OmniSafe Math
Tensor Operations
Documentation
- omnisafe.utils.math.get_transpose(tensor)
Transpose the last two dimensions of a tensor.
Example
>>> tensor = torch.rand(2, 3, 4)
>>> get_transpose(tensor).shape
torch.Size([2, 4, 3])
- Parameters:
tensor (Tensor) – torch.Tensor
- Return type:
Tensor
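For reference, swapping the last two dimensions can be sketched with a plain PyTorch call. This illustrates the operation described above and is not necessarily how `get_transpose` is implemented; `transpose_last_two` is a hypothetical name used only here.

```python
import torch


def transpose_last_two(tensor: torch.Tensor) -> torch.Tensor:
    """Swap the last two dimensions, leaving any leading batch dims alone."""
    return tensor.transpose(-2, -1)


batch = torch.rand(2, 3, 4)
swapped = transpose_last_two(batch)
print(swapped.shape)  # torch.Size([2, 4, 3])
```

Negative dimension indices keep the helper independent of how many batch dimensions precede the matrix.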
- omnisafe.utils.math.get_diagonal(tensor)
Get the diagonal of the last two dimensions of a tensor.
Example
>>> tensor = torch.rand(2, 3, 4)
>>> get_diagonal(tensor).shape
torch.Size([2, 3])
- Parameters:
tensor (Tensor) – torch.Tensor
- Return type:
Tensor
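The batched diagonal can be sketched in the same spirit with `torch.diagonal`. The helper name `diagonal_last_two` is hypothetical, and the snippet is an assumption about the behaviour described above rather than the library's own code.

```python
import torch


def diagonal_last_two(tensor: torch.Tensor) -> torch.Tensor:
    """Return the main diagonal taken over the last two dimensions."""
    return torch.diagonal(tensor, dim1=-2, dim2=-1)


# Two 3x4 matrices filled with 0..23 so the diagonal entries are easy to read.
batch = torch.arange(24.0).reshape(2, 3, 4)
diag = diagonal_last_two(batch)
print(diag.shape)  # torch.Size([2, 3])
```

For non-square matrices the diagonal has `min(rows, cols)` entries, which is why a `(2, 3, 4)` input yields a `(2, 3)` result.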
- omnisafe.utils.math.safe_inverse(var_q, det)
Invert a matrix, with a safeguard for singular matrices.
Example
>>> var_q = torch.diag(torch.tensor([1.0, 2.0, 3.0]))
>>> var_q
tensor([[1.00, 0.00, 0.00],
        [0.00, 2.00, 0.00],
        [0.00, 0.00, 3.00]])
>>> det = torch.det(var_q)
>>> det
tensor(6.00)
>>> safe_inverse(var_q, det)
tensor([[1.00, 0.00, 0.00],
        [0.00, 0.50, 0.00],
        [0.00, 0.00, 0.33]])
- Parameters:
var_q (Tensor) – torch.Tensor
det (Tensor) – torch.Tensor
- Return type:
Tensor
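One possible way to guard an inversion is to fall back to the Moore-Penrose pseudo-inverse when the determinant signals a singular matrix. The sketch below is an assumption about what such a safeguard could look like, not a copy of `safe_inverse`; the name `guarded_inverse` and the `det <= 0` test are illustrative choices.

```python
import torch


def guarded_inverse(matrix: torch.Tensor, det: torch.Tensor) -> torch.Tensor:
    """Invert `matrix`, using the pseudo-inverse when `det` is not positive.

    Assumption: a non-positive determinant is treated as the unsafe case.
    """
    if bool((det <= 0).any()):
        return torch.linalg.pinv(matrix)
    return torch.linalg.inv(matrix)


regular = torch.diag(torch.tensor([1.0, 2.0, 3.0]))
singular = torch.diag(torch.tensor([1.0, 2.0, 0.0]))

inv_regular = guarded_inverse(regular, torch.det(regular))
inv_singular = guarded_inverse(singular, torch.det(singular))
print(inv_regular)
print(inv_singular)
```

For the singular input, the pseudo-inverse inverts the non-zero diagonal entries and leaves the zero entry at zero instead of raising an error.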
- omnisafe.utils.math.discount_cumsum(x_vector, discount)
Compute the discounted cumulative sum of vectors.
Example
>>> x_vector = torch.arange(1, 5)
>>> x_vector
tensor([1, 2, 3, 4])
>>> discount_cumsum(x_vector, 0.9)
tensor([8.15, 7.94, 6.60, 4.00])
- Parameters:
x_vector (torch.Tensor) – shape (B, T).
discount (float) – discount factor.
- Return type:
Tensor
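Under the usual reward-to-go definition, the discounted cumulative sum satisfies the backward recurrence \(y_t = x_t + \gamma\, y_{t+1}\) with \(y_T = x_T\). The following is a minimal sketch of that recurrence for a 1-D tensor; `discounted_cumsum` is a hypothetical name, and the library function may differ in dtype handling or batching.

```python
import torch


def discounted_cumsum(x_vector: torch.Tensor, discount: float) -> torch.Tensor:
    """Backward recurrence: out[t] = x[t] + discount * out[t + 1]."""
    # Work on a float copy so integer inputs are not truncated or modified.
    out = x_vector.to(torch.float64).clone()
    for t in reversed(range(out.shape[0] - 1)):
        out[t] = out[t] + discount * out[t + 1]
    return out


result = discounted_cumsum(torch.arange(1, 5), 0.9)
print(result)
```

For the input `[1, 2, 3, 4]` and a discount of 0.9, this recurrence gives `[8.146, 7.94, 6.6, 4.0]`: the last entry is unchanged, and each earlier entry adds 0.9 times its successor.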
- omnisafe.utils.math.conjugate_gradients(Avp, b_vector, num_steps=10, residual_tol=1e-10, eps=1e-6)
Implementation of the conjugate gradient algorithm.
The conjugate gradient algorithm is used to solve the linear system of equations \(Ax = b\). The algorithm is described in detail in the paper Conjugate Gradient Method.
Note
Increasing num_steps leads to a more accurate approximation of \(A^{-1} b\), and possibly slightly improved performance, at the cost of slower computation. It is usually best to leave this hyperparameter at its default value.
- Parameters:
Avp (Callable[[torch.Tensor], torch.Tensor]) – Fisher information matrix-vector product.
b_vector (torch.Tensor) – The vector \(b\) in the equation \(Ax = b\).
num_steps (int) – The number of steps to run the algorithm for.
residual_tol (float) – The tolerance for the residual.
eps (float) – A small number to avoid dividing by zero.
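The parameters above map onto the textbook conjugate gradient iteration. The sketch below is a generic version of that iteration, assuming a symmetric positive-definite operator; `cg_solve` is a hypothetical name and the code is not taken from the library.

```python
from typing import Callable

import torch


def cg_solve(
    avp: Callable[[torch.Tensor], torch.Tensor],
    b: torch.Tensor,
    num_steps: int = 10,
    residual_tol: float = 1e-10,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Approximately solve A x = b using only matrix-vector products avp(v) = A v."""
    x = torch.zeros_like(b)
    r = b.clone()  # residual b - A x for the initial guess x = 0
    p = b.clone()  # current search direction
    rdotr = torch.dot(r, r)
    for _ in range(num_steps):
        ap = avp(p)
        alpha = rdotr / (torch.dot(p, ap) + eps)  # eps guards against division by zero
        x = x + alpha * p
        r = r - alpha * ap
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x


# Small symmetric positive-definite system with known solution [1/11, 7/11].
A = torch.tensor([[4.0, 1.0], [1.0, 3.0]], dtype=torch.float64)
b = torch.tensor([1.0, 2.0], dtype=torch.float64)
x = cg_solve(lambda v: A @ v, b)
print(x)
```

Because only `avp` is called, the matrix \(A\) never has to be formed explicitly, which is the reason a Fisher-vector product callable is passed instead of the Fisher matrix itself.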
Distribution Operations
Documentation
- omnisafe.utils.math.gaussian_kl(mean_p, mean_q, var_p, var_q)
Decoupled KL divergence between two Gaussian distributions.
Note
In detail,
\[KL(q \,\|\, p) = \frac{1}{2} \left( \operatorname{tr}(\Sigma_p^{-1} \Sigma_q) + (\mu_p - \mu_q)^T \Sigma_p^{-1} (\mu_p - \mu_q) - k + \log \frac{\det(\Sigma_p)}{\det(\Sigma_q)} \right) \tag{2}\]
where \(\mu_p\) and \(\mu_q\) are the means of \(p\) and \(q\), respectively, \(\Sigma_p\) and \(\Sigma_q\) are the covariances of \(p\) and \(q\), respectively, and \(k\) is the dimension of the distribution.
For more details, please refer to the paper A General and Adaptive Robust Loss Function, and the notes here.
- Parameters:
mean_p (torch.Tensor) – mean of the first distribution, shape (B, n)
mean_q (torch.Tensor) – mean of the second distribution, shape (B, n)
var_p (torch.Tensor) – covariance of the first distribution, shape (B, n, n)
var_q (torch.Tensor) – covariance of the second distribution, shape (B, n, n)
- Return type:
Tuple[Tensor, Tensor, Tensor, Tensor]
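As a sanity check on the formula for \(KL(q\|p)\), it can be evaluated term by term and compared against PyTorch's built-in `kl_divergence`. The sketch below computes the full KL as a single value per batch element; it does not reproduce the decoupled four-tensor return of `gaussian_kl`, and `full_gaussian_kl` is a hypothetical name.

```python
import torch
from torch.distributions import MultivariateNormal, kl_divergence


def full_gaussian_kl(mean_p, mean_q, var_p, var_q):
    """Evaluate KL(q || p) for batched Gaussians, shapes (B, n) and (B, n, n)."""
    k = mean_p.shape[-1]
    inv_p = torch.linalg.inv(var_p)
    diff = (mean_p - mean_q).unsqueeze(-1)  # (B, n, 1)
    trace_term = torch.diagonal(inv_p @ var_q, dim1=-2, dim2=-1).sum(-1)
    maha_term = (diff.transpose(-2, -1) @ inv_p @ diff).squeeze(-1).squeeze(-1)
    logdet_term = torch.logdet(var_p) - torch.logdet(var_q)
    return 0.5 * (trace_term + maha_term - k + logdet_term)


mean_p = torch.tensor([[0.0, 0.0]], dtype=torch.float64)
mean_q = torch.tensor([[1.0, -1.0]], dtype=torch.float64)
var_p = torch.tensor([[[2.0, 0.0], [0.0, 1.0]]], dtype=torch.float64)
var_q = torch.tensor([[[1.0, 0.0], [0.0, 0.5]]], dtype=torch.float64)

kl_formula = full_gaussian_kl(mean_p, mean_q, var_p, var_q)
# kl_divergence(q, p) computes KL(q || p), matching the argument order of the formula.
kl_reference = kl_divergence(
    MultivariateNormal(mean_q, covariance_matrix=var_q),
    MultivariateNormal(mean_p, covariance_matrix=var_p),
)
print(kl_formula, kl_reference)
```

Here the trace term is 1, the quadratic term is 1.5, \(k = 2\), and the log-determinant ratio is \(\log 4\), so both values are about 0.943.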
- class omnisafe.utils.math.SafeTanhTransformer(cache_size=0)
Safe Tanh Transformer.
This transformer is used to avoid the numerical errors that occur when the input to the tanh function is too large or too small.
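The failure mode is easy to reproduce: for large inputs, `tanh` rounds to exactly ±1 in floating point, and the inverse transform then returns infinity. The sketch below shows one common remedy, clamping before inverting. It is an illustration under that assumption, not the class's actual code; `clamped_atanh` and the `margin` value are hypothetical.

```python
import torch


def clamped_atanh(y: torch.Tensor, margin: float = 1e-6) -> torch.Tensor:
    """Inverse tanh with the input kept strictly inside (-1, 1)."""
    return torch.atanh(y.clamp(min=-1.0 + margin, max=1.0 - margin))


# tanh(20) and tanh(-20) round to exactly +1 and -1 in float32.
saturated = torch.tanh(torch.tensor([20.0, -20.0, 0.5]))
naive = torch.atanh(saturated)        # infinite for the saturated entries
guarded = clamped_atanh(saturated)    # finite everywhere
print(naive, guarded)
```

The clamp trades a small bias near the boundaries for finite values, which keeps quantities such as log-probabilities of squashed actions from becoming infinite or NaN.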
- class omnisafe.utils.math.TanhNormal(loc, scale, validate_args=None)
Creates a tanh-normal distribution.
\[\begin{aligned} X &\sim \mathrm{Normal}(\mathrm{loc}, \mathrm{scale}) \\ Y = \tanh(X) &\sim \mathrm{TanhNormal}(\mathrm{loc}, \mathrm{scale}) \end{aligned} \tag{3}\]
Example:
>>> m = TanhNormal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # tanh-normal distributed with loc=0 and scale=1
tensor([-0.7616])
- Parameters:
loc (float or Tensor) – mean of the underlying normal distribution
scale (float or Tensor) – standard deviation of the underlying normal distribution
- property loc
The loc of the tanh normal distribution.
- property mean
The mean of the tanh normal distribution.
- property scale
The scale of the tanh normal distribution.
- property stddev
The stddev of the tanh normal distribution.
- property variance
The variance of the tanh normal distribution.
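The construction \(Y = \tanh(X)\) with \(X \sim \mathrm{Normal}(\mathrm{loc}, \mathrm{scale})\) can be reproduced with PyTorch's generic building blocks. The sketch below uses `TransformedDistribution` and `TanhTransform` from `torch.distributions` rather than OmniSafe's `TanhNormal` class, and checks the log-density against the change-of-variables formula \(\log p_Y(y) = \log p_X(x) - \log(1 - y^2)\).

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

torch.manual_seed(0)

# float64 keeps the manual atanh round trip numerically tight.
loc = torch.tensor([0.0], dtype=torch.float64)
scale = torch.tensor([1.0], dtype=torch.float64)

base = Normal(loc, scale)
tanh_normal = TransformedDistribution(base, [TanhTransform(cache_size=1)])

y = tanh_normal.rsample()           # sample squashed into (-1, 1)
log_prob = tanh_normal.log_prob(y)

# Change of variables: subtract log|dy/dx| = log(1 - tanh(x)^2) = log(1 - y^2).
x = torch.atanh(y)
manual = base.log_prob(x) - torch.log1p(-y.pow(2))
print(y, log_prob, manual)
```

Setting `cache_size=1` lets the transform reuse the pre-squash value when `log_prob` is called on a sample it just produced, which avoids an explicit `atanh` on values close to ±1.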