From 52395c1255fcc31498215826b0f6ad7f9010cce2 Mon Sep 17 00:00:00 2001
From: Dawith Lim
Date: Tue, 30 Sep 2025 14:05:47 -0400
Subject: [PATCH] ETL now persists data as Parquet files, which are read back
 when needed

---
 pipe/etl.py  | 58 ++++++++++++++++++++++++++++++++++++++++++++++++++----
 pipe/load.py | 18 ++++--------------
 2 files changed, 58 insertions(+), 18 deletions(-)

diff --git a/pipe/etl.py b/pipe/etl.py
index f3f9277..413df67 100644
--- a/pipe/etl.py
+++ b/pipe/etl.py
@@ -18,7 +18,10 @@
 from pipe.transform import transform
-from pipe.load import load
+from pipe.load import build_dict, load, trim
+
+import numpy as np
+from pyspark.sql import DataFrame
 
-def etl(spark: SparkSession) -> types.DataFrame:
+def etl(spark: SparkSession) -> tuple:
     """
-    Performs the ETL process in series and returns the final DataFrame.
+    Performs the ETL process in series and returns the split datasets.
 
@@ -30,6 +33,53 @@ def etl(spark: SparkSession) -> types.DataFrame:
     """
     data = extract(spark)
     data = transform(spark, data, keys=["treatment", "target"])
-    data = load(data)
+    load(data)
+    data = split(data)
 
     return data
+
+
+def read(spark: SparkSession) -> tuple:
+    """
+    Reads the processed data from a Parquet file and splits it into training,
+    validation, and test sets.
+
+    Args:
+        spark (SparkSession): The Spark session to use for data processing.
+
+    Returns:
+        tuple: The split datasets and category dictionary.
+    """
+
+    data = spark.read.parquet("/app/workdir/parquet/data.parquet")
+    data = split(data)
+    return data
+
+
+def split(data: DataFrame, weights=(0.99, 0.005, 0.005)) -> tuple:
+    """
+    Splits the DataFrame into training, validation, and test sets with a
+    fixed random seed.
+
+    Args:
+        data (DataFrame): The DataFrame to split.
+        weights (tuple, optional): The split ratios for the train, val, and
+            test sets. Defaults to (0.99, 0.005, 0.005).
+
+    Returns:
+        tuple: A tuple containing the split datasets and category dictionary.
+    """
+
+    category_dict = {
+        key: build_dict(data, key) for key in ["treatment", "target"]
+    }
+    splits = data.randomSplit(list(weights), seed=42)
+    trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
+    trainy, valy, testy = (
+        [
+            np.array(dset.select("treatment").collect()).squeeze(),
+            np.array(dset.select("target").collect()).squeeze()
+        ] for dset in splits
+    )
+
+    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
diff --git a/pipe/load.py b/pipe/load.py
index 5d12919..bc02e39 100644
--- a/pipe/load.py
+++ b/pipe/load.py
@@ -28,19 +28,9 @@ def trim(dataframe, column):
     return ndarray
 
-def load(data: DataFrame, split=[0.99, 0.005, 0.005]):
-    category_dict = {
-        key: build_dict(data, key) for key in ["treatment", "target"]
-    }
-    splits = data.randomSplit(split, seed=42)
-    trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
-    trainy, valy, testy = (
-        [
-            np.array(dset.select("treatment").collect()).squeeze(),
-            np.array(dset.select("target").collect()).squeeze()
-        ] for dset in splits
-    )
-
-    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
+def load(data: DataFrame) -> None:
+    """Writes the transformed data to a persistent Parquet file."""
+    data.write.mode("overwrite") \
+        .parquet("/app/workdir/parquet/data.parquet")
 
 
 # EOF
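
Usage note (not part of the patch): a minimal sketch of the intended
flow after this change, assuming the pipe package layout above and a
locally built SparkSession; the builder settings and script shape are
illustrative, not taken from the repository.

    from pyspark.sql import SparkSession

    from pipe.etl import etl, read

    # Hypothetical session setup; any working SparkSession will do.
    spark = SparkSession.builder.appName("pipe").getOrCreate()

    # First run: extract, transform, persist to Parquet, then split.
    (trainx, trainy), (valx, valy), (testx, testy), cats = etl(spark)

    # Later runs: skip ETL and split the persisted Parquet data directly.
    (trainx, trainy), (valx, valy), (testx, testy), cats = read(spark)

    spark.stop()

Because both etl() and read() end by calling split() with the same
seed, the two entry points should yield the same train/val/test
partitions over the persisted data.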