basicsr.archs.edvr_arch
- class basicsr.archs.edvr_arch.EDVR(num_in_ch=3, num_out_ch=3, num_feat=64, num_frame=5, deformable_groups=8, num_extract_block=5, num_reconstruct_block=10, center_frame_idx=None, hr_in=False, with_predeblur=False, with_tsa=True)
Bases: Module
EDVR network structure for video super-resolution.
Currently only an upsampling factor of x4 is supported.
Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks
- Parameters:
num_in_ch (int) – Channel number of input image. Default: 3.
num_out_ch (int) – Channel number of output image. Default: 3.
num_feat (int) – Channel number of intermediate features. Default: 64.
num_frame (int) – Number of input frames. Default: 5.
deformable_groups (int) – Number of deformable groups in the deformable convolutions. Default: 8.
num_extract_block (int) – Number of blocks for feature extraction. Default: 5.
num_reconstruct_block (int) – Number of blocks for reconstruction. Default: 10.
center_frame_idx (int) – Index of the center frame, counting from 0. Default: None, which selects the middle frame (num_frame // 2).
hr_in (bool) – Whether the input has high resolution. Default: False.
with_predeblur (bool) – Whether to use the pre-deblur module. Default: False.
with_tsa (bool) – Whether to use the TSA fusion module. Default: True.
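The parameter list leaves the input/output shape contract implicit. The pure-Python sketch below spells it out as we understand it from the BasicSR source. It does not call BasicSR, and the helper `edvr_output_shape` is hypothetical. Assumptions: the input has shape (b, t, c, h, w); `center_frame_idx=None` falls back to `num_frame // 2`; h and w must be multiples of 16 when `hr_in=True` and multiples of 4 otherwise; with `hr_in=True` the output keeps the input resolution, and with `hr_in=False` it is 4x larger.

```python
def edvr_output_shape(input_shape, num_out_ch=3, num_frame=5,
                      center_frame_idx=None, hr_in=False):
    """Hypothetical helper: predict (center index, output shape) for EDVR.

    Encodes assumptions read from the BasicSR source; not a BasicSR API.
    """
    b, t, c, h, w = input_shape  # (batch, frames, channels, height, width)
    if t != num_frame:
        raise ValueError(f"expected {num_frame} frames, got {t}")
    # Assumption: center_frame_idx=None selects the middle frame.
    center = num_frame // 2 if center_frame_idx is None else center_frame_idx
    if not 0 <= center < num_frame:
        raise ValueError(f"center_frame_idx {center} is out of range")
    # Assumption: HR input is reduced by 4 internally and the feature pyramid
    # then halves it twice more, hence 16; LR input only needs multiples of 4.
    multiple = 16 if hr_in else 4
    if h % multiple or w % multiple:
        raise ValueError(f"height and width must be multiples of {multiple}")
    # Fixed x4 upsampling; with hr_in=True this restores the input size.
    scale = 1 if hr_in else 4
    return center, (b, num_out_ch, h * scale, w * scale)


center, out_shape = edvr_output_shape((1, 5, 3, 64, 64))
print(center, out_shape)  # 2 (1, 3, 256, 256)
```

For example, five 64x64 LR frames give center index 2 and a 256x256 output. The same frames at 256x256 with `hr_in=True` would give a 256x256 output.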
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class basicsr.archs.edvr_arch.PCDAlignment(num_feat=64, deformable_groups=8)
Bases: Module
Alignment module using Pyramid, Cascading and Deformable convolution (PCD). It is used in EDVR.
Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks
- Parameters:
num_feat (int) – Channel number of middle features. Default: 64.
deformable_groups (int) – Number of deformable groups in the deformable convolutions. Default: 8.
- forward(nbr_feat_l, ref_feat_l)
Align neighboring frame features to the reference frame features.
- Parameters:
nbr_feat_l (list[Tensor]) – Neighboring feature list. It contains three pyramid levels (L1, L2, L3), each with shape (b, c, h, w), where h and w halve from one level to the next.
ref_feat_l (list[Tensor]) – Reference feature list. It contains three pyramid levels (L1, L2, L3), each with shape (b, c, h, w), where h and w halve from one level to the next.
- Returns:
Aligned features with shape (b, c, h, w) at the L1 resolution.
- Return type:
Tensor
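In the BasicSR implementation, as we read it, PCD works coarse to fine. Offsets are first estimated at the smallest level (L3). They are then upsampled by 2 and multiplied by 2 before they help estimate the offsets of the next finer level, because a displacement of one pixel at half resolution spans two pixels at full resolution. The pure-Python sketch below traces only that schedule and rescaling for a made-up offset. It assumes each level halves h and w. It leaves out the convolutions that refine the offset at every level, the deformable convolutions themselves, and the final cascading refinement at L1. `coarse_to_fine` is a hypothetical name.

```python
def coarse_to_fine(h, w, coarse_offset, num_levels=3):
    """Hypothetical helper: trace how one offset is carried from L3 to L1.

    Returns (level, (height, width), offset) in the order PCD visits levels.
    Omits the per-level convolutions that refine the offset in EDVR.
    """
    dy, dx = coarse_offset  # displacement in pixels of the coarsest level
    trace = []
    for level in range(num_levels, 0, -1):  # L3, then L2, then L1
        scale = 2 ** (level - 1)  # assumed: each level halves h and w
        trace.append((level, (h // scale, w // scale), (dy, dx)))
        if level > 1:
            # Moving to a 2x finer grid: the same motion covers twice the pixels.
            dy, dx = 2 * dy, 2 * dx
    return trace


for step in coarse_to_fine(64, 64, (1.0, -0.5)):
    print(step)
# (3, (16, 16), (1.0, -0.5))
# (2, (32, 32), (2.0, -1.0))
# (1, (64, 64), (4.0, -2.0))
```

A one-pixel motion estimated at L3 therefore corresponds to a four-pixel motion at L1. This lets large motions be captured at the coarse level, where they look small.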
- training: bool
- class basicsr.archs.edvr_arch.PredeblurModule(num_in_ch=3, num_feat=64, hr_in=False)
Bases: Module
Pre-deblur module.
- Parameters:
num_in_ch (int) – Channel number of input image. Default: 3.
num_feat (int) – Channel number of intermediate features. Default: 64.
hr_in (bool) – Whether the input has high resolution. Default: False.
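When `hr_in=True`, the input is at high resolution. As we read the BasicSR source, the module first downsamples it by 4 with two stride-2 convolutions and then builds its own three-level pyramid. The sketch below computes the resulting spatial sizes with the standard convolution output-size formula (dilation 1). It assumes every downsampling step is a 3x3 convolution with stride 2 and padding 1. The sizes also suggest why EDVR asks for multiples of 16 in this mode: with other sizes, upsampling a coarse level by 2 would no longer match the finer level it is merged with. `predeblur_level_sizes` is hypothetical.

```python
def conv_out(n, kernel=3, stride=2, padding=1):
    """Output length of a convolution along one axis (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def predeblur_level_sizes(h, w, hr_in=False):
    """Hypothetical helper: spatial sizes of the three pyramid levels.

    Assumes 3x3, stride-2, padding-1 convolutions for every downsampling step.
    """
    if hr_in:
        # Assumed: two stride-2 convolutions first reduce HR input by 4.
        for _ in range(2):
            h, w = conv_out(h), conv_out(w)
    l1 = (h, w)
    l2 = (conv_out(l1[0]), conv_out(l1[1]))
    l3 = (conv_out(l2[0]), conv_out(l2[1]))
    return l1, l2, l3


print(predeblur_level_sizes(64, 64, hr_in=True))    # ((16, 16), (8, 8), (4, 4))
print(predeblur_level_sizes(100, 100, hr_in=True))  # ((25, 25), (13, 13), (7, 7))
```

With a 100x100 HR input, upsampling the 7x7 L3 features by 2 gives 14x14, which does not match the 13x13 L2 features.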
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class basicsr.archs.edvr_arch.TSAFusion(num_feat=64, num_frame=5, center_frame_idx=2)
Bases: Module
Temporal Spatial Attention (TSA) fusion module.
- Temporal: Calculates the correlation between the center frame and each neighboring frame.
- Spatial: Uses 3 pyramid levels; the attention is similar to SFT (SFT: Recovering Realistic Texture in Image Super-Resolution by Deep Spatial Feature Transform).
- Parameters:
num_feat (int) – Channel number of middle features. Default: 64.
num_frame (int) – Number of frames. Default: 5.
center_frame_idx (int) – Index of the center frame. Default: 2.
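The two attention steps can be illustrated on toy numbers at a single pixel. This is a sketch under assumptions taken from our reading of the BasicSR source, not its code. For temporal attention, each frame's embedding is dot-multiplied with the center frame's embedding over channels and passed through a sigmoid. The real module produces these embeddings with convolutions; the sketch omits them and uses the raw vectors. For spatial attention, the assumed SFT-like form is `feat * sigmoid(attn) * 2 + attn_add`. The helper names are hypothetical.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def temporal_weights(frame_embs, center_idx):
    """Temporal attention at one pixel (hypothetical helper).

    frame_embs: one embedding vector (c values) per frame.
    Each frame gets sigmoid(dot(embedding, center embedding)); the weights are
    independent per frame and need not sum to 1.
    """
    ref = frame_embs[center_idx]
    return [sigmoid(sum(a * b for a, b in zip(emb, ref))) for emb in frame_embs]


def sft_modulate(feat, attn_logit, attn_add):
    """SFT-style spatial modulation of one feature value (assumed form).

    The multiplicative factor sigmoid(attn_logit) * 2 lies in (0, 2), so it can
    damp or amplify the feature; attn_add is an additive shift.
    """
    return feat * sigmoid(attn_logit) * 2 + attn_add


weights = temporal_weights([[0.0, 0.0], [1.0, 1.0], [1.0, 0.0]], center_idx=1)
print([round(x, 3) for x in weights])  # [0.5, 0.881, 0.731]
print(sft_modulate(3.0, 0.0, 1.0))     # 4.0
```

Frames whose embeddings agree more with the center frame receive larger weights. Because a sigmoid is used instead of a softmax, several frames can be weighted highly at the same time.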
- forward(aligned_feat)
- Parameters:
aligned_feat (Tensor) – Aligned features with shape (b, t, c, h, w).
- Returns:
Features after TSA with the shape (b, c, h, w).
- Return type:
Tensor
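TSA receives t frames of c channels each but returns only c channels. As we read the BasicSR source, the attention-weighted features are viewed as (b, t*c, h, w) and fused by a 1x1 convolution before the spatial attention is applied. EDVR appears to use the same kind of 1x1 fusion layer in place of TSA when `with_tsa=False`. A 1x1 convolution is a linear map over channels applied independently at every pixel. The sketch below shows that map for one pixel; `fuse_pixel` and the toy weights are hypothetical.

```python
def fuse_pixel(aligned, weights, bias):
    """1x1-convolution fusion at one pixel (hypothetical helper).

    aligned: t frames, each a list of c channel values at this pixel.
    weights: one row of t * c coefficients per output channel.
    bias: one value per output channel.
    """
    # Flatten (t, c) -> t * c, like viewing (b, t, c, h, w) as (b, t * c, h, w).
    flat = [value for frame in aligned for value in frame]
    if any(len(row) != len(flat) for row in weights) or len(weights) != len(bias):
        raise ValueError("weights must be (out_channels, t * c); bias (out_channels,)")
    # A 1x1 convolution applies this same linear map at every pixel.
    return [sum(w * x for w, x in zip(row, flat)) + b for row, b in zip(weights, bias)]


aligned = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # t = 3 frames, c = 2 channels
weights = [[1, 0, 0, 0, 0, 0],  # output channel 0: channel 0 of frame 0
           [0, 0, 1, 1, 0, 0]]  # output channel 1: both channels of frame 1
print(fuse_pixel(aligned, weights, [0.0, 0.5]))  # [1.0, 7.5]
```

In the real module the weights are learned, so each output channel can mix information from all frames.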
- training: bool