#!/bin/python3
import os
from typing import Any, Optional
import stable_baselines3
from sb3_ros_support import core
from sb3_ros_support.utils import yaml_utils
# ROS packages required
import rospy
import rospkg
class A2C(core.BasicModel):
    """
    Advantage Actor-Critic (A2C) algorithm.

    Builds (or loads) a ``stable_baselines3.A2C`` model from a YAML parameter
    file and stores it in ``self.model``.

    Paper: https://arxiv.org/abs/1602.01783
    """

    def __init__(self, env: Any, save_model_path: str, log_path: str,
                 model_pkg_path: Optional[str] = None, load_trained: bool = False,
                 load_model_path: Optional[str] = None,
                 config_file_pkg: Optional[str] = None,
                 config_filename: Optional[str] = None,
                 abs_config_path: Optional[str] = None,
                 seed: Optional[int] = None) -> None:
        """
        Create a new A2C model, resume one named in the YAML config, or load a
        trained one from ``load_model_path``.

        Args:
            env (gym.Env): The environment to be used.
            save_model_path (str): The path to save the model. Absolute, or relative
                to the package given by ``model_pkg_path``.
            log_path (str): The path to save the log. Absolute, or relative to the
                package given by ``model_pkg_path``.
            model_pkg_path (str): The package name to save or load the model. When
                given, all model/log paths are resolved inside that package.
            load_trained (bool): Whether to load a trained model or not.
            load_model_path (str): The path to load the model. Should include the
                model name. Required when ``load_trained`` is True.
            config_file_pkg (str): The package name of the config file. Required if
                abs_config_path is not provided.
            config_filename (str): The name of the config file. Required if
                abs_config_path is not provided.
            abs_config_path (str): The absolute path to the config file. Required if
                config_file_pkg and config_filename are not provided.
            seed (int): If provided, overrides the YAML ``a2c_params.seed``
                for the SB3 learner (PyTorch / rollout RNG).

        Raises:
            ValueError: If ``load_trained`` is True but ``load_model_path`` is empty.
            AssertionError: If the YAML asks to resume ``model_name`` but the
                corresponding ``.zip`` file does not exist in ``save_model_path``.
        """
        rospy.loginfo("Init A2C Policy")

        # Fail early with a clear message instead of a TypeError/IndexError deep
        # inside the path handling or the SB3 loader.
        if load_trained and not load_model_path:
            raise ValueError("load_model_path must be provided when load_trained is True")

        # --- Set the environment
        self.env = env

        # --- Set the save and log path (resolved inside the ROS package if given)
        if model_pkg_path is not None:
            pkg_path = rospkg.RosPack().get_path(model_pkg_path)
            save_model_path = self._in_package(pkg_path, save_model_path, as_dir=True)
            log_path = self._in_package(pkg_path, log_path, as_dir=True)
            if load_trained:
                # The load path names a file, so no trailing "/" is appended.
                load_model_path = self._in_package(pkg_path, load_model_path, as_dir=False)

        # --- Load YAML Config File
        parm_dict = yaml_utils.load_yaml(pkg_name=config_file_pkg, file_name=config_filename,
                                         file_abs_path=abs_config_path)

        # --- Init superclass
        super().__init__(env, save_model_path, log_path, parm_dict, load_trained=load_trained,
                         action_noise=False, seed=seed)

        if load_trained:
            # Predict-only path: the stored hyper-parameters are used as-is and
            # no logger is attached.
            rospy.logwarn("Loading trained model")
            self.model = stable_baselines3.A2C.load(load_model_path, env=env)
            return

        # --- SDE for A2C
        if parm_dict["use_sde"]:
            model_sde = True
            model_sde_sample_freq = parm_dict["sde_params"]["sde_sample_freq"]
            self.action_noise = None
        else:
            model_sde = False
            model_sde_sample_freq = -1

        # --- A2C model parameters, shared by the "resume" and "create" branches
        a2c_params = parm_dict["a2c_params"]
        model_kwargs = dict(
            verbose=1,
            learning_rate=a2c_params["learning_rate"],
            n_steps=a2c_params["n_steps"],
            gamma=a2c_params["gamma"],
            gae_lambda=a2c_params["gae_lambda"],
            ent_coef=a2c_params["ent_coef"],
            vf_coef=a2c_params["vf_coef"],
            max_grad_norm=a2c_params["max_grad_norm"],
            use_sde=model_sde,
            sde_sample_freq=model_sde_sample_freq,
            use_rms_prop=a2c_params["use_rms_prop"],
            rms_prop_eps=a2c_params["rms_prop_eps"],
            normalize_advantage=a2c_params["normalize_advantage"],
            # An explicit constructor seed wins over the YAML value.
            seed=seed if seed is not None else a2c_params["seed"],
        )

        # --- Create or load model
        if parm_dict["load_model"]:  # Load model
            model_name = parm_dict["model_name"]
            model_file = save_model_path + model_name
            if not os.path.exists(model_file + ".zip"):
                # Raised explicitly (not via ``assert``) so the check survives
                # ``python -O``; the exception type is kept for existing callers.
                raise AssertionError("Model {} doesn't exist".format(model_name))

            rospy.logwarn("Loading model: " + model_name)
            self.model = stable_baselines3.A2C.load(model_file, env=env, **model_kwargs)

            buffer_file = model_file + "_replay_buffer"
            if not os.path.exists(buffer_file + ".pkl"):
                rospy.logwarn("No replay buffer found")
            elif hasattr(self.model, "load_replay_buffer"):
                rospy.logwarn("Loading replay buffer")
                self.model.load_replay_buffer(buffer_file)
            else:
                # A2C is on-policy: SB3 only provides load_replay_buffer on
                # off-policy algorithms, so a stray buffer file must not crash us.
                rospy.logwarn("Replay buffer file found but the A2C model has no replay buffer; ignoring it")
        else:  # Create new model
            rospy.logwarn("Creating new model")
            # NOTE(review): self.policy_kwargs is presumably set by
            # core.BasicModel.__init__ from parm_dict -- confirm.
            self.model = stable_baselines3.A2C("MlpPolicy", env, policy_kwargs=self.policy_kwargs,
                                               **model_kwargs)

        # --- Logger
        self.set_model_logger()

    @staticmethod
    def _in_package(pkg_path: str, rel_path: str, as_dir: bool) -> str:
        """
        Prefix ``rel_path`` with ``pkg_path``.

        A leading "/" is added to ``rel_path`` when missing and, when ``as_dir``
        is True, a trailing "/" as well. Uses ``startswith``/``endswith`` so an
        empty ``rel_path`` resolves to the package root instead of raising.
        """
        if not rel_path.startswith("/"):
            rel_path = "/" + rel_path
        if as_dir and not rel_path.endswith("/"):
            rel_path = rel_path + "/"
        return pkg_path + rel_path

    @staticmethod
    def load_trained_model(model_path: str, model_pkg: Optional[str] = None,
                           env: Optional[Any] = None,
                           config_file_pkg: Optional[str] = None,
                           config_filename: Optional[str] = None,
                           abs_config_path: Optional[str] = None) -> "A2C":
        """
        Load a trained model. Use only with predict function, as the logs will not be saved.

        Args:
            model_path (str): The path to the trained model. Can be absolute or relative.
            model_pkg (str): The package name to load the model. Required if model_path is relative.
            env (gym.Env): The environment to be used.
            config_file_pkg (str): The package name of the config file. Use the same package as model_pkg if not provided.
            config_filename (str): The name of the config file.
            abs_config_path (str): The absolute path to the config file.

        Returns:
            model: The loaded model.
        """
        if config_file_pkg is None and config_filename is None and abs_config_path is None:
            # No config given at all: fall back to the default file.
            config_file_pkg = "sb3_ros_support"
            config_filename = "a2c.yaml"
            rospy.logwarn("Using default config file: " + config_filename + " from package: " + config_file_pkg)
        elif model_pkg is not None and config_filename is not None and config_file_pkg is None:
            # Only a file name was given: look for it in the model's package.
            config_file_pkg = model_pkg

        # model_path doubles as save/log path; they are required by the
        # constructor even though this entry point only loads the model.
        model = A2C(env=env, save_model_path=model_path, log_path=model_path, model_pkg_path=model_pkg,
                    load_trained=True, load_model_path=model_path, config_file_pkg=config_file_pkg,
                    config_filename=config_filename, abs_config_path=abs_config_path)
        return model