Source code for omnisafe.algorithms.on_policy.base.policy_gradient

# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the Policy Gradient algorithm."""

import time
from typing import Any, Dict, Tuple, Union

import torch
import torch.nn as nn
from rich.progress import track
from torch.utils.data import DataLoader, TensorDataset

from omnisafe.adapter import OnPolicyAdapter
from omnisafe.algorithms import registry
from omnisafe.algorithms.base_algo import BaseAlgo
from omnisafe.common.buffer import VectorOnPolicyBuffer
from omnisafe.common.logger import Logger
from omnisafe.models.actor_critic.constraint_actor_critic import ConstraintActorCritic
from omnisafe.utils import distributed


@registry.register
# pylint: disable-next=too-many-instance-attributes, too-few-public-methods, line-too-long
class PolicyGradient(BaseAlgo):
    """The Policy Gradient algorithm.

    The class is registered in the algorithm registry through the
    ``@registry.register`` decorator and derives from :class:`BaseAlgo`.

    References:
        - Title: Policy Gradient Methods for Reinforcement Learning with Function Approximation
        - Authors: Richard S. Sutton, David McAllester, Satinder Singh, Yishay Mansour.
        - URL: `PG <https://proceedings.neurips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf>`_
    """
def _init_env(self) -> None:
    """Build the environment adapter and derive the per-environment rollout length.

    The environment is wrapped in :class:`omnisafe.adapter.OnPolicyAdapter`.
    Override this hook in a subclass to use a different adapter.

    ``algo_cfgs.update_cycle`` must be an exact multiple of
    ``distributed.world_size() * train_cfgs.vector_env_nums``; the resulting
    quotient is stored in ``self._steps_per_epoch``.

    Example:
        >>> def _init_env(self) -> None:
        >>>     self._env = CustomAdapter()

    Raises:
        AssertionError: If ``update_cycle`` is not divisible by the number of
            processes times the number of vectorized environments.
    """
    self._env = OnPolicyAdapter(
        self._env_id,
        self._cfgs.train_cfgs.vector_env_nums,
        self._seed,
        self._cfgs,
    )

    update_cycle = self._cfgs.algo_cfgs.update_cycle
    num_envs = self._cfgs.train_cfgs.vector_env_nums

    # Every process steps ``num_envs`` environments in lock-step, so the
    # update cycle has to split evenly across all of them.
    assert update_cycle % (distributed.world_size() * num_envs) == 0, (
        'The number of steps per epoch is not divisible by the number of ' 'environments.'
    )

    self._steps_per_epoch = update_cycle // distributed.world_size() // num_envs
def _init_model(self) -> None:
    """Create the actor-critic model and place it on the training device.

    The default model is
    :class:`omnisafe.models.actor_critic.constraint_actor_critic.ConstraintActorCritic`.
    Override this hook in a subclass to use a different model.

    When more than one process is running, parameters are synchronized via
    ``distributed.sync_params``. When ``model_cfgs.exploration_noise_anneal``
    is set, annealing is configured over ``[0, train_cfgs.epochs]`` with
    ``model_cfgs.std_range``.

    Example:
        >>> def _init_model(self) -> None:
        >>>     self._actor_critic = CustomActorCritic()
    """
    model_cfgs = self._cfgs.model_cfgs
    total_epochs = self._cfgs.train_cfgs.epochs

    actor_critic = ConstraintActorCritic(
        obs_space=self._env.observation_space,
        act_space=self._env.action_space,
        model_cfgs=model_cfgs,
        epochs=total_epochs,
    )
    self._actor_critic = actor_critic.to(self._device)

    if distributed.world_size() > 1:
        distributed.sync_params(self._actor_critic)

    if not model_cfgs.exploration_noise_anneal:
        return

    self._actor_critic.set_annealing(
        epochs=[0, total_epochs],
        std=model_cfgs.std_range,
    )
def _init(self) -> None:
    """Allocate the on-policy rollout buffer.

    A :class:`VectorOnPolicyBuffer` holding ``self._steps_per_epoch`` steps
    for each of the ``train_cfgs.vector_env_nums`` environments is created
    from the settings in ``algo_cfgs`` and stored in ``self._buf``.

    Subclasses may extend the initialization by overriding this hook.

    Example:
        >>> def _init(self) -> None:
        >>>     super()._init()
        >>>     self._buffer = CustomBuffer()
        >>>     self._model = CustomModel()
    """
    algo_cfgs = self._cfgs.algo_cfgs
    buffer_kwargs = {
        'obs_space': self._env.observation_space,
        'act_space': self._env.action_space,
        'size': self._steps_per_epoch,
        'gamma': algo_cfgs.gamma,
        'lam': algo_cfgs.lam,
        'lam_c': algo_cfgs.lam_c,
        'advantage_estimator': algo_cfgs.adv_estimation_method,
        'standardized_adv_r': algo_cfgs.standardized_rew_adv,
        'standardized_adv_c': algo_cfgs.standardized_cost_adv,
        'penalty_coefficient': algo_cfgs.penalty_coef,
        'num_envs': self._cfgs.train_cfgs.vector_env_nums,
        'device': self._device,
    }
    self._buf = VectorOnPolicyBuffer(**buffer_kwargs)
def _init_log(self) -> None:
    """Create the logger, set up model saving and register every logged key.

    The saver always receives the actor under ``'pi'``; the observation
    normalizer is added under ``'obs_normalizer'`` when
    ``algo_cfgs.obs_normalize`` is enabled. An initial checkpoint is written
    right after the saver is configured.

    Registered keys (in registration order):

    .. list-table::

        *   -   Things to log
            -   Description
        *   -   Metrics/EpRet, Metrics/EpCost, Metrics/EpLen
            -   Episode return / cost / length, registered with
                ``window_length=50`` (presumably filled during roll-out;
                not written in this class).
        *   -   Train/Epoch
            -   Index of the current epoch.
        *   -   Train/Entropy
            -   Mean entropy of the policy distribution.
        *   -   Train/KL
            -   KL divergence between the policy before and after the update.
        *   -   Train/StopIter
            -   Number of update iterations performed in the epoch.
        *   -   Train/PolicyRatio
            -   Mean probability ratio between new and old policy.
        *   -   Train/LR
            -   Learning rate of the actor.
        *   -   Train/PolicyStd
            -   Only registered when ``model_cfgs.actor_type`` is
                ``'gaussian_learning'``.
        *   -   TotalEnvSteps
            -   Total number of environment steps so far.
        *   -   Loss/Loss_pi
            -   Policy loss (registered with ``delta=True``).
        *   -   Value/Adv
            -   Mean reward advantage.
        *   -   Loss/Loss_reward_critic
            -   Reward critic loss (registered with ``delta=True``).
        *   -   Value/reward
            -   Presumably the reward critic's value estimate recorded
                during roll-out (not written in this class).
        *   -   Loss/Loss_cost_critic, Value/cost
            -   Cost critic counterparts; only registered when
                ``algo_cfgs.use_cost`` is enabled.
        *   -   Time/Total, Time/Rollout, Time/Update, Time/Epoch, Time/FPS
            -   Wall-clock timings and throughput of the epoch.
    """
    logger_cfgs = self._cfgs.logger_cfgs
    self._logger = Logger(
        output_dir=logger_cfgs.log_dir,
        exp_name=self._cfgs.exp_name,
        seed=self._cfgs.seed,
        use_tensorboard=logger_cfgs.use_tensorboard,
        use_wandb=logger_cfgs.use_wandb,
        config=self._cfgs,
    )

    what_to_save: Dict[str, Any] = {'pi': self._actor_critic.actor}
    if self._cfgs.algo_cfgs.obs_normalize:
        what_to_save['obs_normalizer'] = self._env.save()['obs_normalizer']
    self._logger.setup_torch_saver(what_to_save)
    self._logger.torch_save()

    # Keys are registered in a fixed sequence; keep it stable because the
    # logger may use registration order for its output columns.
    register = self._logger.register_key

    for metric_key in ('Metrics/EpRet', 'Metrics/EpCost', 'Metrics/EpLen'):
        register(metric_key, window_length=50)

    for train_key in (
        'Train/Epoch',
        'Train/Entropy',
        'Train/KL',
        'Train/StopIter',
        'Train/PolicyRatio',
        'Train/LR',
    ):
        register(train_key)
    if self._cfgs.model_cfgs.actor_type == 'gaussian_learning':
        register('Train/PolicyStd')

    register('TotalEnvSteps')

    # log information about actor
    register('Loss/Loss_pi', delta=True)
    register('Value/Adv')

    # log information about critic
    register('Loss/Loss_reward_critic', delta=True)
    register('Value/reward')

    if self._cfgs.algo_cfgs.use_cost:
        # log information about cost critic
        register('Loss/Loss_cost_critic', delta=True)
        register('Value/cost')

    for time_key in ('Time/Total', 'Time/Rollout', 'Time/Update', 'Time/Epoch', 'Time/FPS'):
        register(time_key)
def learn(self) -> Tuple[Union[int, float], ...]:
    r"""Run the full training loop.

    Every epoch consists of:

    - ``self._env.roll_out(...)``: collect interaction data into the buffer.
    - :meth:`_update`: perform the actor/critic updates.
    - optional exploration-noise annealing and an actor learning-rate
      scheduler step (skipped when ``model_cfgs.actor.lr`` is the string
      ``'None'``).
    - logging of step counts and timings, followed by ``dump_tabular()``.
    - a model checkpoint every ``logger_cfgs.save_model_freq`` epochs.

    Returns:
        A tuple ``(ep_ret, ep_cost, ep_len)`` holding the first element of
        the logger statistics for ``Metrics/EpRet``, ``Metrics/EpCost`` and
        ``Metrics/EpLen``.
    """
    train_begin = time.time()
    self._logger.log('INFO: Start training')

    for epoch in range(self._cfgs.train_cfgs.epochs):
        epoch_begin = time.time()

        # --- data collection -------------------------------------------
        rollout_begin = time.time()
        self._env.roll_out(
            steps_per_epoch=self._steps_per_epoch,
            agent=self._actor_critic,
            buffer=self._buf,
            logger=self._logger,
        )
        self._logger.store(**{'Time/Rollout': time.time() - rollout_begin})

        # --- parameter update ------------------------------------------
        update_begin = time.time()
        self._update()
        self._logger.store(**{'Time/Update': time.time() - update_begin})

        if self._cfgs.model_cfgs.exploration_noise_anneal:
            self._actor_critic.annealing(epoch)

        # The config encodes "no learning rate schedule" as the string 'None'.
        has_lr_schedule = self._cfgs.model_cfgs.actor.lr != 'None'
        if has_lr_schedule:
            self._actor_critic.actor_scheduler.step()

        # --- epoch bookkeeping -----------------------------------------
        update_cycle = self._cfgs.algo_cfgs.update_cycle
        total_env_steps = (epoch + 1) * update_cycle
        fps = update_cycle / (time.time() - epoch_begin)
        total_elapsed = time.time() - train_begin
        epoch_elapsed = time.time() - epoch_begin
        if has_lr_schedule:
            current_lr = self._actor_critic.actor_scheduler.get_last_lr()[0]
        else:
            current_lr = 0.0

        self._logger.store(
            **{
                'TotalEnvSteps': total_env_steps,
                'Time/FPS': fps,
                'Time/Total': total_elapsed,
                'Time/Epoch': epoch_elapsed,
                'Train/Epoch': epoch,
                'Train/LR': current_lr,
            }
        )
        self._logger.dump_tabular()

        # save model to disk
        if (epoch + 1) % self._cfgs.logger_cfgs.save_model_freq == 0:
            self._logger.torch_save()

    ep_ret = self._logger.get_stats('Metrics/EpRet')[0]
    ep_cost = self._logger.get_stats('Metrics/EpCost')[0]
    ep_len = self._logger.get_stats('Metrics/EpLen')[0]
    self._logger.close()

    return ep_ret, ep_cost, ep_len
def _update(self) -> None:
    r"""Update actor and critics from the data collected in the buffer.

    - Get the ``data`` from buffer

    .. hint::

        .. list-table::

            *   -   obs
                -   ``observation`` stored in buffer.
            *   -   act
                -   ``action`` stored in buffer.
            *   -   target_value_r
                -   ``target value`` stored in buffer.
            *   -   target_value_c
                -   ``target cost`` stored in buffer.
            *   -   logp
                -   ``log probability`` stored in buffer.
            *   -   adv_r
                -   ``estimated advantage`` (e.g. **GAE**) stored in buffer.
            *   -   adv_c
                -   ``estimated cost advantage`` (e.g. **GAE**) stored in buffer.

    - Update value net by :meth:`_update_reward_critic()`.
    - Update cost net by :meth:`_update_cost_critic()` (only when
      ``algo_cfgs.use_cost`` is enabled).
    - Update policy net by :meth:`_update_actor()`.

    The basic process of each update is as follows:

    #. Get the data from buffer.
    #. Shuffle the data and split it into mini-batch data.
    #. Get the loss of network.
    #. Update the network by loss.
    #. Repeat steps 3, 4 until the mini-batches are used up.
    #. Repeat steps 2 to 5 for ``algo_cfgs.update_iters`` iterations, or
       until the KL divergence exceeds ``algo_cfgs.target_kl`` when
       ``algo_cfgs.kl_early_stop`` is enabled.

    After the loop, ``Train/StopIter`` (number of completed iterations),
    ``Value/Adv`` (mean reward advantage over the whole buffer) and
    ``Train/KL`` (last measured KL divergence) are stored in the logger.
    """
    data = self._buf.get()
    obs, act, logp, target_value_r, target_value_c, adv_r, adv_c = (
        data['obs'],
        data['act'],
        data['logp'],
        data['target_value_r'],
        data['target_value_c'],
        data['adv_r'],
        data['adv_c'],
    )

    # Policy distribution before any gradient step; the KL check after each
    # iteration compares against it.
    old_distribution = self._actor_critic.actor(obs)

    dataloader = DataLoader(
        dataset=TensorDataset(obs, act, logp, target_value_r, target_value_c, adv_r, adv_c),
        batch_size=self._cfgs.algo_cfgs.batch_size,
        shuffle=True,
    )

    # Defaults for the final log entry, so that ``update_iters == 0`` logs
    # zeros instead of failing on unbound loop variables.
    update_counts = 0
    final_kl = 0.0

    for i in track(range(self._cfgs.algo_cfgs.update_iters), description='Updating...'):
        # Mini-batch tensors get their own names so they do not shadow the
        # full-buffer tensors that are still needed after the loop.
        for (
            batch_obs,
            batch_act,
            batch_logp,
            batch_target_value_r,
            batch_target_value_c,
            batch_adv_r,
            batch_adv_c,
        ) in dataloader:
            self._update_reward_critic(batch_obs, batch_target_value_r)
            if self._cfgs.algo_cfgs.use_cost:
                self._update_cost_critic(batch_obs, batch_target_value_c)
            self._update_actor(batch_obs, batch_act, batch_logp, batch_adv_r, batch_adv_c)

        # Measure how far the policy moved on the complete batch.
        new_distribution = self._actor_critic.actor(obs)

        kl = (
            torch.distributions.kl.kl_divergence(old_distribution, new_distribution)
            .sum(-1, keepdim=True)
            .mean()
            .item()
        )
        kl = distributed.dist_avg(kl)

        final_kl = kl
        update_counts = i + 1

        if self._cfgs.algo_cfgs.kl_early_stop and kl > self._cfgs.algo_cfgs.target_kl:
            self._logger.log(f'Early stopping at iter {i + 1} due to reaching max kl')
            break

    self._logger.store(
        **{
            'Train/StopIter': update_counts,
            # Mean over the whole buffer, not over the last shuffled mini-batch.
            'Value/Adv': adv_r.mean().item(),
            'Train/KL': final_kl,
        }
    )
def _update_reward_critic(self, obs: torch.Tensor, target_value_r: torch.Tensor) -> None:
    r"""Perform one gradient step on the reward critic.

    This method is called once per mini-batch from the nested loop in
    :meth:`_update`.

    The loss is the mean squared error (``torch.nn.functional.mse_loss``):

    .. math::

        L = \frac{1}{N} \sum_{i=1}^N (\hat{V} - V)^2

    where :math:`\hat{V}` is the predicted value and :math:`V` is the target value.

    #. Compute the loss function.
    #. Add the ``critic norm`` to the loss function if ``use_critic_norm`` is ``True``.
    #. Clip the gradient if ``use_max_grad_norm`` is ``True``.
    #. Average the gradients across processes and update the network.

    Args:
        obs (torch.Tensor): ``observation`` stored in buffer.
        target_value_r (torch.Tensor): ``target_value_r`` stored in buffer.
    """
    algo_cfgs = self._cfgs.algo_cfgs
    optimizer = self._actor_critic.reward_critic_optimizer
    critic = self._actor_critic.reward_critic

    optimizer.zero_grad()

    # The first element of the critic's output is used as the prediction.
    predicted_value = critic(obs)[0]
    loss = nn.functional.mse_loss(predicted_value, target_value_r)

    if algo_cfgs.use_critic_norm:
        # Squared-weight penalty on every critic parameter.
        for weight in critic.parameters():
            loss += weight.pow(2).sum() * algo_cfgs.critic_norm_coef

    loss.backward()

    if algo_cfgs.use_max_grad_norm:
        torch.nn.utils.clip_grad_norm_(critic.parameters(), algo_cfgs.max_grad_norm)
    distributed.avg_grads(critic)
    optimizer.step()

    self._logger.store(**{'Loss/Loss_reward_critic': loss.mean().item()})
def _update_cost_critic(self, obs: torch.Tensor, target_value_c: torch.Tensor) -> None:
    r"""Perform one gradient step on the cost critic.

    This method is called once per mini-batch from the nested loop in
    :meth:`_update` when ``algo_cfgs.use_cost`` is enabled.

    The loss is the mean squared error (``torch.nn.functional.mse_loss``):

    .. math::

        L = \frac{1}{N} \sum_{i=1}^N (\hat{V} - V)^2

    where :math:`\hat{V}` is the predicted cost and :math:`V` is the target cost.

    #. Compute the loss function.
    #. Add the ``critic norm`` to the loss function if ``use_critic_norm`` is ``True``.
    #. Clip the gradient if ``use_max_grad_norm`` is ``True``.
    #. Average the gradients across processes and update the network.

    Args:
        obs (torch.Tensor): ``observation`` stored in buffer.
        target_value_c (torch.Tensor): ``target_value_c`` stored in buffer.
    """
    algo_cfgs = self._cfgs.algo_cfgs
    optimizer = self._actor_critic.cost_critic_optimizer
    critic = self._actor_critic.cost_critic

    optimizer.zero_grad()

    # The first element of the critic's output is used as the prediction.
    predicted_cost = critic(obs)[0]
    loss = nn.functional.mse_loss(predicted_cost, target_value_c)

    if algo_cfgs.use_critic_norm:
        # Squared-weight penalty on every critic parameter.
        for weight in critic.parameters():
            loss += weight.pow(2).sum() * algo_cfgs.critic_norm_coef

    loss.backward()

    if algo_cfgs.use_max_grad_norm:
        torch.nn.utils.clip_grad_norm_(critic.parameters(), algo_cfgs.max_grad_norm)
    distributed.avg_grads(critic)
    optimizer.step()

    self._logger.store(**{'Loss/Loss_cost_critic': loss.mean().item()})
def _update_actor(  # pylint: disable=too-many-arguments
    self,
    obs: torch.Tensor,
    act: torch.Tensor,
    logp: torch.Tensor,
    adv_r: torch.Tensor,
    adv_c: torch.Tensor,
) -> None:
    r"""Perform one gradient step on the policy network.

    This method is called once per mini-batch from the nested loop in
    :meth:`_update`.

    #. Combine the advantages with :meth:`_compute_adv_surrogate`.
    #. Compute the loss function with :meth:`_loss_pi`.
    #. Clip the gradient if ``use_max_grad_norm`` is ``True``.
    #. Average the gradients across processes and update the network.
    #. Store entropy, policy ratio and loss in the logger.

    .. warning::
        For some ``KL divergence`` based algorithms (e.g. TRPO, CPO, etc.),
        the ``KL divergence`` between the old policy and the new policy is calculated.
        And the ``KL divergence`` is used to determine whether the update is successful.
        If the ``KL divergence`` is too large, the update will be terminated.

    Args:
        obs (torch.Tensor): ``observation`` stored in buffer.
        act (torch.Tensor): ``action`` stored in buffer.
        logp (torch.Tensor): ``log probability`` stored in buffer.
        adv_r (torch.Tensor): ``advantage`` stored in buffer.
        adv_c (torch.Tensor): ``cost_advantage`` stored in buffer.
    """
    adv = self._compute_adv_surrogate(adv_r, adv_c)
    loss, info = self._loss_pi(obs, act, logp, adv)

    self._actor_critic.actor_optimizer.zero_grad()
    loss.backward()
    if self._cfgs.algo_cfgs.use_max_grad_norm:
        torch.nn.utils.clip_grad_norm_(
            self._actor_critic.actor.parameters(), self._cfgs.algo_cfgs.max_grad_norm
        )
    distributed.avg_grads(self._actor_critic.actor)
    self._actor_critic.actor_optimizer.step()

    stats = {
        'Train/Entropy': info['entropy'],
        'Train/PolicyRatio': info['ratio'],
    }
    # ``Train/PolicyStd`` is only registered in ``_init_log`` for the
    # 'gaussian_learning' actor, so it is only stored under that condition.
    if self._cfgs.model_cfgs.actor_type == 'gaussian_learning':
        stats['Train/PolicyStd'] = info['std']
    stats['Loss/Loss_pi'] = loss.mean().item()

    self._logger.store(**stats)
[docs] def _compute_adv_surrogate( # pylint: disable=unused-argument self, adv_r: torch.Tensor, adv_c: torch.Tensor ) -> torch.Tensor: """Compute surrogate loss. Policy Gradient only use reward advantage. Args: adv_r (torch.Tensor): reward advantage adv_c (torch.Tensor): cost advantage """ return adv_r
def _loss_pi(
    self,
    obs: torch.Tensor,
    act: torch.Tensor,
    logp: torch.Tensor,
    adv: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    r"""Compute the policy (actor) loss.

    In Policy Gradient, the loss is defined as:

    .. math::

        L = -\mathbb{E}_{s_t \sim \rho_\theta} [
            \sum_{t=0}^T ( \frac{\pi^{'}_\theta(a_t|s_t)}{\pi_\theta(a_t|s_t)} )
            A^{R}_{\pi_{\theta}}(s_t, a_t)
        ]

    where :math:`\pi_\theta` is the policy network, :math:`\pi^{'}_\theta`
    is the new policy network, :math:`A^{R}_{\pi_{\theta}}(s_t, a_t)` is the advantage.

    Args:
        obs (torch.Tensor): ``observation`` stored in buffer.
        act (torch.Tensor): ``action`` stored in buffer.
        logp (torch.Tensor): ``log probability`` of action stored in buffer.
        adv (torch.Tensor): ``advantage`` stored in buffer.

    Returns:
        A pair ``(loss, info)``: the loss tensor to back-propagate and a
        dict with the keys ``'entropy'``, ``'ratio'`` and ``'std'``.
    """
    actor = self._actor_critic.actor

    # The forward pass has to run before ``log_prob`` is queried on the actor.
    policy_dist = actor(obs)
    new_logp = actor.log_prob(act)
    policy_std = actor.std

    # Probability ratio between the current policy and the one that
    # collected the data.
    importance_ratio = torch.exp(new_logp - logp)
    surrogate_loss = -(importance_ratio * adv).mean()

    info = {
        'entropy': policy_dist.entropy().mean().item(),
        'ratio': importance_ratio.mean().item(),
        'std': policy_std,
    }
    return surrogate_loss, info