Source code for sb3_ros_support.utils.sb3_common

#!/bin/python3

import numpy as np
import torch as th
import stable_baselines3

# ROS packages required
import rospy

# Noise
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.env_checker import check_env

# Callbacks
from stable_baselines3.common.callbacks import BaseCallback
import time


def get_policy_kwargs(parm_dict: dict) -> dict:
    """
    Build the ``policy_kwargs`` dictionary for a custom policy from the config.

    Args:
        parm_dict (dict): Configuration dictionary. ``use_custom_policy`` is always read.
            When it is truthy, ``policy_params`` must provide ``activation_fn``,
            ``features_extractor_class``, ``optimizer_class`` and ``net_arch``.

    Returns:
        dict: The policy kwargs (``activation_fn``, ``features_extractor_class``,
        ``optimizer_class``, ``net_arch``), or ``None`` when ``use_custom_policy`` is falsy.
    """
    # Guard clause: nothing to build when no custom policy was requested
    if not parm_dict["use_custom_policy"]:
        return None

    params = parm_dict["policy_params"]

    # Activation function: matched case-insensitively against the supported names
    activation_names = {"relu": "ReLU", "tanh": "Tanh", "elu": "ELU", "selu": "SELU"}
    activation_key = params["activation_fn"].lower()
    if activation_key in activation_names:
        activation_fn = getattr(th.nn, activation_names[activation_key])
    else:
        rospy.logwarn("Activation function not found, using ReLU")
        activation_fn = th.nn.ReLU

    # Features extractor: the config value is the exact class name in torch_layers
    extractor_name = params["features_extractor_class"]
    torch_layers = stable_baselines3.common.torch_layers
    if extractor_name in ("FlattenExtractor", "BaseFeaturesExtractor", "CombinedExtractor"):
        features_extractor_class = getattr(torch_layers, extractor_name)
    else:
        rospy.logwarn("Feature extractor not found, using FlattenExtractor")
        features_extractor_class = torch_layers.FlattenExtractor

    # Optimizer: the config value is the exact class name in torch.optim
    optimizer_name = params["optimizer_class"]
    if optimizer_name in ("Adam", "SGD", "RMSprop", "Adagrad", "Adadelta"):
        optimizer_class = getattr(th.optim, optimizer_name)
    else:
        rospy.logwarn("Optimizer not found, using Adam")
        optimizer_class = th.optim.Adam

    # Network architecture is passed through unchanged
    policy_kwargs = {
        "activation_fn": activation_fn,
        "features_extractor_class": features_extractor_class,
        "optimizer_class": optimizer_class,
        "net_arch": params["net_arch"],
    }

    # log
    rospy.logwarn(policy_kwargs)
    return policy_kwargs
def get_action_noise(action_space_shape: int, parm_dict: dict, action_noise_type: str = "normal"):
    """
    Create the action-noise object described by the config.

    Args:
        action_space_shape (int): The shape of the action space; passed to ``np.ones`` to
            broadcast the configured mean and sigma.
        parm_dict (dict): The dictionary containing the parameters. ``use_action_noise`` is
            optional; when it is ``True`` the ``action_noise`` block must provide ``mean`` and
            ``sigma`` (plus ``theta`` and ``dt`` for Ornstein-Uhlenbeck noise, and optionally
            ``initial_noise``).
        action_noise_type (str): The type of action noise to use. Can be "normal" or "ornstein".

    Returns:
        action_noise: A ``NormalActionNoise`` or ``OrnsteinUhlenbeckActionNoise`` instance, or
        ``None`` when noise is disabled, not configured, or the noise type is not recognised.
    """
    # .get() so that a missing key takes the "not found" path instead of raising KeyError
    use_action_noise = parm_dict.get("use_action_noise")
    if use_action_noise is None:
        rospy.loginfo("Parameter use_action_noise was not found")
        return None

    # Only the literal boolean True enables noise (same strict check as before)
    if use_action_noise is not True:
        return None

    noise_params = parm_dict["action_noise"]
    mean = noise_params["mean"] * np.ones(action_space_shape)
    sigma = noise_params["sigma"] * np.ones(action_space_shape)

    # normal
    if action_noise_type == "normal":
        return NormalActionNoise(mean=mean, sigma=sigma)

    # ornstein
    if action_noise_type == "ornstein":
        return OrnsteinUhlenbeckActionNoise(
            mean=mean,
            sigma=sigma,
            theta=noise_params["theta"],
            dt=noise_params["dt"],
            # initial_noise is optional: an absent key behaves like an explicit None
            initial_noise=noise_params.get("initial_noise"),
        )

    # Unknown type: make the fallback to "no noise" visible instead of silent
    rospy.logwarn("Unknown action noise type '%s', no action noise will be used" % action_noise_type)
    return None
def test_env(env) -> bool:
    """
    Run the Stable-Baselines3 environment checker on ``env``.

    Args:
        env: The environment to validate; handed to ``check_env`` unchanged.

    Returns:
        bool: Always ``True``. It is only reached when ``check_env`` returns normally;
        any exception raised by the checker propagates to the caller.
    """
    check_env(env)
    return True
def is_dict_obs_space(env) -> bool:
    """Tell whether ``env.observation_space`` is a ``Dict`` space (goal-conditioned).

    The algorithm classes use this to auto-select between ``"MlpPolicy"`` (Box
    observation) and ``"MultiInputPolicy"`` (Dict observation) so callers don't
    have to specify the policy by hand or pick the right algorithm class for
    their env.

    Both ``gymnasium`` and the legacy ``gym`` package are probed, in that order;
    a package that is not installed is skipped.

    Args:
        env: Object exposing an ``observation_space`` attribute.

    Returns:
        bool: ``True`` if the observation space is an instance of ``spaces.Dict``
        from either package, ``False`` otherwise.
    """
    import importlib

    for package_name in ("gymnasium", "gym"):
        try:
            package = importlib.import_module(package_name)
            if isinstance(env.observation_space, package.spaces.Dict):
                return True
        except ImportError:
            # Package not available: try the next candidate
            continue
    return False
def her_replay_buffer_kwargs(parm_dict: dict) -> dict:
    """Build HER replay-buffer kwargs from a config's ``her_params`` block.

    Defaults match the previous algorithm-specific code paths: 4 sampled goals
    per real one, ``"future"`` goal-selection strategy.

    Args:
        parm_dict (dict): Configuration dictionary. ``her_params`` is optional; a
            missing or empty (``None``) block yields the defaults.

    Returns:
        dict: ``n_sampled_goal`` and ``goal_selection_strategy`` entries.
    """
    # "or {}" also covers a key that is present but empty (e.g. a blank YAML
    # block loads as None), which would otherwise fail on her.get(...)
    her = parm_dict.get("her_params") or {}
    return dict(
        n_sampled_goal=her.get("n_sampled_goal", 4),
        goal_selection_strategy=her.get("goal_selection_strategy", "future"),
    )
class TimeLimitCallback(BaseCallback):
    """
    Callback that enforces an action cycle during training.

    Each call to ``_on_step`` blocks until the current deadline has passed and
    then schedules the next deadline ``action_cycle_time`` seconds later.
    """

    def __init__(self, action_cycle_time: float, verbose: int = 0) -> None:
        """
        Args:
            action_cycle_time (float): The time in seconds for the action cycle.
            verbose (int): The verbosity level: 0 none, 1 training information, 2 debug.
        """
        super().__init__(verbose)
        self.action_cycle_time = action_cycle_time
        # First deadline is measured from construction time
        self.next_action_cycle = time.time() + action_cycle_time

    def _on_step(self) -> bool:
        """
        Sleep until the pending deadline, then arm the next one.

        Returns:
            bool: Always ``True``.
        """
        # Time left until the deadline; a deadline already in the past means no wait
        remaining = self.next_action_cycle - time.time()
        time.sleep(remaining if remaining > 0 else 0)

        # The next cycle is measured from the moment the wait finished
        self.next_action_cycle = time.time() + self.action_cycle_time
        return True