"""Episode time-limit wrapper (module ``realros.wrappers.time_limit_wrapper``)."""

from copy import deepcopy
from typing import Any, Dict, Optional, Tuple

import gymnasium as gym

from gymnasium.envs.registration import EnvSpec

class TimeLimitWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """Issue a `truncated` signal once a maximum number of timesteps is reached.

    If a truncation is not defined inside the environment itself, this is the
    only place that the truncation signal is issued. Critically, this is
    different from the `terminated` signal that originates from the underlying
    environment as part of the MDP.

    On the step that reaches the limit, ``info['time_limit_reached']`` is set
    to ``True`` and ``info['is_success']`` is set to ``False`` unless the
    wrapped environment already provided it.
    """

    def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None) -> None:
        """Initialize the wrapper with an environment and a step limit.

        Args:
            env: The environment to apply the wrapper to.
            max_episode_steps: Number of steps after which ``truncated`` is
                forced to ``True``. If ``None``, ``env.spec.max_episode_steps``
                is used instead.

        Raises:
            ValueError: If ``max_episode_steps`` is ``None`` and ``env.spec``
                does not provide a ``max_episode_steps`` value either.
        """
        # Record the argument exactly as passed, so re-creating the wrapper
        # from its recorded args re-resolves the spec fallback the same way.
        gym.utils.RecordConstructorArgs.__init__(self, max_episode_steps=max_episode_steps)
        gym.Wrapper.__init__(self, env)

        if max_episode_steps is None:
            # Documented fallback: take the limit from the wrapped env's spec.
            env_spec = env.spec
            max_episode_steps = env_spec.max_episode_steps if env_spec is not None else None
            if max_episode_steps is None:
                raise ValueError(
                    "max_episode_steps was not given and env.spec does not define "
                    "max_episode_steps"
                )

        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = 0

    def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict[str, Any]]:
        """Step the environment, truncating once ``max_episode_steps`` is reached.

        Args:
            action: The environment step action.

        Returns:
            The environment step ``(observation, reward, terminated, truncated, info)``.
            ``truncated`` is ``True`` when the number of elapsed steps is
            ``>= max_episode_steps``. In that case the wrapped env's ``info``
            dict is also updated in place: ``time_limit_reached`` is set, and
            ``is_success`` is defaulted to ``False``.
        """
        observation, reward, terminated, truncated, info = self.env.step(action)
        self._elapsed_steps += 1

        if self._elapsed_steps >= self._max_episode_steps:
            truncated = True
            info['time_limit_reached'] = True
            # Hitting the time limit counts as "not successful" unless the
            # environment reported its own success flag.
            if 'is_success' not in info:
                info['is_success'] = False

        return observation, reward, terminated, truncated, info

    def reset(self, **kwargs: Any) -> Tuple[Any, Dict[str, Any]]:
        """Reset the environment and zero the elapsed-step counter.

        Args:
            **kwargs: Forwarded unchanged to the wrapped environment's ``reset``.

        Returns:
            Whatever the wrapped environment's ``reset`` returns.
        """
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)

    @property
    def spec(self) -> Optional[EnvSpec]:
        """Return the wrapped env's spec with ``max_episode_steps`` set to this wrapper's limit.

        The result is cached once a non-``None`` spec has been built. Returns
        ``None`` if the wrapped env has no spec, or if its spec cannot be
        deep-copied.
        """
        if self._cached_spec is not None:
            return self._cached_spec

        env_spec = self.env.spec
        if env_spec is not None:
            # A spec may carry arbitrary kwargs (e.g. live objects) that are
            # not deep-copyable; degrade to "no spec" instead of crashing.
            try:
                env_spec = deepcopy(env_spec)
                env_spec.max_episode_steps = self._max_episode_steps
            except Exception as e:
                gym.logger.warn(
                    f"An exception occurred ({e}) while copying the environment spec={env_spec}"
                )
                return None

        self._cached_spec = env_spec
        return env_spec
# Helper: read max_episode_steps from the env spec, falling back to a wrapper's `_max_episode_steps`.
def _find_max_episode_steps(env: gym.Env) -> Optional[int]:
    """Return the first non-``None`` ``_max_episode_steps`` found along the wrapper chain.

    Starts at ``env`` and follows each layer's ``env`` attribute inward. Only
    each layer's instance ``__dict__`` is read, so the lookup never relies on a
    wrapper forwarding private attribute names to the object it wraps. Returns
    ``None`` if no layer defines the attribute. A ``seen`` set guards against
    cyclic ``env`` links.
    """
    seen = set()
    current: Any = env
    while current is not None and id(current) not in seen:
        seen.add(id(current))
        attrs = getattr(current, "__dict__", {})
        value = attrs.get("_max_episode_steps")
        if value is not None:
            return value
        current = attrs.get("env")
    return None


def get_env_params(env: gym.Env) -> Dict[str, Optional[int]]:
    """Collect environment parameters into a dict.

    ``max_timesteps`` is taken from ``env.spec.max_episode_steps`` when that is
    available. Otherwise it falls back to the first ``_max_episode_steps``
    defined on ``env`` or on any wrapper reachable through ``.env`` links.

    Args:
        env: The (possibly wrapped) environment.

    Returns:
        ``{'max_timesteps': steps}``. ``steps`` is ``None`` only when
        ``env.spec`` exists but defines no limit and no wrapper does either.

    Raises:
        AttributeError: If ``env.spec`` is ``None`` and no layer of the
            wrapper chain defines ``_max_episode_steps``.
    """
    spec = env.spec
    max_steps = spec.max_episode_steps if spec is not None else None
    if max_steps is None:
        max_steps = _find_max_episode_steps(env)
        if max_steps is None and spec is None:
            raise AttributeError(
                "Cannot determine max_episode_steps: env.spec is None and no "
                "wrapper defines _max_episode_steps"
            )
    return {'max_timesteps': max_steps}