Skip to content

Cleanup on decodertest code #2

Merged
merged 3 commits into from
Oct 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 3 additions & 138 deletions decodertest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,39 +27,7 @@
from model.model import CompoundModel, DecoderModel
from pipe.etl import etl, read
import train.decoder_train as dt

def trim(dataframe, column):
    """
    Collect one column of a dataframe and return it as a 3-D ndarray.

    The rows returned by ``dataframe.select(column).collect()`` are
    converted with ``np.array`` and reshaped to ``(-1, 32, 130)``, so the
    collected data must hold a multiple of ``32 * 130`` values.
    """
    collected = dataframe.select(column).collect()
    stacked = np.array(collected)
    return stacked.reshape(-1, 32, 130)

def transform(spark, dataframe, keys):
    """
    One-hot encode the columns named in ``keys`` and join them back on.

    Every column in ``keys`` is collected to the driver, encoded with
    ``OneHotEncoder(sparse_output=False)`` and attached to ``dataframe``
    under its original name. The original values stay available in a
    ``<key>_str`` column. The ``index`` column used as the join key is left
    on the returned frame.

    Args:
        spark: Session used to build the frame holding the encoded vectors.
        dataframe: Frame containing every column listed in ``keys``.
        keys: Names of the columns to encode.

    Returns:
        ``dataframe`` with ``index``, ``<key>_str`` and encoded ``<key>``
        columns.
    """
    dataframe = dataframe.withColumn(
        "index", functions.monotonically_increasing_id()
    )

    # Collect the generated id together with the values it belongs to. The
    # ids embed the partition id, so generating them a second time on the
    # encoded frame (as was done before) only matches when both frames share
    # the same partition layout; otherwise the inner join drops rows.
    # NOTE(review): this still relies on the id being reproduced when
    # ``dataframe`` is re-evaluated for the join -- consider caching the
    # input if the upstream plan is not deterministic.
    rows = dataframe.select("index", *keys).collect()

    encoded = {
        key: OneHotEncoder(sparse_output=False)
        .fit_transform([[row[key]] for row in rows])
        .tolist()
        for key in keys
    }
    bundle = [
        {"index": row["index"], **{key: encoded[key][pos] for key in keys}}
        for pos, row in enumerate(rows)
    ]
    schema = types.StructType(
        [types.StructField("index", types.LongType(), False)]
        + [
            types.StructField(key, types.ArrayType(types.FloatType()), True)
            for key in keys
        ]
    )
    newframe = spark.createDataFrame(bundle, schema=schema)

    # Free the original names so the encoded columns can take them over.
    for key in keys:
        dataframe = dataframe.withColumnRenamed(key, f"{key}_str")

    return dataframe.join(newframe, on="index", how="inner")
from visualize.plot import spectra_plot

def build_dict(df, key):
df = df.select(key, f"{key}_str").distinct()
Expand All @@ -68,37 +36,6 @@ def build_dict(df, key):
lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
).collectAsMap()


def get_data(spark, split=(0.99, 0.005, 0.005)):
    """
    Load the spectrogram data and return train/validation/test arrays.

    Labels are the first comma-separated field of every line of
    ``/app/workdir/train.csv``. The data is read through
    ``SpectrogramPipe``, the ``treatment`` aliases are normalised, the
    ``treatment`` and ``target`` columns are one-hot encoded with
    ``transform`` and the frame is split with a fixed seed.

    Args:
        spark: Session handed to ``SpectrogramPipe`` and ``transform``.
        split: Three weights passed to ``randomSplit``; the result is
            unpacked into train, validation and test, so exactly three
            weights are required.

    Returns:
        ``((trainx, trainy), (valx, valy), (testx, testy), category_dict)``
        where each ``x`` comes from ``trim(..., "spectra")``, each ``y`` is
        ``[treatment_array, target_array]`` and ``category_dict`` maps each
        encoded key to the result of ``build_dict``.
    """
    path = Path("/app/workdir")

    with open(path / "train.csv", "r") as file:
        labels = [line.strip().split(",")[0] for line in file]

    pipe = SpectrogramPipe(spark, filetype="matfiles")
    data = pipe.spectrogram_pipe(path, labels)

    # ``replace`` returns a new frame. The previous code discarded that
    # result (and ran it on a ``select``/``distinct`` projection), so the
    # aliases were never actually merged. Apply it to ``data`` itself,
    # restricted to the ``treatment`` column.
    data = data.replace(
        {"virus": "cpv", "cont": "pbs", "control": "pbs", "dld": "pbs"},
        subset=["treatment"],
    )

    data = transform(spark, data, ["treatment", "target"])
    category_dict = {
        key: build_dict(data, key) for key in ["treatment", "target"]
    }

    # The default is an immutable tuple; hand ``randomSplit`` a list as before.
    splits = data.randomSplit(list(split), seed=42)
    trainx, valx, testx = (trim(dset, "spectra") for dset in splits)
    trainy, valy, testy = (
        [
            np.array(dset.select("treatment").collect()).squeeze(),
            np.array(dset.select("target").collect()).squeeze(),
        ]
        for dset in splits
    )

    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)

def multi_gpu():
devices = jax.devices("gpu")
mesh = keras.distribution.DeviceMesh(
Expand Down Expand Up @@ -147,80 +84,8 @@ def main():
metrics = {key: None for key in keys}
metrics, test_predict = dt.test_decoder(decoder, test_set, metrics)

plt.pcolor(test_predict[0], cmap='bwr', clim=(-1, 1))
plt.savefig("sample_spectra.png")
plt.close()
plt.pcolor(test_predict[0] - np.mean(np.array(test_predict[0])[:4], axis=0),
cmap='bwr', clim=(-1, 1))
plt.savefig("sample_spectrogram.png")
plt.close()
exit()
print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}")
for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
conf_matrix = confusion_matrix(
np.argmax(predict, axis=1),
np.argmax(groundtruth, axis=1),
labels=range(len(categories[key].values())),
normalize="pred"
)
plt.pcolormesh(conf_matrix, edgecolors="black", linewidth=0.5)#origin="upper")
plt.gca().set_aspect("equal")
plt.colorbar()
plt.xticks([int(num) for num in categories[key].keys()],
categories[key].values(), rotation=270)
plt.yticks([int(num) for num in categories[key].keys()],
categories[key].values())
plt.xlabel("True label")
plt.ylabel("Predicted label")
plt.gcf().set_size_inches(len(categories[key])/10+4,
len(categories[key])/10+3)
plt.savefig(f"/app/workdir/figures/confusion_matrix_{key}.png",
bbox_inches="tight")
plt.close()
with open(f"confusion_matrix_{key}.json", 'w') as f:
confusion_dict = {"prediction": predict.tolist(),
"true": groundtruth.tolist(),
"matrix": conf_matrix.tolist()}
json.dump(confusion_dict, f)

label_binarizer = LabelBinarizer().fit(groundtruth)
y_onehot_test = label_binarizer.transform(groundtruth)
fpr, tpr, _ = roc_curve(
groundtruth.ravel(),
predict.ravel()
)
roc_auc = auc(fpr, tpr)
plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}")
plt.savefig(f"/app/workdir/figures/roc_curve_{key}.png",
bbox_inches="tight")
with open(f"roc_fpr_tpr_{key}.json", 'w') as f:
roc_dict = {"fpr": fpr.tolist(),
"tpr": tpr.tolist(),
"auc": roc_auc}
json.dump(roc_dict, f)
print("Done")

# Save the hyperparameters and metric to csv
metric = {
"head_size": HEAD_SIZE,
"num_heads": NUM_HEADS,
"ff_dim": FF_DIM,
"num_transformer_blocks": NUM_TRANSFORMER_BLOCKS,
"mlp_units": MLP_UNITS[0],
"dropout": DROPOUT,
"mlp_dropout": MLP_DROPOUT,
"batch_size": BATCH_SIZE,
"epochs": EPOCHS,
"test_loss": test_loss,
"test_accuracy": test_accuracy
}
if not os.path.exists("/app/workdir/metrics.csv"):
with open("/app/workdir/metrics.csv", "w") as f:
f.write(",".join(metric.keys()) + "\n")
with open("/app/workdir/metrics.csv", "a") as f:
f.write(",".join([str(value) for value in metric.values()]) + "\n")

return
spectra_plot(test_predict[0], name=None)


if __name__ == "__main__":
main()
Expand Down
20 changes: 20 additions & 0 deletions visualize/plot.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# --*- coding: utf-8 -*-

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

def lineplot(data=None, x=None, y=None, hue=None):
"""
"""

if data is None or x is None or y is None:
raise ValueError("Data, x, and y parameters must be provided.")

Expand All @@ -19,5 +23,21 @@ def lineplot(data=None, x=None, y=None, hue=None):
plt.savefig(f"/app/workdir/figures/lineplot_{y}_by_{x}.png")
plt.close()

def spectra_plot(spec_array, name=None):
    """
    Save two heat maps of a 2-D array: as given, and baseline-subtracted.

    Only the first 130 columns are plotted. The first figure shows those
    values directly; the second shows them after subtracting the
    column-wise mean of the first six rows. Both use the ``bwr`` colour map
    clipped to ``[-1, 1]`` and are written to ``/app/workdir/figures``.

    Args:
        spec_array: 2-D array-like; it is not modified.
        name: Interpolated verbatim into the titles and file names, so the
            default ``None`` yields ``None_spectra_plot.png`` and
            ``None_spectrogram_plot.png``.
    """
    spec_array = np.asarray(spec_array)[:, :130]
    print(spec_array.shape)
    plt.pcolor(spec_array, cmap="bwr", vmin=-1, vmax=1)
    plt.title(f"{name} spectra")
    plt.savefig(f"/app/workdir/figures/{name}_spectra_plot.png")
    plt.close()

    # Subtract out of place: the slice above is a view, so ``-=`` would
    # write into the caller's array (and raises for integer input).
    spec_array = spec_array - np.mean(spec_array[:6, :], axis=0)
    plt.pcolor(spec_array, cmap="bwr", vmin=-1, vmax=1)
    plt.title(f"{name} spectrogram")
    plt.savefig(f"/app/workdir/figures/{name}_spectrogram_plot.png")
    plt.close()

# EOF
Loading