diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..092dc51 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +*POSCAR* diff --git a/src/graphs.py b/src/graphs.py new file mode 100644 index 0000000..3ab8d5e --- /dev/null +++ b/src/graphs.py @@ -0,0 +1,119 @@ +"""Classes to create bulk, surface, and adsorbate graphs.""" + +import abc + +import networkx as nx +import numpy as np +from ase.neighborlist import build_neighbor_list + + +class AtomsGraph: + """Create graph representation of a collection of atoms.""" + + def __init__(self, atoms, select_idx, max_atoms=50): + """Initialize variables of the class. + + Parameters + ---------- + atoms: ase.Atoms object + Atoms object containing all the atoms in the slab + select_idx: list or np.ndarray + List of indices of atoms that are to be included in the graph + neighbor_list: ase.neighborlist.NeighborList object + Neighbor list that defines bonds between atoms + max_atoms: int (default = 50) + The maximum number of atoms in the graph. Graphs that have fewer + atoms are padded with 0s to reach this value. + """ + # Save parameters + self.atoms = atoms + self.select_idx = select_idx + self.max_atoms = max_atoms + + # Create graph + self.create_graph() + + def create_graph(self): + """Create a graph from an atoms object and neighbor_list.""" + # Create neighbor list of atoms + self.neighbor_list = build_neighbor_list( + self.atoms, bothways=True, self_interaction=False + ) + + # Create NetworkX Multigraph + graph = nx.MultiGraph() + + # Iterate over selected atoms and add them as nodes + for atom in self.atoms: + if atom.index in self.select_idx and atom.index not in list(graph.nodes()): + graph.add_node( + atom.index, + atomic_number=atom.number, + symbol=atom.symbol, + position=atom.position, + ) + + # Iterate over nodes, identify neighbors, and add edges between them + for n in graph.nodes(): + # Get neighbors from neighbor list + neighbor_idx, _ = self.neighbor_list.get_neighbors(n) + # Iterate over neighbors + for nn in neighbor_idx: + # If neighbor is not in graph, add it as a node + if not graph.has_node(nn): + graph.add_node( + nn, + atomic_number=self.atoms[nn].number, + symbol=self.atoms[nn].symbol, + position=self.atoms[nn].position, + ) + # Calculate bond distance + bond_dist = np.linalg.norm( + graph.nodes[n].position - graph.nodes[nn].position + ) + graph.add_edge(n, nn, bond_distance=bond_dist) + + # Assign graph object + self.graph = graph + + def featurize(self, node_featurizer, bond_featurizer): + """Featurize nodes and edges of the graph. + + Parameters + ---------- + node_featurizer: TODO + Object that featurizes atoms + bond_featurizer: TODO + Object that featurizes bonds + """ + pass + + def plot(self, filename=None): + """Plot the graph using NetworkX. + + Parameters + ---------- + filename: str (optional) + If provided, the plot is saved with the given filename. + """ + pass + + def get_node_tensor(self): + """Get the node matrix of the graph as a PyTorch Tensor. + + Returns + ------- + node_matrix: torch.Tensor + Node matrix + """ + pass + + def get_edge_tensor(self): + """Get the edge matrix of the graph as a PyTorch Tensor. + + Returns + ------- + edge_matrix: torch.Tensor + Edge tensor + """ + pass