rl4lms.algorithms.common.maskable package
Submodules
rl4lms.algorithms.common.maskable.buffers module
Code adapted from https://github.com/DLR-RM/stable-baselines3
- class rl4lms.algorithms.common.maskable.buffers.MaskableRolloutBufferSamples(observations, actions, old_values, old_log_prob, advantages, returns, action_masks)
Bases:
NamedTuple
- observations
Alias for field number 0
- actions
Alias for field number 1
- old_values
Alias for field number 2
- old_log_prob
Alias for field number 3
- advantages
Alias for field number 4
- returns
Alias for field number 5
- action_masks
Alias for field number 6
- class rl4lms.algorithms.common.maskable.buffers.MaskableDictRolloutBufferSamples(observations, actions, old_values, old_log_prob, advantages, returns, action_masks)
Bases:
MaskableRolloutBufferSamples
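The "Alias for field number N" entries above follow from the sample classes being NamedTuples: each field can be read by name or by position. A minimal sketch of that behaviour, using a hypothetical stand-in class (`SketchSamples`) with plain Python lists instead of the real field types, which this page does not show:

```python
from typing import List, NamedTuple


class SketchSamples(NamedTuple):
    """Hypothetical stand-in with the same seven fields, in the same order."""
    observations: List[float]
    actions: List[int]
    old_values: List[float]
    old_log_prob: List[float]
    advantages: List[float]
    returns: List[float]
    action_masks: List[List[int]]


batch = SketchSamples(
    observations=[0.1, 0.2],
    actions=[1, 0],
    old_values=[0.5, 0.4],
    old_log_prob=[-0.7, -0.9],
    advantages=[0.3, -0.1],
    returns=[0.8, 0.3],
    action_masks=[[0, 1], [1, 1]],
)

# Field names and positional indices refer to the same objects.
assert batch.observations is batch[0]
assert batch.action_masks is batch[6]

# A NamedTuple also unpacks in declaration order.
obs, actions, *_rest, masks = batch
```

Access by name is usually the safer choice in training code, since it keeps working if fields are ever reordered.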
- class rl4lms.algorithms.common.maskable.buffers.MaskableRolloutBuffer(*args, **kwargs)
Bases:
RolloutBuffer
Rollout buffer that also stores the invalid action masks associated with each observation.
- Parameters:
buffer_size – Maximum number of elements in the buffer
observation_space – Observation space
action_space – Action space
device – Device on which sampled tensors are placed, for example 'cpu'
gae_lambda – Factor for the trade-off between bias and variance in the Generalized Advantage Estimator. Equivalent to the classic advantage when set to 1.
gamma – Discount factor
n_envs – Number of parallel environments
- add(*args, action_masks: ndarray | None = None, **kwargs) → None
- Parameters:
action_masks – Masks applied to constrain the choice of possible actions.
- get(batch_size: int | None = None) → Generator[MaskableRolloutBufferSamples, None, None]
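The add/get pair above can be pictured with a small pure-Python toy. This is only a sketch, not the rl4lms implementation: `ToyMaskableBuffer` and `ToySamples` are hypothetical names, only three of the seven fields are kept, and two behaviours are assumptions that the text above does not confirm (a missing mask is treated as "all actions valid", and `batch_size=None` yields everything as one batch).

```python
import random
from typing import Iterator, List, NamedTuple, Optional


class ToySamples(NamedTuple):
    observations: List[float]
    actions: List[int]
    action_masks: List[List[int]]


class ToyMaskableBuffer:
    """Hypothetical pure-Python buffer; not the rl4lms implementation."""

    def __init__(self, buffer_size: int, n_actions: int) -> None:
        self.buffer_size = buffer_size
        self.n_actions = n_actions
        self.observations: List[float] = []
        self.actions: List[int] = []
        self.action_masks: List[List[int]] = []

    def add(self, obs: float, action: int,
            action_masks: Optional[List[int]] = None) -> None:
        if len(self.observations) >= self.buffer_size:
            raise RuntimeError("buffer is full")
        # Assumption: a missing mask means every action is allowed.
        if action_masks is None:
            action_masks = [1] * self.n_actions
        self.observations.append(obs)
        self.actions.append(action)
        self.action_masks.append(list(action_masks))

    def get(self, batch_size: Optional[int] = None) -> Iterator[ToySamples]:
        indices = list(range(len(self.observations)))
        random.shuffle(indices)
        # Assumption: batch_size=None yields everything as a single batch.
        if batch_size is None:
            batch_size = max(1, len(indices))
        for start in range(0, len(indices), batch_size):
            chunk = indices[start:start + batch_size]
            # The mask travels with the transition it was recorded for.
            yield ToySamples(
                [self.observations[i] for i in chunk],
                [self.actions[i] for i in chunk],
                [self.action_masks[i] for i in chunk],
            )


buffer = ToyMaskableBuffer(buffer_size=4, n_actions=3)
buffer.add(0.1, action=2, action_masks=[0, 1, 1])
buffer.add(0.2, action=0)  # no mask given: all actions valid
buffer.add(0.3, action=1, action_masks=[1, 1, 0])

batches = list(buffer.get(batch_size=2))
```

The point of the sketch is that each minibatch carries the masks that were in force when its actions were taken, so the policy update can re-evaluate those actions under the same constraints.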
- class rl4lms.algorithms.common.maskable.buffers.MaskableDictRolloutBuffer(buffer_size: int, observation_space: gym.spaces.space.Space, action_space: gym.spaces.space.Space, device: torch.device | str = 'cpu', gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1)
Bases:
DictRolloutBuffer
Dict rollout buffer used in on-policy algorithms like A2C/PPO. Extends the RolloutBuffer to use dictionary observations.
It corresponds to buffer_size transitions collected using the current policy. This experience is discarded after the policy update. To use the PPO objective, the buffer also stores the current value of each state and the log probability of each action taken.
The term rollout here refers to the model-free notion and should not be confused with the concept of rollout used in model-based RL or planning. Hence, it is only involved in policy and value function training, not in action selection.
- Parameters:
buffer_size – Maximum number of elements in the buffer
observation_space – Observation space
action_space – Action space
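device – Device on which sampled tensors are placed, given as a device object or a string; defaults to 'cpu'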
gae_lambda – Factor for the trade-off between bias and variance in the Generalized Advantage Estimator. Equivalent to the classic advantage when set to 1.
gamma – Discount factor
n_envs – Number of parallel environments
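To make the roles of gamma and gae_lambda concrete, here is a minimal sketch of Generalized Advantage Estimation over one trajectory segment. It is the textbook recursion, not code taken from rl4lms; the function name `gae_advantages` is hypothetical, and the sketch assumes no episode ends inside the segment.

```python
from typing import List


def gae_advantages(rewards: List[float], values: List[float], last_value: float,
                   gamma: float = 0.99, gae_lambda: float = 1.0) -> List[float]:
    """Textbook GAE. Assumption: no episode ends inside the segment."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        # Bootstrap from last_value after the final stored step.
        next_value = last_value if t == len(rewards) - 1 else values[t + 1]
        delta = rewards[t] + gamma * next_value - values[t]  # TD residual
        running = delta + gamma * gae_lambda * running
        advantages[t] = running
    return advantages


rewards = [1.0, 1.0, 1.0]
values = [0.5, 0.4, 0.3]
last_value = 0.2
gamma = 0.9

adv = gae_advantages(rewards, values, last_value, gamma=gamma, gae_lambda=1.0)

# With gae_lambda=1 the first advantage equals the bootstrapped discounted
# return minus the value estimate of the first state.
discounted_return = sum(gamma**k * r for k, r in enumerate(rewards)) + gamma**3 * last_value
assert abs(adv[0] - (discounted_return - values[0])) < 1e-9
```

Setting gae_lambda to 0 instead reduces each advantage to the one-step TD residual, which has lower variance but leans more heavily on the value estimates.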
- __init__(buffer_size: int, observation_space: gym.spaces.space.Space, action_space: gym.spaces.space.Space, device: torch.device | str = 'cpu', gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1)
- add(*args, action_masks=None, **kwargs) → None
- Parameters:
action_masks – Masks applied to constrain the choice of possible actions.
- get(batch_size: int | None = None) → Generator[MaskableDictRolloutBufferSamples, None, None]
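The action_masks parameter is described above as constraining the choice of possible actions. One common way to apply such a mask is to restrict a softmax to the valid actions, so that masked-out actions receive probability zero. The sketch below illustrates that idea only; `masked_probabilities` is a hypothetical helper, and this page does not say how the rl4lms policies consume the stored masks.

```python
import math
from typing import List


def masked_probabilities(logits: List[float], mask: List[int]) -> List[float]:
    """Softmax over valid actions only; masked-out actions get probability 0."""
    if not any(mask):
        raise ValueError("at least one action must be valid")
    valid = [logit for logit, m in zip(logits, mask) if m]
    peak = max(valid)  # subtract the max for numerical stability
    weights = [math.exp(logit - peak) if m else 0.0
               for logit, m in zip(logits, mask)]
    total = sum(weights)
    return [w / total for w in weights]


# Action 1 is invalid here, so it can never be selected.
probs = masked_probabilities([2.0, 1.0, 0.5], mask=[1, 0, 1])
```

Storing the mask with each transition lets the same restriction be reapplied when the log probability of the taken action is recomputed during the update.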