Skip to content

Commit

Permalink
Trying to fix input dimensions.
Browse files Browse the repository at this point in the history
  • Loading branch information
lim185 committed Dec 8, 2024
1 parent 0f17e2d commit d3ebe35
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions model/autoencoder.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,28 @@
import keras

class Encoder(keras.Model):
    """Convolutional encoder that maps a 2-D feature map to a latent vector.

    The input (default (130, 26), presumably time x feature frames — TODO
    confirm against the training pipeline) is passed through three Conv2D
    layers, flattened, and projected to ``latent_dim`` units.
    """

    def __init__(self, input_size=(130, 26), latent_dim=128, **kwargs):
        """Build the encoder network.

        Args:
            input_size: Spatial shape of one sample, either (H, W) or
                (H, W, C). Conv2D requires a channel axis, so a 2-D shape
                is extended with a single channel — this keeps the original
                (130, 26) default working instead of failing at build time.
            latent_dim: Size of the latent representation.
            **kwargs: Forwarded to ``keras.Model`` (e.g. ``name=``).
        """
        super(Encoder, self).__init__(**kwargs)
        if len(input_size) == 2:
            # Conv2D expects (H, W, C); add a singleton channel axis.
            input_size = (*input_size, 1)
        self.encoder = keras.Sequential([
            # The InputLayer fixes the shape; per-layer input_shape kwargs
            # on the Conv2D layers were redundant and have been removed.
            keras.layers.InputLayer(input_shape=input_size),
            keras.layers.Conv2D(8, (3, 3), activation="relu"),
            keras.layers.Conv2D(16, (3, 3), activation="relu"),
            keras.layers.Conv2D(32, (3, 3), activation="relu"),
            keras.layers.Flatten(),
            keras.layers.Dense(latent_dim, activation="relu"),
        ])

    def call(self, x):
        """Encode a batch of inputs into latent vectors."""
        return self.encoder(x)

class Decoder(keras.Model):
def __init__(self):
def __init__(self, latent_dim=128):
super(Decoder, self).__init__()
self.decoder = keras.Sequential([
keras.layers.Dense(128, activation="relu"),
keras.layers.InputLayer(input_shape=(latent_dim,)),
keras.layers.Reshape((4, 4, 8)),
keras.layers.Conv2DTranspose(32, (3, 3), activation="relu"),
keras.layers.Conv2DTranspose(16, (3, 3), activation="relu"),
Expand All @@ -33,10 +34,10 @@ def call(self, x):
return self.decoder(x)

class Autoencoder(keras.Model):
def __init__(self, input_size=(130, 26)):
def __init__(self, input_size=(130, 26), latent_dim=128, **kwargs):
super(Autoencoder, self).__init__()
self.encoder = Encoder(input_size=input_size)
self.decoder = Decoder()
self.encoder = Encoder(input_size=input_size, latent_dim=latent_dim)
self.decoder = Decoder(latent_dim=latent_dim)

def call(self, x):
encoded = self.encoder(x)
Expand Down

0 comments on commit d3ebe35

Please sign in to comment.