# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of ActorCritic."""
from __future__ import annotations
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ConstantLR, LinearLR
from omnisafe.models.actor import GaussianLearningActor, GaussianSACActor, MLPActor
from omnisafe.models.actor.actor_builder import ActorBuilder
from omnisafe.models.base import Critic
from omnisafe.models.critic.critic_builder import CriticBuilder
from omnisafe.typing import OmnisafeSpace
from omnisafe.utils.config import ModelConfig
from omnisafe.utils.schedule import PiecewiseSchedule, Schedule
class ActorCritic(nn.Module):
    """Container that bundles an actor with a reward value critic.

    In OmniSafe, the actor and the critic are combined into this single class.

    +-----------------+-----------------------------------------------+
    | Model           | Description                                   |
    +=================+===============================================+
    | Actor           | Input is observation. Output is action.       |
    +-----------------+-----------------------------------------------+
    | Reward V Critic | Input is observation. Output is reward value. |
    +-----------------+-----------------------------------------------+

    Args:
        obs_space (OmnisafeSpace): The observation space.
        act_space (OmnisafeSpace): The action space.
        model_cfgs (ModelConfig): The model configurations.
        epochs (int): The number of epochs.

    Attributes:
        actor (Actor): The actor network.
        reward_critic (Critic): The critic network.
        actor_optimizer (optim.Optimizer): Adam optimizer of the actor. Only created when
            ``model_cfgs.actor.lr`` is not ``None``.
        actor_scheduler (LinearLR or ConstantLR): Learning rate scheduler of the actor. Only
            created when ``model_cfgs.actor.lr`` is not ``None``.
        reward_critic_optimizer (optim.Optimizer): Adam optimizer of the reward critic. Only
            created when ``model_cfgs.critic.lr`` is not ``None``.
        std_schedule (Schedule): The schedule for the standard deviation of the Gaussian
            distribution. Only available after :meth:`set_annealing` has been called.
    """

    std_schedule: Schedule

    # pylint: disable-next=too-many-arguments
    def __init__(
        self,
        obs_space: OmnisafeSpace,
        act_space: OmnisafeSpace,
        model_cfgs: ModelConfig,
        epochs: int,
    ) -> None:
        """Initialize an instance of :class:`ActorCritic`."""
        super().__init__()

        # Assigning an ``nn.Module`` to an attribute registers it as a sub-module, so no
        # explicit ``add_module`` call is needed for ``actor`` and ``reward_critic``.
        self.actor: GaussianLearningActor | GaussianSACActor | MLPActor = ActorBuilder(
            obs_space=obs_space,
            act_space=act_space,
            hidden_sizes=model_cfgs.actor.hidden_sizes,
            activation=model_cfgs.actor.activation,
            weight_initialization_mode=model_cfgs.weight_initialization_mode,
        ).build_actor(actor_type=model_cfgs.actor_type)
        self.reward_critic: Critic = CriticBuilder(
            obs_space=obs_space,
            act_space=act_space,
            hidden_sizes=model_cfgs.critic.hidden_sizes,
            activation=model_cfgs.critic.activation,
            weight_initialization_mode=model_cfgs.weight_initialization_mode,
            num_critics=1,
            use_obs_encoder=False,
        ).build_critic(critic_type='v')

        if model_cfgs.actor.lr is not None:
            self.actor_optimizer: optim.Optimizer
            self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=model_cfgs.actor.lr)

            # The scheduler is tied to the actor optimizer, so both live in the same branch.
            # ``verbose`` is intentionally not passed: it is deprecated since PyTorch 2.2;
            # use ``scheduler.get_last_lr()`` to inspect the current learning rate.
            self.actor_scheduler: LinearLR | ConstantLR
            if model_cfgs.linear_lr_decay:
                self.actor_scheduler = LinearLR(
                    self.actor_optimizer,
                    start_factor=1.0,
                    end_factor=0.0,
                    total_iters=epochs,
                )
            else:
                self.actor_scheduler = ConstantLR(
                    self.actor_optimizer,
                    factor=1.0,
                    total_iters=epochs,
                )

        if model_cfgs.critic.lr is not None:
            self.reward_critic_optimizer: optim.Optimizer = optim.Adam(
                self.reward_critic.parameters(),
                lr=model_cfgs.critic.lr,
            )

    def step(
        self,
        obs: torch.Tensor,
        deterministic: bool = False,
    ) -> tuple[torch.Tensor, ...]:
        """Choose the action based on the observation. Used in rollout without gradient.

        Args:
            obs (torch.Tensor): The observation.
            deterministic (bool, optional): Whether to use deterministic action. Defaults to False.

        Returns:
            action: The deterministic action if ``deterministic`` is True, otherwise the action with
                Gaussian noise.
            value_r: The reward value of the observation (first output of the reward critic).
            log_prob: The log probability of the action.
        """
        # Everything is evaluated under ``no_grad`` so the returned tensors carry no graph.
        with torch.no_grad():
            value_r = self.reward_critic(obs)
            act = self.actor.predict(obs, deterministic=deterministic)
            log_prob = self.actor.log_prob(act)
        # The critic returns a sequence of values; only the first entry is used here.
        return act, value_r[0], log_prob

    def forward(
        self,
        obs: torch.Tensor,
        deterministic: bool = False,
    ) -> tuple[torch.Tensor, ...]:
        """Choose the action based on the observation.

        This is an alias of :meth:`step`: it delegates to it unchanged, so the returned tensors
        are likewise computed under ``torch.no_grad()`` and carry no gradient.

        Args:
            obs (torch.Tensor): The observation.
            deterministic (bool, optional): Whether to use deterministic action. Defaults to False.

        Returns:
            action: The deterministic action if ``deterministic`` is True, otherwise the action with
                Gaussian noise.
            value_r: The reward value of the observation.
            log_prob: The log probability of the action.
        """
        return self.step(obs, deterministic=deterministic)

    def set_annealing(self, epochs: list[int], std: list[float]) -> None:
        """Set the annealing mode for the actor.

        Builds a piecewise schedule from the ``(epoch, std)`` pairs. Outside of the given
        endpoints the schedule yields the last value of ``std``.

        Args:
            epochs (list of int): The list of epochs.
            std (list of float): The list of standard deviation.

        Raises:
            AssertionError: If the actor is not a :class:`GaussianLearningActor`.
            ValueError: If ``std`` is empty or ``epochs`` and ``std`` differ in length.
        """
        assert isinstance(
            self.actor,
            GaussianLearningActor,
        ), 'Only GaussianLearningActor support annealing.'
        # ``zip`` would silently drop the surplus entries of the longer list and ``std[-1]``
        # would fail with a bare ``IndexError`` on empty input, so validate explicitly.
        if not std or len(epochs) != len(std):
            raise ValueError(
                'epochs and std must be non-empty and have the same length, '
                f'got {len(epochs)} epochs and {len(std)} std values.',
            )
        self.std_schedule = PiecewiseSchedule(
            endpoints=list(zip(epochs, std)),
            outside_value=std[-1],
        )

    def annealing(self, epoch: int) -> None:
        """Update the standard deviation of the actor according to the schedule.

        :meth:`set_annealing` must have been called beforehand to create ``std_schedule``.

        Args:
            epoch (int): The current epoch.

        Raises:
            AssertionError: If the actor is not a :class:`GaussianLearningActor`.
        """
        assert isinstance(
            self.actor,
            GaussianLearningActor,
        ), 'Only GaussianLearningActor support annealing.'
        self.actor.std = self.std_schedule.value(epoch)