From 4878ea54bbcfc02365217b5e5788aae36016d88d Mon Sep 17 00:00:00 2001 From: Dawith Date: Mon, 3 Nov 2025 16:09:28 -0500 Subject: [PATCH] Combined model training code with improved visuals --- train/decoder_train.py | 30 +++++++++++++++++--------- train/encoder_train.py | 38 ++++++++++++++++++++++++++++---- train_model.py | 49 ++++++++++++++++++++++-------------------- visualize/plot.py | 4 +++- visualize/visualize.py | 14 ++++++++++++ 5 files changed, 97 insertions(+), 38 deletions(-) diff --git a/train/decoder_train.py b/train/decoder_train.py index 182e443..56bd9fa 100644 --- a/train/decoder_train.py +++ b/train/decoder_train.py @@ -2,27 +2,35 @@ import keras from keras import Model +import numpy as np import time import typing from typing import List from model.model import DecoderModel -""" -# data parameters -SPLIT = params["split"] -LOAD_FROM_SCRATCH = params["load_from_scratch"] +from visualize.plot import spectra_plot +def decoder_workflow(params, train_set, validation_set, test_set, + n_classes, categories, keys): + decoder = load_decoder(params, train_set[0].shape[1:], n_classes) -# training parameters -BATCH_SIZE = params["batch_size"] -EPOCHS = params["epochs"] -LOG_LEVEL = params["log_level"] -""" + decoder = train_decoder(decoder, params, train_set, validation_set) + # Test model performance + + metrics = {key: None for key in keys} + metrics, test_predict = test_decoder(decoder, test_set, metrics) + + target = categories["target"][str(np.argmax(test_set[1][1][0]))] + treatment = categories["treatment"][str(np.argmax(test_set[1][0][0]))] + + spectra_plot(test_predict[0], name=f"{target}-{treatment}-predict") + spectra_plot(test_set[0][0], name=f"{target}-{treatment}-true") def load_decoder(params, input_shape, n_classes): """ """ + params = params["decoder_params"] decoder = DecoderModel( input_shape, params["head_size"], @@ -47,6 +55,8 @@ def train_decoder(decoder, params, train, validation): """ """ + log_level = params["log_level"] + params = params["decoder_params"] start = time.time() decoder.fit( x=train[1], @@ -54,7 +64,7 @@ def train_decoder(decoder, params, train, validation): validation_data=(validation[1], validation[0]), batch_size=params["batch_size"], epochs=params["epochs"], - verbose=params["log_level"] + verbose=log_level ) end = time.time() print("Training time: ", end - start) diff --git a/train/encoder_train.py b/train/encoder_train.py index 230c6f4..fae134b 100644 --- a/train/encoder_train.py +++ b/train/encoder_train.py @@ -9,7 +9,33 @@ from visualize.visualize import confusion_matrix from visualize.plot import roc_plot +def encoder_workflow(params, shape, n_classes, + train_set, validation_set, test_set, + categories, keys): + model = build_encoder(params, shape, n_classes) + model = train_encoder(params, model, train_set, validation_set) + test_predict, test_loss, test_accuracy = test_encoder( + params, + model, + test_set, + categories, + keys + ) + + evaluate_encoder( + params, + test_predict, + test_set, + test_loss, + test_accuracy, + categories, + keys + ) + + def build_encoder(params, input_shape, n_classes): + log_level = params["log_level"] + params = params["encoder_params"] model = CompoundModel( input_shape, params["head_size"], @@ -24,28 +50,31 @@ def build_encoder(params, input_shape, n_classes): model.compile( optimizer=keras.optimizers.Adam(learning_rate=4e-4), - loss="categorical_crossentropy", - metrics=["categorical_accuracy", "categorical_accuracy"] + loss=params["loss"], + metrics=params["metrics"] ) - if params["log_level"] == 1: + if log_level == 1: model.summary() return model def train_encoder(params, model, train_set, validation_set): + log_level = params["log_level"] + params = params["encoder_params"] start = time.time() model.fit( x=train_set[0], y=train_set[1], validation_data=(validation_set[0], validation_set[1]), batch_size=params["batch_size"], epochs=params["epochs"], - verbose=params["log_level"] + verbose=log_level ) end = time.time() print("Training time: ", end - start) return model def test_encoder(params, model, test_set, categories, keys): + params = params["encoder_params"] # Test model performance test_loss, test_accuracy, _, _, _, _ = model.evaluate( test_set[0], @@ -57,6 +86,7 @@ def test_encoder(params, model, test_set, categories, keys): return test_predict, test_loss, test_accuracy def evaluate_encoder(params, test_predict, test_set, test_loss, test_accuracy, categories, keys): + params = params["encoder_params"] for predict, groundtruth, key in zip(test_predict, test_set[1], keys): confusion_matrix(predict, groundtruth, categories[key], key) roc_plot(predict, groundtruth, key) diff --git a/train_model.py b/train_model.py index 5607cb0..979b14f 100644 --- a/train_model.py +++ b/train_model.py @@ -17,8 +17,8 @@ import matplotlib.pyplot as plt from pipe.etl import etl, read -from train.encoder_train import build_encoder, evaluate_encoder, \ - train_encoder, test_encoder +from train.encoder_train import encoder_workflow +from train.decoder_train import decoder_workflow def parameters(): with open("parameters.json", "r") as file: @@ -55,34 +55,37 @@ def main(): keys = ["treatment", "target"] if params["load_from_scratch"]: - data = etl(spark, split=params["split"]) + data = etl(spark, split=params["encoder_params"]["split"]) else: - data = read(spark, split=params["split"]) + data = read(spark, split=params["encoder_params"]["split"]) (train_set, validation_set, test_set, categories) = data n_classes = [dset.shape[1] for dset in train_set[1]] shape = train_set[0].shape[1:] + + if params["train_encoder"]: + encoder_workflow( + params, + shape, + n_classes, + train_set, + validation_set, + test_set, + categories, + keys + ) - model = build_encoder(params, shape, n_classes) - model = train_encoder(params, model, train_set, validation_set) - test_predict, test_loss, test_accuracy = test_encoder( - params, - model, - test_set, - categories, - keys - ) - - evaluate_encoder( - params, - test_predict, - test_set, - test_loss, - test_accuracy, - categories, - keys - ) + if params["train_decoder"]: + decoder_workflow( + params, + train_set, + validation_set, + test_set, + n_classes, + categories, + keys + ) return diff --git a/visualize/plot.py b/visualize/plot.py index a64da97..f3db37c 100644 --- a/visualize/plot.py +++ b/visualize/plot.py @@ -43,6 +43,7 @@ def roc_plot(predict, groundtruth, key): plt.ylabel("True Positive Rate") plt.savefig(f"/app/workdir/figures/roc_curve_{key}.png", bbox_inches="tight") + plt.close() with open(f"roc_fpr_tpr_{key}.json", 'w') as f: roc_dict = {"fpr": fpr.tolist(), "tpr": tpr.tolist(), @@ -55,15 +56,16 @@ def spectra_plot(spec_array, name=None): """ spec_array = spec_array[:,:130] - print(spec_array.shape) plt.pcolor(spec_array, cmap="bwr", vmin=-1, vmax=1) plt.title(f"{name} spectra") + plt.colorbar() plt.savefig(f"/app/workdir/figures/{name}_spectra_plot.png") plt.close() spec_array -= np.mean(spec_array[:6,:], axis=0) plt.pcolor(spec_array, cmap="bwr", vmin=-1, vmax=1) plt.title(f"{name} spectrogram") + plt.colorbar() plt.savefig(f"/app/workdir/figures/{name}_spectrogram_plot.png") plt.close() diff --git a/visualize/visualize.py b/visualize/visualize.py index c56630b..743957b 100644 --- a/visualize/visualize.py +++ b/visualize/visualize.py @@ -39,6 +39,20 @@ def confusion_matrix(prediction, groundtruth, labels, key): "matrix": conf_matrix.tolist()} json.dump(confusion_dict, f) +def similarity_matrix(array1, array2, label1, label2): + array1 = np.array(array1) + array2 = np.array(array2) + sim_matrix = np.array([ + [np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)) + for vec2 in array2] + for vec1 in array1 + ]) + + plt.pcolor(sim_matrix, edgecolors="black", vmin=-1, vmax=1, linewidth=0.5) + plt.gca().set_aspect("equal") + plt.colorbar() + + def visualize_data_distribution(data: DataFrame) -> None: """