diff --git a/pipe/pipe.py b/pipe/pipe.py
index 7282246..1c76f83 100644
--- a/pipe/pipe.py
+++ b/pipe/pipe.py
@@ -75,10 +75,16 @@ def spectrogram_pipe_matfiles(self, specpath: Path, labels:list,
         for label in labels:
             matdata = sp.io.loadmat(specpath/label)
             row["treatment"] = matdata["header"][0][0][4][0].lower()
+            try:
+                row["target"] = matdata["header"][0][0][2][0].lower()
+            except IndexError:
+                row["target"] = "unknown"
             row["label"] = label
-            spectrogram = np.array(matdata["SPF"][0])
+            spectrogram = np.array(matdata["SP"][0])
             if len(spectrogram.shape) == 3:
                 spectrogram = spectrogram[0]
+            if spectrogram.shape[0] > default_size[0]:
+                spectrogram = spectrogram[:default_size[0], :]
             spectrogram = np.pad(
                 spectrogram,
                 ((default_size[0] - spectrogram.shape[0], 0),
diff --git a/train.py b/train.py
index a447ddd..4274774 100644
--- a/train.py
+++ b/train.py
@@ -15,12 +15,14 @@
 import json
 import numpy as np
 from pipe.pipe import SpectrogramPipe
-from pyspark.ml.feature import StringIndexer, IndexToString
-from pyspark.sql import SparkSession, functions
+import pyspark as spark
+#from pyspark.ml.feature import OneHotEncoder, StringIndexer
+from pyspark.sql import SparkSession, functions, types, Row
 import tensorflow as tf
 import keras
 import matplotlib.pyplot as plt
 from sklearn.metrics import confusion_matrix
+from sklearn.preprocessing import OneHotEncoder
 
 from model.model import TimeSeriesTransformer as TSTF
 
@@ -52,6 +54,46 @@ def trim(dataframe, column):
     return ndarray
 
 
+def get_model(input_shape, n_classes):
+    model = TSTF(input_shape, HEAD_SIZE, NUM_HEADS, FF_DIM,
+                 NUM_TRANSFORMER_BLOCKS, MLP_UNITS, n_classes,
+                 dropout=DROPOUT, mlp_dropout=MLP_DROPOUT)
+    return model
+
+def transform(spark, dataframe, keys):
+    dataframe = dataframe.withColumn(
+        "index", functions.monotonically_increasing_id()
+    )
+    bundle = {key: [
+        arr.tolist()
+        for arr in OneHotEncoder(sparse_output=False) \
+            .fit_transform(dataframe.select(key).collect())
+    ] for key in keys
+    }
+
+    bundle = [dict(zip(bundle.keys(), values))
+              for values in zip(*bundle.values())]
+    schema = types.StructType([
+        types.StructField(key, types.ArrayType(types.FloatType()), True)
+        for key in keys
+    ])
+    newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
+        "index", functions.monotonically_increasing_id()
+    )
+    for key in keys:
+        dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
+    dataframe = dataframe.join(newframe, on="index", how="inner")
+
+    return dataframe
+
+def build_dict(df, key):
+    df = df.select(key, f"{key}_str").distinct()
+
+    return df.rdd.map(
+        lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
+    ).collectAsMap()
+
+
 def get_data(spark, split=[0.99, 0.005, 0.005]):
     path = Path("/app/workdir")
 
@@ -62,48 +104,28 @@
 
     pipe = SpectrogramPipe(spark, filetype="matfiles")
     data = pipe.spectrogram_pipe(path, labels)
+    data.select("treatment").replace("virus", "cpv") \
+        .replace("cont", "pbs") \
+        .replace("control", "pbs") \
+        .replace("dld", "pbs").distinct()
 
-    indexer = StringIndexer(inputCol="treatment", outputCol="treatment_index")
-    indexed = indexer.fit(data).transform(data)
+    data = transform(spark, data, ["treatment", "target"])
+    category_dict = {
+        key: build_dict(data, key) for key in ["treatment", "target"]
+    }
+    splits = data.randomSplit(split, seed=42)
+    trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
+    trainy, valy, testy = ([np.array(dset.select("treatment").collect()).squeeze(),
+                            np.array(dset.select("target").collect()).squeeze()]
+                           for dset in splits)
 
-    selected = indexed.select("treatment", "treatment_index").distinct()
-    selected = selected.sort("treatment_index")
-    index_max = selected.agg(functions.max("treatment_index")).collect()[0][0]
-
-    train_df, validation_df, test_df = indexed.randomSplit(split, seed=42)
-
-    trainx = trim(train_df, "spectrogram")
-    trainy = np.array(train_df.select("treatment_index").collect()).astype(int)
-    _trainy = np.zeros((len(trainy), int(index_max+1)))
-    for index, value in enumerate(trainy):
-        _trainy[index, value] = 1.
-    #_trainy[np.arange(trainy.shape[0]), trainy] = 1.
-    trainy = _trainy
-    del _trainy
-
-    valx = trim(validation_df, "spectrogram")
-    valy = np.array(validation_df.select("treatment_index").collect()) \
-        .astype(int)
-    _valy = np.zeros((len(valy), int(index_max+1)))
-    for index, value in enumerate(valy):
-        _valy[index, value] = 1.
-    #_valy[np.arange(valy.shape[0]), valy] = 1.
-    valy = _valy
-    del _valy
-
-    testx = trim(test_df, "spectrogram")
-    testy = np.array(test_df.select("treatment_index").collect()).astype(int)
-    _testy = np.zeros((len(testy), int(index_max+1)))
-    for index, value in enumerate(testy):
-        _testy[index, value] = 1.
-    testy = _testy
-    del _testy
-
-    return (selected, (trainx, trainy), (valx, valy), (testx, testy))
+
+    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
 
 
 def main():
     # jax mesh setup
+    """
     devices = jax.devices("gpu")
     mesh = keras.distribution.DeviceMesh(
         shape=(len(devices),), axis_names=["model"], devices=devices
@@ -124,16 +146,19 @@
         layout_map=layout_map
     )
     keras.distribution.set_distribution(model_parallel)
+    """
 
     spark = SparkSession.builder.appName("train").getOrCreate()
 
-    indices, train_set, validation_set, test_set = get_data(spark, split=SPLIT)
+    keys = ["treatment", "target"]
+    (train_set, validation_set,
+     test_set, categories) = get_data(spark, split=SPLIT)
 
-    n_classes = indices.count()
+    n_classes = [dset.shape[1] for dset in train_set[1]]
     model = get_model(train_set[0].shape[1:], n_classes)
     model.compile(optimizer=keras.optimizers.Adam(learning_rate=4e-4),
                   loss="categorical_crossentropy",
-                  metrics=["categorical_accuracy"]
+                  metrics=["categorical_accuracy", "categorical_accuracy"]
                   )
     model.summary()
 
@@ -145,15 +170,29 @@
     print("Training time: ", end - start)
 
     # Test model performance
-    test_loss, test_accuracy = model.evaluate(test_set[0], test_set[1])
+    test_loss, test_accuracy, _, _, _, _ = model.evaluate(test_set[0], test_set[1])
+    print(model.metrics_names)
     test_predict = model.predict(test_set[0])
     print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}")
-
-    conf_matrix = confusion_matrix(np.argmax(test_predict, axis=1),
-                                   np.argmax(test_set[1], axis=1))
-    plt.imshow(conf_matrix, origin="upper")
-    plt.gca().set_aspect("equal")
-    plt.savefig("/app/workdir/confusion_matrix.png")
+    for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
+        conf_matrix = confusion_matrix(
+            np.argmax(predict, axis=1),
+            np.argmax(groundtruth, axis=1),
+            labels=range(len(categories[key].values())),
+            normalize="pred"
+        )
+        plt.imshow(conf_matrix, origin="upper")
+        plt.gca().set_aspect("equal")
+        plt.colorbar()
+        plt.xticks([int(num) for num in categories[key].keys()],
+                   categories[key].values(), rotation=270)
+        plt.yticks([int(num) for num in categories[key].keys()],
+                   categories[key].values())
+        plt.gcf().set_size_inches(len(categories[key])/10+4,
+                                  len(categories[key])/10+3)
+        plt.savefig(f"/app/workdir/confusion_matrix_{key}.png",
+                    bbox_inches="tight")
+        plt.close()
 
     return