From 42aced767ac15297c685309dc381fd73032a407b Mon Sep 17 00:00:00 2001
From: Dawith
Date: Thu, 12 Dec 2024 00:04:29 -0500
Subject: [PATCH] Training code now works

---
 model/autoencoder.py | 50 +++++++++++++++++++++++++--------
 train.py             | 67 ++++++++++++++++++++++++++++++++++++++-----
 2 files changed, 98 insertions(+), 19 deletions(-)

diff --git a/model/autoencoder.py b/model/autoencoder.py
index 8a6fe77..9d51eb8 100644
--- a/model/autoencoder.py
+++ b/model/autoencoder.py
@@ -1,40 +1,61 @@
 import keras
 
 class Encoder(keras.Model):
-    def __init__(self, input_size=(130, 26), latent_dim=128):
+    def __init__(self, input_size=(26, 130, 1), latent_dim=128):
         super(Encoder, self).__init__()
         self.encoder = keras.Sequential([
-            keras.layers.InputLayer(input_shape=input_size),
-            keras.layers.Conv2D(8, (3, 3), activation="relu",
-                                input_shape=input_size),
-            keras.layers.Conv2D(16, (3, 3), activation="relu",
-                                input_shape=input_size),
-            keras.layers.Conv2D(32, (3, 3), activation="relu",
-                                input_shape=input_size),
+            keras.layers.InputLayer(shape=input_size),
+            keras.layers.Conv2D(8, (3, 3), activation="relu"),
+            keras.layers.MaxPooling2D((2, 2)),
+            keras.layers.Conv2D(16, (3, 3), activation="relu"),
+            #keras.layers.MaxPooling2D((2, 2)),
+            keras.layers.Conv2D(32, (3, 3), activation="relu"),
             keras.layers.Flatten(),
+            keras.layers.Dense(latent_dim*128, activation="relu"),
+            keras.layers.Dense(latent_dim*64, activation="relu"),
+            keras.layers.Dense(latent_dim*32, activation="relu"),
+            keras.layers.Dense(latent_dim*16, activation="relu"),
+            keras.layers.Dense(latent_dim*8, activation="relu"),
+            keras.layers.Dense(latent_dim*4, activation="relu"),
+            keras.layers.Dense(latent_dim*2, activation="relu"),
             keras.layers.Dense(latent_dim, activation="relu")
         ])
 
     def call(self, x):
         return self.encoder(x)
 
+    def summary(self):
+        self.encoder.summary()
+
 class Decoder(keras.Model):
     def __init__(self, latent_dim=128):
         super(Decoder, self).__init__()
         self.decoder = keras.Sequential([
-            keras.layers.InputLayer(input_shape=(latent_dim,)),
-            keras.layers.Reshape((4, 4, 8)),
+            keras.layers.InputLayer(shape=(latent_dim,)),
+            keras.layers.Dense(latent_dim*2, activation="relu"),
+            keras.layers.Dense(latent_dim*4, activation="relu"),
+            keras.layers.Dense(latent_dim*8, activation="relu"),
+            keras.layers.Dense(latent_dim*16, activation="relu"),
+            keras.layers.Dense(latent_dim*32, activation="relu"),
+            keras.layers.Dense(latent_dim*64, activation="relu"),
+            keras.layers.Dense(latent_dim*128, activation="relu"),
+            keras.layers.Dense(15360, activation="relu"),
+            keras.layers.Reshape((8, 60, 32)),
             keras.layers.Conv2DTranspose(32, (3, 3), activation="relu"),
+            #keras.layers.UpSampling2D((2, 2)),
             keras.layers.Conv2DTranspose(16, (3, 3), activation="relu"),
-            keras.layers.Conv2DTranspose(8, (3, 3), activation="relu"),
+            keras.layers.UpSampling2D((2, 2)),
             keras.layers.Conv2DTranspose(1, (3, 3), activation="sigmoid")
         ])
 
     def call(self, x):
         return self.decoder(x)
 
+    def summary(self):
+        self.decoder.summary()
+
 class Autoencoder(keras.Model):
-    def __init__(self, input_size=(130, 26), latent_dim=128, **kwargs):
+    def __init__(self, input_size=(26, 130, 1), latent_dim=128, **kwargs):
         super(Autoencoder, self).__init__()
         self.encoder = Encoder(input_size=input_size, latent_dim=latent_dim)
         self.decoder = Decoder(latent_dim=latent_dim)
@@ -43,3 +64,8 @@ def call(self, x):
         encoded = self.encoder(x)
         decoded = self.decoder(encoded)
         return decoded
+
+    def summary(self):
+        super().summary()
+        self.encoder.summary()
+        self.decoder.summary()
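
Note (reviewer sketch, not part of the patch): the new shapes round-trip. With
"valid" padding the encoder maps (26, 130, 1) -> (24, 128, 8) -> pool ->
(12, 64, 8) -> (10, 62, 16) -> (8, 60, 32), which flattens to 15360; that is
where the decoder's Dense(15360) and Reshape((8, 60, 32)) come from, and the
transposed convolutions plus the 2x2 upsample invert it back to (26, 130, 1).
A quick check, assuming the patched model/autoencoder.py is importable under
Keras 3:

import os
os.environ["KERAS_BACKEND"] = "jax"  # same backend train.py selects

import numpy as np
from model.autoencoder import Autoencoder

# Batch of 4 random "spectrograms"; the batch size is arbitrary.
model = Autoencoder(input_size=(26, 130, 1), latent_dim=128)
x = np.random.rand(4, 26, 130, 1).astype("float32")

assert model.encoder(x).shape == (4, 128)  # latent code
assert model(x).shape == x.shape           # reconstruction round-trips
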
diff --git a/train.py b/train.py
index 7fba5f4..246d083 100644
--- a/train.py
+++ b/train.py
@@ -4,14 +4,35 @@
 Launches the training process for the model.
 """
 
+import os
 from pathlib import Path
 
+os.environ["KERAS_BACKEND"] = "jax"
+os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
+
+import jax
+import numpy as np
 from pipe.pipe import SpectrogramPipe
+from pyspark.ml.feature import StringIndexer
 from pyspark.sql import SparkSession
+import tensorflow as tf
+import keras
+
+from jax.experimental import mesh_utils
+from jax.sharding import Mesh
+from jax.sharding import NamedSharding
+from jax.sharding import PartitionSpec
 
 from model.autoencoder import Autoencoder
 
-def get_data(spark):
+def trim(dataframe, column):
+
+    ndarray = np.array(dataframe.select(column).collect()) \
+        .reshape(-1, 26, 130, 1)
+
+    return ndarray
+
+def get_data(spark, split=[0.99, 0.005, 0.005]):
 
     path = Path("/app/datadump/train")
     labels = []
@@ -22,21 +43,53 @@ def get_data(spark):
     pipe = SpectrogramPipe(spark)
     data = pipe.spectrogram_pipe(path, labels)
 
-    return data.randomSplit([0.8, 0.1, 0.1], seed=42)
+    indexer = StringIndexer(inputCol="treatment", outputCol="treatment_index")
+    indexed = indexer.fit(data).transform(data)
+
+    train_df, validation_df, test_df = indexed.randomSplit(split, seed=42)
+    print(train_df.count())
+    print(validation_df.count())
+    print(test_df.count())
+
+    trainx = trim(train_df, "spectrogram")
+    trainy = np.array(train_df.select("treatment_index").collect())
+
+    valx = trim(validation_df, "spectrogram")
+    valy = np.array(validation_df.select("treatment_index").collect())
+
+    testx = trim(test_df, "spectrogram")
+    testy = np.array(test_df.select("treatment_index").collect())
+
+    return ((trainx, trainy), (valx, valy), (testx, testy))
 
 def get_model():
 
     model = Autoencoder()
 
-    return
+    return model
 
 def main():
+    # jax mesh setup
+    devices = jax.devices("gpu")
+    mesh = keras.distribution.DeviceMesh(
+        shape=(1, 2), axis_names=["data", "model"], devices=devices
+    )
+    layout_map = keras.distribution.LayoutMap(mesh)
+    layout_map["d1/kernel"] = (None, "model")
+    layout_map["d1/bias"] = ("model",)
+    layout_map["d2/output"] = ("model", None)
+    model_parallel = keras.distribution.ModelParallel(
+        layout_map=layout_map, batch_dim_name="data"
+    )
+    keras.distribution.set_distribution(model_parallel)
+
     spark = SparkSession.builder.appName("train").getOrCreate()
-    train_df, validation_df, test_df = get_data(spark)
-    print(train_df.count())
-    print(validation_df.count())
-    print(test_df.count())
+    train_set, validation_set, test_set = get_data(spark)
 
     model = get_model()
+    model.summary()
+
+    model.compile(optimizer="adam", loss="mean_squared_error")
+    model.fit(x=train_set[0], y=train_set[0], batch_size=1, epochs=100)
 
     return
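
Note (reviewer sketch, not part of the patch): if I read the Keras 3 docs
right, LayoutMap keys are regular expressions matched against variable paths,
and nothing in Autoencoder names a layer "d1" or "d2", so the three sharding
rules above most likely match no variables; training still runs, just without
model-parallel sharding. Giving the wide Dense layers explicit names (e.g.
name="d1") would let the existing rules attach. Independently of Spark, the
compile/fit path can be smoke-tested with synthetic arrays shaped like
trim()'s (-1, 26, 130, 1) output; all names below are illustrative:

import os
os.environ["KERAS_BACKEND"] = "jax"

import numpy as np
from model.autoencoder import Autoencoder

# Synthetic stand-in for the spectrogram pipeline output.
x = np.random.rand(8, 26, 130, 1).astype("float32")

model = Autoencoder()
model.compile(optimizer="adam", loss="mean_squared_error")
# An autoencoder reconstructs its input, so x is both input and target,
# matching model.fit(x=train_set[0], y=train_set[0]) in train.py.
model.fit(x=x, y=x, batch_size=4, epochs=1)
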