Skip to content

Commit

Permalink
CPU version that predicts with ~33% accuracy
Browse files: browse the repository at this point in the history
  • Loading branch information
lim185 committed Jan 3, 2025
1 parent 17dcf1b commit f8268e2
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 27 deletions.
8 changes: 5 additions & 3 deletions model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
"""

from keras import Input, Model
from keras.layers import Conv1D, Dense, Dropout, GlobalAveragePooling1D, \
LayerNormalization, Masking, MultiHeadAttention
from keras.layers import BatchNormalization, Conv1D, Dense, Dropout, \
GlobalAveragePooling1D, LayerNormalization, Masking, \
MultiHeadAttention

class TimeSeriesTransformer(Model):

Expand Down Expand Up @@ -65,7 +66,8 @@ def _modelstack(self, input_shape, head_size, num_heads, ff_dim,
"""

inputs = Input(shape=input_shape)
x = Masking(mask_value=32.)(inputs)
#x = inputs
x = BatchNormalization()(inputs)
for _ in range(num_Transformer_blocks):
x = self._transformerblocks(x, head_size, num_heads, ff_dim,
dropout)
Expand Down
5 changes: 4 additions & 1 deletion pipe/pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def metadata_pipe(self, metapath: Path, labels:list,

def spectrogram_pipe_matfiles(self, specpath: Path, labels:list,
default_size: tuple = (32, 130),
pad_value: float = 32.) \
pad_value: float = 0.) \
-> DataFrame:
"""
Loads spectrograms for each stack iteration from a set of mat files,
Expand All @@ -84,6 +84,9 @@ def spectrogram_pipe_matfiles(self, specpath: Path, labels:list,
((default_size[0] - spectrogram.shape[0], 0),
(default_size[1] - spectrogram.shape[1], 0)),
mode="constant", constant_values=pad_value)
spectrogram[np.isnan(spectrogram)] = 0.
spectrogram[np.abs(spectrogram) == np.inf] = 0.
spectrogram = spectrogram / np.sum(spectrogram)
row["spectrogram"] = spectrogram.tolist()
spectrograms.append(Row(**row))

Expand Down
78 changes: 55 additions & 23 deletions train_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,31 @@
import jax
import numpy as np
from pipe.pipe import SpectrogramPipe
from pyspark.ml.feature import StringIndexer
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer, IndexToString
from pyspark.sql import SparkSession, functions
import tensorflow as tf
import keras

from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

from model.autoencoder_smol import Autoencoder
from model.model import TimeSeriesTransformer as TSTF

# data parameters
SPLIT = [0.98, 0.015, 0.05]
SPLIT = [0.9, 0.05, 0.05]

# model parameters
HEAD_SIZE = 256
NUM_HEADS = 4
FF_DIM = 4
NUM_TRANSFORMER_BLOCKS = 4
HEAD_SIZE = 32
NUM_HEADS = 12
FF_DIM = 32
NUM_TRANSFORMER_BLOCKS = 12
MLP_UNITS = [128]
DROPOUT = 0.2
DROPOUT = 0.3
MLP_DROPOUT = 0.3

# training parameters
BATCH_SIZE = 8
EPOCHS = 25
EPOCHS = 800

def trim(dataframe, column):

Expand All @@ -63,19 +60,43 @@ def get_data(spark, split=[0.99, 0.005, 0.005]):

indexer = StringIndexer(inputCol="treatment", outputCol="treatment_index")
indexed = indexer.fit(data).transform(data)

selected = indexed.select("treatment", "treatment_index").distinct()
selected = selected.sort("treatment_index")
index_max = selected.agg(functions.max("treatment_index")).collect()[0][0]

train_df, validation_df, test_df = indexed.randomSplit(split, seed=42)

trainx = trim(train_df, "spectrogram")
trainy = np.array(train_df.select("treatment_index").collect())
trainy = np.array(train_df.select("treatment_index").collect()).astype(int)
_trainy = np.zeros((len(trainy), int(index_max+1)))
for index, value in enumerate(trainy):
_trainy[index, value] = 1.
#_trainy[np.arange(trainy.shape[0]), trainy] = 1.
trainy = _trainy
del _trainy

valx = trim(validation_df, "spectrogram")
valy = np.array(validation_df.select("treatment_index").collect())
valy = np.array(validation_df.select("treatment_index").collect()) \
.astype(int)
_valy = np.zeros((len(valy), int(index_max+1)))
for index, value in enumerate(valy):
_valy[index, value] = 1.
#_valy[np.arange(valy.shape[0]), valy] = 1.
valy = _valy
del _valy

testx = trim(test_df, "spectrogram")
testy = np.array(test_df.select("treatment_index").collect())
testy = np.array(test_df.select("treatment_index").collect()).astype(int)
_testy = np.zeros((len(testy), int(index_max+1)))
for index, value in enumerate(testy):
_testy[index, value] = 1.
#_testy[np.arange(testy.shape[0]), testy] = 1.
testy = _testy
del _testy


return ((trainx, trainy), (valx, valy), (testx, testy))
return (selected, (trainx, trainy), (valx, valy), (testx, testy))

def get_model(input_shape, n_classes):
model = TSTF(input_shape, HEAD_SIZE, NUM_HEADS, FF_DIM,
Expand All @@ -88,13 +109,13 @@ def main():

spark = SparkSession.builder.appName("train").getOrCreate()

train_set, validation_set, test_set = get_data(spark, split=SPLIT)
indices, train_set, validation_set, test_set = get_data(spark, split=SPLIT)

n_classes = len(set(train_set[1].flatten()))
n_classes = indices.count()
model = get_model(train_set[0].shape[1:], n_classes)
model.compile(optimizer=keras.optimizers.Adam(),
loss="sparse_categorical_crossentropy",
metrics=["sparse_categorical_accuracy"]
model.compile(optimizer=keras.optimizers.Adam(learning_rate=4e-4),
loss="categorical_crossentropy",
metrics=["categorical_accuracy"]
)
model.summary()

Expand All @@ -104,6 +125,17 @@ def main():
batch_size=BATCH_SIZE, epochs=EPOCHS)
end = time.time()
print("Training time: ", end - start)

# Test model performance
test_loss, test_accuracy = model.evaluate(test_set[0], test_set[1])
test_predict = model.predict(test_set[0])
print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}")

conf_matrix = confusion_matrix(np.argmax(test_predict, axis=1),
np.argmax(test_set[1], axis=1))
plt.imshow(conf_matrix, origin="upper")
plt.gca().set_aspect("equal")
plt.savefig("confusion_matrix.png")

return

Expand Down

0 comments on commit f8268e2

Please sign in to comment.