From 003f6698dfdffebd8d38433d68cf76fc6c145765 Mon Sep 17 00:00:00 2001
From: Gaurav S Deshmukh
Date: Wed, 20 Sep 2023 11:15:22 -0400
Subject: [PATCH] Added len and get methods

---
Review notes (this section is ignored when the patch is applied):

* Path.mkdir() returns None, so the processed directory is now stored as a
  Path first and created in a separate statement; previously
  self.processed_path was None and every later "/" join raised TypeError.
* "/" binds tighter than "+", so `root_path / name + ".cif"` evaluated to
  `Path + str` and raised TypeError; the file name is now built with an
  f-string before joining.
* get() derives its location from self.processed_dir instead of
  self.processed_path, which is only assigned inside process(); loading
  from an already-processed dataset therefore no longer depends on
  process() having run on the same instance.
* Dropped a redundant "\" continuation inside the featurize_atoms() call
  and ended the file with a trailing newline.
* Line breaks and indentation of this patch were reconstructed from a
  whitespace-collapsed copy; verify context whitespace (in particular the
  whitespace-only last line of the old file) against src/data.py before
  applying. The index line still carries the blob ids of the original
  revision of this patch.

 src/data.py | 29 ++++++++++++++++++-----------
 1 file changed, 18 insertions(+), 11 deletions(-)

diff --git a/src/data.py b/src/data.py
index 528b64c..708640b 100644
--- a/src/data.py
+++ b/src/data.py
@@ -29,6 +29,7 @@ def __init__(self, root, prop_csv):
             corresponding target property values.
         """
         super().__init__(root)
+        self.root_path = Path(self.root)

         # Read csv
         self.prop_csv = prop_csv
@@ -73,16 +74,14 @@ def process(self, z_cutoffs, node_features, edge_features, max_atoms=12,
         encoder: OneHotEncoder object
             Encoder to convert properties to vectors
         """
-        # Root path
-        root_path = Path(self.root)
-
         # Create processed path if it doesn't exist
-        processed_path = Path(self.processed_dir).mkdir(exist_ok=True)
+        self.processed_path = Path(self.processed_dir)
+        self.processed_path.mkdir(exist_ok=True)

         # Iterate over files and process them
-        for name in self.names:
+        for i, name in enumerate(self.names):
             # Set file path
-            file_path = root_path / name + ".cif"
+            file_path = self.root_path / f"{name}.cif"

             # Read structure
             atoms = read(str(file_path))
@@ -92,12 +91,12 @@ def process(self, z_cutoffs, node_features, edge_features, max_atoms=12,

             # Featurize partitions
             data_objects = []
-            for i, part_idx in enumerate(part_atoms):
+            for j, part_idx in enumerate(part_atoms):
                 feat_dict = featurize_atoms(
                     atoms,
                     part_idx,
-                    node_features=node_features[i],
-                    edge_features=edge_features[i],\
+                    node_features=node_features[j],
+                    edge_features=edge_features[j],
                     max_atoms=max_atoms,
                     encoder=encoder
                 )
@@ -112,5 +111,13 @@ def process(self, z_cutoffs, node_features, edge_features, max_atoms=12,
                 data_objects.append(data_obj)

             # Save data objects
-            torch.save(data_objects, processed_path / name + ".pt")
-        
\ No newline at end of file
+            torch.save(data_objects, self.processed_path / f"data_{i}.pt")
+
+    def len(self):
+        """Return size of the dataset."""
+        return len(self.names)
+
+    def get(self, i):
+        """Fetch the processed graph(s) at the i-th index."""
+        data_objects = torch.load(Path(self.processed_dir) / f"data_{i}.pt")
+        return data_objects