Skip to content

Commit

Permalink
Merge pull request #2 from Nolte-Group/20251024
Browse files Browse the repository at this point in the history
Cleanup on decodertest code
  • Loading branch information
lim185 authored Oct 27, 2025
2 parents f88b647 + ea74dc7 commit 97ebb66
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 138 deletions.
141 changes: 3 additions & 138 deletions decodertest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,39 +27,7 @@
from model.model import CompoundModel, DecoderModel
from pipe.etl import etl, read
import train.decoder_train as dt

def trim(dataframe, column):

ndarray = np.array(dataframe.select(column).collect()) \
.reshape(-1, 32, 130)

return ndarray

def transform(spark, dataframe, keys):
dataframe = dataframe.withColumn(
"index", functions.monotonically_increasing_id()
)
bundle = {key: [
arr.tolist()
for arr in OneHotEncoder(sparse_output=False) \
.fit_transform(dataframe.select(key).collect())
] for key in keys
}

bundle = [dict(zip(bundle.keys(), values))
for values in zip(*bundle.values())]
schema = types.StructType([
types.StructField(key, types.ArrayType(types.FloatType()), True)
for key in keys
])
newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
"index", functions.monotonically_increasing_id()
)
for key in keys:
dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
dataframe = dataframe.join(newframe, on="index", how="inner")

return dataframe
from visualize.plot import spectra_plot

def build_dict(df, key):
df = df.select(key, f"{key}_str").distinct()
Expand All @@ -68,37 +36,6 @@ def build_dict(df, key):
lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
).collectAsMap()


def get_data(spark, split=[0.99, 0.005, 0.005]):
path = Path("/app/workdir")

labels = []
with open(path / "train.csv", "r") as file:
for line in file:
labels.append(line.strip().split(",")[0])

pipe = SpectrogramPipe(spark, filetype="matfiles")
data = pipe.spectrogram_pipe(path, labels)
data.select("treatment").replace("virus", "cpv") \
.replace("cont", "pbs") \
.replace("control", "pbs") \
.replace("dld", "pbs").distinct()

data = transform(spark, data, ["treatment", "target"])
category_dict = {
key: build_dict(data, key) for key in ["treatment", "target"]
}
splits = data.randomSplit(split, seed=42)
trainx, valx, testx = (trim(dset, "spectra") for dset in splits)
trainy, valy, testy = (
[np.array(dset.select("treatment").collect()).squeeze(),
np.array(dset.select("target").collect()).squeeze()]
for dset in splits
)


return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)

def multi_gpu():
devices = jax.devices("gpu")
mesh = keras.distribution.DeviceMesh(
Expand Down Expand Up @@ -147,80 +84,8 @@ def main():
metrics = {key: None for key in keys}
metrics, test_predict = dt.test_decoder(decoder, test_set, metrics)

plt.pcolor(test_predict[0], cmap='bwr', clim=(-1, 1))
plt.savefig("sample_spectra.png")
plt.close()
plt.pcolor(test_predict[0] - np.mean(np.array(test_predict[0])[:4], axis=0),
cmap='bwr', clim=(-1, 1))
plt.savefig("sample_spectrogram.png")
plt.close()
exit()
print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}")
for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
conf_matrix = confusion_matrix(
np.argmax(predict, axis=1),
np.argmax(groundtruth, axis=1),
labels=range(len(categories[key].values())),
normalize="pred"
)
plt.pcolormesh(conf_matrix, edgecolors="black", linewidth=0.5)#origin="upper")
plt.gca().set_aspect("equal")
plt.colorbar()
plt.xticks([int(num) for num in categories[key].keys()],
categories[key].values(), rotation=270)
plt.yticks([int(num) for num in categories[key].keys()],
categories[key].values())
plt.xlabel("True label")
plt.ylabel("Predicted label")
plt.gcf().set_size_inches(len(categories[key])/10+4,
len(categories[key])/10+3)
plt.savefig(f"/app/workdir/figures/confusion_matrix_{key}.png",
bbox_inches="tight")
plt.close()
with open(f"confusion_matrix_{key}.json", 'w') as f:
confusion_dict = {"prediction": predict.tolist(),
"true": groundtruth.tolist(),
"matrix": conf_matrix.tolist()}
json.dump(confusion_dict, f)

label_binarizer = LabelBinarizer().fit(groundtruth)
y_onehot_test = label_binarizer.transform(groundtruth)
fpr, tpr, _ = roc_curve(
groundtruth.ravel(),
predict.ravel()
)
roc_auc = auc(fpr, tpr)
plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}")
plt.savefig(f"/app/workdir/figures/roc_curve_{key}.png",
bbox_inches="tight")
with open(f"roc_fpr_tpr_{key}.json", 'w') as f:
roc_dict = {"fpr": fpr.tolist(),
"tpr": tpr.tolist(),
"auc": roc_auc}
json.dump(roc_dict, f)
print("Done")

# Save the hyperparameters and metric to csv
metric = {
"head_size": HEAD_SIZE,
"num_heads": NUM_HEADS,
"ff_dim": FF_DIM,
"num_transformer_blocks": NUM_TRANSFORMER_BLOCKS,
"mlp_units": MLP_UNITS[0],
"dropout": DROPOUT,
"mlp_dropout": MLP_DROPOUT,
"batch_size": BATCH_SIZE,
"epochs": EPOCHS,
"test_loss": test_loss,
"test_accuracy": test_accuracy
}
if not os.path.exists("/app/workdir/metrics.csv"):
with open("/app/workdir/metrics.csv", "w") as f:
f.write(",".join(metric.keys()) + "\n")
with open("/app/workdir/metrics.csv", "a") as f:
f.write(",".join([str(value) for value in metric.values()]) + "\n")

return
spectra_plot(test_predict[0], name=None)


if __name__ == "__main__":
main()
Expand Down
20 changes: 20 additions & 0 deletions visualize/plot.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# --*- coding: utf-8 -*-

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

def lineplot(data=None, x=None, y=None, hue=None):
"""
"""

if data is None or x is None or y is None:
raise ValueError("Data, x, and y parameters must be provided.")

Expand All @@ -19,5 +23,21 @@ def lineplot(data=None, x=None, y=None, hue=None):
plt.savefig(f"/app/workdir/figures/lineplot_{y}_by_{x}.png")
plt.close()

def spectra_plot(spec_array, name=None):
"""
"""

spec_array = spec_array[:,:130]
print(spec_array.shape)
plt.pcolor(spec_array, cmap="bwr", vmin=-1, vmax=1)
plt.title(f"{name} spectra")
plt.savefig(f"/app/workdir/figures/{name}_spectra_plot.png")
plt.close()

spec_array -= np.mean(spec_array[:6,:], axis=0)
plt.pcolor(spec_array, cmap="bwr", vmin=-1, vmax=1)
plt.title(f"{name} spectrogram")
plt.savefig(f"/app/workdir/figures/{name}_spectrogram_plot.png")
plt.close()

# EOF

0 comments on commit 97ebb66

Please sign in to comment.