-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
JAX Toolbox
committed
Apr 20, 2025
1 parent
2e5ef1f
commit 55161f7
Showing
1 changed file
with
83 additions
and
80 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,80 +1,83 @@ | ||
| # -*- coding: utf-8 -*- | ||
| """ | ||
| etl.py | ||
| This module contains the ETL (Extract, Transform, Load) pipeline for processing | ||
| the spectrogram data and the labels. | ||
| """ | ||
|
|
||
| from pyspark.sql import SparkSession, functions, types, Row | ||
| import tensorflow as tf | ||
| import keras | ||
| import matplotlib.pyplot as plt | ||
| from sklearn.metrics import confusion_matrix | ||
| from sklearn.preprocessing import OneHotEncoder | ||
|
|
||
| def transform(spark, dataframe, keys): | ||
| dataframe = dataframe.withColumn( | ||
| "index", functions.monotonically_increasing_id() | ||
| ) | ||
| bundle = {key: [ | ||
| arr.tolist() | ||
| for arr in OneHotEncoder(sparse_output=False) \ | ||
| .fit_transform(dataframe.select(key).collect()) | ||
| ] for key in keys | ||
| } | ||
|
|
||
| bundle = [dict(zip(bundle.keys(), values)) | ||
| for values in zip(*bundle.values())] | ||
| schema = types.StructType([ | ||
| types.StructField(key, types.ArrayType(types.FloatType()), True) | ||
| for key in keys | ||
| ]) | ||
| newframe = spark.createDataFrame(bundle, schema=schema).withColumn( | ||
| "index", functions.monotonically_increasing_id() | ||
| ) | ||
| for key in keys: | ||
| dataframe = dataframe.withColumnRenamed(key, f"{key}_str") | ||
| dataframe = dataframe.join(newframe, on="index", how="inner") | ||
|
|
||
| return dataframe | ||
|
|
||
| def build_dict(df, key): | ||
| df = df.select(key, f"{key}_str").distinct() | ||
|
|
||
| return df.rdd.map( | ||
| lambda row: (str(np.argmax(row[key])), row[f"{key}_str"]) | ||
| ).collectAsMap() | ||
|
|
||
| def extract(spark): | ||
| path = Path("/app/workdir") | ||
| labels = [] | ||
| with open(path / "train.csv", "r") as file: | ||
| for line in file: | ||
| labels.append(line.strip().split(",")[0]) | ||
|
|
||
| pipe = SpectrogramPipe(spark, filetype="matfiles") | ||
|
|
||
| return pipe.spectrogram_pipe(path, labels) | ||
|
|
||
| def load(spark, split=[0.99, 0.005, 0.005]): | ||
| data = extract(spark) | ||
| data.select("treatment").replace("virus", "cpv") \ | ||
| .replace("cont", "pbs") \ | ||
| .replace("control", "pbs") \ | ||
| .replace("dld", "pbs").distinct() | ||
|
|
||
| data = transform(spark, data, ["treatment", "target"]) | ||
| category_dict = { | ||
| key: build_dict(data, key) for key in ["treatment", "target"] | ||
| } | ||
| splits = data.randomSplit(split, seed=42) | ||
| trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits) | ||
| trainy, valy, testy = ( | ||
| [ | ||
| np.array(dset.select("treatment").collect()).squeeze(), | ||
| np.array(dset.select("target").collect()).squeeze() | ||
| ] for dset in splits | ||
| ) | ||
|
|
||
| return ((trainx, trainy), (valx, valy), (testx, testy), category_dict) | ||
| # -*- coding: utf-8 -*- | ||
| """ | ||
| etl.py | ||
| This module contains the ETL (Extract, Transform, Load) pipeline for processing | ||
| the spectrogram data and the labels. | ||
| """ | ||
|
|
||
from pathlib import Path

import keras
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from pyspark.sql import SparkSession, functions, types, Row
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import OneHotEncoder

from pipe.pipe import SpectrogramPipe
|
|
||
def transform(spark, dataframe, keys):
    """One-hot encode the string columns listed in ``keys``.

    For every column name in ``keys``, the original string column is kept
    under ``<key>_str`` and a new column ``<key>`` is added holding that
    row's one-hot indicator vector (a list of floats).

    Parameters
    ----------
    spark : SparkSession used to build the encoded dataframe.
    dataframe : input Spark DataFrame containing all ``keys`` columns.
    keys : iterable of column names to encode.

    Returns
    -------
    The joined DataFrame carrying an extra ``index`` column plus, per key,
    the renamed ``<key>_str`` column and the encoded ``<key>`` column.
    """
    dataframe = dataframe.withColumn(
        "index", functions.monotonically_increasing_id()
    )
    # Pull each key column to the driver and one-hot encode it with
    # scikit-learn; each row becomes a dense indicator list.
    # NOTE(review): .collect() materializes the whole column on the driver —
    # acceptable for label columns, not for large data.
    bundle = {key: [
        arr.tolist()
        for arr in OneHotEncoder(sparse_output=False) \
            .fit_transform(dataframe.select(key).collect())
    ] for key in keys
    }

    # Re-shape {key: [vec, ...]} into [{key: vec, ...}, ...] — one dict per
    # row — so the encodings can back a new Spark DataFrame.
    bundle = [dict(zip(bundle.keys(), values))
              for values in zip(*bundle.values())]
    schema = types.StructType([
        types.StructField(key, types.ArrayType(types.FloatType()), True)
        for key in keys
    ])
    newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
        "index", functions.monotonically_increasing_id()
    )
    # NOTE(review): monotonically_increasing_id() is guaranteed unique but
    # NOT consecutive, and its values depend on partitioning. Joining two
    # independently generated "index" columns assumes both dataframes get
    # identical ids in row order — confirm this holds here (e.g. both are
    # single-partition) or switch to an explicit zipWithIndex.
    for key in keys:
        dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
    dataframe = dataframe.join(newframe, on="index", how="inner")

    return dataframe
|
|
||
def build_dict(df, key):
    """Map each one-hot class index (as a string) back to its label.

    ``df`` must contain the encoded column ``key`` (one-hot list) and its
    companion ``<key>_str`` column with the original string label, as
    produced by ``transform``.

    Returns a dict of ``str(argmax(one_hot)) -> original_label``.
    """
    str_col = f"{key}_str"
    distinct_pairs = df.select(key, str_col).distinct()

    def to_entry(row):
        # The position of the "hot" element identifies the class index.
        return (str(np.argmax(row[key])), row[str_col])

    return distinct_pairs.rdd.map(to_entry).collectAsMap()
|
|
||
def extract(spark):
    """Read the training labels and load the matching spectrograms.

    Labels are taken from the first comma-separated field of every line of
    ``/app/workdir/train.csv``, then handed to ``SpectrogramPipe`` to pull
    the corresponding .mat spectrogram files.
    """
    workdir = Path("/app/workdir")
    # First CSV field on each line is the sample label.
    with open(workdir / "train.csv", "r") as csv_file:
        labels = [line.strip().split(",")[0] for line in csv_file]

    pipe = SpectrogramPipe(spark, filetype="matfiles")
    return pipe.spectrogram_pipe(workdir, labels)
|
|
||
def load(spark, split=(0.99, 0.005, 0.005)):
    """Run the full ETL pipeline and return train/val/test splits.

    Parameters
    ----------
    spark : active SparkSession.
    split : three fractions for the train/validation/test random split.
        A tuple default avoids the shared-mutable-default pitfall of the
        previous list literal; any sequence of weights is accepted.

    Returns
    -------
    ``((trainx, trainy), (valx, valy), (testx, testy), category_dict)``
    where each ``*y`` is ``[treatment_one_hot, target_one_hot]`` as numpy
    arrays and ``category_dict`` maps class indices back to labels.
    """
    data = extract(spark)
    # Collapse treatment-name aliases onto canonical labels. The previous
    # revision chained this off data.select(...) and discarded the result,
    # so the aliases were never actually merged — assign the replacement
    # back onto the dataframe.
    data = data.replace(
        {"virus": "cpv", "cont": "pbs", "control": "pbs", "dld": "pbs"},
        subset=["treatment"],
    )

    data = transform(spark, data, ["treatment", "target"])
    category_dict = {
        key: build_dict(data, key) for key in ["treatment", "target"]
    }
    # randomSplit expects a list of weights; seed fixed for reproducibility.
    splits = data.randomSplit(list(split), seed=42)
    trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
    trainy, valy, testy = (
        [
            np.array(dset.select("treatment").collect()).squeeze(),
            np.array(dset.select("target").collect()).squeeze(),
        ]
        for dset in splits
    )

    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)