import skorch
import json
import os
from skorch.utils import to_numpy
from torch_geometric.data import Batch
from torch.nn.parallel.scatter_gather import gather
import torch
class InOrderSplit:
    """Callable that splits a dataset into train/validation parts in order.

    No shuffling takes place: the leading samples form the training subset
    and the trailing ``val_frac`` share of the dataset (rounded down to a
    whole number of samples) forms the validation subset.

    Args:
        val_frac: Fraction of the dataset to reserve for validation.
    """

    def __init__(self, val_frac):
        self.val_frac = val_frac

    def __call__(self, dataset):
        """Return ``(train_dataset, val_dataset)`` as ``Subset`` views.

        Args:
            dataset: Any object supporting ``len()``; it is wrapped, not
                copied, by ``torch.utils.data.Subset``.

        Returns:
            A 2-tuple of ``torch.utils.data.Subset`` objects: the first
            covers the leading indices, the second the remaining ones.
        """
        total = len(dataset)
        # int() truncates, so the validation share is rounded down.
        boundary = total - int(total * self.val_frac)
        leading = [*range(boundary)]
        trailing = [*range(boundary, total)]
        return (
            torch.utils.data.Subset(dataset, leading),
            torch.utils.data.Subset(dataset, trailing),
        )
def to_tensor(X, device, accept_sparse=False):
    """Move nested targets onto ``device`` and collapse them to one output.

    If the first element of ``X`` is a ``torch_geometric`` ``Batch`` the
    input is returned untouched. Otherwise every ``X[i][j]`` is replaced
    *in place* by ``X[i][j].to(device)``; the entries are then gathered
    onto ``device`` for non-CPU devices, or ``X[0]`` is returned for CPU.

    Args:
        X: Indexable, mutable nested sequence (e.g. list of lists) whose
            leaf items expose ``.to(device)``, or a sequence of ``Batch``.
        device: Target device, either a string such as ``"cpu"`` or a
            ``torch.device``.
        accept_sparse: Unused by this function; kept in the signature for
            callers that pass it.

    Returns:
        ``X`` itself when it holds ``Batch`` objects, otherwise the
        gathered output (non-CPU) or ``X[0]`` (CPU).

    Raises:
        IndexError: If ``X`` is empty.
    """
    if isinstance(X[0], Batch):
        return X

    # NOTE: the caller's containers are mutated in place.
    for i, batch in enumerate(X):
        for j, targets in enumerate(batch):
            X[i][j] = targets.to(device)

    # Compare on the device *type* so that torch.device("cpu") is treated the
    # same as the string "cpu"; a plain string comparison would misclassify
    # the device object as non-CPU and route CPU tensors through gather().
    device_type = getattr(device, "type", device)
    if device_type != "cpu":
        return gather(X, device)
    return X[0]
def save_normalizers(normalizers, path):
    """Write feature and target normalizer parameters to a JSON file.

    Args:
        normalizers: Mapping with keys ``"feature"`` and ``"target"``. The
            feature entry must expose ``transform`` and ``scales``; the
            target entry must expose ``target_mean`` and ``target_std``.
            ``scales``, ``target_mean`` and ``target_std`` must provide a
            ``.numpy()`` method (e.g. CPU tensors). ``transform`` is written
            as-is and therefore has to be JSON-serializable already.
        path: Destination file; it is created or overwritten.

    Returns:
        None.
    """
    feature = normalizers["feature"]
    target = normalizers["target"]

    # json cannot encode numpy arrays, so convert them to plain Python
    # lists/scalars with tolist() before dumping.
    tosave = {
        "feature": {
            "type": feature.transform,
            "scales": feature.scales.numpy().tolist(),
        },
        "target": {
            "mean": target.target_mean.numpy().tolist(),
            "stddev": target.target_std.numpy().tolist(),
        },
    }
    with open(path, "w", encoding="utf8") as json_file:
        json.dump(tosave, json_file, indent=4)
class train_end_load_best_loss(skorch.callbacks.base.Callback):
    """skorch callback that reloads stored parameters when training ends.

    The parameters are read from ``./checkpoints/<filename>/params.pt``,
    relative to the current working directory. Presumably that file is
    written by a checkpointing callback configured elsewhere; this class
    only loads it.

    Args:
        filename: Name of the sub-directory under ``./checkpoints``.
    """

    def __init__(self, filename):
        self.filename = filename

    def on_train_end(self, net, X, y):
        """Load the checkpointed parameters into ``net``.

        ``X`` and ``y`` are accepted because skorch passes them to this
        hook; they are not used here.
        """
        checkpoint_path = f"./checkpoints/{self.filename}/params.pt"
        net.load_params(checkpoint_path)
class check_memory(skorch.callbacks.base.Callback):
    """skorch callback that prints CUDA memory usage after every batch."""

    def on_batch_end(self, net, **kwargs):
        """Print allocated and cached (reserved) CUDA memory in 1e6-byte units.

        ``net`` and ``kwargs`` are supplied by skorch and are not used.
        """
        # torch.cuda.memory_cached is deprecated in favour of memory_reserved
        # and warns on every call; use the new name when this torch build has
        # it and fall back to the old one otherwise.
        reserved_bytes = getattr(torch.cuda, "memory_reserved", None)
        if reserved_bytes is None:
            reserved_bytes = torch.cuda.memory_cached
        print(
            f"Allocated {torch.cuda.memory_allocated() / 1e6} Mb, cached {reserved_bytes() / 1e6} Mb"
        )