ETL moved out to its own file so that it can be shared
Showing 3 changed files with 133 additions and 1 deletion.
analysis.py
@@ -0,0 +1,51 @@
```python
# -*- coding: utf-8 -*-
"""
analysis.py

This script is used to perform classical machine learning analysis methods on
the data.
"""
import matplotlib.pyplot as plt
from pyspark.ml.feature import PCA
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

from etl import load


def category_distribution(data):
    """
    Plot the distribution of the categories in the data to visualize the skew
    in the distribution of data.
    """
    for category in ["treatment", "target"]:
        # Collect the per-category counts to the driver before plotting.
        counts = data.groupby(category).count().collect()
        plt.bar([row[category] for row in counts],
                [row["count"] for row in counts])
        # Save each figure before closing it (the filename is illustrative).
        plt.savefig(f"{category}_distribution.png")
        plt.close()


def pca(data):
    """
    Perform PCA on the data.

    :param data: The data to perform PCA on.
    :return: The PCA-transformed data.
    """
    # Create a DataFrame with the features as dense vectors; rows must be
    # tuples for toDF to infer a schema.
    features = data.select("features") \
        .rdd.map(lambda x: (Vectors.dense(x[0]),)) \
        .toDF(["features"])

    # Create and fit a two-component PCA model
    model = PCA(k=2, inputCol="features", outputCol="pca_features") \
        .fit(features)

    # Project the data onto the principal components
    pca_data = model.transform(features)

    return pca_data


if __name__ == "__main__":
    spark = SparkSession.builder.appName("analysis").getOrCreate()
    data = load(spark, split=[0.9, 0.5, 0.5])

    pca(data)
    category_distribution(data)

# EOF
```
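For context, a minimal standalone sketch of the same fit-and-project pattern that pca() uses, run against a toy features column (the SparkSession, app name, and example vectors are illustrative, not part of this commit):

```python
from pyspark.ml.feature import PCA
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("pca-demo").getOrCreate()

# Toy stand-in for the spectrogram feature vectors.
rows = [(Vectors.dense([1.0, 0.0, 7.0]),),
        (Vectors.dense([2.0, 0.0, 3.0]),),
        (Vectors.dense([4.0, 0.0, 1.0]),)]
features = spark.createDataFrame(rows, ["features"])

# Fit a two-component model and project, as in pca() above.
model = PCA(k=2, inputCol="features", outputCol="pca_features").fit(features)
print(model.explainedVariance)              # variance captured per component
model.transform(features).show(truncate=False)
```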
etl.py
@@ -0,0 +1,80 @@
```python
# -*- coding: utf-8 -*-
"""
etl.py

This module contains the ETL (Extract, Transform, Load) pipeline for processing
the spectrogram data and the labels.
"""
from pathlib import Path

import numpy as np
from pyspark.sql import functions, types
from sklearn.preprocessing import OneHotEncoder

# NOTE: SpectrogramPipe and trim are assumed to be provided elsewhere in the
# project; their imports are not part of this file.


def transform(spark, dataframe, keys):
    """One-hot encode the string columns named in ``keys`` and join the
    encoded vectors back onto the frame."""
    dataframe = dataframe.withColumn(
        "index", functions.monotonically_increasing_id()
    )
    # One-hot encode each key column with scikit-learn, keeping the encoded
    # rows as plain Python lists.
    bundle = {key: [
        arr.tolist()
        for arr in OneHotEncoder(sparse_output=False)
        .fit_transform(dataframe.select(key).collect())
    ] for key in keys}

    # Reshape the per-column lists into one dictionary per row.
    bundle = [dict(zip(bundle.keys(), values))
              for values in zip(*bundle.values())]
    schema = types.StructType([
        types.StructField(key, types.ArrayType(types.FloatType()), True)
        for key in keys
    ])
    newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
        "index", functions.monotonically_increasing_id()
    )
    # Keep the original strings under "<key>_str" and join the one-hot
    # columns back on the generated index.
    for key in keys:
        dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
    dataframe = dataframe.join(newframe, on="index", how="inner")

    return dataframe


def build_dict(df, key):
    """Map the argmax of each one-hot vector back to its original label."""
    df = df.select(key, f"{key}_str").distinct()

    return df.rdd.map(
        lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
    ).collectAsMap()


def extract(spark):
    """Read the training labels and pipe the spectrograms into Spark."""
    path = Path("/app/workdir")
    labels = []
    with open(path / "train.csv", "r") as file:
        for line in file:
            labels.append(line.strip().split(",")[0])

    pipe = SpectrogramPipe(spark, filetype="matfiles")

    return pipe.spectrogram_pipe(path, labels)


def load(spark, split=[0.99, 0.005, 0.005]):
    data = extract(spark)
    # Collapse the raw treatment labels onto the two canonical classes.
    data = data.replace(
        {"virus": "cpv", "cont": "pbs", "control": "pbs", "dld": "pbs"},
        subset=["treatment"],
    )

    data = transform(spark, data, ["treatment", "target"])
    category_dict = {
        key: build_dict(data, key) for key in ["treatment", "target"]
    }
    splits = data.randomSplit(split, seed=42)
    trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
    trainy, valy, testy = (
        [
            np.array(dset.select("treatment").collect()).squeeze(),
            np.array(dset.select("target").collect()).squeeze()
        ] for dset in splits
    )

    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
```
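The subtlest part of this module is the one-hot round trip shared by transform() and build_dict(). Here is a minimal scikit-learn-only sketch of that pattern with hypothetical labels, outside of Spark:

```python
import numpy as np
from sklearn.preprocessing import OneHotEncoder

labels = [["cpv"], ["pbs"], ["pbs"], ["cpv"]]
encoder = OneHotEncoder(sparse_output=False)
onehot = encoder.fit_transform(labels)   # [[1, 0], [0, 1], [0, 1], [1, 0]]

# Recover the index -> label mapping the way build_dict() does: pair each
# vector's argmax with its original string.
mapping = {str(int(np.argmax(vec))): lab[0]
           for vec, lab in zip(onehot, labels)}
print(mapping)  # {'0': 'cpv', '1': 'pbs'}
```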
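Finally, a short sketch of how load()'s four-part return value is presumably meant to be consumed by callers such as analysis.py (the app name and printed values are illustrative, and it assumes the /app/workdir data layout that extract() expects):

```python
from pyspark.sql import SparkSession

from etl import load

spark = SparkSession.builder.appName("etl-demo").getOrCreate()

# Three (features, [treatment, target]) splits plus the index -> label
# dictionaries built by build_dict().
(trainx, trainy), (valx, valy), (testx, testy), category_dict = load(spark)

print(trainy[0].shape)             # one-hot treatment labels for training
print(category_dict["treatment"])  # e.g. {'0': 'cpv', '1': 'pbs'}
```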