From 964dfcd9e2ea56b1a0ed0c7b328334408c010b7a Mon Sep 17 00:00:00 2001 From: Gaurav S Deshmukh Date: Thu, 21 Sep 2023 19:30:42 -0400 Subject: [PATCH] Add processed flag --- src/data.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/data.py b/src/data.py index 28a9049..423fd26 100644 --- a/src/data.py +++ b/src/data.py @@ -36,6 +36,7 @@ def __init__(self, root, prop_csv): # Create processed path if it doesn't exist self.processed_path = Path(self.processed_dir) self.processed_path.mkdir(exist_ok=True) + self.process_flag = False # Read csv self.prop_csv = prop_csv @@ -124,6 +125,9 @@ def process_data(self, z_cutoffs, node_features, edge_features, pad=False, # Save data objects torch.save(data_objects, self.processed_path / f"data_{i}.pt") + + # Set process flag to true + self.process_flag = True def len(self): """Return size of the dataset.""" @@ -133,6 +137,11 @@ def get(self, i): """Fetch the processed graph(s) at the i-th index.""" data_objects = torch.load(self.processed_path / f"data_{i}.pt") return data_objects + + def processed_status(self): + """Check if the dataset is processed.""" + return self.process_flag + if __name__ == "__main__": # Get path to root directory