diff --git a/src/data.py b/src/data.py index c482e22..c1920c1 100644 --- a/src/data.py +++ b/src/data.py @@ -59,8 +59,8 @@ def __init__(self, root, prop_csv): if self.processed_status(): self.df_name_idx = pd.read_csv(self.index_path) - def process_data(self, z_cutoffs, node_features, edge_features, pad=False, - max_atoms=12, encoder=OneHotEncoder()): + def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None, + encoder=OneHotEncoder()): """Process raw data in the root directory into PyTorch Data and save. Each atomic structure in the root directory is partitioned based on the @@ -84,11 +84,9 @@ def process_data(self, z_cutoffs, node_features, edge_features, pad=False, partition. For e.g., specify [["bulk_bond_distance"], ["surface_bond_distance"], ["adsorbate_bond_distance"]] for a typical bulk + surface + adsorbate partition. - pad: bool - Whether to pad the graph with empty nodes to make total nodes add - up to max_atoms - max_atoms: int - Maximum number of nodes in graph. Only used if pad is True. + max_atoms: int (default = None) + Maximum number of nodes in graph. If a value is provided, graphs are + padded to make sure the total number of nodes matches max_atoms. encoder: OneHotEncoder object Encoder to convert properties to vectors """ @@ -121,7 +119,6 @@ def process_data(self, z_cutoffs, node_features, edge_features, pad=False, part_idx, node_features=node_features[j], edge_features=edge_features[j], - pad=pad, max_atoms=max_atoms, encoder=encoder ) @@ -184,8 +181,8 @@ def __init__(self, atoms): self.atoms = atoms self.data = [] - def process_data(self, z_cutoffs, node_features, edge_features, pad=False, - max_atoms=12, encoder=OneHotEncoder()): + def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None, + encoder=OneHotEncoder()): """Process list of Atoms objects into PyTorch Data and save. 
Each atomic structure in the root directory is partitioned based on the @@ -209,11 +206,9 @@ def process_data(self, z_cutoffs, node_features, edge_features, pad=False, partition. For e.g., specify [["bulk_bond_distance"], ["surface_bond_distance"], ["adsorbate_bond_distance"]] for a typical bulk + surface + adsorbate partition. - pad: bool - Whether to pad the graph with empty nodes to make total nodes add - up to max_atoms - max_atoms: int - Maximum number of nodes in graph. Only used if pad is True. + max_atoms: int (default = None) + Maximum number of nodes in graph. If a value is provided, graphs are + padded to make sure the total number of nodes matches max_atoms. encoder: OneHotEncoder object Encoder to convert properties to vectors """ @@ -230,7 +225,6 @@ def process_data(self, z_cutoffs, node_features, edge_features, pad=False, part_idx, node_features=node_features[j], edge_features=edge_features[j], - pad=pad, max_atoms=max_atoms, encoder=encoder ) @@ -255,6 +249,73 @@ def get(self, i): data_objects = self.data[i] return data_objects +def load_dataset(root, prop_csv, process_dict=None): + """Load an AtomsDataset at the path given by root. + + If process_dict is provided, the process_data method of AtomsDataset is called + to convert the atomic structures to graphs based on the given parameters in + process_dict. This should be used when the dataset is created for the first + time. + + Parameters + ---------- + root: str + Path to the dataset + prop_csv: str + Path to the file mapping atomic structure filename and property. + This file will typically have two columns, the first with the + names of the cif files and the second with the + corresponding target property values. + process_dict: dict (default = None) + If this is provided, atomic structures at root will be processed into + graphs and stored under a "processed" subdirectory. Only use this when + creating a new dataset.
This should contain the following keys: z_cutoffs, + node_features, edge_features, max_atoms (optional), encoder (optional). + Refer to the documentation of process_atoms for more information regarding + these parameters. + + Returns + ------- + dataset: AtomsDataset + Initialized AtomsDataset object + """ + dataset = AtomsDataset(root, prop_csv) + if process_dict is not None: + dataset.process_data(**process_dict) + + return dataset + +def load_datapoints(atoms, process_dict): + """Load AtomsDatapoints for the provided ase.Atoms or list of ase.Atoms. + + If process_dict is provided, the process_data method of AtomsDatapoints is called + to convert the atomic structures to graphs based on the given parameters in + process_dict. This should be used when the datapoints are created for the first + time. + + Parameters + ---------- + atoms: ase.Atoms object or a list of ase.Atoms objects + Structures for which predictions are to be made. + process_dict: dict + Parameters to process the provided Atoms objects into graphs. + This should contain the following keys: z_cutoffs, node_features, + edge_features, max_atoms (optional), encoder (optional). Refer to the + documentation of process_atoms for more information regarding these + parameters.
+ + Returns + ------- + datapoints: AtomsDatapoints + Initialized AtomsDatapoints object + """ + datapoints = AtomsDatapoints(atoms) + if process_dict is not None: + datapoints.process_data(**process_dict) + + return datapoints + + if __name__ == "__main__": # Get path to root directory data_root_path = Path(REPO_PATH) / "data" / "S_calcs" @@ -274,4 +335,20 @@ def get(self, i): # ["adsorbate_bond_distance"], # ]) print(dataset[0][-1].x) - print(dataset.df_name_idx.head()) \ No newline at end of file + print(dataset.df_name_idx.head()) + + # Create datapoint + atoms = read(data_root_path / "Pt_3_Rh_9_-7-7-S.cif") + datapoint = AtomsDatapoints(atoms) + datapoint.process_data(z_cutoffs=[13., 20.], + node_features=[ + ["atomic_number", "dband_center"], + ["atomic_number", "reactivity"], + ["atomic_number", "reactivity"], + ], + edge_features=[ + ["bulk_bond_distance"], + ["surface_bond_distance"], + ["adsorbate_bond_distance"], + ]) + print(datapoint.get(0)) \ No newline at end of file diff --git a/src/graphs.py b/src/graphs.py index a3ac282..7e0f747 100644 --- a/src/graphs.py +++ b/src/graphs.py @@ -14,7 +14,7 @@ class AtomsGraph: """Create graph representation of a collection of atoms.""" - def __init__(self, atoms, select_idx, pad=True, max_atoms=50): + def __init__(self, atoms, select_idx, max_atoms=None): """Initialize variables of the class. Parameters @@ -25,14 +25,14 @@ def __init__(self, atoms, select_idx, pad=True, max_atoms=50): List of indices of atoms that are to be included in the graph neighbor_list: ase.neighborlist.NeighborList object Neighbor list that defines bonds between atoms - max_atoms: int (default = 50) - The maximum number of atoms in the graph. Graphs that have fewer - atoms are padded with 0s to reach this value. + max_atoms: int (default = None) + The maximum number of atoms in the graph. If it is not None, graphs + that have fewer nodes than max_atoms are padded with 0s to ensure + that the total number of nodes is equal to max_atoms. 
""" # Save parameters self.atoms = atoms self.select_idx = select_idx - self.pad = pad self.max_atoms = max_atoms # Create graph @@ -109,7 +109,7 @@ def create_graph(self): graph.add_edge(n, self.map_idx_node[nn], bond_distance=bond_dist) # Pad graph - if self.pad: + if self.max_atoms is not None: graph = self.pad_graph(graph) # Add coordination numbers diff --git a/src/utils.py b/src/utils.py index 8436778..2e73f2f 100644 --- a/src/utils.py +++ b/src/utils.py @@ -53,8 +53,7 @@ def featurize_atoms( select_idx, node_features, edge_features, - pad=True, - max_atoms=50, + max_atoms=None, encoder=OneHotEncoder(), ): """Featurize atoms and bonds with the chosen featurizers. @@ -75,14 +74,10 @@ def featurize_atoms( Names of edge featurizers to use (current options: bulk_bond_distance, surface_bond_distance, adsorbate_bond_distance). All of these encode bond distance using a one-hot encoder, but the bounds for each vary. - pad: bool - If True, the graph is padded to ensure the number of nodes is equal to - max_atoms. In that case, the blank nodes have all 0s in their node tensors. - max_atoms: int (default = 50) - Maximum number of allowed atoms. If the number of atoms in the graph are - fewer than this number, the graph is padded with empty nodes. This is - required to make the sizes of the node feature tensors consistent across - structures. + max_atoms: int (default = None) + Maximum number of allowed atoms. If it is not None, graphs + that have fewer nodes than max_atoms are padded with 0s to ensure + that the total number of nodes is equal to max_atoms. encoder: encoder object from featurizers.py Currently only the OneHotEncoder is supported @@ -93,8 +88,7 @@ def featurize_atoms( corresponding tensors as values. 
""" # Create graph - atoms_graph = AtomsGraph(atoms=atoms, select_idx=select_idx, max_atoms=max_atoms, - pad=pad) + atoms_graph = AtomsGraph(atoms=atoms, select_idx=select_idx, max_atoms=max_atoms) # Collect node featurizers node_feats = [] @@ -154,6 +148,5 @@ def featurize_atoms( ["atomic_number", "dband_center"], ["bulk_bond_distance"], max_atoms=34, - pad=False, ) print(feat_dict)