
Commit

ETL loop almost complete:
lim185 committed Sep 30, 2025
1 parent 52395c1 commit 9ca02e2
Showing 4 changed files with 75 additions and 41 deletions.
33 changes: 29 additions & 4 deletions pipe/etl.py
@@ -10,15 +10,15 @@
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from pyspark.sql import SparkSession
from pyspark.sql import DataFrame, SparkSession
from sklearn.metrics import confusion_matrix
import tensorflow as tf

from pipe.extract import extract
from pipe.transform import transform
from pipe.load import load

def etl(spark: SparkSession) -> DataFrame:
def etl(spark: SparkSession, split: list=None) -> DataFrame:
"""
Performs the ETL process in series and returns the final DataFrame.
@@ -31,7 +31,11 @@ def etl(spark: SparkSession) -> DataFrame:
data = extract(spark)
data = transform(spark, data, keys=["treatment", "target"])
load(data)
data = split(data)
match split:
case None:
data = split_sets(data)
case _:
data = split_sets(data, split=split)
return data
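
For context, a minimal usage sketch of the new optional split argument (the 0.8/0.1/0.1 ratios are illustrative, not from this commit); as the train.py change further down shows, etl now hands back the tuple produced by split_sets rather than a single DataFrame:

spark = SparkSession.builder.appName("etl").getOrCreate()
# default 0.99/0.005/0.005 split
train_set, validation_set, test_set, categories = etl(spark)
# caller-supplied ratios
train_set, validation_set, test_set, categories = etl(spark, split=[0.8, 0.1, 0.1])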

def read(spark: SparkSession) -> DataFrame:
@@ -50,7 +54,7 @@ def read(spark: SparkSession) -> DataFrame:
data = split(data)
return data

def split(data: DataFrame, split=[0.99, 0.005, 0.005]) -> tuple:
def split_sets(data: DataFrame, split=[0.99, 0.005, 0.005]) -> tuple:
"""
Splits the DataFrame into training, validation, and test sets with random
seed.
@@ -77,3 +81,24 @@ def split(data: DataFrame, split=[0.99, 0.005, 0.005]) -> tuple:
)

return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
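
The body of split_sets is collapsed in this view; a ratio-based split in Spark is commonly done with DataFrame.randomSplit, so the hidden section presumably does something along these lines (a sketch under that assumption, arbitrary seed), with trim and build_dict below turning the splits into ndarrays and a label dictionary:

train_df, val_df, test_df = data.randomSplit(split, seed=42)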

def trim(dataframe, column):

ndarray = np.array(dataframe.select(column).collect()) \
.reshape(-1, 32, 130)

return ndarray


def build_dict(df, key):
"""
Takes a dataframe as input and returns a dictionary of unique values
in the column corresponding to the key.
"""

df = df.select(key, f"{key}_str").distinct()

return df.rdd.map(
lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
).collectAsMap()
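
build_dict maps the argmax index of each one-hot vector back to its original string label, so for the merged treatment labels it returns something like (values illustrative):

build_dict(df, "treatment")   # -> {"0": "cpv", "1": "pbs"}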

26 changes: 3 additions & 23 deletions pipe/load.py
@@ -6,30 +6,10 @@
import typing

import numpy as np
from pyspark.sql import DataFrame
from pyspark.sql import DataFrame, SparkSession


def build_dict(df, key):
"""
Takes a dataframe as input and returns a dictionary of unique values
in the column corresponding to the key.
"""

df = df.select(key, f"{key}_str").distinct()

return df.rdd.map(
lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
).collectAsMap()

def trim(dataframe, column):

ndarray = np.array(dataframe.select(column).collect()) \
.reshape(-1, 32, 130)

return ndarray

def load(spark:SparkSession, data: DataFrame):
df = df.write.mode("overwrite") \
.parquet("/app/workdir/parquet/data.parquet")
def load(data: DataFrame):
data.write.mode("overwrite").parquet("/app/workdir/parquet/data.parquet")

# EOF
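
The simplified load now only writes the Parquet dataset; reading it back (as read in pipe/etl.py presumably does) is the mirror call, e.g. a sketch assuming the same path:

data = spark.read.parquet("/app/workdir/parquet/data.parquet")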
42 changes: 34 additions & 8 deletions pipe/transform.py
@@ -5,29 +5,42 @@

import typing

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import DataFrame, functions, SparkSession, types
from sklearn.preprocessing import OneHotEncoder

def merge_redundant_labels(dataframe: DataFrame) -> DataFrame:
def merge_redundant_treatment_labels(dataframe: DataFrame) -> DataFrame:
"""
Merge redundant labels in the 'treatment' column of the dataframe. This
step is inefficient but necessary due to inconsistent naming conventions
used in the MATLAB onekey processing.
Args:
dataframe (DataFrame): Input Spark DataFrame. Must have a 'treatment'
column.
Returns:
DataFrame: Modified DataFrame with merged labels in the 'treatment'
column.
"""
dataframe.select("treatment").replace("virus", "cpv") \
.replace("cont", "pbs") \
.replace("control", "pbs") \
.replace("dld", "pbs").distinct()
return dataframe

def transform(spark: SparkSession, dataframe: DataFrame, keys: list) \
-> DataFrame:
def onehot(dataframe: DataFrame, keys: list) -> DataFrame:
"""
One-hot encode the specified categorical columns in the dataframe. The
column names to be encoded this way are provided in the 'keys' list.
Args:
dataframe (DataFrame): Input Spark DataFrame.
keys (list): List of column names to be one-hot encoded.
Returns:
DataFrame: New DataFrame with one-hot encoded columns.
"""
dataframe = merge_redundant_labels(dataframe)
dataframe = dataframe.withColumn(
"index", functions.monotonically_increasing_id()
)

bundle = {key: [
arr.tolist()
for arr in OneHotEncoder(sparse_output=False) \
@@ -41,6 +54,19 @@ def transform(spark: SparkSession, dataframe: DataFrame, keys: list) \
types.StructField(key, types.ArrayType(types.FloatType()), True)
for key in keys
])

return bundle, schema

def transform(spark: SparkSession, dataframe: DataFrame, keys: list) \
-> DataFrame:
"""
"""
dataframe = merge_redundant_treatment_labels(dataframe)
dataframe = dataframe.withColumn(
"index", functions.monotonically_increasing_id()
)

bundle, schema = onehot(dataframe, keys)
newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
"index", functions.monotonically_increasing_id()
)
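For reference, OneHotEncoder with sparse_output=False returns a dense array whose rows become the float lists stored under each key in bundle; a minimal standalone sketch of that step (column values illustrative):

import numpy as np
from sklearn.preprocessing import OneHotEncoder

labels = np.array([["cpv"], ["pbs"], ["cpv"]])
encoded = OneHotEncoder(sparse_output=False).fit_transform(labels)
# encoded -> [[1., 0.], [0., 1.], [1., 0.]]; categories are sorted, so "cpv" maps to index 0
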
15 changes: 9 additions & 6 deletions train.py
@@ -24,7 +24,7 @@
from sklearn.preprocessing import OneHotEncoder

from model.model import TimeSeriesTransformer as TSTF
from pipe.etl import extract, transform, load
from pipe.etl import etl, read

with open("parameters.json", "r") as file:
params = json.load(file)
@@ -153,11 +153,14 @@ def main():
spark = SparkSession.builder.appName("train").getOrCreate()

keys = ["treatment", "target"]

data = extract(spark)
data = transform(spark, data, ["treatment", "target"])
(train_set, validation_set,
test_set, categories) = load(data, split=SPLIT)

load_from_scratch = True
if load_from_scratch:
data = etl(spark, split=SPLIT)
else:
data = read(spark)

(train_set, validation_set, test_set, categories) = data

n_classes = [dset.shape[1] for dset in train_set[1]]
model = get_model(train_set[0].shape[1:], n_classes)
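For context on the n_classes line: train_set[1] holds one one-hot label array per key, so shape[1] gives each key's class count, e.g. with the hypothetical shapes below n_classes would be [2, 3]:

train_y = (np.zeros((1000, 2)), np.zeros((1000, 3)))   # illustrative treatment / target label arrays
n_classes = [dset.shape[1] for dset in train_y]        # -> [2, 3]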
