# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of ConstraintActorQCritic."""
from copy import deepcopy
from torch import optim
from omnisafe.models.actor_critic.actor_q_critic import ActorQCritic
from omnisafe.models.base import Critic
from omnisafe.models.critic.critic_builder import CriticBuilder
from omnisafe.typing import OmnisafeSpace
from omnisafe.utils.config import ModelConfig
class ConstraintActorQCritic(ActorQCritic):
    """ConstraintActorQCritic extends :class:`ActorQCritic` with a cost Q critic.

    In OmniSafe, the actor and the critics are combined into this one class.

    +-----------------+---------------------------------------------------+
    | Model           | Description                                       |
    +=================+===================================================+
    | Actor           | Input is observation. Output is action.           |
    +-----------------+---------------------------------------------------+
    | Reward Q Critic | Input is obs-action pair. Output is reward value. |
    +-----------------+---------------------------------------------------+
    | Cost Q Critic   | Input is obs-action pair. Output is cost value.   |
    +-----------------+---------------------------------------------------+

    Args:
        obs_space (OmnisafeSpace): The observation space.
        act_space (OmnisafeSpace): The action space.
        model_cfgs (ModelConfig): The model configurations.
        epochs (int): The number of epochs.

    Attributes:
        actor (Actor): The actor network.
        target_actor (Actor): The target actor network.
        reward_critic (Critic): The reward critic network.
        target_reward_critic (Critic): The target reward critic network.
        cost_critic (Critic): The cost critic network, built in this class.
        target_cost_critic (Critic): The target cost critic network, a frozen
            deep copy of ``cost_critic`` that is only changed by
            :meth:`polyak_update`.
        actor_optimizer (Optimizer): The optimizer for the actor network.
        reward_critic_optimizer (Optimizer): The optimizer for the reward critic network.
        cost_critic_optimizer (Optimizer): The Adam optimizer for the cost critic
            network. Only created when ``model_cfgs.critic.lr`` is not ``None``.
        std_schedule (Schedule): The schedule for the standard deviation of the Gaussian distribution.

    Note:
        Attributes other than ``cost_critic``, ``target_cost_critic`` and
        ``cost_critic_optimizer`` are not assigned in this class; they are
        expected to come from :class:`ActorQCritic`.
    """

    def __init__(
        self,
        obs_space: OmnisafeSpace,
        act_space: OmnisafeSpace,
        model_cfgs: ModelConfig,
        epochs: int,
    ) -> None:
        """Initialize an instance of :class:`ConstraintActorQCritic`."""
        super().__init__(obs_space, act_space, model_cfgs, epochs)

        # A single Q-type critic (no observation encoder) that estimates cost.
        self.cost_critic: Critic = CriticBuilder(
            obs_space=obs_space,
            act_space=act_space,
            hidden_sizes=model_cfgs.critic.hidden_sizes,
            activation=model_cfgs.critic.activation,
            weight_initialization_mode=model_cfgs.weight_initialization_mode,
            num_critics=1,
            use_obs_encoder=False,
        ).build_critic('q')

        # The target network starts as an exact copy and never receives
        # gradients; it is moved towards ``cost_critic`` by ``polyak_update``.
        self.target_cost_critic: Critic = deepcopy(self.cost_critic)
        for param in self.target_cost_critic.parameters():
            param.requires_grad = False

        # Explicitly register the cost critic under the name ``cost_critic``.
        self.add_module('cost_critic', self.cost_critic)

        # No optimizer is created when the configuration gives no critic
        # learning rate.
        if model_cfgs.critic.lr is not None:
            self.cost_critic_optimizer: optim.Optimizer = optim.Adam(
                self.cost_critic.parameters(),
                lr=model_cfgs.critic.lr,
            )

    def polyak_update(self, tau: float) -> None:
        """Update the target networks with polyak averaging.

        The parent class update is applied first, then every parameter of the
        target cost critic is set to
        ``tau * param + (1 - tau) * target_param``.

        Args:
            tau (float): The polyak averaging factor, i.e. the weight given to
                the current (non-target) parameters.
        """
        super().polyak_update(tau)
        for target_param, param in zip(
            self.target_cost_critic.parameters(),
            self.cost_critic.parameters(),
        ):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)