diff --git a/decodertest.py b/decodertest.py index bb34818..6c5b63e 100644 --- a/decodertest.py +++ b/decodertest.py @@ -155,72 +155,7 @@ def main(): plt.savefig("sample_spectrogram.png") plt.close() exit() - print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}") - for predict, groundtruth, key in zip(test_predict, test_set[1], keys): - conf_matrix = confusion_matrix( - np.argmax(predict, axis=1), - np.argmax(groundtruth, axis=1), - labels=range(len(categories[key].values())), - normalize="pred" - ) - plt.pcolormesh(conf_matrix, edgecolors="black", linewidth=0.5)#origin="upper") - plt.gca().set_aspect("equal") - plt.colorbar() - plt.xticks([int(num) for num in categories[key].keys()], - categories[key].values(), rotation=270) - plt.yticks([int(num) for num in categories[key].keys()], - categories[key].values()) - plt.xlabel("True label") - plt.ylabel("Predicted label") - plt.gcf().set_size_inches(len(categories[key])/10+4, - len(categories[key])/10+3) - plt.savefig(f"/app/workdir/figures/confusion_matrix_{key}.png", - bbox_inches="tight") - plt.close() - with open(f"confusion_matrix_{key}.json", 'w') as f: - confusion_dict = {"prediction": predict.tolist(), - "true": groundtruth.tolist(), - "matrix": conf_matrix.tolist()} - json.dump(confusion_dict, f) - - label_binarizer = LabelBinarizer().fit(groundtruth) - y_onehot_test = label_binarizer.transform(groundtruth) - fpr, tpr, _ = roc_curve( - groundtruth.ravel(), - predict.ravel() - ) - roc_auc = auc(fpr, tpr) - plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}") - plt.savefig(f"/app/workdir/figures/roc_curve_{key}.png", - bbox_inches="tight") - with open(f"roc_fpr_tpr_{key}.json", 'w') as f: - roc_dict = {"fpr": fpr.tolist(), - "tpr": tpr.tolist(), - "auc": roc_auc} - json.dump(roc_dict, f) - print("Done") - - # Save the hyperparameters and metric to csv - metric = { - "head_size": HEAD_SIZE, - "num_heads": NUM_HEADS, - "ff_dim": FF_DIM, - "num_transformer_blocks": NUM_TRANSFORMER_BLOCKS, - "mlp_units": MLP_UNITS[0], - "dropout": DROPOUT, - "mlp_dropout": MLP_DROPOUT, - "batch_size": BATCH_SIZE, - "epochs": EPOCHS, - "test_loss": test_loss, - "test_accuracy": test_accuracy - } - if not os.path.exists("/app/workdir/metrics.csv"): - with open("/app/workdir/metrics.csv", "w") as f: - f.write(",".join(metric.keys()) + "\n") - with open("/app/workdir/metrics.csv", "a") as f: - f.write(",".join([str(value) for value in metric.values()]) + "\n") - - return + if __name__ == "__main__": main()