Skip to content

Commit

Permalink
Minor change to train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lim185 committed Oct 14, 2025
1 parent 4a0dc38 commit 6601688
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def get_data(spark, split=[0.99, 0.005, 0.005]):
key: build_dict(data, key) for key in ["treatment", "target"]
}
splits = data.randomSplit(split, seed=42)
trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
trainx, valx, testx = (trim(dset, "spectra") for dset in splits)
trainy, valy, testy = (
[np.array(dset.select("treatment").collect()).squeeze(),
np.array(dset.select("target").collect()).squeeze()]
Expand Down Expand Up @@ -154,7 +154,7 @@ def main():

keys = ["treatment", "target"]

load_from_scratch = False
load_from_scratch = True
if load_from_scratch:
data = etl(spark, split=SPLIT)
else:
Expand Down

0 comments on commit 6601688

Please sign in to comment.