Skip to content

Commit

Permalink
Fixed minor issues with names in the etl pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
Dawith committed Sep 30, 2025
1 parent 08f4d25 commit 07aebcf
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 20 deletions.
19 changes: 0 additions & 19 deletions pipe/etl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,6 @@ def etl(spark):
data = load(data)
return data

def build_dict(df, key):
    """
    Build a mapping from one-hot index to human-readable label.

    Selects the distinct (``key``, ``key + "_str"``) pairs from ``df`` and
    returns a dict mapping ``str(np.argmax(row[key]))`` — the position of
    the hot element in the (presumably one-hot) vector column — to the
    corresponding string label.  NOTE(review): assumes ``row[key]`` is an
    argmax-able sequence and that ``f"{key}_str"`` exists — confirm schema.

    :param df: pyspark.sql.DataFrame containing both columns.
    :param key: base column name; the label column is ``f"{key}_str"``.
    :return: dict[str, str] collected to the driver via collectAsMap().
    """

    # Distinct pairs only, so the collected map stays small.
    df = df.select(key, f"{key}_str").distinct()

    return df.rdd.map(
        lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
    ).collectAsMap()

def trim(dataframe, column):
    """
    Collect a single column to the driver as a numpy array of shape
    (-1, 32, 130).

    NOTE(review): the (32, 130) block shape is hard-coded — presumably the
    spectrogram dimensions used by this pipeline; each collected row must
    flatten to a multiple of 32*130 elements or reshape raises ValueError.

    :param dataframe: pyspark.sql.DataFrame; only ``column`` is read.
    :param column: name of the column to collect.
    :return: np.ndarray of shape (n, 32, 130).
    """

    ndarray = np.array(dataframe.select(column).collect()) \
        .reshape(-1, 32, 130)

    return ndarray

def visualize_data_distribution(data):
for category in ["treatment", "target"]:
select = data.select(category) \
Expand Down
2 changes: 1 addition & 1 deletion pipe/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def extract(spark):

reader = SpectrogramReader(spark, filetype="matfiles")

return spectrogram_read(path, labels)
return reader.spectrogram_read(path, labels)

def image_pipe(spark: SparkSession, imagepath: Path, namepattern: str,
stacksize: int) -> np.ndarray:
Expand Down
21 changes: 21 additions & 0 deletions pipe/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,26 @@
import numpy as np
from pyspark.sql import DataFrame


def build_dict(df, key):
    """
    Map one-hot vector positions to their string labels.

    Looks at the distinct (``key``, ``key + "_str"``) pairs in ``df`` and
    returns a dict whose keys are ``str(np.argmax(row[key]))`` (the index
    of the hot element in the vector column) and whose values are the
    matching entries of the ``f"{key}_str"`` label column.

    :param df: pyspark.sql.DataFrame holding both columns.
    :param key: base column name; labels live in ``f"{key}_str"``.
    :return: dict[str, str] materialized on the driver.
    """

    label_col = f"{key}_str"
    pairs = df.select(key, label_col).distinct()

    def as_entry(row):
        # Key on the hot index so lookups match model output indices.
        return (str(np.argmax(row[key])), row[label_col])

    return pairs.rdd.map(as_entry).collectAsMap()

def trim(dataframe, column, shape=(32, 130)):
    """
    Collect one column of ``dataframe`` to the driver as a numpy array
    of fixed-shape blocks.

    The hard-coded (32, 130) spectrogram shape is now a ``shape``
    parameter with the same default, so existing callers are unaffected
    while other block shapes become possible.

    :param dataframe: pyspark.sql.DataFrame; only ``column`` is read.
    :param column: name of the column to collect.
    :param shape: trailing shape of each block; defaults to (32, 130).
    :return: np.ndarray of shape (n, *shape).
    :raises ValueError: if the collected data does not flatten to a
        multiple of ``prod(shape)`` elements.
    """

    collected = dataframe.select(column).collect()
    # -1 lets numpy infer the leading (sample) dimension.
    return np.array(collected).reshape(-1, *shape)

def load(data: DataFrame, split=[0.99, 0.005, 0.005]):
category_dict = {
key: build_dict(data, key) for key in ["treatment", "target"]
Expand All @@ -23,3 +43,4 @@ def load(data: DataFrame, split=[0.99, 0.005, 0.005]):

return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)

# EOF

0 comments on commit 07aebcf

Please sign in to comment.