From e98d9988c15b78f2184118462e8c1623e0e78555 Mon Sep 17 00:00:00 2001 From: Gaurav S Deshmukh Date: Thu, 21 Sep 2023 19:07:29 -0400 Subject: [PATCH] AtomsDataset works --- .gitignore | 1 + src/data.py | 48 +++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index bc0c661..513abe4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ *POSCAR* *CONTCAR* *.csv +data/* !data/dband_centers.csv __pycache__ *.cif diff --git a/src/data.py b/src/data.py index c322fb5..671537d 100644 --- a/src/data.py +++ b/src/data.py @@ -4,12 +4,14 @@ from pathlib import Path import torch +import tqdm from ase.io import read from torch_geometric.data import Data, Dataset from utils import partition_structure, featurize_atoms from featurizers import OneHotEncoder +from constants import REPO_PATH class AtomsDataset(Dataset): """Class to hold a dataset containing graphs of atomic_structures.""" @@ -22,7 +24,7 @@ def __init__(self, root, prop_csv): --------- root: str Path to the directory in which atomic structures are stored - pro_csv: str + prop_csv: str Path to the file mapping atomic structure filename and property. This filename will typically have two columns, the first with the names of the cif files and the second with the @@ -31,6 +33,10 @@ def __init__(self, root, prop_csv): super().__init__(root) self.root_path = Path(self.root) + # Create processed path if it doesn't exist + self.processed_path = Path(self.processed_dir) + self.processed_path.mkdir(exist_ok=True) + # Read csv self.prop_csv = prop_csv self.names = [] @@ -46,8 +52,8 @@ def __init__(self, root, prop_csv): name: prop for name, prop in zip(self.names, self.props) } - def process(self, z_cutoffs, node_features, edge_features, max_atoms=12, - encoder=OneHotEncoder()): + def process_data(self, z_cutoffs, node_features, edge_features, pad=False, + max_atoms=12, encoder=OneHotEncoder()): """Process raw data in the root directory into PyTorch Data and save. Each atomic structure in the root directory is partitioned based on the @@ -71,14 +77,17 @@ def process(self, z_cutoffs, node_features, edge_features, max_atoms=12, partition. For e.g., specify [["bulk_bond_distance"], ["surface_bond_distance"], ["adsorbate_bond_distance"]] for a typical bulk + surface + adsorbate partition. + pad: bool + Whether to pad the graph with empty nodes to make total nodes add + up to max_atoms + max_atoms: int + Maximum number of nodes in graph. Only used if pad is True. encoder: OneHotEncoder object Encoder to convert properties to vectors """ - # Create processed path if it doesn't exist - self.processed_path = Path(self.processed_dir).mkdir(exist_ok=True) - # Iterate over files and process them - for i, name in enumerate(self.names): + for i, name in tqdm.tqdm(enumerate(self.names), desc="Processing data", + total=len(self.names)): # Set file path file_path = self.root_path / name @@ -95,7 +104,8 @@ def process(self, z_cutoffs, node_features, edge_features, max_atoms=12, atoms, part_idx, node_features=node_features[j], - edge_features=edge_features[j],\ + edge_features=edge_features[j], + pad=pad, max_atoms=max_atoms, encoder=encoder ) @@ -122,4 +132,24 @@ def len(self): def get(self, i): """Fetch the processed graph(s) at the i-th index.""" data_objects = torch.load(self.processed_path / f"data_{i}.pt") - return data_objects \ No newline at end of file + return data_objects + +if __name__ == "__main__": + # Get path to root directory + data_root_path = Path(REPO_PATH) / "data" / "S_calcs" + prop_csv_path = data_root_path / "name_prop.csv" + + # Create dataset + dataset = AtomsDataset(data_root_path, prop_csv_path) + dataset.process_data(z_cutoffs=[13., 20.], + node_features=[ + ["atomic_number", "dband_center"], + ["atomic_number", "reactivity"], + ["atomic_number", "reactivity"], + ], + edge_features=[ + ["bulk_bond_distance"], + ["surface_bond_distance"], + ["adsorbate_bond_distance"], + ]) + print(dataset[0][-2].x) \ No newline at end of file