From b3bfa1b1787d61da69b6b84efba7443f2a27d9ea Mon Sep 17 00:00:00 2001
From: Dawith
Date: Tue, 28 Oct 2025 12:47:29 -0400
Subject: [PATCH] Move encoder workflow into modularized files for better
 organization and accessibility

---
 train/encoder_train.py |  86 ++++++++++++++
 train_encoder.py       | 233 +++++------------------------------------
 visualize/plot.py      |  25 +++++
 visualize/visualize.py |  37 +++++++-
 4 files changed, 182 insertions(+), 199 deletions(-)

diff --git a/train/encoder_train.py b/train/encoder_train.py
index 6b3b110..230c6f4 100644
--- a/train/encoder_train.py
+++ b/train/encoder_train.py
@@ -1,3 +1,89 @@
 #-*- coding: utf-8 -*-
+import time
+
+import os
+import keras
+
+from model.model import CompoundModel
+from visualize.visualize import confusion_matrix
+from visualize.plot import roc_plot
+
+def build_encoder(params, input_shape, n_classes):
+    model = CompoundModel(
+        input_shape,
+        params["head_size"],
+        params["num_heads"],
+        params["ff_dim"],
+        params["num_transformer_blocks"],
+        params["mlp_units"],
+        n_classes,
+        dropout=params["dropout"],
+        mlp_dropout=params["mlp_dropout"]
+    )
+
+    model.compile(
+        optimizer=keras.optimizers.Adam(learning_rate=4e-4),
+        loss="categorical_crossentropy",
+        metrics=["categorical_accuracy", "categorical_accuracy"]
+    )
+    if params["log_level"] == 1:
+        model.summary()
+
+    return model
+
+def train_encoder(params, model, train_set, validation_set):
+    start = time.time()
+    model.fit(
+        x=train_set[0], y=train_set[1],
+        validation_data=(validation_set[0], validation_set[1]),
+        batch_size=params["batch_size"],
+        epochs=params["epochs"],
+        verbose=params["log_level"]
+    )
+    end = time.time()
+    print("Training time: ", end - start)
+    return model
+
+def test_encoder(params, model, test_set, categories, keys):
+    # Test model performance
+    test_loss, test_accuracy, _, _, _, _ = model.evaluate(
+        test_set[0],
+        test_set[1]
+    )
+
+    test_predict = model.predict(test_set[0])
+    print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}")
+    return test_predict, test_loss, test_accuracy
+
+def evaluate_encoder(params, test_predict, test_set, test_loss, test_accuracy, categories, keys):
+    for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
+        confusion_matrix(predict, groundtruth, categories[key], key)
+        roc_plot(predict, groundtruth, key)
+    save_metric(params, test_loss, test_accuracy)
+
+def save_metric(params, test_loss, test_accuracy):
+    """
+    Save the hyperparameters and metrics to a CSV file
+    """
+
+    metric = {
+        "head_size": params["head_size"],
+        "num_heads": params["num_heads"],
+        "ff_dim": params["ff_dim"],
+        "num_transformer_blocks": params["num_transformer_blocks"],
+        "mlp_units": params["mlp_units"][0],
+        "dropout": params["dropout"],
+        "mlp_dropout": params["mlp_dropout"],
+        "batch_size": params["batch_size"],
+        "epochs": params["epochs"],
+        "test_loss": test_loss,
+        "test_accuracy": test_accuracy
+    }
+    if not os.path.exists("/app/workdir/metrics.csv"):
+        with open("/app/workdir/metrics.csv", "w") as f:
+            f.write(",".join(metric.keys()) + "\n")
+    with open("/app/workdir/metrics.csv", "a") as f:
+        f.write(",".join([str(value) for value in metric.values()]) + "\n")
 
 # EOF
diff --git a/train_encoder.py b/train_encoder.py
index 1adfdce..5607cb0 100644
--- a/train_encoder.py
+++ b/train_encoder.py
@@ -6,130 +6,26 @@
 """
 
 import os
-from pathlib import Path
-import time
 
 os.environ["KERAS_BACKEND"] = "jax"
 os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
 
 import jax
 import json
-import numpy as np
-#from pyspark.ml.feature import OneHotEncoder, StringIndexer
-from pyspark.sql import SparkSession, functions, types, Row
-import pyspark as spark
-import tensorflow as tf
+from pyspark.sql import SparkSession
 import keras
 import matplotlib.pyplot as plt
-from sklearn.metrics import auc, confusion_matrix, roc_curve
-from sklearn.preprocessing import OneHotEncoder, LabelBinarizer
-from model.model import CompoundModel
 from pipe.etl import etl, read
+from train.encoder_train import (build_encoder, evaluate_encoder,
+                                 train_encoder, test_encoder)
 
-with open("parameters.json", "r") as file:
-    params = json.load(file)
+def parameters():
+    with open("parameters.json", "r") as file:
+        params = json.load(file)
+    return params
 
-    # data parameters
-    SPLIT = params["split"]
-    LOAD_FROM_SCRATCH = params["load_from_scratch"]
-
-    # model parameters
-    HEAD_SIZE = params["head_size"]
-    NUM_HEADS = params["num_heads"]
-    FF_DIM = params["ff_dim"]
-    NUM_TRANSFORMER_BLOCKS = params["num_transformer_blocks"]
-    MLP_UNITS = params["mlp_units"]
-    DROPOUT = params["dropout"]
-    MLP_DROPOUT = params["mlp_dropout"]
-
-    # training parameters
-    BATCH_SIZE = params["batch_size"]
-    EPOCHS = params["epochs"]
-    LOG_LEVEL = params["log_level"]
-    del params
-
-def trim(dataframe, column):
-
-    ndarray = np.array(dataframe.select(column).collect()) \
-                .reshape(-1, 32, 130)
-
-    return ndarray
-
-def get_model(input_shape, n_classes):
-    model = CompoundModel(input_shape, HEAD_SIZE, NUM_HEADS, FF_DIM,
-                          NUM_TRANSFORMER_BLOCKS, MLP_UNITS, n_classes,
-                          dropout=DROPOUT, mlp_dropout=MLP_DROPOUT)
-    return model
-
-def transform(spark, dataframe, keys):
-    dataframe = dataframe.withColumn(
-        "index", functions.monotonically_increasing_id()
-    )
-    bundle = {key: [
-        arr.tolist()
-        for arr in OneHotEncoder(sparse_output=False) \
-            .fit_transform(dataframe.select(key).collect())
-        ] for key in keys
-    }
-
-    bundle = [dict(zip(bundle.keys(), values))
-              for values in zip(*bundle.values())]
-    schema = types.StructType([
-        types.StructField(key, types.ArrayType(types.FloatType()), True)
-        for key in keys
-    ])
-    newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
-        "index", functions.monotonically_increasing_id()
-    )
-    for key in keys:
-        dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
-    dataframe = dataframe.join(newframe, on="index", how="inner")
-
-    return dataframe
-
-def build_dict(df, key):
-    df = df.select(key, f"{key}_str").distinct()
-
-    return df.rdd.map(
-        lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
-    ).collectAsMap()
-
-
-def get_data(spark, split=[0.99, 0.005, 0.005]):
-    path = Path("/app/workdir")
-
-    labels = []
-    with open(path / "train.csv", "r") as file:
-        for line in file:
-            labels.append(line.strip().split(",")[0])
-
-    pipe = SpectrogramPipe(spark, filetype="matfiles")
-    data = pipe.spectrogram_pipe(path, labels)
-    data.select("treatment").replace("virus", "cpv") \
-        .replace("cont", "pbs") \
-        .replace("control", "pbs") \
-        .replace("dld", "pbs").distinct()
-
-    data = transform(spark, data, ["treatment", "target"])
-    category_dict = {
-        key: build_dict(data, key) for key in ["treatment", "target"]
-    }
-    splits = data.randomSplit(split, seed=42)
-    trainx, valx, testx = (trim(dset, "spectra") for dset in splits)
-    trainy, valy, testy = (
-        [np.array(dset.select("treatment").collect()).squeeze(),
-         np.array(dset.select("target").collect()).squeeze()]
-        for dset in splits
-    )
-
-
-    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
-
-
-def main():
-    # jax mesh setup
-    """
+def multi_gpu():
     devices = jax.devices("gpu")
jax.devices("gpu") mesh = keras.distribution.DeviceMesh( shape=(len(devices),), axis_names=["model"], devices=devices @@ -150,102 +46,43 @@ def main(): layout_map=layout_map ) keras.distribution.set_distribution(model_parallel) - """ - spark = SparkSession.builder.appName("train").getOrCreate() +def main(): + # jax mesh setup + params = parameters() + spark = SparkSession.builder.appName("train").getOrCreate() keys = ["treatment", "target"] - if LOAD_FROM_SCRATCH: - data = etl(spark, split=SPLIT) + if params["load_from_scratch"]: + data = etl(spark, split=params["split"]) else: - data = read(spark, split=SPLIT) + data = read(spark, split=params["split"]) (train_set, validation_set, test_set, categories) = data n_classes = [dset.shape[1] for dset in train_set[1]] - model = get_model(train_set[0].shape[1:], n_classes) - model.compile(optimizer=keras.optimizers.Adam(learning_rate=4e-4), - loss="categorical_crossentropy", - metrics=["categorical_accuracy", "categorical_accuracy"] - ) - if LOG_LEVEL == 1: - model.summary() - - start = time.time() - model.fit(x=train_set[0], y=train_set[1], - validation_data=(validation_set[0], validation_set[1]), - batch_size=BATCH_SIZE, epochs=EPOCHS, verbose=LOG_LEVEL) - end = time.time() - print("Training time: ", end - start) - - # Test model performance - test_loss, test_accuracy, _, _, _, _ = model.evaluate(test_set[0], test_set[1]) - test_predict = model.predict(test_set[0]) - print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}") - for predict, groundtruth, key in zip(test_predict, test_set[1], keys): - conf_matrix = confusion_matrix( - np.argmax(predict, axis=1), - np.argmax(groundtruth, axis=1), - labels=range(len(categories[key].values())), - normalize="pred" - ) - plt.pcolormesh(conf_matrix, edgecolors="black", linewidth=0.5)#origin="upper") - plt.gca().set_aspect("equal") - plt.colorbar() - plt.xticks([int(num) for num in categories[key].keys()], - categories[key].values(), rotation=270) - plt.yticks([int(num) for num in categories[key].keys()], - categories[key].values()) - plt.xlabel("True label") - plt.ylabel("Predicted label") - plt.gcf().set_size_inches(len(categories[key])/10+4, - len(categories[key])/10+3) - plt.savefig(f"/app/workdir/figures/confusion_matrix_{key}.png", - bbox_inches="tight") - plt.close() - with open(f"confusion_matrix_{key}.json", 'w') as f: - confusion_dict = {"prediction": predict.tolist(), - "true": groundtruth.tolist(), - "matrix": conf_matrix.tolist()} - json.dump(confusion_dict, f) - - label_binarizer = LabelBinarizer().fit(groundtruth) - y_onehot_test = label_binarizer.transform(groundtruth) - fpr, tpr, _ = roc_curve( - groundtruth.ravel(), - predict.ravel() - ) - roc_auc = auc(fpr, tpr) - plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}") - plt.savefig(f"/app/workdir/figures/roc_curve_{key}.png", - bbox_inches="tight") - with open(f"roc_fpr_tpr_{key}.json", 'w') as f: - roc_dict = {"fpr": fpr.tolist(), - "tpr": tpr.tolist(), - "auc": roc_auc} - json.dump(roc_dict, f) - print("Done") + shape = train_set[0].shape[1:] + + model = build_encoder(params, shape, n_classes) + model = train_encoder(params, model, train_set, validation_set) + test_predict, test_loss, test_accuracy = test_encoder( + params, + model, + test_set, + categories, + keys + ) - # Save the hyperparameters and metric to csv - metric = { - "head_size": HEAD_SIZE, - "num_heads": NUM_HEADS, - "ff_dim": FF_DIM, - "num_transformer_blocks": NUM_TRANSFORMER_BLOCKS, - "mlp_units": MLP_UNITS[0], - "dropout": DROPOUT, - "mlp_dropout": MLP_DROPOUT, - 
"batch_size": BATCH_SIZE, - "epochs": EPOCHS, - "test_loss": test_loss, - "test_accuracy": test_accuracy - } - if not os.path.exists("/app/workdir/metrics.csv"): - with open("/app/workdir/metrics.csv", "w") as f: - f.write(",".join(metric.keys()) + "\n") - with open("/app/workdir/metrics.csv", "a") as f: - f.write(",".join([str(value) for value in metric.values()]) + "\n") + evaluate_encoder( + params, + test_predict, + test_set, + test_loss, + test_accuracy, + categories, + keys + ) return diff --git a/visualize/plot.py b/visualize/plot.py index 1fcf3a6..a64da97 100644 --- a/visualize/plot.py +++ b/visualize/plot.py @@ -1,8 +1,11 @@ # --*- coding: utf-8 -*- +import json import matplotlib.pyplot as plt import numpy as np import seaborn as sns +from sklearn.metrics import auc, roc_curve +from sklearn.preprocessing import LabelBinarizer def lineplot(data=None, x=None, y=None, hue=None): """ @@ -23,6 +26,30 @@ def lineplot(data=None, x=None, y=None, hue=None): plt.savefig(f"/app/workdir/figures/lineplot_{y}_by_{x}.png") plt.close() +def roc_plot(predict, groundtruth, key): + + label_binarizer = LabelBinarizer().fit(groundtruth) + y_onehot_test = label_binarizer.transform(groundtruth) + fpr, tpr, _ = roc_curve( + groundtruth.ravel(), + predict.ravel() + ) + roc_auc = auc(fpr, tpr) + plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}") + plt.plot([0, 1], [0, 1], '--', color='r', label="Random") + plt.legend() + plt.title(f"ROC Curve for {key}") + plt.xlabel("False Positive Rate") + plt.ylabel("True Positive Rate") + plt.savefig(f"/app/workdir/figures/roc_curve_{key}.png", + bbox_inches="tight") + with open(f"roc_fpr_tpr_{key}.json", 'w') as f: + roc_dict = {"fpr": fpr.tolist(), + "tpr": tpr.tolist(), + "auc": roc_auc} + json.dump(roc_dict, f) + print("Done") + def spectra_plot(spec_array, name=None): """ """ diff --git a/visualize/visualize.py b/visualize/visualize.py index fb63894..c56630b 100644 --- a/visualize/visualize.py +++ b/visualize/visualize.py @@ -3,9 +3,42 @@ visualize.py """ +import json +import matplotlib.pyplot as plt +import numpy as np +from pyspark.sql import DataFrame +from sklearn.metrics import confusion_matrix as sk_confusion_matrix import typing -import matplotlib.pyplot as plt +def confusion_matrix(prediction, groundtruth, labels, key): + conf_matrix = sk_confusion_matrix( + np.argmax(prediction, axis=1), + np.argmax(groundtruth, axis=1), + labels=range(len(prediction[0])), + normalize="pred" + ) + plt.pcolormesh(conf_matrix, edgecolors="black", linewidth=0.5) + plt.gca().set_aspect("equal") + plt.colorbar() + plt.xticks([int(num) for num in labels.keys()], + labels.values(), rotation=270) + plt.yticks([int(num) for num in labels.keys()], + labels.values()) + plt.xlabel("True label") + plt.ylabel("Predicted label") + plt.gcf().set_size_inches(len(labels)/10+4, + len(labels)/10+3) + plt.savefig(f"/app/workdir/figures/confusion_matrix_{key}.png", + bbox_inches="tight") + plt.close() + + # Save the confusion matrix data as JSON + with open(f"confusion_matrix_{key}.json", 'w') as f: + confusion_dict = {"prediction": prediction.tolist(), + "true": groundtruth.tolist(), + "matrix": conf_matrix.tolist()} + json.dump(confusion_dict, f) + def visualize_data_distribution(data: DataFrame) -> None: """