Training code now works
Dawith committed Dec 12, 2024
1 parent d3ebe35 commit 42aced7
Showing 2 changed files with 98 additions and 19 deletions.
model/autoencoder.py (50 changes: 38 additions & 12 deletions)
@@ -1,40 +1,61 @@
 import keras
 
 class Encoder(keras.Model):
-    def __init__(self, input_size=(130, 26), latent_dim=128):
+    def __init__(self, input_size=(26, 130, 1), latent_dim=128):
         super(Encoder, self).__init__()
         self.encoder = keras.Sequential([
-            keras.layers.InputLayer(input_shape=input_size),
-            keras.layers.Conv2D(8, (3, 3), activation="relu",
-                input_shape=input_size),
-            keras.layers.Conv2D(16, (3, 3), activation="relu",
-                input_shape=input_size),
-            keras.layers.Conv2D(32, (3, 3), activation="relu",
-                input_shape=input_size),
+            keras.layers.InputLayer(shape=input_size),
+            keras.layers.Conv2D(8, (3, 3), activation="relu"),
+            keras.layers.MaxPooling2D((2, 2)),
+            keras.layers.Conv2D(16, (3, 3), activation="relu"),
+            #keras.layers.MaxPooling2D((2, 2)),
+            keras.layers.Conv2D(32, (3, 3), activation="relu"),
             keras.layers.Flatten(),
+            keras.layers.Dense(latent_dim*128, activation="relu"),
+            keras.layers.Dense(latent_dim*64, activation="relu"),
+            keras.layers.Dense(latent_dim*32, activation="relu"),
+            keras.layers.Dense(latent_dim*16, activation="relu"),
+            keras.layers.Dense(latent_dim*8, activation="relu"),
+            keras.layers.Dense(latent_dim*4, activation="relu"),
+            keras.layers.Dense(latent_dim*2, activation="relu"),
             keras.layers.Dense(latent_dim, activation="relu")
         ])
 
     def call(self, x):
         return self.encoder(x)
 
+    def summary(self):
+        self.encoder.summary()
+
 class Decoder(keras.Model):
     def __init__(self, latent_dim=128):
         super(Decoder, self).__init__()
         self.decoder = keras.Sequential([
-            keras.layers.InputLayer(input_shape=(latent_dim,)),
-            keras.layers.Reshape((4, 4, 8)),
+            keras.layers.InputLayer(shape=(latent_dim,)),
+            keras.layers.Dense(latent_dim*2, activation="relu"),
+            keras.layers.Dense(latent_dim*4, activation="relu"),
+            keras.layers.Dense(latent_dim*8, activation="relu"),
+            keras.layers.Dense(latent_dim*16, activation="relu"),
+            keras.layers.Dense(latent_dim*32, activation="relu"),
+            keras.layers.Dense(latent_dim*64, activation="relu"),
+            keras.layers.Dense(latent_dim*128, activation="relu"),
+            keras.layers.Dense(15360, activation="relu"),
+            keras.layers.Reshape((8, 60, 32)),
             keras.layers.Conv2DTranspose(32, (3, 3), activation="relu"),
+            #keras.layers.UpSampling2D((2, 2)),
             keras.layers.Conv2DTranspose(16, (3, 3), activation="relu"),
             keras.layers.Conv2DTranspose(8, (3, 3), activation="relu"),
+            keras.layers.UpSampling2D((2, 2)),
             keras.layers.Conv2DTranspose(1, (3, 3), activation="sigmoid")
         ])
 
     def call(self, x):
         return self.decoder(x)
 
+    def summary(self):
+        self.decoder.summary()
+
 class Autoencoder(keras.Model):
-    def __init__(self, input_size=(130, 26), latent_dim=128, **kwargs):
+    def __init__(self, input_size=(26, 130, 1), latent_dim=128, **kwargs):
         super(Autoencoder, self).__init__()
         self.encoder = Encoder(input_size=input_size, latent_dim=latent_dim)
         self.decoder = Decoder(latent_dim=latent_dim)
@@ -43,3 +64,8 @@ def call(self, x):
         encoded = self.encoder(x)
         decoded = self.decoder(encoded)
         return decoded
+
+    def summary(self):
+        super().summary()
+        self.encoder.summary()
+        self.decoder.summary()
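A note on shapes (not part of the commit): with valid padding and stride 1, each (3, 3) Conv2D trims two pixels from each spatial dimension, so the new encoder maps (26, 130, 1) -> (24, 128, 8) -> (12, 64, 8) after pooling -> (10, 62, 16) -> (8, 60, 32), and Flatten yields 8*60*32 = 15360 features, exactly what the decoder's Dense(15360) and Reshape((8, 60, 32)) stage expects. Stride-1 Conv2DTranspose layers grow each dimension by two and UpSampling2D doubles it, so the decoder's final output shape is worth verifying against the input before training. A minimal sketch for that check, assuming the package layout used in train.py:

import os
os.environ["KERAS_BACKEND"] = "jax"  # same backend choice as train.py

import numpy as np
from model.autoencoder import Autoencoder

model = Autoencoder(input_size=(26, 130, 1), latent_dim=128)
batch = np.random.rand(2, 26, 130, 1).astype("float32")  # dummy spectrograms
recon = model(batch)
print(batch.shape, "->", recon.shape)  # reconstruction dims should match the input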
train.py (67 changes: 60 additions & 7 deletions)
@@ -4,14 +4,35 @@
 Launches the training process for the model.
 """
 
+import os
+from pathlib import Path
+
+os.environ["KERAS_BACKEND"] = "jax"
+os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
+
+import jax
 import numpy as np
 from pipe.pipe import SpectrogramPipe
+from pyspark.ml.feature import StringIndexer
 from pyspark.sql import SparkSession
 import tensorflow as tf
+import keras
+
+from jax.experimental import mesh_utils
+from jax.sharding import Mesh
+from jax.sharding import NamedSharding
+from jax.sharding import PartitionSpec
 
 from model.autoencoder import Autoencoder
 
-def get_data(spark):
+def trim(dataframe, column):
+
+    ndarray = np.array(dataframe.select(column).collect()) \
+        .reshape(-1, 26, 130, 1)
+
+    return ndarray
+
+def get_data(spark, split=[0.99, 0.005, 0.005]):
     path = Path("/app/datadump/train")
 
     labels = []
@@ -22,21 +43,53 @@ def get_data(spark):
     pipe = SpectrogramPipe(spark)
     data = pipe.spectrogram_pipe(path, labels)
 
-    return data.randomSplit([0.8, 0.1, 0.1], seed=42)
+    indexer = StringIndexer(inputCol="treatment", outputCol="treatment_index")
+    indexed = indexer.fit(data).transform(data)
+
+    train_df, validation_df, test_df = indexed.randomSplit(split, seed=42)
+    print(train_df.count())
+    print(validation_df.count())
+    print(test_df.count())
+
+    trainx = trim(train_df, "spectrogram")
+    trainy = np.array(train_df.select("treatment_index").collect())
+
+    valx = trim(validation_df, "spectrogram")
+    valy = np.array(validation_df.select("treatment_index").collect())
+
+    testx = trim(test_df, "spectrogram")
+    testy = np.array(test_df.select("treatment_index").collect())
+
+    return ((trainx, trainy), (valx, valy), (testx, testy))
 
 def get_model():
     model = Autoencoder()
-    return
+    return model
 
 def main():
+    # jax mesh setup
+    devices = jax.devices("gpu")
+    mesh = keras.distribution.DeviceMesh(
+        shape=(1, 2), axis_names=["data", "model"], devices=devices
+    )
+    layout_map = keras.distribution.LayoutMap(mesh)
+    layout_map["d1/kernel"] = (None, "model")
+    layout_map["d1/bias"] = ("model",)
+    layout_map["d2/output"] = ("model", None)
+    model_parallel = keras.distribution.ModelParallel(
+        layout_map=layout_map, batch_dim_name="data"
+    )
+    keras.distribution.set_distribution(model_parallel)
+
+
     spark = SparkSession.builder.appName("train").getOrCreate()
 
-    train_df, validation_df, test_df = get_data(spark)
-    print(train_df.count())
-    print(validation_df.count())
-    print(test_df.count())
+    train_set, validation_set, test_set = get_data(spark)
 
     model = get_model()
+    model.summary()
+
+    model.compile(optimizer="adam", loss="mean_squared_error")
+    model.fit(x=train_set[0], y=train_set[0], batch_size=1, epochs=100)
 
     return
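The mesh setup in main() follows the Keras 3 distribution API for the JAX backend. One caveat worth flagging: LayoutMap keys are matched against variable paths, and the rules here ("d1/kernel", "d1/bias", "d2/output") appear to mirror the Keras documentation example rather than this model's actual layer names, so variables whose paths match no rule fall back to the default layout. A data-parallel-only sketch (an alternative, not the commit's method) that sidesteps path matching entirely, assuming two or more visible GPUs:

import os
os.environ["KERAS_BACKEND"] = "jax"

import jax
import keras

# Shard the batch across all GPUs and replicate the weights; DataParallel
# needs no per-variable layout rules.
devices = jax.devices("gpu")
mesh = keras.distribution.DeviceMesh(
    shape=(len(devices),), axis_names=["data"], devices=devices
)
keras.distribution.set_distribution(
    keras.distribution.DataParallel(device_mesh=mesh)
)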
