Skip to content

Commit

Permalink
Added data functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav S Deshmukh committed Sep 22, 2023
1 parent 33293ea commit 986380f
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 36 deletions.
111 changes: 94 additions & 17 deletions src/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def __init__(self, root, prop_csv):
if self.processed_status():
self.df_name_idx = pd.read_csv(self.index_path)

def process_data(self, z_cutoffs, node_features, edge_features, pad=False,
max_atoms=12, encoder=OneHotEncoder()):
def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None,
encoder=OneHotEncoder()):
"""Process raw data in the root directory into PyTorch Data and save.
Each atomic structure in the root directory is partitioned based on the
Expand All @@ -84,11 +84,9 @@ def process_data(self, z_cutoffs, node_features, edge_features, pad=False,
partition. For e.g., specify [["bulk_bond_distance"],
["surface_bond_distance"], ["adsorbate_bond_distance"]] for
a typical bulk + surface + adsorbate partition.
pad: bool
Whether to pad the graph with empty nodes to make total nodes add
up to max_atoms
max_atoms: int
Maximum number of nodes in graph. Only used if pad is True.
max_atoms: int (default = None)
Maximum number of nodes in graph. If a value is provided, graphs are
padded to make sure the total number of nodes matches max_atoms.
encoder: OneHotEncoder object
Encoder to convert properties to vectors
"""
Expand Down Expand Up @@ -121,7 +119,6 @@ def process_data(self, z_cutoffs, node_features, edge_features, pad=False,
part_idx,
node_features=node_features[j],
edge_features=edge_features[j],
pad=pad,
max_atoms=max_atoms,
encoder=encoder
)
Expand Down Expand Up @@ -184,8 +181,8 @@ def __init__(self, atoms):
self.atoms = atoms
self.data = []

def process_data(self, z_cutoffs, node_features, edge_features, pad=False,
max_atoms=12, encoder=OneHotEncoder()):
def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None,
encoder=OneHotEncoder()):
"""Process list of Atoms objects into PyTorch Data and save.
Each atomic structure in the root directory is partitioned based on the
Expand All @@ -209,11 +206,9 @@ def process_data(self, z_cutoffs, node_features, edge_features, pad=False,
partition. For e.g., specify [["bulk_bond_distance"],
["surface_bond_distance"], ["adsorbate_bond_distance"]] for
a typical bulk + surface + adsorbate partition.
pad: bool
Whether to pad the graph with empty nodes to make total nodes add
up to max_atoms
max_atoms: int
Maximum number of nodes in graph. Only used if pad is True.
max_atoms: int (default is None)
Maximum number of nodes in graph. If a value is provided, graphs are
padded to make sure the total number of nodes matches max_atoms.
encoder: OneHotEncoder object
Encoder to convert properties to vectors
"""
Expand All @@ -230,7 +225,6 @@ def process_data(self, z_cutoffs, node_features, edge_features, pad=False,
part_idx,
node_features=node_features[j],
edge_features=edge_features[j],
pad=pad,
max_atoms=max_atoms,
encoder=encoder
)
Expand All @@ -255,6 +249,73 @@ def get(self, i):
data_objects = self.data[i]
return data_objects

def load_dataset(root, prop_csv, process_dict=None):
"""Load an AtomsDataset at the path given by root.
If process_dict is provided, the process_data method of AtomsDataset is called
to convert the atomic structures to graphs based on the given parameters in
process_dict. This should be used when the dataset is created for the first
time.
Parameters
----------
root: str
Path to the dataset
prop_csv: str
Path to the file mapping atomic structure filename and property.
This filename will typically have two columns, the first with the
names of the cif files and the second with the
corresponding target property values.
process_dict: dict (default = None)
If this is provided, atomic structures at root will be processed into
graphs and stored under a "processed" subdirectory. Only use this when
creating a new dataset. This should contain the following keys: z_cutoffs,
node_features, edge_features, max_atoms (optional), encoder (optional).
Refer to the documentation of process_atoms for more information regarding
these parameters.
Returns
-------
dataset: AtomsDataset
Initialized AtomsDataset object
"""
dataset = AtomsDataset(root, prop_csv)
if process_dict is not None:
dataset.process_data(**process_dict)

return dataset

def load_datapoints(atoms, process_dict):
"""Load AtomsDatapoints for the provided ase.Atoms or list of ase.Atoms.
If process_dict is provided, the process_data method of AtomsDatapoints is called
to convert the atomic structures to graphs based on the given parameters in
process_dict. This should be used when the dataset is created for the first
time.
Parameters
----------
atoms: ase.Atoms object or a list of ase.Atoms objects
Structures for which predictions are to be made.
process_dict: dict
Parameters to process the provided Atoms objects into graphs.
This should contain the following keys: z_cutoffs, node_features,
edge_features, max_atoms (optional), encoder (optional). Refer to the
documentation of process_atoms for more information regarding these
parameters.
Returns
-------
datapoints: AtomsDatapoints
Initialized AtomsDatapoints object
"""
datapoints = AtomsDatapoints(atoms)
if process_dict is not None:
datapoints.process_data(**process_dict)

return datapoints


if __name__ == "__main__":
# Get path to root directory
data_root_path = Path(REPO_PATH) / "data" / "S_calcs"
Expand All @@ -274,4 +335,20 @@ def get(self, i):
# ["adsorbate_bond_distance"],
# ])
print(dataset[0][-1].x)
print(dataset.df_name_idx.head())
print(dataset.df_name_idx.head())

# Create datapoint
atoms = read(data_root_path / "Pt_3_Rh_9_-7-7-S.cif")
datapoint = AtomsDatapoints(atoms)
datapoint.process_data(z_cutoffs=[13., 20.],
node_features=[
["atomic_number", "dband_center"],
["atomic_number", "reactivity"],
["atomic_number", "reactivity"],
],
edge_features=[
["bulk_bond_distance"],
["surface_bond_distance"],
["adsorbate_bond_distance"],
])
print(datapoint.get(0))
12 changes: 6 additions & 6 deletions src/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class AtomsGraph:
"""Create graph representation of a collection of atoms."""

def __init__(self, atoms, select_idx, pad=True, max_atoms=50):
def __init__(self, atoms, select_idx, max_atoms=None):
"""Initialize variables of the class.
Parameters
Expand All @@ -25,14 +25,14 @@ def __init__(self, atoms, select_idx, pad=True, max_atoms=50):
List of indices of atoms that are to be included in the graph
neighbor_list: ase.neighborlist.NeighborList object
Neighbor list that defines bonds between atoms
max_atoms: int (default = 50)
The maximum number of atoms in the graph. Graphs that have fewer
atoms are padded with 0s to reach this value.
max_atoms: int (default = None)
The maximum number of atoms in the graph. If it is not None, graphs
that have fewer nodes than max_atoms are padded with 0s to ensure
that the total number of nodes is equal to max_atoms.
"""
# Save parameters
self.atoms = atoms
self.select_idx = select_idx
self.pad = pad
self.max_atoms = max_atoms

# Create graph
Expand Down Expand Up @@ -109,7 +109,7 @@ def create_graph(self):
graph.add_edge(n, self.map_idx_node[nn], bond_distance=bond_dist)

# Pad graph
if self.pad:
if self.max_atoms is not None:
graph = self.pad_graph(graph)

# Add coordination numbers
Expand Down
19 changes: 6 additions & 13 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ def featurize_atoms(
select_idx,
node_features,
edge_features,
pad=True,
max_atoms=50,
max_atoms=None,
encoder=OneHotEncoder(),
):
"""Featurize atoms and bonds with the chosen featurizers.
Expand All @@ -75,14 +74,10 @@ def featurize_atoms(
Names of edge featurizers to use (current options: bulk_bond_distance,
surface_bond_distance, adsorbate_bond_distance). All of these encode
bond distance using a one-hot encoder, but the bounds for each vary.
pad: bool
If True, the graph is padded to ensure the number of nodes is equal to
max_atoms. In that case, the blank nodes have all 0s in their node tensors.
max_atoms: int (default = 50)
Maximum number of allowed atoms. If the number of atoms in the graph are
fewer than this number, the graph is padded with empty nodes. This is
required to make the sizes of the node feature tensors consistent across
structures.
max_atoms: int (default = None)
Maximum number of allowed atoms. If it is not None, graphs
that have fewer nodes than max_atoms are padded with 0s to ensure
that the total number of nodes is equal to max_atoms.
encoder: encoder object from featurizers.py
Currently only the OneHotEncoder is supported
Expand All @@ -93,8 +88,7 @@ def featurize_atoms(
corresponding tensors as values.
"""
# Create graph
atoms_graph = AtomsGraph(atoms=atoms, select_idx=select_idx, max_atoms=max_atoms,
pad=pad)
atoms_graph = AtomsGraph(atoms=atoms, select_idx=select_idx, max_atoms=max_atoms)

# Collect node featurizers
node_feats = []
Expand Down Expand Up @@ -154,6 +148,5 @@ def featurize_atoms(
["atomic_number", "dband_center"],
["bulk_bond_distance"],
max_atoms=34,
pad=False,
)
print(feat_dict)

0 comments on commit 986380f

Please sign in to comment.