diff --git a/visualize/visualize.py b/visualize/visualize.py new file mode 100644 index 0000000..5c11e14 --- /dev/null +++ b/visualize/visualize.py @@ -0,0 +1,24 @@ +#-*- coding: utf-8 -*- +""" +visualize.py +""" + +import typing + +import matplotlib.pyplot as plt + +def visualize_data_distribution(data: DataFrame) -> None: + for category in ["treatment", "target"]: + select = data.select(category) \ + .groupby(category) \ + .count() + plt.barh( + np.array(select.select(category).collect()).squeeze(), + np.array(select.select("count").collect()).astype("float") \ + .squeeze()) + plt.xlabel("Count") + plt.ylabel(category) + plt.savefig(f"{category}_counts.png", bbox_inches="tight") + plt.close() + +# EOF