# Source code for rl4lms.envs.common.observation

from dataclasses import dataclass
from abc import ABC, abstractmethod
import torch


@dataclass()
class BaseObservation:
    """Empty base dataclass for observation data.

    Declares no fields of its own; it serves as a common parent/type for
    concrete observation classes (presumably defined as subclasses elsewhere
    in the project — confirm against callers).
    """
class BaseObservationFeaturizer(ABC):
    """Abstract interface for turning observations into tensors.

    Subclasses must implement :meth:`featurize`. To use
    :meth:`get_observation_dim` they must also override
    :meth:`get_input_dim` and :meth:`get_context_dim`.
    """

    @abstractmethod
    def featurize(self, observation: BaseObservation) -> torch.Tensor:
        """Convert ``observation`` into a tensor.

        The shape/dtype of the returned tensor is defined by the concrete
        subclass (not fixed here — confirm per implementation).

        Raises:
            NotImplementedError: always, in this abstract base.
        """
        raise NotImplementedError

    def get_input_dim(self) -> int:
        """Return the input component of the observation dimension.

        Summed with :meth:`get_context_dim` by :meth:`get_observation_dim`.
        Declared here so the dependency is explicit. It is deliberately not
        abstract, so existing subclasses that only implement ``featurize``
        can still be instantiated.

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_input_dim() "
            "to use get_observation_dim()"
        )

    def get_context_dim(self) -> int:
        """Return the context component of the observation dimension.

        Summed with :meth:`get_input_dim` by :meth:`get_observation_dim`.
        It is deliberately not abstract, for the same backward-compatibility
        reason as :meth:`get_input_dim`.

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_context_dim() "
            "to use get_observation_dim()"
        )

    def get_observation_dim(self) -> int:
        """Return the observation dim: ``get_input_dim() + get_context_dim()``.

        Raises:
            NotImplementedError: if a subclass does not override one of the
                two component hooks.
        """
        return self.get_input_dim() + self.get_context_dim()