rl4lms.algorithms.common package
Subpackages
- rl4lms.algorithms.common.maskable package
- Submodules
- rl4lms.algorithms.common.maskable.buffers module
- rl4lms.algorithms.common.maskable.callbacks module
- rl4lms.algorithms.common.maskable.distributions module
- rl4lms.algorithms.common.maskable.evaluation module
- rl4lms.algorithms.common.maskable.logits_processor module
- rl4lms.algorithms.common.maskable.policies module
- rl4lms.algorithms.common.maskable.utils module
- Module contents
Submodules
rl4lms.algorithms.common.algo_utils module
- rl4lms.algorithms.common.algo_utils.quantile_huber_loss(current_quantiles: torch.Tensor, target_quantiles: torch.Tensor, cum_prob: torch.Tensor | None = None, sum_over_quantiles: bool = True) -> torch.Tensor
The quantile-regression loss, as described in the QR-DQN and TQC papers. Partially taken from https://github.com/bayesgroup/tqc_pytorch.
- Parameters:
current_quantiles – current estimate of quantiles, must be either (batch_size, n_quantiles) or (batch_size, n_critics, n_quantiles)
target_quantiles – target of quantiles, must be either (batch_size, n_target_quantiles), (batch_size, 1, n_target_quantiles), or (batch_size, n_critics, n_target_quantiles)
cum_prob – cumulative probabilities used to calculate the quantiles (called midpoints in the QR-DQN paper), must be either (batch_size, n_quantiles), (batch_size, 1, n_quantiles), or (batch_size, n_critics, n_quantiles). If None, unit quantiles are calculated.
sum_over_quantiles – whether to sum over the quantile dimension
- Returns:
the loss
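The weighting described above can be illustrated with a minimal pure-Python sketch for a single sample (no batch or critic dimension). The function name quantile_huber_loss_sketch is hypothetical, the Huber threshold of 1 is an assumption, and the midpoint formula (i + 0.5) / n_quantiles for the default cum_prob is assumed from the QR-DQN convention; this is not the library's implementation.

```python
def quantile_huber_loss_sketch(current_quantiles, target_quantiles, cum_prob=None):
    """Quantile Huber loss for one sample, summed over quantiles.

    Hypothetical helper for illustration only; assumes a Huber threshold of 1.
    """
    n_quantiles = len(current_quantiles)
    if cum_prob is None:
        # Assumed default: quantile midpoints (i + 0.5) / n_quantiles.
        cum_prob = [(i + 0.5) / n_quantiles for i in range(n_quantiles)]

    total = 0.0
    for theta, tau in zip(current_quantiles, cum_prob):
        for target in target_quantiles:
            delta = target - theta
            abs_delta = abs(delta)
            # Huber loss: quadratic near zero, linear beyond the threshold.
            huber = abs_delta - 0.5 if abs_delta > 1.0 else 0.5 * delta * delta
            # Asymmetric weight |tau - 1{delta < 0}| from quantile regression.
            weight = abs(tau - (1.0 if delta < 0 else 0.0))
            total += weight * huber
    # Sum over current quantiles, mean over target quantiles.
    return total / len(target_quantiles)


loss = quantile_huber_loss_sketch([0.0, 1.0], [0.5])
zero_loss = quantile_huber_loss_sketch([1.0], [1.0])
```

The sketch corresponds to sum_over_quantiles=True; averaging over both loops instead would correspond to the False case.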
- rl4lms.algorithms.common.algo_utils.conjugate_gradient_solver(matrix_vector_dot_fn: typing.Callable[[torch.Tensor], torch.Tensor], b, max_iter=10, residual_tol=1e-10) -> torch.Tensor
Finds an approximate solution to a set of linear equations Ax = b.
- Sources:
- Reference:
- Parameters:
matrix_vector_dot_fn – a function that right multiplies a matrix A by a vector v
b – the right hand term in the set of linear equations Ax = b
max_iter – the maximum number of iterations (default is 10)
residual_tol – residual tolerance for early stopping of the solving (default is 1e-10)
- Returns:
x, the approximate solution to the system of equations defined by matrix_vector_dot_fn and b
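For orientation, the textbook conjugate gradient iteration can be sketched in pure Python with the same argument names. The function conjugate_gradient_sketch and the matvec helper are hypothetical and operate on plain lists rather than tensors; the method assumes A is symmetric positive-definite. The stopping rule (squared residual norm below residual_tol) is an assumption, and this is not the library's implementation.

```python
def conjugate_gradient_sketch(matrix_vector_dot_fn, b, max_iter=10, residual_tol=1e-10):
    """Approximately solve Ax = b, where A is only available as v -> A v."""
    x = [0.0] * len(b)
    residual = list(b)          # r = b - A x, with x starting at zero
    direction = list(residual)  # initial search direction
    rs_old = sum(r * r for r in residual)
    if rs_old < residual_tol:
        return x

    for _ in range(max_iter):
        a_dir = matrix_vector_dot_fn(direction)
        # Step length that minimizes the error along the search direction.
        alpha = rs_old / sum(p * q for p, q in zip(direction, a_dir))
        x = [xi + alpha * pi for xi, pi in zip(x, direction)]
        residual = [ri - alpha * qi for ri, qi in zip(residual, a_dir)]
        rs_new = sum(r * r for r in residual)
        if rs_new < residual_tol:
            break  # early stopping once the residual is small enough
        beta = rs_new / rs_old
        direction = [ri + beta * pi for ri, pi in zip(residual, direction)]
        rs_old = rs_new
    return x


# Example system with a symmetric positive-definite matrix.
A = [[4.0, 1.0], [1.0, 3.0]]

def matvec(v):
    return [sum(a * vi for a, vi in zip(row, v)) for row in A]

solution = conjugate_gradient_sketch(matvec, [1.0, 2.0])
```

Passing a function instead of the matrix itself means A never has to be materialized, which is why the solver takes matrix_vector_dot_fn.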
- rl4lms.algorithms.common.algo_utils.flat_grad(output, parameters: typing.Sequence[torch.nn.parameter.Parameter], create_graph: bool = False, retain_graph: bool = False, device: str = 'cuda:0') -> torch.Tensor
Returns the gradients of output with respect to the passed sequence of parameters, flattened into a single tensor. The order of the parameters is preserved.
- Parameters:
output – functional output to compute the gradient for
parameters – sequence of Parameter
retain_graph – If False, the graph used to compute the grad will be freed. Defaults to the value of create_graph.
create_graph – If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Default: False.
- Returns:
Tensor containing the flattened gradients
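A flattened gradient of this kind can be sketched with torch.autograd.grad. The helper flat_grad_sketch is hypothetical: it runs on whatever device the inputs live on, uses None as the retain_graph default so it can fall back to create_graph, and fills unused parameters with zeros, all of which are assumptions rather than the library's behavior.

```python
import torch


def flat_grad_sketch(output, parameters, create_graph=False, retain_graph=None):
    """Gradient of a scalar output w.r.t. parameters, as one 1-D tensor."""
    if retain_graph is None:
        # Mirror the documented default: follow create_graph.
        retain_graph = create_graph
    grads = torch.autograd.grad(
        output,
        parameters,
        create_graph=create_graph,
        retain_graph=retain_graph,
        allow_unused=True,
    )
    # Flatten each gradient and concatenate in the order the parameters were given.
    # Parameters that do not influence the output get a zero gradient (assumption).
    pieces = [
        (g if g is not None else torch.zeros_like(p)).reshape(-1)
        for g, p in zip(grads, parameters)
    ]
    return torch.cat(pieces)


layer = torch.nn.Linear(3, 2)
loss = layer(torch.ones(1, 3)).sum()
gradient = flat_grad_sketch(loss, list(layer.parameters()))
```

Here the result has 8 entries: 6 for the 2x3 weight matrix followed by 2 for the bias, matching the parameter order.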