diff --git a/train.py b/train.py index 22e3608..d550750 100644 --- a/train.py +++ b/train.py @@ -20,7 +20,7 @@ def main(): spark = SparkSession.builder.appName("train").getOrCreate() pipe = SpectrogramPipe(spark) data = pipe.spectrogram_pipe(path, labels) - train_df, validation_df, test_df = data.randomsplit([0.8, 0.1, 0.1], + train_df, validation_df, test_df = data.randomSplit([0.8, 0.1, 0.1], seed=42) print(train_df.count()) print(validation_df.count())