Source code for omnisafe.models.actor_critic.actor_critic

# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of ActorCritic."""

from __future__ import annotations

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ConstantLR, LinearLR

from omnisafe.models.actor import GaussianLearningActor, GaussianSACActor, MLPActor
from omnisafe.models.actor.actor_builder import ActorBuilder
from omnisafe.models.base import Critic
from omnisafe.models.critic.critic_builder import CriticBuilder
from omnisafe.typing import OmnisafeSpace
from omnisafe.utils.config import ModelConfig
from omnisafe.utils.schedule import PiecewiseSchedule, Schedule


class ActorCritic(nn.Module):
    """Class for ActorCritic.

    In OmniSafe, we combine the actor and the critic into this single class.

    +-----------------+-----------------------------------------------+
    | Model           | Description                                   |
    +=================+===============================================+
    | Actor           | Input is observation. Output is action.       |
    +-----------------+-----------------------------------------------+
    | Reward V Critic | Input is observation. Output is reward value. |
    +-----------------+-----------------------------------------------+

    Args:
        obs_space (OmnisafeSpace): The observation space.
        act_space (OmnisafeSpace): The action space.
        model_cfgs (ModelConfig): The model configurations.
        epochs (int): The number of epochs.

    Attributes:
        actor (Actor): The actor network.
        reward_critic (Critic): The critic network.
        std_schedule (Schedule): The schedule for the standard deviation of the Gaussian
            distribution. Only assigned once :meth:`set_annealing` has been called.
        actor_optimizer (optim.Optimizer): Adam optimizer over the actor parameters. Only
            created when ``model_cfgs.actor.lr`` is not ``None``.
        reward_critic_optimizer (optim.Optimizer): Adam optimizer over the reward critic
            parameters. Only created when ``model_cfgs.critic.lr`` is not ``None``.
        actor_scheduler (LinearLR | ConstantLR): Learning-rate scheduler attached to
            ``actor_optimizer``. Only created when ``model_cfgs.actor.lr`` is not ``None``.
    """

    # Declared here but given a value only by ``set_annealing``.
    std_schedule: Schedule

    # pylint: disable-next=too-many-arguments
    def __init__(
        self,
        obs_space: OmnisafeSpace,
        act_space: OmnisafeSpace,
        model_cfgs: ModelConfig,
        epochs: int,
    ) -> None:
        """Initialize an instance of :class:`ActorCritic`."""
        super().__init__()

        self.actor: GaussianLearningActor | GaussianSACActor | MLPActor = ActorBuilder(
            obs_space=obs_space,
            act_space=act_space,
            hidden_sizes=model_cfgs.actor.hidden_sizes,
            activation=model_cfgs.actor.activation,
            weight_initialization_mode=model_cfgs.weight_initialization_mode,
        ).build_actor(actor_type=model_cfgs.actor_type)
        self.reward_critic: Critic = CriticBuilder(
            obs_space=obs_space,
            act_space=act_space,
            hidden_sizes=model_cfgs.critic.hidden_sizes,
            activation=model_cfgs.critic.activation,
            weight_initialization_mode=model_cfgs.weight_initialization_mode,
            num_critics=1,
            use_obs_encoder=False,
        ).build_critic(critic_type='v')
        # Explicitly register both sub-networks under their attribute names.
        self.add_module('actor', self.actor)
        self.add_module('reward_critic', self.reward_critic)

        # Optimizers (and the actor scheduler) exist only when a learning rate is configured.
        if model_cfgs.actor.lr is not None:
            self.actor_optimizer: optim.Optimizer
            self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=model_cfgs.actor.lr)
        if model_cfgs.critic.lr is not None:
            self.reward_critic_optimizer: optim.Optimizer = optim.Adam(
                self.reward_critic.parameters(),
                lr=model_cfgs.critic.lr,
            )
        if model_cfgs.actor.lr is not None:
            self.actor_scheduler: LinearLR | ConstantLR
            # NOTE(review): ``verbose`` appears to be deprecated in recent PyTorch
            # releases -- confirm against the torch version this project pins.
            if model_cfgs.linear_lr_decay:
                # Scale the actor learning rate linearly from 1.0x down to 0.0x over ``epochs``.
                self.actor_scheduler = LinearLR(
                    self.actor_optimizer,
                    start_factor=1.0,
                    end_factor=0.0,
                    total_iters=epochs,
                    verbose=True,
                )
            else:
                # A factor of 1.0 keeps the learning rate unchanged.
                self.actor_scheduler = ConstantLR(
                    self.actor_optimizer,
                    factor=1.0,
                    total_iters=epochs,
                    verbose=True,
                )

    def step(
        self,
        obs: torch.Tensor,
        deterministic: bool = False,
    ) -> tuple[torch.Tensor, ...]:
        """Choose the action based on the observation. Used in rollout without gradient.

        The critic and the actor are both evaluated inside ``torch.no_grad()``.

        Args:
            obs (torch.tensor): The observation.
            deterministic (bool, optional): Whether to use deterministic action. Defaults to False.

        Returns:
            action: The deterministic action if ``deterministic`` is True, otherwise the action
                with Gaussian noise.
            value_r: The reward value of the observation (first element of the critic output).
            log_prob: The log probability of the action.
        """
        with torch.no_grad():
            value_r = self.reward_critic(obs)
            act = self.actor.predict(obs, deterministic=deterministic)
            # ``log_prob`` receives only the action; presumably the actor keeps the
            # distribution built by ``predict`` -- verify against the actor implementation.
            log_prob = self.actor.log_prob(act)

        # The critic output is indexed, so only its first entry is returned.
        return act, value_r[0], log_prob

    def forward(
        self,
        obs: torch.Tensor,
        deterministic: bool = False,
    ) -> tuple[torch.Tensor, ...]:
        """Choose the action based on the observation.

        This simply delegates to :meth:`step`, so the outputs are likewise computed
        inside ``torch.no_grad()`` and carry no gradient.

        Args:
            obs (torch.tensor): The observation.
            deterministic (bool, optional): Whether to use deterministic action. Defaults to False.

        Returns:
            action: The deterministic action if ``deterministic`` is True, otherwise the action
                with Gaussian noise.
            value_r: The reward value of the observation.
            log_prob: The log probability of the action.
        """
        return self.step(obs, deterministic=deterministic)

    def set_annealing(self, epochs: list[int], std: list[float]) -> None:
        """Set the annealing schedule for the actor's standard deviation.

        ``epochs`` and ``std`` are paired element-wise into the endpoints of a
        :class:`PiecewiseSchedule`; the last entry of ``std`` is used as the
        schedule's ``outside_value``.

        Args:
            epochs (list of int): The list of epochs.
            std (list of float): The list of standard deviation.

        Raises:
            AssertionError: If the actor is not a :class:`GaussianLearningActor`.
            ValueError: If ``epochs`` and ``std`` differ in length.
        """
        assert isinstance(
            self.actor,
            GaussianLearningActor,
        ), 'Only GaussianLearningActor support annealing.'
        # ``zip`` would silently drop the surplus entries of the longer list, leaving a
        # schedule that no longer matches the configuration, so fail loudly instead.
        if len(epochs) != len(std):
            raise ValueError(
                'epochs and std must have the same length, '
                f'got {len(epochs)} and {len(std)}.',
            )
        self.std_schedule = PiecewiseSchedule(
            endpoints=list(zip(epochs, std)),
            outside_value=std[-1],
        )

    def annealing(self, epoch: int) -> None:
        """Update the actor's standard deviation from the annealing schedule.

        :meth:`set_annealing` must have been called beforehand, since that is the
        only place where ``std_schedule`` is assigned.

        Args:
            epoch (int): The current epoch.

        Raises:
            AssertionError: If the actor is not a :class:`GaussianLearningActor`.
        """
        assert isinstance(
            self.actor,
            GaussianLearningActor,
        ), 'Only GaussianLearningActor support annealing.'
        self.actor.std = self.std_schedule.value(epoch)