diff --git a/pipe/etl.py b/pipe/etl.py index 83b4485..617b901 100644 --- a/pipe/etl.py +++ b/pipe/etl.py @@ -51,7 +51,7 @@ def read(spark: SparkSession) -> DataFrame: """ data = spark.read.parquet("/app/workdir/parquet/data.parquet") - data = split(data) + data = split_sets(data) return data def split_sets(data: DataFrame, split=[0.99, 0.005, 0.005]) -> tuple: diff --git a/train.py b/train.py index 78ef7d2..9474007 100644 --- a/train.py +++ b/train.py @@ -154,7 +154,7 @@ def main(): keys = ["treatment", "target"] - load_from_scratch = True + load_from_scratch = False if load_from_scratch: data = etl(spark, split=SPLIT) else: