Skip to content

Commit

Permalink
encoder workflow moved to modularized files for better organization a…
Browse files Browse the repository at this point in the history
…nd accessibility
  • Loading branch information
Dawith committed Oct 28, 2025
1 parent 6403973 commit b3bfa1b
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 199 deletions.
86 changes: 86 additions & 0 deletions train/encoder_train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,89 @@
#-*- coding: utf-8 -*-

import time

import os
import keras

from model.model import CompoundModel
from visualize.visualize import confusion_matrix
from visualize.plot import roc_plot

def build_encoder(params, input_shape, n_classes):
model = CompoundModel(
input_shape,
params["head_size"],
params["num_heads"],
params["ff_dim"],
params["num_transformer_blocks"],
params["mlp_units"],
n_classes,
dropout=params["dropout"],
mlp_dropout=params["mlp_dropout"]
)

model.compile(
optimizer=keras.optimizers.Adam(learning_rate=4e-4),
loss="categorical_crossentropy",
metrics=["categorical_accuracy", "categorical_accuracy"]
)
if params["log_level"] == 1:
model.summary()

return model

def train_encoder(params, model, train_set, validation_set):
start = time.time()
model.fit(
x=train_set[0], y=train_set[1],
validation_data=(validation_set[0], validation_set[1]),
batch_size=params["batch_size"],
epochs=params["epochs"],
verbose=params["log_level"]
)
end = time.time()
print("Training time: ", end - start)
return model

def test_encoder(params, model, test_set, categories, keys):
# Test model performance
test_loss, test_accuracy, _, _, _, _ = model.evaluate(
test_set[0],
test_set[1]
)

test_predict = model.predict(test_set[0])
print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}")
return test_predict, test_loss, test_accuracy

def evaluate_encoder(params, test_predict, test_set, test_loss, test_accuracy, categories, keys):
for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
confusion_matrix(predict, groundtruth, categories[key], key)
roc_plot(predict, groundtruth, key)
save_metric(params, test_loss, test_accuracy)

def save_metric(params, test_loss, test_accuracy):
"""
Save the hyperparameters and metric to csv
"""

metric = {
"head_size": params["head_size"],
"num_heads": params["num_heads"],
"ff_dim": params["ff_dim"],
"num_transformer_blocks": params["num_transformer_blocks"],
"mlp_units": params["mlp_units"][0],
"dropout": params["dropout"],
"mlp_dropout": params["mlp_dropout"],
"batch_size": params["batch_size"],
"epochs": params["epochs"],
"test_loss": test_loss,
"test_accuracy": test_accuracy
}
if not os.path.exists("/app/workdir/metrics.csv"):
with open("/app/workdir/metrics.csv", "w") as f:
f.write(",".join(metric.keys()) + "\n")
with open("/app/workdir/metrics.csv", "a") as f:
f.write(",".join([str(value) for value in metric.values()]) + "\n")

# EOF
233 changes: 35 additions & 198 deletions train_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,130 +6,26 @@
"""

import os
from pathlib import Path
import time

os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import json
import numpy as np
#from pyspark.ml.feature import OneHotEncoder, StringIndexer
from pyspark.sql import SparkSession, functions, types, Row
import pyspark as spark
import tensorflow as tf
from pyspark.sql import SparkSession
import keras
import matplotlib.pyplot as plt
from sklearn.metrics import auc, confusion_matrix, roc_curve
from sklearn.preprocessing import OneHotEncoder, LabelBinarizer

from model.model import CompoundModel
from pipe.etl import etl, read
from train.encoder_train import build_encoder, evaluate_encoder, \
train_encoder, test_encoder

with open("parameters.json", "r") as file:
params = json.load(file)
def parameters():
with open("parameters.json", "r") as file:
params = json.load(file)
return params

# data parameters
SPLIT = params["split"]
LOAD_FROM_SCRATCH = params["load_from_scratch"]

# model parameters
HEAD_SIZE = params["head_size"]
NUM_HEADS = params["num_heads"]
FF_DIM = params["ff_dim"]
NUM_TRANSFORMER_BLOCKS = params["num_transformer_blocks"]
MLP_UNITS = params["mlp_units"]
DROPOUT = params["dropout"]
MLP_DROPOUT = params["mlp_dropout"]

# training parameters
BATCH_SIZE = params["batch_size"]
EPOCHS = params["epochs"]
LOG_LEVEL = params["log_level"]
del params

def trim(dataframe, column):

ndarray = np.array(dataframe.select(column).collect()) \
.reshape(-1, 32, 130)

return ndarray

def get_model(input_shape, n_classes):
model = CompoundModel(input_shape, HEAD_SIZE, NUM_HEADS, FF_DIM,
NUM_TRANSFORMER_BLOCKS, MLP_UNITS, n_classes,
dropout=DROPOUT, mlp_dropout=MLP_DROPOUT)
return model

def transform(spark, dataframe, keys):
dataframe = dataframe.withColumn(
"index", functions.monotonically_increasing_id()
)
bundle = {key: [
arr.tolist()
for arr in OneHotEncoder(sparse_output=False) \
.fit_transform(dataframe.select(key).collect())
] for key in keys
}

bundle = [dict(zip(bundle.keys(), values))
for values in zip(*bundle.values())]
schema = types.StructType([
types.StructField(key, types.ArrayType(types.FloatType()), True)
for key in keys
])
newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
"index", functions.monotonically_increasing_id()
)
for key in keys:
dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
dataframe = dataframe.join(newframe, on="index", how="inner")

return dataframe

def build_dict(df, key):
df = df.select(key, f"{key}_str").distinct()

return df.rdd.map(
lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
).collectAsMap()


def get_data(spark, split=[0.99, 0.005, 0.005]):
path = Path("/app/workdir")

labels = []
with open(path / "train.csv", "r") as file:
for line in file:
labels.append(line.strip().split(",")[0])

pipe = SpectrogramPipe(spark, filetype="matfiles")
data = pipe.spectrogram_pipe(path, labels)
data.select("treatment").replace("virus", "cpv") \
.replace("cont", "pbs") \
.replace("control", "pbs") \
.replace("dld", "pbs").distinct()

data = transform(spark, data, ["treatment", "target"])
category_dict = {
key: build_dict(data, key) for key in ["treatment", "target"]
}
splits = data.randomSplit(split, seed=42)
trainx, valx, testx = (trim(dset, "spectra") for dset in splits)
trainy, valy, testy = (
[np.array(dset.select("treatment").collect()).squeeze(),
np.array(dset.select("target").collect()).squeeze()]
for dset in splits
)


return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)


def main():
# jax mesh setup
"""
def multi_gpu():
devices = jax.devices("gpu")
mesh = keras.distribution.DeviceMesh(
shape=(len(devices),), axis_names=["model"], devices=devices
Expand All @@ -150,102 +46,43 @@ def main():
layout_map=layout_map
)
keras.distribution.set_distribution(model_parallel)
"""

spark = SparkSession.builder.appName("train").getOrCreate()

def main():
# jax mesh setup
params = parameters()
spark = SparkSession.builder.appName("train").getOrCreate()
keys = ["treatment", "target"]

if LOAD_FROM_SCRATCH:
data = etl(spark, split=SPLIT)
if params["load_from_scratch"]:
data = etl(spark, split=params["split"])
else:
data = read(spark, split=SPLIT)
data = read(spark, split=params["split"])

(train_set, validation_set, test_set, categories) = data

n_classes = [dset.shape[1] for dset in train_set[1]]
model = get_model(train_set[0].shape[1:], n_classes)
model.compile(optimizer=keras.optimizers.Adam(learning_rate=4e-4),
loss="categorical_crossentropy",
metrics=["categorical_accuracy", "categorical_accuracy"]
)
if LOG_LEVEL == 1:
model.summary()

start = time.time()
model.fit(x=train_set[0], y=train_set[1],
validation_data=(validation_set[0], validation_set[1]),
batch_size=BATCH_SIZE, epochs=EPOCHS, verbose=LOG_LEVEL)
end = time.time()
print("Training time: ", end - start)

# Test model performance
test_loss, test_accuracy, _, _, _, _ = model.evaluate(test_set[0], test_set[1])
test_predict = model.predict(test_set[0])
print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}")
for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
conf_matrix = confusion_matrix(
np.argmax(predict, axis=1),
np.argmax(groundtruth, axis=1),
labels=range(len(categories[key].values())),
normalize="pred"
)
plt.pcolormesh(conf_matrix, edgecolors="black", linewidth=0.5)#origin="upper")
plt.gca().set_aspect("equal")
plt.colorbar()
plt.xticks([int(num) for num in categories[key].keys()],
categories[key].values(), rotation=270)
plt.yticks([int(num) for num in categories[key].keys()],
categories[key].values())
plt.xlabel("True label")
plt.ylabel("Predicted label")
plt.gcf().set_size_inches(len(categories[key])/10+4,
len(categories[key])/10+3)
plt.savefig(f"/app/workdir/figures/confusion_matrix_{key}.png",
bbox_inches="tight")
plt.close()
with open(f"confusion_matrix_{key}.json", 'w') as f:
confusion_dict = {"prediction": predict.tolist(),
"true": groundtruth.tolist(),
"matrix": conf_matrix.tolist()}
json.dump(confusion_dict, f)

label_binarizer = LabelBinarizer().fit(groundtruth)
y_onehot_test = label_binarizer.transform(groundtruth)
fpr, tpr, _ = roc_curve(
groundtruth.ravel(),
predict.ravel()
)
roc_auc = auc(fpr, tpr)
plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}")
plt.savefig(f"/app/workdir/figures/roc_curve_{key}.png",
bbox_inches="tight")
with open(f"roc_fpr_tpr_{key}.json", 'w') as f:
roc_dict = {"fpr": fpr.tolist(),
"tpr": tpr.tolist(),
"auc": roc_auc}
json.dump(roc_dict, f)
print("Done")
shape = train_set[0].shape[1:]

model = build_encoder(params, shape, n_classes)
model = train_encoder(params, model, train_set, validation_set)
test_predict, test_loss, test_accuracy = test_encoder(
params,
model,
test_set,
categories,
keys
)

# Save the hyperparameters and metric to csv
metric = {
"head_size": HEAD_SIZE,
"num_heads": NUM_HEADS,
"ff_dim": FF_DIM,
"num_transformer_blocks": NUM_TRANSFORMER_BLOCKS,
"mlp_units": MLP_UNITS[0],
"dropout": DROPOUT,
"mlp_dropout": MLP_DROPOUT,
"batch_size": BATCH_SIZE,
"epochs": EPOCHS,
"test_loss": test_loss,
"test_accuracy": test_accuracy
}
if not os.path.exists("/app/workdir/metrics.csv"):
with open("/app/workdir/metrics.csv", "w") as f:
f.write(",".join(metric.keys()) + "\n")
with open("/app/workdir/metrics.csv", "a") as f:
f.write(",".join([str(value) for value in metric.values()]) + "\n")
evaluate_encoder(
params,
test_predict,
test_set,
test_loss,
test_accuracy,
categories,
keys
)

return

Expand Down
Loading

0 comments on commit b3bfa1b

Please sign in to comment.