From ed755bf8574783e15195db7fda10239ae10a7cd0 Mon Sep 17 00:00:00 2001
From: Dawith
Date: Thu, 23 Oct 2025 20:40:18 -0400
Subject: [PATCH] Initial commit

---
 train/decoder_train.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 48 insertions(+)
 create mode 100644 train/decoder_train.py

diff --git a/train/decoder_train.py b/train/decoder_train.py
new file mode 100644
index 0000000..19d22e8
--- /dev/null
+++ b/train/decoder_train.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+
+import keras
+from keras import Model
+from typing import List
+
+# DecoderModel and the hyperparameter constants (HEAD_SIZE, NUM_HEADS, FF_DIM,
+# NUM_TRANSFORMER_BLOCKS, MLP_UNITS, DROPOUT, MLP_DROPOUT) are expected to be
+# supplied by the project's model and configuration modules.
+
+def load_decoder(input_shape, n_classes) -> Model:
+    """Build an uncompiled DecoderModel for the given input shape and class count."""
+    model = DecoderModel(input_shape, HEAD_SIZE, NUM_HEADS, FF_DIM,
+                         NUM_TRANSFORMER_BLOCKS, MLP_UNITS, n_classes,
+                         dropout=DROPOUT, mlp_dropout=MLP_DROPOUT)
+    return model
+
+def train_decoder(decoder: Model, train: List, validation: List) -> Model:
+    """Compile and fit the decoder; each split is ordered (targets, inputs)."""
+    decoder.compile(
+        optimizer=keras.optimizers.Adam(learning_rate=4e-4),
+        loss=["mse"],
+    )
+    # epochs and batch size are left at the Keras defaults; tune as needed
+    decoder.fit(
+        train[1], train[0],
+        validation_data=(validation[1], validation[0]),
+    )
+    return decoder
+
+def test_decoder(decoder: Model, test: List, metrics: dict):
+    """Evaluate the decoder on the test split and collect its predictions."""
+    test_eval = decoder.evaluate(test[1], test[0])
+    if len(metrics) == 1:
+        # evaluate() returns a scalar when only the loss is tracked
+        metrics[next(iter(metrics))] = test_eval
+    else:
+        for i, key in enumerate(metrics):
+            metrics[key] = test_eval[i]
+
+    test_predict = decoder.predict(test[1])
+    return metrics, test_predict
+
+def save_decoder(decoder: Model, path: str) -> None:
+    """Write the trained decoder to disk in Keras's native format."""
+    decoder.save(path)
+
+# EOF
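
Note (outside the patch): a minimal sketch of how these helpers might be wired together in a driver script. The (targets, inputs) ordering of each split is inferred from the evaluate/predict calls above; load_splits(), the n_classes derivation, the import path (assuming train/ is importable as a package), and the checkpoint path are illustrative assumptions, not part of this commit.

    # hypothetical driver, assuming each split is a (targets, inputs) pair of
    # numpy arrays produced by a project-specific load_splits() helper
    from train.decoder_train import (
        load_decoder, train_decoder, test_decoder, save_decoder)

    train, validation, test = load_splits()              # assumed project helper
    n_classes = train[0].shape[-1]                       # assumed target layout
    decoder = load_decoder(train[1].shape[1:], n_classes)
    decoder = train_decoder(decoder, train, validation)
    metrics, predictions = test_decoder(decoder, test, {"loss": None})
    save_decoder(decoder, "checkpoints/decoder.keras")   # path is illustrative

The metrics dict passed to test_decoder supplies the key names to fill in; with only "loss" present, the scalar returned by evaluate() is stored under that single key.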