"""scripts.model_conversion.convert_dfdnet: convert the official DFDNet checkpoint to BasicSR ``DFDNet`` parameter names."""

import torch

from basicsr.archs.dfdnet_arch import DFDNet
from basicsr.archs.vgg_arch import NAMES


def _to_original_key(crt_k):
    """Return the official-checkpoint key corresponding to BasicSR DFDNet key ``crt_k``.

    Returns ``None`` when ``crt_k`` belongs to none of the sub-modules this
    script knows how to map (vgg_extractor, attn_blocks,
    multi_scale_dilation, upsample*).

    Raises:
        ValueError: if an ``attn_blocks`` key names none of the four facial parts.
    """
    # vgg feature extractor
    if 'vgg_extractor' in crt_k:
        ori_k = crt_k.replace('vgg_extractor', 'VggExtract').replace('vgg_net', 'model')
        if 'mean' in crt_k:
            return ori_k.replace('mean', 'RGB_mean')
        if 'std' in crt_k:
            return ori_k.replace('std', 'RGB_std')
        # The third dotted component is looked up in NAMES['vgg19']; its position
        # there is the index used by the official model's ``features`` sequence.
        idx = NAMES['vgg19'].index(crt_k.split('.')[2])
        param = 'weight' if 'weight' in crt_k else 'bias'
        return f'VggExtract.model.features.{idx}.{param}'

    # attention blocks: one per facial part, renamed to a two-letter prefix.
    # Check order matters and matches the original mapping.
    if 'attn_blocks' in crt_k:
        for part, abbr in (('left_eye', 'le'), ('right_eye', 're'), ('mouth', 'mo'), ('nose', 'no')):
            if part in crt_k:
                return crt_k.replace(f'attn_blocks.{part}', abbr)
        raise ValueError('Wrong!')

    # multi-scale dilation block
    if 'multi_scale_dilation' in crt_k:
        if 'conv_blocks' in crt_k:
            _, _, c, d, e = crt_k.split('.')
            # Official names number the dilation branches starting at 1.
            return f'MSDilate.conv{int(c) + 1}.{d}.{e}'
        return crt_k.replace('multi_scale_dilation.conv_fusion', 'MSDilate.convi')

    # upsampling / reconstruction blocks
    if crt_k.startswith('upsample'):
        ori_k = crt_k.replace('upsample', 'up')
        if 'scale_block' in crt_k:
            return ori_k.replace('scale_block', 'ScaleModel1')
        if 'shift_block' in crt_k:
            return ori_k.replace('shift_block', 'ShiftModel1')
        if 'upsample4' in crt_k and 'body' in crt_k:
            return ori_k.replace('body', 'Model')
        return ori_k

    return None


def convert_net(ori_net, crt_net):
    """Copy weights from the official DFDNet state dict into a BasicSR DFDNet state dict.

    Args:
        ori_net: State dict loaded from the official checkpoint.
        crt_net: State dict of BasicSR ``DFDNet``; its values are replaced in place.

    Returns:
        ``crt_net`` (the same object), with every mappable entry replaced by the
        corresponding ``ori_net`` tensor. Keys that cannot be mapped are reported
        on stdout and left unchanged.

    Raises:
        ValueError: if an attention-block key names no known facial part, or if a
            mapped tensor's size differs between the two state dicts.
        KeyError: if a mapped key is missing from ``ori_net``.
    """
    # Only values of existing keys are reassigned below, so iterating the dict
    # while updating it is safe (its size never changes).
    for crt_k in crt_net:
        ori_k = _to_original_key(crt_k)
        if ori_k is None:
            # No counterpart in the official checkpoint: report and skip rather than
            # falling through with an unbound or stale ``ori_k`` from a previous key.
            print('unprocess key: ', crt_k)
            continue

        # replace
        crt_size = crt_net[crt_k].size()
        ori_size = ori_net[ori_k].size()
        if crt_size != ori_size:
            raise ValueError('Wrong tensor size: \n'
                             f'crt_net: {crt_size}\n'
                             f'ori_net: {ori_size}')
        crt_net[crt_k] = ori_net[ori_k]

    return crt_net
if __name__ == '__main__':
    # Load onto CPU so the conversion works regardless of the device the official
    # checkpoint's tensors were saved from (e.g. on a machine without CUDA).
    ori_net = torch.load(
        'experiments/pretrained_models/DFDNet/DFDNet_official_original.pth', map_location='cpu')
    dfd_net = DFDNet(64, dict_path='experiments/pretrained_models/DFDNet/DFDNet_dict_512.pth')
    crt_net = dfd_net.state_dict()
    crt_net_params = convert_net(ori_net, crt_net)

    # Saved under the 'params' key in the legacy (non-zipfile) serialization format.
    torch.save(
        dict(params=crt_net_params),
        'experiments/pretrained_models/DFDNet/DFDNet_official.pth',
        _use_new_zipfile_serialization=False)