From fe00fc6bea90a3df69c34f3d6a7bf86b067bd251 Mon Sep 17 00:00:00 2001 From: Dawith Lim Date: Tue, 30 Sep 2025 12:55:56 -0400 Subject: [PATCH] Function description added --- visualize/visualize.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/visualize/visualize.py b/visualize/visualize.py index 5c11e14..fb63894 100644 --- a/visualize/visualize.py +++ b/visualize/visualize.py @@ -8,6 +8,15 @@ import matplotlib.pyplot as plt def visualize_data_distribution(data: DataFrame) -> None: + """ + Visualize the distribution of treatment and target categories in the + dataset. + + Args: + data (DataFrame): The input DataFrame containing treatment and target + columns. + """ + for category in ["treatment", "target"]: select = data.select(category) \ .groupby(category) \ @@ -18,7 +27,8 @@ def visualize_data_distribution(data: DataFrame) -> None: .squeeze()) plt.xlabel("Count") plt.ylabel(category) - plt.savefig(f"{category}_counts.png", bbox_inches="tight") + plt.savefig(f"/app/workdir/figures/{category}_counts.png", + bbox_inches="tight") plt.close() # EOF