-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Gaurav S Deshmukh
committed
Sep 19, 2023
1 parent
c5e779d
commit 4cced65
Showing
3 changed files
with
120 additions
and
10 deletions.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| "Store graph data using PyTorch Geometric abstractions." | ||
|
|
||
| import csv | ||
| from pathlib import Path | ||
|
|
||
| import torch | ||
|
|
||
| from ase.io import read | ||
| from torch_geometric.data import Data, Dataset | ||
|
|
||
| from utils import partition_structure, featurize_atoms | ||
| from featurizers import OneHotEncoder | ||
|
|
||
class AtomsDataset(Dataset):
    """Class to hold a dataset containing graphs of atomic structures."""

    def __init__(self, root, prop_csv):
        """Initialize an AtomsDataset.

        Atomic structures stored as .cif files in the root directory are
        loaded.

        Parameters
        ----------
        root: str
            Path to the directory in which atomic structures are stored.
        prop_csv: str
            Path to the file mapping atomic structure filename and property.
            This file will typically have two columns, the first with the
            names of the cif files (without .cif) and the second with the
            corresponding target property values.

        Raises
        ------
        ValueError
            If a value in the second column cannot be converted to float.
        IndexError
            If a non-empty row has fewer than two columns.
        """
        # NOTE(review): depending on the installed torch_geometric version,
        # the base-class constructor may invoke its own processing hook
        # without arguments -- confirm this works with `process` below.
        super().__init__(root)

        # Read csv
        self.prop_csv = prop_csv
        self.names = []
        self.props = []
        # newline="" is what the csv module asks for when opening files.
        with open(self.prop_csv, "r", newline="") as f:
            for row in csv.reader(f):
                if not row:
                    # Blank lines (e.g. a trailing newline) carry no data.
                    continue
                self.names.append(str(row[0]))
                self.props.append(float(row[1]))

        # Create name to property map
        self.map_name_prop = dict(zip(self.names, self.props))

    def process(self, z_cutoffs, node_features, edge_features, max_atoms=12,
                encoder=None):
        """Process raw data in the root directory into PyTorch Data and save.

        Each atomic structure in the root directory is partitioned based on
        the given z_cutoffs and each partition is featurized according to the
        given node_features and edge_features. The featurized graphs are
        converted into Data objects and stored in the "processed" directory
        under root, one "<name>.pt" file per structure holding the list of
        Data objects for its partitions.

        Parameters
        ----------
        z_cutoffs: list or np.ndarray
            List of z-coordinates based on which atomic structures are
            partitioned. The number of partitions is equal to one more than
            the length of z_cutoffs.
        node_features: list[list]
            List of lists of node featurization methods to be used for each
            partition. For e.g., specify [["atomic_number", "dband_center"],
            ["atomic_number", "reactivity"], ["atomic_number", "reactivity"]]
            for a typical bulk + surface + adsorbate partition.
        edge_features: list[list]
            List of lists of edge featurization methods to be used for each
            partition. For e.g., specify [["bulk_bond_distance"],
            ["surface_bond_distance"], ["adsorbate_bond_distance"]] for
            a typical bulk + surface + adsorbate partition.
        max_atoms: int
            Forwarded unchanged to featurize_atoms (default 12). Presumably
            an upper bound on the atoms per graph -- confirm against
            utils.featurize_atoms.
        encoder: OneHotEncoder object, optional
            Encoder to convert properties to vectors. A new OneHotEncoder is
            created when this is None.
        """
        # Build the default lazily so that no encoder instance is created at
        # class-definition time and shared between calls.
        if encoder is None:
            encoder = OneHotEncoder()

        # Root path
        root_path = Path(self.root)

        # Create processed path if it doesn't exist. Path.mkdir returns None,
        # so the Path object has to be kept separately from the mkdir call.
        processed_path = Path(self.processed_dir)
        processed_path.mkdir(parents=True, exist_ok=True)

        # Iterate over files and process them
        for name in self.names:
            # Set file path. The extension is appended to the name before
            # joining: "/" binds tighter than "+", and Path + str is invalid.
            file_path = root_path / f"{name}.cif"

            # Read structure
            atoms = read(str(file_path))

            # Partition structure
            part_atoms = partition_structure(atoms, z_cutoffs)

            # Featurize partitions
            data_objects = []
            for i, part_idx in enumerate(part_atoms):
                feat_dict = featurize_atoms(
                    atoms,
                    part_idx,
                    node_features=node_features[i],
                    edge_features=edge_features[i],
                    max_atoms=max_atoms,
                    encoder=encoder,
                )

                # Convert to Data object
                data_obj = Data(
                    x=feat_dict["node_tensor"],
                    edge_index=feat_dict["edge_indices"],
                    edge_attr=feat_dict["edge_tensor"],
                    y=torch.Tensor([self.map_name_prop[name]]),
                )
                data_objects.append(data_obj)

            # Save data objects
            torch.save(data_objects, processed_path / f"{name}.pt")
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters