basicsr.models.base_model

class basicsr.models.base_model.BaseModel(opt)

Bases: object

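Base model class. Concrete models subclass it and inherit the shared utilities for device placement, optimizers, schedulers, checkpointing, and logging documented below.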

feed_data(data)
get_bare_model(net)

Return the bare network, unwrapping it if it is wrapped with DistributedDataParallel or DataParallel.

get_current_learning_rate()
get_current_log()
get_current_visuals()
get_optimizer(optim_type, params, lr, **kwargs)
load_network(net, load_path, strict=True, param_key='params')

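Load the weights stored at load_path into a network.

Parameters:
  • net (nn.Module) – Network to load the weights into.

  • load_path (str) – Path of the checkpoint to be loaded.

  • strict (bool) – Whether to strictly require that the keys in the checkpoint match the keys of the network. Default: True.

  • param_key (str) – Key under which the network parameters are stored in the checkpoint. If set to None, the root of the loaded checkpoint is used as the parameters. Default: ‘params’.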
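The param_key rule can be sketched in plain Python. This is a minimal illustration, not BasicSR's implementation: select_state_dict is a hypothetical helper, plain dicts stand in for the loaded checkpoint and its tensors, and stripping a leading 'module.' from keys is an assumption about checkpoints saved from a network wrapped with DataParallel or DistributedDataParallel.

```python
from typing import Any, Dict, Optional

MODULE_PREFIX = "module."


def select_state_dict(
    checkpoint: Dict[str, Any], param_key: Optional[str] = "params"
) -> Dict[str, Any]:
    """Pick the parameters out of a loaded checkpoint (hypothetical helper).

    param_key=None means the checkpoint root itself holds the parameters.
    """
    state = checkpoint if param_key is None else checkpoint[param_key]
    # Assumption: strip the 'module.' prefix that a DataParallel or
    # DistributedDataParallel wrapper adds, so keys match the bare network.
    return {
        (key[len(MODULE_PREFIX):] if key.startswith(MODULE_PREFIX) else key): value
        for key, value in state.items()
    }


# Integers stand in for tensors.
checkpoint = {
    "params": {"module.conv.weight": 1, "conv.bias": 2},
    "params_ema": {"conv.weight": 3, "conv.bias": 4},
}
print(select_state_dict(checkpoint))                # parameters under 'params'
print(select_state_dict(checkpoint, "params_ema"))  # parameters under 'params_ema'
print(select_state_dict({"conv.weight": 5}, None))  # checkpoint root used directly
```

In PyTorch, the selected mapping would then be passed to net.load_state_dict(state, strict=strict), which is where the strict flag takes effect.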

model_ema(decay=0.999)
model_to_device(net)

Move a network to the device. It also wraps the network with DistributedDataParallel or DataParallel.

Parameters:

net (nn.Module) – Network.

optimize_parameters()
print_network(net)

Print the string representation and the number of parameters of a network.

Parameters:

net (nn.Module) – Network.

reduce_loss_dict(loss_dict)

Reduce the loss dict.

In distributed training, it averages the losses across GPUs.

Parameters:

loss_dict (OrderedDict) – Loss dict.
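The averaging can be pictured without any GPUs. The sketch below uses a hypothetical helper that treats each rank's loss dict as one entry in a Python list and averages every named loss over the number of ranks. The real method performs the exchange with torch.distributed collective operations on tensors; the loss names l_pix and l_percep are only illustrative.

```python
from collections import OrderedDict
from typing import Dict, List


def average_loss_dicts(per_rank_losses: List[Dict[str, float]]) -> "OrderedDict[str, float]":
    """Average each named loss over all simulated ranks (hypothetical helper)."""
    if not per_rank_losses:
        raise ValueError("need the loss dict of at least one rank")
    world_size = len(per_rank_losses)
    averaged: "OrderedDict[str, float]" = OrderedDict()
    # Keep the key order of the first rank's loss dict.
    for name in per_rank_losses[0]:
        total = sum(losses[name] for losses in per_rank_losses)
        averaged[name] = total / world_size
    return averaged


# Two simulated GPUs reporting the same (illustrative) loss names.
rank_losses = [
    OrderedDict([("l_pix", 0.25), ("l_percep", 1.0)]),
    OrderedDict([("l_pix", 0.75), ("l_percep", 3.0)]),
]
averaged = average_loss_dicts(rank_losses)
print(averaged)
```

Averaging, rather than summing, keeps the logged values on the same scale as a single-GPU run.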

resume_training(resume_state)

Reload the optimizers and schedulers for resumed training.

Parameters:

resume_state (dict) – Resume state, i.e. the training state saved by save_training_state.

save(epoch, current_iter)

Save networks and training state.

save_network(net, net_label, current_iter, param_key='params')

Save networks.

Parameters:
  • net (nn.Module | list[nn.Module]) – Network(s) to be saved.

  • net_label (str) – Network label.

  • current_iter (int) – Current iteration.

  • param_key (str | list[str]) – The parameter key(s) to save network. Default: ‘params’.
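Because net and param_key may both be lists, each network is presumably stored under its own key in a single checkpoint. The following minimal sketch assumes that one-to-one pairing; build_save_dict is a hypothetical helper, plain dicts stand in for each network's state_dict(), and the key name params_ema is used only for illustration.

```python
from typing import Any, Dict, List, Union

StateDict = Dict[str, Any]


def build_save_dict(
    state_dicts: Union[StateDict, List[StateDict]],
    param_keys: Union[str, List[str]] = "params",
) -> Dict[str, StateDict]:
    """Pair each network's parameters with its parameter key (hypothetical helper)."""
    # Promote single values to one-element lists, mirroring the
    # `nn.Module | list[nn.Module]` and `str | list[str]` parameter types.
    if not isinstance(state_dicts, list):
        state_dicts = [state_dicts]
    if not isinstance(param_keys, list):
        param_keys = [param_keys]
    if len(state_dicts) != len(param_keys):
        raise ValueError("net and param_key must have the same length")
    return {key: sd for sd, key in zip(state_dicts, param_keys)}


# Integers stand in for tensors; net_g / net_g_ema are illustrative names.
net_g = {"conv.weight": 1}
net_g_ema = {"conv.weight": 2}
single = build_save_dict(net_g)
paired = build_save_dict([net_g, net_g_ema], ["params", "params_ema"])
print(single)
print(paired)
```

A checkpoint built this way could be written with torch.save and later read back by load_network with the matching param_key.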

save_training_state(epoch, current_iter)

Save the training state during training so that it can be used to resume training later.

Parameters:
  • epoch (int) – Current epoch.

  • current_iter (int) – Current iteration.
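The save/resume round trip can be sketched as follows, assuming the training state is a dict holding the epoch, the iteration, and one state dict per optimizer and scheduler (the key names here are assumptions). StubStateful is a hypothetical stand-in for a PyTorch optimizer or learning-rate scheduler, both of which expose state_dict() and load_state_dict(); make_training_state and restore_training_state are likewise hypothetical helpers.

```python
from typing import Any, Dict, List


class StubStateful:
    """Stand-in for a torch optimizer or LR scheduler (illustration only)."""

    def __init__(self, state: Dict[str, Any]) -> None:
        self.state = dict(state)

    def state_dict(self) -> Dict[str, Any]:
        return dict(self.state)

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.state = dict(state)


def make_training_state(
    epoch: int, current_iter: int,
    optimizers: List[StubStateful], schedulers: List[StubStateful],
) -> Dict[str, Any]:
    """Bundle epoch, iteration, and optimizer/scheduler states (assumed key names)."""
    return {
        "epoch": epoch,
        "iter": current_iter,
        "optimizers": [o.state_dict() for o in optimizers],
        "schedulers": [s.state_dict() for s in schedulers],
    }


def restore_training_state(
    resume_state: Dict[str, Any],
    optimizers: List[StubStateful], schedulers: List[StubStateful],
) -> None:
    """Reload optimizer and scheduler states in the order they were saved."""
    if len(resume_state["optimizers"]) != len(optimizers):
        raise ValueError("wrong number of optimizers")
    if len(resume_state["schedulers"]) != len(schedulers):
        raise ValueError("wrong number of schedulers")
    for optimizer, state in zip(optimizers, resume_state["optimizers"]):
        optimizer.load_state_dict(state)
    for scheduler, state in zip(schedulers, resume_state["schedulers"]):
        scheduler.load_state_dict(state)


# Save from one "run" and restore into freshly created objects.
old_opt = StubStateful({"lr": 1e-4, "step": 500})
old_sched = StubStateful({"last_epoch": 500})
saved = make_training_state(3, 500, [old_opt], [old_sched])

new_opt, new_sched = StubStateful({}), StubStateful({})
restore_training_state(saved, [new_opt], [new_sched])
```

Restoring by position means the optimizers and schedulers must be created in the same order as when the state was saved.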

setup_schedulers()

Set up schedulers.

update_learning_rate(current_iter, warmup_iter=-1)

Update the learning rate.

Parameters:
  • current_iter (int) – Current iteration.

  • warmup_iter (int) – Number of warm-up iterations; -1 means no warm-up. Default: -1.
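The effect of warmup_iter can be sketched as a pure function. This assumes a linear warm-up in which each parameter group's learning rate grows toward its initial value, scaled by current_iter / warmup_iter, until warmup_iter is reached, after which the scheduler's learning rate applies; warmup_lrs is a hypothetical helper, not part of BasicSR.

```python
from typing import List


def warmup_lrs(
    initial_lrs: List[float],
    scheduled_lrs: List[float],
    current_iter: int,
    warmup_iter: int = -1,
) -> List[float]:
    """Learning rate per parameter group under an assumed linear warm-up.

    While current_iter < warmup_iter, each initial lr is scaled by
    current_iter / warmup_iter. Otherwise the scheduler's lrs are kept.
    With the default warmup_iter=-1 the condition never holds for a
    non-negative current_iter, so there is no warm-up.
    """
    if current_iter < warmup_iter:
        return [lr / warmup_iter * current_iter for lr in initial_lrs]
    return list(scheduled_lrs)


# Illustrative values: initial lr 2e-4, 100 warm-up iterations.
print(warmup_lrs([2e-4], [2e-4], current_iter=50, warmup_iter=100))   # about half of 2e-4
print(warmup_lrs([2e-4], [2e-4], current_iter=150, warmup_iter=100))  # scheduler lr kept
print(warmup_lrs([2e-4], [1e-4], current_iter=50))                    # no warm-up
```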

validation(dataloader, current_iter, tb_logger, save_img=False)

Validation function.

Parameters:
  • dataloader (torch.utils.data.DataLoader) – Validation dataloader.

  • current_iter (int) – Current iteration.

  • tb_logger (tensorboard logger) – TensorBoard logger.

  • save_img (bool) – Whether to save images. Default: False.