From 9ca02e2090e578ab0a210ffdc04846bfb6c0bcd1 Mon Sep 17 00:00:00 2001
From: Dawith Lim
Date: Tue, 30 Sep 2025 14:24:00 -0400
Subject: [PATCH] ETL loop almost complete:

---
 pipe/etl.py       | 33 +++++++++++++++++++++++++++++----
 pipe/load.py      | 26 +++-----------------------
 pipe/transform.py | 42 ++++++++++++++++++++++++++++++++++--------
 train.py          | 15 +++++++++------
 4 files changed, 75 insertions(+), 41 deletions(-)

diff --git a/pipe/etl.py b/pipe/etl.py
index 413df67..83b4485 100644
--- a/pipe/etl.py
+++ b/pipe/etl.py
@@ -10,7 +10,7 @@
 import matplotlib.pyplot as plt
 import numpy as np
 
 from pathlib import Path
-from pyspark.sql import SparkSession
+from pyspark.sql import DataFrame, SparkSession
 from sklearn.metrics import confusion_matrix
 import tensorflow as tf
@@ -18,7 +18,7 @@
 from pipe.transform import transform
 from pipe.load import load
 
-def etl(spark: SparkSession) -> DataFrame:
+def etl(spark: SparkSession, split: list=None) -> DataFrame:
     """
     Performs the ETL process in series and returns the final DataFrame.
 
@@ -31,7 +31,11 @@
     data = extract(spark)
     data = transform(spark, data, keys=["treatment", "target"])
     load(data)
-    data = split(data)
+    match split:
+        case None:
+            data = split_sets(data)
+        case _:
+            data = split_sets(data, split=split)
     return data
 
 def read(spark: SparkSession) -> DataFrame:
@@ -50,7 +54,7 @@
     data = split(data)
     return data
 
-def split(data: DataFrame, split=[0.99, 0.005, 0.005]) -> tuple:
+def split_sets(data: DataFrame, split=[0.99, 0.005, 0.005]) -> tuple:
     """
     Splits the DataFrame into training, validation, and test sets
     with random seed.
@@ -77,3 +81,24 @@
     )
 
     return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
+
+def trim(dataframe, column):
+
+    ndarray = np.array(dataframe.select(column).collect()) \
+        .reshape(-1, 32, 130)
+
+    return ndarray
+
+
+def build_dict(df, key):
+    """
+    Takes a dataframe as input and returns a dictionary of unique values
+    in the column corresponding to the key.
+    """
+
+    df = df.select(key, f"{key}_str").distinct()
+
+    return df.rdd.map(
+        lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
+    ).collectAsMap()
+
diff --git a/pipe/load.py b/pipe/load.py
index bc02e39..ae2ba24 100644
--- a/pipe/load.py
+++ b/pipe/load.py
@@ -6,30 +6,10 @@
 import typing
 
 import numpy as np
-from pyspark.sql import DataFrame
+from pyspark.sql import DataFrame, SparkSession
 
-def build_dict(df, key):
-    """
-    Takes a dataframe as input and returns a dictionary of unique values
-    in the column corresponding to the key.
-    """
-
-    df = df.select(key, f"{key}_str").distinct()
-
-    return df.rdd.map(
-        lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
-    ).collectAsMap()
-
-def trim(dataframe, column):
-
-    ndarray = np.array(dataframe.select(column).collect()) \
-        .reshape(-1, 32, 130)
-
-    return ndarray
-
-def load(spark:SparkSession, data: DataFrame):
-    df = df.write.mode("overwrite") \
-        .parquet("/app/workdir/parquet/data.parquet")
+def load(data: DataFrame):
+    data.write.mode("overwrite").parquet("/app/workdir/parquet/data.parquet")
 
 # EOF
diff --git a/pipe/transform.py b/pipe/transform.py
index 8604c60..3b1b2bd 100644
--- a/pipe/transform.py
+++ b/pipe/transform.py
@@ -5,14 +5,22 @@
 import typing
 
-from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql import DataFrame, functions, SparkSession, types
 from sklearn.preprocessing import OneHotEncoder
 
-def merge_redundant_labels(dataframe: DataFrame) -> DataFrame:
+def merge_redundant_treatment_labels(dataframe: DataFrame) -> DataFrame:
     """
     Merge redundant labels in the 'treatment' column of the dataframe.
     This step is inefficient but necessary due to inconsistent naming
     conventions used in the MATLAB onekey processing.
+
+    Args:
+        dataframe (DataFrame): Input Spark DataFrame. Must have a 'treatment'
+            column.
+
+    Returns:
+        DataFrame: Modified DataFrame with merged labels in the 'treatment'
+            column.
     """
     dataframe.select("treatment").replace("virus", "cpv") \
         .replace("cont", "pbs") \
@@ -20,14 +28,19 @@
         .replace("dld", "pbs").distinct()
     return dataframe
 
-def transform(spark: SparkSession, dataframe: DataFrame, keys: list) \
-    -> DataFrame:
+def onehot(dataframe: DataFrame, keys: list) -> tuple:
     """
+    One-hot encode the specified categorical columns in the dataframe. The
+    column names to be encoded this way are provided in the 'keys' list.
+
+    Args:
+        dataframe (DataFrame): Input Spark DataFrame.
+        keys (list): List of column names to be one-hot encoded.
+
+    Returns:
+        tuple: The encoded column bundle and its matching Spark schema.
     """
-    dataframe = merge_redundant_labels(dataframe)
-    dataframe = dataframe.withColumn(
-        "index", functions.monotonically_increasing_id()
-    )
+
     bundle = {key: [
         arr.tolist() for arr
         in OneHotEncoder(sparse_output=False) \
@@ -41,6 +54,19 @@
         types.StructField(key, types.ArrayType(types.FloatType()), True)
         for key in keys
     ])
+
+    return bundle, schema
+
+def transform(spark: SparkSession, dataframe: DataFrame, keys: list) \
+    -> DataFrame:
+    """Merge redundant treatment labels, add an index column, and one-hot
+    encode the given key columns, returning the transformed DataFrame."""
+    dataframe = merge_redundant_treatment_labels(dataframe)
+    dataframe = dataframe.withColumn(
+        "index", functions.monotonically_increasing_id()
+    )
+
+    bundle, schema = onehot(dataframe, keys)
     newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
         "index", functions.monotonically_increasing_id()
     )
diff --git a/train.py b/train.py
index 1562f3e..78ef7d2 100644
--- a/train.py
+++ b/train.py
@@ -24,7 +24,7 @@
 from sklearn.preprocessing import OneHotEncoder
 
 from model.model import TimeSeriesTransformer as TSTF
-from pipe.etl import extract, transform, load
+from pipe.etl import etl, read
 
 with open("parameters.json", "r") as file:
     params = json.load(file)
@@ -153,11 +153,14 @@
     spark = SparkSession.builder.appName("train").getOrCreate()
     keys = ["treatment", "target"]
-
-    data = extract(spark)
-    data = transform(spark, data, ["treatment", "target"])
-    (train_set, validation_set,
-     test_set, categories) = load(data, split=SPLIT)
+
+    load_from_scratch = True
+    if load_from_scratch:
+        data = etl(spark, split=SPLIT)
+    else:
+        data = read(spark)
+
+    (train_set, validation_set, test_set, categories) = data
 
     n_classes = [dset.shape[1] for dset in train_set[1]]
     model = get_model(train_set[0].shape[1:], n_classes)