Source code for amptorch.metrics

import numpy as np
import torch
from skorch.callbacks import Checkpoint, EpochScoring
from torch.nn import L1Loss, MSELoss

from amptorch.utils import target_extractor


[docs]def mae_energy_score(net, X, y): """ Compute the energy MAE of the model. """ mae_loss = L1Loss() energy_pred, _ = net.forward(X) if isinstance(X, torch.utils.data.Subset): X = X.dataset # TODO: investigate the np.concatenate part in skorch 0.10 # site-packages/skorch/callbacks/scoring.py Line 625 # as shown by the following test, it's concatenate the y_target # oiginal value # it is probably skorch's fault here, nothing much we can do # so might worth trying new versions of it # print(np.concatenate([i[0] for i in y])) energy_pred = X.target_scaler.denorm(energy_pred, pred="energy") energy_target = X.target_scaler.denorm( torch.FloatTensor(np.concatenate([i[0] for i in y])), pred="energy" ) energy_loss = mae_loss(energy_pred, energy_target) return energy_loss.item()
[docs]def mae_forces_score(net, X, y): """ Compute the force MAE of the model. """ mae_loss = L1Loss() _, force_pred = net.forward(X) if isinstance(X, torch.utils.data.Subset): X = X.dataset force_pred = X.target_scaler.denorm(force_pred, pred="forces") force_target = X.target_scaler.denorm( torch.FloatTensor(np.concatenate([i[1] for i in y])), pred="forces" ) force_loss = mae_loss(force_pred, force_target) return force_loss.item()
[docs]def mse_energy_score(net, X, y): """ Compute the energy MSE of the model. """ mse_loss = MSELoss() energy_pred, _ = net.forward(X) if isinstance(X, torch.utils.data.Subset): X = X.dataset energy_pred = X.target_scaler.denorm(energy_pred, pred="energy") energy_target = X.target_scaler.denorm( torch.FloatTensor(np.concatenate([i[0] for i in y])), pred="energy" ) energy_loss = mse_loss(energy_pred, energy_target) return energy_loss.item()
[docs]def mse_forces_score(net, X, y): """ Compute the force MSE of the model. """ mse_loss = MSELoss() _, force_pred = net.forward(X) if isinstance(X, torch.utils.data.Subset): X = X.dataset force_pred = X.target_scaler.denorm(force_pred, pred="forces") force_target = X.target_scaler.denorm( torch.FloatTensor(np.concatenate([i[1] for i in y])), pred="forces" ) force_loss = mse_loss(force_pred, force_target) return force_loss.item()
[docs]def evaluator( val_split, metric, identifier, forcetraining, cp_metric, ): """ For metric display with callbacks. """ # print("evaluator") callbacks = [] isval = val_split != 0 if isval: cp_on = "val" else: cp_on = "train" if metric == "mae": energy_score = mae_energy_score forces_score = mae_forces_score elif metric == "mse": energy_score = mse_energy_score forces_score = mse_forces_score else: raise NotImplementedError(f"{metric} metric not available!") if cp_metric not in ["energy", "forces"]: raise NotImplementedError(f"{cp_metric} value not valid!") callbacks.append( MemEffEpochScoring( energy_score, on_train=True, use_caching=True, name="train_energy_{}".format(metric), target_extractor=target_extractor, ) ) if isval: callbacks.append( MemEffEpochScoring( energy_score, on_train=False, use_caching=True, name="val_energy_{}".format(metric), target_extractor=target_extractor, ) ) if forcetraining: callbacks.append( MemEffEpochScoring( forces_score, on_train=True, use_caching=True, name="train_forces_{}".format(metric), target_extractor=target_extractor, ) ) if isval: callbacks.append( MemEffEpochScoring( forces_score, on_train=False, use_caching=True, name="val_forces_{}".format(metric), target_extractor=target_extractor, ) ) callbacks.append( Checkpoint( monitor="{}_{}_{}_best".format(cp_on, cp_metric, metric), fn_prefix="checkpoints/{}/".format(identifier), ) ) return callbacks
[docs]def to_cpu(X): """ Detach to cpu. """ if isinstance(X, (tuple, list)): return type(X)(to_cpu(x) for x in X) return X.detach().to("cpu")
[docs]class MemEffEpochScoring(EpochScoring): """ Memory-efficient epoch scorer that caches the predictions for all batches during the epoch. """ def __init__( self, scoring, lower_is_better=True, on_train=False, name=None, target_extractor=None, use_caching=True, ): super().__init__( scoring, lower_is_better, on_train, name, target_extractor, use_caching )
[docs] def on_batch_end(self, net, y, y_pred, training, **kwargs): if not self.use_caching or training != self.on_train: return if y is not None: self.y_trues_.append(to_cpu(y)) self.y_preds_.append(to_cpu(y_pred))