RL4LMs
latest
Getting Started
Installation
Quick Start - Train PPO/NLPO using pre-defined YAML configs
Custom Building Blocks
Module Guide
rl4lms.algorithms package
rl4lms.envs package
rl4lms.core_components package
rl4lms.data_pools package
RL4LMs
Index
Edit on GitHub
Index
_ | A | B | C | D | E | F | G | I | K | L | M | N | O | P | Q | R | S | T | U | W | X | Z
_
__call__() (rl4lms.envs.common.reward.RewardFunction class method)
__init__() (rl4lms.algorithms.common.maskable.buffers.MaskableDictRolloutBuffer method)
(rl4lms.algorithms.common.maskable.buffers.MaskableRolloutBuffer method)
(rl4lms.core_components.sampler.PrioritySampler method)
(rl4lms.data_pools.text_generation_pool.Sample method)
(rl4lms.data_pools.text_generation_pool.TextGenPool method)
(rl4lms.envs.common.action_space.ActionSpace method)
(rl4lms.envs.common.observation.BaseObservation method)
(rl4lms.envs.text_generation.caption_metrics.cider.Cider method)
(rl4lms.envs.text_generation.caption_metrics.cider.CiderScorer method)
(rl4lms.envs.text_generation.caption_metrics.spice.spice.Spice method)
(rl4lms.envs.text_generation.kl_controllers.KLController method)
(rl4lms.envs.text_generation.observation.Observation method)
(rl4lms.envs.text_generation.summ_metrics.summa_c.SummaCImager method)
(rl4lms.envs.text_generation.summ_metrics.summa_c.SummaCZS method)
A
action_history (rl4lms.envs.text_generation.observation.Observation attribute)
action_masks (rl4lms.algorithms.common.maskable.buffers.MaskableRolloutBufferSamples attribute)
action_to_ix() (rl4lms.envs.common.action_space.ActionSpace method)
actions (rl4lms.algorithms.common.maskable.buffers.MaskableRolloutBufferSamples attribute)
ActionSpace (class in rl4lms.envs.common.action_space)
add() (rl4lms.algorithms.common.maskable.buffers.MaskableDictRolloutBuffer method)
(rl4lms.algorithms.common.maskable.buffers.MaskableRolloutBuffer method)
(rl4lms.core_components.sampler.PrioritySampler method)
advantages (rl4lms.algorithms.common.maskable.buffers.MaskableRolloutBufferSamples attribute)
B
BaseObservation (class in rl4lms.envs.common.observation)
BaseObservationFeaturizer (class in rl4lms.envs.common.observation)
batcher() (in module rl4lms.envs.text_generation.summ_metrics.summa_c)
build_image() (rl4lms.envs.text_generation.summ_metrics.summa_c.SummaCImager method)
C
card_to_name() (in module rl4lms.envs.text_generation.summ_metrics.summa_c)
Cider (class in rl4lms.envs.text_generation.caption_metrics.cider)
CiderScorer (class in rl4lms.envs.text_generation.caption_metrics.cider)
CNNDailyMail (class in rl4lms.data_pools.custom_text_generation_pools)
CommonGen (class in rl4lms.data_pools.custom_text_generation_pools)
compute_cider() (rl4lms.envs.text_generation.caption_metrics.cider.CiderScorer method)
compute_doc_freq() (rl4lms.envs.text_generation.caption_metrics.cider.CiderScorer method)
compute_score() (rl4lms.envs.text_generation.caption_metrics.cider.Cider method)
(rl4lms.envs.text_generation.caption_metrics.cider.CiderScorer method)
(rl4lms.envs.text_generation.caption_metrics.spice.spice.Spice method)
conjugate_gradient_solver() (in module rl4lms.algorithms.common.algo_utils)
context_attention_mask_pt (rl4lms.envs.text_generation.observation.Observation attribute)
context_encoded_pt (rl4lms.envs.text_generation.observation.Observation attribute)
context_text (rl4lms.envs.text_generation.observation.Observation attribute)
cook_append() (rl4lms.envs.text_generation.caption_metrics.cider.CiderScorer method)
cook_refs() (in module rl4lms.envs.text_generation.caption_metrics.cider)
cook_test() (in module rl4lms.envs.text_generation.caption_metrics.cider)
copy() (rl4lms.envs.text_generation.caption_metrics.cider.CiderScorer method)
CRD3DialogueGeneration (class in rl4lms.data_pools.custom_text_generation_pools)
D
DailyDialog (class in rl4lms.data_pools.custom_text_generation_pools)
DEST_BASE_FOLDER (rl4lms.data_pools.custom_text_generation_pools.CRD3DialogueGeneration attribute)
DEST_EXTRACTED_FOLDER (rl4lms.data_pools.custom_text_generation_pools.CRD3DialogueGeneration attribute)
dict_hash() (in module rl4lms.core_components.sweep)
download_file_using_url() (in module rl4lms.data_pools.custom_text_generation_pools)
E
EOU_TOKEN (rl4lms.data_pools.custom_text_generation_pools.DailyDialog attribute)
F
featurize() (rl4lms.envs.common.observation.BaseObservationFeaturizer method)
find_products() (in module rl4lms.core_components.sweep)
flat_grad() (in module rl4lms.algorithms.common.algo_utils)
float_convert() (rl4lms.envs.text_generation.caption_metrics.spice.spice.Spice method)
G
gen_split_name() (rl4lms.data_pools.custom_text_generation_pools.CommonGen static method)
(rl4lms.data_pools.custom_text_generation_pools.ToTTo static method)
get() (rl4lms.algorithms.common.maskable.buffers.MaskableDictRolloutBuffer method)
(rl4lms.algorithms.common.maskable.buffers.MaskableRolloutBuffer method)
get_action_masks() (in module rl4lms.algorithms.common.maskable.utils)
get_all_samples() (rl4lms.core_components.sampler.PrioritySampler method)
get_cache_file() (rl4lms.envs.text_generation.summ_metrics.summa_c.SummaCImager method)
get_dataset() (rl4lms.data_pools.custom_text_generation_pools.IWSLT2017EnDe class method)
(rl4lms.data_pools.custom_text_generation_pools.WMT class method)
(rl4lms.data_pools.custom_text_generation_pools.WMT14PreprocessedEnDe class method)
(rl4lms.data_pools.custom_text_generation_pools.WMT16NewsOnlyDatasetEnDe class method)
get_dict_obj() (in module rl4lms.core_components.sweep)
get_neutral_idx() (in module rl4lms.envs.text_generation.summ_metrics.summa_c)
get_observation_dim() (rl4lms.envs.common.observation.BaseObservationFeaturizer method)
get_state_dict() (rl4lms.envs.text_generation.kl_controllers.KLController method)
I
id (rl4lms.data_pools.text_generation_pool.Sample attribute)
IMDB (class in rl4lms.data_pools.custom_text_generation_pools)
IMDBForSeq2Seq (class in rl4lms.data_pools.custom_text_generation_pools)
init_from_sample() (rl4lms.envs.text_generation.observation.Observation class method)
input_attention_mask_pt (rl4lms.envs.text_generation.observation.Observation attribute)
input_encoded_pt (rl4lms.envs.text_generation.observation.Observation attribute)
is_masking_supported() (in module rl4lms.algorithms.common.maskable.utils)
IWSLT2017EnDe (class in rl4lms.data_pools.custom_text_generation_pools)
ix_to_action() (rl4lms.envs.common.action_space.ActionSpace method)
K
kl_coeff (rl4lms.envs.text_generation.kl_controllers.KLController property)
KLController (class in rl4lms.envs.text_generation.kl_controllers)
L
load_cache() (rl4lms.envs.text_generation.summ_metrics.summa_c.SummaCImager method)
load_from_state_dict() (rl4lms.envs.text_generation.kl_controllers.KLController method)
load_nli() (rl4lms.envs.text_generation.summ_metrics.summa_c.SummaCImager method)
M
MaskableDictRolloutBuffer (class in rl4lms.algorithms.common.maskable.buffers)
MaskableDictRolloutBufferSamples (class in rl4lms.algorithms.common.maskable.buffers)
MaskableRolloutBuffer (class in rl4lms.algorithms.common.maskable.buffers)
MaskableRolloutBufferSamples (class in rl4lms.algorithms.common.maskable.buffers)
meta_data (rl4lms.data_pools.text_generation_pool.Sample attribute)
meta_info (rl4lms.envs.text_generation.observation.Observation attribute)
method() (rl4lms.envs.text_generation.caption_metrics.cider.Cider method)
(rl4lms.envs.text_generation.caption_metrics.spice.spice.Spice method)
module
rl4lms.algorithms
rl4lms.algorithms.a2c
rl4lms.algorithms.common
rl4lms.algorithms.common.algo_utils
rl4lms.algorithms.common.maskable
rl4lms.algorithms.common.maskable.buffers
rl4lms.algorithms.common.maskable.utils
rl4lms.algorithms.ppo
rl4lms.core_components
rl4lms.core_components.sampler
rl4lms.core_components.sweep
rl4lms.data_pools
rl4lms.data_pools.custom_text_generation_pools
rl4lms.data_pools.text_generation_pool
rl4lms.envs
rl4lms.envs.common
rl4lms.envs.common.action_space
rl4lms.envs.common.observation
rl4lms.envs.common.reward
rl4lms.envs.text_generation
rl4lms.envs.text_generation.caption_metrics
rl4lms.envs.text_generation.caption_metrics.cider
rl4lms.envs.text_generation.caption_metrics.spice
rl4lms.envs.text_generation.caption_metrics.spice.spice
rl4lms.envs.text_generation.kl_controllers
rl4lms.envs.text_generation.observation
rl4lms.envs.text_generation.policy, [1]
rl4lms.envs.text_generation.post_processors
rl4lms.envs.text_generation.summ_metrics
rl4lms.envs.text_generation.summ_metrics.summa_c
rl4lms.envs.text_generation.test_datapool
N
name_to_card() (in module rl4lms.envs.text_generation.summ_metrics.summa_c)
NarrativeQA (class in rl4lms.data_pools.custom_text_generation_pools)
normalize_text() (rl4lms.data_pools.custom_text_generation_pools.NarrativeQA class method)
O
Observation (class in rl4lms.envs.text_generation.observation)
observations (rl4lms.algorithms.common.maskable.buffers.MaskableRolloutBufferSamples attribute)
old_log_prob (rl4lms.algorithms.common.maskable.buffers.MaskableRolloutBufferSamples attribute)
old_values (rl4lms.algorithms.common.maskable.buffers.MaskableRolloutBufferSamples attribute)
P
PATH_TO_ALIGNED_DATA (rl4lms.data_pools.custom_text_generation_pools.CRD3DialogueGeneration attribute)
PATH_TO_CLEANED_DATA (rl4lms.data_pools.custom_text_generation_pools.CRD3DialogueGeneration attribute)
precook() (in module rl4lms.envs.text_generation.caption_metrics.cider)
prepare() (rl4lms.data_pools.custom_text_generation_pools.CNNDailyMail class method)
(rl4lms.data_pools.custom_text_generation_pools.CommonGen class method)
(rl4lms.data_pools.custom_text_generation_pools.CRD3DialogueGeneration class method)
(rl4lms.data_pools.custom_text_generation_pools.DailyDialog class method)
(rl4lms.data_pools.custom_text_generation_pools.IMDB class method)
(rl4lms.data_pools.custom_text_generation_pools.IMDBForSeq2Seq class method)
(rl4lms.data_pools.custom_text_generation_pools.IWSLT2017EnDe class method)
(rl4lms.data_pools.custom_text_generation_pools.NarrativeQA class method)
(rl4lms.data_pools.custom_text_generation_pools.ToTTo class method)
(rl4lms.data_pools.custom_text_generation_pools.WMT class method)
(rl4lms.data_pools.custom_text_generation_pools.WMT14PreprocessedEnDe class method)
(rl4lms.data_pools.custom_text_generation_pools.WMT16NewsOnlyDatasetEnDe class method)
(rl4lms.data_pools.custom_text_generation_pools.Xsum class method)
(rl4lms.data_pools.text_generation_pool.TextGenPool class method)
(rl4lms.envs.text_generation.test_datapool.TestTextGenPool class method)
PrioritySampler (class in rl4lms.core_components.sampler)
prompt_or_input_attention_mask_pt (rl4lms.envs.text_generation.observation.Observation attribute)
prompt_or_input_encoded_pt (rl4lms.envs.text_generation.observation.Observation attribute)
prompt_or_input_text (rl4lms.data_pools.text_generation_pool.Sample attribute)
(rl4lms.envs.text_generation.observation.Observation attribute)
Q
quantile_huber_loss() (in module rl4lms.algorithms.common.algo_utils)
R
references (rl4lms.data_pools.text_generation_pool.Sample attribute)
reset() (rl4lms.algorithms.common.maskable.buffers.MaskableDictRolloutBuffer method)
(rl4lms.algorithms.common.maskable.buffers.MaskableRolloutBuffer method)
returns (rl4lms.algorithms.common.maskable.buffers.MaskableRolloutBufferSamples attribute)
RewardFunction (class in rl4lms.envs.common.reward)
rl4lms.algorithms
module
rl4lms.algorithms.a2c
module
rl4lms.algorithms.common
module
rl4lms.algorithms.common.algo_utils
module
rl4lms.algorithms.common.maskable
module
rl4lms.algorithms.common.maskable.buffers
module
rl4lms.algorithms.common.maskable.utils
module
rl4lms.algorithms.ppo
module
rl4lms.core_components
module
rl4lms.core_components.sampler
module
rl4lms.core_components.sweep
module
rl4lms.data_pools
module
rl4lms.data_pools.custom_text_generation_pools
module
rl4lms.data_pools.text_generation_pool
module
rl4lms.envs
module
rl4lms.envs.common
module
rl4lms.envs.common.action_space
module
rl4lms.envs.common.observation
module
rl4lms.envs.common.reward
module
rl4lms.envs.text_generation
module
rl4lms.envs.text_generation.caption_metrics
module
rl4lms.envs.text_generation.caption_metrics.cider
module
rl4lms.envs.text_generation.caption_metrics.spice
module
rl4lms.envs.text_generation.caption_metrics.spice.spice
module
rl4lms.envs.text_generation.kl_controllers
module
rl4lms.envs.text_generation.observation
module
rl4lms.envs.text_generation.policy
module, [1]
rl4lms.envs.text_generation.post_processors
module
rl4lms.envs.text_generation.summ_metrics
module
rl4lms.envs.text_generation.summ_metrics.summa_c
module
rl4lms.envs.text_generation.test_datapool
module
S
Sample (class in rl4lms.data_pools.text_generation_pool)
sample() (rl4lms.core_components.sampler.PrioritySampler method)
(rl4lms.data_pools.text_generation_pool.TextGenPool method)
save_cache() (rl4lms.envs.text_generation.summ_metrics.summa_c.SummaCImager method)
save_imager_cache() (rl4lms.envs.text_generation.summ_metrics.summa_c.SummaCZS method)
score() (rl4lms.envs.text_generation.summ_metrics.summa_c.SummaCZS method)
score_one() (rl4lms.envs.text_generation.summ_metrics.summa_c.SummaCZS method)
size() (rl4lms.envs.common.action_space.ActionSpace method)
(rl4lms.envs.text_generation.caption_metrics.cider.CiderScorer method)
SOURCE_URL (rl4lms.data_pools.custom_text_generation_pools.CRD3DialogueGeneration attribute)
Spice (class in rl4lms.envs.text_generation.caption_metrics.spice.spice)
split() (rl4lms.data_pools.text_generation_pool.TextGenPool method)
split_2sents() (rl4lms.envs.text_generation.summ_metrics.summa_c.SummaCImager method)
split_config() (in module rl4lms.core_components.sweep)
split_paragraphs() (rl4lms.envs.text_generation.summ_metrics.summa_c.SummaCImager method)
split_sentences() (rl4lms.envs.text_generation.summ_metrics.summa_c.SummaCImager method)
split_text() (rl4lms.envs.text_generation.summ_metrics.summa_c.SummaCImager method)
step() (rl4lms.envs.text_generation.kl_controllers.KLController method)
SummaCImager (class in rl4lms.envs.text_generation.summ_metrics.summa_c)
SummaCZS (class in rl4lms.envs.text_generation.summ_metrics.summa_c)
T
target_or_reference_texts (rl4lms.envs.text_generation.observation.Observation attribute)
TestTextGenPool (class in rl4lms.envs.text_generation.test_datapool)
TextGenPool (class in rl4lms.data_pools.text_generation_pool)
three_sentence_summary() (in module rl4lms.envs.text_generation.post_processors)
to_dict() (rl4lms.envs.text_generation.observation.Observation method)
to_expand() (in module rl4lms.core_components.sweep)
tokenize() (rl4lms.envs.text_generation.caption_metrics.cider.Cider method)
(rl4lms.envs.text_generation.caption_metrics.spice.spice.Spice method)
tokenize_rewards() (in module rl4lms.algorithms.common.algo_utils)
ToTTo (class in rl4lms.data_pools.custom_text_generation_pools)
U
update() (rl4lms.core_components.sampler.PrioritySampler method)
(rl4lms.envs.text_generation.observation.Observation method)
W
WMT (class in rl4lms.data_pools.custom_text_generation_pools)
WMT14PreprocessedEnDe (class in rl4lms.data_pools.custom_text_generation_pools)
WMT16NewsOnlyDatasetEnDe (class in rl4lms.data_pools.custom_text_generation_pools)
X
Xsum (class in rl4lms.data_pools.custom_text_generation_pools)
Z
ZIP_FILE_NAME (rl4lms.data_pools.custom_text_generation_pools.CRD3DialogueGeneration attribute)