diff --git a/pipe/etl.py b/pipe/etl.py index 854f242..069a020 100644 --- a/pipe/etl.py +++ b/pipe/etl.py @@ -38,7 +38,7 @@ def etl(spark: SparkSession, split: list=None) -> DataFrame: data = split_sets(data, split=split) return data -def read(spark: SparkSession) -> DataFrame: +def read(spark: SparkSession, split=None) -> DataFrame: """ Reads the processed data from a Parquet file and splits it into training, validation, and test sets. @@ -51,7 +51,7 @@ def read(spark: SparkSession) -> DataFrame: """ data = spark.read.parquet("/app/workdir/parquet/data.parquet") - data = split_sets(data) + data = split_sets(data, split=split) return data def split_sets(data: DataFrame, split=[0.99, 0.005, 0.005]) -> tuple: