autoencoder added as a test run
lim185 committed Dec 8, 2024
1 parent ffe9db0 commit 91125c9
Showing 2 changed files with 60 additions and 4 deletions.
44 changes: 44 additions & 0 deletions model/autoencoder.py
@@ -0,0 +1,44 @@
import keras

class Encoder(keras.Model):
    """Compresses a single-channel spectrogram into a 128-dim latent vector."""

    def __init__(self, input_size=(130, 26, 1)):
        # input_size is (time, frequency, channels); Conv2D needs the channel axis
        super(Encoder, self).__init__()
        self.encoder = keras.Sequential([
            keras.Input(shape=input_size),
            keras.layers.Conv2D(8, (3, 3), activation="relu"),
            keras.layers.Conv2D(16, (3, 3), activation="relu"),
            keras.layers.Conv2D(32, (3, 3), activation="relu"),
            keras.layers.Flatten(),
            keras.layers.Dense(128, activation="relu")
        ])

    def call(self, x):
        return self.encoder(x)

class Decoder(keras.Model):
    """Expands a 128-dim latent vector back into a single-channel feature map."""

    def __init__(self):
        super(Decoder, self).__init__()
        self.decoder = keras.Sequential([
            keras.layers.Dense(128, activation="relu"),
            # 128 units rearranged into a small 4x4x8 feature map
            keras.layers.Reshape((4, 4, 8)),
            keras.layers.Conv2DTranspose(32, (3, 3), activation="relu"),
            keras.layers.Conv2DTranspose(16, (3, 3), activation="relu"),
            keras.layers.Conv2DTranspose(8, (3, 3), activation="relu"),
            # 3x3 kernels at stride 1 grow the map to 12x12x1, which does not yet
            # match the 130x26 input spectrograms
            keras.layers.Conv2DTranspose(1, (3, 3), activation="sigmoid")
        ])

    def call(self, x):
        return self.decoder(x)

class Autoencoder(keras.Model):
    """Encoder followed by decoder; returns the reconstruction of the input."""

    def __init__(self, input_size=(130, 26, 1)):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder(input_size=input_size)
        self.decoder = Decoder()

    def call(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
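
A minimal smoke test for the classes above (not part of this commit), assuming Keras 3 with any backend and random NumPy input standing in for the SpectrogramPipe output:

import numpy as np

from model.autoencoder import Autoencoder

# Fake batch of eight single-channel spectrogram-shaped inputs
x = np.random.rand(8, 130, 26, 1).astype("float32")

model = Autoencoder(input_size=(130, 26, 1))
reconstruction = model(x)

# The encoder compresses to a 128-dim latent vector; the decoder currently
# upsamples only to (12, 12, 1), so the reconstruction does not yet match the
# (130, 26, 1) input that an MSE reconstruction loss would need.
print(reconstruction.shape)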
20 changes: 16 additions & 4 deletions train.py
@@ -9,22 +9,34 @@
 from pipe.pipe import SpectrogramPipe
 from pyspark.sql import SparkSession
 
-def main():
+from model.autoencoder import Autoencoder
+
+def get_data(spark):
     path = Path("/app/datadump/train")
 
     labels = []
     with open(path / "train.csv", "r") as file:
         for line in file:
             labels.append(line.strip().split(",")[0])
 
-    spark = SparkSession.builder.appName("train").getOrCreate()
     pipe = SpectrogramPipe(spark)
     data = pipe.spectrogram_pipe(path, labels)
-    train_df, validation_df, test_df = data.randomSplit([0.8, 0.1, 0.1],
-                                                        seed=42)
 
+    return data.randomSplit([0.8, 0.1, 0.1], seed=42)
+
+def get_model():
+    model = Autoencoder()
+    return model
+
+def main():
+    spark = SparkSession.builder.appName("train").getOrCreate()
+
+    train_df, validation_df, test_df = get_data(spark)
     print(train_df.count())
     print(validation_df.count())
     print(test_df.count())
 
+    model = get_model()
+
     return
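
A sketch of how the split data and get_model() might come together later (not part of this commit). It assumes the spectrograms can be collected from the Spark DataFrames into NumPy — the column name "spectrogram" is hypothetical — and that the decoder is first adjusted to reproduce the (130, 26, 1) input shape so an MSE reconstruction loss applies:

import numpy as np

def to_array(df, column="spectrogram"):
    # Hypothetical helper: collect a (small) DataFrame column into a NumPy batch
    # and add the channel axis the Conv2D stack expects.
    rows = df.select(column).collect()
    batch = np.stack([np.asarray(row[column], dtype="float32") for row in rows])
    return batch[..., np.newaxis]

def train(model, train_df, validation_df, epochs=10):
    x_train = to_array(train_df)
    x_val = to_array(validation_df)
    model.compile(optimizer="adam", loss="mse")
    # The autoencoder is trained to reproduce its own input.
    model.fit(x_train, x_train, validation_data=(x_val, x_val), epochs=epochs)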
