Source code for rl4lms.data_pools.text_generation_pool

import random
from abc import abstractclassmethod, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple


@dataclass
class Sample:
    """One entry of a text-generation data pool.

    Attributes:
        id: String identifier of the sample.
        prompt_or_input_text: The prompt / input text of the sample.
        references: List of reference strings attached to the sample.
        meta_data: Optional dictionary of arbitrary extra values; ``None``
            when the caller supplies nothing.
    """

    id: str
    prompt_or_input_text: str
    references: List[str]
    # The default is ``None`` rather than a dict, so the annotation must be
    # Optional (PEP 484 does not allow an implicit Optional).
    meta_data: Optional[Dict[str, Any]] = None
class TextGenPool:
    """An indexable, splittable collection of :class:`Sample` objects.

    The pool supports ``len()``, integer indexing (and therefore plain
    ``for`` iteration through the sequence protocol), uniform random
    sampling and splitting into consecutive sub-pools.
    """

    def __init__(self, samples: List[Sample]):
        """Create a pool backed by ``samples``.

        Args:
            samples: The samples held by the pool. The list is stored as-is
                (not copied).
        """
        self._samples = samples

    def __len__(self) -> int:
        """Return the number of samples in the pool."""
        return len(self._samples)

    def __getitem__(self, ix: int) -> Tuple[Sample, float]:
        """Return the sample at position ``ix`` together with the value ``1.0``.

        Args:
            ix: Position of the requested sample. Negative values are
                forwarded to the underlying list unchanged.

        Returns:
            A ``(sample, 1.0)`` tuple. The second element is a constant
            (presumably a per-sample weight -- confirm against callers).

        Raises:
            IndexError: If ``ix`` is at or beyond the end of the pool.
        """
        if ix >= len(self):
            # IndexError (not StopIteration) is what the sequence protocol
            # expects: ``for`` loops still terminate cleanly on it, while a
            # stray StopIteration would silently cut short an enclosing
            # ``map()`` or turn into a RuntimeError inside a generator.
            raise IndexError(
                f"sample index {ix} out of range for pool of size {len(self)}"
            )
        return self._samples[ix], 1.0

    def sample(self) -> Sample:
        """Return one sample chosen uniformly at random from the pool."""
        return random.choice(self._samples)

    @classmethod
    @abstractmethod
    def prepare(cls, **args) -> 'TextGenPool':
        """Factory method that instantiates a data pool.

        The base implementation always raises; concrete pools are expected
        to provide their own version.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def split(self, split_ratios: List[float]) -> List['TextGenPool']:
        """Cut the pool into consecutive sub-pools.

        Args:
            split_ratios: One fraction of the pool size per requested
                sub-pool, consumed in order.

        Returns:
            One pool of the same type as ``self`` per ratio. Each holds
            ``int(len(self) * ratio)`` samples, taken from where the previous
            sub-pool ended; because the count is truncated, trailing samples
            may end up in no sub-pool.
        """
        pools = []
        start_ix = 0
        for ratio in split_ratios:
            end_ix = start_ix + int(len(self) * ratio)
            # type(self) keeps subclasses intact: the constructor of the
            # concrete class is called with the slice of samples.
            pools.append(type(self)(self._samples[start_ix:end_ix]))
            start_ix = end_ix
        return pools