diff --git a/model_eval.py b/model_eval.py index 5db5572..6ff0a57 100644 --- a/model_eval.py +++ b/model_eval.py @@ -15,13 +15,18 @@ def main(): metric_keys = ["test_accuracy", "test_loss"] [hyperparam_keys.remove(key) for key in metric_keys] - for key in hyperparam_keys: - for metric_key in metric_keys: - plot.lineplot( - data=metric, - x=key, - y=metric_key, - ) + for x_key in hyperparam_keys: + for hue_key in hyperparam_keys: + if x_key == hue_key: + continue + else: + for metric_key in metric_keys: + plot.lineplot( + data=metric, + x=x_key, + y=metric_key, + hue=hue_key, + ) if __name__ == "__main__": main()