diff --git a/visualize/plot.py b/visualize/plot.py index da1c722..065c90c 100644 --- a/visualize/plot.py +++ b/visualize/plot.py @@ -3,15 +3,20 @@ import matplotlib.pyplot as plt import seaborn as sns -def lineplot(data=None, x=None, y=None): +def lineplot(data=None, x=None, y=None, hue=None): if data is None or x is None or y is None: raise ValueError("Data, x, and y parameters must be provided.") - sns.lineplot(data=data, x=x, y=y) - plt.title(f"{y} by {x}") + sns.lineplot(data=data, x=x, y=y, hue=hue) plt.xlabel(x) plt.ylabel(y) - plt.savefig(f"/app/workdir/figures/lineplot_{y}_by_{x}.png") + plt.legend() + if hue is not None: + plt.title(f"{y} by {x} and {hue}") + plt.savefig(f"/app/workdir/figures/lineplot_{y}_by_{x}_{hue}.png") + else: + plt.title(f"{y} by {x}") + plt.savefig(f"/app/workdir/figures/lineplot_{y}_by_{x}.png") plt.close()