from typing import Optional, Dict, Any
import torch
class KLController:
    """Adaptive controller for a KL-penalty coefficient.

    When a ``target_kl`` is configured, every call to :meth:`step` nudges the
    coefficient multiplicatively according to the clipped relative error
    between the observed KL divergence and the target. When ``target_kl`` is
    ``None`` the coefficient stays fixed at its initial value.
    """

    # Bound applied to the relative error before it is used for the update.
    _ERROR_CLIP = 0.2
    # Scale applied to the clipped error; with the clip above, one step
    # changes the coefficient by at most +/- 2 %.
    _STEP_GAIN = 0.1

    def __init__(self, kl_coeff: float, target_kl: Optional[float] = None) -> None:
        """Create a controller.

        Args:
            kl_coeff: Initial KL coefficient.
            target_kl: KL value to steer towards, or ``None`` to keep the
                coefficient constant. NOTE(review): a target of ``0`` makes
                the relative error in :meth:`step` a division by zero
                (inf/nan in tensor arithmetic); callers should pass a
                non-zero value.
        """
        self._kl_coeff = kl_coeff
        self._target_kl = target_kl

    def step(self, kl_div: "torch.Tensor") -> None:
        """Adapt the KL coefficient from one observed KL divergence.

        Does nothing when no target KL is configured.

        Args:
            kl_div: Observed KL divergence, as a single-element tensor or a
                plain Python number.
        """
        if self._target_kl is None:
            return

        # Coerce so plain Python numbers are accepted as well; tensors pass
        # through unchanged, so results for tensor inputs are unaffected.
        kl_div = torch.as_tensor(kl_div)
        relative_error = (kl_div - self._target_kl) / self._target_kl
        clipped_error = torch.clip(
            relative_error, -self._ERROR_CLIP, self._ERROR_CLIP
        ).item()
        self._kl_coeff = self._kl_coeff * (1 + self._STEP_GAIN * clipped_error)

    @property
    def kl_coeff(self) -> float:
        """Current KL coefficient."""
        return self._kl_coeff

    def get_state_dict(self) -> Dict[str, Any]:
        """Return the controller state as a dict.

        Returns:
            A dict with the keys ``"target_kl"`` and ``"current_kl_coeff"``.
        """
        return {
            "target_kl": self._target_kl,
            "current_kl_coeff": self._kl_coeff,
        }

    def load_from_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore the controller from a dict made by :meth:`get_state_dict`.

        Args:
            state_dict: Mapping that contains ``"current_kl_coeff"`` and
                ``"target_kl"``.

        Raises:
            KeyError: If either key is missing.
        """
        self._kl_coeff = state_dict["current_kl_coeff"]
        self._target_kl = state_dict["target_kl"]
if __name__ == "__main__":
    # Manual smoke test: feed a few KL observations to the controller and
    # print the coefficient after each update.
    controller = KLController(kl_coeff=0.1, target_kl=0.1)
    for observed_kl in (-0.2, 0.3, 0.4):
        controller.step(torch.tensor(observed_kl))
        print(controller.kl_coeff)

    snapshot = controller.get_state_dict()
    print(snapshot)

    # Wipe the private fields so that the reload below is what restores them.
    controller._kl_coeff = None
    controller._target_kl = None
    controller.load_from_state_dict(snapshot)
    assert controller._target_kl == snapshot["target_kl"]
    assert controller._kl_coeff == snapshot["current_kl_coeff"]