# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the PPO algorithm."""
from typing import Dict, Tuple
import torch
from omnisafe.algorithms import registry
from omnisafe.algorithms.on_policy.base.policy_gradient import PolicyGradient
@registry.register
class PPO(PolicyGradient):
    """The Proximal Policy Optimization (PPO) algorithm.

    References:
        - Title: Proximal Policy Optimization Algorithms
        - Authors: John Schulman, Filip Wolski, Prafulla Dhariwal, Alec Radford, Oleg Klimov.
        - URL: `PPO <https://arxiv.org/abs/1707.06347>`_
    """

    def _loss_pi(
        self, obs: torch.Tensor, act: torch.Tensor, logp: torch.Tensor, adv: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        r"""Compute the pi/actor loss.

        In Proximal Policy Optimization, the loss is defined as:

        .. math::
            L^{CLIP} = \mathbb{E}_{s_t \sim \rho_{\theta}}
            \left[ \min(r_t A^{R}_{\pi_{\theta}}(s_t, a_t) ,
            \text{clip}(r_t, 1-\epsilon, 1+\epsilon) A^{R}_{\pi_{\theta}}(s_t, a_t)) \right]

        where :math:`r_t = \frac{\pi_\theta ^{'}(a_t|s_t)}{\pi_\theta(a_t|s_t)}`,
        :math:`\epsilon` is the clip parameter, :math:`A^{R}_{\pi_{\theta}}(s_t, a_t)` is the advantage.

        The value returned is the negated clipped objective (so that it can be
        minimized) minus ``entropy_coef`` times the mean entropy of the current
        action distribution.

        Args:
            obs (torch.Tensor): ``observation`` stored in buffer.
            act (torch.Tensor): ``action`` stored in buffer.
            logp (torch.Tensor): ``log probability`` of action stored in buffer.
            adv (torch.Tensor): ``advantage`` stored in buffer.

        Returns:
            A tuple ``(loss, info)`` where ``loss`` is the scalar actor loss and
            ``info`` is a dict with the keys ``'entropy'`` (mean entropy of the
            action distribution), ``'ratio'`` (mean probability ratio) and
            ``'std'`` (the actor's ``std`` attribute).
        """
        actor = self._actor_critic.actor
        algo_cfgs = self._cfgs.algo_cfgs

        # NOTE(review): ``log_prob`` receives only ``act``, so it presumably relies
        # on the distribution produced by the forward call right before it; keep
        # these statements in this order.
        distribution = actor(obs)
        logp_ = actor.log_prob(act)
        std = actor.std

        # Probability ratio between the current policy and the one that
        # generated the buffer, and its clipped counterpart.
        ratio = torch.exp(logp_ - logp)
        ratio_clipped = torch.clamp(ratio, 1 - algo_cfgs.clip, 1 + algo_cfgs.clip)

        # Evaluate the entropy once and reuse it for both the regulariser and
        # the logged statistic.
        entropy = distribution.entropy().mean()

        loss = -torch.min(ratio * adv, ratio_clipped * adv).mean()
        loss -= algo_cfgs.entropy_coef * entropy

        # useful extra info
        info = {'entropy': entropy.item(), 'ratio': ratio.mean().item(), 'std': std}
        return loss, info