# Source code for omnisafe.common.lagrange

# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Lagrange."""

from typing import Optional

import torch


class Lagrange:
    r"""Base class for Lagrangian-based algorithms.

    This class implements the Lagrange multiplier update and the Lagrange loss.

    .. note::
        Any traditional policy gradient algorithm can be converted to a
        Lagrangian-based algorithm by inheriting from this class and
        implementing the :meth:`_loss_pi` method.

    Example:
        >>> from omnisafe.common.lagrange import Lagrange
        >>> def loss_pi(self, data):
        >>>     # implement your own loss function here
        >>>     return loss

    You can also use this class to implement your own Lagrangian-based
    algorithm, with any policy gradient method you like in ``omnisafe``.

    Example:
        >>> from omnisafe.common.lagrange import Lagrange
        >>> class CustomAlgo:
        >>>     def __init__(self) -> None:
        >>>         # initialize your own algorithm here
        >>>         super().__init__()
        >>>         # initialize the Lagrange multiplier
        >>>         self.lagrange = Lagrange(**self._cfgs.lagrange_cfgs)

    Attributes:
        cost_limit: Cost threshold subtracted from the mean episode cost.
        lambda_lr: Learning rate handed to the multiplier's optimizer.
        lagrangian_upper_bound: Upper clamp for the multiplier, or ``None``
            for no upper limit.
        lagrangian_multiplier: Scalar ``torch.nn.Parameter`` holding lambda.
        lambda_range_projection: A ``torch.nn.ReLU`` module; it is created
            here but not called inside this class.
        lambda_optimizer: The ``torch.optim`` optimizer instance that updates
            ``lagrangian_multiplier``.
    """

    # pylint: disable-next=too-many-arguments
    def __init__(
        self,
        cost_limit: float,
        lagrangian_multiplier_init: float,
        lambda_lr: float,
        lambda_optimizer: str,
        lagrangian_upper_bound: Optional[float] = None,
    ) -> None:
        """Initialize the Lagrange multiplier and its optimizer.

        Args:
            cost_limit: Cost threshold the mean episode cost is compared to.
            lagrangian_multiplier_init: Initial multiplier value; values below
                ``1e-5`` are raised to ``1e-5``.
            lambda_lr: Learning rate of the multiplier's optimizer.
            lambda_optimizer: Name of an attribute of ``torch.optim``
                (for example ``'SGD'`` or ``'Adam'``).
            lagrangian_upper_bound: Optional upper clamp applied after every
                update; ``None`` leaves the multiplier unbounded above.

        Raises:
            AssertionError: If ``lambda_optimizer`` is not an attribute of
                ``torch.optim``.
        """
        self.cost_limit = cost_limit
        self.lambda_lr = lambda_lr
        self.lagrangian_upper_bound = lagrangian_upper_bound

        init_value = max(lagrangian_multiplier_init, 1e-5)
        init_tensor = torch.as_tensor(init_value)
        if not init_tensor.is_floating_point():
            # An integer initial value (e.g. ``1`` from a config file) would
            # produce an integer tensor, which cannot require gradients.
            init_tensor = init_tensor.to(torch.get_default_dtype())
        self.lagrangian_multiplier = torch.nn.Parameter(
            init_tensor,
            requires_grad=True,
        )
        self.lambda_range_projection = torch.nn.ReLU()

        # fetch optimizer from PyTorch optimizer package
        assert hasattr(
            torch.optim, lambda_optimizer
        ), f'Optimizer={lambda_optimizer} not found in torch.'
        torch_opt = getattr(torch.optim, lambda_optimizer)
        self.lambda_optimizer = torch_opt(
            [
                self.lagrangian_multiplier,
            ],
            lr=lambda_lr,
        )

    def compute_lambda_loss(self, mean_ep_cost: float) -> torch.Tensor:
        r"""Penalty loss for the Lagrange multiplier.

        The loss is :math:`-\lambda (J_c - J_c^*)`, so minimizing it increases
        the multiplier when the cost exceeds the limit and decreases it
        otherwise.

        .. note::
            ``mean_ep_cost`` is expected to be obtained from
            ``self.logger.get_stats('EpCosts')[0]``, which is already averaged
            across MPI processes.

        Args:
            mean_ep_cost (float): mean episode cost.

        Returns:
            The scalar loss tensor, differentiable with respect to
            ``lagrangian_multiplier``.
        """
        return -self.lagrangian_multiplier * (mean_ep_cost - self.cost_limit)

    def update_lagrange_multiplier(self, Jc: float) -> None:
        r"""Update the Lagrange multiplier (lambda).

        The multiplier is updated by taking one optimizer step on the penalty
        loss from :meth:`compute_lambda_loss`. With plain SGD this step is:

        .. math::

            \lambda ^{'} = \lambda + \eta * (J_c - J_c^*)

        where :math:`\lambda` is the Lagrange multiplier, :math:`\eta` is the
        learning rate, :math:`J_c` is the mean episode cost, and
        :math:`J_c^*` is the cost limit. Other optimizers rescale the step.
        The result is then clamped in place.

        Args:
            Jc (float): mean episode cost.
        """
        self.lambda_optimizer.zero_grad()
        lambda_loss = self.compute_lambda_loss(Jc)
        lambda_loss.backward()
        self.lambda_optimizer.step()
        # enforce: lambda in [0, lagrangian_upper_bound]; an upper bound of
        # None means no upper limit
        self.lagrangian_multiplier.data.clamp_(
            0.0, self.lagrangian_upper_bound
        )