diff --git a/decodertest.py b/decodertest.py
new file mode 100644
index 0000000..544df15
--- /dev/null
+++ b/decodertest.py
@@ -0,0 +1,264 @@
+# -*- coding: utf-8 -*-
+"""
+decodertest.py
+
+Launches the training process for the decoder model.
+"""
+
+import os
+from pathlib import Path
+import time
+
+os.environ["KERAS_BACKEND"] = "jax"
+os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
+
+import jax
+import json
+import numpy as np
+#from pyspark.ml.feature import OneHotEncoder, StringIndexer
+from pyspark.sql import SparkSession, functions, types, Row
+import pyspark as spark
+import tensorflow as tf
+import keras
+import keras.metrics as metrics
+import matplotlib.pyplot as plt
+from sklearn.metrics import auc, confusion_matrix, roc_curve
+from sklearn.preprocessing import OneHotEncoder, LabelBinarizer
+
+from model.model import CompoundModel, DecoderModel
+from pipe.etl import etl, read
+
+with open("parameters.json", "r") as file:
+    params = json.load(file)
+
+    # data parameters
+    SPLIT = params["split"]
+    LOAD_FROM_SCRATCH = params["load_from_scratch"]
+
+    # model parameters
+    HEAD_SIZE = params["head_size"]
+    NUM_HEADS = params["num_heads"]
+    FF_DIM = params["ff_dim"]
+    NUM_TRANSFORMER_BLOCKS = params["num_transformer_blocks"]
+    MLP_UNITS = params["mlp_units"]
+    DROPOUT = params["dropout"]
+    MLP_DROPOUT = params["mlp_dropout"]
+
+    # training parameters
+    BATCH_SIZE = params["batch_size"]
+    EPOCHS = params["epochs"]
+    LOG_LEVEL = params["log_level"]
+    del params
+
+def trim(dataframe, column):
+    # Collect one array column into a numpy array shaped (samples, 32, 130).
+    ndarray = np.array(dataframe.select(column).collect()) \
+        .reshape(-1, 32, 130)
+
+    return ndarray
+
+def get_model(input_shape, n_classes):
+    model = DecoderModel(input_shape, HEAD_SIZE, NUM_HEADS, FF_DIM,
+                         NUM_TRANSFORMER_BLOCKS, MLP_UNITS, n_classes,
+                         dropout=DROPOUT, mlp_dropout=MLP_DROPOUT)
+    return model
+
+def transform(spark, dataframe, keys):
+    # One-hot encode each label column and join the result back on a row index.
+    dataframe = dataframe.withColumn(
+        "index", functions.monotonically_increasing_id()
+    )
+    bundle = {key: [
+        arr.tolist()
+        for arr in OneHotEncoder(sparse_output=False) \
+            .fit_transform(dataframe.select(key).collect())
+        ] for key in keys
+    }
+
+    bundle = [dict(zip(bundle.keys(), values))
+              for values in zip(*bundle.values())]
+    schema = types.StructType([
+        types.StructField(key, types.ArrayType(types.FloatType()), True)
+        for key in keys
+    ])
+    newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
+        "index", functions.monotonically_increasing_id()
+    )
+    for key in keys:
+        dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
+    dataframe = dataframe.join(newframe, on="index", how="inner")
+
+    return dataframe
+
+def build_dict(df, key):
+    # Map the argmax index of each one-hot vector back to its string label.
+    df = df.select(key, f"{key}_str").distinct()
+
+    return df.rdd.map(
+        lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
+    ).collectAsMap()
+
+
+def get_data(spark, split=[0.99, 0.005, 0.005]):
+    path = Path("/app/workdir")
+
+    labels = []
+    with open(path / "train.csv", "r") as file:
+        for line in file:
+            labels.append(line.strip().split(",")[0])
+
+    # NOTE: SpectrogramPipe must be provided by the spectrogram pipeline module; it is not imported here.
+    pipe = SpectrogramPipe(spark, filetype="matfiles")
+    data = pipe.spectrogram_pipe(path, labels)
+    # Normalise the treatment labels before encoding.
+    data = data.replace("virus", "cpv", subset=["treatment"]) \
+        .replace("cont", "pbs", subset=["treatment"]) \
+        .replace("control", "pbs", subset=["treatment"]) \
+        .replace("dld", "pbs", subset=["treatment"])
+
+    data = transform(spark, data, ["treatment", "target"])
+    category_dict = {
+        key: build_dict(data, key) for key in ["treatment", "target"]
+    }
+    splits = data.randomSplit(split, seed=42)
+    trainx, valx, testx = (trim(dset, "spectra") for dset in splits)
+    trainy, valy, testy = (
+        [np.array(dset.select("treatment").collect()).squeeze(),
+         np.array(dset.select("target").collect()).squeeze()]
+        for dset in splits
+    )
+
+    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
+
+
+def main():
+    # jax mesh setup
+    """
+    devices = jax.devices("gpu")
+    mesh = keras.distribution.DeviceMesh(
+        shape=(len(devices),), axis_names=["model"], devices=devices
+    )
+    layout_map = keras.distribution.LayoutMap(mesh)
+    layout_map["dense.*kernel"] = (None, "model")
+    layout_map["dense.*bias"] = ("model",)
+    layout_map["dense.*kernel_regularizer"] = (None, "model")
+    layout_map["dense.*bias_regularizer"] = ("model",)
+    layout_map["dense.*activity_regularizer"] = (None,)
+    layout_map["dense.*kernel_constraint"] = (None, "model")
+    layout_map["dense.*bias_constraint"] = ("model",)
+    layout_map["conv1d.*kernel"] = (None, None, None, "model")
+    layout_map["conv1d.*kernel_regularizer"] = (None, None, None, "model")
+    layout_map["conv1d.*bias_regularizer"] = ("model",)
+
+    model_parallel = keras.distribution.ModelParallel(
+        layout_map=layout_map
+    )
+    keras.distribution.set_distribution(model_parallel)
+    """
+
+    spark = SparkSession.builder.appName("train").getOrCreate()
+
+    keys = ["treatment", "target"]
+
+    # Override the value read from parameters.json: always read preprocessed data.
+    LOAD_FROM_SCRATCH = False
+    if LOAD_FROM_SCRATCH:
+        data = etl(spark, split=SPLIT)
+    else:
+        data = read(spark, split=SPLIT)
+
+    (train_set, validation_set, test_set, categories) = data
+
+    n_classes = [dset.shape[1] for dset in train_set[1]]
+    model = get_model(train_set[0].shape[1:], n_classes)
+    model.compile(optimizer=keras.optimizers.Adam(learning_rate=4e-4),
+                  loss=["mse"],
+                  )
+    if LOG_LEVEL == 1:
+        model.summary()
+
+    # The decoder maps label vectors (train_set[1]) back to spectra (train_set[0]).
+    start = time.time()
+    model.fit(x=train_set[1], y=train_set[0],
+              validation_data=(validation_set[1], validation_set[0]),
+              batch_size=BATCH_SIZE, epochs=EPOCHS, verbose=LOG_LEVEL)
+    end = time.time()
+    print("Training time: ", end - start)
+
+    # Test model performance
+    #test_loss, test_accuracy = model.evaluate(test_set[1], test_set[0])
+    test_predict = model.predict(test_set[1])
+    plt.pcolor(test_predict[0], cmap='bwr', clim=(-1, 1))
+    plt.savefig("sample_spectra.png")
+    plt.close()
+    plt.pcolor(test_predict[0] - np.mean(np.array(test_predict[0])[:4], axis=0),
+               cmap='bwr', clim=(-1, 1))
+    plt.savefig("sample_spectrogram.png")
+    plt.close()
+    # Stop after saving the sample reconstructions; the classification metrics
+    # below are left over from the classifier script and are not computed here.
+    exit()
+    print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}")
+    for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
+        conf_matrix = confusion_matrix(
+            np.argmax(predict, axis=1),
+            np.argmax(groundtruth, axis=1),
+            labels=range(len(categories[key].values())),
+            normalize="pred"
+        )
+        plt.pcolormesh(conf_matrix, edgecolors="black", linewidth=0.5)  # origin="upper"
+        plt.gca().set_aspect("equal")
+        plt.colorbar()
+        plt.xticks([int(num) for num in categories[key].keys()],
+                   categories[key].values(), rotation=270)
+        plt.yticks([int(num) for num in categories[key].keys()],
+                   categories[key].values())
+        plt.xlabel("True label")
+        plt.ylabel("Predicted label")
+        plt.gcf().set_size_inches(len(categories[key])/10+4,
+                                  len(categories[key])/10+3)
+        plt.savefig(f"/app/workdir/figures/confusion_matrix_{key}.png",
+                    bbox_inches="tight")
+        plt.close()
+        with open(f"confusion_matrix_{key}.json", 'w') as f:
+            confusion_dict = {"prediction": predict.tolist(),
+                              "true": groundtruth.tolist(),
+                              "matrix": conf_matrix.tolist()}
+            json.dump(confusion_dict, f)
+
+        label_binarizer = LabelBinarizer().fit(groundtruth)
+        y_onehot_test = label_binarizer.transform(groundtruth)
+        fpr, tpr, _ = roc_curve(
+            groundtruth.ravel(),
+            predict.ravel()
+        )
+        roc_auc = auc(fpr, tpr)
+        plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}")
+        plt.savefig(f"/app/workdir/figures/roc_curve_{key}.png",
+                    bbox_inches="tight")
+        with open(f"roc_fpr_tpr_{key}.json", 'w') as f:
+            roc_dict = {"fpr": fpr.tolist(),
+                        "tpr": tpr.tolist(),
+                        "auc": roc_auc}
+            json.dump(roc_dict, f)
+    print("Done")
+
+    # Save the hyperparameters and metrics to csv
+    metric = {
+        "head_size": HEAD_SIZE,
+        "num_heads": NUM_HEADS,
+        "ff_dim": FF_DIM,
+        "num_transformer_blocks": NUM_TRANSFORMER_BLOCKS,
+        "mlp_units": MLP_UNITS[0],
+        "dropout": DROPOUT,
+        "mlp_dropout": MLP_DROPOUT,
+        "batch_size": BATCH_SIZE,
+        "epochs": EPOCHS,
+        "test_loss": test_loss,
+        "test_accuracy": test_accuracy
+    }
+    if not os.path.exists("/app/workdir/metrics.csv"):
+        with open("/app/workdir/metrics.csv", "w") as f:
+            f.write(",".join(metric.keys()) + "\n")
+    with open("/app/workdir/metrics.csv", "a") as f:
+        f.write(",".join([str(value) for value in metric.values()]) + "\n")
+
+    return
+
+if __name__ == "__main__":
+    main()
+
+# EOF