Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Dawith committed Jun 11, 2025
1 parent 55161f7 commit 6b3eb45
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 3 deletions.
6 changes: 4 additions & 2 deletions analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import matplotlib.pyplot as plt
from pyspark.ml.feature import PCA
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession, functions, types, Row

from pipe.etl import load

Expand All @@ -21,7 +22,7 @@ def category_distribution(data):
plt.bar(select[category], select["count"])
plt.close()

def pca(data):
def pca(data, features):
"""
Perform PCA on the data.
:param data: The data to perform PCA on.
Expand All @@ -43,9 +44,10 @@ def pca(data):
return pca_data

if __name__ == "__main__":
    # Entry point: build a Spark session, load the split dataset, and run
    # the analyses defined above.
    spark = SparkSession.builder.appName("train").getOrCreate()
    data = load(spark, split=[0.9, 0.5, 0.5])

    # NOTE(review): pca's signature is now pca(data, features) but only
    # data is passed here — confirm the intended features argument.
    pca(data)
    # Fixed typo: previous call was misspelled 'caegory_distribution'.
    category_distribution(data)

# EOF
18 changes: 18 additions & 0 deletions pipe/etl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import keras
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from pyspark.sql import SparkSession, functions, types, Row
from sklearn.metrics import confusion_matrix
Expand Down Expand Up @@ -49,6 +50,13 @@ def build_dict(df, key):
lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
).collectAsMap()

def trim(dataframe, column, shape=(32, 130)):
    """
    Collect a single DataFrame column and reshape it into a 3-D ndarray.

    :param dataframe: A (py)Spark DataFrame exposing select(column).collect().
    :param column: Name of the column to extract.
    :param shape: Trailing (rows, cols) dimensions of each sample; the
        leading dimension is inferred. Defaults to (32, 130), the value
        previously hard-coded here, so existing callers are unaffected.
    :return: numpy.ndarray of shape (-1, *shape).
    :raises ValueError: If the collected data size is not a multiple of
        the product of ``shape`` (raised by numpy's reshape).
    """
    collected = np.array(dataframe.select(column).collect())
    return collected.reshape(-1, *shape)

def extract(spark):
path = Path("/app/workdir")
labels = []
Expand All @@ -67,6 +75,16 @@ def load(spark, split=[0.99, 0.005, 0.005]):
.replace("control", "pbs") \
.replace("dld", "pbs").distinct()

for category in ["treatment", "target"]:
select = data.select(category).groupby(category).count()
plt.barh(np.array(select.select(category).collect()).squeeze(),
np.array(select.select("count").collect()).astype("float") \
.squeeze())
plt.xlabel("Count")
plt.ylabel(category)
plt.savefig(f"{category}_counts.png", bbox_inches="tight")
plt.close()
exit()
data = transform(spark, data, ["treatment", "target"])
category_dict = {
key: build_dict(data, key) for key in ["treatment", "target"]
Expand Down
7 changes: 6 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
import jax
import json
import numpy as np
import pyspark as spark
#from pyspark.ml.feature import OneHotEncoder, StringIndexer
from pyspark.sql import SparkSession, functions, types, Row
import pyspark as spark
import tensorflow as tf
import keras
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -193,6 +193,11 @@ def main():
plt.savefig(f"/app/workdir/confusion_matrix_{key}.png",
bbox_inches="tight")
plt.close()
with open(f"confusion_matrix_{key}.json", 'w') as f:
confusion_dict = {"prediction": predict.tolist(),
"true": groundtruth.tolist(),
"matrix": conf_matrix.tolist()}
json.dump(confusion_dict, f)

return

Expand Down

0 comments on commit 6b3eb45

Please sign in to comment.