Commit
Decoder training code to test the decoder model build
Dawith committed Oct 24, 2025
1 parent c971c8a · commit b29467a
Showing 1 changed file with 264 additions and 0 deletions.
train.py (new file):

@@ -0,0 +1,264 @@
# -*- coding: utf-8 -*-
"""
train.py
Launches the training process for the model.
"""

import os
from pathlib import Path
import time

os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import json
import numpy as np
#from pyspark.ml.feature import OneHotEncoder, StringIndexer
from pyspark.sql import SparkSession, functions, types, Row
import pyspark as spark
import tensorflow as tf
import keras
import keras.metrics as metrics
import matplotlib.pyplot as plt
from sklearn.metrics import auc, confusion_matrix, roc_curve
from sklearn.preprocessing import OneHotEncoder, LabelBinarizer

from model.model import CompoundModel, DecoderModel
from pipe.etl import etl, read

with open("parameters.json", "r") as file:
    params = json.load(file)

# data parameters
SPLIT = params["split"]
LOAD_FROM_SCRATCH = params["load_from_scratch"]

# model parameters
HEAD_SIZE = params["head_size"]
NUM_HEADS = params["num_heads"]
FF_DIM = params["ff_dim"]
NUM_TRANSFORMER_BLOCKS = params["num_transformer_blocks"]
MLP_UNITS = params["mlp_units"]
DROPOUT = params["dropout"]
MLP_DROPOUT = params["mlp_dropout"]

# training parameters
BATCH_SIZE = params["batch_size"]
EPOCHS = params["epochs"]
LOG_LEVEL = params["log_level"]
del params
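
# For reference, a parameters.json file with the keys read above might look
# like this (the values are illustrative placeholders only, not the values
# used in any actual run):
# {
#     "split": [0.8, 0.1, 0.1],
#     "load_from_scratch": false,
#     "head_size": 64,
#     "num_heads": 4,
#     "ff_dim": 128,
#     "num_transformer_blocks": 4,
#     "mlp_units": [128],
#     "dropout": 0.1,
#     "mlp_dropout": 0.1,
#     "batch_size": 32,
#     "epochs": 100,
#     "log_level": 1
# }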


def trim(dataframe, column):
    # Collect a single array column from the Spark dataframe into a
    # (n_samples, 32, 130) NumPy array.
    ndarray = np.array(dataframe.select(column).collect()) \
                .reshape(-1, 32, 130)

    return ndarray


def get_model(input_shape, n_classes):
    # Assemble the decoder with the hyperparameters loaded from parameters.json.
    model = DecoderModel(input_shape, HEAD_SIZE, NUM_HEADS, FF_DIM,
                         NUM_TRANSFORMER_BLOCKS, MLP_UNITS, n_classes,
                         dropout=DROPOUT, mlp_dropout=MLP_DROPOUT)
    return model


def transform(spark, dataframe, keys):
    # One-hot encode each string column named in `keys`, keep the original
    # values in a "<key>_str" column, and join the encoded columns back onto
    # the dataframe via a generated row index.
    dataframe = dataframe.withColumn(
        "index", functions.monotonically_increasing_id()
    )
    bundle = {key: [
        arr.tolist()
        for arr in OneHotEncoder(sparse_output=False) \
            .fit_transform(dataframe.select(key).collect())
        ] for key in keys
    }

    bundle = [dict(zip(bundle.keys(), values))
              for values in zip(*bundle.values())]
    schema = types.StructType([
        types.StructField(key, types.ArrayType(types.FloatType()), True)
        for key in keys
    ])
    newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
        "index", functions.monotonically_increasing_id()
    )
    for key in keys:
        dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
    dataframe = dataframe.join(newframe, on="index", how="inner")

    return dataframe
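
# Illustration of transform() on a hypothetical row (column names as above,
# label values invented for the example):
#   before: Row(treatment="cpv")
#   after:  Row(index=0, treatment_str="cpv", treatment=[1.0, 0.0])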


def build_dict(df, key):
    df = df.select(key, f"{key}_str").distinct()

    return df.rdd.map(
        lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
    ).collectAsMap()
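
# build_dict() maps the argmax index of each one-hot vector back to its
# original string label, e.g. {"0": "cpv", "1": "pbs"} (hypothetical labels);
# these dictionaries are used to name the confusion-matrix axes in main().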


def get_data(spark, split=[0.99, 0.005, 0.005]):
    path = Path("/app/workdir")

    labels = []
    with open(path / "train.csv", "r") as file:
        for line in file:
            labels.append(line.strip().split(",")[0])

    pipe = SpectrogramPipe(spark, filetype="matfiles")
    data = pipe.spectrogram_pipe(path, labels)
    # NOTE: the result of this replace chain is never assigned, so the label
    # normalization below has no effect as written.
    data.select("treatment").replace("virus", "cpv") \
        .replace("cont", "pbs") \
        .replace("control", "pbs") \
        .replace("dld", "pbs").distinct()

    data = transform(spark, data, ["treatment", "target"])
    category_dict = {
        key: build_dict(data, key) for key in ["treatment", "target"]
    }
    splits = data.randomSplit(split, seed=42)
    trainx, valx, testx = (trim(dset, "spectra") for dset in splits)
    trainy, valy, testy = (
        [np.array(dset.select("treatment").collect()).squeeze(),
         np.array(dset.select("target").collect()).squeeze()]
        for dset in splits
    )

    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
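
# NOTE: get_data() is not called by main() below, which loads data through
# pipe.etl.etl()/read() instead; SpectrogramPipe is also not imported in this
# module, so get_data() would need that import before it could run.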


def main():
    # jax mesh setup
    """
    devices = jax.devices("gpu")
    mesh = keras.distribution.DeviceMesh(
        shape=(len(devices),), axis_names=["model"], devices=devices
    )
    layout_map = keras.distribution.LayoutMap(mesh)
    layout_map["dense.*kernel"] = (None, "model")
    layout_map["dense.*bias"] = ("model",)
    layout_map["dense.*kernel_regularizer"] = (None, "model")
    layout_map["dense.*bias_regularizer"] = ("model",)
    layout_map["dense.*activity_regularizer"] = (None,)
    layout_map["dense.*kernel_constraint"] = (None, "model")
    layout_map["dense.*bias_constraint"] = ("model",)
    layout_map["conv1d.*kernel"] = (None, None, None, "model")
    layout_map["conv1d.*kernel_regularizer"] = (None, None, None, "model")
    layout_map["conv1d.*bias_regularizer"] = ("model",)
    model_parallel = keras.distribution.ModelParallel(
        layout_map=layout_map
    )
    keras.distribution.set_distribution(model_parallel)
    """

    spark = SparkSession.builder.appName("train").getOrCreate()

    keys = ["treatment", "target"]

    # Hard-coded override of the value read from parameters.json.
    LOAD_FROM_SCRATCH = False
    if LOAD_FROM_SCRATCH:
        data = etl(spark, split=SPLIT)
    else:
        data = read(spark, split=SPLIT)

    (train_set, validation_set, test_set, categories) = data
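
    # Assuming etl()/read() return the same structure as get_data() above:
    # each x is an (n_samples, 32, 130) spectra array and each y is a list of
    # two one-hot label arrays, one per key in `keys`.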

    n_classes = [dset.shape[1] for dset in train_set[1]]
    model = get_model(train_set[0].shape[1:], n_classes)
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=4e-4),
                  loss=["mse"],
                  )
    if True:  # LOG_LEVEL == 1:
        model.summary()

    start = time.time()
    # The decoder is fit with the one-hot label arrays as inputs and the
    # spectra as regression targets, hence the swapped x/y arguments.
    model.fit(x=train_set[1], y=train_set[0],
              validation_data=(validation_set[1], validation_set[0]),
              batch_size=BATCH_SIZE, epochs=EPOCHS, verbose=LOG_LEVEL)
    end = time.time()
    print("Training time: ", end - start)

    # Test model performance
    #test_loss, test_accuracy = model.evaluate(test_set[1], test_set[0])
    test_predict = model.predict(test_set[1])
    plt.pcolor(test_predict[0], cmap='bwr', clim=(-1, 1))
    plt.savefig("sample_spectra.png")
    plt.close()
    plt.pcolor(test_predict[0] - np.mean(np.array(test_predict[0])[:4], axis=0),
               cmap='bwr', clim=(-1, 1))
    plt.savefig("sample_spectrogram.png")
    plt.close()
    # NOTE: exit() stops this decoder test here; the classification metrics
    # below are skipped, and the next print would also need the commented-out
    # model.evaluate() call to define test_loss and test_accuracy.
    exit()
    print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}")
    for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
        conf_matrix = confusion_matrix(
            np.argmax(predict, axis=1),
            np.argmax(groundtruth, axis=1),
            labels=range(len(categories[key].values())),
            normalize="pred"
        )
        plt.pcolormesh(conf_matrix, edgecolors="black", linewidth=0.5)  # origin="upper"
        plt.gca().set_aspect("equal")
        plt.colorbar()
        plt.xticks([int(num) for num in categories[key].keys()],
                   categories[key].values(), rotation=270)
        plt.yticks([int(num) for num in categories[key].keys()],
                   categories[key].values())
        plt.xlabel("True label")
        plt.ylabel("Predicted label")
        plt.gcf().set_size_inches(len(categories[key])/10 + 4,
                                  len(categories[key])/10 + 3)
        plt.savefig(f"/app/workdir/figures/confusion_matrix_{key}.png",
                    bbox_inches="tight")
        plt.close()
        with open(f"confusion_matrix_{key}.json", 'w') as f:
            confusion_dict = {"prediction": predict.tolist(),
                              "true": groundtruth.tolist(),
                              "matrix": conf_matrix.tolist()}
            json.dump(confusion_dict, f)

        label_binarizer = LabelBinarizer().fit(groundtruth)
        y_onehot_test = label_binarizer.transform(groundtruth)
        fpr, tpr, _ = roc_curve(
            groundtruth.ravel(),
            predict.ravel()
        )
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}")
        plt.savefig(f"/app/workdir/figures/roc_curve_{key}.png",
                    bbox_inches="tight")
        with open(f"roc_fpr_tpr_{key}.json", 'w') as f:
            roc_dict = {"fpr": fpr.tolist(),
                        "tpr": tpr.tolist(),
                        "auc": roc_auc}
            json.dump(roc_dict, f)
    print("Done")

    # Save the hyperparameters and metric to csv
    metric = {
        "head_size": HEAD_SIZE,
        "num_heads": NUM_HEADS,
        "ff_dim": FF_DIM,
        "num_transformer_blocks": NUM_TRANSFORMER_BLOCKS,
        "mlp_units": MLP_UNITS[0],
        "dropout": DROPOUT,
        "mlp_dropout": MLP_DROPOUT,
        "batch_size": BATCH_SIZE,
        "epochs": EPOCHS,
        "test_loss": test_loss,
        "test_accuracy": test_accuracy
    }
    if not os.path.exists("/app/workdir/metrics.csv"):
        with open("/app/workdir/metrics.csv", "w") as f:
            f.write(",".join(metric.keys()) + "\n")
    with open("/app/workdir/metrics.csv", "a") as f:
        f.write(",".join([str(value) for value in metric.values()]) + "\n")
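
    # metrics.csv thus accumulates one row per run, e.g. (illustrative values):
    #   head_size,num_heads,ff_dim,...,epochs,test_loss,test_accuracy
    #   64,4,128,...,100,0.042,0.91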

    return


if __name__ == "__main__":
    main()

# EOF