Skip to content

Commit

Permalink
Combined model training code with improved visuals
Browse files Browse the repository at this point in the history
  • Loading branch information
Dawith committed Nov 3, 2025
1 parent 288de9d commit 4878ea5
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 38 deletions.
30 changes: 20 additions & 10 deletions train/decoder_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,35 @@

import keras
from keras import Model
import numpy as np
import time
import typing
from typing import List

from model.model import DecoderModel
"""
# data parameters
SPLIT = params["split"]
LOAD_FROM_SCRATCH = params["load_from_scratch"]
from visualize.plot import spectra_plot

def decoder_workflow(params, train_set, validation_set, test_set,
n_classes, categories, keys):
decoder = load_decoder(params, train_set[0].shape[1:], n_classes)

# training parameters
BATCH_SIZE = params["batch_size"]
EPOCHS = params["epochs"]
LOG_LEVEL = params["log_level"]
"""
decoder = train_decoder(decoder, params, train_set, validation_set)
# Test model performance

metrics = {key: None for key in keys}
metrics, test_predict = test_decoder(decoder, test_set, metrics)

target = categories["target"][str(np.argmax(test_set[1][1][0]))]
treatment = categories["treatment"][str(np.argmax(test_set[1][0][0]))]

spectra_plot(test_predict[0], name=f"{target}-{treatment}-predict")
spectra_plot(test_set[0][0], name=f"{target}-{treatment}-true")

def load_decoder(params, input_shape, n_classes):
"""
"""

params = params["decoder_params"]
decoder = DecoderModel(
input_shape,
params["head_size"],
Expand All @@ -47,14 +55,16 @@ def train_decoder(decoder, params, train, validation):
"""
"""

log_level = params["log_level"]
params = params["decoder_params"]
start = time.time()
decoder.fit(
x=train[1],
y=train[0],
validation_data=(validation[1], validation[0]),
batch_size=params["batch_size"],
epochs=params["epochs"],
verbose=params["log_level"]
verbose=log_level
)
end = time.time()
print("Training time: ", end - start)
Expand Down
38 changes: 34 additions & 4 deletions train/encoder_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,33 @@
from visualize.visualize import confusion_matrix
from visualize.plot import roc_plot

def encoder_workflow(params, shape, n_classes,
        train_set, validation_set, test_set,
        categories, keys):
    """Run the full encoder pipeline: build, train, test, and evaluate.

    Args:
        params: configuration dict; sub-dict "encoder_params" holds the
            model hyperparameters, "log_level" controls verbosity.
        shape: input feature shape passed to the model constructor.
        n_classes: per-output class counts for the classification heads.
        train_set / validation_set / test_set: (inputs, labels) splits.
        categories: label-index -> name mappings, keyed by output name.
        keys: ordered output names (e.g. "treatment", "target").
    """
    encoder = build_encoder(params, shape, n_classes)
    encoder = train_encoder(params, encoder, train_set, validation_set)

    test_predict, test_loss, test_accuracy = test_encoder(
        params, encoder, test_set, categories, keys)

    evaluate_encoder(
        params, test_predict, test_set, test_loss, test_accuracy,
        categories, keys)


def build_encoder(params, input_shape, n_classes):
log_level = params["log_level"]
params = params["encoder_params"]
model = CompoundModel(
input_shape,
params["head_size"],
Expand All @@ -24,28 +50,31 @@ def build_encoder(params, input_shape, n_classes):

model.compile(
optimizer=keras.optimizers.Adam(learning_rate=4e-4),
loss="categorical_crossentropy",
metrics=["categorical_accuracy", "categorical_accuracy"]
loss=params["loss"],
metrics=params["metrics"]
)
if params["log_level"] == 1:
if log_level == 1:
model.summary()

return model

def train_encoder(params, model, train_set, validation_set):
    """Fit the encoder on the training split and report wall-clock time.

    Args:
        params: configuration dict; "log_level" sets Keras fit verbosity,
            "encoder_params" supplies "batch_size" and "epochs".
        model: compiled model exposing a Keras-style ``fit`` method.
        train_set: (inputs, labels) pair used for training.
        validation_set: (inputs, labels) pair used for validation.

    Returns:
        The same model instance, after fitting.
    """
    verbosity = params["log_level"]
    hyper = params["encoder_params"]

    t_start = time.time()
    model.fit(
        x=train_set[0],
        y=train_set[1],
        validation_data=(validation_set[0], validation_set[1]),
        batch_size=hyper["batch_size"],
        epochs=hyper["epochs"],
        verbose=verbosity
    )
    print("Training time: ", time.time() - t_start)

    return model

def test_encoder(params, model, test_set, categories, keys):
params = params["encoder_params"]
# Test model performance
test_loss, test_accuracy, _, _, _, _ = model.evaluate(
test_set[0],
Expand All @@ -57,6 +86,7 @@ def test_encoder(params, model, test_set, categories, keys):
return test_predict, test_loss, test_accuracy

def evaluate_encoder(params, test_predict, test_set, test_loss, test_accuracy, categories, keys):
params = params["encoder_params"]
for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
confusion_matrix(predict, groundtruth, categories[key], key)
roc_plot(predict, groundtruth, key)
Expand Down
49 changes: 26 additions & 23 deletions train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import matplotlib.pyplot as plt

from pipe.etl import etl, read
from train.encoder_train import build_encoder, evaluate_encoder, \
train_encoder, test_encoder
from train.encoder_train import encoder_workflow
from train.decoder_train import decoder_workflow

def parameters():
with open("parameters.json", "r") as file:
Expand Down Expand Up @@ -55,34 +55,37 @@ def main():
keys = ["treatment", "target"]

if params["load_from_scratch"]:
data = etl(spark, split=params["split"])
data = etl(spark, split=params["encoder_params"]["split"])
else:
data = read(spark, split=params["split"])
data = read(spark, split=params["encoder_params"]["split"])

(train_set, validation_set, test_set, categories) = data

n_classes = [dset.shape[1] for dset in train_set[1]]
shape = train_set[0].shape[1:]

if params["train_encoder"]:
encoder_workflow(
params,
shape,
n_classes,
train_set,
validation_set,
test_set,
categories,
keys
)

model = build_encoder(params, shape, n_classes)
model = train_encoder(params, model, train_set, validation_set)
test_predict, test_loss, test_accuracy = test_encoder(
params,
model,
test_set,
categories,
keys
)

evaluate_encoder(
params,
test_predict,
test_set,
test_loss,
test_accuracy,
categories,
keys
)
if params["train_decoder"]:
decoder_workflow(
params,
train_set,
validation_set,
test_set,
n_classes,
categories,
keys
)

return

Expand Down
4 changes: 3 additions & 1 deletion visualize/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def roc_plot(predict, groundtruth, key):
plt.ylabel("True Positive Rate")
plt.savefig(f"/app/workdir/figures/roc_curve_{key}.png",
bbox_inches="tight")
plt.close()
with open(f"roc_fpr_tpr_{key}.json", 'w') as f:
roc_dict = {"fpr": fpr.tolist(),
"tpr": tpr.tolist(),
Expand All @@ -55,15 +56,16 @@ def spectra_plot(spec_array, name=None):
"""

spec_array = spec_array[:,:130]
print(spec_array.shape)
plt.pcolor(spec_array, cmap="bwr", vmin=-1, vmax=1)
plt.title(f"{name} spectra")
plt.colorbar()
plt.savefig(f"/app/workdir/figures/{name}_spectra_plot.png")
plt.close()

spec_array -= np.mean(spec_array[:6,:], axis=0)
plt.pcolor(spec_array, cmap="bwr", vmin=-1, vmax=1)
plt.title(f"{name} spectrogram")
plt.colorbar()
plt.savefig(f"/app/workdir/figures/{name}_spectrogram_plot.png")
plt.close()

Expand Down
14 changes: 14 additions & 0 deletions visualize/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,20 @@ def confusion_matrix(prediction, groundtruth, labels, key):
"matrix": conf_matrix.tolist()}
json.dump(confusion_dict, f)

def similarity_matrix(array1, array2, label1, label2):
    """Plot the pairwise cosine-similarity matrix between two vector sets.

    Args:
        array1: array-like of shape (m, d); one matrix row (y-axis) per vector.
        array2: array-like of shape (n, d); one matrix column (x-axis) per vector.
        label1: axis label naming the ``array1`` vector set.
        label2: axis label naming the ``array2`` vector set.

    Returns:
        numpy.ndarray of shape (m, n) where entry (i, j) is the cosine
        similarity between ``array1[i]`` and ``array2[j]``.
    """
    array1 = np.asarray(array1, dtype=float)
    array2 = np.asarray(array2, dtype=float)

    # Vectorized cosine similarity: (m, n) = (m, d) @ (d, n), scaled by the
    # outer product of the row norms.  Replaces the O(m*n) Python loop.
    norms1 = np.linalg.norm(array1, axis=1, keepdims=True)
    norms2 = np.linalg.norm(array2, axis=1, keepdims=True)
    sim_matrix = (array1 @ array2.T) / (norms1 * norms2.T)

    plt.pcolor(sim_matrix, edgecolors="black", vmin=-1, vmax=1, linewidth=0.5)
    plt.gca().set_aspect("equal")
    # Previously-unused label parameters now identify the two vector sets.
    plt.ylabel(label1)
    plt.xlabel(label2)
    plt.colorbar()
    # NOTE(review): unlike roc_plot/spectra_plot in this package, the figure
    # is neither saved nor closed here — confirm the caller handles that.
    return sim_matrix



def visualize_data_distribution(data: DataFrame) -> None:
"""
Expand Down

0 comments on commit 4878ea5

Please sign in to comment.