diff --git a/src/graphs.py b/src/graphs.py index fdead15..3ab8d5e 100644 --- a/src/graphs.py +++ b/src/graphs.py @@ -1,14 +1,15 @@ """Classes to create bulk, surface, and adsorbate graphs.""" import abc + import networkx as nx import numpy as np - from ase.neighborlist import build_neighbor_list class AtomsGraph: """Create graph representation of a collection of atoms.""" + def __init__(self, atoms, select_idx, max_atoms=50): """Initialize variables of the class. @@ -64,7 +65,7 @@ def create_graph(self): nn, atomic_number=self.atoms[nn].number, symbol=self.atoms[nn].symbol, - position=self.atoms[nn].position + position=self.atoms[nn].position, ) # Calculate bond distance bond_dist = np.linalg.norm( @@ -74,7 +75,7 @@ def create_graph(self): # Assign graph object self.graph = graph - + def featurize(self, node_featurizer, bond_featurizer): """Featurize nodes and edges of the graph. @@ -89,7 +90,7 @@ def featurize(self, node_featurizer, bond_featurizer): def plot(self, filename=None): """Plot the graph using NetworkX. - + Parameters ---------- filename: str (optional) @@ -99,21 +100,20 @@ def plot(self, filename=None): def get_node_tensor(self): """Get the node matrix of the graph as a PyTorch Tensor. - + Returns ------- node_matrix: torch.Tensor Node matrix """ pass - + def get_edge_tensor(self): """Get the edge matrix of the graph as a PyTorch Tensor. - + Returns ------- edge_matrix: torch.Tensor Edge tensor """ pass -