ETL now loads data to a persistent form (Parquet files), which is then read when needed
lim185 committed Sep 30, 2025
1 parent 43dc233 commit 52395c1
Showing 2 changed files with 49 additions and 16 deletions.
48 changes: 46 additions & 2 deletions pipe/etl.py
@@ -18,7 +18,7 @@
 from pipe.transform import transform
 from pipe.load import load
 
-def etl(spark: SparkSession) -> types.DataFrame:
+def etl(spark: SparkSession) -> tuple:
     """
     Performs the ETL process in series and returns the split datasets.
@@ -30,6 +30,50 @@ def etl(spark: SparkSession) -> types.DataFrame:
     """
     data = extract(spark)
     data = transform(spark, data, keys=["treatment", "target"])
-    data = load(data)
+    load(data)
+    data = split(data)
     return data

+def read(spark: SparkSession) -> tuple:
+    """
+    Reads the processed data from a Parquet file and splits it into training,
+    validation, and test sets.
+
+    Args:
+        spark (SparkSession): The Spark session to use for data processing.
+
+    Returns:
+        tuple: The split datasets and category dictionary.
+    """
+    data = spark.read.parquet("/app/workdir/parquet/data.parquet")
+    data = split(data)
+    return data

+def split(data: DataFrame, split=[0.99, 0.005, 0.005]) -> tuple:
+    """
+    Splits the DataFrame into training, validation, and test sets with a
+    fixed random seed.
+
+    Args:
+        data (DataFrame): The DataFrame to split.
+        split (list, optional): The split ratios for the train, val, and
+            test sets. Defaults to [0.99, 0.005, 0.005].
+
+    Returns:
+        tuple: A tuple containing the split datasets and the category
+            dictionary.
+    """
+    category_dict = {
+        key: build_dict(data, key) for key in ["treatment", "target"]
+    }
+    splits = data.randomSplit(split, seed=42)
+    trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
+    trainy, valy, testy = (
+        [
+            np.array(dset.select("treatment").collect()).squeeze(),
+            np.array(dset.select("target").collect()).squeeze()
+        ] for dset in splits
+    )
+
+    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
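
For orientation, a minimal sketch of how a caller might unpack the four-tuple returned by split (and therefore by etl and read), given an existing SparkSession named spark; the variable names are illustrative and not part of this repository:

    # Hypothetical consumer of the tuple returned by etl()/read()/split().
    train, val, test, category_dict = read(spark)
    trainx, trainy = train  # spectrogram arrays and [treatment, target] labels
    valx, valy = val
    testx, testy = test
    print(f"train samples: {len(trainx)}")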
17 changes: 3 additions & 14 deletions pipe/load.py
@@ -28,19 +28,8 @@ def trim(dataframe, column):
 
     return ndarray
 
-def load(data: DataFrame, split=[0.99, 0.005, 0.005]):
-    category_dict = {
-        key: build_dict(data, key) for key in ["treatment", "target"]
-    }
-    splits = data.randomSplit(split, seed=42)
-    trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
-    trainy, valy, testy = (
-        [
-            np.array(dset.select("treatment").collect()).squeeze(),
-            np.array(dset.select("target").collect()).squeeze()
-        ] for dset in splits
-    )
-
-    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
+def load(data: DataFrame):
+    """Writes the processed DataFrame to a Parquet file for later reads."""
+    data.write.mode("overwrite") \
+        .parquet("/app/workdir/parquet/data.parquet")
 
 # EOF
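
Putting the two files together, the intended workflow is: run etl once to extract, transform, persist to Parquet, and split; on later runs, call read to skip extraction and reload the persisted data. A minimal sketch under that assumption — the appName and session setup are illustrative, not from this repository:

    from pyspark.sql import SparkSession
    from pipe.etl import etl, read

    spark = SparkSession.builder.appName("pipe").getOrCreate()

    # First run: full ETL, which also persists the transformed data to Parquet.
    train, val, test, category_dict = etl(spark)

    # Subsequent runs: read the persisted Parquet instead of re-running ETL.
    train, val, test, category_dict = read(spark)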
