Skip to content

Commit

Permalink
Added option to not pad graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav S Deshmukh committed Sep 19, 2023
1 parent 4cced65 commit 77c5620
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
6 changes: 4 additions & 2 deletions src/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class AtomsGraph:
"""Create graph representation of a collection of atoms."""

def __init__(self, atoms, select_idx, max_atoms=50):
def __init__(self, atoms, select_idx, pad=True, max_atoms=50):
"""Initialize variables of the class.
Parameters
Expand All @@ -32,6 +32,7 @@ def __init__(self, atoms, select_idx, max_atoms=50):
# Save parameters
self.atoms = atoms
self.select_idx = select_idx
self.pad = pad
self.max_atoms = max_atoms

# Create graph
Expand Down Expand Up @@ -108,7 +109,8 @@ def create_graph(self):
graph.add_edge(n, self.map_idx_node[nn], bond_distance=bond_dist)

# Pad graph
graph = self.pad_graph(graph)
if self.pad:
graph = self.pad_graph(graph)

# Add coordination numbers
for n in graph.nodes():
Expand Down
10 changes: 8 additions & 2 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def featurize_atoms(
select_idx,
node_features,
edge_features,
pad=True,
max_atoms=50,
encoder=OneHotEncoder(),
):
Expand All @@ -74,6 +75,9 @@ def featurize_atoms(
Names of edge featurizers to use (current options: bulk_bond_distance,
surface_bond_distance, adsorbate_bond_distance). All of these encode
bond distance using a one-hot encoder, but the bounds for each vary.
pad: bool
If True, the graph is padded to ensure the number of nodes is equal to
max_atoms. In that case, the blank nodes have all 0s in their node tensors.
max_atoms: int (default = 50)
Maximum number of allowed atoms. If the number of atoms in the graph are
fewer than this number, the graph is padded with empty nodes. This is
Expand All @@ -89,7 +93,8 @@ def featurize_atoms(
corresponding tensors as values.
"""
# Create graph
atoms_graph = AtomsGraph(atoms=atoms, select_idx=select_idx, max_atoms=max_atoms)
atoms_graph = AtomsGraph(atoms=atoms, select_idx=select_idx, max_atoms=max_atoms,
pad=pad)

# Collect node featurizers
node_feats = []
Expand Down Expand Up @@ -140,7 +145,7 @@ def featurize_atoms(

atoms = read("CONTCAR")

part_atoms = partition_structure(atoms, 3, z_cutoffs=[15, 23.5])
part_atoms = partition_structure(atoms, z_cutoffs=[15, 23.5])
print(part_atoms)

feat_dict = featurize_atoms(
Expand All @@ -149,5 +154,6 @@ def featurize_atoms(
["atomic_number", "dband_center"],
["bulk_bond_distance"],
max_atoms=34,
pad=False,
)
print(feat_dict)

0 comments on commit 77c5620

Please sign in to comment.