rl4lms.algorithms.common.maskable package

Submodules

rl4lms.algorithms.common.maskable.buffers module

Code adapted from https://github.com/DLR-RM/stable-baselines3

class rl4lms.algorithms.common.maskable.buffers.MaskableRolloutBufferSamples(observations, actions, old_values, old_log_prob, advantages, returns, action_masks)

Bases: NamedTuple

observations

Alias for field number 0

actions

Alias for field number 1

old_values

Alias for field number 2

old_log_prob

Alias for field number 3

advantages

Alias for field number 4

returns

Alias for field number 5

action_masks

Alias for field number 6

class rl4lms.algorithms.common.maskable.buffers.MaskableDictRolloutBufferSamples(observations, actions, old_values, old_log_prob, advantages, returns, action_masks)

Bases: MaskableRolloutBufferSamples
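The two sample classes above are plain NamedTuples, so every field can be read by name or by its field number, and the dict variant inherits the same seven fields. The sketch below illustrates this with hypothetical stand-ins (SamplesSketch, DictSamplesSketch) that use numpy arrays instead of the library's actual field types; it is not the library's code.

```python
from typing import Any, NamedTuple

import numpy as np


# Hypothetical stand-in with the same seven fields, in the same order, as
# MaskableRolloutBufferSamples. numpy arrays replace the real field types
# purely for illustration.
class SamplesSketch(NamedTuple):
    observations: Any  # an array here; a dict of arrays in the dict variant
    actions: np.ndarray
    old_values: np.ndarray
    old_log_prob: np.ndarray
    advantages: np.ndarray
    returns: np.ndarray
    action_masks: np.ndarray


# Subclassing a NamedTuple adds no new fields: the constructor and the
# field order are inherited, as with MaskableDictRolloutBufferSamples.
class DictSamplesSketch(SamplesSketch):
    pass


batch = DictSamplesSketch(
    observations={"obs": np.zeros((2, 3))},
    actions=np.array([0, 1]),
    old_values=np.zeros(2),
    old_log_prob=np.zeros(2),
    advantages=np.ones(2),
    returns=np.ones(2),
    action_masks=np.array([[True, False], [True, True]]),
)

# "Alias for field number 6": positional and named access give the same object.
masks = batch[6]
print(batch.action_masks is masks)  # True
```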

class rl4lms.algorithms.common.maskable.buffers.MaskableRolloutBuffer(*args, **kwargs)

Bases: RolloutBuffer

Rollout buffer that also stores the invalid action masks associated with each observation.

Parameters:
  • buffer_size – Maximum number of elements in the buffer

  • observation_space – Observation space

  • action_space – Action space

  • device – PyTorch device

  • gae_lambda – Factor for the trade-off of bias vs. variance in the Generalized Advantage Estimator. Equivalent to the classic advantage when set to 1.

  • gamma – Discount factor

  • n_envs – Number of parallel environments
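The gamma and gae_lambda parameters control how advantages and returns are computed from a finished rollout. The following is a simplified sketch of Generalized Advantage Estimation over arrays of shape (buffer_size, n_envs), written here for illustration rather than copied from the library. The function name compute_gae and the convention that dones[t] = 1 means the episode ended after step t are assumptions of this sketch; the library's own bookkeeping may use a different convention.

```python
import numpy as np


def compute_gae(rewards, values, dones, last_values, gamma=0.99, gae_lambda=1.0):
    """Hypothetical GAE sketch.

    rewards, values, dones: arrays of shape (buffer_size, n_envs).
    dones[t] == 1 means the episode ended after step t (assumed convention),
    so nothing is bootstrapped past that step.
    last_values: value estimates (n_envs,) for the observation after the last step.
    """
    buffer_size, n_envs = rewards.shape
    advantages = np.zeros((buffer_size, n_envs), dtype=np.float64)
    last_gae = np.zeros(n_envs)
    for t in reversed(range(buffer_size)):
        next_values = last_values if t == buffer_size - 1 else values[t + 1]
        not_done = 1.0 - dones[t]
        # One-step TD error for step t.
        delta = rewards[t] + gamma * next_values * not_done - values[t]
        # Exponentially weighted sum of future TD errors.
        last_gae = delta + gamma * gae_lambda * not_done * last_gae
        advantages[t] = last_gae
    # Returns are the value targets: advantage plus the baseline.
    returns = advantages + values
    return advantages, returns
```

With gae_lambda=1 the recursion telescopes to the discounted return minus the value baseline (the "classic advantage" mentioned above); with gae_lambda=0 it reduces to the one-step TD error, which has lower variance but more bias.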

__init__(*args, **kwargs)
reset() -> None

Reset the buffer.

add(*args, action_masks: ndarray | None = None, **kwargs) -> None
Parameters:

action_masks – Masks applied to constrain the choice of possible actions.
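A rollout buffer that "also stores the invalid action masks associated with each observation" needs one mask row per (step, environment) pair, written at the current buffer position each time add is called. The class below is a hypothetical sketch of that bookkeeping (MaskStoreSketch and n_actions are names introduced here, not library API). Defaulting every mask to 1 so that an omitted action_masks argument means "all actions allowed" is a design choice of this sketch.

```python
from typing import Optional

import numpy as np


class MaskStoreSketch:
    """Hypothetical sketch: one mask row per (step, env), filled by add()."""

    def __init__(self, buffer_size: int, n_envs: int, n_actions: int):
        self.buffer_size = buffer_size
        self.n_envs = n_envs
        self.n_actions = n_actions
        self.reset()

    def reset(self) -> None:
        # All ones: a step added without masks leaves every action valid.
        self.action_masks = np.ones(
            (self.buffer_size, self.n_envs, self.n_actions), dtype=np.float32
        )
        self.pos = 0
        self.full = False

    def add(self, action_masks: Optional[np.ndarray] = None) -> None:
        if self.full:
            raise IndexError("buffer is full; call reset() first")
        if action_masks is not None:
            # Accept any array with n_envs * n_actions entries.
            self.action_masks[self.pos] = np.asarray(action_masks).reshape(
                self.n_envs, self.n_actions
            )
        self.pos += 1
        self.full = self.pos == self.buffer_size
```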

get(batch_size: int | None = None) -> Generator[MaskableRolloutBufferSamples, None, None]
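get yields minibatches as a generator. A common way to do this, sketched below under stated assumptions, is to merge the step and environment axes so that each (step, env) pair becomes one sample, shuffle the sample indices, and slice them into chunks of batch_size, returning the whole buffer as one batch when batch_size is None. The function iterate_minibatches and its seed parameter are hypothetical, only two of the seven fields are carried for brevity, and the exact flattening order used by the library may differ. The key property is that each action stays paired with its own mask.

```python
from typing import Iterator, Optional, Tuple

import numpy as np


def iterate_minibatches(
    actions: np.ndarray,
    action_masks: np.ndarray,
    batch_size: Optional[int] = None,
    seed: int = 0,
) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
    """Yield shuffled (actions, masks) minibatches from (buffer_size, n_envs, ...) arrays."""
    buffer_size, n_envs = actions.shape[:2]
    n = buffer_size * n_envs
    # Merge the step and env axes; row k of both arrays refers to the same sample.
    flat_actions = actions.reshape(n, *actions.shape[2:])
    flat_masks = action_masks.reshape(n, *action_masks.shape[2:])
    if batch_size is None:
        batch_size = n  # no batch size: return everything in a single batch
    indices = np.random.default_rng(seed).permutation(n)
    for start in range(0, n, batch_size):
        idx = indices[start:start + batch_size]
        yield flat_actions[idx], flat_masks[idx]
```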
class rl4lms.algorithms.common.maskable.buffers.MaskableDictRolloutBuffer(buffer_size: int, observation_space: gym.spaces.space.Space, action_space: gym.spaces.space.Space, device: torch.device | str = 'cpu', gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1)

Bases: DictRolloutBuffer

Dict rollout buffer used in on-policy algorithms like A2C/PPO. Extends the RolloutBuffer to use dictionary observations. Like MaskableRolloutBuffer, it also stores the invalid action masks associated with each observation.

It corresponds to buffer_size transitions collected using the current policy. This experience will be discarded after the policy update. In order to use the PPO objective, we also store the current value of each state and the log probability of each taken action.

The term rollout here refers to the model-free notion and should not be confused with the concept of rollout used in model-based RL or planning. Hence, it is only involved in policy and value function training, not in action selection.
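Dictionary observations are usually stored as one preallocated array per key, each with the leading shape (buffer_size, n_envs), and the whole buffer is emptied after each policy update. The class below is a hypothetical sketch of that layout; DictObsStoreSketch and the observation keys used in the test ("prompt_ids", "context_ids") are illustrative names, not the library's API or RL4LMs' actual observation keys.

```python
from typing import Dict, Tuple

import numpy as np


class DictObsStoreSketch:
    """Hypothetical sketch: one (buffer_size, n_envs, *shape) array per observation key."""

    def __init__(self, buffer_size: int, n_envs: int, obs_shapes: Dict[str, Tuple[int, ...]]):
        self.buffer_size = buffer_size
        self.n_envs = n_envs
        self.observations = {
            key: np.zeros((buffer_size, n_envs, *shape), dtype=np.float32)
            for key, shape in obs_shapes.items()
        }
        self.pos = 0

    def add(self, obs: Dict[str, np.ndarray]) -> None:
        # Write each key's batch of n_envs observations at the current position.
        for key, value in obs.items():
            per_env_shape = self.observations[key].shape[2:]
            self.observations[key][self.pos] = np.asarray(value).reshape(self.n_envs, *per_env_shape)
        self.pos += 1

    def reset(self) -> None:
        # On-policy data is discarded after each update.
        for arr in self.observations.values():
            arr.fill(0)
        self.pos = 0
```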

Parameters:
  • buffer_size – Maximum number of elements in the buffer

  • observation_space – Observation space

  • action_space – Action space

  • device – PyTorch device

  • gae_lambda – Factor for the trade-off of bias vs. variance in the Generalized Advantage Estimator. Equivalent to the classic advantage when set to 1.

  • gamma – Discount factor

  • n_envs – Number of parallel environments

__init__(buffer_size: int, observation_space: gym.spaces.space.Space, action_space: gym.spaces.space.Space, device: torch.device | str = 'cpu', gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1)
reset() -> None

Reset the buffer.

add(*args, action_masks=None, **kwargs) -> None
Parameters:

action_masks – Masks applied to constrain the choice of possible actions.
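One way stored masks "constrain the choice of possible actions" is at sampling time: the logits of invalid actions are replaced by a very large negative number before the softmax, so those actions receive zero probability. The function below is a minimal numpy sketch of that technique; the name masked_softmax and the constant -1e8 are illustrative choices, not taken from this package.

```python
import numpy as np


def masked_softmax(logits: np.ndarray, action_masks: np.ndarray) -> np.ndarray:
    """Action probabilities with invalid actions (mask == 0) forced to zero."""
    # Invalid actions get a huge negative logit, so exp() underflows to 0.
    masked = np.where(action_masks.astype(bool), logits, -1e8)
    # Subtract the row maximum for numerical stability.
    masked = masked - masked.max(axis=-1, keepdims=True)
    exp = np.exp(masked)
    return exp / exp.sum(axis=-1, keepdims=True)
```

Note that if every action in a row is masked, all logits become equal and this sketch falls back to a uniform distribution; a real implementation might prefer to raise an error in that case.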

get(batch_size: int | None = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]

rl4lms.algorithms.common.maskable.callbacks module

rl4lms.algorithms.common.maskable.distributions module

rl4lms.algorithms.common.maskable.evaluation module

rl4lms.algorithms.common.maskable.logits_processor module

rl4lms.algorithms.common.maskable.policies module

rl4lms.algorithms.common.maskable.utils module

rl4lms.algorithms.common.maskable.utils.get_action_masks(env: Env | VecEnv) -> ndarray

Retrieves the invalid action masks exposed by a gym env

Parameters:

env – the Gym environment to get masks from

Returns:

A numpy array of the masks
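Conceptually, retrieving masks means calling a mask-returning method on the environment and, for a vectorized environment, stacking one mask row per sub-environment. The sketch below assumes the method is named action_masks (an assumption here, mirroring the convention in sb3-contrib) and uses a plain Python list of hypothetical ToyEnv objects as a stand-in for a VecEnv, so it runs without gym.

```python
from typing import List, Sequence, Union

import numpy as np


class ToyEnv:
    """Hypothetical env that exposes its valid actions via `action_masks()`."""

    def __init__(self, valid: Sequence[bool]):
        self.valid = list(valid)

    def action_masks(self) -> np.ndarray:
        return np.array(self.valid, dtype=bool)


def get_action_masks_sketch(env: Union[ToyEnv, List[ToyEnv]]) -> np.ndarray:
    # A list stands in for a vectorized env: stack one mask row per sub-env.
    if isinstance(env, list):
        return np.stack([e.action_masks() for e in env])
    return env.action_masks()
```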

rl4lms.algorithms.common.maskable.utils.is_masking_supported(env: Env | VecEnv) -> bool

Checks whether gym env exposes a method returning invalid action masks

Parameters:

env – the Gym environment to check

Returns:

True if the method is found, False otherwise
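The check described above amounts to duck typing: masking is considered supported when the environment has the mask-returning method. This minimal sketch assumes the method name action_masks and treats a plain list of environments as a hypothetical stand-in for a VecEnv, requiring every sub-environment to have the method; the library's own handling of VecEnv may differ.

```python
from typing import Any


class MaskedEnv:
    """Hypothetical env that supports masking."""

    def action_masks(self):
        return [True, False]


class PlainEnv:
    """Hypothetical env without a mask method."""


def is_masking_supported_sketch(env: Any, method_name: str = "action_masks") -> bool:
    # A list stands in for a vectorized env: every sub-env must expose the method.
    if isinstance(env, list):
        return all(hasattr(e, method_name) for e in env)
    return hasattr(env, method_name)
```

A training loop would typically call such a check once up front and only then fetch masks on every step, so environments without masks keep working unchanged.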

Module contents