basicsr.losses.basic_loss

class basicsr.losses.basic_loss.CharbonnierLoss(loss_weight=1.0, reduction='mean', eps=1e-12)[source]

Bases: Module

Charbonnier loss (one variant of Robust L1Loss, a differentiable variant of L1Loss).

Described in “Deep Laplacian Pyramid Networks for Fast and Accurate Super-Resolution”.

Parameters:
  • loss_weight (float) – Loss weight for Charbonnier loss. Default: 1.0.

  • reduction (str) – Specifies the reduction to apply to the output. Supported choices are ‘none’ | ‘mean’ | ‘sum’. Default: ‘mean’.

  • eps (float) – A value used to control the curvature near zero. Default: 1e-12.

forward(pred, target, weight=None, **kwargs)[source]
Parameters:
  • pred (Tensor) – of shape (N, C, H, W). Predicted tensor.

  • target (Tensor) – of shape (N, C, H, W). Ground truth tensor.

  • weight (Tensor, optional) – of shape (N, C, H, W). Element-wise weights. Default: None.
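
As a rough illustration of how eps shapes the loss near zero, the sketch below evaluates sqrt((pred - target)^2 + eps) element by element. It uses plain Python lists of floats instead of tensors to stay dependency-free, and the helper name charbonnier is hypothetical; treat it as an assumption about the formula, not as the library's code.

```python
import math


def charbonnier(pred, target, eps=1e-12, reduction="mean"):
    """Hypothetical helper: element-wise sqrt((pred - target)^2 + eps)."""
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(f"Unsupported reduction: {reduction}")
    losses = [math.sqrt((p - t) ** 2 + eps) for p, t in zip(pred, target)]
    if reduction == "none":
        return losses
    total = sum(losses)
    return total / len(losses) if reduction == "mean" else total


pred = [0.0, 0.5, 1.0]
target = [0.0, 0.0, 0.0]

# With zero error the loss bottoms out at sqrt(eps) instead of 0,
# which keeps the function smooth where plain L1 has a kink.
per_element = charbonnier(pred, target, reduction="none")
mean_loss = charbonnier(pred, target, reduction="mean")
```

For errors much larger than sqrt(eps) the value is practically identical to the absolute error, so eps only matters close to zero.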

training: bool
class basicsr.losses.basic_loss.L1Loss(loss_weight=1.0, reduction='mean')[source]

Bases: Module

L1 (mean absolute error, MAE) loss.

Parameters:
  • loss_weight (float) – Loss weight for L1 loss. Default: 1.0.

  • reduction (str) – Specifies the reduction to apply to the output. Supported choices are ‘none’ | ‘mean’ | ‘sum’. Default: ‘mean’.

forward(pred, target, weight=None, **kwargs)[source]
Parameters:
  • pred (Tensor) – of shape (N, C, H, W). Predicted tensor.

  • target (Tensor) – of shape (N, C, H, W). Ground truth tensor.

  • weight (Tensor, optional) – of shape (N, C, H, W). Element-wise weights. Default: None.
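
The three reduction choices and the loss_weight scaling can be sketched as follows. This is a minimal, dependency-free illustration on flat lists of floats; the function name l1 is hypothetical and the library itself operates on tensors.

```python
def l1(pred, target, loss_weight=1.0, reduction="mean"):
    """Hypothetical helper: absolute error with 'none' | 'mean' | 'sum'."""
    diffs = [abs(p - t) for p, t in zip(pred, target)]
    if reduction == "none":
        return [loss_weight * d for d in diffs]  # one value per element
    if reduction == "sum":
        return loss_weight * sum(diffs)
    if reduction == "mean":
        return loss_weight * sum(diffs) / len(diffs)
    raise ValueError(f"Unsupported reduction: {reduction}")


pred = [1.0, 2.0, 4.0, 0.0]
target = [1.5, 2.0, 2.0, 1.5]

unreduced = l1(pred, target, reduction="none")   # [0.5, 0.0, 2.0, 1.5]
summed = l1(pred, target, reduction="sum")       # 4.0
averaged = l1(pred, target, reduction="mean")    # 1.0
scaled = l1(pred, target, loss_weight=2.0)       # 2.0
```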

training: bool
class basicsr.losses.basic_loss.MSELoss(loss_weight=1.0, reduction='mean')[source]

Bases: Module

MSE (L2) loss.

Parameters:
  • loss_weight (float) – Loss weight for MSE loss. Default: 1.0.

  • reduction (str) – Specifies the reduction to apply to the output. Supported choices are ‘none’ | ‘mean’ | ‘sum’. Default: ‘mean’.

forward(pred, target, weight=None, **kwargs)[source]
Parameters:
  • pred (Tensor) – of shape (N, C, H, W). Predicted tensor.

  • target (Tensor) – of shape (N, C, H, W). Ground truth tensor.

  • weight (Tensor, optional) – of shape (N, C, H, W). Element-wise weights. Default: None.
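
The effect of the optional element-wise weight can be sketched as below. The example assumes that, under 'mean' reduction, the weighted sum is divided by the sum of the weights instead of the element count; that convention is an assumption of this sketch, as is the helper name weighted_mse. Plain lists of floats stand in for tensors.

```python
def weighted_mse(pred, target, weight=None, reduction="mean"):
    """Hypothetical helper: squared error with optional element-wise weights."""
    squared = [(p - t) ** 2 for p, t in zip(pred, target)]
    if weight is None:
        weight = [1.0] * len(squared)
    weighted = [w * s for w, s in zip(weight, squared)]
    if reduction == "none":
        return weighted
    if reduction == "sum":
        return sum(weighted)
    if reduction == "mean":
        # Assumed convention: normalize by the total weight (must be > 0).
        return sum(weighted) / sum(weight)
    raise ValueError(f"Unsupported reduction: {reduction}")


pred = [1.0, 2.0, 3.0]
target = [0.0, 2.0, 5.0]

plain = weighted_mse(pred, target)                            # (1 + 0 + 4) / 3
masked = weighted_mse(pred, target, weight=[1.0, 1.0, 0.0])   # (1 + 0) / 2
```

A weight of 0 removes an element from the loss entirely, which is how a mask can be expressed through this argument.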

training: bool
class basicsr.losses.basic_loss.PerceptualLoss(layer_weights, vgg_type='vgg19', use_input_norm=True, range_norm=False, perceptual_weight=1.0, style_weight=0.0, criterion='l1')[source]

Bases: Module

Perceptual loss with commonly used style loss.

Parameters:
  • layer_weights (dict) – The weight for each layer of vgg feature. Here is an example: {‘conv5_4’: 1.}, which means the conv5_4 feature layer (before relu5_4) will be extracted with weight 1.0 in calculating losses.

  • vgg_type (str) – The type of vgg network used as feature extractor. Default: ‘vgg19’.

  • use_input_norm (bool) – If True, normalize the input image in vgg. Default: True.

  • range_norm (bool) – If True, rescale images from the range [-1, 1] to [0, 1]. Default: False.

  • perceptual_weight (float) – If perceptual_weight > 0, the perceptual loss is calculated and multiplied by this weight. Default: 1.0.

  • style_weight (float) – If style_weight > 0, the style loss is calculated and multiplied by this weight. Default: 0.

  • criterion (str) – Criterion used for perceptual loss. Default: ‘l1’.
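
How layer_weights and perceptual_weight might combine can be sketched as a weighted sum of per-layer distances. In this sketch the feature maps are made-up flat lists keyed by layer name (a real extractor would produce tensors), the helper names are hypothetical, and the 'l1' criterion is taken to be a mean absolute difference. Skipping the term when the weight is not positive follows the parameter description above; returning None in that case is an assumption.

```python
def l1_mean(a, b):
    """Mean absolute difference between two equally sized flat lists."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def perceptual_term(x_feats, gt_feats, layer_weights, perceptual_weight=1.0):
    """Hypothetical helper: weighted sum of per-layer feature distances."""
    if perceptual_weight <= 0:
        return None  # assumed: the term is not computed at all
    total = 0.0
    for layer, weight in layer_weights.items():
        total += weight * l1_mean(x_feats[layer], gt_feats[layer])
    return perceptual_weight * total


# Made-up activations for two layers, for illustration only.
layer_weights = {"conv4_4": 0.5, "conv5_4": 1.0}
x_feats = {"conv4_4": [1.0, 3.0], "conv5_4": [2.0, 2.0]}
gt_feats = {"conv4_4": [0.0, 1.0], "conv5_4": [2.0, 0.0]}

loss = perceptual_term(x_feats, gt_feats, layer_weights)  # 0.5 * 1.5 + 1.0 * 1.0
```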

forward(x, gt)[source]

Forward function.

Parameters:
  • x (Tensor) – Input tensor with shape (n, c, h, w).

  • gt (Tensor) – Ground-truth tensor with shape (n, c, h, w).

Returns:

Forward results.

Return type:

Tensor
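
A style loss of the commonly used kind compares Gram matrices of the feature maps instead of the features themselves. The sketch below shows that idea on plain Python lists. The helper names, the normalization by c * h * w, and the mean absolute difference between the two Gram matrices are assumptions made for illustration and are not confirmed by this page.

```python
def gram_matrix(feat, h, w):
    """feat: list of C channels, each a flat list of h * w activations."""
    c = len(feat)
    norm = c * h * w  # assumed normalization
    return [
        [sum(a * b for a, b in zip(fi, fj)) / norm for fj in feat]
        for fi in feat
    ]


def style_term(x_feat, gt_feat, h, w):
    """Mean absolute difference between the two Gram matrices."""
    gx = gram_matrix(x_feat, h, w)
    gg = gram_matrix(gt_feat, h, w)
    diffs = [abs(a - b) for rx, rg in zip(gx, gg) for a, b in zip(rx, rg)]
    return sum(diffs) / len(diffs)


# Two channels of a 2x2 feature map, values made up for illustration.
x_feat = [[1.0, 0.0, 1.0, 0.0], [0.0, 2.0, 0.0, 2.0]]
gt_feat = [[1.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]

gram_x = gram_matrix(x_feat, h=2, w=2)    # [[0.25, 0.0], [0.0, 1.0]]
style = style_term(x_feat, gt_feat, 2, 2)  # 0.25
```

Because the Gram matrix sums over spatial positions, it captures which channels fire together and discards where they fire.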

training: bool
class basicsr.losses.basic_loss.WeightedTVLoss(loss_weight=1.0, reduction='mean')[source]

Bases: L1Loss

Weighted TV loss.

Parameters:
  • loss_weight (float) – Loss weight. Default: 1.0.

forward(pred, weight=None)[source]
Parameters:
  • pred (Tensor) – of shape (N, C, H, W). Predicted tensor.

  • weight (Tensor, optional) – of shape (N, C, H, W). Element-wise weights. Default: None.
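
Total variation penalizes differences between neighbouring pixels of pred, which is why no ground truth is needed. The sketch below computes it for a single channel stored as a 2D list. The helper name tv_loss is hypothetical, and summing the horizontal and vertical L1 terms after reducing each separately is an assumption of this sketch.

```python
def tv_loss(img, reduction="mean"):
    """img: 2D list (H rows, W columns), with H >= 2 and W >= 2."""
    h, w = len(img), len(img[0])
    # Absolute differences between vertically adjacent pixels.
    y_diffs = [abs(img[i][j] - img[i + 1][j])
               for i in range(h - 1) for j in range(w)]
    # Absolute differences between horizontally adjacent pixels.
    x_diffs = [abs(img[i][j] - img[i][j + 1])
               for i in range(h) for j in range(w - 1)]
    if reduction == "sum":
        return sum(x_diffs) + sum(y_diffs)
    if reduction == "mean":
        return sum(x_diffs) / len(x_diffs) + sum(y_diffs) / len(y_diffs)
    raise ValueError(f"Unsupported reduction: {reduction}")


img = [[0.0, 1.0],
       [2.0, 4.0]]

tv_mean = tv_loss(img)                   # 1.5 + 2.5
tv_sum = tv_loss(img, reduction="sum")   # 3.0 + 5.0
flat = tv_loss([[1.0, 1.0], [1.0, 1.0]])  # constant image has zero variation
```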

training: bool
basicsr.losses.basic_loss.charbonnier_loss(pred, target, eps=1e-12)[source]
basicsr.losses.basic_loss.l1_loss(pred, target)[source]
basicsr.losses.basic_loss.mse_loss(pred, target)[source]