import torch
def _edvr_original_key(crt_k):
    """Translate a current-format EDVR state-dict key into the original naming.

    The rules are substring tests evaluated in order and the first match wins,
    so the more specific markers (e.g. ``predeblur.conv_first``) must stay
    ahead of the generic ones (e.g. ``conv_first``).

    Args:
        crt_k (str): Key from the current network's state dict.

    Returns:
        str | None: The matching key in the original network, or None when no
            rule applies.
    """
    # --- predeblur: convolutions on the high-resolution input ---
    if 'predeblur.stride_conv_hr1' in crt_k:
        return crt_k.replace('predeblur.stride_conv_hr1', 'pre_deblur.conv_first_2')
    if 'predeblur.stride_conv_hr2' in crt_k:
        return crt_k.replace('predeblur.stride_conv_hr2', 'pre_deblur.conv_first_3')
    if 'predeblur.conv_first' in crt_k:
        return crt_k.replace('predeblur.conv_first', 'pre_deblur.conv_first_1')

    # --- predeblur: pyramid levels ---
    if 'predeblur.stride_conv_l2' in crt_k:
        return crt_k.replace('predeblur.stride_conv_l2', 'pre_deblur.deblur_L2_conv')
    if 'predeblur.stride_conv_l3' in crt_k:
        return crt_k.replace('predeblur.stride_conv_l3', 'pre_deblur.deblur_L3_conv')
    if 'predeblur.resblock_l3' in crt_k:
        return crt_k.replace('predeblur.resblock_l3', 'pre_deblur.RB_L3_1')
    if 'predeblur.resblock_l2' in crt_k:
        # Only the prefix changes; whatever follows "l" is kept as is.
        return crt_k.replace('predeblur.resblock_l', 'pre_deblur.RB_L')
    if 'predeblur.resblock_l1' in crt_k:
        # Expects exactly five dot-separated parts; the third is a 0-based
        # index that becomes a 1-based suffix in the original name.
        _, _, block_idx, layer, param = crt_k.split('.')
        return f'pre_deblur.RB_L1_{int(block_idx) + 1}.{layer}.{param}'

    # --- pyramid feature extraction ---
    if 'conv_l2' in crt_k:
        return crt_k.replace('conv_l2_', 'fea_L2_conv')
    if 'conv_l3' in crt_k:
        return crt_k.replace('conv_l3_', 'fea_L3_conv')

    # --- PCD alignment ---
    if 'pcd_align.dcn_pack' in crt_k:
        # Everything after the first ".l" is "<level>.<name>[.<weight|bias>]".
        tail = crt_k.split('.l')[1].split('.')
        level, name = tail[0], tail[1]
        if 'conv_offset' in crt_k:
            name = name.replace('conv_offset', 'conv_offset_mask')
            return f'pcd_align.L{level}_dcnpack.{name}.{tail[2]}'
        return f'pcd_align.L{level}_dcnpack.{name}'
    if 'pcd_align.offset_conv' in crt_k:
        _, conv_name, level_name, param = crt_k.split('.')
        conv_idx = conv_name.split('conv')[1]
        level = level_name.split('l')[1]
        return f'pcd_align.L{level}_offset_conv{conv_idx}.{param}'
    if 'pcd_align.feat_conv' in crt_k:
        _, _, level_name, param = crt_k.split('.')
        level = level_name.split('l')[1]
        return f'pcd_align.L{level}_fea_conv.{param}'
    if 'pcd_align.cas_dcnpack' in crt_k:
        return crt_k.replace('conv_offset', 'conv_offset_mask')

    # --- keys that are identical in both namings ---
    if ('conv_first' in crt_k or 'feature_extraction' in crt_k or 'pcd_align.cas_offset' in crt_k
            or 'upconv' in crt_k or 'conv_last' in crt_k or 'conv_1x1' in crt_k):
        return crt_k

    # --- TSA fusion (note that temporal_attn1/2 map to tAtt_2/1, i.e. swapped) ---
    if 'temporal_attn1' in crt_k:
        return crt_k.replace('fusion.temporal_attn1', 'tsa_fusion.tAtt_2')
    if 'temporal_attn2' in crt_k:
        return crt_k.replace('fusion.temporal_attn2', 'tsa_fusion.tAtt_1')
    if 'fusion.feat_fusion' in crt_k:
        return crt_k.replace('fusion.feat_fusion', 'tsa_fusion.fea_fusion')
    if 'fusion.spatial_attn_add' in crt_k:
        return crt_k.replace('fusion.spatial_attn_add', 'tsa_fusion.sAtt_add_')
    if 'fusion.spatial_attn_l' in crt_k:
        return crt_k.replace('fusion.spatial_attn_l', 'tsa_fusion.sAtt_L')
    if 'fusion.spatial_attn' in crt_k:
        return crt_k.replace('fusion.spatial_attn', 'tsa_fusion.sAtt_')

    # --- reconstruction and high-resolution conv ---
    if 'reconstruction' in crt_k:
        return crt_k.replace('reconstruction', 'recon_trunk')
    if 'conv_hr' in crt_k:
        return crt_k.replace('conv_hr', 'HRconv')

    # --- for model woTSA: any remaining fusion key ---
    if 'fusion' in crt_k:
        return crt_k.replace('fusion', 'tsa_fusion')

    return None


def convert_edvr():
    """Copy weights from an original-format EDVR checkpoint into the current key layout.

    Loads the original state dict from
    ``experiments/pretrained_models/EDVR_REDS_SR_M.pth`` and the current one
    from ``xxx/net_g_8.pth``. Every entry of the current state dict is
    overwritten with the tensor stored under the corresponding original key,
    and the result is saved to ``./edvr_medium_x4_reds_sr_official.pth``.

    Keys that match no rename rule are reported with ``unprocess key`` and
    keep the value they already had in the current checkpoint.

    Raises:
        KeyError: If a translated key is missing from the original checkpoint.
    """
    ori_net = torch.load('experiments/pretrained_models/EDVR_REDS_SR_M.pth')
    crt_net = torch.load('xxx/net_g_8.pth')
    save_path = './edvr_medium_x4_reds_sr_official.pth'

    # Only existing keys are reassigned, so mutating while iterating is safe.
    for crt_k in crt_net:
        ori_k = _edvr_original_key(crt_k)
        if ori_k is None:
            # Nothing to copy from; skip instead of indexing with a missing key.
            print('unprocess key', crt_k)
            continue
        crt_net[crt_k] = ori_net[ori_k]

    torch.save(crt_net, save_path)
def _edsr_original_key(crt_k, num_block):
    """Translate a current-format EDSR key into the original naming.

    Rules are substring tests evaluated in order; ``conv_after_body`` must be
    checked before the generic ``body`` rule because it contains that marker.

    Args:
        crt_k (str): Key from the current network's state dict.
        num_block (int): Number of blocks; used as the index of the
            convolution that follows the body in the original naming.

    Returns:
        str | None: The matching original key, or None when no rule applies.
    """
    if 'conv_first' in crt_k:
        return crt_k.replace('conv_first', 'head.0')
    if 'conv_after_body' in crt_k:
        return crt_k.replace('conv_after_body', f'body.{num_block}')
    if 'body' in crt_k:
        return crt_k.replace('conv1', 'body.0').replace('conv2', 'body.2')
    if 'upsample.0' in crt_k:
        return crt_k.replace('upsample.0', 'tail.0.0')
    if 'upsample.2' in crt_k:
        return crt_k.replace('upsample.2', 'tail.0.2')
    if 'conv_last' in crt_k:
        return crt_k.replace('conv_last', 'tail.1')
    return None


def convert_edsr(ori_net_path, crt_net_path, save_path, num_block=32):
    """Convert EDSR models in https://github.com/thstkdgus35/EDSR-PyTorch.

    It supports converting x2, x3 and x4 models.

    Every entry of the current state dict is overwritten with the tensor
    stored under the corresponding key of the original state dict. Keys that
    match no rename rule are reported with ``unprocess key`` and keep the
    value they already had, instead of receiving the tensor selected for the
    previous key.

    Args:
        ori_net_path (str): Original network path.
        crt_net_path (str): Current network path.
        save_path (str): The path to save the converted model.
        num_block (int): Number of blocks. Default: 32.

    Raises:
        KeyError: If a translated key is missing from the original network.
    """
    ori_net = torch.load(ori_net_path)
    crt_net = torch.load(crt_net_path)

    # Only existing keys are reassigned, so mutating while iterating is safe.
    for crt_k in crt_net:
        ori_k = _edsr_original_key(crt_k, num_block)
        if ori_k is None:
            print('unprocess key', crt_k)
            continue
        crt_net[crt_k] = ori_net[ori_k]

    torch.save(crt_net, save_path)
def _rcan_original_key(crt_k):
    """Translate a current-format RCAN key into the original naming.

    Rules are substring tests evaluated in order; the first match wins, so
    ``attention`` is checked before ``rcab`` and both before the generic
    ``body`` rule.

    Args:
        crt_k (str): Key from the current network's state dict.

    Returns:
        str | None: The matching original key, or None when no rule applies.

    Raises:
        ValueError: If an ``attention``/``rcab`` key does not have the
            expected number of dot-separated parts (nine and seven).
    """
    if 'conv_first' in crt_k:
        return crt_k.replace('conv_first', 'head.0')
    if 'conv_after_body' in crt_k:
        return crt_k.replace('conv_after_body', 'body.10')
    if 'upsample.0' in crt_k:
        return crt_k.replace('upsample.0', 'tail.0.0')
    if 'upsample.2' in crt_k:
        return crt_k.replace('upsample.2', 'tail.0.2')
    if 'conv_last' in crt_k:
        return crt_k.replace('conv_last', 'tail.1')
    if 'attention' in crt_k:
        # Nine parts; the three indices are kept and the layer index inside
        # the attention module is shifted down by one for `conv_du`.
        _, group_idx, _, block_idx, _, layer_idx, _, attn_idx, param = crt_k.split('.')
        return f'body.{group_idx}.body.{block_idx}.body.{layer_idx}.conv_du.{int(attn_idx) - 1}.{param}'
    if 'rcab' in crt_k:
        # Seven parts; every module name collapses to "body".
        _, group_idx, _, block_idx, _, layer_idx, param = crt_k.split('.')
        return f'body.{group_idx}.body.{block_idx}.body.{layer_idx}.{param}'
    if 'body' in crt_k:
        return crt_k.replace('conv.', 'body.20.')
    return None


def convert_rcan_model():
    """Copy weights from an original-format RCAN checkpoint into the current key layout.

    Loads the original state dict from ``RCAN_model_best.pt`` and the current
    one from
    ``experiments/201_RCANx4_scratch_DIV2K_rand0/models/net_g_5000.pth``.
    Every entry of the current state dict is overwritten with the tensor
    stored under the corresponding original key, and the result is saved to
    ``RCAN_model_best.pth``.

    Keys that match no rename rule are reported with ``unprocess key`` and
    keep the value they already had, instead of receiving the tensor selected
    for the previous key.

    Raises:
        KeyError: If a translated key is missing from the original checkpoint.
    """
    ori_net = torch.load('RCAN_model_best.pt')
    crt_net = torch.load('experiments/201_RCANx4_scratch_DIV2K_rand0/models/net_g_5000.pth')

    # Only existing keys are reassigned, so mutating while iterating is safe.
    for crt_k in crt_net:
        ori_k = _rcan_original_key(crt_k)
        if ori_k is None:
            print('unprocess key', crt_k)
            continue
        crt_net[crt_k] = ori_net[ori_k]

    torch.save(crt_net, 'RCAN_model_best.pth')
def convert_esrgan_model():
    """Re-key an original-format ESRGAN checkpoint to match ``RRDBNet``.

    Builds an ``RRDBNet(3, 3, num_feat=64, num_block=23, num_grow_ch=32)``,
    takes its state dict as the target layout, and fills every entry with the
    tensor stored under the corresponding key of
    ``experiments/pretrained_models/RRDB_ESRGAN_x4.pth``. The result is saved
    to ``experiments/pretrained_models/ESRGAN_x4_SR_DF2KOST_official.pth``.

    Keys matching no rename rule are looked up under their own name and
    printed.
    """
    from basicsr.archs.rrdbnet_arch import RRDBNet

    # (marker, substitutions): the first rule whose marker occurs in the key
    # wins, and its substitutions are applied left to right.
    rename_rules = (
        ('rdb', (('rdb', 'RDB'), ('body', 'RRDB_trunk'))),
        ('conv_body', (('conv_body', 'trunk_conv'),)),
        ('conv_up', (('conv_up', 'upconv'),)),
        ('conv_hr', (('conv_hr', 'HRconv'),)),
    )

    target_state = RRDBNet(3, 3, num_feat=64, num_block=23, num_grow_ch=32).state_dict()
    source_state = torch.load('experiments/pretrained_models/RRDB_ESRGAN_x4.pth')

    # Only existing keys are reassigned, so mutating while iterating is safe.
    for target_key in target_state:
        for marker, substitutions in rename_rules:
            if marker not in target_key:
                continue
            source_key = target_key
            for old, new in substitutions:
                source_key = source_key.replace(old, new)
            break
        else:
            # No rule matched: the key is shared by both layouts.
            source_key = target_key
            print(target_key)
        target_state[target_key] = source_state[source_key]

    torch.save(target_state, 'experiments/pretrained_models/ESRGAN_x4_SR_DF2KOST_official.pth')
def convert_duf_model():
    """Re-key an old-format 16-layer x2 DUF checkpoint to match ``DUF``.

    Builds ``DUF(scale=2, num_layer=16, adapt_official_weights=True)``, takes
    its state dict as the target layout, and fills every entry with the tensor
    stored under the corresponding key of
    ``experiments/pretrained_models/old_DUF_x2_16L_official.pth``. Keys
    matching no rename rule are looked up under their own name and printed.

    Afterwards the rows of ``conv3d_r2.weight`` and ``conv3d_r2.bias`` are
    regrouped in place: rows ``0, 3, 6, ...`` go to the first ``scale**2``
    positions, rows ``1, 4, 7, ...`` to the next ``scale**2``, and rows
    ``2, 5, 8, ...`` to the remainder. The result is saved to
    ``experiments/pretrained_models/DUF_x2_16L_official.pth``.
    """
    from basicsr.archs.duf_arch import DUF

    scale = 2
    target_state = DUF(scale=scale, num_layer=16, adapt_official_weights=True).state_dict()
    source_state = torch.load('experiments/pretrained_models/old_DUF_x2_16L_official.pth')

    # Ordered (marker, replacement) rules for the 16-layer model; the first
    # marker found in a key wins.
    # NOTE: an alternative, disabled mapping used to live here. It differed
    # only for dense block 1, where it renamed
    # `dense_block1.dense_blocks.<c>.<d>.<e>` to
    # `dense_block_1.dense_blocks.<c * 6 + d>.<e>` (presumably for the other
    # layer counts -- TODO confirm).
    rename_rules = [('conv3d1', 'conv3d_1'), ('conv3d2', 'conv3d_2')]
    # Within each six-element group, positions 0/3 are renamed to bn3d_* and
    # positions 2/5 to conv3d_*, numbered consecutively across the groups.
    positions = ((0, 'bn3d', 1), (2, 'conv3d', 1), (3, 'bn3d', 2), (5, 'conv3d', 2))
    for block in range(3):
        for pos, layer, offset in positions:
            rename_rules.append((f'dense_block1.dense_blocks.{block}.{pos}',
                                 f'dense_block_1.{layer}_{2 * block + offset}'))
    for stage in range(1, 4):
        for pos, layer, offset in positions:
            rename_rules.append((f'dense_block2.temporal_reduce{stage}.{pos}',
                                 f'dense_block_2.{layer}_{2 * (stage - 1) + offset}'))
    rename_rules.append(('bn3d2', 'bn3d_2'))

    # Only existing keys are reassigned, so mutating while iterating is safe.
    for target_key in target_state:
        for marker, replacement in rename_rules:
            if marker in target_key:
                source_key = target_key.replace(marker, replacement)
                break
        else:
            # No rule matched: the key is shared by both layouts.
            source_key = target_key
            print(target_key)
        target_state[target_key] = source_state[source_key]

    # Regroup the interleaved rows of conv3d_r2 into three contiguous chunks.
    chunk = scale**2
    destinations = (slice(None, chunk), slice(chunk, 2 * chunk), slice(2 * chunk, None))
    for suffix in ('weight', 'bias'):
        key = f'conv3d_r2.{suffix}'
        # Read from a snapshot because the stored tensor is overwritten in place.
        snapshot = target_state[key].clone()
        for start, rows in enumerate(destinations):
            target_state[key][rows, ...] = snapshot[start::3, ...]

    torch.save(target_state, 'experiments/pretrained_models/DUF_x2_16L_official.pth')
if __name__ == '__main__':
    # Script entry point: runs the DUF conversion.
    #
    # To convert an EDSR model instead, call `convert_edsr` with the original
    # checkpoint path, the current checkpoint path and the output path, e.g.
    #     convert_edsr('path to original model', 'path to current model', 'save path', num_block=32)
    convert_duf_model()