rl4lms.envs.text_generation package

Subpackages

Submodules

rl4lms.envs.text_generation.alg_wrappers module

rl4lms.envs.text_generation.env module

rl4lms.envs.text_generation.evaluation_utils module

rl4lms.envs.text_generation.hf_generation_utils module

rl4lms.envs.text_generation.kl_controllers module

class rl4lms.envs.text_generation.kl_controllers.KLController(kl_coeff: float, target_kl: float | None = None)

Bases: object

__init__(kl_coeff: float, target_kl: float | None = None) → None
step(kl_div: torch.Tensor)

Adapts the KL coefficient based on the observed KL divergence kl_div.

property kl_coeff
get_state_dict() → Dict[str, Any]
load_from_state_dict(state_dict: Dict[str, Any])
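The adaptation rule used by step is not documented on this page. A common choice is the proportional controller from Ziegler et al. (2019), which multiplies the coefficient by (1 + K·e), where e is the relative gap between the observed and target KL, clipped to a small range. The sketch below implements that rule in plain Python with floats instead of tensors. The class name SketchKLController and the gain and clip parameters are hypothetical, and this is not necessarily how RL4LMs' KLController.step behaves.

```python
from typing import Any, Dict, Optional


class SketchKLController:
    """Illustrative adaptive KL controller (hypothetical, not the RL4LMs source)."""

    def __init__(self, kl_coeff: float, target_kl: Optional[float] = None,
                 gain: float = 0.1, clip: float = 0.2) -> None:
        self._kl_coeff = kl_coeff
        self._target_kl = target_kl
        self._gain = gain  # hypothetical step size (K in Ziegler et al.)
        self._clip = clip  # hypothetical bound on the relative error

    def step(self, kl_div: float) -> None:
        # Without a target the coefficient stays fixed.
        if self._target_kl is None:
            return
        # Relative error between observed and target KL, clipped for stability.
        error = (kl_div - self._target_kl) / self._target_kl
        error = max(-self._clip, min(self._clip, error))
        # Increase the penalty when KL overshoots, decrease it when KL undershoots.
        self._kl_coeff *= 1.0 + self._gain * error

    @property
    def kl_coeff(self) -> float:
        return self._kl_coeff

    def get_state_dict(self) -> Dict[str, Any]:
        # Everything needed to resume adaptation after a checkpoint.
        return {"kl_coeff": self._kl_coeff, "target_kl": self._target_kl}

    def load_from_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self._kl_coeff = state_dict["kl_coeff"]
        self._target_kl = state_dict["target_kl"]
```

For example, with kl_coeff=0.1 and target_kl=0.5, an observed KL of 1.0 gives a clipped error of 0.2, so the coefficient becomes 0.1 × 1.02 = 0.102. The clipping keeps a single outlier batch from changing the penalty by more than a few percent.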

rl4lms.envs.text_generation.logging_utils module

rl4lms.envs.text_generation.metric module

rl4lms.envs.text_generation.observation module

class rl4lms.envs.text_generation.observation.Observation(prompt_or_input_encoded_pt: torch.Tensor, prompt_or_input_attention_mask_pt: torch.Tensor, prompt_or_input_text: str, context_encoded_pt: torch.Tensor, context_attention_mask_pt: torch.Tensor, context_text: str, target_or_reference_texts: List[str], input_encoded_pt: torch.Tensor, input_attention_mask_pt: torch.Tensor, action_history: List[str], meta_info: Dict[str, Any])

Bases: object

prompt_or_input_encoded_pt: torch.Tensor
prompt_or_input_attention_mask_pt: torch.Tensor
prompt_or_input_text: str
context_encoded_pt: torch.Tensor
context_attention_mask_pt: torch.Tensor
context_text: str
target_or_reference_texts: List[str]
input_encoded_pt: torch.Tensor
input_attention_mask_pt: torch.Tensor
action_history: List[str]
meta_info: Dict[str, Any]
to_dict() → Dict[str, torch.Tensor]

Returns the observation as a dictionary for Stable Baselines, containing only the tensor-valued fields.

update(action: int, tokenizer: AutoTokenizer) → Observation

Updates the observation with the given action (a token ID) and returns the resulting Observation.

classmethod init_from_sample(sample: Sample, tokenizer: AutoTokenizer, max_input_length: int, max_context_length: int, prompt_truncation_side: str, context_start_token: int | None = None, meta_info: Dict[str, Any] | None = None)
__init__(prompt_or_input_encoded_pt: torch.Tensor, prompt_or_input_attention_mask_pt: torch.Tensor, prompt_or_input_text: str, context_encoded_pt: torch.Tensor, context_attention_mask_pt: torch.Tensor, context_text: str, target_or_reference_texts: List[str], input_encoded_pt: torch.Tensor, input_attention_mask_pt: torch.Tensor, action_history: List[str], meta_info: Dict[str, Any]) → None
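The fields above suggest that the model input is the prompt followed by a context window that grows as actions are taken. The sketch below illustrates one plausible way update and to_dict could fit together, assuming a fixed-length, left-padded context window (pad ID 0) that slides left by one slot per action. It uses plain lists in place of tensors; the class SketchObservation and its field names are hypothetical, it records raw token IDs in action_history (the documented field is typed List[str]), and it is not the RL4LMs implementation.

```python
from dataclasses import dataclass, field, replace
from typing import Dict, List


@dataclass(frozen=True)
class SketchObservation:
    """Toy stand-in for Observation; token IDs are plain lists, not tensors."""

    prompt_ids: List[int]
    prompt_mask: List[int]
    context_ids: List[int]    # fixed-length window, left-padded with 0 (assumed)
    context_mask: List[int]   # 1 marks a real token, 0 marks padding
    prompt_text: str
    action_history: List[int] = field(default_factory=list)

    def update(self, action: int) -> "SketchObservation":
        # Drop the oldest context slot and append the new token so the
        # window length stays constant; the original object is left untouched.
        return replace(
            self,
            context_ids=self.context_ids[1:] + [action],
            context_mask=self.context_mask[1:] + [1],
            action_history=self.action_history + [action],
        )

    def to_dict(self) -> Dict[str, List[int]]:
        # Keep only the "tensor" fields; text and history are left out.
        return {
            "prompt_or_input_encoded_pt": self.prompt_ids,
            "prompt_or_input_attention_mask_pt": self.prompt_mask,
            "context_encoded_pt": self.context_ids,
            "context_attention_mask_pt": self.context_mask,
            # The model input is the prompt followed by the context window.
            "input_encoded_pt": self.prompt_ids + self.context_ids,
            "input_attention_mask_pt": self.prompt_mask + self.context_mask,
        }
```

Returning a new object from update, rather than mutating in place, keeps earlier observations valid if a rollout buffer still holds references to them.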

rl4lms.envs.text_generation.policy module

rl4lms.envs.text_generation.post_processors module

rl4lms.envs.text_generation.post_processors.three_sentence_summary(text)

Returns the first three sentences of the generated text.
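This page does not say how sentences are detected. The sketch below shows the idea with a naive regular-expression splitter that breaks after ".", "!" or "?" followed by whitespace; that splitter is a swapped-in technique chosen for illustration, and the function name three_sentence_summary_sketch is hypothetical.

```python
import re
from typing import List

# Naive sentence boundary: whitespace that follows ".", "!" or "?".
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+")


def split_sentences(text: str) -> List[str]:
    # Drop empty pieces, e.g. when the input is empty or only whitespace.
    return [s for s in _SENTENCE_BOUNDARY.split(text.strip()) if s]


def three_sentence_summary_sketch(text: str) -> str:
    """Keep at most the first three sentences, joined by single spaces."""
    return " ".join(split_sentences(text)[:3])


print(three_sentence_summary_sketch("One. Two! Three? Four."))
```

A regex splitter mis-handles abbreviations such as "Dr. Smith", so a trained tokenizer (for example NLTK's sent_tokenize) is usually preferable on real summaries.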

rl4lms.envs.text_generation.preference_reward module

rl4lms.envs.text_generation.registry module

rl4lms.envs.text_generation.reward module

rl4lms.envs.text_generation.test_datapool module

class rl4lms.envs.text_generation.test_datapool.TestTextGenPool(samples: List[Sample])

Bases: TextGenPool

classmethod prepare(split: str, prompt: str, n_samples=100)

Factory method that builds a data pool for the given split.
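The factory-classmethod pattern used by prepare can be sketched as follows: the classmethod builds the list of samples and passes it to the constructor, so subclasses inherit the same entry point. ToySample, ToyTextGenPool and the sample fields are hypothetical stand-ins, and generating n_samples copies of the prompt is an assumption about what a test pool might contain, not a description of TestTextGenPool itself.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToySample:
    # Hypothetical stand-in for rl4lms' Sample; field names are assumptions.
    id: str
    prompt_or_input_text: str
    references: List[str] = field(default_factory=list)


class ToyTextGenPool:
    def __init__(self, samples: List[ToySample]) -> None:
        self._samples = samples

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, ix: int) -> ToySample:
        return self._samples[ix]

    @classmethod
    def prepare(cls, split: str, prompt: str, n_samples: int = 100) -> "ToyTextGenPool":
        # Build n_samples samples sharing the same prompt, tagged with the
        # split name, then hand them to the constructor via cls.
        samples = [
            ToySample(id=f"{split}_{i}", prompt_or_input_text=prompt)
            for i in range(n_samples)
        ]
        return cls(samples)


pool = ToyTextGenPool.prepare("train", "Hello", n_samples=3)
print(len(pool), pool[0].id)
```

Calling cls(...) rather than ToyTextGenPool(...) means a subclass that overrides the constructor still gets an instance of its own type from prepare.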

rl4lms.envs.text_generation.test_metric module

rl4lms.envs.text_generation.test_reward module

rl4lms.envs.text_generation.training_utils module

rl4lms.envs.text_generation.utils_supervised module

rl4lms.envs.text_generation.warm_start module

Module contents