Skip to content

Commit

Permalink
Added name index dataframe
Browse files: browse the repository at this point in the history
  • Loading branch information
Gaurav S Deshmukh committed Sep 21, 2023
1 parent 964dfcd commit dc8f977
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions src/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import csv
from pathlib import Path

import pandas as pd
import torch
import tqdm

Expand Down Expand Up @@ -36,7 +37,6 @@ def __init__(self, root, prop_csv):
# Create processed path if it doesn't exist
self.processed_path = Path(self.processed_dir)
self.processed_path.mkdir(exist_ok=True)
self.process_flag = False

# Read csv
self.prop_csv = prop_csv
Expand All @@ -53,6 +53,11 @@ def __init__(self, root, prop_csv):
name: prop for name, prop in zip(self.names, self.props)
}

# Load index.csv if processed
self.index_path = self.processed_path / "index.csv"
if self.processed_status():
self.df_name_idx = pd.read_csv(self.index_path)

def process_data(self, z_cutoffs, node_features, edge_features, pad=False,
max_atoms=12, encoder=OneHotEncoder()):
"""Process raw data in the root directory into PyTorch Data and save.
Expand Down Expand Up @@ -86,9 +91,18 @@ def process_data(self, z_cutoffs, node_features, edge_features, pad=False,
encoder: OneHotEncoder object
Encoder to convert properties to vectors
"""
# Create empty dataframe to store index and name correspondence
self.df_name_idx = pd.DataFrame(
{"index": [0] * len(self.names), "name": [""] * len(self.names)}
)

# Iterate over files and process them
for i, name in tqdm.tqdm(enumerate(self.names), desc="Processing data",
total=len(self.names)):
# Map index to name
self.df_name_idx.loc[i, "index"] = i
self.df_name_idx.loc[i, "name"] = name

# Set file path
file_path = self.root_path / name

Expand Down Expand Up @@ -120,14 +134,11 @@ def process_data(self, z_cutoffs, node_features, edge_features, pad=False,
)
data_objects.append(data_obj)

# Add name of structure
data_objects.append(name)

# Save data objects
torch.save(data_objects, self.processed_path / f"data_{i}.pt")

# Set process flag to true
self.process_flag = True
# Save name-index dataframe
self.df_name_idx.to_csv(self.index_path, index=None)

def len(self):
"""Return size of the dataset."""
Expand All @@ -140,7 +151,10 @@ def get(self, i):

def processed_status(self):
"""Check if the dataset is processed."""
return self.process_flag
if Path(self.index_path).exists():
return True
else:
return False


if __name__ == "__main__":
Expand All @@ -161,4 +175,5 @@ def processed_status(self):
# ["surface_bond_distance"],
# ["adsorbate_bond_distance"],
# ])
print(dataset[0][-2].x)
print(dataset[0][-1].x)
print(dataset.df_name_idx.head())

0 comments on commit dc8f977

Please sign in to comment.