# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the TRPO algorithm."""
from __future__ import annotations
import torch
from torch.distributions import Distribution
from omnisafe.algorithms import registry
from omnisafe.algorithms.on_policy.base.natural_pg import NaturalPG
from omnisafe.utils import distributed
from omnisafe.utils.math import conjugate_gradients
from omnisafe.utils.tools import (
get_flat_gradients_from,
get_flat_params_from,
set_param_values_to_model,
)
@registry.register
class TRPO(NaturalPG):
    """The Trust Region Policy Optimization (TRPO) algorithm.

    References:
        - Title: Trust Region Policy Optimization
        - Authors: John Schulman, Sergey Levine, Philipp Moritz, Michael I. Jordan, Pieter Abbeel.
        - URL: `TRPO <https://arxiv.org/abs/1502.05477>`_
    """

    def _init_log(self) -> None:
        """Register TRPO-specific logger keys.

        Extends the parent class's logger setup with the ``Misc/AcceptanceStep`` key, which
        :meth:`_update_actor` stores after every line search.
        """
        super()._init_log()
        self._logger.register_key('Misc/AcceptanceStep')

    # pylint: disable-next=too-many-arguments,too-many-locals,arguments-differ
    def _search_step_size(
        self,
        step_direction: torch.Tensor,
        grads: torch.Tensor,
        p_dist: Distribution,
        obs: torch.Tensor,
        act: torch.Tensor,
        logp: torch.Tensor,
        adv: torch.Tensor,
        loss_before: float,
        total_steps: int = 15,
        decay: float = 0.8,
    ) -> tuple[torch.Tensor, int]:
        """TRPO performs `line-search <https://en.wikipedia.org/wiki/Line_search>`_ until constraint satisfaction.

        .. hint::
            TRPO searches for a policy-update step that improves the surrogate loss while staying
            inside the trust region. Starting from the full ``step_direction``, the step is shrunk
            geometrically by ``decay`` until the candidate parameters do not increase the loss and
            the mean KL divergence between the old and the new policy does not exceed
            ``target_kl``.

        Side effects:
            If a step is accepted, the actor is left holding the accepted parameters. If no step
            is accepted, the actor's original parameters are restored. ``Train/KL`` is stored in
            the logger (``0.0`` when no step was accepted).

        Args:
            step_direction (torch.Tensor): The full (unscaled) flat step direction.
            grads (torch.Tensor): The flat gradient direction used for the first-order improvement
                estimate (:meth:`_update_actor` passes the negated gradient of the policy loss).
            p_dist (torch.distributions.Distribution): The old policy distribution.
            obs (torch.Tensor): The observation.
            act (torch.Tensor): The action.
            logp (torch.Tensor): The log probability of the action.
            adv (torch.Tensor): The advantage.
            loss_before (float): The loss of the policy before the update.
            total_steps (int, optional): The maximum number of candidate steps to try.
                Defaults to 15.
            decay (float, optional): The factor by which the step size shrinks after each
                rejected candidate. Defaults to 0.8.

        Returns:
            The tuple of the final update (``step_frac * step_direction``, or a zero tensor if no
            step was accepted) and the 1-based acceptance step (``0`` if no step was accepted).
        """
        # Fraction of the full step tried in the current iteration.
        step_frac = 1.0
        # Flat parameters of the policy before the update.
        theta_old = get_flat_params_from(self._actor_critic.actor)
        # First-order prediction of the loss decrease for a FULL step along step_direction;
        # the prediction for a fractional step is step_frac times this value.
        expected_improve = grads.dot(step_direction)
        final_kl = 0.0

        # Backtrack until a candidate is accepted or total_steps candidates have been tried.
        for step in range(total_steps):
            # Load the candidate parameters into the actor.
            new_theta = theta_old + step_frac * step_direction
            set_param_values_to_model(self._actor_critic.actor, new_theta)

            with torch.no_grad():
                loss = self._loss_pi(obs, act, logp, adv)
                # Mean KL divergence between the old (p) and candidate (q) policy distributions,
                # averaged across processes.
                q_dist = self._actor_critic.actor(obs)
                kl = torch.distributions.kl.kl_divergence(p_dist, q_dist).mean().item()
                kl = distributed.dist_avg(kl).mean().item()
            # Actual improvement: old policy loss - candidate policy loss, averaged across processes.
            loss_improve = loss_before - loss.item()
            loss_improve = distributed.dist_avg(loss_improve)
            # Scale the prediction by step_frac so it is comparable to the measured improvement.
            self._logger.log(
                f'Expected Improvement: {step_frac * expected_improve} Actual: {loss_improve}',
            )

            if not torch.isfinite(loss):
                self._logger.log('WARNING: loss_pi not finite')
            elif loss_improve < 0:
                self._logger.log('INFO: did not improve improve <0')
            elif kl > self._cfgs.algo_cfgs.target_kl:
                self._logger.log('INFO: violated KL constraint.')
            else:
                # Accept: the surrogate did not get worse and the step is within the trust region.
                # The actor keeps the accepted parameters.
                acceptance_step = step + 1
                self._logger.log(f'Accept step at i={acceptance_step}')
                final_kl = kl
                break
            step_frac *= decay
        else:
            # No candidate was accepted: discard the step and restore the original parameters.
            self._logger.log('INFO: no suitable step found...')
            step_direction = torch.zeros_like(step_direction)
            acceptance_step = 0
            set_param_values_to_model(self._actor_critic.actor, theta_old)

        self._logger.store(
            {
                'Train/KL': final_kl,
            },
        )
        return step_frac * step_direction, acceptance_step

    def _update_actor(  # pylint: disable=too-many-arguments,too-many-locals
        self,
        obs: torch.Tensor,
        act: torch.Tensor,
        logp: torch.Tensor,
        adv_r: torch.Tensor,
        adv_c: torch.Tensor,
    ) -> None:
        """Update policy network.

        Trust Region Policy Optimization updates the policy network using the
        `conjugate gradient <https://en.wikipedia.org/wiki/Conjugate_gradient_method>`_ algorithm,
        following the steps:

        - Compute the gradient of the policy.
        - Compute the step direction.
        - Search for a step size that satisfies the constraint.
        - Update the policy network.

        Args:
            obs (torch.Tensor): The observation tensor.
            act (torch.Tensor): The action tensor.
            logp (torch.Tensor): The log probability of the action.
            adv_r (torch.Tensor): The reward advantage tensor.
            adv_c (torch.Tensor): The cost advantage tensor.
        """
        # Keep every fvp_sample_freq-th observation for the Fisher-vector product.
        # NOTE(review): presumably read by self._fvp in the parent class — confirm.
        self._fvp_obs = obs[:: self._cfgs.algo_cfgs.fvp_sample_freq]
        theta_old = get_flat_params_from(self._actor_critic.actor)
        self._actor_critic.actor.zero_grad()
        adv = self._compute_adv_surrogate(adv_r, adv_c)
        loss = self._loss_pi(obs, act, logp, adv)
        # Process-averaged loss; the baseline the line search must not worsen.
        loss_before = distributed.dist_avg(loss).item()
        # Policy distribution before the update; the reference for the KL constraint.
        p_dist = self._actor_critic.actor(obs)

        loss.backward()
        distributed.avg_grads(self._actor_critic.actor)

        # Negated flat gradient of the loss, i.e. the direction of steepest loss decrease.
        grads = -get_flat_gradients_from(self._actor_critic.actor)
        # Approximately solve H x = grads with conjugate gradients, where H is the matrix applied
        # by self._fvp (presumably the Fisher information matrix).
        x = conjugate_gradients(self._fvp, grads, self._cfgs.algo_cfgs.cg_iters)
        assert torch.isfinite(x).all(), 'x is not finite'
        xHx = torch.dot(x, self._fvp(x))
        assert xHx.item() >= 0, 'xHx is negative'
        # Scale x so that the quadratic estimate 0.5 * (alpha x)^T H (alpha x) is approximately
        # target_kl; the 1e-8 guards against division by zero.
        alpha = torch.sqrt(2 * self._cfgs.algo_cfgs.target_kl / (xHx + 1e-8))
        step_direction = x * alpha
        assert torch.isfinite(step_direction).all(), 'step_direction is not finite'

        step_direction, accept_step = self._search_step_size(
            step_direction=step_direction,
            grads=grads,
            p_dist=p_dist,
            obs=obs,
            act=act,
            logp=logp,
            adv=adv,
            loss_before=loss_before,
        )

        # Apply the accepted step (a zero tensor if the line search found none).
        theta_new = theta_old + step_direction
        set_param_values_to_model(self._actor_critic.actor, theta_new)

        # Re-evaluate the loss at the updated parameters. The return value is unused here.
        # NOTE(review): kept because _loss_pi may record statistics as a side effect — confirm.
        with torch.no_grad():
            loss = self._loss_pi(obs, act, logp, adv)

        self._logger.store(
            {
                'Misc/Alpha': alpha.item(),
                'Misc/FinalStepNorm': torch.norm(step_direction).mean().item(),
                'Misc/xHx': xHx.item(),
                'Misc/gradient_norm': torch.norm(grads).mean().item(),
                'Misc/H_inv_g': x.norm().item(),
                'Misc/AcceptanceStep': accept_step,
            },
        )