from torch.utils.data import Dataset
from torch_geometric.data import Batch
from amptorch.descriptor.Gaussian import Gaussian
from amptorch.descriptor.GMP import GMP
from amptorch.descriptor.GMPOrderNorm import GMPOrderNorm
from amptorch.preprocessing import (
AtomsToData,
FeatureScaler,
TargetScaler,
sparse_block_diag,
)
class AtomsDataset(Dataset):
    """
    Dataset of fingerprinted (and scaled) data objects built from ase.Atoms images.

    Args:
        images (list): A list of ase.Atoms objects.
        descriptor_setup (tuple): Four-item sequence
            ``(fp_scheme, fp_params, cutoff_params, elements)`` that is
            unpacked by ``construct_descriptor`` to build the descriptor.
        forcetraining (bool): Whether to train with forces (default is True).
            Forwarded to ``AtomsToData`` as both ``r_forces`` and ``fprimes``
            and to the feature/target scalers.
        save_fps (bool): Whether to save the fingerprints (default is True).
            Forwarded to ``AtomsToData``.
        scaling (dict): A dictionary on how to scale the fingerprints
            (default is {"type": "normalize", "range": (0, 1), "threshold": 1e-6}).
            Handed to ``FeatureScaler`` when the data is processed.
        cores (int): The number of cores to use for parallel processing
            (default is 1). Forwarded to ``AtomsToData``.
        process (bool): Whether to process the data during initialization
            (default is True). When False, ``data_list`` is left as ``None``
            and ``len()``, indexing and ``input_dim`` are unusable until
            ``data_list`` has been assigned, e.g.
            ``dataset.data_list = dataset.process()``.
    """

    def __init__(
        self,
        images,
        descriptor_setup,
        forcetraining=True,
        save_fps=True,
        scaling={"type": "normalize", "range": (0, 1), "threshold": 1e-6},
        cores=1,
        process=True,
    ):
        super().__init__()
        self.images = images
        self.forcetraining = forcetraining
        # The default above is one dict object shared by every call. Keep a
        # private shallow copy so that in-place changes to ``self.scaling``
        # (by a scaler or by user code) cannot leak into other instances or
        # into the caller's own dict. Non-dict values pass through unchanged.
        self.scaling = dict(scaling) if isinstance(scaling, dict) else scaling
        self.descriptor = construct_descriptor(descriptor_setup)
        self.a2d = AtomsToData(
            descriptor=self.descriptor,
            r_energy=True,
            r_forces=self.forcetraining,
            save_fps=save_fps,
            fprimes=self.forcetraining,
            cores=cores,
        )
        self.data_list = self.process() if process else None

    def process(self):
        """
        Compute the fingerprints and scale the features and targets.

        Converts ``self.images`` with ``AtomsToData.convert_all``, builds a
        ``FeatureScaler`` and a ``TargetScaler`` from the converted data
        (stored as ``self.feature_scaler`` / ``self.target_scaler``), and
        calls ``norm`` of both scalers on the data. The return values of
        ``norm`` are not used, so the scalers are presumably modifying the
        data objects in place.

        Returns:
            The list produced by ``convert_all`` after both ``norm`` calls.
        """
        data_list = self.a2d.convert_all(self.images)
        self.feature_scaler = FeatureScaler(data_list, self.forcetraining, self.scaling)
        self.target_scaler = TargetScaler(data_list, self.forcetraining)
        self.feature_scaler.norm(data_list)
        self.target_scaler.norm(data_list)
        return data_list

    @property
    def input_dim(self):
        """Size of the second axis of the first entry's ``fingerprint``."""
        return self.data_list[0].fingerprint.shape[1]

    def __len__(self):
        """Number of processed data objects."""
        return len(self.data_list)

    def __getitem__(self, index):
        """Return the processed data object stored at ``index``."""
        return self.data_list[index]
class DataCollater:
    """
    Callable helper that collates a list of data objects into one batch.

    Args:
        train (bool): When True, calling the collater returns the batch
            together with a list of targets; when False, only the batch.
        forcetraining (bool): When True (and ``train`` is True), the target
            list holds ``batch.energy`` and ``batch.forces``; otherwise it
            holds only ``batch.energy``.
    """

    def __init__(self, train=True, forcetraining=True):
        self.train = train
        self.forcetraining = forcetraining

    def __call__(self, data_list):
        """
        Build a ``Batch`` from ``data_list``.

        If the first element has an ``fprimes`` attribute, every element's
        ``fprimes`` is detached while ``Batch.from_data_list`` runs
        (presumably so it is not collated like the other attributes), the
        per-element matrices are combined with ``sparse_block_diag``, and the
        result is attached to the batch as ``batch.fprimes``. The elements of
        ``data_list`` get their original ``fprimes`` back afterwards.

        Returns:
            ``batch`` when ``self.train`` is False, otherwise
            ``(batch, targets)`` where ``targets`` is
            ``[batch.energy, batch.forces]`` with force training and
            ``[batch.energy]`` without.
        """
        if hasattr(data_list[0], "fprimes"):
            # Snapshot every matrix before touching any element, so a missing
            # attribute further down the list cannot leave earlier elements
            # stripped of their fprimes.
            mtxs = [data.fprimes for data in data_list]
            try:
                for data in data_list:
                    data.fprimes = None
                batch = Batch.from_data_list(data_list)
            finally:
                # Always hand the caller's objects back intact, even when
                # batching raises.
                for data, mtx in zip(data_list, mtxs):
                    data.fprimes = mtx
            batch.fprimes = sparse_block_diag(mtxs)
        else:
            batch = Batch.from_data_list(data_list)

        if not self.train:
            return batch
        if self.forcetraining:
            return batch, [batch.energy, batch.forces]
        return batch, [batch.energy]
def construct_descriptor(descriptor_setup):
    """
    Build the descriptor object for the requested fingerprinting scheme.

    Args:
        descriptor_setup: Four-item sequence
            ``(fp_scheme, fp_params, cutoff_params, elements)``.
            ``fp_scheme`` selects the descriptor class: ``"gaussian"``,
            ``"gmp"`` or ``"gmpordernorm"``. ``fp_params`` is passed as
            ``Gs`` (Gaussian) or ``MCSHs`` (GMP / GMPOrderNorm).
            ``cutoff_params`` is expanded as keyword arguments for the
            Gaussian descriptor only and is ignored by the other schemes.

    Returns:
        A ``Gaussian``, ``GMP`` or ``GMPOrderNorm`` instance.

    Raises:
        NotImplementedError: If ``fp_scheme`` is not one of the supported
            scheme names.
    """
    fp_scheme, fp_params, cutoff_params, elements = descriptor_setup

    if fp_scheme == "gaussian":
        return Gaussian(Gs=fp_params, elements=elements, **cutoff_params)
    if fp_scheme == "gmp":
        return GMP(MCSHs=fp_params, elements=elements)
    if fp_scheme == "gmpordernorm":
        return GMPOrderNorm(MCSHs=fp_params, elements=elements)

    # Name the rejected value so a typo in the config is obvious.
    raise NotImplementedError(
        f"Unsupported fingerprinting scheme {fp_scheme!r}; "
        "expected one of 'gaussian', 'gmp', 'gmpordernorm'."
    )