From a9d3d7afe769445bad84acdf42563f5c7774ab98 Mon Sep 17 00:00:00 2001 From: Gaurav S Deshmukh Date: Fri, 22 Sep 2023 21:47:47 -0400 Subject: [PATCH] Removed float32 --- src/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models.py b/src/models.py index 56321d2..015d29e 100644 --- a/src/models.py +++ b/src/models.py @@ -129,7 +129,7 @@ def forward(self, data_objects): # For each data object for i, data in enumerate(data_objects): # Apply initial transform - conv_data = self.init_transform[i](data.x.to(torch.float32)) + conv_data = self.init_transform[i](data.x) # Apply convolutional layers for layer in self.conv_layers[i]: