Commit

Decoder training steps organized into dedicated module

lim185 committed Oct 24, 2025
1 parent e74353e commit 829bee1

Showing 2 changed files with 72 additions and 67 deletions.
72 changes: 18 additions & 54 deletions decodertest.py
@@ -7,7 +7,6 @@
 
 import os
 from pathlib import Path
-import time
 
 os.environ["KERAS_BACKEND"] = "jax"
 os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
@@ -27,28 +26,7 @@
 
 from model.model import CompoundModel, DecoderModel
 from pipe.etl import etl, read
-
-with open("parameters.json", "r") as file:
-    params = json.load(file)
-
-# data parameters
-SPLIT = params["split"]
-LOAD_FROM_SCRATCH = params["load_from_scratch"]
-
-# model parameters
-HEAD_SIZE = params["head_size"]
-NUM_HEADS = params["num_heads"]
-FF_DIM = params["ff_dim"]
-NUM_TRANSFORMER_BLOCKS = params["num_transformer_blocks"]
-MLP_UNITS = params["mlp_units"]
-DROPOUT = params["dropout"]
-MLP_DROPOUT = params["mlp_dropout"]
-
-# training parameters
-BATCH_SIZE = params["batch_size"]
-EPOCHS = params["epochs"]
-LOG_LEVEL = params["log_level"]
-del params
+import train.decoder_train as dt
 
 def trim(dataframe, column):
 
@@ -57,12 +35,6 @@ def trim(dataframe, column):
 
     return ndarray
 
-def get_model(input_shape, n_classes):
-    model = DecoderModel(input_shape, HEAD_SIZE, NUM_HEADS, FF_DIM,
-                         NUM_TRANSFORMER_BLOCKS, MLP_UNITS, n_classes,
-                         dropout=DROPOUT, mlp_dropout=MLP_DROPOUT)
-    return model
-
 def transform(spark, dataframe, keys):
     dataframe = dataframe.withColumn(
         "index", functions.monotonically_increasing_id()
@@ -127,10 +99,7 @@ def get_data(spark, split=[0.99, 0.005, 0.005]):
 
     return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
 
-
-def main():
-    # jax mesh setup
-    """
+def multi_gpu():
     devices = jax.devices("gpu")
     mesh = keras.distribution.DeviceMesh(
         shape=(len(devices),), axis_names=["model"], devices=devices
@@ -151,38 +120,33 @@ def main():
         layout_map=layout_map
     )
     keras.distribution.set_distribution(model_parallel)
-    """
+
+def main():
+
+    with open("parameters.json", "r") as file:
+        params = json.load(file)
 
     spark = SparkSession.builder.appName("train").getOrCreate()
 
     keys = ["treatment", "target"]
 
-    LOAD_FROM_SCRATCH = False
-    if LOAD_FROM_SCRATCH:
-        data = etl(spark, split=SPLIT)
+    if params["load_from_scratch"]:
+        (train_set, validation_set,
+         test_set, categories) = etl(spark, split=params["split"])
     else:
-        data = read(spark, split=SPLIT)
+        (train_set, validation_set,
+         test_set, categories) = read(spark, split=params["split"])
 
-    (train_set, validation_set, test_set, categories) = data
-
     n_classes = [dset.shape[1] for dset in train_set[1]]
-    model = get_model(train_set[0].shape[1:], n_classes)
-    model.compile(optimizer=keras.optimizers.Adam(learning_rate=4e-4),
-                  loss=["mse"],
-                  )
-    if True: #LOG_LEVEL == 1:
-        model.summary()
-
-    start = time.time()
-    model.fit(x=train_set[1], y=train_set[0],
-              validation_data=(validation_set[1], validation_set[0]),
-              batch_size=BATCH_SIZE, epochs=EPOCHS, verbose=LOG_LEVEL)
-    end = time.time()
-    print("Training time: ", end - start)
+    decoder = dt.load_decoder(params, train_set[0].shape[1:], n_classes)
 
+    decoder = dt.train_decoder(decoder, params, train_set, validation_set)
     # Test model performance
-    #test_loss, test_accuracy = model.evaluate(test_set[1], test_set[0])
-    test_predict = model.predict(test_set[1])
+    metrics = {key: None for key in keys}
+    metrics, test_predict = dt.test_decoder(decoder, test_set, metrics)
 
     plt.pcolor(test_predict[0], cmap='bwr', clim=(-1, 1))
     plt.savefig("sample_spectra.png")
     plt.close()
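Both scripts now pull every data, model, and training hyperparameter from parameters.json. For reference, a sketch of the expected keys, rendered as a Python dict so it can carry comments: the key names come from the code above, while all values, and the "mae" metric, are illustrative assumptions rather than values from this repository.

# parameters.json (sketch; every value is a placeholder)
params = {
    "split": [0.99, 0.005, 0.005],   # train/validation/test fractions (get_data default)
    "load_from_scratch": False,      # True: rebuild via etl(); False: read() cached splits
    "head_size": 64,                 # attention head size
    "num_heads": 4,
    "ff_dim": 128,                   # feed-forward width per transformer block
    "num_transformer_blocks": 4,
    "mlp_units": [128],
    "dropout": 0.1,
    "mlp_dropout": 0.1,
    "batch_size": 32,
    "epochs": 10,
    "log_level": 1,                  # forwarded to fit(verbose=...)
    "loss": ["mse"],                 # previously hard-coded as ["mse"]
    "metrics": ["mae"],              # assumed; any Keras metric list fits here
}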
67 changes: 54 additions & 13 deletions train/decoder_train.py
@@ -1,42 +1,83 @@
 #-*- coding: utf-8 -*-
 
 import keras
 from keras import Model
 import time
 import typing
 from typing import List
 
-def load_decoder(input_shape, n_classes):
-    """
-    """
-    model = DecoderModel(input_shape, HEAD_SIZE, NUM_HEADS, FF_DIM,
-                         NUM_TRANSFORMER_BLOCKS, MLP_UNITS, n_classes,
-                         dropout=DROPOUT, mlp_dropout=MLP_DROPOUT)
-    return model
-
-def train_decoder(decoder, train, validation):
+from model.model import DecoderModel
+
+"""
+# data parameters
+SPLIT = params["split"]
+LOAD_FROM_SCRATCH = params["load_from_scratch"]
+
+# training parameters
+BATCH_SIZE = params["batch_size"]
+EPOCHS = params["epochs"]
+LOG_LEVEL = params["log_level"]
+"""

+def load_decoder(params, input_shape, n_classes):
+    """
+    """
 
-    decoder = load_decoder(train_set[0].shape[1:], n_classes)
+    decoder = DecoderModel(
+        input_shape,
+        params["head_size"],
+        params["num_heads"],
+        params["ff_dim"],
+        params["num_transformer_blocks"],
+        params["mlp_units"],
+        n_classes,
+        params["dropout"],
+        params["mlp_dropout"]
+    )
+
     decoder.compile(
         optimizer=keras.optimizers.Adam(learning_rate=4e-4),
-        loss=["mse"],
+        loss=params["loss"],
+        metrics=params["metrics"]
     )
 
+    return decoder

+def train_decoder(decoder, params, train, validation):
+    """
+    """
+
+    start = time.time()
+    decoder.fit(
+        x=train[1],
+        y=train[0],
+        validation_data=(validation[1], validation[0]),
+        batch_size=params["batch_size"],
+        epochs=params["epochs"],
+        verbose=params["log_level"]
+    )
+    end = time.time()
+    print("Training time: ", end - start)
+
+    return decoder

 def test_decoder(decoder: Model, test: List, metrics: dict):
     """
     """
 
-    test_eval = model.evaluate(test[1], test[0])
+    test_eval = decoder.evaluate(test[1], test[0])
     if len(metrics.keys()) == 1:
         # dict_keys is not subscriptable in Python 3, so materialize it first
         metrics[list(metrics.keys())[0]] = test_eval
     else:
         for i, key in enumerate(metrics.keys()):
             metrics[key] = test_eval[i]
 
-    test_predict = model.predict(test[1])
+    test_predict = decoder.predict(test[1])
 
     return metrics, test_predict
 
-def save_decoder():
+def save_decoder(decoder: Model):
 
     return
 
 # EOF
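save_decoder is still a stub. A minimal sketch of one way to complete it, using Keras 3's single-archive serialization; the decoder.keras default path is an assumption for illustration, not something this commit specifies:

def save_decoder(decoder: Model, path: str = "decoder.keras"):
    # Keras 3 writes architecture, weights, and compile state to one .keras file
    decoder.save(path)
    return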
