From 66016884f8d5f7fe016fcef4c893aaf27b9c3550 Mon Sep 17 00:00:00 2001 From: Dawith Lim Date: Mon, 13 Oct 2025 20:27:13 -0400 Subject: [PATCH] Minor change to train.py --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index ddf37f1..55f0b69 100644 --- a/train.py +++ b/train.py @@ -114,7 +114,7 @@ def get_data(spark, split=[0.99, 0.005, 0.005]): key: build_dict(data, key) for key in ["treatment", "target"] } splits = data.randomSplit(split, seed=42) - trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits) + trainx, valx, testx = (trim(dset, "spectra") for dset in splits) trainy, valy, testy = ( [np.array(dset.select("treatment").collect()).squeeze(), np.array(dset.select("target").collect()).squeeze()] @@ -154,7 +154,7 @@ def main(): keys = ["treatment", "target"] - load_from_scratch = False + load_from_scratch = True if load_from_scratch: data = etl(spark, split=SPLIT) else: