Skip to content

Commit

Permalink
Model structure slightly modified
Browse files Browse the repository at this point in the history
  • Loading branch information
Dawith committed Mar 2, 2026
1 parent 4878ea5 commit 2f9298c
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
GlobalAveragePooling1D, LayerNormalization, Masking, Conv2D, \
MultiHeadAttention, concatenate

from model.activation import sublinear
from model.activation import sublinear, linear
from model.transformer import TimeseriesTransformerBuilder as TSTFBuilder

class CompoundModel(Model):
Expand Down Expand Up @@ -77,6 +77,8 @@ def _modelstack(self, input_shape, head_size, num_heads, ff_dim,
#x = inputs
#inputs = Masking(mask_value=pad_value)(inputs)
x = BatchNormalization()(inputs)

# Transformer blocks
for _ in range(num_Transformer_blocks):
x = self.tstfbuilder.build_transformerblock(
x,
Expand All @@ -85,12 +87,18 @@ def _modelstack(self, input_shape, head_size, num_heads, ff_dim,
ff_dim,
dropout
)

# Pooling and simple DNN block
x = GlobalAveragePooling1D(data_format="channels_first")(x)
for dim in mlp_units:
x = Dense(dim, activation="relu")(x)
x = Dropout(mlp_dropout)(x)
y = Dense(n_classes[0], activation="softmax")(x)
z = Dense(n_classes[1], activation="softmax")(x)

# Two separate latent spaces supported
#y = Dense(n_classes[0], activation="softmax")(x)
#z = Dense(n_classes[1], activation="softmax")(x)
y = Dense(n_classes[0], activation="relu")(x)
z = Dense(n_classes[1], activation="relu")(x)

return Model(inputs, [y, z])

Expand Down Expand Up @@ -189,10 +197,11 @@ def _modelstack(self, input_shape, head_size, num_heads, ff_dim,
x = Dense(full_dimension, activation="relu")(x)
x = Reshape((input_shape[0], input_shape[1]))(x)

"""
for dim in mlp_units:
x = Dense(dim, activation="relu")(x)
x = Dropout(mlp_dropout)(x)

"""
for _ in range(num_Transformer_blocks):
x = self.tstfbuilder.build_transformerblock(
x,
Expand All @@ -206,7 +215,7 @@ def _modelstack(self, input_shape, head_size, num_heads, ff_dim,
x = Conv1D(filters=input_shape[1],
kernel_size=1,
padding="valid",
activation=sublinear)(x)
activation=linear)(x)

return Model(inputs, x)

Expand Down

0 comments on commit 2f9298c

Please sign in to comment.