# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the PPO algorithm."""
from __future__ import annotations
import torch
from omnisafe.algorithms import registry
from omnisafe.algorithms.on_policy.base.policy_gradient import PolicyGradient
@registry.register
class PPO(PolicyGradient):
    """The Proximal Policy Optimization (PPO) algorithm.

    References:
        - Title: Proximal Policy Optimization Algorithms
        - Authors: John Schulman, Filip Wolski, Prafulla Dhariwal, Alec Radford, Oleg Klimov.
        - URL: `PPO <https://arxiv.org/abs/1707.06347>`_
    """

    def _loss_pi(
        self,
        obs: torch.Tensor,
        act: torch.Tensor,
        logp: torch.Tensor,
        adv: torch.Tensor,
    ) -> torch.Tensor:
        r"""Computing pi/actor loss.

        In Proximal Policy Optimization, the clipped surrogate objective is defined as:

        .. math::

            L^{CLIP} = \mathbb{E}_{s_t \sim \rho_{\theta}} \left[
                \min \left( r_t A^{R}_{\pi_{\theta}} (s_t, a_t) ,
                \text{clip} (r_t, 1 - \epsilon, 1 + \epsilon) A^{R}_{\pi_{\theta}} (s_t, a_t) \right)
            \right]

        where :math:`r_t = \frac{\pi_{\theta}^{'} (a_t|s_t)}{\pi_{\theta} (a_t|s_t)}`,
        :math:`\epsilon` is the clip parameter, and :math:`A^{R}_{\pi_{\theta}} (s_t, a_t)` is the
        advantage.

        The returned loss is :math:`-L^{CLIP}` (negated so that minimizing the loss maximizes the
        surrogate), minus ``algo_cfgs.entropy_coef`` times the mean entropy of the current policy
        distribution (an entropy bonus).

        The entropy, the policy ratio, the policy std and the loss value are also recorded
        through ``self._logger.store`` under the ``Train/Entropy``, ``Train/PolicyRatio``,
        ``Train/PolicyStd`` and ``Loss/Loss_pi`` keys.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.
            act (torch.Tensor): The ``action`` sampled from buffer.
            logp (torch.Tensor): The ``log probability`` of action sampled from buffer.
            adv (torch.Tensor): The ``advantage`` sampled from buffer.

        Returns:
            The loss of pi/actor.
        """
        distribution = self._actor_critic.actor(obs)
        # NOTE(review): log_prob is called on the actor right after the forward pass above;
        # presumably it evaluates the distribution produced by that call -- confirm in the
        # actor implementation.
        logp_ = self._actor_critic.actor.log_prob(act)
        std = self._actor_critic.actor.std

        # Importance ratio between the current policy and the log-probabilities stored in the
        # buffer.
        ratio = torch.exp(logp_ - logp)
        ratio_clipped = torch.clamp(
            ratio,
            1 - self._cfgs.algo_cfgs.clip,
            1 + self._cfgs.algo_cfgs.clip,
        )
        # Element-wise pessimistic (min) surrogate, negated because the optimizer minimizes.
        loss = -torch.min(ratio * adv, ratio_clipped * adv).mean()

        # Compute the mean entropy once and reuse it for both the bonus and the log entry.
        entropy = distribution.entropy().mean()
        loss -= self._cfgs.algo_cfgs.entropy_coef * entropy

        # useful extra info
        self._logger.store(
            {
                'Train/Entropy': entropy.item(),
                'Train/PolicyRatio': ratio,
                'Train/PolicyStd': std,
                'Loss/Loss_pi': loss.item(),
            },
        )
        return loss