diff --git a/src/models.py b/src/models.py index 56321d2..015d29e 100644 --- a/src/models.py +++ b/src/models.py @@ -129,7 +129,7 @@ def forward(self, data_objects): # For each data object for i, data in enumerate(data_objects): # Apply initial transform - conv_data = self.init_transform[i](data.x.to(torch.float32)) + conv_data = self.init_transform[i](data.x) # Apply convolutional layers for layer in self.conv_layers[i]: