From 7ab49d640ea2fa96563d6f437e61c3b7638bdf96 Mon Sep 17 00:00:00 2001
From: Dawith Lim
Date: Mon, 13 Oct 2025 20:26:40 -0400
Subject: [PATCH] ETL pipeline uses a proper StringIndexer instead of a
 jerry-rigged one-hot encoding scheme

---
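For reviewers: a minimal, self-contained sketch of the encoding scheme the
new onehot() builds. The toy DataFrame, its column values, and the local
session are illustrative only; the stage parameters (handleInvalid="keep",
dropLast=False) are the ones the patch uses.

from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import OneHotEncoder, StringIndexer

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [("drug_a",), ("control",), ("drug_a",)], ["treatment"]
)
stages = [
    # Map each distinct string label to a numeric index; "keep" routes
    # labels unseen at fit time into an extra bucket instead of failing.
    StringIndexer(inputCol="treatment", outputCol="treatment_indexed",
                  handleInvalid="keep"),
    # Expand the index into an indicator vector; dropLast=False keeps one
    # slot per label rather than dropping the final category.
    OneHotEncoder(inputCol="treatment_indexed", outputCol="treatment_encoded",
                  dropLast=False),
]
encoded = Pipeline(stages=stages).fit(df).transform(df)
encoded.show(truncate=False)
# Each treatment_encoded value is a SparseVector with a single 1 at the
# row's label index.
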
 pipe/etl.py       |  2 +-
 pipe/extract.py   | 27 ++++++++++++++++-----------
 pipe/transform.py | 54 +++++++++++++++++++++++++++++++++++++++++++++++++-----
 3 files changed, 66 insertions(+), 17 deletions(-)

diff --git a/pipe/etl.py b/pipe/etl.py
index 617b901..4525775 100644
--- a/pipe/etl.py
+++ b/pipe/etl.py
@@ -72,7 +72,7 @@ def split_sets(data: DataFrame, split=[0.99, 0.005, 0.005]) -> tuple:
         key: build_dict(data, key) for key in ["treatment", "target"]
     }
     splits = data.randomSplit(split, seed=42)
-    trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
+    trainx, valx, testx = (trim(dset, "spectra") for dset in splits)
     trainy, valy, testy = (
         [
             np.array(dset.select("treatment").collect()).squeeze(),
diff --git a/pipe/extract.py b/pipe/extract.py
index 8396d0c..663a1cd 100644
--- a/pipe/extract.py
+++ b/pipe/extract.py
@@ -4,6 +4,7 @@
 """
 
+import glob
 import json
 import os
 from pathlib import Path
 import typing
@@ -33,9 +34,9 @@ def extract(spark: SparkSession) -> DataFrame:
         for line in file:
             labels.append(line.strip().split(",")[0])
 
-    reader = SpectrogramReader(spark, filetype="matfiles")
+    reader = FileReader(spark, filetype=FileType.MAT)
 
-    return reader.spectrogram_read(path, labels)
+    return reader.read_file(path, labels)
 
 
 def image_pipe(spark: SparkSession, imagepath: Path, namepattern: str, stacksize: int) -> np.ndarray:
@@ -72,11 +73,11 @@ def __init__(self, spark: SparkSession, filetype: FileType):
         self.spark = spark
         match filetype:
             case FileType.HDF5:
-                self.spectrogram_read = self.spectrogram_read_hdf5
+                self.read_file = self.read_hdf5
             case FileType.SHARD:
-                self.spectrogram_read = self.spectrogram_read_shards
+                self.read_file = self.read_shards
             case FileType.MAT:
-                self.spectrogram_read = self.spectrogram_read_matfiles
+                self.read_file = self.read_matfiles
             case _:
                 raise ValueError(Expected)
 
@@ -122,25 +123,29 @@ def read_matfiles(self, specpath: Path,
         """
         data = []
         row = {}
+        labels = [os.path.basename(p) for p in glob.glob(str(specpath/"matfiles"/"*.mat"))]
         for label in labels:
             matdata = sp.io.loadmat(specpath/"matfiles"/label)
 
             if DataKind.TREATMENT in datakinds:
                 row["treatment"] = matdata["header"][0][0][4][0].lower()
             if DataKind.TARGET in datakinds:
-                row["target"] = matdata["header"][0][0][2][0].lower()
+                try:
+                    row["target"] = matdata["header"][0][0][2][0].lower()
+                except (KeyError, IndexError):
+                    row["target"] = "unknown"
             if DataKind.FPS in datakinds:
                 row["fps"] = 2*float(matdata["header"][0][0][15][0])
             if DataKind.BB in datakinds:
-                row["bb"] = matdata["bb"]
+                row["bb"] = matdata["BB"].tolist()
             if DataKind.NSD in datakinds:
-                row["nsd"] = matdata["nsd"]
+                row["nsd"] = matdata["NSD"].tolist()
             if DataKind.NCNT in datakinds:
-                row["ncnt"] = matdata["ncnt"]
+                row["ncnt"] = matdata["NCNT"].astype(float).tolist()
             if DataKind.SPEC in datakinds:
                 spectra = np.array(matdata["SP"][0])
                 if len(spectra.shape) == 3:
-                    spectra = spectrogram[0]
+                    spectra = spectra[0]
                 if spectra.shape[0] > default_size[0]:
                     spectra = spectra[:default_size[0], :]
                     spectra = np.pad(
@@ -154,7 +159,7 @@ def read_matfiles(self, specpath: Path,
                 row["spectra"] = spectra.tolist()
             data.append(Row(**row))
 
-        return self.spark.createDataFrame(spectrograms)
+        return self.spark.createDataFrame(data)
 
     def read_hdf5(self, specpath: Path, labels: list,
                   namepattern:str="averaged_spectrogram{}.hdf5") \
diff --git a/pipe/transform.py b/pipe/transform.py
index 3b1b2bd..8bcb616 100644
--- a/pipe/transform.py
+++ b/pipe/transform.py
@@ -6,7 +6,8 @@
 import typing
 
 from pyspark.sql import DataFrame, functions, SparkSession, types
-from sklearn.preprocessing import OneHotEncoder
+from pyspark.ml import Pipeline
+from pyspark.ml.feature import StringIndexer, OneHotEncoder
 
 
 def merge_redundant_treatment_labels(dataframe: DataFrame) -> DataFrame:
@@ -34,13 +35,52 @@ def onehot(dataframe: DataFrame, keys: list) -> DataFrame:
 
     column names to be encoded this way are provided in the 'keys' list.
 
     Args:
-        dataframe (DataFrame): Input Spark DataFrame.
-        keys (list): List of column names to be one-hot encoded.
+        dataframe (pyspark.sql.DataFrame): Input Spark DataFrame that contains
+            the column of strings to encode.
+        keys (list(string)): List of column names to be one-hot encoded.
 
     Returns:
-        DataFrame: New DataFrame with one-hot encoded columns.
+        pyspark.sql.DataFrame: New DataFrame with one-hot encoded column(s).
     """
+    indexers = []
+    encoders = []
+    indexed_cols = []
+
+    for key in keys:
+        if key not in dataframe.columns:
+            raise ValueError(f"Column \"{key}\" cannot be found in DataFrame.")
+
+        indexed_column = f"{key}_indexed"
+        encoded_column = f"{key}_encoded"
+        indexed_cols.append(indexed_column)
+
+        indexers.append(
+            StringIndexer(
+                inputCol=key,
+                outputCol=indexed_column,
+                handleInvalid="keep"
+            )
+        )
+
+        encoders.append(
+            OneHotEncoder(
+                inputCol=indexed_column,
+                outputCol=encoded_column,
+                dropLast=False
+            )
+        )
+
+    pipeline = Pipeline(stages=indexers + encoders)
+    model = pipeline.fit(dataframe)
+    result = model.transform(dataframe)
+
+    result = result.drop(*indexed_cols)
+    for column_name in keys:
+        result = result.withColumnRenamed(column_name, f"{column_name}_str")
+        result = result.withColumnRenamed(f"{column_name}_encoded", column_name)
+
+    """
     bundle = {key: [
         arr.tolist() for arr in OneHotEncoder(sparse_output=False) \
@@ -55,7 +95,8 @@ def onehot(dataframe: DataFrame, keys: list) -> DataFrame:
     for key in keys
     ])
 
-    return bundle, schema
+    return bundle, schema"""
+    return result
 
 def transform(spark: SparkSession, dataframe: DataFrame, keys: list) \
     -> DataFrame:
@@ -66,6 +107,7 @@ def transform(spark: SparkSession, dataframe: DataFrame, keys: list) \
         "index", functions.monotonically_increasing_id()
     )
 
+    """
     bundle, schema = onehot(dataframe, keys)
     newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
         "index", functions.monotonically_increasing_id()
@@ -74,5 +116,7 @@ def transform(spark: SparkSession, dataframe: DataFrame, keys: list) \
     for key in keys:
         dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
     dataframe = dataframe.join(newframe, on="index", how="inner")
+    """
+    dataframe = onehot(dataframe, keys)
 
     return dataframe
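
A closing note on the read_matfiles() change: scipy.io.loadmat() returns
MATLAB structs as nested object arrays, which is why header fields are
addressed positionally ([0][0][2][0] for the target). A short sketch of the
pattern, assuming a hypothetical file name; the field position and the
"unknown" fallback come from the patch itself.

import scipy.io

matdata = scipy.io.loadmat("matfiles/example.mat")  # hypothetical path
try:
    # header is a MATLAB struct; field 2 holds the target label.
    target = matdata["header"][0][0][2][0].lower()
except (KeyError, IndexError):
    # Same fallback the patch adds for files whose header lacks the field.
    target = "unknown"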