Source code for omnisafe.utils.tools

# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tool_function_packages"""

import hashlib
import json
import os
import random
import sys

import numpy as np
import torch
import yaml
from rich.console import Console


[docs]def get_flat_params_from(model: torch.nn.Module) -> torch.Tensor: """This function is used to get the flattened parameters from the model. .. note:: Some algorithms need to get the flattened parameters from the model, such as the :class:`TRPO` and :class:`CPO` algorithm. In these algorithms, the parameters are flattened and then used to calculate the loss. Example: >>> model = torch.nn.Linear(2, 2) >>> model.weight.data = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) >>> get_flat_params_from(model) tensor([1., 2., 3., 4.]) Args: model (torch.nn.Module): model to be flattened. """ flat_params = [] for _, param in model.named_parameters(): if param.requires_grad: data = param.data data = data.view(-1) # flatten tensor flat_params.append(data) assert flat_params, 'No gradients were found in model parameters.' return torch.cat(flat_params)
[docs]def get_flat_gradients_from(model: torch.nn.Module) -> torch.Tensor: """This function is used to get the flattened gradients from the model. .. note:: Some algorithms need to get the flattened gradients from the model, such as the :class:`TRPO` and :class:`CPO` algorithm. In these algorithms, the gradients are flattened and then used to calculate the loss. Args: model (torch.nn.Module): model to be flattened. """ grads = [] for _, param in model.named_parameters(): if param.requires_grad and param.grad is not None: grad = param.grad grads.append(grad.view(-1)) # flatten tensor and append assert grads, 'No gradients were found in model parameters.' return torch.cat(grads)
[docs]def set_param_values_to_model(model: torch.nn.Module, vals: torch.Tensor) -> None: """This function is used to set the parameters to the model. .. note:: Some algorithms (e.g. TRPO, CPO, etc.) need to set the parameters to the model, instead of using the ``optimizer.step()``. Example: >>> model = torch.nn.Linear(2, 2) >>> model.weight.data = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) >>> vals = torch.tensor([1.0, 2.0, 3.0, 4.0]) >>> set_param_values_to_model(model, vals) >>> model.weight.data tensor([[1., 2.], [3., 4.]]) Args: model (torch.nn.Module): model to be set. vals (torch.Tensor): parameters to be set. """ assert isinstance(vals, torch.Tensor) i: int = 0 for _, param in model.named_parameters(): if param.requires_grad: # param has grad and, hence, must be set orig_size = param.size() size = np.prod(list(param.size())) new_values = vals[i : int(i + size)] # set new param values new_values = new_values.view(orig_size) param.data = new_values i += int(size) # increment array position assert i == len(vals), f'Lengths do not match: {i} vs. {len(vals)}'
[docs]def seed_all(seed: int): """This function is used to set the random seed for all the packages. .. hint:: To reproduce the results, you need to set the random seed for all the packages. Including ``numpy``, ``random``, ``torch``, ``torch.cuda``, ``torch.backends.cudnn``. .. warning:: If you want to use the ``torch.backends.cudnn.benchmark`` or ``torch.backends.cudnn. deterministic`` and your ``cuda`` version is over 10.2, you need to set the ``CUBLAS_WORKSPACE_CONFIG`` and ``PYTHONHASHSEED`` environment variables. """ os.environ['PYTHONHASHSEED'] = str(seed) random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) try: torch.use_deterministic_algorithms(True) torch.backends.cudnn.enabled = False torch.backends.cudnn.benchmark = False if float(torch.version.cuda) >= 10.2: os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' os.environ['PYTHONHASHSEED'] = str(seed) except AttributeError: pass
[docs]def custom_cfgs_to_dict(key_list, value): """This function is used to convert the custom configurations to dict. .. note:: This function is used to convert the custom configurations to dict. For example, if the custom configurations are ``train_cfgs:use_wandb`` and ``True``, then the output dict will be ``{'train_cfgs': {'use_wandb': True}}``. Args: key_list (list): list of keys. value: value. """ if value == 'True': value = True elif value == 'False': value = False elif '.' in value: value = float(value) elif value.isdigit(): value = int(value) elif value.startswith('[') and value.endswith(']'): value = value[1:-1] value = value.split(',') else: value = str(value) keys_split = key_list.replace('-', '_').split(':') return_dict = {keys_split[-1]: value} for key in reversed(keys_split[:-1]): return_dict = {key.replace('-', '_'): return_dict} return return_dict
[docs]def update_dic(total_dic, item_dic): '''Updater of multi-level dictionary.''' for idd in item_dic.keys(): total_value = total_dic.get(idd) item_value = item_dic.get(idd) if total_value is None: total_dic.update({idd: item_value}) elif isinstance(item_value, dict): update_dic(total_value, item_value) total_dic.update({idd: total_value}) else: total_value = item_value total_dic.update({idd: total_value})
[docs]def load_yaml(path) -> dict: """Get the default kwargs from ``yaml`` file. .. note:: This function search the ``yaml`` file by the algorithm name and environment name. Make sure your new implemented algorithm or environment has the same name as the yaml file. Args: path (str): path of the ``yaml`` file. """ with open(path, encoding='utf-8') as file: try: kwargs = yaml.load(file, Loader=yaml.FullLoader) except yaml.YAMLError as exc: assert False, f'{path} error: {exc}' return kwargs
[docs]def recursive_check_config(config, default_config, exclude_keys=()): """Check whether config is valid in default_config. Args: config (dict): config to be checked. default_config (dict): default config. """ for key in config.keys(): if key not in default_config.keys() and key not in exclude_keys: raise KeyError(f'Invalid key: {key}') if isinstance(config[key], dict): recursive_check_config(config[key], default_config[key])
def assert_with_exit(condition, msg) -> None: """Assert with message. Args: condition (bool): condition to be checked. msg (str): message to be printed. """ try: assert condition except AssertionError: console = Console() console.print('ERROR: ' + msg, style='bold red') sys.exit(1) def recursive_dict2json(dict_obj) -> str: """This function is used to recursively convert the dict to json. Args: dict_obj (dict): dict to be converted. """ assert isinstance(dict_obj, dict), 'Input must be a dict.' flat_dict = {} def _flatten_dict(dict_obj, path=''): if isinstance(dict_obj, dict): for key, value in dict_obj.items(): _flatten_dict(value, path + key + ':') else: flat_dict[path[:-1]] = dict_obj _flatten_dict(dict_obj) flat_dict_str = json.dumps(flat_dict, sort_keys=True).replace('"', "'") return flat_dict_str def hash_string(string) -> str: """This function is used to generate the folder name. Args: string (str): string to be hashed. """ salt = b'\xf8\x99/\xe4\xe6J\xd8d\x1a\x9b\x8b\x98\xa2\x1d\xff3*^\\\xb1\xc1:e\x11M=PW\x03\xa5\\h' # convert string to bytes and add salt salted_string = salt + string.encode('utf-8') # use sha256 to hash hash_object = hashlib.sha256(salted_string) # get the hex digest folder_name = hash_object.hexdigest() return folder_name