diff --git a/decodertest.py b/decodertest.py
deleted file mode 100644
index e150970..0000000
--- a/decodertest.py
+++ /dev/null
@@ -1,86 +0,0 @@
-# --*-- coding: utf-8 --*--
-"""
-train.py
-
-Launches the training process for the model.
-"""
-
-import os
-from pathlib import Path
-
-os.environ["KERAS_BACKEND"] = "jax"
-os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
-
-import jax
-import json
-import numpy as np
-#from pyspark.ml.feature import OneHotEncoder, StringIndexer
-from pyspark.sql import SparkSession, functions, types, Row
-import pyspark as spark
-import tensorflow as tf
-import keras
-import keras.metrics as metrics
-import matplotlib.pyplot as plt
-from sklearn.metrics import auc, confusion_matrix, roc_curve
-from sklearn.preprocessing import OneHotEncoder, LabelBinarizer
-
-from model.model import CompoundModel, DecoderModel
-from pipe.etl import etl, read
-import train.decoder_train as dt
-from visualize.plot import spectra_plot
-
-def multi_gpu():
-    devices = jax.devices("gpu")
-    mesh = keras.distribution.DeviceMesh(
-        shape=(len(devices),), axis_names=["model"], devices=devices
-    )
-    layout_map = keras.distribution.LayoutMap(mesh)
-    layout_map["dense.*kernel"] = (None, "model")
-    layout_map["dense.*bias"] = ("model",)
-    layout_map["dense.*kernel_regularizer"] = (None, "model")
-    layout_map["dense.*bias_regularizer"] = ("model",)
-    layout_map["dense.*activity_regularizer"] = (None,)
-    layout_map["dense.*kernel_constraint"] = (None, "model")
-    layout_map["dense.*bias_constraint"] = ("model",)
-    layout_map["conv1d.*kernel"] = (None, None, None, "model")
-    layout_map["conv1d.*kernel_regularizer"] = (None, None, None, "model")
-    layout_map["conv1d.*bias_regularizer"] = ("model",)
-
-    model_parallel = keras.distribution.ModelParallel(
-        layout_map=layout_map
-    )
-    keras.distribution.set_distribution(model_parallel)
-
-def main():
-
-    with open("parameters.json", "r") as file:
-        params = json.load(file)
-
-    spark = SparkSession.builder.appName("train").getOrCreate()
-
-    keys = ["treatment", "target"]
-
-    if params["load_from_scratch"]:
-        (train_set, validation_set,
-         test_set, categories) = etl(spark, split=params["split"])
-    else:
-        (train_set, validation_set,
-         test_set, categories) = read(spark, split=params["split"])
-
-
-    n_classes = [dset.shape[1] for dset in train_set[1]]
-    decoder = dt.load_decoder(params, train_set[0].shape[1:], n_classes)
-
-    decoder = dt.train_decoder(decoder, params, train_set, validation_set)
-    # Test model performance
-
-    metrics = {key: None for key in keys}
-    metrics, test_predict = dt.test_decoder(decoder, test_set, metrics)
-
-    spectra_plot(test_predict[0], name=None)
-
-
-if __name__ == "__main__":
-    main()
-
-# EOF
diff --git a/model/model.py b/model/model.py
index be7b97b..d2b5251 100644
--- a/model/model.py
+++ b/model/model.py
@@ -11,6 +11,7 @@
     GlobalAveragePooling1D, LayerNormalization, Masking, Conv2D, \
     MultiHeadAttention, concatenate
 
+from model.activation import sublinear
 from model.transformer import TimeseriesTransformerBuilder as TSTFBuilder
 
 class CompoundModel(Model):
@@ -205,7 +206,7 @@ def _modelstack(self, input_shape, head_size, num_heads, ff_dim,
         x = Conv1D(filters=input_shape[1],
                    kernel_size=1,
                    padding="valid",
-                   activation="linear")(x)
+                   activation=sublinear)(x)
 
         return Model(inputs, x)
 
diff --git a/train_encoder.py b/train_model.py
similarity index 100%
rename from train_encoder.py
rename to train_model.py
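
Note on the new activation: the `model/model.py` hunk replaces the final Conv1D's `"linear"` activation with `sublinear`, imported from `model/activation.py`, a module that is not part of this diff. Its real definition is therefore unknown; the sketch below is only a plausible illustration of what a sublinear activation could look like in Keras 3 (a signed logarithm, linear near zero but sublinear in |x| for large inputs). The formula and the serialization decorator are assumptions, not the author's code.

    # Hypothetical sketch of model/activation.py -- not included in this
    # diff. The signed-log formula is an assumption about what "sublinear"
    # means here, not the actual implementation.
    import keras
    from keras import ops

    @keras.saving.register_keras_serializable(package="model")
    def sublinear(x):
        """Signed log: sign(x) * log(1 + |x|).

        Approximately linear near zero, sign-preserving, and sublinear
        in |x| for large inputs, which damps extreme outputs.
        """
        return ops.sign(x) * ops.log1p(ops.abs(x))

Since Keras accepts any callable for `activation=`, the hunk's `activation=sublinear` behaves like a built-in activation name; registering the function for serialization keeps saved models loadable without passing a `custom_objects` mapping.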