Skip to content

Commit

Permalink
Added padding and fixed bond distances
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav S Deshmukh committed Sep 12, 2023
1 parent d80ef8b commit 2347f03
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 38 deletions.
54 changes: 37 additions & 17 deletions src/featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, encoder):
pass

@abc.abstractmethod
def featurize_graph(self, graph):
def featurize_graph(self, atoms_graph):
"""Featurize an AtomsGraph.
This class should create a feature tensor from the given graph. This
Expand Down Expand Up @@ -115,8 +115,8 @@ def name(self):

class AtomNumFeaturizer(Featurizer):
"""Featurize nodes based on atomic number."""
def __init__(self, encoder, min=1, max=80, n_intervals=10):
"""Initialize featurizer with min = 1, max = 80, n_intervals = 10.
def __init__(self, encoder, min=0, max=80, n_intervals=10):
"""Initialize featurizer with min = 0, max = 80, n_intervals = 10.
Parameters
----------
Expand All @@ -138,7 +138,7 @@ def __init__(self, encoder, min=1, max=80, n_intervals=10):
self.encoder = encoder
self.encoder.fit(self.min, self.max, self.n_intervals)

def featurize_graph(self, graph):
def featurize_graph(self, atoms_graph):
"""Featurize an AtomsGraph.
Parameters
Expand All @@ -147,7 +147,7 @@ def featurize_graph(self, graph):
A graph of a collection of bulk, surface, or adsorbate atoms.
"""
# Get atomic numbers
atom_num_dict = nx.get_node_attributes(graph, "atomic_number")
atom_num_dict = nx.get_node_attributes(atoms_graph.graph, "atomic_number")
atom_num_arr = np.array(list(atom_num_dict.values()))

# Create node feature matrix
Expand Down Expand Up @@ -202,7 +202,7 @@ def __init__(self, encoder, min=-5, max=3, n_intervals=10):
for row in csv_reader:
self.map_dict[int(row[0])] = float(row[1])

def featurize_graph(self, graph):
def featurize_graph(self, atoms_graph):
"""Featurize an AtomsGraph.
Parameters
Expand All @@ -211,7 +211,7 @@ def featurize_graph(self, graph):
A graph of a collection of bulk, surface, or adsorbate atoms.
"""
# Get atomic numbers
atom_num_dict = nx.get_node_attributes(graph, "atomic_number")
atom_num_dict = nx.get_node_attributes(atoms_graph.graph, "atomic_number")
atom_num_arr = np.array(list(atom_num_dict.values()))

# Map from atomic number to d-band center
Expand Down Expand Up @@ -263,11 +263,11 @@ def __init__(self, encoder, min=1, max=12, n_intervals=12):
self.encoder.fit(self.min, self.max, self.n_intervals)

# Create a map between atomic number and number of valence electrons
self.map_dict = {1: 1, 2:0}
self.map_dict = {0: 0, 1: 1, 2:0}
for i in range(3, 21, 1):
self.map_dict[i] = min(element(i).ec.get_valence().ne(), 12)

def featurize_graph(self, graph):
def featurize_graph(self, atoms_graph):
"""Featurize an AtomsGraph.
Parameters
Expand All @@ -276,7 +276,7 @@ def featurize_graph(self, graph):
A graph of a collection of bulk, surface, or adsorbate atoms.
"""
# Get atomic numbers
atom_num_dict = nx.get_node_attributes(graph, "atomic_number")
atom_num_dict = nx.get_node_attributes(atoms_graph.graph, "atomic_number")
atom_num_arr = np.array(list(atom_num_dict.values()))

# Create node feature matrix
Expand Down Expand Up @@ -324,7 +324,7 @@ def __init__(self, encoder, min=1, max=15, n_intervals=15):
self.encoder = encoder
self.encoder.fit(self.min, self.max, self.n_intervals)

def featurize_graph(self, graph):
def featurize_graph(self, atoms_graph):
"""Featurize an AtomsGraph.
Parameters
Expand All @@ -333,7 +333,7 @@ def featurize_graph(self, graph):
A graph of a collection of bulk, surface, or adsorbate atoms.
"""
# Get coordination numbers
cn_dict = nx.get_node_attributes(graph, "coordination")
cn_dict = nx.get_node_attributes(atoms_graph.graph, "coordination")
cn_arr = np.array(list(cn_dict.values()))

# Create node feature matrix
Expand Down Expand Up @@ -381,7 +381,7 @@ def __init__(self, encoder, min, max, n_intervals):
self.encoder = encoder
self.encoder.fit(self.min, self.max, self.n_intervals)

def featurize_graph(self, graph):
def featurize_graph(self, atoms_graph):
"""Featurize an AtomsGraph.
Parameters
Expand All @@ -390,11 +390,14 @@ def featurize_graph(self, graph):
A graph of a collection of bulk, surface, or adsorbate atoms.
"""
# Get bond distances
bond_dist_dict = nx.get_edge_attributes(graph, "bond_distance")
bond_dist_dict = nx.get_edge_attributes(atoms_graph.graph, "bond_distance")
bond_dist_arr = np.array(list(bond_dist_dict.values()))

# Create edge feature matrix
self._feat_tensor = self.encoder.transform(bond_dist_arr)

# Create list of edge indices
self._edge_indices = torch.Tensor(list(atoms_graph.graph.edges()))

@property
def feat_tensor(self):
Expand All @@ -407,6 +410,17 @@ def feat_tensor(self):
M = n_intervals provided to the encoder
"""
return self._feat_tensor

@property
def edge_indices(self):
    """Expose the edge indices stored on this featurizer.

    Returns
    -------
    edge_indices: torch.Tensor
        Tensor with edge indices, exactly as held in the private
        ``_edge_indices`` attribute
    """
    stored_indices = self._edge_indices
    return stored_indices

@staticmethod
def name():
Expand Down Expand Up @@ -448,6 +462,12 @@ def __init__(self, encoder, min=0, max=4, n_intervals=16):

atoms = read("CONTCAR")
g = AtomsGraph(atoms, select_idx=[1, 10, 11, 12])
dbf = DBandFeaturizer(OneHotEncoder())
dbf.featurize_graph(g.graph)
print(dbf.feat_tensor)

anf = AtomNumFeaturizer(OneHotEncoder())
anf.featurize_graph(g)
print(anf.feat_tensor.shape)

bdf = BulkBondDistanceFeaturizer(OneHotEncoder())
bdf.featurize_graph(g)
print(bdf.feat_tensor)
print(bdf.edge_indices)
127 changes: 106 additions & 21 deletions src/graphs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Classes to create bulk, surface, and adsorbate graphs."""

from copy import deepcopy

import networkx as nx
import numpy as np
from ase.neighborlist import build_neighbor_list, natural_cutoffs
from ase.neighborlist import (build_neighbor_list, natural_cutoffs,
NewPrimitiveNeighborList)


class AtomsGraph:
Expand Down Expand Up @@ -33,56 +36,77 @@ def __init__(self, atoms, select_idx, max_atoms=50):

def create_graph(self):
"""Create a graph from an atoms object and neighbor_list."""
# Create neighbor list of atoms
self.neighbor_list = build_neighbor_list(
self.atoms,
natural_cutoffs(self.atoms),
bothways=True,
self_interaction=False
)

# Create NetworkX Multigraph
graph = nx.MultiGraph()

# Iterate over selected atoms and add them as nodes
self.node_count = 0
self.map_idx_node = {}
for atom in self.atoms:
if atom.index in self.select_idx and atom.index not in list(graph.nodes()):
if (atom.index in self.select_idx
and atom.index not in list(graph.nodes(data="index"))):
graph.add_node(
atom.index,
self.node_count,
index=atom.index,
atomic_number=atom.number,
symbol=atom.symbol,
position=atom.position,
)
self.map_idx_node[atom.index] = self.node_count
self.node_count += 1

# Create neighbor list of atoms
self.neighbor_list = build_neighbor_list(
self.atoms,
natural_cutoffs(self.atoms),
bothways=True,
self_interaction=False,
primitive=NewPrimitiveNeighborList
)

# Iterate over nodes, identify neighbors, and add edges between them
node_list = list(graph.nodes())
bond_tuples = []
for n in node_list:
# Get neighbors from neighbor list
neighbor_idx, _ = self.neighbor_list.get_neighbors(n)
neighbor_idx, neighbor_offsets = self.neighbor_list.get_neighbors(
graph.nodes[n]["index"]
)
# Iterate over neighbors
for nn in neighbor_idx:
for nn, offset in zip(neighbor_idx, neighbor_offsets):
# Skip if self atom
if nn == graph.nodes[n]["index"]:
continue
# Save bond
bond = (n, nn)
bond = (graph.nodes[n]["index"], nn)
rev_bond = tuple(reversed(bond))
# Check if bond has already been added
if rev_bond in bond_tuples:
continue
else:
bond_tuples.append(bond)
# If neighbor is not in graph, add it as a node
if not graph.has_node(nn):
node_indices = nx.get_node_attributes(graph, "index")
if nn not in list(node_indices.values()):
graph.add_node(
nn,
self.node_count,
index=nn,
atomic_number=self.atoms[nn].number,
symbol=self.atoms[nn].symbol,
position=self.atoms[nn].position,
)
self.map_idx_node[nn] = self.node_count
self.node_count += 1
# Calculate bond distance
bond_dist = np.linalg.norm(
graph.nodes[n]["position"] - graph.nodes[nn]["position"]
bond_dist = self.calc_minimum_distance(
self.atoms[graph.nodes[n]["index"]].position,
self.atoms[nn].position,
offset,
)
graph.add_edge(
n, self.map_idx_node[nn], bond_distance=bond_dist
)
graph.add_edge(n, nn, bond_distance=bond_dist)

# Pad graph
graph = self.pad_graph(graph)

# Add coordination numbers
for n in graph.nodes():
Expand All @@ -91,6 +115,59 @@ def create_graph(self):
# Assign graph object
self.graph = graph

def pad_graph(self, graph):
    """Return a copy of the graph padded with placeholder nodes.

    Placeholder nodes are added under the integer labels
    ``self.node_count`` up to (but excluding) ``self.max_atoms``. The graph
    passed in is deep-copied first and is not modified.

    Parameters
    ----------
    graph: Networkx.Graph
        A Networkx graph

    Returns
    -------
    padded_graph: Networkx.Graph
        Deep copy of ``graph`` with the placeholder nodes added
    """
    padded_graph = deepcopy(graph)

    for node_label in range(self.node_count, self.max_atoms):
        # Build the attributes inside the loop so that every placeholder
        # node owns a separate position array.
        placeholder_attrs = {
            "index": -1,
            "atomic_number": 0,
            "symbol": "",
            "position": np.zeros(3),
        }
        padded_graph.add_node(node_label, **placeholder_attrs)

    return padded_graph

def calc_minimum_distance(self, pos_1, pos_2, offset):
    """Calculate minimum distance between two atoms.

    Two candidate distances are evaluated: the direct distance between the
    two positions, and the distance obtained after shifting the second
    position by ``offset @ self.atoms.get_cell()``. The smaller of the two
    is returned.

    Parameters
    ----------
    pos_1: array-like
        Position of first atom in x, y, z coordinates
    pos_2: array-like
        Position of second atom in x, y, z coordinates
    offset: array-like
        Offset returned by the neighbor list in ASE

    Returns
    -------
    min_dist: float
        The smaller of the direct and the offset-shifted distance
    """
    # Coerce inputs so plain lists/tuples work as well as NumPy arrays
    # (list - list is not defined in Python).
    pos_1 = np.asarray(pos_1, dtype=float)
    pos_2 = np.asarray(pos_2, dtype=float)
    offset = np.asarray(offset)

    # Distance between the positions as given (no offset applied)
    direct_dist = np.linalg.norm(pos_1 - pos_2)

    # Distance after shifting the second position by the offset expressed
    # in terms of the cell vectors
    shifted_pos_2 = pos_2 + offset @ self.atoms.get_cell()
    shifted_dist = np.linalg.norm(pos_1 - shifted_pos_2)

    return min(direct_dist, shifted_dist)

def plot(self, filename=None):
"""Plot the graph using NetworkX.
Expand All @@ -100,3 +177,11 @@ def plot(self, filename=None):
If provided, the plot is saved with the given filename.
"""
pass

if __name__ == "__main__":
    from ase.io import read

    # Build a graph around a few selected atoms of the structure in CONTCAR
    structure = read("CONTCAR")
    atoms_graph = AtomsGraph(structure, select_idx=[1, 10, 11, 12])

    # Show the atom-index -> node-label map and the bond distance stored on
    # each edge
    print(atoms_graph.map_idx_node)
    print(atoms_graph.graph.edges(data="bond_distance"))

0 comments on commit 2347f03

Please sign in to comment.