Proximal Policy Optimization Algorithms

Quick Facts

  1. PPO is an on-policy algorithm.

  2. PPO can be used for environments with both discrete and continuous action spaces.

  3. PPO can be thought of as a simpler implementation of TRPO.

  4. The OmniSafe implementation of PPO supports parallelization.

  5. API documentation is available for PPO.

PPO Theorem

Background

Proximal Policy Optimization (PPO) is an RL algorithm that inherits some of the benefits of TRPO while being much simpler to implement. PPO shares the same goal as TRPO: how can we take the biggest possible improvement step on a policy using the data we already have, without stepping so far that we accidentally cause performance collapse? TRPO solves this problem with a complex second-order method. PPO instead uses a few simpler tricks to keep new policies close to old ones. There are two primary PPO variants: PPO-Penalty and PPO-Clip.

Problems of TRPO

  • The KL-divergence constraint in TRPO is complicated to compute.

  • Only the raw data sampled by the Monte Carlo method is used.

  • It relies on second-order optimization methods.

Advantages of PPO

  • It uses a clipping method to keep the new policy from deviating too far from the old one.

  • It uses the \(\text{GAE}\) method to estimate advantages.

  • It is simple to implement.

  • It uses only first-order optimization methods.


Optimization Objective

As introduced in the previous chapters, TRPO solves the following optimization problem:

\[\begin{split}&\pi_{k+1}=\arg\max_{\pi \in \Pi_{\boldsymbol{\theta}}}J^R(\pi)\\ \text{s.t.}\quad&D(\pi,\pi_k)\le\delta\end{split}\tag{1}\]

where \(\Pi_{\boldsymbol{\theta}} \subseteq \Pi\) denotes the set of parameterized policies with parameters \(\boldsymbol{\theta}\), and \(D\) is some distance measure. The problem TRPO needs to solve is how to find a suitable update direction and step size, so that updating the actor improves performance without moving too far from the original actor. In practice, TRPO rewrites Problem Eq.1 as:

\[\begin{split}&\max_{\theta} L_{\theta_{old}}(\theta) \\ &\text{s.t.} \quad \bar{D}_{\mathrm{KL}}(\theta_{old}, \theta) \le \delta\end{split}\tag{2}\]

where \(L_{\theta_{old}}(\theta)=\mathbb{E}_{s, a \sim \pi_{\theta_{old}}}\left[\frac{\pi_\theta(a \mid s)}{\pi_{\theta_{old}}(a \mid s)} \hat{A}_\pi(s, a)\right]\), and \(\hat{A}_{\pi}(s, a)\) is an estimator of the advantage function given \(s\) and \(a\).

You may wonder why we use \(\hat{A}\) instead of \(A\). The true advantage \(A\) is unknown and has to be estimated from sampled data. \(\hat{A}\) is such an estimate, computed here with the generalized advantage estimator (\(\text{GAE}\)). GAE is widely used in modern policy-gradient algorithms because it estimates \(A\) more efficiently. In other words, \(\hat{A}\) is the \(\text{GAE}\) version of \(A\).
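The following minimal pure-Python sketch estimates the surrogate \(L_{\theta_{old}}(\theta)\) from a batch of samples. The function name surrogate_objective and the toy numbers are hypothetical. The ratio is formed from log-probabilities, in the same way the OmniSafe actor loss computes torch.exp(logp_ - logp):

```python
import math


def surrogate_objective(logp_new, logp_old, adv):
    """Sample estimate of L_{theta_old}(theta) = E[ratio * A_hat].

    logp_new: log pi_theta(a|s) for each sampled (s, a)
    logp_old: log pi_theta_old(a|s) for the same samples (recorded at rollout time)
    adv:      advantage estimates A_hat(s, a)
    """
    total = 0.0
    for lp_new, lp_old, a in zip(logp_new, logp_old, adv):
        # Importance-sampling ratio pi_theta / pi_theta_old, formed in log space.
        ratio = math.exp(lp_new - lp_old)
        total += ratio * a
    return total / len(adv)


logp_old = [-1.2, -0.7, -2.0]
adv = [0.5, -1.0, 2.0]
# At theta == theta_old every ratio is 1, so the surrogate equals the mean advantage.
value_at_old = surrogate_objective(logp_old, logp_old, adv)
```

At \(\theta=\theta_{old}\), the surrogate reduces to the average advantage. Moving \(\theta\) raises it by making positive-advantage actions more likely.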


PPO-Penalty

The theory behind TRPO actually suggests using a penalty instead of a constraint, that is, solving the unconstrained optimization problem:

\[\max_\theta \mathbb{E}\left[\frac{\pi_\theta(a \mid s)}{\pi_{\theta_{old}}(a \mid s)} \hat{A}_\pi(s, a)-\beta D_{\mathrm{KL}}\left[\pi_{\theta_{old}}(\cdot \mid s), \pi_\theta(\cdot \mid s)\right]\right]\tag{3}\]

However, experiments show that simply choosing a fixed penalty coefficient \(\beta\) and optimizing the penalized objective Eq.3 with SGD (stochastic gradient descent) is not sufficient. It is hard to find a single value of \(\beta\) that works well across different problems, or even within one problem whose characteristics change over the course of learning. TRPO therefore uses a hard constraint instead of a penalty.

PPO-Penalty addresses this problem with an adaptive KL penalty coefficient, which makes the penalized objective Eq.3 work well in practice. In the simplest implementation of this algorithm, PPO-Penalty performs the following steps in each policy update:

Step I

Using several epochs of mini-batch SGD, optimize the KL-penalized objective of Eq.3:

\[\begin{split}L^{\mathrm{KLPEN}}(\theta)=\hat{\mathbb{E}}\Big[&\frac{\pi_\theta(a \mid s)}{\pi_{\theta_{old}}(a \mid s)} \hat{A}_\pi(s, a)\\ &-\beta D_{\mathrm{KL}}\left[\pi_{\theta_{old}}(\cdot \mid s), \pi_\theta(\cdot \mid s)\right]\Big]\end{split}\tag{4}\]

Step II

Compute \(d=\hat{\mathbb{E}}\left[D_{\mathrm{KL}}[\pi_{\theta_{old}}(\cdot \mid s), \pi_\theta(\cdot \mid s)]\right]\).

If \(d<d_{\text{targ}} / 1.5\), set \(\beta \leftarrow \beta / 2\).

If \(d>d_{\text{targ}} \times 1.5\), set \(\beta \leftarrow \beta \times 2\).

The updated \(\beta\) is used for the next policy update.
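Step II is a simple multiplicative rule. The following minimal sketch assumes the KL estimate \(d\) has already been computed. The function name update_beta is hypothetical and not part of OmniSafe's API. The 1.5 and 2 constants come from the steps above:

```python
def update_beta(beta, d, d_targ):
    """Adaptive KL penalty coefficient for PPO-Penalty (Step II).

    beta:   current penalty coefficient
    d:      measured mean KL divergence between the old and new policy
    d_targ: target KL divergence
    """
    if d < d_targ / 1.5:
        # The policy moved much less than the target: weaken the penalty.
        return beta / 2
    if d > d_targ * 1.5:
        # The policy moved much more than the target: strengthen the penalty.
        return beta * 2
    # Within [d_targ / 1.5, d_targ * 1.5]: keep beta unchanged.
    return beta


beta = 1.0
for d in [0.001, 0.05, 0.01]:
    # The returned value is used as beta for the next policy update.
    beta = update_beta(beta, d, d_targ=0.01)
```

With this rule, \(\beta\) only changes when the measured KL leaves the band around \(d_{\text{targ}}\). The penalty therefore tracks the target without needing a hand-tuned fixed value.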


PPO-Clip

Let \(r(\theta)\) denote the probability ratio \(r(\theta)=\frac{\pi_\theta(a \mid s)}{\pi_{\theta_{old}}(a \mid s)}\). PPO-Clip rewrites the surrogate objective as:

\[L^{\mathrm{CLIP}}(\theta)=\mathbb{E}\left[\min\left(r(\theta) \hat{A}_{\pi}(s, a), \text{clip}(r(\theta), 1-\varepsilon, 1+\varepsilon) \hat{A}_{\pi}(s, a)\right)\right]\tag{5}\]

in which \(\varepsilon\) is a small hyperparameter that roughly controls how far the new policy is allowed to move away from the old one. The formula is fairly involved. At first glance it is hard to tell what it does, or how it keeps the new policy close to the old one. To make it easier to understand, let \(L(s, a, \theta)\) denote \(\min\left[r(\theta) \hat{A}_{\pi}(s, a), \text{clip}(r(\theta), 1-\varepsilon, 1+\varepsilon) \hat{A}_{\pi}(s, a)\right]\). We can then simplify it in two cases:

PPO Clip

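  1. When the advantage is positive, we can rewrite \(L(s, a, \theta)\) as:

    \[L(s, a, \theta)=\min\left(r(\theta), 1+\varepsilon\right) \hat{A}_{\pi}(s, a)\tag{6}\]

     The objective grows as the action becomes more likely. Once \(r(\theta) > 1+\varepsilon\), however, the min caps it, so there is no further gain from increasing \(\pi_\theta(a \mid s)\).

  2. When the advantage is negative, we can rewrite \(L(s, a, \theta)\) as:

    \[L(s, a, \theta)=\max\left(r(\theta), 1-\varepsilon\right) \hat{A}_{\pi}(s, a)\tag{7}\]

     The objective grows as the action becomes less likely. Once \(r(\theta) < 1-\varepsilon\), however, the max caps it, so there is no further gain from decreasing \(\pi_\theta(a \mid s)\).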

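Under the clipped surrogate objective Eq.5, the new policy gains nothing from pushing \(r(\theta)\) outside \([1-\varepsilon, 1+\varepsilon]\) in the direction favored by the advantage. This removes the incentive for the new policy to move far away from the old one. It is a soft mechanism rather than a strict guarantee. In the experiments of the original PPO paper, PPO-Clip performed better than PPO-Penalty.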
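The case analysis can be checked numerically. Below is a minimal pure-Python sketch that evaluates the clipped objective of Eq.5. It asserts the two-case simplification directly: for a positive advantage the per-sample term equals \(\min(r, 1+\varepsilon)\hat{A}\), and for a negative advantage it equals \(\max(r, 1-\varepsilon)\hat{A}\). The names clip, clipped_term, and l_clip are illustrative, and eps = 0.2 is only an example value:

```python
def clip(x, low, high):
    """Scalar equivalent of clip(x, low, high)."""
    return max(low, min(x, high))


def clipped_term(ratio, adv, eps=0.2):
    """Per-sample PPO-Clip term: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    return min(ratio * adv, clip(ratio, 1 - eps, 1 + eps) * adv)


def l_clip(ratios, advs, eps=0.2):
    """Sample estimate of L^CLIP (Eq.5): the mean of the per-sample terms."""
    return sum(clipped_term(r, a, eps) for r, a in zip(ratios, advs)) / len(advs)


# Check the two-case simplification on a grid of ratios.
eps = 0.2
for r in [0.5, 0.8, 1.0, 1.1, 1.5]:
    pos, neg = 2.0, -2.0
    # Positive advantage: the objective stops improving once r exceeds 1 + eps.
    assert abs(clipped_term(r, pos, eps) - min(r, 1 + eps) * pos) < 1e-12
    # Negative advantage: the objective stops improving once r drops below 1 - eps.
    assert abs(clipped_term(r, neg, eps) - max(r, 1 - eps) * neg) < 1e-12
```

In a real implementation, the same expression is computed on tensors and negated to form a loss, as in the _loss_pi code later on this page.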


Practical Implementation

Generalized Advantage Estimation

One style of policy gradient implementation, popularized in and well-suited for use with recurrent neural networks, runs the policy for \(T\) timesteps (where \(T\) is much less than the episode length), and uses the collected samples for an update. This style requires an advantage estimator that does not look beyond timestep \(T\). This section will be concerned with producing an accurate estimate \(\hat{A}_{\pi}(s,a)\).

Define \(\delta_t^V=r_t+\gamma V(s_{t+1})-V(s_t)\) as the TD residual of \(V\) with discount \(\gamma\). Next, consider the discounted sum of \(k\) of these \(\delta\) terms, which we denote by \(\hat{A}_{\pi}^{(k)}\):

\[\begin{aligned} \hat{A}_{\pi}^{(1)}&:=\delta_t^V =-V(s_t)+r_t+\gamma V(s_{t+1}) \\ \hat{A}_{\pi}^{(2)}&:=\delta_t^V+\gamma \delta_{t+1}^V =-V(s_t)+r_t+\gamma r_{t+1}+\gamma^2 V(s_{t+2}) \\ \hat{A}_{\pi}^{(3)}&:=\delta_t^V+\gamma \delta_{t+1}^V+\gamma^2 \delta_{t+2}^V =-V(s_t)+r_t+\gamma r_{t+1}+\gamma^2 r_{t+2}+\gamma^3 V(s_{t+3}) \\ \hat{A}_{\pi}^{(k)}&:=\sum_{l=0}^{k-1} \gamma^l \delta_{t+l}^V =-V(s_t)+r_t+\gamma r_{t+1}+\cdots+\gamma^{k-1} r_{t+k-1}+\gamma^k V(s_{t+k}) \end{aligned}\tag{8}\]

We can consider \(\hat{A}_{\pi}^{(k)}\) to be an estimator of the advantage function.
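The telescoping in Eq.8 is easy to verify numerically. In the discounted sum of TD residuals, every intermediate value term cancels, leaving the \(k\)-step return plus a bootstrap term minus \(V(s_t)\). The following minimal pure-Python sketch uses made-up rewards and value estimates; all names and numbers are hypothetical:

```python
def td_residuals(rewards, values, gamma):
    """delta_t = r_t + gamma * V(s_{t+1}) - V(s_t); values has one more entry than rewards."""
    return [r + gamma * values[t + 1] - values[t] for t, r in enumerate(rewards)]


def k_step_advantage(rewards, values, gamma, k, t=0):
    """A_hat^(k) at time t as the discounted sum of k TD residuals (middle of Eq.8)."""
    deltas = td_residuals(rewards, values, gamma)
    return sum(gamma**l * deltas[t + l] for l in range(k))


def k_step_return_minus_baseline(rewards, values, gamma, k, t=0):
    """Right-hand side of Eq.8: k-step discounted return, plus bootstrap, minus V(s_t)."""
    ret = sum(gamma**l * rewards[t + l] for l in range(k))
    return -values[t] + ret + gamma**k * values[t + k]


rewards = [1.0, 0.0, 2.0, 1.0]
values = [0.5, 0.4, 1.0, 0.3, 0.2]  # V(s_0), ..., V(s_4)
gamma = 0.9
for k in range(1, 5):
    lhs = k_step_advantage(rewards, values, gamma, k)
    rhs = k_step_return_minus_baseline(rewards, values, gamma, k)
    # The intermediate V terms telescope away, so both sides agree.
    assert abs(lhs - rhs) < 1e-12
```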

Hint

The bias generally becomes smaller as \(k \rightarrow +\infty\), since the term \(\gamma^k V(s_{t+k})\) becomes more heavily discounted. Taking \(k \rightarrow +\infty\), we get:

\[\hat{A}_{\pi}^{(\infty)}=\sum_{l=0}^{\infty} \gamma^l \delta_{t+l}^V=-V(s_t)+\sum_{l=0}^{\infty} \gamma^l r_{t+l}\tag{9}\]

which is simply the empirical return minus the value-function baseline.

The generalized advantage estimator \(\text{GAE}(\gamma,\lambda)\) is defined as the exponentially-weighted average of these \(k\)-step estimators:

\[\begin{split}\hat{A}_{\pi}&:= (1-\lambda)(\hat{A}_{\pi}^{(1)}+\lambda \hat{A}_{\pi}^{(2)}+\lambda^2 \hat{A}_{\pi}^{(3)}+\ldots) \\ &= (1-\lambda)(\delta_t^V+\lambda(\delta_t^V+\gamma \delta_{t+1}^V)+\lambda^2(\delta_t^V+\gamma \delta_{t+1}^V+\gamma^2 \delta_{t+2}^V)+\ldots) \\ &= (1-\lambda)(\delta_t^V(1+\lambda+\lambda^2+\ldots)+\gamma \delta_{t+1}^V(\lambda+\lambda^2+\lambda^3+\ldots)+\gamma^2 \delta_{t+2}^V(\lambda^2+\lambda^3+\lambda^4+\ldots)+\ldots) \\ &= (1-\lambda)\left(\delta_t^V\frac{1}{1-\lambda}+\gamma \delta_{t+1}^V\frac{\lambda}{1-\lambda}+\gamma^2 \delta_{t+2}^V\frac{\lambda^2}{1-\lambda}+\ldots\right) \\ &= \sum_{l=0}^{\infty}(\gamma \lambda)^l \delta_{t+l}^V\end{split}\tag{10}\]

There are two notable special cases of this formula, obtained by setting \(\lambda =0\) and \(\lambda =1\).

\[\begin{split}\text{GAE}(\gamma, 0):\quad & \hat{A}_{\pi}:=\delta_t^V =r_t+\gamma V(s_{t+1})-V(s_t) \\ \text{GAE}(\gamma, 1):\quad & \hat{A}_{\pi}:=\sum_{l=0}^{\infty} \gamma^l \delta_{t+l}^V =\sum_{l=0}^{\infty} \gamma^l r_{t+l}-V(s_t)\end{split}\tag{11}\]

Hint

\(\text{GAE}(\gamma,1)\) is the traditional Monte Carlo (MC) estimate of the advantage function. It has high variance because it sums many random reward terms. \(\text{GAE}(\gamma,0)\) is a TD-based estimate with low variance, but it is biased whenever \(V\) is an inaccurate estimate of the true value function.

The generalized advantage estimator for \(0\le\lambda\le1\) makes a compromise between bias and variance, controlled by parameter \(\lambda\).
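The following minimal pure-Python sketch implements Eq.10 for one trajectory segment, matching the \(T\)-step implementation style described at the start of this section. It assumes the value estimates are already given. The function name gae and the toy numbers are hypothetical, and this is not OmniSafe's own buffer code. It uses the backward recursion \(\hat{A}_t = \delta_t^V + \gamma\lambda \hat{A}_{t+1}\), which gives the sum in Eq.10 truncated at the end of the segment:

```python
def gae(rewards, values, gamma, lam):
    """GAE(gamma, lambda) advantages for one segment of length T.

    rewards: r_0, ..., r_{T-1}
    values:  V(s_0), ..., V(s_T); the last entry is the bootstrap value
             (use 0.0 if s_T is terminal).
    """
    T = len(rewards)
    advantages = [0.0] * T
    running = 0.0  # holds A_hat_{t+1}; zero beyond the end of the segment
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]  # TD residual delta_t^V
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages


rewards = [1.0, 0.0, 2.0]
values = [0.5, 0.4, 1.0, 0.0]  # the segment ends in a terminal state, so V(s_T) = 0
gamma = 0.9

# lambda = 0: one-step TD residuals (low variance, biased if V is inaccurate).
adv_td = gae(rewards, values, gamma, lam=0.0)
# lambda = 1: discounted return minus V(s_t) (unbiased, higher variance).
adv_mc = gae(rewards, values, gamma, lam=1.0)
```

Setting lam to 0.0 and 1.0 reproduces the two special cases of Eq.11. Intermediate values trade bias against variance.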

Code with OmniSafe

Quick start

Run PPO in OmniSafe

Here are 3 ways to run PPO in OmniSafe:

  • Run Agent from preset yaml file

  • Run Agent from custom config dict

  • Run Agent from custom terminal config

Run Agent from preset yaml file:

    import omnisafe


    env_id = 'SafetyPointGoal1-v0'

    agent = omnisafe.Agent('PPO', env_id)
    agent.learn()
Run Agent from custom config dict:

    import omnisafe


    env_id = 'SafetyPointGoal1-v0'
    custom_cfgs = {
        'train_cfgs': {
            'total_steps': 1024000,
            'vector_env_nums': 1,
            'parallel': 1,
        },
        'algo_cfgs': {
            'steps_per_epoch': 2048,
            'update_iters': 1,
        },
        'logger_cfgs': {
            'use_wandb': False,
        },
    }

    agent = omnisafe.Agent('PPO', env_id, custom_cfgs=custom_cfgs)
    agent.learn()

Run Agent from custom terminal config:

We use train_policy.py as the entry point. You can train an agent with PPO simply by passing the PPO and environment arguments to train_policy.py. For example, to run PPO in SafetyPointGoal1-v0 with 1 torch thread and seed 0, you can use the following command:

    cd examples
    python train_policy.py --algo PPO --env-id SafetyPointGoal1-v0 --parallel 1 --total-steps 1024000 --device cpu --vector-env-nums 1 --torch-threads 1

Below is the documentation of the PyTorch implementation of PPO.

Architecture of functions

  • PPO.learn()

    • PPO._env.rollout()

    • PPO._update()

      • PPO._buf.get()

      • PPO.update_lagrange_multiplier(ep_costs)

      • PPO._update_actor

      • PPO._update_reward_critic


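Documentation of algorithm-specific functions

ppo._loss_pi()

Compute the loss of the actor, following the steps below:

  1. Get the policy importance sampling ratio.

    # Forward pass: the action distribution of the current policy at obs.
    distribution = self._actor_critic.actor(obs)
    # Log-probabilities of the taken actions under the current policy.
    logp_ = self._actor_critic.actor.log_prob(act)
    std = self._actor_critic.actor.std
    # `logp` holds the log-probabilities recorded under the old policy,
    # so this is r(theta) = pi_theta(a|s) / pi_theta_old(a|s).
    ratio = torch.exp(logp_ - logp)

  2. Get the clipped surrogate function.

    # clip(r(theta), 1 - eps, 1 + eps), with eps = algo_cfgs.clip
    ratio_cliped = torch.clamp(
        ratio, 1 - self._cfgs.algo_cfgs.clip, 1 + self._cfgs.algo_cfgs.clip
    )
    # Negate Eq.5 because the optimizer minimizes the loss.
    loss = -torch.min(ratio * adv, ratio_cliped * adv).mean()
    # Entropy bonus (scaled by entropy_coef) encourages exploration.
    loss -= self._cfgs.algo_cfgs.entropy_coef * distribution.entropy().mean()

  3. Log useful information.

    entropy = distribution.entropy().mean().item()
    info = {'entropy': entropy, 'ratio': ratio.mean().item(), 'std': std}
    return loss, info

  4. Return the actor loss and the logged information.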


Configs

Train Configs

  • device (str): Device to use for training, options: cpu, cuda, cuda:0, etc.

  • torch_threads (int): Number of threads to use for PyTorch.

  • total_steps (int): Total number of steps to train the agent.

  • parallel (int): Number of parallel agents, similar to A3C.

  • vector_env_nums (int): Number of the vector environments.

Algorithms Configs

Note

The following configs are specific to the PPO algorithm.

  • clip (float): Clipping parameter \(\varepsilon\) of PPO-Clip (Eq.5).

  • steps_per_epoch (int): Number of steps to update the policy network.

  • update_iters (int): Number of iterations to update the policy network.

  • batch_size (int): Batch size for each iteration.

  • target_kl (float): Target KL divergence.

  • entropy_coef (float): Coefficient of entropy.

  • reward_normalize (bool): Whether to normalize the reward.

  • cost_normalize (bool): Whether to normalize the cost.

  • obs_normalize (bool): Whether to normalize the observation.

  • kl_early_stop (bool): Whether to stop the current policy update early when the KL divergence exceeds target_kl.

  • max_grad_norm (float): Maximum gradient norm.

  • use_max_grad_norm (bool): Whether to use maximum gradient norm.

  • use_critic_norm (bool): Whether to use critic norm.

  • critic_norm_coef (float): Coefficient of critic norm.

  • gamma (float): Discount factor.

  • cost_gamma (float): Cost discount factor.

  • lam (float): Lambda for GAE-Lambda.

  • lam_c (float): Lambda for cost GAE-Lambda.

  • adv_estimation_method (str): The method to estimate the advantage.

  • standardized_rew_adv (bool): Whether to use standardized reward advantage.

  • standardized_cost_adv (bool): Whether to use standardized cost advantage.

  • penalty_coef (float): Penalty coefficient for cost.

  • use_cost (bool): Whether to use cost.
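The algorithm configs above can be overridden through the algo_cfgs entry of a custom config dict, in the same way as steps_per_epoch and update_iters in the Quick start example. The following sketch assumes these key names nest under algo_cfgs as listed here. The values are illustrative only, not recommended settings or OmniSafe defaults:

```python
# Hypothetical example values for illustration; they are not OmniSafe's defaults.
custom_cfgs = {
    'algo_cfgs': {
        'clip': 0.2,            # epsilon in the clipped objective
        'target_kl': 0.02,      # target KL divergence
        'kl_early_stop': True,  # stop the update early when KL is too large
        'entropy_coef': 0.0,    # weight of the entropy bonus in the actor loss
        'gamma': 0.99,          # reward discount factor
        'lam': 0.95,            # lambda for reward GAE
    },
}
```

The dict is then passed as custom_cfgs=custom_cfgs to omnisafe.Agent('PPO', env_id, ...), exactly as in the Quick start example.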

Model Configs

  • weight_initialization_mode (str): The type of weight initialization method.

  • actor_type (str): The type of actor, defaults to gaussian_learning.

  • linear_lr_decay (bool): Whether to use linear learning rate decay.

  • exploration_noise_anneal (bool): Whether to anneal the exploration noise.

  • std_range (list): The range of standard deviation.

Hint

actor (dictionary): parameters for the actor network

  • activations: tanh

  • hidden_sizes:

      • 64

      • 64

Hint

critic (dictionary): parameters for the critic network

  • activations: tanh

  • hidden_sizes:

      • 64

      • 64

Logger Configs

  • use_wandb (bool): Whether to use wandb to log the training process.

  • wandb_project (str): The name of wandb project.

  • use_tensorboard (bool): Whether to use tensorboard to log the training process.

  • log_dir (str): The directory to save the log files.

  • window_lens (int): The length of the window to calculate the average reward.

  • save_model_freq (int): The frequency to save the model.


References