# Source code for basicsr.data.prefetch_dataloader

import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader


class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    Iterates ``generator`` in a background daemon thread and buffers up to
    ``num_prefetch_queue`` items in a queue, so the consumer can work on the
    current item while the next ones are being produced.

    If the wrapped generator raises, the exception is re-raised in the
    consumer on the corresponding ``next()`` call instead of leaving the
    consumer blocked forever. Once exhausted, further ``next()`` calls keep
    raising ``StopIteration``.

    Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator (any iterable works).
        num_prefetch_queue (int): Number of prefetch queue, i.e. the maximum
            number of items buffered ahead of the consumer. Values <= 0 make
            the queue unbounded (``queue.Queue`` semantics).
    """

    # Private end-of-stream marker. Unlike ``None`` it can never collide with
    # a value the wrapped generator legitimately yields.
    _SENTINEL = object()

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        # Exception raised by the producer thread, forwarded to the consumer.
        self._producer_error = None
        # True once the consumer has seen the end-of-stream marker.
        self._consumer_done = False
        # Both fields above are initialised before the thread starts so run()
        # never races with their creation.
        self.start()

    def run(self):
        """Producer loop: enqueue every item, then always enqueue the sentinel."""
        try:
            for item in self.generator:
                self.queue.put(item)
        except Exception as exc:  # forwarded to the consumer in __next__
            self._producer_error = exc
        finally:
            # Always signal the end so the consumer never blocks forever,
            # even when the generator raised. The queue's internal locking
            # makes the _producer_error write visible to the consumer.
            self.queue.put(self._SENTINEL)

    def __next__(self):
        if self._consumer_done:
            raise StopIteration
        next_item = self.queue.get()
        if next_item is self._SENTINEL:
            self._consumer_done = True
            if self._producer_error is not None:
                error, self._producer_error = self._producer_error, None
                raise error
            raise StopIteration
        return next_item

    def __iter__(self):
        return self
class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Every call to ``iter()`` wraps the regular ``DataLoader`` iterator in a
    ``PrefetchGenerator``, so batches are fetched by a background thread.

    Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO: Need to test on single gpu and ddp (multi-gpu). There is a known
    issue in ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        *args: Positional arguments for dataloader (e.g. ``dataset``).
        **kwargs: Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, *args, **kwargs):
        # Set before DataLoader.__init__ so the attribute exists even if the
        # base initialiser inspects the instance.
        self.num_prefetch_queue = num_prefetch_queue
        # Forward positional arguments too, so the usual
        # ``PrefetchDataLoader(n, dataset, ...)`` form works like DataLoader.
        super(PrefetchDataLoader, self).__init__(*args, **kwargs)

    def __iter__(self):
        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
class CPUPrefetcher():
    """CPU prefetcher.

    Thin, resettable wrapper that hands out one batch per ``next()`` call and
    returns ``None`` (instead of raising) once the loader is exhausted.

    Args:
        loader: Dataloader (any re-iterable object).
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        """Return the next batch, or ``None`` when no batches remain."""
        # The default argument of the built-in next() turns StopIteration
        # into None.
        return next(self.loader, None)

    def reset(self):
        """Start a fresh pass over the original loader."""
        fresh_iterator = iter(self.ori_loader)
        self.loader = fresh_iterator
class CUDAPrefetcher():
    """CUDA prefetcher.

    Always holds the next batch ready. On CUDA, its tensors are copied to the
    GPU on a dedicated side stream so the host-to-device transfer can overlap
    with work on the current stream. When ``opt['num_gpu'] == 0`` the tensors
    are moved to the CPU device synchronously and no CUDA API is called, so
    the class also works on machines without CUDA.

    Reference: https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader. Each batch is expected to be a dict; its tensor
            values are moved to the target device, other values are passed
            through unchanged.
        opt (dict): Options. This class reads ``opt['num_gpu']``.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        # A side stream only makes sense (and only exists) with CUDA.
        # Creating it unconditionally would crash the CPU fallback.
        self.stream = torch.cuda.Stream() if self.device.type == 'cuda' else None
        self.preload()

    def _move_batch_to_device(self, non_blocking):
        """Replace every tensor value of ``self.batch`` by its copy on ``self.device``."""
        for key, value in self.batch.items():
            if torch.is_tensor(value):
                self.batch[key] = value.to(device=self.device, non_blocking=non_blocking)

    def preload(self):
        """Fetch the next batch into ``self.batch`` (``None`` when exhausted)."""
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        if self.stream is None:
            # CPU path: plain synchronous move, no CUDA involved.
            self._move_batch_to_device(non_blocking=False)
        else:
            # Put tensors to GPU asynchronously on the side stream.
            with torch.cuda.stream(self.stream):
                self._move_batch_to_device(non_blocking=True)
        return None

    def next(self):
        """Return the prefetched batch (``None`` once exhausted) and preload the next one."""
        if self.stream is not None:
            current_stream = torch.cuda.current_stream()
            # The consumer's stream must wait for the async copies issued on
            # the side stream before it touches the batch.
            current_stream.wait_stream(self.stream)
            if self.batch is not None:
                # These tensors were allocated on the side stream. Recording
                # the current stream stops the caching allocator from reusing
                # their memory while kernels on this stream may still use it.
                for value in self.batch.values():
                    if torch.is_tensor(value) and value.is_cuda:
                        value.record_stream(current_stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        """Restart iteration from the beginning of the original loader."""
        self.loader = iter(self.ori_loader)
        self.preload()