basicsr.archs.swinir_arch

class basicsr.archs.swinir_arch.BasicLayer(dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, downsample=None, use_checkpoint=False)[source]

Bases: Module

A basic Swin Transformer layer for one stage.

Parameters:
  • dim (int) – Number of input channels.

  • input_resolution (tuple[int]) – Input resolution.

  • depth (int) – Number of blocks.

  • num_heads (int) – Number of attention heads.

  • window_size (int) – Local window size.

  • mlp_ratio (float) – Ratio of mlp hidden dim to embedding dim.

  • qkv_bias (bool, optional) – If True, add a learnable bias to query, key, value. Default: True

  • qk_scale (float | None, optional) – Override default qk scale of head_dim ** -0.5 if set.

  • drop (float, optional) – Dropout rate. Default: 0.0

  • attn_drop (float, optional) – Attention dropout rate. Default: 0.0

  • drop_path (float | tuple[float], optional) – Stochastic depth rate. Default: 0.0

  • norm_layer (nn.Module, optional) – Normalization layer. Default: nn.LayerNorm

  • downsample (nn.Module | None, optional) – Downsample layer at the end of the layer. Default: None

  • use_checkpoint (bool) – Whether to use checkpointing to save memory. Default: False.
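As a rough illustration of how depth, window_size, and drop_path interact, the sketch below derives per-block settings for one stage. It assumes the usual Swin convention that even-indexed blocks use regular windows (shift 0) and odd-indexed blocks use a shift of window_size // 2, and that a sequence-valued drop_path supplies one rate per block. The helper block_settings is hypothetical and not part of basicsr.

```python
from typing import List, Sequence, Tuple, Union


def block_settings(
    depth: int,
    window_size: int,
    drop_path: Union[float, Sequence[float]] = 0.0,
) -> List[Tuple[int, float]]:
    """Return (shift_size, drop_path_rate) for each block in one stage.

    Hypothetical helper for illustration only.
    """
    settings = []
    for i in range(depth):
        # Assumed convention: alternate regular and shifted windows.
        shift_size = 0 if i % 2 == 0 else window_size // 2
        # A scalar rate is shared by all blocks; a sequence gives one per block.
        rate = drop_path if isinstance(drop_path, (int, float)) else drop_path[i]
        settings.append((shift_size, float(rate)))
    return settings


print(block_settings(depth=4, window_size=8, drop_path=[0.0, 0.1, 0.2, 0.3]))
```

With these inputs, blocks 0 and 2 use unshifted windows and blocks 1 and 3 use a shift of 4.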

extra_repr() → str[source]

Set the extra representation of the module

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

flops()[source]
forward(x, x_size)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class basicsr.archs.swinir_arch.DropPath(drop_prob=None)[source]

Bases: Module

Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class basicsr.archs.swinir_arch.Mlp(in_features, hidden_features=None, out_features=None, act_layer=<class 'torch.nn.modules.activation.GELU'>, drop=0.0)[source]

Bases: Module

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class basicsr.archs.swinir_arch.PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None)[source]

Bases: Module

Image to Patch Embedding

Parameters:
  • img_size (int) – Image size. Default: 224.

  • patch_size (int) – Patch token size. Default: 4.

  • in_chans (int) – Number of input image channels. Default: 3.

  • embed_dim (int) – Number of linear projection output channels. Default: 96.

  • norm_layer (nn.Module, optional) – Normalization layer. Default: None

flops()[source]
forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class basicsr.archs.swinir_arch.PatchMerging(input_resolution, dim, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]

Bases: Module

Patch Merging Layer.

Parameters:
  • input_resolution (tuple[int]) – Resolution of input feature.

  • dim (int) – Number of input channels.

  • norm_layer (nn.Module, optional) – Normalization layer. Default: nn.LayerNorm

extra_repr() → str[source]

Set the extra representation of the module

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

flops()[source]
forward(x)[source]

Parameters:
  • x – Input features with shape (b, h*w, c)

training: bool
class basicsr.archs.swinir_arch.PatchUnEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None)[source]

Bases: Module

Patch Unembedding: restores patch tokens to an image-shaped feature map.

Parameters:
  • img_size (int) – Image size. Default: 224.

  • patch_size (int) – Patch token size. Default: 4.

  • in_chans (int) – Number of input image channels. Default: 3.

  • embed_dim (int) – Number of linear projection output channels. Default: 96.

  • norm_layer (nn.Module, optional) – Normalization layer. Default: None
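The relationship between PatchEmbed and PatchUnEmbed can be sketched as a pair of reshapes. This is a minimal sketch, assuming a patch size of 1 and no normalization layer, so that embedding only flattens the spatial grid into tokens and unembedding restores it from x_size. The function names are hypothetical stand-ins, not the classes themselves.

```python
import torch


def patch_embed_sketch(x: torch.Tensor) -> torch.Tensor:
    """(b, c, h, w) -> (b, h*w, c): flatten the spatial grid into tokens."""
    return x.flatten(2).transpose(1, 2)


def patch_unembed_sketch(tokens: torch.Tensor, x_size) -> torch.Tensor:
    """(b, h*w, c) -> (b, c, h, w): restore the grid given x_size = (h, w)."""
    b, _, c = tokens.shape
    return tokens.transpose(1, 2).reshape(b, c, x_size[0], x_size[1])


x = torch.randn(2, 96, 6, 5)
tokens = patch_embed_sketch(x)
restored = patch_unembed_sketch(tokens, (6, 5))
print(tokens.shape, restored.shape)
```

Because the token sequence carries no record of height and width, the spatial size has to be passed separately, which is why forward takes x_size.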

flops()[source]
forward(x, x_size)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class basicsr.archs.swinir_arch.RSTB(dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, downsample=None, use_checkpoint=False, img_size=224, patch_size=4, resi_connection='1conv')[source]

Bases: Module

Residual Swin Transformer Block (RSTB).

Parameters:
  • dim (int) – Number of input channels.

  • input_resolution (tuple[int]) – Input resolution.

  • depth (int) – Number of blocks.

  • num_heads (int) – Number of attention heads.

  • window_size (int) – Local window size.

  • mlp_ratio (float) – Ratio of mlp hidden dim to embedding dim.

  • qkv_bias (bool, optional) – If True, add a learnable bias to query, key, value. Default: True

  • qk_scale (float | None, optional) – Override default qk scale of head_dim ** -0.5 if set.

  • drop (float, optional) – Dropout rate. Default: 0.0

  • attn_drop (float, optional) – Attention dropout rate. Default: 0.0

  • drop_path (float | tuple[float], optional) – Stochastic depth rate. Default: 0.0

  • norm_layer (nn.Module, optional) – Normalization layer. Default: nn.LayerNorm

  • downsample (nn.Module | None, optional) – Downsample layer at the end of the layer. Default: None

  • use_checkpoint (bool) – Whether to use checkpointing to save memory. Default: False.

  • img_size (int) – Input image size. Default: 224

  • patch_size (int) – Patch size. Default: 4

  • resi_connection (str) – The convolutional block before the residual connection. Default: '1conv'

flops()[source]
forward(x, x_size)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class basicsr.archs.swinir_arch.SwinIR(img_size=64, patch_size=1, in_chans=3, embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), window_size=7, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, ape=False, patch_norm=True, use_checkpoint=False, upscale=2, img_range=1.0, upsampler='', resi_connection='1conv', **kwargs)[source]

Bases: Module

A PyTorch implementation of SwinIR: Image Restoration Using Swin Transformer, based on Swin Transformer.

Parameters:
  • img_size (int | tuple(int)) – Input image size. Default: 64

  • patch_size (int | tuple(int)) – Patch size. Default: 1

  • in_chans (int) – Number of input image channels. Default: 3

  • embed_dim (int) – Patch embedding dimension. Default: 96

  • depths (tuple(int)) – Depth of each Swin Transformer layer. Default: (6, 6, 6, 6)

  • num_heads (tuple(int)) – Number of attention heads in different layers. Default: (6, 6, 6, 6)

  • window_size (int) – Window size. Default: 7

  • mlp_ratio (float) – Ratio of mlp hidden dim to embedding dim. Default: 4

  • qkv_bias (bool) – If True, add a learnable bias to query, key, value. Default: True

  • qk_scale (float) – Override default qk scale of head_dim ** -0.5 if set. Default: None

  • drop_rate (float) – Dropout rate. Default: 0

  • attn_drop_rate (float) – Attention dropout rate. Default: 0

  • drop_path_rate (float) – Stochastic depth rate. Default: 0.1

  • norm_layer (nn.Module) – Normalization layer. Default: nn.LayerNorm.

  • ape (bool) – If True, add absolute position embedding to the patch embedding. Default: False

  • patch_norm (bool) – If True, add normalization after patch embedding. Default: True

  • use_checkpoint (bool) – Whether to use checkpointing to save memory. Default: False

  • upscale (int) – Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction. Default: 2

  • img_range (float) – Image range. 1. or 255. Default: 1.0

  • upsampler (str) – The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None. Default: ''

  • resi_connection (str) – The convolutional block before the residual connection. '1conv'/'3conv'. Default: '1conv'
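The way drop_path_rate and depths combine can be sketched as follows. This is a minimal sketch, assuming the common stochastic depth decay rule in which the rate rises linearly from 0 to drop_path_rate across all blocks and each layer receives its own contiguous slice. The helper drop_path_schedule is hypothetical.

```python
from typing import List, Sequence


def drop_path_schedule(depths: Sequence[int], drop_path_rate: float) -> List[List[float]]:
    """Split a linear 0 -> drop_path_rate ramp into one list of rates per layer."""
    total = sum(depths)
    if total <= 1:
        ramp = [0.0] * total
    else:
        ramp = [drop_path_rate * i / (total - 1) for i in range(total)]
    per_layer, start = [], 0
    for depth in depths:
        # Each layer takes the next `depth` rates from the shared ramp.
        per_layer.append(ramp[start:start + depth])
        start += depth
    return per_layer


for layer_rates in drop_path_schedule(depths=(2, 2), drop_path_rate=0.3):
    print([round(r, 3) for r in layer_rates])
```

Under this rule, deeper blocks are dropped more often than shallow ones, and the last block uses exactly drop_path_rate.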

flops()[source]
forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_features(x)[source]
no_weight_decay()[source]
no_weight_decay_keywords()[source]
training: bool
class basicsr.archs.swinir_arch.SwinTransformerBlock(dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]

Bases: Module

Swin Transformer Block.

Parameters:
  • dim (int) – Number of input channels.

  • input_resolution (tuple[int]) – Input resolution.

  • num_heads (int) – Number of attention heads.

  • window_size (int) – Window size.

  • shift_size (int) – Shift size for SW-MSA.

  • mlp_ratio (float) – Ratio of mlp hidden dim to embedding dim.

  • qkv_bias (bool, optional) – If True, add a learnable bias to query, key, value. Default: True

  • qk_scale (float | None, optional) – Override default qk scale of head_dim ** -0.5 if set.

  • drop (float, optional) – Dropout rate. Default: 0.0

  • attn_drop (float, optional) – Attention dropout rate. Default: 0.0

  • drop_path (float, optional) – Stochastic depth rate. Default: 0.0

  • act_layer (nn.Module, optional) – Activation layer. Default: nn.GELU

  • norm_layer (nn.Module, optional) – Normalization layer. Default: nn.LayerNorm
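When shift_size is greater than 0, tokens that end up in the same window after the cyclic shift may come from non-adjacent image regions, so an attention mask is needed. The sketch below shows one way such a mask can be built. It assumes shift_size > 0, a resolution divisible by window_size, and a large negative constant (-100.0 here) standing in for -inf. The function names are hypothetical.

```python
import torch


def window_partition_sketch(x: torch.Tensor, window_size: int) -> torch.Tensor:
    """(b, h, w, c) -> (num_windows*b, window_size, window_size, c)."""
    b, h, w, c = x.shape
    x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)


def shifted_window_mask(x_size, window_size: int, shift_size: int) -> torch.Tensor:
    """Build an additive mask of shape (num_windows, ws*ws, ws*ws); needs shift_size > 0."""
    h, w = x_size
    img_mask = torch.zeros((1, h, w, 1))
    # Label the regions that the cyclic shift brings together in one window.
    bounds = (
        slice(0, -window_size),
        slice(-window_size, -shift_size),
        slice(-shift_size, None),
    )
    region = 0
    for h_slice in bounds:
        for w_slice in bounds:
            img_mask[:, h_slice, w_slice, :] = region
            region += 1
    mask_windows = window_partition_sketch(img_mask, window_size)
    mask_windows = mask_windows.view(-1, window_size * window_size)
    # Token pairs from different regions receive a large negative bias.
    diff = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    return torch.zeros_like(diff).masked_fill(diff != 0, -100.0)


mask = shifted_window_mask((8, 8), window_size=4, shift_size=2)
print(mask.shape)  # torch.Size([4, 16, 16])
```

Windows that contain a single region get an all-zero mask, so only windows touching the wrapped-around border are affected.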

calculate_mask(x_size)[source]
extra_repr() → str[source]

Set the extra representation of the module

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

flops()[source]
forward(x, x_size)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class basicsr.archs.swinir_arch.Upsample(scale, num_feat)[source]

Bases: Sequential

Upsample module.

Parameters:
  • scale (int) – Scale factor. Supported scales: 2^n and 3.

  • num_feat (int) – Channel number of intermediate features.
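The restriction to scales of 2^n and 3 can be illustrated with a small planning helper. This is a sketch under the assumption that each stage is a convolution that expands the channel count followed by a pixel shuffle: repeated x2 stages for powers of two, a single x3 stage for scale 3. The helper upsample_plan is hypothetical and only lists the stages rather than building modules.

```python
from typing import List, Tuple


def upsample_plan(scale: int, num_feat: int) -> List[Tuple[int, int]]:
    """Return (conv_out_channels, pixel_shuffle_factor) for each stage."""
    if scale >= 1 and (scale & (scale - 1)) == 0:  # scale is a power of two
        n = scale.bit_length() - 1  # number of x2 stages
        # A x2 pixel shuffle consumes 4x the channels.
        return [(4 * num_feat, 2)] * n
    if scale == 3:
        # A x3 pixel shuffle consumes 9x the channels.
        return [(9 * num_feat, 3)]
    raise ValueError(f"scale {scale} is not supported; expected 2^n or 3")


print(upsample_plan(4, 64))  # [(256, 2), (256, 2)]
print(upsample_plan(3, 64))  # [(576, 3)]
```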

class basicsr.archs.swinir_arch.UpsampleOneStep(scale, num_feat, num_out_ch, input_resolution=None)[source]

Bases: Sequential

UpsampleOneStep module. Unlike Upsample, it always consists of exactly one convolution followed by one pixel shuffle.

Used in lightweight SR to save parameters.

Parameters:
  • scale (int) – Scale factor. Supported scales: 2^n and 3.

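  • num_feat (int) – Channel number of intermediate features.

  • num_out_ch (int) – Channel number of the output.

  • input_resolution (tuple[int] | None, optional) – Input resolution, used by flops(). Default: None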

flops()[source]
class basicsr.archs.swinir_arch.WindowAttention(dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0.0, proj_drop=0.0)[source]

Bases: Module

Window-based multi-head self-attention (W-MSA) module with relative position bias. It supports both shifted and non-shifted windows.

Parameters:
  • dim (int) – Number of input channels.

  • window_size (tuple[int]) – The height and width of the window.

  • num_heads (int) – Number of attention heads.

  • qkv_bias (bool, optional) – If True, add a learnable bias to query, key, value. Default: True

  • qk_scale (float | None, optional) – Override default qk scale of head_dim ** -0.5 if set

  • attn_drop (float, optional) – Dropout ratio of attention weight. Default: 0.0

  • proj_drop (float, optional) – Dropout ratio of output. Default: 0.0
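The relative position bias can be pictured as a lookup: each pair of tokens in a window maps to one entry of a learnable table, chosen by their relative offset. The sketch below computes that index in plain Python. It assumes row-major token order and a table with (2*Wh-1)*(2*Ww-1) entries per head, which is the usual Swin layout. The helper name is hypothetical.

```python
from typing import List, Tuple


def relative_position_index(window_size: Tuple[int, int]) -> List[List[int]]:
    """Return an (Wh*Ww) x (Wh*Ww) table of indices into the bias table."""
    wh, ww = window_size
    coords = [(y, x) for y in range(wh) for x in range(ww)]  # row-major tokens
    index = []
    for yi, xi in coords:
        row = []
        for yj, xj in coords:
            dy = yi - yj + (wh - 1)  # shift offsets so they start at 0
            dx = xi - xj + (ww - 1)
            row.append(dy * (2 * ww - 1) + dx)
        index.append(row)
    return index


for row in relative_position_index((2, 2)):
    print(row)
```

Token pairs with the same relative offset share an index, so the number of bias parameters depends on the window size only, not on the image size.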

extra_repr() → str[source]

Set the extra representation of the module

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

flops(n)[source]
forward(x, mask=None)[source]
Parameters:
  • x – input features with shape of (num_windows*b, n, c)

  • mask – (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None

training: bool
basicsr.archs.swinir_arch.drop_path(x, drop_prob: float = 0.0, training: bool = False)[source]

Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
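A minimal sketch of per-sample stochastic depth, written to mirror the signature above. It assumes the usual formulation: during training each sample's residual branch is zeroed with probability drop_prob and the surviving samples are rescaled by 1 / (1 - drop_prob). The function drop_path_sketch is an illustration, not the library function.

```python
import torch


def drop_path_sketch(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """Randomly zero whole samples of x during training."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dimensions.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep = torch.rand(shape, dtype=x.dtype, device=x.device) < keep_prob
    # Rescale kept samples so the expected value is unchanged.
    return x / keep_prob * keep.to(x.dtype)


x = torch.ones(8, 3, 4)
y = drop_path_sketch(x, drop_prob=0.5, training=True)
print(y.reshape(8, -1)[:, 0])  # each entry is either 0.0 or 2.0
```

At evaluation time the function is the identity, so no rescaling is needed at inference.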

basicsr.archs.swinir_arch.window_partition(x, window_size)[source]
Parameters:
  • x – (b, h, w, c)

  • window_size (int) – window size

Returns:

windows – (num_windows*b, window_size, window_size, c)
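The shape contract above can be sketched with a view and a permute. This is a minimal sketch, assuming h and w are divisible by window_size. The function window_partition_sketch is a stand-in written for illustration.

```python
import torch


def window_partition_sketch(x: torch.Tensor, window_size: int) -> torch.Tensor:
    """(b, h, w, c) -> (num_windows*b, window_size, window_size, c)."""
    b, h, w, c = x.shape
    # Split each spatial axis into (number of windows, window size).
    x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
    # Bring the two window-count axes together, then merge them with the batch.
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)


x = torch.arange(2 * 8 * 8 * 3, dtype=torch.float32).view(2, 8, 8, 3)
windows = window_partition_sketch(x, window_size=4)
print(windows.shape)  # torch.Size([8, 4, 4, 3])
```

Windows are ordered by batch first, then row of windows, then column of windows, so windows[1] is the top-right window of the first image.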

basicsr.archs.swinir_arch.window_reverse(windows, window_size, h, w)[source]
Parameters:
  • windows – (num_windows*b, window_size, window_size, c)

  • window_size (int) – Window size

  • h (int) – Height of image

  • w (int) – Width of image

Returns:

x – (b, h, w, c)
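Reversing the partition is the same reshape run backwards. The sketch below pairs a partition helper with its inverse and checks the round trip. It assumes h and w are divisible by window_size, and both function names are hypothetical stand-ins.

```python
import torch


def partition_sketch(x: torch.Tensor, window_size: int) -> torch.Tensor:
    """(b, h, w, c) -> (num_windows*b, window_size, window_size, c)."""
    b, h, w, c = x.shape
    x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)


def window_reverse_sketch(windows: torch.Tensor, window_size: int, h: int, w: int) -> torch.Tensor:
    """(num_windows*b, window_size, window_size, c) -> (b, h, w, c)."""
    windows_per_image = (h // window_size) * (w // window_size)
    b = windows.shape[0] // windows_per_image
    x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
    # Interleave window rows and columns back into full image rows and columns.
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)


x = torch.randn(2, 8, 12, 3)
restored = window_reverse_sketch(partition_sketch(x, 4), 4, 8, 12)
print(torch.equal(restored, x))  # True
```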