Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
mao114 authored Jul 28, 2021
1 parent 51e01a8 commit 80d4928
Show file tree
Hide file tree
Showing 7 changed files with 333 additions and 0 deletions.
Binary file added data/P2S_model.pt
Binary file not shown.
Binary file added data/dictionary.npy
Binary file not shown.
55 changes: 55 additions & 0 deletions demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
'''
Demo code for imaging through turbulence simulation
Z. Mao, N. Chimitt, and S. H. Chan, "Accerlerating Atmospheric Turbulence
Simulation via Learned Phase-to-Space Transform", ICCV 2021
Arxiv: https://arxiv.org/abs/2107.11627
Zhiyuan Mao, Nicholas Chimitt, and Stanley H. Chan
Copyright 2021
Purdue University, West Lafayette, IN, USA
'''

from simulator import Simulator
from turbStats import tilt_mat, corr_mat
import matplotlib.pyplot as plt
import torch

# Select device.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('CPU')


'''
The corr_mat function is used to generate spatial-temporal correlation matrix
for point spread functions. It may take over 10 minutes to finish. However,
for each correlation value, it only needs to be computed once and can be
used for all D/r0 values. You can also download the pre-generated correlation
matrix from our website.
https://engineering.purdue.edu/ChanGroup/project_turbulence.html
'''

# Uncomment the following line to generate correlation matrix
# corr_mat(-0.1,'./data/')

# Generate correlation matrix for tilt. Do this once for each different turbulence parameter.
tilt_mat(x.shape[1], 0.1, 0.05, 3000)

# Load image, permute axis if color
x = plt.imread('./images/color.png')
if len(x.shape) == 3:
x = x.transpose((2,0,1))
x = torch.tensor(x, device = device, dtype=torch.float32)

# Simulate
simulator = Simulator(2, 512).to(device, dtype=torch.float32)

out = simulator(x).detach().cpu().numpy()

if len(out.shape) == 3:
out = out.transpose((1,2,0))

# save image
plt.imsave('./images/out.png',out)


Binary file added images/color.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/out.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
107 changes: 107 additions & 0 deletions simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
'''
This file contains the codes for the P2S module and the simulator.
Z. Mao, N. Chimitt, and S. H. Chan, "Accerlerating Atmospheric Turbulence
Simulation via Learned Phase-to-Space Transform", ICCV 2021
Arxiv: https://arxiv.org/abs/2107.11627
Zhiyuan Mao, Nicholas Chimitt, and Stanley H. Chan
Copyright 2021
Purdue University, West Lafayette, IN, USA
'''

import torch, os
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Simulator(nn.Module):
'''
Class variables:
Dr0 - D/r0, which characterizes turbulence strength. suggested range: (1 ~ 5)
img_size - Size of the image used for simulation. suggested value: (128,256,512,1024)
corr - Correlation strength for PSF. suggested range: (-5 ~ -0.01)
data_path - Path where the model, PSF dictionary, and correlation matrices are stored
device - 'cuda:0' for GPU or 'CPU'
scale - Used to artificially increase or decrase turbulence strength
use_temp - If true, the correlation matrix for tilt will be loaded from S_half-temp.npy.
'''
def __init__(self, Dr0, img_size, corr = -0.1, data_path = './data', device = 'cuda:0', scale=1.0, use_temp = False):
super().__init__()
self.img_size = img_size
self.initial_grid = 16
self.Dr0 = torch.tensor(Dr0)
self.device = torch.device(device)
self.Dr0 = torch.tensor(Dr0).to(self.device,dtype=torch.float32)
self.mapping = _P2S()
self.mapping.load_state_dict(torch.load(os.path.join(data_path,'P2S_model.pt')))
self.dict_psf = np.load(os.path.join(data_path,'dictionary.npy'), allow_pickle = True)
self.mu = torch.tensor(self.dict_psf.item()['mu']).reshape((1,1,33,33)).to(self.device,dtype=torch.float32)
self.dict_psf = torch.tensor(self.dict_psf.item()['dictionary'][:100,:]).reshape((100,1,33,33))
self.dict_psf = self.dict_psf.to(self.device,dtype=torch.float32)

self.R = np.load(os.path.join(data_path,'R-corr_{}.npy'.format(corr)))
self.R = torch.tensor(self.R).to(self.device,dtype=torch.float32)
self.offset = torch.tensor([31,31]).to(self.device,dtype=torch.float32)

if use_temp:
self.S_half = np.load(os.path.join(data_path,'S_half-temp.npy'.format(img_size,Dr0)), allow_pickle=True)
else:
self.S_half = np.load(os.path.join(data_path,'S_half-size_{}-D_r0_{:.4f}.npy'.format(img_size,Dr0)), allow_pickle=True)
self.const = self.S_half.item()['const']
self.S_half = torch.tensor(self.S_half.item()['s_half']).to(self.device,dtype=torch.float32)

xx = torch.arange(0, img_size).view(1,-1).repeat(img_size,1)
yy = torch.arange(0, img_size).view(-1,1).repeat(1,img_size)
xx = xx.view(1,1,img_size,img_size).repeat(1,1,1,1)
yy = yy.view(1,1,img_size,img_size).repeat(1,1,1,1)
self.grid = torch.cat((xx,yy),1).permute(0,2,3,1).to(self.device,dtype=torch.float32)

self.scale=scale


def forward(self, img):
img_pad = F.pad(img.view((-1,1,self.img_size,self.img_size)), (16,16,16,16), mode = 'reflect')
img_mean = F.conv2d(img_pad, self.mu).squeeze()
dict_img = F.conv2d(img_pad, self.dict_psf)
random_ = torch.sqrt(self.Dr0 ** (5 / 3))*torch.randn((self.initial_grid**2*36),1,device=self.device)

zer = torch.matmul(self.R,random_).view(self.initial_grid,self.initial_grid,36).permute(2,0,1).unsqueeze(0)
zer = F.interpolate(zer,size=(self.img_size,self.img_size),mode='bilinear', align_corners=False)
zer = zer * self.scale
weight = self.mapping(zer.squeeze().permute(1,2,0).view(self.img_size**2,-1))

weight = weight.view((self.img_size,self.img_size,100)).permute(2,0,1)# target: (100,512,512)
out = torch.sum(weight.unsqueeze(0)*dict_img,1) + img_mean

MVx = torch.ifft((self.S_half*torch.randn(self.img_size,self.img_size,device=self.device)).permute(1,2,0),2)
MVy = torch.ifft((self.S_half*torch.randn(self.img_size,self.img_size,device=self.device)).permute(1,2,0),2)
pos = torch.stack((MVx[:,:,0],MVy[:,:,1]),2) * self.const
flow = self.grid+pos
flow = 2.0*flow / (self.img_size-1) - 1.0
out = F.grid_sample(out.view((1,-1,self.img_size,self.img_size)), flow, 'bilinear', padding_mode='border', align_corners=False).squeeze()

return out

class _P2S(nn.Module):
def __init__(self, input_dim = 36, output_dim = 100):
super().__init__()
self.fc1 = nn.Linear(input_dim, 100)
self.fc2 = nn.Linear(100, 100)
self.fc3 = nn.Linear(100, 100)
self.out = nn.Linear(100, output_dim)

def forward(self, x):
y = F.relu(self.fc1(x))
y = F.relu(self.fc2(y))
y = F.relu(self.fc2(y))
out = self.out(y)

return out






171 changes: 171 additions & 0 deletions turbStats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
'''
This file contains the codes for generating correlation matrices for tilt
and higher-order abberations.
Z. Mao, N. Chimitt, and S. H. Chan, "Accerlerating Atmospheric Turbulence
Simulation via Learned Phase-to-Space Transform", ICCV 2021
Arxiv: https://arxiv.org/abs/2107.11627
Zhiyuan Mao, Nicholas Chimitt, and Stanley H. Chan
Copyright 2021
Purdue University, West Lafayette, IN, USA
'''

import math, os
import numpy as np
from scipy.special import jv
import scipy.integrate as integrate

def corr_mat(corr,save_path = './data'):
'''
This function generates the correlation matrix for point spread functions (higher-order abberations)
The correlation matrix will be stored in specified path with format 'R-corr_{}.npy'.format(corr)
Input:
corr - Correlation strength. suggested range: (-5 ~ -0.01), with -5 has the
weakest correlation and -0.01 has the strongest.
save_path - Path to save the correlation matrix
'''
# corr: [-0.01, -0.1, -1, -5]
num_zern=36
N_rows=16
N_cols=16

subC = _nollCovMat(num_zern, 1, 1)
C = np.zeros((int(N_rows*N_cols*num_zern), int(N_rows*N_cols*num_zern)))

dist = np.zeros((N_rows,N_cols))
for i in range(N_rows):
for j in range(N_cols):
for ii in range(N_rows):
for jj in range(N_cols):
if not (i == ii and j == jj):
dist[ii,jj] = np.exp(corr*(np.linalg.norm(i - ii) + np.linalg.norm(j - jj)))
else:
dist[ii,jj] = 1
C[num_zern * (N_cols * i + j): num_zern * (N_cols * i + j + 1) \
, num_zern * (N_cols * ii + jj): num_zern * (N_cols * ii + jj + 1)] = dist[ii, jj] * subC

e_val, e_vec = np.linalg.eig(C)

R = np.real(e_vec * np.sqrt(e_val))
np.save(os.path.join(save_path,'R-corr_{}.npy'.format(corr)),R)


def tilt_mat(N, D, r0, L, save_path = './data', thre = 0.002, adj = 1, use_temp = False):
'''
This function generates the correlation matrix for tilt
The correlation matrix will be stored in specified path
with format 'S_half-size_{}-D_r0_{:.4f}.npy'.format(N,D_r0)
Input:
N - Image size. suggested value: (128,256,512,1024)
D - Apeture diameter
r0 - Fried parameter
L - Propogation distance
save_path - Path to save the correlation matrix
thre - Used to suppress small valus in the correlation matrix. Increase
this threshold if the pixel displacement appears to be scattering
use_temp - If true, the correlation matrix will be stored in S_half-temp.npy.
'''
# N: image size
# D: Apeture diameter
# r0: Fried parameter
# L: Propagation distance
D_r0 = D/r0
wavelength = 0.500e-6
k = 2*np.pi/wavelength
delta0 = L*wavelength/(2*D)
delta0 *= adj# Adjusting factor
c1 = 2*((24/5)*math.gamma(6/5))**(5/6);
c2 = 4*c1/np.pi*(math.gamma(11/6))**2;
smax = delta0/D*N
spacing = delta0/D
I0_arr, I2_arr = _calculate_integral(smax, spacing)

i, j = np.int32(N/2), np.int32(N/2)
[x,y] = np.meshgrid(np.arange(1,N+0.01,1),np.arange(1,N+0.01,1))
s = np.sqrt((x-i)**2 + (y-j)**2)
s *= spacing

C0 = (_In_m(s, spacing, I0_arr) + _In_m(s, spacing, I2_arr))/_I0(0)
C0[i,j] = 1
C0_scaled = C0*_I0(0)*c2*((D_r0)**(5/3))/(2**(5/3))*((2*wavelength/(np.pi*D))**2)*2*np.pi
Cfft = np.fft.fft2(C0_scaled)
S_half = np.sqrt(Cfft)
S_half_max = np.max(np.max(np.abs(S_half)))
S_half[np.abs(S_half) < thre*S_half_max] = 0
S_half_new = np.zeros((2,N,N))
S_half_new[0] = np.real(S_half)
S_half_new[1] = np.imag(S_half)
data = {}
data['s_half'] = S_half_new
data['const'] = np.sqrt(2)*N*(L/delta0)

if use_temp:
np.save(os.path.join(save_path,'S_half-temp.npy'.format(N,D_r0)),data)
else:
np.save(os.path.join(save_path,'S_half-size_{}-D_r0_{:.4f}.npy'.format(N,D_r0)),data)

def _nollToZernInd(j):
"""
Authors: Tim van Werkhoven, Jason Saredy
See: https://github.com/tvwerkhoven/libtim-py/blob/master/libtim/zern.py
"""
if (j == 0):
raise ValueError("Noll indices start at 1, 0 is invalid.")
n = 0
j1 = j-1
while (j1 > n):
n += 1
j1 -= n
m = (-1)**j * ((n % 2) + 2 * int((j1+((n+1)%2)) / 2.0 ))

return n, m

def _nollCovMat(Z, D, fried):
C = np.zeros((Z,Z))
# Z: Number of Zernike Coeff's
for i in range(Z):
for j in range(Z):
ni, mi = _nollToZernInd(i+1)
nj, mj = _nollToZernInd(j+1)
if (abs(mi) == abs(mj)) and (np.mod(i - j, 2) == 0):
num = math.gamma(14.0/3.0) * math.gamma((ni + nj - 5.0/3.0)/2.0)
den = math.gamma((-ni + nj + 17.0/3.0)/2.0) * math.gamma((ni - nj + 17.0/3.0)/2.0) * \
math.gamma((ni + nj + 23.0/3.0)/2.0)
coef1 = 0.0072 * (np.pi ** (8.0/3.0)) * ((D/fried) ** (5.0/3.0)) * np.sqrt((ni + 1) * (nj + 1)) * \
((-1) ** ((ni + nj - 2*abs(mi))/2.0))
C[i, j] = coef1*num/den
else:
C[i, j] = 0
C[0,0] = 1
return C

def _I0(s):
I0_s, _ = integrate.quad( lambda z: (z**(-14/3))*jv(0,2*s*z)*(jv(2,z)**2), 0, 1e3, limit = 100000)

return I0_s

def _I2(s):
I2_s, _ = integrate.quad( lambda z: (z**(-14/3))*jv(2,2*s*z)*(jv(2,z)**2), 0, 1e3, limit = 100000)

return I2_s

def _calculate_integral(s_max, spacing):
s_arr = np.arange(0,s_max,spacing)
I0_arr = np.float32(s_arr*0)
I2_arr = np.float32(s_arr*0)
for i in range(len(s_arr)):
I0_arr[i] = _I0(s_arr[i])
I2_arr[i] = _I2(s_arr[i])

return I0_arr, I2_arr

def _In_m(s, spacing, In_arr):
idx = np.int32(np.floor(s.flatten()/spacing))
M,N = np.shape(s)[0], np.shape(s)[1]
In = np.reshape(np.take(In_arr, idx), [M,N])

return In

0 comments on commit 80d4928

Please sign in to comment.