Commit

AtomsDataset works
Gaurav S Deshmukh committed Sep 21, 2023
1 parent c8adcb0 commit e98d998
Showing 2 changed files with 40 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,6 +1,7 @@
*POSCAR*
*CONTCAR*
*.csv
data/*
!data/dband_centers.csv
__pycache__
*.cif
48 changes: 39 additions & 9 deletions src/data.py
@@ -4,12 +4,14 @@
from pathlib import Path

import torch
import tqdm

from ase.io import read
from torch_geometric.data import Data, Dataset

from utils import partition_structure, featurize_atoms
from featurizers import OneHotEncoder
from constants import REPO_PATH

class AtomsDataset(Dataset):
"""Class to hold a dataset containing graphs of atomic_structures."""
@@ -22,7 +24,7 @@ def __init__(self, root, prop_csv):
---------
root: str
Path to the directory in which atomic structures are stored
pro_csv: str
prop_csv: str
Path to the file mapping atomic structure filename and property.
This filename will typically have two columns, the first with the
names of the cif files and the second with the
@@ -31,6 +33,10 @@ def __init__(self, root, prop_csv):
super().__init__(root)
self.root_path = Path(self.root)

# Create processed path if it doesn't exist
self.processed_path = Path(self.processed_dir)
self.processed_path.mkdir(exist_ok=True)

# Read csv
self.prop_csv = prop_csv
self.names = []
@@ -46,8 +52,8 @@ def __init__(self, root, prop_csv):
name: prop for name, prop in zip(self.names, self.props)
}

def process(self, z_cutoffs, node_features, edge_features, max_atoms=12,
encoder=OneHotEncoder()):
def process_data(self, z_cutoffs, node_features, edge_features, pad=False,
max_atoms=12, encoder=OneHotEncoder()):
"""Process raw data in the root directory into PyTorch Data and save.
Each atomic structure in the root directory is partitioned based on the
@@ -71,14 +77,17 @@ def process(self, z_cutoffs, node_features, edge_features, max_atoms=12,
partition. For example, specify [["bulk_bond_distance"],
["surface_bond_distance"], ["adsorbate_bond_distance"]] for
a typical bulk + surface + adsorbate partition.
pad: bool
Whether to pad the graph with empty nodes to make total nodes add
up to max_atoms
max_atoms: int
Maximum number of nodes in graph. Only used if pad is True.
encoder: OneHotEncoder object
Encoder to convert properties to vectors
"""
# Create processed path if it doesn't exist
self.processed_path = Path(self.processed_dir).mkdir(exist_ok=True)

# Iterate over files and process them
for i, name in enumerate(self.names):
for i, name in tqdm.tqdm(enumerate(self.names), desc="Processing data",
total=len(self.names)):
# Set file path
file_path = self.root_path / name

@@ -95,7 +104,8 @@ def process(self, z_cutoffs, node_features, edge_features, max_atoms=12,
atoms,
part_idx,
node_features=node_features[j],
edge_features=edge_features[j],\
edge_features=edge_features[j],
pad=pad,
max_atoms=max_atoms,
encoder=encoder
)
@@ -122,4 +132,24 @@ def len(self):
def get(self, i):
"""Fetch the processed graph(s) at the i-th index."""
data_objects = torch.load(self.processed_path / f"data_{i}.pt")
return data_objects
return data_objects

if __name__ == "__main__":
# Get path to root directory
data_root_path = Path(REPO_PATH) / "data" / "S_calcs"
prop_csv_path = data_root_path / "name_prop.csv"

# Create dataset
dataset = AtomsDataset(data_root_path, prop_csv_path)
dataset.process_data(z_cutoffs=[13., 20.],
node_features=[
["atomic_number", "dband_center"],
["atomic_number", "reactivity"],
["atomic_number", "reactivity"],
],
edge_features=[
["bulk_bond_distance"],
["surface_bond_distance"],
["adsorbate_bond_distance"],
])
print(dataset[0][-2].x)
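
A minimal usage sketch, not part of this commit: once process_data has been run as in the __main__ block above, the saved per-structure graphs can be read back through the dataset. It assumes the same data/S_calcs layout, that src/data.py is importable as data, and that each dataset item is the list of partition graphs (bulk, surface, adsorbate) for one structure.

# Usage sketch (assumption, not committed code): inspect the processed graphs.
from pathlib import Path

from constants import REPO_PATH
from data import AtomsDataset  # assumes src/ is on the Python path

data_root_path = Path(REPO_PATH) / "data" / "S_calcs"
dataset = AtomsDataset(data_root_path, data_root_path / "name_prop.csv")

graphs = dataset[0]  # list of Data objects: one per partition (bulk, surface, adsorbate)
for g in graphs:
    # Print the node count and node-feature matrix shape of each partition graph
    print(g.num_nodes, tuple(g.x.shape))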
