Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Dawith committed Oct 24, 2025
1 parent d7356ed commit ed755bf
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions train/decoder_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#-*- coding: utf-8 -*-

from keras import Model
import typing
from typing import List

def load_decoder(input_shape, n_classes):
"""
"""

model = DecoderModel(input_shape, HEAD_SIZE, NUM_HEADS, FF_DIM,
NUM_TRANSFORMER_BLOCKS, MLP_UNITS, n_classes,
dropout=DROPOUT, mlp_dropout=MLP_DROPOUT)
return model

def train_decoder(decoder, train, validation):
"""
"""

decoder = load_decoder(train_set[0].shape[1:], n_classes)
decoder.compile(
optimizer=keras.optimizers.Adam(learning_rate=4e-4),
loss=["mse"],
)

def test_decoder(decoder: Model, test: List, metrics: dict):
"""
"""

test_eval = model.evaluate(test[1], test[0])
if len(metrics.keys()) == 1:
metrics[metrics.keys()[0]] = test_eval
else:
for i, key in enumerate(metrics.keys()):
metrics[key] = test_eval[i]

test_predict = model.predict(test[1])
return metrics, test_predict

def save_decoder():

# EOF

0 comments on commit ed755bf

Please sign in to comment.