sb3_ros_support API reference

Stable Baselines 3 algorithm wrappers configured for ROS-based training scripts. Each algorithm subclasses sb3_ros_support.core.BasicModel and exposes a uniform train / validate / save / load surface, so swapping PPO for SAC or TD3 is a configuration edit, not a code rewrite.

Base model

Extends the SB3 models of the frobs_rl library.

Recreated to address the following points:
  • The original cannot be used with multiple environments

  • Support for goal-conditioned environments

  • pyyaml for loading parameters

class sb3_ros_support.core.BasicModel(env, save_model_path, log_path, parm_dict, load_trained=False, action_noise_type='normal', action_noise=True, seed=None)[source]

Bases: object

Base class for all the algorithms of Stable Baselines3.

Parameters:
  • env (gym.Env) – The environment to be used.

  • save_model_path (str) – The path to save the model.

  • log_path (str) – The path to save the log.

  • parm_dict (dict) – The dictionary containing the parameters.

  • load_trained (bool) – Whether to load a trained model or not.

  • action_noise_type (str) – The type of action noise to use. Can be “normal” or “ornstein”. (Optional)

  • action_noise (bool) – Whether to use action noise or not. (Optional)

  • seed (int) – If provided, appended to save_prefix / trained_model_name / log_folder as _s<seed>_<YYYYmmdd_HHMMSS> so every run lands in its own checkpoint + log directory (no clobber across seeds or repeat runs).
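
The per-run suffix described for seed can be sketched as follows. This is a minimal illustration of the _s<seed>_<YYYYmmdd_HHMMSS> format only; make_run_tag and the base name ppo_model are hypothetical and not part of the library.

```python
from datetime import datetime
from typing import Optional


def make_run_tag(seed: Optional[int], now: Optional[datetime] = None) -> str:
    """Return "_s<seed>_<YYYYmmdd_HHMMSS>", or "" when no seed is given."""
    if seed is None:
        return ""
    stamp = (now or datetime.now()).strftime("%Y%m%d_%H%M%S")
    return f"_s{seed}_{stamp}"


# Hypothetical base name, for illustration only.
save_prefix = "ppo_model"
tag = make_run_tag(7, datetime(2024, 5, 1, 13, 45, 9))
print(save_prefix + tag)  # ppo_model_s7_20240501_134509
```

Because the timestamp is part of the tag, two runs with the same seed still get different names unless they start within the same second.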

train(action_cycle_time=None)[source]

Train the model for the number of steps specified in the YAML config file. The model is saved automatically after training.

Parameters:

action_cycle_time (float) – The time to wait between actions. (Optional)

Returns:

True if the model was trained, False otherwise.

Return type:

bool

save_model()[source]

Function to save the model.

Returns:

True if the model was saved, False otherwise.

Return type:

bool

save_replay_buffer()[source]

Save the replay buffer. Training must have finished before this is called; otherwise an error is raised.

Returns:

True if the replay buffer was saved, False otherwise.

Return type:

bool

set_model_logger()[source]

Function to set a logger of the model.

The log directory is composed as <log_path>/<log_folder><run_tag> where run_tag is the per-run _s<seed>_<timestamp> suffix (empty when no seed was supplied). If the resulting directory already exists (e.g. two runs collide within a one-second timestamp window), an extra timestamp segment is appended so the new logger never writes over an existing run’s TensorBoard data.

Returns:

True if the logger was set, False otherwise.

Return type:

bool
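
The directory rule described for set_model_logger can be sketched as below. This is an assumption-laden illustration: resolve_log_dir is a hypothetical helper, and the format of the extra timestamp segment (with microseconds) is a guess, since the text above only states that an extra segment is appended.

```python
import os
import tempfile
from datetime import datetime


def resolve_log_dir(log_path: str, log_folder: str, run_tag: str = "") -> str:
    """Compose <log_path>/<log_folder><run_tag>, never reusing an existing dir."""
    candidate = os.path.join(log_path, log_folder + run_tag)
    while os.path.exists(candidate):
        # Assumed segment format; the documented behaviour is only that
        # "an extra timestamp segment is appended".
        candidate += "_" + datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    os.makedirs(candidate)
    return candidate


with tempfile.TemporaryDirectory() as root:
    first = resolve_log_dir(root, "ppo_logs", "_s1_20240501_134509")
    second = resolve_log_dir(root, "ppo_logs", "_s1_20240501_134509")
    print(os.path.basename(first))  # ppo_logs_s1_20240501_134509
    print(first != second)          # True
```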

close_env()[source]

Use the env close method to close the environment.

Returns:

True if the environment was closed, False otherwise.

Return type:

bool

check_env()[source]

Use the stable-baselines check_env method to check the environment.

Returns:

True if the environment was checked, False otherwise.

Return type:

bool

predict(observation, state=None, deterministic=False)[source]

Get the action for the given observation and, for recurrent policies, the previous state.

Parameters:
  • observation (ndarray) – The environment observation.

  • state (ndarray) – The previous states of the environment, used in recurrent policies. (Optional)

  • deterministic (bool) – Whether to return deterministic actions or not. (Optional)

Returns:

The action to be taken.

Return type:

ndarray
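
A rollout loop built on predict might look like the sketch below. StubEnv and StubModel are hypothetical stand-ins so the example runs without ROS or SB3; the stub environment follows the Gymnasium reset/step convention, and the stub model returns only the action, as documented above. A real model and environment would take their place.

```python
from typing import Any, List, Tuple


class StubEnv:
    """Hypothetical stand-in following the Gymnasium reset/step convention."""

    def __init__(self, horizon: int = 3) -> None:
        self.horizon = horizon
        self.t = 0

    def reset(self) -> Tuple[List[float], dict]:
        self.t = 0
        return [0.0], {}

    def step(self, action: Any) -> Tuple[List[float], float, bool, bool, dict]:
        self.t += 1
        terminated = self.t >= self.horizon
        return [float(self.t)], 1.0, terminated, False, {}


class StubModel:
    """Hypothetical stand-in exposing predict() with the documented signature."""

    def predict(self, observation, state=None, deterministic=False):
        return [0.0]  # a fixed action, for illustration only


def run_episode(model, env, max_steps: int = 100) -> float:
    """Roll out one episode with deterministic actions; return the summed reward."""
    obs, _info = env.reset()
    total = 0.0
    for _ in range(max_steps):
        action = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, _info = env.step(action)
        total += reward
        if terminated or truncated:
            break
    return total


print(run_episode(StubModel(), StubEnv(horizon=3)))  # 3.0
```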

Algorithms

PPO

class sb3_ros_support.ppo.PPO(env, save_model_path, log_path, model_pkg_path=None, load_trained=False, load_model_path=None, config_file_pkg=None, config_filename=None, abs_config_path=None, seed=None)[source]

Bases: BasicModel

Proximal Policy Optimization (PPO) algorithm.

Paper: https://arxiv.org/abs/1707.06347

Parameters:
  • env (gym.Env) – The environment to be used.

  • save_model_path (str) – The path to save the model. Can be absolute or relative.

  • log_path (str) – The path to save the log. Can be absolute or relative.

  • model_pkg_path (str) – The package name to save or load the model.

  • load_trained (bool) – Whether to load a trained model or not.

  • load_model_path (str) – The path to load the model. Should include the model name. Can be absolute or relative.

  • config_file_pkg (str) – The package name of the config file. Required if abs_config_path is not provided.

  • config_filename (str) – The name of the config file. Required if abs_config_path is not provided.

  • abs_config_path (str) – The absolute path to the config file. Required if config_file_pkg and config_filename are not provided.

  • seed (int) – If provided, overrides the YAML ppo_params.seed for the SB3 learner (PyTorch / rollout RNG).
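
The config-location rule in the parameter list (either abs_config_path, or config_file_pkg together with config_filename) can be sketched as a small validation step. resolve_config_source is a hypothetical helper, and giving abs_config_path precedence when both forms are supplied is an assumption, not documented behaviour.

```python
from typing import Optional, Tuple


def resolve_config_source(
    config_file_pkg: Optional[str] = None,
    config_filename: Optional[str] = None,
    abs_config_path: Optional[str] = None,
) -> Tuple[str, ...]:
    """Return ("abs", path) or ("pkg", package, filename); raise if neither is usable."""
    if abs_config_path is not None:
        # Assumption: an absolute path wins when both forms are given.
        return ("abs", abs_config_path)
    if config_file_pkg is not None and config_filename is not None:
        return ("pkg", config_file_pkg, config_filename)
    raise ValueError(
        "Provide abs_config_path, or both config_file_pkg and config_filename."
    )


# "my_robot_rl" and the file names are placeholders.
print(resolve_config_source(config_file_pkg="my_robot_rl", config_filename="ppo.yaml"))
print(resolve_config_source(abs_config_path="/tmp/ppo.yaml"))
```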

static load_trained_model(model_path, model_pkg=None, env=None, config_file_pkg=None, config_filename=None, abs_config_path=None)[source]

Load a trained model. Use only with predict function, as the logs will not be saved.

Parameters:
  • model_path (str) – The path to the trained model. Can be absolute or relative.

  • model_pkg (str) – The package name to load the model. Required if model_path is relative.

  • env (gym.Env) – The environment to be used.

  • config_file_pkg (str) – The package name of the config file. Use the same package as model_pkg if not provided.

  • config_filename (str) – The name of the config file.

  • abs_config_path (str) – The absolute path to the config file.

Returns:

The loaded model.

Return type:

model

A2C

class sb3_ros_support.a2c.A2C(env, save_model_path, log_path, model_pkg_path=None, load_trained=False, load_model_path=None, config_file_pkg=None, config_filename=None, abs_config_path=None, seed=None)[source]

Bases: BasicModel

Advantage Actor-Critic (A2C) algorithm.

Paper: https://arxiv.org/abs/1602.01783

Parameters:
  • env (gym.Env) – The environment to be used.

  • save_model_path (str) – The path to save the model. Can be absolute or relative.

  • log_path (str) – The path to save the log. Can be absolute or relative.

  • model_pkg_path (str) – The package name to save or load the model.

  • load_trained (bool) – Whether to load a trained model or not.

  • load_model_path (str) – The path to load the model. Should include the model name. Can be absolute or relative.

  • config_file_pkg (str) – The package name of the config file. Required if abs_config_path is not provided.

  • config_filename (str) – The name of the config file. Required if abs_config_path is not provided.

  • abs_config_path (str) – The absolute path to the config file. Required if config_file_pkg and config_filename are not provided.

  • seed (int) – If provided, overrides the YAML a2c_params.seed for the SB3 learner (PyTorch / rollout RNG).

static load_trained_model(model_path, model_pkg=None, env=None, config_file_pkg=None, config_filename=None, abs_config_path=None)[source]

Load a trained model. Use only with predict function, as the logs will not be saved.

Parameters:
  • model_path (str) – The path to the trained model. Can be absolute or relative.

  • model_pkg (str) – The package name to load the model. Required if model_path is relative.

  • env (gym.Env) – The environment to be used.

  • config_file_pkg (str) – The package name of the config file. Use the same package as model_pkg if not provided.

  • config_filename (str) – The name of the config file.

  • abs_config_path (str) – The absolute path to the config file.

Returns:

The loaded model.

Return type:

model

DDPG

class sb3_ros_support.ddpg.DDPG(env, save_model_path, log_path, model_pkg_path=None, load_trained=False, load_model_path=None, config_file_pkg=None, config_filename=None, abs_config_path=None, use_her=False, seed=None)[source]

Bases: BasicModel

Deep Deterministic Policy Gradient (DDPG) with optional HER for goal-conditioned envs.

Paper: https://arxiv.org/abs/1509.02971

Policy selection is automatic:
  • gymnasium.spaces.Dict observation space → "MultiInputPolicy"

  • anything else → "MlpPolicy"

Hindsight Experience Replay (HER) is enabled when use_her=True or when the YAML config has use_HER: true. HER is only valid for goal-conditioned envs.
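
The selection rules above can be sketched without importing Gymnasium. FakeBox and FakeDict are hypothetical stand-ins; the check duck-types on a .spaces mapping, which is an assumption about how a Dict space can be recognised, and reading use_HER from the top level of the config is likewise assumed.

```python
from collections.abc import Mapping


class FakeBox:
    """Hypothetical stand-in for a flat (Box-like) observation space."""


class FakeDict:
    """Hypothetical stand-in for gymnasium.spaces.Dict (exposes a .spaces mapping)."""

    def __init__(self, **spaces) -> None:
        self.spaces = dict(spaces)


def select_policy(observation_space) -> str:
    """Dict observation space -> "MultiInputPolicy", anything else -> "MlpPolicy"."""
    is_dict = isinstance(getattr(observation_space, "spaces", None), Mapping)
    return "MultiInputPolicy" if is_dict else "MlpPolicy"


def her_enabled(use_her: bool, config: Mapping) -> bool:
    """HER is on if the constructor flag or the config's use_HER key is true."""
    return bool(use_her) or bool(config.get("use_HER", False))


goal_space = FakeDict(
    observation=FakeBox(), achieved_goal=FakeBox(), desired_goal=FakeBox()
)
print(select_policy(goal_space))            # MultiInputPolicy
print(select_policy(FakeBox()))             # MlpPolicy
print(her_enabled(False, {"use_HER": True}))  # True
```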

Parameters:
  • env (gym.Env) – The environment to be used.

  • save_model_path (str) – The path to save the model. Can be absolute or relative.

  • log_path (str) – The path to save the log. Can be absolute or relative.

  • model_pkg_path (str) – The package name to save or load the model.

  • load_trained (bool) – Whether to load a trained model or not.

  • load_model_path (str) – The path to load the model. Should include the model name. Can be absolute or relative.

  • config_file_pkg (str) – The package name of the config file. Required if abs_config_path is not provided.

  • config_filename (str) – The name of the config file. Required if abs_config_path is not provided.

  • abs_config_path (str) – The absolute path to the config file. Required if config_file_pkg and config_filename are not provided.

  • use_her (bool) – Whether to use Hindsight Experience Replay. Only valid for goal-conditioned envs (Dict obs space).

  • seed (int) – If provided, overrides the YAML ddpg_params.seed for the SB3 learner (PyTorch / replay-buffer RNG).

static load_trained_model(model_path, model_pkg=None, env=None, config_file_pkg=None, config_filename=None, abs_config_path=None, use_her=False)[source]

Load a trained model. Use only with predict function, as the logs will not be saved.

Parameters:
  • model_path (str) – The path to the trained model. Can be absolute or relative.

  • model_pkg (str) – The package name to load the model. Required if model_path is relative.

  • env (gym.Env) – The environment to be used.

  • config_file_pkg (str) – The package name of the config file. Use the same package as model_pkg if not provided.

  • config_filename (str) – The name of the config file.

  • abs_config_path (str) – The absolute path to the config file.

  • use_her (bool) – Whether to use Hindsight Experience Replay. Only valid for goal-conditioned envs.

Returns:

The loaded model.

Return type:

model

TD3

class sb3_ros_support.td3.TD3(env, save_model_path, log_path, model_pkg_path=None, load_trained=False, load_model_path=None, config_file_pkg=None, config_filename=None, abs_config_path=None, use_her=False, seed=None)[source]

Bases: BasicModel

Twin Delayed DDPG (TD3) with optional HER for goal-conditioned envs.

Paper: https://arxiv.org/abs/1802.09477

Policy selection is automatic:
  • gymnasium.spaces.Dict observation space → "MultiInputPolicy"

  • anything else → "MlpPolicy"

Hindsight Experience Replay (HER) is enabled when use_her=True or when the YAML config has use_HER: true. HER is only valid for goal-conditioned envs.

Parameters:
  • env (gym.Env) – The environment to be used.

  • save_model_path (str) – The path to save the model. Can be absolute or relative.

  • log_path (str) – The path to save the log. Can be absolute or relative.

  • model_pkg_path (str) – The package name to save or load the model.

  • load_trained (bool) – Whether to load a trained model or not.

  • load_model_path (str) – The path to load the model. Should include the model name. Can be absolute or relative.

  • config_file_pkg (str) – The package name of the config file. Required if abs_config_path is not provided.

  • config_filename (str) – The name of the config file. Required if abs_config_path is not provided.

  • abs_config_path (str) – The absolute path to the config file. Required if config_file_pkg and config_filename are not provided.

  • use_her (bool) – Whether to use Hindsight Experience Replay. Only valid for goal-conditioned envs (Dict obs space).

  • seed (int) – If provided, overrides the YAML td3_params.seed for the SB3 learner (PyTorch / replay-buffer RNG).

static load_trained_model(model_path, model_pkg=None, env=None, config_file_pkg=None, config_filename=None, abs_config_path=None, use_her=False)[source]

Load a trained model. Use only with predict function, as the logs will not be saved.

Parameters:
  • model_path (str) – The path to the trained model. Can be absolute or relative.

  • model_pkg (str) – The package name to load the model. Required if model_path is relative.

  • env (gym.Env) – The environment to be used.

  • config_file_pkg (str) – The package name of the config file. Use the same package as model_pkg if not provided.

  • config_filename (str) – The name of the config file.

  • abs_config_path (str) – The absolute path to the config file.

  • use_her (bool) – Whether to use Hindsight Experience Replay. Only valid for goal-conditioned envs.

Returns:

The loaded model.

Return type:

model

SAC

class sb3_ros_support.sac.SAC(env, save_model_path, log_path, model_pkg_path=None, load_trained=False, load_model_path=None, config_file_pkg=None, config_filename=None, abs_config_path=None, use_her=False, seed=None)[source]

Bases: BasicModel

Soft Actor-Critic (SAC) with optional HER for goal-conditioned envs.

Paper: https://arxiv.org/abs/1801.01290

Policy selection is automatic:
  • gymnasium.spaces.Dict observation space → "MultiInputPolicy"

  • anything else → "MlpPolicy"

Hindsight Experience Replay (HER) is enabled when use_her=True or when the YAML config has use_HER: true. HER is only valid for goal-conditioned envs.

Parameters:
  • env (gym.Env) – The environment to be used.

  • save_model_path (str) – The path to save the model. Can be absolute or relative.

  • log_path (str) – The path to save the log. Can be absolute or relative.

  • model_pkg_path (str) – The package name to save or load the model.

  • load_trained (bool) – Whether to load a trained model or not.

  • load_model_path (str) – The path to load the model. Should include the model name. Can be absolute or relative.

  • config_file_pkg (str) – The package name of the config file. Required if abs_config_path is not provided.

  • config_filename (str) – The name of the config file. Required if abs_config_path is not provided.

  • abs_config_path (str) – The absolute path to the config file. Required if config_file_pkg and config_filename are not provided.

  • use_her (bool) – Whether to use Hindsight Experience Replay. Only valid for goal-conditioned envs (Dict obs space).

  • seed (int) – If provided, overrides the YAML sac_params.seed for the SB3 learner (PyTorch / replay-buffer RNG).

static load_trained_model(model_path, model_pkg=None, env=None, config_file_pkg=None, config_filename=None, abs_config_path=None, use_her=False)[source]

Load a trained model. Use only with predict function, as the logs will not be saved.

Parameters:
  • model_path (str) – The path to the trained model. Can be absolute or relative.

  • model_pkg (str) – The package name to load the model. Required if model_path is relative.

  • env (gym.Env) – The environment to be used.

  • config_file_pkg (str) – The package name of the config file. Use the same package as model_pkg if not provided.

  • config_filename (str) – The name of the config file.

  • abs_config_path (str) – The absolute path to the config file.

  • use_her (bool) – Whether to use Hindsight Experience Replay. Only valid for goal-conditioned envs.

Returns:

The loaded model.

Return type:

model

DQN

class sb3_ros_support.dqn.DQN(env, save_model_path, log_path, model_pkg_path=None, load_trained=False, load_model_path=None, config_file_pkg=None, config_filename=None, abs_config_path=None, use_her=False, seed=None)[source]

Bases: BasicModel

Deep Q Network (DQN) with optional HER for goal-conditioned envs.

Paper: https://arxiv.org/abs/1312.5602

Policy selection is automatic:
  • gymnasium.spaces.Dict observation space → "MultiInputPolicy"

  • anything else → "MlpPolicy"

Hindsight Experience Replay (HER) is enabled when use_her=True or when the YAML config has use_HER: true. HER is only valid for goal-conditioned envs.

Parameters:
  • env (gym.Env) – The environment to be used.

  • save_model_path (str) – The path to save the model. Can be absolute or relative.

  • log_path (str) – The path to save the log. Can be absolute or relative.

  • model_pkg_path (str) – The package name to save or load the model.

  • load_trained (bool) – Whether to load a trained model or not.

  • load_model_path (str) – The path to load the model. Should include the model name. Can be absolute or relative.

  • config_file_pkg (str) – The package name of the config file. Required if abs_config_path is not provided.

  • config_filename (str) – The name of the config file. Required if abs_config_path is not provided.

  • abs_config_path (str) – The absolute path to the config file. Required if config_file_pkg and config_filename are not provided.

  • use_her (bool) – Whether to use Hindsight Experience Replay. Only valid for goal-conditioned envs (Dict obs space).

  • seed (int) – If provided, overrides the YAML dqn_params.seed for the SB3 learner (PyTorch / replay-buffer RNG).

static load_trained_model(model_path, model_pkg=None, env=None, config_file_pkg=None, config_filename=None, abs_config_path=None, use_her=False)[source]

Load a trained model. Use only with predict function, as the logs will not be saved.

Parameters:
  • model_path (str) – The path to the trained model. Can be absolute or relative.

  • model_pkg (str) – The package name to load the model. Required if model_path is relative.

  • env (gym.Env) – The environment to be used.

  • config_file_pkg (str) – The package name of the config file. Use the same package as model_pkg if not provided.

  • config_filename (str) – The name of the config file.

  • abs_config_path (str) – The absolute path to the config file.

  • use_her (bool) – Whether to use Hindsight Experience Replay. Only valid for goal-conditioned envs.

Returns:

The loaded model.

Return type:

model

Goal-conditioned variants (HER)

DDPG_GOAL

Backwards-compat shim for sb3_ros_support.ddpg_goal.DDPG_GOAL.

Goal-conditioned DDPG support was consolidated into the single sb3_ros_support.ddpg.DDPG class. DDPG now auto-detects the policy type from the env’s observation space (gymnasium.spaces.Dict → "MultiInputPolicy", otherwise "MlpPolicy") and enables Hindsight Experience Replay (HER) via the use_her constructor flag or the YAML config’s use_HER key.

DDPG_GOAL remains a working alias for backwards compatibility but emits a DeprecationWarning on instantiation. New code should import sb3_ros_support.ddpg.DDPG directly.
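
The alias-with-warning pattern described above can be sketched generically. Algo and AlgoGoal are hypothetical stand-ins for a consolidated class and its legacy shim; the warning text is illustrative and not the library's actual message.

```python
import warnings


class Algo:
    """Hypothetical stand-in for a consolidated algorithm class."""

    def __init__(self, env, use_her: bool = False) -> None:
        self.env = env
        self.use_her = use_her


class AlgoGoal(Algo):
    """Hypothetical legacy alias: warns on instantiation, then defers to Algo."""

    def __init__(self, *args, **kwargs) -> None:
        warnings.warn(
            "AlgoGoal is deprecated; use Algo with use_her=True.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller's line
        )
        super().__init__(*args, **kwargs)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is not filtered out
    legacy = AlgoGoal("goal_env", use_her=True)

print(caught[0].category.__name__)  # DeprecationWarning
print(isinstance(legacy, Algo))     # True
```

Subclassing rather than rebinding the name keeps isinstance checks against the consolidated class working for legacy callers.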

class sb3_ros_support.ddpg_goal.DDPG_GOAL(*args, **kwargs)[source]

Bases: DDPG

Deprecated. Use sb3_ros_support.ddpg.DDPG with use_her=True.

Parameters:
  • env (gym.Env) – The environment to be used.

  • save_model_path (str) – The path to save the model. Can be absolute or relative.

  • log_path (str) – The path to save the log. Can be absolute or relative.

  • model_pkg_path (str) – The package name to save or load the model.

  • load_trained (bool) – Whether to load a trained model or not.

  • load_model_path (str) – The path to load the model. Should include the model name. Can be absolute or relative.

  • config_file_pkg (str) – The package name of the config file. Required if abs_config_path is not provided.

  • config_filename (str) – The name of the config file. Required if abs_config_path is not provided.

  • abs_config_path (str) – The absolute path to the config file. Required if config_file_pkg and config_filename are not provided.

  • use_her (bool) – Whether to use Hindsight Experience Replay. Only valid for goal-conditioned envs (Dict obs space).

  • seed (int) – If provided, overrides the YAML ddpg_params.seed for the SB3 learner (PyTorch / replay-buffer RNG).

static load_trained_model(model_path, model_pkg=None, env=None, config_file_pkg=None, config_filename=None, abs_config_path=None, use_her=True)[source]

Backwards-compat loader preserving the pre-cleanup defaults.

Defaults config_filename to ddpg_goal.yaml and use_her to True so legacy callers using DDPG_GOAL.load_trained_model(path, env=goal_env) keep the goal-conditioned behaviour instead of silently picking up ddpg.yaml with HER disabled.

Parameters:
  • model_path (str)

  • model_pkg (str | None)

  • env (Any | None)

  • config_file_pkg (str | None)

  • config_filename (str | None)

  • abs_config_path (str | None)

  • use_her (bool)

Return type:

DDPG_GOAL

TD3_GOAL

Backwards-compat shim for sb3_ros_support.td3_goal.TD3_GOAL.

Goal-conditioned TD3 support was consolidated into the single sb3_ros_support.td3.TD3 class. TD3 now auto-detects the policy type from the env’s observation space (gymnasium.spaces.Dict → "MultiInputPolicy", otherwise "MlpPolicy") and enables Hindsight Experience Replay (HER) via the use_her constructor flag or the YAML config’s use_HER key.

TD3_GOAL remains a working alias for backwards compatibility but emits a DeprecationWarning on instantiation. New code should import sb3_ros_support.td3.TD3 directly.

class sb3_ros_support.td3_goal.TD3_GOAL(*args, **kwargs)[source]

Bases: TD3

Deprecated. Use sb3_ros_support.td3.TD3 with use_her=True.

Parameters:
  • env (gym.Env) – The environment to be used.

  • save_model_path (str) – The path to save the model. Can be absolute or relative.

  • log_path (str) – The path to save the log. Can be absolute or relative.

  • model_pkg_path (str) – The package name to save or load the model.

  • load_trained (bool) – Whether to load a trained model or not.

  • load_model_path (str) – The path to load the model. Should include the model name. Can be absolute or relative.

  • config_file_pkg (str) – The package name of the config file. Required if abs_config_path is not provided.

  • config_filename (str) – The name of the config file. Required if abs_config_path is not provided.

  • abs_config_path (str) – The absolute path to the config file. Required if config_file_pkg and config_filename are not provided.

  • use_her (bool) – Whether to use Hindsight Experience Replay. Only valid for goal-conditioned envs (Dict obs space).

  • seed (int) – If provided, overrides the YAML td3_params.seed for the SB3 learner (PyTorch / replay-buffer RNG).

static load_trained_model(model_path, model_pkg=None, env=None, config_file_pkg=None, config_filename=None, abs_config_path=None, use_her=True)[source]

Backwards-compat loader preserving the pre-cleanup defaults.

Defaults config_filename to td3_goal.yaml and use_her to True so legacy callers using TD3_GOAL.load_trained_model(path, env=goal_env) keep the goal-conditioned behaviour instead of silently picking up td3.yaml with HER disabled.

Parameters:
  • model_path (str)

  • model_pkg (str | None)

  • env (Any | None)

  • config_file_pkg (str | None)

  • config_filename (str | None)

  • abs_config_path (str | None)

  • use_her (bool)

Return type:

TD3_GOAL

SAC_GOAL

Backwards-compat shim for sb3_ros_support.sac_goal.SAC_GOAL.

Goal-conditioned SAC support was consolidated into the single sb3_ros_support.sac.SAC class. SAC now auto-detects the policy type from the env’s observation space (gymnasium.spaces.Dict → "MultiInputPolicy", otherwise "MlpPolicy") and enables Hindsight Experience Replay (HER) via the use_her constructor flag or the YAML config’s use_HER key.

SAC_GOAL remains a working alias for backwards compatibility but emits a DeprecationWarning on instantiation. New code should import sb3_ros_support.sac.SAC directly.

class sb3_ros_support.sac_goal.SAC_GOAL(*args, **kwargs)[source]

Bases: SAC

Deprecated. Use sb3_ros_support.sac.SAC with use_her=True.

Parameters:
  • env (gym.Env) – The environment to be used.

  • save_model_path (str) – The path to save the model. Can be absolute or relative.

  • log_path (str) – The path to save the log. Can be absolute or relative.

  • model_pkg_path (str) – The package name to save or load the model.

  • load_trained (bool) – Whether to load a trained model or not.

  • load_model_path (str) – The path to load the model. Should include the model name. Can be absolute or relative.

  • config_file_pkg (str) – The package name of the config file. Required if abs_config_path is not provided.

  • config_filename (str) – The name of the config file. Required if abs_config_path is not provided.

  • abs_config_path (str) – The absolute path to the config file. Required if config_file_pkg and config_filename are not provided.

  • use_her (bool) – Whether to use Hindsight Experience Replay. Only valid for goal-conditioned envs (Dict obs space).

  • seed (int) – If provided, overrides the YAML sac_params.seed for the SB3 learner (PyTorch / replay-buffer RNG).

static load_trained_model(model_path, model_pkg=None, env=None, config_file_pkg=None, config_filename=None, abs_config_path=None, use_her=True)[source]

Backwards-compat loader preserving the pre-cleanup defaults.

Defaults config_filename to sac_goal.yaml and use_her to True so legacy callers using SAC_GOAL.load_trained_model(path, env=goal_env) keep the goal-conditioned behaviour instead of silently picking up sac.yaml with HER disabled.

Parameters:
  • model_path (str)

  • model_pkg (str | None)

  • env (Any | None)

  • config_file_pkg (str | None)

  • config_filename (str | None)

  • abs_config_path (str | None)

  • use_her (bool)

Return type:

SAC_GOAL

DQN_GOAL

Backwards-compat shim for sb3_ros_support.dqn_goal.DQN_GOAL.

Goal-conditioned DQN support was consolidated into the single sb3_ros_support.dqn.DQN class. DQN now auto-detects the policy type from the env’s observation space (gymnasium.spaces.Dict → "MultiInputPolicy", otherwise "MlpPolicy") and enables Hindsight Experience Replay (HER) via the use_her constructor flag or the YAML config’s use_HER key.

DQN_GOAL remains a working alias for backwards compatibility but emits a DeprecationWarning on instantiation. New code should import sb3_ros_support.dqn.DQN directly.

class sb3_ros_support.dqn_goal.DQN_GOAL(*args, **kwargs)[source]

Bases: DQN

Deprecated. Use sb3_ros_support.dqn.DQN with use_her=True.

Parameters:
  • env (gym.Env) – The environment to be used.

  • save_model_path (str) – The path to save the model. Can be absolute or relative.

  • log_path (str) – The path to save the log. Can be absolute or relative.

  • model_pkg_path (str) – The package name to save or load the model.

  • load_trained (bool) – Whether to load a trained model or not.

  • load_model_path (str) – The path to load the model. Should include the model name. Can be absolute or relative.

  • config_file_pkg (str) – The package name of the config file. Required if abs_config_path is not provided.

  • config_filename (str) – The name of the config file. Required if abs_config_path is not provided.

  • abs_config_path (str) – The absolute path to the config file. Required if config_file_pkg and config_filename are not provided.

  • use_her (bool) – Whether to use Hindsight Experience Replay. Only valid for goal-conditioned envs (Dict obs space).

  • seed (int) – If provided, overrides the YAML dqn_params.seed for the SB3 learner (PyTorch / replay-buffer RNG).

static load_trained_model(model_path, model_pkg=None, env=None, config_file_pkg=None, config_filename=None, abs_config_path=None, use_her=True)[source]

Backwards-compat loader preserving the pre-cleanup defaults.

Defaults config_filename to dqn_goal.yaml and use_her to True so legacy callers using DQN_GOAL.load_trained_model(path, env=goal_env) keep the goal-conditioned behaviour instead of silently picking up dqn.yaml with HER disabled.

Parameters:
  • model_path (str)

  • model_pkg (str | None)

  • env (Any | None)

  • config_file_pkg (str | None)

  • config_filename (str | None)

  • abs_config_path (str | None)

  • use_her (bool)

Return type:

DQN_GOAL

Utilities

sb3_ros_support.utils.sb3_common.get_policy_kwargs(parm_dict)[source]

Function to get the policy kwargs from the parm_dict.

Parameters:

parm_dict (dict) – The dictionary containing the parameters.

Returns:

Dictionary containing the policy kwargs.

Return type:

dict

sb3_ros_support.utils.sb3_common.get_action_noise(action_space_shape, parm_dict, action_noise_type='normal')[source]

Function to get the action noise from the parm_dict.

Parameters:
  • action_space_shape (int) – The shape of the action space.

  • parm_dict (dict) – The dictionary containing the parameters.

  • action_noise_type (str) – The type of action noise to use. Can be “normal” or “ornstein”.

Returns:

The action noise.

Return type:

action_noise
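
To illustrate how the two noise types differ, here is a scalar, pure-Python sketch that is independent of how the library builds its noise objects. "normal" draws independent Gaussian samples; "ornstein" uses an Euler-discretised Ornstein-Uhlenbeck process, so successive samples are correlated. The class names, make_noise, and the theta and dt values are illustrative assumptions.

```python
import math
import random
from typing import Optional


class NormalNoise:
    """Independent Gaussian samples around `mean`."""

    def __init__(self, mean: float, sigma: float, rng: Optional[random.Random] = None):
        self.mean, self.sigma = mean, sigma
        self.rng = rng or random.Random()

    def __call__(self) -> float:
        return self.rng.gauss(self.mean, self.sigma)


class OrnsteinUhlenbeckNoise:
    """Temporally correlated noise that drifts back toward `mean`."""

    def __init__(self, mean: float, sigma: float, theta: float = 0.15,
                 dt: float = 1e-2, rng: Optional[random.Random] = None):
        self.mean, self.sigma, self.theta, self.dt = mean, sigma, theta, dt
        self.rng = rng or random.Random()
        self.prev = 0.0  # process state, starts at zero

    def __call__(self) -> float:
        drift = self.theta * (self.mean - self.prev) * self.dt
        diffusion = self.sigma * math.sqrt(self.dt) * self.rng.gauss(0.0, 1.0)
        self.prev = self.prev + drift + diffusion
        return self.prev


def make_noise(kind: str, mean: float = 0.0, sigma: float = 0.1,
               rng: Optional[random.Random] = None):
    """Map the documented type names to a noise generator."""
    if kind == "normal":
        return NormalNoise(mean, sigma, rng)
    if kind == "ornstein":
        return OrnsteinUhlenbeckNoise(mean, sigma, rng=rng)
    raise ValueError(f"Unknown action_noise_type: {kind!r}")


noise = make_noise("ornstein", sigma=0.2, rng=random.Random(0))
samples = [noise() for _ in range(3)]
print(len(samples))  # 3
```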

sb3_ros_support.utils.sb3_common.test_env(env)[source]

Use SB3 env checker.

Return type:

bool

sb3_ros_support.utils.sb3_common.is_dict_obs_space(env)[source]

Return True if env.observation_space is a Dict (goal-conditioned).

The algorithm classes use this to auto-select between "MlpPolicy" (Box observation) and "MultiInputPolicy" (Dict observation) so callers don’t have to specify the policy by hand or pick the right algorithm class for their env.

Return type:

bool

sb3_ros_support.utils.sb3_common.her_replay_buffer_kwargs(parm_dict)[source]

Build HER replay-buffer kwargs from a config’s her_params block.

Defaults match the previous algorithm-specific code paths: 4 sampled goals per real one, "future" goal-selection strategy.

Parameters:

parm_dict (dict)

Return type:

dict
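
The defaulting described for her_replay_buffer_kwargs can be sketched as below. build_her_kwargs is a hypothetical re-implementation, and the key names inside her_params are assumed to match the SB3 HerReplayBuffer argument names n_sampled_goal and goal_selection_strategy.

```python
from typing import Any, Dict, Mapping


def build_her_kwargs(parm_dict: Mapping[str, Any]) -> Dict[str, Any]:
    """Read the her_params block, falling back to 4 sampled goals and "future"."""
    her_params = parm_dict.get("her_params") or {}
    return {
        # Assumed key names, mirroring SB3's HerReplayBuffer arguments.
        "n_sampled_goal": her_params.get("n_sampled_goal", 4),
        "goal_selection_strategy": her_params.get("goal_selection_strategy", "future"),
    }


print(build_her_kwargs({}))
# {'n_sampled_goal': 4, 'goal_selection_strategy': 'future'}
print(build_her_kwargs({"her_params": {"n_sampled_goal": 8}}))
# {'n_sampled_goal': 8, 'goal_selection_strategy': 'future'}
```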

class sb3_ros_support.utils.sb3_common.TimeLimitCallback(*args, **kwargs)[source]

Bases: BaseCallback

Callback for setting an action cycle for training.

Parameters:
  • action_cycle_time (float) – The time in seconds for the action cycle.

  • verbose (int) – The verbosity level: 0 none, 1 training information, 2 debug.

sb3_ros_support.utils.yaml_utils.load_yaml(pkg_name=None, file_name=None, file_abs_path=None)[source]

Fetch a YAML file from a package or an absolute path, parse it, and convert it to a Python dictionary.

Parameters:
  • pkg_name (str) – name of package. Required if file_abs_path is None.

  • file_name (str) – name of file. Required if file_abs_path is None.

  • file_abs_path (str) – Absolute path of the YAML file. Required if pkg_name and file_name are None.

Returns:

Dictionary containing the YAML file.

Return type:

dict