diff --git a/src/graphs.py b/src/graphs.py index 40a09b3..a3ac282 100644 --- a/src/graphs.py +++ b/src/graphs.py @@ -14,7 +14,7 @@ class AtomsGraph: """Create graph representation of a collection of atoms.""" - def __init__(self, atoms, select_idx, max_atoms=50): + def __init__(self, atoms, select_idx, pad=True, max_atoms=50): """Initialize variables of the class. Parameters @@ -32,6 +32,7 @@ def __init__(self, atoms, select_idx, max_atoms=50): # Save parameters self.atoms = atoms self.select_idx = select_idx + self.pad = pad self.max_atoms = max_atoms # Create graph @@ -108,7 +109,8 @@ def create_graph(self): graph.add_edge(n, self.map_idx_node[nn], bond_distance=bond_dist) # Pad graph - graph = self.pad_graph(graph) + if self.pad: + graph = self.pad_graph(graph) # Add coordination numbers for n in graph.nodes(): diff --git a/src/utils.py b/src/utils.py index 0d3592a..8436778 100644 --- a/src/utils.py +++ b/src/utils.py @@ -53,6 +53,7 @@ def featurize_atoms( select_idx, node_features, edge_features, + pad=True, max_atoms=50, encoder=OneHotEncoder(), ): @@ -74,6 +75,9 @@ def featurize_atoms( Names of edge featurizers to use (current options: bulk_bond_distance, surface_bond_distance, adsorbate_bond_distance). All of these encode bond distance using a one-hot encoder, but the bounds for each vary. + pad: bool + If True, the graph is padded to ensure the number of nodes is equal to + max_atoms. In that case, the blank nodes have all 0s in their node tensors. max_atoms: int (default = 50) Maximum number of allowed atoms. If the number of atoms in the graph are fewer than this number, the graph is padded with empty nodes. This is @@ -89,7 +93,8 @@ def featurize_atoms( corresponding tensors as values. """ # Create graph - atoms_graph = AtomsGraph(atoms=atoms, select_idx=select_idx, max_atoms=max_atoms) + atoms_graph = AtomsGraph(atoms=atoms, select_idx=select_idx, max_atoms=max_atoms, + pad=pad) # Collect node featurizers node_feats = [] @@ -140,7 +145,7 @@ def featurize_atoms( atoms = read("CONTCAR") - part_atoms = partition_structure(atoms, 3, z_cutoffs=[15, 23.5]) + part_atoms = partition_structure(atoms, z_cutoffs=[15, 23.5]) print(part_atoms) feat_dict = featurize_atoms( @@ -149,5 +154,6 @@ def featurize_atoms( ["atomic_number", "dband_center"], ["bulk_bond_distance"], max_atoms=34, + pad=False, ) print(feat_dict)