working on decoder
Dawith committed Oct 21, 2025
1 parent 3f43970 commit 465dd25
Showing 2 changed files with 121 additions and 48 deletions.
137 changes: 92 additions & 45 deletions model/model.py
@@ -93,37 +93,109 @@ def _modelstack(self, input_shape, head_size, num_heads, ff_dim,

return Model(inputs, [y, z])

def _transformerblocks(self, inputs, head_size, num_heads,
ff_dim, dropout):
def call(self, inputs):
"""
Constructs the transformer block. This consists of multi-head
attention, dropout, layer normalization, a residual connection,
a feedforward neural network, and another residual connection.
Calls the TimeSeriesTransformer model on a batch of inputs.
Args:
inputs: Tensor, batch of input data.
Returns:
Tensor, resulting output of the TimeSeriesTransformer model.
"""
return self.timeseriestransformer(inputs)

def summary(self):
"""
Prints a summary of the TimeSeriesTransformer model.
Args:
None.
Returns:
None.
"""
self.timeseriestransformer.summary()

class DecoderModel(Model):

def __init__(self, input_shape, head_size, num_heads, ff_dim,
num_Transformer_blocks, mlp_units, n_classes,
dropout=0, mlp_dropout=0):
"""
Initializes the DecoderModel class. This class is a wrapper around
a Keras model that consists of dense input projections, an MLP, and
a series of Transformer decoder blocks.
Args:
input_shape: tuple, shape of the input tensor.
head_size: int, the number of features in each attention head.
num_heads: int, the number of attention heads.
ff_dim: int, the number of neurons in the feedforward neural
network.
num_Transformer_blocks: int, the number of Transformer blocks.
mlp_units: list of ints, the number of neurons in each layer of
the MLP.
n_classes: list of ints, the number of output classes, one per
input branch.
dropout: float, dropout rate.
mlp_dropout: float, dropout rate in the MLP.
Attributes:
timeseriestransdecoder: Keras model, the Transformer decoder
model.
"""
super(DecoderModel, self).__init__()
self.tstfbuilder = TSTFBuilder()

self.timeseriestransdecoder = self._modelstack(
input_shape, head_size, num_heads, ff_dim,
num_Transformer_blocks, mlp_units, n_classes,
dropout, mlp_dropout)

def _modelstack(self, input_shape, head_size, num_heads, ff_dim,
num_Transformer_blocks, mlp_units, n_classes,
dropout, mlp_dropout):
"""
Creates the Transformer decoder model. This consists of dense
projections of the two inputs, an MLP, and a series of Transformer
decoder blocks.
Args:
input_shape: tuple, shape of the input tensor.
head_size: int, the number of features in each attention head.
num_heads: int, the number of attention heads.
ff_dim: int, the number of neurons in the feedforward neural
network.
num_Transformer_blocks: int, the number of Transformer blocks.
mlp_units: list of ints, the number of neurons in each layer of
the MLP.
n_classes: list of ints, the number of output classes.
dropout: float, dropout rate.
mlp_dropout: float, dropout rate in the MLP.
Returns:
A model layer.
"""
x = MultiHeadAttention(
key_dim=head_size, num_heads=num_heads,
dropout=dropout)(inputs, inputs)
x = Dropout(dropout)(x)
x = LayerNormalization(epsilon=1e-6)(x)
res = x + inputs

x = Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(res)
x = Dropout(dropout)(x)
x = Conv1D(filters=inputs.shape[-1], kernel_size=1)(x)
outputs = Dropout(dropout)(x) + res

return outputs
A Keras model.
"""

# Two separate input branches of the same shape (assumed intent); a
# single Input layer cannot be unpacked into two tensors.
inputs = [Input(shape=input_shape), Input(shape=input_shape)]
x1, x2 = inputs
x1 = Dense(n_classes[0], activation="relu")(x1)
x2 = Dense(n_classes[1], activation="relu")(x2)
x = (x1 + x2)
x = GlobalAveragePooling1D(data_format="channels_first")(x)

for dim in mlp_units:
x = Dense(dim, activation="relu")(x)
x = Dropout(mlp_dropout)(x)

for _ in range(num_Transformer_blocks):
x = self.tstfbuilder.build_decoderblock(
x,
head_size,
num_heads,
ff_dim,
dropout
)

return Model(inputs, x)

def call(self, inputs):
"""
@@ -149,29 +221,4 @@ def summary(self):
"""
self.timeseriestransdecoder.summary()

'''
def compile(self, loss="sparse_categorical_crossentropy",
optimizer="adam",
metrics=["sparse_categorical_accuracy"]):
"""
Compiles the TimeSeriesTransformer model.
Args:
loss: str, loss function.
optimizer: str, optimizer.
metrics: list of str, evaluation metrics.
Returns:
None.
"""
super()
self.timeseriestransformer.compile(
loss="sparse_categorical_crossentropy",
optimizer="adam",
metrics=["sparse_categorical_accuracy"])
return
'''

# EOF
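
# A minimal, runnable sketch (assumed values, not from the repository) of the
# front end that DecoderModel._modelstack builds before its decoder stack:
# two input branches projected with Dense layers, summed, pooled over time,
# and passed through a small MLP.
import numpy as np
from tensorflow.keras.layers import (Dense, Dropout, GlobalAveragePooling1D,
                                     Input)
from tensorflow.keras.models import Model

input_shape = (128, 8)        # assumed: 128 timesteps, 8 features
n_classes = [10, 10]          # width of each Dense projection head
mlp_units = [64, 32]          # one Dense layer per entry
mlp_dropout = 0.1

in1, in2 = Input(shape=input_shape), Input(shape=input_shape)
x = (Dense(n_classes[0], activation="relu")(in1)
     + Dense(n_classes[1], activation="relu")(in2))
# data_format="channels_first" mirrors the repository code: it averages over
# the last axis, leaving one value per timestep.
x = GlobalAveragePooling1D(data_format="channels_first")(x)
for dim in mlp_units:
    x = Dense(dim, activation="relu")(x)
    x = Dropout(mlp_dropout)(x)

front_end = Model([in1, in2], x)
batch = [np.zeros((4,) + input_shape, dtype="float32")] * 2
print(front_end(batch).shape)   # (4, 32) with the values assumed above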
32 changes: 29 additions & 3 deletions model/transformer.py
@@ -40,9 +40,14 @@ def __init__(self):
def build_transformerblock(self, inputs, head_size, num_heads,
ff_dim, dropout):
"""
Constructs the transformer block. This consists of multi-head
attention, dropout, layer normalization, a residual connection,
a feedforward neural network, and another residual connection.
Constructs the transformer block. A transformer block consists of the
following steps:
1. multi-head attention
2. dropout
3. layer normalization
4. residual connection
5. feedforward neural network
6. residual connection
Args:
inputs: Tensor, batch of input data.
Expand All @@ -69,6 +74,27 @@ def build_transformerblock(self, inputs, head_size, num_heads,

return outputs
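
# Shape-preservation sketch (class/import path assumed): the encoder block's
# second Conv1D projects back to inputs.shape[-1] before each residual
# addition, so the block's output shape matches its input shape and blocks
# can be stacked directly.
from tensorflow.keras import Input, Model
from model.transformer import TSTFBuilder   # assumed import path

inp = Input(shape=(128, 8))                  # assumed toy shape
out = TSTFBuilder().build_transformerblock(
    inp, head_size=64, num_heads=4, ff_dim=32, dropout=0.1)
Model(inp, out).summary()                    # final output shape: (None, 128, 8)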

def build_decoderblock(self, inputs, head_size, num_heads, ff_dim,
dropout):
"""
Constructs the decoder block. This applies the encoder block's
components in roughly reverse order: layer normalization, a
feedforward neural network with dropout, a residual connection, and
finally masked (causal) multi-head attention.
Args:
inputs: Tensor, batch of input data.
head_size: int, the number of features in each attention head.
num_heads: int, the number of attention heads.
ff_dim: int, the number of neurons in the feedforward neural
network.
dropout: float, dropout rate.
Returns:
Tensor, output of the decoder block.
"""

x = LayerNormalization(epsilon=1e-6)(inputs)
x = Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(x)
x = Dropout(dropout)(x)
x = Conv1D(filters=inputs.shape[-1], kernel_size=1)(x)
x = Dropout(dropout)(x)
res = x + inputs
outputs = MultiHeadAttention(
key_dim=head_size, num_heads=num_heads,
dropout=dropout)(res, res, use_causal_mask=True)

return outputs
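
# Decoder-block sketch (import path assumed): use_causal_mask=True restricts
# each timestep's attention to positions at or before it, the standard
# autoregressive masking for a decoder. The pointwise Conv1D and
# normalization steps act per timestep, so the whole block stays causal.
from tensorflow.keras import Input, Model
from model.transformer import TSTFBuilder   # assumed import path

inp = Input(shape=(128, 8))                  # assumed toy shape
out = TSTFBuilder().build_decoderblock(
    inp, head_size=64, num_heads=4, ff_dim=32, dropout=0.1)
Model(inp, out).summary()                    # output shape: (None, 128, 8)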

def call(self, inputs):
"""
Calls the TimeSeriesTransformer model on a batch of inputs.
