diff --git a/demo.py b/demo.py index dbb5912..3105f51 100644 --- a/demo.py +++ b/demo.py @@ -32,15 +32,16 @@ # Uncomment the following line to generate correlation matrix # corr_mat(-0.1,'./data/') -# Generate correlation matrix for tilt. Do this once for each different turbulence parameter. -tilt_mat(x.shape[1], 0.1, 0.05, 3000) - # Load image, permute axis if color x = plt.imread('./images/color.png') + if len(x.shape) == 3: x = x.transpose((2,0,1)) x = torch.tensor(x, device = device, dtype=torch.float32) +# Generate correlation matrix for tilt. Do this once for each different turbulence parameter. +tilt_mat(x.shape[1], 0.1, 0.05, 3000) + # Simulate simulator = Simulator(2, 512).to(device, dtype=torch.float32)