"""
Adapted from https://github.com/Open-Catalyst-Project/ocp/blob/master/ocpmodels/common/data_parallel.py
"""
from itertools import chain
import torch
class DataParallel(torch.nn.DataParallel):
    """
    Data Parallelization for GPU scheme.

    Variant of ``torch.nn.DataParallel`` whose ``forward`` receives a list of
    already-split batches (one per device) rather than a single input to be
    scattered.

    Args:
        module: Model to wrap.
        output_device: Device on which outputs are gathered. When
            ``num_gpus > 1`` it is compared against ``num_gpus``, so it is
            expected to be an integer device index in that case.
        num_gpus: Number of GPUs to use. Must be at least 1 and no larger
            than ``torch.cuda.device_count()``.

    Raises:
        ValueError: If ``num_gpus`` is not positive, exceeds the number of
            available GPUs, or (for ``num_gpus > 1``) ``output_device`` is
            not one of the first ``num_gpus`` devices.
    """

    def __init__(self, module, output_device, num_gpus):
        # Reject zero as well as negatives: the message promises "positive",
        # and a zero count would otherwise fall through to the multi-GPU
        # branch and fail with an unrelated "Main device ..." error.
        if num_gpus < 1:
            raise ValueError("# GPUs must be positive.")
        if num_gpus > torch.cuda.device_count():
            raise ValueError("# GPUs specified larger than available")

        if num_gpus == 1:
            # Single GPU: run only on the requested output device.
            device_ids = [output_device]
        else:
            if output_device >= num_gpus:
                raise ValueError("Main device must be less than # of GPUs")
            device_ids = list(range(num_gpus))

        super().__init__(
            module=module, device_ids=device_ids, output_device=output_device
        )
        # Device the wrapped module's parameters and buffers are checked
        # against in the multi-device path of ``forward``.
        self.src_device = torch.device(output_device)

    def forward(self, batch_list):
        """
        Run the wrapped module on a list of per-device batches.

        Args:
            batch_list: Sequence of batch objects exposing ``.to(device)``.
                With a single device only ``batch_list[0]`` is used;
                otherwise the i-th batch is sent to the i-th device id.

        Returns:
            The module output directly when one device is configured,
            otherwise the per-replica outputs gathered on
            ``self.output_device``.

        Raises:
            RuntimeError: In the multi-device path, if any parameter or
                buffer of the module is not on ``self.src_device``.
        """
        if len(self.device_ids) == 1:
            # No replication needed: move the single batch and call the
            # module directly.
            return self.module(batch_list[0].to(f"cuda:{self.device_ids[0]}"))

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device:
                raise RuntimeError(
                    (
                        "Module must have its parameters and buffers on device "
                        "{} but found one of them on device {}."
                    ).format(self.src_device, t.device)
                )

        inputs = [
            batch.to(f"cuda:{self.device_ids[i]}")
            for i, batch in enumerate(batch_list)
        ]
        # Replicate only onto as many devices as there are batches.
        replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, None)
        return self.gather(outputs, self.output_device)
class ParallelCollater:
    """
    Data collater for multi-GPU training.

    Args:
        num_gpus: Number of GPUs batches are split across. Values ``<= 1``
            disable splitting.
        collater: Callable taking a list of data objects and returning an
            indexable result whose element 0 is the batch and element 1 is
            the target.
    """

    def __init__(self, num_gpus, collater):
        self.num_gpus = num_gpus
        self.collater = collater

    def __call__(self, data_list):
        """
        Collate ``data_list`` into per-device batches and targets.

        Args:
            data_list: List of data objects. In the multi-GPU path each one
                must expose a ``num_nodes`` attribute.

        Returns:
            ``[batch_list, target_list]``, where both lists hold one entry
            per chunk produced (a single entry when ``num_gpus <= 1``).
        """
        if self.num_gpus <= 1:
            batch = self.collater(data_list)
            return [[batch[0]], [batch[1]]]

        num_devices = min(self.num_gpus, len(data_list))

        # Cumulative node counts with a leading zero: entries i and i + 1
        # bracket sample i on the running node total.
        count = torch.tensor([data.num_nodes for data in data_list])
        cumsum = count.cumsum(0)
        cumsum = torch.cat([cumsum.new_zeros(1), cumsum], dim=0)

        # Scale the running total to [0, num_devices] and assign each sample
        # to the device its interval midpoint falls in, so chunks are
        # balanced by node count rather than by sample count.
        device_id = num_devices * cumsum.to(torch.float) / cumsum[-1].item()
        device_id = (device_id[:-1] + device_id[1:]) / 2.0
        device_id = device_id.to(torch.long)

        # Samples per device -> slice boundaries into data_list.
        split = device_id.bincount().cumsum(0)
        split = torch.cat([split.new_zeros(1), split], dim=0)
        # Duplicate boundaries come from devices that received no samples;
        # dropping them avoids collating empty slices.
        split = torch.unique(split, sorted=True).tolist()

        skorch_list = [
            self.collater(data_list[start:stop])
            for start, stop in zip(split[:-1], split[1:])
        ]
        batch_list = [batch[0] for batch in skorch_list]
        target_list = [batch[1] for batch in skorch_list]
        return [batch_list, target_list]