ETL pipeline uses proper StringIndexer instead of jerry-rigged one-hot encoding scheme
lim185 committed Oct 14, 2025
1 parent 9879577 commit 7ab49d6
Showing 3 changed files with 67 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pipe/etl.py
@@ -72,7 +72,7 @@ def split_sets(data: DataFrame, split=[0.99, 0.005, 0.005]) -> tuple:
key: build_dict(data, key) for key in ["treatment", "target"]
}
splits = data.randomSplit(split, seed=42)
- trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
+ trainx, valx, testx = (trim(dset, "spectra") for dset in splits)
trainy, valy, testy = (
[
np.array(dset.select("treatment").collect()).squeeze(),
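Context for the rename: split_sets partitions the frame with DataFrame.randomSplit before trimming the (now) "spectra" column. A minimal sketch of randomSplit's behavior on toy data, not taken from the repo:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.range(1000)  # toy frame; the real one carries "spectra" and label columns
train, val, test = df.randomSplit([0.99, 0.005, 0.005], seed=42)
# Weights are normalized and the seed fixes the assignment, but each
# split's row count is only approximately proportional to its weight.
print(train.count(), val.count(), test.count())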
28 changes: 17 additions & 11 deletions pipe/extract.py
@@ -4,6 +4,7 @@
"""

import json
+ import glob
import os
from pathlib import Path
import typing
@@ -33,9 +34,9 @@ def extract(spark: SparkSession) -> DataFrame:
for line in file:
labels.append(line.strip().split(",")[0])

- reader = SpectrogramReader(spark, filetype="matfiles")
+ reader = FileReader(spark, filetype=FileType.MAT)

- return reader.spectrogram_read(path, labels)
+ return reader.read_file(path, labels)

def image_pipe(spark: SparkSession, imagepath: Path, namepattern: str,
stacksize: int) -> np.ndarray:
@@ -72,11 +73,11 @@ def __init__(self, spark: SparkSession, filetype: FileType):
self.spark = spark
match filetype:
case FileType.HDF5:
- self.spectrogram_read = self.spectrogram_read_hdf5
+ self.read_file = self.read_hdf5
case FileType.SHARD:
- self.spectrogram_read = self.spectrogram_read_shards
+ self.read_file = self.read_shards
case FileType.MAT:
- self.spectrogram_read = self.spectrogram_read_matfiles
+ self.read_file = self.read_matfiles
case _:
raise ValueError("Expected a valid FileType")
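For readers unfamiliar with the pattern: the match/case above binds self.read_file to one bound method per format, so every call site can use reader.read_file(...) without knowing the underlying file type. A minimal sketch with a hypothetical FileType enum (its real definition is outside this diff):

from enum import Enum, auto

class FileType(Enum):  # hypothetical stand-in; mirrors the members used above
    HDF5 = auto()
    SHARD = auto()
    MAT = auto()

class FileReader:
    def __init__(self, spark, filetype: FileType):
        self.spark = spark
        match filetype:
            case FileType.MAT:
                # Bound method: reader.read_file(path, labels) now
                # dispatches to read_matfiles transparently.
                self.read_file = self.read_matfiles
            case _:
                raise ValueError("Expected a valid FileType")

    def read_matfiles(self, specpath, labels):
        ...  # per-format reader elided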

@@ -122,25 +123,30 @@ def read_matfiles(self, specpath: Path,
"""
data = []
row = {}
+ labels = [os.path.basename(p) for p in glob.glob(str(specpath/"matfiles"/"*.mat"))]  # names only; loadmat below re-joins the directory

for label in labels:
matdata = sp.io.loadmat(specpath/"matfiles"/label)
if DataKind.TREATMENT in datakinds:
row["treatment"] = matdata["header"][0][0][4][0].lower()
if DataKind.TARGET in datakinds:
- row["target"] = matdata["header"][0][0][2][0].lower()
+ try:
+     row["target"] = matdata["header"][0][0][2][0].lower()
+ except (IndexError, KeyError):  # header field may be absent
+     row["target"] = "unknown"
if DataKind.FPS in datakinds:
row["fps"] = 2*float(matdata["header"][0][0][15][0])
if DataKind.BB in datakinds:
- row["bb"] = matdata["bb"]
+ row["bb"] = matdata["BB"].tolist()
if DataKind.NSD in datakinds:
- row["nsd"] = matdata["nsd"]
+ row["nsd"] = matdata["NSD"].tolist()
if DataKind.NCNT in datakinds:
- row["ncnt"] = matdata["ncnt"]
+ row["ncnt"] = matdata["NCNT"].astype(float).tolist()
if DataKind.SPEC in datakinds:
spectra = np.array(matdata["SP"][0])
if len(spectra.shape) == 3:
- spectra = spectrogram[0]
+ spectra = spectra[0]
if spectra.shape[0] > default_size[0]:
spectra = spectra[:default_size[0], :]
spectra = np.pad(
@@ -154,7 +160,7 @@ def read_matfiles(self, specpath: Path,
row["spectra"] = spectra.tolist()
data.append(Row(**row))

- return self.spark.createDataFrame(spectrograms)
+ return self.spark.createDataFrame(data)

def read_hdf5(self, specpath: Path, labels: list,
namepattern:str="averaged_spectrogram{}.hdf5") \
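Two details of read_matfiles are worth unpacking. First, chained indexing like matdata["header"][0][0][4][0] is how scipy.io.loadmat represents a MATLAB struct: a (1, 1) object array of a structured dtype whose fields are themselves wrapped in arrays. A sketch of the access pattern; the file and field names here are hypothetical:

import scipy.io as sio

matdata = sio.loadmat("recording.mat")  # hypothetical .mat file
header = matdata["header"]              # MATLAB struct -> (1, 1) structured array
record = header[0][0]                   # numpy.void holding the struct's fields
print(record.dtype.names)               # field names saved with the struct
treatment = record[4][0]                # positional access, as in read_matfiles
treatment = record["treatment"][0]      # equivalent by-name access (name assumed)

Second, the truncate-then-pad step normalizes every spectrogram to default_size rows. The np.pad arguments are folded out of the diff above, so this is only a minimal pad-or-truncate sketch under that assumption:

import numpy as np

default_size = (128, 128)               # illustrative; the real value lives elsewhere
spectra = np.random.rand(150, 128)
spectra = spectra[:default_size[0], :]  # truncate if too tall
spectra = np.pad(
    spectra,
    ((0, default_size[0] - spectra.shape[0]), (0, 0)),
    mode="constant",
)                                       # zero-pad if too short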
54 changes: 49 additions & 5 deletions pipe/transform.py
@@ -6,7 +6,8 @@
import typing

from pyspark.sql import DataFrame, functions, SparkSession, types
- from sklearn.preprocessing import OneHotEncoder
+ from pyspark.ml import Pipeline
+ from pyspark.ml.feature import StringIndexer, OneHotEncoder

def merge_redundant_treatment_labels(dataframe: DataFrame) -> DataFrame:
"""
@@ -34,13 +35,52 @@ def onehot(dataframe: DataFrame, keys: list) -> DataFrame:
column names to be encoded this way are provided in the 'keys' list.
Args:
- dataframe (DataFrame): Input Spark DataFrame.
- keys (list): List of column names to be one-hot encoded.
+ dataframe (pyspark.sql.DataFrame): Input Spark DataFrame containing
+     the string column(s) to encode.
+ keys (list[str]): List of column names to be one-hot encoded.
Returns:
- DataFrame: New DataFrame with one-hot encoded columns.
+ pyspark.sql.DataFrame: New DataFrame with one-hot encoded column(s).
"""

+ indexers = []
+ encoders = []
+ indexed_cols = []
+
+ for key in keys:
+     if key not in dataframe.columns:
+         raise ValueError(f'Column "{key}" cannot be found in DataFrame.')
+
+     indexed_column = f"{key}_indexed"
+     encoded_column = f"{key}_encoded"
+     indexed_cols.append(indexed_column)
+
+     indexers.append(
+         StringIndexer(
+             inputCol=key,
+             outputCol=indexed_column,
+             handleInvalid="keep"
+         )
+     )
+
+     encoders.append(
+         OneHotEncoder(
+             inputCol=indexed_column,
+             outputCol=encoded_column,
+             dropLast=False
+         )
+     )
+
+ pipeline = Pipeline(stages=indexers + encoders)
+ model = pipeline.fit(dataframe)
+ result = model.transform(dataframe)
+
+ result = result.drop(*indexed_cols)
+ for column_name in keys:
+     result = result.withColumnRenamed(column_name, f"{column_name}_str")
+     result = result.withColumnRenamed(f"{column_name}_encoded", column_name)
+
+ """
bundle = {key: [
arr.tolist()
for arr in OneHotEncoder(sparse_output=False) \
@@ -55,7 +95,8 @@ def onehot(dataframe: DataFrame, keys: list) -> DataFrame:
for key in keys
])
- return bundle, schema
+ return bundle, schema"""
+ return result

def transform(spark: SparkSession, dataframe: DataFrame, keys: list) \
-> DataFrame:
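A usage sketch for the new onehot on toy data (the column values are illustrative, not from the repo): StringIndexer maps each distinct string to a numeric index, handleInvalid="keep" reserves an extra bucket for values unseen at fit time, and OneHotEncoder with dropLast=False emits a full-length indicator vector per row.

from pyspark.sql import SparkSession
from pipe.transform import onehot

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("saline",), ("drug_a",), ("drug_a",)], ["treatment"])
onehot(df, ["treatment"]).show(truncate=False)
# Indices are assigned by descending label frequency, so roughly:
# +-------------+-------------+
# |treatment_str|treatment    |
# +-------------+-------------+
# |saline       |(3,[1],[1.0])|
# |drug_a       |(3,[0],[1.0])|
# |drug_a       |(3,[0],[1.0])|
# +-------------+-------------+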
@@ -66,6 +107,7 @@ def transform(spark: SparkSession, dataframe: DataFrame, keys: list) \
"index", functions.monotonically_increasing_id()
)

+ """
bundle, schema = onehot(dataframe, keys)
newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
"index", functions.monotonically_increasing_id()
@@ -74,5 +116,7 @@ def transform(spark: SparkSession, dataframe: DataFrame, keys: list) \
for key in keys:
dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
dataframe = dataframe.join(newframe, on="index", how="inner")
+ """
+ dataframe = onehot(dataframe, keys)

return dataframe
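One caveat worth noting: unlike the sklearn encoder this commit removes, Spark's OneHotEncoder produces SparseVector cells, so downstream code that builds numpy arrays straight from collect() (as split_sets in pipe/etl.py does) needs a densify step. A hedged sketch, assuming the encoded column is named treatment as in this pipeline:

import numpy as np

rows = dataframe.select("treatment").collect()
labels = np.array([row["treatment"].toArray() for row in rows])  # SparseVector -> dense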
