-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Gaurav S Deshmukh
committed
Sep 7, 2023
1 parent
c94eb86
commit fa0f476
Showing
1 changed file
with
119 additions
and
0 deletions.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| """Classes to create bulk, surface, and adsorbate graphs.""" | ||
|
|
||
| import abc | ||
| import networkx as nx | ||
| import numpy as np | ||
|
|
||
| from ase.neighborlist import build_neighbor_list | ||
|
|
||
|
|
||
| class AtomsGraph: | ||
| """Create graph representation of a collection of atoms.""" | ||
| def __init__(self, atoms, select_idx, max_atoms=50): | ||
| """Initialize variables of the class. | ||
| Parameters | ||
| ---------- | ||
| atoms: ase.Atoms object | ||
| Atoms object containing all the atoms in the slab | ||
| select_idx: list or np.ndarray | ||
| List of indices of atoms that are to be included in the graph | ||
| neighbor_list: ase.neighborlist.NeighborList object | ||
| Neighbor list that defines bonds between atoms | ||
| max_atoms: int (default = 50) | ||
| The maximum number of atoms in the graph. Graphs that have fewer | ||
| atoms are padded with 0s to reach this value. | ||
| """ | ||
| # Save parameters | ||
| self.atoms = atoms | ||
| self.select_idx = select_idx | ||
| self.max_atoms = max_atoms | ||
|
|
||
| # Create graph | ||
| self.create_graph() | ||
|
|
||
| def create_graph(self): | ||
| """Create a graph from an atoms object and neighbor_list.""" | ||
| # Create neighbor list of atoms | ||
| self.neighbor_list = build_neighbor_list( | ||
| self.atoms, bothways=True, self_interaction=False | ||
| ) | ||
|
|
||
| # Create NetworkX Multigraph | ||
| graph = nx.MultiGraph() | ||
|
|
||
| # Iterate over selected atoms and add them as nodes | ||
| for atom in self.atoms: | ||
| if atom.index in self.select_idx and atom.index not in list(graph.nodes()): | ||
| graph.add_node( | ||
| atom.index, | ||
| atomic_number=atom.number, | ||
| symbol=atom.symbol, | ||
| position=atom.position, | ||
| ) | ||
|
|
||
| # Iterate over nodes, identify neighbors, and add edges between them | ||
| for n in graph.nodes(): | ||
| # Get neighbors from neighbor list | ||
| neighbor_idx, _ = self.neighbor_list.get_neighbors(n) | ||
| # Iterate over neighbors | ||
| for nn in neighbor_idx: | ||
| # If neighbor is not in graph, add it as a node | ||
| if not graph.has_node(nn): | ||
| graph.add_node( | ||
| nn, | ||
| atomic_number=self.atoms[nn].number, | ||
| symbol=self.atoms[nn].symbol, | ||
| position=self.atoms[nn].position | ||
| ) | ||
| # Calculate bond distance | ||
| bond_dist = np.linalg.norm( | ||
| graph.nodes[n].position - graph.nodes[nn].position | ||
| ) | ||
| graph.add_edge(n, nn, bond_distance=bond_dist) | ||
|
|
||
| # Assign graph object | ||
| self.graph = graph | ||
|
|
||
| def featurize(self, node_featurizer, bond_featurizer): | ||
| """Featurize nodes and edges of the graph. | ||
| Parameters | ||
| ---------- | ||
| node_featurizer: TODO | ||
| Object that featurizes atoms | ||
| bond_featurizer: TODO | ||
| Object that featurizes bonds | ||
| """ | ||
| pass | ||
|
|
||
| def plot(self, filename=None): | ||
| """Plot the graph using NetworkX. | ||
| Parameters | ||
| ---------- | ||
| filename: str (optional) | ||
| If provided, the plot is saved with the given filename. | ||
| """ | ||
| pass | ||
|
|
||
| def get_node_tensor(self): | ||
| """Get the node matrix of the graph as a PyTorch Tensor. | ||
| Returns | ||
| ------- | ||
| node_matrix: torch.Tensor | ||
| Node matrix | ||
| """ | ||
| pass | ||
|
|
||
| def get_adjacency_tensor(self): | ||
| """Get the adjacency matrix of the graph as a PyTorch Tensor. | ||
| Returns | ||
| ------- | ||
| adj_matrix: torch.Tensor | ||
| Adjacency tensor | ||
| """ | ||
| pass | ||
|
|