Skip to content

Commit

Permalink
Fixed encoder training
Browse files Browse the repository at this point in the history
  • Loading branch information
Dawith committed Mar 2, 2026
1 parent 9156af5 commit e4092b2
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions train/encoder_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

def encoder_workflow(params, shape, n_classes,
train_set, validation_set, test_set,
categories, keys):
categories, keys, modelpath):
model = build_encoder(params, shape, n_classes)
model = train_encoder(params, model, train_set, validation_set)
test_predict, test_loss, test_accuracy = test_encoder(
Expand All @@ -31,7 +31,8 @@ def encoder_workflow(params, shape, n_classes,
categories,
keys
)


save_encoder(model, path)

def build_encoder(params, input_shape, n_classes):
log_level = params["log_level"]
Expand Down Expand Up @@ -85,7 +86,8 @@ def test_encoder(params, model, test_set, categories, keys):
print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}")
return test_predict, test_loss, test_accuracy

def evaluate_encoder(params, test_predict, test_set, test_loss, test_accuracy, categories, keys):
def evaluate_encoder(params, test_predict, test_set, test_loss, test_accuracy,
categories, keys):
params = params["encoder_params"]
for predict, groundtruth, key in zip(test_predict, test_set[1], keys):
confusion_matrix(predict, groundtruth, categories[key], key)
Expand Down Expand Up @@ -116,4 +118,7 @@ def save_metric(params, test_loss, test_accuracy):
with open("/app/workdir/metrics.csv", "a") as f:
f.write(",".join([str(value) for value in metric.values()]) + "\n")

def save_encoder(model, path):
model.save(path + "encoder.keras")

# EOF

0 comments on commit e4092b2

Please sign in to comment.