diff --git a/train/encoder_train.py b/train/encoder_train.py
index fae134b..ceab077 100644
--- a/train/encoder_train.py
+++ b/train/encoder_train.py
@@ -11,7 +11,7 @@
 def encoder_workflow(params, shape, n_classes,
                      train_set, validation_set, test_set,
-                     categories, keys):
+                     categories, keys, modelpath):
     model = build_encoder(params, shape, n_classes)
     model = train_encoder(params, model, train_set, validation_set)
     test_predict, test_loss, test_accuracy = test_encoder(
@@ -31,7 +31,8 @@ def encoder_workflow(params, shape, n_classes,
         categories, keys
     )
-
+
+    save_encoder(model, modelpath)
 
 def build_encoder(params, input_shape, n_classes):
     log_level = params["log_level"]
@@ -85,7 +86,8 @@ def test_encoder(params, model, test_set, categories, keys):
     print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}")
     return test_predict, test_loss, test_accuracy
 
-def evaluate_encoder(params, test_predict, test_set, test_loss, test_accuracy, categories, keys):
+def evaluate_encoder(params, test_predict, test_set, test_loss, test_accuracy,
+                     categories, keys):
     params = params["encoder_params"]
     for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
         confusion_matrix(predict, groundtruth, categories[key], key)
@@ -116,4 +118,7 @@ def save_metric(params, test_loss, test_accuracy):
     with open("/app/workdir/metrics.csv", "a") as f:
         f.write(",".join([str(value) for value in metric.values()]) + "\n")
 
+def save_encoder(model, path):
+    model.save(path + "encoder.keras")
+