diff --git a/train.py b/train.py index ddf37f1..55f0b69 100644 --- a/train.py +++ b/train.py @@ -114,7 +114,7 @@ def get_data(spark, split=[0.99, 0.005, 0.005]): key: build_dict(data, key) for key in ["treatment", "target"] } splits = data.randomSplit(split, seed=42) - trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits) + trainx, valx, testx = (trim(dset, "spectra") for dset in splits) trainy, valy, testy = ( [np.array(dset.select("treatment").collect()).squeeze(), np.array(dset.select("target").collect()).squeeze()] @@ -154,7 +154,7 @@ def main(): keys = ["treatment", "target"] - load_from_scratch = False + load_from_scratch = True if load_from_scratch: data = etl(spark, split=SPLIT) else: