Source code for omnisafe.algorithms.on_policy.base.ppo

# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the PPO algorithm."""

from typing import Dict, Tuple

import torch

from omnisafe.algorithms import registry
from omnisafe.algorithms.on_policy.base.policy_gradient import PolicyGradient


@registry.register
class PPO(PolicyGradient):
    """The Proximal Policy Optimization (PPO) algorithm.

    References:
        - Title: Proximal Policy Optimization Algorithms
        - Authors: John Schulman, Filip Wolski, Prafulla Dhariwal, Alec Radford, Oleg Klimov.
        - URL: `PPO <https://arxiv.org/abs/1707.06347>`_
    """

    def _loss_pi(
        self,
        obs: torch.Tensor,
        act: torch.Tensor,
        logp: torch.Tensor,
        adv: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        r"""Computing pi/actor loss.

        In Proximal Policy Optimization, the clipped surrogate objective is defined as:

        .. math::

            L^{CLIP} = \mathbb{E}_{s_t \sim \rho_{\theta}} \left[
                \min\left( r_t A^{R}_{\pi_{\theta}}(s_t, a_t),
                \text{clip}(r_t, 1-\epsilon, 1+\epsilon) A^{R}_{\pi_{\theta}}(s_t, a_t) \right)
            \right]

        where :math:`r_t = \frac{\pi_\theta ^{'}(a_t|s_t)}{\pi_\theta(a_t|s_t)}`,
        :math:`\epsilon` is the clip parameter, and :math:`A^{R}_{\pi_{\theta}}(s_t, a_t)` is the
        advantage.

        The returned loss is the quantity to *minimize*: the negated batch mean of the clipped
        surrogate, minus ``entropy_coef`` times the mean policy entropy (an entropy bonus).

        Args:
            obs (torch.Tensor): ``observation`` stored in buffer.
            act (torch.Tensor): ``action`` stored in buffer.
            logp (torch.Tensor): ``log probability`` of action stored in buffer.
            adv (torch.Tensor): ``advantage`` stored in buffer.

        Returns:
            A tuple ``(loss, info)``. ``loss`` is the scalar actor loss tensor. ``info`` is a dict
            with the keys ``'entropy'`` (mean policy entropy, as a Python float), ``'ratio'``
            (mean importance ratio, as a Python float) and ``'std'`` (the value of
            ``self._actor_critic.actor.std``).
        """
        distribution = self._actor_critic.actor(obs)
        # NOTE(review): ``log_prob`` is not passed ``obs``; it presumably scores ``act`` under the
        # distribution built by the forward call above, so keep this call order — confirm.
        logp_ = self._actor_critic.actor.log_prob(act)
        std = self._actor_critic.actor.std

        # Importance ratio r_t between the current policy and the stored (behavior) log-probs.
        ratio = torch.exp(logp_ - logp)
        clip = self._cfgs.algo_cfgs.clip
        ratio_clipped = torch.clamp(ratio, 1 - clip, 1 + clip)

        # Pessimistic element-wise minimum of the two surrogates, negated because the optimizer
        # minimizes while PPO maximizes L^CLIP.
        loss = -torch.min(ratio * adv, ratio_clipped * adv).mean()

        # Compute the entropy once and reuse it for both the bonus term and logging.
        entropy = distribution.entropy().mean()
        loss -= self._cfgs.algo_cfgs.entropy_coef * entropy

        # useful extra info
        info = {'entropy': entropy.item(), 'ratio': ratio.mean().item(), 'std': std}
        return loss, info