diff --git a/src/featurizers.py b/src/featurizers.py index 9a37b73..9597295 100644 --- a/src/featurizers.py +++ b/src/featurizers.py @@ -505,7 +505,7 @@ def featurize_graph(self, atoms_graph): self._feat_tensor = self.encoder.transform(bond_dist_arr) # Create list of edge indices - self._edge_indices = torch.Tensor(list(atoms_graph.graph.edges())).view(2, -1) + self._edge_indices = torch.LongTensor(list(atoms_graph.graph.edges())).view(2, -1) @property def feat_tensor(self):