Skip to content

Commit

Permalink
Merge pull request #4 from GreeleyGroup/enh/graphs
Browse files Browse the repository at this point in the history
ENH: Add AtomsGraph class to create graph representation of selected atoms in a slab
  • Loading branch information
deshmukg authored Sep 7, 2023
2 parents c94eb86 + 72e986b commit cb3ba9c
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*POSCAR*
119 changes: 119 additions & 0 deletions src/graphs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Classes to create bulk, surface, and adsorbate graphs."""

import abc

import networkx as nx
import numpy as np
from ase.neighborlist import build_neighbor_list


class AtomsGraph:
"""Create graph representation of a collection of atoms."""

def __init__(self, atoms, select_idx, max_atoms=50):
"""Initialize variables of the class.
Parameters
----------
atoms: ase.Atoms object
Atoms object containing all the atoms in the slab
select_idx: list or np.ndarray
List of indices of atoms that are to be included in the graph
neighbor_list: ase.neighborlist.NeighborList object
Neighbor list that defines bonds between atoms
max_atoms: int (default = 50)
The maximum number of atoms in the graph. Graphs that have fewer
atoms are padded with 0s to reach this value.
"""
# Save parameters
self.atoms = atoms
self.select_idx = select_idx
self.max_atoms = max_atoms

# Create graph
self.create_graph()

def create_graph(self):
"""Create a graph from an atoms object and neighbor_list."""
# Create neighbor list of atoms
self.neighbor_list = build_neighbor_list(
self.atoms, bothways=True, self_interaction=False
)

# Create NetworkX Multigraph
graph = nx.MultiGraph()

# Iterate over selected atoms and add them as nodes
for atom in self.atoms:
if atom.index in self.select_idx and atom.index not in list(graph.nodes()):
graph.add_node(
atom.index,
atomic_number=atom.number,
symbol=atom.symbol,
position=atom.position,
)

# Iterate over nodes, identify neighbors, and add edges between them
for n in graph.nodes():
# Get neighbors from neighbor list
neighbor_idx, _ = self.neighbor_list.get_neighbors(n)
# Iterate over neighbors
for nn in neighbor_idx:
# If neighbor is not in graph, add it as a node
if not graph.has_node(nn):
graph.add_node(
nn,
atomic_number=self.atoms[nn].number,
symbol=self.atoms[nn].symbol,
position=self.atoms[nn].position,
)
# Calculate bond distance
bond_dist = np.linalg.norm(
graph.nodes[n].position - graph.nodes[nn].position
)
graph.add_edge(n, nn, bond_distance=bond_dist)

# Assign graph object
self.graph = graph

def featurize(self, node_featurizer, bond_featurizer):
"""Featurize nodes and edges of the graph.
Parameters
----------
node_featurizer: TODO
Object that featurizes atoms
bond_featurizer: TODO
Object that featurizes bonds
"""
pass

def plot(self, filename=None):
"""Plot the graph using NetworkX.
Parameters
----------
filename: str (optional)
If provided, the plot is saved with the given filename.
"""
pass

def get_node_tensor(self):
"""Get the node matrix of the graph as a PyTorch Tensor.
Returns
-------
node_matrix: torch.Tensor
Node matrix
"""
pass

def get_edge_tensor(self):
"""Get the edge matrix of the graph as a PyTorch Tensor.
Returns
-------
edge_matrix: torch.Tensor
Edge tensor
"""
pass

0 comments on commit cb3ba9c

Please sign in to comment.