import torch
import torch.nn as nn
from networks import Generator, Discriminator, weights_init_normal
from torch.utils.data import DataLoader
from dataloader import ImageFolder
import torchvision.transforms as transforms
import os
from torch.optim import lr_scheduler, Adam
from torch import compile
import time
from PIL import Image
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cuda.matmul.allow_tf32 = True
    NUM_WORKERS = 1
    img_size = 256
    # sets the spatial output dim of the discriminator (one score per receptive-field patch)
    if img_size == 128:
        recp = 14
    elif img_size == 256:
        recp = 30
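    # (assumption: networks.Discriminator is the standard 70x70 PatchGAN with
    # three stride-2 convs followed by two stride-1 convs, so a 256x256 input
    # yields a 30x30 patch map: 256 -> 128 -> 64 -> 32 -> 31 -> 30,
    # and a 128x128 input yields 14x14)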
    # define transforms
    transform = transforms.Compose([
        transforms.Resize(int(img_size * 1.12), Image.BICUBIC),
        transforms.RandomCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
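    # resize to ~1.12x (256 -> 286) and randomly crop back to img_size, the
    # jitter augmentation used in the original CycleGAN training recipe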
    root_dir = 'data'
    sub_dir = 'maps'
    batch_size = 1
    n_block = 9
    n_epochs = n_epochs_decay = 75  # change as required
    total_epochs = n_epochs + n_epochs_decay
    lambda_cyc = 10.0
    lr = 0.0002
    # set dataloaders
    data_X = ImageFolder(os.path.join(root_dir, sub_dir, 'trainA'), transform)
    data_Y = ImageFolder(os.path.join(root_dir, sub_dir, 'trainB'), transform)
    dl_X = DataLoader(data_X, batch_size, shuffle=True,
                      num_workers=NUM_WORKERS, drop_last=True)
    dl_Y = DataLoader(data_Y, batch_size, shuffle=True,
                      num_workers=NUM_WORKERS, drop_last=True)
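    # note: zip() over the two loaders stops at the shorter one, so each epoch
    # iterates min(len(dl_X), len(dl_Y)) unpaired (X, Y) batches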
    # initialize networks
    G = compile(Generator(n_block=n_block).to(device))
    F = compile(Generator(n_block=n_block).to(device))
    D_X = compile(Discriminator().to(device))
    D_Y = compile(Discriminator().to(device))
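    # torch.compile wraps each model in an OptimizedModule; apply()/parameters()
    # still reach the underlying modules, but state_dict() keys gain an
    # '_orig_mod.' prefix, which matters when reloading into uncompiled models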
    # initialize weights to random normal
    G.apply(weights_init_normal)
    F.apply(weights_init_normal)
    D_X.apply(weights_init_normal)
    D_Y.apply(weights_init_normal)
    # criteria for the adversarial and cycle-consistency losses
    crit_GAN = nn.MSELoss()
    crit_cyc = nn.L1Loss()
    # composite optimizer for both generators, separate ones per discriminator
    optim_GENS = Adam(
        params=list(G.parameters()) + list(F.parameters()),
        lr=lr,
        betas=(0.5, 0.999)
    )
    optim_DISC_X = Adam(
        params=D_X.parameters(),
        lr=lr,
        betas=(0.5, 0.999)
    )
    optim_DISC_Y = Adam(
        params=D_Y.parameters(),
        lr=lr,
        betas=(0.5, 0.999)
    )
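    # note: beta1 is lowered to 0.5 (from Adam's default 0.9), matching the
    # original CycleGAN training setup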
    # set learning rate schedulers: lr is constant for n_epochs, then decays
    # linearly towards 0 over the remaining n_epochs_decay epochs
    def lambda_rule(epoch):
        return 1.0 - max(0, 1 + epoch - n_epochs) / float(n_epochs_decay + 1)
    scheduler_GENS = lr_scheduler.LambdaLR(optim_GENS, lr_lambda=lambda_rule)
    scheduler_DISC_X = lr_scheduler.LambdaLR(
        optim_DISC_X, lr_lambda=lambda_rule)
    scheduler_DISC_Y = lr_scheduler.LambdaLR(
        optim_DISC_Y, lr_lambda=lambda_rule)
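    # with n_epochs = n_epochs_decay = 75 the factor stays 1.0 for epochs 0-74,
    # then falls linearly: epoch 75 -> 1 - 1/76, ..., epoch 149 -> 1 - 75/76,
    # so the lr approaches (but never quite reaches) 0 by the final epoch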
    # initialize true and false labels for the discriminators
    true_label = torch.ones((batch_size, 1, recp, recp)).to(device)
    false_label = torch.zeros((batch_size, 1, recp, recp)).to(device)
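    # the target tensors match the (1, recp, recp) PatchGAN output, and MSELoss
    # against them gives the least-squares GAN objective used by CycleGAN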
    # start training
    for epoch in range(total_epochs):
        start = time.time()
        for i, (imgX, imgY) in enumerate(zip(dl_X, dl_Y)):
            r_X = imgX.to(device)
            r_Y = imgY.to(device)
            # G(X) -> Y' ; F(Y) -> X'
            f_Y = G(r_X)
            f_X = F(r_Y)
            # Generator training
            optim_GENS.zero_grad()
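            # freeze the discriminators for the generator step so their grads
            # are not computed needlessly (they are re-enabled via the
            # requires_grad_() calls before each discriminator update below)
            D_X.requires_grad_(False)
            D_Y.requires_grad_(False)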
            # adversarial losses for the generators
            loss_G = crit_GAN(D_Y(f_Y), true_label)
            loss_F = crit_GAN(D_X(f_X), true_label)
            loss_gens = (loss_G + loss_F)
            # cycle consistency loss
            rec_X = F(f_Y)
            rec_Y = G(f_X)
            loss_cyc_X = crit_cyc(rec_X, r_X)
            loss_cyc_Y = crit_cyc(rec_Y, r_Y)
            loss_cyc = (loss_cyc_X + loss_cyc_Y)
            loss_GAN = loss_gens + lambda_cyc * loss_cyc
            loss_GAN.backward()
            optim_GENS.step()
            # Discriminator training
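            # each discriminator scores real images against true_label and
            # detached fakes against false_label; the 0.5 factor halves the
            # objective to slow D relative to G, as in the CycleGAN paper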
            # Discriminator X
            D_X.requires_grad_()
            optim_DISC_X.zero_grad()
            pred_r_X = D_X(r_X)
            loss_D_X_real = crit_GAN(pred_r_X, true_label)
            pred_f_X = D_X(f_X.detach())
            loss_D_X_fake = crit_GAN(pred_f_X, false_label)
            loss_D_X = (loss_D_X_real + loss_D_X_fake) * 0.5
            loss_D_X.backward()
            optim_DISC_X.step()
# Discriminator Y
D_Y.requires_grad_()
optim_DISC_Y.zero_grad()
pred_r_Y = D_Y(r_Y)
loss_D_Y_real = crit_GAN(pred_r_Y, true_label)
pred_f_Y = D_Y(f_Y.detach())
loss_D_Y_fake = crit_GAN(pred_f_Y, false_label)
loss_D_Y = (loss_D_Y_real + loss_D_Y_fake)*0.5
loss_D_Y.backward()
optim_DISC_Y.step()
            loss_D = (loss_D_X + loss_D_Y)
            # print losses
            if i % 500 == 0:
                print(f'[Epoch {epoch+1}/{total_epochs}] [Batch {i+1}/{len(dl_X)}] [D loss : {loss_D.item():.6f}] [G loss : {loss_GAN.item():.6f} - (adv : {loss_gens.item():.6f}, cycle : {loss_cyc.item():.6f})]')
        # update lr schedulers after every epoch
        scheduler_GENS.step()
        scheduler_DISC_X.step()
        scheduler_DISC_Y.step()
        end = time.time()
        print(
            f'[Time taken for epoch {epoch+1}/{total_epochs}: {int(end-start)}s]')
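    # make sure the output directory exists before saving
    os.makedirs(os.path.join('models', sub_dir), exist_ok=True)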
    # Save the trained models
    torch.save(G.state_dict(),
               f'models/{sub_dir}/G_{total_epochs}_{batch_size}.pth')
    torch.save(F.state_dict(),
               f'models/{sub_dir}/F_{total_epochs}_{batch_size}.pth')