diff --git a/src/train.py b/src/train.py index 32539f9..2e9ab02 100644 --- a/src/train.py +++ b/src/train.py @@ -5,31 +5,32 @@ import numpy as np import torch - from sklearn.metrics import mean_absolute_error, mean_squared_error from models import MultiGCN + class Standardizer: + """Class to standardize targets.""" def __init__(self, X): """ Class to standardize outputs. - + Parameters ---------- - X: torch.Tensor + X: torch.Tensor Tensor of outputs """ self.mean = torch.mean(X) self.std = torch.std(X) - + def standardize(self, X): """ Convert a non-standardized output to a standardized output. Parameters ---------- - X: torch.Tensor + X: torch.Tensor Tensor of non-standardized outputs Returns @@ -40,7 +41,7 @@ def standardize(self, X): """ Z = (X - self.mean) / (self.std) return Z - + def restore(self, Z): """ Restore a standardized output to the non-standardized output. @@ -52,13 +53,13 @@ def restore(self, Z): Returns ------- - X: torch.Tensor + X: torch.Tensor Tensor of non-standardized outputs """ X = self.mean + Z * self.std return X - + def get_state(self): """ Return dictionary of the state of the Standardizer. @@ -69,8 +70,8 @@ def get_state(self): Dictionary with the mean and std of the outputs """ - return {"mean" : self.mean, "std" : self.std} - + return {"mean": self.mean, "std": self.std} + def set_state(self, state): """ Load a dictionary containing the state of the Standardizer. @@ -78,13 +79,15 @@ def set_state(self, state): Parameters ---------- state : dict - Dictionary containing mean and std + Dictionary containing mean and std """ self.mean = state["mean"] self.std = state["std"] + class Model: """Wrapper class for a MultiGCN model that allows training and prediction.""" + def __init__(self, global_config, partition_configs, model_path): """Initialize a MultiGCN model. @@ -116,7 +119,7 @@ def __init__(self, global_config, partition_configs, model_path): # Create model path self.make_directory_structure(model_path) - + # Set GPU status self.use_gpu = global_config["gpu"] @@ -127,7 +130,7 @@ def __init__(self, global_config, partition_configs, model_path): raise ValueError( "Incorrect loss function. Currently only 'mse' is supported" ) - + # Set metric function if global_config["metric_function"] == "mae": self.metric_fn = mean_absolute_error @@ -149,8 +152,7 @@ def __init__(self, global_config, partition_configs, model_path): # Set scheduler if "lr_milestones" in global_config.keys(): self.scheduler = torch.optim.MultiStepLR( - optimizer=self.optimizer, - milestones=global_config["lr_milestones"] + optimizer=self.optimizer, milestones=global_config["lr_milestones"] ) else: self.scheduler = None @@ -236,7 +238,7 @@ def train_epoch(self, dataloader): avg_metric = avg_metric / count return avg_loss, avg_metric - + def validate(self, dataloader): """Validate/test the model. @@ -332,11 +334,14 @@ def predict(self, dataset, indices, return_targets=False): pred_dict = self.model(nn_input) predictions[i] = self.standardizer.restore(pred_dict["output"].cpu()) - predictions_dict = {"targets": targets, "predictions": predictions, - "indices": indices} - + predictions_dict = { + "targets": targets, + "predictions": predictions, + "indices": indices, + } + return predictions_dict - + def save(self, epoch, best_status=None): """Save the current state of the model as a dictionary. @@ -364,7 +369,7 @@ def save(self, epoch, best_status=None): def load(self, epoch=None, best_status=None): """Load a model saved at a particular epoch or the best model. - + If best_status is True, epoch is ignored and the best model is loaded. Parameters @@ -379,7 +384,7 @@ def load(self, epoch=None, best_status=None): load_path = self.model_save_path / "best.pt" else: load_path = self.model_save_path / f"model_{epoch}.pt" - + # Load the dictionary load_dict = torch.load(load_path) @@ -392,7 +397,7 @@ def train(self, epochs, dataloader_dict, verbose=False): The training is performed with early stopping, i.e., the metric function is evaluated at every epoch and the model with the best value for this - metric is loaded after training for testing. + metric is loaded after training for testing. Parameters ---------- @@ -422,22 +427,22 @@ def train(self, epochs, dataloader_dict, verbose=False): # Train and validate model for i in range(epochs): - # Train + # Train train_loss, train_metric = self.train_epoch(dataloader_dict["train"]) - + # Validate val_loss, val_metric = self.validate(dataloader_dict["val"]) - + # Check if model is best if val_metric < prev_val_metric: best_status = True prev_val_metric = deepcopy(val_metric) else: best_status = False - + # Save model self.save(i, best_status) - + # Save losses and metrics train_losses.append(train_loss) val_losses.append(val_loss) @@ -450,14 +455,9 @@ def train(self, epochs, dataloader_dict, verbose=False): # Test the model test_loss, test_metric = self.validate(dataloader_dict["test"]) - loss_dict = { - "train": train_losses, "val": val_losses, "test": test_loss - } - metric_dict = { - "train": train_metrics, "val": val_metrics, "test": test_metric - } + loss_dict = {"train": train_losses, "val": val_losses, "test": test_loss} + metric_dict = {"train": train_metrics, "val": val_metrics, "test": test_metric} results_dict = {"loss": loss_dict, "metric": metric_dict} - return results_dict - \ No newline at end of file + return results_dict