From 12cfff61dd9e055053f72e5b5dd6d5b3ff1bcf71 Mon Sep 17 00:00:00 2001
From: Dawith
Date: Mon, 29 Sep 2025 01:32:37 -0400
Subject: [PATCH] Split ETL into separate extract/transform/load steps
 instead of rolling everything into load

---
 pipe/etl.py | 49 +++++++++++++++++++++++++++++++------------------
 train.py    | 18 ++++++++++++------
 2 files changed, 43 insertions(+), 24 deletions(-)

diff --git a/pipe/etl.py b/pipe/etl.py
index d8e82cd..4dc3892 100644
--- a/pipe/etl.py
+++ b/pipe/etl.py
@@ -18,6 +18,12 @@ from pipe.pipe import SpectrogramPipe
 
 
 def transform(spark, dataframe, keys):
+    # Normalize treatment labels (replace() returns a new DataFrame).
+    dataframe = dataframe.replace(
+        {"virus": "cpv", "cont": "pbs", "control": "pbs", "dld": "pbs"},
+        subset=["treatment"]
+    )
+
     dataframe = dataframe.withColumn(
         "index", functions.monotonically_increasing_id()
     )
@@ -44,6 +50,11 @@
     return dataframe
 
 def build_dict(df, key):
+    """
+    Return a dictionary mapping each unique value in the `key` column
+    to its corresponding label in the `{key}_str` column.
+    """
+
     df = df.select(key, f"{key}_str").distinct()
 
     return df.rdd.map(
@@ -58,6 +69,11 @@ def trim(dataframe, column):
     return ndarray
 
 def extract(spark):
+    """
+    First step of the ETL pipeline: read the list of .mat files from
+    the train.csv manifest, then pull the spectrogram from each file.
+    """
+
     path = Path("/app/workdir")
     labels = []
     with open(path / "train.csv", "r") as file:
@@ -68,24 +84,7 @@
     return pipe.spectrogram_pipe(path, labels)
 
 
-def load(spark, split=[0.99, 0.005, 0.005]):
-    data = extract(spark)
-    data.select("treatment").replace("virus", "cpv") \
-        .replace("cont", "pbs") \
-        .replace("control", "pbs") \
-        .replace("dld", "pbs").distinct()
-
-    for category in ["treatment", "target"]:
-        select = data.select(category).groupby(category).count()
-        plt.barh(np.array(select.select(category).collect()).squeeze(),
-                 np.array(select.select("count").collect()).astype("float") \
-                     .squeeze())
-        plt.xlabel("Count")
-        plt.ylabel(category)
-        plt.savefig(f"{category}_counts.png", bbox_inches="tight")
-        plt.close()
-    exit()
-    data = transform(spark, data, ["treatment", "target"])
+def load(data, split=[0.99, 0.005, 0.005]):
     category_dict = {
         key: build_dict(data, key) for key in ["treatment", "target"]
     }
@@ -99,3 +98,17 @@
     )
 
     return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
+
+def visualize_data_distribution(data):
+    for category in ["treatment", "target"]:
+        select = data.select(category) \
+                     .groupby(category) \
+                     .count()
+        plt.barh(
+            np.array(select.select(category).collect()).squeeze(),
+            np.array(select.select("count").collect()).astype("float")
+                .squeeze())
+        plt.xlabel("Count")
+        plt.ylabel(category)
+        plt.savefig(f"{category}_counts.png", bbox_inches="tight")
+        plt.close()
diff --git a/train.py b/train.py
index b028e9b..1562f3e 100644
--- a/train.py
+++ b/train.py
@@ -24,7 +24,7 @@ from sklearn.preprocessing import OneHotEncoder
 
 from model.model import TimeSeriesTransformer as TSTF
 
-from pipe.etl import load
+from pipe.etl import extract, transform, load
 
 with open("parameters.json", "r") as file:
     params = json.load(file)
@@ -115,9 +115,11 @@ def get_data(spark, split=[0.99, 0.005, 0.005]):
     }
     splits = data.randomSplit(split, seed=42)
     trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
-    trainy, valy, testy = ([np.array(dset.select("treatment").collect()).squeeze(),
-                            np.array(dset.select("target").collect()).squeeze()]
-                           for dset in splits)
+    trainy, valy, testy = (
+        [np.array(dset.select("treatment").collect()).squeeze(),
+         np.array(dset.select("target").collect()).squeeze()]
+        for dset in splits
+    )
 
     return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
 
@@ -151,8 +153,11 @@ def main():
     spark = SparkSession.builder.appName("train").getOrCreate()
 
     keys = ["treatment", "target"]
+
+    data = extract(spark)
+    data = transform(spark, data, keys)
     (train_set, validation_set,
-     test_set, categories) = load(spark, split=SPLIT)#get_data(spark, split=SPLIT)
+     test_set, categories) = load(data, split=SPLIT)
 
     n_classes = [dset.shape[1] for dset in train_set[1]]
     model = get_model(train_set[0].shape[1:], n_classes)
@@ -181,7 +186,7 @@
                 labels=range(len(categories[key].values())),
                 normalize="pred"
             )
-            plt.imshow(conf_matrix, origin="upper")
+            plt.pcolormesh(conf_matrix, edgecolors="black", linewidth=0.5)
             plt.gca().set_aspect("equal")
             plt.colorbar()
             plt.xticks([int(num) for num in categories[key].keys()],
@@ -198,6 +203,7 @@
                                "true": groundtruth.tolist(),
                                "matrix": conf_matrix.tolist()}
             json.dump(confusion_dict, f)
+        print("Done")
 
         return
 
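---

Usage note (a minimal sketch, not part of the patch): after this change the
three ETL stages compose explicitly at the call site, as main() in train.py
now does. The app name and split ratios below are illustrative assumptions;
extract() still expects /app/workdir/train.csv to list the .mat files.

    from pyspark.sql import SparkSession

    from pipe.etl import extract, transform, load

    spark = SparkSession.builder.appName("train").getOrCreate()

    keys = ["treatment", "target"]
    data = extract(spark)                # read spectrograms listed in train.csv
    data = transform(spark, data, keys)  # normalize labels, add index column
    (train_set, validation_set,
     test_set, categories) = load(data, split=[0.99, 0.005, 0.005])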