diff --git a/simulator.py b/simulator.py index cd2299d..c3f9efb 100644 --- a/simulator.py +++ b/simulator.py @@ -77,7 +77,7 @@ def forward(self, img): MVx = torch.ifft((self.S_half*torch.randn(self.img_size,self.img_size,device=self.device)).permute(1,2,0),2) MVy = torch.ifft((self.S_half*torch.randn(self.img_size,self.img_size,device=self.device)).permute(1,2,0),2) - pos = torch.stack((MVx[:,:,0],MVy[:,:,1]),2) * self.const + pos = torch.stack((MVx[:,:,0],MVy[:,:,0]),2) * self.const flow = self.grid+pos flow = 2.0*flow / (self.img_size-1) - 1.0 out = F.grid_sample(out.view((1,-1,self.img_size,self.img_size)), flow, 'bilinear', padding_mode='border', align_corners=False).squeeze()