Skip to content

Commit

Permalink
constants moved out to parameter file
Browse files Browse the repository at this point in the history
  • Loading branch information
lim185 committed Jan 3, 2025
1 parent 721f2f8 commit 04558db
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 30 deletions.
34 changes: 19 additions & 15 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,25 @@

from model.model import TimeSeriesTransformer as TSTF

# data parameters
SPLIT = [0.9, 0.05, 0.05]

# model parameters
HEAD_SIZE = 32
NUM_HEADS = 12
FF_DIM = 32
NUM_TRANSFORMER_BLOCKS = 12
MLP_UNITS = [128]
DROPOUT = 0.3
MLP_DROPOUT = 0.3

# training parameters
BATCH_SIZE = 8
EPOCHS = 800
with open("parameters.json", "r") as file:
params = json.load(file)

# data parameters
SPLIT = params["split"]

# model parameters
HEAD_SIZE = params["head_size"]
NUM_HEADS = params["num_heads"]
FF_DIM = params["ff_dim"]
NUM_TRANSFORMER_BLOCKS = params["num_transformer_blocks"]
MLP_UNITS = params["mlp_units"]
DROPOUT = params["dropout"]
MLP_DROPOUT = params["mlp_dropout"]

# training parameters
BATCH_SIZE = params["batch_size"]
EPOCHS = params["epochs"]
del params

def trim(dataframe, column):

Expand Down
35 changes: 20 additions & 15 deletions train_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import json
import numpy as np
from pipe.pipe import SpectrogramPipe
from pyspark.ml.feature import StringIndexer, IndexToString
Expand All @@ -24,21 +25,25 @@
from model.autoencoder_smol import Autoencoder
from model.model import TimeSeriesTransformer as TSTF

# data parameters
SPLIT = [0.9, 0.05, 0.05]

# model parameters
HEAD_SIZE = 32
NUM_HEADS = 12
FF_DIM = 32
NUM_TRANSFORMER_BLOCKS = 12
MLP_UNITS = [128]
DROPOUT = 0.3
MLP_DROPOUT = 0.3

# training parameters
BATCH_SIZE = 8
EPOCHS = 800
with open("parameters.json", "r") as file:
params = json.load(file)

# data parameters
SPLIT = params["split"]

# model parameters
HEAD_SIZE = params["head_size"]
NUM_HEADS = params["num_heads"]
FF_DIM = params["ff_dim"]
NUM_TRANSFORMER_BLOCKS = params["num_transformer_blocks"]
MLP_UNITS = params["mlp_units"]
DROPOUT = params["dropout"]
MLP_DROPOUT = params["mlp_dropout"]

# training parameters
BATCH_SIZE = params["batch_size"]
EPOCHS = params["epochs"]
del params

def trim(dataframe, column):

Expand Down

0 comments on commit 04558db

Please sign in to comment.