-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
333 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
''' | ||
Demo code for imaging through turbulence simulation | ||
Z. Mao, N. Chimitt, and S. H. Chan, "Accerlerating Atmospheric Turbulence | ||
Simulation via Learned Phase-to-Space Transform", ICCV 2021 | ||
Arxiv: https://arxiv.org/abs/2107.11627 | ||
Zhiyuan Mao, Nicholas Chimitt, and Stanley H. Chan | ||
Copyright 2021 | ||
Purdue University, West Lafayette, IN, USA | ||
''' | ||
|
||
from simulator import Simulator | ||
from turbStats import tilt_mat, corr_mat | ||
import matplotlib.pyplot as plt | ||
import torch | ||
|
||
# Select device. | ||
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('CPU') | ||
|
||
|
||
''' | ||
The corr_mat function is used to generate spatial-temporal correlation matrix | ||
for point spread functions. It may take over 10 minutes to finish. However, | ||
for each correlation value, it only needs to be computed once and can be | ||
used for all D/r0 values. You can also download the pre-generated correlation | ||
matrix from our website. | ||
https://engineering.purdue.edu/ChanGroup/project_turbulence.html | ||
''' | ||
|
||
# Uncomment the following line to generate correlation matrix | ||
# corr_mat(-0.1,'./data/') | ||
|
||
# Generate correlation matrix for tilt. Do this once for each different turbulence parameter. | ||
tilt_mat(x.shape[1], 0.1, 0.05, 3000) | ||
|
||
# Load image, permute axis if color | ||
x = plt.imread('./images/color.png') | ||
if len(x.shape) == 3: | ||
x = x.transpose((2,0,1)) | ||
x = torch.tensor(x, device = device, dtype=torch.float32) | ||
|
||
# Simulate | ||
simulator = Simulator(2, 512).to(device, dtype=torch.float32) | ||
|
||
out = simulator(x).detach().cpu().numpy() | ||
|
||
if len(out.shape) == 3: | ||
out = out.transpose((1,2,0)) | ||
|
||
# save image | ||
plt.imsave('./images/out.png',out) | ||
|
||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
''' | ||
This file contains the codes for the P2S module and the simulator. | ||
Z. Mao, N. Chimitt, and S. H. Chan, "Accerlerating Atmospheric Turbulence | ||
Simulation via Learned Phase-to-Space Transform", ICCV 2021 | ||
Arxiv: https://arxiv.org/abs/2107.11627 | ||
Zhiyuan Mao, Nicholas Chimitt, and Stanley H. Chan | ||
Copyright 2021 | ||
Purdue University, West Lafayette, IN, USA | ||
''' | ||
|
||
import torch, os | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import numpy as np | ||
|
||
class Simulator(nn.Module): | ||
''' | ||
Class variables: | ||
Dr0 - D/r0, which characterizes turbulence strength. suggested range: (1 ~ 5) | ||
img_size - Size of the image used for simulation. suggested value: (128,256,512,1024) | ||
corr - Correlation strength for PSF. suggested range: (-5 ~ -0.01) | ||
data_path - Path where the model, PSF dictionary, and correlation matrices are stored | ||
device - 'cuda:0' for GPU or 'CPU' | ||
scale - Used to artificially increase or decrase turbulence strength | ||
use_temp - If true, the correlation matrix for tilt will be loaded from S_half-temp.npy. | ||
''' | ||
def __init__(self, Dr0, img_size, corr = -0.1, data_path = './data', device = 'cuda:0', scale=1.0, use_temp = False): | ||
super().__init__() | ||
self.img_size = img_size | ||
self.initial_grid = 16 | ||
self.Dr0 = torch.tensor(Dr0) | ||
self.device = torch.device(device) | ||
self.Dr0 = torch.tensor(Dr0).to(self.device,dtype=torch.float32) | ||
self.mapping = _P2S() | ||
self.mapping.load_state_dict(torch.load(os.path.join(data_path,'P2S_model.pt'))) | ||
self.dict_psf = np.load(os.path.join(data_path,'dictionary.npy'), allow_pickle = True) | ||
self.mu = torch.tensor(self.dict_psf.item()['mu']).reshape((1,1,33,33)).to(self.device,dtype=torch.float32) | ||
self.dict_psf = torch.tensor(self.dict_psf.item()['dictionary'][:100,:]).reshape((100,1,33,33)) | ||
self.dict_psf = self.dict_psf.to(self.device,dtype=torch.float32) | ||
|
||
self.R = np.load(os.path.join(data_path,'R-corr_{}.npy'.format(corr))) | ||
self.R = torch.tensor(self.R).to(self.device,dtype=torch.float32) | ||
self.offset = torch.tensor([31,31]).to(self.device,dtype=torch.float32) | ||
|
||
if use_temp: | ||
self.S_half = np.load(os.path.join(data_path,'S_half-temp.npy'.format(img_size,Dr0)), allow_pickle=True) | ||
else: | ||
self.S_half = np.load(os.path.join(data_path,'S_half-size_{}-D_r0_{:.4f}.npy'.format(img_size,Dr0)), allow_pickle=True) | ||
self.const = self.S_half.item()['const'] | ||
self.S_half = torch.tensor(self.S_half.item()['s_half']).to(self.device,dtype=torch.float32) | ||
|
||
xx = torch.arange(0, img_size).view(1,-1).repeat(img_size,1) | ||
yy = torch.arange(0, img_size).view(-1,1).repeat(1,img_size) | ||
xx = xx.view(1,1,img_size,img_size).repeat(1,1,1,1) | ||
yy = yy.view(1,1,img_size,img_size).repeat(1,1,1,1) | ||
self.grid = torch.cat((xx,yy),1).permute(0,2,3,1).to(self.device,dtype=torch.float32) | ||
|
||
self.scale=scale | ||
|
||
|
||
def forward(self, img): | ||
img_pad = F.pad(img.view((-1,1,self.img_size,self.img_size)), (16,16,16,16), mode = 'reflect') | ||
img_mean = F.conv2d(img_pad, self.mu).squeeze() | ||
dict_img = F.conv2d(img_pad, self.dict_psf) | ||
random_ = torch.sqrt(self.Dr0 ** (5 / 3))*torch.randn((self.initial_grid**2*36),1,device=self.device) | ||
|
||
zer = torch.matmul(self.R,random_).view(self.initial_grid,self.initial_grid,36).permute(2,0,1).unsqueeze(0) | ||
zer = F.interpolate(zer,size=(self.img_size,self.img_size),mode='bilinear', align_corners=False) | ||
zer = zer * self.scale | ||
weight = self.mapping(zer.squeeze().permute(1,2,0).view(self.img_size**2,-1)) | ||
|
||
weight = weight.view((self.img_size,self.img_size,100)).permute(2,0,1)# target: (100,512,512) | ||
out = torch.sum(weight.unsqueeze(0)*dict_img,1) + img_mean | ||
|
||
MVx = torch.ifft((self.S_half*torch.randn(self.img_size,self.img_size,device=self.device)).permute(1,2,0),2) | ||
MVy = torch.ifft((self.S_half*torch.randn(self.img_size,self.img_size,device=self.device)).permute(1,2,0),2) | ||
pos = torch.stack((MVx[:,:,0],MVy[:,:,1]),2) * self.const | ||
flow = self.grid+pos | ||
flow = 2.0*flow / (self.img_size-1) - 1.0 | ||
out = F.grid_sample(out.view((1,-1,self.img_size,self.img_size)), flow, 'bilinear', padding_mode='border', align_corners=False).squeeze() | ||
|
||
return out | ||
|
||
class _P2S(nn.Module): | ||
def __init__(self, input_dim = 36, output_dim = 100): | ||
super().__init__() | ||
self.fc1 = nn.Linear(input_dim, 100) | ||
self.fc2 = nn.Linear(100, 100) | ||
self.fc3 = nn.Linear(100, 100) | ||
self.out = nn.Linear(100, output_dim) | ||
|
||
def forward(self, x): | ||
y = F.relu(self.fc1(x)) | ||
y = F.relu(self.fc2(y)) | ||
y = F.relu(self.fc2(y)) | ||
out = self.out(y) | ||
|
||
return out | ||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
''' | ||
This file contains the codes for generating correlation matrices for tilt | ||
and higher-order abberations. | ||
Z. Mao, N. Chimitt, and S. H. Chan, "Accerlerating Atmospheric Turbulence | ||
Simulation via Learned Phase-to-Space Transform", ICCV 2021 | ||
Arxiv: https://arxiv.org/abs/2107.11627 | ||
Zhiyuan Mao, Nicholas Chimitt, and Stanley H. Chan | ||
Copyright 2021 | ||
Purdue University, West Lafayette, IN, USA | ||
''' | ||
|
||
import math, os | ||
import numpy as np | ||
from scipy.special import jv | ||
import scipy.integrate as integrate | ||
|
||
def corr_mat(corr,save_path = './data'): | ||
''' | ||
This function generates the correlation matrix for point spread functions (higher-order abberations) | ||
The correlation matrix will be stored in specified path with format 'R-corr_{}.npy'.format(corr) | ||
Input: | ||
corr - Correlation strength. suggested range: (-5 ~ -0.01), with -5 has the | ||
weakest correlation and -0.01 has the strongest. | ||
save_path - Path to save the correlation matrix | ||
''' | ||
# corr: [-0.01, -0.1, -1, -5] | ||
num_zern=36 | ||
N_rows=16 | ||
N_cols=16 | ||
|
||
subC = _nollCovMat(num_zern, 1, 1) | ||
C = np.zeros((int(N_rows*N_cols*num_zern), int(N_rows*N_cols*num_zern))) | ||
|
||
dist = np.zeros((N_rows,N_cols)) | ||
for i in range(N_rows): | ||
for j in range(N_cols): | ||
for ii in range(N_rows): | ||
for jj in range(N_cols): | ||
if not (i == ii and j == jj): | ||
dist[ii,jj] = np.exp(corr*(np.linalg.norm(i - ii) + np.linalg.norm(j - jj))) | ||
else: | ||
dist[ii,jj] = 1 | ||
C[num_zern * (N_cols * i + j): num_zern * (N_cols * i + j + 1) \ | ||
, num_zern * (N_cols * ii + jj): num_zern * (N_cols * ii + jj + 1)] = dist[ii, jj] * subC | ||
|
||
e_val, e_vec = np.linalg.eig(C) | ||
|
||
R = np.real(e_vec * np.sqrt(e_val)) | ||
np.save(os.path.join(save_path,'R-corr_{}.npy'.format(corr)),R) | ||
|
||
|
||
def tilt_mat(N, D, r0, L, save_path = './data', thre = 0.002, adj = 1, use_temp = False): | ||
''' | ||
This function generates the correlation matrix for tilt | ||
The correlation matrix will be stored in specified path | ||
with format 'S_half-size_{}-D_r0_{:.4f}.npy'.format(N,D_r0) | ||
Input: | ||
N - Image size. suggested value: (128,256,512,1024) | ||
D - Apeture diameter | ||
r0 - Fried parameter | ||
L - Propogation distance | ||
save_path - Path to save the correlation matrix | ||
thre - Used to suppress small valus in the correlation matrix. Increase | ||
this threshold if the pixel displacement appears to be scattering | ||
use_temp - If true, the correlation matrix will be stored in S_half-temp.npy. | ||
''' | ||
# N: image size | ||
# D: Apeture diameter | ||
# r0: Fried parameter | ||
# L: Propagation distance | ||
D_r0 = D/r0 | ||
wavelength = 0.500e-6 | ||
k = 2*np.pi/wavelength | ||
delta0 = L*wavelength/(2*D) | ||
delta0 *= adj# Adjusting factor | ||
c1 = 2*((24/5)*math.gamma(6/5))**(5/6); | ||
c2 = 4*c1/np.pi*(math.gamma(11/6))**2; | ||
smax = delta0/D*N | ||
spacing = delta0/D | ||
I0_arr, I2_arr = _calculate_integral(smax, spacing) | ||
|
||
i, j = np.int32(N/2), np.int32(N/2) | ||
[x,y] = np.meshgrid(np.arange(1,N+0.01,1),np.arange(1,N+0.01,1)) | ||
s = np.sqrt((x-i)**2 + (y-j)**2) | ||
s *= spacing | ||
|
||
C0 = (_In_m(s, spacing, I0_arr) + _In_m(s, spacing, I2_arr))/_I0(0) | ||
C0[i,j] = 1 | ||
C0_scaled = C0*_I0(0)*c2*((D_r0)**(5/3))/(2**(5/3))*((2*wavelength/(np.pi*D))**2)*2*np.pi | ||
Cfft = np.fft.fft2(C0_scaled) | ||
S_half = np.sqrt(Cfft) | ||
S_half_max = np.max(np.max(np.abs(S_half))) | ||
S_half[np.abs(S_half) < thre*S_half_max] = 0 | ||
S_half_new = np.zeros((2,N,N)) | ||
S_half_new[0] = np.real(S_half) | ||
S_half_new[1] = np.imag(S_half) | ||
data = {} | ||
data['s_half'] = S_half_new | ||
data['const'] = np.sqrt(2)*N*(L/delta0) | ||
|
||
if use_temp: | ||
np.save(os.path.join(save_path,'S_half-temp.npy'.format(N,D_r0)),data) | ||
else: | ||
np.save(os.path.join(save_path,'S_half-size_{}-D_r0_{:.4f}.npy'.format(N,D_r0)),data) | ||
|
||
def _nollToZernInd(j): | ||
""" | ||
Authors: Tim van Werkhoven, Jason Saredy | ||
See: https://github.com/tvwerkhoven/libtim-py/blob/master/libtim/zern.py | ||
""" | ||
if (j == 0): | ||
raise ValueError("Noll indices start at 1, 0 is invalid.") | ||
n = 0 | ||
j1 = j-1 | ||
while (j1 > n): | ||
n += 1 | ||
j1 -= n | ||
m = (-1)**j * ((n % 2) + 2 * int((j1+((n+1)%2)) / 2.0 )) | ||
|
||
return n, m | ||
|
||
def _nollCovMat(Z, D, fried): | ||
C = np.zeros((Z,Z)) | ||
# Z: Number of Zernike Coeff's | ||
for i in range(Z): | ||
for j in range(Z): | ||
ni, mi = _nollToZernInd(i+1) | ||
nj, mj = _nollToZernInd(j+1) | ||
if (abs(mi) == abs(mj)) and (np.mod(i - j, 2) == 0): | ||
num = math.gamma(14.0/3.0) * math.gamma((ni + nj - 5.0/3.0)/2.0) | ||
den = math.gamma((-ni + nj + 17.0/3.0)/2.0) * math.gamma((ni - nj + 17.0/3.0)/2.0) * \ | ||
math.gamma((ni + nj + 23.0/3.0)/2.0) | ||
coef1 = 0.0072 * (np.pi ** (8.0/3.0)) * ((D/fried) ** (5.0/3.0)) * np.sqrt((ni + 1) * (nj + 1)) * \ | ||
((-1) ** ((ni + nj - 2*abs(mi))/2.0)) | ||
C[i, j] = coef1*num/den | ||
else: | ||
C[i, j] = 0 | ||
C[0,0] = 1 | ||
return C | ||
|
||
def _I0(s): | ||
I0_s, _ = integrate.quad( lambda z: (z**(-14/3))*jv(0,2*s*z)*(jv(2,z)**2), 0, 1e3, limit = 100000) | ||
|
||
return I0_s | ||
|
||
def _I2(s): | ||
I2_s, _ = integrate.quad( lambda z: (z**(-14/3))*jv(2,2*s*z)*(jv(2,z)**2), 0, 1e3, limit = 100000) | ||
|
||
return I2_s | ||
|
||
def _calculate_integral(s_max, spacing): | ||
s_arr = np.arange(0,s_max,spacing) | ||
I0_arr = np.float32(s_arr*0) | ||
I2_arr = np.float32(s_arr*0) | ||
for i in range(len(s_arr)): | ||
I0_arr[i] = _I0(s_arr[i]) | ||
I2_arr[i] = _I2(s_arr[i]) | ||
|
||
return I0_arr, I2_arr | ||
|
||
def _In_m(s, spacing, In_arr): | ||
idx = np.int32(np.floor(s.flatten()/spacing)) | ||
M,N = np.shape(s)[0], np.shape(s)[1] | ||
In = np.reshape(np.take(In_arr, idx), [M,N]) | ||
|
||
return In |