From ebf8fed185e97be7c776d7fc5b552acd5fd5c4ae Mon Sep 17 00:00:00 2001
From: Dawith
Date: Tue, 21 Oct 2025 12:44:02 -0400
Subject: [PATCH] ROC plotting work in progress

---
 train.py | 34 +++++++++++++++++++++++++++++-----
 1 file changed, 29 insertions(+), 5 deletions(-)

diff --git a/train.py b/train.py
index 694fa93..1adfdce 100644
--- a/train.py
+++ b/train.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 """
     train.py
 
@@ -20,8 +21,8 @@ import tensorflow as tf
 import keras
 import matplotlib.pyplot as plt
 
-from sklearn.metrics import confusion_matrix
-from sklearn.preprocessing import OneHotEncoder
+from sklearn.metrics import auc, confusion_matrix, roc_curve
+from sklearn.preprocessing import OneHotEncoder, LabelBinarizer
 
 from model.model import CompoundModel
 from pipe.etl import etl, read
@@ -158,7 +159,7 @@ def main():
     if LOAD_FROM_SCRATCH:
         data = etl(spark, split=SPLIT)
     else:
-        data = read(spark)
+        data = read(spark, split=SPLIT)
 
     (train_set, validation_set, test_set, categories) = data
 
@@ -168,7 +169,8 @@ def main():
         loss="categorical_crossentropy",
         metrics=["categorical_accuracy", "categorical_accuracy"]
     )
-    model.summary()
+    if LOG_LEVEL == 1:
+        model.summary()
     start = time.time()
     model.fit(x=train_set[0],
               y=train_set[1],
@@ -179,7 +181,6 @@ def main():
 
     # Test model performance
     test_loss, test_accuracy, _, _, _, _ = model.evaluate(test_set[0], test_set[1])
-    print(model.metrics_names)
     test_predict = model.predict(test_set[0])
     print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}")
     for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
@@ -196,6 +197,8 @@ def main():
                    categories[key].values(),
                    rotation=270)
         plt.yticks([int(num) for num in categories[key].keys()],
                    categories[key].values())
+        plt.xlabel("Predicted label")
+        plt.ylabel("True label")
         plt.gcf().set_size_inches(len(categories[key])/10+4, len(categories[key])/10+3)
         plt.savefig(f"/app/workdir/figures/confusion_matrix_{key}.png",
@@ -206,6 +209,27 @@ def main():
                 "true": groundtruth.tolist(),
                 "matrix": conf_matrix.tolist()}
             json.dump(confusion_dict, f)
+
+        # Micro-averaged ROC: binarize the labels, then flatten labels and
+        # scores so every (sample, class) cell is one binary decision.
+        label_binarizer = LabelBinarizer().fit(groundtruth)
+        y_onehot_test = label_binarizer.transform(groundtruth)
+        fpr, tpr, _ = roc_curve(
+            y_onehot_test.ravel(),
+            predict.ravel()
+        )
+        roc_auc = auc(fpr, tpr)
+        plt.figure()
+        plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}")
+        plt.legend()
+        plt.savefig(f"/app/workdir/figures/roc_curve_{key}.png",
+                    bbox_inches="tight")
+        plt.close()
+        with open(f"roc_fpr_tpr_{key}.json", 'w') as f:
+            roc_dict = {"fpr": fpr.tolist(),
+                        "tpr": tpr.tolist(),
+                        "auc": roc_auc}
+            json.dump(roc_dict, f)
         print("Done")
 
     # Save the hyperparameters and metric to csv
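
Note (not part of the patch): the added block computes a micro-averaged ROC curve, binarizing the labels and then flattening labels and scores so that every (sample, class) cell counts as one binary decision. Below is a minimal, self-contained sketch of that technique against synthetic data, assuming groundtruth is a one-hot (n_samples, n_classes) array and predict holds softmax scores of the same shape; the names y_true and y_score are stand-ins introduced here, not code from train.py.

    import numpy as np
    from sklearn.metrics import auc, roc_curve

    rng = np.random.default_rng(0)
    n_samples, n_classes = 200, 4

    # Synthetic stand-ins for the one-hot ground truth and softmax scores.
    labels = rng.integers(0, n_classes, size=n_samples)
    y_true = np.eye(n_classes)[labels]
    logits = rng.normal(size=(n_samples, n_classes)) + 2.0 * y_true
    y_score = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

    # Micro-averaging: ravel both (n_samples, n_classes) arrays into one
    # long binary problem and compute a single ROC curve over all cells.
    fpr, tpr, _ = roc_curve(y_true.ravel(), y_score.ravel())
    print(f"micro-averaged AUC = {auc(fpr, tpr):.3f}")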
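
A natural follow-up for this work in progress is one ROC curve per class rather than a single flattened curve. A hedged sketch under the same shape assumptions; plot_per_class_roc, class_names, and path are hypothetical names introduced here, not existing helpers in this repo:

    import matplotlib.pyplot as plt
    from sklearn.metrics import auc, roc_curve

    def plot_per_class_roc(y_true, y_score, class_names, path):
        """Draw a one-vs-rest ROC curve for each class on shared axes."""
        fig, ax = plt.subplots()
        for idx, name in enumerate(class_names):
            # Column idx is a binary problem: this class vs. all others.
            fpr, tpr, _ = roc_curve(y_true[:, idx], y_score[:, idx])
            ax.plot(fpr, tpr, label=f"{name} (AUC = {auc(fpr, tpr):.2f})")
        ax.plot([0, 1], [0, 1], linestyle="--", color="grey")  # chance level
        ax.set_xlabel("False positive rate")
        ax.set_ylabel("True positive rate")
        ax.legend(fontsize="small")
        fig.savefig(path, bbox_inches="tight")
        plt.close(fig)

Inside the existing loop this could be called as plot_per_class_roc(groundtruth, predict, list(categories[key].values()), f"/app/workdir/figures/roc_per_class_{key}.png"), assuming categories[key].values() yields the class names in column order.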