From d2fc802f6b58bb26d6e40a2250e2ad9008d0f66e Mon Sep 17 00:00:00 2001 From: Dawith Date: Tue, 17 Dec 2024 09:02:09 -0500 Subject: [PATCH] Rework autoencoder dense stacks; use 1-D model-parallel mesh and time training --- model/autoencoder.py | 10 ++++++++-- train.py | 25 ++++++++++++++++++------- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/model/autoencoder.py b/model/autoencoder.py index 9d51eb8..7bea8ac 100644 --- a/model/autoencoder.py +++ b/model/autoencoder.py @@ -11,8 +11,11 @@ def __init__(self, input_size=(26, 130, 1), latent_dim=128): #keras.layers.MaxPooling2D((2, 2)), keras.layers.Conv2D(32, (3, 3), activation="relu"), keras.layers.Flatten(), - keras.layers.Dense(latent_dim*128, activation="relu"), keras.layers.Dense(latent_dim*64, activation="relu"), + keras.layers.Dense(latent_dim*64, activation="relu"), + keras.layers.Dense(latent_dim*64, activation="relu"), + keras.layers.Dense(latent_dim*32, activation="relu"), + keras.layers.Dense(latent_dim*32, activation="relu"), keras.layers.Dense(latent_dim*32, activation="relu"), keras.layers.Dense(latent_dim*16, activation="relu"), keras.layers.Dense(latent_dim*8, activation="relu"), @@ -37,8 +40,11 @@ def __init__(self, latent_dim=128): keras.layers.Dense(latent_dim*8, activation="relu"), keras.layers.Dense(latent_dim*16, activation="relu"), keras.layers.Dense(latent_dim*32, activation="relu"), + keras.layers.Dense(latent_dim*32, activation="relu"), + keras.layers.Dense(latent_dim*32, activation="relu"), + keras.layers.Dense(latent_dim*64, activation="relu"), + keras.layers.Dense(latent_dim*64, activation="relu"), keras.layers.Dense(latent_dim*64, activation="relu"), - keras.layers.Dense(latent_dim*128, activation="relu"), keras.layers.Dense(15360, activation="relu"), keras.layers.Reshape((8, 60, 32)), keras.layers.Conv2DTranspose(32, (3, 3), activation="relu"), diff --git a/train.py b/train.py index 246d083..2cc2a90 100644 --- a/train.py +++ b/train.py @@ -6,6 +6,7 @@ import os from pathlib import Path +import time os.environ["KERAS_BACKEND"] = "jax" 
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" @@ -69,17 +70,24 @@ def main(): # jax mesh setup devices = jax.devices("gpu") mesh = keras.distribution.DeviceMesh( - shape=(1, 2), axis_names=["data", "model"], devices=devices + shape=(2,), axis_names=["model"], devices=devices ) layout_map = keras.distribution.LayoutMap(mesh) - layout_map["d1/kernel"] = (None, "model") - layout_map["d1/bias"] = ("model",) - layout_map["d2/output"] = ("model", None) + layout_map["dense.*kernel"] = (None, "model") + layout_map["dense.*bias"] = ("model",) + layout_map["dense.*kernel_regularizer"] = (None, "model") + layout_map["dense.*bias_regularizer"] = ("model",) + layout_map["dense.*activity_regularizer"] = (None,) + layout_map["dense.*kernel_constraint"] = (None, "model") + layout_map["dense.*bias_constraint"] = ("model",) + layout_map["conv2d.*kernel"] = (None, None, None, "model") + layout_map["conv2d.*kernel_regularizer"] = (None, None, None, "model") + layout_map["conv2d.*bias_regularizer"] = ("model",) + model_parallel = keras.distribution.ModelParallel( - layout_map=layout_map, batch_dim_name="data" + layout_map=layout_map ) keras.distribution.set_distribution(model_parallel) - spark = SparkSession.builder.appName("train").getOrCreate() @@ -89,7 +97,10 @@ def main(): model.summary() model.compile(optimizer="adam", loss="mean_squared_error") - model.fit(x=train_set[0], y=train_set[0], batch_size=1, epochs=100) + start = time.time() + model.fit(x=train_set[0], y=train_set[0], batch_size=2, epochs=50) + end = time.time() + print("Training time: ", end - start) return