Skip to content

Commit

Permalink
Function description added
Browse files Browse the repository at this point in the history
  • Loading branch information
lim185 committed Sep 30, 2025
1 parent fe00fc6 commit 940d227
Showing 1 changed file with 9 additions and 16 deletions.
25 changes: 9 additions & 16 deletions pipe/etl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,26 @@
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from pyspark.sql import SparkSession, functions, types, Row
from pyspark.sql import SparkSession
from sklearn.metrics import confusion_matrix
import tensorflow as tf

from pipe.extract import extract
from pipe.transform import transform
from pipe.load import load

def etl(spark):
def etl(spark: SparkSession) -> types.DataFrame:
"""
Performs the ETL process in series.
Performs the ETL process in series and returns the final DataFrame.
Args:
spark (SparkSession): The Spark session to use for data processing.
Returns:
types.DataFrame: The final processed DataFrame after ETL.
"""
data = extract(spark)
data = transform(spark, data, keys=["treatment", "target"])
data = load(data)
return data

def visualize_data_distribution(data):
for category in ["treatment", "target"]:
select = data.select(category) \
.groupby(category) \
.count()
plt.barh(
np.array(select.select(category).collect()).squeeze(),
np.array(select.select("count").collect()).astype("float") \
.squeeze())
plt.xlabel("Count")
plt.ylabel(category)
plt.savefig(f"{category}_counts.png", bbox_inches="tight")
plt.close()

0 comments on commit 940d227

Please sign in to comment.