diff --git a/data/P2S_model.pt b/data/P2S_model.pt new file mode 100644 index 0000000..b4618ff Binary files /dev/null and b/data/P2S_model.pt differ diff --git a/data/dictionary.npy b/data/dictionary.npy new file mode 100644 index 0000000..312c352 Binary files /dev/null and b/data/dictionary.npy differ diff --git a/demo.py b/demo.py new file mode 100644 index 0000000..dbb5912 --- /dev/null +++ b/demo.py @@ -0,0 +1,55 @@ +''' +Demo code for imaging through turbulence simulation + +Z. Mao, N. Chimitt, and S. H. Chan, "Accerlerating Atmospheric Turbulence +Simulation via Learned Phase-to-Space Transform", ICCV 2021 + +Arxiv: https://arxiv.org/abs/2107.11627 + +Zhiyuan Mao, Nicholas Chimitt, and Stanley H. Chan +Copyright 2021 +Purdue University, West Lafayette, IN, USA +''' + +from simulator import Simulator +from turbStats import tilt_mat, corr_mat +import matplotlib.pyplot as plt +import torch + +# Select device. +device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('CPU') + + +''' +The corr_mat function is used to generate spatial-temporal correlation matrix +for point spread functions. It may take over 10 minutes to finish. However, +for each correlation value, it only needs to be computed once and can be +used for all D/r0 values. You can also download the pre-generated correlation +matrix from our website. +https://engineering.purdue.edu/ChanGroup/project_turbulence.html +''' + +# Uncomment the following line to generate correlation matrix +# corr_mat(-0.1,'./data/') + +# Generate correlation matrix for tilt. Do this once for each different turbulence parameter. +tilt_mat(x.shape[1], 0.1, 0.05, 3000) + +# Load image, permute axis if color +x = plt.imread('./images/color.png') +if len(x.shape) == 3: + x = x.transpose((2,0,1)) +x = torch.tensor(x, device = device, dtype=torch.float32) + +# Simulate +simulator = Simulator(2, 512).to(device, dtype=torch.float32) + +out = simulator(x).detach().cpu().numpy() + +if len(out.shape) == 3: + out = out.transpose((1,2,0)) + +# save image +plt.imsave('./images/out.png',out) + + diff --git a/images/color.png b/images/color.png new file mode 100644 index 0000000..3dab66b Binary files /dev/null and b/images/color.png differ diff --git a/images/out.png b/images/out.png new file mode 100644 index 0000000..d21f10a Binary files /dev/null and b/images/out.png differ diff --git a/simulator.py b/simulator.py new file mode 100644 index 0000000..cd2299d --- /dev/null +++ b/simulator.py @@ -0,0 +1,107 @@ +''' +This file contains the codes for the P2S module and the simulator. + +Z. Mao, N. Chimitt, and S. H. Chan, "Accerlerating Atmospheric Turbulence +Simulation via Learned Phase-to-Space Transform", ICCV 2021 + +Arxiv: https://arxiv.org/abs/2107.11627 + +Zhiyuan Mao, Nicholas Chimitt, and Stanley H. Chan +Copyright 2021 +Purdue University, West Lafayette, IN, USA +''' + +import torch, os +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +class Simulator(nn.Module): + ''' + Class variables: + Dr0 - D/r0, which characterizes turbulence strength. suggested range: (1 ~ 5) + img_size - Size of the image used for simulation. suggested value: (128,256,512,1024) + corr - Correlation strength for PSF. suggested range: (-5 ~ -0.01) + data_path - Path where the model, PSF dictionary, and correlation matrices are stored + device - 'cuda:0' for GPU or 'CPU' + scale - Used to artificially increase or decrase turbulence strength + use_temp - If true, the correlation matrix for tilt will be loaded from S_half-temp.npy. + ''' + def __init__(self, Dr0, img_size, corr = -0.1, data_path = './data', device = 'cuda:0', scale=1.0, use_temp = False): + super().__init__() + self.img_size = img_size + self.initial_grid = 16 + self.Dr0 = torch.tensor(Dr0) + self.device = torch.device(device) + self.Dr0 = torch.tensor(Dr0).to(self.device,dtype=torch.float32) + self.mapping = _P2S() + self.mapping.load_state_dict(torch.load(os.path.join(data_path,'P2S_model.pt'))) + self.dict_psf = np.load(os.path.join(data_path,'dictionary.npy'), allow_pickle = True) + self.mu = torch.tensor(self.dict_psf.item()['mu']).reshape((1,1,33,33)).to(self.device,dtype=torch.float32) + self.dict_psf = torch.tensor(self.dict_psf.item()['dictionary'][:100,:]).reshape((100,1,33,33)) + self.dict_psf = self.dict_psf.to(self.device,dtype=torch.float32) + + self.R = np.load(os.path.join(data_path,'R-corr_{}.npy'.format(corr))) + self.R = torch.tensor(self.R).to(self.device,dtype=torch.float32) + self.offset = torch.tensor([31,31]).to(self.device,dtype=torch.float32) + + if use_temp: + self.S_half = np.load(os.path.join(data_path,'S_half-temp.npy'.format(img_size,Dr0)), allow_pickle=True) + else: + self.S_half = np.load(os.path.join(data_path,'S_half-size_{}-D_r0_{:.4f}.npy'.format(img_size,Dr0)), allow_pickle=True) + self.const = self.S_half.item()['const'] + self.S_half = torch.tensor(self.S_half.item()['s_half']).to(self.device,dtype=torch.float32) + + xx = torch.arange(0, img_size).view(1,-1).repeat(img_size,1) + yy = torch.arange(0, img_size).view(-1,1).repeat(1,img_size) + xx = xx.view(1,1,img_size,img_size).repeat(1,1,1,1) + yy = yy.view(1,1,img_size,img_size).repeat(1,1,1,1) + self.grid = torch.cat((xx,yy),1).permute(0,2,3,1).to(self.device,dtype=torch.float32) + + self.scale=scale + + + def forward(self, img): + img_pad = F.pad(img.view((-1,1,self.img_size,self.img_size)), (16,16,16,16), mode = 'reflect') + img_mean = F.conv2d(img_pad, self.mu).squeeze() + dict_img = F.conv2d(img_pad, self.dict_psf) + random_ = torch.sqrt(self.Dr0 ** (5 / 3))*torch.randn((self.initial_grid**2*36),1,device=self.device) + + zer = torch.matmul(self.R,random_).view(self.initial_grid,self.initial_grid,36).permute(2,0,1).unsqueeze(0) + zer = F.interpolate(zer,size=(self.img_size,self.img_size),mode='bilinear', align_corners=False) + zer = zer * self.scale + weight = self.mapping(zer.squeeze().permute(1,2,0).view(self.img_size**2,-1)) + + weight = weight.view((self.img_size,self.img_size,100)).permute(2,0,1)# target: (100,512,512) + out = torch.sum(weight.unsqueeze(0)*dict_img,1) + img_mean + + MVx = torch.ifft((self.S_half*torch.randn(self.img_size,self.img_size,device=self.device)).permute(1,2,0),2) + MVy = torch.ifft((self.S_half*torch.randn(self.img_size,self.img_size,device=self.device)).permute(1,2,0),2) + pos = torch.stack((MVx[:,:,0],MVy[:,:,1]),2) * self.const + flow = self.grid+pos + flow = 2.0*flow / (self.img_size-1) - 1.0 + out = F.grid_sample(out.view((1,-1,self.img_size,self.img_size)), flow, 'bilinear', padding_mode='border', align_corners=False).squeeze() + + return out + +class _P2S(nn.Module): + def __init__(self, input_dim = 36, output_dim = 100): + super().__init__() + self.fc1 = nn.Linear(input_dim, 100) + self.fc2 = nn.Linear(100, 100) + self.fc3 = nn.Linear(100, 100) + self.out = nn.Linear(100, output_dim) + + def forward(self, x): + y = F.relu(self.fc1(x)) + y = F.relu(self.fc2(y)) + y = F.relu(self.fc2(y)) + out = self.out(y) + + return out + + + + + + diff --git a/turbStats.py b/turbStats.py new file mode 100644 index 0000000..b96e114 --- /dev/null +++ b/turbStats.py @@ -0,0 +1,171 @@ +''' +This file contains the codes for generating correlation matrices for tilt +and higher-order abberations. + +Z. Mao, N. Chimitt, and S. H. Chan, "Accerlerating Atmospheric Turbulence +Simulation via Learned Phase-to-Space Transform", ICCV 2021 + +Arxiv: https://arxiv.org/abs/2107.11627 + +Zhiyuan Mao, Nicholas Chimitt, and Stanley H. Chan +Copyright 2021 +Purdue University, West Lafayette, IN, USA +''' + +import math, os +import numpy as np +from scipy.special import jv +import scipy.integrate as integrate + +def corr_mat(corr,save_path = './data'): + ''' + This function generates the correlation matrix for point spread functions (higher-order abberations) + The correlation matrix will be stored in specified path with format 'R-corr_{}.npy'.format(corr) + + Input: + corr - Correlation strength. suggested range: (-5 ~ -0.01), with -5 has the + weakest correlation and -0.01 has the strongest. + save_path - Path to save the correlation matrix + ''' + # corr: [-0.01, -0.1, -1, -5] + num_zern=36 + N_rows=16 + N_cols=16 + + subC = _nollCovMat(num_zern, 1, 1) + C = np.zeros((int(N_rows*N_cols*num_zern), int(N_rows*N_cols*num_zern))) + + dist = np.zeros((N_rows,N_cols)) + for i in range(N_rows): + for j in range(N_cols): + for ii in range(N_rows): + for jj in range(N_cols): + if not (i == ii and j == jj): + dist[ii,jj] = np.exp(corr*(np.linalg.norm(i - ii) + np.linalg.norm(j - jj))) + else: + dist[ii,jj] = 1 + C[num_zern * (N_cols * i + j): num_zern * (N_cols * i + j + 1) \ + , num_zern * (N_cols * ii + jj): num_zern * (N_cols * ii + jj + 1)] = dist[ii, jj] * subC + + e_val, e_vec = np.linalg.eig(C) + + R = np.real(e_vec * np.sqrt(e_val)) + np.save(os.path.join(save_path,'R-corr_{}.npy'.format(corr)),R) + + +def tilt_mat(N, D, r0, L, save_path = './data', thre = 0.002, adj = 1, use_temp = False): + ''' + This function generates the correlation matrix for tilt + The correlation matrix will be stored in specified path + with format 'S_half-size_{}-D_r0_{:.4f}.npy'.format(N,D_r0) + + Input: + N - Image size. suggested value: (128,256,512,1024) + D - Apeture diameter + r0 - Fried parameter + L - Propogation distance + save_path - Path to save the correlation matrix + thre - Used to suppress small valus in the correlation matrix. Increase + this threshold if the pixel displacement appears to be scattering + use_temp - If true, the correlation matrix will be stored in S_half-temp.npy. + ''' + # N: image size + # D: Apeture diameter + # r0: Fried parameter + # L: Propagation distance + D_r0 = D/r0 + wavelength = 0.500e-6 + k = 2*np.pi/wavelength + delta0 = L*wavelength/(2*D) + delta0 *= adj# Adjusting factor + c1 = 2*((24/5)*math.gamma(6/5))**(5/6); + c2 = 4*c1/np.pi*(math.gamma(11/6))**2; + smax = delta0/D*N + spacing = delta0/D + I0_arr, I2_arr = _calculate_integral(smax, spacing) + + i, j = np.int32(N/2), np.int32(N/2) + [x,y] = np.meshgrid(np.arange(1,N+0.01,1),np.arange(1,N+0.01,1)) + s = np.sqrt((x-i)**2 + (y-j)**2) + s *= spacing + + C0 = (_In_m(s, spacing, I0_arr) + _In_m(s, spacing, I2_arr))/_I0(0) + C0[i,j] = 1 + C0_scaled = C0*_I0(0)*c2*((D_r0)**(5/3))/(2**(5/3))*((2*wavelength/(np.pi*D))**2)*2*np.pi + Cfft = np.fft.fft2(C0_scaled) + S_half = np.sqrt(Cfft) + S_half_max = np.max(np.max(np.abs(S_half))) + S_half[np.abs(S_half) < thre*S_half_max] = 0 + S_half_new = np.zeros((2,N,N)) + S_half_new[0] = np.real(S_half) + S_half_new[1] = np.imag(S_half) + data = {} + data['s_half'] = S_half_new + data['const'] = np.sqrt(2)*N*(L/delta0) + + if use_temp: + np.save(os.path.join(save_path,'S_half-temp.npy'.format(N,D_r0)),data) + else: + np.save(os.path.join(save_path,'S_half-size_{}-D_r0_{:.4f}.npy'.format(N,D_r0)),data) + +def _nollToZernInd(j): + """ + Authors: Tim van Werkhoven, Jason Saredy + See: https://github.com/tvwerkhoven/libtim-py/blob/master/libtim/zern.py + """ + if (j == 0): + raise ValueError("Noll indices start at 1, 0 is invalid.") + n = 0 + j1 = j-1 + while (j1 > n): + n += 1 + j1 -= n + m = (-1)**j * ((n % 2) + 2 * int((j1+((n+1)%2)) / 2.0 )) + + return n, m + +def _nollCovMat(Z, D, fried): + C = np.zeros((Z,Z)) + # Z: Number of Zernike Coeff's + for i in range(Z): + for j in range(Z): + ni, mi = _nollToZernInd(i+1) + nj, mj = _nollToZernInd(j+1) + if (abs(mi) == abs(mj)) and (np.mod(i - j, 2) == 0): + num = math.gamma(14.0/3.0) * math.gamma((ni + nj - 5.0/3.0)/2.0) + den = math.gamma((-ni + nj + 17.0/3.0)/2.0) * math.gamma((ni - nj + 17.0/3.0)/2.0) * \ + math.gamma((ni + nj + 23.0/3.0)/2.0) + coef1 = 0.0072 * (np.pi ** (8.0/3.0)) * ((D/fried) ** (5.0/3.0)) * np.sqrt((ni + 1) * (nj + 1)) * \ + ((-1) ** ((ni + nj - 2*abs(mi))/2.0)) + C[i, j] = coef1*num/den + else: + C[i, j] = 0 + C[0,0] = 1 + return C + +def _I0(s): + I0_s, _ = integrate.quad( lambda z: (z**(-14/3))*jv(0,2*s*z)*(jv(2,z)**2), 0, 1e3, limit = 100000) + + return I0_s + +def _I2(s): + I2_s, _ = integrate.quad( lambda z: (z**(-14/3))*jv(2,2*s*z)*(jv(2,z)**2), 0, 1e3, limit = 100000) + + return I2_s + +def _calculate_integral(s_max, spacing): + s_arr = np.arange(0,s_max,spacing) + I0_arr = np.float32(s_arr*0) + I2_arr = np.float32(s_arr*0) + for i in range(len(s_arr)): + I0_arr[i] = _I0(s_arr[i]) + I2_arr[i] = _I2(s_arr[i]) + + return I0_arr, I2_arr + +def _In_m(s, spacing, In_arr): + idx = np.int32(np.floor(s.flatten()/spacing)) + M,N = np.shape(s)[0], np.shape(s)[1] + In = np.reshape(np.take(In_arr, idx), [M,N]) + + return In