ETL split into separate steps instead of all being rolled into load
Dawith committed Sep 29, 2025
1 parent 79ee5da commit 12cfff6
Showing 2 changed files with 43 additions and 24 deletions.
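For orientation, a minimal sketch of the new call order under this refactor. It assumes the repo's pipe package is importable and that extract() can find its train.csv under /app/workdir, as the diff below shows; everything else is illustrative.

# Hedged sketch of the refactored flow: extract -> transform -> load.
from pyspark.sql import SparkSession

from pipe.etl import extract, transform, load

spark = SparkSession.builder.appName("etl-demo").getOrCreate()

data = extract(spark)                                   # read spectrograms listed in train.csv
data = transform(spark, data, ["treatment", "target"])  # normalize labels, add index columns
train_set, val_set, test_set, categories = load(data, split=[0.99, 0.005, 0.005])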
49 changes: 31 additions & 18 deletions pipe/etl.py
@@ -18,6 +18,11 @@
 from pipe.pipe import SpectrogramPipe
 
 def transform(spark, dataframe, keys):
+    # Collapse synonymous treatment labels onto canonical values.
+    dataframe = dataframe.replace(
+        {"virus": "cpv", "cont": "pbs", "control": "pbs", "dld": "pbs"},
+        subset=["treatment"])
+
     dataframe = dataframe.withColumn(
         "index", functions.monotonically_increasing_id()
     )
@@ -44,6 +49,11 @@ def transform(spark, dataframe, keys)
     return dataframe
 
 def build_dict(df, key):
+    """
+    Takes a dataframe as input and returns a dictionary of unique values
+    in the column corresponding to the key.
+    """
+
     df = df.select(key, f"{key}_str").distinct()
 
     return df.rdd.map(
@@ -58,6 +68,12 @@ def trim(dataframe, column):
     return ndarray
 
 def extract(spark):
+    """
+    First step of the ETL pipeline. Reads the list of .mat files from a
+    CSV index, then opens each file and pulls out its spectrogram.
+    """
+
     path = Path("/app/workdir")
     labels = []
     with open(path / "train.csv", "r") as file:
@@ -68,24 +84,7 @@ def extract(spark)
 
     return pipe.spectrogram_pipe(path, labels)
 
-def load(spark, split=[0.99, 0.005, 0.005]):
-    data = extract(spark)
-    data.select("treatment").replace("virus", "cpv") \
-        .replace("cont", "pbs") \
-        .replace("control", "pbs") \
-        .replace("dld", "pbs").distinct()
-
-    for category in ["treatment", "target"]:
-        select = data.select(category).groupby(category).count()
-        plt.barh(np.array(select.select(category).collect()).squeeze(),
-                 np.array(select.select("count").collect()).astype("float") \
-                     .squeeze())
-        plt.xlabel("Count")
-        plt.ylabel(category)
-        plt.savefig(f"{category}_counts.png", bbox_inches="tight")
-        plt.close()
-    exit()
-    data = transform(spark, data, ["treatment", "target"])
+def load(data, split=[0.99, 0.005, 0.005]):
     category_dict = {
         key: build_dict(data, key) for key in ["treatment", "target"]
     }
@@ -99,3 +98,17 @@ def load(spark, split=[0.99, 0.005, 0.005]):
     )
 
     return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
+
+def visualize_data_distribution(data):
+    for category in ["treatment", "target"]:
+        select = data.select(category) \
+                     .groupby(category) \
+                     .count()
+        plt.barh(
+            np.array(select.select(category).collect()).squeeze(),
+            np.array(select.select("count").collect()).astype("float") \
+                .squeeze())
+        plt.xlabel("Count")
+        plt.ylabel(category)
+        plt.savefig(f"{category}_counts.png", bbox_inches="tight")
+        plt.close()
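A quick, self-contained check of the label normalization that transform() now performs. The toy rows and session setup here are illustrative; only the replacement mapping comes from the diff.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("check").getOrCreate()
df = spark.createDataFrame(
    [("virus",), ("cont",), ("control",), ("dld",), ("cpv",)],
    ["treatment"],
)
# Dict form of DataFrame.replace: value is omitted, subset limits the column.
df = df.replace(
    {"virus": "cpv", "cont": "pbs", "control": "pbs", "dld": "pbs"},
    subset=["treatment"],
)
print(sorted(row["treatment"] for row in df.distinct().collect()))  # ['cpv', 'pbs']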
18 changes: 12 additions & 6 deletions train.py
@@ -24,7 +24,7 @@
 from sklearn.preprocessing import OneHotEncoder
 
 from model.model import TimeSeriesTransformer as TSTF
-from pipe.etl import load
+from pipe.etl import extract, transform, load
 
 with open("parameters.json", "r") as file:
     params = json.load(file)
@@ -115,9 +115,11 @@ def get_data(spark, split=[0.99, 0.005, 0.005]):
     }
     splits = data.randomSplit(split, seed=42)
     trainx, valx, testx = (trim(dset, "spectrogram") for dset in splits)
-    trainy, valy, testy = ([np.array(dset.select("treatment").collect()).squeeze(),
-                            np.array(dset.select("target").collect()).squeeze()]
-                           for dset in splits)
+    trainy, valy, testy = (
+        [np.array(dset.select("treatment").collect()).squeeze(),
+         np.array(dset.select("target").collect()).squeeze()]
+        for dset in splits
+    )
 
 
     return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
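The reformatted assignment above leans on tuple-unpacking a generator expression; a tiny illustration of the idiom, detached from Spark:

# The three targets consume exactly three yielded items; a length
# mismatch would raise ValueError at unpack time.
splits = ["train", "val", "test"]
trainy, valy, testy = (name.upper() for name in splits)
print(trainy, valy, testy)  # TRAIN VAL TEST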
@@ -151,8 +153,11 @@ def main():
     spark = SparkSession.builder.appName("train").getOrCreate()
 
     keys = ["treatment", "target"]
+
+    data = extract(spark)
+    data = transform(spark, data, keys)
     (train_set, validation_set,
-     test_set, categories) = load(spark, split=SPLIT)#get_data(spark, split=SPLIT)
+     test_set, categories) = load(data, split=SPLIT)
 
     n_classes = [dset.shape[1] for dset in train_set[1]]
     model = get_model(train_set[0].shape[1:], n_classes)
@@ -181,7 +186,7 @@
             labels=range(len(categories[key].values())),
             normalize="pred"
         )
-        plt.imshow(conf_matrix, origin="upper")
+        plt.pcolormesh(conf_matrix, edgecolors="black", linewidth=0.5)
         plt.gca().set_aspect("equal")
         plt.colorbar()
         plt.xticks([int(num) for num in categories[key].keys()],
@@ -198,6 +203,7 @@ def main():
                 "true": groundtruth.tolist(),
                 "matrix": conf_matrix.tolist()}
             json.dump(confusion_dict, f)
+    print("Done")
 
     return

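One behavioral note on the imshow -> pcolormesh swap above: pcolormesh draws row 0 at the bottom, while the replaced imshow(origin="upper") drew it at the top. A hedged sketch with dummy data; invert_yaxis() restores the old orientation if that is wanted.

import numpy as np
import matplotlib.pyplot as plt

conf_matrix = np.array([[0.9, 0.1], [0.2, 0.8]])  # dummy values for illustration
plt.pcolormesh(conf_matrix, edgecolors="black", linewidth=0.5)
plt.gca().set_aspect("equal")
plt.gca().invert_yaxis()  # match the row order of imshow(origin="upper")
plt.colorbar()
plt.savefig("confusion_example.png", bbox_inches="tight")
plt.close()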
