Projection-Based Constrained Policy Optimization#

Quick Facts#

  1. PCPO is an on-policy algorithm.

  2. PCPO can be used for environments with both discrete and continuous action spaces.

  3. PCPO is an improvement built upon CPO.

  4. The OmniSafe implementation of PCPO supports parallelization.

  5. API documentation is available for PCPO.

PCPO Theorem#

Background#

Projection-Based Constrained Policy Optimization (PCPO) is an iterative method for optimizing policies in a two-stage process: the first stage performs a local reward improvement update, while the second stage reconciles any constraint violation by projecting the policy back onto the constraint set.

PCPO builds upon CPO (Constrained Policy Optimization). Like CPO, it provides a lower bound on reward improvement and an upper bound on constraint violation for each policy update. PCPO further characterizes its convergence under two different projection metrics: the \(L2\) norm and the KL divergence.

In short, PCPO is a CPO-based algorithm for learning control policies that optimize a reward function while satisfying constraints arising from safety, fairness, or other costs.

Hint

If you have not yet studied CPO-type algorithms, we strongly recommend reading our CPO tutorial (Constrained Policy Optimization) first, so that the PCPO ideas introduced in this section are easier to follow.


Optimization Objective#

In the previous chapters, you learned that CPO solves the following optimization problems:

(1)#\[\begin{split}&\pi_{k+1} = \arg\max_{\pi \in \Pi_{\boldsymbol{\theta}}}J^R(\pi)\\ \text{s.t.}\quad&D(\pi,\pi_k)\le\delta\\ &J^{C_i}(\pi)\le d_i\quad i=1,...m\end{split}\]

where \(\Pi_{\theta}\subseteq\Pi\) denotes the set of parametrized policies with parameters \(\theta\), and \(D\) is some distance measure. In local policy search for CMDPs, we additionally require the policy iterates to be feasible for the CMDP, so instead of optimizing over \(\Pi_{\theta}\), PCPO optimizes over \(\Pi_{\theta}\cap\Pi_{C}\). Next, we introduce how PCPO solves this optimization problem. To build a clearer understanding, keep the following questions in mind as you read the next section:

Questions

  • What is the two-stage policy update, and how does it work?

  • What are the performance bounds for PCPO, and how are they derived?

  • How does PCPO solve the optimization problem in practice?


Two-stage Policy Update#

PCPO performs the policy update in two stages. The first stage is the Reward Improvement Stage, which maximizes reward using a trust region optimization method without constraints; this may produce an intermediate policy that does not satisfy the constraints. The second stage, the Projection Stage, reconciles the constraint violation (if any) by projecting the policy back onto the constraint set, i.e., choosing the policy in the constraint set that is closest to the intermediate policy. Next, we describe how PCPO carries out this two-stage update.

Reward Improvement Stage

First, PCPO optimizes the reward by maximizing the reward advantage function \(A^R_{\pi_k}(s,a)\) subject to a KL divergence constraint. This constrains the intermediate policy \(\pi_{k+\frac12}\) to lie within a \(\delta\)-neighborhood of \(\pi_{k}\):

(2)#\[\begin{split}&\pi_{k+\frac12}=\underset{\pi}{\arg\max}\underset{s\sim d^{\pi_k}, a\sim\pi}{\mathbb{E}}[A^R_{\pi_k}(s,a)]\\ \text{s.t.}\quad &\underset{s\sim d^{\pi_k}}{\mathbb{E}}[D_{KL}(\pi||\pi_k)[s]]\le\delta\nonumber\end{split}\]

This update rule with a trust region is the one used by TRPO (see Trust Region Policy Optimization). It constrains the policy change to a divergence neighborhood and guarantees reward improvement.

Projection Stage

Second, PCPO projects the intermediate policy \(\pi_{k+\frac12}\) onto the constraint set by minimizing a distance measure \(D\) between \(\pi_{k+\frac12}\) and \(\pi\):

(3)#\[\begin{split}&\pi_{k+1}=\underset{\pi}{\arg\min}\quad D(\pi,\pi_{k+\frac12})\\ \text{s.t.}\quad &J^C\left(\pi_k\right)+\underset{\substack{s \sim d^{\pi_k} , a \sim \pi}}{\mathbb{E}}\left[A^C_{\pi_k}(s, a)\right] \leq d\end{split}\]

The Projection Stage ensures that the constraint-satisfying policy \(\pi_{k+1}\) stays close to \(\pi_{k+\frac{1}{2}}\). In other words, the Reward Improvement Stage moves the policy in the direction that maximizes reward without exceeding the trust-region step size, while the Projection Stage then moves the policy toward constraint satisfaction while deviating from \(\pi_{k+\frac{1}{2}}\) as little as possible under the distance measure \(D\).
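To make the two-stage structure concrete, the following toy example (not the OmniSafe implementation) performs one PCPO iteration on a two-dimensional parameter vector, using the \(L2\) norm projection and an identity Fisher matrix; the gradients, violation, and step size are made-up numbers for illustration only.

# Toy illustration of PCPO's two-stage update with an L2-norm projection.
# All numbers below (g, a, b, delta) are assumed for illustration; this is
# not the OmniSafe implementation.
import numpy as np

theta_k = np.zeros(2)        # current policy parameters
g = np.array([1.0, 0.5])     # gradient of the reward advantage (assumed)
a = np.array([0.8, -0.2])    # gradient of the cost advantage (assumed)
b = 0.1                      # J^C(pi_k) - d, current constraint violation (assumed)
delta = 0.01                 # trust-region step size

# Stage 1: reward improvement (with H = I the trust-region step is along g).
theta_half = theta_k + np.sqrt(2.0 * delta / (g @ g)) * g

# Stage 2: project theta_half onto the linearized constraint set
# {theta : a^T (theta - theta_k) + b <= 0} under the L2 norm.
violation = a @ (theta_half - theta_k) + b
theta_next = theta_half - max(0.0, violation / (a @ a)) * a

print("intermediate policy:", theta_half)
print("projected policy:   ", theta_next)
print("constraint value:   ", a @ (theta_next - theta_k) + b)  # <= 0 after projection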


Policy Performance Bounds#

In safety-critical applications, how much the performance of a system can degrade when applying a learning algorithm is an important issue. PCPO provides worst-case performance bounds for both cases: when the current policy satisfies the constraint and when it violates it.

Worst-case Bound on Updating Constraint-satisfying Policies

Define \(\epsilon_{\pi_{k+1}}^{R}\doteq \max\limits_{s}\big|\mathbb{E}_{a\sim\pi_{k+1}}[A^{R}_{\pi_{k}}(s,a)]\big|\), and \(\epsilon_{\pi_{k+1}}^{C}\doteq \max\limits_{s}\big|\mathbb{E}_{a\sim\pi_{k+1}}[A^{C}_{\pi_{k}}(s,a)]\big|\). If the current policy \(\pi_k\) satisfies the constraint, then under KL divergence projection, the lower bound on reward improvement, and upper bound on constraint violation for each policy update are

(4)#\[\begin{split}J^{R}(\pi_{k+1})-J^{R}(\pi_{k})&\geq&-\frac{\sqrt{2\delta}\gamma\epsilon_{\pi_{k+1}}^{R}}{(1-\gamma)^{2}}\\ J^{C}(\pi_{k+1})&\leq& d+\frac{\sqrt{2\delta}\gamma\epsilon_{\pi_{k+1}}^{C}}{(1-\gamma)^{2}}\end{split}\]

where \(\delta\) is the step size in the reward improvement step.
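As a quick numeric illustration of how the bounds in Eq.4 scale, the snippet below plugs in hypothetical values of \(\delta\), \(\gamma\), and the advantage bounds; the numbers are assumptions, not results from any experiment. Note how the \((1-\gamma)^{-2}\) factor dominates the bound.

# Evaluate the worst-case bounds of Eq.4 for assumed (illustrative) values.
import math

delta = 0.01   # trust-region step size
gamma = 0.99   # discount factor
eps_r = 0.1    # epsilon^R_{pi_{k+1}}: max expected reward advantage (assumed)
eps_c = 0.05   # epsilon^C_{pi_{k+1}}: max expected cost advantage (assumed)
d = 25.0       # cost limit (assumed)

reward_lower_bound = -math.sqrt(2 * delta) * gamma * eps_r / (1 - gamma) ** 2
cost_upper_bound = d + math.sqrt(2 * delta) * gamma * eps_c / (1 - gamma) ** 2

print(f"J^R(pi_(k+1)) - J^R(pi_k) >= {reward_lower_bound:.2f}")
print(f"J^C(pi_(k+1))             <= {cost_upper_bound:.2f}")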

Worst-case Bound on Updating Constraint-violating Policies

Define \(\epsilon_{\pi_{k+1}}^{R}\doteq \max\limits_{s}\big|\mathbb{E}_{a\sim\pi_{k+1}}[A^{R}_{\pi_{k}}(s,a)]\big|\), \(\epsilon_{\pi_{k+1}}^{C}\doteq \max\limits_{s}\big|\mathbb{E}_{a\sim\pi_{k+1}}[A^{C}_{\pi_{k}}(s,a)]\big|\), \(b^{+}\doteq \max(0,J^{C}(\pi_k)-d),\) and \(\alpha_{KL} \doteq \frac{1}{2a^T\boldsymbol{H}^{-1}a},\) where \(a\) is the gradient of the cost advantage function and \(\boldsymbol{H}\) is the Hessian of the KL divergence constraint. If the current policy \(\pi_k\) violates the constraint, then under KL divergence projection, the lower bound on reward improvement and the upper bound on constraint violation for each policy update are

(5)#\[\begin{split}J^{R}(\pi_{k+1})-J^{R}(\pi_{k})\geq&-\frac{\sqrt{2(\delta+{b^+}^{2}\alpha_\mathrm{KL})}\gamma\epsilon_{\pi_{k+1}}^{R}}{(1-\gamma)^{2}}\\ J^{C}(\pi_{k+1})\leq& ~d+\frac{\sqrt{2(\delta+{b^+}^{2}\alpha_\mathrm{KL})}\gamma\epsilon_{\pi_{k+1}}^{C}}{(1-\gamma)^{2}}\end{split}\]

where \(\delta\) is the step size in the reward improvement step.


Practical Implementation#

Implementation of a Two-stage Update#

For a large neural network policy with hundreds of thousands of parameters, directly solving the PCPO updates in Eq.2 and Eq.3 is impractical due to the computational cost. However, with a small step size \(\delta\), the reward objective and the cost constraint can be approximated well by linearizing around \(\pi_k\), while the KL divergence measure can be approximated by a second-order expansion.

Reward Improvement Stage

Define:

\(g\doteq\nabla_\theta\underset{\substack{s\sim d^{\pi_k}, a\sim \pi}}{\mathbb{E}}[A_{\pi_k}^{R}(s,a)]\) as the gradient of the reward advantage function,

\(a\doteq\nabla_\theta\underset{\substack{s\sim d^{\pi_k}, a\sim \pi}}{\mathbb{E}}[A_{\pi_k}^{C}(s,a)]\) as the gradient of the cost advantage function,

and \(\boldsymbol{H}_{i,j}\doteq \frac{\partial^2 \underset{s\sim d^{\pi_{k}}}{\mathbb{E}}\big[KL(\pi ||\pi_{k})[s]\big]}{\partial \theta_i\partial \theta_j}\) as the Hessian of the KL divergence constraint (\(\boldsymbol{H}\) is also called the Fisher information matrix; it is symmetric positive semi-definite). Let \(b\doteq J^{C}(\pi_k)-d\) denote the constraint violation of the policy \(\pi_{k}\), and let \(\theta\) denote the policy parameters. PCPO linearizes the objective function around \(\pi_k\) and uses a second-order approximation of the KL divergence constraint, which yields the following update:

(6)#\[\begin{split}&\theta_{k+\frac{1}{2}} = \underset{\theta}{\arg\max}\,g^{T}(\theta-\theta_k) \\ \text{s.t.}\quad &\frac{1}{2}(\theta-\theta_{k})^{T}\boldsymbol{H}(\theta-\theta_k)\le \delta.\end{split}\]

The above problem is essentially the optimization problem solved by TRPO, and it can be solved using the method introduced in the TRPO tutorial.
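In practice, \(g\), \(a\), and products with \(\boldsymbol{H}\) are obtained with automatic differentiation rather than by forming matrices explicitly. Below is a minimal PyTorch sketch of this idea for a toy Gaussian policy; the surrogate losses, the advantages, and the policy itself are stand-ins for illustration, not OmniSafe's actual modules.

# Minimal sketch: compute g, a, and Fisher-vector products with autograd.
# The policy, actions, and advantages are toy stand-ins (assumed).
import torch
from torch.distributions import Normal, kl_divergence

mean = torch.nn.Parameter(torch.zeros(2))      # toy Gaussian policy parameters
log_std = torch.nn.Parameter(torch.zeros(2))
params = [mean, log_std]

actions = torch.tensor([[0.1, -0.2], [0.3, 0.0]])
adv_r = torch.tensor([0.7, -0.3])              # placeholder reward advantages
adv_c = torch.tensor([0.2, 0.4])               # placeholder cost advantages

dist = Normal(mean, log_std.exp())
logp = dist.log_prob(actions).sum(-1)

# g and a: gradients of the surrogate reward and cost advantage objectives.
g = torch.autograd.grad((logp * adv_r).mean(), params, retain_graph=True)
a = torch.autograd.grad((logp * adv_c).mean(), params, retain_graph=True)

def fisher_vector_product(v):
    """Return H v via double backprop through the KL divergence at pi_k."""
    new_dist = Normal(mean, log_std.exp())
    old_dist = Normal(mean.detach(), log_std.detach().exp())
    kl = kl_divergence(old_dist, new_dist).sum()
    kl_grads = torch.autograd.grad(kl, params, create_graph=True)
    flat_grad = torch.cat([grad.reshape(-1) for grad in kl_grads])
    hv = torch.autograd.grad((flat_grad * v).sum(), params)
    return torch.cat([h.reshape(-1) for h in hv])

g_flat = torch.cat([t.reshape(-1) for t in g])
print(fisher_vector_product(g_flat))           # H g, without ever building H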

Projection Stage

PCPO offers guidance on choosing the distance measure: if the projection is defined in the parameter space, the \(L2\) norm projection is used, while if it is defined in the probability space, the KL divergence projection is preferable; the latter can be approximated through a second-order expansion. Again, PCPO linearizes the cost constraint at \(\pi_{k}\). This gives the following update for the projection step:

(7)#\[\begin{split}&\theta_{k+1} =\underset{\theta}{\arg\min}\frac{1}{2}(\theta-{\theta}_{k+\frac{1}{2}})^{T}\boldsymbol{L}(\theta-{\theta}_{k+\frac{1}{2}})\\ \text{s.t.}\quad & a^{T}(\theta-\theta_{k})+b\leq 0\end{split}\]

where \(\boldsymbol{L}=\boldsymbol{I}\) for \(L2\) norm projection, and \(\boldsymbol{L}=\boldsymbol{H}\) for KL divergence projection.

PCPO solves Eq.6 and Eq.7 using convex programming; see the appendix for details.

For each policy update:

(8)#\[\theta_{k+1}=\theta_{k}+\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}\boldsymbol{H}^{-1}g -\max\left(0,\frac{\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}a^{T}\boldsymbol{H}^{-1}g+b}{a^T\boldsymbol{L}^{-1}a}\right)\boldsymbol{L}^{-1}a\]

Hint

\(\boldsymbol{H}\) is assumed to be invertible, and PCPO requires inverting \(\boldsymbol{H}\), which is impractical for huge neural network policies. Hence it uses the conjugate gradient method. (See the appendix for a discussion of the trade-off between the approximation error and the computational efficiency of the conjugate gradient method.)
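The NumPy sketch below evaluates the closed-form update of Eq.8 on a small synthetic problem and compares the KL divergence and \(L2\) norm projections. The matrices and vectors are randomly generated stand-ins, and explicit inverses are used only because the problem is tiny; a real implementation relies on conjugate gradient for the \(\boldsymbol{H}^{-1}\) products, as noted in the hint above.

# Evaluate the closed-form PCPO update (Eq.8) on a synthetic problem.
# H, g, a, b, and delta are made-up stand-ins for illustration only.
import numpy as np

rng = np.random.default_rng(0)
n = 4
M = rng.normal(size=(n, n))
H = M @ M.T + n * np.eye(n)   # synthetic positive-definite "Fisher" matrix
g = rng.normal(size=n)        # reward advantage gradient (assumed)
a = rng.normal(size=n)        # cost advantage gradient (assumed)
b = 0.05                      # constraint violation J^C(pi_k) - d (assumed)
delta = 0.01
theta_k = np.zeros(n)

def pcpo_update(L):
    """One PCPO step with projection matrix L (H for KL, identity for L2)."""
    H_inv_g = np.linalg.solve(H, g)   # conjugate gradient in practice
    step = np.sqrt(2.0 * delta / (g @ H_inv_g))
    correction = max(0.0, (step * (a @ H_inv_g) + b) / (a @ np.linalg.solve(L, a)))
    return theta_k + step * H_inv_g - correction * np.linalg.solve(L, a)

theta_kl = pcpo_update(H)           # KL divergence projection
theta_l2 = pcpo_update(np.eye(n))   # L2 norm projection
for name, theta in (("KL", theta_kl), ("L2", theta_l2)):
    # ~0 when the projection is active, negative when the step is already feasible
    print(name, "linearized constraint:", a @ (theta - theta_k) + b)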

Question

Is a linear approximation to the constraint set enough to ensure constraint satisfaction, given that the real constraint set may be non-convex?

Answer

If the step size \(\delta\) is small, then the linearization of the constraint set is accurate enough to approximate it locally.

Question

Can PCPO handle problems with multiple constraints? If so, how?

Answer

Yes. By sequentially projecting onto each of the constraint sets, the projection step can be extended using alternating projections.
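Below is a minimal sketch of the alternating-projection idea for two linearized cost constraints under the \(L2\) norm; the constraint vectors and offsets are illustrative assumptions, and a fixed number of sweeps is used for simplicity.

# Alternating L2 projections onto two linearized cost constraints
# {theta : a_i^T (theta - theta_k) + b_i <= 0}; all numbers are assumed.
import numpy as np

theta_k = np.zeros(2)
theta = np.array([0.3, 0.2])            # intermediate policy parameters
constraints = [
    (np.array([1.0, 0.5]), 0.05),       # (a_1, b_1)
    (np.array([-0.2, 1.0]), 0.02),      # (a_2, b_2)
]

for _ in range(50):                     # a few sweeps usually suffice
    for a, b in constraints:
        violation = a @ (theta - theta_k) + b
        theta = theta - max(0.0, violation / (a @ a)) * a

for a, b in constraints:
    print(a @ (theta - theta_k) + b)    # both <= 0 up to numerical error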


Analysis#

The update rule in Eq.8 shows that the difference between PCPO with the KL divergence and \(L2\) norm projections lies in the cost update direction, which leads to a difference in reward improvement. The two projections converge to different stationary points with different convergence rates, related to the smallest and largest singular values of the Fisher information matrix, as shown in Theorem 3. For this analysis, PCPO assumes that it minimizes the negative reward objective function \(f: \mathbb{R}^n \rightarrow \mathbb{R}\), and that \(f\) is \(L\)-smooth and twice continuously differentiable over the closed and convex constraint set \(\mathcal{C}\).

Theorem 3

Let \(\eta\doteq \sqrt{\frac{2\delta}{g^{T}\boldsymbol{H}^{-1}g}}\) in Eq.8, where \(\delta\) is the step size for reward improvement, \(g\) is the gradient of \(f\), and \(\boldsymbol{H}\) is the Fisher information matrix. Let \(\sigma_\mathrm{max}(\boldsymbol{H})\) be the largest singular value of \(\boldsymbol{H}\), and \(a\) be the gradient of the cost advantage function in Eq.8. Then PCPO with the KL divergence projection converges to a stationary point either inside the constraint set or on the boundary of the constraint set. In the latter case, the Lagrangian constraint \(g=-\alpha a,\ \alpha\geq0\) holds. Moreover, at step \(k+1\) the objective value satisfies

(9)#\[f(\theta_{k+1})\leq f(\theta_{k})+||\theta_{k+1}-\theta_{k}||^2_{-\frac{1}{\eta}\boldsymbol{H}+\frac{L}{2}\boldsymbol{I}}.\]

PCPO with the \(L2\) norm projection converges to a stationary point either inside the constraint set or on the boundary of the constraint set. In the latter case, the Lagrangian constraint \(\boldsymbol{H}^{-1}g=-\alpha a,\ \alpha\geq0\) holds. If \(\sigma_\mathrm{max}(\boldsymbol{H})\leq1\), then at step \(k+1\) the objective value satisfies

(10)#\[f(\theta_{k+1})\leq f(\theta_{k})+(\frac{L}{2}-\frac{1}{\eta})||\theta_{k+1}-\theta_{k}||^2_2.\]

Theorem 3 shows that at a stationary point, \(g\) points in the direction opposite to \(a\). Further, the improvement of the objective value depends on the singular values of the Fisher information matrix. Specifically, the objective under the KL divergence projection decreases when \(\frac{L\eta}{2}\boldsymbol{I}\prec\boldsymbol{H}\), implying that \(\sigma_\mathrm{min}(\boldsymbol{H})> \frac{L\eta}{2}\). The objective under the \(L2\) norm projection decreases when \(\eta<\frac{2}{L}\), implying that the condition number of \(\boldsymbol{H}\) is upper bounded: \(\frac{\sigma_\mathrm{max}(\boldsymbol{H})}{\sigma_\mathrm{min}(\boldsymbol{H})}<\frac{2||g||^2_2}{L^2\delta}\). Observing the singular values of the Fisher information matrix therefore allows us to adaptively choose the appropriate projection and achieve objective improvement. The supplemental material of the paper further uses an example to compare the optimization trajectories and stationary points of the KL divergence and \(L2\) norm projections.
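The adaptive choice described above can be checked numerically: given estimates of the smoothness constant \(L\), the step size \(\eta\), and the singular values of \(\boldsymbol{H}\), the snippet below tests the two sufficient conditions for objective decrease. All values are assumptions made purely for illustration.

# Check the sufficient conditions for objective decrease from Theorem 3.
# L, eta, and the singular values of H are assumed for illustration.
import numpy as np

L = 10.0                            # smoothness constant of f (assumed)
eta = 0.05                          # sqrt(2 * delta / (g^T H^{-1} g)) (assumed)
sigma = np.array([0.2, 0.9, 3.5])   # singular values of H (assumed)

kl_decreases = sigma.min() > L * eta / 2.0   # condition for the KL projection
l2_decreases = eta < 2.0 / L                 # condition for the L2 projection
print("KL projection decrease condition:", kl_decreases)   # False here
print("L2 projection decrease condition:", l2_decreases)   # True here -> prefer L2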


Code with OmniSafe#

Quick start#

Run PCPO in OmniSafe

Here are 3 ways to run PCPO in OmniSafe:

  • Run Agent from preset yaml file

  • Run Agent from custom config dict

  • Run Agent from custom terminal config

Run Agent from preset yaml file:

import omnisafe


env_id = 'SafetyPointGoal1-v0'

agent = omnisafe.Agent('PCPO', env_id)
agent.learn()

Run Agent from custom config dict:

import omnisafe


env_id = 'SafetyPointGoal1-v0'
custom_cfgs = {
    'train_cfgs': {
        'total_steps': 1024000,
        'vector_env_nums': 1,
        'parallel': 1,
    },
    'algo_cfgs': {
        'steps_per_epoch': 2048,
        'update_iters': 1,
    },
    'logger_cfgs': {
        'use_wandb': False,
    },
}

agent = omnisafe.Agent('PCPO', env_id, custom_cfgs=custom_cfgs)
agent.learn()

We use train_policy.py as the entry file. You can train an agent with PCPO simply by running train_policy.py with arguments specifying PCPO and the environment. For example, to run PCPO in SafetyPointGoal1-v0 with 1 torch thread and seed 0, you can use the following command:

cd examples
python train_policy.py --algo PCPO --env-id SafetyPointGoal1-v0 --parallel 1 --total-steps 1024000 --device cpu --vector-env-nums 1 --torch-threads 1

Architecture of functions#

  • PCPO.learn()

    • PCPO._env.rollout()

    • PCPO._update()

      • PCPO._buf.get()

      • PCPO._update_actor()

        • PCPO._fvp()

        • conjugate_gradients()

        • PCPO._cpo_search_step()

      • PCPO._update_cost_critic()

      • PCPO._update_reward_critic()


Documentation of basic functions#


Documentation of algorithm specific functions#

PCPO._update_actor()

Update the policy network by following these steps:

  1. Get the policy reward performance gradient g (flattened as a vector)

theta_old = get_flat_params_from(self._actor_critic.actor)
self._actor_critic.actor.zero_grad()
loss_reward, info = self._loss_pi(obs, act, logp, adv_r)
loss_reward_before = distributed.dist_avg(loss_reward).item()
p_dist = self._actor_critic.actor(obs)
  2. Get the policy cost performance gradient b (flattened as a vector)

self._actor_critic.zero_grad()
loss_cost = self._loss_pi_cost(obs, act, logp, adv_c)
loss_cost_before = distributed.dist_avg(loss_cost).item()

loss_cost.backward()
distributed.avg_grads(self._actor_critic.actor)

b_grads = get_flat_gradients_from(self._actor_critic.actor)
  3. Build the Hessian-vector product based on an approximation of the KL divergence, using conjugate_gradients

# p approximates H^{-1} b_grads, computed with conjugate gradient instead of an explicit inverse
p = conjugate_gradients(self._fvp, b_grads, self._cfgs.algo_cfgs.cg_iters)
# xHx and grads come from the reward-improvement computation earlier in _update_actor
q = xHx
r = grads.dot(p)    # g^T H^{-1} a
s = b_grads.dot(p)  # a^T H^{-1} a
  4. Determine the step direction and search for an acceptable step size via a backtracking line search (self._cpo_search_step())

step_direction, accept_step = self._cpo_search_step(
    step_direction=step_direction,
    grads=grads,
    p_dist=p_dist,
    obs=obs,
    act=act,
    logp=logp,
    adv_r=adv_r,
    adv_c=adv_c,
    loss_reward_before=loss_reward_before,
    loss_cost_before=loss_cost_before,
    total_steps=200,
    violation_c=ep_costs,
)
  5. Update the actor network parameters

theta_new = theta_old + step_direction
set_param_values_to_model(self._actor_critic.actor, theta_new)

Configs#

Train Configs

  • device (str): Device to use for training, options: cpu, cuda, cuda:0, etc.

  • torch_threads (int): Number of threads to use for PyTorch.

  • total_steps (int): Total number of steps to train the agent.

  • parallel (int): Number of parallel agents, similar to A3C.

  • vector_env_nums (int): Number of the vector environments.

Algorithms Configs

Note

The following configs are specific to the PCPO algorithm.

  • cg_damping (float): Damping coefficient for conjugate gradient.

  • cg_iters (int): Number of iterations for conjugate gradient.

  • fvp_sample_freq (int): Frequency of sampling for Fisher vector product.

  • steps_per_epoch (int): Number of steps to update the policy network.

  • update_iters (int): Number of iterations to update the policy network.

  • batch_size (int): Batch size for each iteration.

  • target_kl (float): Target KL divergence.

  • entropy_coef (float): Coefficient of entropy.

  • reward_normalize (bool): Whether to normalize the reward.

  • cost_normalize (bool): Whether to normalize the cost.

  • obs_normalize (bool): Whether to normalize the observation.

  • kl_early_stop (bool): Whether to stop the training when KL divergence is too large.

  • max_grad_norm (float): Maximum gradient norm.

  • use_max_grad_norm (bool): Whether to use maximum gradient norm.

  • use_critic_norm (bool): Whether to use critic norm.

  • critic_norm_coef (float): Coefficient of critic norm.

  • gamma (float): Discount factor.

  • cost_gamma (float): Cost discount factor.

  • lam (float): Lambda for GAE-Lambda.

  • lam_c (float): Lambda for cost GAE-Lambda.

  • adv_estimation_method (str): The method to estimate the advantage.

  • standardized_rew_adv (bool): Whether to use standardized reward advantage.

  • standardized_cost_adv (bool): Whether to use standardized cost advantage.

  • penalty_coef (float): Penalty coefficient for cost.

  • use_cost (bool): Whether to use cost.
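Building on the custom configuration example in the quick start, algorithm-specific options such as those listed above can be overridden through the algo_cfgs entry of custom_cfgs. The values below are illustrative, assuming, as in the quick-start example, that partial overrides are merged with the defaults.

import omnisafe


env_id = 'SafetyPointGoal1-v0'
custom_cfgs = {
    'algo_cfgs': {
        'target_kl': 0.01,    # trust-region size delta
        'cg_iters': 15,       # conjugate gradient iterations
        'cg_damping': 0.1,    # damping added to the Fisher-vector product
    },
}

agent = omnisafe.Agent('PCPO', env_id, custom_cfgs=custom_cfgs)
agent.learn()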

Model Configs

  • weight_initialization_mode (str): The type of weight initialization method.

  • actor_type (str): The type of actor; defaults to gaussian_learning.

  • linear_lr_decay (bool): Whether to use linear learning rate decay.

  • exploration_noise_anneal (bool): Whether to use exploration noise anneal.

  • std_range (list): The range of standard deviation.

Hint

actor (dictionary): parameters for the actor network

  • activations: tanh

  • hidden_sizes: [64, 64]

Hint

critic (dictionary): parameters for the critic network

  • activations: tanh

  • hidden_sizes: [64, 64]

Logger Configs

  • use_wandb (bool): Whether to use wandb to log the training process.

  • wandb_project (str): The name of wandb project.

  • use_tensorboard (bool): Whether to use tensorboard to log the training process.

  • log_dir (str): The directory to save the log files.

  • window_lens (int): The length of the window to calculate the average reward.

  • save_model_freq (int): The frequency to save the model.


References#

Appendix#


Proof of Theorem 2#

To prove the policy performance bound when the current policy is infeasible (constraint-violating), we first prove two lemmas of the KL divergence between \(\pi_{k}\) and \(\pi_{k+1}\) for the KL divergence projection. We then prove the main theorem for the worst-case performance degradation.

Lemma 1

If the current policy \(\pi_{k}\) satisfies the constraint, the constraint set is closed and convex, and the KL divergence constraint for the first step is \(\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+\frac{1}{2}} ||\pi_{k})[s]\big]\leq \delta\), where \(\delta\) is the step size in the reward improvement step, then under KL divergence projection, we have

(11)#\[\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+1} ||\pi_{k})[s]\big]\leq \delta.\]

Lemma 2

If the current policy \(\pi_{k}\) violates the constraint, the constraint set is closed and convex, and the KL divergence constraint for the first step is \(\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+\frac{1}{2}} ||\pi_{k})[s]\big]\leq \delta\), where \(\delta\) is the step size in the reward improvement step, then under the KL divergence projection, we have

(12)#\[\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+1} ||\pi_{k})[s]\big]\leq \delta+{b^+}^2\alpha_\mathrm{KL},\]

where \(\alpha_\mathrm{KL} \doteq \frac{1}{2a^T\boldsymbol{H}^{-1}a}\), \(a\) is the gradient of the cost advantage function, \(\boldsymbol{H}\) is the Hessian of the KL divergence constraint, and \(b^+\doteq\max(0,J^{C}(\pi_k)-d)\).

Proof of Lemma 1

By the Bregman divergence projection inequality, with \(\pi_{k}\) lying in the constraint set and \(\pi_{k+1}\) being the projection of \(\pi_{k+\frac{1}{2}}\) onto the constraint set, we have

(13)#\[\begin{split}&\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k} ||\pi_{k+\frac{1}{2}})[s]\big]\geq \mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k}||\pi_{k+1})[s]\big] \\ &+ \mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+1} ||\pi_{k+\frac{1}{2}})[s]\big]\\ &\Rightarrow\delta\geq \mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k} ||\pi_{k+\frac{1}{2}})[s]\big]\geq \mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k}||\pi_{k+1})[s]\big].\end{split}\]

The derivation uses the fact that the KL divergence is always non-negative. Since the KL divergence is asymptotically symmetric when updating the policy within a local neighborhood, we have

(14)#\[\delta\geq \mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+\frac{1}{2}} ||\pi_{k})[s]\big]\geq \mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+1}||\pi_{k})[s]\big].\]

Proof of Lemma 2

We define the sub-level set of cost constraint functions for the current infeasible policy \(\pi_k\):

(15)#\[\begin{split}L^{\pi_k}=\{\pi~|~J^{C}(\pi_{k})+ \mathbb{E}_{\substack{s\sim d^{\pi_{k}}\\ a\sim \pi}}[A_{\pi_k}^{C}(s,a)]\leq J^{C}(\pi_{k})\}.\end{split}\]

This implies that the current policy \(\pi_k\) lies in \(L^{\pi_k}\), and \(\pi_{k+\frac{1}{2}}\) is projected onto the constraint set: \(\{\pi~|~J^{C}(\pi_{k})+ \mathbb{E}_{\substack{s\sim d^{\pi_{k}}\\ a\sim \pi}}[A_{\pi_k}^{C}(s,a)]\leq d\}\). Next, we define the policy \(\pi_{k+1}^l\) as the projection of \(\pi_{k+\frac{1}{2}}\) onto \(L^{\pi_k}\).

For these three policies \(\pi_k\), \(\pi_{k+1}\), and \(\pi_{k+1}^l\), with \(\varphi(x)\doteq\sum_i x_i\log x_i\), we have

(16)#\[ \begin{align}\begin{aligned}\begin{split}\delta &\geq \mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+1}^l ||\pi_{k})[s]\big] \\&=\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+1} ||\pi_{k})[s]\big] -\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL} (\pi_{k+1} ||\pi_{k+1}^l)[s]\big]\\ &+\mathbb{E}_{s\sim d^{\pi_{k}}}\big[(\nabla\varphi(\pi_k)-\nabla\varphi(\pi_{k+1}^{l}))^T(\pi_{k+1}-\pi_{k+1}^l)[s]\big] \nonumber \\\end{split}\\\begin{split}\Rightarrow \mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL} (\pi_{k+1} ||\pi_{k})[s]\big]&\leq \delta + \mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL} (\pi_{k+1} ||\pi_{k+1}^l)[s]\big]\\ &- \mathbb{E}_{s\sim d^{\pi_{k}}}\big[(\nabla\varphi(\pi_k)-\nabla\varphi(\pi_{k+1}^{l}))^T(\pi_{k+1}-\pi_{k+1}^l)[s]\big].\end{split}\end{aligned}\end{align} \]

The inequality \(\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL} (\pi_{k+1}^l ||\pi_{k})[s]\big]\leq\delta\) follows from the fact that \(\pi_{k}\) and \(\pi_{k+1}^l\) lie in \(L^{\pi_k}\), together with Lemma 1.

If the constraint violation of the current policy \(\pi_k\) is small, i.e., \(b^+\) is small, then \(\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL} (\pi_{k+1} ||\pi_{k+1}^l)[s]\big]\) can be approximated by a second-order expansion. By the update rule in Eq.8, we have

(17)#\[\begin{split}\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+1} ||\pi_{k+1}^l)[s]\big] &\approx \frac{1}{2}(\theta_{k+1}-\theta_{k+1}^l)^{T}\boldsymbol{H}(\theta_{k+1}-\theta_{k+1}^l)\\ &=\frac{1}{2} \Big(\frac{b^+}{a^T\boldsymbol{H}^{-1}a}\boldsymbol{H}^{-1}a\Big)^T\boldsymbol{H}\Big(\frac{b^+}{a^T\boldsymbol{H}^{-1}a}\boldsymbol{H}^{-1}a\Big)\\ &=\frac{{b^+}^2}{2a^T\boldsymbol{H}^{-1}a}\\ &={b^+}^2\alpha_\mathrm{KL},\end{split}\]

where \(\alpha_\mathrm{KL} \doteq \frac{1}{2a^T\boldsymbol{H}^{-1}a}.\)

And since \(\delta\) is small, we have \(\nabla\varphi(\pi_k)-\nabla\varphi(\pi_{k+1}^{l})\approx \mathbf{0}\) for a given \(s\). Thus, the third term in Eq.16 can be eliminated.

Combining Eq.16 and Eq.17, we have \(\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+1}||\pi_{k})[s]\big]\leq \delta+{b^+}^2\alpha_\mathrm{KL}\).

Now we use Lemma 2 to prove Theorem 2. Following the same steps as in the proof of Theorem 1, we complete the proof.

Proof of Analytical Solution to PCPO#

Analytical Solution to PCPO

Consider the PCPO problem. In the first step, we optimize the reward:

(18)#\[\begin{split}\theta_{k+\frac{1}{2}} = &\underset{\theta}{\arg\min}\quad -g^{T}(\theta-\theta_{k}) \\ \text{s.t.}\quad&\frac{1}{2}(\theta-\theta_{k})^{T}\boldsymbol{H}(\theta-\theta_{k})\leq \delta,\end{split}\]

and in the second step, we project the policy onto the constraint set:

(19)#\[\begin{split}\theta_{k+1} = &\underset{\theta}{\arg\,min}\quad \frac{1}{2}(\theta-{\theta}_{k+\frac{1}{2}})^{T}\boldsymbol{L}(\theta-{\theta}_{k+\frac{1}{2}}) \\ \text{s.t.}\quad &a^{T}(\theta-\theta_{k})+b\leq 0,\end{split}\]

where \(g, a, \theta \in \mathbb{R}^n\), \(b, \delta\in \mathbb{R}\), \(\delta>0\), and \(\boldsymbol{H},\boldsymbol{L}\in \mathbb{R}^{n\times n}\), with \(\boldsymbol{L}=\boldsymbol{H}\) if using the KL divergence projection and \(\boldsymbol{L}=\boldsymbol{I}\) if using the \(L2\) norm projection. When there is at least one strictly feasible point, the optimal solution satisfies

(20)#\[\begin{split}\theta_{k+1}&=\theta_{k}+\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}\boldsymbol{H}^{-1}g\nonumber\\ &-\max(0,\frac{\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}a^{T}\boldsymbol{H}^{-1}g+b}{a^T\boldsymbol{L}^{-1}a})\boldsymbol{L}^{-1}a\end{split}\]

assuming that \(\boldsymbol{H}\) is invertible to get a unique solution.

Proof of Analytical Solution to PCPO

For the first problem, \(\boldsymbol{H}\) is the Fisher information matrix, which is automatically positive semi-definite, so this is a convex program with a quadratic inequality constraint. Hence, if the primal problem has a feasible point, Slater's condition is satisfied and strong duality holds. Let \(\theta^{*}\) and \(\lambda^*\) denote the solutions to the primal and dual problems, respectively. In addition, the primal objective function is continuously differentiable. Hence the Karush-Kuhn-Tucker (KKT) conditions are necessary and sufficient for the optimality of \(\theta^{*}\) and \(\lambda^*\). We now form the Lagrangian:

(21)#\[\mathcal{L}(\theta,\lambda)=-g^{T}(\theta-\theta_{k})+\lambda\Big(\frac{1}{2}(\theta-\theta_{k})^{T}\boldsymbol{H}(\theta-\theta_{k})- \delta\Big).\]

And we have the following KKT conditions:

(22)#\[\begin{split}-g + \lambda^*\boldsymbol{H}\theta^{*}-\lambda^*\boldsymbol{H}\theta_{k}=0~~~~&~~~\nabla_\theta\mathcal{L}(\theta^{*},\lambda^{*})=0 \\ \frac{1}{2}(\theta^{*}-\theta_{k})^{T}\boldsymbol{H}(\theta^{*}-\theta_{k})- \delta=0~~~~&~~~\nabla_\lambda\mathcal{L}(\theta^{*},\lambda^{*})=0 \\ \frac{1}{2}(\theta^{*}-\theta_{k})^{T}\boldsymbol{H}(\theta^{*}-\theta_{k})-\delta\leq0~~~~&~~~\text{primal constraints}\label{KKT_3}\\ \lambda^*\geq0~~~~&~~~\text{dual constraints}\\ \lambda^*\Big(\frac{1}{2}(\theta^{*}-\theta_{k})^{T}\boldsymbol{H}(\theta^{*}-\theta_{k})-\delta\Big)=0~~~~&~~~\text{complementary slackness}\end{split}\]

By the stationarity condition in Eq.22, we have \(\theta^{*}=\theta_{k}+\frac{1}{\lambda^*}\boldsymbol{H}^{-1}g\), and \(\lambda^*=\sqrt{\frac{g^T\boldsymbol{H}^{-1}g}{2\delta}}\). Hence we have the optimal solution:

(23)#\[\theta_{k+\frac{1}{2}}=\theta^{*}=\theta_{k}+\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}\boldsymbol{H}^{-1}g\]
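For completeness, \(\lambda^*\) follows from substituting \(\theta^{*}-\theta_{k}=\frac{1}{\lambda^*}\boldsymbol{H}^{-1}g\) into the complementary slackness condition with \(\lambda^*>0\):

\[\frac{1}{2}\Big(\frac{1}{\lambda^*}\boldsymbol{H}^{-1}g\Big)^{T}\boldsymbol{H}\Big(\frac{1}{\lambda^*}\boldsymbol{H}^{-1}g\Big)=\delta \Rightarrow \frac{g^{T}\boldsymbol{H}^{-1}g}{2{\lambda^*}^{2}}=\delta \Rightarrow \lambda^*=\sqrt{\frac{g^{T}\boldsymbol{H}^{-1}g}{2\delta}}.\]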

Following the same reasoning, we now form the Lagrangian of the second problem:

(24)#\[\mathcal{L}(\theta,\lambda)=\frac{1}{2}(\theta-{\theta}_{k+\frac{1}{2}})^{T}\boldsymbol{L}(\theta-{\theta}_{k+\frac{1}{2}})+\lambda(a^T(\theta-\theta_{k})+b)\]

And we have the following KKT conditions:

(25)#\[\begin{split}\boldsymbol{L}\theta^*-\boldsymbol{L}\theta_{k+\frac{1}{2}}+\lambda^*a=0~~~~&~~~\nabla_\theta\mathcal{L}(\theta^{*},\lambda^{*})=0 \\ a^T(\theta^*-\theta_{k})+b=0~~~~&~~~\nabla_\lambda\mathcal{L}(\theta^{*},\lambda^{*})=0 \\ a^T(\theta^*-\theta_{k})+b\leq0~~~~&~~~\text{primal constraints} \\ \lambda^*\geq0~~~~&~~~\text{dual constraints} \\ \lambda^*(a^T(\theta^*-\theta_{k})+b)=0~~~~&~~~\text{complementary slackness}\end{split}\]

By the stationarity condition in Eq.25, we have \(\theta^{*}=\theta_{k+\frac{1}{2}}-\lambda^*\boldsymbol{L}^{-1}a\). And by solving Eq.25, we have \(\lambda^*=\max\Big(0,\frac{a^T(\theta_{k+\frac{1}{2}}-\theta_{k})+b}{a^T\boldsymbol{L}^{-1}a}\Big)\). Hence we have the optimal solution:

(26)#\[\theta_{k+1}=\theta^{*}=\theta_{k+\frac{1}{2}}-\max\Big(0,\frac{a^T(\theta_{k+\frac{1}{2}}-\theta_{k})+b}{a^T\boldsymbol{L}^{-1}a}\Big)\boldsymbol{L}^{-1}a\]

Substituting the reward-improvement solution \(\theta_{k+\frac{1}{2}}=\theta_{k}+\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}\boldsymbol{H}^{-1}g\) into this expression, we have

(27)#\[\begin{split}\theta_{k+1}&=\theta_{k}+\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}\boldsymbol{H}^{-1}g\\ &-\max(0,\frac{\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}a^{T}\boldsymbol{H}^{-1}g+b}{a^T\boldsymbol{L}^{-1}a})\boldsymbol{L}^{-1}a\end{split}\]

Proof of Theorem 3#

For this analysis, we make the following assumptions: we minimize the negative reward objective function \(f: \mathbb{R}^n \rightarrow \mathbb{R}\) (following the convention in the optimization literature of minimizing the objective function), and \(f\) is \(L\)-smooth and twice continuously differentiable over the closed and convex constraint set \(\mathcal{C}\). We use the following Lemma 3 to characterize the projection and to prove Theorem 3.

Lemma 3

For any \(\theta\), \(\theta^{*}=\mathrm{Proj}^{\boldsymbol{L}}_{\mathcal{C}}(\theta)\) if and only if \((\theta-\theta^*)^T\boldsymbol{L}(\theta'-\theta^*)\leq0, \forall\theta'\in\mathcal{C}\), where \(\mathrm{Proj}^{\boldsymbol{L}}_{\mathcal{C}}(\theta)\doteq \underset{\theta' \in \mathrm{C}}{\arg\,min}||\theta-\theta'||^2_{\boldsymbol{L}}\) and \(\boldsymbol{L}=\boldsymbol{H}\) if using the KL divergence projection, and \(\boldsymbol{L}=\boldsymbol{I}\) if using the \(L2\) norm projection.

Based on Lemma 3, we can now prove Theorem 3.

Theorem 3 (Stationary Points of PCPO with the KL divergence and \(L2\) Norm Projections)

Let \(\eta\doteq \sqrt{\frac{2\delta}{g^{T}\boldsymbol{H}^{-1}g}}\) in Eq.8, where \(\delta\) is the step size for reward improvement, \(g\) is the gradient of \(f\), and \(\boldsymbol{H}\) is the Fisher information matrix. Let \(\sigma_\mathrm{max}(\boldsymbol{H})\) be the largest singular value of \(\boldsymbol{H}\), and \(a\) be the gradient of the cost advantage function in Eq.8. Then PCPO with the KL divergence projection converges to stationary points with \(g\in-a\) (i.e., the gradient of \(f\) belongs to the negative gradient of the cost advantage function). The objective value changes by

(31)#\[f(\theta_{k+1})\leq f(\theta_{k})+||\theta_{k+1}-\theta_{k}||^2_{-\frac{1}{\eta}\boldsymbol{H}+\frac{L}{2}\boldsymbol{I}}\]

PCPO with the \(L2\) norm projection converges to stationary points with \(\boldsymbol{H}^{-1}g\in-a\) (i.e., the product of the inverse of \(\boldsymbol{H}\) and gradient of \(f\) belongs to the negative gradient of the cost advantage function). If \(\sigma_\mathrm{max}(\boldsymbol{H})\leq1\), then the objective value changes by

(32)#\[f(\theta_{k+1})\leq f(\theta_{k})+(\frac{L}{2}-\frac{1}{\eta})||\theta_{k+1}-\theta_{k}||^2_2\]
Proof of Theorem 3

The proof of the theorem is based on working in a Hilbert space and the non-expansive property of the projection. We first prove stationary points for PCPO with the KL divergence and \(L2\) norm projections and then prove the change of the objective value.

At a stationary point \(\theta^{*}\), we have

(33)#\[\begin{split}\theta^{*}&=\theta^{*}-\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}\boldsymbol{H}^{-1}g -\max\left(0,\frac{\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}a^{T}\boldsymbol{H}^{-1}g+b}{a^T\boldsymbol{L}^{-1}a}\right)\boldsymbol{L}^{-1}a\\ &\Leftrightarrow \sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}\boldsymbol{H}^{-1}g = -\max\left(0,\frac{\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}a^{T}\boldsymbol{H}^{-1}g+b}{a^T\boldsymbol{L}^{-1}a}\right)\boldsymbol{L}^{-1}a\\ &\Leftrightarrow \boldsymbol{H}^{-1}g \in -\boldsymbol{L}^{-1}a.\end{split}\]

For the KL divergence projection (\(\boldsymbol{L}=\boldsymbol{H}\)), Eq.33 boils down to \(g\in-a\), and for the \(L2\) norm projection (\(\boldsymbol{L}=\boldsymbol{I}\)), Eq.33 is equivalent to \(\boldsymbol{H}^{-1}g\in-a\).

Now we prove the second part of the theorem. Based on Lemma 3, for the KL divergence projection, we have

(34)#\[\begin{split}\left(\theta_k-\theta_{k+1}\right)^T \boldsymbol{H}\left(\theta_k-\eta \boldsymbol{H}^{-1} \boldsymbol{g}-\theta_{k+1}\right) \leq 0 \\ \Rightarrow \boldsymbol{g}^T\left(\theta_{k+1}-\theta_k\right) \leq-\frac{1}{\eta}\left\|\theta_{k+1}-\theta_k\right\|_{\boldsymbol{H}}^2\end{split}\]

By Eq.34 and the \(L\)-smoothness of \(f\), we have

(35)#\[\begin{split}f\left(\theta_{k+1}\right) & \leq f\left(\theta_k\right)+\boldsymbol{g}^T\left(\theta_{k+1}-\theta_k\right)+\frac{L}{2}\left\|\theta_{k+1}-\theta_k\right\|_2^2 \\ & \leq f\left(\theta_k\right)-\frac{1}{\eta}\left\|\theta_{k+1}-\theta_k\right\|_{\boldsymbol{H}}^2+\frac{L}{2}\left\|\theta_{k+1}-\theta_k\right\|_2^2 \\ &=f\left(\theta_k\right)+\left(\theta_{k+1}-\theta_k\right)^T\left(-\frac{1}{\eta} \boldsymbol{H}+\frac{L}{2} \boldsymbol{I}\right)\left(\theta_{k+1}-\theta_k\right) \\ &=f\left(\theta_k\right)+\left\|\theta_{k+1}-\theta_k\right\|_{-\frac{1}{\eta} \boldsymbol{H}+\frac{L}{2} \boldsymbol{I}}^2\end{split}\]

For the \(L2\) norm projection, we have

(36)#\[\begin{split}(\theta_{k}-\theta_{k+1})^T(\theta_{k}-\eta\boldsymbol{H}^{-1}g-\theta_{k+1})\leq0\\ \Rightarrow g^T\boldsymbol{H}^{-1}(\theta_{k+1}-\theta_{k})\leq -\frac{1}{\eta}||\theta_{k+1}-\theta_{k}||^2_2\end{split}\]

By Eq.36, the \(L\)-smoothness of \(f\), and the assumption that \(\sigma_\mathrm{max}(\boldsymbol{H})\leq1\), we have

(37)#\[\begin{split}f(\theta_{k+1})&\leq f(\theta_{k})+g^T(\theta_{k+1}-\theta_{k})+\frac{L}{2}||\theta_{k+1}-\theta_{k}||^2_2 \nonumber\\ &\leq f(\theta_{k})+(\frac{L}{2}-\frac{1}{\eta})||\theta_{k+1}-\theta_{k}||^2_2.\nonumber\end{split}\]

To see why we need the assumption of \(\sigma_\mathrm{max}(\boldsymbol{H})\leq1\), we define \(\boldsymbol{H}=\boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{U}^T\) as the singular value decomposition of \(\boldsymbol{H}\) with \(u_i\) being the column vector of \(\boldsymbol{U}\). Then we have

(38)#\[\begin{split}g^T\boldsymbol{H}^{-1}(\theta_{k+1}-\theta_{k}) &=g^T\boldsymbol{U}\boldsymbol{\Sigma}^{-1}\boldsymbol{U}^T(\theta_{k+1}-\theta_{k}) \nonumber\\ &=g^T(\sum_{i}\frac{1}{\sigma_i(\boldsymbol{H})}u_iu_i^T)(\theta_{k+1}-\theta_{k})\nonumber\\ &=\sum_{i}\frac{1}{\sigma_i(\boldsymbol{H})}g^T(\theta_{k+1}-\theta_{k}).\nonumber\end{split}\]

If we want to have

(39)#\[g^T(\theta_{k+1}-\theta_{k})\leq g^T\boldsymbol{H}^{-1}(\theta_{k+1}-\theta_{k})\leq -\frac{1}{\eta}||\theta_{k+1}-\theta_{k}||^2_2,\]

then every singular value \(\sigma_i(\boldsymbol{H})\) of \(\boldsymbol{H}\) needs to be smaller than \(1\), and hence \(\sigma_\mathrm{max}(\boldsymbol{H})\leq1\), which justifies the assumption we use to prove the bound.

Hint

To make the objective value improve for PCPO with the KL divergence projection, the second term on the right-hand side of Eq.35 needs to be negative. Hence we have \(\frac{L\eta}{2}\boldsymbol{I}\prec\boldsymbol{H}\), implying that \(\sigma_\mathrm{min}(\boldsymbol{H})>\frac{L\eta}{2}\). And to make the objective value improve for PCPO with the \(L2\) norm projection, the second term on the right-hand side of Eq.37 needs to be negative. Hence we have \(\eta<\frac{2}{L}\), implying that

(40)#\[\begin{split}&\eta = \sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}<\frac{2}{L}\nonumber\\ \Rightarrow& \frac{2\delta}{g^T\boldsymbol{H}^{-1}g} < \frac{4}{L^2} \nonumber\\ \Rightarrow& \frac{g^{T}\boldsymbol{H}^{-1}g}{2\delta}>\frac{L^2}{4}\nonumber\\ \Rightarrow& \frac{L^2\delta}{2}<g^T\boldsymbol{H}^{-1}g\nonumber\\ &\leq||g||_2||\boldsymbol{H}^{-1}g||_2\nonumber\\ &\leq||g||_2||\boldsymbol{H}^{-1}||_2||g||_2\nonumber\\ &=\sigma_\mathrm{max}(\boldsymbol{H}^{-1})||g||^2_2\nonumber\\ &=\sigma_\mathrm{min}(\boldsymbol{H})||g||^2_2\nonumber\\ \Rightarrow&\sigma_\mathrm{min}(\boldsymbol{H})>\frac{L^2\delta}{2||g||^2_2}.\end{split}\]

By the definition of the condition number and Eq.40, we have \(\frac{\sigma_\mathrm{max}(\boldsymbol{H})}{\sigma_\mathrm{min}(\boldsymbol{H})}<\frac{2||g||^2_2}{L^2\delta}\), which is the upper bound on the condition number of \(\boldsymbol{H}\) stated in the Analysis section.