ETL split into extract, transform, load so it's easier to follow the scope of each portion in the code
lim185 committed Sep 30, 2025
1 parent 4f59193 commit 0e2f81f
Showing 4 changed files with 115 additions and 73 deletions.
75 changes: 11 additions & 64 deletions pipe/etl.py
@@ -12,41 +12,20 @@
from pathlib import Path
from pyspark.sql import SparkSession, functions, types, Row
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import OneHotEncoder
import tensorflow as tf

from pipe.pipe import SpectrogramPipe
from pipe.extract import extract
from pipe.transform import transform
from pipe.load import load

def transform(spark, dataframe, keys):
dataframe.select("treatment").replace("virus", "cpv") \
.replace("cont", "pbs") \
.replace("control", "pbs") \
.replace("dld", "pbs").distinct()

dataframe = dataframe.withColumn(
"index", functions.monotonically_increasing_id()
)
bundle = {key: [
arr.tolist()
for arr in OneHotEncoder(sparse_output=False) \
.fit_transform(dataframe.select(key).collect())
] for key in keys
}

bundle = [dict(zip(bundle.keys(), values))
for values in zip(*bundle.values())]
schema = types.StructType([
types.StructField(key, types.ArrayType(types.FloatType()), True)
for key in keys
])
newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
"index", functions.monotonically_increasing_id()
)
for key in keys:
dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
dataframe = dataframe.join(newframe, on="index", how="inner")

return dataframe
def etl(spark):
"""
Performs the ETL process in series.
"""
data = extract(spark)
data = transform(spark, data, keys=["treatment", "target"])
data = load(data)
return data

def build_dict(df, key):
"""
@@ -67,38 +46,6 @@ def trim(dataframe, column):

return ndarray

def extract(spark):
"""
First step of the ETL pipeline. It reads the list of .mat files from
a CSV list, opens and pulls the spectrogram from each respective file.
"""

path = Path("/app/workdir")
labels = []
with open(path / "train.csv", "r") as file:
for line in file:
labels.append(line.strip().split(",")[0])

pipe = SpectrogramPipe(spark, filetype="matfiles")

return pipe.spectrogram_pipe(path, labels)

def load(data, split=[0.99, 0.005, 0.005]):
category_dict = {
key: build_dict(data, key) for key in ["treatment", "target"]
}
splits = data.randomSplit(split, seed=42)
trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
trainy, valy, testy = (
[
np.array(dset.select("treatment").collect()).squeeze(),
np.array(dset.select("target").collect()).squeeze()
] for dset in splits
)

return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)

def visualize_data_distribution(data):
for category in ["treatment", "target"]:
select = data.select(category) \
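For orientation, here is a minimal sketch of driving the new etl() entry point end to end. It assumes a locally built SparkSession; the app name and the main() wrapper are illustrative and not part of this commit.

# Sketch: invoking the refactored pipeline (session setup is illustrative)
from pyspark.sql import SparkSession

from pipe.etl import etl

def main():
    spark = SparkSession.builder.appName("spectrogram-etl").getOrCreate()
    # etl() runs extract -> transform -> load in series and returns the splits.
    (trainx, trainy), (valx, valy), (testx, testy), categories = etl(spark)
    print(trainx.shape, sorted(categories))
    spark.stop()

if __name__ == "__main__":
    main()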
36 changes: 27 additions & 9 deletions pipe/pipe.py → pipe/extract.py
@@ -1,5 +1,6 @@
#-*- coding: utf-8 -*-
"""
pipe.py
extract.py
"""

import json
@@ -13,6 +14,23 @@
from pyspark.sql import SparkSession, Row, DataFrame
import scipy as sp

def extract(spark):
"""
First step of the ETL pipeline. It reads the list of .mat files from
a CSV list, opens and pulls the spectrogram from each respective file.
"""

path = Path("/app/workdir")
labels = []
with open(path / "train.csv", "r") as file:
for line in file:
labels.append(line.strip().split(",")[0])

reader = SpectrogramReader(spark, filetype="matfiles")

return reader.spectrogram_read(path, labels)

def image_pipe(spark: SparkSession, imagepath: Path, namepattern: str,
stacksize: int) -> np.ndarray:
images = np.zeros((stacksize, 800,800))
@@ -21,20 +39,20 @@ def image_pipe(spark: SparkSession, imagepath: Path, namepattern: str,

return images

class SpectrogramPipe:
class SpectrogramReader:

def __init__(self, spark: SparkSession, filetype: str = "hdf5"):
self.spark = spark
if filetype == "hdf5":
self.spectrogram_pipe = self.spectrogram_pipe_hdf5
self.spectrogram_read = self.spectrogram_read_hdf5
elif filetype == "shards":
self.spectrogram_pipe = self.spectrogram_pipe_shards
self.spectrogram_read = self.spectrogram_read_shards
elif filetype == "matfiles":
self.spectrogram_pipe = self.spectrogram_pipe_matfiles
self.spectrogram_read = self.spectrogram_read_matfiles
else:
raise ValueError

def metadata_pipe(self, metapath: Path, labels:list,
def metadata_read(self, metapath: Path, labels:list,
namepattern: str="metadata{}.json") -> dict:
"""
Loads metadata for each target label from a set of json files and
@@ -56,7 +74,7 @@ def metadata_pipe(self, metapath: Path, labels:list,

return metadata

def spectrogram_pipe_matfiles(self, specpath: Path, labels:list,
def spectrogram_read_matfiles(self, specpath: Path, labels:list,
default_size: tuple = (32, 130),
pad_value: float = 0.) \
-> DataFrame:
@@ -98,7 +116,7 @@ def spectrogram_pipe_matfiles(self, specpath: Path, labels:list,

return self.spark.createDataFrame(spectrograms)

def spectrogram_pipe_hdf5(self, specpath: Path, labels: list,
def spectrogram_read_hdf5(self, specpath: Path, labels: list,
namepattern:str="averaged_spectrogram{}.hdf5") \
-> DataFrame:
"""
@@ -123,7 +141,7 @@ def spectrogram_pipe_hdf5(self, specpath: Path, labels: list,

return self.spark.createDataFrame(spectrograms)

def spectrogram_pipe_shards(self, specpath: Path, namepattern: str,
def spectrogram_read_shards(self, specpath: Path, namepattern: str,
stacksize: int, freq_samples: int) \
-> DataFrame:
"""
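As a usage note, SpectrogramReader binds spectrogram_read to one of three backends (hdf5, shards, matfiles) at construction time. A short sketch of calling the matfiles path directly, with an illustrative working directory and label list:

# Sketch: using SpectrogramReader outside extract(); paths and labels are illustrative
from pathlib import Path

from pyspark.sql import SparkSession
from pipe.extract import SpectrogramReader

spark = SparkSession.builder.getOrCreate()
reader = SpectrogramReader(spark, filetype="matfiles")  # or "hdf5" / "shards"
frame = reader.spectrogram_read(Path("/app/workdir"), ["sample01", "sample02"])
frame.printSchema()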
25 changes: 25 additions & 0 deletions pipe/load.py
@@ -0,0 +1,25 @@
#-*- coding: utf-8 -*-
"""
load.py
"""

import typing

import numpy as np
from pyspark.sql import DataFrame

def load(data: DataFrame, split=[0.99, 0.005, 0.005]):
category_dict = {
key: build_dict(data, key) for key in ["treatment", "target"]
}
splits = data.randomSplit(split, seed=42)
trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
trainy, valy, testy = (
[
np.array(dset.select("treatment").collect()).squeeze(),
np.array(dset.select("target").collect()).squeeze()
] for dset in splits
)

return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
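Note that load() still relies on build_dict and trim, which remain defined in pipe/etl.py. For reference, a minimal sketch of unpacking what load() returns downstream; the summarize() helper is invented for illustration, and data is the transformed Spark DataFrame produced earlier in the pipeline.

# Sketch: consuming load()'s output (summarize() is a hypothetical helper)
from pipe.load import load

def summarize(data):
    (trainx, trainy), (valx, valy), (testx, testy), categories = load(data)
    treatment_train, target_train = trainy  # one label array per key
    print(trainx.shape, treatment_train.shape, target_train.shape)
    print(sorted(categories))  # built by build_dict for "treatment" and "target"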

52 changes: 52 additions & 0 deletions pipe/transform.py
@@ -0,0 +1,52 @@
#-*- coding: utf-8 -*-
"""
transform.py
"""

import typing

from pyspark.sql import DataFrame, SparkSession, functions, types
from sklearn.preprocessing import OneHotEncoder

def merge_redundant_labels(dataframe: DataFrame) -> DataFrame:
"""
Merge redundant labels in the 'treatment' column of the dataframe. This
step is inefficient but necessary due to inconsistent naming conventions
used in the MATLAB onekey processing.
"""
dataframe.select("treatment").replace("virus", "cpv") \
.replace("cont", "pbs") \
.replace("control", "pbs") \
.replace("dld", "pbs").distinct()
return dataframe

def transform(spark: SparkSession, dataframe: DataFrame, keys: list) \
-> DataFrame:
"""
"""
dataframe = merge_redundant_labels(dataframe)
dataframe = dataframe.withColumn(
"index", functions.monotonically_increasing_id()
)
bundle = {key: [
arr.tolist()
for arr in OneHotEncoder(sparse_output=False) \
.fit_transform(dataframe.select(key).collect())
] for key in keys
}

bundle = [dict(zip(bundle.keys(), values))
for values in zip(*bundle.values())]
schema = types.StructType([
types.StructField(key, types.ArrayType(types.FloatType()), True)
for key in keys
])
newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
"index", functions.monotonically_increasing_id()
)

for key in keys:
dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
dataframe = dataframe.join(newframe, on="index", how="inner")

return dataframe
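To make the encoding step concrete, a small standalone example of what OneHotEncoder(sparse_output=False) produces for one label column, independent of Spark; the sample labels are invented. transform() then joins these arrays back onto the original frame via the generated index column, keeping the raw strings under <key>_str.

# Sketch: the per-column one-hot step in isolation (sample labels are invented)
from sklearn.preprocessing import OneHotEncoder

labels = [["cpv"], ["pbs"], ["cpv"], ["pbs"]]  # one column of string labels
encoder = OneHotEncoder(sparse_output=False)
encoded = encoder.fit_transform(labels)  # shape (4, 2), float64
print(encoder.categories_)  # [array(['cpv', 'pbs'], dtype=object)]
print(encoded.tolist())  # [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]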
