basicsr.losses.gan_loss

class basicsr.losses.gan_loss.GANLoss(gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0)[source]

Bases: Module

Define GAN loss.

Parameters:
  • gan_type (str) – Support ‘vanilla’, ‘lsgan’, ‘wgan’, ‘hinge’.

  • real_label_val (float) – The value for real label. Default: 1.0.

  • fake_label_val (float) – The value for fake label. Default: 0.0.

  • loss_weight (float) – Loss weight. Default: 1.0. Note that loss_weight is only for generators; and it is always 1.0 for discriminators.

forward(input, target_is_real, is_disc=False)[source]
Parameters:
  • input (Tensor) – The input for the loss module, i.e., the network prediction.

  • target_is_real (bool) – Whether the target is real or fake.

  • is_disc (bool) – Whether the loss is for discriminators. Default: False.

Returns:

GAN loss value.

Return type:

Tensor
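The behavior described above can be sketched as follows. This is a minimal illustrative re-implementation for the 'vanilla' and 'lsgan' types only, not basicsr's actual code; the class name SimpleGANLoss is hypothetical, and the real implementation also handles 'wgan' and 'hinge'.

```python
import torch
import torch.nn as nn

class SimpleGANLoss(nn.Module):
    """Illustrative sketch of the documented GANLoss behavior (not basicsr's code)."""

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super().__init__()
        self.gan_type = gan_type
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val
        self.loss_weight = loss_weight
        if gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            raise NotImplementedError(f'Unsupported gan_type in this sketch: {gan_type}')

    def forward(self, input, target_is_real, is_disc=False):
        val = self.real_label_val if target_is_real else self.fake_label_val
        target = input.new_full(input.size(), val)
        loss = self.loss(input, target)
        # Per the docs, loss_weight applies only to the generator (is_disc=False);
        # the discriminator always uses weight 1.0.
        return loss if is_disc else loss * self.loss_weight

pred = torch.zeros(4, 1)
criterion = SimpleGANLoss('lsgan', loss_weight=0.5)
g_loss = criterion(pred, True)                 # generator branch: weighted
d_loss = criterion(pred, True, is_disc=True)   # discriminator branch: unweighted
```

Note how the same call is weighted differently depending on `is_disc`, matching the note on `loss_weight` above.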

get_target_label(input, target_is_real)[source]

Get target label.

Parameters:
  • input (Tensor) – Input tensor.

  • target_is_real (bool) – Whether the target is real or fake.

Returns:

Target tensor. Returns bool for ‘wgan’; otherwise, returns a Tensor.

Return type:

(bool | Tensor)
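The bool-vs-Tensor return described above can be sketched like this; the helper below is a hypothetical standalone version for illustration, not basicsr's method.

```python
import torch

def get_target_label(input, target_is_real, gan_type='lsgan',
                     real_label_val=1.0, fake_label_val=0.0):
    # For wgan the prediction is compared against its sign directly,
    # so the "label" is just the boolean flag rather than a tensor.
    if gan_type == 'wgan':
        return target_is_real
    val = real_label_val if target_is_real else fake_label_val
    # Otherwise, build a tensor of the label value matching the input shape.
    return input.new_full(input.size(), val)

x = torch.zeros(2, 3)
wgan_label = get_target_label(x, True, gan_type='wgan')   # a bool
lsgan_label = get_target_label(x, True)                   # a (2, 3) tensor of ones
```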

training: bool

class basicsr.losses.gan_loss.MultiScaleGANLoss(gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0)[source]

Bases: GANLoss

MultiScaleGANLoss accepts a list of predictions.

forward(input, target_is_real, is_disc=False)[source]

The input is a list of tensors, or a list of lists of tensors (one entry per scale).

training: bool
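A minimal sketch of the multi-scale aggregation described above, assuming the common convention that the single-scale loss is averaged over scales and that, when an entry is itself a list of intermediate feature maps, only the final prediction is scored; the actual basicsr implementation may differ in these details.

```python
import torch

def multi_scale_gan_loss(single_scale_loss, input, target_is_real, is_disc=False):
    """Average a single-scale GAN loss over a list of per-scale predictions (sketch)."""
    total = 0.0
    for pred in input:
        # An entry may itself be a list of intermediate discriminator
        # outputs; score only the final prediction (assumed convention).
        if isinstance(pred, list):
            pred = pred[-1]
        total = total + single_scale_loss(pred, target_is_real, is_disc)
    return total / len(input)

# A toy lsgan-style single-scale loss standing in for GANLoss.
def lsgan(pred, target_is_real, is_disc=False):
    target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
    return torch.mean((pred - target) ** 2)

# One plain prediction and one list of (feature maps..., final prediction).
preds = [torch.zeros(1, 1, 4, 4), [torch.zeros(1, 8, 8, 8), torch.zeros(1, 1, 2, 2)]]
loss = multi_scale_gan_loss(lsgan, preds, target_is_real=True)
```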
basicsr.losses.gan_loss.g_path_regularize(fake_img, latents, mean_path_length, decay=0.01)[source]
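This function carries no docstring; its signature matches StyleGAN2-style path-length regularization, which penalizes deviation of the generator's Jacobian norm from a running mean. The sketch below is an assumed implementation of that technique for latents of shape (batch, latent_dim), not basicsr's actual code.

```python
import torch

def path_length_penalty(fake_img, latents, mean_path_length, decay=0.01):
    # Scale unit noise so the expected gradient magnitude does not grow
    # with image resolution (StyleGAN2-style; assumed behavior).
    noise = torch.randn_like(fake_img) / (fake_img.shape[2] * fake_img.shape[3]) ** 0.5
    grad, = torch.autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)
    path_lengths = grad.pow(2).sum(dim=1).sqrt()
    # Exponential moving average of the observed path length, controlled by `decay`.
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    penalty = (path_lengths - path_mean).pow(2).mean()
    return penalty, path_mean.detach(), path_lengths

latents = torch.randn(2, 8, requires_grad=True)
fake_img = (latents @ torch.randn(8, 4)).view(2, 1, 2, 2)  # toy generator output
penalty, path_mean, _ = path_length_penalty(fake_img, latents, mean_path_length=0.0)
```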
basicsr.losses.gan_loss.gradient_penalty_loss(discriminator, real_data, fake_data, weight=None)[source]

Calculate gradient penalty for wgan-gp.

Parameters:
  • discriminator (nn.Module) – Network for the discriminator.

  • real_data (Tensor) – Real input data.

  • fake_data (Tensor) – Fake input data.

  • weight (Tensor) – Weight tensor. Default: None.

Returns:

A tensor for gradient penalty.

Return type:

Tensor
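The WGAN-GP penalty described above can be sketched as follows: evaluate the discriminator on a random interpolate between real and fake samples, and penalize the deviation of the input gradient norm from 1. This is an illustrative implementation of the standard technique, not basicsr's code; the optional `weight` handling is an assumption.

```python
import torch

def gradient_penalty(discriminator, real_data, fake_data, weight=None):
    """WGAN-GP penalty on a random interpolate between real and fake data (sketch)."""
    batch = real_data.size(0)
    # One interpolation coefficient per sample, broadcast over (C, H, W).
    alpha = torch.rand(batch, 1, 1, 1, device=real_data.device)
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    d_out = discriminator(interpolates)
    grad, = torch.autograd.grad(
        outputs=d_out.sum(), inputs=interpolates, create_graph=True)
    if weight is not None:
        grad = grad * weight  # assumed element-wise weighting
    grad_norm = grad.view(batch, -1).norm(2, dim=1)
    # Penalize deviation of the gradient norm from 1 (the Lipschitz target).
    return ((grad_norm - 1) ** 2).mean()

# A linear "discriminator": its gradient is 1 everywhere, so each per-sample
# gradient norm over 4 elements is sqrt(4) = 2 and the penalty is (2 - 1)^2 = 1.
d = lambda x: x.sum(dim=(1, 2, 3))
gp = gradient_penalty(d, torch.zeros(2, 1, 2, 2), torch.ones(2, 1, 2, 2))
```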

basicsr.losses.gan_loss.r1_penalty(real_pred, real_img)[source]

R1 regularization for discriminator. The core idea is to penalize the gradient on real data alone: when the generator distribution produces the true data distribution and the discriminator is equal to 0 on the data manifold, the gradient penalty ensures that the discriminator cannot create a non-zero gradient orthogonal to the data manifold without suffering a loss in the GAN game.

Reference: Eq. 9 in “Which Training Methods for GANs do actually Converge?”.
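The penalty described above reduces to the squared L2 norm of the discriminator's gradient on real images. The sketch below is an assumed implementation of that formula for illustration; the function name is hypothetical and the real basicsr version may scale or reduce differently.

```python
import torch

def r1_penalty_sketch(real_pred, real_img):
    # Squared L2 norm of grad_x D(x), computed on real images only (Eq. 9).
    grad, = torch.autograd.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True)
    return grad.pow(2).view(grad.size(0), -1).sum(1).mean()

real_img = torch.ones(2, 1, 2, 2, requires_grad=True)
# Toy discriminator whose gradient is 3 at every pixel, so each per-sample
# squared-gradient sum is 9 * 4 = 36 and the batch mean is 36.
real_pred = (real_img * 3.0).sum(dim=(1, 2, 3))
penalty = r1_penalty_sketch(real_pred, real_img)
```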