# Source code for omnisafe.envs.safety_gymnasium_env

# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Environments in the Safety-Gymnasium."""

from __future__ import annotations

from typing import Any

import numpy as np
import safety_gymnasium
import torch

from omnisafe.envs.core import CMDP, env_register
from omnisafe.typing import DEVICE_CPU, Box


@env_register
class SafetyGymnasiumEnv(CMDP):
    """Safety Gymnasium Environment.

    Thin adapter around ``safety_gymnasium`` environments that exchanges
    :class:`torch.Tensor` objects with the caller and NumPy arrays with the
    wrapped environment.

    Args:
        env_id (str): Environment id.
        num_envs (int, optional): Number of environments. Defaults to 1.
        device (torch.device, optional): Device to store the data. Defaults to
            ``torch.device('cpu')``.
        **kwargs (Any): Other arguments, forwarded to the ``safety_gymnasium``
            factory function.

    Attributes:
        need_auto_reset_wrapper (bool): Whether to use auto reset wrapper.
        need_time_limit_wrapper (bool): Whether to use time limit wrapper.
    """

    need_auto_reset_wrapper: bool = False
    need_time_limit_wrapper: bool = False

    # Environment ids handled by this class.
    _support_envs: list[str] = [
        'SafetyPointGoal0-v0',
        'SafetyPointGoal1-v0',
        'SafetyPointGoal2-v0',
        'SafetyPointButton0-v0',
        'SafetyPointButton1-v0',
        'SafetyPointButton2-v0',
        'SafetyPointPush0-v0',
        'SafetyPointPush1-v0',
        'SafetyPointPush2-v0',
        'SafetyPointCircle0-v0',
        'SafetyPointCircle1-v0',
        'SafetyPointCircle2-v0',
        'SafetyCarGoal0-v0',
        'SafetyCarGoal1-v0',
        'SafetyCarGoal2-v0',
        'SafetyCarButton0-v0',
        'SafetyCarButton1-v0',
        'SafetyCarButton2-v0',
        'SafetyCarPush0-v0',
        'SafetyCarPush1-v0',
        'SafetyCarPush2-v0',
        'SafetyCarCircle0-v0',
        'SafetyCarCircle1-v0',
        'SafetyCarCircle2-v0',
        'SafetyAntGoal0-v0',
        'SafetyAntGoal1-v0',
        'SafetyAntGoal2-v0',
        'SafetyAntButton0-v0',
        'SafetyAntButton1-v0',
        'SafetyAntButton2-v0',
        'SafetyAntPush0-v0',
        'SafetyAntPush1-v0',
        'SafetyAntPush2-v0',
        'SafetyAntCircle0-v0',
        'SafetyAntCircle1-v0',
        'SafetyAntCircle2-v0',
        'SafetyHalfCheetahVelocity-v1',
        'SafetyHopperVelocity-v1',
        'SafetySwimmerVelocity-v1',
        'SafetyWalker2dVelocity-v1',
        'SafetyAntVelocity-v1',
        'SafetyHumanoidVelocity-v1',
    ]

    def __init__(
        self,
        env_id: str,
        num_envs: int = 1,
        device: torch.device = DEVICE_CPU,
        **kwargs: Any,
    ) -> None:
        """Initialize an instance of :class:`SafetyGymnasiumEnv`.

        Args:
            env_id (str): Environment id passed to ``safety_gymnasium``.
            num_envs (int, optional): Number of environments. A value greater
                than 1 builds a vectorized environment; any other value builds
                a single environment. Defaults to 1.
            device (torch.device, optional): Device on which returned tensors
                are placed. Defaults to ``DEVICE_CPU``.
            **kwargs (Any): Extra keyword arguments forwarded to the
                ``safety_gymnasium`` factory function.

        Raises:
            AssertionError: If the action space or the observation space of the
                created environment is not a ``Box``.
        """
        super().__init__(env_id)
        self._num_envs = num_envs
        self._device = torch.device(device)

        if num_envs > 1:
            self._env = safety_gymnasium.vector.make(env_id=env_id, num_envs=num_envs, **kwargs)
            # For the vectorized environment the per-environment ("single")
            # spaces are the ones exposed by this class.
            action_space = self._env.single_action_space
            observation_space = self._env.single_observation_space
        else:
            # Instance-level overrides of the class defaults for the
            # non-vectorized case.
            self.need_time_limit_wrapper = True
            self.need_auto_reset_wrapper = True
            self._env = safety_gymnasium.make(id=env_id, autoreset=True, **kwargs)
            action_space = self._env.action_space
            observation_space = self._env.observation_space

        # Both spaces are validated before either is stored on the instance.
        assert isinstance(action_space, Box), 'Only support Box action space.'
        assert isinstance(observation_space, Box), 'Only support Box observation space.'
        self._action_space = action_space
        self._observation_space = observation_space

        self._metadata = self._env.metadata

    def step(
        self,
        action: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        dict[str, Any],
    ]:
        """Step the environment.

        .. note::
            OmniSafe uses auto reset wrapper to reset the environment when the
            episode is terminated. So the ``obs`` will be the first observation
            of the next episode. The true final observation of the finished
            episode is stored under the ``final_observation`` key of ``info``.

        Args:
            action (torch.Tensor): Action to take. It is detached, moved to the
                CPU and converted to a NumPy array before being passed to the
                wrapped environment.

        Returns:
            observation: The agent's observation of the current environment.
            reward: The amount of reward returned after previous action.
            cost: The amount of cost returned after previous action.
            terminated: Whether the episode has ended.
            truncated: Whether the episode has been truncated due to a time limit.
            info: Some information logged by the environment.

            The first five values are ``float32`` tensors on ``self._device``
            (including ``terminated`` and ``truncated``).
        """
        obs, reward, cost, terminated, truncated, info = self._env.step(
            action.detach().cpu().numpy(),
        )
        obs, reward, cost, terminated, truncated = (
            torch.as_tensor(x, dtype=torch.float32, device=self._device)
            for x in (obs, reward, cost, terminated, truncated)
        )
        if 'final_observation' in info:
            # ``None`` entries (presumably sub-environments whose episode did
            # not end on this step -- TODO confirm) are replaced by zero vectors
            # of the observation dimension so all entries stack into one array.
            info['final_observation'] = np.array(
                [
                    array if array is not None else np.zeros(obs.shape[-1])
                    for array in info['final_observation']
                ],
            )
            info['final_observation'] = torch.as_tensor(
                info['final_observation'],
                dtype=torch.float32,
                device=self._device,
            )

        return obs, reward, cost, terminated, truncated, info

    def reset(self, seed: int | None = None) -> tuple[torch.Tensor, dict[str, Any]]:
        """Reset the environment.

        Args:
            seed (int or None, optional): Seed to reset the environment.
                Defaults to None.

        Returns:
            observation: Agent's observation of the current environment, as a
                ``float32`` tensor on ``self._device``.
            info: Some information logged by the environment.
        """
        obs, info = self._env.reset(seed=seed)
        return torch.as_tensor(obs, dtype=torch.float32, device=self._device), info

    def set_seed(self, seed: int) -> None:
        """Set the seed for the environment.

        Seeding is done by calling :meth:`reset` with ``seed``, so this also
        resets the environment.

        Args:
            seed (int): Seed to set.
        """
        self.reset(seed=seed)

    def sample_action(self) -> torch.Tensor:
        """Sample a random action.

        Returns:
            A random action drawn from the wrapped environment's
            ``action_space``, as a ``float32`` tensor on ``self._device``.
        """
        return torch.as_tensor(
            self._env.action_space.sample(),
            dtype=torch.float32,
            device=self._device,
        )

    def render(self) -> Any:
        """Render the environment.

        Returns:
            Rendered image, exactly as returned by the wrapped environment.
        """
        return self._env.render()

    def close(self) -> None:
        """Close the environment."""
        self._env.close()