Skip to content

Commit

Permalink
Added AtomsDatapoint
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav S Deshmukh committed Sep 22, 2023
1 parent dc8f977 commit 33293ea
Showing 1 changed file with 98 additions and 0 deletions.
98 changes: 98 additions & 0 deletions src/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import tqdm

from ase.io import read
from ase import Atoms
from torch_geometric.data import Data, Dataset

from utils import partition_structure, featurize_atoms
Expand Down Expand Up @@ -156,6 +157,103 @@ def processed_status(self):
else:
return False

class AtomsDatapoints:
    """Hold atomic structures as datapoints (without targets).

    The main difference between this class and AtomsDataset is that this class
    is initialized with ase.Atoms objects held in memory (as opposed to a
    directory with files containing atomic structures) and no targets are
    specified. This is useful to make predictions on atomic structures for
    which true target values are not known, i.e., previously unseen structures.

    Attributes
    ----------
    atoms: list
        The structures passed at construction, stored as a list.
    data: list[list]
        One entry per structure; each entry is the list of Data objects (one
        per partition) created by the most recent call to process_data. Empty
        until process_data has been called.
    """

    def __init__(self, atoms):
        """Initialize AtomsDatapoints.

        Parameters
        ----------
        atoms: ase.Atoms object or an iterable of ase.Atoms objects
            Structures for which predictions are to be made.
        """
        # A single structure must be wrapped explicitly. The isinstance check
        # comes first so that a lone Atoms object is never passed to list().
        if isinstance(atoms, Atoms):
            atoms = [atoms]
        else:
            # Take a private copy: later mutation of the caller's list cannot
            # leak in, and one-shot iterables (e.g., generators) stay usable
            # across repeated process_data calls.
            atoms = list(atoms)

        self.atoms = atoms
        self.data = []

    def process_data(self, z_cutoffs, node_features, edge_features, pad=False,
                     max_atoms=12, encoder=None):
        """Process the stored Atoms objects into PyTorch Geometric Data objects.

        Each stored atomic structure is partitioned based on the given
        z_cutoffs and each partition is featurized according to the given
        node_features and edge_features. The featurized graphs are converted
        into Data objects and kept in memory in ``self.data`` (nothing is
        written to disk).

        Calling this method again replaces the previously processed data
        instead of appending to it. ``self.data`` is only updated once every
        structure has been processed successfully.

        Parameters
        ----------
        z_cutoffs: list or np.ndarray
            List of z-coordinates based on which atomic structures are
            partitioned. The number of partitions is equal to one more than the
            length of z_cutoffs.
        node_features: list[list]
            List of lists of node featurization methods to be used for each
            partition. For e.g., specify [["atomic_number", "dband_center"],
            ["atomic_number", "reactivity"], ["atomic_number", "reactivity"]]
            for a typical bulk + surface + adsorbate partition.
        edge_features: list[list]
            List of lists of edge featurization methods to be used for each
            partition. For e.g., specify [["bulk_bond_distance"],
            ["surface_bond_distance"], ["adsorbate_bond_distance"]] for
            a typical bulk + surface + adsorbate partition.
        pad: bool
            Whether to pad the graph with empty nodes to make total nodes add
            up to max_atoms.
        max_atoms: int
            Maximum number of nodes in graph. Only used if pad is True.
        encoder: OneHotEncoder object, optional
            Encoder to convert properties to vectors. If None (default), a new
            OneHotEncoder() is created for this call.
        """
        # Build the default lazily so that calls do not share one encoder
        # instance created at function-definition time.
        if encoder is None:
            encoder = OneHotEncoder()

        # Collect results locally so that re-processing starts from a clean
        # slate and a failure part-way through leaves self.data untouched.
        processed = []
        for atoms_obj in self.atoms:
            # Partition structure
            part_atoms = partition_structure(atoms_obj, z_cutoffs)

            # Featurize each partition with its own feature specification.
            # Indexing (rather than zip) keeps a too-short feature list loud:
            # it raises IndexError instead of silently dropping partitions.
            data_objects = []
            for j, part_idx in enumerate(part_atoms):
                feat_dict = featurize_atoms(
                    atoms_obj,
                    part_idx,
                    node_features=node_features[j],
                    edge_features=edge_features[j],
                    pad=pad,
                    max_atoms=max_atoms,
                    encoder=encoder,
                )

                # Convert to Data object
                data_objects.append(
                    Data(
                        x=feat_dict["node_tensor"],
                        edge_index=feat_dict["edge_indices"],
                        edge_attr=feat_dict["edge_tensor"],
                    )
                )

            processed.append(data_objects)

        # Publish the freshly processed data in one step.
        self.data = processed

    def len(self):
        """Return the number of processed structures."""
        return len(self.data)

    def get(self, i):
        """Fetch the list of processed graphs (one per partition) at index i."""
        return self.data[i]

if __name__ == "__main__":
# Get path to root directory
Expand Down

0 comments on commit 33293ea

Please sign in to comment.