From 91125c9746e0890c93c3ea721238e97a680095da Mon Sep 17 00:00:00 2001 From: maelstrom Date: Sat, 7 Dec 2024 22:22:08 -0500 Subject: [PATCH] autoencoder added as a test run --- model/autoencoder.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ train.py | 20 ++++++++++++++++---- 2 files changed, 60 insertions(+), 4 deletions(-) create mode 100644 model/autoencoder.py diff --git a/model/autoencoder.py b/model/autoencoder.py new file mode 100644 index 0000000..51597ce --- /dev/null +++ b/model/autoencoder.py @@ -0,0 +1,44 @@ +import keras + +class Encoder(keras.Model): + def __init__(self, input_size=(130, 26)): + super(Encoder, self).__init__() + self.encoder = keras.Sequential([ + keras.layers.Conv2D(8, (3, 3), activation="relu", + input_shape=input_size), + keras.layers.Conv2D(16, (3, 3), activation="relu", + input_shape=input_size), + keras.layers.Conv2D(32, (3, 3), activation="relu", + input_shape=input_size), + keras.layers.Flatten(), + keras.layers.Dense(128, activation="relu") + ]) + + def call(self, x): + return self.encoder(x) + +class Decoder(keras.Model): + def __init__(self): + super(Decoder, self).__init__() + self.decoder = keras.Sequential([ + keras.layers.Dense(128, activation="relu"), + keras.layers.Reshape((4, 4, 8)), + keras.layers.Conv2DTranspose(32, (3, 3), activation="relu"), + keras.layers.Conv2DTranspose(16, (3, 3), activation="relu"), + keras.layers.Conv2DTranspose(8, (3, 3), activation="relu"), + keras.layers.Conv2DTranspose(1, (3, 3), activation="sigmoid") + ]) + + def call(self, x): + return self.decoder(x) + +class Autoencoder(keras.Model): + def __init__(self, input_size=(130, 26): + super(Autoencoder, self).__init__() + self.encoder = Encoder(input_size=input_size) + self.decoder = Decoder() + + def call(self, x): + encoded = self.encoder(x) + decoded = self.decoder(encoded) + return decoded diff --git a/train.py b/train.py index d550750..7fba5f4 100644 --- a/train.py +++ b/train.py @@ -9,7 +9,9 @@ from pipe.pipe import SpectrogramPipe from pyspark.sql import SparkSession -def main(): +from model.autoencoder import Autoencoder + +def get_data(spark): path = Path("/app/datadump/train") labels = [] @@ -17,14 +19,24 @@ def main(): for line in file: labels.append(line.strip().split(",")[0]) - spark = SparkSession.builder.appName("train").getOrCreate() pipe = SpectrogramPipe(spark) data = pipe.spectrogram_pipe(path, labels) - train_df, validation_df, test_df = data.randomSplit([0.8, 0.1, 0.1], - seed=42) + + return data.randomSplit([0.8, 0.1, 0.1], seed=42) + +def get_model(): + model = Autoencoder() + return + +def main(): + spark = SparkSession.builder.appName("train").getOrCreate() + + train_df, validation_df, test_df = get_data(spark) print(train_df.count()) print(validation_df.count()) print(test_df.count()) + + model = get_model() return