from dataclasses import dataclass
from abc import ABC, abstractmethod
import torch
@dataclass
class BaseObservation:
    """
    Placeholder base class for observation data.

    Declares no fields of its own. It is the type named in the
    ``observation`` annotation of ``BaseObservationFeaturizer.featurize``.
    """
class BaseObservationFeaturizer(ABC):
    """
    Abstract interface for turning an observation into a tensor.

    Subclasses must implement :meth:`featurize`; the class cannot be
    instantiated until they do.
    """

    @abstractmethod
    def featurize(self, observation: BaseObservation) -> torch.Tensor:
        """
        Convert ``observation`` into a ``torch.Tensor``.

        Abstract: the base implementation only raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    def get_observation_dim(self) -> int:
        """
        Returns the observation dim

        Computed as ``self.get_input_dim() + self.get_context_dim()``.
        """
        # NOTE(review): get_input_dim() and get_context_dim() are not defined
        # in the part of this class visible here; presumably subclasses provide
        # them (or override this method) -- confirm against the subclasses.
        return self.get_input_dim() + self.get_context_dim()