Decoder training code to test the decoder model build
Dawith committed Oct 24, 2025
1 parent c971c8a commit b29467a
Showing 1 changed file with 264 additions and 0 deletions.
264 changes: 264 additions & 0 deletions decodertest.py
@@ -0,0 +1,264 @@
# -*- coding: utf-8 -*-
"""
decodertest.py

Launches a training run of the decoder model to test the decoder build.
"""

import os
from pathlib import Path
import time

os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import json
import numpy as np
#from pyspark.ml.feature import OneHotEncoder, StringIndexer
from pyspark.sql import SparkSession, functions, types, Row
import pyspark as spark
import tensorflow as tf
import keras
import keras.metrics as metrics
import matplotlib.pyplot as plt
from sklearn.metrics import auc, confusion_matrix, roc_curve
from sklearn.preprocessing import OneHotEncoder, LabelBinarizer

from model.model import CompoundModel, DecoderModel
from pipe.etl import etl, read

with open("parameters.json", "r") as file:
params = json.load(file)
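
# Illustrative layout of parameters.json (example values only; the actual
# values live in the project's parameters.json):
# {
#     "split": [0.8, 0.1, 0.1],
#     "load_from_scratch": false,
#     "head_size": 64,
#     "num_heads": 4,
#     "ff_dim": 128,
#     "num_transformer_blocks": 2,
#     "mlp_units": [128],
#     "dropout": 0.1,
#     "mlp_dropout": 0.1,
#     "batch_size": 32,
#     "epochs": 10,
#     "log_level": 1
# }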

# data parameters
SPLIT = params["split"]
LOAD_FROM_SCRATCH = params["load_from_scratch"]

# model parameters
HEAD_SIZE = params["head_size"]
NUM_HEADS = params["num_heads"]
FF_DIM = params["ff_dim"]
NUM_TRANSFORMER_BLOCKS = params["num_transformer_blocks"]
MLP_UNITS = params["mlp_units"]
DROPOUT = params["dropout"]
MLP_DROPOUT = params["mlp_dropout"]

# training parameters
BATCH_SIZE = params["batch_size"]
EPOCHS = params["epochs"]
LOG_LEVEL = params["log_level"]
del params

def trim(dataframe, column):
    # Collect the Spark column into a NumPy array of spectrograms
    # shaped (n_samples, 32, 130).
    ndarray = np.array(dataframe.select(column).collect()) \
        .reshape(-1, 32, 130)

    return ndarray

def get_model(input_shape, n_classes):
model = DecoderModel(input_shape, HEAD_SIZE, NUM_HEADS, FF_DIM,
NUM_TRANSFORMER_BLOCKS, MLP_UNITS, n_classes,
dropout=DROPOUT, mlp_dropout=MLP_DROPOUT)
return model
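
# Example call, with illustrative shapes: trim() above yields spectrograms of
# shape (n_samples, 32, 130), so a typical build might look like
# get_model(input_shape=(32, 130), n_classes=[2, 3]) for two categorical heads.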

def transform(spark, dataframe, keys):
    # One-hot encode each string column with sklearn and join the encoded
    # arrays back onto the Spark dataframe through a generated row index.
    dataframe = dataframe.withColumn(
        "index", functions.monotonically_increasing_id()
    )
bundle = {key: [
arr.tolist()
for arr in OneHotEncoder(sparse_output=False) \
.fit_transform(dataframe.select(key).collect())
] for key in keys
}

bundle = [dict(zip(bundle.keys(), values))
for values in zip(*bundle.values())]
schema = types.StructType([
types.StructField(key, types.ArrayType(types.FloatType()), True)
for key in keys
])
newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
"index", functions.monotonically_increasing_id()
)
for key in keys:
dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
dataframe = dataframe.join(newframe, on="index", how="inner")

return dataframe
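
# Sketch of transform()'s effect on a single row (values illustrative): a
# string column such as treatment = "cpv" becomes treatment = [1.0, 0.0], with
# the original string preserved in treatment_str and re-joined via "index".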

def build_dict(df, key):
    # Map each one-hot class index (as a string) back to its original label.
    df = df.select(key, f"{key}_str").distinct()

return df.rdd.map(
lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
).collectAsMap()
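
# build_dict() then yields a lookup such as {"0": "cpv", "1": "pbs"} for the
# treatment column (ordering follows whatever the one-hot encoder chose;
# the values here are illustrative).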


def get_data(spark, split=[0.99, 0.005, 0.005]):
path = Path("/app/workdir")

labels = []
with open(path / "train.csv", "r") as file:
for line in file:
labels.append(line.strip().split(",")[0])

    # NOTE: SpectrogramPipe is assumed to be imported from the project's
    # pipeline package; it is not imported at the top of this file.
    pipe = SpectrogramPipe(spark, filetype="matfiles")
    data = pipe.spectrogram_pipe(path, labels)
    # Map raw treatment labels onto the canonical classes
    # (virus -> cpv; cont/control/dld -> pbs).
    data = data.replace(
        {"virus": "cpv", "cont": "pbs", "control": "pbs", "dld": "pbs"},
        subset=["treatment"]
    )

data = transform(spark, data, ["treatment", "target"])
category_dict = {
key: build_dict(data, key) for key in ["treatment", "target"]
}
splits = data.randomSplit(split, seed=42)
trainx, valx, testx = (trim(dset, "spectra") for dset in splits)
trainy, valy, testy = (
[np.array(dset.select("treatment").collect()).squeeze(),
np.array(dset.select("target").collect()).squeeze()]
for dset in splits
)


return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
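
# Each split is therefore a (spectra, labels) pair: spectra with shape
# (n_samples, 32, 130) and labels as a two-element list of one-hot arrays for
# "treatment" and "target", with category_dict mapping class indices back to
# the original label strings.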


def main():
    # JAX mesh / model-parallel setup, currently disabled (kept for reference).
    """
devices = jax.devices("gpu")
mesh = keras.distribution.DeviceMesh(
shape=(len(devices),), axis_names=["model"], devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh)
layout_map["dense.*kernel"] = (None, "model")
layout_map["dense.*bias"] = ("model",)
layout_map["dense.*kernel_regularizer"] = (None, "model")
layout_map["dense.*bias_regularizer"] = ("model",)
layout_map["dense.*activity_regularizer"] = (None,)
layout_map["dense.*kernel_constraint"] = (None, "model")
layout_map["dense.*bias_constraint"] = ("model",)
layout_map["conv1d.*kernel"] = (None, None, None, "model")
layout_map["conv1d.*kernel_regularizer"] = (None, None, None, "model")
layout_map["conv1d.*bias_regularizer"] = ("model",)
model_parallel = keras.distribution.ModelParallel(
layout_map=layout_map
)
keras.distribution.set_distribution(model_parallel)
"""

spark = SparkSession.builder.appName("train").getOrCreate()

keys = ["treatment", "target"]

    # Override the config for this test: always read the already-processed data.
    LOAD_FROM_SCRATCH = False
    if LOAD_FROM_SCRATCH:
data = etl(spark, split=SPLIT)
else:
data = read(spark, split=SPLIT)

(train_set, validation_set, test_set, categories) = data

n_classes = [dset.shape[1] for dset in train_set[1]]
model = get_model(train_set[0].shape[1:], n_classes)
model.compile(optimizer=keras.optimizers.Adam(learning_rate=4e-4),
loss=["mse"],
)
    if True:  # summary printed unconditionally for this test (normally LOG_LEVEL == 1)
        model.summary()

start = time.time()
    # Decoder test: the one-hot labels are the inputs (x) and the spectrograms
    # are the reconstruction targets (y).
    model.fit(x=train_set[1], y=train_set[0],
              validation_data=(validation_set[1], validation_set[0]),
              batch_size=BATCH_SIZE, epochs=EPOCHS, verbose=LOG_LEVEL)
end = time.time()
print("Training time: ", end - start)

# Test model performance
#test_loss, test_accuracy = model.evaluate(test_set[1], test_set[0])
test_predict = model.predict(test_set[1])
plt.pcolor(test_predict[0], cmap='bwr', clim=(-1, 1))
plt.savefig("sample_spectra.png")
plt.close()
plt.pcolor(test_predict[0] - np.mean(np.array(test_predict[0])[:4], axis=0),
cmap='bwr', clim=(-1, 1))
plt.savefig("sample_spectrogram.png")
plt.close()
    # The decoder test ends here; everything below is classifier-style
    # evaluation that never runs and would need the commented-out
    # model.evaluate() call above to define test_loss and test_accuracy.
    exit()
    print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}")
for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
conf_matrix = confusion_matrix(
np.argmax(predict, axis=1),
np.argmax(groundtruth, axis=1),
labels=range(len(categories[key].values())),
normalize="pred"
)
        plt.pcolormesh(conf_matrix, edgecolors="black", linewidth=0.5)
plt.gca().set_aspect("equal")
plt.colorbar()
plt.xticks([int(num) for num in categories[key].keys()],
categories[key].values(), rotation=270)
plt.yticks([int(num) for num in categories[key].keys()],
categories[key].values())
plt.xlabel("True label")
plt.ylabel("Predicted label")
plt.gcf().set_size_inches(len(categories[key])/10+4,
len(categories[key])/10+3)
plt.savefig(f"/app/workdir/figures/confusion_matrix_{key}.png",
bbox_inches="tight")
plt.close()
with open(f"confusion_matrix_{key}.json", 'w') as f:
confusion_dict = {"prediction": predict.tolist(),
"true": groundtruth.tolist(),
"matrix": conf_matrix.tolist()}
json.dump(confusion_dict, f)

label_binarizer = LabelBinarizer().fit(groundtruth)
y_onehot_test = label_binarizer.transform(groundtruth)
fpr, tpr, _ = roc_curve(
groundtruth.ravel(),
predict.ravel()
)
roc_auc = auc(fpr, tpr)
plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}")
plt.savefig(f"/app/workdir/figures/roc_curve_{key}.png",
bbox_inches="tight")
with open(f"roc_fpr_tpr_{key}.json", 'w') as f:
roc_dict = {"fpr": fpr.tolist(),
"tpr": tpr.tolist(),
"auc": roc_auc}
json.dump(roc_dict, f)
print("Done")

# Save the hyperparameters and metric to csv
metric = {
"head_size": HEAD_SIZE,
"num_heads": NUM_HEADS,
"ff_dim": FF_DIM,
"num_transformer_blocks": NUM_TRANSFORMER_BLOCKS,
"mlp_units": MLP_UNITS[0],
"dropout": DROPOUT,
"mlp_dropout": MLP_DROPOUT,
"batch_size": BATCH_SIZE,
"epochs": EPOCHS,
"test_loss": test_loss,
"test_accuracy": test_accuracy
}
if not os.path.exists("/app/workdir/metrics.csv"):
with open("/app/workdir/metrics.csv", "w") as f:
f.write(",".join(metric.keys()) + "\n")
with open("/app/workdir/metrics.csv", "a") as f:
f.write(",".join([str(value) for value in metric.values()]) + "\n")

return

if __name__ == "__main__":
main()
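
# Minimal sketch of how this test is expected to be launched (assuming the
# Spark + JAX/Keras environment implied by the /app/workdir paths):
#
#     python decodertest.py
#
# with parameters.json in the working directory and the pre-processed data
# readable by read() from pipe.etl.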

# EOF
