# Source code for amptorch.data_parallel

"""
Adapted from https://github.com/Open-Catalyst-Project/ocp/blob/master/ocpmodels/common/data_parallel.py
"""
from itertools import chain
import torch


class DataParallel(torch.nn.DataParallel):
    """
    Data Parallelization for GPU scheme.

    Thin wrapper around ``torch.nn.DataParallel`` whose ``forward`` takes a
    list of already-split batches (one per device) instead of scattering a
    single batch itself.
    """

    def __init__(self, module, output_device, num_gpus):
        """
        Validate the GPU configuration and initialise the parent class.

        Args:
            module: The model to run; passed through to
                ``torch.nn.DataParallel``.
            output_device: Device the outputs are gathered on. It is also the
                only device used when ``num_gpus == 1``.
            num_gpus: Number of GPUs to use. When it is not 1, devices
                ``0 .. num_gpus - 1`` are used.

        Raises:
            ValueError: If ``num_gpus`` is negative, exceeds
                ``torch.cuda.device_count()``, or (when ``num_gpus != 1``)
                ``output_device`` is not smaller than ``num_gpus``.
        """
        if num_gpus < 0:
            raise ValueError("# GPUs must be positive.")
        if num_gpus > torch.cuda.device_count():
            raise ValueError("# GPUs specified larger than available")

        if num_gpus != 1:
            # The main device has to be one of the devices 0..num_gpus-1.
            if output_device >= num_gpus:
                raise ValueError("Main device must be less than # of GPUs")
            device_ids = list(range(num_gpus))
        else:
            device_ids = [output_device]

        super().__init__(
            module=module,
            device_ids=device_ids,
            output_device=output_device,
        )
        # Device every parameter/buffer must live on before replication.
        self.src_device = torch.device(output_device)

    def forward(self, batch_list):
        """
        Run the wrapped module on a list of per-device batches.

        Args:
            batch_list: Sequence of batch objects exposing ``.to(device)``;
                entry ``i`` is moved to ``cuda:{self.device_ids[i]}``.

        Returns:
            The module output directly when a single device is configured;
            otherwise the per-replica outputs gathered on
            ``self.output_device``.

        Raises:
            RuntimeError: If any parameter or buffer of the module is not on
                ``self.src_device`` (only checked in the multi-device path).
        """
        # Single device: no replication, just run the first batch there.
        if len(self.device_ids) == 1:
            only_device = f"cuda:{self.device_ids[0]}"
            return self.module(batch_list[0].to(only_device))

        # Find the first parameter/buffer that is not on the source device.
        module_tensors = chain(self.module.parameters(), self.module.buffers())
        misplaced = next(
            (t for t in module_tensors if t.device != self.src_device),
            None,
        )
        if misplaced is not None:
            raise RuntimeError(
                (
                    "Module must have its parameters and buffers on device "
                    "{} but found one of them on device {}."
                ).format(self.src_device, misplaced.device)
            )

        # Move batch i onto device i.
        inputs = []
        for position, batch in enumerate(batch_list):
            inputs.append(batch.to(f"cuda:{self.device_ids[position]}"))

        # Only as many replicas as there are batches to process.
        active_devices = self.device_ids[: len(inputs)]
        replicas = self.replicate(self.module, active_devices)
        outputs = self.parallel_apply(replicas, inputs, None)
        return self.gather(outputs, self.output_device)
class ParallelCollater:
    """
    Data collater for multi-GPU training.

    Wraps another collater and, when more than one GPU is configured, splits
    the incoming samples into contiguous chunks (balanced by each sample's
    ``num_nodes``) and collates every chunk separately.
    """

    def __init__(self, num_gpus, collater):
        """
        Args:
            num_gpus: Number of GPUs the batches are produced for.
            collater: Callable applied to a list of samples; its result is
                indexed with ``[0]`` (batch) and ``[1]`` (target).
        """
        self.num_gpus = num_gpus
        self.collater = collater

    def __call__(self, data_list):
        """
        Collate ``data_list`` into per-device batches and targets.

        Args:
            data_list: Sequence of samples. In the multi-GPU path each sample
                must expose ``num_nodes``.

        Returns:
            ``[batch_list, target_list]`` where both lists hold one entry per
            chunk (a single entry when ``num_gpus <= 1``).
        """
        # Zero or one GPU: a single collated batch wrapped in lists.
        if self.num_gpus <= 1:
            merged = self.collater(data_list)
            return [[merged[0]], [merged[1]]]

        bounds = self._split_points(data_list)
        chunks = [
            self.collater(data_list[start:stop])
            for start, stop in zip(bounds[:-1], bounds[1:])
        ]
        batch_list = [chunk[0] for chunk in chunks]
        target_list = [chunk[1] for chunk in chunks]
        return [batch_list, target_list]

    def _split_points(self, data_list):
        """Return the sorted, de-duplicated slice boundaries into ``data_list``."""
        # Never use more devices than there are samples.
        n_devices = min(self.num_gpus, len(data_list))

        node_counts = torch.tensor([sample.num_nodes for sample in data_list])
        edges = node_counts.cumsum(0)
        edges = torch.cat([edges.new_zeros(1), edges], dim=0)

        # Scale cumulative node counts to [0, n_devices] and assign each
        # sample to the device containing the midpoint of its node range.
        scaled = n_devices * edges.to(torch.float) / edges[-1].item()
        midpoints = (scaled[:-1] + scaled[1:]) / 2.0
        owner = midpoints.to(torch.long)

        # Samples per device -> cumulative end index of every chunk.
        ends = owner.bincount().cumsum(0)
        ends = torch.cat([ends.new_zeros(1), ends], dim=0)
        # Devices that received no samples produce repeated boundaries.
        ends = torch.unique(ends, sorted=True)
        return ends.tolist()