# os and numpy are used directly below (os.path.join, np.reshape); import
# them explicitly rather than relying on utils' star import.
import os
from math import log

import numpy as np

import keras
from keras import regularizers as regs
from keras import backend as K
from keras.layers import *
from keras.optimizers import Adam, SGD
from keras.models import Model, load_model
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, ReduceLROnPlateau, LambdaCallback

from utils import *

class model(object):
    """Burst denoising model: a denoising encoder and a shift-aligning
    encoder merged into a shared decoder."""
    def __init__(self, args):
        # Data / patch configuration
        self.num_patch = args.num_patch
        self.patch_sz = args.patch_sz
        self.burst_sz = args.burst_sz
        self.batch_sz = args.batch_sz
        # Training configuration
        self.epochs = args.epochs
        self.lr = args.lr
        # Paths and checkpoints
        self.ckpt_dir = args.ckpt_dir
        self.ckpt_file = args.ckpt_file
        self.train_dir = args.train_dir
        self.valid_dir = args.valid_dir
        self.test_dir = args.test_dir
        self.output_dir = args.output_dir
        self.shift_ckpt = args.shift_ckpt
        self.noisy_ckpt = args.noisy_ckpt
        # Burst generation (jitter / noise) parameters
        self.jit = args.jit
        self.J = args.J
        self.alpha = args.alpha
        self.seed = args.seed
        # Conv2D kernel size
        self.kernel_size = 3

    def learning_rate(self, epoch):
        # Staircase decay: the base rate self.lr is divided by a factor that
        # grows every 5 epochs, bottoming out at self.lr / 400.
        factor = [1, 1, 1, 1, 2, 2, 2, 2, 5, 5, 5, 5, 10, 10, 10, 10, 20, 20, 20, 20]
        factor += [20, 30, 40, 50, 60, 70, 80, 80, 120, 120, 120, 120, 160, 160, 160, 160]
        return self.lr / factor[epoch//5] if epoch//5 < len(factor) else self.lr / 400

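    # Example of the schedule above (with a hypothetical base rate
    # self.lr = 1e-3; the actual value comes from args.lr):
    #   epoch   0  ->  1e-3    (factor 1)
    #   epoch  20  ->  5e-4    (factor 2)
    #   epoch  40  ->  2e-4    (factor 5)
    #   epoch  80  ->  5e-5    (factor 20)
    #   epoch 180+ ->  2.5e-6  (floor of self.lr / 400)
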
    def calculate_encoder_loss(self, y_true, y_pred):
        # L2 loss on the intermediate encoder outputs. K.in_train_phase makes
        # it active only during training; at validation/test time it returns
        # a constant 0, so checkpoint selection depends on the final output.
        def l2_distance(a, b):
            return K.sum(K.pow((a - b), 2))
        return K.in_train_phase(l2_distance(y_pred, y_true), K.constant(0.0))

    def model_encoder_loss(self):
        loss_func = self.calculate_encoder_loss
        def encoder_loss(y_true, y_pred):
            return loss_func(y_true, y_pred)
        return encoder_loss

    def calculate_l2_loss(self, y_true, y_pred):
        def l2_distance(a, b):
            return K.sum(K.pow((a - b), 2))
        return l2_distance(y_pred, y_true)

    def model_l2_loss(self):
        loss_func = self.calculate_l2_loss
        def l2_loss(y_true, y_pred):
            return loss_func(y_true, y_pred)
        return l2_loss

    def calculate_accuracy(self, y_true, y_pred):
        # PSNR in dB, assuming pixel values in [0, 1]: -10 * log10(MSE)
        return -10.0 * (1.0/log(10)) * K.log(K.mean(K.square(y_pred - y_true)))

    def calculate_ssim(self, y_true, y_pred):
        import tensorflow as tf  # local import: only needed for tf.image.ssim
        return tf.reduce_mean(tf.image.ssim(y_true, y_pred, 1.0))

    def model_metric(self):
        metric_func = self.calculate_accuracy
        def psnr_metric(y_true, y_pred):
            return metric_func(y_true, y_pred)
        return psnr_metric

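    # Sanity check for the metric above (made-up numbers, not results):
    # with images in [0, 1], an MSE of 0.01 gives
    # PSNR = -10 * log10(0.01) = 20 dB, and an MSE of 0.001 gives 30 dB.
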
    def encoder_block(self, x, num_filters, name=None):
        # Conv + ReLU (name=None lets Keras auto-name the layer)
        conv = Conv2D(num_filters, kernel_size=self.kernel_size, activation="relu",
                      padding="same", kernel_initializer="he_normal", name=name)
        return conv(x)

    def decoder_block(self, x, num_filters, xSkip=None):
        # Skip connection: concatenate the encoder feature map with the input
        # (Concatenate takes a list of tensors, not a single tensor)
        if xSkip is not None:
            y = Concatenate()([xSkip, x])
        else:
            y = x
        # Deconv + ReLU
        conv = Conv2DTranspose(num_filters, kernel_size=self.kernel_size, activation="relu",
                               padding="same", kernel_initializer="he_normal")
        return conv(y)

    def final_decoder_block(self, x, num_filters):
        # Deconv with no activation: produces the denoised output image
        conv = Conv2DTranspose(num_filters, kernel_size=self.kernel_size, activation=None,
                               padding="same", kernel_initializer="he_normal", name='outputs')
        return conv(x)

    def network(self, input_shape=(None, None, 8)):
        inputs = Input(shape=input_shape)
        # Denoising sub-network encoder: 15 stacked Conv blocks, 64 filters each
        x = inputs
        for _ in range(14):
            x = self.encoder_block(x, 64)
        xNet1 = self.encoder_block(x, 64, name="xNet1")
        # Shift-aligning sub-network encoder (the intermediate outputs
        # xN2Conv1..4 feed the decoder's skip connections)
        xN2Conv1 = self.encoder_block(inputs, 16)
        xN2Conv1 = self.encoder_block(xN2Conv1, 16)
        xN2Conv1 = self.encoder_block(xN2Conv1, 16)
        xN2Conv2 = self.encoder_block(xN2Conv1, 32)
        xN2Conv2 = self.encoder_block(xN2Conv2, 32)
        xN2Conv2 = self.encoder_block(xN2Conv2, 32)
        xN2Conv3 = self.encoder_block(xN2Conv2, 64)
        xN2Conv3 = self.encoder_block(xN2Conv3, 64)
        xN2Conv3 = self.encoder_block(xN2Conv3, 64)
        xN2Conv4 = self.encoder_block(xN2Conv3, 64)
        xN2Conv4 = self.encoder_block(xN2Conv4, 64)
        xN2Conv4 = self.encoder_block(xN2Conv4, 64)
        xN2Conv5 = self.encoder_block(xN2Conv4, 64)
        xN2Conv5 = self.encoder_block(xN2Conv5, 64)
        xNet2 = self.encoder_block(xN2Conv5, 64, name="xNet2")
        # Concatenate the two encoders
        xBoth = Concatenate()([xNet1, xNet2])
        # Decoder with skip connections from the shift-aligning encoder
        xDeconv5 = self.decoder_block(xBoth, 128)
        xDeconv5 = self.decoder_block(xDeconv5, 128)
        xDeconv5 = self.decoder_block(xDeconv5, 128)
        xDeconv4 = self.decoder_block(xDeconv5, 128, xSkip=xN2Conv4)
        xDeconv4 = self.decoder_block(xDeconv4, 128)
        xDeconv4 = self.decoder_block(xDeconv4, 128)
        xDeconv3 = self.decoder_block(xDeconv4, 128, xSkip=xN2Conv3)
        xDeconv3 = self.decoder_block(xDeconv3, 128)
        xDeconv3 = self.decoder_block(xDeconv3, 128)
        xDeconv2 = self.decoder_block(xDeconv3, 128, xSkip=xN2Conv2)
        xDeconv2 = self.decoder_block(xDeconv2, 128)
        xDeconv2 = self.decoder_block(xDeconv2, 128)
        xDeconv1 = self.decoder_block(xDeconv2, 128, xSkip=xN2Conv1)
        xDeconv1 = self.decoder_block(xDeconv1, 128)
        outputs = self.final_decoder_block(xDeconv1, 1)
        # The two encoder outputs are exposed so they can be supervised during
        # stage-1 training; "outputs" is the denoised image.
        model = Model(inputs=inputs, outputs=[xNet1, xNet2, outputs])
        return model

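    # Minimal sketch of building and inspecting the network (the 64x64x8
    # shape here is illustrative, standing in for patch_sz x patch_sz x
    # burst_sz; `args` is whatever namespace the caller builds):
    #
    #   net = model(args).network(input_shape=(64, 64, 8))
    #   net.summary()
    #   # net.outputs -> [xNet1, xNet2, outputs]; "outputs" is the
    #   # (None, 64, 64, 1) denoised image.
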
    def train(self):
        # Prepare training and validation data
        print("[*] Loading training data ...")
        random_init(self.seed)
        train_gen = BurstSequence2E(self.train_dir, self.patch_sz, self.num_patch, self.burst_sz, self.batch_sz,
                                    jit=self.jit, J=self.J, rd_crop=False, noise=True, alpha=self.alpha, read_noise=0.25,
                                    regen_after=5, renoi_after=5, is_train=True, shift_ckpt=self.shift_ckpt, noisy_ckpt=self.noisy_ckpt)
        valid_gen = BurstSequence2E(self.valid_dir, self.patch_sz, self.num_patch, self.burst_sz, self.batch_sz,
                                    jit=self.jit, J=self.J, rd_crop=False, noise=True, alpha=self.alpha, read_noise=0.25,
                                    regen_after=10, renoi_after=10)
        # Prepare callback functions
        print("[*] Preparing callbacks ...")
        filepath_w = os.path.join(self.ckpt_dir, "weighted_" + self.ckpt_file)
        filepath = os.path.join(self.ckpt_dir, self.ckpt_file)
        checkpoint_w = ModelCheckpoint(filepath=filepath_w, monitor='val_outputs_psnr_metric', verbose=1, mode='max', save_best_only=True)
        checkpoint = ModelCheckpoint(filepath=filepath, monitor='val_outputs_psnr_metric', verbose=1, mode='max', save_best_only=True)
        lr_scheduler = LearningRateScheduler(self.learning_rate)
        lr_reducer = ReduceLROnPlateau(factor=0.5, cooldown=0, patience=8, verbose=1, min_lr=1e-8)
        callbacks_w = [checkpoint_w, lr_scheduler, lr_reducer]
        callbacks = [checkpoint, lr_scheduler, lr_reducer]
        # Compile network and train
        print("[*] Start training ...")
        model = self.network(input_shape=(self.patch_sz, self.patch_sz, self.burst_sz))
        # Stage 1: train with weighted encoder losses; the encoder weight
        # decays every quarter of the training run
        loss_weights = [10.0, 5.0, 2.0, 1.0]
        for i in range(4):
            model.compile(loss={"xNet1": self.model_encoder_loss(), "xNet2": self.model_encoder_loss(), "outputs": self.model_l2_loss()},
                          loss_weights={"xNet1": loss_weights[i]/2, "xNet2": loss_weights[i]/2, "outputs": 1.0},
                          optimizer=Adam(lr=self.learning_rate(0)), metrics={"outputs": self.model_metric()})
            # Propagate the learning-phase flag so K.in_train_phase works
            model.outputs[0]._uses_learning_phase = True  # xNet1
            model.outputs[1]._uses_learning_phase = True  # xNet2
            model.summary()
            # With initial_epoch set, "epochs" is the final epoch index, so
            # each pass must end at (i+1) quarters of self.epochs
            history = model.fit_generator(train_gen,
                                          epochs=(i+1)*(self.epochs//4),
                                          initial_epoch=i*(self.epochs//4),
                                          validation_data=valid_gen,
                                          shuffle=True,
                                          callbacks=callbacks_w)
        # Stage 2: train with the output L2 loss only (no encoder supervision)
        model.compile(loss={"outputs": self.model_l2_loss()}, optimizer=Adam(lr=self.learning_rate(0)), metrics={"outputs": self.model_metric()})
        history = model.fit_generator(train_gen,
                                      epochs=self.epochs,
                                      initial_epoch=0,
                                      validation_data=valid_gen,
                                      shuffle=True,
                                      callbacks=callbacks)
        # Score the trained model on the validation set
        scores = model.evaluate_generator(valid_gen, verbose=1)
        print('Validation loss:', scores[0])
        print('Validation PSNR:', scores[2])

    def test(self):
        # Prepare test data
        print("[*] Loading test data ...")
        random_init(self.seed)
        test_gen = BurstSequence2E(self.test_dir, self.patch_sz, self.num_patch, self.burst_sz, self.batch_sz,
                                   jit=self.jit, J=self.J, rd_crop=False, noise=True, alpha=self.alpha, read_noise=0.25)
        # Load model
        filepath = os.path.join(self.ckpt_dir, self.ckpt_file)
        model = load_model(filepath, custom_objects={'l2_loss': self.model_l2_loss(), 'psnr_metric': self.model_metric()})
        # Evaluate and log PSNR
        score = model.evaluate_generator(test_gen, verbose=0)
        with open(self.output_dir + "/psnr.txt", "w") as logfile:
            logfile.write("l2_loss = %.2f \n" % score[0])
            logfile.write("PSNR = %.4f \n" % score[2])
        # Save sample denoised images
        test_gen.batch_sz = 16
        images = test_gen.get_images(0)
        y_test = images["y"]          # clean, static
        x_shift = images["x_shift"]   # clean, dynamic
        x_noisy = images["x_noisy"]   # noisy, static
        x_burst = images["x_burst"]   # noisy, dynamic
        y_pred = model.predict(x_burst, batch_size=1)
        y_pred = y_pred[2]  # third model output: the denoised image
        y_hat = K.sum(y_pred, axis=-1, keepdims=True)
        y_hat = K.eval(y_hat)
        for i in range(16):
            save_image(y_hat[i], self.output_dir, "img_%d_out.png" % i)
            save_image(y_test[i], self.output_dir, "img_%d_gt.png" % i)
            for j in range(self.burst_sz):
                #save_image(np.reshape(x_shift[i,:,:,j], (64,64,1)), self.output_dir, "img_%d_shift_%d.png" % (i, j))
                #save_image(np.reshape(x_noisy[i,:,:,j], (64,64,1)), self.output_dir, "img_%d_noisy_%d.png" % (i, j))
                save_image(np.reshape(x_burst[i,:,:,j], (self.patch_sz, self.patch_sz, 1)), self.output_dir, "img_%d_burst_%d.png" % (i, j))
            save_image(np.sum(x_burst[i], axis=2, keepdims=True)/self.burst_sz, self.output_dir, "img_%d_burstsum.png" % i)
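
# Minimal driver sketch (hypothetical; the repo presumably builds `args`
# elsewhere, e.g. with argparse -- the option names below mirror the
# attributes read in __init__ and are otherwise assumptions):
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--patch_sz", type=int, default=64)
#   parser.add_argument("--burst_sz", type=int, default=8)
#   ...  # num_patch, batch_sz, epochs, lr, ckpt_dir, ckpt_file, dirs, etc.
#   args = parser.parse_args()
#   m = model(args)
#   m.train()  # two-stage training with checkpointing
#   m.test()   # PSNR log and sample image dumps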