Skip to content

ENH: Add AtomsGraph class to create graph representation of selected atoms in a slab #4

Merged
merged 4 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*POSCAR*
119 changes: 119 additions & 0 deletions src/graphs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Classes to create bulk, surface, and adsorbate graphs."""

import abc

import networkx as nx
import numpy as np
from ase.neighborlist import build_neighbor_list


class AtomsGraph:
"""Create graph representation of a collection of atoms."""

def __init__(self, atoms, select_idx, max_atoms=50):
"""Initialize variables of the class.
Parameters
----------
atoms: ase.Atoms object
Atoms object containing all the atoms in the slab
select_idx: list or np.ndarray
List of indices of atoms that are to be included in the graph
neighbor_list: ase.neighborlist.NeighborList object
Neighbor list that defines bonds between atoms
max_atoms: int (default = 50)
The maximum number of atoms in the graph. Graphs that have fewer
atoms are padded with 0s to reach this value.
"""
# Save parameters
self.atoms = atoms
self.select_idx = select_idx
self.max_atoms = max_atoms

# Create graph
self.create_graph()

def create_graph(self):
"""Create a graph from an atoms object and neighbor_list."""
# Create neighbor list of atoms
self.neighbor_list = build_neighbor_list(
self.atoms, bothways=True, self_interaction=False
)

# Create NetworkX Multigraph
graph = nx.MultiGraph()

# Iterate over selected atoms and add them as nodes
for atom in self.atoms:
if atom.index in self.select_idx and atom.index not in list(graph.nodes()):
graph.add_node(
atom.index,
atomic_number=atom.number,
symbol=atom.symbol,
position=atom.position,
)

# Iterate over nodes, identify neighbors, and add edges between them
for n in graph.nodes():
# Get neighbors from neighbor list
neighbor_idx, _ = self.neighbor_list.get_neighbors(n)
# Iterate over neighbors
for nn in neighbor_idx:
# If neighbor is not in graph, add it as a node
if not graph.has_node(nn):
graph.add_node(
nn,
atomic_number=self.atoms[nn].number,
symbol=self.atoms[nn].symbol,
position=self.atoms[nn].position,
)
# Calculate bond distance
bond_dist = np.linalg.norm(
graph.nodes[n].position - graph.nodes[nn].position
)
graph.add_edge(n, nn, bond_distance=bond_dist)

# Assign graph object
self.graph = graph

def featurize(self, node_featurizer, bond_featurizer):
"""Featurize nodes and edges of the graph.
Parameters
----------
node_featurizer: TODO
Object that featurizes atoms
bond_featurizer: TODO
Object that featurizes bonds
"""
pass

def plot(self, filename=None):
"""Plot the graph using NetworkX.
Parameters
----------
filename: str (optional)
If provided, the plot is saved with the given filename.
"""
pass

def get_node_tensor(self):
"""Get the node matrix of the graph as a PyTorch Tensor.
Returns
-------
node_matrix: torch.Tensor
Node matrix
"""
pass

def get_edge_tensor(self):
"""Get the edge matrix of the graph as a PyTorch Tensor.
Returns
-------
edge_matrix: torch.Tensor
Edge tensor
"""
pass