basicsr.models.base_model
- class basicsr.models.base_model.BaseModel(opt)
Bases: object
Base class for BasicSR models.
- get_bare_model(net)
Get the bare model. If net is wrapped with DistributedDataParallel or DataParallel, the wrapped network (net.module) is returned; otherwise net is returned unchanged.
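The unwrapping can be sketched in a few lines, assuming it is a single isinstance check followed by reading the wrapper's `.module` attribute (both PyTorch wrappers store the wrapped network there). To keep the sketch runnable without PyTorch, `FakeWrapper` and `TinyNet` are hypothetical stand-ins; in real code the type tuple would be `(torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)`.

```python
class FakeWrapper:
    """Hypothetical stand-in for DataParallel/DistributedDataParallel.

    Like the real wrappers, it keeps the wrapped network in `.module`.
    """

    def __init__(self, module):
        self.module = module


class TinyNet:
    """Hypothetical stand-in for an nn.Module."""


def get_bare_model(net, wrapper_types=(FakeWrapper,)):
    """Return net.module if net is a known wrapper, else net itself."""
    if isinstance(net, wrapper_types):
        net = net.module
    return net


inner = TinyNet()
print(get_bare_model(FakeWrapper(inner)) is inner)  # True
print(get_bare_model(inner) is inner)  # True
```

Because the bare model is returned either way, callers can use the result uniformly for `state_dict()` or `load_state_dict()` whether or not the network was wrapped.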
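- load_network(net, load_path, strict=True, param_key='params')
Load network weights from a checkpoint file into net.
- Parameters:
net (nn.Module) – Network.
load_path (str) – The path of the network checkpoint to be loaded.
strict (bool) – Whether to strictly require that the keys in the checkpoint match the keys of net. Default: True.
param_key (str) – The key of the state dict inside the loaded checkpoint. If set to None, the whole loaded checkpoint is used as the state dict. Default: ‘params’.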
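A simplified sketch of how param_key might select the state dict from a loaded checkpoint. It assumes the checkpoint is a plain dict (for example `{"params": {...}, "params_ema": {...}}`) and that a leading `module.` prefix, which DataParallel/DistributedDataParallel add to parameter names, is stripped; `select_state_dict` is a hypothetical helper, not part of BasicSR, and floats stand in for tensors.

```python
def select_state_dict(checkpoint, param_key="params"):
    """Pick the state dict out of a loaded checkpoint (sketch).

    checkpoint: dict loaded from a .pth file.
    param_key: key to select; None means the checkpoint itself is the state dict.
    """
    state_dict = checkpoint if param_key is None else checkpoint[param_key]
    # Strip the 'module.' prefix that DataParallel/DDP add to parameter names.
    prefix = "module."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


ckpt = {"params": {"module.conv.weight": 1.0, "conv.bias": 0.5}}
print(select_state_dict(ckpt))  # {'conv.weight': 1.0, 'conv.bias': 0.5}
print(select_state_dict({"conv.weight": 2.0}, param_key=None))  # {'conv.weight': 2.0}
```

In PyTorch the resulting dict would then be passed to `net.load_state_dict(state_dict, strict=strict)`, where strict controls whether missing or unexpected keys raise an error.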
- model_to_device(net)
Move the model to the device. It also wraps the model with DistributedDataParallel (in distributed training) or DataParallel (when more than one GPU is used), and returns the result.
- Parameters:
net (nn.Module) – Network.
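The wrapping decision can be sketched as a pure function of the options, assuming the option keys `dist` (bool) and `num_gpu` (int); `choose_wrapper` is hypothetical and returns only the wrapper's name so that the sketch runs without PyTorch or GPUs.

```python
def choose_wrapper(opt):
    """Return which wrapper model_to_device would apply (sketch).

    Assumes the option keys 'dist' (bool) and 'num_gpu' (int).
    """
    if opt.get("dist", False):
        return "DistributedDataParallel"
    if opt.get("num_gpu", 0) > 1:
        return "DataParallel"
    return None  # single GPU or CPU: the network is only moved to the device


print(choose_wrapper({"dist": True, "num_gpu": 4}))   # DistributedDataParallel
print(choose_wrapper({"dist": False, "num_gpu": 2}))  # DataParallel
print(choose_wrapper({"dist": False, "num_gpu": 1}))  # None
```

Checking `dist` first reflects that distributed training takes precedence: each process drives its own GPU, so DataParallel would be redundant there.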
- print_network(net)
Print the string representation and the total number of parameters of a network.
- Parameters:
net (nn.Module) – Network.
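The parameter count is the sum of the element counts of all parameter tensors (in PyTorch, `sum(p.numel() for p in net.parameters())`). The sketch below reproduces that arithmetic from tensor shapes alone; `count_parameters` and the example layer shapes are hypothetical, and the printed message is illustrative rather than the method's actual output.

```python
from math import prod


def count_parameters(param_shapes):
    """Total number of scalar parameters, given each tensor's shape."""
    return sum(prod(shape) for shape in param_shapes)


# Hypothetical 3x3 conv layer: 3 input channels, 64 output channels, plus bias.
shapes = [(64, 3, 3, 3), (64,)]
n = count_parameters(shapes)
print(f"parameters: {n:,d}")  # parameters: 1,792
```

Here the weight contributes 64 × 3 × 3 × 3 = 1,728 parameters and the bias 64, giving 1,792.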
- reduce_loss_dict(loss_dict)
Reduce the loss dict.
In distributed training, it averages the losses across GPUs. It returns an OrderedDict of scalar loss values for logging.
- Parameters:
loss_dict (OrderedDict) – Loss dict.
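The averaging can be simulated in plain Python by holding one loss dict per GPU in a list. In real distributed training the values would be tensors summed with a collective such as `torch.distributed.reduce` and then divided by the world size; here floats stand in for them, and `average_loss_dicts` and the loss names `l_pix`/`l_percep` are hypothetical.

```python
from collections import OrderedDict


def average_loss_dicts(per_rank_losses):
    """Average each named loss across ranks (sketch of the distributed case).

    per_rank_losses: list with one {name: float} dict per GPU/process.
    """
    world_size = len(per_rank_losses)
    log_dict = OrderedDict()
    for name in per_rank_losses[0]:
        total = sum(rank_losses[name] for rank_losses in per_rank_losses)
        log_dict[name] = total / world_size
    return log_dict


losses = [{"l_pix": 0.25, "l_percep": 1.0}, {"l_pix": 0.75, "l_percep": 3.0}]
print(dict(average_loss_dicts(losses)))  # {'l_pix': 0.5, 'l_percep': 2.0}
```

Using an OrderedDict keeps the losses in a stable order, which keeps log lines consistent from one iteration to the next.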
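- resume_training(resume_state)
Reload the optimizers and schedulers for resumed training.
- Parameters:
resume_state (dict) – Resume state, as saved by save_training_state; its ‘optimizers’ and ‘schedulers’ lists must match the model's optimizers and schedulers in length.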
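A minimal sketch of the reload step, assuming resume_state holds parallel lists under `optimizers` and `schedulers` that line up one-to-one with the live objects. `StatefulStub` is a hypothetical stand-in exposing the same `state_dict()`/`load_state_dict()` pair that PyTorch optimizers and LR schedulers provide, and `resume` is a hypothetical helper.

```python
class StatefulStub:
    """Hypothetical stand-in for an optimizer or LR scheduler."""

    def __init__(self):
        self.state = {}

    def state_dict(self):
        return dict(self.state)

    def load_state_dict(self, state):
        self.state = dict(state)


def resume(optimizers, schedulers, resume_state):
    """Restore each optimizer/scheduler from its saved state (sketch)."""
    # The saved lists must line up one-to-one with the live objects.
    assert len(resume_state["optimizers"]) == len(optimizers), "Wrong lengths of optimizers"
    assert len(resume_state["schedulers"]) == len(schedulers), "Wrong lengths of schedulers"
    for obj, saved in zip(optimizers, resume_state["optimizers"]):
        obj.load_state_dict(saved)
    for obj, saved in zip(schedulers, resume_state["schedulers"]):
        obj.load_state_dict(saved)


optimizer, scheduler = StatefulStub(), StatefulStub()
resume([optimizer], [scheduler], {"optimizers": [{"lr": 1e-4}], "schedulers": [{"last_epoch": 5000}]})
print(optimizer.state, scheduler.state)  # {'lr': 0.0001} {'last_epoch': 5000}
```

The length checks fail fast if the model was built with a different number of optimizers or schedulers than the run being resumed.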
- save_network(net, net_label, current_iter, param_key='params')
Save networks.
- Parameters:
net (nn.Module | list[nn.Module]) – Network(s) to be saved.
net_label (str) – Network label.
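current_iter (int) – Current iteration number. -1 saves the file as ‘latest’.
param_key (str | list[str]) – The key(s) under which each network's parameters are saved; when lists are given, param_key must have the same length as net. Default: ‘params’.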
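The naming and pairing logic can be sketched as below, assuming the file is named `{net_label}_{current_iter}.pth` with -1 mapped to `latest`, and that single values for net and param_key are normalized to one-element lists before pairing. `save_filename` and `pair_keys` are hypothetical helpers; the strings `"G"` and `"G_ema"` stand in for the networks' state dicts, which the real method would write with `torch.save`.

```python
def save_filename(net_label, current_iter):
    """Checkpoint file name (sketch); -1 is assumed to mean 'latest'."""
    tag = "latest" if current_iter == -1 else current_iter
    return f"{net_label}_{tag}.pth"


def pair_keys(nets, param_keys="params"):
    """Normalize net/param_key to lists and pair them up (sketch)."""
    nets = nets if isinstance(nets, list) else [nets]
    param_keys = param_keys if isinstance(param_keys, list) else [param_keys]
    assert len(nets) == len(param_keys), "net and param_key must have the same length"
    return dict(zip(param_keys, nets))


print(save_filename("net_g", 5000))  # net_g_5000.pth
print(save_filename("net_g", -1))    # net_g_latest.pth
print(pair_keys(["G", "G_ema"], ["params", "params_ema"]))  # {'params': 'G', 'params_ema': 'G_ema'}
```

Storing several networks under different keys in one file is what lets load_network later choose between them through its param_key argument.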
- save_training_state(epoch, current_iter)
Save the training state (current epoch and iteration, plus the optimizer and scheduler states) so that training can be resumed later.
- Parameters:
epoch (int) – Current epoch.
current_iter (int) – Current iteration.
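The saved state can be sketched as a plain dict, assuming the keys `epoch`, `iter`, `optimizers`, and `schedulers` (the last two holding one `state_dict()` per object, in the same order that resume_training expects) and a file name of the form `{current_iter}.state`. `build_training_state` is hypothetical; the real method would write the dict with `torch.save`.

```python
def build_training_state(epoch, current_iter, optimizer_states, scheduler_states):
    """Assemble a resumable training-state dict and its file name (sketch)."""
    state = {
        "epoch": epoch,
        "iter": current_iter,
        "optimizers": list(optimizer_states),  # one state_dict() per optimizer
        "schedulers": list(scheduler_states),  # one state_dict() per scheduler
    }
    return f"{current_iter}.state", state


name, state = build_training_state(3, 5000, [{"lr": 1e-4}], [{"last_epoch": 5000}])
print(name)           # 5000.state
print(sorted(state))  # ['epoch', 'iter', 'optimizers', 'schedulers']
```

Saving the epoch and iteration alongside the optimizer and scheduler states lets a resumed run continue its loop counters as well as its learning-rate schedule.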
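- update_learning_rate(current_iter, warmup_iter=-1)
Update the learning rate by stepping the schedulers. During warm-up (current_iter < warmup_iter), the learning rate of each parameter group is instead set linearly to initial_lr * current_iter / warmup_iter.
- Parameters:
current_iter (int) – Current iteration.
warmup_iter (int) – Number of warm-up iterations; -1 means no warm-up. Default: -1.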
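The linear warm-up arithmetic is sketched below; it covers only the warm-up override, not the scheduler stepping. `initial_lrs` stands for the `initial_lr` of each optimizer parameter group, and `warmup_lrs` is a hypothetical helper that returns None when no warm-up applies.

```python
def warmup_lrs(initial_lrs, current_iter, warmup_iter=-1):
    """Linear warm-up of per-group learning rates (sketch).

    Returns None when no warm-up applies (warmup_iter=-1 or warm-up finished),
    in which case the schedulers alone determine the learning rate.
    """
    if current_iter >= warmup_iter:
        return None
    return [lr / warmup_iter * current_iter for lr in initial_lrs]


print(warmup_lrs([1e-4], current_iter=500, warmup_iter=1000))  # roughly [5e-05]
print(warmup_lrs([1e-4], current_iter=500))                    # None (no warm-up)
```

With warmup_iter=-1 the condition `current_iter >= warmup_iter` holds for every non-negative iteration, which is why -1 acts as the "no warm-up" default.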
- validation(dataloader, current_iter, tb_logger, save_img=False)
Validation function. In distributed training it calls dist_validation; otherwise it calls nondist_validation.
- Parameters:
dataloader (torch.utils.data.DataLoader) – Validation dataloader.
current_iter (int) – Current iteration.
tb_logger (tensorboard logger) – TensorBoard logger.
save_img (bool) – Whether to save images. Default: False.
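A sketch of the dispatch pattern, assuming validation routes on an `opt["dist"]` flag to `dist_validation` or `nondist_validation`, which concrete models implement. `ValidationDispatchSketch` is a hypothetical class whose two implementations only return a label so that the routing can be checked without a dataloader or logger.

```python
class ValidationDispatchSketch:
    """Hypothetical class showing how validation() routes its call."""

    def __init__(self, opt):
        self.opt = opt

    def validation(self, dataloader, current_iter, tb_logger, save_img=False):
        # Pick the distributed or single-process implementation.
        if self.opt.get("dist", False):
            return self.dist_validation(dataloader, current_iter, tb_logger, save_img)
        return self.nondist_validation(dataloader, current_iter, tb_logger, save_img)

    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        return "dist"

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        # A real model would iterate over dataloader, compute metrics,
        # optionally save images, and log the results to tb_logger.
        return "nondist"


model = ValidationDispatchSketch({"dist": False})
print(model.validation(dataloader=[], current_iter=1000, tb_logger=None))  # nondist
```

Keeping the routing in the base class means subclasses only override the two implementations, while training loops call the single validation entry point.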