From 940d227379fc40f814f221721d9216af7391a3db Mon Sep 17 00:00:00 2001
From: Dawith Lim
Date: Tue, 30 Sep 2025 12:57:12 -0400
Subject: [PATCH] Function description added

---
Review note: the return annotation must reference a name that is still
imported. `types` is dropped from the pyspark import in this patch, and
pyspark.sql.types does not define DataFrame, so the annotation now uses
pyspark.sql.DataFrame, imported alongside SparkSession.

 pipe/etl.py | 25 +++++++++----------------
 1 file changed, 9 insertions(+), 16 deletions(-)

diff --git a/pipe/etl.py b/pipe/etl.py
index 807d427..f3f9277 100644
--- a/pipe/etl.py
+++ b/pipe/etl.py
@@ -10,7 +10,7 @@
 import matplotlib.pyplot as plt
 import numpy as np
 from pathlib import Path
-from pyspark.sql import SparkSession, functions, types, Row
+from pyspark.sql import DataFrame, SparkSession
 from sklearn.metrics import confusion_matrix
 import tensorflow as tf
 
@@ -18,25 +18,18 @@
 from pipe.transform import transform
 from pipe.load import load
 
-def etl(spark):
+def etl(spark: SparkSession) -> DataFrame:
     """
-    Performs the ETL process in series.
+    Performs the ETL process in series and returns the final DataFrame.
+
+    Args:
+        spark (SparkSession): The Spark session to use for data processing.
+
+    Returns:
+        DataFrame: The final processed DataFrame after ETL.
     """
     data = extract(spark)
     data = transform(spark, data, keys=["treatment", "target"])
     data = load(data)
     return data
 
-def visualize_data_distribution(data):
-    for category in ["treatment", "target"]:
-        select = data.select(category) \
-                     .groupby(category) \
-                     .count()
-        plt.barh(
-            np.array(select.select(category).collect()).squeeze(),
-            np.array(select.select("count").collect()).astype("float") \
-                .squeeze())
-        plt.xlabel("Count")
-        plt.ylabel(category)
-        plt.savefig(f"{category}_counts.png", bbox_inches="tight")
-        plt.close()