diff --git a/decodertest.py b/decodertest.py
index 544df15..bb34818 100644
--- a/decodertest.py
+++ b/decodertest.py
@@ -7,7 +7,6 @@
 import os
 from pathlib import Path
-import time

 os.environ["KERAS_BACKEND"] = "jax"
 os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

@@ -27,28 +26,7 @@
 from model.model import CompoundModel, DecoderModel
 from pipe.etl import etl, read

-
-with open("parameters.json", "r") as file:
-    params = json.load(file)
-
-    # data parameters
-    SPLIT = params["split"]
-    LOAD_FROM_SCRATCH = params["load_from_scratch"]
-
-    # model parameters
-    HEAD_SIZE = params["head_size"]
-    NUM_HEADS = params["num_heads"]
-    FF_DIM = params["ff_dim"]
-    NUM_TRANSFORMER_BLOCKS = params["num_transformer_blocks"]
-    MLP_UNITS = params["mlp_units"]
-    DROPOUT = params["dropout"]
-    MLP_DROPOUT = params["mlp_dropout"]
-
-    # training parameters
-    BATCH_SIZE = params["batch_size"]
-    EPOCHS = params["epochs"]
-    LOG_LEVEL = params["log_level"]
-    del params
+import train.decoder_train as dt


 def trim(dataframe, column):
@@ -57,12 +35,6 @@ def trim(dataframe, column):
     return ndarray


-def get_model(input_shape, n_classes):
-    model = DecoderModel(input_shape, HEAD_SIZE, NUM_HEADS, FF_DIM,
-                         NUM_TRANSFORMER_BLOCKS, MLP_UNITS, n_classes,
-                         dropout=DROPOUT, mlp_dropout=MLP_DROPOUT)
-    return model
-
 def transform(spark, dataframe, keys):
     dataframe = dataframe.withColumn(
         "index", functions.monotonically_increasing_id()
@@ -127,10 +99,7 @@
     return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)

-
-def main():
-    # jax mesh setup
-    """
+def multi_gpu():
     devices = jax.devices("gpu")
     mesh = keras.distribution.DeviceMesh(
         shape=(len(devices),), axis_names=["model"], devices=devices
     )
@@ -151,38 +120,33 @@ def main():
         layout_map=layout_map
     )
     keras.distribution.set_distribution(model_parallel)
-    """
+
+def main():
+
+    with open("parameters.json", "r") as file:
+        params = json.load(file)

     spark = SparkSession.builder.appName("train").getOrCreate()
     keys = ["treatment", "target"]

-    LOAD_FROM_SCRATCH = False
-    if LOAD_FROM_SCRATCH:
-        data = etl(spark, split=SPLIT)
+    if params["load_from_scratch"]:
+        (train_set, validation_set,
+         test_set, categories) = etl(spark, split=params["split"])
     else:
-        data = read(spark, split=SPLIT)
+        (train_set, validation_set,
+         test_set, categories) = read(spark, split=params["split"])

-    (train_set, validation_set, test_set, categories) = data
     n_classes = [dset.shape[1] for dset in train_set[1]]

-    model = get_model(train_set[0].shape[1:], n_classes)
-    model.compile(optimizer=keras.optimizers.Adam(learning_rate=4e-4),
-                  loss=["mse"],
-                  )
-    if True: #LOG_LEVEL == 1:
-        model.summary()
-
-    start = time.time()
-    model.fit(x=train_set[1], y=train_set[0],
-              validation_data=(validation_set[1], validation_set[0]),
-              batch_size=BATCH_SIZE, epochs=EPOCHS, verbose=LOG_LEVEL)
-    end = time.time()
-    print("Training time: ", end - start)
+    decoder = dt.load_decoder(params, train_set[0].shape[1:], n_classes)
+    decoder = dt.train_decoder(decoder, params, train_set, validation_set)

     # Test model performance
-    #test_loss, test_accuracy = model.evaluate(test_set[1], test_set[0])
-    test_predict = model.predict(test_set[1])
+
+    metrics = {key: None for key in keys}
+    metrics, test_predict = dt.test_decoder(decoder, test_set, metrics)
+
     plt.pcolor(test_predict[0], cmap='bwr', clim=(-1, 1))
     plt.savefig("sample_spectra.png")
     plt.close()
diff --git a/train/decoder_train.py b/train/decoder_train.py
index 19d22e8..182e443 100644
--- a/train/decoder_train.py
+++ b/train/decoder_train.py
@@ -1,42 +1,83 @@
 #-*- coding: utf-8 -*-
+import keras
 from keras import Model
+import time
 import typing
 from typing import List

-def load_decoder(input_shape, n_classes):
-    """
-    """
+from model.model import DecoderModel
+"""
+# data parameters
+SPLIT = params["split"]
+LOAD_FROM_SCRATCH = params["load_from_scratch"]
-    model = DecoderModel(input_shape, HEAD_SIZE, NUM_HEADS, FF_DIM,
-                         NUM_TRANSFORMER_BLOCKS, MLP_UNITS, n_classes,
-                         dropout=DROPOUT, mlp_dropout=MLP_DROPOUT)
-    return model
-def train_decoder(decoder, train, validation):
+
+# training parameters
+BATCH_SIZE = params["batch_size"]
+EPOCHS = params["epochs"]
+LOG_LEVEL = params["log_level"]
+"""
+
+def load_decoder(params, input_shape, n_classes):
     """
     """
-    decoder = load_decoder(train_set[0].shape[1:], n_classes)
+    # build the decoder from the hyperparameters in params; dropout and
+    # mlp_dropout are passed by keyword, matching the original call
+    decoder = DecoderModel(
+        input_shape,
+        params["head_size"],
+        params["num_heads"],
+        params["ff_dim"],
+        params["num_transformer_blocks"],
+        params["mlp_units"],
+        n_classes,
+        dropout=params["dropout"],
+        mlp_dropout=params["mlp_dropout"]
+    )
+
     decoder.compile(
         optimizer=keras.optimizers.Adam(learning_rate=4e-4),
-        loss=["mse"],
+        loss=params["loss"],
+        metrics=params["metrics"]
+    )
+
+    return decoder
+
+def train_decoder(decoder, params, train, validation):
+    """
+    """
+
+    start = time.time()
+    decoder.fit(
+        x=train[1],
+        y=train[0],
+        validation_data=(validation[1], validation[0]),
+        batch_size=params["batch_size"],
+        epochs=params["epochs"],
+        verbose=params["log_level"]
     )
+    end = time.time()
+    print("Training time: ", end - start)
+
+    return decoder

 def test_decoder(decoder: Model, test: List, metrics: dict):
     """
     """
-    test_eval = model.evaluate(test[1], test[0])
+    test_eval = decoder.evaluate(test[1], test[0])

     if len(metrics.keys()) == 1:
         metrics[metrics.keys()[0]] = test_eval
     else:
         for i, key in enumerate(metrics.keys()):
             metrics[key] = test_eval[i]

-    test_predict = model.predict(test[1])
+    test_predict = decoder.predict(test[1])
+
+    return metrics, test_predict

-def save_decoder():
+def save_decoder(decoder: Model):
+
+    return

 # EOF
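Note: after this change both main() and train.decoder_train read their configuration from parameters.json. Below is a minimal sketch of a file covering every key the new code accesses; the values are placeholders (the split default and the "mse" loss come from the removed code, while the "mae" entry under "metrics" is only an illustrative assumption), not the project's real settings.

import json

# Hypothetical parameters.json generator. main() reads "split" and
# "load_from_scratch"; dt.load_decoder() reads the model keys plus
# "loss"/"metrics"; dt.train_decoder() reads the training keys.
params = {
    "split": [0.99, 0.005, 0.005],   # default split from get_data()
    "load_from_scratch": False,      # etl() vs read() in main()
    "head_size": 64,                 # placeholder model hyperparameters
    "num_heads": 4,
    "ff_dim": 128,
    "num_transformer_blocks": 4,
    "mlp_units": [128],
    "dropout": 0.1,
    "mlp_dropout": 0.1,
    "loss": ["mse"],                 # matches the removed hard-coded loss
    "metrics": ["mae"],              # assumed metric, adjust as needed
    "batch_size": 32,                # training keys read by train_decoder()
    "epochs": 10,
    "log_level": 1,
}

with open("parameters.json", "w") as file:
    json.dump(params, file, indent=4)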