Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions model/decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# -*- encoding: utf-8 -*-

# 3rd party module imports
from keras import Model
from keras.layers import Input, Dense, Reshape, Conv1D

# Local module imports
from model.activation import sublinear, linear
from model.transformer import TimeseriesTransformerBuilder as TSTFBuilder

class Decoder(Model):

    def __init__(self, input_shape, head_size, num_heads, ff_dim,
                 num_Transformer_blocks, mlp_units,
                 dropout=0, mlp_dropout=0):
        """
        Decoder that reconstructs a (timesteps, channels) series from a
        latent vector: a Dense expansion, a Reshape to the target series
        shape, a stack of Transformer blocks, and a 1x1 Conv1D projection.
        Args:
            input_shape: tuple, (timesteps, channels) shape of the
                reconstructed output tensor.
            head_size: int, the number of features in each attention head.
            num_heads: int, the number of attention heads.
            ff_dim: int, the number of neurons in the feedforward neural
                network.
            num_Transformer_blocks: int, the number of Transformer blocks.
            mlp_units: list of ints, the encoder MLP layer widths; the last
                entry is the latent dimension this decoder consumes.
            dropout: float, dropout rate inside the Transformer blocks.
            mlp_dropout: float, currently unused by the decoder stack; kept
                for signature symmetry with the Encoder.
        Attributes:
            decoder: Keras model, the functional decoder stack.
        """
        # Keras requires super().__init__() to run before any attribute is
        # assigned on a Model subclass; assigning self.tstfbuilder first
        # (as the original did) raises at construction time.
        super(Decoder, self).__init__()
        self.tstfbuilder = TSTFBuilder()
        self.decoder = self._modelstack(
            input_shape, head_size, num_heads, ff_dim,
            num_Transformer_blocks, mlp_units,
            dropout, mlp_dropout)

    def _modelstack(self, input_shape, head_size, num_heads, ff_dim,
                    num_Transformer_blocks, mlp_units,
                    dropout, mlp_dropout):
        """
        Creates the functional decoder model.
        Args:
            input_shape: tuple, (timesteps, channels) of the output tensor.
            head_size: int, the number of features in each attention head.
            num_heads: int, the number of attention heads.
            ff_dim: int, the number of neurons in the feedforward neural
                network.
            num_Transformer_blocks: int, the number of Transformer blocks.
            mlp_units: list of ints, the encoder MLP layer widths; only the
                last entry (the latent dimension) is used here.
            dropout: float, dropout rate inside the Transformer blocks.
            mlp_dropout: float, unused in this stack.
        Returns:
            A Keras model mapping (mlp_units[-1],) -> input_shape.
        """

        inputs = Input(shape=(mlp_units[-1],), name="decoder_input")

        # Expand the latent vector to the full flattened output size, then
        # reshape to (timesteps, channels) for the Transformer blocks.
        full_dimension = input_shape[0] * input_shape[1]
        x = Dense(full_dimension, activation="relu", name="dec_dense1")(inputs)
        x = Reshape((input_shape[0], input_shape[1]), name="dec_reshape")(x)

        for _ in range(num_Transformer_blocks):
            x = self.tstfbuilder.build_transformerblock(
                x,
                head_size,
                num_heads,
                ff_dim,
                dropout
            )

        # 1x1 convolution projects back to the channel count; the linear
        # activation leaves the reconstruction unbounded.
        x = Conv1D(filters=input_shape[1],
                   kernel_size=1,
                   padding="valid",
                   activation=linear,
                   name="dec_conv1d")(x)

        return Model(inputs, x, name="decoder")

    def call(self, inputs):
        """
        Runs the decoder on a batch of latent vectors.
        Args:
            inputs (Tensor): batch of latent vectors, shape
                (batch, mlp_units[-1]).
        Returns:
            (Tensor) Decoded reconstruction of the spectral data.
        """
        return self.decoder(inputs)

    def summary(self):
        """
        Prints the wrapped functional decoder model summary.
        """
        self.decoder.summary()

# EOF
119 changes: 119 additions & 0 deletions model/encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# -*- coding: utf-8 -*-
"""
Model Encoder block
"""

# Third party module imports
from keras import Input, Model
from keras.layers import BatchNormalization, Dense, Dropout, GlobalAveragePooling1D

# Local module imports
from model.transformer import TimeseriesTransformerBuilder as TSTFBuilder

class Encoder(Model):
    """
    Transformer-based encoder that maps a (timesteps, channels) series to
    a latent vector via BatchNorm, Transformer blocks, pooling and an MLP.
    Args:
        input_shape (tuple): Shape of the input tensor
            (timesteps, channels).
        head_size (int): Number of features in each attention head.
        num_heads (int): Number of attention heads.
        ff_dim (int): Number of neurons in the feedforward neural network.
        num_Transformer_blocks (int): Number of Transformer blocks.
        mlp_units (List(int)): Number of neurons in each layer of the MLP.
        dropout (float): Dropout rate inside the Transformer blocks.
        mlp_dropout (float): Dropout rate in the MLP.
    Attributes:
        encoder: Keras model, the functional encoder stack.
    """

    def __init__(self, input_shape, head_size, num_heads, ff_dim,
                 num_Transformer_blocks, mlp_units,
                 dropout=0, mlp_dropout=0):
        # Keras requires super().__init__() to run before any attribute is
        # assigned on a Model subclass; assigning self.tstfbuilder first
        # (as the original did) raises at construction time.
        super(Encoder, self).__init__()
        self.tstfbuilder = TSTFBuilder()
        self.encoder = self._modelstack(
            input_shape,
            head_size,
            num_heads,
            ff_dim,
            num_Transformer_blocks,
            mlp_units,
            dropout,
            mlp_dropout
        )

    def _modelstack(self, input_shape, head_size, num_heads, ff_dim,
                    num_Transformer_blocks, mlp_units, dropout, mlp_dropout):
        """
        Creates the functional encoder model: BatchNormalization, a series
        of Transformer blocks, global pooling, then an MLP with dropout.
        Args:
            input_shape: tuple, shape of the input tensor.
            head_size: int, the number of features in each attention head.
            num_heads: int, the number of attention heads.
            ff_dim: int, the number of neurons in the feedforward neural
                network.
            num_Transformer_blocks: int, the number of Transformer blocks.
            mlp_units: list of ints, the number of neurons in each layer of
                the MLP.
            dropout: float, dropout rate inside the Transformer blocks.
            mlp_dropout: float, dropout rate in the MLP.
        Returns:
            A Keras model mapping input_shape -> (mlp_units[-1],).
        """

        inputs = Input(shape=input_shape)
        x = BatchNormalization()(inputs)

        # Transformer blocks
        for _ in range(num_Transformer_blocks):
            x = self.tstfbuilder.build_transformerblock(
                x,
                head_size,
                num_heads,
                ff_dim,
                dropout
            )

        # Pooling and simple DNN block.
        # NOTE(review): data_format="channels_first" makes the pooling
        # average over the LAST axis of the (batch, steps, features)
        # tensor -- confirm this reduction is intentional.
        x = GlobalAveragePooling1D(data_format="channels_first")(x)
        for dim in mlp_units:
            x = Dense(dim, activation="relu")(x)
            x = Dropout(mlp_dropout)(x)

        # Two separate latent spaces supported

        return Model(inputs, x)

    def call(self, inputs):
        """
        Runs the encoder on a batch of inputs.
        Args:
            inputs: Tensor, batch of input data.
        Returns:
            Tensor, latent representation of shape (batch, mlp_units[-1]).
        """
        return self.encoder(inputs)

    def summary(self):
        """
        Prints the wrapped functional encoder model summary.
        """
        self.encoder.summary()

# EOF
139 changes: 139 additions & 0 deletions model/latent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-

# 3rd party module imports
import keras
from keras import Input, Model
from keras.layers import Dense, Layer
import keras.ops as ops

# Local module imports
from model.losses import categorical_crossentropy, mean_absolute_error

class Latent(Model):
    def __init__(self, mlp_dim, semantic_dims, var_dim):
        """
        Combined latent block: concatenates a semantic (supervised)
        embedding with a variational embedding and projects back to
        mlp_dim.
        Args:
            mlp_dim: int, width of the encoder output / latent input.
            semantic_dims: list of ints, class counts for the categorical
                heads; the last entry is the regression head width.
            var_dim: int, width of the variational embedding.
        """
        # Model.name is a read-only property in Keras, so the name must be
        # passed to the parent constructor (the original assigned
        # self.name -- before super().__init__(), and with the string
        # "Semantic Embedding Block", an apparent copy/paste from
        # Semantic -- which raises at construction time).
        super(Latent, self).__init__(name="latent_block")
        self.semantic = Semantic(mlp_dim, semantic_dims)
        self.variational = Variational(mlp_dim, var_dim)
        self.latent = self._build(mlp_dim, semantic_dims, var_dim)

    def _build(self, mlp_dim, semantic_dims, var_dim):
        """
        Builds the functional latent model.
        Args:
            mlp_dim: int, width of the main input and of the output.
            semantic_dims: list of ints, target widths for the semantic
                heads (one extra Input per entry).
            var_dim: int, width of the variational embedding.
        Returns:
            Model: maps [features, *semantic_targets] to
            [latent, semantic_error, kl_loss].
        """

        inputs = Input(shape=(mlp_dim,))
        sem_inputs = [Input(shape=(dim,)) for dim in semantic_dims]
        s, s_err = self.semantic([inputs] + sem_inputs)
        v, kl_loss = self.variational(inputs)
        # Fuse both embeddings back into a single mlp_dim-wide vector.
        x = Dense(mlp_dim, activation="relu")(ops.concatenate([s, v], axis=-1))

        return Model([inputs, *sem_inputs], [x, s_err, kl_loss], name="latent")

    def call(self, inputs):
        # Delegates to the functional model built in _build.
        return self.latent(inputs)

    def summary(self):
        # Prints the wrapped functional model summary.
        return self.latent.summary()

class Semantic(Model):
    def __init__(self, mlp_dim, semantic_dims):
        """
        Semantic embedding block: projects features onto supervised
        categorical and regression heads and re-embeds them.
        Args:
            mlp_dim: int, width of the input features and of the output.
            semantic_dims: list of ints, class counts for the categorical
                heads; the last entry is the regression head width.
        """
        # Model.name is a read-only property in Keras, so the name must be
        # passed to the parent constructor; the original assigned
        # self.name before super().__init__(), which raises at
        # construction time.
        super(Semantic, self).__init__(name="semantic_embedding_block")
        self.latent = self._build(mlp_dim, semantic_dims)

    def _build(self, mlp_dim, semantic_dims):
        """
        Embedding space for semantically meaningful variables. Everything is
        laterally spaced out.
        Args:
            mlp_dim (int): Width of the feature input and of the output.
            semantic_dims (list): List of dimensions for each variable. The
                final dimension is reserved for regression variables.
        Returns:
            Model: Keras Model object that contains just a single layer deep
                model, with a structured latent space that maps to
                semantically meaningful variables.
        """

        # Compute inverse log of dimensions to get per-head loss weights.
        # NOTE(review): a class count of 1 gives log(1) == 0 and a
        # division by zero here -- confirm all counts are >= 2.
        class_counts = ops.array(semantic_dims[:-1], dtype="float32")
        weights = 1/ops.log(class_counts)

        inputs = Input(shape=(mlp_dim,))
        targets = [Input(shape=(dim,)) for dim in semantic_dims]

        # One-hot encoding spaces
        # Sample type, growth, treatment, dose value, dose unit
        one_hots = [Dense(dim, activation="softmax")(inputs)
                    for dim in semantic_dims[:-1]]

        # Regression spaces
        # Min. frequency, Max. frequency, baseline time, treatment time,
        # loop dt
        reg = Dense(semantic_dims[-1], activation=None)(inputs)

        # Compute categorical_crossentropy error against targets for
        # categorical variables
        errors = [
            categorical_crossentropy(target, pred) * weight
            for target, pred, weight in zip(targets[:-1], one_hots, weights)
        ]

        # Compute MAE, with normalization by approximate range of values (~4)
        errors.append(mean_absolute_error(targets[-1], reg) / 4.)
        error = ops.sum(ops.stack(errors))

        # Combine everything into single dense layer
        concat = keras.layers.concatenate(one_hots + [reg], axis=-1)
        output = Dense(mlp_dim, activation="relu")(concat)

        return Model([inputs, *targets], [output, error], name="semantic")

    def call(self, inputs):
        # Delegates to the functional model built in _build.
        return self.latent(inputs)

    def summary(self):
        # Prints the wrapped functional model summary.
        return self.latent.summary()

class Variational(Model):
    def __init__(self, dim, var_dim):
        """
        Variational embedding block (VAE-style): maps features to a
        sampled latent vector plus its KL divergence penalty.
        Args:
            dim: int, width of the input features.
            var_dim: int, width of the variational latent space.
        """
        # Model.name is a read-only property in Keras, so the name must be
        # passed to the parent constructor; the original assigned
        # self.name before super().__init__(), which raises at
        # construction time.
        super(Variational, self).__init__(name="variational_embedding_block")
        self.latent = self._build(dim, var_dim)

    def _build(self, dim, var_dim):
        """
        Builds the functional variational model.
        Returns:
            Model: maps (dim,) features to [z, kl_loss], where z is a
            reparameterized sample from N(z_mean, exp(z_log_var)).
        """
        inputs = Input(shape=(dim,))
        x = Dense(var_dim, activation="relu")(inputs)
        z_mean = Dense(var_dim, activation=None)(x)
        z_log_var = Dense(var_dim, activation=None)(x)
        z = Sampling()([z_mean, z_log_var])
        # Analytic KL divergence between N(z_mean, exp(z_log_var)) and
        # the standard normal prior, summed over the latent axis.
        kl_loss = -0.5 * ops.sum(
            1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var),
            axis=1
        )

        return Model(inputs, [z, kl_loss], name="variational")

    def call(self, inputs):
        # Delegates to the functional model built in _build.
        return self.latent(inputs)

    def summary(self):
        # Prints the wrapped functional model summary.
        return self.latent.summary()

class Sampling(Layer):
    """Reparameterization-trick layer: draws z ~ N(z_mean, exp(z_log_var))."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Fixed-seed generator keeps sampling deterministic across runs.
        self.seed_gen = keras.random.SeedGenerator(1337)

    def call(self, inputs):
        mean, log_var = inputs

        # Reparameterization trick: z = mu + sigma * eps, with
        # eps ~ N(0, 1), so gradients flow through mean and log-variance.
        noise = keras.random.normal(shape=ops.shape(mean), seed=self.seed_gen)
        return mean + noise * ops.exp(0.5 * log_var)


# EOF
10 changes: 10 additions & 0 deletions model/losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# -*- encoding: utf-8 -*-

import keras.ops as ops

def categorical_crossentropy(y_true, y_pred):
    """Per-sample categorical cross-entropy, summed over the last axis."""
    # Clip predictions away from zero so log() stays finite.
    clipped = ops.clip(y_pred, 1e-7, 1.0)
    log_probs = ops.log(clipped)
    return -ops.sum(y_true * log_probs, axis=-1)

def mean_absolute_error(y_true, y_pred):
    """Mean absolute error over the last axis."""
    # |a - b| is symmetric, so the operand order is irrelevant.
    abs_diff = ops.abs(y_true - y_pred)
    return ops.mean(abs_diff, axis=-1)
Loading
Loading