JAX Toolbox committed Apr 20, 2025
1 parent 2e5ef1f · commit 55161f7
Showing 1 changed file with 83 additions and 80 deletions.
```diff
@@ -1,80 +1,83 @@
-# -*- coding: utf-8 -*-
-"""
-etl.py
-This module contains the ETL (Extract, Transform, Load) pipeline for processing
-the spectrogram data and the labels.
-"""
-
-from pyspark.sql import SparkSession, functions, types, Row
-import tensorflow as tf
-import keras
-import matplotlib.pyplot as plt
-from sklearn.metrics import confusion_matrix
-from sklearn.preprocessing import OneHotEncoder
-
-def transform(spark, dataframe, keys):
-    dataframe = dataframe.withColumn(
-        "index", functions.monotonically_increasing_id()
-    )
-    bundle = {key: [
-            arr.tolist()
-            for arr in OneHotEncoder(sparse_output=False) \
-                .fit_transform(dataframe.select(key).collect())
-        ] for key in keys
-    }
-
-    bundle = [dict(zip(bundle.keys(), values))
-              for values in zip(*bundle.values())]
-    schema = types.StructType([
-        types.StructField(key, types.ArrayType(types.FloatType()), True)
-        for key in keys
-    ])
-    newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
-        "index", functions.monotonically_increasing_id()
-    )
-    for key in keys:
-        dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
-    dataframe = dataframe.join(newframe, on="index", how="inner")
-
-    return dataframe
-
-def build_dict(df, key):
-    df = df.select(key, f"{key}_str").distinct()
-
-    return df.rdd.map(
-        lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
-    ).collectAsMap()
-
-def extract(spark):
-    path = Path("/app/workdir")
-    labels = []
-    with open(path / "train.csv", "r") as file:
-        for line in file:
-            labels.append(line.strip().split(",")[0])
-
-    pipe = SpectrogramPipe(spark, filetype="matfiles")
-
-    return pipe.spectrogram_pipe(path, labels)
-
-def load(spark, split=[0.99, 0.005, 0.005]):
-    data = extract(spark)
-    data.select("treatment").replace("virus", "cpv") \
-        .replace("cont", "pbs") \
-        .replace("control", "pbs") \
-        .replace("dld", "pbs").distinct()
-
-    data = transform(spark, data, ["treatment", "target"])
-    category_dict = {
-        key: build_dict(data, key) for key in ["treatment", "target"]
-    }
-    splits = data.randomSplit(split, seed=42)
-    trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
-    trainy, valy, testy = (
-        [
-            np.array(dset.select("treatment").collect()).squeeze(),
-            np.array(dset.select("target").collect()).squeeze()
-        ] for dset in splits
-    )
-
-    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
+# -*- coding: utf-8 -*-
+"""
+etl.py
+This module contains the ETL (Extract, Transform, Load) pipeline for processing
+the spectrogram data and the labels.
+"""
+
+import keras
+import matplotlib.pyplot as plt
+from pathlib import Path
+from pyspark.sql import SparkSession, functions, types, Row
+from sklearn.metrics import confusion_matrix
+from sklearn.preprocessing import OneHotEncoder
+import tensorflow as tf
+
+from pipe.pipe import SpectrogramPipe
+
+def transform(spark, dataframe, keys):
+    dataframe = dataframe.withColumn(
+        "index", functions.monotonically_increasing_id()
+    )
+    bundle = {key: [
+            arr.tolist()
+            for arr in OneHotEncoder(sparse_output=False) \
+                .fit_transform(dataframe.select(key).collect())
+        ] for key in keys
+    }
+
+    bundle = [dict(zip(bundle.keys(), values))
+              for values in zip(*bundle.values())]
+    schema = types.StructType([
+        types.StructField(key, types.ArrayType(types.FloatType()), True)
+        for key in keys
+    ])
+    newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
+        "index", functions.monotonically_increasing_id()
+    )
+    for key in keys:
+        dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
+    dataframe = dataframe.join(newframe, on="index", how="inner")
+
+    return dataframe
+
+def build_dict(df, key):
+    df = df.select(key, f"{key}_str").distinct()
+
+    return df.rdd.map(
+        lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
+    ).collectAsMap()
+
+def extract(spark):
+    path = Path("/app/workdir")
+    labels = []
+    with open(path / "train.csv", "r") as file:
+        for line in file:
+            labels.append(line.strip().split(",")[0])
+
+    pipe = SpectrogramPipe(spark, filetype="matfiles")
+
+    return pipe.spectrogram_pipe(path, labels)
+
+def load(spark, split=[0.99, 0.005, 0.005]):
+    data = extract(spark)
+    data.select("treatment").replace("virus", "cpv") \
+        .replace("cont", "pbs") \
+        .replace("control", "pbs") \
+        .replace("dld", "pbs").distinct()
+
+    data = transform(spark, data, ["treatment", "target"])
+    category_dict = {
+        key: build_dict(data, key) for key in ["treatment", "target"]
+    }
+    splits = data.randomSplit(split, seed=42)
+    trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
+    trainy, valy, testy = (
+        [
+            np.array(dset.select("treatment").collect()).squeeze(),
+            np.array(dset.select("target").collect()).squeeze()
+        ] for dset in splits
+    )
+
+    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
```
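
The visible change is confined to the imports: they are reordered alphabetically, and the previously missing `from pathlib import Path` and `from pipe.pipe import SpectrogramPipe` are added. A few problems survive the commit, though: `np` is used in `build_dict` and `load` but numpy is never imported, `trim` is called in `load` without being defined or imported here, and the chained `.replace(...)` calls in `load` are never assigned back, so the treatment labels are not actually normalized. A minimal follow-up sketch; the `trim` body is a guess at the intended helper, not the project's actual implementation:

```python
import numpy as np  # used by build_dict() and load(), but never imported


def trim(dset, key):
    # Hypothetical stand-in for the undefined trim() helper: collect one
    # column of a DataFrame split and squeeze it into a dense numpy array.
    return np.array(dset.select(key).collect()).squeeze()


def normalize_treatment(data):
    # DataFrames are immutable, so the result of replace() must be assigned
    # back; the original chain on data.select("treatment") discarded its
    # result (and every other column), making it a no-op.
    return data.replace(
        {"virus": "cpv", "cont": "pbs", "control": "pbs", "dld": "pbs"},
        subset=["treatment"],
    )
```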
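
A separate caveat on `transform`: it pairs the one-hot columns with the original rows by calling `monotonically_increasing_id()` on each DataFrame independently and joining on the result. Those ids are only guaranteed to be increasing and unique, not consecutive, and they depend on partition layout, so the two id columns are not guaranteed to line up row for row. A sketch of a deterministic row index via `zipWithIndex`, illustrative and not part of the commit:

```python
from pyspark.sql import types


def with_row_index(spark, df):
    # Append a deterministic 0..n-1 "index" column using the RDD API;
    # unlike monotonically_increasing_id(), the numbering is consecutive
    # and independent of how the DataFrame happens to be partitioned.
    schema = types.StructType(
        df.schema.fields
        + [types.StructField("index", types.LongType(), False)]
    )
    indexed = df.rdd.zipWithIndex().map(lambda pair: (*pair[0], pair[1]))
    return spark.createDataFrame(indexed, schema=schema)
```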
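
For completeness, a rough sketch of how `load` might be driven; the session settings are illustrative, since the module imports `SparkSession` but never builds one itself:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("etl").getOrCreate()
(trainx, trainy), (valx, valy), (testx, testy), categories = load(spark)

# trainy is a [treatment, target] pair of one-hot label arrays; categories
# maps the argmax index of each one-hot vector back to its string label.
```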