basicsr.losses.gan_loss
- class basicsr.losses.gan_loss.GANLoss(gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0)
Bases: Module
Define GAN loss.
- Parameters:
gan_type (str) – Type of GAN loss. Supported values: ‘vanilla’, ‘lsgan’, ‘wgan’, ‘hinge’.
real_label_val (float) – Label value for real samples. Default: 1.0.
fake_label_val (float) – Label value for fake samples. Default: 0.0.
loss_weight (float) – Loss weight. Default: 1.0. The weight is applied to generator losses only; discriminator losses always use a weight of 1.0.
- forward(input, target_is_real, is_disc=False)
- Parameters:
input (Tensor) – The input for the loss module, i.e., the network prediction.
target_is_real (bool) – Whether the target is real or fake.
is_disc (bool) – Whether the loss is computed for the discriminator. Default: False.
- Returns:
GAN loss value.
- Return type:
Tensor
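To make the four `gan_type` options concrete, the following is a minimal pure-Python sketch of the textbook formulations, applied to a flat list of raw discriminator scores. It is not the BasicSR implementation: the helper name `gan_loss_sketch` is hypothetical, and the formulas (binary cross-entropy on logits for ‘vanilla’, mean squared error for ‘lsgan’, signed mean for ‘wgan’, hinge for ‘hinge’) are the standard ones, assumed here rather than confirmed by this page. Only the rule that `loss_weight` is applied to generator losses and not to discriminator losses is taken from the parameter description above.

```python
import math


def gan_loss_sketch(scores, target_is_real, gan_type, is_disc=False,
                    real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
    """Textbook GAN losses on a list of raw scores (hypothetical helper)."""
    n = len(scores)
    label = real_label_val if target_is_real else fake_label_val

    if gan_type == 'vanilla':
        # Numerically stable binary cross-entropy on logits:
        # max(s, 0) - label * s + log(1 + exp(-|s|))
        loss = sum(max(s, 0.0) - label * s + math.log1p(math.exp(-abs(s)))
                   for s in scores) / n
    elif gan_type == 'lsgan':
        # Mean squared error against the label value.
        loss = sum((s - label) ** 2 for s in scores) / n
    elif gan_type == 'wgan':
        # Signed mean: push real scores up and fake scores down.
        mean = sum(scores) / n
        loss = -mean if target_is_real else mean
    elif gan_type == 'hinge':
        if is_disc:
            # Discriminator: relu(1 - s) for real, relu(1 + s) for fake.
            signed = [-s if target_is_real else s for s in scores]
            loss = sum(max(0.0, 1.0 + s) for s in signed) / n
        else:
            # Generator: maximize the mean score.
            loss = -sum(scores) / n
    else:
        raise ValueError(f'unsupported gan_type: {gan_type}')

    # Assumption taken from the docs: the weight is skipped for discriminators.
    return loss if is_disc else loss * loss_weight
```

Calling the helper with the same scores and `loss_weight=0.1` returns the unweighted value when `is_disc=True` and one tenth of it when `is_disc=False`, which mirrors the note on `loss_weight`.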
- get_target_label(input, target_is_real)
Get target label.
- Parameters:
input (Tensor) – Input tensor.
target_is_real (bool) – Whether the target is real or fake.
- Returns:
The target label: a bool for ‘wgan’, otherwise a Tensor.
- Return type:
(bool | Tensor)
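The two return types can be illustrated with a small pure-Python sketch. The helper name `target_label_sketch` is hypothetical, a plain list stands in for a Tensor, and the assumption that the non-WGAN label has the same shape as the input and is filled with the real or fake label value is not stated on this page.

```python
def target_label_sketch(scores, target_is_real, gan_type,
                        real_label_val=1.0, fake_label_val=0.0):
    """Return a bool for 'wgan', otherwise one label per score (hypothetical)."""
    if gan_type == 'wgan':
        # WGAN only needs the sign of the loss, so the flag passes through.
        return target_is_real
    value = real_label_val if target_is_real else fake_label_val
    # Assumed stand-in for a tensor shaped like the input and filled with `value`.
    return [value] * len(scores)
```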
- training: bool
- class basicsr.losses.gan_loss.MultiScaleGANLoss(gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0)
Bases: GANLoss
A GANLoss variant that accepts a list of predictions.
- forward(input, target_is_real, is_disc=False)
The input is a list of tensors, or a list of lists of tensors.
- training: bool
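One plausible way to handle the two input shapes accepted by `MultiScaleGANLoss.forward` is sketched below in pure Python. Everything here is an assumption made for illustration: the helper names are hypothetical, flat lists of floats stand in for tensors, a least-squares loss stands in for the parent `GANLoss.forward`, the per-scale losses are averaged, and for a nested entry only the last item is scored (on the guess that earlier items are intermediate features). None of these choices is confirmed by this page.

```python
def lsgan_loss(scores, target_is_real):
    """Least-squares loss for one scale; stands in for GANLoss.forward."""
    label = 1.0 if target_is_real else 0.0
    return sum((s - label) ** 2 for s in scores) / len(scores)


def multi_scale_loss_sketch(preds, target_is_real):
    """Average a per-scale loss over a list of predictions (hypothetical)."""
    total = 0.0
    for pred in preds:
        if pred and isinstance(pred[0], list):
            # Nested entry: assume the final prediction is the last item.
            pred = pred[-1]
        total += lsgan_loss(pred, target_is_real)
    # Assumed reduction: mean over scales.
    return total / len(preds)
```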
- basicsr.losses.gan_loss.gradient_penalty_loss(discriminator, real_data, fake_data, weight=None)
Calculate the gradient penalty for WGAN-GP.
- Parameters:
discriminator (nn.Module) – Network for the discriminator.
real_data (Tensor) – Real input data.
fake_data (Tensor) – Fake input data.
weight (Tensor) – Weight tensor. Default: None.
- Returns:
A tensor for gradient penalty.
- Return type:
Tensor
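The WGAN-GP penalty is usually written as the mean of (||∇D(x̂)||₂ − 1)², where x̂ is a random interpolation between a real and a fake sample. The sketch below illustrates that formula using only the standard library. It is not the BasicSR function: the helper names are hypothetical, samples are flat lists of floats, the discriminator is any Python callable returning a scalar, central finite differences replace autograd, and the `weight` argument is left out.

```python
import math
import random


def numerical_gradient(f, x, eps=1e-5):
    """Central-difference gradient of a scalar function f at point x (a list)."""
    grad = []
    for i in range(len(x)):
        hi = x[:i] + [x[i] + eps] + x[i + 1:]
        lo = x[:i] + [x[i] - eps] + x[i + 1:]
        grad.append((f(hi) - f(lo)) / (2 * eps))
    return grad


def gradient_penalty_sketch(discriminator, real_batch, fake_batch, rng=random):
    """Mean of (||grad D(x_hat)||_2 - 1)^2 over random interpolates."""
    penalties = []
    for real, fake in zip(real_batch, fake_batch):
        alpha = rng.random()  # one mixing coefficient per sample
        x_hat = [alpha * r + (1 - alpha) * f for r, f in zip(real, fake)]
        grad = numerical_gradient(discriminator, x_hat)
        norm = math.sqrt(sum(g * g for g in grad))
        penalties.append((norm - 1.0) ** 2)
    return sum(penalties) / len(penalties)
```

For a linear discriminator the gradient is the same everywhere, so the result does not depend on the random interpolation: a weight vector with norm 1 gives a penalty near zero, and one with norm 5 gives a penalty near 16.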
- basicsr.losses.gan_loss.r1_penalty(real_pred, real_img)
R1 regularization for the discriminator. The core idea is to penalize the gradient on real data alone: when the generator distribution produces the true data distribution and the discriminator is equal to 0 on the data manifold, the gradient penalty ensures that the discriminator cannot create a non-zero gradient orthogonal to the data manifold without suffering a loss in the GAN game.
Reference: Eq. 9 in “Which Training Methods for GANs do actually Converge?”.
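As a rough illustration of the idea, the sketch below computes the mean squared gradient norm of a discriminator at real samples, using only the standard library. It differs from `r1_penalty` in several assumed ways: the helper names are hypothetical, it takes a callable discriminator rather than a prediction and an image tensor, central finite differences replace autograd, and any scaling coefficient is left to the caller.

```python
def numerical_gradient(f, x, eps=1e-5):
    """Central-difference gradient of a scalar function f at point x (a list)."""
    grad = []
    for i in range(len(x)):
        hi = x[:i] + [x[i] + eps] + x[i + 1:]
        lo = x[:i] + [x[i] - eps] + x[i + 1:]
        grad.append((f(hi) - f(lo)) / (2 * eps))
    return grad


def r1_penalty_sketch(discriminator, real_batch):
    """Mean over real samples of the squared gradient norm (hypothetical)."""
    total = 0.0
    for real in real_batch:
        grad = numerical_gradient(discriminator, real)
        total += sum(g * g for g in grad)
    return total / len(real_batch)
```

Unlike the WGAN-GP penalty, no fake data or interpolation is involved, and the gradient norm is pushed toward zero rather than toward one.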