Skip to content

Commit

Permalink
Added featurizers
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav S Deshmukh committed Sep 11, 2023
1 parent cb3ba9c commit 916580b
Show file tree
Hide file tree
Showing 5 changed files with 320 additions and 31 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
*POSCAR*
*.csv
!data/dband_centers.csv
__pycache__
23 changes: 23 additions & 0 deletions data/dband_centers.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
22, 1.50
23, 1.06
24, 0.16
25, 0.07
26, -0.92
27, -1.17
28, -1.29
29, -2.67
40, 1.95
41, 1.41
42, 0.35
43, -0.60
44, -1.41
45, -1.73
46, -1.83
47, -4.30
72, 2.47
73, 2.00
74, 0.77
75, -0.51
77, -2.11
78, -2.25
79, -3.56
7 changes: 7 additions & 0 deletions src/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Constants are defined here."""

import pathlib

# Repository root: the parent of the directory that contains this file.
REPO_PATH = pathlib.Path(__file__).parents[1]

# Location of the d-band center table inside the repository's data folder.
DBAND_FILE_PATH = REPO_PATH.joinpath("data", "dband_centers.csv")
287 changes: 287 additions & 0 deletions src/featurizers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
"""Node and bond featurizers."""

import abc
import csv

import numpy as np
import networkx as nx
import torch

from torch.nn.functional import one_hot
from mendeleev import element

from constants import DBAND_FILE_PATH

class OneHotEncoder:
    """Featurize a property using a one-hot encoding scheme.

    The closed range [min, max] is split into ``n_intervals`` equal-width
    bins and every value is encoded as a one-hot vector marking its bin.
    """

    def __init__(self):
        """Blank constructor; call ``fit`` before ``transform``."""

    # NOTE: the parameter names shadow builtins, but they are part of the
    # public interface (callers may pass them by keyword), so they are kept.
    def fit(self, min, max, n_intervals):
        """Fit encoder based on min, max, and number of intervals parameters.

        Parameters
        ----------
        min: int
            Minimum possible value of the property.
        max: int
            Maximum possible value of the property.
        n_intervals: int
            Number of elements in the one-hot encoded array.

        Raises
        ------
        ValueError
            If ``n_intervals`` is smaller than 1 or ``max`` is not strictly
            greater than ``min`` (the scaling would divide by zero or
            produce meaningless bins).
        """
        if n_intervals < 1:
            raise ValueError("n_intervals must be at least 1.")
        if max <= min:
            raise ValueError("max must be strictly greater than min.")

        self.min = min
        self.max = max
        self.n_intervals = n_intervals

    def transform(self, property):
        """Transform a given property vector/matrix/tensor.

        Parameters
        ----------
        property: list or np.ndarray or torch.Tensor
            Tensor containing value(s) of the property to be transformed. The
            tensor must have a shape of N where N is the number of atoms.

        Returns
        -------
        onehot_prop: torch.Tensor
            Integer tensor with one extra trailing dimension of size
            ``n_intervals`` holding the one-hot encoding of each value.

        Raises
        ------
        RuntimeError
            Raised by ``one_hot`` when a value lies outside [min, max].
        """
        # as_tensor (unlike the legacy torch.Tensor constructor) never
        # interprets a bare integer as a size and accepts any numeric dtype.
        values = torch.as_tensor(property, dtype=torch.float32)

        # Scale between 0 and n_intervals, then take the bin index.
        scaled = ((values - self.min)
                  / (self.max - self.min)) * self.n_intervals
        bin_idx = torch.floor(scaled).to(torch.int64)

        # A value equal to max scales to exactly n_intervals, one past the
        # last valid class. Fold such in-range values into the last bin;
        # values beyond max are left untouched so one_hot still rejects them.
        last_bin = self.n_intervals - 1
        at_upper_edge = (values <= self.max) & (bin_idx > last_bin)
        bin_idx = torch.where(at_upper_edge,
                              torch.full_like(bin_idx, last_bin),
                              bin_idx)

        return one_hot(bin_idx, num_classes=self.n_intervals)

class Featurizer(abc.ABC):
    """Abstract base class defining the interface of featurizers."""

    @abc.abstractmethod
    def __init__(self, encoder):
        """Initialize class variables and fit encoder.

        Parameters
        ----------
        encoder: OneHotEncoder
            Initialized object of class OneHotEncoder.
        """

    @abc.abstractmethod
    def featurize_graph(self, graph):
        """Featurize an AtomsGraph.

        This method should create a feature tensor from the given graph. This
        feature tensor should have a shape of (N, M) where N = number of atoms
        and M = n_intervals provided to the encoder. The feature tensor should
        be saved as self._feat_tensor.

        Parameters
        ----------
        graph: AtomsGraph
            A graph of a collection of bulk, surface, or adsorbate atoms.
        """

    # abc.abstractproperty is deprecated; stacking property on top of
    # abstractmethod is the supported spelling.
    @property
    @abc.abstractmethod
    def feat_tensor(self):
        """Return the featurized node tensor.

        Returns
        -------
        feat_tensor: torch.Tensor
            Featurized tensor having shape (N, M) where N = number of atoms and
            M = n_intervals provided to the encoder
        """

    # abc.abstractstaticmethod is deprecated as well. A static method takes
    # no ``self``, which also matches how the concrete featurizers define it.
    @staticmethod
    @abc.abstractmethod
    def name():
        """Return the name of the featurizer.

        Returns
        -------
        name: str
            Name of the featurizer.
        """
        return "abstract_featurizer"


class AtomNumFeaturizer(Featurizer):
    """Featurize nodes based on atomic number."""

    def __init__(self, encoder):
        """Initialize featurizer with min = 1, max = 80, n_intervals = 10.

        Parameters
        ----------
        encoder: OneHotEncoder
            Initialized object of class OneHotEncoder.
        """
        # Initialize variables
        self.min = 1
        self.max = 80
        self.n_intervals = 10

        # Fit encoder
        self.encoder = encoder
        self.encoder.fit(self.min, self.max, self.n_intervals)

    def featurize_graph(self, graph):
        """Featurize an AtomsGraph.

        The result is stored on the instance and exposed via ``feat_tensor``.

        Parameters
        ----------
        graph: AtomsGraph
            A graph of a collection of bulk, surface, or adsorbate atoms.
        """
        # Get atomic numbers. The dict view must be materialized first:
        # np.array(dict.values()) would yield a 0-d object array wrapping the
        # view instead of a numeric array.
        atom_num_dict = nx.get_node_attributes(graph, "atomic_number")
        atom_num_arr = np.array(list(atom_num_dict.values()), dtype=float)

        # Create node feature matrix
        self._feat_tensor = self.encoder.transform(atom_num_arr)

    @property
    def feat_tensor(self):
        """Return the featurized node tensor.

        Returns
        -------
        feat_tensor: torch.Tensor
            Featurized tensor having shape (N, M) where N = number of atoms and
            M = n_intervals provided to the encoder
        """
        return self._feat_tensor

    @staticmethod
    def name():
        """Return the name of the featurizer."""
        return "atomic_number"

class DBandFeaturizer(Featurizer):
    """Featurize nodes based on close-packed d-band center."""

    def __init__(self, encoder):
        """Initialize featurizer with min = -5, max = 3, n_intervals = 10.

        Also loads the atomic-number -> d-band-center table from
        ``DBAND_FILE_PATH`` into ``self.map_dict``.

        Parameters
        ----------
        encoder: OneHotEncoder
            Initialized object of class OneHotEncoder.
        """
        # Initialize variables
        self.min = -5
        self.max = 3
        self.n_intervals = 10

        # Fit encoder
        self.encoder = encoder
        self.encoder.fit(self.min, self.max, self.n_intervals)

        # Get dband centers from csv. The csv module yields strings, so both
        # columns are converted: keys to int so lookups by atomic number
        # succeed, values to float so they can be encoded numerically
        # (float() tolerates the space that follows the comma in the file).
        self.map_dict = {}
        with open(DBAND_FILE_PATH, "r", newline="") as f:
            for row in csv.reader(f):
                if not row:
                    # Skip blank lines, e.g. a trailing newline.
                    continue
                self.map_dict[int(row[0])] = float(row[1])

    def featurize_graph(self, graph):
        """Featurize an AtomsGraph.

        The result is stored on the instance and exposed via ``feat_tensor``.

        Parameters
        ----------
        graph: AtomsGraph
            A graph of a collection of bulk, surface, or adsorbate atoms.

        Raises
        ------
        KeyError
            If a node's atomic number has no entry in the d-band table.
        """
        # Get atomic numbers
        atom_num_dict = nx.get_node_attributes(graph, "atomic_number")

        # Map from atomic number to d-band center. A comprehension over the
        # materialized values replaces np.array(dict.values()), which wraps
        # the view in a 0-d object array, and np.vectorize, which fails on
        # empty input.
        dband_arr = np.array(
            [self.map_dict[int(z)] for z in atom_num_dict.values()],
            dtype=float,
        )

        # Create node feature matrix
        self._feat_tensor = self.encoder.transform(dband_arr)

    @property
    def feat_tensor(self):
        """Return the featurized node tensor.

        Returns
        -------
        feat_tensor: torch.Tensor
            Featurized tensor having shape (N, M) where N = number of atoms and
            M = n_intervals provided to the encoder
        """
        return self._feat_tensor

    @staticmethod
    def name():
        """Return the name of the featurizer."""
        return "dband_center"

class ValenceFeaturizer(Featurizer):
    """Featurize nodes based on number of valence electrons."""

    def __init__(self, encoder):
        """Initialize featurizer with min = 1, max = 12, n_intervals = 12.

        Parameters
        ----------
        encoder: OneHotEncoder
            Initialized object of class OneHotEncoder.
        """
        # Initialize variables
        self.min = 1
        self.max = 12
        self.n_intervals = 12

        # Fit encoder
        self.encoder = encoder
        self.encoder.fit(self.min, self.max, self.n_intervals)

        # Create a map between atomic number and number of valence electrons.
        # H and He are hard-coded; Z = 3..20 are precomputed with mendeleev.
        # NOTE(review): He maps to 0, which is below self.min = 1 -- confirm
        # how the encoder is expected to treat it.
        self.map_dict = {1: 1, 2: 0}
        for i in range(3, 21):
            self.map_dict[i] = element(i).ec.get_valence().ne()

    def _valence_electrons(self, atomic_number):
        """Return the valence electron count for one atomic number.

        Elements missing from ``self.map_dict`` are looked up with mendeleev
        on first use and cached, so atoms heavier than Z = 20 do not raise
        KeyError.
        """
        z = int(atomic_number)
        if z not in self.map_dict:
            self.map_dict[z] = element(z).ec.get_valence().ne()
        return self.map_dict[z]

    def featurize_graph(self, graph):
        """Featurize an AtomsGraph.

        The result is stored on the instance and exposed via ``feat_tensor``.

        Parameters
        ----------
        graph: AtomsGraph
            A graph of a collection of bulk, surface, or adsorbate atoms.
        """
        # Get atomic numbers
        atom_num_dict = nx.get_node_attributes(graph, "atomic_number")

        # Map atomic number to number of valence electrons. Previously the
        # raw atomic numbers were encoded and map_dict was never consulted.
        valence_arr = np.array(
            [self._valence_electrons(z) for z in atom_num_dict.values()],
            dtype=float,
        )

        # Create node feature matrix
        self._feat_tensor = self.encoder.transform(valence_arr)

    @property
    def feat_tensor(self):
        """Return the featurized node tensor.

        Returns
        -------
        feat_tensor: torch.Tensor
            Featurized tensor having shape (N, M) where N = number of atoms and
            M = n_intervals provided to the encoder
        """
        return self._feat_tensor

    @staticmethod
    def name():
        """Return the name of the featurizer."""
        return "valence"

if __name__ == "__main__":
    # Manual smoke check: encode five sample values with a 5-bin encoder
    # fitted on the range 1..6 and print the resulting one-hot tensor.
    encoder = OneHotEncoder()
    encoder.fit(1, 6, 5)
    sample_values = np.array([1.5, 2.5, 3.5, 4.5, 5.5])
    print(encoder.transform(sample_values))
31 changes: 0 additions & 31 deletions src/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,6 @@ def create_graph(self):
# Assign graph object
self.graph = graph

def featurize(self, node_featurizer, bond_featurizer):
    """Featurize the nodes and edges of the graph (not implemented yet).

    Parameters
    ----------
    node_featurizer: TODO
        Object that featurizes atoms
    bond_featurizer: TODO
        Object that featurizes bonds
    """
    # Placeholder: no featurization is performed.
    return None

def plot(self, filename=None):
"""Plot the graph using NetworkX.
Expand All @@ -98,22 +86,3 @@ def plot(self, filename=None):
"""
pass

def get_node_tensor(self):
    """Get the graph's node matrix as a PyTorch Tensor (not implemented yet).

    Returns
    -------
    node_matrix: torch.Tensor
        Node matrix
    """
    # Placeholder: currently returns None instead of a tensor.
    return None

def get_edge_tensor(self):
    """Get the graph's edge matrix as a PyTorch Tensor (not implemented yet).

    Returns
    -------
    edge_matrix: torch.Tensor
        Edge tensor
    """
    # Placeholder: currently returns None instead of a tensor.
    return None

0 comments on commit 916580b

Please sign in to comment.