Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
ECCV2020_Dynamic/e2_model.py
Go to file. This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
287 lines (227 sloc)
10.5 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import keras | |
from keras import regularizers as regs | |
from keras import backend as K | |
from keras.layers import * | |
from keras.optimizers import Adam, SGD | |
from keras.models import Model, load_model | |
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, ReduceLROnPlateau, LambdaCallback | |
from math import log | |
from utils import * | |
class model(object):
    """Two-encoder / one-decoder Keras network with its train and test drivers.

    The network built by :meth:`network` feeds the same input into two
    convolutional encoders (tagged "denoising" and "shift aligning" in the
    code), concatenates their final feature maps and decodes them with a
    stack of transposed convolutions into a single-channel image.

    All configuration is read from the ``args`` object given to the
    constructor.  ``BurstSequence2E``, ``random_init`` and ``save_image`` are
    not defined in this file (presumably they come from ``utils`` -- verify).
    """

    def __init__(self, args):
        """Copy the run configuration from ``args`` onto the instance.

        Args:
            args: Any object exposing the attributes read below (for example
                an ``argparse.Namespace``).
        """
        self.num_patch = args.num_patch
        self.patch_sz = args.patch_sz
        self.burst_sz = args.burst_sz
        self.batch_sz = args.batch_sz
        self.epochs = args.epochs
        self.lr = args.lr
        self.ckpt_dir = args.ckpt_dir
        self.ckpt_file = args.ckpt_file
        self.train_dir = args.train_dir
        self.valid_dir = args.valid_dir
        self.test_dir = args.test_dir
        self.output_dir = args.output_dir
        self.shift_ckpt = args.shift_ckpt
        self.noisy_ckpt = args.noisy_ckpt
        self.jit = args.jit
        self.J = args.J
        self.alpha = args.alpha
        self.seed = args.seed
        # Conv2D kernel size
        self.kernel_size = 3

    def learning_rate(self, epoch):
        """Return the scheduled learning rate for a 0-based ``epoch`` index.

        The base rate ``self.lr`` is divided by a factor that changes every
        5 epochs; once the factor table is exhausted the rate is
        ``self.lr / 400``.
        """
        factor = [1, 1, 1, 1, 2, 2, 2, 2, 5, 5, 5, 5, 10, 10, 10, 10, 20, 20, 20, 20]
        factor += [20, 30, 40, 50, 60, 70, 80, 80, 120, 120, 120, 120, 160, 160, 160, 160]
        step = epoch // 5
        if step < len(factor):
            return self.lr / factor[step]
        return self.lr / 400

    @staticmethod
    def _sum_squared_error(a, b):
        """Return the sum over all elements of ``(a - b) ** 2`` as a backend tensor."""
        return K.sum(K.pow((a - b), 2))

    def calculate_encoder_loss(self, y_true, y_pred):
        """Sum-of-squares loss that is active only in the training phase.

        Outside the training phase the loss evaluates to the constant 0.
        """
        return K.in_train_phase(self._sum_squared_error(y_pred, y_true), K.constant(0.0))

    def model_encoder_loss(self):
        """Return a Keras loss function named ``encoder_loss``."""
        loss_func = self.calculate_encoder_loss

        # The inner function name is what Keras records for this loss, so it
        # must stay ``encoder_loss``.
        def encoder_loss(y_true, y_pred):
            return loss_func(y_true, y_pred)

        return encoder_loss

    def calculate_l2_loss(self, y_true, y_pred):
        """Sum over all elements of the squared difference of the two tensors."""
        return self._sum_squared_error(y_pred, y_true)

    def model_l2_loss(self):
        """Return a Keras loss function named ``l2_loss``."""
        loss_func = self.calculate_l2_loss

        # ``test`` loads checkpoints with custom_objects keyed on 'l2_loss',
        # so the inner function name must not change.
        def l2_loss(y_true, y_pred):
            return loss_func(y_true, y_pred)

        return l2_loss

    def calculate_accuracy(self, y_true, y_pred):
        """Return ``-10 * log10(MSE)``, i.e. PSNR in dB for a peak value of 1."""
        return -10.0 * (1.0 / log(10)) * K.log(K.mean(K.square(y_pred - y_true)))

    def calculate_ssim(self, y_true, y_pred):
        """Return the mean of ``tf.image.ssim`` computed with ``max_val`` 1.0."""
        import tensorflow as tf
        return tf.reduce_mean(tf.image.ssim(y_true, y_pred, 1.0))

    def model_metric(self):
        """Return a Keras metric function named ``psnr_metric``."""
        metric_func = self.calculate_accuracy

        # The checkpoint monitors ('val_outputs_psnr_metric') and the
        # custom_objects in ``test`` depend on this inner name.
        def psnr_metric(y_true, y_pred):
            return metric_func(y_true, y_pred)

        return psnr_metric

    def encoder_block(self, x, num_filters, name=None):
        """Apply one same-padded ReLU ``Conv2D`` layer to ``x``.

        Args:
            x: Input tensor.
            num_filters: Number of convolution filters.
            name: Optional layer name; ``None`` lets Keras pick one.

        Returns:
            The output tensor of the convolution.
        """
        # Conv.
        conv = Conv2D(num_filters, kernel_size=self.kernel_size, activation="relu",
                      padding="same", kernel_initializer="he_normal", name=name)
        return conv(x)

    def decoder_block(self, x, num_filters, xSkip=None):
        """Apply one same-padded ReLU ``Conv2DTranspose`` layer.

        Args:
            x: Input tensor.
            num_filters: Number of deconvolution filters.
            xSkip: Optional skip connection: a single tensor, or a list/tuple
                of tensors.  The skip tensor(s) are concatenated in front of
                ``x`` before the deconvolution.

        Returns:
            The output tensor of the deconvolution.
        """
        # Skip connection
        if xSkip is not None:
            # ``Concatenate`` must be called on a list; ``network`` passes a
            # single tensor, so normalise both forms before merging.
            skips = list(xSkip) if isinstance(xSkip, (list, tuple)) else [xSkip]
            y = Concatenate()(skips + [x])
        else:
            y = x
        # Deconv.
        conv = Conv2DTranspose(num_filters, kernel_size=self.kernel_size, activation="relu",
                               padding="same", kernel_initializer="he_normal")
        return conv(y)

    def final_decoder_block(self, x, num_filters):
        """Apply the linear (no activation) ``Conv2DTranspose`` layer named 'outputs'."""
        # Deconv.
        conv = Conv2DTranspose(num_filters, kernel_size=self.kernel_size, activation=None,
                               padding="same", kernel_initializer="he_normal", name='outputs')
        return conv(x)

    def network(self, input_shape=(None, None, 8)):
        """Build the two-encoder / one-decoder Keras model.

        Args:
            input_shape: Shape of one input sample, without the batch axis.

        Returns:
            A Keras ``Model`` with three outputs, in this order: the last
            layer of the first encoder ("xNet1"), the last layer of the
            second encoder ("xNet2"), and the decoded image ("outputs").
        """
        inputs = Input(shape=input_shape)

        # Denoising sub-network encoder: 15 convolutions with 64 filters,
        # the last one named "xNet1".
        x = inputs
        for _ in range(14):
            x = self.encoder_block(x, 64)
        xNet1 = self.encoder_block(x, 64, name="xNet1")

        # Shift aligning sub-network encoder: 5 stages of 3 convolutions.
        # The output of every stage is kept for the decoder skip connections;
        # the very last convolution is named "xNet2".
        stage_filters = (16, 32, 64, 64, 64)
        convs_per_stage = 3
        stage_outputs = []
        x = inputs
        for stage, num_filters in enumerate(stage_filters):
            for rep in range(convs_per_stage):
                is_last = (stage == len(stage_filters) - 1) and (rep == convs_per_stage - 1)
                x = self.encoder_block(x, num_filters, name="xNet2" if is_last else None)
            stage_outputs.append(x)
        xNet2 = stage_outputs[-1]

        # Concat 2 encoders
        x = Concatenate()([xNet1, xNet2])

        # Decoder: three plain deconvolutions first ...
        for _ in range(3):
            x = self.decoder_block(x, 128)
        # ... then one group per remaining encoder stage (4, 3, 2, 1).  Each
        # group starts with a skip-connected deconvolution followed by plain
        # ones: two for stages 4-2, one for stage 1.
        skips = stage_outputs[-2::-1]
        plain_after_skip = (2, 2, 2, 1)
        for skip, num_plain in zip(skips, plain_after_skip):
            x = self.decoder_block(x, 128, xSkip=skip)
            for _ in range(num_plain):
                x = self.decoder_block(x, 128)
        outputs = self.final_decoder_block(x, 1)

        # Return model
        return Model(inputs=inputs, outputs=[xNet1, xNet2, outputs])

    def _stage_epoch_bounds(self, stage, num_stages):
        """Return ``(initial_epoch, final_epoch)`` for one training stage.

        ``self.epochs`` is split into ``num_stages`` equal slices (integer
        division).  Keras interprets ``epochs`` as the index of the final
        epoch, not as a count, so the end of a slice has to be absolute.
        """
        per_stage = self.epochs // num_stages
        return stage * per_stage, (stage + 1) * per_stage

    def train(self):
        """Train the network in two stages and print the validation scores.

        Stage 1 runs four consecutive slices of ``self.epochs // 4`` epochs
        with decreasing weights on the two encoder losses; stage 2 then
        trains for ``self.epochs`` epochs on the 'outputs' loss only.
        The best models (by 'val_outputs_psnr_metric') are written to
        ``ckpt_dir`` as ``"weighted_" + ckpt_file`` (stage 1) and
        ``ckpt_file`` (stage 2).
        """
        # Prepare training and validation data
        print("[*] Loading training data ...")
        random_init(self.seed)
        train_gen = BurstSequence2E(self.train_dir, self.patch_sz, self.num_patch, self.burst_sz, self.batch_sz,
                                    jit=self.jit, J=self.J, rd_crop=False, noise=True, alpha=self.alpha, read_noise=0.25,
                                    regen_after=5, renoi_after=5, is_train=True, shift_ckpt=self.shift_ckpt, noisy_ckpt=self.noisy_ckpt)
        valid_gen = BurstSequence2E(self.valid_dir, self.patch_sz, self.num_patch, self.burst_sz, self.batch_sz,
                                    jit=self.jit, J=self.J, rd_crop=False, noise=True, alpha=self.alpha, read_noise=0.25,
                                    regen_after=10, renoi_after=10)

        # Prepare callback functions
        print("[*] Preparing callbacks ...")
        filepath_w = os.path.join(self.ckpt_dir, "weighted_" + self.ckpt_file)
        filepath = os.path.join(self.ckpt_dir, self.ckpt_file)
        checkpoint_w = ModelCheckpoint(filepath=filepath_w, monitor='val_outputs_psnr_metric', verbose=1, mode='max', save_best_only=True)
        checkpoint = ModelCheckpoint(filepath=filepath, monitor='val_outputs_psnr_metric', verbose=1, mode='max', save_best_only=True)
        # NOTE(review): LearningRateScheduler sets the rate at the start of
        # every epoch, so it likely overrides any reduction made by
        # ReduceLROnPlateau -- confirm this combination is intended.
        lr_scheduler = LearningRateScheduler(self.learning_rate)
        lr_reducer = ReduceLROnPlateau(factor=0.5, cooldown=0, patience=8, verbose=1, min_lr=1e-8)
        callbacks_w = [checkpoint_w, lr_scheduler, lr_reducer]
        callbacks = [checkpoint, lr_scheduler, lr_reducer]

        # Compile network and train
        print("[*] Start training ...")
        net = self.network(input_shape=(self.patch_sz, self.patch_sz, self.burst_sz))

        # Stage 1: Train with weighted loss
        loss_weights = [10.0, 5.0, 2.0, 1.0]
        num_stages = len(loss_weights)
        for stage, weight in enumerate(loss_weights):
            net.compile(loss={"xNet1": self.model_encoder_loss(), "xNet2": self.model_encoder_loss(), "outputs": self.model_l2_loss()},
                        loss_weights={"xNet1": weight / 2, "xNet2": weight / 2, "outputs": 1.0},
                        optimizer=Adam(lr=self.learning_rate(0)), metrics={"outputs": self.model_metric()})
            # The encoder losses go through K.in_train_phase, so the two
            # encoder outputs are flagged as learning-phase dependent.
            net.outputs[0]._uses_learning_phase = True  # xNet1
            net.outputs[1]._uses_learning_phase = True  # xNet2
            net.summary()
            # ``epochs`` is the absolute index of the last epoch; passing the
            # slice length here would make every stage after the first a
            # no-op because initial_epoch would already have reached it.
            first_epoch, last_epoch = self._stage_epoch_bounds(stage, num_stages)
            net.fit_generator(train_gen,
                              epochs=last_epoch,
                              initial_epoch=first_epoch,
                              validation_data=valid_gen,
                              shuffle=True,
                              callbacks=callbacks_w)

        # Stage 2: Train with constant loss
        net.compile(loss={"outputs": self.model_l2_loss()}, optimizer=Adam(lr=self.learning_rate(0)), metrics={"outputs": self.model_metric()})
        net.fit_generator(train_gen,
                          epochs=self.epochs,
                          initial_epoch=0,
                          validation_data=valid_gen,
                          shuffle=True,
                          callbacks=callbacks)

        # Score trained model
        scores = net.evaluate_generator(valid_gen, verbose=1)
        print('Test loss:', scores[0])
        print('Test accuracy:', scores[2])

    def test(self):
        """Evaluate the saved checkpoint and write sample images.

        Loads ``ckpt_dir/ckpt_file``, evaluates it on the test generator,
        writes the scores to ``output_dir/psnr.txt`` and saves, for up to 16
        samples, the prediction, the ground truth, every burst frame and the
        burst average via ``save_image``.
        """
        # Prepare test data
        print("[*] Loading test data ...")
        random_init(self.seed)
        test_gen = BurstSequence2E(self.test_dir, self.patch_sz, self.num_patch, self.burst_sz, self.batch_sz,
                                   jit=self.jit, J=self.J, rd_crop=False, noise=True, alpha=self.alpha, read_noise=0.25)

        # Load model
        filepath = os.path.join(self.ckpt_dir, self.ckpt_file)
        net = load_model(filepath, custom_objects={'l2_loss': self.model_l2_loss(), 'psnr_metric': self.model_metric()})

        # Test for PSNR
        score = net.evaluate_generator(test_gen, verbose=0)
        with open(os.path.join(self.output_dir, "psnr.txt"), "w") as logfile:
            logfile.write("l2_loss = %.2f \n" % score[0])
            logfile.write("PSNR = %.4f \n" % score[2])

        # Save sample denoised images
        num_samples = 16
        test_gen.batch_sz = num_samples
        images = test_gen.get_images(0)
        # Per the original annotations, ``images`` also carries "x_shift"
        # (clean, dynamic) and "x_noisy" (noisy, static); they are not saved.
        y_test = images["y"]            # clean, static
        x_burst = images["x_burst"]     # noisy, dynamic
        y_pred = net.predict(x_burst, batch_size=1)
        # The model has three outputs; index 2 is the decoded image.  Its
        # last axis has a single channel, so the sum only keeps the shape.
        y_hat = np.sum(y_pred[2], axis=-1, keepdims=True)
        for i in range(min(num_samples, len(y_hat))):
            save_image(y_hat[i], self.output_dir, "img_%d_out.png" % i)
            save_image(y_test[i], self.output_dir, "img_%d_gt.png" % i)
            # Iterate over the configured burst size instead of a fixed 8.
            for j in range(self.burst_sz):
                save_image(np.reshape(x_burst[i, :, :, j], (self.patch_sz, self.patch_sz, 1)), self.output_dir, "img_%d_burst_%d.png" % (i, j))
            save_image(np.sum(x_burst[i], axis=2, keepdims=True) / self.burst_sz, self.output_dir, "img_%d_burstsum.png" % i)