OmniSafe Math

Tensor Operations

omnisafe.utils.math.get_transpose(tensor)

Transpose the last two dimensions of a tensor.

Example

>>> tensor = torch.rand(2, 3, 4)
>>> get_transpose(tensor).shape
torch.Size([2, 4, 3])

Parameters:

tensor (Tensor) – The input tensor; it must have at least two dimensions.

Return type:

Tensor
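For readers who want the plain-PyTorch equivalent, here is a minimal sketch that assumes get_transpose behaves like torch.Tensor.transpose(-2, -1). The name transpose_last_two is hypothetical and used only for illustration.

```python
import torch


def transpose_last_two(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for get_transpose: swap the last two dimensions."""
    # Negative indices make this work for any number of leading batch dimensions.
    return tensor.transpose(-2, -1)


batch = torch.rand(2, 3, 4)
swapped = transpose_last_two(batch)  # shape torch.Size([2, 4, 3])
```

Because only the last two axes are swapped, each matrix in the batch is transposed independently.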

omnisafe.utils.math.get_diagonal(tensor)

Get the diagonal of the last two dimensions of a tensor.

Example

>>> tensor = torch.rand(2, 3, 4)
>>> get_diagonal(tensor).shape
torch.Size([2, 3])

Parameters:

tensor (Tensor) – The input tensor; it must have at least two dimensions.

Return type:

Tensor
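A minimal sketch of the same operation with torch.diagonal, assuming get_diagonal extracts the main diagonal over the last two dimensions as the example shape suggests. diagonal_last_two is a hypothetical name.

```python
import torch


def diagonal_last_two(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for get_diagonal: main diagonal of the last two dims."""
    # torch.diagonal removes dims -2 and -1 and appends the diagonal as the last dim.
    return torch.diagonal(tensor, dim1=-2, dim2=-1)


batch = torch.rand(2, 3, 4)
diag = diagonal_last_two(batch)  # shape torch.Size([2, 3])
```

For a non-square trailing block such as (3, 4), the diagonal has length min(3, 4) = 3, which is why the example returns shape (2, 3).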

omnisafe.utils.math.safe_inverse(var_q, det)

Compute the inverse of a matrix, with a safeguard for singular matrices.

Example

>>> var_q = torch.diag(torch.tensor([1.0, 2.0, 3.0]))
>>> var_q
tensor([[1.00, 0.00, 0.00],
        [0.00, 2.00, 0.00],
        [0.00, 0.00, 3.00]])
>>> det = torch.det(var_q)
>>> det
tensor(6.00)
>>> safe_inverse(var_q, det)
tensor([[1.00, 0.00, 0.00],
        [0.00, 0.50, 0.00],
        [0.00, 0.00, 0.33]])

Parameters:
  • var_q (Tensor) – The matrix to invert.

  • det (Tensor) – The determinant of var_q, used to detect singular matrices.

Return type:

Tensor
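One common way to build such a safeguard is sketched below. It assumes the fallback for a (near-)singular matrix is the Moore-Penrose pseudo-inverse and that "near-singular" means |det| <= tol; both the fallback and the threshold tol are assumptions, and safe_inverse_sketch is a hypothetical name, not OmniSafe's implementation.

```python
import torch


def safe_inverse_sketch(var_q: torch.Tensor, det: torch.Tensor, tol: float = 1e-6) -> torch.Tensor:
    """Hypothetical sketch: invert var_q, falling back to the pseudo-inverse if det ~ 0."""
    # A (near-)zero determinant means the exact inverse is unreliable or undefined,
    # so use the Moore-Penrose pseudo-inverse instead. `tol` is an assumed threshold.
    if bool((det.abs() <= tol).any()):
        return torch.linalg.pinv(var_q)
    return torch.linalg.inv(var_q)


var_q = torch.diag(torch.tensor([1.0, 2.0, 3.0]))
inv_q = safe_inverse_sketch(var_q, torch.det(var_q))  # approximately diag(1.0, 0.5, 0.3333)

singular = torch.tensor([[1.0, 2.0], [2.0, 4.0]])  # rank 1, determinant 0
pinv_s = safe_inverse_sketch(singular, torch.det(singular))  # no error is raised
```

The pseudo-inverse coincides with the ordinary inverse for invertible matrices, so the fallback only changes behaviour in the singular case.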

omnisafe.utils.math.discount_cumsum(x_vector, discount)

Compute the discounted cumulative sum of vectors, \(y_t = \sum_{k \ge 0} \gamma^k x_{t+k}\), where \(\gamma\) is the discount factor.

Example

>>> x_vector = torch.arange(1, 5)
>>> x_vector
tensor([1, 2, 3, 4])
>>> discount_cumsum(x_vector, 0.9)
tensor([8.15, 7.94, 6.60, 4.00])

Parameters:
  • x_vector (torch.Tensor) – shape (B, T).

  • discount (float) – discount factor.

Return type:

Tensor
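The discounted cumulative sum obeys the backward recursion \(y_{T-1} = x_{T-1}\) and \(y_t = x_t + \gamma y_{t+1}\), so it can be computed in a single pass from the end of the sequence. A minimal sketch, assuming a 1-D input as in the example (discount_cumsum_sketch is a hypothetical name):

```python
import torch


def discount_cumsum_sketch(x_vector: torch.Tensor, discount: float) -> torch.Tensor:
    """Hypothetical sketch of y[t] = x[t] + discount * y[t + 1] along the first dimension."""
    out = x_vector.to(torch.float64).clone()  # float copy, so the input is left untouched
    for t in reversed(range(out.shape[0] - 1)):
        # Walk backwards: each entry adds the discounted sum of everything after it.
        out[t] += discount * out[t + 1]
    return out


result = discount_cumsum_sketch(torch.arange(1, 5), 0.9)  # approximately [8.146, 7.94, 6.6, 4.0]
```

For example, the first entry is \(1 + 0.9 \cdot 2 + 0.9^2 \cdot 3 + 0.9^3 \cdot 4 = 8.146\).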

omnisafe.utils.math.conjugate_gradients(Avp, b_vector, num_steps=10, residual_tol=1e-10, eps=1e-6)

Implementation of the conjugate gradient algorithm.

The conjugate gradient algorithm solves the linear system \(Ax = b\) for a symmetric positive-definite matrix \(A\), accessing \(A\) only through matrix-vector products. The algorithm is described in detail in the paper Conjugate Gradient Method.

Note

Increasing num_steps yields a more accurate approximation of \(A^{-1} b\) and may slightly improve performance, but makes each call slower. In most cases, this hyperparameter is best left at its default.

Parameters:
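  • Avp (Callable[[torch.Tensor], torch.Tensor]) – A function that returns the matrix-vector product \(Av\) for a given vector \(v\), e.g. the Fisher information matrix-vector product.

  • b_vector (torch.Tensor) – The vector \(b\) in the equation \(Ax = b\).

  • num_steps (int) – The maximum number of iterations to run.

  • residual_tol (float) – The tolerance for the residual; iteration stops early once the residual falls below it.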

  • eps (float) – A small number to avoid dividing by zero.
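A minimal sketch of textbook conjugate gradient using the same parameter names. Where eps enters the denominators and the choice to compare residual_tol against the squared residual norm are assumptions, not necessarily OmniSafe's choices; conjugate_gradients_sketch is a hypothetical name.

```python
from typing import Callable

import torch


def conjugate_gradients_sketch(
    avp: Callable[[torch.Tensor], torch.Tensor],
    b_vector: torch.Tensor,
    num_steps: int = 10,
    residual_tol: float = 1e-10,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Hypothetical sketch: approximately solve A x = b using only products A v."""
    x = torch.zeros_like(b_vector)
    r = b_vector.clone()  # residual b - A x, with x = 0 initially
    p = r.clone()  # first search direction
    rdotr = torch.dot(r, r)
    for _ in range(num_steps):
        if rdotr < residual_tol:
            break  # squared residual norm is small enough
        z = avp(p)
        alpha = rdotr / (torch.dot(p, z) + eps)  # step size along p
        x = x + alpha * p
        r = r - alpha * z
        new_rdotr = torch.dot(r, r)
        # Next direction, A-conjugate to the previous ones in exact arithmetic.
        p = r + (new_rdotr / (rdotr + eps)) * p
        rdotr = new_rdotr
    return x


# Usage: a small symmetric positive-definite system, accessed only via A @ v.
A = torch.tensor([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]], dtype=torch.float64)
b = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
x = conjugate_gradients_sketch(lambda v: A @ v, b)
```

Passing a callable instead of a matrix is what makes this useful for TRPO-style updates: the Fisher matrix never has to be formed explicitly.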

Distribution Operations

omnisafe.utils.math.gaussian_kl(mean_p, mean_q, var_p, var_q)

Decoupled KL divergence between two Gaussian distributions.

Note

Specifically,

\[KL(q \| p) = \frac{1}{2} \left( \operatorname{tr}(\Sigma_p^{-1} \Sigma_q) + (\mu_p - \mu_q)^T \Sigma_p^{-1} (\mu_p - \mu_q) - k + \log \frac{\det(\Sigma_p)}{\det(\Sigma_q)} \right) \tag{2}\]

where \(\mu_p\) and \(\mu_q\) are the means of \(p\) and \(q\), \(\Sigma_p\) and \(\Sigma_q\) are their covariance matrices, and \(k\) is the dimension of the distributions (\(n\) in the parameter shapes below).

For more details, please refer to the paper A General and Adaptive Robust Loss Function, and the notes here.

Parameters:
  • mean_p (torch.Tensor) – mean of the first distribution \(p\), shape (B, n)

  • mean_q (torch.Tensor) – mean of the second distribution \(q\), shape (B, n)

  • var_p (torch.Tensor) – covariance of the first distribution \(p\), shape (B, n, n)

  • var_q (torch.Tensor) – covariance of the second distribution \(q\), shape (B, n, n)

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor]
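"Decoupled" is read here as splitting the KL divergence into a term that depends only on the means, \(\frac{1}{2}(\mu_p - \mu_q)^T \Sigma_p^{-1} (\mu_p - \mu_q)\), and a term that depends only on the covariances, \(\frac{1}{2}(\operatorname{tr}(\Sigma_p^{-1} \Sigma_q) - k + \log \det \Sigma_p - \log \det \Sigma_q)\). That reading is an assumption. The sketch below returns only these two tensors, not the four returned by gaussian_kl, and decoupled_gaussian_kl_sketch is a hypothetical name. Their sum is checked against PyTorch's closed-form KL.

```python
from typing import Tuple

import torch
from torch.distributions import MultivariateNormal, kl_divergence


def decoupled_gaussian_kl_sketch(
    mean_p: torch.Tensor,
    mean_q: torch.Tensor,
    var_p: torch.Tensor,
    var_q: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch: split KL(q || p) into a mean part and a covariance part.

    Means have shape (B, n); covariances have shape (B, n, n).
    """
    k = mean_p.shape[-1]
    var_p_inv = torch.linalg.inv(var_p)
    diff = (mean_p - mean_q).unsqueeze(-1)  # (B, n, 1)
    # Mean part: 0.5 * (mu_p - mu_q)^T Sigma_p^{-1} (mu_p - mu_q), shape (B,)
    kl_mean = 0.5 * (diff.transpose(-2, -1) @ var_p_inv @ diff).squeeze(-1).squeeze(-1)
    # Covariance part: 0.5 * (tr(Sigma_p^{-1} Sigma_q) - k + log det Sigma_p - log det Sigma_q)
    trace = torch.diagonal(var_p_inv @ var_q, dim1=-2, dim2=-1).sum(-1)
    kl_sigma = 0.5 * (trace - k + torch.logdet(var_p) - torch.logdet(var_q))
    return kl_mean, kl_sigma


# Usage on random symmetric positive-definite covariances (float64 for accuracy).
torch.manual_seed(0)
B, n = 4, 3
eye = torch.eye(n, dtype=torch.float64)
mean_p = torch.randn(B, n, dtype=torch.float64)
mean_q = torch.randn(B, n, dtype=torch.float64)
a = torch.randn(B, n, n, dtype=torch.float64)
c = torch.randn(B, n, n, dtype=torch.float64)
var_p = a @ a.transpose(-2, -1) + eye
var_q = c @ c.transpose(-2, -1) + eye

kl_mean, kl_sigma = decoupled_gaussian_kl_sketch(mean_p, mean_q, var_p, var_q)
# Sanity check: the two parts add up to PyTorch's closed-form KL(q || p).
reference = kl_divergence(
    MultivariateNormal(mean_q, covariance_matrix=var_q),
    MultivariateNormal(mean_p, covariance_matrix=var_p),
)
```

Keeping the two parts separate lets a trust-region method bound changes in the mean and in the covariance with different thresholds.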

class omnisafe.utils.math.SafeTanhTransformer(cache_size=0)

Safe Tanh Transformer.

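This transformer avoids the numerical errors that occur when the input to the tanh function has a very large magnitude (positive or negative), where tanh saturates at \(\pm 1\) and its inverse becomes infinite.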

_call(x)

Compute the forward (tanh) transformation of x.

Return type:

Tensor

_inverse(y)

Compute the inverse (atanh) transformation of y.

Return type:

Tensor
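One way to obtain this behaviour with PyTorch's own TanhTransform is sketched below. It clamps the tanh output to lie strictly inside \((-1, 1)\), so the inverse stays finite even for saturated inputs. ClampedTanhTransform is a hypothetical name, and using the dtype's machine epsilon as the clamping margin is an assumption, not necessarily what SafeTanhTransformer does.

```python
import torch
from torch.distributions.transforms import TanhTransform


class ClampedTanhTransform(TanhTransform):
    """Hypothetical sketch: a tanh transform whose outputs stay strictly inside (-1, 1)."""

    def _call(self, x: torch.Tensor) -> torch.Tensor:
        # Clamp so that the inverse below never receives exactly +/-1.
        eps = torch.finfo(x.dtype).eps
        return torch.tanh(x).clamp(-1.0 + eps, 1.0 - eps)

    def _inverse(self, y: torch.Tensor) -> torch.Tensor:
        eps = torch.finfo(y.dtype).eps
        y = y.clamp(-1.0 + eps, 1.0 - eps)
        # atanh(y) = 0.5 * (log(1 + y) - log(1 - y)), written with log1p for stability.
        return 0.5 * (torch.log1p(y) - torch.log1p(-y))


transform = ClampedTanhTransform()
x = torch.tensor([20.0, -20.0, 0.5])
recovered = transform.inv(transform(x))  # finite, although tanh(+/-20) rounds to +/-1 in float32
```

With the unmodified TanhTransform, the same round trip for x = 20 returns inf, because float32 tanh(20) is exactly 1.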

class omnisafe.utils.math.TanhNormal(loc, scale, validate_args=None)

Creates a tanh-normal distribution.

\[\begin{aligned} X &\sim \mathrm{Normal}(\mathrm{loc}, \mathrm{scale}) \\ Y &= \tanh(X) \sim \mathrm{TanhNormal}(\mathrm{loc}, \mathrm{scale}) \end{aligned} \tag{3}\]

Example

>>> m = TanhNormal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # tanh of a sample drawn from Normal(0, 1)
tensor([-0.7616])

Parameters:
  • loc (float or Tensor) – mean of the underlying normal distribution

  • scale (float or Tensor) – standard deviation of the underlying normal distribution

entropy()

The entropy of the tanh-normal distribution.

expand(batch_shape, instance=None)

Expand the distribution to the given batch_shape.

property loc

The loc (mean) of the underlying normal distribution.

property mean

The mean of the tanh-normal distribution.

property scale

The scale (standard deviation) of the underlying normal distribution.

property stddev

The stddev of the tanh-normal distribution.

property variance

The variance of the tanh-normal distribution.
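Equation (3) can be realized with standard PyTorch components by wrapping Normal(loc, scale) in a TransformedDistribution with a tanh transform. The sketch below shows this construction. It is not necessarily how TanhNormal is implemented, and make_tanh_normal is a hypothetical helper.

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform


def make_tanh_normal(loc: torch.Tensor, scale: torch.Tensor) -> TransformedDistribution:
    """Hypothetical sketch: Y = tanh(X) with X ~ Normal(loc, scale)."""
    # cache_size=1 lets log_prob reuse the pre-tanh value of the most recent sample.
    return TransformedDistribution(Normal(loc, scale), [TanhTransform(cache_size=1)])


dist = make_tanh_normal(torch.tensor([0.0]), torch.tensor([1.0]))
action = dist.rsample()  # reparameterized sample, bounded in [-1, 1]
log_prob = dist.log_prob(action)  # includes the change-of-variables term -log(1 - tanh(x)^2)
```

Note that loc and scale here parameterize the underlying normal, not the tanh-normal itself: the mean of Y is generally not tanh(loc).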