Source code for omnisafe.common.buffer.base

# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Abstract base class for buffer."""

from abc import ABC, abstractmethod
from typing import Dict

import torch
from gymnasium.spaces import Box

from omnisafe.typing import OmnisafeSpace


[docs]class BaseBuffer(ABC): """Abstract base class for buffer."""
[docs] def __init__( self, obs_space: OmnisafeSpace, act_space: OmnisafeSpace, size: int, device: torch.device = torch.device('cpu'), ): """Initialize the buffer. .. warning:: The buffer only supports Box spaces. In base buffer, we store the following data: .. list-table:: * - Name - Shape - Dtype - Description * - obs - (size, obs_space.shape) - torch.float32 - The observation. * - act - (size, act_space.shape) - torch.float32 - The action. * - reward - (size, ) - torch.float32 - Single step reward. * - cost - (size, ) - torch.float32 - Single step cost. * - done - (size, ) - torch.float32 - Whether the episode is done. Args: obs_space (OmnisafeSpace): The observation space. act_space (OmnisafeSpace): The action space. size (int): The size of the buffer. device (torch.device): The device of the buffer. """ if isinstance(obs_space, Box): obs_buf = torch.zeros((size, *obs_space.shape), dtype=torch.float32, device=device) else: raise NotImplementedError if isinstance(act_space, Box): act_buf = torch.zeros((size, *act_space.shape), dtype=torch.float32, device=device) else: raise NotImplementedError self.data: Dict[str, torch.Tensor] = { 'obs': obs_buf, 'act': act_buf, 'reward': torch.zeros(size, dtype=torch.float32, device=device), 'cost': torch.zeros(size, dtype=torch.float32, device=device), 'done': torch.zeros(size, dtype=torch.float32, device=device), } self._size = size self._device = device
@property def device(self) -> torch.device: """Return the device of the buffer.""" return self._device @property def size(self) -> int: """Return the size of the buffer.""" return self._size def __len__(self): """Return the length of the buffer.""" return self._size
[docs] def add_field(self, name: str, shape: tuple, dtype: torch.dtype): """Add a field to the buffer. Example: >>> buffer = BaseBuffer(...) >>> buffer.add_field('new_field', (2, 3), torch.float32) >>> buffer.data['new_field'].shape >>> (buffer.size, 2, 3) Args: name (str): The name of the field. shape (tuple): The shape of the field. dtype (torch.dtype): The dtype of the field. """ self.data[name] = torch.zeros((self._size, *shape), dtype=dtype, device=self._device)
[docs] @abstractmethod def store(self, **data: torch.Tensor): """Store a transition in the buffer. .. warning:: This is an abstract method. Example: >>> buffer = BaseBuffer(...) >>> buffer.store(obs=obs, act=act, reward=reward, cost=cost, done=done) Args: data (torch.Tensor): The data to store. """