diff --git a/visualize/visualize.py b/visualize/visualize.py index 5c11e14..fb63894 100644 --- a/visualize/visualize.py +++ b/visualize/visualize.py @@ -8,6 +8,15 @@ import matplotlib.pyplot as plt def visualize_data_distribution(data: DataFrame) -> None: + """ + Visualize the distribution of treatment and target categories in the + dataset. + + Args: + data (DataFrame): The input DataFrame containing treatment and target + columns. + """ + for category in ["treatment", "target"]: select = data.select(category) \ .groupby(category) \ @@ -18,7 +27,8 @@ def visualize_data_distribution(data: DataFrame) -> None: .squeeze()) plt.xlabel("Count") plt.ylabel(category) - plt.savefig(f"{category}_counts.png", bbox_inches="tight") + plt.savefig(f"/app/workdir/figures/{category}_counts.png", + bbox_inches="tight") plt.close() # EOF