ENH: Add workflow to train model #20

Merged: 5 commits, Sep 25, 2023
3 changes: 3 additions & 0 deletions .gitignore
@@ -5,3 +5,6 @@ data/*
!data/dband_centers.csv
__pycache__
*.cif
trained_models
*.pt
*__init__.py
41 changes: 12 additions & 29 deletions src/data.py
@@ -10,9 +10,9 @@
from ase.io import read
from torch_geometric.data import Data, Dataset

from constants import REPO_PATH
from featurizers import OneHotEncoder
from utils import featurize_atoms, partition_structure
from .constants import REPO_PATH
from .featurizers import OneHotEncoder
from .utils import featurize_atoms, partition_structure_by_layers


class AtomsDataset(Dataset):
@@ -60,7 +60,7 @@ def __init__(self, root, prop_csv)

def process_data(
self,
z_cutoffs,
layer_cutoffs,
node_features,
edge_features,
max_atoms=None,
@@ -75,8 +75,8 @@

Parameters
----------
z_cutoffs: list or np.ndarray
List of z-coordinates based on which atomic structures are
layer_cutoffs: list or np.ndarray
List of layer cutoffs based on which atomic structures are
partitioned. The number of partitions is equal to one more than the
length of z_cutoffs.
node_features: list[list]
@@ -115,7 +115,7 @@
atoms = read(str(file_path))

# Partition structure
part_atoms = partition_structure(atoms, z_cutoffs)
part_atoms = partition_structure_by_layers(atoms, layer_cutoffs)

# Featurize partitions
data_objects = []
@@ -191,7 +191,7 @@ def __init__(self, atoms):

def process_data(
self,
z_cutoffs,
layer_cutoffs,
node_features,
edge_features,
max_atoms=None,
@@ -206,8 +206,8 @@

Parameters
----------
z_cutoffs: list or np.ndarray
List of z-coordinates based on which atomic structures are
layer_cutoffs: list or np.ndarray
List of layer cutoffs based on which atomic structures are
partitioned. The number of partitions is equal to one more than the
length of z_cutoffs.
node_features: list[list]
@@ -229,7 +229,7 @@
# Iterate over files and process them
for atoms_obj in self.atoms:
# Partition structure
part_atoms = partition_structure(atoms_obj, z_cutoffs)
part_atoms = partition_structure_by_layers(atoms_obj, layer_cutoffs)

# Featurize partitions
data_objects = []
@@ -337,28 +337,11 @@ def load_datapoints(atoms, process_dict):
data_root_path = Path(REPO_PATH) / "data" / "S_calcs"
prop_csv_path = data_root_path / "name_prop.csv"

# Create dataset
dataset = AtomsDataset(data_root_path, prop_csv_path)
# dataset.process_data(z_cutoffs=[13., 20.],
# node_features=[
# ["atomic_number", "dband_center"],
# ["atomic_number", "reactivity"],
# ["atomic_number", "reactivity"],
# ],
# edge_features=[
# ["bulk_bond_distance"],
# ["surface_bond_distance"],
# ["adsorbate_bond_distance"],
# ])
print(dataset[0][-1].x)
print(dataset.df_name_idx.head())
print(dataset[0][-1].name)

# Create datapoint
atoms = read(data_root_path / "Pt_3_Rh_9_-7-7-S.cif")
datapoint = AtomsDatapoints(atoms)
datapoint.process_data(
z_cutoffs=[13.0, 20.0],
layer_cutoffs=[3, 6],
node_features=[
["atomic_number", "dband_center"],
["atomic_number", "reactivity"],
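Renaming z_cutoffs to layer_cutoffs changes the partitioning interface from absolute z coordinates (e.g. [13.0, 20.0]) to layer indices (e.g. [3, 6]), which is why the __main__ example above now passes layer_cutoffs=[3, 6]. The helper partition_structure_by_layers is defined in src/utils.py and is not part of this diff; the sketch below is only an illustration of the assumed behaviour, namely binning atoms into layers by z coordinate and splitting at the given layer indices:

import numpy as np
from ase.build import add_adsorbate, fcc111

# Illustration only: the real logic lives in src/utils.py, which this diff does not show.
atoms = fcc111("Pt", size=(2, 2, 6), vacuum=10.0)      # 6-layer Pt slab
add_adsorbate(atoms, "S", height=1.8, position="fcc")  # adsorbate on top

# Layer index per atom: 0 for the lowest layer, increasing with z.
layer_ids = np.unique(np.round(atoms.positions[:, 2], 1), return_inverse=True)[1]

layer_cutoffs = [3, 6]                                  # three partitions
bounds = [0, *layer_cutoffs, layer_ids.max() + 1]
partitions = [
    atoms[np.flatnonzero((layer_ids >= lo) & (layer_ids < hi))]
    for lo, hi in zip(bounds[:-1], bounds[1:])
]
print([len(p) for p in partitions])                     # [12, 12, 1]

With layer_cutoffs=[3, 6] this yields three partitions (bulk, surface, adsorbate), matching the three node_features and edge_features lists passed to process_data.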
4 changes: 2 additions & 2 deletions src/featurizers.py
@@ -9,8 +9,8 @@
from mendeleev import element
from torch.nn.functional import one_hot

from constants import DBAND_FILE_PATH
from graphs import AtomsGraph
from .constants import DBAND_FILE_PATH
from .graphs import AtomsGraph


class OneHotEncoder:
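The only change here is the switch to package-relative imports (from .constants, from .graphs), mirroring the same change in data.py and train.py: the modules are now meant to be imported through the src package rather than run as loose scripts. A minimal sanity check, assuming the repository root is on sys.path and src/ is importable as a package:

# Assumes src/ is importable as a package from the repository root,
# e.g. modules are run as "python -m src.train" instead of "python src/train.py".
from src.featurizers import OneHotEncoder
from src.data import AtomsDataset

print(OneHotEncoder, AtomsDataset)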
22 changes: 12 additions & 10 deletions src/models.py
@@ -45,7 +45,7 @@ def __init__(self, partition_configs):

# Initialize layers
# Initial transform
self.init_transform = []
self.init_transform = nn.ModuleList()
for i in range(self.n_partitions):
self.init_transform.append(
nn.Sequential(
@@ -58,12 +58,12 @@ self.init_conv_layers()
self.init_conv_layers()

# Pooling layers
self.pool_layers = []
self.pool_layers = nn.ModuleList()
for i in range(self.n_partitions):
self.pool_layers.append(gnn.pool.global_add_pool)
self.pool_layers.append(gnn.aggr.SumAggregation())

# Pool transform
self.pool_transform = []
self.pool_transform = nn.ModuleList()
for i in range(self.n_partitions):
self.pool_transform.append(
nn.Sequential(
@@ -73,7 +73,7 @@
)

# Hidden layers
self.hidden_layers = []
self.hidden_layers = nn.ModuleList()
for i in range(self.n_partitions):
self.hidden_layers.append(
nn.Sequential(
@@ -98,10 +98,12 @@ def __init__(self, partition_configs):
self.final_lin_transform = nn.Linear(self.n_partitions, 1, bias=False)
with torch.no_grad():
self.final_lin_transform.weight.copy_(torch.ones(self.n_partitions))
for p in self.final_lin_transform.parameters():
p.requires_grad = False

def init_conv_layers(self):
"""Initialize convolutional layers."""
self.conv_layers = []
self.conv_layers = nn.ModuleList()
for i in range(self.n_partitions):
part_conv_layers = []
for j in range(self.n_conv[i]):
@@ -110,7 +112,7 @@ def init_conv_layers(self):
gnn.CGConv(
channels=self.conv_size[i],
dim=self.num_edge_features[i],
batch_norm=False,
batch_norm=True,
),
nn.LeakyReLU(inplace=True),
]
@@ -151,7 +153,7 @@ def forward(self, data_objects):
conv_data = layer(conv_data)

# Apply pooling layer
pooled_data = self.pool_layers[i](x=conv_data, batch=None)
pooled_data = self.pool_layers[i](x=conv_data, index=data.batch)

# Apply pool-to-hidden transform
hidden_data = self.pool_transform[i](pooled_data)
@@ -163,8 +165,8 @@
contributions.append(hidden_data)

# Apply final transformation
contributions = torch.cat(contributions, dim=-1)
output = self.final_lin_transform(contributions)
contributions = torch.cat(contributions)
output = self.final_lin_transform(contributions.view(-1, 3))

return {"output": output, "contributions": contributions}

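The key fix in models.py is storing the per-partition layer stacks in nn.ModuleList instead of plain Python lists: modules kept in a plain list are not registered with the parent module, so their parameters never reach model.parameters(), the optimizer, .cuda(), or state_dict(). A standalone illustration of the difference (not the PR's model):

import torch.nn as nn

class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(4, 4) for _ in range(3)]               # not registered

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(3))  # registered

print(sum(p.numel() for p in PlainList().parameters()))       # 0
print(sum(p.numel() for p in WithModuleList().parameters()))  # 60

The pooling change follows from the same constraint: the functional global_add_pool is replaced by the gnn.aggr.SumAggregation module so it can live inside a ModuleList, and it is now called with index=data.batch so the graphs in a mini-batch are pooled per graph instead of being summed together. The final all-ones linear layer is also frozen (requires_grad = False), keeping the prediction a plain sum of the partition contributions.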
81 changes: 65 additions & 16 deletions src/train.py
@@ -7,15 +7,24 @@
import torch
from sklearn.metrics import mean_absolute_error, mean_squared_error

from models import MultiGCN
from .models import MultiGCN


class Standardizer:
"""Class to standardize targets."""
def __init__(self, X):

def __init__(self):
"""
Class to standardize outputs.

Initialize with dummy values.
"""
self.mean = 0
self.std = 0.1

def initialize(self, X):
"""Initialize mean and std based on the given tensor.

Parameters
----------
X: torch.Tensor
@@ -60,6 +69,24 @@ def restore(self, Z):
X = self.mean + Z * self.std
return X

def restore_cont(self, Z):
"""
Restore a standardized contribution to the non-standardized contribution.

Parameters
----------
Z: torch.Tensor
Tensor of standardized contributions

Returns
-------
X: torch.Tensor
Tensor of non-standardized contributions

"""
X = (self.mean / Z.shape[0]) + Z * self.std
return X

def get_state(self):
"""
Return dictionary of the state of the Standardizer.
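The new restore_cont mirrors restore for the per-partition contributions: because the frozen final linear layer simply sums the standardized contributions, spreading the mean over the Z.shape[0] partitions keeps the restored contributions summing exactly to the restored total prediction. A quick standalone check of that identity:

import torch

mean, std = 2.5, 0.4
z = torch.tensor([0.1, -0.3, 0.5])              # standardized contributions, one sample
output = z.sum()                                # standardized prediction = sum of contributions

restored_output = mean + output * std           # restore()
restored_conts = mean / z.shape[0] + z * std    # restore_cont()
assert torch.isclose(restored_conts.sum(), restored_output)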
@@ -122,7 +149,8 @@ def __init__(self, global_config, partition_configs, model_path):

# Set GPU status
self.use_gpu = global_config["gpu"]

if self.use_gpu:
self.model.cuda()
# Set loss function
if global_config["loss_function"] == "mse":
self.loss_fn = torch.nn.MSELoss()
@@ -145,18 +173,21 @@ def __init__(self, global_config, partition_configs, model_path):
)
elif global_config["optimizer"].lower().strip() == "sgd":
self.optimizer = torch.optim.SGD(
self.model_parameters(),
self.model.parameters(),
lr=global_config["learning_rate"],
)

# Set scheduler
if "lr_milestones" in global_config.keys():
self.scheduler = torch.optim.MultiStepLR(
self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer=self.optimizer, milestones=global_config["lr_milestones"]
)
else:
self.scheduler = None

# Set standardizer
self.standardizer = Standardizer()

def make_directory_structure(self, model_path):
"""Make directory structure to store models and results."""
self.model_path = Path(model_path)
@@ -173,7 +204,7 @@ def init_standardizer(self, targets):
targets: np.ndarray or torch.Tensor
Array of training outputs
"""
self.standardizer = Standardizer(targets)
self.standardizer.initialize(torch.Tensor(targets))

def train_epoch(self, dataloader):
"""Train the model for a single epoch.
@@ -209,11 +240,11 @@ def train_epoch(self, dataloader):
pred_dict = self.model(nn_input)

# Calculate loss
loss = self.loss_fn(nn_output, pred_dict["output"])
loss = self.loss_fn(pred_dict["output"], nn_output.unsqueeze(1))
avg_loss += loss

# Calculate metric
y_pred = self.standardizer.restore(pred_dict["output"].cpu())
y_pred = self.standardizer.restore(pred_dict["output"].cpu().detach())
metric = self.metric_fn(y, y_pred)
avg_metric += metric

@@ -273,11 +304,11 @@ def validate(self, dataloader):
pred_dict = self.model(nn_input)

# Calculate loss
loss = self.loss_fn(nn_output, pred_dict["output"])
loss = self.loss_fn(pred_dict["output"], nn_output.unsqueeze(1))
avg_loss += loss

# Calculate metric
y_pred = self.standardizer.restore(pred_dict["output"].cpu())
y_pred = self.standardizer.restore(pred_dict["output"].cpu().detach())
metric = self.metric_fn(y, y_pred)
avg_metric += metric

@@ -305,20 +336,22 @@ def predict(self, dataset, indices, return_targets=False):
Returns
-------
prediction_dict: dict
Dictionary containing "targets", "predictions", and "indices" (copy of
predict_idx).
Dictionary containing "targets", "predictions", "contributions" and
"indices" (copy of predict_idx).
"""
# Create arrays
n_partitions = len(dataset.get(indices[0]))
targets = np.zeros(len(indices))
predictions = np.zeros(len(indices))
contributions = np.zeros((len(indices), n_partitions))

# Enable eval mode of model
self.model.eval()

# Go over each batch in the dataloader
for i, idx in enumerate(indices):
# Get data objects
data_objects = dataset.get(i)
data_objects = dataset.get(idx)

# Standardize output
if return_targets:
Expand All @@ -332,12 +365,19 @@ def predict(self, dataset, indices, return_targets=False):

# Compute prediction
pred_dict = self.model(nn_input)
predictions[i] = self.standardizer.restore(pred_dict["output"].cpu())
predictions[i] = self.standardizer.restore(
pred_dict["output"].cpu().detach()
)
conts_std = pred_dict["contributions"].cpu().detach()
contributions[i, :] = (
self.standardizer.restore_cont(conts_std).numpy().flatten()
)

predictions_dict = {
"targets": targets,
"predictions": predictions,
"indices": indices,
"contributions": contributions,
}

return predictions_dict
@@ -444,11 +484,20 @@ def train(self, epochs, dataloader_dict, verbose=False):
self.save(i, best_status)

# Save losses and metrics
train_losses.append(train_loss)
val_losses.append(val_loss)
train_losses.append(train_loss.cpu().detach().numpy())
val_losses.append(val_loss.cpu().detach().numpy())
train_metrics.append(train_metric)
val_metrics.append(val_metric)

# Print, if verbose
if verbose:
print(
f"Epoch: [{i}] Training loss: [{train_loss:.3f}] "
+ f"Training metric: [{train_metric:.3f}] "
+ f"Validation loss: [{val_loss:.3f}] "
+ f"Validation metric: [{val_metric:.3f}]"
)

# Load the best model
self.load(best_status=True)

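One subtle fix in train_epoch and validate is the target shape in the loss: the model output is (batch_size, 1) after final_lin_transform, and nn.MSELoss will broadcast a (batch_size,) target against it (with only a UserWarning), averaging over batch_size squared pairs instead of batch_size. Unsqueezing the target avoids that; a minimal illustration:

import torch

pred = torch.randn(8, 1)                    # model output shape
target = torch.randn(8)                     # raw target shape

loss_fn = torch.nn.MSELoss()
wrong = loss_fn(pred, target)               # broadcasts to (8, 8); warns, value is off
right = loss_fn(pred, target.unsqueeze(1))  # element-wise over the 8 samples
print(wrong.item(), right.item())

The added .cpu().detach() before restore() addresses a related practical problem: the restored predictions are handed to the sklearn metrics, which convert them to NumPy, and that conversion raises on tensors that still require grad.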