import torch
from basicsr.archs.dfdnet_arch import DFDNet
from basicsr.archs.vgg_arch import NAMES
def _convert_key(crt_k):
    """Map a BasicSR ``DFDNet`` state-dict key to its key in the official checkpoint.

    Args:
        crt_k (str): Parameter/buffer name from the BasicSR ``DFDNet.state_dict()``.

    Returns:
        str | None: Matching key in the official DFDNet checkpoint, or ``None``
        when ``crt_k`` belongs to none of the sub-modules handled here.

    Raises:
        ValueError: If an ``attn_blocks`` key names none of left_eye / right_eye /
            mouth / nose, or a ``multi_scale_dilation.conv_blocks`` key does not
            split into exactly five dot-separated parts.
    """
    # Checks run in a fixed order; the first matching sub-module wins.
    if 'vgg_extractor' in crt_k:
        ori_k = crt_k.replace('vgg_extractor', 'VggExtract').replace('vgg_net', 'model')
        if 'mean' in crt_k:
            return ori_k.replace('mean', 'RGB_mean')
        if 'std' in crt_k:
            return ori_k.replace('std', 'RGB_std')
        # Conv layers are named in BasicSR but indexed positionally in the
        # official model: the layer name's position in NAMES['vgg19'] is the
        # index into ``VggExtract.model.features``.
        idx = NAMES['vgg19'].index(crt_k.split('.')[2])
        param = 'weight' if 'weight' in crt_k else 'bias'
        return f'VggExtract.model.features.{idx}.{param}'

    if 'attn_blocks' in crt_k:
        for part, short in (('left_eye', 'le'), ('right_eye', 're'), ('mouth', 'mo'), ('nose', 'no')):
            if part in crt_k:
                return crt_k.replace(f'attn_blocks.{part}', short)
        raise ValueError(f'Unknown attention block in key: {crt_k}')

    if 'multi_scale_dilation' in crt_k:
        if 'conv_blocks' in crt_k:
            # multi_scale_dilation.conv_blocks.<i>.<j>.<p> -> MSDilate.conv<i+1>.<j>.<p>
            _, _, c, d, e = crt_k.split('.')
            return f'MSDilate.conv{int(c) + 1}.{d}.{e}'
        return crt_k.replace('multi_scale_dilation.conv_fusion', 'MSDilate.convi')

    if crt_k.startswith('upsample'):
        ori_k = crt_k.replace('upsample', 'up')
        if 'scale_block' in crt_k:
            ori_k = ori_k.replace('scale_block', 'ScaleModel1')
        elif 'shift_block' in crt_k:
            ori_k = ori_k.replace('shift_block', 'ShiftModel1')
        elif 'upsample4' in crt_k and 'body' in crt_k:
            ori_k = ori_k.replace('body', 'Model')
        return ori_k

    return None


def convert_net(ori_net, crt_net):
    """Copy tensors from the official DFDNet state dict into a BasicSR DFDNet state dict.

    Args:
        ori_net (dict): State dict loaded from the official DFDNet checkpoint.
        crt_net (dict): State dict of a BasicSR ``DFDNet``; updated in place.

    Returns:
        dict: ``crt_net`` itself, with every recognised key replaced by the
        matching tensor from ``ori_net``. Unrecognised keys are printed and left
        unchanged.

    Raises:
        ValueError: If a mapped pair of tensors differs in size, or a key cannot
            be mapped (see ``_convert_key``).
        KeyError: If a mapped key is absent from ``ori_net``.
    """
    # Only existing keys are reassigned, so mutating while iterating is safe.
    for crt_k in crt_net:
        ori_k = _convert_key(crt_k)
        if ori_k is None:
            # Skip: falling through would compare against a stale ``ori_k``
            # from the previous iteration (or raise NameError on the first key).
            print('unprocess key: ', crt_k)
            continue

        crt_v, ori_v = crt_net[crt_k], ori_net[ori_k]
        if crt_v.size() != ori_v.size():
            raise ValueError(f'Wrong tensor size for {crt_k} <- {ori_k}: \n'
                             f'crt_net: {crt_v.size()}\n'
                             f'ori_net: {ori_v.size()}')
        crt_net[crt_k] = ori_v
    return crt_net
if __name__ == '__main__':
    # Convert the official DFDNet release into a BasicSR-loadable checkpoint.
    model_dir = 'experiments/pretrained_models/DFDNet'

    # map_location='cpu' lets this run on hosts without CUDA. It also keeps the
    # official tensors, which are copied into the output, on CPU even if they
    # were saved from a GPU.
    ori_net = torch.load(f'{model_dir}/DFDNet_official_original.pth', map_location='cpu')
    dfd_net = DFDNet(64, dict_path=f'{model_dir}/DFDNet_dict_512.pth')
    crt_net = dfd_net.state_dict()
    crt_net_params = convert_net(ori_net, crt_net)

    # Legacy (non-zipfile) format kept as in the original script; presumably
    # for older PyTorch readers — TODO confirm whether still required.
    torch.save(
        dict(params=crt_net_params),
        f'{model_dir}/DFDNet_official.pth',
        _use_new_zipfile_serialization=False)