basicsr.archs.vgg_arch

class basicsr.archs.vgg_arch.VGGFeatureExtractor(layer_name_list, vgg_type='vgg19', use_input_norm=True, range_norm=False, requires_grad=False, remove_pooling=False, pooling_stride=2)[source]

Bases: Module

VGG network for feature extraction.

In this implementation, users can choose whether to normalize the input and which type of VGG network to use. Note that the pretrained weights path must match the chosen VGG type.

Parameters:
  • layer_name_list (list[str]) – Names of the layers whose features the forward function returns. Example: ['relu1_1', 'relu2_1', 'relu3_1'].

  • vgg_type (str) – Type of VGG network. Default: 'vgg19'.

  • use_input_norm (bool) – If True, normalize the input image. Importantly, the input must be in the range [0, 1]. Default: True.

  • range_norm (bool) – If True, rescale input images from the range [-1, 1] to [0, 1]. Default: False.

  • requires_grad (bool) – If True, the parameters of the VGG network will be optimized. Default: False.

  • remove_pooling (bool) – If True, the max pooling operations in the VGG network will be removed. Default: False.

  • pooling_stride (int) – The stride of max pooling operation. Default: 2.
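
The two input options can be illustrated on a single RGB pixel. This is a minimal sketch, not the library's code: the helper name preprocess_pixel is hypothetical, the mean and standard deviation are the commonly used ImageNet statistics (an assumption, since this reference does not state which values are used), and applying range_norm before use_input_norm is inferred from the requirement that the normalized input lie in [0, 1].

```python
# ImageNet channel statistics (assumed; not stated in this reference).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def preprocess_pixel(rgb, use_input_norm=True, range_norm=False):
    """Apply the two optional input steps to one RGB pixel (hypothetical helper)."""
    values = list(rgb)
    if range_norm:
        # Map values from [-1, 1] to [0, 1].
        values = [(v + 1.0) / 2.0 for v in values]
    if use_input_norm:
        # Standardize each channel; expects values already in [0, 1].
        values = [
            (v - m) / s
            for v, m, s in zip(values, IMAGENET_MEAN, IMAGENET_STD)
        ]
    return values


# A pixel stored in [-1, 1] is rescaled first, then standardized.
print(preprocess_pixel((0.0, 0.0, 0.0), use_input_norm=True, range_norm=True))
```

If the images are already in [0, 1], range_norm can stay at its default of False so that only the per-channel standardization is applied.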

forward(x)[source]

Forward function.

Parameters:

x (Tensor) – Input tensor with shape (n, c, h, w).

Returns:

The features extracted at the layers named in layer_name_list.

Return type:

dict[str, Tensor]
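
The idea of returning features "according to the layer_name_list" can be sketched without any deep learning dependency. The helper collect_features and the toy layers below are hypothetical stand-ins, and gathering the results in a dict keyed by layer name is a choice made for this sketch; it only shows the pattern of running the input through the layers in order and keeping the requested intermediate outputs.

```python
def collect_features(x, named_layers, layer_name_list):
    """Run x through named_layers in order and keep the requested outputs.

    named_layers: list of (name, callable) pairs standing in for VGG layers.
    """
    features = {}
    for name, layer in named_layers:
        x = layer(x)
        if name in layer_name_list:
            # Keep the intermediate result produced by this layer.
            features[name] = x
    return features


# Toy stand-ins for real layers: each one just transforms a number.
toy_layers = [
    ("conv1_1", lambda v: v * 2),
    ("relu1_1", lambda v: max(v, 0)),
    ("conv1_2", lambda v: v - 10),
    ("relu1_2", lambda v: max(v, 0)),
]

features = collect_features(3, toy_layers, ["relu1_1", "relu1_2"])
print(features)
```

With a real network, x would be a tensor of shape (n, c, h, w) and each kept value would be the feature map produced by the named layer.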

training: bool

basicsr.archs.vgg_arch.insert_bn(names)[source]

Insert a batch normalization (bn) layer name after each conv layer name.

Parameters:

names (list) – The list of layer names.

Returns:

The list of layer names with bn layers.

Return type:

list
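
A rough sketch of what such a name transformation could look like follows. It is an illustration under stated assumptions, not the library's source: the function name insert_bn_sketch is hypothetical, and it assumes that conv layers are named like 'conv1_1' and that the matching bn layer reuses the same suffix ('bn1_1').

```python
def insert_bn_sketch(names):
    """Return a copy of names with a bn entry after every conv entry.

    Assumes conv layers are named 'conv<suffix>' and that the matching
    bn layer is named 'bn<suffix>' (naming scheme assumed for this sketch).
    """
    names_with_bn = []
    for name in names:
        names_with_bn.append(name)
        if name.startswith("conv"):
            # Reuse the conv suffix so that 'conv1_1' yields 'bn1_1'.
            names_with_bn.append("bn" + name[len("conv"):])
    return names_with_bn


layer_names = ["conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1"]
print(insert_bn_sketch(layer_names))
```

Non-conv entries such as relu and pool names pass through unchanged, so the relative order of the original layers is preserved.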