Source code for amptorch.descriptor.base_descriptor

import os
from abc import ABC, abstractmethod

import h5py
import numpy as np
from tqdm import tqdm

from .util import get_hash, list_symbols_to_indices, validate_image


class BaseDescriptor(ABC):
    """Abstract base class for per-element atomic fingerprint descriptors.

    Subclasses provide the actual fingerprint calculation; this class drives
    the loop over images, optionally caches per-element results in one HDF5
    file per image, and assembles the per-element results into one dictionary
    per image.
    """

    # HDF5 dataset names, in the same order as the six values that
    # ``calculate_fingerprints`` is unpacked into by this class.
    _FP_DATASETS = (
        "size_info",
        "fps",
        "fp_primes_val",
        "fp_primes_row",
        "fp_primes_col",
        "fp_primes_size",
    )
    # Only the first two datasets are needed when derivatives are not requested.
    _NUM_FP_ONLY_DATASETS = 2

    def __init__(self):
        super().__init__()
        self.fp_database = "processed/descriptors/"

        # To Be specified/calculated
        self.descriptor_type = "default"
        self.descriptor_setup_hash = "default"
        self.elements = []

    @abstractmethod
    def calculate_fingerprints(self, image, params_set, calculate_derivatives=True):
        """Compute the fingerprints of one element in a single snapshot.

        NOTE(review): this class calls the method as
        ``calculate_fingerprints(image, element, calc_derivatives=..., log=...)``
        and unpacks exactly six return values
        ``(size_info, fps, fp_primes_val, fp_primes_row, fp_primes_col,
        fp_primes_size)``, so concrete implementations must accept that call
        form even though the abstract signature here names its parameters
        differently.
        """
        # image is a single snapshot
        pass

    @abstractmethod
    def get_descriptor_setup_hash(self):
        """Set ``self.descriptor_setup_hash`` for the current descriptor setup."""
        # set self.descriptor_setup_hash
        pass

    @abstractmethod
    def save_descriptor_setup(self, filename):
        """Write a description of the descriptor setup to ``filename``."""
        pass

    @abstractmethod
    def prepare_descriptor_parameters(self):
        """Prepare the descriptor parameters (``self.params_set``)."""
        # prepare self.params_set
        pass

    def prepare_fingerprints(
        self, images, calc_derivatives, save_fps, verbose, cores, log
    ):
        """Compute (or load from cache) the descriptors of every image.

        Args:
            images: sized iterable of snapshots; each must pass
                ``validate_image`` and expose ``get_chemical_symbols()``.
            calc_derivatives: whether fingerprint derivatives are needed.
            save_fps: if true, read/write the per-image HDF5 cache.
            verbose: if true, show a progress bar.
            cores: forwarded to the compute helpers (unused by them here).
            log: forwarded to ``calculate_fingerprints``.

        Returns:
            A list with one descriptor dictionary per image.
        """
        images_descriptor_list = []

        # if save is true, create directories if not exist
        self._setup_fingerprint_database(save_fps=save_fps)

        for image in tqdm(
            images,
            total=len(images),
            desc="Computing fingerprints",
            disable=not verbose,
        ):
            validate_image(image)
            image_hash = get_hash(image)
            image_db_filename = "{}/{}.h5".format(self.desc_fp_database_dir, image_hash)
            compute_kwargs = {
                "calc_derivatives": calc_derivatives,
                "save_fps": save_fps,
                "cores": cores,
                "log": log,
            }

            image_descriptors = None
            # if save, then read/write from db as needed
            if save_fps:
                try:
                    image_descriptors = self._compute_fingerprints(
                        image, image_db_filename, **compute_kwargs
                    )
                except Exception:
                    # Deliberate best effort: any cache problem falls back to
                    # computing on the fly below instead of aborting the run.
                    print(
                        "File {} not loaded properly\nProceed to compute in run-time".format(
                            image_db_filename
                        )
                    )

            # if not save (or the cache failed), compute fps on-the-fly
            if image_descriptors is None:
                image_descriptors = self._compute_fingerprints_nodb(
                    image, image_db_filename, **compute_kwargs
                )

            images_descriptor_list.extend(image_descriptors)

        return images_descriptor_list

    def _compute_fingerprints(
        self, image, image_db_filename, calc_derivatives, save_fps, cores, log
    ):
        """Compute one image's descriptors, using its HDF5 file as a cache.

        The file is opened in append mode; per-element data lives under the
        group ``"0"/<element>``. Returns a one-item list holding the image's
        descriptor dictionary.
        """
        symbol_arr = np.array(image.get_chemical_symbols())
        per_element = {}

        with h5py.File(image_db_filename, "a") as db:
            snapshot_grp = db.require_group(str(0))
            for element in self._elements_in_image(symbol_arr):
                element_grp = snapshot_grp.require_group(element)
                per_element[element] = self._load_or_compute_element(
                    element_grp, image, element, calc_derivatives, save_fps, log
                )

        # Everything read from the file was copied into numpy arrays, so the
        # assembly no longer needs the file to be open.
        return [self._assemble_image_dict(symbol_arr, per_element, calc_derivatives)]

    def _compute_fingerprints_nodb(
        self, image, image_db_filename, calc_derivatives, save_fps, cores, log
    ):
        """Compute one image's descriptors without touching the cache.

        ``image_db_filename``, ``save_fps`` and ``cores`` are accepted for
        signature parity with ``_compute_fingerprints`` and are not used.
        Returns a one-item list holding the image's descriptor dictionary.
        """
        symbol_arr = np.array(image.get_chemical_symbols())
        per_element = {}

        for element in self._elements_in_image(symbol_arr):
            (
                size_info,
                fps,
                fp_primes_val,
                fp_primes_row,
                fp_primes_col,
                fp_primes_size,
            ) = self.calculate_fingerprints(
                image, element, calc_derivatives=calc_derivatives, log=log
            )
            per_element[element] = (
                size_info,
                fps,
                fp_primes_val,
                fp_primes_row,
                fp_primes_col,
                fp_primes_size,
            )

        return [self._assemble_image_dict(symbol_arr, per_element, calc_derivatives)]

    def _elements_in_image(self, symbol_arr):
        """Return the entries of ``self.elements`` present in ``symbol_arr``, in order."""
        present = set(symbol_arr.tolist())
        return [element for element in self.elements if element in present]

    def _load_or_compute_element(
        self, element_grp, image, element, calc_derivatives, save_fps, log
    ):
        """Return the six fingerprint values of one element, cached if possible.

        When the values come from the cache and derivatives were not
        requested, the four derivative slots are ``None``.
        """
        if calc_derivatives:
            names = self._FP_DATASETS
        else:
            names = self._FP_DATASETS[: self._NUM_FP_ONLY_DATASETS]

        try:
            cached = [np.array(element_grp[name]) for name in names]
        except Exception:
            # Any read failure (missing or unreadable dataset) is a cache miss.
            cached = None
        if cached is not None:
            padding = (None,) * (len(self._FP_DATASETS) - len(cached))
            return tuple(cached) + padding

        (
            size_info,
            fps,
            fp_primes_val,
            fp_primes_row,
            fp_primes_col,
            fp_primes_size,
        ) = self.calculate_fingerprints(
            image, element, calc_derivatives=calc_derivatives, log=log
        )
        values = (
            size_info,
            fps,
            fp_primes_val,
            fp_primes_row,
            fp_primes_col,
            fp_primes_size,
        )

        if save_fps:
            for name, data in zip(names, values):
                # A partially filled group (e.g. written earlier without
                # derivatives) already holds some of these names, and
                # create_dataset refuses to overwrite; drop stale entries first
                # so the cache can be upgraded instead of failing forever.
                if name in element_grp:
                    del element_grp[name]
                element_grp.create_dataset(name, data=data)

        return values

    def _assemble_image_dict(self, symbol_arr, per_element, calc_derivatives):
        """Merge per-element fingerprint values into one image dictionary.

        Args:
            symbol_arr: numpy array of the image's chemical symbols.
            per_element: mapping element -> six-tuple ordered like
                ``_FP_DATASETS``.
            calc_derivatives: whether to add the ``"descriptor_primes"`` entry.

        Raises:
            ValueError: if ``per_element`` is empty, i.e. none of
                ``self.elements`` occurs in the image.
        """
        if not per_element:
            raise ValueError(
                "None of the descriptor elements {} occur in the image "
                "(symbols present: {})".format(
                    list(self.elements), sorted(set(symbol_arr.tolist()))
                )
            )

        num_atoms = len(symbol_arr)
        image_dict = {
            "atomic_numbers": list_symbols_to_indices(symbol_arr),
            "num_atoms": num_atoms,
        }

        index_arr_dict = {}
        num_desc_dict = {}
        for element, values in per_element.items():
            index_arr_dict[element] = np.arange(num_atoms)[symbol_arr == element]
            # size_info[2] is the number of descriptor columns of this element.
            num_desc_dict[element] = values[0][2]

        # Elements may have different descriptor counts; rows are zero-padded
        # up to the widest one.
        num_desc_max = np.max(list(num_desc_dict.values()))
        image_fp_array = np.zeros((num_atoms, num_desc_max))
        for element, values in per_element.items():
            image_fp_array[
                index_arr_dict[element], : num_desc_dict[element]
            ] = values[1]

        image_dict["descriptors"] = image_fp_array
        image_dict["num_descriptors"] = num_desc_dict

        if calc_derivatives:
            prime_rows = []
            prime_cols = []
            prime_vals = []
            for element, values in per_element.items():
                fp_primes_val, fp_primes_row, fp_primes_col = values[2:5]
                prime_rows.append(
                    self._fp_prime_element_row_index_to_image_row_index(
                        fp_primes_row,
                        index_arr_dict[element],
                        num_desc_dict[element],
                        num_desc_max,
                    )
                )
                prime_cols.append(fp_primes_col)
                prime_vals.append(fp_primes_val)

            image_dict["descriptor_primes"] = {
                "size": np.array([num_desc_max * num_atoms, 3 * num_atoms]),
                "row": np.concatenate(prime_rows),
                "col": np.concatenate(prime_cols),
                "val": np.concatenate(prime_vals),
            }

        return image_dict

    def _fp_prime_element_row_index_to_image_row_index(
        self, original_rows, index_arr, num_desc, num_desc_max
    ):
        """Map per-element derivative row indices to image-wide row indices.

        An element-local row encodes ``atom_in_element * num_desc + desc``;
        the result encodes ``atom_in_image * num_desc_max + desc``, where
        ``index_arr`` translates element-local atom positions to image-wide
        atom positions.
        """
        atom_indices_for_specific_element, desc_indices = np.divmod(
            original_rows, num_desc
        )
        atom_indices_in_image = index_arr[atom_indices_for_specific_element]
        new_row = atom_indices_in_image * num_desc_max + desc_indices
        return new_row

    def _setup_fingerprint_database(self, save_fps):
        """Derive the cache directory names and, if saving, create them.

        Sets ``self.desc_type_database_dir`` and ``self.desc_fp_database_dir``.
        When ``save_fps`` is true, also creates the directories and writes the
        descriptor setup to ``descriptor_log.txt`` inside the latter.
        """
        self.get_descriptor_setup_hash()
        self.desc_type_database_dir = "{}/{}".format(
            self.fp_database, self.descriptor_type
        )
        self.desc_fp_database_dir = "{}/{}".format(
            self.desc_type_database_dir, self.descriptor_setup_hash
        )

        if save_fps:
            os.makedirs(self.fp_database, exist_ok=True)
            os.makedirs(self.desc_type_database_dir, exist_ok=True)
            os.makedirs(self.desc_fp_database_dir, exist_ok=True)
            descriptor_setup_filename = "descriptor_log.txt"
            descriptor_setup_path = "{}/{}".format(
                self.desc_fp_database_dir, descriptor_setup_filename
            )
            self.save_descriptor_setup(descriptor_setup_path)

    def _get_element_list(self):
        """Return the list of elements handled by this descriptor."""
        return self.elements