Skip to content

Commit

Permalink
Fixed codestyle
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav S Deshmukh committed Sep 23, 2023
1 parent 2670973 commit 2e0aa86
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions src/samplers.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
"""Samplers for training, validation, and testing."""

import numpy as np
import pandas as pd
import abc

import numpy as np


class Sampler(abc.ABC):
"""Abstract base class for data samplers."""

def __init__(self):
"""Blank constructor."""
pass

@abc.abstractmethod
def create_samplers(self):
"""Create training, test, and validation samplers.
This should return a dictionary with "train", "val", "test" as keys and
indices of datapoints as values.
"""
Expand All @@ -24,8 +26,10 @@ def name():
"""Name of the sampling method."""
pass


class RandomSampler(Sampler):
"""Perform uniform random sampling on datapoints."""

def __init__(self, seed, dataset_size):
"""Initialize sampler.
Expand Down Expand Up @@ -66,21 +70,21 @@ def create_samplers(self, sample_config):
train_size = int(np.ceil(sample_config["train"] * self.dataset_size))
train_idx = idx_array[:train_size]
val_size = int(np.ceil(sample_config["val"] * self.dataset_size))
val_idx = idx_array[train_size: train_size + val_size]
test_idx = idx_array[train_size + val_size:]
val_idx = idx_array[train_size : train_size + val_size]
test_idx = idx_array[train_size + val_size :]

# Create samples
samples = {"train": train_idx, "val": val_idx, "test": test_idx}

return samples

@staticmethod
def name():
"""Name of the sampling method."""
return "random"


if __name__ == "__main__":
rs = RandomSampler(0, 100)
samples = rs.create_samplers({"train": 0.6, "val": 0.2, "test": 0.2})
print(samples)


0 comments on commit 2e0aa86

Please sign in to comment.