Skip to content

Commit

Permalink
Added RandomSampler
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav S Deshmukh committed Sep 23, 2023
1 parent 945c494 commit 2670973
Showing 1 changed file with 86 additions and 0 deletions.
86 changes: 86 additions & 0 deletions src/samplers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Samplers for training, validation, and testing."""

import numpy as np
import pandas as pd
import abc

class Sampler(abc.ABC):
"""Abstract base class for data samplers."""
def __init__(self):
"""Blank constructor."""
pass

@abc.abstractmethod
def create_samplers(self):
"""Create training, test, and validation samplers.
This should return a dictionary with "train", "val", "test" as keys and
indices of datapoints as values.
"""
pass

@abc.abstractstaticmethod
def name():
"""Name of the sampling method."""
pass

class RandomSampler(Sampler):
"""Perform uniform random sampling on datapoints."""
def __init__(self, seed, dataset_size):
"""Initialize sampler.
Parameters
----------
seed: int
Seed for random sampling.
dataset_size: int
Number of points in dataset
"""
self.seed = seed
self.dataset_size = dataset_size

def create_samplers(self, sample_config):
"""Randomly sample training, validation, and test datapoints.
Parameters
----------
sample_config: dict
Dictionary with "train", "val", "test" as values and corresponding
fractions as values (must sum up to 1).
Returns
-------
samples: dict
Dictionary with indices for train, val, and test points.
"""
# Create randomizer
randomizer = np.random.default_rng(self.seed)

# Create array of indices
idx_array = np.arange(self.dataset_size)

# Shuffle array
np.random.shuffle(idx_array)

# Get indices
train_size = int(np.ceil(sample_config["train"] * self.dataset_size))
train_idx = idx_array[:train_size]
val_size = int(np.ceil(sample_config["val"] * self.dataset_size))
val_idx = idx_array[train_size: train_size + val_size]
test_idx = idx_array[train_size + val_size:]

# Create samples
samples = {"train": train_idx, "val": val_idx, "test": test_idx}

return samples

@staticmethod
def name():
return "random"

if __name__ == "__main__":
rs = RandomSampler(0, 100)
samples = rs.create_samplers({"train": 0.6, "val": 0.2, "test": 0.2})
print(samples)


0 comments on commit 2670973

Please sign in to comment.