Skip to content
Permalink
master
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
@mao114
Latest commit dfd421d Jan 15, 2023 History
1 contributor

Users who have contributed to this file

'''
This file contains the codes for the P2S module and the simulator.
Z. Mao, N. Chimitt, and S. H. Chan, "Accerlerating Atmospheric Turbulence
Simulation via Learned Phase-to-Space Transform", ICCV 2021
Arxiv: https://arxiv.org/abs/2107.11627
Zhiyuan Mao, Nicholas Chimitt, and Stanley H. Chan
Copyright 2021
Purdue University, West Lafayette, IN, USA
'''
import torch, os
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class Simulator(nn.Module):
'''
Class variables:
Dr0 - D/r0, which characterizes turbulence strength. suggested range: (1 ~ 5)
img_size - Size of the image used for simulation. suggested value: (128,256,512,1024)
corr - Correlation strength for PSF. suggested range: (-5 ~ -0.01)
data_path - Path where the model, PSF dictionary, and correlation matrices are stored
device - 'cuda:0' for GPU or 'CPU'
scale - Used to artificially increase or decrase turbulence strength
use_temp - If true, the correlation matrix for tilt will be loaded from S_half-temp.npy.
'''
def __init__(self, Dr0, img_size, corr = -0.1, data_path = './data', device = 'cuda:0', scale=1.0, use_temp = False):
super().__init__()
self.img_size = img_size
self.initial_grid = 16
self.Dr0 = torch.tensor(Dr0)
self.device = torch.device(device)
self.Dr0 = torch.tensor(Dr0).to(self.device,dtype=torch.float32)
self.mapping = _P2S()
self.mapping.load_state_dict(torch.load(os.path.join(data_path,'P2S_model.pt')))
self.dict_psf = np.load(os.path.join(data_path,'dictionary.npy'), allow_pickle = True)
self.mu = torch.tensor(self.dict_psf.item()['mu']).reshape((1,1,33,33)).to(self.device,dtype=torch.float32)
self.dict_psf = torch.tensor(self.dict_psf.item()['dictionary'][:100,:]).reshape((100,1,33,33))
self.dict_psf = self.dict_psf.to(self.device,dtype=torch.float32)
self.R = np.load(os.path.join(data_path,'R-corr_{}.npy'.format(corr)))
self.R = torch.tensor(self.R).to(self.device,dtype=torch.float32)
self.offset = torch.tensor([31,31]).to(self.device,dtype=torch.float32)
if use_temp:
self.S_half = np.load(os.path.join(data_path,'S_half-temp.npy'.format(img_size,Dr0)), allow_pickle=True)
else:
self.S_half = np.load(os.path.join(data_path,'S_half-size_{}-D_r0_{:.4f}.npy'.format(img_size,Dr0)), allow_pickle=True)
self.const = self.S_half.item()['const']
self.S_half = torch.tensor(self.S_half.item()['s_half']).to(self.device,dtype=torch.float32)
xx = torch.arange(0, img_size).view(1,-1).repeat(img_size,1)
yy = torch.arange(0, img_size).view(-1,1).repeat(1,img_size)
xx = xx.view(1,1,img_size,img_size).repeat(1,1,1,1)
yy = yy.view(1,1,img_size,img_size).repeat(1,1,1,1)
self.grid = torch.cat((xx,yy),1).permute(0,2,3,1).to(self.device,dtype=torch.float32)
self.scale=scale
def forward(self, img):
img_pad = F.pad(img.view((-1,1,self.img_size,self.img_size)), (16,16,16,16), mode = 'reflect')
img_mean = F.conv2d(img_pad, self.mu).squeeze()
dict_img = F.conv2d(img_pad, self.dict_psf)
random_ = torch.sqrt(self.Dr0 ** (5 / 3))*torch.randn((self.initial_grid**2*36),1,device=self.device)
zer = torch.matmul(self.R,random_).view(self.initial_grid,self.initial_grid,36).permute(2,0,1).unsqueeze(0)
zer = F.interpolate(zer,size=(self.img_size,self.img_size),mode='bilinear', align_corners=False)
zer = zer * self.scale
weight = self.mapping(zer.squeeze().permute(1,2,0).view(self.img_size**2,-1))
weight = weight.view((self.img_size,self.img_size,100)).permute(2,0,1)# target: (100,512,512)
out = torch.sum(weight.unsqueeze(0)*dict_img,1) + img_mean
pos = torch.fft.irfft2((self.S_half.permute(1, 2, 0).unsqueeze(0) * torch.randn(1, self.img_size,
self.img_size, 2, device=self.device)), s=(self.img_size,self.img_size), dim=(1,2)) * self.const
flow = 2.0*(self.grid+pos) / (self.img_size-1) - 1.0
out = F.grid_sample(out.view((1,-1,self.img_size,self.img_size)), flow, 'bilinear', padding_mode='border', align_corners=False).squeeze()
return out
class _P2S(nn.Module):
def __init__(self, input_dim = 36, output_dim = 100):
super().__init__()
self.fc1 = nn.Linear(input_dim, 100)
self.fc2 = nn.Linear(100, 100)
self.fc3 = nn.Linear(100, 100)
self.out = nn.Linear(100, output_dim)
def forward(self, x):
y = F.relu(self.fc1(x))
y = F.relu(self.fc2(y))
y = F.relu(self.fc2(y))
out = self.out(y)
return out