Revised data pipeline to handle large throughput
Dawith committed Mar 2, 2026
1 parent 2f9298c commit 520fbe3
Showing 3 changed files with 58 additions and 15 deletions.
2 changes: 1 addition & 1 deletion pipe/etl.py
@@ -85,7 +85,7 @@ def split_sets(data: DataFrame, split=[0.99, 0.005, 0.005]) -> tuple:
 def trim(dataframe, column):
 
     ndarray = np.array(dataframe.select(column).collect()) \
-        .reshape(-1, 34, 133)
+        .reshape(-1, 36, 133)
 
     return ndarray

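Note on the shape change: trim() must agree with the records built in pipe/extract.py below, which now size each spectrogram to 36 loop rows and append extra channels (NSD and NCNT assignments are visible in the hunks below) for 133 columns total. A minimal consistency check, assuming a hypothetical DataFrame df whose "timeseries" column came from the revised extractor:

    arr = trim(df, "timeseries")         # df is a stand-in, not defined here
    assert arr.shape[1:] == (36, 133)    # one (loops, features) block per row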
42 changes: 29 additions & 13 deletions pipe/extract.py
@@ -13,6 +13,8 @@
 import cv2 as cv
 import h5py
 import numpy as np
+from numpy import ndarray
+from pyspark.ml.linalg import Vectors
 from pyspark.sql import SparkSession, Row, DataFrame
 import scipy as sp

@@ -35,8 +37,10 @@ def extract(spark: SparkSession) -> DataFrame:
             labels.append(line.strip().split(",")[0])
 
     reader = FileReader(spark, filetype=FileType.MAT)
-
-    return reader.read_file(path, labels)
+    rdd = spark.sparkContext.parallelize(reader.read_file(path, labels),
+                                         numSlices=200)
+    #return reader.read_file(path, labels)
+    return spark.createDataFrame(rdd)
 
 def image_pipe(spark: SparkSession, imagepath: Path, namepattern: str,
                stacksize: int) -> np.ndarray:
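A note on the throughput fix: parallelize(..., numSlices=200) spreads the reader's rows over 200 partitions before the DataFrame is built, rather than handing Spark one driver-side collection. A self-contained sketch of the pattern (local session and toy rows, not the repo's data):

    from pyspark.sql import Row, SparkSession

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    rows = [Row(timeseries=[float(i)], target="unknown") for i in range(10)]
    rdd = spark.sparkContext.parallelize(rows, numSlices=200)
    print(rdd.getNumPartitions())        # 200; most partitions empty here
    df = spark.createDataFrame(rdd)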
@@ -58,6 +62,15 @@ def image_pipe(spark: SparkSession, imagepath: Path, namepattern: str,
 
     return images
 
+def strip_array(arr: ndarray) -> ndarray:
+    dtype = arr.dtype
+    while dtype == "object":
+        arr = arr[0]
+        dtype = arr.dtype
+    if len(arr) == 0:
+        return ["unknown"]
+    return arr
+
 class FileReader:
     """
     Class to read spectrograms and metadata from different file formats based
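strip_array deals with how scipy.io.loadmat returns MAT struct fields: nested object arrays of varying depth. The loop peels object layers until a concrete array remains, and an empty field falls back to ["unknown"] so the caller's [0].lower() still works. A hedged example with a hand-built stand-in for loadmat output (the drug value is made up):

    import numpy as np

    # Mimics matdata["header"]["drug"]: object array -> object array -> strings
    inner = np.empty(1, dtype=object)
    inner[0] = np.array(["dmso"])
    wrapped = np.empty(1, dtype=object)
    wrapped[0] = inner

    print(strip_array(wrapped)[0].lower())   # -> "dmso"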
@@ -105,7 +118,7 @@ def metadata_read(self, metapath: Path, labels:list,

     def read_matfiles(self, specpath: Path,
                       datakinds: List[DataKind],
-                      default_size: tuple = (34, 130),
+                      default_size: tuple = (36, 130),
                       pad_value: float = 0.) -> DataFrame:
         """
         Loads data for each stack iteration from a set of mat files,
@@ -121,8 +134,7 @@ def read_matfiles(self, specpath: Path,
         Returns:
             DataFrame: Spark DataFrame containing the requested data.
         """
-        data = []
-        row = {}
+        #data = []
         labels = glob.glob(str(specpath/"matfiles"/"*.mat"))
         nloops = default_size[0]
         nfreq = default_size[1]
@@ -137,6 +149,8 @@ def read_matfiles(self, specpath: Path,
         ncnt_scale = 5.
 
         for label in labels:
+            row = {}
+            print(label)
             matdata = sp.io.loadmat(specpath/"matfiles"/label)
             ncnt = np.log10(matdata["NCNT"][0])
             if np.min(ncnt) < 2:
@@ -174,18 +188,20 @@ def read_matfiles(self, specpath: Path,
                 (matdata["NSD"][0] - nsd_meanshift)
             timeseries_array[time_offset:, 132] = \
                 np.log10(matdata["NCNT"][0]) / ncnt_scale
-            row["timeseries"] = timeseries_array.tolist()
+            timeseries_array = Vectors.dense(timeseries_array.flatten())
+            row["timeseries"] = timeseries_array
 
             if DataKind.TREATMENT in datakinds:
-                row["treatment"] = matdata["header"]["drug"][0][0][0].lower()
+                row["treatment"] = strip_array(
+                    matdata["header"]["drug"])[0].lower()
             if DataKind.TARGET in datakinds:
-                try:
-                    row["target"] = matdata["header"]["cell"][0][0][0].lower()
-                except:
-                    row["target"] = "unknown"
-            data.append(Row(**row))
+                row["target"] = strip_array(
+                    matdata["header"]["cell"])[0].lower()
+                row["target"] = "unknown"
+            #data.append(Row(**row))
+            yield row
 
-        return self.spark.createDataFrame(data)
+        #return self.spark.createDataFrame(data)
 
     def read_hdf5(self, specpath: Path, labels: list,
                   namepattern:str="averaged_spectrogram{}.hdf5") \
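With yield row in place of data.append(...), read_matfiles is now a generator: the driver no longer accumulates every Row and builds the DataFrame itself; extract() pulls the dicts and partitions them (see the parallelize call above). Packing the (36, 133) spectrogram into a single Vectors.dense column also keeps one vector per record instead of a nested list. A hedged sketch of the consuming pattern with a fake generator (all names here are stand-ins):

    from pyspark.ml.linalg import Vectors
    from pyspark.sql import Row, SparkSession

    def fake_reader():                    # stands in for read_matfiles(...)
        for i in range(3):
            yield {"timeseries": Vectors.dense([float(i)] * 6),
                   "treatment": "unknown", "target": "unknown"}

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    # Row(**d) avoids the deprecated dict-to-schema inference path
    rdd = spark.sparkContext.parallelize(
        [Row(**d) for d in fake_reader()], numSlices=4)
    spark.createDataFrame(rdd).show()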
29 changes: 28 additions & 1 deletion pipe/transform.py
@@ -7,7 +7,7 @@
 
 from pyspark.sql import DataFrame, functions, SparkSession, types
 from pyspark.ml import Pipeline
-from pyspark.ml.feature import StringIndexer, OneHotEncoder
+from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler

 def merge_redundant_treatment_labels(dataframe: DataFrame) -> DataFrame:
     """
@@ -43,6 +43,7 @@ def onehot(dataframe: DataFrame, keys: list) -> DataFrame:
         pyspark.sql.DataFrame: New DataFrame with one-hot encoded column(s).
     """
 
+    """ OLD BLOCK
     indexers = []
     encoders = []
     indexed_cols = []
@@ -79,6 +80,32 @@ def onehot(dataframe: DataFrame, keys: list) -> DataFrame:
     for column_name in keys:
         result = result.withColumnRenamed(column_name, f"{column_name}_str")
         result = result.withColumnRenamed(f"{column_name}_encoded", column_name)
+    """
+
+    indexer = StringIndexer(
+        inputCols=keys,
+        outputCols=[f"{c}_idx" for c in keys],
+        handleInvalid="keep"
+    )
+
+    encoder = OneHotEncoder(
+        inputCols=[f"{c}_idx" for c in keys],
+        outputCols=[f"{c}_vec" for c in keys],
+        dropLast=False
+    )
+
+    assembler = VectorAssembler(
+        inputCols=[f"{c}_vec" for c in keys],
+        outputCol="features"
+    )
+    pipeline = Pipeline(stages=[indexer, encoder, assembler])
+    model = pipeline.fit(dataframe)
+    result = model.transform(dataframe)
+
+    for c in keys:
+        result = result.withColumnRenamed(c, f"{c}_str") \
+                       .withColumnRenamed(f"{c}_vec", c)
+
 
     return result

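The rewritten onehot replaces the per-column loop with one Pipeline over all keys: StringIndexer (handleInvalid="keep" reserves a bucket for labels unseen at fit time), OneHotEncoder with dropLast=False so every category keeps an explicit slot, and VectorAssembler to concatenate the per-key vectors into a single "features" column; the renames afterwards leave the raw strings as <key>_str and put each encoded vector back under the original column name. A hedged usage sketch on a toy frame (session and label values are made up):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    df = spark.createDataFrame([("doxorubicin", "hela"), ("unknown", "a549")],
                               ["treatment", "target"])
    out = onehot(df, ["treatment", "target"])
    out.select("treatment_str", "treatment", "features").show(truncate=False)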
