-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
106 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
''' | ||
This file contains the codes for the P2S module and the simulator. | ||
Z. Mao, N. Chimitt, and S. H. Chan, "Accerlerating Atmospheric Turbulence | ||
Simulation via Learned Phase-to-Space Transform", ICCV 2021 | ||
Arxiv: https://arxiv.org/abs/2107.11627 | ||
Zhiyuan Mao, Nicholas Chimitt, and Stanley H. Chan | ||
Copyright 2021 | ||
Purdue University, West Lafayette, IN, USA | ||
''' | ||
|
||
import torch, os | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import numpy as np | ||
|
||
class Simulator(nn.Module): | ||
''' | ||
Class variables: | ||
Dr0 - D/r0, which characterizes turbulence strength. suggested range: (1 ~ 5) | ||
img_size - Size of the image used for simulation. suggested value: (128,256,512,1024) | ||
corr - Correlation strength for PSF. suggested range: (-5 ~ -0.01) | ||
data_path - Path where the model, PSF dictionary, and correlation matrices are stored | ||
device - 'cuda:0' for GPU or 'CPU' | ||
scale - Used to artificially increase or decrase turbulence strength | ||
use_temp - If true, the correlation matrix for tilt will be loaded from S_half-temp.npy. | ||
''' | ||
def __init__(self, Dr0, img_size, corr = -0.1, data_path = './data', device = 'cuda:0', scale=1.0, use_temp = False): | ||
super().__init__() | ||
self.img_size = img_size | ||
self.initial_grid = 16 | ||
self.Dr0 = torch.tensor(Dr0) | ||
self.device = torch.device(device) | ||
self.Dr0 = torch.tensor(Dr0).to(self.device,dtype=torch.float32) | ||
self.mapping = _P2S() | ||
self.mapping.load_state_dict(torch.load(os.path.join(data_path,'P2S_model.pt'))) | ||
self.dict_psf = np.load(os.path.join(data_path,'dictionary.npy'), allow_pickle = True) | ||
self.mu = torch.tensor(self.dict_psf.item()['mu']).reshape((1,1,33,33)).to(self.device,dtype=torch.float32) | ||
self.dict_psf = torch.tensor(self.dict_psf.item()['dictionary'][:100,:]).reshape((100,1,33,33)) | ||
self.dict_psf = self.dict_psf.to(self.device,dtype=torch.float32) | ||
|
||
self.R = np.load(os.path.join(data_path,'R-corr_{}.npy'.format(corr))) | ||
self.R = torch.tensor(self.R).to(self.device,dtype=torch.float32) | ||
self.offset = torch.tensor([31,31]).to(self.device,dtype=torch.float32) | ||
|
||
if use_temp: | ||
self.S_half = np.load(os.path.join(data_path,'S_half-temp.npy'.format(img_size,Dr0)), allow_pickle=True) | ||
else: | ||
self.S_half = np.load(os.path.join(data_path,'S_half-size_{}-D_r0_{:.4f}.npy'.format(img_size,Dr0)), allow_pickle=True) | ||
self.const = self.S_half.item()['const'] | ||
self.S_half = torch.tensor(self.S_half.item()['s_half']).to(self.device,dtype=torch.float32) | ||
|
||
xx = torch.arange(0, img_size).view(1,-1).repeat(img_size,1) | ||
yy = torch.arange(0, img_size).view(-1,1).repeat(1,img_size) | ||
xx = xx.view(1,1,img_size,img_size).repeat(1,1,1,1) | ||
yy = yy.view(1,1,img_size,img_size).repeat(1,1,1,1) | ||
self.grid = torch.cat((xx,yy),1).permute(0,2,3,1).to(self.device,dtype=torch.float32) | ||
|
||
self.scale=scale | ||
|
||
|
||
def forward(self, img): | ||
img_pad = F.pad(img.view((-1,1,self.img_size,self.img_size)), (16,16,16,16), mode = 'reflect') | ||
img_mean = F.conv2d(img_pad, self.mu).squeeze() | ||
dict_img = F.conv2d(img_pad, self.dict_psf) | ||
random_ = torch.sqrt(self.Dr0 ** (5 / 3))*torch.randn((self.initial_grid**2*36),1,device=self.device) | ||
|
||
zer = torch.matmul(self.R,random_).view(self.initial_grid,self.initial_grid,36).permute(2,0,1).unsqueeze(0) | ||
zer = F.interpolate(zer,size=(self.img_size,self.img_size),mode='bilinear', align_corners=False) | ||
zer = zer * self.scale | ||
weight = self.mapping(zer.squeeze().permute(1,2,0).view(self.img_size**2,-1)) | ||
|
||
weight = weight.view((self.img_size,self.img_size,100)).permute(2,0,1)# target: (100,512,512) | ||
out = torch.sum(weight.unsqueeze(0)*dict_img,1) + img_mean | ||
|
||
pos = torch.fft.irfft2((self.S_half.permute(1, 2, 0).unsqueeze(0) * torch.randn(1, self.img_size, | ||
self.img_size, 2, device=self.device)), s=(self.img_size,self.img_size), dim=(1,2)) * self.const | ||
|
||
flow = 2.0*(self.grid+pos) / (self.img_size-1) - 1.0 | ||
out = F.grid_sample(out.view((1,-1,self.img_size,self.img_size)), flow, 'bilinear', padding_mode='border', align_corners=False).squeeze() | ||
|
||
return out | ||
|
||
class _P2S(nn.Module): | ||
def __init__(self, input_dim = 36, output_dim = 100): | ||
super().__init__() | ||
self.fc1 = nn.Linear(input_dim, 100) | ||
self.fc2 = nn.Linear(100, 100) | ||
self.fc3 = nn.Linear(100, 100) | ||
self.out = nn.Linear(100, output_dim) | ||
|
||
def forward(self, x): | ||
y = F.relu(self.fc1(x)) | ||
y = F.relu(self.fc2(y)) | ||
y = F.relu(self.fc2(y)) | ||
out = self.out(y) | ||
|
||
return out | ||
|
||
|
||
|
||
|
||
|
||
|