
Commit

Deepen autoencoder dense stacks, switch to model-only sharding, and time training
Dawith committed Dec 17, 2024
1 parent 42aced7 commit d2fc802
Showing 2 changed files with 26 additions and 9 deletions.
10 changes: 8 additions & 2 deletions model/autoencoder.py
@@ -11,8 +11,11 @@ def __init__(self, input_size=(26, 130, 1), latent_dim=128):
             #keras.layers.MaxPooling2D((2, 2)),
             keras.layers.Conv2D(32, (3, 3), activation="relu"),
             keras.layers.Flatten(),
-            keras.layers.Dense(latent_dim*128, activation="relu"),
             keras.layers.Dense(latent_dim*64, activation="relu"),
+            keras.layers.Dense(latent_dim*64, activation="relu"),
+            keras.layers.Dense(latent_dim*64, activation="relu"),
             keras.layers.Dense(latent_dim*32, activation="relu"),
+            keras.layers.Dense(latent_dim*32, activation="relu"),
+            keras.layers.Dense(latent_dim*32, activation="relu"),
             keras.layers.Dense(latent_dim*16, activation="relu"),
             keras.layers.Dense(latent_dim*8, activation="relu"),
@@ -37,8 +40,11 @@ def __init__(self, latent_dim=128):
             keras.layers.Dense(latent_dim*8, activation="relu"),
             keras.layers.Dense(latent_dim*16, activation="relu"),
             keras.layers.Dense(latent_dim*32, activation="relu"),
+            keras.layers.Dense(latent_dim*32, activation="relu"),
+            keras.layers.Dense(latent_dim*32, activation="relu"),
             keras.layers.Dense(latent_dim*64, activation="relu"),
+            keras.layers.Dense(latent_dim*64, activation="relu"),
+            keras.layers.Dense(latent_dim*64, activation="relu"),
-            keras.layers.Dense(latent_dim*128, activation="relu"),
             keras.layers.Dense(15360, activation="relu"),
             keras.layers.Reshape((8, 60, 32)),
             keras.layers.Conv2DTranspose(32, (3, 3), activation="relu"),
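Taken together, the two hunks drop the single Dense(latent_dim*128) layer from each stack and triplicate the latent_dim*64 and latent_dim*32 layers, so the widest Dense is now latent_dim*64 = 8192 units. Below is a sketch of the resulting stacks reconstructed from the diff. Only the layers visible in the hunks are certain; the Input layers, the final latent projection, and anything behind the collapsed regions are assumptions added so the sketch is self-contained.

import keras

latent_dim = 128

# Encoder: Conv2D front end, then the dense ladder from the first hunk.
encoder = keras.Sequential([
    keras.layers.Input(shape=(26, 130, 1)),  # input_size from __init__
    keras.layers.Conv2D(32, (3, 3), activation="relu"),
    keras.layers.Flatten(),
    keras.layers.Dense(latent_dim*64, activation="relu"),
    keras.layers.Dense(latent_dim*64, activation="relu"),
    keras.layers.Dense(latent_dim*64, activation="relu"),
    keras.layers.Dense(latent_dim*32, activation="relu"),
    keras.layers.Dense(latent_dim*32, activation="relu"),
    keras.layers.Dense(latent_dim*32, activation="relu"),
    keras.layers.Dense(latent_dim*16, activation="relu"),
    keras.layers.Dense(latent_dim*8, activation="relu"),
    keras.layers.Dense(latent_dim, activation="relu"),  # assumed latent layer
])

# Decoder: mirrors the ladder back up, then widens to 15360 units;
# 8 * 60 * 32 = 15360, so the Reshape consumes the vector exactly.
decoder = keras.Sequential([
    keras.layers.Input(shape=(latent_dim,)),  # assumed
    keras.layers.Dense(latent_dim*8, activation="relu"),
    keras.layers.Dense(latent_dim*16, activation="relu"),
    keras.layers.Dense(latent_dim*32, activation="relu"),
    keras.layers.Dense(latent_dim*32, activation="relu"),
    keras.layers.Dense(latent_dim*32, activation="relu"),
    keras.layers.Dense(latent_dim*64, activation="relu"),
    keras.layers.Dense(latent_dim*64, activation="relu"),
    keras.layers.Dense(latent_dim*64, activation="relu"),
    keras.layers.Dense(15360, activation="relu"),
    keras.layers.Reshape((8, 60, 32)),
    keras.layers.Conv2DTranspose(32, (3, 3), activation="relu"),
    # remaining transpose-conv layers are collapsed in the diff
])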
25 changes: 18 additions & 7 deletions train.py
@@ -6,6 +6,7 @@
 
 import os
 from pathlib import Path
+import time
 
 os.environ["KERAS_BACKEND"] = "jax"
 os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
@@ -69,17 +70,24 @@ def main():
     # jax mesh setup
     devices = jax.devices("gpu")
     mesh = keras.distribution.DeviceMesh(
-        shape=(1, 2), axis_names=["data", "model"], devices=devices
+        shape=(2,), axis_names=["model"], devices=devices
     )
     layout_map = keras.distribution.LayoutMap(mesh)
-    layout_map["d1/kernel"] = (None, "model")
-    layout_map["d1/bias"] = ("model",)
-    layout_map["d2/output"] = ("model", None)
+    layout_map["dense.*kernel"] = (None, "model")
+    layout_map["dense.*bias"] = ("model",)
+    layout_map["dense.*kernel_regularizer"] = (None, "model")
+    layout_map["dense.*bias_regularizer"] = ("model",)
+    layout_map["dense.*activity_regularizer"] = (None,)
+    layout_map["dense.*kernel_constraint"] = (None, "model")
+    layout_map["dense.*bias_constraint"] = ("model",)
+    layout_map["conv2d.*kernel"] = (None, None, None, "model")
+    layout_map["conv2d.*kernel_regularizer"] = (None, None, None, "model")
+    layout_map["conv2d.*bias_regularizer"] = ("model",)
 
     model_parallel = keras.distribution.ModelParallel(
-        layout_map=layout_map, batch_dim_name="data"
+        layout_map=layout_map
     )
     keras.distribution.set_distribution(model_parallel)
 
 
     spark = SparkSession.builder.appName("train").getOrCreate()
 
@@ -89,7 +97,10 @@
     model.summary()
 
     model.compile(optimizer="adam", loss="mean_squared_error")
-    model.fit(x=train_set[0], y=train_set[0], batch_size=1, epochs=100)
+    start = time.time()
+    model.fit(x=train_set[0], y=train_set[0], batch_size=2, epochs=50)
+    end = time.time()
+    print("Training time: ", end - start)
 
     return
 
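In train.py, the mesh hunk drops the unused "data" axis: with shape=(2,) and a single "model" axis, both GPUs hold shards of the weights, and batch_dim_name="data" goes away with it. It also swaps the per-layer keys ("d1/kernel", "d1/bias", "d2/output") for regex keys matched against every variable path. Here is a minimal standalone sketch of that Keras 3 distribution setup, trimmed to the keys that actually name variables (the *_regularizer and *_constraint entries in the commit match no variable paths and so likely never bind):

import os
os.environ["KERAS_BACKEND"] = "jax"

import jax
import keras

# A 1-D mesh: both GPUs sit on the single "model" axis, so matching
# weights are sharded across devices rather than replicated.
devices = jax.devices("gpu")
mesh = keras.distribution.DeviceMesh(
    shape=(2,), axis_names=["model"], devices=devices
)

# Keys are regexes over variable paths (e.g. "sequential/dense_3/kernel");
# each tuple entry maps one tensor dimension: None = replicate,
# "model" = split along that mesh axis.
layout_map = keras.distribution.LayoutMap(mesh)
layout_map["dense.*kernel"] = (None, "model")               # split output units
layout_map["dense.*bias"] = ("model",)
layout_map["conv2d.*kernel"] = (None, None, None, "model")  # split out-channels

# No batch_dim_name: with no "data" axis there is no data parallelism.
keras.distribution.set_distribution(
    keras.distribution.ModelParallel(layout_map=layout_map)
)

Models built after set_distribution pick up these layouts automatically, which is why the autoencoder has to be constructed after this point in main().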

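The time.time() bracket around model.fit in the last hunk yields a single wall-clock figure for the whole run, JIT compilation included. For per-epoch numbers, a Keras callback is the usual alternative; a sketch under that assumption, where EpochTimer is a hypothetical helper and not part of this commit:

import time
import keras

class EpochTimer(keras.callbacks.Callback):
    """Collects wall-clock seconds for each epoch."""

    def on_train_begin(self, logs=None):
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self._start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time.time() - self._start)

# usage:
#   timer = EpochTimer()
#   model.fit(x, y, batch_size=2, epochs=50, callbacks=[timer])
#   print(timer.times[0], "s for epoch 1 (includes compilation)")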