Source code for omnisafe.models.critic.v_critic

# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of VCritic."""

from __future__ import annotations

import torch
import torch.nn as nn

from omnisafe.models.base import Critic
from omnisafe.typing import Activation, InitFunction, OmnisafeSpace
from omnisafe.utils.model import build_mlp_network


[docs]class VCritic(Critic): """Implementation of VCritic. A V-function approximator that uses a multi-layer perceptron (MLP) to map observations to V-values. This class is an inherit class of :class:`Critic`. You can design your own V-function approximator by inheriting this class or :class:`Critic`. Args: obs_dim (int): Observation dimension. act_dim (int): Action dimension. hidden_sizes (list of int): List of hidden layer sizes. activation (Activation, optional): Activation function. Defaults to ``'relu'``. weight_initialization_mode (InitFunction, optional): Weight initialization mode. Defaults to ``'kaiming_uniform'``. num_critics (int, optional): Number of critics. Defaults to 1. """ def __init__( self, obs_space: OmnisafeSpace, act_space: OmnisafeSpace, hidden_sizes: list[int], activation: Activation = 'relu', weight_initialization_mode: InitFunction = 'kaiming_uniform', num_critics: int = 1, ) -> None: """Initialize an instance of :class:`VCritic`.""" super().__init__( obs_space, act_space, hidden_sizes, activation, weight_initialization_mode, num_critics, use_obs_encoder=False, ) self.net_lst: list[nn.Module] self.net_lst = [] for idx in range(self._num_critics): net = build_mlp_network( sizes=[self._obs_dim, *self._hidden_sizes, 1], activation=self._activation, weight_initialization_mode=self._weight_initialization_mode, ) self.net_lst.append(net) self.add_module(f'critic_{idx}', net)
[docs] def forward( self, obs: torch.Tensor, ) -> list[torch.Tensor]: """Forward function. Specifically, V function approximator maps observations to V-values. Args: obs (torch.Tensor): Observations. Returns: The V critic value of observation. """ res = [] for critic in self.net_lst: res.append(torch.squeeze(critic(obs), -1)) return res