From a48f78b1b1fede2d9feebea141abe6dc723b7d05 Mon Sep 17 00:00:00 2001 From: maelstrom Date: Fri, 3 Jan 2025 13:56:46 -0500 Subject: [PATCH] training scripts fixed --- train.py | 8 +++++++- train_cpu.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 25e7743..11b0b8c 100644 --- a/train.py +++ b/train.py @@ -94,6 +94,12 @@ def get_data(spark, split=[0.99, 0.005, 0.005]): testy = np.array(test_df.select("treatment_index").collect()).astype(int) _testy = np.zeros((len(testy), int(index_max+1))) for index, value in enumerate(testy): + _testy[index, value] = 1. + testy = _testy + del _testy + + return (selected, (trainx, trainy), (valx, valy), (testx, testy)) + def main(): # jax mesh setup @@ -146,7 +152,7 @@ def main(): np.argmax(test_set[1], axis=1)) plt.imshow(conf_matrix, origin="upper") plt.gca().set_aspect("equal") - plt.savefig("confusion_matrix.png") + plt.savefig("/app/workdir/confusion_matrix.png") return diff --git a/train_cpu.py b/train_cpu.py index de18b67..c1dbe24 100644 --- a/train_cpu.py +++ b/train_cpu.py @@ -140,7 +140,7 @@ def main(): np.argmax(test_set[1], axis=1)) plt.imshow(conf_matrix, origin="upper") plt.gca().set_aspect("equal") - plt.savefig("confusion_matrix.png") + plt.savefig("/app/workdir/confusion_matrix.png") return