from copy import deepcopy
from typing import Any, Dict, Optional, Tuple
import gymnasium as gym
from gymnasium.envs.registration import EnvSpec
class TimeLimitWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """Issue a ``truncated`` signal once a maximum number of timesteps is reached.

    If truncation is not defined inside the environment itself, this wrapper is the
    only place the truncation signal is issued. Critically, this is different from the
    ``terminated`` signal, which originates from the underlying environment as part of
    the MDP.

    On every step where the elapsed-step count is at or beyond the limit,
    ``info["time_limit_reached"]`` is set to ``True`` and ``info["is_success"]`` is set
    to ``False`` unless the wrapped environment already provided a value for it.
    """

    def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None) -> None:
        """Initialize the wrapper with an environment and a step limit.

        Args:
            env: The environment to wrap.
            max_episode_steps: Number of steps after which ``truncated`` is forced to
                ``True``. If ``None``, ``env.spec.max_episode_steps`` is used.

        Raises:
            ValueError: If ``max_episode_steps`` is ``None`` and ``env.spec`` does not
                provide a ``max_episode_steps`` value either.
        """
        if max_episode_steps is None:
            # Fall back to the limit declared on the wrapped env's spec, as documented.
            env_spec = env.spec
            max_episode_steps = (
                env_spec.max_episode_steps if env_spec is not None else None
            )
            if max_episode_steps is None:
                raise ValueError(
                    "max_episode_steps was not given and env.spec does not define "
                    "max_episode_steps"
                )
        # Record the resolved limit so the constructor arguments are reproducible.
        gym.utils.RecordConstructorArgs.__init__(self, max_episode_steps=max_episode_steps)
        gym.Wrapper.__init__(self, env)
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = 0
        # Initialise the spec cache here so ``spec`` does not depend on the base
        # class having created this attribute.
        self._cached_spec: Optional[EnvSpec] = None

    def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict[str, Any]]:
        """Step the environment, truncating once the step limit is reached.

        Args:
            action: The environment step action.

        Returns:
            The environment step ``(observation, reward, terminated, truncated, info)``,
            with ``truncated=True`` if the number of elapsed steps is >= the limit.
        """
        observation, reward, terminated, truncated, info = self.env.step(action)
        self._elapsed_steps += 1
        if self._elapsed_steps >= self._max_episode_steps:
            truncated = True
            info["time_limit_reached"] = True
            # Hitting the time limit counts as a failure unless the env said otherwise.
            info.setdefault("is_success", False)
        return observation, reward, terminated, truncated, info

    def reset(self, **kwargs: Any) -> Tuple[Any, Dict[str, Any]]:
        """Reset the environment and zero the elapsed-step counter.

        Args:
            **kwargs: Forwarded to the wrapped environment's ``reset``.

        Returns:
            Whatever the wrapped environment's ``reset`` returns (the initial
            observation, and info on Gymnasium API envs).
        """
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)

    @property
    def spec(self) -> Optional[EnvSpec]:
        """Return a copy of the wrapped env's spec with ``max_episode_steps`` overridden.

        The copy is cached after the first successful construction. If the wrapped
        environment has no spec, ``None`` is returned (and recomputed on each access).
        """
        if self._cached_spec is not None:
            return self._cached_spec
        env_spec = self.env.spec
        if env_spec is not None:
            # Deep-copy so the wrapped env's own spec is not mutated.
            env_spec = deepcopy(env_spec)
            env_spec.max_episode_steps = self._max_episode_steps
        self._cached_spec = env_spec
        return self._cached_spec
def get_env_params(env: gym.Env) -> Dict[str, int]:
    """Collect the environment parameters derived from ``env``.

    The episode step limit is read from ``env.spec.max_episode_steps`` when the spec
    exists and defines it; otherwise it falls back to the environment's
    ``_max_episode_steps`` attribute (the attribute ``TimeLimitWrapper`` stores).

    Args:
        env: The environment to inspect.

    Returns:
        A dict with key ``'max_timesteps'`` mapped to the episode step limit.

    Raises:
        AttributeError: If neither the spec nor ``env._max_episode_steps`` provides a
            step limit.
    """
    env_spec = env.spec
    max_steps = env_spec.max_episode_steps if env_spec is not None else None
    if max_steps is None:
        # The spec is missing or declares no limit; try the wrapper's private attribute.
        max_steps = getattr(env, "_max_episode_steps", None)
    if max_steps is None:
        raise AttributeError(
            "Cannot determine max episode steps: env.spec defines no "
            "max_episode_steps and env has no _max_episode_steps attribute"
        )
    return {"max_timesteps": max_steps}