From 6b3eb45830ef81d83b0c12810263fca23c5192a4 Mon Sep 17 00:00:00 2001
From: Dawith
Date: Tue, 10 Jun 2025 21:08:01 -0400
Subject: [PATCH] updates

---
 analysis.py |  6 ++++--
 pipe/etl.py | 18 ++++++++++++++++++
 train.py    |  7 ++++++-
 3 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/analysis.py b/analysis.py
index 19b4fac..f8b7cfa 100644
--- a/analysis.py
+++ b/analysis.py
@@ -8,6 +8,7 @@
 import matplotlib.pyplot as plt
 from pyspark.ml.feature import PCA
 from pyspark.ml.linalg import Vectors
+from pyspark.sql import SparkSession, functions, types, Row
 
 from pipe.etl import load
 
@@ -21,7 +22,7 @@ def category_distribution(data):
         plt.bar(select[category], select["count"])
         plt.close()
 
-def pca(data):
+def pca(data, features=None):
     """ Perform PCA on the data.
     :param data: The data to perform PCA on.
 
@@ -43,9 +44,10 @@ def pca(data):
 
     return pca_data
 
 if __name__ == "__main__":
+    spark = SparkSession.builder.appName("analysis").getOrCreate()
     data = load(spark, split=[0.9, 0.5, 0.5])
     pca(data)
-    caegory_distribution(data)
+    category_distribution(data)
 
 # EOF
diff --git a/pipe/etl.py b/pipe/etl.py
index 433e78e..d8e82cd 100644
--- a/pipe/etl.py
+++ b/pipe/etl.py
@@ -8,6 +8,7 @@
 import keras
 import matplotlib.pyplot as plt
+import numpy as np
 from pathlib import Path
 from pyspark.sql import SparkSession, functions, types, Row
 from sklearn.metrics import confusion_matrix
 
@@ -49,6 +50,13 @@ def build_dict(df, key):
         lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
     ).collectAsMap()
 
+def trim(dataframe, column):
+    """Collect `column` and reshape it into a (batch, 32, 130) ndarray."""
+    ndarray = np.array(dataframe.select(column).collect()) \
+        .reshape(-1, 32, 130)
+
+    return ndarray
+
 def extract(spark):
     path = Path("/app/workdir")
     labels = []
@@ -67,6 +75,16 @@ def load(spark, split=[0.99, 0.005, 0.005]):
         .replace("control", "pbs") \
         .replace("dld", "pbs").distinct()
 
+    for category in ["treatment", "target"]:
+        select = data.select(category).groupby(category).count()
+        plt.barh(np.array(select.select(category).collect()).squeeze(),
+                 np.array(select.select("count").collect()).astype("float")
+                 .squeeze())
+        plt.xlabel("Count")
+        plt.ylabel(category)
+        plt.savefig(f"{category}_counts.png", bbox_inches="tight")
+        plt.close()
+
     data = transform(spark, data, ["treatment", "target"])
     category_dict = {
         key: build_dict(data, key) for key in ["treatment", "target"]
diff --git a/train.py b/train.py
index 3ef4f49..b028e9b 100644
--- a/train.py
+++ b/train.py
@@ -14,9 +14,9 @@
 import jax
 import json
 import numpy as np
-import pyspark as spark
 #from pyspark.ml.feature import OneHotEncoder, StringIndexer
 from pyspark.sql import SparkSession, functions, types, Row
+import pyspark as spark
 import tensorflow as tf
 import keras
 import matplotlib.pyplot as plt
@@ -193,6 +193,11 @@ def main():
 
         plt.savefig(f"/app/workdir/confusion_matrix_{key}.png",
                     bbox_inches="tight")
         plt.close()
+        with open(f"/app/workdir/confusion_matrix_{key}.json", 'w') as f:
+            confusion_dict = {"prediction": predict.tolist(),
+                              "true": groundtruth.tolist(),
+                              "matrix": conf_matrix.tolist()}
+            json.dump(confusion_dict, f)
 
     return
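
The new pca(data, features) hook pairs with the pyspark.ml PCA import already at the
top of analysis.py. Below is a minimal sketch of the kind of call the signature
suggests, assuming `features` names a dense vector column; k=2, the column names,
and the toy rows are illustrative assumptions, since the pca() body is elided from
this diff.

# Sketch only: the real pca() body is not shown in the patch above.
from pyspark.ml.feature import PCA
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("pca-sketch").getOrCreate()

# Toy frame with a dense vector column named "features" (assumed name).
df = spark.createDataFrame(
    [(Vectors.dense([1.0, 0.0, 3.0]),),
     (Vectors.dense([2.0, 1.0, 0.0]),)],
    ["features"])

# Fit a 2-component projection and append it as "pca_features".
model = PCA(k=2, inputCol="features", outputCol="pca_features").fit(df)
model.transform(df).show(truncate=False)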
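
The trim() helper added to pipe/etl.py collects one DataFrame column and reshapes
it to (batch, 32, 130). A usage sketch, assuming each row stores a flattened
32x130 array; the "spectrogram" column name and toy values are hypothetical.

from pyspark.sql import SparkSession

from pipe.etl import trim

spark = SparkSession.builder.appName("trim-sketch").getOrCreate()

# Hypothetical rows: each holds a flat list of 32 * 130 floats.
rows = [([float(i)] * (32 * 130),) for i in range(4)]
df = spark.createDataFrame(rows, ["spectrogram"])

batch = trim(df, "spectrogram")
print(batch.shape)  # (4, 32, 130): one (32, 130) array per input row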
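
The confusion-matrix JSON that train.py now writes next to each png can be
reloaded for offline analysis. A sketch, assuming "treatment" is one of the
category keys (as in pipe/etl.py) and that "prediction" and "true" hold
per-sample labels.

import json

import numpy as np

with open("/app/workdir/confusion_matrix_treatment.json") as f:
    confusion = json.load(f)

# Keys mirror what train.py dumps: "prediction", "true", "matrix".
matrix = np.array(confusion["matrix"])
accuracy = np.mean(np.array(confusion["prediction"]) ==
                   np.array(confusion["true"]))
print(matrix.shape, f"accuracy={accuracy:.3f}")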