Skip to content

Commit

Permalink
roc plotting work in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
Dawith committed Oct 21, 2025
1 parent 3fee715 commit ebf8fed
Showing 1 changed file with 24 additions and 5 deletions.
29 changes: 24 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# --*-- coding: utf-8 --*--
"""
train.py
Expand All @@ -20,8 +21,8 @@
import tensorflow as tf
import keras
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import auc, confusion_matrix, roc_curve
from sklearn.preprocessing import OneHotEncoder, LabelBinarizer

from model.model import CompoundModel
from pipe.etl import etl, read
Expand Down Expand Up @@ -158,7 +159,7 @@ def main():
if LOAD_FROM_SCRATCH:
data = etl(spark, split=SPLIT)
else:
data = read(spark)
data = read(spark, split=SPLIT)

(train_set, validation_set, test_set, categories) = data

Expand All @@ -168,7 +169,8 @@ def main():
loss="categorical_crossentropy",
metrics=["categorical_accuracy", "categorical_accuracy"]
)
model.summary()
if LOG_LEVEL == 1:
model.summary()

start = time.time()
model.fit(x=train_set[0], y=train_set[1],
Expand All @@ -179,7 +181,6 @@ def main():

# Test model performance
test_loss, test_accuracy, _, _, _, _ = model.evaluate(test_set[0], test_set[1])
print(model.metrics_names)
test_predict = model.predict(test_set[0])
print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}")
for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
Expand All @@ -196,6 +197,8 @@ def main():
categories[key].values(), rotation=270)
plt.yticks([int(num) for num in categories[key].keys()],
categories[key].values())
plt.xlabel("True label")
plt.ylabel("Predicted label")
plt.gcf().set_size_inches(len(categories[key])/10+4,
len(categories[key])/10+3)
plt.savefig(f"/app/workdir/figures/confusion_matrix_{key}.png",
Expand All @@ -206,6 +209,22 @@ def main():
"true": groundtruth.tolist(),
"matrix": conf_matrix.tolist()}
json.dump(confusion_dict, f)

label_binarizer = LabelBinarizer().fit(groundtruth)
y_onehot_test = label_binarizer.transform(groundtruth)
fpr, tpr, _ = roc_curve(
groundtruth.ravel(),
predict.ravel()
)
roc_auc = auc(fpr, tpr)
plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}")
plt.savefig(f"/app/workdir/figures/roc_curve_{key}.png",
bbox_inches="tight")
with open(f"roc_fpr_tpr_{key}.json", 'w') as f:
roc_dict = {"fpr": fpr.tolist(),
"tpr": tpr.tolist(),
"auc": roc_auc}
json.dump(roc_dict, f)
print("Done")

# Save the hyperparameters and metric to csv
Expand Down

0 comments on commit ebf8fed

Please sign in to comment.