Skip to content

Commit

Permalink
Fix codestyle
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav S Deshmukh committed Sep 13, 2023
1 parent 2347f03 commit 119e5c7
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 54 deletions.
2 changes: 1 addition & 1 deletion src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

REPO_PATH = pathlib.Path(__file__).parents[1]

DBAND_FILE_PATH = REPO_PATH / "data" / "dband_centers.csv"
DBAND_FILE_PATH = REPO_PATH / "data" / "dband_centers.csv"
85 changes: 51 additions & 34 deletions src/featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
import abc
import csv

import numpy as np
import networkx as nx
import numpy as np
import torch

from torch.nn.functional import one_hot
from mendeleev import element
from torch.nn.functional import one_hot

from constants import DBAND_FILE_PATH
from graphs import AtomsGraph


class OneHotEncoder:
"""Featurize a property using a one-hot encoding scheme."""

def __init__(self):
"""Blank constructor."""
pass
Expand All @@ -37,7 +38,7 @@ def fit(self, min, max, n_intervals):

def transform(self, property):
"""Transform a given property vector/matrix/tensor.
Parameters
----------
property: list or np.ndarray or torch.Tensor
Expand All @@ -48,24 +49,24 @@ def transform(self, property):
property = torch.Tensor(property)

# Scale between 0 and num_intervals
scaled_prop = ((property - self.min)
/ (self.max - self.min)) * self.n_intervals

scaled_prop = ((property - self.min) / (self.max - self.min)) * self.n_intervals

# Apply floor function
floor_prop = torch.floor(scaled_prop)

# Create onehot encoding
onehot_prop = one_hot(floor_prop.to(torch.int64),
num_classes=self.n_intervals)

onehot_prop = one_hot(floor_prop.to(torch.int64), num_classes=self.n_intervals)

return onehot_prop


class Featurizer(abc.ABC):
"""Meta class for defining featurizers."""

@abc.abstractmethod
def __init__(self, encoder):
"""Initialize class variables and fit encoder.
Parameters
----------
encoder: OneHotEncoder
Expand All @@ -92,19 +93,19 @@ def featurize_graph(self, atoms_graph):
@abc.abstractproperty
def feat_tensor(self):
    """Featurized node tensor (abstract; concrete featurizers override it).

    Returns
    -------
    feat_tensor: torch.Tensor
        Featurized tensor having shape (N, M) where N = number of atoms and
        M = n_intervals provided to the encoder
    """
    return None

@abc.abstractstaticmethod
def name(self):
"""Return the name of the featurizer.
Returns
-------
_name = str
Expand All @@ -115,6 +116,7 @@ def name(self):

class AtomNumFeaturizer(Featurizer):
"""Featurize nodes based on atomic number."""

def __init__(self, encoder, min=0, max=80, n_intervals=10):
"""Initialize featurizer with min = 0, max = 80, n_intervals = 10.
Expand Down Expand Up @@ -156,22 +158,24 @@ def featurize_graph(self, atoms_graph):
@property
def feat_tensor(self):
    """Give back the node tensor stored on this featurizer.

    Returns
    -------
    feat_tensor: torch.Tensor
        Featurized tensor having shape (N, M) where N = number of atoms and
        M = n_intervals provided to the encoder
    """
    stored_tensor = self._feat_tensor
    return stored_tensor

@staticmethod
def name():
    """Identify this featurizer by its name string, "atomic_number"."""
    featurizer_name = "atomic_number"
    return featurizer_name


class DBandFeaturizer(Featurizer):
"""Featurize nodes based on close-packed d-band center."""

def __init__(self, encoder, min=-5, max=3, n_intervals=10):
"""Initialize featurizer with min = -5, max = 3, n_intervals = 10.
Expand Down Expand Up @@ -223,7 +227,7 @@ def featurize_graph(self, atoms_graph):
@property
def feat_tensor(self):
"""Return the featurized node tensor.
Returns
-------
feat_tensor: torch.Tensor
Expand All @@ -236,9 +240,11 @@ def feat_tensor(self):
def name():
    """Identify this featurizer by its name string, "dband_center"."""
    featurizer_name = "dband_center"
    return featurizer_name



class ValenceFeaturizer(Featurizer):
"""Featurize nodes based on number of valence electrons."""

def __init__(self, encoder, min=1, max=12, n_intervals=12):
"""Initialize featurizer with min = 1, max = 12, n_intervals = 12.
Expand All @@ -263,7 +269,7 @@ def __init__(self, encoder, min=1, max=12, n_intervals=12):
self.encoder.fit(self.min, self.max, self.n_intervals)

# Create a map between atomic number and number of valence electrons
self.map_dict = {0: 0, 1: 1, 2:0}
self.map_dict = {0: 0, 1: 1, 2: 0}
for i in range(3, 21, 1):
self.map_dict[i] = min(element(i).ec.get_valence().ne(), 12)

Expand All @@ -285,7 +291,7 @@ def featurize_graph(self, atoms_graph):
@property
def feat_tensor(self):
"""Return the featurized node tensor.
Returns
-------
feat_tensor: torch.Tensor
Expand All @@ -298,9 +304,11 @@ def feat_tensor(self):
def name():
    """Identify this featurizer by its name string, "valence"."""
    featurizer_name = "valence"
    return featurizer_name



class CoordinationFeaturizer(Featurizer):
"""Featurize nodes based on coordination number."""

def __init__(self, encoder, min=1, max=15, n_intervals=15):
"""Initialize featurizer with min = 1, max = 15, n_intervals = 15.
Expand Down Expand Up @@ -342,7 +350,7 @@ def featurize_graph(self, atoms_graph):
@property
def feat_tensor(self):
"""Return the featurized node tensor.
Returns
-------
feat_tensor: torch.Tensor
Expand All @@ -355,9 +363,11 @@ def feat_tensor(self):
def name():
    """Identify this featurizer by its name string, "coordination"."""
    featurizer_name = "coordination"
    return featurizer_name



class BondDistanceFeaturizer(Featurizer):
"""Featurize edges based on bond distance."""

def __init__(self, encoder, min, max, n_intervals):
"""Initialize bond distance featurizer.
Expand Down Expand Up @@ -395,22 +405,22 @@ def featurize_graph(self, atoms_graph):

# Create node feature matrix
self._feat_tensor = self.encoder.transform(bond_dist_arr)

# Create list of edge indices
self._edge_indices = torch.Tensor(list(atoms_graph.graph.edges()))

@property
def feat_tensor(self):
    """Return the featurized bond-distance tensor.

    This is the tensor stored by ``featurize_graph``, i.e. the encoder's
    transform of the bond distance array, so it describes bonds rather than
    atoms.

    Returns
    -------
    feat_tensor: torch.Tensor
        Featurized tensor having shape (N, M) where N = number of bond
        distances encoded (presumably one per graph edge -- confirm against
        featurize_graph) and M = n_intervals provided to the encoder
    """
    return self._feat_tensor

@property
def edge_indices(self):
"""Return list of edge indices.
Expand All @@ -426,48 +436,55 @@ def edge_indices(self):
def name():
    """Return the name of the featurizer.

    Returns
    -------
    _name: str
        The string "bond_distance". The previous value, "valence", was the
        same string returned by ValenceFeaturizer.name(), so the two
        featurizers could not be told apart by name.
    """
    return "bond_distance"



class BulkBondDistanceFeaturizer(BondDistanceFeaturizer):
    """Bond distance featurizer preconfigured for bulk atoms.

    Subclass of BondDistanceFeaturizer whose default binning values are
    min = 0, max = 8, and n_intervals = 8.
    """

    def __init__(self, encoder, min=0, max=8, n_intervals=8):
        """Pass the encoder and binning settings on to the parent class."""
        bin_settings = {"min": min, "max": max, "n_intervals": n_intervals}
        super().__init__(encoder, **bin_settings)


class SurfaceBondDistanceFeaturizer(BondDistanceFeaturizer):
    """Featurize surface bond distances.

    Child class of BondDistanceFeaturizer with suitable min, max, and n_interval
    values initialized for surface atoms. The values are: min = 0, max = 5,
    n_intervals = 10.
    """

    def __init__(self, encoder, min=0, max=5, n_intervals=10):
        """Initialize the parent featurizer with the surface defaults."""
        super().__init__(encoder, min=min, max=max, n_intervals=n_intervals)


class AdsorbateBondDistanceFeaturizer(BondDistanceFeaturizer):
    """Featurize adsorbate bond distances.

    Child class of BondDistanceFeaturizer with suitable min, max, and n_interval
    values initialized for adsorbate atoms. The values are: min = 0, max = 4,
    n_intervals = 16.
    """

    def __init__(self, encoder, min=0, max=4, n_intervals=16):
        """Initialize the parent featurizer with the adsorbate defaults."""
        super().__init__(encoder, min=min, max=max, n_intervals=n_intervals)


if __name__ == "__main__":
from ase.io import read

atoms = read("CONTCAR")
g = AtomsGraph(atoms, select_idx=[1, 10, 11, 12])

anf = AtomNumFeaturizer(OneHotEncoder())
anf.featurize_graph(g)
print(anf.feat_tensor.shape)

bdf = BulkBondDistanceFeaturizer(OneHotEncoder())
bdf.featurize_graph(g)
print(bdf.feat_tensor)
print(bdf.edge_indices)
print(bdf.edge_indices)
34 changes: 15 additions & 19 deletions src/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import networkx as nx
import numpy as np
from ase.neighborlist import (build_neighbor_list, natural_cutoffs,
NewPrimitiveNeighborList)
from ase.neighborlist import (NewPrimitiveNeighborList, build_neighbor_list,
natural_cutoffs)


class AtomsGraph:
Expand Down Expand Up @@ -43,8 +43,9 @@ def create_graph(self):
self.node_count = 0
self.map_idx_node = {}
for atom in self.atoms:
if (atom.index in self.select_idx
and atom.index not in list(graph.nodes(data="index"))):
if atom.index in self.select_idx and atom.index not in list(
graph.nodes(data="index")
):
graph.add_node(
self.node_count,
index=atom.index,
Expand All @@ -60,7 +61,7 @@ def create_graph(self):
natural_cutoffs(self.atoms),
bothways=True,
self_interaction=False,
primitive=NewPrimitiveNeighborList
primitive=NewPrimitiveNeighborList,
)

# Iterate over nodes, identify neighbors, and add edges between them
Expand Down Expand Up @@ -101,9 +102,7 @@ def create_graph(self):
self.atoms[nn].position,
offset,
)
graph.add_edge(
n, self.map_idx_node[nn], bond_distance=bond_dist
)
graph.add_edge(n, self.map_idx_node[nn], bond_distance=bond_dist)

# Pad graph
graph = self.pad_graph(graph)
Expand All @@ -120,12 +119,12 @@ def pad_graph(self, graph):
This can be used to make sure that the number of nodes in each graph is
equal to max_atoms
Parameters
----------
graph: Networkx.Graph
A Networkx graph
Returns
-------
padded_graph: Networkx.Graph
Expand All @@ -135,11 +134,7 @@ def pad_graph(self, graph):

for i in range(self.node_count, self.max_atoms, 1):
padded_graph.add_node(
i,
index=-1,
atomic_number=0,
symbol="",
position=np.zeros(3)
i, index=-1, atomic_number=0, symbol="", position=np.zeros(3)
)

return padded_graph
Expand All @@ -159,13 +154,13 @@ def calc_minimum_distance(self, pos_1, pos_2, offset):
"""
# First, calculate the distance without offset
dist_1 = np.linalg.norm(pos_1 - pos_2)

# Next calculate the distance by applying offset to second position
dist_2 = np.linalg.norm(pos_1 - (pos_2 + offset @ self.atoms.get_cell()))

# Get minimum distance
min_dist = min(dist_1, dist_2)

return min_dist

def plot(self, filename=None):
Expand All @@ -178,10 +173,11 @@ def plot(self, filename=None):
"""
pass


if __name__ == "__main__":
from ase.io import read

atoms = read("CONTCAR")
g = AtomsGraph(atoms, select_idx=[1, 10, 11, 12])
print(g.map_idx_node)
print(g.graph.edges(data="bond_distance"))
print(g.graph.edges(data="bond_distance"))

0 comments on commit 119e5c7

Please sign in to comment.