Source code for omnisafe.adapter.offpolicy_adapter

# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""OffPolicy Adapter for OmniSafe."""

from __future__ import annotations

from typing import Any

import torch

from omnisafe.adapter.online_adapter import OnlineAdapter
from omnisafe.common.buffer import VectorOffPolicyBuffer
from omnisafe.common.logger import Logger
from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic
from omnisafe.utils.config import Config


[docs]class OffPolicyAdapter(OnlineAdapter): """OffPolicy Adapter for OmniSafe. :class:`OffPolicyAdapter` is used to adapt the environment to the off-policy training. .. note:: Off-policy training need to update the policy before finish the episode, so the :class:`OffPolicyAdapter` will store the current observation in ``_current_obs``. After update the policy, the agent will *remember* the current observation and use it to interact with the environment. Args: env_id (str): The environment id. num_envs (int): The number of environments. seed (int): The random seed. cfgs (Config): The configuration. """ _current_obs: torch.Tensor _ep_ret: torch.Tensor _ep_cost: torch.Tensor _ep_len: torch.Tensor def __init__( # pylint: disable=too-many-arguments self, env_id: str, num_envs: int, seed: int, cfgs: Config, ) -> None: """Initialize a instance of :class:`OffPolicyAdapter`.""" super().__init__(env_id, num_envs, seed, cfgs) self._current_obs, _ = self.reset() self._max_ep_len: int = 1000 self._reset_log()
[docs] def eval_policy( # pylint: disable=too-many-locals self, episode: int, agent: ConstraintActorQCritic, logger: Logger, ) -> None: """Rollout the environment with deterministic agent action. Args: episode (int): Number of episodes. agent (ConstraintActorCritic): Agent. logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``. """ for _ in range(episode): ep_ret, ep_cost, ep_len = 0.0, 0.0, 0 obs, _ = self._eval_env.reset() obs = obs.to(self._device) done = False while not done: act = agent.step(obs, deterministic=True) obs, reward, cost, terminated, truncated, info = self._eval_env.step(act) obs, reward, cost, terminated, truncated = ( torch.as_tensor(x, dtype=torch.float32, device=self._device) for x in (obs, reward, cost, terminated, truncated) ) ep_ret += info.get('original_reward', reward).cpu() ep_cost += info.get('original_cost', cost).cpu() ep_len += 1 done = bool(terminated[0].item()) or bool(truncated[0].item()) logger.store( { 'Metrics/TestEpRet': ep_ret, 'Metrics/TestEpCost': ep_cost, 'Metrics/TestEpLen': ep_len, }, )
[docs] def rollout( # pylint: disable=too-many-locals self, rollout_step: int, agent: ConstraintActorQCritic, buffer: VectorOffPolicyBuffer, logger: Logger, use_rand_action: bool, ) -> None: """Rollout the environment and store the data in the buffer. .. warning:: As OmniSafe uses :class:`AutoReset` wrapper, the environment will be reset automatically, so the final observation will be stored in ``info['final_observation']``. Args: rollout_step (int): Number of rollout steps. agent (ConstraintActorCritic): Constraint actor-critic, including actor, reward critic, and cost critic. buffer (VectorOnPolicyBuffer): Vector on-policy buffer. logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``. use_rand_action (bool): Whether to use random action. """ for _ in range(rollout_step): if use_rand_action: act = torch.as_tensor(self._env.sample_action(), dtype=torch.float32).to( self._device, ) else: act = agent.step(self._current_obs, deterministic=False) next_obs, reward, cost, terminated, truncated, info = self.step(act) self._log_value(reward=reward, cost=cost, info=info) real_next_obs = next_obs.clone() for idx, done in enumerate(torch.logical_or(terminated, truncated)): if done: real_next_obs[idx] = info['final_observation'][idx] self._log_metrics(logger, idx) self._reset_log(idx) buffer.store( obs=self._current_obs, act=act, reward=reward, cost=cost, done=torch.logical_and(terminated, torch.logical_xor(terminated, truncated)), next_obs=real_next_obs, ) self._current_obs = next_obs
[docs] def _log_value( self, reward: torch.Tensor, cost: torch.Tensor, info: dict[str, Any], ) -> None: """Log value. .. note:: OmniSafe uses :class:`RewardNormalizer` wrapper, so the original reward and cost will be stored in ``info['original_reward']`` and ``info['original_cost']``. Args: reward (torch.Tensor): The immediate step reward. cost (torch.Tensor): The immediate step cost. info (dict[str, Any]): Some information logged by the environment. """ self._ep_ret += info.get('original_reward', reward).cpu() self._ep_cost += info.get('original_cost', cost).cpu() self._ep_len += 1
[docs] def _log_metrics(self, logger: Logger, idx: int) -> None: """Log metrics, including ``EpRet``, ``EpCost``, ``EpLen``. Args: logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``. idx (int): The index of the environment. """ logger.store( { 'Metrics/EpRet': self._ep_ret[idx], 'Metrics/EpCost': self._ep_cost[idx], 'Metrics/EpLen': self._ep_len[idx], }, )
[docs] def _reset_log(self, idx: int | None = None) -> None: """Reset the episode return, episode cost and episode length. Args: idx (int or None, optional): The index of the environment. Defaults to None (single environment). """ if idx is None: self._ep_ret = torch.zeros(self._env.num_envs) self._ep_cost = torch.zeros(self._env.num_envs) self._ep_len = torch.zeros(self._env.num_envs) else: self._ep_ret[idx] = 0.0 self._ep_cost[idx] = 0.0 self._ep_len[idx] = 0.0