rl4lms.envs.text_generation package
Subpackages
- rl4lms.envs.text_generation.caption_metrics package
- rl4lms.envs.text_generation.policy package
- rl4lms.envs.text_generation.summ_metrics package
Submodules
rl4lms.envs.text_generation.alg_wrappers module
rl4lms.envs.text_generation.env module
rl4lms.envs.text_generation.evaluation_utils module
rl4lms.envs.text_generation.hf_generation_utils module
rl4lms.envs.text_generation.kl_controllers module
rl4lms.envs.text_generation.logging_utils module
rl4lms.envs.text_generation.metric module
rl4lms.envs.text_generation.observation module
- class rl4lms.envs.text_generation.observation.Observation(prompt_or_input_encoded_pt: torch.tensor, prompt_or_input_attention_mask_pt: torch.tensor, prompt_or_input_text: str, context_encoded_pt: torch.tensor, context_attention_mask_pt: torch.tensor, context_text: str, target_or_reference_texts: List[str], input_encoded_pt: torch.tensor, input_attention_mask_pt: torch.tensor, action_history: List[str], meta_info: Dict[str, Any])
Bases: object
- prompt_or_input_encoded_pt: torch.tensor
- prompt_or_input_attention_mask_pt: torch.tensor
- prompt_or_input_text: str
- context_encoded_pt: torch.tensor
- context_attention_mask_pt: torch.tensor
- context_text: str
- target_or_reference_texts: List[str]
- input_encoded_pt: torch.tensor
- input_attention_mask_pt: torch.tensor
- action_history: List[str]
- meta_info: Dict[str, Any]
- to_dict() → Dict[str, torch.tensor]
Returns only the tensor-valued fields, for use with Stable Baselines.
- update(action: int, tokenizer: AutoTokenizer) → Observation
Updates the observation with the given action (a token id) and returns an Observation.
- classmethod init_from_sample(sample: Sample, tokenizer: AutoTokenizer, max_input_length: int, max_context_length: int, prompt_truncation_side: str, context_start_token: int | None = None, meta_info: Dict[str, Any] | None = None)
Creates an Observation from the given Sample.
- __init__(prompt_or_input_encoded_pt: torch.tensor, prompt_or_input_attention_mask_pt: torch.tensor, prompt_or_input_text: str, context_encoded_pt: torch.tensor, context_attention_mask_pt: torch.tensor, context_text: str, target_or_reference_texts: List[str], input_encoded_pt: torch.tensor, input_attention_mask_pt: torch.tensor, action_history: List[str], meta_info: Dict[str, Any]) → None
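The reference above describes update() only as updating the observation with an action. The sketch below shows one plausible way such an update can work, assuming a fixed-length, left-padded context window. The new token id is written into the rightmost context slot, the oldest slot is dropped, and prompt and context are re-joined so that all padding stays on the left. ToyObservation and concat_left_padded are hypothetical names, and plain Python lists of token ids stand in for torch tensors. The actual rl4lms implementation may differ; for example, its action_history is typed List[str], whereas this sketch stores ints.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


def concat_left_padded(
    ids_a: List[int], mask_a: List[int],
    ids_b: List[int], mask_b: List[int],
    pad_id: int,
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: join two left-padded sequences and move all padding to the far left."""
    real = [t for t, m in zip(ids_a + ids_b, mask_a + mask_b) if m == 1]
    n_pad = len(ids_a) + len(ids_b) - len(real)
    return [pad_id] * n_pad + real, [0] * n_pad + [1] * len(real)


@dataclass
class ToyObservation:
    """Hypothetical stand-in for Observation; 1-D int lists replace tensors."""
    prompt_ids: List[int]     # left-padded prompt token ids
    prompt_mask: List[int]    # 1 = real token, 0 = padding
    context_ids: List[int]    # fixed-length window of generated token ids
    context_mask: List[int]
    pad_id: int = 0
    action_history: List[int] = field(default_factory=list)  # ints here, not str

    def concatenated(self) -> Tuple[List[int], List[int]]:
        # Analogue of input_encoded_pt / input_attention_mask_pt
        return concat_left_padded(self.prompt_ids, self.prompt_mask,
                                  self.context_ids, self.context_mask,
                                  self.pad_id)

    def update(self, action: int) -> "ToyObservation":
        # Shift the context window left by one slot and write the new
        # token id into the rightmost slot; returns a new object.
        return ToyObservation(
            prompt_ids=list(self.prompt_ids),
            prompt_mask=list(self.prompt_mask),
            context_ids=self.context_ids[1:] + [action],
            context_mask=self.context_mask[1:] + [1],
            pad_id=self.pad_id,
            action_history=self.action_history + [action],
        )

    def to_dict(self) -> Dict[str, List[int]]:
        # Only the array-valued fields, in the spirit of to_dict()
        input_ids, input_mask = self.concatenated()
        return {
            "prompt_or_input_encoded_pt": list(self.prompt_ids),
            "prompt_or_input_attention_mask_pt": list(self.prompt_mask),
            "context_encoded_pt": list(self.context_ids),
            "context_attention_mask_pt": list(self.context_mask),
            "input_encoded_pt": input_ids,
            "input_attention_mask_pt": input_mask,
        }


# Two-token prompt (left-padded to 4) and an empty 3-slot context window
obs = ToyObservation(prompt_ids=[0, 0, 11, 12], prompt_mask=[0, 0, 1, 1],
                     context_ids=[0, 0, 0], context_mask=[0, 0, 0])
obs = obs.update(21).update(22)
print(obs.to_dict()["input_encoded_pt"])  # [0, 0, 0, 11, 12, 21, 22]
```

Keeping all padding on the left means the most recent tokens always sit at the right end of the input, which suits decoder-only models that generate from the last position.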