diff --git a/train.py b/train.py index 762ee1d..b5085c2 100644 --- a/train.py +++ b/train.py @@ -23,21 +23,25 @@ from model.model import TimeSeriesTransformer as TSTF -# data parameters -SPLIT = [0.9, 0.05, 0.05] - -# model parameters -HEAD_SIZE = 32 -NUM_HEADS = 12 -FF_DIM = 32 -NUM_TRANSFORMER_BLOCKS = 12 -MLP_UNITS = [128] -DROPOUT = 0.3 -MLP_DROPOUT = 0.3 - -# training parameters -BATCH_SIZE = 8 -EPOCHS = 800 +with open("parameters.json", "r") as file: + params = json.load(file) + + # data parameters + SPLIT = params["split"] + + # model parameters + HEAD_SIZE = params["head_size"] + NUM_HEADS = params["num_heads"] + FF_DIM = params["ff_dim"] + NUM_TRANSFORMER_BLOCKS = params["num_transformer_blocks"] + MLP_UNITS = params["mlp_units"] + DROPOUT = params["dropout"] + MLP_DROPOUT = params["mlp_dropout"] + + # training parameters + BATCH_SIZE = params["batch_size"] + EPOCHS = params["epochs"] + del params def trim(dataframe, column): diff --git a/train_cpu.py b/train_cpu.py index d0d5ec9..de18b67 100644 --- a/train_cpu.py +++ b/train_cpu.py @@ -12,6 +12,7 @@ #os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" import jax +import json import numpy as np from pipe.pipe import SpectrogramPipe from pyspark.ml.feature import StringIndexer, IndexToString @@ -24,21 +25,25 @@ from model.autoencoder_smol import Autoencoder from model.model import TimeSeriesTransformer as TSTF -# data parameters -SPLIT = [0.9, 0.05, 0.05] - -# model parameters -HEAD_SIZE = 32 -NUM_HEADS = 12 -FF_DIM = 32 -NUM_TRANSFORMER_BLOCKS = 12 -MLP_UNITS = [128] -DROPOUT = 0.3 -MLP_DROPOUT = 0.3 - -# training parameters -BATCH_SIZE = 8 -EPOCHS = 800 +with open("parameters.json", "r") as file: + params = json.load(file) + + # data parameters + SPLIT = params["split"] + + # model parameters + HEAD_SIZE = params["head_size"] + NUM_HEADS = params["num_heads"] + FF_DIM = params["ff_dim"] + NUM_TRANSFORMER_BLOCKS = params["num_transformer_blocks"] + MLP_UNITS = params["mlp_units"] + DROPOUT = params["dropout"] + MLP_DROPOUT = params["mlp_dropout"] + + # training parameters + BATCH_SIZE = params["batch_size"] + EPOCHS = params["epochs"] + del params def trim(dataframe, column):