diff --git a/model_eval.py b/model_eval.py new file mode 100644 index 0000000..5db5572 --- /dev/null +++ b/model_eval.py @@ -0,0 +1,29 @@ +# -- coding: utf-8 -*- +""" +This script takes the model metric results stored as csv and generates plots to +visualize the performance based on the choice of hyperparameters. +""" + +import pandas as pd + +from visualize import plot + +def main(): + metric = pd.read_csv("/app/workdir/metrics.csv") + print(metric.head()) + hyperparam_keys = metric.columns.tolist() + metric_keys = ["test_accuracy", "test_loss"] + [hyperparam_keys.remove(key) for key in metric_keys] + + for key in hyperparam_keys: + for metric_key in metric_keys: + plot.lineplot( + data=metric, + x=key, + y=metric_key, + ) + +if __name__ == "__main__": + main() + +# EOF diff --git a/visualize/plot.py b/visualize/plot.py new file mode 100644 index 0000000..da1c722 --- /dev/null +++ b/visualize/plot.py @@ -0,0 +1,18 @@ +# --*- coding: utf-8 -*- + +import matplotlib.pyplot as plt +import seaborn as sns + +def lineplot(data=None, x=None, y=None): + if data is None or x is None or y is None: + raise ValueError("Data, x, and y parameters must be provided.") + + sns.lineplot(data=data, x=x, y=y) + plt.title(f"{y} by {x}") + plt.xlabel(x) + plt.ylabel(y) + plt.savefig(f"/app/workdir/figures/lineplot_{y}_by_{x}.png") + plt.close() + + +# EOF