Skip to content

Commit

Permalink
training scripts fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
lim185 committed Jan 3, 2025
1 parent 5a26e4e commit a48f78b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 7 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ def get_data(spark, split=[0.99, 0.005, 0.005]):
testy = np.array(test_df.select("treatment_index").collect()).astype(int)
_testy = np.zeros((len(testy), int(index_max+1)))
for index, value in enumerate(testy):
_testy[index, value] = 1.
testy = _testy
del _testy

return (selected, (trainx, trainy), (valx, valy), (testx, testy))


def main():
# jax mesh setup
Expand Down Expand Up @@ -146,7 +152,7 @@ def main():
np.argmax(test_set[1], axis=1))
plt.imshow(conf_matrix, origin="upper")
plt.gca().set_aspect("equal")
plt.savefig("confusion_matrix.png")
plt.savefig("/app/workdir/confusion_matrix.png")

return

Expand Down
2 changes: 1 addition & 1 deletion train_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def main():
np.argmax(test_set[1], axis=1))
plt.imshow(conf_matrix, origin="upper")
plt.gca().set_aspect("equal")
plt.savefig("confusion_matrix.png")
plt.savefig("/app/workdir/confusion_matrix.png")

return

Expand Down

0 comments on commit a48f78b

Please sign in to comment.