From 894a9eb77b17c77ea8d5df06a8542961d60122da Mon Sep 17 00:00:00 2001 From: Dawith Date: Sun, 19 Oct 2025 12:38:44 -0400 Subject: [PATCH] LOAD_FROM_SCRATCH is now read from parameters file. Also added a hyperparameter and metric logging routine so that model performance can be tracked more systematically --- train.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 60ca8c8..694fa93 100644 --- a/train.py +++ b/train.py @@ -31,6 +31,7 @@ # data parameters SPLIT = params["split"] + LOAD_FROM_SCRATCH = params["load_from_scratch"] # model parameters HEAD_SIZE = params["head_size"] @@ -154,8 +155,7 @@ def main(): keys = ["treatment", "target"] - load_from_scratch = True - if load_from_scratch: + if LOAD_FROM_SCRATCH: data = etl(spark, split=SPLIT) else: data = read(spark) @@ -207,6 +207,26 @@ def main(): "matrix": conf_matrix.tolist()} json.dump(confusion_dict, f) print("Done") + + # Save the hyperparameters and metric to csv + metric = { + "head_size": HEAD_SIZE, + "num_heads": NUM_HEADS, + "ff_dim": FF_DIM, + "num_transformer_blocks": NUM_TRANSFORMER_BLOCKS, + "mlp_units": MLP_UNITS[0], + "dropout": DROPOUT, + "mlp_dropout": MLP_DROPOUT, + "batch_size": BATCH_SIZE, + "epochs": EPOCHS, + "test_loss": test_loss, + "test_accuracy": test_accuracy + } + if not os.path.exists("/app/workdir/metrics.csv"): + with open("/app/workdir/metrics.csv", "w") as f: + f.write(",".join(metric.keys()) + "\n") + with open("/app/workdir/metrics.csv", "a") as f: + f.write(",".join([str(value) for value in metric.values()]) + "\n") return