Skip to content

Commit

Permalink
Fix codestyle
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav S Deshmukh committed Sep 24, 2023
1 parent 8875133 commit b212fae
Showing 1 changed file with 36 additions and 36 deletions.
72 changes: 36 additions & 36 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,32 @@

import numpy as np
import torch

from sklearn.metrics import mean_absolute_error, mean_squared_error

from models import MultiGCN


class Standardizer:
"""Class to standardize targets."""
def __init__(self, X):
"""
Class to standardize outputs.
Parameters
----------
X: torch.Tensor
X: torch.Tensor
Tensor of outputs
"""
self.mean = torch.mean(X)
self.std = torch.std(X)

def standardize(self, X):
"""
Convert a non-standardized output to a standardized output.
Parameters
----------
X: torch.Tensor
X: torch.Tensor
Tensor of non-standardized outputs
Returns
Expand All @@ -40,7 +41,7 @@ def standardize(self, X):
"""
Z = (X - self.mean) / (self.std)
return Z

def restore(self, Z):
"""
Restore a standardized output to the non-standardized output.
Expand All @@ -52,13 +53,13 @@ def restore(self, Z):
Returns
-------
X: torch.Tensor
X: torch.Tensor
Tensor of non-standardized outputs
"""
X = self.mean + Z * self.std
return X

def get_state(self):
"""
Return dictionary of the state of the Standardizer.
Expand All @@ -69,22 +70,24 @@ def get_state(self):
Dictionary with the mean and std of the outputs
"""
return {"mean" : self.mean, "std" : self.std}
return {"mean": self.mean, "std": self.std}

def set_state(self, state):
"""
Load a dictionary containing the state of the Standardizer.
Parameters
----------
state : dict
Dictionary containing mean and std
Dictionary containing mean and std
"""
self.mean = state["mean"]
self.std = state["std"]


class Model:
"""Wrapper class for a MultiGCN model that allows training and prediction."""

def __init__(self, global_config, partition_configs, model_path):
"""Initialize a MultiGCN model.
Expand Down Expand Up @@ -116,7 +119,7 @@ def __init__(self, global_config, partition_configs, model_path):

# Create model path
self.make_directory_structure(model_path)

# Set GPU status
self.use_gpu = global_config["gpu"]

Expand All @@ -127,7 +130,7 @@ def __init__(self, global_config, partition_configs, model_path):
raise ValueError(
"Incorrect loss function. Currently only 'mse' is supported"
)

# Set metric function
if global_config["metric_function"] == "mae":
self.metric_fn = mean_absolute_error
Expand All @@ -149,8 +152,7 @@ def __init__(self, global_config, partition_configs, model_path):
# Set scheduler
if "lr_milestones" in global_config.keys():
self.scheduler = torch.optim.MultiStepLR(
optimizer=self.optimizer,
milestones=global_config["lr_milestones"]
optimizer=self.optimizer, milestones=global_config["lr_milestones"]
)
else:
self.scheduler = None
Expand Down Expand Up @@ -236,7 +238,7 @@ def train_epoch(self, dataloader):
avg_metric = avg_metric / count

return avg_loss, avg_metric

def validate(self, dataloader):
"""Validate/test the model.
Expand Down Expand Up @@ -332,11 +334,14 @@ def predict(self, dataset, indices, return_targets=False):
pred_dict = self.model(nn_input)
predictions[i] = self.standardizer.restore(pred_dict["output"].cpu())

predictions_dict = {"targets": targets, "predictions": predictions,
"indices": indices}

predictions_dict = {
"targets": targets,
"predictions": predictions,
"indices": indices,
}

return predictions_dict

def save(self, epoch, best_status=None):
"""Save the current state of the model as a dictionary.
Expand Down Expand Up @@ -364,7 +369,7 @@ def save(self, epoch, best_status=None):

def load(self, epoch=None, best_status=None):
"""Load a model saved at a particular epoch or the best model.
If best_status is True, epoch is ignored and the best model is loaded.
Parameters
Expand All @@ -379,7 +384,7 @@ def load(self, epoch=None, best_status=None):
load_path = self.model_save_path / "best.pt"
else:
load_path = self.model_save_path / f"model_{epoch}.pt"

# Load the dictionary
load_dict = torch.load(load_path)

Expand All @@ -392,7 +397,7 @@ def train(self, epochs, dataloader_dict, verbose=False):
The training is performed with early stopping, i.e., the metric function
is evaluated at every epoch and the model with the best value for this
metric is loaded after training for testing.
metric is loaded after training for testing.
Parameters
----------
Expand Down Expand Up @@ -422,22 +427,22 @@ def train(self, epochs, dataloader_dict, verbose=False):

# Train and validate model
for i in range(epochs):
# Train
# Train
train_loss, train_metric = self.train_epoch(dataloader_dict["train"])

# Validate
val_loss, val_metric = self.validate(dataloader_dict["val"])

# Check if model is best
if val_metric < prev_val_metric:
best_status = True
prev_val_metric = deepcopy(val_metric)
else:
best_status = False

# Save model
self.save(i, best_status)

# Save losses and metrics
train_losses.append(train_loss)
val_losses.append(val_loss)
Expand All @@ -450,14 +455,9 @@ def train(self, epochs, dataloader_dict, verbose=False):
# Test the model
test_loss, test_metric = self.validate(dataloader_dict["test"])

loss_dict = {
"train": train_losses, "val": val_losses, "test": test_loss
}
metric_dict = {
"train": train_metrics, "val": val_metrics, "test": test_metric
}
loss_dict = {"train": train_losses, "val": val_losses, "test": test_loss}
metric_dict = {"train": train_metrics, "val": val_metrics, "test": test_metric}

results_dict = {"loss": loss_dict, "metric": metric_dict}

return results_dict

return results_dict

0 comments on commit b212fae

Please sign in to comment.