Merge pull request #20 from GreeleyGroup/enh/basic_workflow
ENH: Add workflow to train model
deshmukg authored Sep 25, 2023
2 parents b6b8292 + c58ba06 commit b86dcea
Showing 7 changed files with 242 additions and 66 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -5,3 +5,6 @@ data/*
!data/dband_centers.csv
__pycache__
*.cif
trained_models
*.pt
*__init__.py
41 changes: 12 additions & 29 deletions src/data.py
@@ -10,9 +10,9 @@
from ase.io import read
from torch_geometric.data import Data, Dataset

from constants import REPO_PATH
from featurizers import OneHotEncoder
from utils import featurize_atoms, partition_structure
from .constants import REPO_PATH
from .featurizers import OneHotEncoder
from .utils import featurize_atoms, partition_structure_by_layers


class AtomsDataset(Dataset):
@@ -60,7 +60,7 @@ def __init__(self, root, prop_csv):

def process_data(
self,
z_cutoffs,
layer_cutoffs,
node_features,
edge_features,
max_atoms=None,
@@ -75,8 +75,8 @@
Parameters
----------
z_cutoffs: list or np.ndarray
List of z-coordinates based on which atomic structures are
layer_cutoffs: list or np.ndarray
List of layer cutoffs based on which atomic structures are
partitioned. The number of partitions is equal to one more than the
length of z_cutoffs.
node_features: list[list]
@@ -115,7 +115,7 @@
atoms = read(str(file_path))

# Partition structure
part_atoms = partition_structure(atoms, z_cutoffs)
part_atoms = partition_structure_by_layers(atoms, layer_cutoffs)

# Featurize partitions
data_objects = []
@@ -191,7 +191,7 @@ def __init__(self, atoms):

def process_data(
self,
z_cutoffs,
layer_cutoffs,
node_features,
edge_features,
max_atoms=None,
@@ -206,8 +206,8 @@
Parameters
----------
z_cutoffs: list or np.ndarray
List of z-coordinates based on which atomic structures are
layer_cutoffs: list or np.ndarray
List of layer cutoffs based on which atomic structures are
partitioned. The number of partitions is equal to one more than the
length of z_cutoffs.
node_features: list[list]
@@ -229,7 +229,7 @@
# Iterate over files and process them
for atoms_obj in self.atoms:
# Partition structure
part_atoms = partition_structure(atoms_obj, z_cutoffs)
part_atoms = partition_structure_by_layers(atoms_obj, layer_cutoffs)

# Featurize partitions
data_objects = []
@@ -337,28 +337,11 @@ def load_datapoints(atoms, process_dict):
data_root_path = Path(REPO_PATH) / "data" / "S_calcs"
prop_csv_path = data_root_path / "name_prop.csv"

# Create dataset
dataset = AtomsDataset(data_root_path, prop_csv_path)
# dataset.process_data(z_cutoffs=[13., 20.],
# node_features=[
# ["atomic_number", "dband_center"],
# ["atomic_number", "reactivity"],
# ["atomic_number", "reactivity"],
# ],
# edge_features=[
# ["bulk_bond_distance"],
# ["surface_bond_distance"],
# ["adsorbate_bond_distance"],
# ])
print(dataset[0][-1].x)
print(dataset.df_name_idx.head())
print(dataset[0][-1].name)

# Create datapoint
atoms = read(data_root_path / "Pt_3_Rh_9_-7-7-S.cif")
datapoint = AtomsDatapoints(atoms)
datapoint.process_data(
z_cutoffs=[13.0, 20.0],
layer_cutoffs=[3, 6],
node_features=[
["atomic_number", "dband_center"],
["atomic_number", "reactivity"],
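The data.py changes above rename the z-coordinate cutoffs (z_cutoffs) to layer-based cutoffs (layer_cutoffs), rename partition_structure to partition_structure_by_layers, and switch to package-relative imports. As a rough guide to the updated call, here is a sketch of driving AtomsDatapoints.process_data from outside the module; the cutoffs and node features come from the example at the bottom of this diff, the edge features from the commented-out block it removes, and the src.* import path is an assumption about how the package is meant to be imported.

# Sketch only: mirrors the example usage visible in this diff; the src.* import
# path and the edge_features values are assumptions, not shown verbatim above.
from pathlib import Path

from ase.io import read

from src.constants import REPO_PATH
from src.data import AtomsDatapoints

data_root_path = Path(REPO_PATH) / "data" / "S_calcs"
atoms = read(data_root_path / "Pt_3_Rh_9_-7-7-S.cif")

datapoint = AtomsDatapoints(atoms)
datapoint.process_data(
    # 2 cutoffs -> 3 partitions (bulk / surface / adsorbate)
    layer_cutoffs=[3, 6],
    node_features=[
        ["atomic_number", "dband_center"],
        ["atomic_number", "reactivity"],
        ["atomic_number", "reactivity"],
    ],
    edge_features=[
        ["bulk_bond_distance"],
        ["surface_bond_distance"],
        ["adsorbate_bond_distance"],
    ],
)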
4 changes: 2 additions & 2 deletions src/featurizers.py
@@ -9,8 +9,8 @@
from mendeleev import element
from torch.nn.functional import one_hot

from constants import DBAND_FILE_PATH
from graphs import AtomsGraph
from .constants import DBAND_FILE_PATH
from .graphs import AtomsGraph


class OneHotEncoder:
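Both data.py and featurizers.py now use package-relative imports (from .constants import ... instead of from constants import ...). A practical consequence, assuming the modules live in an importable src package, is that they should be imported through the package rather than run as standalone scripts from inside src/:

# Assumed layout: <repo>/src/{constants,data,featurizers,graphs,models,train,utils}.py
from src.data import AtomsDataset
from src.featurizers import OneHotEncoder

# Running a file directly (e.g. `python src/train.py`) would now raise
# "ImportError: attempted relative import with no known parent package";
# `python -m src.train` from the repo root is the usual way around that.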
22 changes: 12 additions & 10 deletions src/models.py
@@ -45,7 +45,7 @@ def __init__(self, partition_configs):

# Initialize layers
# Initial transform
self.init_transform = []
self.init_transform = nn.ModuleList()
for i in range(self.n_partitions):
self.init_transform.append(
nn.Sequential(
@@ -58,12 +58,12 @@
self.init_conv_layers()

# Pooling layers
self.pool_layers = []
self.pool_layers = nn.ModuleList()
for i in range(self.n_partitions):
self.pool_layers.append(gnn.pool.global_add_pool)
self.pool_layers.append(gnn.aggr.SumAggregation())

# Pool transform
self.pool_transform = []
self.pool_transform = nn.ModuleList()
for i in range(self.n_partitions):
self.pool_transform.append(
nn.Sequential(
@@ -73,7 +73,7 @@
)

# Hidden layers
self.hidden_layers = []
self.hidden_layers = nn.ModuleList()
for i in range(self.n_partitions):
self.hidden_layers.append(
nn.Sequential(
@@ -98,10 +98,12 @@
self.final_lin_transform = nn.Linear(self.n_partitions, 1, bias=False)
with torch.no_grad():
self.final_lin_transform.weight.copy_(torch.ones(self.n_partitions))
for p in self.final_lin_transform.parameters():
p.requires_grad = False

def init_conv_layers(self):
"""Initialize convolutional layers."""
self.conv_layers = []
self.conv_layers = nn.ModuleList()
for i in range(self.n_partitions):
part_conv_layers = []
for j in range(self.n_conv[i]):
@@ -110,7 +112,7 @@ def init_conv_layers(self):
gnn.CGConv(
channels=self.conv_size[i],
dim=self.num_edge_features[i],
batch_norm=False,
batch_norm=True,
),
nn.LeakyReLU(inplace=True),
]
@@ -151,7 +153,7 @@ def forward(self, data_objects):
conv_data = layer(conv_data)

# Apply pooling layer
pooled_data = self.pool_layers[i](x=conv_data, batch=None)
pooled_data = self.pool_layers[i](x=conv_data, index=data.batch)

# Apply pool-to-hidden transform
hidden_data = self.pool_transform[i](pooled_data)
@@ -163,8 +165,8 @@
contributions.append(hidden_data)

# Apply final transformation
contributions = torch.cat(contributions, dim=-1)
output = self.final_lin_transform(contributions)
contributions = torch.cat(contributions)
output = self.final_lin_transform(contributions.view(-1, 3))

return {"output": output, "contributions": contributions}

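In models.py, the layer containers change from plain Python lists to nn.ModuleList, global_add_pool is replaced by a gnn.aggr.SumAggregation module fed the batch index, batch norm is enabled in the CGConv blocks, and the final linear combination of partition contributions is frozen. The ModuleList switch is the substantive fix: submodules kept in a plain list are invisible to PyTorch, so their parameters never reach model.parameters() (and hence the optimizer) and are not moved by model.cuda(). A small self-contained illustration, not code from this repo:

import torch.nn as nn


class ListVsModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.plain = [nn.Linear(4, 4)]                      # NOT registered as a submodule
        self.registered = nn.ModuleList([nn.Linear(4, 4)])  # registered properly


m = ListVsModuleList()
# Only the ModuleList's weight and bias are visible to PyTorch:
print(len(list(m.parameters())))  # 2, not 4
# state_dict(), .cuda(), and .to(device) likewise only see m.registered.

SumAggregation plays the same role as global_add_pool (summing node features per graph, here via index=data.batch), but as an nn.Module it can live inside a ModuleList and move devices with the rest of the model.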
81 changes: 65 additions & 16 deletions src/train.py
@@ -7,15 +7,24 @@
import torch
from sklearn.metrics import mean_absolute_error, mean_squared_error

from models import MultiGCN
from .models import MultiGCN


class Standardizer:
"""Class to standardize targets."""
def __init__(self, X):

def __init__(self):
"""
Class to standardize outputs.
Initialize with dummy values.
"""
self.mean = 0
self.std = 0.1

def initialize(self, X):
"""Initialize mean and std based on the given tensor.
Parameters
----------
X: torch.Tensor
@@ -60,6 +69,24 @@ def restore(self, Z):
X = self.mean + Z * self.std
return X

def restore_cont(self, Z):
"""
Restore a standardized contribution to the non-standardized contribution.
Parameters
----------
Z: torch.Tensor
Tensor of standardized contributions
Returns
-------
X: torch.Tensor
Tensor of non-standardized contributions
"""
X = (self.mean / Z.shape[0]) + Z * self.std
return X

def get_state(self):
"""
Return dictionary of the state of the Standardizer.
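The Standardizer is now built empty and filled in later via initialize(), and it gains restore_cont() for per-partition contributions. Dividing the mean by Z.shape[0] spreads the target mean evenly over the partitions, so the restored contributions sum back to the restored total prediction, given that the model's output is the plain sum of its contributions (which the frozen all-ones final_lin_transform in models.py enforces). A quick check of that identity (illustrative, not repo code):

import torch

mean, std = 2.5, 0.4
contrib_std = torch.tensor([0.1, -0.3, 0.7])  # standardized per-partition contributions
output_std = contrib_std.sum()                # model output = sum of the contributions

restored_total = mean + output_std * std                          # as in Standardizer.restore
restored_parts = mean / contrib_std.shape[0] + contrib_std * std  # as in Standardizer.restore_cont

assert torch.isclose(restored_parts.sum(), restored_total)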
@@ -122,7 +149,8 @@ def __init__(self, global_config, partition_configs, model_path):

# Set GPU status
self.use_gpu = global_config["gpu"]

if self.use_gpu:
self.model.cuda()
# Set loss function
if global_config["loss_function"] == "mse":
self.loss_fn = torch.nn.MSELoss()
@@ -145,18 +173,21 @@ def __init__(self, global_config, partition_configs, model_path):
)
elif global_config["optimizer"].lower().strip() == "sgd":
self.optimizer = torch.optim.SGD(
self.model_parameters(),
self.model.parameters(),
lr=global_config["learning_rate"],
)

# Set scheduler
if "lr_milestones" in global_config.keys():
self.scheduler = torch.optim.MultiStepLR(
self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer=self.optimizer, milestones=global_config["lr_milestones"]
)
else:
self.scheduler = None

# Set standardizer
self.standardizer = Standardizer()

def make_directory_structure(self, model_path):
"""Make directory structure to store models and results."""
self.model_path = Path(model_path)
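The Trainer now moves the model to the GPU when requested, fixes the SGD branch to use self.model.parameters(), points the scheduler at torch.optim.lr_scheduler.MultiStepLR (its actual namespace), and owns a Standardizer from construction. Based only on the keys this diff reads from global_config, a config might look roughly like the sketch below; the key names come from the code, the values are placeholders.

# Hypothetical global_config covering the keys referenced in the visible hunks.
global_config = {
    "gpu": False,                # when True, the model is moved to CUDA
    "loss_function": "mse",      # selects torch.nn.MSELoss
    "optimizer": "sgd",          # the other optimizer branch is not shown in this diff
    "learning_rate": 1e-3,
    "lr_milestones": [50, 100],  # optional; enables MultiStepLR when present
}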
@@ -173,7 +204,7 @@ def init_standardizer(self, targets):
targets: np.ndarray or torch.Tensor
Array of training outputs
"""
self.standardizer = Standardizer(targets)
self.standardizer.initialize(torch.Tensor(targets))

def train_epoch(self, dataloader):
"""Train the model for a single epoch.
@@ -209,11 +240,11 @@ def train_epoch(self, dataloader):
pred_dict = self.model(nn_input)

# Calculate loss
loss = self.loss_fn(nn_output, pred_dict["output"])
loss = self.loss_fn(pred_dict["output"], nn_output.unsqueeze(1))
avg_loss += loss

# Calculate metric
y_pred = self.standardizer.restore(pred_dict["output"].cpu())
y_pred = self.standardizer.restore(pred_dict["output"].cpu().detach())
metric = self.metric_fn(y, y_pred)
avg_metric += metric

@@ -273,11 +304,11 @@ def validate(self, dataloader):
pred_dict = self.model(nn_input)

# Calculate loss
loss = self.loss_fn(nn_output, pred_dict["output"])
loss = self.loss_fn(pred_dict["output"], nn_output.unsqueeze(1))
avg_loss += loss

# Calculate metric
y_pred = self.standardizer.restore(pred_dict["output"].cpu())
y_pred = self.standardizer.restore(pred_dict["output"].cpu().detach())
metric = self.metric_fn(y, y_pred)
avg_metric += metric

@@ -305,20 +336,22 @@ def predict(self, dataset, indices, return_targets=False):
Returns
-------
prediction_dict: dict
Dictionary containing "targets", "predictions", and "indices" (copy of
predict_idx).
Dictionary containing "targets", "predictions", "contributions" and
"indices" (copy of predict_idx).
"""
# Create arrays
n_partitions = len(dataset.get(indices[0]))
targets = np.zeros(len(indices))
predictions = np.zeros(len(indices))
contributions = np.zeros((len(indices), n_partitions))

# Enable eval mode of model
self.model.eval()

# Go over each batch in the dataloader
for i, idx in enumerate(indices):
# Get data objects
data_objects = dataset.get(i)
data_objects = dataset.get(idx)

# Standardize output
if return_targets:
@@ -332,12 +365,19 @@

# Compute prediction
pred_dict = self.model(nn_input)
predictions[i] = self.standardizer.restore(pred_dict["output"].cpu())
predictions[i] = self.standardizer.restore(
pred_dict["output"].cpu().detach()
)
conts_std = pred_dict["contributions"].cpu().detach()
contributions[i, :] = (
self.standardizer.restore_cont(conts_std).numpy().flatten()
)

predictions_dict = {
"targets": targets,
"predictions": predictions,
"indices": indices,
"contributions": contributions,
}

return predictions_dict
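predict() now also returns the per-partition contributions (restored with restore_cont) alongside the total predictions, and indexes the dataset with idx rather than the loop counter. A hedged usage sketch; only the returned keys and shapes come from the diff, and the trainer/dataset objects are assumed to exist already:

# Assuming `trainer` and `dataset` have been constructed as elsewhere in this diff.
pred = trainer.predict(dataset, indices=[0, 1, 2], return_targets=True)

pred["predictions"]    # shape (3,): restored total predictions
pred["targets"]        # shape (3,): targets (left as zeros when return_targets=False)
pred["contributions"]  # shape (3, n_partitions): per-partition contributions whose
                       # row-wise sums match the corresponding predictions
pred["indices"]        # copy of the indices passed in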
@@ -444,11 +484,20 @@ def train(self, epochs, dataloader_dict, verbose=False):
self.save(i, best_status)

# Save losses and metrics
train_losses.append(train_loss)
val_losses.append(val_loss)
train_losses.append(train_loss.cpu().detach().numpy())
val_losses.append(val_loss.cpu().detach().numpy())
train_metrics.append(train_metric)
val_metrics.append(val_metric)

# Print, if verbose
if verbose:
print(
f"Epoch: [{i}] Training loss: [{train_loss:.3f}] "
+ f"Training metric: [{train_metric:.3f}] "
+ f"Validation loss: [{val_loss:.3f}] "
+ f"Validation metric: [{val_metric:.3f}]"
)

# Load the best model
self.load(best_status=True)

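Tying the pieces together: the commit message says this adds a workflow to train the model, but the workflow file itself is not among the five files shown above (seven files changed in total). Based on the classes and methods touched in this diff, an end-to-end run could look roughly like the sketch below; the config contents, target array, dataloader construction, and output path are all placeholders, not code from this PR.

# Rough end-to-end sketch assembled from the pieces visible in this diff; not the
# actual workflow script added by this PR.
from pathlib import Path

from src.constants import REPO_PATH
from src.data import AtomsDataset
from src.train import Trainer

data_root = Path(REPO_PATH) / "data" / "S_calcs"
dataset = AtomsDataset(data_root, data_root / "name_prop.csv")
# dataset.process_data(layer_cutoffs=..., node_features=..., edge_features=...) as above

global_config = {...}      # e.g. the config sketch shown earlier
partition_configs = [...]  # per-partition model settings (structure not shown in this diff)

trainer = Trainer(global_config, partition_configs, model_path="trained_models/run_01")
trainer.init_standardizer(train_targets)  # placeholder: target values of the training split
trainer.train(epochs=200, dataloader_dict=dataloaders, verbose=True)  # placeholder loaders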
