-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from GreeleyGroup/enh/data
ENH: Add PyTorch data handling methods
- Loading branch information
Showing
5 changed files
with
391 additions
and
22 deletions.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,7 @@ | ||
| *POSCAR* | ||
| *CONTCAR* | ||
| *.csv | ||
| data/* | ||
| !data/dband_centers.csv | ||
| __pycache__ | ||
| *.cif |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,372 @@ | ||
| """Store graph data using PyTorch Geometric abstractions.""" | ||
|
|
||
| import csv | ||
| from pathlib import Path | ||
|
|
||
| import pandas as pd | ||
| import torch | ||
| import tqdm | ||
| from ase import Atoms | ||
| from ase.io import read | ||
| from torch_geometric.data import Data, Dataset | ||
|
|
||
| from constants import REPO_PATH | ||
| from featurizers import OneHotEncoder | ||
| from utils import featurize_atoms, partition_structure | ||
|
|
||
|
|
||
| class AtomsDataset(Dataset): | ||
| """Class to hold a dataset containing graphs of atomic structures.""" | ||
|
|
||
| def __init__(self, root, prop_csv): | ||
| """Initialize an AtomsDataset. | ||
| Atomic structures stored as .cif files in the root directory are loaded. | ||
| Parameters | ||
| --------- | ||
| root: str | ||
| Path to the directory in which atomic structures are stored | ||
| prop_csv: str | ||
| Path to the file mapping atomic structure filename and property. | ||
| This file will typically have two columns, the first with the | ||
| names of the cif files and the second with the | ||
| corresponding target property values. | ||
| """ | ||
| super().__init__(root) | ||
| self.root_path = Path(self.root) | ||
|
|
||
| # Create processed path if it doesn't exist | ||
| self.processed_path = Path(self.processed_dir) | ||
| self.processed_path.mkdir(exist_ok=True) | ||
|
|
||
| # Read csv | ||
| self.prop_csv = prop_csv | ||
| self.names = [] | ||
| self.props = [] | ||
| with open(self.prop_csv, "r") as f: | ||
| csv_reader = csv.reader(f) | ||
| for row in csv_reader: | ||
| self.names.append(str(row[0])) | ||
| self.props.append(float(row[1])) | ||
|
|
||
| # Create name to property map | ||
| self.map_name_prop = {name: prop for name, prop in zip(self.names, self.props)} | ||
|
|
||
| # Load index.csv if processed | ||
| self.index_path = self.processed_path / "index.csv" | ||
| if self.processed_status(): | ||
| self.df_name_idx = pd.read_csv(self.index_path) | ||
|
|
||
| def process_data( | ||
| self, | ||
| z_cutoffs, | ||
| node_features, | ||
| edge_features, | ||
| max_atoms=None, | ||
| encoder=OneHotEncoder(), | ||
| ): | ||
| """Process raw data in the root directory into PyTorch Data and save. | ||
| Each atomic structure in the root directory is partitioned based on the | ||
| given z_cutoffs and each partition is featurized according to the given | ||
| node_features and edge_features. The featurized graphs are converted | ||
| into Data objects and stored in the "processed" directory under root. | ||
| Parameters | ||
| ---------- | ||
| z_cutoffs: list or np.ndarray | ||
| List of z-coordinates based on which atomic structures are | ||
| partitioned. The number of partitions is equal to one more than the | ||
| length of z_cutoffs. | ||
| node_features: list[list] | ||
| List of lists of node featurization methods to be used for each | ||
| partition. For e.g., specify [["atomic_number", "dband_center"], | ||
| ["atomic_number", "reactivity"], ["atomic_number", "reactivity"]] for | ||
| a typical bulk + surface + adsorbate partition. | ||
| edge_features: list[list] | ||
| List of lists of edge featurization methods to be used for each | ||
| partition. For e.g., specify [["bulk_bond_distance"], | ||
| ["surface_bond_distance"], ["adsorbate_bond_distance"]] for | ||
| a typical bulk + surface + adsorbate partition. | ||
| max_atoms: int (default = None) | ||
| Maximum number of nodes in graph. If a value is provided, graphs are | ||
| padded to make sure the total number of nodes matches max_atoms. | ||
| encoder: OneHotEncoder object | ||
| Encoder to convert properties to vectors | ||
| """ | ||
| # Create empty dataframe to store index and name correspondence | ||
| self.df_name_idx = pd.DataFrame( | ||
| {"index": [0] * len(self.names), "name": [""] * len(self.names)} | ||
| ) | ||
|
|
||
| # Iterate over files and process them | ||
| for i, name in tqdm.tqdm( | ||
| enumerate(self.names), desc="Processing data", total=len(self.names) | ||
| ): | ||
| # Map index to name | ||
| self.df_name_idx.loc[i, "index"] = i | ||
| self.df_name_idx.loc[i, "name"] = name | ||
|
|
||
| # Set file path | ||
| file_path = self.root_path / name | ||
|
|
||
| # Read structure | ||
| atoms = read(str(file_path)) | ||
|
|
||
| # Partition structure | ||
| part_atoms = partition_structure(atoms, z_cutoffs) | ||
|
|
||
| # Featurize partitions | ||
| data_objects = [] | ||
| for j, part_idx in enumerate(part_atoms): | ||
| feat_dict = featurize_atoms( | ||
| atoms, | ||
| part_idx, | ||
| node_features=node_features[j], | ||
| edge_features=edge_features[j], | ||
| max_atoms=max_atoms, | ||
| encoder=encoder, | ||
| ) | ||
|
|
||
| # Convert to Data object | ||
| data_obj = Data( | ||
| x=feat_dict["node_tensor"], | ||
| edge_index=feat_dict["edge_indices"], | ||
| edge_attr=feat_dict["edge_tensor"], | ||
| y=torch.Tensor([self.map_name_prop[name]]), | ||
| ) | ||
| data_objects.append(data_obj) | ||
|
|
||
| # Save data objects | ||
| torch.save(data_objects, self.processed_path / f"data_{i}.pt") | ||
|
|
||
| # Save name-index dataframe | ||
| self.df_name_idx.to_csv(self.index_path, index=None) | ||
|
|
||
| def len(self): | ||
| """Return size of the dataset.""" | ||
| return len(self.names) | ||
|
|
||
| def get(self, i): | ||
| """Fetch the processed graph(s) at the i-th index.""" | ||
| data_objects = torch.load(self.processed_path / f"data_{i}.pt") | ||
| return data_objects | ||
|
|
||
| def processed_status(self): | ||
| """Check if the dataset is processed.""" | ||
| if Path(self.index_path).exists(): | ||
| return True | ||
| else: | ||
| return False | ||
|
|
||
|
|
||
| class AtomsDatapoints: | ||
| """Class to hold atomic structures as datapoints (without targets). | ||
| The main difference between this class and AtomsDataset is that this is | ||
| initialized with a list of atoms objects (as opposed to a directory with | ||
| files containing atomic structures) without any targets specified. This is | ||
| useful to make predictions on atomic structures for which true target values | ||
| are not known, i.e., previously unseen structures. | ||
| """ | ||
|
|
||
| def __init__(self, atoms): | ||
| """Initialize an AtomsDatapoint. | ||
| Atomic structures provided in the list are initialized. | ||
| Parameters | ||
| --------- | ||
| atoms: ase.Atoms object or a list of ase.Atoms objects | ||
| Structures for which predictions are to be made. | ||
| """ | ||
| # If single object, convert to list | ||
| if isinstance(atoms, Atoms): | ||
| atoms = [atoms] | ||
|
|
||
| # Save object | ||
| self.atoms = atoms | ||
| self.data = [] | ||
|
|
||
| def process_data( | ||
| self, | ||
| z_cutoffs, | ||
| node_features, | ||
| edge_features, | ||
| max_atoms=None, | ||
| encoder=OneHotEncoder(), | ||
| ): | ||
| """Process list of Atoms objects into PyTorch Data and store in memory. | ||
| Each atomic structure in the provided list is partitioned based on the | ||
| given z_cutoffs and each partition is featurized according to the given | ||
| node_features and edge_features. The featurized graphs are converted | ||
| into Data objects and appended to the data attribute of this object. | ||
| Parameters | ||
| ---------- | ||
| z_cutoffs: list or np.ndarray | ||
| List of z-coordinates based on which atomic structures are | ||
| partitioned. The number of partitions is equal to one more than the | ||
| length of z_cutoffs. | ||
| node_features: list[list] | ||
| List of lists of node featurization methods to be used for each | ||
| partition. For e.g., specify [["atomic_number", "dband_center"], | ||
| ["atomic_number", "reactivity"], ["atomic_number", "reactivity"]] for | ||
| a typical bulk + surface + adsorbate partition. | ||
| edge_features: list[list] | ||
| List of lists of edge featurization methods to be used for each | ||
| partition. For e.g., specify [["bulk_bond_distance"], | ||
| ["surface_bond_distance"], ["adsorbate_bond_distance"]] for | ||
| a typical bulk + surface + adsorbate partition. | ||
| max_atoms: int (default is None) | ||
| Maximum number of nodes in graph. If a value is provided, graphs are | ||
| padded to make sure the total number of nodes matches max_atoms. | ||
| encoder: OneHotEncoder object | ||
| Encoder to convert properties to vectors | ||
| """ | ||
| # Iterate over Atoms objects and process them | ||
| for atoms_obj in self.atoms: | ||
| # Partition structure | ||
| part_atoms = partition_structure(atoms_obj, z_cutoffs) | ||
|
|
||
| # Featurize partitions | ||
| data_objects = [] | ||
| for j, part_idx in enumerate(part_atoms): | ||
| feat_dict = featurize_atoms( | ||
| atoms_obj, | ||
| part_idx, | ||
| node_features=node_features[j], | ||
| edge_features=edge_features[j], | ||
| max_atoms=max_atoms, | ||
| encoder=encoder, | ||
| ) | ||
|
|
||
| # Convert to Data object | ||
| data_obj = Data( | ||
| x=feat_dict["node_tensor"], | ||
| edge_index=feat_dict["edge_indices"], | ||
| edge_attr=feat_dict["edge_tensor"], | ||
| ) | ||
| data_objects.append(data_obj) | ||
|
|
||
| # Save data objects | ||
| self.data.append(data_objects) | ||
|
|
||
| def len(self): | ||
| """Return size of the dataset.""" | ||
| return len(self.data) | ||
|
|
||
| def get(self, i): | ||
| """Fetch the processed graph(s) at the i-th index.""" | ||
| data_objects = self.data[i] | ||
| return data_objects | ||
|
|
||
|
|
||
| def load_dataset(root, prop_csv, process_dict=None): | ||
| """Load an AtomsDataset at the path given by root. | ||
| If process_dict is provided, the process_data method of AtomsDataset is called | ||
| to convert the atomic structures to graphs based on the given parameters in | ||
| process_dict. This should be used when the dataset is created for the first | ||
| time. | ||
| Parameters | ||
| ---------- | ||
| root: str | ||
| Path to the dataset | ||
| prop_csv: str | ||
| Path to the file mapping atomic structure filename and property. | ||
| This file will typically have two columns, the first with the | ||
| names of the cif files and the second with the | ||
| corresponding target property values. | ||
| process_dict: dict (default = None) | ||
| If this is provided, atomic structures at root will be processed into | ||
| graphs and stored under a "processed" subdirectory. Only use this when | ||
| creating a new dataset. This should contain the following keys: z_cutoffs, | ||
| node_features, edge_features, max_atoms (optional), encoder (optional). | ||
| Refer to the documentation of AtomsDataset.process_data for more information regarding | ||
| these parameters. | ||
| Returns | ||
| ------- | ||
| dataset: AtomsDataset | ||
| Initialized AtomsDataset object | ||
| """ | ||
| dataset = AtomsDataset(root, prop_csv) | ||
| if process_dict is not None: | ||
| dataset.process_data(**process_dict) | ||
|
|
||
| return dataset | ||
|
|
||
|
|
||
| def load_datapoints(atoms, process_dict): | ||
| """Load AtomsDatapoints for the provided ase.Atoms or list of ase.Atoms. | ||
| The process_data method of AtomsDatapoints is called to convert the atomic | ||
| structures to graphs based on the given parameters in process_dict. If | ||
| process_dict is None, the structures are stored without being converted to | ||
| graphs. | ||
| Parameters | ||
| ---------- | ||
| atoms: ase.Atoms object or a list of ase.Atoms objects | ||
| Structures for which predictions are to be made. | ||
| process_dict: dict | ||
| Parameters to process the provided Atoms objects into graphs. | ||
| This should contain the following keys: z_cutoffs, node_features, | ||
| edge_features, max_atoms (optional), encoder (optional). Refer to the | ||
| documentation of AtomsDatapoints.process_data for more information regarding these | ||
| parameters. | ||
| Returns | ||
| ------- | ||
| datapoints: AtomsDatapoints | ||
| Initialized AtomsDatapoints object | ||
| """ | ||
| datapoints = AtomsDatapoints(atoms) | ||
| if process_dict is not None: | ||
| datapoints.process_data(**process_dict) | ||
|
|
||
| return datapoints | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| # Get path to root directory | ||
| data_root_path = Path(REPO_PATH) / "data" / "S_calcs" | ||
| prop_csv_path = data_root_path / "name_prop.csv" | ||
|
|
||
| # Create dataset | ||
| dataset = AtomsDataset(data_root_path, prop_csv_path) | ||
| # dataset.process_data(z_cutoffs=[13., 20.], | ||
| # node_features=[ | ||
| # ["atomic_number", "dband_center"], | ||
| # ["atomic_number", "reactivity"], | ||
| # ["atomic_number", "reactivity"], | ||
| # ], | ||
| # edge_features=[ | ||
| # ["bulk_bond_distance"], | ||
| # ["surface_bond_distance"], | ||
| # ["adsorbate_bond_distance"], | ||
| # ]) | ||
| print(dataset[0][-1].x) | ||
| print(dataset.df_name_idx.head()) | ||
|
|
||
| # Create datapoint | ||
| atoms = read(data_root_path / "Pt_3_Rh_9_-7-7-S.cif") | ||
| datapoint = AtomsDatapoints(atoms) | ||
| datapoint.process_data( | ||
| z_cutoffs=[13.0, 20.0], | ||
| node_features=[ | ||
| ["atomic_number", "dband_center"], | ||
| ["atomic_number", "reactivity"], | ||
| ["atomic_number", "reactivity"], | ||
| ], | ||
| edge_features=[ | ||
| ["bulk_bond_distance"], | ||
| ["surface_bond_distance"], | ||
| ["adsorbate_bond_distance"], | ||
| ], | ||
| ) | ||
| print(datapoint.get(0)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.