
Train integration #4

Merged
merged 6 commits into from Mar 2, 2026

Changes from all commits
4 changes: 4 additions & 0 deletions model/activation.py
@@ -5,4 +5,8 @@
 def sublinear(x):
     return x / (1 + K.sqrt(K.sqrt(K.abs(x))))
 
+def linear(x):
+    return x
+    #return x / (1 + K.sqrt(K.sqrt(K.abs(x))))
+
 # EOF
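
For context, a quick sketch of how the two activations compare, assuming K is the tf.keras backend imported at the top of activation.py (that import and K.eval availability are assumptions, outside this hunk):

    import numpy as np
    from keras import backend as K  # assumed to match activation.py's import

    from model.activation import sublinear, linear

    x = K.constant(np.array([-16.0, 0.0, 16.0]))
    print(K.eval(sublinear(x)))  # dampened: 16 / (1 + 16**0.25) = 5.33...
    print(K.eval(linear(x)))     # identity: [-16.  0.  16.]
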
19 changes: 14 additions & 5 deletions model/model.py
@@ -11,7 +11,7 @@
     GlobalAveragePooling1D, LayerNormalization, Masking, Conv2D, \
     MultiHeadAttention, concatenate
 
-from model.activation import sublinear
+from model.activation import sublinear, linear
 from model.transformer import TimeseriesTransformerBuilder as TSTFBuilder
 
 class CompoundModel(Model):
@@ -77,6 +77,8 @@ def _modelstack(self, input_shape, head_size, num_heads, ff_dim,
         #x = inputs
         #inputs = Masking(mask_value=pad_value)(inputs)
         x = BatchNormalization()(inputs)
+
+        # Transformer blocks
         for _ in range(num_Transformer_blocks):
             x = self.tstfbuilder.build_transformerblock(
                 x,
@@ -85,12 +87,18 @@ def _modelstack(self, input_shape, head_size, num_heads, ff_dim,
                 ff_dim,
                 dropout
             )
+
+        # Pooling and simple DNN block
         x = GlobalAveragePooling1D(data_format="channels_first")(x)
         for dim in mlp_units:
             x = Dense(dim, activation="relu")(x)
             x = Dropout(mlp_dropout)(x)
-        y = Dense(n_classes[0], activation="softmax")(x)
-        z = Dense(n_classes[1], activation="softmax")(x)
+
+        # Two separate latent spaces supported
+        #y = Dense(n_classes[0], activation="softmax")(x)
+        #z = Dense(n_classes[1], activation="softmax")(x)
+        y = Dense(n_classes[0], activation="relu")(x)
+        z = Dense(n_classes[1], activation="relu")(x)
 
         return Model(inputs, [y, z])

@@ -189,10 +197,11 @@ def _modelstack(self, input_shape, head_size, num_heads, ff_dim,
         x = Dense(full_dimension, activation="relu")(x)
         x = Reshape((input_shape[0], input_shape[1]))(x)
 
+        """
         for dim in mlp_units:
             x = Dense(dim, activation="relu")(x)
             x = Dropout(mlp_dropout)(x)
-
+        """
         for _ in range(num_Transformer_blocks):
             x = self.tstfbuilder.build_transformerblock(
                 x,
@@ -206,7 +215,7 @@ def _modelstack(self, input_shape, head_size, num_heads, ff_dim,
         x = Conv1D(filters=input_shape[1],
                    kernel_size=1,
                    padding="valid",
-                   activation=sublinear)(x)
+                   activation=linear)(x)
 
         return Model(inputs, x)

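The encoder's two output heads switch from softmax classifiers to relu projections, so the stack now emits two latent vectors instead of class probabilities. A self-contained sketch of the same two-head pattern (toy shapes and layer sizes, not the repo's configuration):

    import keras
    from keras.layers import Dense, Dropout, GlobalAveragePooling1D, Input

    n_classes = (8, 5)                    # hypothetical latent dimensions
    inputs = Input(shape=(36, 133))
    x = GlobalAveragePooling1D(data_format="channels_first")(inputs)
    x = Dense(64, activation="relu")(x)
    x = Dropout(0.1)(x)
    y = Dense(n_classes[0], activation="relu")(x)   # latent head 1
    z = Dense(n_classes[1], activation="relu")(x)   # latent head 2
    model = keras.Model(inputs, [y, z])
    model.summary()
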
2 changes: 1 addition & 1 deletion pipe/etl.py
@@ -85,7 +85,7 @@ def split_sets(data: DataFrame, split=[0.99, 0.005, 0.005]) -> tuple:
 def trim(dataframe, column):
 
     ndarray = np.array(dataframe.select(column).collect()) \
-        .reshape(-1, 34, 133)
+        .reshape(-1, 36, 133)
 
     return ndarray

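The per-sample shape changes from 34 to 36 loops here, matching the new default_size in pipe/extract.py below (133 columns appears to be the 130 spectrogram bins plus the appended feature columns written in read_matfiles). A quick sanity check of the reshape arithmetic (synthetic data; the 36 and 133 come from this diff):

    import numpy as np

    n_samples = 10
    flat = np.zeros((n_samples, 36 * 133))   # one flattened timeseries per row
    stacked = flat.reshape(-1, 36, 133)
    assert stacked.shape == (n_samples, 36, 133)
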
42 changes: 29 additions & 13 deletions pipe/extract.py
@@ -13,6 +13,8 @@
 import cv2 as cv
 import h5py
 import numpy as np
+from numpy import ndarray
+from pyspark.ml.linalg import Vectors
 from pyspark.sql import SparkSession, Row, DataFrame
 import scipy as sp
@@ -35,8 +37,10 @@ def extract(spark: SparkSession) -> DataFrame:
         labels.append(line.strip().split(",")[0])
 
     reader = FileReader(spark, filetype=FileType.MAT)
-
-    return reader.read_file(path, labels)
+    rdd = spark.sparkContext.parallelize(reader.read_file(path, labels),
+                                         numSlices=200)
+    #return reader.read_file(path, labels)
+    return spark.createDataFrame(rdd)

 def image_pipe(spark: SparkSession, imagepath: Path, namepattern: str,
                stacksize: int) -> np.ndarray:
@@ -58,6 +62,15 @@ def image_pipe(spark: SparkSession, imagepath: Path, namepattern: str,
 
     return images
 
+def strip_array(arr: ndarray) -> ndarray:
+    dtype = arr.dtype
+    while dtype == "object":
+        arr = arr[0]
+        dtype = arr.dtype
+    if len(arr) == 0:
+        return ["unknown"]
+    return arr
+
 class FileReader:
     """
     Class to read spectrograms and metadata from different file formats based
@@ -105,7 +118,7 @@ def metadata_read(self, metapath: Path, labels:list,
 
     def read_matfiles(self, specpath: Path,
                       datakinds: List[DataKind],
-                      default_size: tuple = (34, 130),
+                      default_size: tuple = (36, 130),
                       pad_value: float = 0.) -> DataFrame:
         """
         Loads data for each stack iteration from a set of mat files,
@@ -121,8 +134,7 @@ def read_matfiles(self, specpath: Path,
         Returns:
             DataFrame: Spark DataFrame containing the requested data.
         """
-        data = []
-        row = {}
+        #data = []
         labels = glob.glob(str(specpath/"matfiles"/"*.mat"))
         nloops = default_size[0]
         nfreq = default_size[1]
@@ -137,6 +149,8 @@ def read_matfiles(self, specpath: Path,
         ncnt_scale = 5.
 
         for label in labels:
+            row = {}
+            print(label)
             matdata = sp.io.loadmat(specpath/"matfiles"/label)
             ncnt = np.log10(matdata["NCNT"][0])
             if np.min(ncnt) < 2:
@@ -174,18 +188,20 @@ def read_matfiles(self, specpath: Path,
                 (matdata["NSD"][0] - nsd_meanshift)
             timeseries_array[time_offset:, 132] = \
                 np.log10(matdata["NCNT"][0]) / ncnt_scale
-            row["timeseries"] = timeseries_array.tolist()
+            timeseries_array = Vectors.dense(timeseries_array.flatten())
+            row["timeseries"] = timeseries_array
 
             if DataKind.TREATMENT in datakinds:
-                row["treatment"] = matdata["header"]["drug"][0][0][0].lower()
+                row["treatment"] = strip_array(
+                    matdata["header"]["drug"])[0].lower()
             if DataKind.TARGET in datakinds:
                 try:
-                    row["target"] = matdata["header"]["cell"][0][0][0].lower()
+                    row["target"] = strip_array(
+                        matdata["header"]["cell"])[0].lower()
                 except:
                     row["target"] = "unknown"
-            data.append(Row(**row))
+            #data.append(Row(**row))
+            yield row
 
-        return self.spark.createDataFrame(data)
+        #return self.spark.createDataFrame(data)

     def read_hdf5(self, specpath: Path, labels: list,
                   namepattern:str="averaged_spectrogram{}.hdf5") \
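
Two patterns worth noting here: read_matfiles is now a generator whose rows extract() feeds through sparkContext.parallelize, and strip_array replaces the brittle [0][0][0] indexing into the nested object arrays that scipy.io.loadmat produces for MATLAB struct fields. A minimal sketch of strip_array's behavior on synthetic arrays (not real mat-file contents):

    import numpy as np

    from pipe.extract import strip_array

    # mimic loadmat nesting: an object array wrapping a plain string array
    inner = np.array(["ketamine"])          # non-object dtype, e.g. <U8
    nested = np.empty((1,), dtype=object)
    nested[0] = inner
    print(strip_array(nested)[0].lower())   # -> "ketamine"

    # an empty innermost field falls back to ["unknown"]
    wrapped = np.empty((1,), dtype=object)
    wrapped[0] = np.array([], dtype="<U1")
    print(strip_array(wrapped))             # -> ["unknown"]
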
29 changes: 28 additions & 1 deletion pipe/transform.py
@@ -7,7 +7,7 @@
 
 from pyspark.sql import DataFrame, functions, SparkSession, types
 from pyspark.ml import Pipeline
-from pyspark.ml.feature import StringIndexer, OneHotEncoder
+from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
 
 def merge_redundant_treatment_labels(dataframe: DataFrame) -> DataFrame:
     """
@@ -43,6 +43,7 @@ def onehot(dataframe: DataFrame, keys: list) -> DataFrame:
         pyspark.sql.DataFrame: New DataFrame with one-hot encoded column(s).
     """
 
+    """ OLD BLOCK
     indexers = []
     encoders = []
     indexed_cols = []
@@ -79,6 +80,32 @@ def onehot(dataframe: DataFrame, keys: list) -> DataFrame:
     for column_name in keys:
         result = result.withColumnRenamed(column_name, f"{column_name}_str")
         result = result.withColumnRenamed(f"{column_name}_encoded", column_name)
+    """
+
+    indexer = StringIndexer(
+        inputCols=keys,
+        outputCols=[f"{c}_idx" for c in keys],
+        handleInvalid="keep"
+    )
+
+    encoder = OneHotEncoder(
+        inputCols=[f"{c}_idx" for c in keys],
+        outputCols=[f"{c}_vec" for c in keys],
+        dropLast=False
+    )
+
+    assembler = VectorAssembler(
+        inputCols=[f"{c}_vec" for c in keys],
+        outputCol="features"
+    )
+    pipeline = Pipeline(stages=[indexer, encoder, assembler])
+    model = pipeline.fit(dataframe)
+    result = model.transform(dataframe)
+
+    for c in keys:
+        result = result.withColumnRenamed(c, f"{c}_str") \
+                       .withColumnRenamed(f"{c}_vec", c)
+
     return result

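The rewrite collapses the old per-column StringIndexer/OneHotEncoder loop into one multi-column Pipeline plus a VectorAssembler (multi-column indexing needs Spark 3.0+). A minimal usage sketch (local Spark session, toy rows, hypothetical column values):

    from pyspark.sql import SparkSession

    from pipe.transform import onehot

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame(
        [("ketamine", "neuron"), ("saline", "glia")],
        ["treatment", "target"],
    )
    encoded = onehot(df, ["treatment", "target"])
    # the original strings survive as treatment_str/target_str, the renamed
    # treatment/target columns hold one-hot vectors, and "features" stacks them
    encoded.show(truncate=False)
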
120 changes: 119 additions & 1 deletion train/autoencoder_train.py
@@ -1,3 +1,121 @@
-# -*- coding: utf-8 -*-
+#-*- coding: utf-8 -*-
+
+from datetime import datetime
+import time
+import typing
+from typing import List
+
+import numpy as np
+import os
+import keras
+from keras.metrics import MeanSquaredError
+from keras import Model
+from keras.callbacks import ModelCheckpoint, CSVLogger
+import matplotlib.pyplot as plt
+
+from model.model import CompoundModel
+from model.metrics import MutualInformation, mutual_information
+from visualize.visualize import confusion_matrix
+from visualize.plot import roc_plot
+from train.encoder_train import build_encoder
+from train.decoder_train import build_decoder
+
+def autoencoder_workflow(params, shape, n_classes,
+                         train_set, validation_set, test_set,
+                         categories, keys, path):
+
+    model = build_autoencoder(params, shape, n_classes)
+    model = train_autoencoder(params, model, train_set, validation_set, path)
+
+    m = {key: None for key in keys}
+    m, test_predict = test_autoencoder(
+        model,
+        test_set,
+        m
+    )
+    model_metrics = {metric: value for metric, value in m.items()}
+
+    evaluate_autoencoder(
+        params,
+        test_predict,
+        test_set[0],
+        categories,
+        keys,
+        path
+    )
+
+    save_autoencoder(params, model, path)
+
+def build_autoencoder(params, shape, n_classes):
+    autoencoder_params = params["autoencoder_params"]
+    #mi = MutualInformation()
+    mse = MeanSquaredError()
+
+    encoder_model = build_encoder(params, shape, n_classes)
+    decoder_model = build_decoder(params, shape, n_classes)
+    model = keras.Sequential([encoder_model, decoder_model])
+    model.compile(
+        optimizer=keras.optimizers.Adam(learning_rate=4e-4),
+        loss=autoencoder_params["loss"],
+        metrics=[mse]#, mutual_information]
+    )
+
+    return model
+
+def train_autoencoder(params, model, train_set, validation_set, path):
+    log_level = params["log_level"]
+    timestamp = params["timestamp"]
+    params = params["autoencoder_params"]
+    callbacks = [
+        ModelCheckpoint(
+            filepath=path / timestamp / f"{timestamp}_checkpoint.keras",
+            monitor="val_loss",
+            save_best_only=True,
+            save_weights_only=False,
+            verbose=1
+        ),
+        CSVLogger(path / timestamp / f"{timestamp}_log.csv")
+    ]
+
+    start = time.time()
+    model.fit(
+        x=train_set, y=train_set,
+        validation_data=(validation_set, validation_set),
+        batch_size=params["batch_size"],
+        epochs=params["epochs"],
+        verbose=log_level,
+        callbacks=callbacks
+    )
+    end = time.time()
+    print("Training time: ", end - start)
+    return model
+
+def test_autoencoder(model: Model, test: List, metrics: dict):
+    """
+    Evaluate the autoencoder on the test set and fill the metrics dict.
+    """
+
+    test_eval = model.evaluate(test, test)
+    if len(metrics.keys()) == 1:
+        metrics[list(metrics.keys())[0]] = test_eval
+    else:
+        for i, key in enumerate(metrics.keys()):
+            metrics[key] = np.mean(test_eval[i])
+
+    test_predict = model.predict(test)[0]
+
+    return metrics, test_predict

+def evaluate_autoencoder(params, test_predict, test_set, categories, keys, path):
+    plt.pcolor(test_set)
+    plt.savefig(path / params["timestamp"] / "original.png")
+    plt.close()
+    plt.pcolor(test_predict)
+    plt.savefig(path / params["timestamp"] / "reproduction.png")
+    plt.close()
+    return
+
+def save_autoencoder(params, model, path):
+    model.save(path / params["timestamp"] / f"{params['timestamp']}_autoencoder.keras")
+
+# EOF
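
build_autoencoder chains the separately built encoder and decoder with keras.Sequential and trains the result input-against-input. A self-contained sketch of that pattern with toy stand-in models (not the repo's builders or parameter values):

    import numpy as np
    import keras
    from keras import layers
    from keras.metrics import MeanSquaredError

    # stand-ins for build_encoder / build_decoder
    encoder = keras.Sequential([keras.Input(shape=(36, 133)), layers.Flatten(),
                                layers.Dense(16, activation="relu")])
    decoder = keras.Sequential([keras.Input(shape=(16,)), layers.Dense(36 * 133),
                                layers.Reshape((36, 133))])

    autoencoder = keras.Sequential([encoder, decoder])
    autoencoder.compile(optimizer=keras.optimizers.Adam(learning_rate=4e-4),
                        loss="mse", metrics=[MeanSquaredError()])

    x = np.random.rand(64, 36, 133).astype("float32")
    autoencoder.fit(x=x, y=x, batch_size=32, epochs=1)  # reconstruct the input
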
17 changes: 13 additions & 4 deletions train/decoder_train.py
@@ -11,8 +11,8 @@
 from visualize.plot import spectra_plot
 
 def decoder_workflow(params, train_set, validation_set, test_set,
-                     n_classes, categories, keys):
-    decoder = load_decoder(params, train_set[0].shape[1:], n_classes)
+                     n_classes, categories, keys, modelpath):
+    decoder = build_decoder(params, train_set[0].shape[1:], n_classes)
 
     decoder = train_decoder(decoder, params, train_set, validation_set)
     # Test model performance
@@ -26,7 +26,7 @@ def decoder_workflow(params, train_set, validation_set, test_set,
         spectra_plot(test_predict[0], name=f"{target}-{treatment}-predict")
         spectra_plot(test_set[0][0], name=f"{target}-{treatment}-true")
 
-def load_decoder(params, input_shape, n_classes):
+def build_decoder(params, input_shape, n_classes):
     """
     """
 
@@ -86,8 +86,17 @@ def test_decoder(decoder: Model, test: List, metrics: dict):
 
     return metrics, test_predict
 
-def save_decoder(decoder: Model):
+def evaluate_decoder(params, test_predict, test_set, test_loss,
+                     test_accuracy, categories, keys):
+    params = params["decoder_params"]
+
+    for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
+        confusion_matrix(predict, groundtruth, categories[key], key)
+        roc_plot(predict, groundtruth, key)
+    save_metric(params, test_loss, test_accuracy)
+
+def save_decoder(decoder: Model):
     decoder.save(path + "decoder.keras")
     return
 
 # EOF
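
evaluate_decoder pairs each prediction head with its ground truth and label set via zip. A shape-level sketch of that per-head loop, using sklearn's confusion_matrix as a stand-in for the visualize helpers (synthetic data throughout):

    import numpy as np
    from sklearn.metrics import confusion_matrix  # stand-in, not visualize.visualize

    keys = ["treatment", "target"]
    test_predict = [np.random.rand(10, 3), np.random.rand(10, 2)]  # one array per head
    groundtruth = [np.random.randint(0, 3, 10), np.random.randint(0, 2, 10)]

    for predict, truth, key in zip(test_predict, groundtruth, keys):
        print(key)
        print(confusion_matrix(truth, predict.argmax(axis=1)))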