diff --git a/train.py b/train.py index 6315016..22e3608 100644 --- a/train.py +++ b/train.py @@ -20,7 +20,11 @@ def main(): spark = SparkSession.builder.appName("train").getOrCreate() pipe = SpectrogramPipe(spark) data = pipe.spectrogram_pipe(path, labels) - print(data.head()) + train_df, validation_df, test_df = data.randomsplit([0.8, 0.1, 0.1], + seed=42) + print(train_df.count()) + print(validation_df.count()) + print(test_df.count()) return