diff --git a/src/data.py b/src/data.py index 28a9049..423fd26 100644 --- a/src/data.py +++ b/src/data.py @@ -36,6 +36,7 @@ def __init__(self, root, prop_csv): # Create processed path if it doesn't exist self.processed_path = Path(self.processed_dir) self.processed_path.mkdir(exist_ok=True) + self.process_flag = False # Read csv self.prop_csv = prop_csv @@ -124,6 +125,9 @@ def process_data(self, z_cutoffs, node_features, edge_features, pad=False, # Save data objects torch.save(data_objects, self.processed_path / f"data_{i}.pt") + + # Set process flag to true + self.process_flag = True def len(self): """Return size of the dataset.""" @@ -133,6 +137,11 @@ def get(self, i): """Fetch the processed graph(s) at the i-th index.""" data_objects = torch.load(self.processed_path / f"data_{i}.pt") return data_objects + + def processed_status(self): + """Check if the dataset is processed.""" + return self.process_flag + if __name__ == "__main__": # Get path to root directory