# Source code for sb3_ros_support.core

#!/bin/python3
"""
Extending the SB3 models of the frobs_rl library.

Recreated to overcome the following limitations:
    - Could not be used with multiple environments
    - Lacked support for goal-conditioned environments
    - Did not use pyyaml for loading parameters
"""
import os
from datetime import datetime
from typing import Any, Dict, Optional, Tuple

from sb3_ros_support.utils.sb3_common import get_policy_kwargs, get_action_noise, test_env, TimeLimitCallback

# ROS packages required
import rospy

# SB3 Callbacks
from stable_baselines3.common.callbacks import CheckpointCallback

# Logger
from stable_baselines3.common.logger import configure


class BasicModel:
    """
    Base class for all the algorithms of Stable Baselines3.

    This class never builds ``self.model`` itself: it is initialised to ``None`` and
    every method that uses it expects it to have been assigned beforehand
    (presumably by a subclass -- TODO confirm).
    """

    def __init__(self, env: Any, save_model_path: str, log_path: str, parm_dict: Dict[str, Any],
                 load_trained: bool = False, action_noise_type: str = "normal",
                 action_noise: bool = True) -> None:
        """
        Store the configuration and prepare the policy kwargs, action noise and
        checkpoint callback.

        Args:
            env (gym.Env): The environment to be used.
            save_model_path (str): The path to save the model. It is used as a raw string
                prefix for the model file name, so it should end with a path separator.
            log_path (str): The path to save the log. Also used as a raw string prefix.
            parm_dict (dict): The dictionary containing the parameters. The keys
                "save_freq" and "save_prefix" are read here; other keys are read by
                the individual methods.
            load_trained (bool): Whether to load a trained model or not. When True, the
                policy kwargs and the action noise are not created.
            action_noise_type (str): The type of action noise to use. Can be "normal" or "ornstein". (Optional)
            action_noise (bool): Whether to use action noise or not. (Optional)

        Raises:
            KeyError: If "save_freq" or "save_prefix" is missing from ``parm_dict``.
        """
        self.env = env
        self.save_model_path = save_model_path
        self.log_path = log_path
        self.save_trained_model_path = None
        self.model = None
        self.parm_dict = parm_dict

        if load_trained is False:
            # --- Policy kwargs
            self.policy_kwargs = get_policy_kwargs(parm_dict)

            # --- Noise kwargs
            if action_noise:
                # The noise is sized from the last dimension of the action space shape.
                self.action_noise = get_action_noise(self.env.action_space.shape[-1], parm_dict,
                                                     action_noise_type)
            else:
                self.action_noise = None

        # --- Callback
        # Created unconditionally: train() always passes it to learn().
        save_freq = parm_dict["save_freq"]
        save_prefix = parm_dict["save_prefix"]
        self.checkpoint_callback = CheckpointCallback(save_freq=save_freq, save_path=save_model_path,
                                                      name_prefix=save_prefix)

    def train(self, action_cycle_time: Optional[float] = None) -> bool:
        """
        Train the model for the number of steps given by ``parm_dict["training_steps"]``
        and save it afterwards.

        Args:
            action_cycle_time (float): The time to wait between actions. When given, a
                TimeLimitCallback built with this value is added to the callbacks. (Optional)

        Returns:
            bool: Always True; failures surface as exceptions.
        """
        training_steps = self.parm_dict["training_steps"]
        learn_log_int = self.parm_dict["log_interval"]
        learn_reset_num_tm = self.parm_dict["reset_num_timesteps"]

        if learn_reset_num_tm is False:
            # Continuing a previous run: use the environment held by the model and reset it.
            self.env = self.model.get_env()
            self.env.reset()

        # Create the list of callbacks
        callbacks = [self.checkpoint_callback]
        if action_cycle_time is not None:
            callbacks.append(TimeLimitCallback(action_cycle_time=action_cycle_time))

        self.model.learn(total_timesteps=int(training_steps), callback=callbacks,
                         log_interval=learn_log_int, reset_num_timesteps=learn_reset_num_tm)

        self.save_model()
        return True

    def save_model(self) -> bool:
        """
        Save the model under ``save_model_path + parm_dict["trained_model_name"]`` and
        then save the replay buffer (if enabled in ``parm_dict``).

        If a ".zip" file with that name already exists, a timestamp suffix is appended
        to the name so the existing file is not overwritten.

        Returns:
            bool: Always True; failures surface as exceptions.
        """
        # --- Model name
        trained_model_name = self.parm_dict["trained_model_name"]
        self.save_trained_model_path = self.save_model_path + trained_model_name

        # If file exists, name the new model with a suffix
        if os.path.isfile(self.save_model_path + trained_model_name + ".zip"):
            dt_string = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
            self.save_trained_model_path = self.save_trained_model_path + "_" + dt_string
            rospy.logwarn("Trained model name already exists, saving as: " + trained_model_name + "_" + dt_string)

        self.model.save(self.save_trained_model_path)
        self.save_replay_buffer()
        return True

    def save_replay_buffer(self) -> None:
        """
        Save the replay buffer next to the saved model when
        ``parm_dict["save_replay_buffer"]`` is truthy.

        The file name is derived from the path of the last saved model, so
        ``save_model`` must have run first.

        Raises:
            ValueError: If no model has been saved yet (``save_trained_model_path`` is None).
        """
        if self.save_trained_model_path is None:
            raise ValueError("Model not trained yet, cannot save replay buffer")

        if self.parm_dict["save_replay_buffer"]:
            rospy.logwarn("Saving replay buffer")
            self.model.save_replay_buffer(self.save_trained_model_path + '_replay_buffer')

    def set_model_logger(self) -> bool:
        """
        Configure a logger writing to ``log_path + parm_dict["log_folder"]`` with the
        "stdout", "csv" and "tensorboard" formats and attach it to the model.

        Returns:
            bool: Always True; failures surface as exceptions.

        Raises:
            AssertionError: If the log folder already exists.
        """
        log_folder = self.parm_dict["log_folder"]
        log_path = self.log_path + log_folder

        # Raised explicitly instead of using `assert` so the check survives `python -O`;
        # the exception type and message are unchanged for existing callers.
        if os.path.exists(log_path):
            raise AssertionError("Log folder already exists, to log into that folder first delete it.")

        new_logger = configure(log_path + '/', ["stdout", "csv", "tensorboard"])
        self.model.set_logger(new_logger)
        return True

    def close_env(self) -> bool:
        """
        Use the env close method to close the environment.

        Returns:
            bool: Always True; failures surface as exceptions.
        """
        self.env.close()
        return True

    def check_env(self) -> bool:
        """
        Check the environment by passing it to the ``test_env`` helper.

        Returns:
            bool: Always True; failures surface as exceptions.
        """
        test_env(self.env)
        return True

    def predict(self, observation: Any, state: Optional[Any] = None,
                deterministic: bool = False) -> Tuple[Any, Any]:
        """
        Get the current action based on the observation, state or mask.

        Args:
            observation (ndarray): The environment observation.
            state (ndarray): The previous states of the environment, used in recurrent policies. (Optional)
            deterministic (bool): Whether to return deterministic actions or not. (Optional)

        Returns:
            tuple: Whatever ``self.model.predict`` returns, passed through unchanged
            (annotated as a 2-tuple whose first item is the action to be taken).
        """
        return self.model.predict(observation, state=state, deterministic=deterministic)