"""Samplers for training, validation, and testing."""

import abc

import numpy as np
import pandas as pd


class Sampler(abc.ABC):
    """Abstract base class for data samplers."""

    def __init__(self):
        """Blank constructor."""
        pass

    @abc.abstractmethod
    def create_samplers(self):
        """Create training, test, and validation samplers.

        This should return a dictionary with "train", "val", "test" as keys
        and indices of datapoints as values. Concrete samplers may accept
        additional configuration arguments (``RandomSampler`` takes a
        ``sample_config`` dictionary).
        """
        pass

    # ``abc.abstractstaticmethod`` is deprecated since Python 3.3; stacking
    # ``staticmethod`` on top of ``abstractmethod`` is the supported spelling.
    @staticmethod
    @abc.abstractmethod
    def name():
        """Name of the sampling method."""
        pass


class RandomSampler(Sampler):
    """Perform uniform random sampling on datapoints."""

    def __init__(self, seed, dataset_size):
        """Initialize sampler.

        Parameters
        ----------
        seed: int
            Seed for random sampling.
        dataset_size: int
            Number of points in dataset
        """
        super().__init__()
        self.seed = seed
        self.dataset_size = dataset_size

    def create_samplers(self, sample_config):
        """Randomly sample training, validation, and test datapoints.

        The indices ``0 .. dataset_size - 1`` are shuffled with a generator
        seeded by ``self.seed``, so the same seed always yields the same
        split. The shuffled indices are then cut into three consecutive,
        non-overlapping slices.

        Parameters
        ----------
        sample_config: dict
            Dictionary with "train", "val", "test" as keys and corresponding
            fractions as values (must sum up to 1). Only the "train" and
            "val" fractions are read; the test split receives every index
            left over after those two slices.

        Returns
        -------
        samples: dict
            Dictionary with "train", "val", and "test" as keys and NumPy
            arrays of datapoint indices as values.

        Raises
        ------
        KeyError
            If ``sample_config`` lacks a "train" or "val" entry.
        """
        # Seeded generator: this is what makes the split reproducible.
        randomizer = np.random.default_rng(self.seed)

        # Create array of indices
        idx_array = np.arange(self.dataset_size)

        # Shuffle in place with the seeded generator. The module-level
        # np.random.shuffle would draw from the global RNG and ignore the
        # seed entirely.
        randomizer.shuffle(idx_array)

        # Split sizes are rounded up; the test split takes the remainder.
        train_size = int(np.ceil(sample_config["train"] * self.dataset_size))
        val_size = int(np.ceil(sample_config["val"] * self.dataset_size))
        val_end = train_size + val_size

        train_idx = idx_array[:train_size]
        val_idx = idx_array[train_size:val_end]
        test_idx = idx_array[val_end:]

        # Create samples
        samples = {"train": train_idx, "val": val_idx, "test": test_idx}

        return samples

    @staticmethod
    def name():
        """Name of the sampling method."""
        return "random"


if __name__ == "__main__":
    # Small demonstration: split 100 points 60/20/20 with seed 0.
    rs = RandomSampler(0, 100)
    samples = rs.create_samplers({"train": 0.6, "val": 0.2, "test": 0.2})
    print(samples)