basicsr.archs.edvr_arch

class basicsr.archs.edvr_arch.EDVR(num_in_ch=3, num_out_ch=3, num_feat=64, num_frame=5, deformable_groups=8, num_extract_block=5, num_reconstruct_block=10, center_frame_idx=None, hr_in=False, with_predeblur=False, with_tsa=True)[source]

Bases: Module

EDVR network structure for video super-resolution.

Currently, only the x4 upsampling factor is supported.

Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks

Parameters:
  • num_in_ch (int) – Channel number of input image. Default: 3.

  • num_out_ch (int) – Channel number of output image. Default: 3.

  • num_feat (int) – Channel number of intermediate features. Default: 64.

  • num_frame (int) – Number of input frames. Default: 5.

  • deformable_groups (int) – Number of deformable groups. Default: 8.

  • num_extract_block (int) – Number of blocks for feature extraction. Default: 5.

  • num_reconstruct_block (int) – Number of blocks for reconstruction. Default: 10.

  • center_frame_idx (int) – Index of the center frame, counting from 0. Default: the middle of the input frames.

  • hr_in (bool) – Whether the input has high resolution. Default: False.

  • with_predeblur (bool) – Whether to use the predeblur module. Default: False.

  • with_tsa (bool) – Whether to use the TSA module. Default: True.
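The interaction between num_frame, center_frame_idx, hr_in and the fixed x4 factor can be sketched with two small helpers. This is a plain-Python illustration, not code from basicsr: the helper names are hypothetical, the input layout (b, t, c, h, w) is assumed, "middle of input frames" is read as num_frame // 2, and the output size for hr_in=True (same as the input) is an assumption about how high-resolution inputs are handled.

```python
from typing import Optional, Tuple


def resolve_center_frame_idx(num_frame: int, center_frame_idx: Optional[int] = None) -> int:
    """Return the 0-based center frame index (hypothetical helper).

    Assumption: when no index is given, the middle frame num_frame // 2 is used.
    """
    if center_frame_idx is None:
        return num_frame // 2
    if not 0 <= center_frame_idx < num_frame:
        raise ValueError("center_frame_idx must lie in [0, num_frame)")
    return center_frame_idx


def expected_output_shape(
    input_shape: Tuple[int, int, int, int, int],
    num_out_ch: int = 3,
    hr_in: bool = False,
    scale: int = 4,
) -> Tuple[int, int, int, int]:
    """Estimate the output shape for an input of shape (b, t, c, h, w).

    Assumption: with hr_in=False the center frame is upsampled by `scale`;
    with hr_in=True the output keeps the input's spatial size.
    """
    b, _t, _c, h, w = input_shape
    if hr_in:
        return (b, num_out_ch, h, w)
    return (b, num_out_ch, h * scale, w * scale)
```

For example, five 64x64 low-resolution frames would be expected to yield one 256x256 output frame under these assumptions, centered on frame index 2.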

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass must be defined within this function, call the Module instance rather than forward() directly: calling the instance runs the registered hooks, while calling forward() silently ignores them.
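The difference between calling the instance and calling forward() can be illustrated with a toy class. TinyModule below is a hypothetical stand-in written in plain Python; it only mimics the idea of forward hooks and is not PyTorch's implementation.

```python
class TinyModule:
    """Toy stand-in for a module with forward hooks (not PyTorch's code)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, x):
        # The actual computation lives here.
        return x * 2

    def __call__(self, x):
        # Calling the instance runs forward() and then every registered hook.
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, x, out)
        return out


calls = []
m = TinyModule()
m.register_forward_hook(lambda module, inp, out: calls.append(out))

m(3)          # goes through __call__, so the hook fires
m.forward(3)  # bypasses __call__, so the hook is skipped
```

After both calls, the hook has recorded a single output, the one produced by calling the instance.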

training: bool

class basicsr.archs.edvr_arch.PCDAlignment(num_feat=64, deformable_groups=8)[source]

Bases: Module

Alignment module using Pyramid, Cascading and Deformable convolution (PCD). It is used in EDVR.

Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks

Parameters:
  • num_feat (int) – Channel number of middle features. Default: 64.

  • deformable_groups (int) – Number of deformable groups. Default: 8.

forward(nbr_feat_l, ref_feat_l)[source]

Align neighboring frame features to the reference frame features.

Parameters:
  • nbr_feat_l (list[Tensor]) – Neighboring feature list. It contains three pyramid levels (L1, L2, L3), each with shape (b, c, h, w).

  • ref_feat_l (list[Tensor]) – Reference feature list. It contains three pyramid levels (L1, L2, L3), each with shape (b, c, h, w).

Returns:

Aligned features.

Return type:

Tensor
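As a rough guide to the shapes of the two feature lists, the helper below lists per-level shapes for a three-level pyramid. It is a sketch under an assumption not stated in the docstring: each level is taken to halve the spatial resolution of the previous one, so h and w differ per level. The function name is hypothetical.

```python
def pyramid_shapes(b, c, h, w, num_levels=3):
    """List (b, c, h, w) shapes for pyramid levels L1..L{num_levels}.

    Assumption: every level halves the height and width of the level above,
    so the L1 size must be divisible by 2 ** (num_levels - 1).
    """
    factor = 2 ** (num_levels - 1)
    if h % factor or w % factor:
        raise ValueError(f"h and w must be divisible by {factor}")
    return [(b, c, h // 2**i, w // 2**i) for i in range(num_levels)]


# Shapes of nbr_feat_l / ref_feat_l for a 32x48 L1 feature map.
shapes = pyramid_shapes(2, 64, 32, 48)
```

Under this assumption, both nbr_feat_l and ref_feat_l would hold tensors with these three shapes, in L1, L2, L3 order.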

training: bool

class basicsr.archs.edvr_arch.PredeblurModule(num_in_ch=3, num_feat=64, hr_in=False)[source]

Bases: Module

Pre-deblur module.

Parameters:
  • num_in_ch (int) – Channel number of input image. Default: 3.

  • num_feat (int) – Channel number of intermediate features. Default: 64.

  • hr_in (bool) – Whether the input has high resolution. Default: False.

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass must be defined within this function, call the Module instance rather than forward() directly: calling the instance runs the registered hooks, while calling forward() silently ignores them.

training: bool

class basicsr.archs.edvr_arch.TSAFusion(num_feat=64, num_frame=5, center_frame_idx=2)[source]

Bases: Module

Temporal Spatial Attention (TSA) fusion module.

Temporal: Calculate the correlation between the center frame and neighboring frames.

Spatial: It has 3 pyramid levels; the attention is similar to SFT (SFT: Recovering realistic texture in image super-resolution by deep spatial feature transform).

Parameters:
  • num_feat (int) – Channel number of middle features. Default: 64.

  • num_frame (int) – Number of frames. Default: 5.

  • center_frame_idx (int) – The index of center frame. Default: 2.

forward(aligned_feat)[source]

Parameters:

aligned_feat (Tensor) – Aligned features with shape (b, t, c, h, w).

Returns:

Features after TSA with the shape (b, c, h, w).

Return type:

Tensor
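The temporal part of the attention can be sketched for a single spatial location in plain Python. This is a simplified illustration, not the module's code: it assumes the correlation is a channel-wise dot product between each frame's feature vector and the center frame's, squashed by a sigmoid, and it leaves out any learned embedding or fusion convolutions. Function names are hypothetical.

```python
import math


def temporal_attention_weights(frame_feats, center_frame_idx):
    """Compute one attention weight per frame at a single pixel.

    frame_feats: list of t feature vectors, each of length c.
    Assumption: weight = sigmoid(dot(frame, center)), so every weight
    lies in (0, 1) and frames similar to the center get larger weights.
    """
    ref = frame_feats[center_frame_idx]
    weights = []
    for feat in frame_feats:
        corr = sum(a * b for a, b in zip(feat, ref))   # dot product over channels
        weights.append(1.0 / (1.0 + math.exp(-corr)))  # sigmoid
    return weights


def apply_temporal_attention(frame_feats, center_frame_idx):
    """Scale each frame's feature vector by its attention weight."""
    weights = temporal_attention_weights(frame_feats, center_frame_idx)
    return [[w * v for v in feat] for w, feat in zip(weights, frame_feats)]


# Three frames, two channels; the middle frame is the center.
feats = [[1.0, 0.0], [0.0, 1.0], [0.0, -1.0]]
weights = temporal_attention_weights(feats, center_frame_idx=1)
weighted = apply_temporal_attention(feats, center_frame_idx=1)
```

In this toy input, the frame orthogonal to the center gets a neutral weight of 0.5, the center frame itself gets the largest weight, and the frame pointing the opposite way gets the smallest.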

training: bool