From dc8f977ecb6e2936bf513a2e2d418d67b530f7fe Mon Sep 17 00:00:00 2001 From: Gaurav S Deshmukh Date: Thu, 21 Sep 2023 19:59:32 -0400 Subject: [PATCH] Added name index dataframe --- src/data.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/src/data.py b/src/data.py index 423fd26..fb3ec17 100644 --- a/src/data.py +++ b/src/data.py @@ -3,6 +3,7 @@ import csv from pathlib import Path +import pandas as pd import torch import tqdm @@ -36,7 +37,6 @@ def __init__(self, root, prop_csv): # Create processed path if it doesn't exist self.processed_path = Path(self.processed_dir) self.processed_path.mkdir(exist_ok=True) - self.process_flag = False # Read csv self.prop_csv = prop_csv @@ -53,6 +53,11 @@ def __init__(self, root, prop_csv): name: prop for name, prop in zip(self.names, self.props) } + # Load index.csv if processed + self.index_path = self.processed_path / "index.csv" + if self.processed_status(): + self.df_name_idx = pd.read_csv(self.index_path) + def process_data(self, z_cutoffs, node_features, edge_features, pad=False, max_atoms=12, encoder=OneHotEncoder()): """Process raw data in the root directory into PyTorch Data and save. @@ -86,9 +91,18 @@ def process_data(self, z_cutoffs, node_features, edge_features, pad=False, encoder: OneHotEncoder object Encoder to convert properties to vectors """ + # Create empty dataframe to store index and name correspondence + self.df_name_idx = pd.DataFrame( + {"index": [0] * len(self.names), "name": [""] * len(self.names)} + ) + # Iterate over files and process them for i, name in tqdm.tqdm(enumerate(self.names), desc="Processing data", total=len(self.names)): + # Map index to name + self.df_name_idx.loc[i, "index"] = i + self.df_name_idx.loc[i, "name"] = name + # Set file path file_path = self.root_path / name @@ -120,14 +134,11 @@ def process_data(self, z_cutoffs, node_features, edge_features, pad=False, ) data_objects.append(data_obj) - # Add name of structure - data_objects.append(name) - # Save data objects torch.save(data_objects, self.processed_path / f"data_{i}.pt") - # Set process flag to true - self.process_flag = True + # Save name-index dataframe + self.df_name_idx.to_csv(self.index_path, index=None) def len(self): """Return size of the dataset.""" @@ -140,7 +151,10 @@ def get(self, i): def processed_status(self): """Check if the dataset is processed.""" - return self.process_flag + if Path(self.index_path).exists(): + return True + else: + return False if __name__ == "__main__": @@ -161,4 +175,5 @@ def processed_status(self): # ["surface_bond_distance"], # ["adsorbate_bond_distance"], # ]) - print(dataset[0][-2].x) \ No newline at end of file + print(dataset[0][-1].x) + print(dataset.df_name_idx.head()) \ No newline at end of file