# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Normalizer."""
from typing import Any, Mapping, Tuple
import torch
import torch.nn as nn
class Normalizer(nn.Module):
    """Normalize data with a running mean and standard deviation.

    The running statistics are kept as module buffers (so they travel with
    ``state_dict()``) and are merged batch-by-batch using the pairwise update
    formulae from the reference below.

    References:
        - Title: Updating Formulae and a Pairwise Algorithm for Computing Sample Variances
        - Author: Tony F. Chan, Gene H. Golub, Randall J. LeVeque
        - URL: http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf
    """

    def __init__(self, shape: Tuple[int, ...], clip: float = 1e6) -> None:
        """Initialize the normalizer.

        Args:
            shape: shape of a single (un-batched) sample; ``()`` for scalars.
            clip: normalized outputs are clamped to ``[-clip, clip]``.
        """
        super().__init__()
        # ``torch.zeros`` / ``torch.ones`` accept a tuple (including the empty
        # tuple for scalars), so one code path covers every shape.
        self.register_buffer('_mean', torch.zeros(shape))
        self.register_buffer('_sumsq', torch.zeros(shape))
        self.register_buffer('_var', torch.zeros(shape))
        self.register_buffer('_std', torch.zeros(shape))
        self.register_buffer('_count', torch.tensor(0))
        self.register_buffer('_clip', clip * torch.ones(shape))

        self._mean: torch.Tensor  # running mean
        self._sumsq: torch.Tensor  # running sum of squared deviations from the mean
        self._var: torch.Tensor  # running (sample) variance
        self._std: torch.Tensor  # running standard deviation, floored at 1e-2
        self._count: torch.Tensor  # number of samples seen so far
        self._clip: torch.Tensor  # clip value

        self._shape = shape
        # True until the first batch is pushed or a state dict is loaded.
        self._first = True

    @property
    def shape(self) -> Tuple[int, ...]:
        """Return the shape of a single sample handled by the normalizer."""
        return self._shape

    @property
    def mean(self) -> torch.Tensor:
        """Return the running mean."""
        return self._mean

    @property
    def std(self) -> torch.Tensor:
        """Return the running standard deviation."""
        return self._std

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """Normalize the data; equivalent to :meth:`normalize`."""
        return self.normalize(data)

    def normalize(self, data: torch.Tensor) -> torch.Tensor:
        """Update the running statistics with ``data`` and return it normalized.

        .. hint::
            - The statistics are always updated with the data first.
            - While at most one sample has been seen, the data is returned unchanged.
            - Otherwise the data is normalized by the running mean and std and
              clamped to ``[-clip, clip]``.

        Args:
            data: raw data to be normalized, of shape ``shape`` or ``(batch_size, *shape)``.

        Returns:
            The normalized (and clipped) data, or the raw data while fewer than
            two samples have been seen.
        """
        data = data.to(self._mean.device)
        self._push(data)
        # A single sample carries no spread information, so pass it through.
        if self._count <= 1:
            return data
        output = (data - self._mean) / self._std
        return torch.clamp(output, -self._clip, self._clip)

    def _push(self, raw_data: torch.Tensor) -> None:
        """Merge ``raw_data`` into the running mean, variance and std.

        Args:
            raw_data: raw data of shape ``shape`` or ``(batch_size, *shape)``.

        Raises:
            AssertionError: if the trailing dimensions of ``raw_data`` differ from ``shape``.
        """
        # Promote a single sample to a batch of one.
        if raw_data.shape == self._shape:
            raw_data = raw_data.unsqueeze(0)
        assert raw_data.shape[1:] == self._shape, 'data shape must be equal to (batch_size, *shape)'

        # An empty batch has no mean; updating with it would turn every
        # running statistic into NaN, so leave the state untouched.
        if raw_data.shape[0] == 0:
            return

        if self._first:
            self._mean = torch.mean(raw_data, dim=0)
            self._sumsq = torch.sum((raw_data - self._mean) ** 2, dim=0)
            self._count = torch.tensor(
                raw_data.shape[0],
                dtype=self._count.dtype,
                device=self._count.device,
            )
            self._first = False
        else:
            # Pairwise merge of the stored statistics with the batch statistics.
            count_raw = raw_data.shape[0]
            count = self._count + count_raw
            mean_raw = torch.mean(raw_data, dim=0)
            delta = mean_raw - self._mean
            self._mean += delta * count_raw / count
            sumq_raw = torch.sum((raw_data - mean_raw) ** 2, dim=0)
            self._sumsq += sumq_raw + delta**2 * self._count * count_raw / count
            self._count = count

        # With a single sample ``count - 1`` is zero and 0 / 0 would give NaN
        # (which ``torch.max`` below propagates), so floor the denominator at 1.
        denominator = torch.clamp(self._count - 1, min=1)
        self._var = self._sumsq / denominator
        self._std = torch.sqrt(self._var)
        # Floor the std so that normalization never divides by (almost) zero.
        self._std = torch.max(self._std, 1e-2 * torch.ones_like(self._std))

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        """Load the running statistics from ``state_dict``.

        The normalizer is marked as already initialized so that the next pushed
        batch is merged into the loaded statistics instead of replacing them.

        Args:
            state_dict: mapping of buffer names to tensors.
            strict: forwarded to ``nn.Module.load_state_dict``.

        Returns:
            Whatever ``nn.Module.load_state_dict`` returns.
        """
        self._first = False
        return super().load_state_dict(state_dict, strict)