# Source code for omnisafe.adapter.onpolicy_adapter

# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""OnPolicy Adapter for OmniSafe."""

from __future__ import annotations

from typing import Any

import torch
from rich.progress import track

from omnisafe.adapter.online_adapter import OnlineAdapter
from omnisafe.common.buffer import VectorOnPolicyBuffer
from omnisafe.common.logger import Logger
from omnisafe.models.actor_critic.constraint_actor_critic import ConstraintActorCritic
from omnisafe.utils.config import Config


class OnPolicyAdapter(OnlineAdapter):
    """OnPolicy Adapter for OmniSafe.

    :class:`OnPolicyAdapter` is used to adapt the environment to the on-policy training.
    It keeps per-environment running statistics (episode return, cost and length) and
    feeds collected transitions into a :class:`VectorOnPolicyBuffer`.

    Args:
        env_id (str): The environment id.
        num_envs (int): The number of environments.
        seed (int): The random seed.
        cfgs (Config): The configuration.
    """

    # Running per-environment episode statistics; (re)allocated by ``_reset_log``.
    _ep_ret: torch.Tensor
    _ep_cost: torch.Tensor
    _ep_len: torch.Tensor

    def __init__(  # pylint: disable=too-many-arguments
        self,
        env_id: str,
        num_envs: int,
        seed: int,
        cfgs: Config,
    ) -> None:
        """Initialize an instance of :class:`OnPolicyAdapter`."""
        super().__init__(env_id, num_envs, seed, cfgs)
        self._reset_log()

    def rollout(  # pylint: disable=too-many-locals
        self,
        steps_per_epoch: int,
        agent: ConstraintActorCritic,
        buffer: VectorOnPolicyBuffer,
        logger: Logger,
    ) -> None:
        """Rollout the environment and store the data in the buffer.

        .. warning::
            As OmniSafe uses :class:`AutoReset` wrapper, the environment will be reset
            automatically, so the final observation will be stored in
            ``info['final_observation']``.

        For every environment whose trajectory ends at the current step (terminated,
        truncated, or cut off because the epoch is over) the path is closed with
        ``buffer.finish_path``. Terminated trajectories use zero bootstrap values;
        all others are bootstrapped with the critics' estimate of the last observation.

        Args:
            steps_per_epoch (int): Number of steps per epoch.
            agent (ConstraintActorCritic): Constraint actor-critic, including actor,
                reward critic and cost critic.
            buffer (VectorOnPolicyBuffer): Vector on-policy buffer.
            logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``.
        """
        self._reset_log()

        obs, _ = self.reset()
        for step in track(
            range(steps_per_epoch),
            description=f'Processing rollout for epoch: {logger.current_epoch}...',
        ):
            act, value_r, value_c, logp = agent.step(obs)
            next_obs, reward, cost, terminated, truncated, info = self.step(act)

            self._log_value(reward=reward, cost=cost, info=info)

            if self._cfgs.algo_cfgs.use_cost:
                logger.store({'Value/cost': value_c})
            logger.store({'Value/reward': value_r})

            buffer.store(
                obs=obs,
                act=act,
                reward=reward,
                cost=cost,
                value_r=value_r,
                value_c=value_c,
                logp=logp,
            )

            obs = next_obs
            epoch_end = step >= steps_per_epoch - 1
            for idx, (done, time_out) in enumerate(zip(terminated, truncated)):
                if not (epoch_end or done or time_out):
                    continue

                # A terminated trajectory has no future return: bootstrap with zeros.
                last_value_r = torch.zeros(1)
                last_value_c = torch.zeros(1)
                if not done:
                    if time_out:
                        # The episode was truncated, so ``obs[idx]`` already belongs
                        # to the next (auto-reset) episode; bootstrap from the real
                        # final observation instead. Truncation takes precedence over
                        # ``epoch_end`` so no misleading "cut off" warning is logged
                        # and the agent is only queried once.
                        bootstrap_obs = info['final_observation'][idx]
                    else:
                        # Only ``epoch_end`` can be true here: the trajectory is
                        # interrupted mid-episode by the end of the epoch.
                        logger.log(
                            f'Warning: trajectory cut off when rollout by epoch at {self._ep_len[idx]} steps.',
                        )
                        bootstrap_obs = obs[idx]
                    _, last_value_r, last_value_c, _ = agent.step(bootstrap_obs)
                    last_value_r = last_value_r.unsqueeze(0)
                    last_value_c = last_value_c.unsqueeze(0)

                if done or time_out:
                    # The episode is over: report its statistics and zero them.
                    self._log_metrics(logger, idx)
                    self._reset_log(idx)

                buffer.finish_path(last_value_r, last_value_c, idx)

    def _log_value(
        self,
        reward: torch.Tensor,
        cost: torch.Tensor,
        info: dict[str, Any],
    ) -> None:
        """Log value.

        Accumulates the step reward and cost into the running episode statistics and
        increments the episode length of every environment by one.

        .. note::
            OmniSafe uses :class:`RewardNormalizer` wrapper, so the original reward and
            cost will be stored in ``info['original_reward']`` and
            ``info['original_cost']``. When those keys are absent, ``reward`` and
            ``cost`` are used instead.

        Args:
            reward (torch.Tensor): The immediate step reward.
            cost (torch.Tensor): The immediate step cost.
            info (dict[str, Any]): Some information logged by the environment.
        """
        self._ep_ret += info.get('original_reward', reward).cpu()
        self._ep_cost += info.get('original_cost', cost).cpu()
        self._ep_len += 1

    def _log_metrics(self, logger: Logger, idx: int) -> None:
        """Log metrics, including ``EpRet``, ``EpCost``, ``EpLen``.

        Args:
            logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``.
            idx (int): The index of the environment.
        """
        logger.store(
            {
                'Metrics/EpRet': self._ep_ret[idx],
                'Metrics/EpCost': self._ep_cost[idx],
                'Metrics/EpLen': self._ep_len[idx],
            },
        )

    def _reset_log(self, idx: int | None = None) -> None:
        """Reset the episode return, episode cost and episode length.

        Args:
            idx (int or None, optional): The index of the environment to reset.
                Defaults to None, in which case the statistics of all environments
                are re-allocated as zeros.
        """
        if idx is None:
            self._ep_ret = torch.zeros(self._env.num_envs)
            self._ep_cost = torch.zeros(self._env.num_envs)
            self._ep_len = torch.zeros(self._env.num_envs)
        else:
            self._ep_ret[idx] = 0.0
            self._ep_cost[idx] = 0.0
            self._ep_len[idx] = 0.0