Skip to content

Commit

Permalink
LOAD_FROM_SCRATCH is now read from the parameters file. Also added a hyperparameter
Browse files Browse the repository at this point in the history
and metric logging routine so that model performance can be tracked more systematically
  • Loading branch information
Dawith committed Oct 19, 2025
1 parent 52dde05 commit 894a9eb
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

# data parameters
SPLIT = params["split"]
LOAD_FROM_SCRATCH = params["load_from_scratch"]

# model parameters
HEAD_SIZE = params["head_size"]
Expand Down Expand Up @@ -154,8 +155,7 @@ def main():

keys = ["treatment", "target"]

load_from_scratch = True
if load_from_scratch:
if LOAD_FROM_SCRATCH:
data = etl(spark, split=SPLIT)
else:
data = read(spark)
Expand Down Expand Up @@ -207,6 +207,26 @@ def main():
"matrix": conf_matrix.tolist()}
json.dump(confusion_dict, f)
print("Done")

# Append this run's hyperparameters and test metrics as one row of metrics.csv (the header row is written first if the file does not exist yet)
metric = {
"head_size": HEAD_SIZE,
"num_heads": NUM_HEADS,
"ff_dim": FF_DIM,
"num_transformer_blocks": NUM_TRANSFORMER_BLOCKS,
"mlp_units": MLP_UNITS[0],
"dropout": DROPOUT,
"mlp_dropout": MLP_DROPOUT,
"batch_size": BATCH_SIZE,
"epochs": EPOCHS,
"test_loss": test_loss,
"test_accuracy": test_accuracy
}
if not os.path.exists("/app/workdir/metrics.csv"):
with open("/app/workdir/metrics.csv", "w") as f:
f.write(",".join(metric.keys()) + "\n")
with open("/app/workdir/metrics.csv", "a") as f:
f.write(",".join([str(value) for value in metric.values()]) + "\n")

return

Expand Down

0 comments on commit 894a9eb

Please sign in to comment.