Skip to content

ENH: Added RandomSampler #17

Merged
merged 2 commits into from
Sep 23, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions src/samplers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Samplers for training, validation, and testing."""

import abc

import numpy as np


class Sampler(abc.ABC):
    """Abstract base class for data samplers.

    Concrete samplers must implement both ``create_samplers`` and the static
    ``name`` method; the class cannot be instantiated until they do.
    """

    def __init__(self):
        """Blank constructor."""

    @abc.abstractmethod
    def create_samplers(self):
        """Create training, validation, and test samplers.

        Implementations should return a dictionary with "train", "val", and
        "test" as keys and indices of datapoints as values. Concrete samplers
        may accept additional arguments (for example, split fractions).
        """

    # ``abc.abstractstaticmethod`` is deprecated since Python 3.3; stacking
    # ``staticmethod`` on top of ``abstractmethod`` is the supported spelling.
    @staticmethod
    @abc.abstractmethod
    def name():
        """Name of the sampling method."""


class RandomSampler(Sampler):
    """Perform uniform random sampling on datapoints."""

    def __init__(self, seed, dataset_size):
        """Initialize sampler.

        Parameters
        ----------
        seed: int
            Seed for random sampling.
        dataset_size: int
            Number of points in dataset.
        """
        super().__init__()
        self.seed = seed
        self.dataset_size = dataset_size

    def create_samplers(self, sample_config):
        """Randomly sample training, validation, and test datapoints.

        The split is reproducible: calling this method again with the same
        seed, dataset size, and fractions yields the same indices.

        Parameters
        ----------
        sample_config: dict
            Dictionary with "train", "val", "test" as keys and corresponding
            fractions as values (must sum up to 1). Only the "train" and
            "val" fractions are read; the test split receives all remaining
            indices.

        Returns
        -------
        samples: dict
            Dictionary with "train", "val", and "test" as keys and arrays of
            datapoint indices as values.
        """
        # Create a seeded randomizer so the split depends only on self.seed
        randomizer = np.random.default_rng(self.seed)

        # Create array of indices
        idx_array = np.arange(self.dataset_size)

        # Shuffle in place with the seeded generator. The module-level
        # np.random.shuffle would draw from the global state and ignore the seed.
        randomizer.shuffle(idx_array)

        # Get split sizes. The product is rounded before ceil so that
        # floating-point noise (e.g. 0.7 * 100 == 70.00000000000001) does not
        # push a split one element too far.
        train_size = int(
            np.ceil(round(sample_config["train"] * self.dataset_size, 9))
        )
        val_size = int(
            np.ceil(round(sample_config["val"] * self.dataset_size, 9))
        )

        # Get indices; the test split takes whatever is left over
        train_idx = idx_array[:train_size]
        val_idx = idx_array[train_size : train_size + val_size]
        test_idx = idx_array[train_size + val_size :]

        # Create samples
        samples = {"train": train_idx, "val": val_idx, "test": test_idx}

        return samples

    @staticmethod
    def name():
        """Name of the sampling method."""
        return "random"


if __name__ == "__main__":
    # Smoke test: split 100 datapoints 60/20/20 using seed 0 and print the
    # resulting index dictionary.
    split_fractions = {"train": 0.6, "val": 0.2, "test": 0.2}
    print(RandomSampler(0, 100).create_samplers(split_fractions))