Skip to content

Commit

Permalink
Added len and get methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav S Deshmukh committed Sep 20, 2023
1 parent 77c5620 commit 003f669
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions src/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self, root, prop_csv):
corresponding target property values.
"""
super().__init__(root)
self.root_path = Path(self.root)

# Read csv
self.prop_csv = prop_csv
Expand Down Expand Up @@ -73,16 +74,13 @@ def process(self, z_cutoffs, node_features, edge_features, max_atoms=12,
encoder: OneHotEncoder object
Encoder to convert properties to vectors
"""
# Root path
root_path = Path(self.root)

# Create processed path if it doesn't exist
processed_path = Path(self.processed_dir).mkdir(exist_ok=True)
self.processed_path = Path(self.processed_dir).mkdir(exist_ok=True)

# Iterate over files and process them
for name in self.names:
for i, name in enumerate(self.names):
# Set file path
file_path = root_path / name + ".cif"
file_path = self.root_path / name + ".cif"

# Read structure
atoms = read(str(file_path))
Expand All @@ -92,12 +90,12 @@ def process(self, z_cutoffs, node_features, edge_features, max_atoms=12,

# Featurize partitions
data_objects = []
for i, part_idx in enumerate(part_atoms):
for j, part_idx in enumerate(part_atoms):
feat_dict = featurize_atoms(
atoms,
part_idx,
node_features=node_features[i],
edge_features=edge_features[i],\
node_features=node_features[j],
edge_features=edge_features[j],\
max_atoms=max_atoms,
encoder=encoder
)
Expand All @@ -112,5 +110,13 @@ def process(self, z_cutoffs, node_features, edge_features, max_atoms=12,
data_objects.append(data_obj)

# Save data objects
torch.save(data_objects, processed_path / name + ".pt")

torch.save(data_objects, self.processed_path / f"data_{i}.pt")

def len(self):
    """Return size of the dataset.

    Returns
    -------
    int
        Number of entries in self.names.
    """
    return len(self.names)

def get(self, i):
    """Fetch the processed graph(s) at the i-th index.

    Parameters
    ----------
    i: int
        Index of the entry; selects the file "data_{i}.pt" inside the
        processed directory.

    Returns
    -------
    data_objects: object
        Whatever was stored in "data_{i}.pt", as returned by torch.load.
    """
    # Build the path from self.processed_dir on every call instead of reading
    # self.processed_path: that attribute is only assigned inside process(),
    # and there it receives the return value of Path.mkdir(), which is None,
    # so it can never be joined with a file name.
    file_path = Path(self.processed_dir) / f"data_{i}.pt"
    data_objects = torch.load(file_path)
    return data_objects

0 comments on commit 003f669

Please sign in to comment.