-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Multiple data streams integrated together in an intermediate fusion pipeline
- Loading branch information
Showing
6 changed files
with
163 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
model.py | ||
Contains the core model architecture for geonosis. | ||
""" | ||
|
||
from keras import Input, Model | ||
from keras.layers import BatchNormalization, Conv1D, Dense, Dropout, Reshape, \ | ||
GlobalAveragePooling1D, LayerNormalization, Masking, \ | ||
MultiHeadAttention | ||
|
||
class TimeseriesTransformerBuilder:
    """
    Builds a Transformer-based classifier for time-series data.

    The network is a stack of Transformer encoder blocks followed by
    global average pooling and an MLP classification head. The assembled
    Keras model is exposed via the ``timeseriestransformer`` attribute.
    """

    def __init__(self, input_shape=(None, 1), head_size=256, num_heads=4,
                 ff_dim=4, num_transformer_blocks=4, mlp_units=(128,),
                 n_classes=2, dropout=0.0, mlp_dropout=0.0):
        """
        Initializes the TimeseriesTransformerBuilder and constructs the
        underlying Keras model: a series of Transformer blocks followed
        by an MLP classification head.

        Args:
            input_shape: tuple, shape of the input tensor, excluding the
                batch dimension (timesteps, features).
            head_size: int, the number of features in each attention head.
            num_heads: int, the number of attention heads.
            ff_dim: int, the number of neurons in the feedforward neural
                network.
            num_transformer_blocks: int, the number of Transformer blocks.
            mlp_units: iterable of ints, the number of neurons in each
                layer of the MLP head.
            n_classes: int, the number of output classes.
            dropout: float, dropout rate inside the Transformer blocks.
            mlp_dropout: float, dropout rate in the MLP head.

        Attributes:
            timeseriestransformer: Keras Model, the TimeSeriesTransformer
                model.
        """
        # NOTE(review): the original __init__ documented these arguments
        # but neither accepted them nor built the model, so call() and
        # summary() always raised AttributeError. Keyword defaults keep
        # the zero-argument constructor backward compatible.
        inputs = Input(shape=input_shape)
        x = inputs
        for _ in range(num_transformer_blocks):
            x = self.build_transformerblock(x, head_size, num_heads,
                                            ff_dim, dropout)

        # Collapse the time dimension before the classification head.
        x = GlobalAveragePooling1D()(x)
        for units in mlp_units:
            x = Dense(units, activation="relu")(x)
            x = Dropout(mlp_dropout)(x)
        outputs = Dense(n_classes, activation="softmax")(x)

        self.timeseriestransformer = Model(inputs, outputs)

    def build_transformerblock(self, inputs, head_size, num_heads,
                               ff_dim, dropout):
        """
        Constructs one Transformer encoder block: multi-head self-attention,
        dropout, layer normalization, a residual connection, a pointwise
        feedforward network (two Conv1D layers with kernel size 1), and a
        second residual connection.

        Args:
            inputs: Tensor, batch of input data.
            head_size: int, the number of features in each attention head.
            num_heads: int, the number of attention heads.
            ff_dim: int, the number of neurons in the feedforward neural
                network.
            dropout: float, dropout rate.

        Returns:
            Tensor, the output of the Transformer block, with the same
            feature dimension as ``inputs``.
        """
        # Self-attention sub-layer (query and key/value are both `inputs`).
        x = MultiHeadAttention(
            key_dim=head_size, num_heads=num_heads,
            dropout=dropout)(inputs, inputs)
        x = Dropout(dropout)(x)
        x = LayerNormalization(epsilon=1e-6)(x)
        res = x + inputs

        # Pointwise feedforward sub-layer; the second Conv1D projects back
        # to the input feature dimension so the residual add is valid.
        x = Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(res)
        x = Dropout(dropout)(x)
        x = Conv1D(filters=inputs.shape[-1], kernel_size=1)(x)
        outputs = Dropout(dropout)(x) + res

        return outputs

    def call(self, inputs):
        """
        Calls the TimeSeriesTransformer model on a batch of inputs.

        Args:
            inputs: Tensor, batch of input data.

        Returns:
            Tensor, resulting output of the TimeSeriesTransformer model.
        """
        return self.timeseriestransformer(inputs)

    def summary(self):
        """
        Prints a summary of the TimeSeriesTransformer model.

        Args:
            None.

        Returns:
            None.
        """
        self.timeseriestransformer.summary()


# EOF
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters