From fa0f4768a077da8e598aca213ad2d409e2542e99 Mon Sep 17 00:00:00 2001 From: Gaurav S Deshmukh Date: Thu, 7 Sep 2023 17:10:06 -0400 Subject: [PATCH 1/4] Added graphs.py --- src/graphs.py | 119 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 src/graphs.py diff --git a/src/graphs.py b/src/graphs.py new file mode 100644 index 0000000..76c8048 --- /dev/null +++ b/src/graphs.py @@ -0,0 +1,119 @@ +"""Classes to create bulk, surface, and adsorbate graphs.""" + +import abc +import networkx as nx +import numpy as np + +from ase.neighborlist import build_neighbor_list + + +class AtomsGraph: + """Create graph representation of a collection of atoms.""" + def __init__(self, atoms, select_idx, max_atoms=50): + """Initialize variables of the class. + + Parameters + ---------- + atoms: ase.Atoms object + Atoms object containing all the atoms in the slab + select_idx: list or np.ndarray + List of indices of atoms that are to be included in the graph + neighbor_list: ase.neighborlist.NeighborList object + Neighbor list that defines bonds between atoms + max_atoms: int (default = 50) + The maximum number of atoms in the graph. Graphs that have fewer + atoms are padded with 0s to reach this value. + """ + # Save parameters + self.atoms = atoms + self.select_idx = select_idx + self.max_atoms = max_atoms + + # Create graph + self.create_graph() + + def create_graph(self): + """Create a graph from an atoms object and neighbor_list.""" + # Create neighbor list of atoms + self.neighbor_list = build_neighbor_list( + self.atoms, bothways=True, self_interaction=False + ) + + # Create NetworkX Multigraph + graph = nx.MultiGraph() + + # Iterate over selected atoms and add them as nodes + for atom in self.atoms: + if atom.index in self.select_idx and atom.index not in list(graph.nodes()): + graph.add_node( + atom.index, + atomic_number=atom.number, + symbol=atom.symbol, + position=atom.position, + ) + + # Iterate over nodes, identify neighbors, and add edges between them + for n in graph.nodes(): + # Get neighbors from neighbor list + neighbor_idx, _ = self.neighbor_list.get_neighbors(n) + # Iterate over neighbors + for nn in neighbor_idx: + # If neighbor is not in graph, add it as a node + if not graph.has_node(nn): + graph.add_node( + nn, + atomic_number=self.atoms[nn].number, + symbol=self.atoms[nn].symbol, + position=self.atoms[nn].position + ) + # Calculate bond distance + bond_dist = np.linalg.norm( + graph.nodes[n].position - graph.nodes[nn].position + ) + graph.add_edge(n, nn, bond_distance=bond_dist) + + # Assign graph object + self.graph = graph + + def featurize(self, node_featurizer, bond_featurizer): + """Featurize nodes and edges of the graph. + + Parameters + ---------- + node_featurizer: TODO + Object that featurizes atoms + bond_featurizer: TODO + Object that featurizes bonds + """ + pass + + def plot(self, filename=None): + """Plot the graph using NetworkX. + + Parameters + ---------- + filename: str (optional) + If provided, the plot is saved with the given filename. + """ + pass + + def get_node_tensor(self): + """Get the node matrix of the graph as a PyTorch Tensor. + + Returns + ------- + node_matrix: torch.Tensor + Node matrix + """ + pass + + def get_adjacency_tensor(self): + """Get the adjacency matrix of the graph as a PyTorch Tensor. + + Returns + ------- + adj_matrix: torch.Tensor + Adjacency tensor + """ + pass + From 6f5e207ce2378d31a6da6fae59d1658b0c6c9eda Mon Sep 17 00:00:00 2001 From: Gaurav S Deshmukh Date: Thu, 7 Sep 2023 17:13:40 -0400 Subject: [PATCH 2/4] Edited name of function --- src/graphs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/graphs.py b/src/graphs.py index 76c8048..fdead15 100644 --- a/src/graphs.py +++ b/src/graphs.py @@ -107,13 +107,13 @@ def get_node_tensor(self): """ pass - def get_adjacency_tensor(self): - """Get the adjacency matrix of the graph as a PyTorch Tensor. + def get_edge_tensor(self): + """Get the edge matrix of the graph as a PyTorch Tensor. Returns ------- - adj_matrix: torch.Tensor - Adjacency tensor + edge_matrix: torch.Tensor + Edge tensor """ pass From 16decbc582ecdda9c44ae72c45d2467698421033 Mon Sep 17 00:00:00 2001 From: Gaurav S Deshmukh Date: Thu, 7 Sep 2023 17:20:24 -0400 Subject: [PATCH 3/4] Added gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..092dc51 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +*POSCAR* From 72e986b1c30c0e143cafad204cf076e0d37aa6fb Mon Sep 17 00:00:00 2001 From: Gaurav S Deshmukh Date: Thu, 7 Sep 2023 17:26:32 -0400 Subject: [PATCH 4/4] Fixed codestyle --- src/graphs.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/graphs.py b/src/graphs.py index fdead15..3ab8d5e 100644 --- a/src/graphs.py +++ b/src/graphs.py @@ -1,14 +1,15 @@ """Classes to create bulk, surface, and adsorbate graphs.""" import abc + import networkx as nx import numpy as np - from ase.neighborlist import build_neighbor_list class AtomsGraph: """Create graph representation of a collection of atoms.""" + def __init__(self, atoms, select_idx, max_atoms=50): """Initialize variables of the class. @@ -64,7 +65,7 @@ def create_graph(self): nn, atomic_number=self.atoms[nn].number, symbol=self.atoms[nn].symbol, - position=self.atoms[nn].position + position=self.atoms[nn].position, ) # Calculate bond distance bond_dist = np.linalg.norm( @@ -74,7 +75,7 @@ def create_graph(self): # Assign graph object self.graph = graph - + def featurize(self, node_featurizer, bond_featurizer): """Featurize nodes and edges of the graph. @@ -89,7 +90,7 @@ def featurize(self, node_featurizer, bond_featurizer): def plot(self, filename=None): """Plot the graph using NetworkX. - + Parameters ---------- filename: str (optional) @@ -99,21 +100,20 @@ def plot(self, filename=None): def get_node_tensor(self): """Get the node matrix of the graph as a PyTorch Tensor. - + Returns ------- node_matrix: torch.Tensor Node matrix """ pass - + def get_edge_tensor(self): """Get the edge matrix of the graph as a PyTorch Tensor. - + Returns ------- edge_matrix: torch.Tensor Edge tensor """ pass -