diff --git a/src/data.py b/src/data.py index c1920c1..e9561af 100644 --- a/src/data.py +++ b/src/data.py @@ -1,4 +1,4 @@ -"Store graph data using PyTorch Geometric abstractions." +"""Store graph data using PyTorch Geometric abstractions.""" import csv from pathlib import Path @@ -6,17 +6,18 @@ import pandas as pd import torch import tqdm - -from ase.io import read from ase import Atoms +from ase.io import read from torch_geometric.data import Data, Dataset -from utils import partition_structure, featurize_atoms -from featurizers import OneHotEncoder from constants import REPO_PATH +from featurizers import OneHotEncoder +from utils import featurize_atoms, partition_structure + class AtomsDataset(Dataset): """Class to hold a dataset containing graphs of atomic_structures.""" + def __init__(self, root, prop_csv): """Initialize an AtomsDataset. @@ -50,17 +51,21 @@ def __init__(self, root, prop_csv): self.props.append(float(row[1])) # Create name to property map - self.map_name_prop = { - name: prop for name, prop in zip(self.names, self.props) - } + self.map_name_prop = {name: prop for name, prop in zip(self.names, self.props)} # Load index.csv if processed self.index_path = self.processed_path / "index.csv" if self.processed_status(): self.df_name_idx = pd.read_csv(self.index_path) - def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None, - encoder=OneHotEncoder()): + def process_data( + self, + z_cutoffs, + node_features, + edge_features, + max_atoms=None, + encoder=OneHotEncoder(), + ): """Process raw data in the root directory into PyTorch Data and save. Each atomic structure in the root directory is partitioned based on the @@ -76,12 +81,12 @@ def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None, length of z_cutoffs. node_features: list[list] List of lists of node featurization methods to be used for each - partition. For e.g., specify [["atomic_number", "dband_center"], + partition. For e.g., specify [["atomic_number", "dband_center"], ["atomic_number", "reactivity"], ["atomic_number", "reactivity"]] for a typical bulk + surface + adsorbate partition. edge_features: list[list] List of lists of edge featurization methods to be used for each - partition. For e.g., specify [["bulk_bond_distance"], + partition. For e.g., specify [["bulk_bond_distance"], ["surface_bond_distance"], ["adsorbate_bond_distance"]] for a typical bulk + surface + adsorbate partition. max_atoms: int (default = None) @@ -96,8 +101,9 @@ def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None, ) # Iterate over files and process them - for i, name in tqdm.tqdm(enumerate(self.names), desc="Processing data", - total=len(self.names)): + for i, name in tqdm.tqdm( + enumerate(self.names), desc="Processing data", total=len(self.names) + ): # Map index to name self.df_name_idx.loc[i, "index"] = i self.df_name_idx.loc[i, "name"] = name @@ -120,7 +126,7 @@ def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None, node_features=node_features[j], edge_features=edge_features[j], max_atoms=max_atoms, - encoder=encoder + encoder=encoder, ) # Convert to Data object @@ -128,7 +134,7 @@ def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None, x=feat_dict["node_tensor"], edge_index=feat_dict["edge_indices"], edge_attr=feat_dict["edge_tensor"], - y=torch.Tensor([self.map_name_prop[name]]) + y=torch.Tensor([self.map_name_prop[name]]), ) data_objects.append(data_obj) @@ -137,16 +143,16 @@ def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None, # Save name-index dataframe self.df_name_idx.to_csv(self.index_path, index=None) - + def len(self): """Return size of the dataset.""" return len(self.names) - + def get(self, i): """Fetch the processed graph(s) at the i-th index.""" data_objects = torch.load(self.processed_path / f"data_{i}.pt") return data_objects - + def processed_status(self): """Check if the dataset is processed.""" if Path(self.index_path).exists(): @@ -154,15 +160,17 @@ def processed_status(self): else: return False + class AtomsDatapoints: """Class to hold atomic structures as a datapoints (without targets). - + This main difference between this class and AtomsDataset is that this is initialized with a list of atoms objects (as opposed to a directory with files containing atomic structures) without any targets specified. This is useful to make predictions on atomic structures for which true target values are not known, i.e., previously unseen structures. """ + def __init__(self, atoms): """Initialize an AtomsDatapoint. @@ -181,8 +189,14 @@ def __init__(self, atoms): self.atoms = atoms self.data = [] - def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None, - encoder=OneHotEncoder()): + def process_data( + self, + z_cutoffs, + node_features, + edge_features, + max_atoms=None, + encoder=OneHotEncoder(), + ): """Process list of Atoms objects into PyTorch Data and save. Each atomic structure in the root directory is partitioned based on the @@ -198,12 +212,12 @@ def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None, length of z_cutoffs. node_features: list[list] List of lists of node featurization methods to be used for each - partition. For e.g., specify [["atomic_number", "dband_center"], + partition. For e.g., specify [["atomic_number", "dband_center"], ["atomic_number", "reactivity"], ["atomic_number", "reactivity"]] for a typical bulk + surface + adsorbate partition. edge_features: list[list] List of lists of edge featurization methods to be used for each - partition. For e.g., specify [["bulk_bond_distance"], + partition. For e.g., specify [["bulk_bond_distance"], ["surface_bond_distance"], ["adsorbate_bond_distance"]] for a typical bulk + surface + adsorbate partition. max_atoms: int (default is None) @@ -226,7 +240,7 @@ def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None, node_features=node_features[j], edge_features=edge_features[j], max_atoms=max_atoms, - encoder=encoder + encoder=encoder, ) # Convert to Data object @@ -243,12 +257,13 @@ def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None, def len(self): """Return size of the dataset.""" return len(self.data) - + def get(self, i): """Fetch the processed graph(s) at the i-th index.""" data_objects = self.data[i] return data_objects + def load_dataset(root, prop_csv, process_dict=None): """Load an AtomsDataset at the path given by root. @@ -273,7 +288,7 @@ def load_dataset(root, prop_csv, process_dict=None): node_features, edge_features, max_atoms (optional), encoder (optional). Refer to the documentation of process_atoms for more information regarding these parameters. - + Returns ------- dataset: AtomsDataset @@ -285,6 +300,7 @@ def load_dataset(root, prop_csv, process_dict=None): return dataset + def load_datapoints(atoms, process_dict): """Load AtomsDatapoints for the provided ase.Atoms or list of ase.Atoms. @@ -297,13 +313,13 @@ def load_datapoints(atoms, process_dict): ---------- atoms: ase.Atoms object or a list of ase.Atoms objects Structures for which predictions are to be made. - process_dict: dict + process_dict: dict Parameters to process the provided Atoms objects into graphs. - This should contain the following keys: z_cutoffs, node_features, - edge_features, max_atoms (optional), encoder (optional). Refer to the + This should contain the following keys: z_cutoffs, node_features, + edge_features, max_atoms (optional), encoder (optional). Refer to the documentation of process_atoms for more information regarding these parameters. - + Returns ------- datapoints: AtomsDatapoints @@ -312,7 +328,7 @@ def load_datapoints(atoms, process_dict): datapoints = AtomsDatapoints(atoms) if process_dict is not None: datapoints.process_data(**process_dict) - + return datapoints @@ -323,7 +339,7 @@ def load_datapoints(atoms, process_dict): # Create dataset dataset = AtomsDataset(data_root_path, prop_csv_path) - # dataset.process_data(z_cutoffs=[13., 20.], + # dataset.process_data(z_cutoffs=[13., 20.], # node_features=[ # ["atomic_number", "dband_center"], # ["atomic_number", "reactivity"], @@ -340,15 +356,17 @@ def load_datapoints(atoms, process_dict): # Create datapoint atoms = read(data_root_path / "Pt_3_Rh_9_-7-7-S.cif") datapoint = AtomsDatapoints(atoms) - datapoint.process_data(z_cutoffs=[13., 20.], - node_features=[ - ["atomic_number", "dband_center"], - ["atomic_number", "reactivity"], - ["atomic_number", "reactivity"], - ], - edge_features=[ - ["bulk_bond_distance"], - ["surface_bond_distance"], - ["adsorbate_bond_distance"], - ]) - print(datapoint.get(0)) \ No newline at end of file + datapoint.process_data( + z_cutoffs=[13.0, 20.0], + node_features=[ + ["atomic_number", "dband_center"], + ["atomic_number", "reactivity"], + ["atomic_number", "reactivity"], + ], + edge_features=[ + ["bulk_bond_distance"], + ["surface_bond_distance"], + ["adsorbate_bond_distance"], + ], + ) + print(datapoint.get(0))