Skip to content

Commit

Permalink
Merge pull request #17 from GreeleyGroup/enh/samplers
Browse files Browse the repository at this point in the history
ENH: Added RandomSampler
  • Loading branch information
deshmukg authored Sep 23, 2023
2 parents 945c494 + 2e0aa86 commit a12e497
Showing 1 changed file with 90 additions and 0 deletions.
90 changes: 90 additions & 0 deletions src/samplers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Samplers for training, validation, and testing."""

import abc

import numpy as np


class Sampler(abc.ABC):
    """Abstract base class for data samplers.

    Concrete samplers must implement ``create_samplers`` and the static
    method ``name``; the class cannot be instantiated until both are defined.
    """

    def __init__(self):
        """Blank constructor."""
        pass

    @abc.abstractmethod
    def create_samplers(self):
        """Create training, test, and validation samplers.

        This should return a dictionary with "train", "val", "test" as keys and
        indices of datapoints as values.
        """

    # ``abc.abstractstaticmethod`` is deprecated since Python 3.3; stacking
    # ``staticmethod`` on top of ``abstractmethod`` is the supported spelling.
    @staticmethod
    @abc.abstractmethod
    def name():
        """Name of the sampling method."""


class RandomSampler(Sampler):
    """Perform uniform random sampling on datapoints."""

    def __init__(self, seed, dataset_size):
        """Initialize sampler.

        Parameters
        ----------
        seed: int
            Seed for random sampling.
        dataset_size: int
            Number of points in dataset.
        """
        self.seed = seed
        self.dataset_size = dataset_size

    def create_samplers(self, sample_config):
        """Randomly sample training, validation, and test datapoints.

        Parameters
        ----------
        sample_config: dict
            Dictionary with "train", "val", "test" as keys and corresponding
            fractions as values (must sum up to 1). Only the "train" and
            "val" fractions are read; the test split receives every index
            left over after the first two splits.

        Returns
        -------
        samples: dict
            Dictionary with "train", "val", "test" as keys and NumPy arrays
            of shuffled dataset indices as values.
        """
        # Seeded generator: the same seed always produces the same split.
        randomizer = np.random.default_rng(self.seed)

        # Create array of indices
        idx_array = np.arange(self.dataset_size)

        # Shuffle in place with the seeded generator. Using the global
        # ``np.random.shuffle`` here would silently ignore ``self.seed``.
        randomizer.shuffle(idx_array)

        # Split sizes are rounded up; slicing clamps at the end of the array,
        # so the test split simply takes whatever remains.
        train_size = int(np.ceil(sample_config["train"] * self.dataset_size))
        val_size = int(np.ceil(sample_config["val"] * self.dataset_size))
        val_end = train_size + val_size

        train_idx = idx_array[:train_size]
        val_idx = idx_array[train_size:val_end]
        test_idx = idx_array[val_end:]

        # Create samples
        samples = {"train": train_idx, "val": val_idx, "test": test_idx}

        return samples

    @staticmethod
    def name():
        """Name of the sampling method."""
        return "random"


if __name__ == "__main__":
    # Small demo: split 100 datapoints 60/20/20 with seed 0 and show the result.
    split_fractions = {"train": 0.6, "val": 0.2, "test": 0.2}
    sampler = RandomSampler(seed=0, dataset_size=100)
    print(sampler.create_samplers(split_fractions))

0 comments on commit a12e497

Please sign in to comment.