ETL moved out to its own file so that it can be shared
lim185 committed Apr 20, 2025
1 parent c70fa91 commit a8cae58
Showing 3 changed files with 133 additions and 1 deletion.
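
The point of the refactor: analysis.py and train.py now share one loader instead of each defining its own. A minimal sketch of the shared entry point (the SparkSession setup is illustrative, not part of this diff):

    from pyspark.sql import SparkSession
    from etl import load

    spark = SparkSession.builder.appName("shared-etl").getOrCreate()
    data = load(spark, split=[0.99, 0.005, 0.005])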
51 changes: 51 additions & 0 deletions analysis.py
@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
"""
analysis.py
This script is used to perform classical machine learning analysis methods on
the data.
"""

import matplotlib.pyplot as plt
from pyspark.ml.feature import PCA
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

from etl import load

def category_distribution(data):
    """
    Plot the distribution of the categories in the data to visualize the skew
    in the distribution of data.
    """
    for category in ["treatment", "target"]:
        # Aggregate counts on the cluster, then bring the small result to the
        # driver so matplotlib can plot it.
        counts = data.groupby(category).count().toPandas()
        plt.bar(counts[category], counts["count"])
        # Persist the figure before closing it; the filename is illustrative.
        plt.savefig(f"{category}_distribution.png")
        plt.close()

def pca(data):
    """
    Perform PCA on the data.
    :param data: The data to perform PCA on.
    :return: The data projected onto the first two principal components.
    """

    # Re-pack the features column as dense vectors; the map must emit
    # one-element tuples so toDF can infer a schema.
    features = data.select("features").rdd \
        .map(lambda x: (Vectors.dense(x[0]),)) \
        .toDF(["features"])

    # Create a PCA model
    pca = PCA(k=2, inputCol="features", outputCol="pca_features")
    model = pca.fit(features)

    # Transform the data
    pca_data = model.transform(features)

    return pca_data
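
# Usage sketch (assumes a DataFrame with an array-typed "features" column;
# `features_df` is a hypothetical name):
#   projected = pca(features_df)
#   projected.select("pca_features").show(5, truncate=False)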

if __name__ == "__main__":
    spark = SparkSession.builder.appName("analysis").getOrCreate()

    data = load(spark, split=[0.9, 0.05, 0.05])

    pca(data)
    category_distribution(data)

# EOF
80 changes: 80 additions & 0 deletions etl.py
@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
"""
etl.py
This module contains the ETL (Extract, Transform, Load) pipeline for processing
the spectrogram data and the labels.
"""

from pathlib import Path

import numpy as np
from pyspark.sql import functions, types
from sklearn.preprocessing import OneHotEncoder

# Assumed import: SpectrogramPipe is used below but never imported in this
# diff; the module name here is a guess.
from spectrogram_pipe import SpectrogramPipe

def transform(spark, dataframe, keys):
    """
    One-hot encode the given string columns, keeping the original strings
    in companion "<key>_str" columns.
    """
    dataframe = dataframe.withColumn(
        "index", functions.monotonically_increasing_id()
    )
    # One-hot encode each key column on the driver; each encoded row becomes
    # a plain Python list so Spark can rebuild a DataFrame from it.
    bundle = {
        key: [
            arr.tolist()
            for arr in OneHotEncoder(sparse_output=False)
            .fit_transform(dataframe.select(key).collect())
        ]
        for key in keys
    }

    # Pivot the dict of columns into a list of row dicts.
    bundle = [dict(zip(bundle.keys(), values))
              for values in zip(*bundle.values())]
    schema = types.StructType([
        types.StructField(key, types.ArrayType(types.FloatType()), True)
        for key in keys
    ])
    newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
        "index", functions.monotonically_increasing_id()
    )
    # Keep the original strings under "<key>_str" and join the encoded
    # columns back on the synthetic index.
    for key in keys:
        dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
    dataframe = dataframe.join(newframe, on="index", how="inner")

    return dataframe
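
# Example: after transform(spark, df, ["treatment"]), the "treatment" column
# holds a one-hot vector such as [1.0, 0.0] while "treatment_str" keeps the
# original string label (values here are illustrative).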

def build_dict(df, key):
    """
    Map each one-hot class index (as a string) back to its original label.
    """
    df = df.select(key, f"{key}_str").distinct()

    return df.rdd.map(
        lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
    ).collectAsMap()
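
# Example: build_dict(df, "treatment") might return {"0": "cpv", "1": "pbs"},
# recovering the original label that each argmax index stands for.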

def extract(spark):
    """
    Read the training labels and pipe the spectrogram files into Spark.
    """
    path = Path("/app/workdir")
    labels = []
    with open(path / "train.csv", "r") as file:
        for line in file:
            labels.append(line.strip().split(",")[0])

    pipe = SpectrogramPipe(spark, filetype="matfiles")

    return pipe.spectrogram_pipe(path, labels)
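
def trim(df, key):
    """
    Assumed helper: load() below calls trim(), but this diff neither defines
    nor imports it. A minimal sketch that mirrors how the label columns are
    gathered: collect one column into a dense numpy array.
    """
    return np.array(df.select(key).collect()).squeeze()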

def load(spark, split=[0.99, 0.005, 0.005]):
    """
    Run the full ETL pipeline and return train/validation/test splits plus
    the category lookup dictionaries.
    """
    data = extract(spark)
    # Normalize the raw treatment labels: "virus" becomes "cpv" and the
    # various control labels collapse to "pbs".
    data = data.replace(
        {"virus": "cpv", "cont": "pbs", "control": "pbs", "dld": "pbs"},
        subset="treatment"
    )

    data = transform(spark, data, ["treatment", "target"])
    category_dict = {
        key: build_dict(data, key) for key in ["treatment", "target"]
    }
    splits = data.randomSplit(split, seed=42)
    trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
    trainy, valy, testy = (
        [
            np.array(dset.select("treatment").collect()).squeeze(),
            np.array(dset.select("target").collect()).squeeze()
        ] for dset in splits
    )

    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
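
For reference, a sketch of how a consumer unpacks load()'s return value (names here are illustrative):

    (train_set, validation_set, test_set, categories) = load(spark)
    trainx, trainy = train_set       # spectrogram arrays, [treatment, target] label arrays
    categories["treatment"]          # maps argmax index to label, e.g. {"0": "cpv", "1": "pbs"}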
3 changes: 2 additions & 1 deletion train.py
@@ -25,6 +25,7 @@
from sklearn.preprocessing import OneHotEncoder

from model.model import TimeSeriesTransformer as TSTF
from etl import load

with open("parameters.json", "r") as file:
    params = json.load(file)
@@ -152,7 +153,7 @@ def main():

keys = ["treatment", "target"]
(train_set, validation_set,
test_set, categories) = get_data(spark, split=SPLIT)
test_set, categories) = load(spark, split=SPLIT)

n_classes = [dset.shape[1] for dset in train_set[1]]
model = get_model(train_set[0].shape[1:], n_classes)
