diff --git a/src/data.py b/src/data.py index 708640b..a09e111 100644 --- a/src/data.py +++ b/src/data.py @@ -109,6 +109,9 @@ def process(self, z_cutoffs, node_features, edge_features, max_atoms=12, ) data_objects.append(data_obj) + # Add name of structure + data_objects.append(name) + # Save data objects torch.save(data_objects, self.processed_path / f"data_{i}.pt")