Projection-Based Constrained Policy Optimization#

Quick Facts#

  1. PCPO is an on-policy algorithm.

  2. PCPO can be used for environments with both discrete and continuous action spaces.

  3. PCPO is an improvement on CPO.

  4. The OmniSafe implementation of PCPO supports parallelization.

PCPO Theorem#

Background#

Projection-Based Constrained Policy Optimization (PCPO) is an iterative method for optimizing policies in a two-stage process: the first stage performs a local reward improvement update, while the second stage reconciles any constraint violation by projecting the policy back onto the constraint set.

PCPO builds on CPO (Constrained Policy Optimization). Like CPO, it provides a lower bound on reward improvement and an upper bound on constraint violation for each policy update. PCPO further characterizes its convergence under two different distance metrics for the projection: the \(L2\) norm and the KL divergence.

In short, PCPO is a CPO-based algorithm for learning control policies that optimize a reward function while satisfying constraints arising from safety, fairness, or other cost considerations.

Hint

If you are not yet familiar with CPO, we strongly recommend reading our CPO tutorial (Constrained Policy Optimization) before this article, since the PCPO ideas introduced in this section build directly on it.


Optimization Objective#

In the previous chapters, you learned that CPO solves the following optimization problem:

(1)#\[\begin{split}&\pi_{k+1} = \arg\max_{\pi \in \Pi_{\boldsymbol{\theta}}}J^R(\pi)\\ \text{s.t.}\quad&D(\pi,\pi_k)\le\delta\\ &J^{C_i}(\pi)\le d_i\quad i=1,\ldots,m\end{split}\]

where \(\Pi_{\theta}\subseteq\Pi\) denotes the set of parametrized policies with parameters \(\theta\), and \(D\) is some distance measure. In local policy search for CMDPs, we additionally require the policy iterates to be feasible, so instead of optimizing over \(\Pi_{\theta}\), PCPO optimizes over \(\Pi_{\theta}\cap\Pi_{C}\), where \(\Pi_{C}\) is the set of policies that satisfy the constraints. Next, we introduce how PCPO solves this problem. To get the most out of the next section, keep the following questions in mind while reading:

Questions

  • What is the two-stage policy update, and how does it work?

  • What performance bounds does PCPO provide, and how are they derived?

  • How does PCPO solve the optimization problem in practice?


Two-stage Policy Update#

PCPO performs the policy update in two stages. The first stage, the Reward Improvement Stage, maximizes reward using a trust region optimization method without constraints. This may produce an intermediate policy that does not satisfy the constraints. The second stage, the Projection Stage, reconciles the constraint violation (if any) by projecting the policy back onto the constraint set, i.e., choosing the policy in the constraint set that is closest to the intermediate policy. Next, we describe how PCPO carries out each stage.

Reward Improvement Stage

First, PCPO maximizes the expected reward advantage \(A^R_{\pi_k}(s,a)\) subject to a KL divergence constraint, which keeps the intermediate policy \(\pi_{k+\frac12}\) within a \(\delta\)-neighborhood of \(\pi_{k}\):

(2)#\[\begin{split}&\pi_{k+\frac12}=\underset{\pi}{\arg\max}\underset{s\sim d^{\pi_k}, a\sim\pi}{\mathbb{E}}[A^R_{\pi_k}(s,a)]\\ \text{s.t.}\quad &\underset{s\sim d^{\pi_k}}{\mathbb{E}}[D_{KL}(\pi||\pi_k)[s]]\le\delta\end{split}\]

This trust-region update rule is the one used by TRPO (see Trust Region Policy Optimization). It constrains the policy change to a divergence neighborhood of \(\pi_k\) and guarantees reward improvement.

Projection Stage

Second, PCPO projects the intermediate policy \(\pi_{k+\frac12}\) onto the constraint set by minimizing a distance measure \(D\) between \(\pi_{k+\frac12}\) and \(\pi\):

(3)#\[\begin{split}&\pi_{k+1}=\underset{\pi}{\arg\min}\quad D(\pi,\pi_{k+\frac12})\\ \text{s.t.}\quad &J^C\left(\pi_k\right)+\underset{\substack{s \sim d^{\pi_k} , a \sim \pi}}{\mathbb{E}}\left[A^C_{\pi_k}(s, a)\right] \leq d\end{split}\]

The Projection Stage ensures that the constraint-satisfying policy \(\pi_{k+1}\) stays close to \(\pi_{k+\frac{1}{2}}\). In other words, the Reward Improvement Stage moves the policy in the direction that increases reward while respecting the step size \(\delta\), and the Projection Stage then moves it in the direction that satisfies the constraint while changing it as little as possible under the distance measure \(D\).


Policy Performance Bounds#

In safety-critical applications, it is important to know how much the performance of a system can degrade when a learning algorithm is applied. PCPO provides a worst-case performance bound for each of two cases: the current policy satisfies the constraint, and the current policy violates it.

Worst-case Bound on Updating Constraint-satisfying Policies

Define \(\epsilon_{\pi_{k+1}}^{R}\doteq \max\limits_{s}\big|\mathbb{E}_{a\sim\pi_{k+1}}[A^{R}_{\pi_{k}}(s,a)]\big|\), and \(\epsilon_{\pi_{k+1}}^{C}\doteq \max\limits_{s}\big|\mathbb{E}_{a\sim\pi_{k+1}}[A^{C}_{\pi_{k}}(s,a)]\big|\). If the current policy \(\pi_k\) satisfies the constraint, then under KL divergence projection, the lower bound on reward improvement, and upper bound on constraint violation for each policy update are

(4)#\[\begin{split}J^{R}(\pi_{k+1})-J^{R}(\pi_{k})&\geq-\frac{\sqrt{2\delta}\gamma\epsilon_{\pi_{k+1}}^{R}}{(1-\gamma)^{2}}\\ J^{C}(\pi_{k+1})&\leq d+\frac{\sqrt{2\delta}\gamma\epsilon_{\pi_{k+1}}^{C}}{(1-\gamma)^{2}}\end{split}\]

where \(\delta\) is the step size in the reward improvement step.

Worst-case Bound on Updating Constraint-violating Policies

Define \(\epsilon_{\pi_{k+1}}^{R}\doteq \max\limits_{s}\big|\mathbb{E}_{a\sim\pi_{k+1}}[A^{R}_{\pi_{k}}(s,a)]\big|\), \(\epsilon_{\pi_{k+1}}^{C}\doteq \max\limits_{s}\big|\mathbb{E}_{a\sim\pi_{k+1}}[A^{C}_{\pi_{k}}(s,a)]\big|\), \(b^{+}\doteq \max(0,J^{C}(\pi_k)-d),\) and \(\alpha_{KL} \doteq \frac{1}{2a^T\boldsymbol{H}^{-1}a},\) where \(a\) is the gradient of the cost advantage function and \(\boldsymbol{H}\) is the Hessian of the KL divergence constraint. If the current policy \(\pi_k\) violates the constraint, then under KL divergence projection, the lower bound on reward improvement and the upper bound on constraint violation for each policy update are

(5)#\[\begin{split}J^{R}(\pi_{k+1})-J^{R}(\pi_{k})\geq&-\frac{\sqrt{2(\delta+{b^+}^{2}\alpha_\mathrm{KL})}\gamma\epsilon_{\pi_{k+1}}^{R}}{(1-\gamma)^{2}}\\ J^{C}(\pi_{k+1})\leq& ~d+\frac{\sqrt{2(\delta+{b^+}^{2}\alpha_\mathrm{KL})}\gamma\epsilon_{\pi_{k+1}}^{C}}{(1-\gamma)^{2}}\end{split}\]

where \(\delta\) is the step size in the reward improvement step.
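To get a feel for the size of these worst-case bounds, the sketch below evaluates Eq.4 (when \(b^+=0\)) and Eq.5 (otherwise) for given constants. The function name and all numeric values of \(\delta\), \(\gamma\), \(\epsilon^R\), \(\epsilon^C\), \(d\), \(b^+\) and \(\alpha_{KL}\) are illustrative assumptions, not values taken from the PCPO paper or OmniSafe.

```python
import math


def pcpo_worst_case_bounds(delta, gamma, eps_r, eps_c, d, b_plus=0.0, alpha_kl=0.0):
    """Evaluate the PCPO worst-case bounds (Eq.4 when b_plus == 0, Eq.5 otherwise).

    Returns (lower bound on J^R(pi_{k+1}) - J^R(pi_k), upper bound on J^C(pi_{k+1})).
    All inputs are assumed, illustrative constants.
    """
    # Effective KL radius: delta for a feasible pi_k, delta + b+^2 * alpha_KL otherwise.
    radius = delta + b_plus ** 2 * alpha_kl
    scale = math.sqrt(2.0 * radius) * gamma / (1.0 - gamma) ** 2
    reward_lower_bound = -scale * eps_r
    cost_upper_bound = d + scale * eps_c
    return reward_lower_bound, cost_upper_bound


# Feasible current policy (Eq.4), illustrative constants.
print(pcpo_worst_case_bounds(delta=0.02, gamma=0.99, eps_r=0.1, eps_c=0.05, d=25.0))
# Constraint-violating current policy (Eq.5): the bounds loosen as b+ grows.
print(pcpo_worst_case_bounds(delta=0.02, gamma=0.99, eps_r=0.1, eps_c=0.05, d=25.0,
                             b_plus=2.0, alpha_kl=0.01))
```

With \(\gamma=0.99\), the factor \(\gamma/(1-\gamma)^2\) is 9900, so these bounds are very loose; they are worst-case guarantees rather than predictions of typical behavior.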


Practical Implementation#

Implementation of Two-stage Update#

For a large neural network policy with hundreds of thousands of parameters, directly solving the PCPO updates in Eq.2 and Eq.3 is impractical because of the computational cost. PCPO therefore uses a small step size \(\delta\), under which the reward objective and the cost constraint can be approximated by first-order expansions, while the KL divergence constraint in the reward improvement step and the KL divergence measure in the projection step can be approximated by second-order expansions.

Reward Improvement Stage

Define:

\(g\doteq\nabla_\theta\underset{\substack{s\sim d^{\pi_k},\, a\sim \pi}}{\mathbb{E}}[A_{\pi_k}^{R}(s,a)]\), the gradient of the reward advantage function;

\(a\doteq\nabla_\theta\underset{\substack{s\sim d^{\pi_k},\, a\sim \pi}}{\mathbb{E}}[A_{\pi_k}^{C}(s,a)]\), the gradient of the cost advantage function;

\(\boldsymbol{H}_{i,j}\doteq \frac{\partial^2 \underset{s\sim d^{\pi_{k}}}{\mathbb{E}}\big[KL(\pi ||\pi_{k})[s]\big]}{\partial \theta_i\partial \theta_j}\), the Hessian of the KL divergence constraint (\(\boldsymbol{H}\) is also called the Fisher information matrix; it is symmetric positive semi-definite);

\(b\doteq J^{C}(\pi_k)-d\), the constraint violation of the policy \(\pi_{k}\);

and \(\theta\), the parameters of the policy. PCPO linearizes the objective function at \(\pi_k\) and uses the second-order approximation of the KL divergence constraint to obtain the following update:

(6)#\[\begin{split}&\theta_{k+\frac{1}{2}} = \underset{\theta}{\arg\max}\,g^{T}(\theta-\theta_k) \\ \text{s.t.}\quad &\frac{1}{2}(\theta-\theta_{k})^{T}\boldsymbol{H}(\theta-\theta_k)\le \delta .\end{split}\]

This is exactly the optimization problem solved in TRPO, so it can be solved with the method introduced in the TRPO tutorial.
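As a concrete illustration of Eq.6, the sketch below solves the linearized trust-region problem in closed form, \(\theta_{k+\frac12}=\theta_k+\sqrt{2\delta/(g^T\boldsymbol{H}^{-1}g)}\,\boldsymbol{H}^{-1}g\), for a tiny problem where \(\boldsymbol{H}\) can be formed explicitly. The function name and the values of g, H and \(\delta\) are illustrative assumptions; for real networks, \(\boldsymbol{H}^{-1}g\) is approximated with conjugate gradient instead of `np.linalg.solve`.

```python
import numpy as np


def reward_improvement_step(theta_k, g, H, delta):
    """Solve Eq.6: max g^T(theta - theta_k) s.t. 0.5 (theta - theta_k)^T H (theta - theta_k) <= delta."""
    h_inv_g = np.linalg.solve(H, g)                   # H^{-1} g (H assumed invertible)
    step_size = np.sqrt(2.0 * delta / (g @ h_inv_g))  # scales the step onto the trust-region boundary
    return theta_k + step_size * h_inv_g


theta_k = np.zeros(2)
g = np.array([1.0, 0.5])                  # illustrative reward gradient
H = np.array([[2.0, 0.0], [0.0, 1.0]])    # illustrative Fisher matrix
theta_half = reward_improvement_step(theta_k, g, H, delta=0.01)
step = theta_half - theta_k
print(theta_half, 0.5 * step @ H @ step)  # the quadratic KL constraint is active (equals delta)
```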

Projection Stage

PCPO considers two distance measures for the projection: the \(L2\) norm, when the projection is defined in parameter space, and the KL divergence, when the projection is defined in probability space. The KL divergence is approximated by its second-order expansion, and, as in the first stage, PCPO linearizes the cost constraint at \(\pi_{k}\). This gives the following update for the projection step:

(7)#\[\begin{split}&\theta_{k+1} =\underset{\theta}{\arg\min}\frac{1}{2}(\theta-{\theta}_{k+\frac{1}{2}})^{T}\boldsymbol{L}(\theta-{\theta}_{k+\frac{1}{2}})\\ \text{s.t.}\quad & a^{T}(\theta-\theta_{k})+b\leq 0\end{split}\]

where \(\boldsymbol{L}=\boldsymbol{I}\) for \(L2\) norm projection, and \(\boldsymbol{L}=\boldsymbol{H}\) for KL divergence projection.

PCPO solves Eq.6 and Eq.7 in closed form using convex programming; see the Appendix for details.

This yields the following update for each policy iteration:

(8)#\[\theta_{k+1}=\theta_{k}+\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}\boldsymbol{H}^{-1}g -\max\left(0,\frac{\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}a^{T}\boldsymbol{H}^{-1}g+b}{a^T\boldsymbol{L}^{-1}a}\right)\boldsymbol{L}^{-1}a\]
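The closed-form update in Eq.8 can be checked numerically. The sketch below implements it with NumPy for small, explicitly formed matrices and switches between the \(L2\) projection (\(\boldsymbol{L}=\boldsymbol{I}\)) and the KL projection (\(\boldsymbol{L}=\boldsymbol{H}\)). The function name `pcpo_update` and the values of g, a, b and H are illustrative assumptions, and `np.linalg.solve` stands in for the conjugate gradient solves used on real networks.

```python
import numpy as np


def pcpo_update(theta_k, g, a, b, H, delta, projection="kl"):
    """Closed-form PCPO update (Eq.8).

    g: reward gradient, a: cost gradient, b = J^C(pi_k) - d, H: Fisher matrix.
    projection: "kl" uses L = H, "l2" uses L = I.
    """
    h_inv_g = np.linalg.solve(H, g)
    eta = np.sqrt(2.0 * delta / (g @ h_inv_g))
    theta_half = theta_k + eta * h_inv_g                 # reward improvement stage
    L = H if projection == "kl" else np.eye(len(theta_k))
    l_inv_a = np.linalg.solve(L, a)
    violation = a @ (theta_half - theta_k) + b           # = eta * a^T H^{-1} g + b
    # Projection stage: only active when the linearized constraint is violated.
    return theta_half - max(0.0, violation / (a @ l_inv_a)) * l_inv_a


theta_k = np.zeros(2)
g = np.array([1.0, 0.5])                  # illustrative reward gradient
a = np.array([1.0, 1.0])                  # illustrative cost gradient
H = np.array([[2.0, 0.0], [0.0, 1.0]])    # illustrative Fisher matrix
for proj in ("kl", "l2"):
    theta_next = pcpo_update(theta_k, g, a, b=0.1, H=H, delta=0.01, projection=proj)
    print(proj, theta_next, a @ (theta_next - theta_k) + 0.1)  # last value is ~0
```

Both projections land on the boundary of the linearized constraint, but at different points: the KL projection moves along \(\boldsymbol{H}^{-1}a\), the \(L2\) projection along \(a\). When the intermediate policy already satisfies the linearized constraint, the `max(0, ...)` term is zero and the update reduces to the TRPO step.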

Hint

\(\boldsymbol{H}\) is assumed to be invertible. Since explicitly forming and inverting \(\boldsymbol{H}\) is impractical for large neural network policies, PCPO approximates products with \(\boldsymbol{H}^{-1}\) using the conjugate gradient method. (See the appendix for a discussion of the trade-off between the approximation error and the computational efficiency of the conjugate gradient method.)
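A minimal conjugate gradient solver is sketched below. It only needs a function returning the product \(\boldsymbol{H}v\) (a Fisher-vector product in practice), so \(\boldsymbol{H}\) never has to be formed or inverted. In this sketch the product is a plain matrix multiply on a small, made-up symmetric positive-definite matrix, and the name `conjugate_gradient` is our own; it is not OmniSafe's `conjugate_gradients`.

```python
import numpy as np


def conjugate_gradient(hvp, b, iters=10, tol=1e-10):
    """Approximately solve H x = b using only Hessian-vector products hvp(v) = H v."""
    x = np.zeros_like(b)
    r = b.copy()            # residual b - H x (x starts at 0)
    p = r.copy()            # current search direction
    rs_old = r @ r
    for _ in range(iters):
        hp = hvp(p)
        alpha = rs_old / (p @ hp)       # exact line search along p
        x += alpha * p
        r -= alpha * hp
        rs_new = r @ r
        if rs_new < tol:                # residual small enough: stop early
            break
        p = r + (rs_new / rs_old) * p   # next H-conjugate direction
        rs_old = rs_new
    return x


H = np.array([[4.0, 1.0], [1.0, 3.0]])  # illustrative SPD matrix
g = np.array([1.0, 2.0])
x = conjugate_gradient(lambda v: H @ v, g)
print(x, np.linalg.solve(H, g))          # CG matches the direct solve on this example
```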

Question

Is a linear approximation of the constraint set enough to ensure constraint satisfaction, given that the true constraint set may be non-convex?

Answer

If the step size \(\delta\) is small, the linearization of the constraint set is accurate enough to approximate it locally.

Question

Can PCPO handle multiple constraints, and if so, how?

Answer

Yes. The update in Eq.8 can be extended with alternating projections, i.e., by sequentially projecting onto each of the constraint sets.
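For several constraints \(a_i^T(\theta-\theta_k)+b_i\le 0\), the alternating-projection idea can be sketched by repeatedly applying the single-constraint projection of Eq.7 (here with \(\boldsymbol{L}=\boldsymbol{I}\)) to each half-space in turn. This is an illustrative sketch with made-up constraint data and hypothetical function names; it is not taken from OmniSafe.

```python
import numpy as np


def project_halfspace(theta, theta_k, a, b, L=None):
    """Project theta onto {x : a^T (x - theta_k) + b <= 0} in the L-norm (Eq.7); L defaults to I."""
    L = np.eye(len(theta)) if L is None else L
    l_inv_a = np.linalg.solve(L, a)
    violation = a @ (theta - theta_k) + b
    return theta - max(0.0, violation / (a @ l_inv_a)) * l_inv_a


def alternating_projections(theta_half, theta_k, constraints, sweeps=50):
    """Cycle through the (a_i, b_i) constraints, projecting onto each half-space in turn."""
    theta = theta_half.copy()
    for _ in range(sweeps):
        for a, b in constraints:
            theta = project_halfspace(theta, theta_k, a, b)
    return theta


theta_k = np.zeros(2)
theta_half = np.array([1.0, 1.0])                # illustrative intermediate policy parameters
constraints = [(np.array([1.0, 0.0]), -0.2),     # theta_0 <= 0.2
               (np.array([1.0, 1.0]), -0.5)]     # theta_0 + theta_1 <= 0.5
theta_next = alternating_projections(theta_half, theta_k, constraints)
print(theta_next, [a @ (theta_next - theta_k) + b for a, b in constraints])
```

Plain alternating projections converge to a point in the intersection of convex sets, but not necessarily the closest one: here the result is \((-0.15, 0.65)\), while the exact projection of \((1, 1)\) onto the intersection is \((0.2, 0.3)\). Dykstra's algorithm is the variant that recovers the exact projection.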


Analysis#

The update rule in Eq.8 shows that PCPO with the KL divergence projection and PCPO with the \(L2\) norm projection differ only in the cost update direction (\(\boldsymbol{H}^{-1}a\) versus \(a\)), which leads to different reward improvement. As Theorem 3 shows, the two projections converge to different stationary points, at rates related to the smallest and largest singular values of the Fisher information matrix. The analysis assumes that PCPO minimizes the negative reward objective function \(f: \mathbb{R}^n \rightarrow \mathbb{R}\), and that \(f\) is \(L\)-smooth and twice continuously differentiable over the closed and convex constraint set \(\mathcal{C}\).

Theorem 3

Let \(\eta\doteq \sqrt{\frac{2\delta}{g^{T}\boldsymbol{H}^{-1}g}}\) in Eq.8, where \(\delta\) is the step size for reward improvement, \(g\) is the gradient of \(f\), and \(\boldsymbol{H}\) is the Fisher information matrix. Let \(\sigma_\mathrm{max}(\boldsymbol{H})\) be the largest singular value of \(\boldsymbol{H}\), and let \(a\) be the gradient of the cost advantage function in Eq.8. Then PCPO with KL divergence projection converges to a stationary point either inside the constraint set or on its boundary. In the latter case, the Lagrangian condition \(g=-\alpha a, \alpha\geq0\) holds. Moreover, at step \(k+1\) the objective value satisfies

(9)#\[f(\theta_{k+1})\leq f(\theta_{k})+||\theta_{k+1}-\theta_{k}||^2_{-\frac{1}{\eta}\boldsymbol{H}+\frac{L}{2}\boldsymbol{I}}.\]

PCPO with \(L2\) norm projection converges to a stationary point either inside the constraint set or on its boundary. In the latter case, the Lagrangian condition \(\boldsymbol{H}^{-1}g=-\alpha a, \alpha\geq0\) holds. If \(\sigma_\mathrm{max}(\boldsymbol{H})\leq1\), then at step \(k+1\) the objective value satisfies

(10)#\[f(\theta_{k+1})\leq f(\theta_{k})+(\frac{L}{2}-\frac{1}{\eta})||\theta_{k+1}-\theta_{k}||^2_2.\]

Theorem 3 shows that at a stationary point on the boundary, \(g\) (for the KL divergence projection) or \(\boldsymbol{H}^{-1}g\) (for the \(L2\) norm projection) points in the direction opposite to \(a\). Further, the improvement of the objective value depends on the singular values of the Fisher information matrix. Specifically, the objective under KL divergence projection decreases when \(\frac{L\eta}{2}\boldsymbol{I}\prec\boldsymbol{H}\), i.e., when \(\sigma_\mathrm{min}(\boldsymbol{H})> \frac{L\eta}{2}\). The objective under \(L2\) norm projection decreases when \(\eta<\frac{2}{L}\), which implies that the condition number of \(\boldsymbol{H}\) is upper bounded: \(\frac{\sigma_\mathrm{max}(\boldsymbol{H})}{\sigma_\mathrm{min}(\boldsymbol{H})}<\frac{2||g||^2_2}{L^2\delta}\). Observing the singular values of the Fisher information matrix therefore allows one to adaptively choose the projection that guarantees objective improvement. The supplementary material of the PCPO paper further uses an example to compare the optimization trajectories and stationary points of the KL divergence and \(L2\) norm projections.
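The two sufficient conditions above can be checked numerically for a given Fisher matrix. The hypothetical helper below computes \(\eta\) and tests \(\sigma_\mathrm{min}(\boldsymbol{H})>\frac{L\eta}{2}\) (KL projection) and \(\eta<\frac{2}{L}\) together with the Theorem 3 assumption \(\sigma_\mathrm{max}(\boldsymbol{H})\leq1\) (\(L2\) projection). The smoothness constant \(L\) and the example H, g and \(\delta\) are assumed values, since \(L\) is rarely known in practice; for a symmetric positive semi-definite \(\boldsymbol{H}\), the singular values equal the eigenvalues.

```python
import numpy as np


def projection_decrease_conditions(H, g, delta, smoothness):
    """Check the Theorem 3 decrease conditions for the KL and L2 projections.

    smoothness is the L-smoothness constant of f (an assumed value here).
    Returns (eta, kl_decreases, l2_decreases).
    """
    eta = np.sqrt(2.0 * delta / (g @ np.linalg.solve(H, g)))
    sigma = np.linalg.svd(H, compute_uv=False)   # singular values, sorted descending
    kl_decreases = bool(sigma[-1] > smoothness * eta / 2.0)
    # Eq.10 assumes sigma_max(H) <= 1; under it, the L2 objective decreases when eta < 2 / L.
    l2_decreases = bool(sigma[0] <= 1.0 and eta < 2.0 / smoothness)
    return eta, kl_decreases, l2_decreases


H = np.diag([0.5, 0.25])   # illustrative Fisher matrix
g = np.array([1.0, 1.0])   # illustrative gradient
for smoothness in (1.0, 20.0):
    print(smoothness, projection_decrease_conditions(H, g, delta=0.01, smoothness=smoothness))
```

With the larger assumed smoothness constant, only the \(L2\) projection condition holds, which illustrates why observing the spectrum of \(\boldsymbol{H}\) can guide the choice of projection.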


Code with OmniSafe#

Quick start#

Run PCPO in OmniSafe

Here are 3 ways to run PCPO in OmniSafe:

  • Run Agent from preset yaml file

  • Run Agent from custom config dict

  • Run Agent from custom terminal config

# Run the agent from the preset yaml file
import omnisafe


env_id = 'SafetyPointGoal1-v0'

agent = omnisafe.Agent('PCPO', env_id)
agent.learn()

# Run the agent from a custom config dict; keys listed here override the preset values
import omnisafe


env_id = 'SafetyPointGoal1-v0'
custom_cfgs = {
    'train_cfgs': {
        'total_steps': 1024000,
        'vector_env_nums': 1,
        'parallel': 1,
    },
    'algo_cfgs': {
        'update_cycle': 2048,
        'update_iters': 1,
    },
    'logger_cfgs': {
        'use_wandb': False,
    },
}

agent = omnisafe.Agent('PCPO', env_id, custom_cfgs=custom_cfgs)
agent.learn()

We use train_policy.py as the entry point. You can train an agent with PCPO by passing the algorithm and environment arguments to train_policy.py. For example, the following command runs PCPO in SafetyPointGoal1-v0 for 1,024,000 steps on the CPU, with one process, one vectorized environment, and one torch thread:

cd examples
python train_policy.py --algo PCPO --env-id SafetyPointGoal1-v0 --parallel 1 --total-steps 1024000 --device cpu --vector-env-nums 1 --torch-threads 1

Architecture of functions#

  • pcpo.learn()

    • env.roll_out()

    • pcpo.update()

      • pcpo.buf.get()

      • pcpo.update_policy_net()

        • Fvp()

        • conjugate_gradients()

        • search_step_size()

      • pcpo.update_cost_net()

      • pcpo.update_value_net()

  • pcpo.log()


Documentation of basic functions#


Documentation of new functions#

pcpo.update_policy_net()

Update the policy network by following these steps:

  1. Get the reward performance gradient g of the policy (flattened into a vector).

self.pi_optimizer.zero_grad()
loss_pi, pi_info = self.compute_loss_pi(data=data)
loss_pi.backward()
g_flat = get_flat_gradients_from(self.ac.pi.net)
# loss_pi is a loss to minimize, so flip the sign to obtain the reward gradient g
g_flat *= -1

  2. Get the cost performance gradient b of the policy (flattened into a vector). Note that this code variable b is the cost gradient, called \(a\) in the theory section above, not the constraint violation.

self.pi_optimizer.zero_grad()
loss_cost, _ = self.compute_loss_cost_performance(data=data)
loss_cost.backward()
b_flat = get_flat_gradients_from(self.ac.pi.net)

  3. Use conjugate_gradients with the Fisher-vector product Fvp (a Hessian-vector product based on an approximation of the KL divergence) to compute the scalar terms of the update.

# p = H^{-1} b, approximated with conjugate gradient
p = conjugate_gradients(self.Fvp, b_flat, self.cg_iters)
q = xHx  # g^T H^{-1} g (assumed: x = H^{-1} g comes from an earlier CG solve, not shown)
r = g_flat.dot(p)  # g^T H^{-1} b
s = b_flat.dot(p)  # b^T H^{-1} b

  4. Determine the final step direction and the accepted step with adjust_cpo_step_direction().

# c is the constraint violation J^C(pi_k) - d
final_step_dir, accept_step = self.adjust_cpo_step_direction(
    step_dir,
    g_flat,
    c=c,
    optim_case=2,
    p_dist=p_dist,
    data=data,
    total_steps=20,
)

  5. Update the actor network parameters.

new_theta = theta_old + final_step_dir
set_param_values_to_model(self.ac.pi.net, new_theta)
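The quantities computed in the steps above map directly onto Eq.8 with the KL projection (\(\boldsymbol{L}=\boldsymbol{H}\)). Writing \(x=\boldsymbol{H}^{-1}g\), \(p=\boldsymbol{H}^{-1}b\) (with the code's b being the cost gradient), \(q=g^T\boldsymbol{H}^{-1}g\), \(r=g^T\boldsymbol{H}^{-1}b\), \(s=b^T\boldsymbol{H}^{-1}b\) and \(c=J^C(\pi_k)-d\), the step direction becomes \(\sqrt{2\delta/q}\,x-\max\left(0,(\sqrt{2\delta/q}\,r+c)/s\right)p\). The sketch below checks this rewriting on a toy problem; the function name and all values are our own assumptions, not OmniSafe code, and `np.linalg.solve` stands in for the conjugate gradient solves.

```python
import numpy as np


def pcpo_kl_step_from_scalars(x, p, q, r, s, c, target_kl):
    """PCPO step direction (Eq.8 with L = H) written with the code's scalars q, r, s, c."""
    eta = np.sqrt(2.0 * target_kl / q)
    return eta * x - max(0.0, (eta * r + c) / s) * p


H = np.array([[2.0, 0.3], [0.3, 1.0]])  # illustrative Fisher matrix
g = np.array([1.0, 0.5])                # illustrative reward gradient
b_cost = np.array([0.4, 1.0])           # illustrative cost gradient (b in the code, a in the theory)
c = 0.1                                 # illustrative constraint violation J^C(pi_k) - d
x = np.linalg.solve(H, g)               # conjugate gradient would approximate these two solves
p = np.linalg.solve(H, b_cost)
q, r, s = g @ x, g @ p, b_cost @ p
step_dir = pcpo_kl_step_from_scalars(x, p, q, r, s, c, target_kl=0.01)
print(step_dir, b_cost @ step_dir + c)  # the linearized constraint holds with equality here
```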

pcpo.adjust_cpo_step_direction()

PCPO performs a line search to ensure that the update improves the reward surrogate, keeps the cost within the allowed margin, and stays within the trust region, following these steps:

  1. Calculate the expected reward improvement.

expected_rew_improve = g_flat.dot(step_dir)

  2. Perform a line search to find a step that improves the surrogate without violating the trust region.

  • Search for an acceptable step over at most total_steps iterations.

for j in range(total_steps):
    new_theta = _theta_old + step_frac * step_dir
    set_param_values_to_model(self.ac.pi.net, new_theta)
    acceptance_step = j + 1
    # (the evaluation and acceptance checks below continue inside this loop)

  • At each iteration of the loop, compute the reward and cost surrogate losses and the KL divergence between the old and new policies.

with torch.no_grad():
    loss_pi_rew, _ = self.compute_loss_pi(data=data)
    loss_pi_cost, _ = self.compute_loss_cost_performance(data=data)
    q_dist = self.ac.pi.dist(data['obs'])
    torch_kl = torch.distributions.kl.kl_divergence(p_dist, q_dist).mean().item()
loss_rew_improve = self.loss_pi_before - loss_pi_rew.item()
cost_diff = loss_pi_cost.item() - self.loss_pi_cost_before

  • Accept the step only if the surrogate improves, the cost increase stays within the allowed margin, and the KL divergence is within the trust region.

# warn if either loss is non-finite
if not torch.isfinite(loss_pi_rew) or not torch.isfinite(loss_pi_cost):
    self.logger.log('WARNING: loss_pi not finite')
elif optim_case > 1 and loss_rew_improve < 0:
    self.logger.log('INFO: did not improve improve <0')
elif cost_diff > max(-c, 0):
    self.logger.log(f'INFO: no improve {cost_diff} > {max(-c, 0)}')
elif torch_kl > self.target_kl * 1.5:
    self.logger.log(f'INFO: violated KL constraint {torch_kl} at step {j + 1}.')
else:
    self.logger.log(f'Accept step at i={j + 1}')
    break

  3. Return the final step direction and the acceptance step.
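The acceptance logic above can be sketched as a self-contained backtracking line search. The function name, the toy surrogate functions, and the decay factor of 0.8 are illustrative assumptions; the 1.5 × target_kl tolerance and the max(-c, 0) cost margin mirror the checks in the snippet above. Unlike the snippet, this sketch always requires the reward surrogate not to get worse, regardless of the optimization case. It is not OmniSafe's implementation.

```python
import numpy as np


def backtracking_line_search(theta_old, step_dir, reward_loss, cost_loss, kl, c,
                             target_kl, total_steps=20, decay=0.8):
    """Shrink the step until it passes the reward, cost and KL checks.

    reward_loss, cost_loss: callables mapping parameters to surrogate losses.
    kl: callable (old_params, new_params) -> KL divergence estimate.
    c: constraint violation J^C(pi_k) - d.
    Returns (new parameters, accepted step index), or (theta_old, 0) if nothing passes.
    """
    loss_before = reward_loss(theta_old)
    cost_before = cost_loss(theta_old)
    step_frac = 1.0
    for j in range(total_steps):
        new_theta = theta_old + step_frac * step_dir
        rew_improve = loss_before - reward_loss(new_theta)
        cost_diff = cost_loss(new_theta) - cost_before
        accept = (np.isfinite(rew_improve) and np.isfinite(cost_diff)
                  and rew_improve >= 0                              # reward surrogate did not get worse
                  and cost_diff <= max(-c, 0.0)                     # cost increase within the slack
                  and kl(theta_old, new_theta) <= 1.5 * target_kl)  # inside the trust region
        if accept:
            return new_theta, j + 1
        step_frac *= decay          # backtrack: try a shorter step
    return theta_old, 0             # no acceptable step: keep the old parameters


def reward_loss(theta):   # toy surrogate loss, minimized at theta_0 = 1
    return float((theta[0] - 1.0) ** 2)


def cost_loss(theta):     # toy cost surrogate that grows with theta_1
    return float(theta[1])


def kl(old, new):         # toy quadratic stand-in for the KL divergence
    return float(0.5 * np.sum((new - old) ** 2))


theta, accepted = backtracking_line_search(np.zeros(2), np.array([2.0, 0.5]),
                                           reward_loss, cost_loss, kl,
                                           c=-0.2, target_kl=0.01)
print(theta, accepted)
```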


Parameters#

Specific Parameters

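  • target_kl (float): Maximum KL divergence allowed between the old and new policies, i.e., the trust-region size \(\delta\).

  • cg_damping (float): Damping coefficient used when building the Hessian-vector (Fisher-vector) product, which improves the numerical stability of conjugate gradient.

  • cg_iters (int): Number of conjugate gradient iterations to perform.

  • cost_limit (float): Cost threshold \(d\) in the constraint \(J^C(\pi)\le d\).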

Basic parameters

  • algo (string): The name of the algorithm corresponding to the current class. It is used for identification only and does not affect training.

  • actor (string): The type of actor network, discrete or continuous.

  • model_cfgs (dictionary): Actor and critic network configuration, read from the algo.yaml file. It describes the hidden layers, activation functions, shared_weights and weight_initialization_mode.

    • shared_weights (bool): Whether the actor and critic networks share weights.

    • weight_initialization_mode (string): The weight initialization method.

    • pi (dictionary): Parameters for the actor network pi.

      • hidden_sizes: [64, 64]

      • activations: tanh

    • val (dictionary): Parameters for the critic network v.

      • hidden_sizes: [64, 64]

      • activations: tanh

    Hint

      Name  Type       Description
      v     nn.Module  Gives the current estimate of V for states in s.
      pi    nn.Module  Deterministically or stochastically computes an action from the agent, conditioned on states in s.

  • env_id (string): The name of the environment to roll out.

  • seed (int): The random seed of the experiment.

  • parallel (int): The number of parallel processes used for training.

  • epochs (int): The number of epochs to roll out.

  • steps_per_epoch (int): The number of time steps per epoch.

  • pi_iters (int): The number of iterations for updating the actor network per mini-batch.

  • critic_iters (int): The number of iterations for updating the critic network per mini-batch.

Optional parameters

  • use_cost_critic (bool): Whether to use a cost value function.

  • linear_lr_decay (bool): Whether to use linear learning rate decay.

  • exploration_noise_anneal (bool): Whether to anneal the exploration noise.

  • reward_penalty (bool): Whether to use the cost to penalize the reward.

  • kl_early_stopping (bool): Whether to use KL early stopping.

  • max_grad_norm (float): Maximum gradient norm used for gradient clipping.

  • scale_rewards (bool): Whether to scale rewards.

Buffer parameters

Hint

  Name    Description
  Buffer  A buffer for storing trajectories experienced by an agent interacting with the environment, and using Generalized Advantage Estimation (GAE) for calculating the advantages of state-action pairs.

Warning

Buffer collects only raw data received from environment.

  • gamma (float): The discount factor \(\gamma\) used in GAE.

  • lam (float): The \(\lambda\) parameter for reward GAE.

  • adv_estimation_method (string): The method used to estimate advantages (for example, GAE).

  • standardized_reward (bool): Whether to standardize rewards.

  • standardized_cost (bool): Whether to standardize costs.
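For reference, the advantage computation that the buffer performs with gamma and lam can be sketched as standard GAE over a single trajectory segment. This is a plain NumPy sketch with a hypothetical function name and made-up rewards and values, not OmniSafe's buffer code; the same routine applies to costs, with cost values in place of reward values.

```python
import numpy as np


def gae_advantages(rewards, values, last_value, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation for one trajectory segment.

    delta_t = r_t + gamma * V(s_{t+1}) - V(s_t);  A_t = sum_l (gamma * lam)^l * delta_{t+l}.
    last_value bootstraps V(s_T) (use 0.0 for a terminal state).
    """
    values_ext = np.append(values, last_value)
    deltas = rewards + gamma * values_ext[1:] - values_ext[:-1]
    advantages = np.zeros_like(deltas)
    running = 0.0
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    returns = advantages + values_ext[:-1]   # regression targets for the value (or cost) critic
    return advantages, returns


adv, ret = gae_advantages(np.array([1.0, 0.0, 1.0]), np.array([0.5, 0.4, 0.6]), last_value=0.0)
print(adv, ret)
```

Setting lam to 1 recovers Monte Carlo returns minus the value baseline, while lam equal to 0 gives one-step TD errors; intermediate values trade bias against variance.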


References#

Appendix#


Proof of Theorem 2#

To prove the policy performance bound when the current policy is infeasible (constraint-violating), we first prove two lemmas about the KL divergence between \(\pi_{k}\) and \(\pi_{k+1}\) under the KL divergence projection. We then prove the main theorem for the worst-case performance degradation.

Lemma 1

If the current policy \(\pi_{k}\) satisfies the constraint, the constraint set is closed and convex, and the KL divergence constraint for the first step is \(\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+\frac{1}{2}} ||\pi_{k})[s]\big]\leq \delta\), where \(\delta\) is the step size in the reward improvement step, then under KL divergence projection we have

(11)#\[\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+1} ||\pi_{k})[s]\big]\leq \delta.\]

Lemma 2

If the current policy \(\pi_{k}\) violates the constraint, the constraint set is closed and convex, and the KL divergence constraint for the first step is \(\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+\frac{1}{2}} ||\pi_{k})[s]\big]\leq \delta\), where \(\delta\) is the step size in the reward improvement step, then under KL divergence projection we have

(12)#\[\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+1} ||\pi_{k})[s]\big]\leq \delta+{b^+}^2\alpha_\mathrm{KL},\]

where \(\alpha_\mathrm{KL} \doteq \frac{1}{2a^T\boldsymbol{H}^{-1}a}\), \(a\) is the gradient of the cost advantage function, \(\boldsymbol{H}\) is the Hessian of the KL divergence constraint, and \(b^+\doteq\max(0,J^{C}(\pi_k)-d)\).

Proof of Lemma 1

By the Bregman divergence projection inequality, since \(\pi_{k}\) lies in the constraint set and \(\pi_{k+1}\) is the projection of \(\pi_{k+\frac{1}{2}}\) onto the constraint set, we have

(13)#\[\begin{split}&\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k} ||\pi_{k+\frac{1}{2}})[s]\big]\geq \mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k}||\pi_{k+1})[s]\big] \\ &+ \mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+1} ||\pi_{k+\frac{1}{2}})[s]\big]\\ &\Rightarrow\delta\geq \mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k} ||\pi_{k+\frac{1}{2}})[s]\big]\geq \mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k}||\pi_{k+1})[s]\big].\end{split}\]

The derivation uses the fact that the KL divergence is always nonnegative. Moreover, the KL divergence is asymptotically symmetric when the policy is updated within a local neighborhood. Thus, we have

(14)#\[\delta\geq \mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+\frac{1}{2}} ||\pi_{k})[s]\big]\geq \mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+1}||\pi_{k})[s]\big].\]

Proof of Lemma 2

We define the sub-level set of cost constraint function for the current infeasible policy \(\pi_k\):

(15)#\[\begin{split}L^{\pi_k}=\{\pi~|~J^{C}(\pi_{k})+ \mathbb{E}_{\substack{s\sim d^{\pi_{k}}\\ a\sim \pi}}[A_{\pi_k}^{C}(s,a)]\leq J^{C}(\pi_{k})\}.\end{split}\]

By definition, the current policy \(\pi_k\) lies in \(L^{\pi_k}\), while \(\pi_{k+\frac{1}{2}}\) is projected onto the constraint set \(\{\pi~|~J^{C}(\pi_{k})+ \mathbb{E}_{\substack{s\sim d^{\pi_{k}}\\ a\sim \pi}}[A_{\pi_k}^{C}(s,a)]\leq d\}\). Next, we define the policy \(\pi_{k+1}^l\) as the projection of \(\pi_{k+\frac{1}{2}}\) onto \(L^{\pi_k}\).

For these three policies \(\pi_k, \pi_{k+1}\) and \(\pi_{k+1}^l\), with \(\varphi(x)\doteq\sum_i x_i\log x_i\), we have

(16)#\[\begin{split}\delta &\geq \mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+1}^l ||\pi_{k})[s]\big] \\ &=\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+1} ||\pi_{k})[s]\big] -\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL} (\pi_{k+1} ||\pi_{k+1}^l)[s]\big]\\ &\quad+\mathbb{E}_{s\sim d^{\pi_{k}}}\big[(\nabla\varphi(\pi_k)-\nabla\varphi(\pi_{k+1}^{l}))^T(\pi_{k+1}-\pi_{k+1}^l)[s]\big] \\ \Rightarrow \mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL} (\pi_{k+1} ||\pi_{k})[s]\big]&\leq \delta + \mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL} (\pi_{k+1} ||\pi_{k+1}^l)[s]\big]\\ &\quad- \mathbb{E}_{s\sim d^{\pi_{k}}}\big[(\nabla\varphi(\pi_k)-\nabla\varphi(\pi_{k+1}^{l}))^T(\pi_{k+1}-\pi_{k+1}^l)[s]\big].\end{split}\]

The inequality \(\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL} (\pi_{k+1}^l ||\pi_{k})[s]\big]\leq\delta\) follows from Lemma 1, since \(\pi_{k}\) lies in \(L^{\pi_k}\) and \(\pi_{k+1}^l\) is the projection of \(\pi_{k+\frac12}\) onto \(L^{\pi_k}\).

If the constraint violation of the current policy \(\pi_k\) is small, i.e., \(b^+\) is small, then \(\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL} (\pi_{k+1} ||\pi_{k+1}^l)[s]\big]\) can be approximated by its second-order expansion. By the update rule in Eq.8, we have

(17)#\[\begin{split}\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+1} ||\pi_{k+1}^l)[s]\big] &\approx \frac{1}{2}(\theta_{k+1}-\theta_{k+1}^l)^{T}\boldsymbol{H}(\theta_{k+1}-\theta_{k+1}^l)\\ &=\frac{1}{2} \Big(\frac{b^+}{a^T\boldsymbol{H}^{-1}a}\boldsymbol{H}^{-1}a\Big)^T\boldsymbol{H}\Big(\frac{b^+}{a^T\boldsymbol{H}^{-1}a}\boldsymbol{H}^{-1}a\Big)\\ &=\frac{{b^+}^2}{2a^T\boldsymbol{H}^{-1}a}\\ &={b^+}^2\alpha_\mathrm{KL},\end{split}\]

where \(\alpha_\mathrm{KL} \doteq \frac{1}{2a^T\boldsymbol{H}^{-1}a}.\)

Since \(\delta\) is small, we also have \(\nabla\varphi(\pi_k)-\nabla\varphi(\pi_{k+1}^{l})\approx \mathbf{0}\) given \(s\). Thus, the third term in Eq.16 can be eliminated.

Combining Eq.16 and Eq.17, we have \(\mathbb{E}_{s\sim d^{\pi_{k}}}\big[\mathrm{KL}(\pi_{k+1}||\pi_{k})[s]\big]\leq \delta+{b^+}^2\alpha_\mathrm{KL}\).

Theorem 2 now follows from Lemma 2 by the same argument as the proof of Theorem 1, with \(\delta\) replaced by \(\delta+{b^+}^2\alpha_\mathrm{KL}\).
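The algebra in Eq.17 can be checked numerically: for the displacement \(\theta_{k+1}-\theta_{k+1}^l=\frac{b^+}{a^T\boldsymbol{H}^{-1}a}\boldsymbol{H}^{-1}a\), the quadratic form \(\frac12(\theta_{k+1}-\theta_{k+1}^l)^T\boldsymbol{H}(\theta_{k+1}-\theta_{k+1}^l)\) equals \({b^+}^2\alpha_\mathrm{KL}\). The matrix \(\boldsymbol{H}\), the vector \(a\) and the value of \(b^+\) below are arbitrary illustrative assumptions.

```python
import numpy as np

H = np.array([[2.0, 0.3], [0.3, 1.0]])   # illustrative Fisher matrix (symmetric positive definite)
a = np.array([0.4, 1.0])                 # illustrative cost-advantage gradient
b_plus = 0.25                            # illustrative constraint violation b^+

h_inv_a = np.linalg.solve(H, a)
displacement = (b_plus / (a @ h_inv_a)) * h_inv_a   # theta_{k+1} - theta_{k+1}^l (up to sign)
quadratic = 0.5 * displacement @ H @ displacement    # second-order KL approximation
alpha_kl = 1.0 / (2.0 * (a @ h_inv_a))
print(quadratic, b_plus ** 2 * alpha_kl)             # the two numbers agree
```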

Proof of Analytical Solution to PCPO#

Analytical Solution to PCPO

Consider the PCPO problem. In the first step, we optimize the reward:

(18)#\[\begin{split}\theta_{k+\frac{1}{2}} = &\underset{\theta}{\arg\,max}\quad g^{T}(\theta-\theta_{k}) \\ \text{s.t.}\quad&\frac{1}{2}(\theta-\theta_{k})^{T}\boldsymbol{H}(\theta-\theta_{k})\leq \delta,\end{split}\]

and in the second step, we project the policy onto the constraint set:

(19)#\[\begin{split}\theta_{k+1} = &\underset{\theta}{\arg\,min}\quad \frac{1}{2}(\theta-{\theta}_{k+\frac{1}{2}})^{T}\boldsymbol{L}(\theta-{\theta}_{k+\frac{1}{2}}) \\ \text{s.t.}\quad &a^{T}(\theta-\theta_{k})+b\leq 0,\end{split}\]

where \(g, a, \theta \in \mathbb{R}^n\), \(b, \delta\in \mathbb{R}\), \(\delta>0\), and \(\boldsymbol{H},\boldsymbol{L}\in \mathbb{R}^{n\times n}\), with \(\boldsymbol{L}=\boldsymbol{H}\) if using the KL divergence projection and \(\boldsymbol{L}=\boldsymbol{I}\) if using the \(L2\) norm projection. When there is at least one strictly feasible point, the optimal solution satisfies

(20)#\[\begin{split}\theta_{k+1}&=\theta_{k}+\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}\boldsymbol{H}^{-1}g\\ &-\max\left(0,\frac{\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}a^{T}\boldsymbol{H}^{-1}g+b}{a^T\boldsymbol{L}^{-1}a}\right)\boldsymbol{L}^{-1}a\end{split}\]

assuming that \(\boldsymbol{H}\) is invertible to get a unique solution.

Proof of Analytical Solution to PCPO

For the first problem, since \(\boldsymbol{H}\) is the Fisher information matrix, it is positive semi-definite, so the problem is a convex program with a quadratic inequality constraint. Because \(\delta>0\), \(\theta=\theta_k\) is strictly feasible, so Slater's condition is satisfied and strong duality holds. Let \(\theta^{*}\) and \(\lambda^*\) denote the solutions to the primal and dual problems, respectively. In addition, the primal objective function is continuously differentiable. Hence the Karush-Kuhn-Tucker (KKT) conditions are necessary and sufficient for the optimality of \(\theta^{*}\) and \(\lambda^*\). We now form the Lagrangian:

(21)#\[\mathcal{L}(\theta,\lambda)=-g^{T}(\theta-\theta_{k})+\lambda\Big(\frac{1}{2}(\theta-\theta_{k})^{T}\boldsymbol{H}(\theta-\theta_{k})- \delta\Big).\]

And we have the following KKT conditions:

(22)#\[\begin{split}-g + \lambda^*\boldsymbol{H}\theta^{*}-\lambda^*\boldsymbol{H}\theta_{k}=0~~~~&~~~\nabla_\theta\mathcal{L}(\theta^{*},\lambda^{*})=0 \\ \frac{1}{2}(\theta^{*}-\theta_{k})^{T}\boldsymbol{H}(\theta^{*}-\theta_{k})- \delta=0~~~~&~~~\nabla_\lambda\mathcal{L}(\theta^{*},\lambda^{*})=0 \\ \frac{1}{2}(\theta^{*}-\theta_{k})^{T}\boldsymbol{H}(\theta^{*}-\theta_{k})-\delta\leq0~~~~&~~~\text{primal constraints}\\ \lambda^*\geq0~~~~&~~~\text{dual constraints}\\ \lambda^*\Big(\frac{1}{2}(\theta^{*}-\theta_{k})^{T}\boldsymbol{H}(\theta^{*}-\theta_{k})-\delta\Big)=0~~~~&~~~\text{complementary slackness}\end{split}\]

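From the first condition in Eq.22, we have \(\theta^{*}=\theta_{k}+\frac{1}{\lambda^*}\boldsymbol{H}^{-1}g\). Substituting this into the active trust-region constraint \(\frac{1}{2}(\theta^{*}-\theta_{k})^{T}\boldsymbol{H}(\theta^{*}-\theta_{k})=\delta\) gives \(\frac{g^T\boldsymbol{H}^{-1}g}{2{\lambda^*}^2}=\delta\), i.e., \(\lambda^*=\sqrt{\frac{g^T\boldsymbol{H}^{-1}g}{2\delta}}\). Hence we have our optimal solution: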

(23)#\[\theta_{k+\frac{1}{2}}=\theta^{*}=\theta_{k}+\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}\boldsymbol{H}^{-1}g\]

Following the same reasoning, we now form the Lagrangian of the second problem:

(24)#\[\mathcal{L}(\theta,\lambda)=\frac{1}{2}(\theta-{\theta}_{k+\frac{1}{2}})^{T}\boldsymbol{L}(\theta-{\theta}_{k+\frac{1}{2}})+\lambda(a^T(\theta-\theta_{k})+b)\]

And we have the following KKT conditions:

(25)#\[\begin{split}\boldsymbol{L}\theta^*-\boldsymbol{L}\theta_{k+\frac{1}{2}}+\lambda^*a=0~~~~&~~~\nabla_\theta\mathcal{L}(\theta^{*},\lambda^{*})=0 \\ a^T(\theta^*-\theta_{k})+b=0~~~~&~~~\nabla_\lambda\mathcal{L}(\theta^{*},\lambda^{*})=0 \\ a^T(\theta^*-\theta_{k})+b\leq0~~~~&~~~\text{primal constraints} \\ \lambda^*\geq0~~~~&~~~\text{dual constraints} \\ \lambda^*(a^T(\theta^*-\theta_{k})+b)=0~~~~&~~~\text{complementary slackness}\end{split}\]

From the first condition in Eq.25, we have \(\theta^{*}=\theta_{k+\frac{1}{2}}-\lambda^*\boldsymbol{L}^{-1}a\). Substituting this into the constraint and using complementary slackness with \(\lambda^*\geq0\), we obtain \(\lambda^*=\max\left(0,\frac{a^T(\theta_{k+\frac{1}{2}}-\theta_{k})+b}{a^T\boldsymbol{L}^{-1}a}\right)\). Hence we have our optimal solution:

(26)#\[\theta_{k+1}=\theta^{*}=\theta_{k+\frac{1}{2}}-\max\left(0,\frac{a^T(\theta_{k+\frac{1}{2}}-\theta_{k})+b}{a^T\boldsymbol{L}^{-1}a}\right)\boldsymbol{L}^{-1}a\]

Substituting Eq.23 for \(\theta_{k+\frac{1}{2}}\) into Eq.26, we have

(27)#\[\begin{split}\theta_{k+1}&=\theta_{k}+\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}\boldsymbol{H}^{-1}g\\ &-\max(0,\frac{\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}a^{T}\boldsymbol{H}^{-1}g+b}{a^T\boldsymbol{L}^{-1}a})\boldsymbol{L}^{-1}a\end{split}\]

Proof of Theorem 3#

For our analysis, we make the following assumptions: we minimize the negative reward objective function \(f: \mathbb{R}^n \rightarrow \mathbb{R}\) (following the convention in the literature of minimizing the objective function), and \(f\) is \(L\)-smooth and twice continuously differentiable over the closed and convex constraint set \(\mathcal{C}\). The following Lemma 3 characterizes the projection and is used in the proof of Theorem 3.

Lemma 3

For any \(\theta\), \(\theta^{*}=\mathrm{Proj}^{\boldsymbol{L}}_{\mathcal{C}}(\theta)\) if and only if \((\theta-\theta^*)^T\boldsymbol{L}(\theta'-\theta^*)\leq0, \forall\theta'\in\mathcal{C}\), where \(\mathrm{Proj}^{\boldsymbol{L}}_{\mathcal{C}}(\theta)\doteq \underset{\theta' \in \mathcal{C}}{\arg\min}\,||\theta-\theta'||^2_{\boldsymbol{L}}\), with \(\boldsymbol{L}=\boldsymbol{H}\) for the KL divergence projection and \(\boldsymbol{L}=\boldsymbol{I}\) for the \(L2\) norm projection.
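For the half-space that PCPO projects onto, \(\mathcal{C}=\{\theta: a^T(\theta-\theta_k)+b\leq0\}\), the projection has the closed form of Eq.26, so Lemma 3 can be checked numerically. The sketch below writes the half-space as \(a^T\theta\leq c\), uses a random symmetric positive-definite `L`, and samples feasible points \(\theta'\); the helper name `project_halfspace` and the data are hypothetical.

```python
import numpy as np


def project_halfspace(theta, a, c, L):
    """Proj^L_C(theta) for the half-space C = {x : a^T x <= c} in the norm ||.||_L."""
    Linv_a = np.linalg.solve(L, a)
    lam = max(0.0, (a @ theta - c) / (a @ Linv_a))
    return theta - lam * Linv_a


rng = np.random.default_rng(2)
n = 3
M = rng.standard_normal((n, n))
L = M @ M.T + np.eye(n)                 # SPD metric, e.g. the Fisher matrix H
a = np.array([1.0, -1.0, 2.0])
c = 0.5
theta = np.array([3.0, -2.0, 1.0])      # a^T theta = 7 > c, so theta lies outside C
theta_star = project_halfspace(theta, a, c, L)

# Lemma 3: (theta - theta*)^T L (theta' - theta*) <= 0 for every theta' in C
worst = -np.inf
for _ in range(1000):
    theta_prime = 5.0 * rng.standard_normal(n)
    if a @ theta_prime <= c:            # keep only feasible theta'
        worst = max(worst, (theta - theta_star) @ L @ (theta_prime - theta_star))
print(worst)                            # largest inner product found; never positive
```

Algebraically, \(\theta-\theta^*=\lambda\boldsymbol{L}^{-1}a\), so the inner product equals \(\lambda(a^T\theta'-c)\), which is non-positive for every feasible \(\theta'\).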

Lemma 3 is the key ingredient in the proof of the following Theorem 3.

Theorem 3 (Stationary Points of PCPO with the KL divergence and \(L2\) Norm Projections)

Let \(\eta\doteq \sqrt{\frac{2\delta}{g^{T}\boldsymbol{H}^{-1}g}}\) in Eq.5, where \(\delta\) is the step size for reward improvement, \(g\) is the gradient of \(f\), and \(\boldsymbol{H}\) is the Fisher information matrix. Let \(\sigma_\mathrm{max}(\boldsymbol{H})\) be the largest singular value of \(\boldsymbol{H}\), and let \(a\) be the gradient of the cost advantage function in Eq.5. Then PCPO with the KL divergence projection converges to stationary points with \(g\in-a\) (i.e., the gradient of \(f\) is a non-negative multiple of the negative gradient of the cost advantage function). The objective value changes by

(31)#\[f(\theta_{k+1})\leq f(\theta_{k})+||\theta_{k+1}-\theta_{k}||^2_{-\frac{1}{\eta}\boldsymbol{H}+\frac{L}{2}\boldsymbol{I}}\]

PCPO with the \(L2\) norm projection converges to stationary points with \(\boldsymbol{H}^{-1}g\in-a\) (i.e., the product of \(\boldsymbol{H}^{-1}\) and the gradient of \(f\) is a non-negative multiple of the negative gradient of the cost advantage function). If \(\sigma_\mathrm{max}(\boldsymbol{H})\leq1\), then the objective value changes by

(32)#\[f(\theta_{k+1})\leq f(\theta_{k})+(\frac{L}{2}-\frac{1}{\eta})||\theta_{k+1}-\theta_{k}||^2_2\]
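Whether Eq.31 and Eq.32 guarantee a decrease of \(f\) depends only on the sign of their quadratic terms. A small helper can test both conditions, assuming \(g\), \(\boldsymbol{H}\), \(\delta\), and the smoothness constant \(L\) are known. The constant is named `L_smooth` to avoid a clash with the projection matrix \(\boldsymbol{L}\); the function name and the numbers are illustrative.

```python
import numpy as np


def improvement_conditions(g, H, delta, L_smooth):
    """Return eta and whether the quadratic terms of Eq.31 (KL) and Eq.32 (L2) are negative."""
    eta = np.sqrt(2.0 * delta / (g @ np.linalg.solve(H, g)))
    n = len(g)
    # Eq.31: the norm with matrix -H/eta + (L/2) I is negative for all d != 0
    # iff that matrix is negative definite, i.e. sigma_min(H) > L * eta / 2.
    kl_matrix = -H / eta + 0.5 * L_smooth * np.eye(n)
    kl_ok = bool(np.linalg.eigvalsh(kl_matrix).max() < 0)
    # Eq.32 assumes sigma_max(H) <= 1 and is negative iff L/2 - 1/eta < 0.
    l2_ok = bool(np.linalg.eigvalsh(H).max() <= 1.0 and 0.5 * L_smooth - 1.0 / eta < 0)
    return eta, kl_ok, l2_ok


H = np.diag([0.5, 0.8])     # hypothetical Fisher matrix with sigma_max(H) <= 1
g = np.array([1.0, 1.0])
print(improvement_conditions(g, H, delta=0.01, L_smooth=1.0))    # small L: both hold
print(improvement_conditions(g, H, delta=0.01, L_smooth=100.0))  # large L: neither holds
```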

Proof of Theorem 3

The proof works in a Hilbert space and relies on the non-expansive property of the projection. We first characterize the stationary points of PCPO with the KL divergence and \(L2\) norm projections, and then bound the change of the objective value.

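At a stationary point \(\theta^*\), the update leaves the parameters unchanged (\(\theta_{k+1}=\theta_{k}=\theta^*\)). Since \(f\) is now minimized, the reward-improvement step moves along \(-\boldsymbol{H}^{-1}g\), and we have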

(33)#\[\begin{split}\theta^{*}&=\theta^{*}-\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}\boldsymbol{H}^{-1}g -\max\left(0,\frac{\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}a^{T}\boldsymbol{H}^{-1}g+b}{a^T\boldsymbol{L}^{-1}a}\right)\boldsymbol{L}^{-1}a\\ &\Leftrightarrow \sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}\boldsymbol{H}^{-1}g = -\max\left(0,\frac{\sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}a^{T}\boldsymbol{H}^{-1}g+b}{a^T\boldsymbol{L}^{-1}a}\right)\boldsymbol{L}^{-1}a\\ &\Leftrightarrow \boldsymbol{H}^{-1}g \in -\boldsymbol{L}^{-1}a.\end{split}\]

For the KL divergence projection (\(\boldsymbol{L}=\boldsymbol{H}\)), Eq.33 boils down to \(g\in-a\), and for the \(L2\) norm projection (\(\boldsymbol{L}=\boldsymbol{I}\)), Eq.33 is equivalent to \(\boldsymbol{H}^{-1}g\in-a\).
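The stationarity conditions can be seen on a two-dimensional example. This sketch assumes the minimization convention of the proof (the reward step is \(\theta_k-\eta\boldsymbol{H}^{-1}g\)), a linearized constraint that is active at \(\theta_k\) (\(b=0\)), and a hypothetical helper `pcpo_step_min`. A gradient with \(g\in-a\) is a fixed point of the KL projection, a gradient with \(\boldsymbol{H}^{-1}g\in-a\) is a fixed point of the \(L2\) projection, and the first gradient is generally not a fixed point of the \(L2\) projection.

```python
import numpy as np


def pcpo_step_min(theta_k, g, a, b, H, delta, L):
    """One PCPO step for minimizing f: theta_k - eta H^{-1} g, then the L-norm
    projection onto {theta : a^T (theta - theta_k) + b <= 0}."""
    Hinv_g = np.linalg.solve(H, g)
    eta = np.sqrt(2.0 * delta / (g @ Hinv_g))
    theta_half = theta_k - eta * Hinv_g
    Linv_a = np.linalg.solve(L, a)
    lam = max(0.0, (a @ (theta_half - theta_k) + b) / (a @ Linv_a))
    return theta_half - lam * Linv_a


H = np.array([[2.0, 0.3], [0.3, 1.0]])   # hypothetical Fisher matrix
a = np.array([1.0, 2.0])                 # cost gradient
theta_k = np.array([0.5, -0.5])
delta, b = 0.05, 0.0                     # b = 0: constraint active at theta_k

g_kl = -1.5 * a          # g in -a
g_l2 = -1.5 * (H @ a)    # H^{-1} g in -a

print(pcpo_step_min(theta_k, g_kl, a, b, H, delta, L=H))          # theta_k (up to rounding)
print(pcpo_step_min(theta_k, g_l2, a, b, H, delta, L=np.eye(2)))  # theta_k (up to rounding)
print(pcpo_step_min(theta_k, g_kl, a, b, H, delta, L=np.eye(2)))  # moves away from theta_k
```

The last call moves because \(\boldsymbol{H}^{-1}a\) is not parallel to \(a\) for this \(\boldsymbol{H}\), so the \(L2\) projection cannot undo the reward step exactly.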

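Now we prove the second part of the theorem. For the KL divergence projection, apply Lemma 3 with \(\boldsymbol{L}=\boldsymbol{H}\), \(\theta=\theta_k-\eta\boldsymbol{H}^{-1}g\) (the point before projection), \(\theta^*=\theta_{k+1}\), and \(\theta'=\theta_k\), which lies in \(\mathcal{C}\) because the current iterate is feasible. This gives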

(34)#\[\begin{split}\left(\theta_k-\theta_{k+1}\right)^T \boldsymbol{H}\left(\theta_k-\eta \boldsymbol{H}^{-1} \boldsymbol{g}-\theta_{k+1}\right) \leq 0 \\ \Rightarrow \boldsymbol{g}^T\left(\theta_{k+1}-\theta_k\right) \leq-\frac{1}{\eta}\left\|\theta_{k+1}-\theta_k\right\|_{\boldsymbol{H}}^2\end{split}\]

Combining Eq.34 with the \(L\)-smoothness of \(f\), which provides the quadratic upper bound in the first line below, we have

(35)#\[\begin{split}f\left(\theta_{k+1}\right) & \leq f\left(\theta_k\right)+\boldsymbol{g}^T\left(\theta_{k+1}-\theta_k\right)+\frac{L}{2}\left\|\theta_{k+1}-\theta_k\right\|_2^2 \\ & \leq f\left(\theta_k\right)-\frac{1}{\eta}\left\|\theta_{k+1}-\theta_k\right\|_{\boldsymbol{H}}^2+\frac{L}{2}\left\|\theta_{k+1}-\theta_k\right\|_2^2 \\ &=f\left(\theta_k\right)+\left(\theta_{k+1}-\theta_k\right)^T\left(-\frac{1}{\eta} \boldsymbol{H}+\frac{L}{2} \boldsymbol{I}\right)\left(\theta_{k+1}-\theta_k\right) \\ &=f\left(\theta_k\right)+\left\|\theta_{k+1}-\theta_k\right\|_{-\frac{1}{\eta} \boldsymbol{H}+\frac{L}{2} \boldsymbol{I}}^2\end{split}\]
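Eq.34 and Eq.35 can be spot-checked numerically on a quadratic objective, which is \(L\)-smooth with \(L\) equal to the largest eigenvalue of its Hessian. The objective, the random Fisher stand-in `H`, and the constraint data below are assumptions for illustration. Choosing \(b\leq0\) keeps \(\theta_k\) feasible, which the proof needs in order to take \(\theta'=\theta_k\) in Lemma 3, and a small tolerance absorbs floating-point rounding.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 4

# Hypothetical quadratic objective f(theta) = 0.5 theta^T Q theta + c^T theta;
# it is L-smooth with L = largest eigenvalue of Q.
B = rng.standard_normal((n, n))
Q = B @ B.T
c = rng.standard_normal(n)
L_smooth = np.linalg.eigvalsh(Q).max()


def f(theta):
    return 0.5 * theta @ Q @ theta + c @ theta


A = rng.standard_normal((n, n))
H = A @ A.T + np.eye(n)                  # stand-in Fisher matrix (SPD)
a = rng.standard_normal(n)               # linearized cost gradient
b, delta = -0.01, 0.02                   # b <= 0 keeps theta_k feasible
theta_k = rng.standard_normal(n)

# One PCPO step with the KL projection (L = H), minimization convention
g = Q @ theta_k + c                      # gradient of f at theta_k
Hinv_g = np.linalg.solve(H, g)
eta = np.sqrt(2.0 * delta / (g @ Hinv_g))
theta_half = theta_k - eta * Hinv_g
Hinv_a = np.linalg.solve(H, a)
lam = max(0.0, (a @ (theta_half - theta_k) + b) / (a @ Hinv_a))
theta_next = theta_half - lam * Hinv_a
d = theta_next - theta_k

eq34_holds = g @ d <= -(d @ H @ d) / eta + 1e-10
eq35_bound = f(theta_k) + d @ (-H / eta + 0.5 * L_smooth * np.eye(n)) @ d
eq35_holds = f(theta_next) <= eq35_bound + 1e-10
print(eq34_holds, eq35_holds)
```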

For the \(L2\) norm projection, applying Lemma 3 in the same way with \(\boldsymbol{L}=\boldsymbol{I}\) gives

(36)#\[\begin{split}(\theta_{k}-\theta_{k+1})^T(\theta_{k}-\eta\boldsymbol{H}^{-1}g-\theta_{k+1})\leq0\\ \Rightarrow g^T\boldsymbol{H}^{-1}(\theta_{k+1}-\theta_{k})\leq -\frac{1}{\eta}||\theta_{k+1}-\theta_{k}||^2_2\end{split}\]

By Eq.36, the \(L\)-smoothness of \(f\), and the assumption \(\sigma_\mathrm{max}(\boldsymbol{H})\leq1\), we have

(37)#\[\begin{split}f(\theta_{k+1})&\leq f(\theta_{k})+g^T(\theta_{k+1}-\theta_{k})+\frac{L}{2}||\theta_{k+1}-\theta_{k}||^2_2 \nonumber\\ &\leq f(\theta_{k})+(\frac{L}{2}-\frac{1}{\eta})||\theta_{k+1}-\theta_{k}||^2_2.\nonumber\end{split}\]

To see why the assumption \(\sigma_\mathrm{max}(\boldsymbol{H})\leq1\) is needed, let \(\boldsymbol{H}=\boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{U}^T\) be the singular value decomposition of \(\boldsymbol{H}\) (for a symmetric positive-definite \(\boldsymbol{H}\) this coincides with its eigendecomposition), with \(u_i\) the \(i\)-th column of \(\boldsymbol{U}\). Then we have

(38)#\[\begin{split}g^T\boldsymbol{H}^{-1}(\theta_{k+1}-\theta_{k}) &=g^T\boldsymbol{U}\boldsymbol{\Sigma}^{-1}\boldsymbol{U}^T(\theta_{k+1}-\theta_{k})\\ &=g^T\left(\sum_{i}\frac{1}{\sigma_i(\boldsymbol{H})}u_iu_i^T\right)(\theta_{k+1}-\theta_{k})\\ &=\sum_{i}\frac{1}{\sigma_i(\boldsymbol{H})}(g^Tu_i)\,u_i^T(\theta_{k+1}-\theta_{k}).\end{split}\]

To obtain

(39)#\[g^T(\theta_{k+1}-\theta_{k})\leq g^T\boldsymbol{H}^{-1}(\theta_{k+1}-\theta_{k})\leq -\frac{1}{\eta}||\theta_{k+1}-\theta_{k}||^2_2,\]

every singular value \(\sigma_i(\boldsymbol{H})\) of \(\boldsymbol{H}\) needs to be at most \(1\), so that each weight \(\frac{1}{\sigma_i(\boldsymbol{H})}\) in Eq.38 is at least \(1\). Hence \(\sigma_\mathrm{max}(\boldsymbol{H})\leq1\), which justifies the assumption used to prove the bound.
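The spectral expansion behind Eq.38, \(g^T\boldsymbol{H}^{-1}d=\sum_i\frac{1}{\sigma_i(\boldsymbol{H})}(g^Tu_i)(u_i^Td)\), can be verified with NumPy. Because \(\boldsymbol{H}\) is symmetric positive definite here, `numpy.linalg.eigh` returns its singular values and singular vectors. The random matrix and vectors are illustrative. The last lines rescale \(\boldsymbol{H}\) so that \(\sigma_\mathrm{max}(\boldsymbol{H})=1\) and show that every weight \(1/\sigma_i\) is then at least \(1\).

```python
import numpy as np

rng = np.random.default_rng(4)
n = 5
A = rng.standard_normal((n, n))
H = A @ A.T + 0.1 * np.eye(n)   # symmetric positive definite, so its SVD is its eigendecomposition
g = rng.standard_normal(n)
d = rng.standard_normal(n)      # plays the role of theta_{k+1} - theta_k

sigma, U = np.linalg.eigh(H)    # H = U diag(sigma) U^T; column U[:, i] is u_i
direct = g @ np.linalg.solve(H, d)
expanded = sum((g @ U[:, i]) * (U[:, i] @ d) / sigma[i] for i in range(n))
print(direct, expanded)         # equal up to rounding

# With every sigma_i <= 1, each weight 1 / sigma_i is at least 1
H_small = H / np.linalg.eigvalsh(H).max()
print(1.0 / np.linalg.eigvalsh(H_small))
```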

Hint

For the objective value of PCPO with the KL divergence projection to improve, the quadratic term on the right-hand side of Eq.31 must be negative. Hence we need \(\frac{L\eta}{2}\boldsymbol{I}\prec\boldsymbol{H}\), implying that \(\sigma_\mathrm{min}(\boldsymbol{H})>\frac{L\eta}{2}\). For the objective value of PCPO with the \(L2\) norm projection to improve, the second term on the right-hand side of Eq.32 must be negative. Hence we need \(\eta<\frac{2}{L}\), implying that

(40)#\[\begin{split}&\eta = \sqrt{\frac{2\delta}{g^T\boldsymbol{H}^{-1}g}}<\frac{2}{L}\\ \Rightarrow& \frac{2\delta}{g^T\boldsymbol{H}^{-1}g} < \frac{4}{L^2}\\ \Rightarrow& \frac{g^{T}\boldsymbol{H}^{-1}g}{2\delta}>\frac{L^2}{4}\\ \Rightarrow& \frac{L^2\delta}{2}<g^T\boldsymbol{H}^{-1}g\\ &\leq||g||_2||\boldsymbol{H}^{-1}g||_2\\ &\leq||g||_2||\boldsymbol{H}^{-1}||_2||g||_2\\ &=\sigma_\mathrm{max}(\boldsymbol{H}^{-1})||g||^2_2\\ &=\frac{||g||^2_2}{\sigma_\mathrm{min}(\boldsymbol{H})}\\ \Rightarrow&\sigma_\mathrm{min}(\boldsymbol{H})<\frac{2||g||^2_2}{L^2\delta}.\end{split}\]

By the definition of the condition number and Eq.40, we have