
Commit

Improved workflow for model training, including proper monitoring for autoencoder training
Dawith committed Mar 2, 2026
1 parent 520fbe3 commit e891126
Showing 3 changed files with 177 additions and 9 deletions.
120 changes: 119 additions & 1 deletion train/autoencoder_train.py
@@ -1,3 +1,121 @@
# -*- coding: utf-8 -*-

from datetime import datetime
import time
import typing
from typing import List

import numpy as np
import os
import keras
from keras.metrics import MeanSquaredError
from keras import Model
from keras.callbacks import ModelCheckpoint, CSVLogger
import matplotlib.pyplot as plt

from model.model import CompoundModel
from model.metrics import MutualInformation, mutual_information
from visualize.visualize import confusion_matrix
from visualize.plot import roc_plot
from train.encoder_train import build_encoder
from train.decoder_train import build_decoder

def autoencoder_workflow(params, shape, n_classes,
train_set, validation_set, test_set,
categories, keys, path):
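# Build, train, test, evaluate, and save the autoencoder end to end.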

model = build_autoencoder(params, shape, n_classes)
model = train_autoencoder(params, model, train_set, validation_set, path)

model_metrics = {key: None for key in keys}
model_metrics, test_predict = test_autoencoder(
model,
test_set,
model_metrics
)

evaluate_autoencoder(
params,
test_predict,
test_set[0],
categories,
keys,
path
)

save_autoencoder(params, model, path)

def build_autoencoder(params, shape, n_classes):
autoencoder_params = params["autoencoder_params"]
#mi = MutualInformation()
mse = MeanSquaredError()

encoder_model = build_encoder(params, shape, n_classes)
decoder_model = build_decoder(params, shape, n_classes)
model = keras.Sequential([encoder_model, decoder_model])
model.compile(
optimizer=keras.optimizers.Adam(learning_rate=4e-4),
loss=autoencoder_params["loss"],
metrics=[mse]#, mutual_information]
)

return model

def train_autoencoder(params, model, train_set, validation_set, path):
log_level = params["log_level"]
timestamp = params["timestamp"]
params = params["autoencoder_params"]
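# Monitoring: checkpoint the best weights by validation loss and log
# per-epoch metrics to CSV in the run's timestamped directory.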
callbacks = [
ModelCheckpoint(
filepath=path / timestamp / f"{timestamp}_checkpoint.keras",
monitor="val_loss",
save_best_only=True,
save_weights_only=False,
verbose=1
),
CSVLogger(path / timestamp / f"{timestamp}_log.csv")
]

start = time.time()
model.fit(
x=train_set, y=train_set,
validation_data=(validation_set, validation_set),
batch_size=params["batch_size"],
epochs=params["epochs"],
verbose=log_level,
callbacks=callbacks
)
end = time.time()
print("Training time: ", end - start)
return model

def test_autoencoder(model: Model, test: List, metrics: dict):
"""
"""

test_eval = model.evaluate(test, test)
if len(metrics) == 1:
# dict_keys does not support indexing; take the single key directly
metrics[next(iter(metrics))] = test_eval
else:
for i, key in enumerate(metrics.keys()):
metrics[key] = np.mean(test_eval[i])

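# model.predict returns an array of reconstructions; [0] keeps the first one for plotting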
test_predict = model.predict(test)[0]

return metrics, test_predict

def evaluate_autoencoder(params, test_predict, test_set, categories, keys, path):
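# Visual check: save heatmaps of the test data and the model's reconstruction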
plt.pcolor(test_set)
plt.savefig(path / params["timestamp"] / "original.png")
plt.close()
plt.pcolor(test_predict)
plt.savefig(path / params["timestamp"] / "reproduction.png")
plt.close()
return

def save_autoencoder(params, model, path):
model.save(path / params["timestamp"] / f"{params['timestamp']}_autoencoder.keras")

# EOF
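
For context, a minimal sketch of the parameters.json fields that autoencoder_workflow and train_autoencoder read above. The keys are taken from the code; the values are assumptions, not the repository's actual configuration, and "timestamp" is injected at runtime by main() rather than stored in the file.

# Hypothetical minimal parameters.json for the autoencoder workflow; the
# values here are placeholders, not the repo's real settings.
import json

params = {
    "log_level": 1,                    # forwarded to model.fit(verbose=...)
    "train_autoencoder": True,         # checked in main() before this workflow runs
    "autoencoder_params": {
        "loss": "mse",                 # loss passed to model.compile
        "batch_size": 32,
        "epochs": 100,
    },
}

with open("parameters.json", "w") as file:
    json.dump(params, file, indent=4)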
17 changes: 13 additions & 4 deletions train/decoder_train.py
Original file line number Diff line number Diff line change
@@ -11,8 +11,8 @@
from visualize.plot import spectra_plot

def decoder_workflow(params, train_set, validation_set, test_set,
n_classes, categories, keys):
decoder = load_decoder(params, train_set[0].shape[1:], n_classes)
n_classes, categories, keys, modelpath):
decoder = build_decoder(params, train_set[0].shape[1:], n_classes)

decoder = train_decoder(decoder, params, train_set, validation_set)
# Test model performance
@@ -26,7 +26,7 @@ def decoder_workflow(params, train_set, validation_set, test_set,
spectra_plot(test_predict[0], name=f"{target}-{treatment}-predict")
spectra_plot(test_set[0][0], name=f"{target}-{treatment}-true")

def load_decoder(params, input_shape, n_classes):
def build_decoder(params, input_shape, n_classes):
"""
"""

@@ -86,8 +86,17 @@ def test_decoder(decoder: Model, test: List, metrics: dict):

return metrics, test_predict

def save_decoder(decoder: Model):
def evaluate_decoder(params, test_predict, test_set, test_loss,
test_accuracy, categories, keys):
params = params["decoder_params"]

for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
confusion_matrix(predict, groundtruth, categories[key], key)
roc_plot(predict, groundtruth, key)
save_metric(params, test_loss, test_accuracy)

def save_decoder(decoder: Model, path):
decoder.save(path / "decoder.keras")
return

# EOF
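
The new evaluate_decoder assumes the decoder has one output head per key, so test_predict and the labels in test_set[1] are position-aligned lists that zip pairs up with keys. A toy sketch of that pairing, with hypothetical shapes and class counts:

# Illustration of the per-head pairing evaluate_decoder iterates over;
# the keys match the ones used in train_model.py, the shapes are made up.
import numpy as np

keys = ["treatment", "target"]
n_classes = [3, 5]                     # one classification head per key

test_predict = [np.random.rand(10, n) for n in n_classes]                  # per-head predictions
test_labels = [np.eye(n)[np.random.randint(0, n, 10)] for n in n_classes]  # one-hot ground truth

for predict, groundtruth, key in zip(test_predict, test_labels, keys):
    print(key, predict.shape, groundtruth.shape)   # e.g. treatment (10, 3) (10, 3)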
49 changes: 45 additions & 4 deletions train_model.py
@@ -5,27 +5,35 @@
Launches the training process for the model.
"""

# Built-in module imports
from datetime import datetime
import os
from pathlib import Path
import shutil as sh

# Environment variables
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# 3rd party module imports
import jax
import json
from pyspark.sql import SparkSession
import keras
import matplotlib.pyplot as plt

# Local module imports
from pipe.etl import etl, read
from train.encoder_train import encoder_workflow
from train.decoder_train import decoder_workflow
from train.autoencoder_train import autoencoder_workflow

def parameters():
with open("parameters.json", "r") as file:
params = json.load(file)
return params

def multi_gpu():
def model_parallel():
devices = jax.devices("gpu")
mesh = keras.distribution.DeviceMesh(
shape=(len(devices),), axis_names=["model"], devices=devices
@@ -45,14 +45,25 @@ def multi_gpu():
model_parallel = keras.distribution.ModelParallel(
layout_map=layout_map
)
keras.distribution.set_distribution(model_parallel)
return model_parallel

def data_parallel():
devices = jax.devices("gpu")
data_parallel = keras.distribution.DataParallel(devices=devices)
mesh = keras.distribution.DeviceMesh(
shape=(len(devices),), axis_names=["data"], devices=devices
)
data_parallel = keras.distribution.DataParallel(mesh)

return data_parallel

def main():
# jax mesh setup
params = parameters()
spark = SparkSession.builder.appName("train").getOrCreate()
keys = ["treatment", "target"]
#parallel = data_parallel()
#keras.distribution.set_distribution(parallel)

if params["load_from_scratch"]:
data = etl(spark, split=params["encoder_params"]["split"])
@@ -63,7 +82,14 @@ def main():

n_classes = [dset.shape[1] for dset in train_set[1]]
shape = train_set[0].shape[1:]
path = Path("/app/workdir/model")

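# Each run gets a timestamped output directory with a copy of its parameters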
strfmt = "%Y%m%d_%H%M%S"
timestamp = datetime.now().strftime(strfmt)
params["timestamp"] = timestamp
os.mkdir(path/timestamp)
sh.copy("parameters.json", path/timestamp/"parameters.json")

if params["train_encoder"]:
encoder_workflow(
params,
@@ -73,7 +99,8 @@
validation_set,
test_set,
categories,
keys
keys,
path
)

if params["train_decoder"]:
@@ -84,7 +111,21 @@
test_set,
n_classes,
categories,
keys
keys,
path
)

if params["train_autoencoder"]:
autoencoder_workflow(
params,
shape,
n_classes,
train_set[0],
validation_set[0],
test_set[0],
categories,
keys,
path
)

return
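
With set_distribution moved out of model_parallel() and the call in main() left commented out, selecting a distribution strategy is now an explicit caller decision. A minimal sketch of re-enabling it, assuming the Keras 3 JAX backend and the helper functions defined above:

# Pick one strategy before any model is built; data_parallel() and
# model_parallel() are the helpers from train_model.py.
parallel = data_parallel()             # or model_parallel()
keras.distribution.set_distribution(parallel)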
