From c6d8a75735532c2eb1ea0315672f089a98459f1a Mon Sep 17 00:00:00 2001
From: maelstrom
Date: Sat, 7 Dec 2024 22:07:02 -0500
Subject: [PATCH] Splitting dataset into train, validation and test sets

Split the DataFrame returned by SpectrogramPipe.spectrogram_pipe into
train, validation and test subsets using weights 0.8/0.1/0.1 and a fixed
seed (42), and print the row count of each subset in place of the former
data.head() debug print.

---
Reviewer notes:
 * Fixed the method name: the PySpark DataFrame API is randomSplit
   (camelCase). The original patch called data.randomsplit, which fails
   at runtime because no such attribute exists.
 * Line breaks and indentation were reconstructed from a
   whitespace-collapsed copy of this patch. The 4-space body indent and
   the two blank trailing context lines are inferred from the hunk header
   (-20,7 +20,11) and the enclosing "def main():" -- TODO confirm against
   train.py before applying.

 train.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 6315016..22e3608 100644
--- a/train.py
+++ b/train.py
@@ -20,7 +20,11 @@ def main():
     spark = SparkSession.builder.appName("train").getOrCreate()
     pipe = SpectrogramPipe(spark)
     data = pipe.spectrogram_pipe(path, labels)
-    print(data.head())
+    train_df, validation_df, test_df = data.randomSplit([0.8, 0.1, 0.1],
+                                                        seed=42)
+    print(train_df.count())
+    print(validation_df.count())
+    print(test_df.count())
     return
 
 