diff --git a/train/autoencoder_train.py b/train/autoencoder_train.py
index 2703748..e9e7402 100644
--- a/train/autoencoder_train.py
+++ b/train/autoencoder_train.py
@@ -1,3 +1,121 @@
-# -*- coding: utf-8 -*-
+#-*- coding: utf-8 -*-
+
+from datetime import datetime
+import time
+import typing
+from typing import List
+
+import numpy as np
+import os
+import keras
+from keras.metrics import MeanSquaredError
+from keras import Model
+from keras.callbacks import ModelCheckpoint, CSVLogger
+import matplotlib.pyplot as plt
+
+from model.model import CompoundModel
+from model.metrics import MutualInformation, mutual_information
+from visualize.visualize import confusion_matrix
+from visualize.plot import roc_plot
+from train.encoder_train import build_encoder
+from train.decoder_train import build_decoder
+
+def autoencoder_workflow(params, shape, n_classes,
+                         train_set, validation_set, test_set,
+                         categories, keys, path):
+
+    model = build_autoencoder(params, shape, n_classes)
+    model = train_autoencoder(params, model, train_set, validation_set, path)
+
+    m = {key: None for key in keys}
+    m, test_predict = test_autoencoder(
+        model,
+        test_set,
+        m
+    )
+    model_metrics = {metric: value for metric, value in m.items()}
+
+    evaluate_autoencoder(
+        params,
+        test_predict,
+        test_set[0],
+        categories,
+        keys,
+        path
+    )
+
+    save_autoencoder(params, model, path)
+
+def build_autoencoder(params, shape, n_classes):
+    autoencoder_params = params["autoencoder_params"]
+    #mi = MutualInformation()
+    mse = MeanSquaredError()
+
+    encoder_model = build_encoder(params, shape, n_classes)
+    decoder_model = build_decoder(params, shape, n_classes)
+    model = keras.Sequential([encoder_model, decoder_model])
+    model.compile(
+        optimizer=keras.optimizers.Adam(learning_rate=4e-4),
+        loss=autoencoder_params["loss"],
+        metrics=[mse]#, mutual_information]
+    )
+
+    return model
+
+def train_autoencoder(params, model, train_set, validation_set, path):
+    log_level = params["log_level"]
+    timestamp = params["timestamp"]
+    params = params["autoencoder_params"]
+    callbacks = [
+        ModelCheckpoint(
+            filepath=path / timestamp / f"{timestamp}_checkpoint.keras",
+            monitor="val_loss",
+            save_best_only=True,
+            save_weights_only=False,
+            verbose=1
+        ),
+        CSVLogger(path / timestamp / f"{timestamp}_log.csv")
+    ]
+
+    start = time.time()
+    model.fit(
+        x=train_set, y=train_set,
+        validation_data=(validation_set, validation_set),
+        batch_size=params["batch_size"],
+        epochs=params["epochs"],
+        verbose=log_level,
+        callbacks=callbacks
+    )
+    end = time.time()
+    print("Training time: ", end - start)
+    return model
+
+def test_autoencoder(model: Model, test: List, metrics: dict):
+    """
+    Evaluate the autoencoder on the test set and fill the metrics dict.
+    """
+
+    test_eval = model.evaluate(test, test)
+    if len(metrics) == 1:
+        metrics[list(metrics.keys())[0]] = test_eval
+    else:
+        for i, key in enumerate(metrics.keys()):
+            metrics[key] = np.mean(test_eval[i])
+
+    test_predict = model.predict(test)[0]
+
+    return metrics, test_predict
+
+def evaluate_autoencoder(params, test_predict, test_set, categories, keys, path):
+    plt.pcolor(test_set)
+    plt.savefig(path / params["timestamp"] / "original.png")
+    plt.close()
+    plt.pcolor(test_predict)
+    plt.savefig(path / params["timestamp"] / "reproduction.png")
+    plt.close()
+    return
+
+def save_autoencoder(params, model, path):
+    model.save(path / params["timestamp"] / f"{params['timestamp']}_autoencoder.keras")
 
 # EOF
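Note on build_autoencoder: composing the encoder and decoder with keras.Sequential only works when the decoder's expected input shape matches the encoder's latent output shape. A minimal sketch of that contract, using hypothetical stand-in models (the real ones come from build_encoder/build_decoder; the 128-wide input and 32-wide latent are illustrative assumptions, not the project's shapes):

    import keras
    from keras import layers

    latent_dim = 32  # assumed latent width, illustration only
    encoder = keras.Sequential([keras.Input(shape=(128,)),
                                layers.Dense(latent_dim, activation="relu")])
    decoder = keras.Sequential([keras.Input(shape=(latent_dim,)),
                                layers.Dense(128)])

    # Sequential composition requires decoder input == encoder output.
    autoencoder = keras.Sequential([encoder, decoder])
    autoencoder.compile(optimizer=keras.optimizers.Adam(learning_rate=4e-4),
                        loss="mse")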
diff --git a/train/decoder_train.py b/train/decoder_train.py
index 56bd9fa..2e4c6ff 100644
--- a/train/decoder_train.py
+++ b/train/decoder_train.py
@@ -11,8 +11,8 @@ from visualize.plot import spectra_plot
 
 def decoder_workflow(params, train_set, validation_set, test_set,
-                     n_classes, categories, keys):
-    decoder = load_decoder(params, train_set[0].shape[1:], n_classes)
+                     n_classes, categories, keys, modelpath):
+    decoder = build_decoder(params, train_set[0].shape[1:], n_classes)
     decoder = train_decoder(decoder, params, train_set, validation_set)
 
     # Test model performance
@@ -26,7 +26,7 @@ def decoder_workflow(params, train_set, validation_set, test_set,
     spectra_plot(test_predict[0], name=f"{target}-{treatment}-predict")
     spectra_plot(test_set[0][0], name=f"{target}-{treatment}-true")
 
-def load_decoder(params, input_shape, n_classes):
+def build_decoder(params, input_shape, n_classes):
     """
     """
 
@@ -86,8 +86,17 @@ def test_decoder(decoder: Model, test: List, metrics: dict):
 
     return metrics, test_predict
 
-def save_decoder(decoder: Model):
+def evaluate_decoder(params, test_predict, test_set, test_loss,
+                     test_accuracy, categories, keys):
+    params = params["decoder_params"]
+    for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
+        confusion_matrix(predict, groundtruth, categories[key], key)
+        roc_plot(predict, groundtruth, key)
+    save_metric(params, test_loss, test_accuracy)
+
+def save_decoder(decoder: Model, path):
+    decoder.save(path / "decoder.keras")
     return
 
 # EOF
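Note on evaluate_decoder: it zips predictions, ground truth, and keys in parallel, which assumes one decoder output head per key. A minimal sketch of the data layout this implies; the class counts and feature width are assumptions, and keys mirrors train_model.py:

    import numpy as np

    keys = ["treatment", "target"]
    X_test = np.random.rand(10, 128)                      # 10 samples, 128 features (assumed)
    y_treatment = np.eye(3)[np.random.randint(0, 3, 10)]  # one-hot labels, 3 classes (assumed)
    y_target = np.eye(5)[np.random.randint(0, 5, 10)]     # one-hot labels, 5 classes (assumed)
    test_set = (X_test, [y_treatment, y_target])

    test_predict = [y_treatment, y_target]  # stand-in for decoder.predict(X_test)
    for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
        print(key, predict.shape, groundtruth.shape)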
diff --git a/train_model.py b/train_model.py
index 979b14f..d7b79e2 100644
--- a/train_model.py
+++ b/train_model.py
@@ -5,27 +5,35 @@
 Launches the training process for the model.
 """
 
+# Built-in module imports
+from datetime import datetime
 import os
+from pathlib import Path
+import shutil as sh
 
+# Environment variables
 os.environ["KERAS_BACKEND"] = "jax"
 os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
 
+# 3rd party module imports
 import jax
 import json
 from pyspark.sql import SparkSession
 import keras
 import matplotlib.pyplot as plt
 
+# Local module imports
 from pipe.etl import etl, read
 from train.encoder_train import encoder_workflow
 from train.decoder_train import decoder_workflow
+from train.autoencoder_train import autoencoder_workflow
 
 def parameters():
     with open("parameters.json", "r") as file:
         params = json.load(file)
     return params
 
-def multi_gpu():
+def model_parallel():
     devices = jax.devices("gpu")
     mesh = keras.distribution.DeviceMesh(
         shape=(len(devices),), axis_names=["model"], devices=devices
@@ -45,14 +53,24 @@ def multi_gpu():
     model_parallel = keras.distribution.ModelParallel(
         layout_map=layout_map
     )
-    keras.distribution.set_distribution(model_parallel)
+    return model_parallel
+
+def data_parallel():
+    devices = jax.devices("gpu")
+    mesh = keras.distribution.DeviceMesh(
+        shape=(len(devices),), axis_names=["data"], devices=devices
+    )
+    data_parallel = keras.distribution.DataParallel(mesh)
+
+    return data_parallel
 
 def main():
     # jax mesh setup
     params = parameters()
     spark = SparkSession.builder.appName("train").getOrCreate()
     keys = ["treatment", "target"]
+    #parallel = data_parallel()
+    #keras.distribution.set_distribution(parallel)
 
     if params["load_from_scratch"]:
         data = etl(spark, split=params["encoder_params"]["split"])
@@ -63,7 +82,14 @@ def main():
     n_classes = [dset.shape[1] for dset in train_set[1]]
     shape = train_set[0].shape[1:]
 
+    path = Path("/app/workdir/model")
+    strfmt = "%Y%m%d_%H%M%S"
+    timestamp = datetime.now().strftime(strfmt)
+    params["timestamp"] = timestamp
+    os.mkdir(path/timestamp)
+    sh.copy("parameters.json", path/timestamp/"parameters.json")
+
     if params["train_encoder"]:
         encoder_workflow(
             params,
@@ -73,7 +99,8 @@ def main():
             validation_set,
             test_set,
             categories,
-            keys
+            keys,
+            path
         )
 
     if params["train_decoder"]:
@@ -84,7 +111,21 @@ def main():
             test_set,
             n_classes,
             categories,
-            keys
+            keys,
+            path
+        )
+
+    if params["train_autoencoder"]:
+        autoencoder_workflow(
+            params,
+            shape,
+            n_classes,
+            train_set[0],
+            validation_set[0],
+            test_set[0],
+            categories,
+            keys,
+            path
         )
 
     return
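Note on parameters.json: main() and the workflow modules read several flags and nested sections from it. Only the keys referenced in this diff are known; the sketch below shows that subset with placeholder values ("timestamp" is injected at runtime by main() and does not belong in the file):

    # Assumed shape of parameters.json, written as a Python dict for brevity;
    # values are placeholders, not the project's real configuration.
    params = {
        "log_level": 1,
        "load_from_scratch": True,
        "train_encoder": False,
        "train_decoder": False,
        "train_autoencoder": True,
        "encoder_params": {"split": [0.8, 0.1, 0.1]},
        "decoder_params": {},  # contents not shown in this diff
        "autoencoder_params": {
            "loss": "mse",
            "batch_size": 32,
            "epochs": 100,
        },
    }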