# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Config."""
from __future__ import annotations
import json
import os
from typing import Any
from omnisafe.typing import Activation, ActorType, AdvatageEstimator, InitFunction
from omnisafe.utils.tools import load_yaml
class Config(dict):
    """Config class for storing hyperparameters.

    OmniSafe stores hyperparameters in a yaml file and loads them into a :class:`Config` object,
    which is validated (see :func:`check_all_configs`) before being passed to the algorithm
    class.

    A :class:`Config` is a ``dict`` whose entries are also reachable as attributes:
    ``cfg.key`` reads ``cfg['key']`` and ``cfg.key = value`` writes ``cfg['key']``. Nested
    dictionaries given to the constructor, :meth:`dict2config` or :meth:`recurisve_update` are
    converted into nested :class:`Config` objects. Because ``__getattr__`` is only consulted
    after normal attribute lookup fails, dict methods (e.g. ``items``) shadow keys of the same
    name when accessed as attributes; use ``cfg['items']`` for such keys.

    The class-level annotations below are declarations for readers and type checkers only;
    they are not enforced at runtime.

    Attributes:
        seed (int): Random seed.
        device (str): Device to use for training.
        device_id (int): Device id to use for training.
        wrapper_type (str): Wrapper type.
        epochs (int): Number of epochs.
        steps_per_epoch (int): Number of steps per epoch.
        actor_iters (int): Number of actor iterations.
        critic_iters (int): Number of critic iterations.
        check_freq (int): Frequency of checking.
        save_freq (int): Frequency of saving.
        entropy_coef (float): Entropy coefficient.
        max_ep_len (int): Maximum episode length.
        num_mini_batches (int): Number of mini batches.
        actor_lr (float): Actor learning rate.
        critic_lr (float): Critic learning rate.
        log_dir (str): Log directory.
        target_kl (float): Target KL divergence.
        batch_size (int): Batch size.
        use_cost (bool): Whether to use cost.
        cost_gamma (float): Cost gamma.
        linear_lr_decay (bool): Whether to use linear learning rate decay.
        exploration_noise_anneal (bool): Whether to use exploration noise anneal.
        penalty_param (float): Penalty parameter.
        kl_early_stop (bool): Whether to use KL early stop.
        use_max_grad_norm (bool): Whether to use max gradient norm.
        max_grad_norm (float): Max gradient norm.
        use_critic_norm (bool): Whether to use critic norm.
        critic_norm_coeff (float): Critic norm coefficient.
        model_cfgs (ModelConfig): Model config.
        buffer_cfgs (Config): Buffer config.
        gamma (float): Discount factor.
        lam (float): Lambda.
        lam_c (float): Lambda for cost.
        adv_eastimator (AdvatageEstimator): Advantage estimator.
        standardized_rew_adv (bool): Whether to use standardized reward advantage.
        standardized_cost_adv (bool): Whether to use standardized cost advantage.
        env_cfgs (Config): Environment config.
        num_envs (int): Number of environments.
        async_env (bool): Whether to use asynchronous environments.
        env_name (str): Environment name.
        env_kwargs (dict): Environment keyword arguments.
        normalized_obs (bool): Whether to normalize observation.
        normalized_rew (bool): Whether to normalize reward.
        normalized_cost (bool): Whether to normalize cost.
        max_len (int): Maximum length.
        num_threads (int): Number of threads.

    Keyword Args:
        kwargs (Any): keyword arguments to set the attributes.
    """

    seed: int
    device: str
    device_id: int
    wrapper_type: str
    epochs: int
    steps_per_epoch: int
    actor_iters: int
    critic_iters: int
    check_freq: int
    save_freq: int
    entropy_coef: float
    max_ep_len: int
    num_mini_batches: int
    actor_lr: float
    critic_lr: float
    log_dir: str
    target_kl: float
    batch_size: int
    use_cost: bool
    cost_gamma: float
    linear_lr_decay: bool
    exploration_noise_anneal: bool
    penalty_param: float
    kl_early_stop: bool
    use_max_grad_norm: bool
    max_grad_norm: float
    use_critic_norm: bool
    critic_norm_coeff: float
    model_cfgs: ModelConfig
    buffer_cfgs: Config
    gamma: float
    lam: float
    lam_c: float
    adv_eastimator: AdvatageEstimator
    standardized_rew_adv: bool
    standardized_cost_adv: bool
    env_cfgs: Config
    num_envs: int
    async_env: bool
    normalized_rew: bool
    normalized_cost: bool
    normalized_obs: bool
    max_len: int
    num_threads: int

    def __init__(self, **kwargs: Any) -> None:
        """Initialize an instance of :class:`Config`.

        Args:
            kwargs (Any): Initial entries. Every ``dict`` value (including an existing
                :class:`Config`) is converted into a fresh nested :class:`Config`.
        """
        for key, value in kwargs.items():
            self[key] = Config.dict2config(value) if isinstance(value, dict) else value

    def __getattr__(self, name: str) -> Any:
        """Return the entry stored under ``name``.

        Python only calls this after normal attribute lookup has failed, so methods and class
        attributes take precedence over keys.

        Args:
            name (str): The key to look up.

        Returns:
            The value stored under ``name``.

        Raises:
            AttributeError: If ``name`` is not a key. Raising ``AttributeError`` (not
                ``KeyError``) keeps ``hasattr`` and ``getattr(obj, name, default)`` working.
        """
        try:
            return self[name]
        except KeyError:
            # ``from None`` hides the internal KeyError so the traceback shows only the
            # attribute error the caller actually triggered.
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}',
            ) from None

    def __setattr__(self, name: str, value: Any) -> None:
        """Store ``value`` as the entry ``name`` rather than as an instance attribute.

        Args:
            name (str): The key to set.
            value (Any): The value to store (stored as-is, without dict conversion).
        """
        self[name] = value

    def get(self, name: str, default: Any = None) -> Any:
        """Return the entry ``name`` if present, otherwise ``default``.

        Args:
            name (str): The key to look up.
            default (Any): Value returned when ``name`` is missing.

        Returns:
            The stored value or ``default``.
        """
        try:
            return self[name]
        except KeyError:
            return default

    def todict(self) -> dict[str, Any]:
        """Convert Config to a plain dictionary.

        Nested :class:`Config` values are converted recursively; all other values are
        returned unchanged (not copied).

        Returns:
            The dictionary of Config.
        """
        return {
            key: value.todict() if isinstance(value, Config) else value
            for key, value in self.items()
        }

    def tojson(self) -> str:
        """Convert Config to a json string.

        Returns:
            The json string of Config, indented by 4 spaces.
        """
        return json.dumps(self.todict(), indent=4)

    @staticmethod
    def dict2config(config_dict: dict[str, Any]) -> Config:
        """Convert a dictionary to Config.

        Nested dictionaries are converted recursively into nested :class:`Config` objects.
        Keys are copied as-is, so non-string keys are allowed here, unlike ``Config(**d)``.

        Args:
            config_dict (dict[str, Any]): The dictionary to be converted.

        Returns:
            The algorithm config.
        """
        config = Config()
        for key, value in config_dict.items():
            config[key] = Config.dict2config(value) if isinstance(value, dict) else value
        return config

    def recurisve_update(self, update_args: dict[str, Any]) -> None:
        """Recursively update this config in place.

        For each entry in ``update_args``:

        - a non-dict value overwrites (or adds) the entry;
        - a dict value is merged recursively into an existing nested :class:`Config`;
        - otherwise, a dict value replaces (or adds) the entry as a new :class:`Config`.

        Existing keys keep their position; new keys are appended in ``update_args`` order.

        Args:
            update_args (dict[str, Any]): Args to be updated.
        """
        for key, new_value in update_args.items():
            if not isinstance(new_value, dict):
                self[key] = new_value
            elif key in self and isinstance(self[key], Config):
                self[key].recurisve_update(new_value)
            else:
                self[key] = Config.dict2config(new_value)
class ModelConfig(Config):
    """Model config.

    A :class:`Config` subclass that only adds declarations for model-related entries; it has
    no behavior of its own. The annotations are not enforced at runtime. Note that
    :class:`Config` converts nested dictionaries into plain :class:`Config` objects, so
    ``actor`` and ``critic`` are typed as :class:`ModelConfig` for readability only.

    Attributes:
        weight_initialization_mode (InitFunction): Weight initialization mode.
        actor_type (ActorType): Actor type.
        actor (ModelConfig): Actor model config.
        critic (ModelConfig): Critic model config.
        hidden_sizes (list[int]): Hidden layer sizes.
        activation (Activation): Activation function.
        std (list[float]): Standard deviation values.
        use_obs_encoder (bool): Whether to use an observation encoder.
        lr (float | None): Learning rate, or ``None``.
    """

    # Initialization and architecture choices.
    weight_initialization_mode: InitFunction
    actor_type: ActorType
    # Sub-model configs.
    actor: ModelConfig
    critic: ModelConfig
    # Network shape and activation.
    hidden_sizes: list[int]
    activation: Activation
    std: list[float]
    use_obs_encoder: bool
    # Optimizer setting; may be absent (None).
    lr: float | None
def get_default_kwargs_yaml(algo: str, env_id: str, algo_type: str) -> Config:
    """Get the default kwargs from a ``yaml`` file.

    .. note::
        The ``yaml`` file is located as ``../configs/<algo_type>/<algo>.yaml`` relative to this
        module's directory, so a newly implemented algorithm needs a yaml file with the same
        name. The resolved path is printed to stdout.

    The ``defaults`` section is loaded first; if the file also has a section keyed by
    ``env_id`` with a non-null value, it is merged on top via
    :meth:`Config.recurisve_update`.

    Args:
        algo (str): The algorithm name.
        env_id (str): The environment name.
        algo_type (str): The algorithm type.

    Returns:
        The default kwargs.

    Raises:
        KeyError: If the loaded yaml has no ``defaults`` section (assuming ``load_yaml``
            returns a dict — TODO confirm).
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    config_root = os.path.join(module_dir, '..', 'configs')
    cfg_path = os.path.join(config_root, algo_type, f'{algo}.yaml')
    print(f'Loading {algo}.yaml from {cfg_path}')

    yaml_content = load_yaml(cfg_path)
    merged = Config.dict2config(yaml_content['defaults'])

    env_overrides = yaml_content.get(env_id)
    if env_overrides is not None:
        merged.recurisve_update(env_overrides)
    return merged
def check_all_configs(configs: Config, algo_type: str) -> None:
    """Validate a complete set of configs, stopping at the first failed check.

    The checks run in this order: algorithm configs (``configs.algo_cfgs``), parallel and
    vectorized settings (read from ``configs`` itself), then logger configs
    (``configs.logger_cfgs``).

    Args:
        configs (Config): The full configs to be checked.
        algo_type (str): The algorithm type, e.g. ``'on-policy'`` or ``'off-policy'``.
    """
    algo_cfgs = configs.algo_cfgs
    __check_algo_configs(algo_cfgs, algo_type)

    __check_parallel_and_vectorized(configs, algo_type)

    # Read only after the earlier checks pass, matching the original evaluation order.
    logger_cfgs = configs.logger_cfgs
    __check_logger_configs(logger_cfgs)
def __check_parallel_and_vectorized(configs: Config, algo_type: str) -> None:
    """Check parallel and vectorized configs.

    Two restrictions are enforced:

    - off-policy, model-based and offline algorithms require ``train_cfgs.parallel == 1``;
    - ``PPOEarlyTerminated`` and ``TRPOEarlyTerminated`` require
      ``train_cfgs.vector_env_nums == 1``.

    Args:
        configs (Config): The full configs. ``configs.algo`` is always read, while
            ``configs.train_cfgs`` is only read when a restriction applies.
        algo_type (str): The algorithm type.

    Raises:
        AssertionError: If a restriction is violated.
    """
    single_process_types = {'off-policy', 'model-based', 'offline'}
    single_env_algos = ('PPOEarlyTerminated', 'TRPOEarlyTerminated')

    needs_single_process = algo_type in single_process_types
    if needs_single_process:
        assert (
            configs.train_cfgs.parallel == 1
        ), 'off-policy, offline and model-based only support parallel==1!'

    needs_single_env = configs.algo in single_env_algos
    if needs_single_env:
        assert (
            configs.train_cfgs.vector_env_nums == 1
        ), 'PPOEarlyTerminated or TRPOEarlyTerminated only support vector_env_nums == 1!'
def __check_algo_configs(configs: Config, algo_type: str) -> None:
    """Check algorithm configs.

    Only on-policy algorithm configs are validated; for any other ``algo_type`` this
    function returns without reading ``configs``.

    .. note::
        For ``algo_type == 'on-policy'``:

        - ``update_iters`` must be int and greater than 0.
        - ``steps_per_epoch`` must be int and greater than 0.
        - ``batch_size`` must be int and greater than 0.
        - ``target_kl`` must be float and greater than or equal to 0.
        - ``entropy_coef`` must be float in [0, 1].
        - ``reward_normalize``, ``cost_normalize`` and ``obs_normalize`` must be bool.
        - ``kl_early_stop``, ``use_max_grad_norm`` and ``use_critic_norm`` must be bool.
        - ``max_grad_norm`` and ``critic_norm_coef`` must be float.
        - ``gamma``, ``cost_gamma``, ``lam`` and ``lam_c`` must be float in [0, 1].
        - ``clip``, if present, must be float and greater than or equal to 0.
        - ``adv_estimation_method`` must be one of ``gae``, ``gae-rtg``, ``vtrace``, ``plain``.
        - ``standardized_rew_adv`` and ``standardized_cost_adv`` must be bool.
        - ``penalty_coef`` must be float in [0, 1].
        - ``use_cost`` must be bool.

    Args:
        configs (Config): The algorithm configs to be checked.
        algo_type (str): The algorithm type.

    Raises:
        AssertionError: If any on-policy check fails.
    """
    if algo_type != 'on-policy':
        return

    assert (
        isinstance(configs.update_iters, int) and configs.update_iters > 0
    ), 'update_iters must be int and greater than 0'
    assert (
        isinstance(configs.steps_per_epoch, int) and configs.steps_per_epoch > 0
    ), 'steps_per_epoch must be int and greater than 0'
    assert (
        isinstance(configs.batch_size, int) and configs.batch_size > 0
    ), 'batch_size must be int and greater than 0'
    assert (
        isinstance(configs.target_kl, float) and configs.target_kl >= 0.0
    ), 'target_kl must be float and greater than 0.0'
    assert (
        isinstance(configs.entropy_coef, float) and 0.0 <= configs.entropy_coef <= 1.0
    ), 'entropy_coef must be float, and it values must be [0.0, 1.0]'

    assert isinstance(configs.reward_normalize, bool), 'reward_normalize must be bool'
    assert isinstance(configs.cost_normalize, bool), 'cost_normalize must be bool'
    assert isinstance(configs.obs_normalize, bool), 'obs_normalize must be bool'
    assert isinstance(configs.kl_early_stop, bool), 'kl_early_stop must be bool'
    assert isinstance(configs.use_max_grad_norm, bool), 'use_max_grad_norm must be bool'
    assert isinstance(configs.use_critic_norm, bool), 'use_critic_norm must be bool'
    assert isinstance(configs.max_grad_norm, float) and isinstance(
        configs.critic_norm_coef,
        float,
    ), 'norm must be float'

    assert (
        isinstance(configs.gamma, float) and 0.0 <= configs.gamma <= 1.0
    ), 'gamma must be float, and it values must be [0.0, 1.0]'
    assert (
        isinstance(configs.cost_gamma, float) and 0.0 <= configs.cost_gamma <= 1.0
    ), 'cost_gamma must be float, and it values must be [0.0, 1.0]'
    assert (
        isinstance(configs.lam, float) and 0.0 <= configs.lam <= 1.0
    ), 'lam must be float, and it values must be [0.0, 1.0]'
    assert (
        isinstance(configs.lam_c, float) and 0.0 <= configs.lam_c <= 1.0
    ), 'lam_c must be float, and it values must be [0.0, 1.0]'

    # ``clip`` is optional (only some on-policy algorithms define it).
    if hasattr(configs, 'clip'):
        assert (
            isinstance(configs.clip, float) and configs.clip >= 0.0
        ), 'clip must be float, and it values must be [0.0, infty]'

    assert isinstance(configs.adv_estimation_method, str) and configs.adv_estimation_method in (
        'gae',
        'gae-rtg',
        'vtrace',
        'plain',
    ), "adv_estimation_method must be string, and it values must be ['gae','gae-rtg','vtrace','plain']"
    assert isinstance(configs.standardized_rew_adv, bool) and isinstance(
        configs.standardized_cost_adv,
        bool,
    ), 'standardized_<>_adv must be bool'
    assert (
        isinstance(configs.penalty_coef, float) and 0.0 <= configs.penalty_coef <= 1.0
    ), 'penalty_coef must be float, and it values must be [0.0, 1.0]'
    assert isinstance(configs.use_cost, bool), 'use_cost must be bool'
def __check_logger_configs(configs: Config) -> None:
    """Check logger configs.

    Args:
        configs (Config): The logger configs to be checked. Required entries are read as
            attributes; the optional ``window_lens`` entry is read with ``configs.get``.

    Raises:
        AssertionError: If a logger setting has the wrong type.
    """
    assert isinstance(configs.use_wandb, bool) and isinstance(
        configs.wandb_project,
        str,
    ), 'use_wandb and wandb_project must be bool and string'
    assert isinstance(configs.use_tensorboard, bool), 'use_tensorboard must be bool'
    assert isinstance(configs.save_model_freq, int), 'save_model_freq must be int'
    # ``window_lens`` is optional: validate it whenever it is set. A plain truthiness test
    # would silently accept falsy non-int values such as ``''`` or ``[]``.
    window_lens = configs.get('window_lens')
    if window_lens is not None:
        assert isinstance(window_lens, int), 'window_lens must be int'
    assert isinstance(configs.log_dir, str), 'log_dir must be string'