Pipe and train refactored to accommodate two-category classification
Dawith committed Apr 18, 2025
1 parent 044c3fa commit c70fa91
Showing 2 changed files with 94 additions and 49 deletions.
8 changes: 7 additions & 1 deletion pipe/pipe.py
@@ -75,10 +75,16 @@ def spectrogram_pipe_matfiles(self, specpath: Path, labels:list,
         for label in labels:
             matdata = sp.io.loadmat(specpath/label)
             row["treatment"] = matdata["header"][0][0][4][0].lower()
+            try:
+                row["target"] = matdata["header"][0][0][2][0].lower()
+            except IndexError:
+                row["target"] = "unknown"
             row["label"] = label
-            spectrogram = np.array(matdata["SPF"][0])
+            spectrogram = np.array(matdata["SP"][0])
+            if len(spectrogram.shape) == 3:
+                spectrogram = spectrogram[0]
             if spectrogram.shape[0] > default_size[0]:
                 spectrogram = spectrogram[:default_size[0], :]
             spectrogram = np.pad(
                 spectrogram,
                 ((default_size[0] - spectrogram.shape[0], 0),
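The hunk above does three things: it wraps the optional "target" header field in a try/except so files missing that field fall back to "unknown", it reads the "SP" key instead of "SPF", and it normalizes every spectrogram to default_size by reducing a stacked 3-D array to its first slice, truncating overlong arrays, and zero-padding short ones. A standalone sketch of that shape normalization, assuming an illustrative default_size and zero padding on the column axis (the real padding spec is cut off by the fold above):

import numpy as np

def normalize_spectrogram(spectrogram, default_size=(128, 256)):
    # Reduce a stacked recording to its first slice, as in the diff.
    spectrogram = np.asarray(spectrogram)
    if spectrogram.ndim == 3:
        spectrogram = spectrogram[0]
    # Truncate overlong spectrograms, then zero-pad short ones at the top.
    if spectrogram.shape[0] > default_size[0]:
        spectrogram = spectrogram[:default_size[0], :]
    return np.pad(
        spectrogram,
        ((default_size[0] - spectrogram.shape[0], 0), (0, 0)),
    )

print(normalize_spectrogram(np.ones((3, 40, 256))).shape)  # (128, 256)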
135 changes: 87 additions & 48 deletions train.py
@@ -15,12 +15,14 @@
 import json
 import numpy as np
 from pipe.pipe import SpectrogramPipe
-from pyspark.ml.feature import StringIndexer, IndexToString
-from pyspark.sql import SparkSession, functions
+import pyspark as spark
+#from pyspark.ml.feature import OneHotEncoder, StringIndexer
+from pyspark.sql import SparkSession, functions, types, Row
 import tensorflow as tf
 import keras
 import matplotlib.pyplot as plt
 from sklearn.metrics import confusion_matrix
+from sklearn.preprocessing import OneHotEncoder
 
 from model.model import TimeSeriesTransformer as TSTF
 
@@ -52,6 +54,46 @@ def trim(dataframe, column):

     return ndarray
 
+def get_model(input_shape, n_classes):
+    model = TSTF(input_shape, HEAD_SIZE, NUM_HEADS, FF_DIM,
+                 NUM_TRANSFORMER_BLOCKS, MLP_UNITS, n_classes,
+                 dropout=DROPOUT, mlp_dropout=MLP_DROPOUT)
+    return model
+
+def transform(spark, dataframe, keys):
+    dataframe = dataframe.withColumn(
+        "index", functions.monotonically_increasing_id()
+    )
+    bundle = {key: [
+        arr.tolist()
+        for arr in OneHotEncoder(sparse_output=False) \
+            .fit_transform(dataframe.select(key).collect())
+    ] for key in keys
+    }
+
+    bundle = [dict(zip(bundle.keys(), values))
+              for values in zip(*bundle.values())]
+    schema = types.StructType([
+        types.StructField(key, types.ArrayType(types.FloatType()), True)
+        for key in keys
+    ])
+    newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
+        "index", functions.monotonically_increasing_id()
+    )
+    for key in keys:
+        dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
+    dataframe = dataframe.join(newframe, on="index", how="inner")
+
+    return dataframe
+
+def build_dict(df, key):
+    df = df.select(key, f"{key}_str").distinct()
+
+    return df.rdd.map(
+        lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
+    ).collectAsMap()
+
+
 def get_data(spark, split=[0.99, 0.005, 0.005]):
     path = Path("/app/workdir")
 
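Taken together, transform collects each label column to the driver, one-hot encodes it with scikit-learn, rebuilds a Spark DataFrame from the encoded rows, and joins it back over a monotonically_increasing_id column, while build_dict recovers a map from argmax index back to the original label string. One caveat worth flagging: joining two independently generated monotonically_increasing_id columns assumes both DataFrames enumerate rows in the same order, which Spark does not strictly guarantee. A small driver-side sketch of just the encoding and lookup steps, with illustrative label values rather than project data:

import numpy as np
from sklearn.preprocessing import OneHotEncoder

rows = [["cpv"], ["pbs"], ["cpv"], ["unknown"]]  # stand-in for dataframe.select(key).collect()
onehot = OneHotEncoder(sparse_output=False).fit_transform(rows)
# one column per distinct label, e.g. "cpv" -> [1., 0., 0.]

# the same argmax-to-string recovery that build_dict performs on the joined frame
lookup = {str(np.argmax(vec)): row[0] for vec, row in zip(onehot, rows)}
print(lookup)  # {'0': 'cpv', '1': 'pbs', '2': 'unknown'}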
@@ -62,48 +104,28 @@ def get_data(spark, split=[0.99, 0.005, 0.005]):

     pipe = SpectrogramPipe(spark, filetype="matfiles")
     data = pipe.spectrogram_pipe(path, labels)
+    data = data.replace("virus", "cpv", "treatment") \
+               .replace("cont", "pbs", "treatment") \
+               .replace("control", "pbs", "treatment") \
+               .replace("dld", "pbs", "treatment")
+
-    indexer = StringIndexer(inputCol="treatment", outputCol="treatment_index")
-    indexed = indexer.fit(data).transform(data)
+    data = transform(spark, data, ["treatment", "target"])
+    category_dict = {
+        key: build_dict(data, key) for key in ["treatment", "target"]
+    }
+    splits = data.randomSplit(split, seed=42)
+    trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
+    trainy, valy, testy = ([np.array(dset.select("treatment").collect()).squeeze(),
+                            np.array(dset.select("target").collect()).squeeze()]
+                           for dset in splits)
+
-    selected = indexed.select("treatment", "treatment_index").distinct()
-    selected = selected.sort("treatment_index")
-    index_max = selected.agg(functions.max("treatment_index")).collect()[0][0]
-
-    train_df, validation_df, test_df = indexed.randomSplit(split, seed=42)
-
-    trainx = trim(train_df, "spectrogram")
-    trainy = np.array(train_df.select("treatment_index").collect()).astype(int)
-    _trainy = np.zeros((len(trainy), int(index_max+1)))
-    for index, value in enumerate(trainy):
-        _trainy[index, value] = 1.
-    #_trainy[np.arange(trainy.shape[0]), trainy] = 1.
-    trainy = _trainy
-    del _trainy
-
-    valx = trim(validation_df, "spectrogram")
-    valy = np.array(validation_df.select("treatment_index").collect()) \
-        .astype(int)
-    _valy = np.zeros((len(valy), int(index_max+1)))
-    for index, value in enumerate(valy):
-        _valy[index, value] = 1.
-    #_valy[np.arange(valy.shape[0]), valy] = 1.
-    valy = _valy
-    del _valy
-
-    testx = trim(test_df, "spectrogram")
-    testy = np.array(test_df.select("treatment_index").collect()).astype(int)
-    _testy = np.zeros((len(testy), int(index_max+1)))
-    for index, value in enumerate(testy):
-        _testy[index, value] = 1.
-    testy = _testy
-    del _testy
-
-    return (selected, (trainx, trainy), (valx, valy), (testx, testy))
 
+    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
 
 
 def main():
     # jax mesh setup
     """
     devices = jax.devices("gpu")
     mesh = keras.distribution.DeviceMesh(
         shape=(len(devices),), axis_names=["model"], devices=devices
@@ -124,16 +146,19 @@ def main():
         layout_map=layout_map
     )
     keras.distribution.set_distribution(model_parallel)
     """
 
     spark = SparkSession.builder.appName("train").getOrCreate()
 
-    indices, train_set, validation_set, test_set = get_data(spark, split=SPLIT)
+    keys = ["treatment", "target"]
+    (train_set, validation_set,
+     test_set, categories) = get_data(spark, split=SPLIT)
 
-    n_classes = indices.count()
+    n_classes = [dset.shape[1] for dset in train_set[1]]
+    model = get_model(train_set[0].shape[1:], n_classes)
     model.compile(optimizer=keras.optimizers.Adam(learning_rate=4e-4),
                   loss="categorical_crossentropy",
-                  metrics=["categorical_accuracy"]
+                  metrics=["categorical_accuracy", "categorical_accuracy"]
     )
     model.summary()
 
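Because n_classes is now a list, get_model presumably builds one softmax head per label set, and the compile call above feeds a flat metrics list to a two-output model. Keras also accepts an explicit dict keyed by output name, which makes the per-output loss and metric assignment unambiguous; a hedged sketch with a made-up two-head model (the head names and class counts are assumptions, not taken from model/model.py):

import keras
from keras import layers

inputs = keras.Input(shape=(128, 256))
x = layers.GlobalAveragePooling1D()(inputs)
treatment = layers.Dense(2, activation="softmax", name="treatment")(x)
target = layers.Dense(3, activation="softmax", name="target")(x)
model = keras.Model(inputs, [treatment, target])

# dict form makes the per-output loss/metric assignment explicit
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=4e-4),
    loss={"treatment": "categorical_crossentropy",
          "target": "categorical_crossentropy"},
    metrics={"treatment": ["categorical_accuracy"],
             "target": ["categorical_accuracy"]},
)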
@@ -145,15 +170,29 @@ def main():
print("Training time: ", end - start)

# Test model performance
test_loss, test_accuracy = model.evaluate(test_set[0], test_set[1])
test_loss, test_accuracy, _, _, _, _ = model.evaluate(test_set[0], test_set[1])
print(model.metrics_names)
test_predict = model.predict(test_set[0])
print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}")

conf_matrix = confusion_matrix(np.argmax(test_predict, axis=1),
np.argmax(test_set[1], axis=1))
plt.imshow(conf_matrix, origin="upper")
plt.gca().set_aspect("equal")
plt.savefig("/app/workdir/confusion_matrix.png")
for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
conf_matrix = confusion_matrix(
np.argmax(predict, axis=1),
np.argmax(groundtruth, axis=1),
labels=range(len(categories[key].values())),
normalize="pred"
)
plt.imshow(conf_matrix, origin="upper")
plt.gca().set_aspect("equal")
plt.colorbar()
plt.xticks([int(num) for num in categories[key].keys()],
categories[key].values(), rotation=270)
plt.yticks([int(num) for num in categories[key].keys()],
categories[key].values())
plt.gcf().set_size_inches(len(categories[key])/10+4,
len(categories[key])/10+3)
plt.savefig(f"/app/workdir/confusion_matrix_{key}.png",
bbox_inches="tight")
plt.close()

return

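For reference, normalize="pred" in scikit-learn divides each column of the confusion matrix by its column sum, so every column of the saved plots sums to 1. Note that the loop above passes predictions as the first argument, so "pred" there actually normalizes over the second argument, the ground-truth labels. A tiny self-contained check with toy labels:

from sklearn.metrics import confusion_matrix

y_true = [0, 1, 1, 1, 0]
y_pred = [0, 0, 1, 1, 1]
cm = confusion_matrix(y_true, y_pred, labels=[0, 1], normalize="pred")
print(cm)              # [[0.5 0.333...]
                       #  [0.5 0.667...]]
print(cm.sum(axis=0))  # each column sums to 1.0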
