# Source code for omnisafe.common.buffer.vector_offpolicy_buffer

# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of VectorOffPolicyBuffer."""

from typing import Dict

import torch
from gymnasium.spaces import Box

from omnisafe.common.buffer.offpolicy_buffer import OffPolicyBuffer
from omnisafe.typing import OmnisafeSpace


class VectorOffPolicyBuffer(OffPolicyBuffer):
    """A VectorReplayBuffer for OffPolicy Algorithms.

    Every field is stored in a single tensor whose first two dimensions are
    ``(size, num_envs)``: one row per time step, one column per environment.
    """

    def __init__(  # pylint: disable=super-init-not-called,too-many-arguments
        self,
        obs_space: OmnisafeSpace,
        act_space: OmnisafeSpace,
        size: int,
        batch_size: int,
        num_envs: int,
        device: torch.device = torch.device('cpu'),
    ) -> None:
        """Initialize the off policy buffer.

        The vector-off-policy buffer is a vectorized version of the off-policy buffer. It stores
        the data in a single tensor, and the data of each environment is stored in a separate
        column.

        .. warning::
            The buffer only supports Box spaces.

        Args:
            obs_space (OmnisafeSpace): The observation space.
            act_space (OmnisafeSpace): The action space.
            size (int): The size of the buffer (number of rows per environment).
            batch_size (int): The number of samples drawn per environment by ``sample_batch``.
            num_envs (int): The number of environments.
            device (torch.device, optional): The device of the buffer. Defaults to
                torch.device('cpu').

        Raises:
            NotImplementedError: If ``obs_space`` or ``act_space`` is not a ``Box``.
        """
        self._num_envs = num_envs

        if not isinstance(obs_space, Box):
            raise NotImplementedError(
                f'{type(self).__name__} only supports Box observation spaces, '
                f'got {type(obs_space).__name__}.',
            )
        if not isinstance(act_space, Box):
            raise NotImplementedError(
                f'{type(self).__name__} only supports Box action spaces, '
                f'got {type(act_space).__name__}.',
            )

        obs_shape = (size, num_envs, *obs_space.shape)
        act_shape = (size, num_envs, *act_space.shape)
        scalar_shape = (size, num_envs)

        # Key order is kept stable because sample_batch iterates this dict to build its output.
        self.data: Dict[str, torch.Tensor] = {
            'obs': torch.zeros(obs_shape, dtype=torch.float32, device=device),
            'act': torch.zeros(act_shape, dtype=torch.float32, device=device),
            'reward': torch.zeros(scalar_shape, dtype=torch.float32, device=device),
            'cost': torch.zeros(scalar_shape, dtype=torch.float32, device=device),
            'done': torch.zeros(scalar_shape, dtype=torch.float32, device=device),
            'next_obs': torch.zeros(obs_shape, dtype=torch.float32, device=device),
        }
        # NOTE(review): _ptr/_size are presumably advanced by the inherited storing logic;
        # this class only reads _size (as the upper bound for sampling).
        self._ptr: int = 0
        self._size: int = 0
        self._max_size: int = size
        self._batch_size: int = batch_size
        self._device = device

    @property
    def num_envs(self) -> int:
        """Return the number of environments."""
        return self._num_envs

    def add_field(self, name: str, shape: tuple, dtype: torch.dtype) -> None:
        """Add a zero-initialized field to the buffer.

        The new tensor gets the same leading ``(size, num_envs)`` dimensions as the built-in
        fields, followed by ``shape``. An existing field with the same name is replaced.

        Example:
            >>> buffer = VectorOffPolicyBuffer(obs_space, act_space, 100, 32, num_envs=4)
            >>> buffer.add_field('new_field', (2, 3), torch.float32)
            >>> buffer.data['new_field'].shape
            torch.Size([100, 4, 2, 3])

        Args:
            name (str): The name of the field.
            shape (tuple): The per-step, per-environment shape of the field.
            dtype (torch.dtype): The dtype of the field.
        """
        self.data[name] = torch.zeros(
            (self._max_size, self._num_envs, *shape),
            dtype=dtype,
            device=self._device,
        )

    def sample_batch(self) -> Dict[str, torch.Tensor]:
        """Sample a batch from the buffer.

        ``batch_size * num_envs`` row indices are drawn uniformly, with replacement, from the
        filled rows ``[0, size)``. Sample ``k`` is taken from environment ``k % num_envs``, so
        every environment contributes exactly ``batch_size`` samples.

        Returns:
            A dict mapping every field name in ``self.data`` to a tensor of shape
            ``(batch_size * num_envs, *field_shape)``.

        Raises:
            RuntimeError: If the buffer is empty.
        """
        if self._size <= 0:
            raise RuntimeError('Cannot sample a batch from an empty buffer.')

        num_samples = self._batch_size * self._num_envs
        idx = torch.randint(0, self._size, (num_samples,), device=self._device)
        # [0, 1, ..., num_envs - 1, 0, 1, ...] pairs each random row with one environment.
        env_idx = torch.arange(self._num_envs, device=self._device).repeat(self._batch_size)
        return {key: value[idx, env_idx] for key, value in self.data.items()}