
Decoder training modularized #1

Merged · 2 commits · Oct 24, 2025
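This PR pulls the decoder model construction, training, and evaluation out of decodertest.py into a new train/decoder_train.py module exposing load_decoder, train_decoder, and test_decoder; hyperparameters are read from parameters.json inside main() instead of module-level constants. A minimal sketch of the resulting entry-point flow, assuming the repo's pipe.etl.read, parameters.json, and model module are in place (only the shapes of the returned arrays matter to load_decoder):

```python
import json

from pyspark.sql import SparkSession

import train.decoder_train as dt
from pipe.etl import read

with open("parameters.json", "r") as file:
    params = json.load(file)

spark = SparkSession.builder.appName("train").getOrCreate()

# read() yields the four pieces main() unpacks
(train_set, validation_set,
 test_set, categories) = read(spark, split=params["split"])

n_classes = [dset.shape[1] for dset in train_set[1]]
decoder = dt.load_decoder(params, train_set[0].shape[1:], n_classes)
decoder = dt.train_decoder(decoder, params, train_set, validation_set)

metrics = {key: None for key in ["treatment", "target"]}
metrics, test_predict = dt.test_decoder(decoder, test_set, metrics)
```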
72 changes: 18 additions & 54 deletions decodertest.py
@@ -7,7 +7,6 @@

import os
from pathlib import Path
import time

os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
@@ -27,28 +26,7 @@

from model.model import CompoundModel, DecoderModel
from pipe.etl import etl, read

with open("parameters.json", "r") as file:
params = json.load(file)

# data parameters
SPLIT = params["split"]
LOAD_FROM_SCRATCH = params["load_from_scratch"]

# model parameters
HEAD_SIZE = params["head_size"]
NUM_HEADS = params["num_heads"]
FF_DIM = params["ff_dim"]
NUM_TRANSFORMER_BLOCKS = params["num_transformer_blocks"]
MLP_UNITS = params["mlp_units"]
DROPOUT = params["dropout"]
MLP_DROPOUT = params["mlp_dropout"]

# training parameters
BATCH_SIZE = params["batch_size"]
EPOCHS = params["epochs"]
LOG_LEVEL = params["log_level"]
del params
import train.decoder_train as dt

def trim(dataframe, column):

@@ -57,12 +35,6 @@ def trim(dataframe, column):

return ndarray

def get_model(input_shape, n_classes):
model = DecoderModel(input_shape, HEAD_SIZE, NUM_HEADS, FF_DIM,
NUM_TRANSFORMER_BLOCKS, MLP_UNITS, n_classes,
dropout=DROPOUT, mlp_dropout=MLP_DROPOUT)
return model

def transform(spark, dataframe, keys):
dataframe = dataframe.withColumn(
"index", functions.monotonically_increasing_id()
@@ -127,10 +99,7 @@ def get_data(spark, split=[0.99, 0.005, 0.005]):

return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)


def main():
# jax mesh setup
"""
def multi_gpu():
devices = jax.devices("gpu")
mesh = keras.distribution.DeviceMesh(
shape=(len(devices),), axis_names=["model"], devices=devices
@@ -151,38 +120,33 @@ def main():
layout_map=layout_map
)
keras.distribution.set_distribution(model_parallel)
"""

def main():

with open("parameters.json", "r") as file:
params = json.load(file)

spark = SparkSession.builder.appName("train").getOrCreate()

keys = ["treatment", "target"]

LOAD_FROM_SCRATCH = False
if LOAD_FROM_SCRATCH:
data = etl(spark, split=SPLIT)
if params["load_from_scratch"]:
(train_set, validation_set,
test_set, categories) = etl(spark, split=params["split"])
else:
data = read(spark, split=SPLIT)
(train_set, validation_set,
test_set, categories) = read(spark, split=params["split"])

(train_set, validation_set, test_set, categories) = data

n_classes = [dset.shape[1] for dset in train_set[1]]
model = get_model(train_set[0].shape[1:], n_classes)
model.compile(optimizer=keras.optimizers.Adam(learning_rate=4e-4),
loss=["mse"],
)
if True: #LOG_LEVEL == 1:
model.summary()

start = time.time()
model.fit(x=train_set[1], y=train_set[0],
validation_data=(validation_set[1], validation_set[0]),
batch_size=BATCH_SIZE, epochs=EPOCHS, verbose=LOG_LEVEL)
end = time.time()
print("Training time: ", end - start)
decoder = dt.load_decoder(params, train_set[0].shape[1:], n_classes)

decoder = dt.train_decoder(decoder, params, train_set, validation_set)
# Test model performance
#test_loss, test_accuracy = model.evaluate(test_set[1], test_set[0])
test_predict = model.predict(test_set[1])

metrics = {key: None for key in keys}
metrics, test_predict = dt.test_decoder(decoder, test_set, metrics)

plt.pcolor(test_predict[0], cmap='bwr', clim=(-1, 1))
plt.savefig("sample_spectra.png")
plt.close()
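One behavioral note on the new test path: test_decoder fills the caller's metrics dict positionally from whatever decoder.evaluate returns, so the dict's insertion order must match the compiled loss/metrics order. A hedged illustration with made-up compile settings (not this repo's actual values):

```python
# Hypothetical compile: loss=["mse"], metrics=["mae"].
# Keras Model.evaluate() then returns [total_loss, mae], and
# test_decoder copies it into the dict by insertion order:
metrics = {"loss": None, "mae": None}
# after test_decoder(decoder, test_set, metrics):
#   metrics == {"loss": <total_loss>, "mae": <mae>}
#
# With exactly one key, evaluate() returns a scalar and that key gets it.
```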
67 changes: 54 additions & 13 deletions train/decoder_train.py
@@ -1,42 +1,83 @@
#-*- coding: utf-8 -*-

import keras
from keras import Model
import time
import typing
from typing import List

def load_decoder(input_shape, n_classes):
"""
"""
from model.model import DecoderModel
"""
# data parameters
SPLIT = params["split"]
LOAD_FROM_SCRATCH = params["load_from_scratch"]
model = DecoderModel(input_shape, HEAD_SIZE, NUM_HEADS, FF_DIM,
NUM_TRANSFORMER_BLOCKS, MLP_UNITS, n_classes,
dropout=DROPOUT, mlp_dropout=MLP_DROPOUT)
return model
def train_decoder(decoder, train, validation):
# training parameters
BATCH_SIZE = params["batch_size"]
EPOCHS = params["epochs"]
LOG_LEVEL = params["log_level"]
"""

def load_decoder(params, input_shape, n_classes):
"""
"""

decoder = load_decoder(train_set[0].shape[1:], n_classes)
decoder = DecoderModel(
input_shape,
params["head_size"],
params["num_heads"],
params["ff_dim"],
params["num_transformer_blocks"],
params["mlp_units"],
n_classes,
params["dropout"],
params["mlp_dropout"]
)

decoder.compile(
optimizer=keras.optimizers.Adam(learning_rate=4e-4),
loss=["mse"],
loss=params["loss"],
metrics=params["metrics"]
)

return decoder

def train_decoder(decoder, params, train, validation):
"""
"""

start = time.time()
decoder.fit(
x=train[1],
y=train[0],
validation_data=(validation[1], validation[0]),
batch_size=params["batch_size"],
epochs=params["epochs"],
verbose=params["log_level"]
)
end = time.time()
print("Training time: ", end - start)

return decoder

def test_decoder(decoder: Model, test: List, metrics: dict):
"""
"""

test_eval = model.evaluate(test[1], test[0])
test_eval = decoder.evaluate(test[1], test[0])
if len(metrics) == 1:
# dict_keys does not support indexing; take the single key via iter()
metrics[next(iter(metrics))] = test_eval
else:
for i, key in enumerate(metrics.keys()):
metrics[key] = test_eval[i]

test_predict = model.predict(test[1])
test_predict = decoder.predict(test[1])

return metrics, test_predict

def save_decoder():
def save_decoder(decoder: Model):

# Placeholder: persistence for the trained decoder is not implemented yet
return

# EOF
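For reference, a hypothetical parameters.json covering every key the two files look up after this change. The values below are illustrative placeholders, not the project's actual settings (the split mirrors get_data's default and loss mirrors the previously hard-coded ["mse"]):

```python
import json

# Illustrative values only; the key set is what decodertest.py and
# train/decoder_train.py actually read.
params = {
    "split": [0.99, 0.005, 0.005],  # get_data's default split
    "load_from_scratch": False,
    "head_size": 64,
    "num_heads": 4,
    "ff_dim": 128,
    "num_transformer_blocks": 4,
    "mlp_units": [128],
    "dropout": 0.1,
    "mlp_dropout": 0.1,
    "batch_size": 32,
    "epochs": 10,
    "log_level": 1,                 # passed to fit(verbose=...)
    "loss": ["mse"],                # previously hard-coded
    "metrics": ["mae"],
}

with open("parameters.json", "w") as file:
    json.dump(params, file, indent=2)
```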
File renamed without changes.