From f8268e2a7661db26a28e7df133859432864e5bed Mon Sep 17 00:00:00 2001 From: maelstrom Date: Fri, 3 Jan 2025 00:19:12 -0500 Subject: [PATCH] CPU version that predicts with ~33% accuracy --- model/model.py | 8 ++++-- pipe/pipe.py | 5 +++- train_cpu.py | 78 +++++++++++++++++++++++++++++++++++--------------- 3 files changed, 64 insertions(+), 27 deletions(-) diff --git a/model/model.py b/model/model.py index 477e2cd..40b3270 100644 --- a/model/model.py +++ b/model/model.py @@ -4,8 +4,9 @@ """ from keras import Input, Model -from keras.layers import Conv1D, Dense, Dropout, GlobalAveragePooling1D, \ - LayerNormalization, Masking, MultiHeadAttention +from keras.layers import BatchNormalization, Conv1D, Dense, Dropout, \ + GlobalAveragePooling1D, LayerNormalization, Masking, \ + MultiHeadAttention class TimeSeriesTransformer(Model): @@ -65,7 +66,8 @@ def _modelstack(self, input_shape, head_size, num_heads, ff_dim, """ inputs = Input(shape=input_shape) - x = Masking(mask_value=32.)(inputs) + #x = inputs + x = BatchNormalization()(inputs) for _ in range(num_Transformer_blocks): x = self._transformerblocks(x, head_size, num_heads, ff_dim, dropout) diff --git a/pipe/pipe.py b/pipe/pipe.py index 4ac65ad..7282246 100644 --- a/pipe/pipe.py +++ b/pipe/pipe.py @@ -58,7 +58,7 @@ def metadata_pipe(self, metapath: Path, labels:list, def spectrogram_pipe_matfiles(self, specpath: Path, labels:list, default_size: tuple = (32, 130), - pad_value: float = 32.) \ + pad_value: float = 0.) \ -> DataFrame: """ Loads spectrograms for each stack iteration from a set of mat files, @@ -84,6 +84,9 @@ def spectrogram_pipe_matfiles(self, specpath: Path, labels:list, ((default_size[0] - spectrogram.shape[0], 0), (default_size[1] - spectrogram.shape[1], 0)), mode="constant", constant_values=pad_value) + spectrogram[np.isnan(spectrogram)] = 0. + spectrogram[np.abs(spectrogram) == np.inf] = 0. + spectrogram = spectrogram / np.sum(spectrogram) row["spectrogram"] = spectrogram.tolist() spectrograms.append(Row(**row)) diff --git a/train_cpu.py b/train_cpu.py index 439da4d..d0d5ec9 100644 --- a/train_cpu.py +++ b/train_cpu.py @@ -14,34 +14,31 @@ import jax import numpy as np from pipe.pipe import SpectrogramPipe -from pyspark.ml.feature import StringIndexer -from pyspark.sql import SparkSession +from pyspark.ml.feature import StringIndexer, IndexToString +from pyspark.sql import SparkSession, functions import tensorflow as tf import keras - -from jax.experimental import mesh_utils -from jax.sharding import Mesh -from jax.sharding import NamedSharding -from jax.sharding import PartitionSpec +import matplotlib.pyplot as plt +from sklearn.metrics import confusion_matrix from model.autoencoder_smol import Autoencoder from model.model import TimeSeriesTransformer as TSTF # data parameters -SPLIT = [0.98, 0.015, 0.05] +SPLIT = [0.9, 0.05, 0.05] # model parameters -HEAD_SIZE = 256 -NUM_HEADS = 4 -FF_DIM = 4 -NUM_TRANSFORMER_BLOCKS = 4 +HEAD_SIZE = 32 +NUM_HEADS = 12 +FF_DIM = 32 +NUM_TRANSFORMER_BLOCKS = 12 MLP_UNITS = [128] -DROPOUT = 0.2 +DROPOUT = 0.3 MLP_DROPOUT = 0.3 # training parameters BATCH_SIZE = 8 -EPOCHS = 25 +EPOCHS = 800 def trim(dataframe, column): @@ -63,19 +60,43 @@ def get_data(spark, split=[0.99, 0.005, 0.005]): indexer = StringIndexer(inputCol="treatment", outputCol="treatment_index") indexed = indexer.fit(data).transform(data) + + selected = indexed.select("treatment", "treatment_index").distinct() + selected = selected.sort("treatment_index") + index_max = selected.agg(functions.max("treatment_index")).collect()[0][0] train_df, validation_df, test_df = indexed.randomSplit(split, seed=42) trainx = trim(train_df, "spectrogram") - trainy = np.array(train_df.select("treatment_index").collect()) + trainy = np.array(train_df.select("treatment_index").collect()).astype(int) + _trainy = np.zeros((len(trainy), int(index_max+1))) + for index, value in enumerate(trainy): + _trainy[index, value] = 1. + #_trainy[np.arange(trainy.shape[0]), trainy] = 1. + trainy = _trainy + del _trainy valx = trim(validation_df, "spectrogram") - valy = np.array(validation_df.select("treatment_index").collect()) + valy = np.array(validation_df.select("treatment_index").collect()) \ + .astype(int) + _valy = np.zeros((len(valy), int(index_max+1))) + for index, value in enumerate(valy): + _valy[index, value] = 1. + #_valy[np.arange(valy.shape[0]), valy] = 1. + valy = _valy + del _valy testx = trim(test_df, "spectrogram") - testy = np.array(test_df.select("treatment_index").collect()) + testy = np.array(test_df.select("treatment_index").collect()).astype(int) + _testy = np.zeros((len(testy), int(index_max+1))) + for index, value in enumerate(testy): + _testy[index, value] = 1. + #_testy[np.arange(testy.shape[0]), testy] = 1. + testy = _testy + del _testy + - return ((trainx, trainy), (valx, valy), (testx, testy)) + return (selected, (trainx, trainy), (valx, valy), (testx, testy)) def get_model(input_shape, n_classes): model = TSTF(input_shape, HEAD_SIZE, NUM_HEADS, FF_DIM, @@ -88,13 +109,13 @@ def main(): spark = SparkSession.builder.appName("train").getOrCreate() - train_set, validation_set, test_set = get_data(spark, split=SPLIT) + indices, train_set, validation_set, test_set = get_data(spark, split=SPLIT) - n_classes = len(set(train_set[1].flatten())) + n_classes = indices.count() model = get_model(train_set[0].shape[1:], n_classes) - model.compile(optimizer=keras.optimizers.Adam(), - loss="sparse_categorical_crossentropy", - metrics=["sparse_categorical_accuracy"] + model.compile(optimizer=keras.optimizers.Adam(learning_rate=4e-4), + loss="categorical_crossentropy", + metrics=["categorical_accuracy"] ) model.summary() @@ -104,6 +125,17 @@ def main(): batch_size=BATCH_SIZE, epochs=EPOCHS) end = time.time() print("Training time: ", end - start) + + # Test model performance + test_loss, test_accuracy = model.evaluate(test_set[0], test_set[1]) + test_predict = model.predict(test_set[0]) + print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}") + + conf_matrix = confusion_matrix(np.argmax(test_predict, axis=1), + np.argmax(test_set[1], axis=1)) + plt.imshow(conf_matrix, origin="upper") + plt.gca().set_aspect("equal") + plt.savefig("confusion_matrix.png") return