# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the Policy Gradient algorithm."""
from __future__ import annotations
import time
from typing import Any
import torch
import torch.nn as nn
from rich.progress import track
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils.data import DataLoader, TensorDataset
from omnisafe.adapter import OnPolicyAdapter
from omnisafe.algorithms import registry
from omnisafe.algorithms.base_algo import BaseAlgo
from omnisafe.common.buffer import VectorOnPolicyBuffer
from omnisafe.common.logger import Logger
from omnisafe.models.actor_critic.constraint_actor_critic import ConstraintActorCritic
from omnisafe.utils import distributed
@registry.register
# pylint: disable-next=too-many-instance-attributes,too-few-public-methods,line-too-long
class PolicyGradient(BaseAlgo):
    """The Policy Gradient algorithm.

    References:
        - Title: Policy Gradient Methods for Reinforcement Learning with Function Approximation
        - Authors: Richard S. Sutton, David McAllester, Satinder Singh, Yishay Mansour.
        - URL: `PG <https://proceedings.neurips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf>`_
    """

    def _init_env(self) -> None:
        """Initialize the environment.

        OmniSafe uses :class:`omnisafe.adapter.OnPolicyAdapter` to adapt the environment to the
        algorithm.

        User can customize the environment by inheriting this method.

        Examples:
            >>> def _init_env(self) -> None:
            ...     self._env = CustomAdapter()

        Raises:
            AssertionError: If the number of steps per epoch is not divisible by
                ``world_size * vector_env_nums``.
        """
        self._env: OnPolicyAdapter = OnPolicyAdapter(
            self._env_id,
            self._cfgs.train_cfgs.vector_env_nums,
            self._seed,
            self._cfgs,
        )
        assert (self._cfgs.algo_cfgs.steps_per_epoch) % (
            distributed.world_size() * self._cfgs.train_cfgs.vector_env_nums
        ) == 0, 'The number of steps per epoch is not divisible by the number of environments.'
        # Per-process, per-environment share of the configured steps per epoch.
        self._steps_per_epoch: int = (
            self._cfgs.algo_cfgs.steps_per_epoch
            // distributed.world_size()
            // self._cfgs.train_cfgs.vector_env_nums
        )

    def _init_model(self) -> None:
        """Initialize the model.

        OmniSafe uses
        :class:`omnisafe.models.actor_critic.constraint_actor_critic.ConstraintActorCritic`
        as the default model. The model is moved to ``self._device``; parameters are
        synchronized when more than one process is used, and exploration-noise annealing is
        configured when ``model_cfgs.exploration_noise_anneal`` is set.

        User can customize the model by inheriting this method.

        Examples:
            >>> def _init_model(self) -> None:
            ...     self._actor_critic = CustomActorCritic()
        """
        self._actor_critic: ConstraintActorCritic = ConstraintActorCritic(
            obs_space=self._env.observation_space,
            act_space=self._env.action_space,
            model_cfgs=self._cfgs.model_cfgs,
            epochs=self._cfgs.train_cfgs.epochs,
        ).to(self._device)

        if distributed.world_size() > 1:
            distributed.sync_params(self._actor_critic)

        if self._cfgs.model_cfgs.exploration_noise_anneal:
            self._actor_critic.set_annealing(
                epochs=[0, self._cfgs.train_cfgs.epochs],
                std=self._cfgs.model_cfgs.std_range,
            )

    def _init(self) -> None:
        """The initialization of the algorithm.

        Creates the :class:`VectorOnPolicyBuffer` that holds one epoch of rollout data
        (``size=self._steps_per_epoch`` for ``vector_env_nums`` environments).

        User can define the initialization of the algorithm by inheriting this method.

        Examples:
            >>> def _init(self) -> None:
            ...     super()._init()
            ...     self._buffer = CustomBuffer()
            ...     self._model = CustomModel()
        """
        self._buf: VectorOnPolicyBuffer = VectorOnPolicyBuffer(
            obs_space=self._env.observation_space,
            act_space=self._env.action_space,
            size=self._steps_per_epoch,
            gamma=self._cfgs.algo_cfgs.gamma,
            lam=self._cfgs.algo_cfgs.lam,
            lam_c=self._cfgs.algo_cfgs.lam_c,
            advantage_estimator=self._cfgs.algo_cfgs.adv_estimation_method,
            standardized_adv_r=self._cfgs.algo_cfgs.standardized_rew_adv,
            standardized_adv_c=self._cfgs.algo_cfgs.standardized_cost_adv,
            penalty_coefficient=self._cfgs.algo_cfgs.penalty_coef,
            num_envs=self._cfgs.train_cfgs.vector_env_nums,
            device=self._device,
        )

    def _init_log(self) -> None:
        """Create the logger, set up model saving and register every logged key.

        Registered keys:

        - ``Metrics/EpRet``, ``Metrics/EpCost``, ``Metrics/EpLen``: episode return, cost and
          length, each registered with ``window_length=50``. They are not stored in this
          class (presumably filled in during rollout -- confirm in the adapter).
        - ``Train/Epoch``: current epoch index.
        - ``Train/Entropy``: mean entropy of the policy distribution.
        - ``Train/KL``: KL divergence between the pre-update and post-update policy.
        - ``Train/StopIter``: number of update iterations actually performed.
        - ``Train/PolicyRatio``: probability ratio between new and old policy (min and max
          are tracked as well).
        - ``Train/LR``: learning rate of the actor (``0.0`` when no actor lr is configured).
        - ``Train/PolicyStd``: standard deviation of the policy; only registered when
          ``model_cfgs.actor_type == 'gaussian_learning'``.
        - ``TotalEnvSteps``: total number of environment steps so far.
        - ``Loss/Loss_pi``: loss of the policy network.
        - ``Value/Adv``: mean reward advantage.
        - ``Loss/Loss_reward_critic``: loss of the reward critic network.
        - ``Value/reward``: not stored in this class (presumably the reward critic's value
          recorded during rollout -- confirm in the adapter).
        - ``Loss/Loss_cost_critic``, ``Value/cost``: cost critic counterparts; only
          registered when ``algo_cfgs.use_cost`` is set.
        - ``Time/Total``, ``Time/Rollout``, ``Time/Update``, ``Time/Epoch``: wall-clock
          timings in seconds.
        - ``Time/FPS``: environment steps per second of the epoch.
        """
        self._logger = Logger(
            output_dir=self._cfgs.logger_cfgs.log_dir,
            exp_name=self._cfgs.exp_name,
            seed=self._cfgs.seed,
            use_tensorboard=self._cfgs.logger_cfgs.use_tensorboard,
            use_wandb=self._cfgs.logger_cfgs.use_wandb,
            config=self._cfgs,
        )

        # Objects handed to the logger's torch saver.
        what_to_save: dict[str, Any] = {}
        what_to_save['pi'] = self._actor_critic.actor
        if self._cfgs.algo_cfgs.obs_normalize:
            obs_normalizer = self._env.save()['obs_normalizer']
            what_to_save['obs_normalizer'] = obs_normalizer
        self._logger.setup_torch_saver(what_to_save)
        self._logger.torch_save()

        self._logger.register_key('Metrics/EpRet', window_length=50)
        self._logger.register_key('Metrics/EpCost', window_length=50)
        self._logger.register_key('Metrics/EpLen', window_length=50)

        self._logger.register_key('Train/Epoch')
        self._logger.register_key('Train/Entropy')
        self._logger.register_key('Train/KL')
        self._logger.register_key('Train/StopIter')
        self._logger.register_key('Train/PolicyRatio', min_and_max=True)
        self._logger.register_key('Train/LR')
        if self._cfgs.model_cfgs.actor_type == 'gaussian_learning':
            self._logger.register_key('Train/PolicyStd')

        self._logger.register_key('TotalEnvSteps')

        # log information about actor
        self._logger.register_key('Loss/Loss_pi', delta=True)
        self._logger.register_key('Value/Adv')

        # log information about critic
        self._logger.register_key('Loss/Loss_reward_critic', delta=True)
        self._logger.register_key('Value/reward')

        if self._cfgs.algo_cfgs.use_cost:
            # log information about cost critic
            self._logger.register_key('Loss/Loss_cost_critic', delta=True)
            self._logger.register_key('Value/cost')

        self._logger.register_key('Time/Total')
        self._logger.register_key('Time/Rollout')
        self._logger.register_key('Time/Update')
        self._logger.register_key('Time/Epoch')
        self._logger.register_key('Time/FPS')

    def learn(self) -> tuple[float, float, int]:
        """This is main function for algorithm update.

        Each epoch is divided into the following steps:

        - ``self._env.rollout``: collect interactive data from environment into the buffer.
        - :meth:`_update`: perform actor/critic updates.
        - Step the exploration-noise annealing and the actor learning-rate scheduler when
          they are configured.
        - Store epoch information and dump it through the logger; save the model every
          ``logger_cfgs.save_model_freq`` epochs.

        Returns:
            ep_ret: first entry of the logger statistics for ``Metrics/EpRet``.
            ep_cost: first entry of the logger statistics for ``Metrics/EpCost``.
            ep_len: first entry of the logger statistics for ``Metrics/EpLen``, as ``int``.
        """
        start_time = time.time()
        self._logger.log('INFO: Start training')

        for epoch in range(self._cfgs.train_cfgs.epochs):
            epoch_time = time.time()

            rollout_time = time.time()
            self._env.rollout(
                steps_per_epoch=self._steps_per_epoch,
                agent=self._actor_critic,
                buffer=self._buf,
                logger=self._logger,
            )
            self._logger.store({'Time/Rollout': time.time() - rollout_time})

            update_time = time.time()
            self._update()
            self._logger.store({'Time/Update': time.time() - update_time})

            if self._cfgs.model_cfgs.exploration_noise_anneal:
                self._actor_critic.annealing(epoch)

            if self._cfgs.model_cfgs.actor.lr is not None:
                self._actor_critic.actor_scheduler.step()

            self._logger.store(
                {
                    'TotalEnvSteps': (epoch + 1) * self._cfgs.algo_cfgs.steps_per_epoch,
                    'Time/FPS': self._cfgs.algo_cfgs.steps_per_epoch / (time.time() - epoch_time),
                    'Time/Total': (time.time() - start_time),
                    'Time/Epoch': (time.time() - epoch_time),
                    'Train/Epoch': epoch,
                    'Train/LR': 0.0
                    if self._cfgs.model_cfgs.actor.lr is None
                    else self._actor_critic.actor_scheduler.get_last_lr()[0],
                },
            )

            self._logger.dump_tabular()

            # save model to disk
            if (epoch + 1) % self._cfgs.logger_cfgs.save_model_freq == 0:
                self._logger.torch_save()

        ep_ret = self._logger.get_stats('Metrics/EpRet')[0]
        ep_cost = self._logger.get_stats('Metrics/EpCost')[0]
        ep_len = int(self._logger.get_stats('Metrics/EpLen')[0])
        self._logger.close()

        return ep_ret, ep_cost, ep_len

    def _update(self) -> None:
        """Update actor, critic.

        The following entries are read from the dict returned by ``self._buf.get()``:

        - ``obs``: observation.
        - ``act``: action.
        - ``logp``: log probability of the action.
        - ``target_value_r``: target reward value.
        - ``target_value_c``: target cost value.
        - ``adv_r``: estimated reward advantage (e.g. **GAE**).
        - ``adv_c``: estimated cost advantage (e.g. **GAE**).

        The basic process of each update is as follows:

        #. Get the data from buffer and record the policy distribution before any update.
        #. Shuffle the data and split it into mini-batches.
        #. For every mini-batch, update the reward critic by :meth:`_update_reward_critic`,
           the cost critic by :meth:`_update_cost_critic` (only when ``use_cost`` is set) and
           the policy by :meth:`_update_actor`.
        #. After each pass over the data, measure the KL divergence between the recorded
           and the current policy on the full data.
        #. Repeat from step 2 for ``update_iters`` passes, stopping early when
           ``kl_early_stop`` is set and the KL divergence exceeds ``target_kl``.
        """
        data = self._buf.get()
        obs, act, logp, target_value_r, target_value_c, adv_r, adv_c = (
            data['obs'],
            data['act'],
            data['logp'],
            data['target_value_r'],
            data['target_value_c'],
            data['adv_r'],
            data['adv_c'],
        )

        # The mini-batch loop below rebinds ``obs``; keep a handle on the full batch.
        original_obs = obs
        # Only the KL value is needed, so no autograd graph has to be kept alive.
        with torch.no_grad():
            old_distribution = self._actor_critic.actor(obs)

        dataloader = DataLoader(
            dataset=TensorDataset(obs, act, logp, target_value_r, target_value_c, adv_r, adv_c),
            batch_size=self._cfgs.algo_cfgs.batch_size,
            shuffle=True,
        )

        update_counts = 0
        # Plain float so that ``Train/KL`` is a scalar even if no iteration runs.
        final_kl = 0.0

        for i in track(range(self._cfgs.algo_cfgs.update_iters), description='Updating...'):
            for (
                obs,
                act,
                logp,
                target_value_r,
                target_value_c,
                adv_r,
                adv_c,
            ) in dataloader:
                self._update_reward_critic(obs, target_value_r)
                if self._cfgs.algo_cfgs.use_cost:
                    self._update_cost_critic(obs, target_value_c)
                self._update_actor(obs, act, logp, adv_r, adv_c)

            with torch.no_grad():
                new_distribution = self._actor_critic.actor(original_obs)
                kl = (
                    torch.distributions.kl.kl_divergence(old_distribution, new_distribution)
                    .sum(-1, keepdim=True)
                    .mean()
                    .item()
                )
            kl = distributed.dist_avg(kl)

            final_kl = kl
            update_counts += 1

            if self._cfgs.algo_cfgs.kl_early_stop and kl > self._cfgs.algo_cfgs.target_kl:
                self._logger.log(f'Early stopping at iter {i + 1} due to reaching max kl')
                break

        self._logger.store(
            {
                'Train/StopIter': update_counts,
                # ``adv_r`` is the last mini-batch here (the loop above rebinds it).
                'Value/Adv': adv_r.mean().item(),  # pylint: disable=undefined-loop-variable
                'Train/KL': final_kl,
            },
        )

    def _update_reward_critic(self, obs: torch.Tensor, target_value_r: torch.Tensor) -> None:
        r"""Update the reward value network with one gradient step.

        The loss function is ``MSE loss``, computed with ``torch.nn.functional.mse_loss``.
        Specifically, the loss function is defined as:

        .. math::

            L = \frac{1}{N} \sum_{i=1}^N (\hat{V} - V)^2

        where :math:`\hat{V}` is the predicted reward value and :math:`V` is the target
        reward value.

        #. Compute the loss function.
        #. Add the ``critic norm`` to the loss function if ``use_critic_norm`` is ``True``.
        #. Clip the gradient if ``use_max_grad_norm`` is ``True``.
        #. Average the gradients with ``distributed.avg_grads``.
        #. Update the network by loss function.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.
            target_value_r (torch.Tensor): The ``target_value_r`` sampled from buffer.
        """
        self._actor_critic.reward_critic_optimizer.zero_grad()
        loss = nn.functional.mse_loss(self._actor_critic.reward_critic(obs)[0], target_value_r)

        if self._cfgs.algo_cfgs.use_critic_norm:
            # Squared-parameter penalty scaled by ``critic_norm_coef``.
            for param in self._actor_critic.reward_critic.parameters():
                loss += param.pow(2).sum() * self._cfgs.algo_cfgs.critic_norm_coef

        loss.backward()

        if self._cfgs.algo_cfgs.use_max_grad_norm:
            clip_grad_norm_(
                self._actor_critic.reward_critic.parameters(),
                self._cfgs.algo_cfgs.max_grad_norm,
            )
        distributed.avg_grads(self._actor_critic.reward_critic)
        self._actor_critic.reward_critic_optimizer.step()

        self._logger.store({'Loss/Loss_reward_critic': loss.item()})

    def _update_cost_critic(self, obs: torch.Tensor, target_value_c: torch.Tensor) -> None:
        r"""Update the cost value network with one gradient step.

        The loss function is ``MSE loss``, computed with ``torch.nn.functional.mse_loss``.
        Specifically, the loss function is defined as:

        .. math::

            L = \frac{1}{N} \sum_{i=1}^N (\hat{V} - V)^2

        where :math:`\hat{V}` is the predicted cost and :math:`V` is the target cost.

        #. Compute the loss function.
        #. Add the ``critic norm`` to the loss function if ``use_critic_norm`` is ``True``.
        #. Clip the gradient if ``use_max_grad_norm`` is ``True``.
        #. Average the gradients with ``distributed.avg_grads``.
        #. Update the network by loss function.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.
            target_value_c (torch.Tensor): The ``target_value_c`` sampled from buffer.
        """
        self._actor_critic.cost_critic_optimizer.zero_grad()
        loss = nn.functional.mse_loss(self._actor_critic.cost_critic(obs)[0], target_value_c)

        if self._cfgs.algo_cfgs.use_critic_norm:
            # Squared-parameter penalty scaled by ``critic_norm_coef``.
            for param in self._actor_critic.cost_critic.parameters():
                loss += param.pow(2).sum() * self._cfgs.algo_cfgs.critic_norm_coef

        loss.backward()

        if self._cfgs.algo_cfgs.use_max_grad_norm:
            clip_grad_norm_(
                self._actor_critic.cost_critic.parameters(),
                self._cfgs.algo_cfgs.max_grad_norm,
            )
        distributed.avg_grads(self._actor_critic.cost_critic)
        self._actor_critic.cost_critic_optimizer.step()

        self._logger.store({'Loss/Loss_cost_critic': loss.item()})

    def _update_actor(  # pylint: disable=too-many-arguments
        self,
        obs: torch.Tensor,
        act: torch.Tensor,
        logp: torch.Tensor,
        adv_r: torch.Tensor,
        adv_c: torch.Tensor,
    ) -> None:
        """Update the policy network with one gradient step.

        #. Combine the advantages with :meth:`_compute_adv_surrogate`.
        #. Compute the loss function with :meth:`_loss_pi`.
        #. Clip the gradient if ``use_max_grad_norm`` is ``True``.
        #. Average the gradients with ``distributed.avg_grads``.
        #. Update the network by loss function.

        .. warning::
            For some ``KL divergence`` based algorithms (e.g. TRPO, CPO, etc.),
            the ``KL divergence`` between the old policy and the new policy is calculated.
            And the ``KL divergence`` is used to determine whether the update is successful.
            If the ``KL divergence`` is too large, the update will be terminated.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.
            act (torch.Tensor): The ``action`` sampled from buffer.
            logp (torch.Tensor): The ``log_p`` sampled from buffer.
            adv_r (torch.Tensor): The ``reward_advantage`` sampled from buffer.
            adv_c (torch.Tensor): The ``cost_advantage`` sampled from buffer.
        """
        adv = self._compute_adv_surrogate(adv_r, adv_c)
        loss = self._loss_pi(obs, act, logp, adv)
        self._actor_critic.actor_optimizer.zero_grad()
        loss.backward()
        if self._cfgs.algo_cfgs.use_max_grad_norm:
            clip_grad_norm_(
                self._actor_critic.actor.parameters(),
                self._cfgs.algo_cfgs.max_grad_norm,
            )
        distributed.avg_grads(self._actor_critic.actor)
        self._actor_critic.actor_optimizer.step()

    def _compute_adv_surrogate(  # pylint: disable=unused-argument
        self,
        adv_r: torch.Tensor,
        adv_c: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the advantage used in the surrogate loss.

        Policy Gradient only uses the reward advantage; ``adv_c`` is accepted so that
        subclasses can override this method with the same signature.

        Args:
            adv_r (torch.Tensor): The ``reward_advantage`` sampled from buffer.
            adv_c (torch.Tensor): The ``cost_advantage`` sampled from buffer.

        Returns:
            The ``reward_advantage`` used to update policy network.
        """
        return adv_r

    def _loss_pi(
        self,
        obs: torch.Tensor,
        act: torch.Tensor,
        logp: torch.Tensor,
        adv: torch.Tensor,
    ) -> torch.Tensor:
        r"""Computing pi/actor loss.

        In Policy Gradient, the loss is defined as:

        .. math::

            L = -\mathbb{E}_{s_t \sim \rho_{\theta}} [
                \sum_{t=0}^T ( \frac{\pi^{'}_{\theta}(a_t|s_t)}{\pi_{\theta}(a_t|s_t)} )
                A^{R}_{\pi_{\theta}}(s_t, a_t)
            ]

        where :math:`\pi_{\theta}` is the policy network, :math:`\pi^{'}_{\theta}`
        is the new policy network, :math:`A^{R}_{\pi_{\theta}}(s_t, a_t)` is the advantage.

        The entropy, probability ratio and loss are stored in the logger. The policy
        standard deviation is stored only when ``model_cfgs.actor_type`` is
        ``'gaussian_learning'``, the one case in which :meth:`_init_log` registers
        ``Train/PolicyStd``.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.
            act (torch.Tensor): The ``action`` sampled from buffer.
            logp (torch.Tensor): The ``log probability`` of action sampled from buffer.
            adv (torch.Tensor): The ``advantage`` sampled from buffer.

        Returns:
            The loss of pi/actor.
        """
        # The forward pass must come first: ``log_prob`` is queried on the actor right after.
        distribution = self._actor_critic.actor(obs)
        logp_ = self._actor_critic.actor.log_prob(act)
        ratio = torch.exp(logp_ - logp)
        loss = -(ratio * adv).mean()
        entropy = distribution.entropy().mean().item()

        info: dict[str, Any] = {
            'Train/Entropy': entropy,
            'Train/PolicyRatio': ratio,
            'Loss/Loss_pi': loss.item(),
        }
        if self._cfgs.model_cfgs.actor_type == 'gaussian_learning':
            info['Train/PolicyStd'] = self._actor_critic.actor.std
        self._logger.store(info)
        return loss