
Commit 55161f7

missing import
JAX Toolbox committed Apr 20, 2025
1 parent 2e5ef1f commit 55161f7
Showing 1 changed file with 83 additions and 80 deletions.
163 changes: 83 additions & 80 deletions pipe/etl.py
# -*- coding: utf-8 -*-
"""
etl.py
This module implements the ETL (Extract, Transform, Load) pipeline for
processing spectrogram data and labels.
"""

import numpy as np
from pathlib import Path
from pyspark.sql import functions, types
from sklearn.preprocessing import OneHotEncoder

from pipe.pipe import SpectrogramPipe

def transform(spark, dataframe, keys):
    # Tag each row with a generated id so the one-hot columns can be joined
    # back onto the original rows below.
    dataframe = dataframe.withColumn(
        "index", functions.monotonically_increasing_id()
    )
    # One-hot encode each label column on the driver; collect() yields one
    # single-field Row per input row, which OneHotEncoder accepts as a 2-D
    # column of strings.
    bundle = {
        key: [
            arr.tolist()
            for arr in OneHotEncoder(sparse_output=False)
            .fit_transform(dataframe.select(key).collect())
        ]
        for key in keys
    }

    # Pivot the per-column lists into one dict per row.
    bundle = [dict(zip(bundle.keys(), values))
              for values in zip(*bundle.values())]
    schema = types.StructType([
        types.StructField(key, types.ArrayType(types.FloatType()), True)
        for key in keys
    ])
    newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
        "index", functions.monotonically_increasing_id()
    )
    # Keep the original strings under "<key>_str", then join the one-hot
    # arrays back on the generated index. Note that this assumes both
    # dataframes yield matching ids, which holds only when they share the
    # same partitioning and row order.
    for key in keys:
        dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
    dataframe = dataframe.join(newframe, on="index", how="inner")

    return dataframe
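
# Example (hypothetical data): after transform(spark, df, ["treatment"]), a row
# whose "treatment" column held the string "cpv" carries "treatment_str" = "cpv"
# and "treatment" = [1.0, 0.0], with one one-hot slot per distinct label value.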

def build_dict(df, key):
    # Map each one-hot argmax index (as a string) back to the original label.
    df = df.select(key, f"{key}_str").distinct()

    return df.rdd.map(
        lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
    ).collectAsMap()
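
# Example (hypothetical): if "treatment" one-hot vectors are [1.0, 0.0] for
# "cpv" and [0.0, 1.0] for "pbs", build_dict(df, "treatment") returns
# {"0": "cpv", "1": "pbs"}.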

def extract(spark):
    # The first comma-separated field on each line of train.csv is the label.
    path = Path("/app/workdir")
    labels = []
    with open(path / "train.csv", "r") as file:
        for line in file:
            labels.append(line.strip().split(",")[0])

    pipe = SpectrogramPipe(spark, filetype="matfiles")

    return pipe.spectrogram_pipe(path, labels)

def load(spark, split=[0.99, 0.005, 0.005]):
    data = extract(spark)
    # Normalise treatment label variants onto the two canonical values
    # ("virus" -> "cpv"; "cont"/"control"/"dld" -> "pbs").
    data = data.replace(
        {"virus": "cpv", "cont": "pbs", "control": "pbs", "dld": "pbs"},
        subset=["treatment"],
    )

    data = transform(spark, data, ["treatment", "target"])
    category_dict = {
        key: build_dict(data, key) for key in ["treatment", "target"]
    }
    splits = data.randomSplit(split, seed=42)
    # `trim` is not defined or imported in this file; it is assumed to be
    # provided elsewhere in the package.
    trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
    trainy, valy, testy = (
        [
            np.array(dset.select("treatment").collect()).squeeze(),
            np.array(dset.select("target").collect()).squeeze(),
        ]
        for dset in splits
    )

    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
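
A minimal usage sketch, assuming a local Spark session, that the package is
importable as pipe.etl, and that /app/workdir holds train.csv plus the .mat
spectrogram files SpectrogramPipe expects (none of this is shown in the commit
itself):

from pyspark.sql import SparkSession
from pipe.etl import load

spark = SparkSession.builder.appName("etl-demo").getOrCreate()
(trainx, trainy), (valx, valy), (testx, testy), cats = load(spark)
print(cats["treatment"])  # e.g. {"0": "cpv", "1": "pbs"}
spark.stop()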
