# Source code for realros.wrappers.normalize_obs_wrapper

from typing import Any, Dict, Union

import gymnasium as gym
import numpy as np


class NormalizeObservationWrapper(gym.ObservationWrapper):
    """
    A wrapper for normalizing the observation space of an environment.

    Each dimension is linearly rescaled from the wrapped environment's ``[low, high]``
    bounds to ``[-1, 1]`` and returned as ``np.float32`` (matching the declared
    observation space). Dimensions whose bounds coincide (``high == low``) carry no
    information and are mapped to ``0``. Dimensions with infinite bounds cannot be
    rescaled meaningfully and yield NaN/inf values.

    Supported observation spaces:

    * ``gym.spaces.Box``: the whole observation is normalized.
    * ``gym.spaces.Dict`` with keys 'observation', 'achieved_goal' and 'desired_goal'
      (goal-conditioned environments): 'observation' is always normalized; the goal
      entries are normalized only when ``normalize_goal_spaces`` is True. Any other
      keys are passed through unchanged.

    Args:
        env (gym.Env): The environment to wrap.
        normalize_goal_spaces (bool): Whether to normalize the achieved_goal and
            desired_goal spaces (Dict observation spaces only).

    Raises:
        ValueError: If the observation space of the environment is not supported, or
            if goal normalization is requested but a goal space is not a Box.
    """

    def __init__(self, env: gym.Env, normalize_goal_spaces: bool = False) -> None:
        # init the ObservationWrapper
        super().__init__(env)
        self.normalize_goal_spaces = normalize_goal_spaces

        space = env.observation_space
        # check if it is gymnasium.Env based
        if isinstance(space, gym.spaces.Box):
            self.observation_space = self._unit_box(space)
            self.normalize_observation = self._normalize_box_observation
        # check if it is gymnasium_robotics.GoalEnv based
        elif isinstance(space, gym.spaces.Dict):
            # Start from all original sub-spaces so extra keys stay declared, then
            # replace the entries this wrapper actually rescales.
            spaces = dict(space.spaces)
            spaces['observation'] = self._unit_box(space['observation'])
            if normalize_goal_spaces:
                for key in ('achieved_goal', 'desired_goal'):
                    # Fail fast at construction rather than on the first step.
                    if not isinstance(space[key], gym.spaces.Box):
                        raise ValueError(f"Unsupported {key} space: {type(space[key])}")
                    # The declared space must match the rescaled values we emit.
                    spaces[key] = self._unit_box(space[key])
            self.observation_space = gym.spaces.Dict(spaces)
            self.normalize_observation = self._normalize_dict_observation
        else:
            raise ValueError(f"Unsupported observation space: {type(env.observation_space)}")

    @staticmethod
    def _unit_box(space: gym.spaces.Space) -> gym.spaces.Box:
        """Return a float32 Box bounded by [-1, 1] with the same shape as ``space``."""
        return gym.spaces.Box(low=-1.0, high=1.0, shape=space.shape, dtype=np.float32)

    @staticmethod
    def _scale_to_unit_range(value: np.ndarray, low: np.ndarray, high: np.ndarray) -> np.ndarray:
        """
        Linearly map ``value`` from ``[low, high]`` to ``[-1, 1]``.

        Args:
            value: Array to rescale (broadcastable against ``low``/``high``).
            low: Lower bounds of the source range.
            high: Upper bounds of the source range.

        Returns:
            The rescaled values as a ``np.float32`` array. Entries where
            ``high == low`` are set to 0 instead of dividing by zero.
        """
        value = np.asarray(value, dtype=np.float64)
        low = np.asarray(low, dtype=np.float64)
        high = np.asarray(high, dtype=np.float64)
        span = high - low
        degenerate = span == 0
        # Substitute 1 for zero spans so the division is safe; those entries are
        # overwritten with the range midpoint (0.5 -> 0 after rescaling) below.
        fraction = (value - low) / np.where(degenerate, 1.0, span)
        fraction = np.where(degenerate, 0.5, fraction)
        # Cast so outputs match the float32 Box declared in __init__.
        return (2.0 * fraction - 1.0).astype(np.float32)

    def _normalize_box_observation(self, observation: np.ndarray) -> np.ndarray:
        """
        Normalize a Box observation (or the 'observation' entry of a Dict) to [-1, 1].

        Raises:
            ValueError: If the wrapped observation space is neither Box nor Dict.
        """
        if isinstance(self.env.observation_space, gym.spaces.Box):
            bounds = self.env.observation_space
        elif isinstance(self.env.observation_space, gym.spaces.Dict):
            bounds = self.env.observation_space['observation']
        else:
            raise ValueError(f"Unsupported observation space: {type(self.env.observation_space)}")
        return self._scale_to_unit_range(observation, bounds.low, bounds.high)

    def _normalize_achieved_goal(self, achieved_goal: np.ndarray) -> np.ndarray:
        """
        Normalize an achieved_goal observation to [-1, 1].

        Raises:
            ValueError: If the achieved_goal space is not a Box.
        """
        goal_space = self.env.observation_space['achieved_goal']
        if not isinstance(goal_space, gym.spaces.Box):
            raise ValueError(f"Unsupported achieved_goal space: {type(self.env.observation_space['achieved_goal'])}")
        return self._scale_to_unit_range(achieved_goal, goal_space.low, goal_space.high)

    def _normalize_desired_goal(self, desired_goal: np.ndarray) -> np.ndarray:
        """
        Normalize a desired_goal observation to [-1, 1].

        Raises:
            ValueError: If the desired_goal space is not a Box.
        """
        goal_space = self.env.observation_space['desired_goal']
        if not isinstance(goal_space, gym.spaces.Box):
            raise ValueError(f"Unsupported desired_goal space: {type(self.env.observation_space['desired_goal'])}")
        return self._scale_to_unit_range(desired_goal, goal_space.low, goal_space.high)

    def _normalize_dict_observation(self, observation: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """
        Normalize a goal-conditioned Dict observation.

        'observation' is always normalized; 'achieved_goal' and 'desired_goal' only
        when ``normalize_goal_spaces`` is True. Other keys are copied unchanged.

        Returns:
            A new dict; the input dict is left untouched.
        """
        # Shallow copy so the wrapped environment's own dict is not mutated.
        normalized = dict(observation)
        normalized['observation'] = self._normalize_box_observation(observation['observation'])
        if self.normalize_goal_spaces:
            normalized['achieved_goal'] = self._normalize_achieved_goal(observation['achieved_goal'])
            normalized['desired_goal'] = self._normalize_desired_goal(observation['desired_goal'])
        return normalized

    def observation(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """Normalize ``observation`` using the method selected for this space in __init__."""
        return self.normalize_observation(observation)