From 065e8b3ac1fd84ea08ad0a2b452b2b5cb449f1e3 Mon Sep 17 00:00:00 2001
From: Dawith
Date: Fri, 24 Oct 2025 21:15:17 -0400
Subject: [PATCH 1/3] Removed irrelevant plotting code for clarity while working

---
 decodertest.py | 67 +-------------------------------------------------
 1 file changed, 1 insertion(+), 66 deletions(-)

diff --git a/decodertest.py b/decodertest.py
index bb34818..6c5b63e 100644
--- a/decodertest.py
+++ b/decodertest.py
@@ -155,72 +155,7 @@ def main():
     plt.savefig("sample_spectrogram.png")
     plt.close()
     exit()
-    print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}")
-    for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
-        conf_matrix = confusion_matrix(
-            np.argmax(predict, axis=1),
-            np.argmax(groundtruth, axis=1),
-            labels=range(len(categories[key].values())),
-            normalize="pred"
-        )
-        plt.pcolormesh(conf_matrix, edgecolors="black", linewidth=0.5)#origin="upper")
-        plt.gca().set_aspect("equal")
-        plt.colorbar()
-        plt.xticks([int(num) for num in categories[key].keys()],
-                   categories[key].values(), rotation=270)
-        plt.yticks([int(num) for num in categories[key].keys()],
-                   categories[key].values())
-        plt.xlabel("True label")
-        plt.ylabel("Predicted label")
-        plt.gcf().set_size_inches(len(categories[key])/10+4,
-                                  len(categories[key])/10+3)
-        plt.savefig(f"/app/workdir/figures/confusion_matrix_{key}.png",
-                    bbox_inches="tight")
-        plt.close()
-        with open(f"confusion_matrix_{key}.json", 'w') as f:
-            confusion_dict = {"prediction": predict.tolist(),
-                              "true": groundtruth.tolist(),
-                              "matrix": conf_matrix.tolist()}
-            json.dump(confusion_dict, f)
-
-        label_binarizer = LabelBinarizer().fit(groundtruth)
-        y_onehot_test = label_binarizer.transform(groundtruth)
-        fpr, tpr, _ = roc_curve(
-            groundtruth.ravel(),
-            predict.ravel()
-        )
-        roc_auc = auc(fpr, tpr)
-        plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}")
-        plt.savefig(f"/app/workdir/figures/roc_curve_{key}.png",
-                    bbox_inches="tight")
-        with open(f"roc_fpr_tpr_{key}.json", 'w') as f:
-            roc_dict = {"fpr": fpr.tolist(),
-                        "tpr": tpr.tolist(),
-                        "auc": roc_auc}
-            json.dump(roc_dict, f)
-    print("Done")
-
-    # Save the hyperparameters and metric to csv
-    metric = {
-        "head_size": HEAD_SIZE,
-        "num_heads": NUM_HEADS,
-        "ff_dim": FF_DIM,
-        "num_transformer_blocks": NUM_TRANSFORMER_BLOCKS,
-        "mlp_units": MLP_UNITS[0],
-        "dropout": DROPOUT,
-        "mlp_dropout": MLP_DROPOUT,
-        "batch_size": BATCH_SIZE,
-        "epochs": EPOCHS,
-        "test_loss": test_loss,
-        "test_accuracy": test_accuracy
-    }
-    if not os.path.exists("/app/workdir/metrics.csv"):
-        with open("/app/workdir/metrics.csv", "w") as f:
-            f.write(",".join(metric.keys()) + "\n")
-    with open("/app/workdir/metrics.csv", "a") as f:
-        f.write(",".join([str(value) for value in metric.values()]) + "\n")
-
-    return
+
 
 if __name__ == "__main__":
     main()

From 4f443cfbcf3c43646164632cfb65801b44541497 Mon Sep 17 00:00:00 2001
From: Dawith
Date: Mon, 27 Oct 2025 13:44:27 -0400
Subject: [PATCH 2/3] Unnecessary aux functions removed

---
 decodertest.py | 74 ++------------------------------------------------
 1 file changed, 2 insertions(+), 72 deletions(-)

diff --git a/decodertest.py b/decodertest.py
index 6c5b63e..70188cc 100644
--- a/decodertest.py
+++ b/decodertest.py
@@ -27,39 +27,7 @@
 from model.model import CompoundModel, DecoderModel
 from pipe.etl import etl, read
 import train.decoder_train as dt
-
-def trim(dataframe, column):
-
-    ndarray = np.array(dataframe.select(column).collect()) \
-        .reshape(-1, 32, 130)
-
-    return ndarray
-
-def transform(spark, dataframe, keys):
-    dataframe = dataframe.withColumn(
-        "index", functions.monotonically_increasing_id()
-    )
-    bundle = {key: [
-        arr.tolist()
-        for arr in OneHotEncoder(sparse_output=False) \
-            .fit_transform(dataframe.select(key).collect())
-        ] for key in keys
-    }
-
-    bundle = [dict(zip(bundle.keys(), values))
-              for values in zip(*bundle.values())]
-    schema = types.StructType([
-        types.StructField(key, types.ArrayType(types.FloatType()), True)
-        for key in keys
-    ])
-    newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
-        "index", functions.monotonically_increasing_id()
-    )
-    for key in keys:
-        dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
-    dataframe = dataframe.join(newframe, on="index", how="inner")
-
-    return dataframe
+from visualize.plot import spectra_plot
 
 def build_dict(df, key):
     df = df.select(key, f"{key}_str").distinct()
@@ -68,37 +36,6 @@ def build_dict(df, key):
     return df.rdd.map(
         lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
     ).collectAsMap()
-
-def get_data(spark, split=[0.99, 0.005, 0.005]):
-    path = Path("/app/workdir")
-
-    labels = []
-    with open(path / "train.csv", "r") as file:
-        for line in file:
-            labels.append(line.strip().split(",")[0])
-
-    pipe = SpectrogramPipe(spark, filetype="matfiles")
-    data = pipe.spectrogram_pipe(path, labels)
-    data.select("treatment").replace("virus", "cpv") \
-        .replace("cont", "pbs") \
-        .replace("control", "pbs") \
-        .replace("dld", "pbs").distinct()
-
-    data = transform(spark, data, ["treatment", "target"])
-    category_dict = {
-        key: build_dict(data, key) for key in ["treatment", "target"]
-    }
-    splits = data.randomSplit(split, seed=42)
-    trainx, valx, testx = (trim(dset, "spectra") for dset in splits)
-    trainy, valy, testy = (
-        [np.array(dset.select("treatment").collect()).squeeze(),
-         np.array(dset.select("target").collect()).squeeze()]
-        for dset in splits
-    )
-
-
-    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
-
 def multi_gpu():
     devices = jax.devices("gpu")
     mesh = keras.distribution.DeviceMesh(
@@ -147,14 +84,7 @@ def main():
 
     metrics = {key: None for key in keys}
     metrics, test_predict = dt.test_decoder(decoder, test_set, metrics)
-    plt.pcolor(test_predict[0], cmap='bwr', clim=(-1, 1))
-    plt.savefig("sample_spectra.png")
-    plt.close()
-    plt.pcolor(test_predict[0] - np.mean(np.array(test_predict[0])[:4], axis=0),
-               cmap='bwr', clim=(-1, 1))
-    plt.savefig("sample_spectrogram.png")
-    plt.close()
-    exit()
+    spectra_plot(test_predict[0], name=None)
 
 if __name__ == "__main__":
     main()

From ea74dc7f882aacd2ef0af5d33a0fb40a64bc25b9 Mon Sep 17 00:00:00 2001
From: Dawith
Date: Mon, 27 Oct 2025 13:45:03 -0400
Subject: [PATCH 3/3] Spectra/Spectrogram plotting basic code added to plot.py

---
 visualize/plot.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/visualize/plot.py b/visualize/plot.py
index 065c90c..1fcf3a6 100644
--- a/visualize/plot.py
+++ b/visualize/plot.py
@@ -1,9 +1,13 @@
 # --*- coding: utf-8 -*-
 
 import matplotlib.pyplot as plt
+import numpy as np
 import seaborn as sns
 
 def lineplot(data=None, x=None, y=None, hue=None):
+    """
+    """
+
     if data is None or x is None or y is None:
         raise ValueError("Data, x, and y parameters must be provided.")
 
@@ -19,5 +23,21 @@
     plt.savefig(f"/app/workdir/figures/lineplot_{y}_by_{x}.png")
     plt.close()
 
+def spectra_plot(spec_array, name=None):
+    """
+    """
+
+    spec_array = spec_array[:,:130]
+    print(spec_array.shape)
+    plt.pcolor(spec_array, cmap="bwr", vmin=-1, vmax=1)
+    plt.title(f"{name} spectra")
plt.savefig(f"/app/workdir/figures/{name}_spectra_plot.png") + plt.close() + + spec_array -= np.mean(spec_array[:6,:], axis=0) + plt.pcolor(spec_array, cmap="bwr", vmin=-1, vmax=1) + plt.title(f"{name} spectrogram") + plt.savefig(f"/app/workdir/figures/{name}_spectrogram_plot.png") + plt.close() # EOF