diff --git a/docs/source/contributing/tasks.md b/docs/source/contributing/tasks.md
index 4367d6e44..e5c08920b 100644
--- a/docs/source/contributing/tasks.md
+++ b/docs/source/contributing/tasks.md
@@ -68,7 +68,7 @@ class PushCube(BaseEnv):
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(
+ gpu_memory_config=GPUMemoryConfig(
found_lost_pairs_capacity=2**25, max_rigid_patch_count=2**18
)
)
@@ -83,7 +83,7 @@ class RotateSingleObjectInHand(BaseEnv):
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(
+ gpu_memory_config=GPUMemoryConfig(
max_rigid_contact_count=self.num_envs * max(1024, self.num_envs) * 8,
max_rigid_patch_count=self.num_envs * max(1024, self.num_envs) * 2,
found_lost_pairs_capacity=2**26,
@@ -91,7 +91,7 @@ class RotateSingleObjectInHand(BaseEnv):
)
```
-For GPU simulation tuning, there are generally two considerations, memory and speed. It is recommended to set `gpu_memory_cfg` in such a way so that no errors are outputted when simulating as many as `4096` parallel environments with state observations on a single GPU.
+For GPU simulation tuning, there are generally two considerations, memory and speed. It is recommended to set `gpu_memory_config` in such a way so that no errors are outputted when simulating as many as `4096` parallel environments with state observations on a single GPU.
A simple way to test is to run the GPU sim benchmarking script on your already registered environment and check if any errors are reported
@@ -126,5 +126,5 @@ Examples of task cards are found throughout the [task documentation](../tasks/in
When contributing the task, make sure you do the following:
- The task code itself should have a reasonable unique name and be placed in `mani_skill/envs/tasks`.
-- Added a demo video of the task being solved successfully (for each variation if there are several) to `figures/environment_demos`. The video should have ray-tracing on so it looks nicer! This can be done by replaying a trajectory with `shader_dir="rt"` passed into `gym.make` when making the environment.
+- Added a demo video of the task being solved successfully (for each variation if there are several) to `figures/environment_demos`. The video should have ray-tracing on so it looks nicer! This can be done by replaying a trajectory with `human_render_camera_configs=dict(shader_pack="rt")` passed into `gym.make` when making the environment.
- Added a task card to `docs/source/tasks/index.md`.
\ No newline at end of file
diff --git a/docs/source/user_guide/concepts/gpu_simulation.md b/docs/source/user_guide/concepts/gpu_simulation.md
index 40ae0c801..117aac083 100644
--- a/docs/source/user_guide/concepts/gpu_simulation.md
+++ b/docs/source/user_guide/concepts/gpu_simulation.md
@@ -6,7 +6,7 @@ ManiSkill leverages [PhysX](https://github.com/NVIDIA-Omniverse/PhysX) to perfor
With GPU parallelization, the concept is that one can simulate a task thousands of times at once per GPU. In ManiSkill/SAPIEN this is realized by effectively putting all actors and articulations **into the same physx scene** and give each task it's own small workspace in the physx scene known as a **sub-scene**.
-The idea of sub-scenes is that reading data of e.g. actor poses is automatically pre-processed to be relative to the center of the sub-scene and not the physx scene. The diagram below shows how 64 sub-scenes may be organized. Note that each sub-scene's distance to each other is defined by the simulation configuration `sim_cfg.spacing` value which can be set when building your own task.
+The idea of sub-scenes is that reading data of e.g. actor poses is automatically pre-processed to be relative to the center of the sub-scene and not the physx scene. The diagram below shows how 64 sub-scenes may be organized. Note that each sub-scene's distance to each other is defined by the simulation configuration `sim_config.spacing` value which can be set when building your own task.
:::{figure} images/physx_scene_subscene_relationship.png
:::
diff --git a/docs/source/user_guide/concepts/observation.md b/docs/source/user_guide/concepts/observation.md
index 0a7fa5e58..fd252ffee 100644
--- a/docs/source/user_guide/concepts/observation.md
+++ b/docs/source/user_guide/concepts/observation.md
@@ -1,14 +1,12 @@
# Observation
-
-
## Observation mode
**The observation mode defines the observation space.**
All ManiSkill tasks take the observation mode (`obs_mode`) as one of the input arguments of `__init__`.
In general, the observation is organized as a dictionary (with an observation space of `gym.spaces.Dict`).
-There are two raw observations modes: `state_dict` (privileged states) and `sensor_data` (raw sensor data like visual data without postprocessing). `state` is a flat version of `state_dict`. `rgbd` and `pointcloud` apply post-processing on `sensor_data` to give convenient representations of visual data.
+There are two raw observations modes: `state_dict` (privileged states) and `sensor_data` (raw sensor data like visual data without postprocessing). `state` is a flat version of `state_dict`. `rgb+depth`, `rgb+depth+segmentation` (or any combination of `rgb`, `depth`, `segmentation`), and `pointcloud` apply post-processing on `sensor_data` to give convenient representations of visual data.
The details here show the unbatched shapes. In general there is always a batch dimension unless you are using CPU simulation. Moreover, we annotate what dtype some values are, where some have both a torch and numpy dtype depending on whether you are using GPU or CPU simulation repspectively.
@@ -16,7 +14,7 @@ The details here show the unbatched shapes. In general there is always a batch d
The observation is a dictionary of states. It usually contains privileged information such as object poses. It is not supported for soft-body tasks.
-- `agent`: robot proprioception
+- `agent`: robot proprioception (return value of a task's `_get_obs_agent` function)
- `qpos`: [nq], current joint positions. *nq* is the degree of freedom.
- `qvel`: [nq], current joint velocities
@@ -29,7 +27,7 @@ It is a flat version of *state_dict*. The observation space is `gym.spaces.Box`.
### sensor_data
-In addition to `agent` and `extra`, `sensor_data` and `sensor_param` are introduced.
+In addition to `agent` and `extra`, `sensor_data` and `sensor_param` are introduced. At the moment there are only Camera type sensors. Cameras are special in that they can be run with different choices of shaders. The default shader is called `minimal` which is the fastest and most memory efficient option. The shader chosen determines what data is stored in this observation mode. We describe the raw data format for the `minimal` shader here. Detailed information on how sensors/cameras can be customized can be found in the [sensors](../tutorials/sensors/index.md) section.
- `sensor_data`: data captured by sensors configured in the environment
- `{sensor_uid}`:
@@ -46,7 +44,7 @@ In addition to `agent` and `extra`, `sensor_data` and `sensor_param` are introdu
- `extrinsic_cv`: [4, 4], camera extrinsic (OpenCV convention)
- `intrinsic_cv`: [3, 3], camera intrinsic (OpenCV convention)
-### rgbd
+### rgb+depth+segmentation
This observation mode has the same data format as the [sensor_data mode](#sensor_data), but all sensor data from cameras are replaced with the following structure
@@ -58,9 +56,10 @@ This observation mode has the same data format as the [sensor_data mode](#sensor
- `depth`: [H, W, 1], `torch.int16, np.uint16`. The unit is millimeters. 0 stands for an invalid pixel (beyond the camera far).
- `segmentation`: [H, W, 1], `torch.int16, np.uint16`. See the [Segmentation data section](#segmentation-data) for more details.
- Otherwise keep the same data without any additional processing as in the sensor_data mode
+Note that this data is not scaled/normalized to [0, 1] or [-1, 1] in order to conserve memory, so if you consider to train on RGB, depth, and/or segmentation data be sure to scale your data before training on it.
+
-Note that this data is not scaled/normalized to [0, 1] or [-1, 1] in order to conserve memory, so if you consider to train on RGBD data be sure to scale your data before training on it.
+ManiSkill by default flexibly supports different combinations of RGB, depth, and segmentation data, namely `rgb`, `depth`, `segmentation`, `rgb+depth`, `rgb+depth+segmentation`, `rgb+segmentation`, and`depth+segmentation`. (`rgbd` is a short hand for `rgb+depth`). Whichever image modality that is not chosen will not be included in the observation and conserves some memory and GPU bandwith.
The RGB and depth data visualized can look like below:
```{image} images/replica_cad_rgbd.png
@@ -69,6 +68,8 @@ alt: RGBD from two cameras of Fetch robot inside the ReplicaCAD dataset scene
---
```
+
+
### pointcloud
This observation mode has the same data format as the [sensor_data mode](#sensor_data), but all sensor data from cameras are removed and instead a new key is added called `pointcloud`.
diff --git a/docs/source/user_guide/getting_started/images/parallel_gui_render.png b/docs/source/user_guide/getting_started/images/parallel_gui_render.png
index abc95c12c..a1ada167e 100644
Binary files a/docs/source/user_guide/getting_started/images/parallel_gui_render.png and b/docs/source/user_guide/getting_started/images/parallel_gui_render.png differ
diff --git a/docs/source/user_guide/getting_started/quickstart.md b/docs/source/user_guide/getting_started/quickstart.md
index 8e172a23d..91479efa8 100644
--- a/docs/source/user_guide/getting_started/quickstart.md
+++ b/docs/source/user_guide/getting_started/quickstart.md
@@ -86,7 +86,7 @@ For the full documentation of options you can provide for gym.make see the [docs
## GPU Parallelized/Vectorized Tasks
-ManiSkill is powered by SAPIEN which supports GPU parallelized physics simulation and GPU parallelized rendering. This enables achieving 200,000+ state-based simulation FPS and 10,000+ FPS with rendering on a single 4090 GPU on a e.g. manipulation tasks. The FPS can be higher or lower depending on what is simulated. For full benchmarking results see [this page](../additional_resources/performance_benchmarking)
+ManiSkill is powered by SAPIEN which supports GPU parallelized physics simulation and GPU parallelized rendering. This enables achieving 200,000+ state-based simulation FPS and 30,000+ FPS with rendering on a single 4090 GPU on a e.g. manipulation tasks. The FPS can be higher or lower depending on what is simulated. For full benchmarking results see [this page](../additional_resources/performance_benchmarking)
In order to run massively parallelized tasks on a GPU, it is as simple as adding the `num_envs` argument to `gym.make` as so
@@ -137,7 +137,7 @@ which will look something like this
### Parallel Rendering in one Scene
-We further support via recording or GUI to view all parallel environments at once, and you can also turn on ray-tracing for more photo-realism. Note that this feature is not useful for any practical purposes (for e.g. machine learning) apart from generating cool demonstration videos and so it is not well optimized.
+We further support via recording or GUI to view all parallel environments at once, and you can also turn on ray-tracing for more photo-realism. Note that this feature is not useful for any practical purposes (for e.g. machine learning) apart from generating cool demonstration videos.
Turning the parallel GUI render on simply requires adding the argument `parallel_in_single_scene` to `gym.make` as so
@@ -151,7 +151,7 @@ env = gym.make(
control_mode="pd_joint_delta_pos",
num_envs=16,
parallel_in_single_scene=True,
- shader_dir="rt-fast" # optionally set this argument for more photo-realistic rendering
+ viewer_camera_configs=dict(shader_pack="rt-fast"),
)
```
@@ -170,7 +170,7 @@ We currently do not properly support exposing multiple visible CUDA devices to a
Each ManiSkill task supports different **observation modes** and **control modes**, which determine its **observation space** and **action space**. They can be specified by `gym.make(env_id, obs_mode=..., control_mode=...)`.
-The common observation modes are `state`, `rgbd`, `pointcloud`. We also support `state_dict` (states organized as a hierarchical dictionary) and `sensor_data` (raw visual observations without postprocessing). Please refer to [Observation](../concepts/observation.md) for more details.
+The common observation modes are `state`, `rgbd`, `pointcloud`. We also support `state_dict` (states organized as a hierarchical dictionary) and `sensor_data` (raw visual observations without postprocessing). Please refer to [Observation](../concepts/observation.md) for more details. Furthermore, visual data generated by the simulator can be modified in many ways via shaders. Please refer to [the sensors/cameras tutorial](../tutorials/sensors/index.md) for more details.
We support a wide range of controllers. Different controllers can have different effects on your algorithms. Thus, it is recommended to understand the action space you are going to use. Please refer to [Controllers](../concepts/controllers.md) for more details.
diff --git a/docs/source/user_guide/index.md b/docs/source/user_guide/index.md
index 102340371..cc58c2d40 100644
--- a/docs/source/user_guide/index.md
+++ b/docs/source/user_guide/index.md
@@ -43,6 +43,7 @@ datasets/index
data_collection/index
reinforcement_learning/index
learning_from_demos/index
+wrappers/index
```
```{toctree}
diff --git a/docs/source/user_guide/reinforcement_learning/setup.md b/docs/source/user_guide/reinforcement_learning/setup.md
index 273c904e3..c1af79040 100644
--- a/docs/source/user_guide/reinforcement_learning/setup.md
+++ b/docs/source/user_guide/reinforcement_learning/setup.md
@@ -1,5 +1,10 @@
# Setup
+This page documents key things to know when setting up ManiSkill environments for reinforcement learning, including:
+
+- How to convert ManiSkill environments to gymnasium API compatible environments, both [single](#gym-environment-api) and [vectorized](#gym-vectorized-environment-api) APIs.
+- [Useful Wrappers](#useful-wrappers)
+
ManiSkill environments are created by gymnasium's `make` function. The result is by default a "batched" environment where every input and output is batched. Note that this is not standard gymnasium API. If you want the standard gymnasium environemnt / vectorized environment API see the next sections.
```python
@@ -56,3 +61,14 @@ You may also notice that there are two additional options when creating a vector
Note that for efficiency, everything returned by the environment will be a batched torch tensor on the GPU and not a batched numpy array on the CPU. This the only difference you may need to account for between ManiSkill vectorized environments and gymnasium vectorized environments.
+## Useful Wrappers
+
+RL practitioners often use wrappers to modify and augment environments. These are documented in the [wrappers](../wrappers/index.md) section. Some commonly used ones include:
+- [RecordEpisode](../wrappers/record.md) for recording videos/trajectories of rollouts.
+- [FlattenRGBDObservations](../wrappers/flatten.md#flatten-rgbd-observations) for flattening the `obs_mode="rgbd"` or `obs_mode="rgb+depth"` observations into a simple dictionary with just a combined `rgbd` tensor and `state` tensor.
+
+## Common Mistakes / Gotchas
+
+In old environments/benchmarks, people often have used `env.render(mode="rgb_array")` or `env.render()` to get image inputs for RL agents. This is not correct because image observations are returned by `env.reset()` and `env.step()` directly and `env.render` is just for visualization/video recording only in ManiSkill.
+
+For robotics tasks observations often are composed of state information (like robot joint angles) and image observations (like camera images). All tasks in ManiSkill will specifically remove certain priviliged state information from the observations when the `obs_mode` is not `state` or `state_dict` like ground truth object poses. Moreover, the image observations returned by `env.reset()` and `env.step()` are usually from cameras that are positioned in specific locations to provide a good view of the task to make it solvable.
\ No newline at end of file
diff --git a/docs/source/user_guide/tutorials/custom_tasks/advanced.md b/docs/source/user_guide/tutorials/custom_tasks/advanced.md
index 524356d9a..c746c458b 100644
--- a/docs/source/user_guide/tutorials/custom_tasks/advanced.md
+++ b/docs/source/user_guide/tutorials/custom_tasks/advanced.md
@@ -168,7 +168,7 @@ In the drop down below is a copy of all the configurations possible
:::{dropdown} All sim configs
:icon: code
-```
+```python
@dataclass
class GPUMemoryConfig:
"""A gpu memory configuration dataclass that neatly holds all parameters that configure physx GPU memory for simulation"""
@@ -232,16 +232,16 @@ class DefaultMaterialsConfig:
@dataclass
class SimConfig:
- spacing: int = 5
+ spacing: float = 5
"""Controls the spacing between parallel environments when simulating on GPU in meters. Increase this value
if you expect objects in one parallel environment to impact objects within this spacing distance"""
sim_freq: int = 100
"""simulation frequency (Hz)"""
control_freq: int = 20
"""control frequency (Hz). Every control step (e.g. env.step) contains sim_freq / control_freq physx simulation steps"""
- gpu_memory_cfg: GPUMemoryConfig = field(default_factory=GPUMemoryConfig)
- scene_cfg: SceneConfig = field(default_factory=SceneConfig)
- default_materials_cfg: DefaultMaterialsConfig = field(
+ gpu_memory_config: GPUMemoryConfig = field(default_factory=GPUMemoryConfig)
+ scene_config: SceneConfig = field(default_factory=SceneConfig)
+ default_materials_config: DefaultMaterialsConfig = field(
default_factory=DefaultMaterialsConfig
)
@@ -259,7 +259,7 @@ class MyCustomTask(BaseEnv):
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(
+ gpu_memory_config=GPUMemoryConfig(
max_rigid_contact_count=self.num_envs * max(1024, self.num_envs) * 8,
max_rigid_patch_count=self.num_envs * max(1024, self.num_envs) * 2,
found_lost_pairs_capacity=2**26,
diff --git a/docs/source/user_guide/tutorials/index.md b/docs/source/user_guide/tutorials/index.md
index face965b2..e057e0e25 100644
--- a/docs/source/user_guide/tutorials/index.md
+++ b/docs/source/user_guide/tutorials/index.md
@@ -10,6 +10,7 @@ For those looking for a quickstart/tutorial on Google Colab, checkout the [quick
custom_tasks/index
custom_robots
+sensors/index
custom_reusable_scenes
domain_randomization
```
\ No newline at end of file
diff --git a/docs/source/user_guide/tutorials/sensors/index.md b/docs/source/user_guide/tutorials/sensors/index.md
new file mode 100644
index 000000000..e38dfa89f
--- /dev/null
+++ b/docs/source/user_guide/tutorials/sensors/index.md
@@ -0,0 +1,37 @@
+# Sensors / Cameras
+
+This page documents how to use / customize sensors and cameras in ManiSkill in depth at runtime and in task/environment definitions. In ManiSkill, sensors are "devices" that can capture some modality of data. At the moment there is only the Camera sensor type.
+
+## Cameras
+
+Cameras in ManiSkill can capture a ton of different modalities of data. By default ManiSkill limits those to just `rgb`, `depth`, `position` (which is used to derive depth), and `segmentation`. Internally ManiSkill uses [SAPIEN](https://sapien.ucsd.edu/) which has a highly optimized rendering system that leverages shaders to render different modalities of data.
+
+Each shader has a preset configuration that generates textures containing data in a image format, often in a somewhat difficult to use format due to heavy optimization. ManiSkill uses a shader configuration system in python that parses these different shaders into more user friendly formats (namely the well known `rgb`, `depth`, `position`, and `segmentation` type data). This shader config system resides in this file on [Github](https://github.com/haosulab/ManiSkill/blob/main/mani_skill/render/shaders.py) and defines a few friendly defaults for minimal/fast rendering and ray-tracing.
+
+
+Every ManiSkill environment will have 3 categories of cameras (although some categories can be empty): sensors for observations for agents/policies, human_render_cameras for (high-quality) video capture for humans, and a single viewing camera which is used by the GUI application to render the environment.
+
+
+At runtime when creating environments with `gym.make`, you can pass runtime overrides to any of these cameras as so. Below changes human render cameras to use the ray-tracing shader for photorealistic rendering, modifies sensor cameras to have width 320 and height 240, and changes the viewer camera to have a different field of view value.
+
+```python
+gym.make("PickCube-v1",
+ sensor_configs=dict(width=320, height=240),
+ human_render_camera_configs=dict(shader_pack="rt"),
+ viewer_camera_configs=dict(fov=1),
+)
+```
+
+These overrides will affect every camera in the environment in that group. So `sensor_configs=dict(width=320, height=240)` will change the width and height of every sensor camera in the environment, but will not affect the human render cameras or the viewer camera.
+
+To override specific cameras, you can do it by camera name. For example, if you want to override the sensor camera with name `camera_0` to have a different width and height, you can do it as so:
+
+```python
+gym.make("PickCube-v1",
+ sensor_configs=dict(camera_0=dict(width=320, height=240)),
+)
+```
+
+Now all other sensor cameras will have the default width and height, and `camera_0` will have the specified width and height.
+
+These specific customizations can be useful for those looking to customize how they render or generate policy observations to suit their needs.
\ No newline at end of file
diff --git a/docs/source/user_guide/wrappers/flatten.md b/docs/source/user_guide/wrappers/flatten.md
new file mode 100644
index 000000000..0adf99f15
--- /dev/null
+++ b/docs/source/user_guide/wrappers/flatten.md
@@ -0,0 +1,54 @@
+# Flattening Data
+
+A suite of common wrappers useful for flattening/transforming data like observations/actions into more useful formats for e.g. Reinforcement Learning or Imitation Learning.
+
+## Flatten Observations
+
+A simple wrapper to flatten a dictionary observation space into a flat array observation space.
+
+```python
+import mani_skill.envs
+from mani_skill.utils.wrappers import FlattenObservationWrapper
+import gymnasium as gym
+
+env = gym.make("PickCube-v1", obs_mode="state_dict")
+print(env.observation_space) # is a complex nested dictionary
+# Dict('agent': Dict('qpos': Box(-inf, inf, (1, 9), float32), 'qvel': Box(-inf, inf, (1, 9), float32)), 'extra': Dict('is_grasped': Box(False, True, (1,), bool), 'tcp_pose': Box(-inf, inf, (1, 7), float32), 'goal_pos': Box(-inf, inf, (1, 3), float32), 'obj_pose': Box(-inf, inf, (1, 7), float32), 'tcp_to_obj_pos': Box(-inf, inf, (1, 3), float32), 'obj_to_goal_pos': Box(-inf, inf, (1, 3), float32)))
+env = FlattenObservationWrapper(env)
+print(env.observation_space) # is a flat array now
+# Box(-inf, inf, (1, 42), float32)
+```
+
+## Flatten Actions
+
+A simple wrapper to flatten a dictionary action space into a flat array action space. Commonly used for multi-agent like environments when you want to control multiple agents/robots together with one action space.
+
+```python
+import mani_skill.envs
+from mani_skill.utils.wrappers import FlattenActionSpaceWrapper
+import gymnasium as gym
+
+env = gym.make("TwoRobotStackCube-v1")
+print(env.action_space) # is a dictionary
+# Dict('panda-0': Box(-1.0, 1.0, (8,), float32), 'panda-1': Box(-1.0, 1.0, (8,), float32))
+env = FlattenActionSpaceWrapper(env)
+print(env.action_space) # is a flat array now
+# Box(-1.0, 1.0, (16,), float32)
+```
+
+## Flatten RGBD Observations
+
+This wrapper concatenates all the RGB and Depth images into a single image with combined channels, and concatenates all state data into a single array so that the observation space becomes a simple dictionary composed of a `state` key and a `rgbd` key.
+
+```python
+import mani_skill.envs
+from mani_skill.utils.wrappers import FlattenRGBDObservationWrapper
+import gymnasium as gym
+
+env = gym.make("PickCube-v1", obs_mode="rgbd")
+print(env.observation_space) # is a complex dictionary
+# Dict('agent': Dict('qpos': Box(-inf, inf, (1, 9), float32), 'qvel': Box(-inf, inf, (1, 9), float32)), 'extra': Dict('is_grasped': Box(False, True, (1,), bool), 'tcp_pose': Box(-inf, inf, (1, 7), float32), 'goal_pos': Box(-inf, inf, (1, 3), float32)), 'sensor_param': Dict('base_camera': Dict('extrinsic_cv': Box(-inf, inf, (1, 3, 4), float32), 'cam2world_gl': Box(-inf, inf, (1, 4, 4), float32), 'intrinsic_cv': Box(-inf, inf, (1, 3, 3), float32))), 'sensor_data': Dict('base_camera': Dict('rgb': Box(0, 255, (1, 128, 128, 3), uint8), 'depth': Box(-32768, 32767, (1, 128, 128, 1), int16))))
+env = FlattenRGBDObservationWrapper(env)
+print(env.observation_space) # is a much simpler dictionary now
+# Dict('state': Box(-inf, inf, (1, 29), float32), 'rgbd': Box(-32768, 32767, (1, 128, 128, 4), int16))
+```
\ No newline at end of file
diff --git a/docs/source/user_guide/wrappers/index.md b/docs/source/user_guide/wrappers/index.md
new file mode 100644
index 000000000..5b6de84b4
--- /dev/null
+++ b/docs/source/user_guide/wrappers/index.md
@@ -0,0 +1,9 @@
+# Wrappers
+
+
+```{toctree}
+:titlesonly:
+
+record
+flatten
+```
\ No newline at end of file
diff --git a/docs/source/user_guide/wrappers/record.md b/docs/source/user_guide/wrappers/record.md
new file mode 100644
index 000000000..984973a06
--- /dev/null
+++ b/docs/source/user_guide/wrappers/record.md
@@ -0,0 +1,58 @@
+# Recording Episodes
+
+ManiSkill provides a few ways to record videos/trajectories of tasks on single and vectorized environments. The recommended way is via [a wrapper](#recordepisode-wrapper). The other way is to [call a function](#capture-individual-images) to generate the video frames yourself and compile them into a video yourself.
+
+## RecordEpisode Wrapper
+
+The recommended approach is to use our RecordEpisode wrapper, which supports both single and vectorized environments, and saves videos and/or trajectory data (in the [ManiSkill format](../datasets/demos.md)) to disk. It will save whatever render_mode is specified upon environment creation (can be "rgb_array", "sensors", or "all" which combines both).
+
+This wrapper by default saves videos on environment reset for single environments
+```python
+import mani_skill.envs
+import gymnasium as gym
+from mani_skill.utils.wrappers.record import RecordEpisode
+env = gym.make("PickCube-v1", num_envs=1, render_mode="rgb_array")
+env = RecordEpisode(env, output_dir="videos", save_trajectory=True, trajectory_name="trajectory", save_video=True, video_fps=30)
+env.reset()
+for _ in range(200):
+ obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
+ if terminated or truncated:
+ env.reset()
+```
+
+For vectorized environments, the wrapper will save videos of length `max_steps_per_video` before flushing the video to disk and starting a new video. It does not save on reset as environments can have partial resets.
+
+```python
+import mani_skill.envs
+import gymnasium as gym
+from mani_skill.utils.wrappers.record import RecordEpisode
+from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
+N = 4
+env = gym.make("PickCube-v1", num_envs=N, render_mode="rgb_array")
+env = RecordEpisode(env, output_dir="videos", save_trajectory=True, trajectory_name="trajectory", max_steps_per_video=50, video_fps=30)
+env = ManiSkillVectorEnv(env, auto_reset=True) # adds auto reset
+env.reset()
+for _ in range(200):
+ obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
+```
+
+## Capture Individual Images
+
+If you want to use your own custom video recording methods, then you can call the API directly to capture images of the environment. This works the same for both single and vectorized environments.
+
+```python
+import mani_skill.envs
+import gymnasium as gym
+N = 1
+env = gym.make("PickCube-v1", num_envs=N)
+images = []
+env.reset()
+images.append(env.render_rgb_array())
+for _ in range(200):
+ obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
+ images.append(env.render_rgb_array())
+ # env.render_sensors() # render sensors mode
+ # env.render_all() # render all mode
+```
+
+Note that the return of `env.render_rgb_array(), env.render_sensors()` etc. are all batched torch tensors on the GPU. You will likely need to convert them to CPU numpy arrays to save them to disk.
\ No newline at end of file
diff --git a/mani_skill/agents/controllers/utils/kinematics.py b/mani_skill/agents/controllers/utils/kinematics.py
index 9efecbd78..4a3c52075 100644
--- a/mani_skill/agents/controllers/utils/kinematics.py
+++ b/mani_skill/agents/controllers/utils/kinematics.py
@@ -7,9 +7,9 @@
try:
import pytorch_kinematics as pk
-finally:
- print(
- "pytorch_kinematics not installed. Install with pip install pytorch_kinematics_ms"
+except ImportError:
+ raise ImportError(
+ "pytorch_kinematics_ms not installed. Install with pip install pytorch_kinematics_ms"
)
import torch
from sapien.wrapper.pinocchio_model import PinocchioModel
diff --git a/mani_skill/envs/sapien_env.py b/mani_skill/envs/sapien_env.py
index bd8068479..e8d2d8cc4 100644
--- a/mani_skill/envs/sapien_env.py
+++ b/mani_skill/envs/sapien_env.py
@@ -1,5 +1,6 @@
import copy
import gc
+import logging
import os
from functools import cached_property
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@@ -13,7 +14,6 @@
import sapien.utils.viewer.control_window
import torch
from gymnasium.vector.utils import batch_space
-from sapien.utils import Viewer
from mani_skill import PACKAGE_ASSET_DIR, logger
from mani_skill.agents import REGISTERED_AGENTS
@@ -21,22 +21,22 @@
from mani_skill.agents.multi_agent import MultiAgent
from mani_skill.envs.scene import ManiSkillScene
from mani_skill.envs.utils.observations import (
+ parse_visual_obs_mode_to_struct,
sensor_data_to_pointcloud,
- sensor_data_to_rgbd,
)
-from mani_skill.render import SAPIEN_RENDER_SYSTEM
from mani_skill.sensors.base_sensor import BaseSensor, BaseSensorConfig
from mani_skill.sensors.camera import (
Camera,
CameraConfig,
- parse_camera_cfgs,
- update_camera_cfgs_from_dict,
+ parse_camera_configs,
+ update_camera_configs_from_dict,
)
from mani_skill.sensors.depth_camera import StereoDepthCamera, StereoDepthCameraConfig
from mani_skill.utils import common, gym_utils, sapien_utils
from mani_skill.utils.structs import Actor, Articulation
from mani_skill.utils.structs.types import Array, SimConfig
-from mani_skill.utils.visualization.misc import observations_to_images, tile_images
+from mani_skill.utils.visualization.misc import tile_images
+from mani_skill.viewer import create_viewer
class BaseEnv(gym.Env):
@@ -50,7 +50,9 @@ class BaseEnv(gym.Env):
gpu_sim_backend: The GPU simulation backend to use (only used if the given num_envs argument is > 1). This affects the type of tensor
returned by the environment for e.g. observations and rewards. Can be "torch" or "jax". Default is "torch"
- obs_mode: observation mode to be used. Must be one of ("state", "state_dict", "none", "sensor_data", "rgb", "rgbd", "pointcloud")
+ obs_mode: observation mode to be used. Must be one of ("state", "state_dict", "none", "sensor_data", "rgb", "depth", "segmentation", "rgbd", "rgb+depth", "rgb+depth+segmentation", "rgb+segmentation", "depth+segmentation", "pointcloud")
+ The obs_mode is mostly for convenience to automatically optimize/setup all sensors/cameras for the given observation mode to render the correct data and try to ignore unecesary rendering.
+ For the most advanced use cases (e.g. you have 1 RGB only camera and 1 depth only camera)
reward_mode: reward mode to use. Must be one of ("normalized_dense", "dense", "sparse", "none"). With "none" the reward returned is always 0
@@ -59,22 +61,28 @@ class BaseEnv(gym.Env):
render_mode: render mode registered in @SUPPORTED_RENDER_MODES.
- shader_dir (str): shader directory. Defaults to "default".
- "default", "rt", "rt-fast" are built-in options with SAPIEN. Other options are user-defined. "rt" means ray-tracing which results
+ shader_dir (Optional[str]): shader directory. Defaults to None.
+ Setting this will override the shader used for all cameras in the environment. This is legacy behavior kept for backwards compatibility.
+ The proper way to change the shaders used for cameras is to either change the environment code or pass in sensor_configs/human_render_camera_configs with the desired shaders.
+
+
+ Previously the options are "default", "rt", "rt-fast". "rt" means ray-tracing which results
in more photorealistic renders but is slow, "rt-fast" is a lower quality but faster version of "rt".
enable_shadow (bool): whether to enable shadow for lights. Defaults to False.
- sensor_cfgs (dict): configurations of sensors. See notes for more details.
+ sensor_configs (dict): configurations of sensors. See notes for more details.
- human_render_camera_cfgs (dict): configurations of human rendering cameras. Similar usage as @sensor_cfgs.
+ human_render_camera_configs (dict): configurations of human rendering cameras. Similar usage as @sensor_configs.
+
+ viewer_camera_configs (dict): configurations of the viewer camera in the GUI. Similar usage as @sensor_configs.
robot_uids (Union[str, BaseAgent, List[Union[str, BaseAgent]]]): List of robots to instantiate and control in the environment.
- sim_cfg (Union[SimConfig, dict]): Configurations for simulation if used that override the environment defaults. If given
- a dictionary, it can just override specific attributes e.g. `sim_cfg=dict(scene_cfg=dict(solver_iterations=25))`. If
+ sim_config (Union[SimConfig, dict]): Configurations for simulation if used that override the environment defaults. If given
+ a dictionary, it can just override specific attributes e.g. `sim_config=dict(scene_config=dict(solver_iterations=25))`. If
passing in a SimConfig object, while typed, will override every attribute including the task defaults. Some environments
- define their own recommended default sim configurations via the `self._default_sim_cfg` attribute that generally should not be
+ define their own recommended default sim configurations via the `self._default_sim_config` attribute that generally should not be
completely overriden. For a full detail/explanation of what is in the sim config see the type hints / go to the source
https://github.com/haosulab/ManiSkill/blob/main/mani_skill/utils/structs/types.py
@@ -97,7 +105,7 @@ class BaseEnv(gym.Env):
otherwise as it slows down simulation and rendering.
Note:
- `sensor_cfgs` is used to update environement-specific sensor configurations.
+ `sensor_configs` is used to update environement-specific sensor configurations.
If the key is one of sensor names (e.g. a camera), the value will be applied to the corresponding sensor.
Otherwise, the value will be applied to all sensors (but overridden by sensor-specific values).
"""
@@ -106,9 +114,9 @@ class BaseEnv(gym.Env):
SUPPORTED_ROBOTS: List[Union[str, Tuple[str]]] = None
"""Override this to enforce which robots or tuples of robots together are supported in the task. During env creation,
setting robot_uids auto loads all desired robots into the scene, but not all tasks are designed to support some robot setups"""
- SUPPORTED_OBS_MODES = ("state", "state_dict", "none", "sensor_data", "rgb", "rgbd", "pointcloud")
+ SUPPORTED_OBS_MODES = ("state", "state_dict", "none", "sensor_data", "rgb", "depth", "segmentation", "rgbd", "rgb+depth", "rgb+depth+segmentation", "rgb+segmentation", "depth+segmentation", "pointcloud")
SUPPORTED_REWARD_MODES = ("normalized_dense", "dense", "sparse", "none")
- SUPPORTED_RENDER_MODES = ("human", "rgb_array", "sensors")
+ SUPPORTED_RENDER_MODES = ("human", "rgb_array", "sensors", "all")
"""The supported render modes. Human opens up a GUI viewer. rgb_array returns an rgb array showing the current environment state.
sensors returns an rgb array but only showing all data collected by sensors as images put together"""
@@ -159,12 +167,13 @@ def __init__(
reward_mode: str = None,
control_mode: str = None,
render_mode: str = None,
- shader_dir: str = "default",
+ shader_dir: Optional[str] = None,
enable_shadow: bool = False,
- sensor_configs: dict = None,
- human_render_camera_configs: dict = None,
+ sensor_configs: Optional[dict] = dict(),
+ human_render_camera_configs: Optional[dict] = dict(),
+ viewer_camera_configs: Optional[dict] = dict(),
robot_uids: Union[str, BaseAgent, List[Union[str, BaseAgent]]] = None,
- sim_cfg: Union[SimConfig, dict] = dict(),
+ sim_config: Union[SimConfig, dict] = dict(),
reconfiguration_freq: int = None,
sim_backend: str = "auto",
render_backend: str = "gpu",
@@ -174,8 +183,14 @@ def __init__(
self.num_envs = num_envs
self.reconfiguration_freq = reconfiguration_freq if reconfiguration_freq is not None else 0
self._reconfig_counter = 0
+ if shader_dir is not None:
+ logging.warn("shader_dir argument will be deprecated after ManiSkill v3.0.0 official release. Please use sensor_configs/human_render_camera_configs to set shaders.")
+ sensor_configs |= dict(shader_pack=shader_dir)
+ human_render_camera_configs |= dict(shader_pack=shader_dir)
+ viewer_camera_configs |= dict(shader_pack=shader_dir)
self._custom_sensor_configs = sensor_configs
self._custom_human_render_camera_configs = human_render_camera_configs
+ self._custom_viewer_camera_configs = viewer_camera_configs
self._parallel_in_single_scene = parallel_in_single_scene
self.robot_uids = robot_uids
if self.SUPPORTED_ROBOTS is not None:
@@ -219,67 +234,36 @@ def __init__(
elif render_backend[:4] == "cuda":
self._render_device = sapien.Device(render_backend)
+
+
+
# raise a number of nicer errors
if sim_backend == "cpu" and num_envs > 1:
raise RuntimeError("""Cannot set the sim backend to 'cpu' and have multiple environments.
If you want to do CPU sim backends and have environment vectorization you must use multi-processing across CPUs.
This can be done via the gymnasium's AsyncVectorEnv API""")
- if "rt" == shader_dir[:2]:
- if obs_mode in ["sensor_data", "rgb", "rgbd", "pointcloud"]:
- raise RuntimeError("""Currently you cannot use ray-tracing while running simulation with visual observation modes. You may still use
- env.render_rgb_array() or the RecordEpisode wrapper to save videos of ray-traced results""")
- if num_envs > 1 and parallel_in_single_scene == False:
- raise RuntimeError("""Currently you cannot run ray-tracing on more than one environment in a single process""")
+
+ if shader_dir is not None:
+ if "rt" == shader_dir[:2]:
+ if num_envs > 1 and parallel_in_single_scene == False:
+ raise RuntimeError("""Currently you cannot run ray-tracing on more than one environment in a single process""")
assert not parallel_in_single_scene or (obs_mode not in ["sensor_data", "pointcloud", "rgb", "depth", "rgbd"]), \
"Parallel rendering from parallel cameras is only supported when the gui/viewer is not used. parallel_in_single_scene must be False if using parallel rendering. If True only state based observations are supported."
- if isinstance(sim_cfg, SimConfig):
- sim_cfg = sim_cfg.dict()
- merged_gpu_sim_cfg = self._default_sim_config.dict()
- common.dict_merge(merged_gpu_sim_cfg, sim_cfg)
- self.sim_cfg = dacite.from_dict(data_class=SimConfig, data=merged_gpu_sim_cfg, config=dacite.Config(strict=True))
+ if isinstance(sim_config, SimConfig):
+ sim_config = sim_config.dict()
+ merged_gpu_sim_config = self._default_sim_config.dict()
+ common.dict_merge(merged_gpu_sim_config, sim_config)
+ self.sim_config = dacite.from_dict(data_class=SimConfig, data=merged_gpu_sim_config, config=dacite.Config(strict=True))
"""the final sim config after merging user overrides with the environment default"""
- physx.set_gpu_memory_config(**self.sim_cfg.gpu_memory_cfg.dict())
-
- if SAPIEN_RENDER_SYSTEM == "3.0":
- self.shader_dir = shader_dir
- if self.shader_dir == "default":
- sapien.render.set_camera_shader_dir("minimal")
- sapien.render.set_picture_format("Color", "r8g8b8a8unorm")
- sapien.render.set_picture_format("ColorRaw", "r8g8b8a8unorm")
- sapien.render.set_picture_format("PositionSegmentation", "r16g16b16a16sint")
- elif self.shader_dir == "rt":
- sapien.render.set_camera_shader_dir("rt")
- sapien.render.set_viewer_shader_dir("rt")
- sapien.render.set_ray_tracing_samples_per_pixel(32)
- sapien.render.set_ray_tracing_path_depth(16)
- sapien.render.set_ray_tracing_denoiser(
- "optix"
- ) # TODO "optix or oidn?" previous value was just True
- elif self.shader_dir == "rt-fast":
- sapien.render.set_camera_shader_dir("rt")
- sapien.render.set_viewer_shader_dir("rt")
- sapien.render.set_ray_tracing_samples_per_pixel(2)
- sapien.render.set_ray_tracing_path_depth(1)
- sapien.render.set_ray_tracing_denoiser("optix")
- elif self.shader_dir == "rt-med":
- sapien.render.set_camera_shader_dir("rt")
- sapien.render.set_viewer_shader_dir("rt")
- sapien.render.set_ray_tracing_samples_per_pixel(4)
- sapien.render.set_ray_tracing_path_depth(3)
- sapien.render.set_ray_tracing_denoiser("optix")
- elif SAPIEN_RENDER_SYSTEM == "3.1":
- self.shader_dir = "None"
+ physx.set_gpu_memory_config(**self.sim_config.gpu_memory_config.dict())
sapien.render.set_log_level(os.getenv("MS_RENDERER_LOG_LEVEL", "warn"))
# Set simulation and control frequency
- self._sim_freq = self.sim_cfg.sim_freq
- self._control_freq = self.sim_cfg.control_freq
- if self._sim_freq % self._control_freq != 0:
- logger.warn(
- f"sim_freq({self._sim_freq}) is not divisible by control_freq({self._control_freq}).",
- )
+ self._sim_freq = self.sim_config.sim_freq
+ self._control_freq = self.sim_config.control_freq
+ assert self._sim_freq % self._control_freq == 0, f"sim_freq({self._sim_freq}) is not divisible by control_freq({self._control_freq})."
self._sim_steps_per_control = self._sim_freq // self._control_freq
# Observation mode
@@ -288,6 +272,7 @@ def __init__(
if obs_mode not in self.SUPPORTED_OBS_MODES:
raise NotImplementedError("Unsupported obs mode: {}".format(obs_mode))
self._obs_mode = obs_mode
+ self._visual_obs_mode_struct = parse_visual_obs_mode_to_struct(self._obs_mode)
# Reward mode
if reward_mode is None:
@@ -392,11 +377,18 @@ def _default_sensor_configs(
def _default_human_render_camera_configs(
self,
) -> Union[
- BaseSensorConfig, Sequence[BaseSensorConfig], Dict[str, BaseSensorConfig]
+ CameraConfig, Sequence[CameraConfig], Dict[str, CameraConfig]
]:
"""Add default cameras for rendering when using render_mode='rgb_array'. These can be overriden by the user at env creation time """
return []
+ @property
+ def _default_viewer_camera_configs(
+ self,
+ ) -> CameraConfig:
+ """Default configuration for the viewer camera, controlling shader, fov, etc. By default if there is a human render camera called "render_camera" then the viewer will use that camera's pose."""
+ return CameraConfig(uid="viewer", pose=sapien.Pose([0, 0, 1]), width=1920, height=1080, shader_pack="minimal")
+
@property
def sim_freq(self):
return self._sim_freq
@@ -452,15 +444,15 @@ def get_obs(self, info: Optional[Dict] = None):
obs = common.flatten_state_dict(state_dict, use_torch=True, device=self.device)
elif self._obs_mode == "state_dict":
obs = self._get_obs_state_dict(info)
- elif self._obs_mode in ["sensor_data", "rgbd", "rgb", "pointcloud"]:
+ elif self._obs_mode == "pointcloud":
+ # TODO support more flexible pcd obs mode with new render system
+ obs = self._get_obs_with_sensor_data(info)
+ obs = sensor_data_to_pointcloud(obs, self._sensors)
+ elif self._obs_mode == "sensor_data":
+ # return raw texture data dependent on choice of shader
+ obs = self._get_obs_with_sensor_data(info, apply_texture_transforms=False)
+ elif self._obs_mode in ["rgb", "depth", "segmentation", "rgbd", "rgb+depth", "rgb+depth+segmentation", "depth+segmentation", "rgb+segmentation"]:
obs = self._get_obs_with_sensor_data(info)
- if self._obs_mode == "rgbd":
- obs = sensor_data_to_rgbd(obs, self._sensors, rgb=True, depth=True, segmentation=True)
- elif self._obs_mode == "rgb":
- # NOTE (stao): this obs mode is merely a convenience, it does not make simulation run noticebally faster
- obs = sensor_data_to_rgbd(obs, self._sensors, rgb=True, depth=False, segmentation=True)
- elif self.obs_mode == "pointcloud":
- obs = sensor_data_to_pointcloud(obs, self._sensors)
else:
raise NotImplementedError(self._obs_mode)
return obs
@@ -473,11 +465,12 @@ def _get_obs_state_dict(self, info: Dict):
)
def _get_obs_agent(self):
- """Get observations from the agent's sensors, e.g., proprioceptive sensors."""
+ """Get observations about the agent's state. By default it is proprioceptive observations which include qpos and qvel.
+ Controller state is also included although most default controllers do not have any state."""
return self.agent.get_proprioception()
def _get_obs_extra(self, info: Dict):
- """Get task-relevant extra observations."""
+ """Get task-relevant extra observations. Usually defined on a task by task basis"""
return dict()
def capture_sensor_data(self):
@@ -485,13 +478,9 @@ def capture_sensor_data(self):
for sensor in self._sensors.values():
sensor.capture()
- def get_sensor_obs(self) -> Dict[str, Dict[str, torch.Tensor]]:
- """Get raw sensor data for use as observations."""
- return self.scene.get_sensor_obs()
-
def get_sensor_images(self) -> Dict[str, Dict[str, torch.Tensor]]:
- """Get raw sensor data as images for visualization purposes."""
- return self.scene.get_sensor_images()
+ """Get image (RGB) visualizations of what sensors currently sense"""
+ return self.scene.get_sensor_images(self._get_obs_sensor_data())
def get_sensor_params(self) -> Dict[str, Dict[str, torch.Tensor]]:
"""Get all sensor parameters."""
@@ -500,16 +489,31 @@ def get_sensor_params(self) -> Dict[str, Dict[str, torch.Tensor]]:
params[name] = sensor.get_params()
return params
- def _get_obs_with_sensor_data(self, info: Dict) -> dict:
+ def _get_obs_sensor_data(self, apply_texture_transforms: bool = True) -> dict:
+ """get only data from sensors. Auto hides any objects that are designated to be hidden"""
for obj in self._hidden_objects:
obj.hide_visual()
self.scene.update_render()
self.capture_sensor_data()
+ sensor_obs = dict()
+ for name, sensor in self.scene.sensors.items():
+ if isinstance(sensor, Camera):
+ if self.obs_mode in ["state", "state_dict"]:
+ # normally in non visual observation modes we do not render sensor observations. But some users may want to render sensor data for debugging or various algorithms
+ sensor_obs[name] = sensor.get_obs(position=False, segmentation=False, apply_texture_transforms=apply_texture_transforms)
+ else:
+ sensor_obs[name] = sensor.get_obs(rgb=self._visual_obs_mode_struct.rgb, depth=self._visual_obs_mode_struct.depth, position=self._visual_obs_mode_struct.position, segmentation=self._visual_obs_mode_struct.segmentation, apply_texture_transforms=apply_texture_transforms)
+ # explicitly synchronize and wait for cuda kernels to finish
+ # this prevents the GPU from making poor scheduling decisions when other physx code begins to run
+ torch.cuda.synchronize()
+ return sensor_obs
+ def _get_obs_with_sensor_data(self, info: Dict, apply_texture_transforms: bool = True) -> dict:
+ """Get the observation with sensor data"""
return dict(
agent=self._get_obs_agent(),
extra=self._get_obs_extra(info),
sensor_param=self.get_sensor_params(),
- sensor_data=self.get_sensor_obs(),
+ sensor_data=self._get_obs_sensor_data(apply_texture_transforms),
)
@property
@@ -619,52 +623,62 @@ def _setup_sensors(self, options: dict):
self._sensor_configs = dict()
# Add task/external sensors
- self._sensor_configs.update(parse_camera_cfgs(self._default_sensor_configs))
+ self._sensor_configs.update(parse_camera_configs(self._default_sensor_configs))
# Add agent sensors
self._agent_sensor_configs = dict()
- self._agent_sensor_configs = parse_camera_cfgs(self.agent._sensor_configs)
+ self._agent_sensor_configs = parse_camera_configs(self.agent._sensor_configs)
self._sensor_configs.update(self._agent_sensor_configs)
# Add human render camera configs
- self._human_render_camera_configs = parse_camera_cfgs(
+ self._human_render_camera_configs = parse_camera_configs(
self._default_human_render_camera_configs
)
+ self._viewer_camera_config = parse_camera_configs(
+ self._default_viewer_camera_configs
+ )
+
# Override camera configurations with user supplied configurations
if self._custom_sensor_configs is not None:
- update_camera_cfgs_from_dict(
+ update_camera_configs_from_dict(
self._sensor_configs, self._custom_sensor_configs
)
if self._custom_human_render_camera_configs is not None:
- update_camera_cfgs_from_dict(
+ update_camera_configs_from_dict(
self._human_render_camera_configs,
self._custom_human_render_camera_configs,
)
+ if self._custom_viewer_camera_configs is not None:
+ update_camera_configs_from_dict(
+ self._viewer_camera_config,
+ self._custom_viewer_camera_configs,
+ )
+ self._viewer_camera_config = self._viewer_camera_config["viewer"]
# Now we instantiate the actual sensor objects
self._sensors = dict()
- for uid, sensor_cfg in self._sensor_configs.items():
+ for uid, sensor_config in self._sensor_configs.items():
if uid in self._agent_sensor_configs:
articulation = self.agent.robot
else:
articulation = None
- if isinstance(sensor_cfg, StereoDepthCameraConfig):
+ if isinstance(sensor_config, StereoDepthCameraConfig):
sensor_cls = StereoDepthCamera
- elif isinstance(sensor_cfg, CameraConfig):
+ elif isinstance(sensor_config, CameraConfig):
sensor_cls = Camera
self._sensors[uid] = sensor_cls(
- sensor_cfg,
+ sensor_config,
self.scene,
articulation=articulation,
)
# Cameras for rendering only
self._human_render_cameras = dict()
- for uid, camera_cfg in self._human_render_camera_configs.items():
+ for uid, camera_config in self._human_render_camera_configs.items():
self._human_render_cameras[uid] = Camera(
- camera_cfg,
+ camera_config,
self.scene,
)
@@ -879,7 +893,7 @@ def _step_action(
if self.num_envs == 1 and action_is_unbatched:
action = common.batch(action)
self.agent.set_action(action)
- if physx.is_gpu_enabled():
+ if self._sim_device.is_cuda():
self.scene.px.gpu_apply_articulation_target_position()
self.scene.px.gpu_apply_articulation_target_velocity()
self._before_control_step()
@@ -933,11 +947,10 @@ def _after_simulation_step(self):
# Simulation and other gym interfaces
# -------------------------------------------------------------------------- #
def _set_scene_config(self):
- # **self.sim_cfg.scene_cfg.dict()
- physx.set_shape_config(contact_offset=self.sim_cfg.scene_cfg.contact_offset, rest_offset=self.sim_cfg.scene_cfg.rest_offset)
- physx.set_body_config(solver_position_iterations=self.sim_cfg.scene_cfg.solver_position_iterations, solver_velocity_iterations=self.sim_cfg.scene_cfg.solver_velocity_iterations, sleep_threshold=self.sim_cfg.scene_cfg.sleep_threshold)
- physx.set_scene_config(gravity=self.sim_cfg.scene_cfg.gravity, bounce_threshold=self.sim_cfg.scene_cfg.bounce_threshold, enable_pcm=self.sim_cfg.scene_cfg.enable_pcm, enable_tgs=self.sim_cfg.scene_cfg.enable_tgs, enable_ccd=self.sim_cfg.scene_cfg.enable_ccd, enable_enhanced_determinism=self.sim_cfg.scene_cfg.enable_enhanced_determinism, enable_friction_every_iteration=self.sim_cfg.scene_cfg.enable_friction_every_iteration, cpu_workers=self.sim_cfg.scene_cfg.cpu_workers )
- physx.set_default_material(**self.sim_cfg.default_materials_cfg.dict())
+ physx.set_shape_config(contact_offset=self.sim_config.scene_config.contact_offset, rest_offset=self.sim_config.scene_config.rest_offset)
+ physx.set_body_config(solver_position_iterations=self.sim_config.scene_config.solver_position_iterations, solver_velocity_iterations=self.sim_config.scene_config.solver_velocity_iterations, sleep_threshold=self.sim_config.scene_config.sleep_threshold)
+ physx.set_scene_config(gravity=self.sim_config.scene_config.gravity, bounce_threshold=self.sim_config.scene_config.bounce_threshold, enable_pcm=self.sim_config.scene_config.enable_pcm, enable_tgs=self.sim_config.scene_config.enable_tgs, enable_ccd=self.sim_config.scene_config.enable_ccd, enable_enhanced_determinism=self.sim_config.scene_config.enable_enhanced_determinism, enable_friction_every_iteration=self.sim_config.scene_config.enable_friction_every_iteration, cpu_workers=self.sim_config.scene_config.cpu_workers )
+ physx.set_default_material(**self.sim_config.default_materials_config.dict())
def _setup_scene(self):
"""Setup the simulation scene instance.
@@ -959,8 +972,8 @@ def _setup_scene(self):
self.physx_system.set_scene_offset(
scene,
[
- scene_x * self.sim_cfg.spacing,
- scene_y * self.sim_cfg.spacing,
+ scene_x * self.sim_config.spacing,
+ scene_y * self.sim_config.spacing,
0,
],
)
@@ -973,10 +986,9 @@ def _setup_scene(self):
# create a "global" scene object that users can work with that is linked with all other scenes created
self.scene = ManiSkillScene(
sub_scenes,
- sim_cfg=self.sim_cfg,
+ sim_config=self.sim_config,
device=self.device,
- parallel_in_single_scene=self._parallel_in_single_scene,
- shader_dir=self.shader_dir
+ parallel_in_single_scene=self._parallel_in_single_scene
)
self.physx_system.timestep = 1.0 / self._sim_freq
@@ -1078,21 +1090,6 @@ def _setup_viewer(self):
Called by `self._reconfigure`
"""
- # commented code below is for a different parallel render system in the GUI but it does not support ray tracing
- # instead to show parallel envs in the GUI they are all spawned into the same sub scene and offsets are auto
- # added / subtracted from object poses.
- # if self.num_envs > 1:
- # side = int(np.ceil(self.num_envs ** 0.5))
- # idx = np.arange(self.num_envs)
- # offsets = np.stack([idx // side, idx % side, np.zeros_like(idx)], axis=1) * self.sim_cfg.spacing
- # self.viewer.set_scenes(self.scene.sub_scenes, offsets=offsets)
- # vs = self.viewer.window._internal_scene # type: ignore
- # cubemap = self.scene.sub_scenes[0].render_system.get_cubemap()
- # if cubemap is not None: # type: ignore [sapien may return None]
- # vs.set_cubemap(cubemap._internal_cubemap)
- # else:
- # vs.set_ambient_light([0.5, 0.5, 0.5])
- # else:
self._viewer.set_scene(self.scene.sub_scenes[0])
control_window: sapien.utils.viewer.control_window.ControlWindow = (
sapien_utils.get_obj_by_type(
@@ -1110,7 +1107,7 @@ def render_human(self):
for obj in self._hidden_objects:
obj.show_visual()
if self._viewer is None:
- self._viewer = Viewer()
+ self._viewer = create_viewer(self._viewer_camera_config)
self._setup_viewer()
if physx.is_gpu_enabled() and self.scene._gpu_sim_initialized:
self.physx_system.sync_poses_gpu_to_cpu()
@@ -1142,14 +1139,11 @@ def render_sensors(self):
"""
Renders all sensors that the agent can use and see and displays them
"""
- for obj in self._hidden_objects:
- obj.hide_visual()
images = []
- self.scene.update_render()
- self.capture_sensor_data()
sensor_images = self.get_sensor_images()
for image in sensor_images.values():
- images.append(image)
+ for img in image.values():
+ images.append(img)
return tile_images(images)
def render_all(self):
@@ -1160,15 +1154,13 @@ def render_all(self):
self.scene.update_render()
render_images = self.scene.get_human_render_camera_images()
- for obj in self._hidden_objects:
- obj.hide_visual()
- self.scene.update_render()
- self.capture_sensor_data()
sensor_images = self.get_sensor_images()
for image in render_images.values():
- images.append(image)
+ for img in image.values():
+ images.append(img)
for image in sensor_images.values():
- images.append(image)
+ for img in image.values():
+ images.append(img)
return tile_images(images)
def render(self):
@@ -1229,8 +1221,8 @@ def print_sim_details(self):
sensor_settings_str = []
for uid, cam in self._sensors.items():
if isinstance(cam, Camera):
- cfg = cam.cfg
- sensor_settings_str.append(f"RGBD({cfg.width}x{cfg.height})")
+ config = cam.config
+ sensor_settings_str.append(f"RGBD({config.width}x{config.height})")
sensor_settings_str = ", ".join(sensor_settings_str)
sim_backend = "gpu" if physx.is_gpu_enabled() else "cpu"
print(
diff --git a/mani_skill/envs/scene.py b/mani_skill/envs/scene.py
index 24c71e8b8..f179ce456 100644
--- a/mani_skill/envs/scene.py
+++ b/mani_skill/envs/scene.py
@@ -1,5 +1,5 @@
from functools import cached_property
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import sapien
@@ -8,7 +8,7 @@
import torch
from sapien.render import RenderCameraComponent
-from mani_skill.render import SAPIEN_RENDER_SYSTEM, set_shader_pack
+from mani_skill.render import SAPIEN_RENDER_SYSTEM
from mani_skill.sensors.base_sensor import BaseSensor
from mani_skill.sensors.camera import Camera
from mani_skill.utils import common, sapien_utils
@@ -40,11 +40,10 @@ class ManiSkillScene:
def __init__(
self,
sub_scenes: List[sapien.Scene] = None,
- sim_cfg: SimConfig = SimConfig(),
+ sim_config: SimConfig = SimConfig(),
debug_mode: bool = True,
device: Device = None,
parallel_in_single_scene: bool = False,
- shader_dir: str = "default",
):
if sub_scenes is None:
sub_scenes = [sapien.Scene()]
@@ -52,11 +51,10 @@ def __init__(
self.px: Union[physx.PhysxCpuSystem, physx.PhysxGpuSystem] = self.sub_scenes[
0
].physx_system
- self.sim_cfg = sim_cfg
+ self.sim_config = sim_config
self._gpu_sim_initialized = False
self.debug_mode = debug_mode
self.device = device
- self.shader_dir = shader_dir
self.render_system_group: sapien.render.RenderSystemGroup = None
self.camera_groups: Dict[str, sapien.render.RenderCameraGroup] = dict()
@@ -369,6 +367,7 @@ def _sapien_31_update_render(self):
scene.update_render()
self._setup_gpu_rendering()
self._gpu_setup_sensors(self.sensors)
+ self._gpu_setup_sensors(self.human_render_cameras)
manager: sapien.render.GpuSyncManager = self.render_system_group
manager.sync()
@@ -980,7 +979,7 @@ def _sapien_gpu_setup_sensors(self, sensors: Dict[str, BaseSensor]):
try:
camera_group = self.render_system_group.create_camera_group(
sensor.camera._render_cameras,
- sensor.texture_names,
+ list(sensor.config.shader_config.texture_names.keys()),
)
except RuntimeError as e:
raise RuntimeError(
@@ -997,70 +996,60 @@ def _sapien_31_gpu_setup_sensors(self, sensors: dict[str, BaseSensor]):
for name, sensor in sensors.items():
if isinstance(sensor, Camera):
batch_renderer = sapien.render.RenderManager(
- sapien.render.get_shader_pack("default")
+ sapien.render.get_shader_pack(
+ sensor.config.shader_config.shader_pack
+ )
)
- batch_renderer.set_size(sensor.cfg.width, sensor.cfg.height)
-
+ batch_renderer.set_size(sensor.config.width, sensor.config.height)
batch_renderer.set_cameras(sensor.camera._render_cameras)
- batch_renderer.take_picture()
sensor.camera.camera_group = self.camera_groups[name] = batch_renderer
else:
raise NotImplementedError(
f"This sensor {sensor} of type {sensor.__class__} has not bget_picture_cuda implemented yet on the GPU"
)
- def get_sensor_obs(self) -> Dict[str, Dict[str, torch.Tensor]]:
- """Get raw sensor data for use as observations."""
- sensor_data = dict()
- for name, sensor in self.sensors.items():
- sensor_data[name] = sensor.get_obs()
- return sensor_data
-
- def get_sensor_images(self) -> Dict[str, Dict[str, torch.Tensor]]:
+ def get_sensor_images(
+ self, obs: Dict[str, Any]
+ ) -> Dict[str, Dict[str, torch.Tensor]]:
"""Get raw sensor data as images for visualization purposes."""
sensor_data = dict()
for name, sensor in self.sensors.items():
- sensor_data[name] = sensor.get_images()
+ sensor_data[name] = sensor.get_images(obs[name])
return sensor_data
def get_human_render_camera_images(
self, camera_name: str = None
- ) -> Dict[str, Dict[str, torch.Tensor]]:
+ ) -> Dict[str, torch.Tensor]:
image_data = dict()
if physx.is_gpu_enabled():
if self.parallel_in_single_scene:
for name, camera in self.human_render_cameras.items():
camera.camera._render_cameras[0].take_picture()
- # TODO (stao): in the future shaders will be handled more cleanly
- if self.shader_dir == "default":
- rgb = common.to_tensor(
- camera.camera._render_cameras[0].get_picture("Color")
- )[None, ...]
- rgb = (rgb[..., :3]).to(torch.uint8)
- else:
- rgb = common.to_tensor(
- camera.camera._render_cameras[0].get_picture("Color")
- )[None, ...]
- rgb = (rgb[..., :3] * 255).to(torch.uint8)
+ rgb = camera.get_obs(
+ rgb=True, depth=False, segmentation=False, position=False
+ )["rgb"]
image_data[name] = rgb
else:
- for name in self.human_render_cameras.keys():
- camera_group = self.camera_groups[name]
+ for name, camera in self.human_render_cameras.items():
if camera_name is not None and name != camera_name:
continue
- camera_group.take_picture()
- rgb = (
- camera_group.get_picture_cuda("Color").torch()[..., :3].clone()
- )
+ assert camera.config.shader_config.shader_pack not in [
+ "rt",
+ "rt-fast",
+ "rt-med",
+ ], "ray tracing shaders do not work with parallel rendering"
+ camera.capture()
+ rgb = camera.get_obs(
+ rgb=True, depth=False, segmentation=False, position=False
+ )["rgb"]
image_data[name] = rgb
else:
for name, camera in self.human_render_cameras.items():
if camera_name is not None and name != camera_name:
continue
camera.capture()
- if self.shader_dir == "default":
- rgb = (camera.get_picture("Color")[..., :3]).to(torch.uint8)
- else:
- rgb = (camera.get_picture("Color")[..., :3] * 255).to(torch.uint8)
+ rgb = camera.get_obs(
+ rgb=True, depth=False, segmentation=False, position=False
+ )["rgb"]
image_data[name] = rgb
return image_data
diff --git a/mani_skill/envs/scenes/base_env.py b/mani_skill/envs/scenes/base_env.py
index 82ceeecf3..eb7c33e7b 100644
--- a/mani_skill/envs/scenes/base_env.py
+++ b/mani_skill/envs/scenes/base_env.py
@@ -75,7 +75,7 @@ def __init__(
def _default_sim_config(self):
return SimConfig(
spacing=50,
- gpu_memory_cfg=GPUMemoryConfig(
+ gpu_memory_config=GPUMemoryConfig(
found_lost_pairs_capacity=2**25,
max_rigid_patch_count=2**21,
max_rigid_contact_count=2**23,
diff --git a/mani_skill/envs/tasks/control/cartpole.py b/mani_skill/envs/tasks/control/cartpole.py
index 71b0d1e27..c3f24bb02 100644
--- a/mani_skill/envs/tasks/control/cartpole.py
+++ b/mani_skill/envs/tasks/control/cartpole.py
@@ -91,7 +91,7 @@ def __init__(self, *args, robot_uids=CartPoleRobot, **kwargs):
def _default_sim_config(self):
return SimConfig(
spacing=20,
- scene_cfg=SceneConfig(
+ scene_config=SceneConfig(
solver_position_iterations=4, solver_velocity_iterations=0
),
)
diff --git a/mani_skill/envs/tasks/control/hopper.py b/mani_skill/envs/tasks/control/hopper.py
index 57cfcadd7..71d000324 100644
--- a/mani_skill/envs/tasks/control/hopper.py
+++ b/mani_skill/envs/tasks/control/hopper.py
@@ -111,7 +111,7 @@ def __init__(self, *args, robot_uids=HopperRobot, **kwargs):
@property
def _default_sim_config(self):
return SimConfig(
- scene_cfg=SceneConfig(
+ scene_config=SceneConfig(
solver_position_iterations=4, solver_velocity_iterations=1
),
sim_freq=100,
diff --git a/mani_skill/envs/tasks/dexterity/rotate_single_object_in_hand.py b/mani_skill/envs/tasks/dexterity/rotate_single_object_in_hand.py
index a37ccb17e..517edd238 100644
--- a/mani_skill/envs/tasks/dexterity/rotate_single_object_in_hand.py
+++ b/mani_skill/envs/tasks/dexterity/rotate_single_object_in_hand.py
@@ -69,7 +69,7 @@ def __init__(
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(
+ gpu_memory_config=GPUMemoryConfig(
max_rigid_contact_count=self.num_envs * max(1024, self.num_envs) * 8,
max_rigid_patch_count=self.num_envs * max(1024, self.num_envs) * 2,
found_lost_pairs_capacity=2**26,
diff --git a/mani_skill/envs/tasks/fmb/fmb.py b/mani_skill/envs/tasks/fmb/fmb.py
index 6a05ce304..7009d9c86 100644
--- a/mani_skill/envs/tasks/fmb/fmb.py
+++ b/mani_skill/envs/tasks/fmb/fmb.py
@@ -48,7 +48,7 @@ def __init__(self, *args, robot_uids="panda", robot_init_qpos_noise=0.02, **kwar
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(
+ gpu_memory_config=GPUMemoryConfig(
max_rigid_contact_count=2**21, max_rigid_patch_count=2**20
)
)
diff --git a/mani_skill/envs/tasks/humanoid/humanoid_pick_place.py b/mani_skill/envs/tasks/humanoid/humanoid_pick_place.py
index bbb35c476..efb3a2ca1 100644
--- a/mani_skill/envs/tasks/humanoid/humanoid_pick_place.py
+++ b/mani_skill/envs/tasks/humanoid/humanoid_pick_place.py
@@ -201,13 +201,13 @@ def __init__(self, *args, **kwargs):
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(
+ gpu_memory_config=GPUMemoryConfig(
max_rigid_contact_count=2**22, max_rigid_patch_count=2**21
),
# TODO (stao): G1 robot may need some custom collision disabling as the dextrous fingers may often be close to each other
# and slow down simulation. A temporary fix is to reduce contact_offset value down so that we don't check so many possible
# collisions
- scene_cfg=SceneConfig(contact_offset=0.01),
+ scene_config=SceneConfig(contact_offset=0.01),
)
def _initialize_episode(self, env_idx: torch.Tensor, options: Dict):
diff --git a/mani_skill/envs/tasks/humanoid/humanoid_stand.py b/mani_skill/envs/tasks/humanoid/humanoid_stand.py
index 72e989ab7..4d6717b1a 100644
--- a/mani_skill/envs/tasks/humanoid/humanoid_stand.py
+++ b/mani_skill/envs/tasks/humanoid/humanoid_stand.py
@@ -4,7 +4,7 @@
import sapien
import torch
-from mani_skill.agents.robots import UnitreeH1Simplified, UnitreeG1Simplified
+from mani_skill.agents.robots import UnitreeG1Simplified, UnitreeH1Simplified
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.sensors.camera import CameraConfig
from mani_skill.utils import common, sapien_utils
@@ -74,7 +74,7 @@ def __init__(self, *args, robot_uids="unitree_h1_simplified", **kwargs):
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(
+ gpu_memory_config=GPUMemoryConfig(
max_rigid_contact_count=2**22, max_rigid_patch_count=2**21
)
)
@@ -89,11 +89,9 @@ def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
b = len(env_idx)
standing_keyframe = self.agent.keyframes["standing"]
random_qpos = (
- torch.randn(
- size=(b, self.agent.robot.dof[0]), dtype=torch.float) * 0.05
+ torch.randn(size=(b, self.agent.robot.dof[0]), dtype=torch.float) * 0.05
)
- random_qpos += common.to_tensor(standing_keyframe.qpos,
- device=self.device)
+ random_qpos += common.to_tensor(standing_keyframe.qpos, device=self.device)
self.agent.robot.set_qpos(random_qpos)
self.agent.robot.set_pose(sapien.Pose(p=[0, 0, 0.975]))
@@ -109,7 +107,7 @@ def __init__(self, *args, robot_uids="unitree_g1_simplified_legs", **kwargs):
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(
+ gpu_memory_config=GPUMemoryConfig(
max_rigid_contact_count=2**22, max_rigid_patch_count=2**21
)
)
diff --git a/mani_skill/envs/tasks/mobile_manipulation/open_cabinet_drawer.py b/mani_skill/envs/tasks/mobile_manipulation/open_cabinet_drawer.py
index e6cef00a7..c37cc0104 100644
--- a/mani_skill/envs/tasks/mobile_manipulation/open_cabinet_drawer.py
+++ b/mani_skill/envs/tasks/mobile_manipulation/open_cabinet_drawer.py
@@ -72,7 +72,7 @@ def __init__(
def _default_sim_config(self):
return SimConfig(
spacing=10,
- gpu_memory_cfg=GPUMemoryConfig(
+ gpu_memory_config=GPUMemoryConfig(
max_rigid_contact_count=2**21, max_rigid_patch_count=2**19
),
)
diff --git a/mani_skill/envs/tasks/quadruped/quadruped_reach.py b/mani_skill/envs/tasks/quadruped/quadruped_reach.py
index ffcf18249..ce4d27ed1 100644
--- a/mani_skill/envs/tasks/quadruped/quadruped_reach.py
+++ b/mani_skill/envs/tasks/quadruped/quadruped_reach.py
@@ -28,8 +28,8 @@ def __init__(self, *args, robot_uids="anymal-c", **kwargs):
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(max_rigid_contact_count=2**20),
- scene_cfg=SceneConfig(
+ gpu_memory_config=GPUMemoryConfig(max_rigid_contact_count=2**20),
+ scene_config=SceneConfig(
solver_position_iterations=4, solver_velocity_iterations=0
),
)
diff --git a/mani_skill/envs/tasks/quadruped/quadruped_spin.py b/mani_skill/envs/tasks/quadruped/quadruped_spin.py
index 13f81405c..c64387762 100644
--- a/mani_skill/envs/tasks/quadruped/quadruped_spin.py
+++ b/mani_skill/envs/tasks/quadruped/quadruped_spin.py
@@ -28,8 +28,8 @@ def __init__(self, *args, robot_uids="anymal-c", **kwargs):
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(max_rigid_contact_count=2**20),
- scene_cfg=SceneConfig(
+ gpu_memory_config=GPUMemoryConfig(max_rigid_contact_count=2**20),
+ scene_config=SceneConfig(
solver_position_iterations=4, solver_velocity_iterations=0
),
)
diff --git a/mani_skill/envs/tasks/rotate_cube.py b/mani_skill/envs/tasks/rotate_cube.py
index ef20b9419..6db451cc2 100644
--- a/mani_skill/envs/tasks/rotate_cube.py
+++ b/mani_skill/envs/tasks/rotate_cube.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Tuple, Union
+from typing import Any, Dict, Tuple
import numpy as np
import torch
@@ -26,11 +26,11 @@ class RotateCubeEnv(BaseEnv):
SUPPORTED_ROBOTS = ["trifingerpro"]
# Specify some supported robot types
- agent: Union[TriFingerPro]
+ agent: TriFingerPro
# Specify default simulation/gpu memory configurations.
- sim_cfg = SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(
+ sim_config = SimConfig(
+ gpu_memory_config=GPUMemoryConfig(
found_lost_pairs_capacity=2**25, max_rigid_patch_count=2**18
)
)
diff --git a/mani_skill/envs/tasks/tabletop/assembling_kits.py b/mani_skill/envs/tasks/tabletop/assembling_kits.py
index 9187f63d3..f0719b44d 100644
--- a/mani_skill/envs/tasks/tabletop/assembling_kits.py
+++ b/mani_skill/envs/tasks/tabletop/assembling_kits.py
@@ -67,7 +67,7 @@ def __init__(
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(max_rigid_contact_count=2**20)
+ gpu_memory_config=GPUMemoryConfig(max_rigid_contact_count=2**20)
)
@property
diff --git a/mani_skill/envs/tasks/tabletop/pick_clutter_ycb.py b/mani_skill/envs/tasks/tabletop/pick_clutter_ycb.py
index 7537ad78f..89a42880d 100644
--- a/mani_skill/envs/tasks/tabletop/pick_clutter_ycb.py
+++ b/mani_skill/envs/tasks/tabletop/pick_clutter_ycb.py
@@ -68,7 +68,7 @@ def __init__(
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(
+ gpu_memory_config=GPUMemoryConfig(
max_rigid_contact_count=2**21, max_rigid_patch_count=2**19
)
)
@@ -119,14 +119,14 @@ def _load_scene(self, options: dict):
for i, eps_idx in enumerate(eps_idxs):
self.selectable_target_objects.append([])
episode = self._episodes[eps_idx]
- for actor_cfg in episode["actors"]:
- builder = self._load_model(actor_cfg["model_id"])
- init_pose = actor_cfg["pose"]
+ for actor_config in episode["actors"]:
+ builder = self._load_model(actor_config["model_id"])
+ init_pose = actor_config["pose"]
builder.initial_pose = sapien.Pose(p=init_pose[:3], q=init_pose[3:])
builder.set_scene_idxs([i])
- obj = builder.build(name=f"set_{i}_{actor_cfg['model_id']}")
+ obj = builder.build(name=f"set_{i}_{actor_config['model_id']}")
all_objects.append(obj)
- if actor_cfg["rep_pts"] is not None:
+ if actor_config["rep_pts"] is not None:
# rep_pts is representative points, representing visible points
# we only permit selecting target objects that are visible
self.selectable_target_objects[-1].append(obj)
diff --git a/mani_skill/envs/tasks/tabletop/place_sphere.py b/mani_skill/envs/tasks/tabletop/place_sphere.py
index 21997cc33..bd38f5b60 100644
--- a/mani_skill/envs/tasks/tabletop/place_sphere.py
+++ b/mani_skill/envs/tasks/tabletop/place_sphere.py
@@ -1,14 +1,16 @@
from typing import Any, Dict, Union
+import gymnasium as gym
+import matplotlib.pyplot as plt
import numpy as np
+import sapien
import torch
import torch.random
-import sapien
from transforms3d.euler import euler2quat
-from mani_skill.envs.utils import randomization
from mani_skill.agents.robots import Fetch, Panda, Xmate3Robotiq
from mani_skill.envs.sapien_env import BaseEnv
+from mani_skill.envs.utils import randomization
from mani_skill.sensors.camera import CameraConfig
from mani_skill.utils import common, sapien_utils
from mani_skill.utils.building import actors
@@ -17,8 +19,6 @@
from mani_skill.utils.structs import Pose
from mani_skill.utils.structs.types import Array, GPUMemoryConfig, SimConfig
-import matplotlib.pyplot as plt
-import gymnasium as gym
@register_env("PlaceSphere-v1", max_episode_steps=50)
class PlaceSphereEnv(BaseEnv):
@@ -30,7 +30,7 @@ class PlaceSphereEnv(BaseEnv):
Randomizations
--------------
The position of the bin and the sphere are randomized: The bin is inited in [0, 0.1]x[-0.1, 0.1], and the sphere is inited in [-0.1, -0.05]x[-0.1, 0.1]
-
+
Success Conditions
------------------
The sphere is place on the top of the bin. The robot remains static and the gripper is not closed at the end state
@@ -42,12 +42,20 @@ class PlaceSphereEnv(BaseEnv):
agent: Union[Panda, Xmate3Robotiq, Fetch]
# set some commonly used values
- radius = 0.02 # radius of the sphere
- inner_side_half_len = 0.02 # side length of the bin's inner square
- short_side_half_size = 0.0025 # length of the shortest edge of the block
- block_half_size = [short_side_half_size, 2*short_side_half_size+inner_side_half_len, 2*short_side_half_size+inner_side_half_len] # The bottom block of the bin, which is larger: The list represents the half length of the block along the [x, y, z] axis respectively.
- edge_block_half_size = [short_side_half_size, 2*short_side_half_size+inner_side_half_len, 2*short_side_half_size] # The edge block of the bin, which is smaller. The representations are similar to the above one
-
+ radius = 0.02 # radius of the sphere
+ inner_side_half_len = 0.02 # side length of the bin's inner square
+ short_side_half_size = 0.0025 # length of the shortest edge of the block
+ block_half_size = [
+ short_side_half_size,
+ 2 * short_side_half_size + inner_side_half_len,
+ 2 * short_side_half_size + inner_side_half_len,
+ ] # The bottom block of the bin, which is larger: The list represents the half length of the block along the [x, y, z] axis respectively.
+ edge_block_half_size = [
+ short_side_half_size,
+ 2 * short_side_half_size + inner_side_half_len,
+ 2 * short_side_half_size,
+ ] # The edge block of the bin, which is smaller. The representations are similar to the above one
+
def __init__(self, *args, robot_uids="panda", robot_init_qpos_noise=0.02, **kwargs):
self.robot_init_qpos_noise = robot_init_qpos_noise
super().__init__(*args, robot_uids=robot_uids, **kwargs)
@@ -55,7 +63,7 @@ def __init__(self, *args, robot_uids="panda", robot_init_qpos_noise=0.02, **kwar
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(
+ gpu_memory_config=GPUMemoryConfig(
found_lost_pairs_capacity=2**25, max_rigid_patch_count=2**18
)
)
@@ -81,15 +89,15 @@ def _default_human_render_camera_configs(self):
return CameraConfig(
"render_camera", pose=pose, width=512, height=512, fov=1, near=0.01, far=100
)
-
+
def _build_bin(self, radius):
builder = self.scene.create_actor_builder()
-
+
# init the locations of the basic blocks
- dx = self.block_half_size[1] - self.block_half_size[0]
- dy = self.block_half_size[1] - self.block_half_size[0]
+ dx = self.block_half_size[1] - self.block_half_size[0]
+ dy = self.block_half_size[1] - self.block_half_size[0]
dz = self.edge_block_half_size[2] + self.block_half_size[0]
-
+
# build the bin bottom and edge blocks
poses = [
sapien.Pose([0, 0, 0]),
@@ -102,14 +110,22 @@ def _build_bin(self, radius):
[self.block_half_size[1], self.block_half_size[2], self.block_half_size[0]],
self.edge_block_half_size,
self.edge_block_half_size,
- [self.edge_block_half_size[1], self.edge_block_half_size[0], self.edge_block_half_size[2]],
- [self.edge_block_half_size[1], self.edge_block_half_size[0], self.edge_block_half_size[2]],
+ [
+ self.edge_block_half_size[1],
+ self.edge_block_half_size[0],
+ self.edge_block_half_size[2],
+ ],
+ [
+ self.edge_block_half_size[1],
+ self.edge_block_half_size[0],
+ self.edge_block_half_size[2],
+ ],
]
for pose, half_size in zip(poses, half_sizes):
builder.add_box_collision(pose, half_size)
builder.add_box_visual(pose, half_size)
- # build the kinematic bin
+ # build the kinematic bin
return builder.build_kinematic(name="bin")
def _load_scene(self, options: dict):
@@ -118,7 +134,7 @@ def _load_scene(self, options: dict):
env=self, robot_init_qpos_noise=self.robot_init_qpos_noise
)
self.table_scene.build()
-
+
# load the sphere
self.obj = actors.build_sphere(
self.scene,
@@ -127,7 +143,7 @@ def _load_scene(self, options: dict):
name="sphere",
body_type="dynamic",
)
-
+
# load the bin
self.bin = self._build_bin(self.radius)
@@ -136,21 +152,29 @@ def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
# init the table scene
b = len(env_idx)
self.table_scene.initialize(env_idx)
-
+
# init the sphere in the first 1/4 zone along the x-axis (so that it doesn't collide the bin)
xyz = torch.zeros((b, 3))
- xyz[..., 0] = (torch.rand((b, 1)) * 0.05 - 0.1)[..., 0] # first 1/4 zone of x ([-0.1, -0.05])
- xyz[..., 1] = (torch.rand((b, 1)) * 0.2 - 0.1)[..., 0] # spanning all possible ys
- xyz[..., 2] = self.radius # on the table
+ xyz[..., 0] = (torch.rand((b, 1)) * 0.05 - 0.1)[
+ ..., 0
+ ] # first 1/4 zone of x ([-0.1, -0.05])
+ xyz[..., 1] = (torch.rand((b, 1)) * 0.2 - 0.1)[
+ ..., 0
+ ] # spanning all possible ys
+ xyz[..., 2] = self.radius # on the table
q = [1, 0, 0, 0]
obj_pose = Pose.create_from_pq(p=xyz, q=q)
self.obj.set_pose(obj_pose)
# init the bin in the last 1/2 zone along the x-axis (so that it doesn't collide the sphere)
pos = torch.zeros((b, 3))
- pos[:, 0] = torch.rand((b, 1))[..., 0] * 0.1 # the last 1/2 zone of x ([0, 0.1])
- pos[:, 1] = torch.rand((b, 1))[..., 0] * 0.2 - 0.1 # spanning all possible ys
- pos[:, 2] = self.block_half_size[0] # on the table
+ pos[:, 0] = (
+ torch.rand((b, 1))[..., 0] * 0.1
+ ) # the last 1/2 zone of x ([0, 0.1])
+ pos[:, 1] = (
+ torch.rand((b, 1))[..., 0] * 0.2 - 0.1
+ ) # spanning all possible ys
+ pos[:, 2] = self.block_half_size[0] # on the table
q = [1, 0, 0, 0]
bin_pose = Pose.create_from_pq(p=pos, q=q)
self.bin.set_pose(bin_pose)
@@ -159,11 +183,10 @@ def evaluate(self):
pos_obj = self.obj.pose.p
pos_bin = self.bin.pose.p
offset = pos_obj - pos_bin
- xy_flag = (
- torch.linalg.norm(offset[..., :2], axis=1)
- <= 0.005
+ xy_flag = torch.linalg.norm(offset[..., :2], axis=1) <= 0.005
+ z_flag = (
+ torch.abs(offset[..., 2] - self.radius - self.block_half_size[0]) <= 0.005
)
- z_flag = torch.abs(offset[..., 2] - self.radius - self.block_half_size[0]) <= 0.005
is_obj_on_bin = torch.logical_and(xy_flag, z_flag)
is_obj_static = self.obj.is_static(lin_thresh=1e-2, ang_thresh=0.5)
is_obj_grasped = self.agent.is_grasping(self.obj)
@@ -179,7 +202,7 @@ def _get_obs_extra(self, info: Dict):
obs = dict(
is_grasped=info["is_obj_grasped"],
tcp_pose=self.agent.tcp.pose.raw_pose,
- bin_pos=self.bin.pose.p
+ bin_pos=self.bin.pose.p,
)
if "state" in self.obs_mode:
obs.update(
@@ -197,7 +220,7 @@ def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
# grasp and place reward
obj_pos = self.obj.pose.p
- bin_pos = self.bin.pose.p
+ self.bin.pose.p
bin_top_pos = self.bin.pose.p.clone()
bin_top_pos[:, 2] = bin_top_pos[:, 2] + self.block_half_size[0] + self.radius
obj_to_bin_top_dist = torch.linalg.norm(bin_top_pos - obj_pos, axis=1)
@@ -205,22 +228,24 @@ def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
reward[info["is_obj_grasped"]] = (4 + place_reward)[info["is_obj_grasped"]]
# ungrasp and static reward
- gripper_width = (self.agent.robot.get_qlimits()[0, -1, 1] * 2).to(
- self.device
- )
+ gripper_width = (self.agent.robot.get_qlimits()[0, -1, 1] * 2).to(self.device)
is_obj_grasped = info["is_obj_grasped"]
ungrasp_reward = (
torch.sum(self.agent.robot.get_qpos()[:, -2:], axis=1) / gripper_width
)
- ungrasp_reward[~is_obj_grasped] = 16.0 # give ungrasp a bigger reward, so that it exceeds the robot static reward and the gripper can close
+ ungrasp_reward[
+ ~is_obj_grasped
+ ] = 16.0 # give ungrasp a bigger reward, so that it exceeds the robot static reward and the gripper can close
v = torch.linalg.norm(self.obj.linear_velocity, axis=1)
av = torch.linalg.norm(self.obj.angular_velocity, axis=1)
static_reward = 1 - torch.tanh(v * 10 + av)
- robot_static_reward = self.agent.is_static(0.2) # keep the robot static at the end state, since the sphere may spin when being placed on top
+ robot_static_reward = self.agent.is_static(
+ 0.2
+ ) # keep the robot static at the end state, since the sphere may spin when being placed on top
reward[info["is_obj_on_bin"]] = (
6 + (ungrasp_reward + static_reward + robot_static_reward) / 3.0
)[info["is_obj_on_bin"]]
-
+
# success reward
reward[info["success"]] = 13
return reward
@@ -229,7 +254,3 @@ def compute_normalized_dense_reward(self, obs: Any, action: Array, info: Dict):
# this should be equal to compute_dense_reward / max possible reward
max_reward = 13.0
return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
-
-
-
-
diff --git a/mani_skill/envs/tasks/tabletop/push_cube.py b/mani_skill/envs/tasks/tabletop/push_cube.py
index 9a97ae43c..ff5408ee0 100644
--- a/mani_skill/envs/tasks/tabletop/push_cube.py
+++ b/mani_skill/envs/tasks/tabletop/push_cube.py
@@ -70,7 +70,7 @@ def __init__(self, *args, robot_uids="panda", robot_init_qpos_noise=0.02, **kwar
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(
+ gpu_memory_config=GPUMemoryConfig(
found_lost_pairs_capacity=2**25, max_rigid_patch_count=2**18
)
)
diff --git a/mani_skill/envs/tasks/tabletop/push_t.py b/mani_skill/envs/tasks/tabletop/push_t.py
index 7516c8918..9b51a0ee9 100644
--- a/mani_skill/envs/tasks/tabletop/push_t.py
+++ b/mani_skill/envs/tasks/tabletop/push_t.py
@@ -1,6 +1,7 @@
from typing import Any, Dict, Union
import numpy as np
+import sapien
import torch
import torch.random
from transforms3d.euler import euler2quat
@@ -14,7 +15,7 @@
from mani_skill.utils.scene_builder.table import TableSceneBuilder
from mani_skill.utils.structs import Pose
from mani_skill.utils.structs.types import Array, GPUMemoryConfig, SimConfig
-import sapien
+
# extending TableSceneBuilder and only making 2 changes:
# 1.Making table smooth and white, 2. adding support for keyframes of new robots - panda stick
@@ -23,7 +24,17 @@ def initialize(self, env_idx: torch.Tensor):
super().initialize(env_idx)
b = len(env_idx)
if self.env.robot_uids == "panda_stick":
- qpos = np.array([0.662,0.212,0.086,-2.685,-.115,2.898,1.673,])
+ qpos = np.array(
+ [
+ 0.662,
+ 0.212,
+ 0.086,
+ -2.685,
+ -0.115,
+ 2.898,
+ 1.673,
+ ]
+ )
qpos = (
self.env._episode_rng.normal(
0, self.robot_init_qpos_noise, (b, len(qpos))
@@ -32,11 +43,16 @@ def initialize(self, env_idx: torch.Tensor):
)
self.env.agent.reset(qpos)
self.env.agent.robot.set_pose(sapien.Pose([-0.615, 0, 0]))
+
def build(self):
super().build()
- #cheap way to un-texture table
+ # cheap way to un-texture table
for part in self.table._objs:
- for triangle in part.find_component_by_type(sapien.render.RenderBodyComponent).render_shapes[0].parts:
+ for triangle in (
+ part.find_component_by_type(sapien.render.RenderBodyComponent)
+ .render_shapes[0]
+ .parts
+ ):
triangle.material.set_base_color(np.array([255, 255, 255, 255]) / 255)
triangle.material.set_base_color_texture(None)
triangle.material.set_normal_texture(None)
@@ -45,6 +61,7 @@ def build(self):
triangle.material.set_metallic_texture(None)
triangle.material.set_roughness_texture(None)
+
@register_env("PushT-v1", max_episode_steps=100)
class PushTEnv(BaseEnv):
"""
@@ -77,7 +94,7 @@ class PushTEnv(BaseEnv):
- End-effector initial position [-0.322, 0.284, 0.024]
- intersection % threshold for success 90%
- Table View Camera parameters sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
-
+
TODO's (xhin):
--------------
- Add hand mounted camera for panda_stick robot, for visual rl
@@ -102,8 +119,8 @@ class PushTEnv(BaseEnv):
# Hand crafted params to match visual of real life setup
# T Goal initial position on table
- goal_offset = torch.tensor([-0.156,-0.1])
- goal_z_rot = (5/3)*np.pi
+ goal_offset = torch.tensor([-0.156, -0.1])
+ goal_z_rot = (5 / 3) * np.pi
# end effector goal - NOTE that chaning this will not change the actual
# ee starting position of the robot - need to change joint position resting
@@ -115,19 +132,21 @@ class PushTEnv(BaseEnv):
# intersection threshold for success in T position
intersection_thresh = 0.90
- #T block design choices
+ # T block design choices
T_mass = 0.8
T_dynamic_friction = 3
T_static_friction = 3
- def __init__(self, *args, robot_uids="panda_stick", robot_init_qpos_noise=0.02,**kwargs):
+ def __init__(
+ self, *args, robot_uids="panda_stick", robot_init_qpos_noise=0.02, **kwargs
+ ):
self.robot_init_qpos_noise = robot_init_qpos_noise
super().__init__(*args, robot_uids=robot_uids, **kwargs)
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(
+ gpu_memory_config=GPUMemoryConfig(
found_lost_pairs_capacity=2**25, max_rigid_patch_count=2**18
)
)
@@ -165,63 +184,89 @@ def _load_scene(self, options: dict):
env=self, robot_init_qpos_noise=self.robot_init_qpos_noise
)
self.table_scene.build()
-
+
# returns 3d cad of create_tee - center of mass at (0,0,0)
# cad Tee is upside down (both 3D tee and target)
- TARGET_RED = np.array([194, 19, 22, 255]) / 255 # same as mani_skill.utils.building.actors.common - goal target
+ TARGET_RED = (
+ np.array([194, 19, 22, 255]) / 255
+ ) # same as mani_skill.utils.building.actors.common - goal target
+
def create_tee(name="tee", target=False, base_color=TARGET_RED):
- # dimensions of boxes that make tee
+ # dimensions of boxes that make tee
# box2 is same as box1, except (3/4) the lenght, and rotated 90 degrees
# these dimensions are an exact replica of the 3D tee model given by diffusion policy: https://cad.onshape.com/documents/f1140134e38f6ed6902648d5/w/a78cf81827600e4ff4058d03/e/f35f57fb7589f72e05c76caf
- box1_half_w = 0.2/2
- box1_half_h = 0.05/2
- half_thickness = 0.04/2 if not target else 1e-4
+ box1_half_w = 0.2 / 2
+ box1_half_h = 0.05 / 2
+ half_thickness = 0.04 / 2 if not target else 1e-4
# we have to center tee at its com so rotations are applied to com
# vertical block is (3/4) size of horizontal block, so
# center of mass is (1*com_horiz + (3/4)*com_vert) / (1+(3/4))
# # center of mass is (1*(0,0)) + (3/4)*(0,(.025+.15)/2)) / (1+(3/4)) = (0,0.0375)
com_y = 0.0375
-
+
builder = self.scene.create_actor_builder()
- first_block_pose = sapien.Pose([0., 0.-com_y, 0.])
+ first_block_pose = sapien.Pose([0.0, 0.0 - com_y, 0.0])
first_block_size = [box1_half_w, box1_half_h, half_thickness]
if not target:
builder._mass = self.T_mass
tee_material = sapien.pysapien.physx.PhysxMaterial(
- static_friction=self.T_dynamic_friction,
- dynamic_friction=self.T_static_friction,
- restitution=0
+ static_friction=self.T_dynamic_friction,
+ dynamic_friction=self.T_static_friction,
+ restitution=0,
+ )
+ builder.add_box_collision(
+ pose=first_block_pose,
+ half_size=first_block_size,
+ material=tee_material,
)
- builder.add_box_collision(pose=first_block_pose, half_size=first_block_size, material=tee_material)
- #builder.add_box_collision(pose=first_block_pose, half_size=first_block_size)
- builder.add_box_visual(pose=first_block_pose, half_size=first_block_size, material=sapien.render.RenderMaterial(
- base_color=base_color,
- ),)
+ # builder.add_box_collision(pose=first_block_pose, half_size=first_block_size)
+ builder.add_box_visual(
+ pose=first_block_pose,
+ half_size=first_block_size,
+ material=sapien.render.RenderMaterial(
+ base_color=base_color,
+ ),
+ )
# for the second block (vertical part), we translate y by 4*(box1_half_h)-com_y to align flush with horizontal block
# note that the cad model tee made here is upside down
- second_block_pose = sapien.Pose([0., 4*(box1_half_h)-com_y, 0.])
- second_block_size = [box1_half_h, (3/4)*(box1_half_w), half_thickness]
+ second_block_pose = sapien.Pose([0.0, 4 * (box1_half_h) - com_y, 0.0])
+ second_block_size = [box1_half_h, (3 / 4) * (box1_half_w), half_thickness]
if not target:
- builder.add_box_collision(pose=second_block_pose, half_size=second_block_size,material=tee_material)
- #builder.add_box_collision(pose=second_block_pose, half_size=second_block_size)
- builder.add_box_visual(pose=second_block_pose, half_size=second_block_size, material=sapien.render.RenderMaterial(
- base_color=base_color,
- ),)
+ builder.add_box_collision(
+ pose=second_block_pose,
+ half_size=second_block_size,
+ material=tee_material,
+ )
+ # builder.add_box_collision(pose=second_block_pose, half_size=second_block_size)
+ builder.add_box_visual(
+ pose=second_block_pose,
+ half_size=second_block_size,
+ material=sapien.render.RenderMaterial(
+ base_color=base_color,
+ ),
+ )
if not target:
return builder.build(name=name)
- else: return builder.build_kinematic(name=name)
+ else:
+ return builder.build_kinematic(name=name)
self.tee = create_tee(name="Tee", target=False)
- self.goal_tee = create_tee(name="goal_Tee", target=True, base_color=np.array([128,128,128,255])/255)
+ self.goal_tee = create_tee(
+ name="goal_Tee",
+ target=True,
+ base_color=np.array([128, 128, 128, 255]) / 255,
+ )
# adding end-effector end-episode goal position
builder = self.scene.create_actor_builder()
builder.add_cylinder_visual(
radius=0.02,
half_length=1e-4,
- material=sapien.render.RenderMaterial(base_color=np.array([128, 128, 128, 255]) / 255),
+ material=sapien.render.RenderMaterial(
+ base_color=np.array([128, 128, 128, 255]) / 255
+ ),
)
self.ee_goal_pos = builder.build_kinematic(name="goal_ee")
@@ -230,52 +275,72 @@ def create_tee(name="tee", target=False, base_color=TARGET_RED):
uv_half_width = 0.15
self.uv_half_width = uv_half_width
self.res = res
- oned_grid = (torch.arange(res, dtype=torch.float32).view(1,res).repeat(res,1) - (res/2))
- self.uv_grid = (torch.cat([oned_grid.unsqueeze(0), (-1*oned_grid.T).unsqueeze(0)], dim=0) + 0.5) / ((res/2)/uv_half_width)
+ oned_grid = torch.arange(res, dtype=torch.float32).view(1, res).repeat(
+ res, 1
+ ) - (res / 2)
+ self.uv_grid = (
+ torch.cat([oned_grid.unsqueeze(0), (-1 * oned_grid.T).unsqueeze(0)], dim=0)
+ + 0.5
+ ) / ((res / 2) / uv_half_width)
self.uv_grid = self.uv_grid.to(self.device)
- self.homo_uv = torch.cat([self.uv_grid, torch.ones_like(self.uv_grid[0]).unsqueeze(0)], dim=0)
-
+ self.homo_uv = torch.cat(
+ [self.uv_grid, torch.ones_like(self.uv_grid[0]).unsqueeze(0)], dim=0
+ )
+
# tee render
# tee is made of two different boxes, and then translated by center of mass
- self.center_of_mass = (0,0.0375) #in frame of upside tee with center of horizontal box (add cetner of mass to get to real tee frame)
- box1 = torch.tensor([[-0.1, 0.025], [0.1, 0.025], [-0.1, -0.025], [0.1, -0.025]])
- box2 = torch.tensor([[-0.025, 0.175], [0.025, 0.175], [-0.025, 0.025], [0.025, 0.025]])
+ self.center_of_mass = (
+ 0,
+ 0.0375,
+ ) # in frame of upside tee with center of horizontal box (add cetner of mass to get to real tee frame)
+ box1 = torch.tensor(
+ [[-0.1, 0.025], [0.1, 0.025], [-0.1, -0.025], [0.1, -0.025]]
+ )
+ box2 = torch.tensor(
+ [[-0.025, 0.175], [0.025, 0.175], [-0.025, 0.025], [0.025, 0.025]]
+ )
box1[:, 1] -= self.center_of_mass[1]
box2[:, 1] -= self.center_of_mass[1]
- #convert tee boxes to indices
- box1 *= ((res/2)/uv_half_width)
- box1 += (res/2)
+ # convert tee boxes to indices
+ box1 *= (res / 2) / uv_half_width
+ box1 += res / 2
- box2 *= ((res/2)/uv_half_width)
- box2 += (res/2)
+ box2 *= (res / 2) / uv_half_width
+ box2 += res / 2
box1 = box1.long()
box2 = box2.long()
- self.tee_render = torch.zeros(res,res)
+ self.tee_render = torch.zeros(res, res)
# image map has flipped x and y, set values in transpose to undo
- self.tee_render.T[box1[0,0]:box1[1,0], box1[2,1]:box1[0,1]] = 1
- self.tee_render.T[box2[0,0]:box2[1,0], box2[2,1]:box2[0,1]] = 1
+ self.tee_render.T[box1[0, 0] : box1[1, 0], box1[2, 1] : box1[0, 1]] = 1
+ self.tee_render.T[box2[0, 0] : box2[1, 0], box2[2, 1] : box2[0, 1]] = 1
# image map y is flipped of xy plane, flip to unflip
self.tee_render = self.tee_render.flip(0).to(self.device)
-
- goal_fake_quat = torch.tensor([(torch.tensor([self.goal_z_rot])/2).cos(),0,0,0.0]).unsqueeze(0)
- zrot = self.quat_to_zrot(goal_fake_quat).squeeze(0) # 3x3 rot matrix for goal to world transform
+
+ goal_fake_quat = torch.tensor(
+ [(torch.tensor([self.goal_z_rot]) / 2).cos(), 0, 0, 0.0]
+ ).unsqueeze(0)
+ zrot = self.quat_to_zrot(goal_fake_quat).squeeze(
+ 0
+ ) # 3x3 rot matrix for goal to world transform
goal_trans = torch.eye(3)
- goal_trans[:2,:2] = zrot[:2,:2]
+ goal_trans[:2, :2] = zrot[:2, :2]
goal_trans[0:2, 2] = self.goal_offset
- self.world_to_goal_trans = torch.linalg.inv(goal_trans).to(self.device) # this is just a 3x3 matrix (2d homogenious transform)
+ self.world_to_goal_trans = torch.linalg.inv(goal_trans).to(
+ self.device
+ ) # this is just a 3x3 matrix (2d homogenious transform)
def quat_to_z_euler(self, quats):
assert len(quats.shape) == 2 and quats.shape[-1] == 4
# z rotation == can be defined by just qw = cos(alpha/2), so alpha = 2*cos^{-1}(qw)
# for fixing quaternion double covering
- #for some reason, torch.sign() had bugs???
- signs = torch.ones_like(quats[:,-1])
- signs[quats[:,-1] < 0] = -1.0
- qw = quats[:,0] * signs
- z_euler = 2*qw.acos()
+ # for some reason, torch.sign() had bugs???
+ signs = torch.ones_like(quats[:, -1])
+ signs[quats[:, -1] < 0] = -1.0
+ qw = quats[:, 0] * signs
+ z_euler = 2 * qw.acos()
return z_euler
def quat_to_zrot(self, quats):
@@ -284,58 +349,78 @@ def quat_to_zrot(self, quats):
# output is batch of rotation matrices (b,3,3)
alphas = self.quat_to_z_euler(quats)
# constructing rot matrix with rotation around z
- rot_mats = torch.zeros(quats.shape[0], 3,3).to(quats.device)
- rot_mats[:,2,2] = 1
- rot_mats[:,0,0] = alphas.cos()
- rot_mats[:,1,1] = alphas.cos()
- rot_mats[:,0,1] = -alphas.sin()
- rot_mats[:,1,0] = alphas.sin()
+ rot_mats = torch.zeros(quats.shape[0], 3, 3).to(quats.device)
+ rot_mats[:, 2, 2] = 1
+ rot_mats[:, 0, 0] = alphas.cos()
+ rot_mats[:, 1, 1] = alphas.cos()
+ rot_mats[:, 0, 1] = -alphas.sin()
+ rot_mats[:, 1, 0] = alphas.sin()
return rot_mats
-
+
def pseudo_render_intersection(self):
"""'pseudo render' algo for calculating the intersection
- made custom 'psuedo renderer' to compute intersection area
+ made custom 'psuedo renderer' to compute intersection area
all computation in parallel on cuda, zero explicit loops
views blocks in 2d in the goal tee frame to see overlap"""
# we are given T_{a->w} where a == actor frame and w == world frame
# we are given T_{g->w} where g == goal frame and w == world frame
# applying T_{a->w} and then T_{w->g}, we get the actor's orientation in the goal tee's frame
# T_{w->g} is T_{g->w}^{-1}, we already have the goal's orientation, and it doesn't change
- tee_to_world_trans = self.quat_to_zrot(self.tee.pose.q) # should be (b,3,3) rot matrices
- tee_to_world_trans[:,0:2,2] = self.tee.pose.p[:,:2] # should be (b,3,3) rigid trans matrices
-
+ tee_to_world_trans = self.quat_to_zrot(
+ self.tee.pose.q
+ ) # should be (b,3,3) rot matrices
+ tee_to_world_trans[:, 0:2, 2] = self.tee.pose.p[
+ :, :2
+ ] # should be (b,3,3) rigid trans matrices
+
# these matrices convert egocentric 3d tee to 2d goal tee frame
- tee_to_goal_trans = self.world_to_goal_trans @ tee_to_world_trans # should be (b,3,3) rigid trans matrices
+ tee_to_goal_trans = (
+ self.world_to_goal_trans @ tee_to_world_trans
+ ) # should be (b,3,3) rigid trans matrices
- # making homogenious coords of uv map to apply transformations to view tee in goal tee frame
+ # making homogenious coords of uv map to apply transformations to view tee in goal tee frame
b = tee_to_world_trans.shape[0]
res = self.uv_grid.shape[1]
homo_uv = self.homo_uv
- #finally, get uv coordinates of tee in goal tee frame
- tees_in_goal_frame = (tee_to_goal_trans @ homo_uv.view(3,-1)).view(b,3,res,res)
+ # finally, get uv coordinates of tee in goal tee frame
+ tees_in_goal_frame = (tee_to_goal_trans @ homo_uv.view(3, -1)).view(
+ b, 3, res, res
+ )
# convert from homogenious coords to normal coords
- tees_in_goal_frame = tees_in_goal_frame[:,0:2,:,:] / tees_in_goal_frame[:,-1,:,:].unsqueeze(1) # now (b,2,res,res)
+ tees_in_goal_frame = tees_in_goal_frame[:, 0:2, :, :] / tees_in_goal_frame[
+ :, -1, :, :
+ ].unsqueeze(
+ 1
+ ) # now (b,2,res,res)
- #we now have a collection of coordinates xy that are the coordinates of the tees in the goal frame
+ # we now have a collection of coordinates xy that are the coordinates of the tees in the goal frame
# we just extract the indices in the uv map where the egocentic T is, to get the transformed T coords
# this works because while we transformed the coordinates of the uv map -
# the indices where the egocentric T is is still the indices of the T in the uv map (indices of uv map never chnaged, just values)
- tee_coords = tees_in_goal_frame[:, :, self.tee_render==1].view(b,2,-1) # (b,2,num_points_in_tee)
-
- #convert tee_coords to indices - this is basically a batch of indices - same shape as tee_coords
+ tee_coords = tees_in_goal_frame[:, :, self.tee_render == 1].view(
+ b, 2, -1
+ ) # (b,2,num_points_in_tee)
+
+ # convert tee_coords to indices - this is basically a batch of indices - same shape as tee_coords
# this is the inverse function of creating the uv map from image indices used in load_scene
- tee_indices = (tee_coords * ((res/2)/self.uv_half_width) + (res/2)).long().view(b,2,-1) # (b,2,num_points_in_tee)
+ tee_indices = (
+ (tee_coords * ((res / 2) / self.uv_half_width) + (res / 2))
+ .long()
+ .view(b, 2, -1)
+ ) # (b,2,num_points_in_tee)
# setting all of our work in image format to compare with egocentric image of goal T
- final_renders = torch.zeros(b,res,res).to(self.device)
+ final_renders = torch.zeros(b, res, res).to(self.device)
# for batch indexing
num_tee_pixels = tee_indices.shape[-1]
- batch_indices = torch.arange(b).view(-1,1).repeat(1,num_tee_pixels).to(self.device)
+ batch_indices = (
+ torch.arange(b).view(-1, 1).repeat(1, num_tee_pixels).to(self.device)
+ )
# # ensure no out of bounds indexing - it's fine to not fully 'render' tee, just need to fully see goal tee which is insured
# # because we are in the goal tee frame, and 'cad' tee render setup of egocentric view includes full tee
- # # also, the reward isn't miou, it's intersection area / goal area - don't need union -> don't need full T 'render'
+ # # also, the reward isn't miou, it's intersection area / goal area - don't need union -> don't need full T 'render'
# #ugly solution for now to keep parallelism no loop - set out of bound image t indices to [0,0]
# # anywhere where x or y is out of bounds, make indices (0,0)
invalid_xs = (tee_indices[:, 0, :] < 0) | (tee_indices[:, 0, :] >= self.res)
@@ -345,15 +430,17 @@ def pseudo_render_intersection(self):
tee_indices[:, 0, :][invalid_ys] = 0
tee_indices[:, 1, :][invalid_ys] = 0
- final_renders[batch_indices, tee_indices[:,0,:], tee_indices[:,1,:]] = 1
+ final_renders[batch_indices, tee_indices[:, 0, :], tee_indices[:, 1, :]] = 1
# coord to image fix - need to transpose each image in the batch, then reverse y coords to correctly visualize
- final_renders = final_renders.permute(0,2,1).flip(1)
+ final_renders = final_renders.permute(0, 2, 1).flip(1)
# finally, we can calculate intersection/goal_area for reward
- intersection = (final_renders.bool() & self.tee_render.bool()).sum(dim=[-1,-2]).float()
+ intersection = (
+ (final_renders.bool() & self.tee_render.bool()).sum(dim=[-1, -2]).float()
+ )
goal_area = self.tee_render.bool().sum().float()
- reward = (intersection / goal_area)
+ reward = intersection / goal_area
# del tee_to_world_trans; del tee_to_goal_trans; del tees_in_goal_frame; del tee_coords; del tee_indices
# del final_renders; del invalid_xs; del invalid_ys; batch_indices; del intersection; del goal_area
@@ -378,18 +465,24 @@ def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
)
)
- # randomization code that randomizes the x, y position of the tee we
+ # randomization code that randomizes the x, y position of the tee we
# goal tee is alredy at y = -0.1 relative to robot, so we allow the tee to be only -0.2 y relative to robot arm
- target_region_xyz[..., 0] += torch.rand(b) * (self.tee_spawnbox_xlength) + self.tee_spawnbox_xoffset
- target_region_xyz[..., 1] += torch.rand(b) * (self.tee_spawnbox_ylength) + self.tee_spawnbox_yoffset
+ target_region_xyz[..., 0] += (
+ torch.rand(b) * (self.tee_spawnbox_xlength) + self.tee_spawnbox_xoffset
+ )
+ target_region_xyz[..., 1] += (
+ torch.rand(b) * (self.tee_spawnbox_ylength) + self.tee_spawnbox_yoffset
+ )
- target_region_xyz[..., 2] = 0.04/2 + 1e-3 #this is the half thickness of the tee plus a little
+ target_region_xyz[..., 2] = (
+ 0.04 / 2 + 1e-3
+ ) # this is the half thickness of the tee plus a little
# rotation for pose is just random rotation around z axis
# z axis rotation euler to quaternion = [cos(theta/2),0,0,sin(theta/2)]
- q_euler_angle = torch.rand(b)*(2*torch.pi)
- q = torch.zeros((b,4))
- q[:,0] = (q_euler_angle/2).cos()
- q[:,-1] = (q_euler_angle/2).sin()
+ q_euler_angle = torch.rand(b) * (2 * torch.pi)
+ q = torch.zeros((b, 4))
+ q[:, 0] = (q_euler_angle / 2).cos()
+ q[:, -1] = (q_euler_angle / 2).sin()
obj_pose = Pose.create_from_pq(p=target_region_xyz, q=q)
self.tee.set_pose(obj_pose)
@@ -432,30 +525,30 @@ def compute_dense_reward(self, obs: Any, action: Array, info: Dict):
# reward for overlap of the tees
# legacy reward
- #reward = self.pseudo_render_reward()
+ # reward = self.pseudo_render_reward()
# Pose based reward below is preferred over legacy reward
# legacy reward gets stuck in local maxs of 50-75% intersection
# and then fails to promote large explorations to perfectly orient the T, for PPO algorithm
-
+
# new pose based reward: cos(z_rot_euler) + function of translation, between target and goal both in [0,1]
# z euler cosine similarity reward: -- quat_to_z_euler guarenteed to reutrn value from [0,2pi]
tee_z_eulers = self.quat_to_z_euler(self.tee.pose.q)
# subtract the goal z rotatation to get relative rotation
rot_rew = (tee_z_eulers - self.goal_z_rot).cos()
# cos output [-1,1], we want reward of 0.5
- #reward = (rot_rew+1)/4
- reward = (((rot_rew+1)/2)**2)/2
+ # reward = (rot_rew+1)/4
+ reward = (((rot_rew + 1) / 2) ** 2) / 2
# x and y distance as reward
- tee_to_goal_pose = self.tee.pose.p[:,0:2] - self.goal_tee.pose.p[:,0:2]
+ tee_to_goal_pose = self.tee.pose.p[:, 0:2] - self.goal_tee.pose.p[:, 0:2]
tee_to_goal_pose_dist = torch.linalg.norm(tee_to_goal_pose, axis=1)
- reward += ((1 - torch.tanh(5 * tee_to_goal_pose_dist))**2)/2
+ reward += ((1 - torch.tanh(5 * tee_to_goal_pose_dist)) ** 2) / 2
# giving the robot a little help by rewarding it for having its end-effector close to the tee center of mass
- #tcp_to_push_pose = self.tee.pose.p[:,0:2] - self.agent.tcp.pose.p[:,0:2]
+ # tcp_to_push_pose = self.tee.pose.p[:,0:2] - self.agent.tcp.pose.p[:,0:2]
tcp_to_push_pose = self.tee.pose.p - self.agent.tcp.pose.p
tcp_to_push_pose_dist = torch.linalg.norm(tcp_to_push_pose, axis=1)
- reward += ((1 - torch.tanh(5 * tcp_to_push_pose_dist)).sqrt())/20
+ reward += ((1 - torch.tanh(5 * tcp_to_push_pose_dist)).sqrt()) / 20
# assign rewards to parallel environments that achieved success to the maximum of 3.
reward[info["success"]] = 3
@@ -463,4 +556,4 @@ def compute_dense_reward(self, obs: Any, action: Array, info: Dict):
def compute_normalized_dense_reward(self, obs: Any, action: Array, info: Dict):
max_reward = 3.0
- return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
\ No newline at end of file
+ return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
diff --git a/mani_skill/envs/tasks/tabletop/roll_ball.py b/mani_skill/envs/tasks/tabletop/roll_ball.py
index 18405d09e..7a8ecbfcd 100644
--- a/mani_skill/envs/tasks/tabletop/roll_ball.py
+++ b/mani_skill/envs/tasks/tabletop/roll_ball.py
@@ -48,7 +48,7 @@ def __init__(self, *args, robot_uids="panda", robot_init_qpos_noise=0.02, **kwar
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(
+ gpu_memory_config=GPUMemoryConfig(
found_lost_pairs_capacity=2**25, max_rigid_patch_count=2**18
)
)
diff --git a/mani_skill/envs/tasks/tabletop/two_robot_pick_cube.py b/mani_skill/envs/tasks/tabletop/two_robot_pick_cube.py
index 8c09276a9..4376a4708 100644
--- a/mani_skill/envs/tasks/tabletop/two_robot_pick_cube.py
+++ b/mani_skill/envs/tasks/tabletop/two_robot_pick_cube.py
@@ -53,7 +53,7 @@ def __init__(
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(
+ gpu_memory_config=GPUMemoryConfig(
found_lost_pairs_capacity=2**25,
max_rigid_patch_count=2**19,
max_rigid_contact_count=2**21,
diff --git a/mani_skill/envs/tasks/tabletop/two_robot_stack_cube.py b/mani_skill/envs/tasks/tabletop/two_robot_stack_cube.py
index 3da7318c4..eaa8ad08b 100644
--- a/mani_skill/envs/tasks/tabletop/two_robot_stack_cube.py
+++ b/mani_skill/envs/tasks/tabletop/two_robot_stack_cube.py
@@ -56,7 +56,7 @@ def __init__(
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(
+ gpu_memory_config=GPUMemoryConfig(
found_lost_pairs_capacity=2**25,
max_rigid_patch_count=2**19,
max_rigid_contact_count=2**21,
diff --git a/mani_skill/envs/template.py b/mani_skill/envs/template.py
index e535d316d..962790b4b 100644
--- a/mani_skill/envs/template.py
+++ b/mani_skill/envs/template.py
@@ -82,7 +82,7 @@ def __init__(self, *args, robot_uids="panda", robot_init_qpos_noise=0.02, **kwar
@property
def _default_sim_config(self):
return SimConfig(
- gpu_memory_cfg=GPUMemoryConfig(
+ gpu_memory_config=GPUMemoryConfig(
found_lost_pairs_capacity=2**25, max_rigid_patch_count=2**18
)
)
diff --git a/mani_skill/envs/utils/observations/__init__.py b/mani_skill/envs/utils/observations/__init__.py
index df4447b22..86b7af6e0 100644
--- a/mani_skill/envs/utils/observations/__init__.py
+++ b/mani_skill/envs/utils/observations/__init__.py
@@ -1 +1,53 @@
+from dataclasses import dataclass
+
from .observations import *
+
+
+@dataclass
+class CameraObsTextures:
+ rgb: bool
+ depth: bool
+ segmentation: bool
+ position: bool
+
+
+def parse_visual_obs_mode_to_struct(obs_mode: str) -> CameraObsTextures:
+ """Given user supplied observation mode, return a struct with the relevant textures that are to be captured"""
+ if obs_mode == "rgb":
+ return CameraObsTextures(
+ rgb=True, depth=False, segmentation=False, position=False
+ )
+ elif obs_mode == "rgbd":
+ return CameraObsTextures(
+ rgb=True, depth=True, segmentation=False, position=False
+ )
+ elif obs_mode == "depth":
+ return CameraObsTextures(
+ rgb=False, depth=True, segmentation=False, position=False
+ )
+ elif obs_mode == "segmentation":
+ return CameraObsTextures(
+ rgb=False, depth=False, segmentation=True, position=False
+ )
+ elif obs_mode == "rgb+depth":
+ return CameraObsTextures(
+ rgb=True, depth=True, segmentation=False, position=False
+ )
+ elif obs_mode == "rgb+depth+segmentation":
+ return CameraObsTextures(
+ rgb=True, depth=True, segmentation=True, position=False
+ )
+ elif obs_mode == "rgb+segmentation":
+ return CameraObsTextures(
+ rgb=True, depth=False, segmentation=True, position=False
+ )
+ elif obs_mode == "depth+segmentation":
+ return CameraObsTextures(
+ rgb=False, depth=True, segmentation=True, position=False
+ )
+ elif obs_mode == "pointcloud":
+ return CameraObsTextures(
+ rgb=True, depth=False, segmentation=True, position=True
+ )
+ else:
+ return None
diff --git a/mani_skill/envs/utils/observations/observations.py b/mani_skill/envs/utils/observations/observations.py
index 51abbeb27..52bfaa09a 100644
--- a/mani_skill/envs/utils/observations/observations.py
+++ b/mani_skill/envs/utils/observations/observations.py
@@ -8,58 +8,12 @@
import sapien.physx as physx
import torch
+from mani_skill.render import SAPIEN_RENDER_SYSTEM
from mani_skill.sensors.base_sensor import BaseSensor, BaseSensorConfig
from mani_skill.sensors.camera import Camera
from mani_skill.utils import common
-def sensor_data_to_rgbd(
- observation: Dict,
- sensors: Dict[str, BaseSensor],
- rgb=True,
- depth=True,
- segmentation=True,
-):
- """
- Converts all camera data to a easily usable rgb+depth format
-
- Optionally can include segmentation
- """
- sensor_data = observation["sensor_data"]
- for (cam_uid, ori_images), (sensor_uid, sensor) in zip(
- sensor_data.items(), sensors.items()
- ):
- assert cam_uid == sensor_uid
- if isinstance(sensor, Camera):
- new_images = dict()
- ori_images: Dict[str, torch.Tensor]
- for key in ori_images:
- if key == "Color":
- if rgb:
- rgb_data = ori_images[key][..., :3].clone() # [H, W, 4]
- new_images["rgb"] = rgb_data # [H, W, 4]
- elif key == "PositionSegmentation":
- if depth:
- depth_data = -ori_images[key][..., [2]] # [H, W, 1]
- # NOTE (stao): This is a bit of a hack since normally we have generic to_numpy call to convert
- # internal torch tensors to numpy if we do not use GPU simulation
- # but torch does not have a uint16 type so we convert that here earlier
- # if not physx.is_gpu_enabled():
- # depth_data = depth_data.numpy().astype(np.uint16)
- new_images["depth"] = depth_data
- if segmentation:
- segmentation_data = ori_images[key][..., [3]]
- # if not physx.is_gpu_enabled():
- # segmentation_data = segmentation_data.numpy().astype(
- # np.uint16
- # )
- new_images["segmentation"] = segmentation_data # [H, W, 1]
- else:
- new_images[key] = ori_images[key]
- sensor_data[cam_uid] = new_images
- return observation
-
-
def sensor_data_to_pointcloud(observation: Dict, sensors: Dict[str, BaseSensor]):
"""convert all camera data in sensor to pointcloud data"""
sensor_data = observation["sensor_data"]
@@ -72,30 +26,29 @@ def sensor_data_to_pointcloud(observation: Dict, sensors: Dict[str, BaseSensor])
assert cam_uid == sensor_uid
if isinstance(sensor, Camera):
cam_pcd = {}
-
+ # TODO: double check if the .clone()s are necessary
# Each pixel is (x, y, z, actor_id) in OpenGL camera space
# actor_id = 0 for the background
images: Dict[str, torch.Tensor]
- position = images["PositionSegmentation"]
- segmentation = position[..., 3].clone()
+ position = images["position"].clone()
+ segmentation = images["segmentation"].clone()
position = position.float()
- position[..., 3] = position[..., 3] != 0
position[..., :3] = (
position[..., :3] / 1000.0
) # convert the raw depth from millimeters to meters
# Convert to world space
cam2world = camera_params[cam_uid]["cam2world_gl"]
- xyzw = position.reshape(position.shape[0], -1, 4) @ cam2world.transpose(
- 1, 2
- )
+ xyzw = torch.cat([position, segmentation != 0], dim=-1).reshape(
+ position.shape[0], -1, 4
+ ) @ cam2world.transpose(1, 2)
cam_pcd["xyzw"] = xyzw
# Extra keys
- if "Color" in images:
- rgb = images["Color"][..., :3].clone()
+ if "rgb" in images:
+ rgb = images["rgb"][..., :3].clone()
cam_pcd["rgb"] = rgb.reshape(rgb.shape[0], -1, 3)
- if "PositionSegmentation" in images:
+ if "segmentation" in images:
cam_pcd["segmentation"] = segmentation.reshape(
segmentation.shape[0], -1, 1
)
diff --git a/mani_skill/examples/benchmarking/envs/maniskill/cartpole.py b/mani_skill/examples/benchmarking/envs/maniskill/cartpole.py
index f321b044a..9716118f8 100644
--- a/mani_skill/examples/benchmarking/envs/maniskill/cartpole.py
+++ b/mani_skill/examples/benchmarking/envs/maniskill/cartpole.py
@@ -20,7 +20,7 @@ def _default_sim_config(self):
sim_freq=120,
spacing=20,
control_freq=60,
- scene_cfg=SceneConfig(
+ scene_config=SceneConfig(
bounce_threshold=0.5,
solver_position_iterations=4, solver_velocity_iterations=0
),
diff --git a/mani_skill/examples/benchmarking/envs/maniskill/pick_cube.py b/mani_skill/examples/benchmarking/envs/maniskill/pick_cube.py
index 7019195f8..7572eb698 100644
--- a/mani_skill/examples/benchmarking/envs/maniskill/pick_cube.py
+++ b/mani_skill/examples/benchmarking/envs/maniskill/pick_cube.py
@@ -11,7 +11,7 @@ def _default_sim_config(self):
return SimConfig(
sim_freq=100,
control_freq=50,
- scene_cfg=SceneConfig(
+ scene_config=SceneConfig(
bounce_threshold=0.01,
solver_position_iterations=8, solver_velocity_iterations=0
),
diff --git a/mani_skill/examples/benchmarking/gpu_sim.py b/mani_skill/examples/benchmarking/gpu_sim.py
index 7ef72e872..69ddc4263 100644
--- a/mani_skill/examples/benchmarking/gpu_sim.py
+++ b/mani_skill/examples/benchmarking/gpu_sim.py
@@ -23,11 +23,11 @@
def main(args):
profiler = Profiler(output_format="stdout")
num_envs = args.num_envs
- sim_cfg = dict()
+ sim_config = dict()
if args.control_freq:
- sim_cfg["control_freq"] = args.control_freq
+ sim_config["control_freq"] = args.control_freq
if args.sim_freq:
- sim_cfg["sim_freq"] = args.sim_freq
+ sim_config["sim_freq"] = args.sim_freq
if not args.cpu_sim:
kwargs = dict()
if args.env_id in BENCHMARK_ENVS:
@@ -44,7 +44,7 @@ def main(args):
# enable_shadow=True,
render_mode=args.render_mode,
control_mode=args.control_mode,
- sim_cfg=sim_cfg,
+ sim_config=sim_config,
**kwargs
)
if isinstance(env.action_space, gym.spaces.Dict):
diff --git a/mani_skill/examples/demo_random_action.py b/mani_skill/examples/demo_random_action.py
index 07b477714..e14463ad9 100644
--- a/mani_skill/examples/demo_random_action.py
+++ b/mani_skill/examples/demo_random_action.py
@@ -2,6 +2,7 @@
import gymnasium as gym
import numpy as np
+import sapien
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils.wrappers import RecordEpisode
@@ -16,7 +17,7 @@ def parse_args(args=None):
parser.add_argument("--num-envs", type=int, default=1, help="Number of environments to run.")
parser.add_argument("-c", "--control-mode", type=str)
parser.add_argument("--render-mode", type=str, default="rgb_array")
- parser.add_argument("--shader", default="default", type=str, help="Change shader used for rendering. Default is 'default' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer")
+ parser.add_argument("--shader", default="minimal", type=str, help="Change shader used for all cameras in the environment for rendering. Default is 'minimal' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer")
parser.add_argument("--record-dir", type=str)
parser.add_argument("-p", "--pause", action="store_true", help="If using human render mode, auto pauses the simulation upon loading")
parser.add_argument("--quiet", action="store_true", help="Disable verbose output.")
@@ -57,7 +58,9 @@ def main(args):
reward_mode=args.reward_mode,
control_mode=args.control_mode,
render_mode=args.render_mode,
- shader_dir=args.shader,
+ sensor_configs=dict(shader_pack=args.shader),
+ human_render_camera_configs=dict(shader_pack=args.shader),
+ viewer_camera_configs=dict(shader_pack=args.shader),
num_envs=args.num_envs,
sim_backend=args.sim_backend,
parallel_in_single_scene=parallel_in_single_scene,
@@ -78,7 +81,8 @@ def main(args):
env.action_space.seed(args.seed)
if args.render_mode is not None:
viewer = env.render()
- viewer.paused = args.pause
+ if isinstance(viewer, sapien.utils.Viewer):
+ viewer.paused = args.pause
env.render()
while True:
action = env.action_space.sample()
diff --git a/mani_skill/examples/demo_reset_distribution.py b/mani_skill/examples/demo_reset_distribution.py
index 39cac0678..4c9241ac8 100644
--- a/mani_skill/examples/demo_reset_distribution.py
+++ b/mani_skill/examples/demo_reset_distribution.py
@@ -10,7 +10,7 @@ def parse_args(args=None):
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--env-id", type=str, default="PushCube-v1", help="The environment ID of the task you want to simulate")
parser.add_argument("-b", "--sim-backend", type=str, default="auto", help="Which simulation backend to use. Can be 'auto', 'cpu', 'gpu'")
- parser.add_argument("--shader", default="default", type=str, help="Change shader used for rendering. Default is 'default' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer")
+ parser.add_argument("--shader", default="minimal", type=str, help="Change shader used for rendering. Default is 'default' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer")
parser.add_argument("--render-mode", type=str, default="rgb_array", help="Can be 'human' to open a viewer, or rgb_array / sensors which change the cameras saved videos use")
parser.add_argument("--record-dir", type=str, default="videos/reset_distributions", help="Where to save recorded videos. If none, no videos are saved")
parser.add_argument("-n", "--num-resets", type=int, default=20, help="Number of times to reset the environment")
@@ -33,7 +33,9 @@ def main(args):
obs_mode="none",
reward_mode="none",
render_mode=args.render_mode,
- shader_dir=args.shader,
+ sensor_configs=dict(shader_pack=args.shader),
+ human_render_camera_configs=dict(shader_pack=args.shader),
+ viewer_camera_configs=dict(shader_pack=args.shader),
sim_backend=args.sim_backend,
)
if args.record_dir is not None and args.render_mode != "human":
diff --git a/mani_skill/examples/demo_robot.py b/mani_skill/examples/demo_robot.py
index dda0f53c7..fb8bfe7a0 100644
--- a/mani_skill/examples/demo_robot.py
+++ b/mani_skill/examples/demo_robot.py
@@ -35,7 +35,9 @@ def main():
enable_shadow=True,
control_mode=args.control_mode,
robot_uids=args.robot_uid,
- shader_dir=args.shader,
+ sensor_configs=dict(shader_pack=args.shader),
+ human_render_camera_configs=dict(shader_pack=args.shader),
+ viewer_camera_configs=dict(shader_pack=args.shader),
render_mode="human",
sim_backend=args.sim_backend,
)
diff --git a/mani_skill/examples/demo_vis_pcd.py b/mani_skill/examples/demo_vis_pcd.py
index dd15ab1a5..54404fe49 100644
--- a/mani_skill/examples/demo_vis_pcd.py
+++ b/mani_skill/examples/demo_vis_pcd.py
@@ -47,10 +47,10 @@ def main(args):
# view from first camera
- for uid, cfg in env.unwrapped._sensor_configs.items():
- if isinstance(cfg, CameraConfig):
+ for uid, config in env.unwrapped._sensor_configs.items():
+ if isinstance(config, CameraConfig):
cam2world = obs["sensor_param"][uid]["cam2world_gl"][0]
- camera = trimesh.scene.Camera(uid, (1024, 1024), fov=(np.rad2deg(cfg.fov), np.rad2deg(cfg.fov)))
+ camera = trimesh.scene.Camera(uid, (1024, 1024), fov=(np.rad2deg(config.fov), np.rad2deg(config.fov)))
break
trimesh.Scene([pcd], camera=camera, camera_transform=cam2world).show()
if terminated or truncated:
diff --git a/mani_skill/examples/demo_vis_rgbd.py b/mani_skill/examples/demo_vis_rgbd.py
index a91b34928..eb8c1aaad 100644
--- a/mani_skill/examples/demo_vis_rgbd.py
+++ b/mani_skill/examples/demo_vis_rgbd.py
@@ -17,6 +17,7 @@
def parse_args(args=None):
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--env-id", type=str, default="PushCube-v1", help="The environment ID of the task you want to simulate")
+ parser.add_argument("--shader", default="minimal", type=str, help="Change shader used for all cameras in the environment for rendering. Default is 'minimal' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer")
parser.add_argument("--num-envs", type=int, default=1, help="Number of environments to run. Used for some basic testing and not visualized")
parser.add_argument("--cam-width", type=int, help="Override the width of every camera in the environment")
parser.add_argument("--cam-height", type=int, help="Override the height of every camera in the environment")
@@ -44,6 +45,7 @@ def main(args):
sensor_configs["width"] = args.cam_width
if args.cam_height:
sensor_configs["height"] = args.cam_height
+ sensor_configs["shader_pack"] = args.shader
env: BaseEnv = gym.make(
args.env_id,
obs_mode="rgbd",
@@ -67,7 +69,6 @@ def main(args):
imgs=[]
for cam in obs["sensor_data"].keys():
if "rgb" in obs["sensor_data"][cam]:
-
rgb = common.to_numpy(obs["sensor_data"][cam]["rgb"][0])
depth = common.to_numpy(obs["sensor_data"][cam]["depth"][0]).astype(np.float32)
depth = depth / (depth.max() - depth.min())
diff --git a/mani_skill/examples/motionplanning/panda/run.py b/mani_skill/examples/motionplanning/panda/run.py
index 948f23a8c..4d18374b3 100644
--- a/mani_skill/examples/motionplanning/panda/run.py
+++ b/mani_skill/examples/motionplanning/panda/run.py
@@ -36,7 +36,9 @@ def main(args):
control_mode="pd_joint_pos",
render_mode=args.render_mode,
reward_mode="dense" if args.reward_mode is None else args.reward_mode,
- shader_dir=args.shader,
+ sensor_configs=dict(shader_pack=args.shader),
+ human_render_camera_configs=dict(shader_pack=args.shader),
+ viewer_camera_configs=dict(shader_pack=args.shader),
sim_backend=args.sim_backend
)
if env_id not in MP_SOLUTIONS:
diff --git a/mani_skill/examples/motionplanning/panda/solutions/peg_insertion_side.py b/mani_skill/examples/motionplanning/panda/solutions/peg_insertion_side.py
index 6961dc4f4..232371157 100644
--- a/mani_skill/examples/motionplanning/panda/solutions/peg_insertion_side.py
+++ b/mani_skill/examples/motionplanning/panda/solutions/peg_insertion_side.py
@@ -16,7 +16,6 @@ def main():
control_mode="pd_joint_pos",
render_mode="rgb_array",
reward_mode="dense",
- # shader_dir="rt-fast",
)
for seed in range(100):
res = solve(env, seed=seed, debug=False, vis=True)
diff --git a/mani_skill/examples/motionplanning/panda/solutions/plug_charger.py b/mani_skill/examples/motionplanning/panda/solutions/plug_charger.py
index 9da9cc22d..2e1806ee7 100644
--- a/mani_skill/examples/motionplanning/panda/solutions/plug_charger.py
+++ b/mani_skill/examples/motionplanning/panda/solutions/plug_charger.py
@@ -19,7 +19,6 @@ def main():
control_mode="pd_joint_pos",
render_mode="rgb_array",
reward_mode="sparse",
- # shader_dir="rt-fast",
)
for seed in tqdm(range(100)):
res = solve(env, seed=seed, debug=False, vis=True)
diff --git a/mani_skill/examples/teleoperation/interactive_panda.py b/mani_skill/examples/teleoperation/interactive_panda.py
index 9da940b5b..3d7b45f3f 100644
--- a/mani_skill/examples/teleoperation/interactive_panda.py
+++ b/mani_skill/examples/teleoperation/interactive_panda.py
@@ -21,7 +21,6 @@ def main(args):
control_mode="pd_joint_pos",
render_mode="rgb_array",
reward_mode="sparse",
- # shader_dir="rt-fast",
)
env = RecordEpisode(
env,
@@ -63,7 +62,6 @@ def main(args):
control_mode="pd_joint_pos",
render_mode="rgb_array",
reward_mode="sparse",
- shader_dir="rt-med",
)
env = RecordEpisode(
env,
diff --git a/mani_skill/render/__init__.py b/mani_skill/render/__init__.py
index 62b1d5ff1..a15684429 100644
--- a/mani_skill/render/__init__.py
+++ b/mani_skill/render/__init__.py
@@ -1,23 +1,2 @@
-from contextlib import contextmanager
-
-import sapien
-
-SAPIEN_RENDER_SYSTEM = "3.0"
-try:
- # NOTE (stao): hacky way to determine which render system in sapien 3 is being used for testing purposes
- from sapien.wrapper.scene import get_camera_shader_pack
-
- SAPIEN_RENDER_SYSTEM = "3.1"
-except:
- pass
-
-GlobalShaderPack = None
-
-
-@contextmanager
-def set_shader_pack(shader_pack):
- global GlobalShaderPack
- old = GlobalShaderPack
- GlobalShaderPack = shader_pack
- yield
- GlobalShaderPack = old
+from .shaders import PREBUILT_SHADER_CONFIGS, ShaderConfig, set_shader_pack
+from .version import SAPIEN_RENDER_SYSTEM
diff --git a/mani_skill/render/shaders.py b/mani_skill/render/shaders.py
new file mode 100644
index 000000000..8f9958253
--- /dev/null
+++ b/mani_skill/render/shaders.py
@@ -0,0 +1,147 @@
+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, List
+
+import sapien
+import torch
+
+from mani_skill.render.version import SAPIEN_RENDER_SYSTEM
+
+
+@dataclass
+class ShaderConfig:
+ """simple shader config dataclass to determine which shader pack to use, textures to render, and any possible configurations for the shader pack. Can be used as part of the CameraConfig
+ to further customize the camera output.
+
+ A shader config must define which shader pack to use, and which textures to consider rendering. Additional shader pack configs can be passed which are specific to the shader config itself
+ and can modify shader settings.
+
+ Texture transforms must be defined and are used to process the texture data into more standard formats for use. Some textures might be combined textures (e.g. depth+segmentation together)
+ due to shader optimizations. texture transforms must then split these combined textures back into their component parts.
+
+ The standard image modalities and expected dtypes/shapes are:
+ - rgb (torch.uint8, shape: [H, W, 3])
+ - depth (torch.int16, shape: [H, W])
+ - segmentation (torch.int16, shape: [H, W])
+ - position (torch.float32, shape: [H, W, 3]) (infinite points have segmentation == 0)
+ """
+
+ shader_pack: str
+ texture_names: Dict[str, List[str]] = field(default_factory=dict)
+ """dictionary mapping shader texture names to the image modalities that are rendered. e.g. Color, Depth, Segmentation, etc."""
+ shader_pack_config: Dict[str, Any] = field(default_factory=dict)
+ """configs for the shader pack. for e.g. the ray tracing shader you can configure the denoiser, samples per pixel, etc."""
+
+ texture_transforms: Dict[
+ str, Callable[[torch.Tensor], Dict[str, torch.Tensor]]
+ ] = field(default_factory=dict)
+ """texture transform functions that map each texture name to a function that converts the texture data into one or more standard image modalities. The return type should be a
+ dictionary with keys equal to the names of standard image modalities and values equal to the transformed data"""
+
+
+PREBUILT_SHADER_CONFIGS = {
+ "minimal": ShaderConfig(
+ shader_pack="minimal",
+ texture_names={
+ "Color": ["rgb"],
+ "PositionSegmentation": ["position", "depth", "segmentation"],
+ },
+ texture_transforms={
+ "Color": lambda data: {"rgb": data[..., :3]},
+ "PositionSegmentation": lambda data: {
+ "position": data[..., :3],
+ "depth": -data[..., [2]],
+ "segmentation": data[..., [3]],
+ },
+ },
+ ),
+ "default": ShaderConfig(
+ shader_pack="default",
+ texture_names={
+ "Color": ["rgb"],
+ "Position": ["position", "depth"],
+ "Segmentation": ["segmentation"],
+ },
+ texture_transforms={
+ "Color": lambda data: {"rgb": (data[..., :3] * 255).to(torch.uint8)},
+ "Position": lambda data: {
+ "depth": (-data[..., [2]] * 1000).to(torch.int16),
+ "position": data[..., :3],
+ },
+ "Segmentation": lambda data: {"segmentation": data[..., 3][..., None]},
+ },
+ ),
+ "rt": ShaderConfig(
+ shader_pack="rt",
+ texture_names={
+ "Color": ["rgb"],
+ },
+ shader_pack_config={
+ "ray_tracing_samples_per_pixel": 32,
+ "ray_tracing_path_depth": 16,
+ "ray_tracing_denoiser": "optix",
+ },
+ texture_transforms={
+ "Color": lambda data: {"rgb": (data[..., :3] * 255).to(torch.uint8)},
+ },
+ ),
+ "rt-med": ShaderConfig(
+ shader_pack="rt",
+ texture_names={
+ "Color": ["rgb"],
+ },
+ shader_pack_config={
+ "ray_tracing_samples_per_pixel": 4,
+ "ray_tracing_path_depth": 3,
+ "ray_tracing_denoiser": "optix",
+ },
+ texture_transforms={
+ "Color": lambda data: {"rgb": (data[..., :3] * 255).to(torch.uint8)},
+ },
+ ),
+ "rt-fast": ShaderConfig(
+ shader_pack="rt",
+ texture_names={
+ "Color": ["rgb"],
+ },
+ shader_pack_config={
+ "ray_tracing_samples_per_pixel": 2,
+ "ray_tracing_path_depth": 1,
+ "ray_tracing_denoiser": "optix",
+ },
+ texture_transforms={
+ "Color": lambda data: {"rgb": (data[..., :3] * 255).to(torch.uint8)},
+ },
+ ),
+}
+"""pre-defined shader configs"""
+
+
+def set_shader_pack(shader_config: ShaderConfig):
+ """sets a global shader pack for cameras. Used only for the 3.0 SAPIEN rendering system"""
+ if SAPIEN_RENDER_SYSTEM == "3.0":
+ sapien.render.set_camera_shader_dir(shader_config.shader_pack)
+ if shader_config.shader_pack == "minimal":
+ sapien.render.set_camera_shader_dir("minimal")
+ sapien.render.set_picture_format("Color", "r8g8b8a8unorm")
+ sapien.render.set_picture_format("ColorRaw", "r8g8b8a8unorm")
+ sapien.render.set_picture_format("PositionSegmentation", "r16g16b16a16sint")
+ if shader_config.shader_pack == "default":
+ sapien.render.set_camera_shader_dir("default")
+ sapien.render.set_picture_format("Color", "r32g32b32a32sfloat")
+ sapien.render.set_picture_format("ColorRaw", "r32g32b32a32sfloat")
+ sapien.render.set_picture_format(
+ "PositionSegmentation", "r32g32b32a32sfloat"
+ )
+ if shader_config.shader_pack[:2] == "rt":
+ sapien.render.set_ray_tracing_samples_per_pixel(
+ shader_config.shader_pack_config["ray_tracing_samples_per_pixel"]
+ )
+ sapien.render.set_ray_tracing_path_depth(
+ shader_config.shader_pack_config["ray_tracing_path_depth"]
+ )
+ sapien.render.set_ray_tracing_denoiser(
+ shader_config.shader_pack_config["ray_tracing_denoiser"]
+ )
+ elif SAPIEN_RENDER_SYSTEM == "3.1":
+ # sapien.render.set_camera_shader_pack_name would set a global default
+ pass
diff --git a/mani_skill/render/version.py b/mani_skill/render/version.py
new file mode 100644
index 000000000..37b16733a
--- /dev/null
+++ b/mani_skill/render/version.py
@@ -0,0 +1,8 @@
+SAPIEN_RENDER_SYSTEM = "3.0"
+try:
+ # NOTE (stao): hacky way to determine which render system in sapien 3 is being used for testing purposes
+ from sapien.wrapper.scene import get_camera_shader_pack
+
+ SAPIEN_RENDER_SYSTEM = "3.1"
+except:
+ pass
diff --git a/mani_skill/sensors/base_sensor.py b/mani_skill/sensors/base_sensor.py
index bfb2a2406..76a1ee3b5 100644
--- a/mani_skill/sensors/base_sensor.py
+++ b/mani_skill/sensors/base_sensor.py
@@ -14,8 +14,8 @@ class BaseSensor:
Base class for all sensors
"""
- def __init__(self, cfg: BaseSensorConfig) -> None:
- self.cfg = cfg
+ def __init__(self, config: BaseSensorConfig) -> None:
+ self.config = config
def setup(self) -> None:
"""
@@ -30,7 +30,7 @@ def capture(self) -> None:
non-blocking function if possible.
"""
- def get_obs(self):
+ def get_obs(self, **kwargs):
"""
Retrieves captured sensor data as an observation for use by an agent.
"""
@@ -51,4 +51,4 @@ def get_images(self) -> torch.Tensor:
@property
def uid(self):
- return self.cfg.uid
+ return self.config.uid
diff --git a/mani_skill/sensors/camera.py b/mani_skill/sensors/camera.py
index f359ad8fc..daed95936 100644
--- a/mani_skill/sensors/camera.py
+++ b/mani_skill/sensors/camera.py
@@ -2,14 +2,20 @@
import copy
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, Sequence, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
import numpy as np
import sapien
import sapien.render
+import torch
from torch._tensor import Tensor
-from mani_skill.render import SAPIEN_RENDER_SYSTEM
+from mani_skill.render import (
+ PREBUILT_SHADER_CONFIGS,
+ SAPIEN_RENDER_SYSTEM,
+ ShaderConfig,
+ set_shader_pack,
+)
from mani_skill.utils.structs import Actor, Articulation, Link
from mani_skill.utils.structs.pose import Pose
from mani_skill.utils.structs.types import Array
@@ -21,10 +27,6 @@
from .base_sensor import BaseSensor, BaseSensorConfig
-DEFAULT_TEXTURE_NAMES = ("Color", "PositionSegmentation")
-if SAPIEN_RENDER_SYSTEM == "3.1":
- DEFAULT_TEXTURE_NAMES = ("Color", "Position", "Segmentation")
-
@dataclass
class CameraConfig(BaseSensorConfig):
@@ -49,84 +51,95 @@ class CameraConfig(BaseSensorConfig):
"""entity_uid (str, optional): unique id of the entity to mount the camera. Defaults to None."""
mount: Union[Actor, Link] = None
"""the Actor or Link to mount the camera on top of. This means the global pose of the mounted camera is now mount.pose * local_pose"""
- texture_names: Sequence[str] = DEFAULT_TEXTURE_NAMES
- """texture_names (Sequence[str], optional): texture names to render. Defaults to ("Color", "PositionSegmentation"). Note that the renderign speed will not really change if you remove PositionSegmentation"""
+ texture_names: Optional[Sequence[str]] = None
+ """texture_names (Sequence[str], optional): texture names to render."""
+ shader_pack: Optional[str] = "minimal"
+ """The shader to use for rendering. Defaults to "minimal" which is the fastest rendering system with minimal GPU memory usage. There is also `default` and `rt`."""
+ shader_config: Optional[ShaderConfig] = None
+ """The shader config to use for rendering. If None, the shader_pack will be used to search amongst prebuilt shader configs to create a ShaderConfig."""
def __post_init__(self):
self.pose = Pose.create(self.pose)
+ if self.shader_config is None:
+ self.shader_config = PREBUILT_SHADER_CONFIGS[self.shader_pack]
+ else:
+ self.shader_pack = self.shader_config.shader_pack
def __repr__(self) -> str:
return self.__class__.__name__ + "(" + str(self.__dict__) + ")"
-def update_camera_cfgs_from_dict(
- camera_cfgs: Dict[str, CameraConfig], cfg_dict: Dict[str, dict]
+def update_camera_configs_from_dict(
+ camera_configs: Dict[str, CameraConfig], config_dict: Dict[str, dict]
):
# Update CameraConfig to StereoDepthCameraConfig
- if cfg_dict.pop("use_stereo_depth", False):
+ if config_dict.pop("use_stereo_depth", False):
from .depth_camera import StereoDepthCameraConfig # fmt: skip
- for name, cfg in camera_cfgs.items():
- camera_cfgs[name] = StereoDepthCameraConfig.fromCameraConfig(cfg)
+ for name, config in camera_configs.items():
+ camera_configs[name] = StereoDepthCameraConfig.fromCameraConfig(config)
# First, apply global configuration
- for k, v in cfg_dict.items():
- if k in camera_cfgs:
+ for k, v in config_dict.items():
+ if k in camera_configs:
continue
- for cfg in camera_cfgs.values():
- if not hasattr(cfg, k):
+ for config in camera_configs.values():
+ if not hasattr(config, k):
raise AttributeError(f"{k} is not a valid attribute of CameraConfig")
else:
- setattr(cfg, k, v)
+ if k == "shader_pack":
+ config.shader_config = None
+ setattr(config, k, v)
# Then, apply camera-specific configuration
- for name, v in cfg_dict.items():
- if name not in camera_cfgs:
+ for name, v in config_dict.items():
+ if name not in camera_configs:
continue
# Update CameraConfig to StereoDepthCameraConfig
if v.pop("use_stereo_depth", False):
from .depth_camera import StereoDepthCameraConfig # fmt: skip
- cfg = camera_cfgs[name]
- camera_cfgs[name] = StereoDepthCameraConfig.fromCameraConfig(cfg)
+ config = camera_configs[name]
+ camera_configs[name] = StereoDepthCameraConfig.fromCameraConfig(config)
- cfg = camera_cfgs[name]
+ config = camera_configs[name]
for kk in v:
- assert hasattr(cfg, kk), f"{kk} is not a valid attribute of CameraConfig"
+ if kk == "shader_pack":
+ config.shader_config = None
+ assert hasattr(config, kk), f"{kk} is not a valid attribute of CameraConfig"
v = copy.deepcopy(v)
# for json serailizable gym.make args, user has to pass a list, not a Pose object.
if "pose" in v and isinstance(v["pose"], list):
v["pose"] = sapien.Pose(v["pose"][:3], v["pose"][3:])
- cfg.__dict__.update(v)
-
-
-def parse_camera_cfgs(camera_cfgs):
- if isinstance(camera_cfgs, (tuple, list)):
- return dict([(cfg.uid, cfg) for cfg in camera_cfgs])
- elif isinstance(camera_cfgs, dict):
- return dict(camera_cfgs)
- elif isinstance(camera_cfgs, CameraConfig):
- return dict([(camera_cfgs.uid, camera_cfgs)])
+ config.__dict__.update(v)
+ for config in camera_configs.values():
+ config.__post_init__()
+
+
+def parse_camera_configs(camera_configs):
+ if isinstance(camera_configs, (tuple, list)):
+ return dict([(config.uid, config) for config in camera_configs])
+ elif isinstance(camera_configs, dict):
+ return dict(camera_configs)
+ elif isinstance(camera_configs, CameraConfig):
+ return dict([(camera_configs.uid, camera_configs)])
else:
- raise TypeError(type(camera_cfgs))
+ raise TypeError(type(camera_configs))
class Camera(BaseSensor):
"""Implementation of the Camera sensor which uses the sapien Camera."""
- cfg: CameraConfig
+ config: CameraConfig
def __init__(
self,
- camera_cfg: CameraConfig,
+ camera_config: CameraConfig,
scene: ManiSkillScene,
articulation: Articulation = None,
):
- super().__init__(cfg=camera_cfg)
-
- self.camera_cfg = camera_cfg
-
- entity_uid = camera_cfg.entity_uid
- if camera_cfg.mount is not None:
- self.entity = camera_cfg.mount
+ super().__init__(config=camera_config)
+ entity_uid = camera_config.entity_uid
+ if camera_config.mount is not None:
+ self.entity = camera_config.mount
elif entity_uid is None:
self.entity = None
else:
@@ -141,55 +154,87 @@ def __init__(
if self.entity is None:
raise RuntimeError(f"Mount entity ({entity_uid}) is not found")
- intrinsic = camera_cfg.intrinsic
- assert (camera_cfg.fov is None and intrinsic is not None) or (
- camera_cfg.fov is not None and intrinsic is None
+ intrinsic = camera_config.intrinsic
+ assert (camera_config.fov is None and intrinsic is not None) or (
+ camera_config.fov is not None and intrinsic is None
)
# Add camera to scene. Add mounted one if a entity is given
+ set_shader_pack(self.config.shader_config)
if self.entity is None:
self.camera = scene.add_camera(
- name=camera_cfg.uid,
- pose=camera_cfg.pose,
- width=camera_cfg.width,
- height=camera_cfg.height,
- fovy=camera_cfg.fov,
+ name=camera_config.uid,
+ pose=camera_config.pose,
+ width=camera_config.width,
+ height=camera_config.height,
+ fovy=camera_config.fov,
intrinsic=intrinsic,
- near=camera_cfg.near,
- far=camera_cfg.far,
+ near=camera_config.near,
+ far=camera_config.far,
)
else:
self.camera = scene.add_camera(
- name=camera_cfg.uid,
+ name=camera_config.uid,
mount=self.entity,
- pose=camera_cfg.pose,
- width=camera_cfg.width,
- height=camera_cfg.height,
- fovy=camera_cfg.fov,
+ pose=camera_config.pose,
+ width=camera_config.width,
+ height=camera_config.height,
+ fovy=camera_config.fov,
intrinsic=intrinsic,
- near=camera_cfg.near,
- far=camera_cfg.far,
+ near=camera_config.near,
+ far=camera_config.far,
)
# Filter texture names according to renderer type if necessary (legacy for Kuafu)
- self.texture_names = camera_cfg.texture_names
def capture(self):
self.camera.take_picture()
- def get_obs(self):
- images = {}
- for name in self.texture_names:
- image = self.get_picture(name)
- images[name] = image
- return images
-
- def get_picture(self, name: str):
- return self.camera.get_picture(name)
-
- def get_images(self) -> Tensor:
- return visualization.tile_images(
- visualization.observations_to_images(self.get_obs())
- )
+ def get_obs(
+ self,
+ rgb: bool = True,
+ depth: bool = True,
+ position: bool = True,
+ segmentation: bool = True,
+ apply_texture_transforms: bool = True,
+ ):
+ images_dict = {}
+ # determine which textures are needed to get the desired modalities
+ required_texture_names = []
+ for (
+ texture_name,
+ output_modalities,
+ ) in self.config.shader_config.texture_names.items():
+ if rgb and "rgb" in output_modalities:
+ required_texture_names.append(texture_name)
+ if depth and "depth" in output_modalities:
+ required_texture_names.append(texture_name)
+ if position and "position" in output_modalities:
+ required_texture_names.append(texture_name)
+ if segmentation and "segmentation" in output_modalities:
+ required_texture_names.append(texture_name)
+ required_texture_names = list(set(required_texture_names))
+
+ # fetch the image data
+ output_textures = self.camera.get_picture(required_texture_names)
+ for texture_name, texture in zip(required_texture_names, output_textures):
+ if apply_texture_transforms:
+ images_dict |= self.config.shader_config.texture_transforms[
+ texture_name
+ ](texture)
+ else:
+ images_dict[texture_name] = texture
+ if not rgb and "rgb" in images_dict:
+ del images_dict["rgb"]
+ if not depth and "depth" in images_dict:
+ del images_dict["depth"]
+ if not position and "position" in images_dict:
+ del images_dict["position"]
+ if not segmentation and "segmentation" in images_dict:
+ del images_dict["segmentation"]
+ return images_dict
+
+ def get_images(self, obs) -> Tensor:
+ return camera_observations_to_images(obs)
# TODO (stao): Computing camera parameters on GPU sim is not that fast, especially with mounted cameras and for model_matrix computation.
def get_params(self):
@@ -198,3 +243,46 @@ def get_params(self):
cam2world_gl=self.camera.get_model_matrix(),
intrinsic_cv=self.camera.get_intrinsic_matrix(),
)
+
+
+def normalize_depth(depth, min_depth=0, max_depth=None):
+ if min_depth is None:
+ min_depth = depth.min()
+ if max_depth is None:
+ max_depth = depth.max()
+ depth = (depth - min_depth) / (max_depth - min_depth)
+ depth = depth.clip(0, 1)
+ return depth
+
+
+def camera_observations_to_images(
+ observations: Dict[str, torch.Tensor], max_depth=None
+) -> List[Array]:
+ """Parse images from camera observations."""
+ images = dict()
+ for key in observations:
+ if "rgb" in key or "Color" in key:
+ rgb = observations[key][..., :3]
+ if torch is not None and rgb.dtype == torch.float:
+ rgb = torch.clip(rgb * 255, 0, 255).to(torch.uint8)
+ images[key] = rgb
+ elif "depth" in key or "position" in key:
+ depth = observations[key]
+ if "position" in key: # [H, W, 4]
+ depth = -depth[..., 2:3]
+ # [H, W, 1]
+ depth = normalize_depth(depth, max_depth=max_depth)
+ depth = (depth * 255).clip(0, 255)
+
+ depth = depth.to(torch.uint8)
+ depth = torch.repeat_interleave(depth, 3, dim=-1)
+ images[key] = depth
+ elif "segmentation" in key:
+ seg = observations[key] # [H, W, 1]
+ assert seg.ndim == 4 and seg.shape[-1] == 1, seg.shape
+ # A heuristic way to colorize labels
+ if seg.dtype == torch.uint32:
+ seg = seg.to(torch.int32)
+ seg = (seg * torch.tensor([11, 61, 127], device=seg.device)).to(torch.uint8)
+ images[key] = seg
+ return images
diff --git a/mani_skill/utils/sapien_utils.py b/mani_skill/utils/sapien_utils.py
index 4e3dd9ae5..00ec7d55c 100644
--- a/mani_skill/utils/sapien_utils.py
+++ b/mani_skill/utils/sapien_utils.py
@@ -143,18 +143,18 @@ def parse_urdf_config(config_dict: dict) -> Dict:
def apply_urdf_config(loader: sapien.wrapper.urdf_loader.URDFLoader, urdf_config: dict):
if "link" in urdf_config:
- for name, link_cfg in urdf_config["link"].items():
- if "material" in link_cfg:
- mat: physx.PhysxMaterial = link_cfg["material"]
+ for name, link_config in urdf_config["link"].items():
+ if "material" in link_config:
+ mat: physx.PhysxMaterial = link_config["material"]
loader.set_link_material(
name, mat.static_friction, mat.dynamic_friction, mat.restitution
)
- if "patch_radius" in link_cfg:
- loader.set_link_patch_radius(name, link_cfg["patch_radius"])
- if "min_patch_radius" in link_cfg:
- loader.set_link_min_patch_radius(name, link_cfg["min_patch_radius"])
- if "density" in link_cfg:
- loader.set_link_density(name, link_cfg["density"])
+ if "patch_radius" in link_config:
+ loader.set_link_patch_radius(name, link_config["patch_radius"])
+ if "min_patch_radius" in link_config:
+ loader.set_link_min_patch_radius(name, link_config["min_patch_radius"])
+ if "density" in link_config:
+ loader.set_link_density(name, link_config["density"])
if "material" in urdf_config:
mat: physx.PhysxMaterial = urdf_config["material"]
loader.set_material(mat.static_friction, mat.dynamic_friction, mat.restitution)
diff --git a/mani_skill/utils/scene_builder/replicacad/scene_builder.py b/mani_skill/utils/scene_builder/replicacad/scene_builder.py
index 16ae24a8e..0a02db52f 100644
--- a/mani_skill/utils/scene_builder/replicacad/scene_builder.py
+++ b/mani_skill/utils/scene_builder/replicacad/scene_builder.py
@@ -6,8 +6,8 @@
import json
import os.path as osp
-from typing import Dict, List, Tuple, Union
from pathlib import Path
+from typing import Dict, List, Tuple, Union
import numpy as np
import sapien
@@ -16,9 +16,9 @@
from mani_skill import ASSET_DIR
from mani_skill.agents.robots.fetch import (
- Fetch,
- FETCH_WHEELS_COLLISION_BIT,
FETCH_BASE_COLLISION_BIT,
+ FETCH_WHEELS_COLLISION_BIT,
+ Fetch,
)
from mani_skill.utils.scene_builder import SceneBuilder
from mani_skill.utils.scene_builder.registration import register_scene_builder
@@ -82,14 +82,14 @@ def build(self, build_config_idxs: Union[int, List[int]]):
env_idx = [i for i, v in enumerate(build_config_idxs) if v == bci]
unique_id = "scs-" + str(env_idx).replace(" ", "")
- build_cfg_path = self.build_configs[bci]
+ build_config_path = self.build_configs[bci]
# We read the json config file describing the scene setup for the selected ReplicaCAD scene
with open(
osp.join(
ASSET_DIR,
"scene_datasets/replica_cad_dataset/configs/scenes",
- build_cfg_path,
+ build_config_path,
),
"rb",
) as f:
@@ -132,19 +132,19 @@ def build(self, build_config_idxs: Union[int, List[int]]):
# Again, for any dataset you will have to figure out how they reference object files
# Note that ASSET_DIR will always refer to the ~/.maniskill/data folder or whatever MS_ASSET_DIR is set to
- obj_cfg_path = osp.join(
+ obj_config_path = osp.join(
ASSET_DIR,
"scene_datasets/replica_cad_dataset/configs/objects",
f"{osp.basename(obj_meta['template_name'])}.object_config.json",
)
- with open(obj_cfg_path) as f:
- obj_cfg = json.load(f)
+ with open(obj_config_path) as f:
+ obj_config = json.load(f)
visual_file = osp.join(
- osp.dirname(obj_cfg_path), obj_cfg["render_asset"]
+ osp.dirname(obj_config_path), obj_config["render_asset"]
)
- if "collision_asset" in obj_cfg:
+ if "collision_asset" in obj_config:
collision_file = osp.join(
- osp.dirname(obj_cfg_path), obj_cfg["collision_asset"]
+ osp.dirname(obj_config_path), obj_config["collision_asset"]
)
builder = self.scene.create_actor_builder()
pos = obj_meta["translation"]
@@ -157,8 +157,8 @@ def build(self, build_config_idxs: Union[int, List[int]]):
if obj_meta["motion_type"] == "DYNAMIC":
builder.add_visual_from_file(visual_file)
if (
- "use_bounding_box_for_collision" in obj_cfg
- and obj_cfg["use_bounding_box_for_collision"]
+ "use_bounding_box_for_collision" in obj_config
+ and obj_config["use_bounding_box_for_collision"]
):
# some dynamic objects do not have decomposed convex meshes and instead should use a simple bounding box for collision detection
# in this case we use the add_convex_collision_from_file function of SAPIEN which just creates a convex collision based on the visual mesh
@@ -215,12 +215,12 @@ def build(self, build_config_idxs: Union[int, List[int]]):
# for now classify articulated objects as "movable" object
for env_num in env_idx:
- self.articulations[f"env-{env_num}_{articulation.name}"] = (
- articulation
- )
- self.scene_objects[f"env-{env_num}_{articulation.name}"] = (
- articulation
- )
+ self.articulations[
+ f"env-{env_num}_{articulation.name}"
+ ] = articulation
+ self.scene_objects[
+ f"env-{env_num}_{articulation.name}"
+ ] = articulation
# ReplicaCAD also specifies where to put lighting
with open(
@@ -230,19 +230,19 @@ def build(self, build_config_idxs: Union[int, List[int]]):
f"{osp.basename(build_config_json['default_lighting'])}.lighting_config.json",
)
) as f:
- lighting_cfg = json.load(f)
- for light_cfg in lighting_cfg["lights"].values():
+ lighting_config = json.load(f)
+ for light_config in lighting_config["lights"].values():
# It appears ReplicaCAD only specifies point light sources so we only use those here
- if light_cfg["type"] == "point":
+ if light_config["type"] == "point":
light_pos_fixed = (
- sapien.Pose(q=q) * sapien.Pose(p=light_cfg["position"])
+ sapien.Pose(q=q) * sapien.Pose(p=light_config["position"])
).p
# In SAPIEN, one can set color to unbounded values, higher just means more intense. ReplicaCAD provides color and intensity separately so
# we multiply it together here. We also take absolute value of intensity since some scene configs write negative intensities (which result in black holes)
self.scene.add_point_light(
light_pos_fixed,
- color=np.array(light_cfg["color"])
- * np.abs(light_cfg["intensity"]),
+ color=np.array(light_config["color"])
+ * np.abs(light_config["intensity"]),
scene_idxs=env_idx,
)
self.scene.set_ambient_light([0.3, 0.3, 0.3])
diff --git a/mani_skill/utils/structs/render_camera.py b/mani_skill/utils/structs/render_camera.py
index 3fdf18a13..9a256e8f5 100644
--- a/mani_skill/utils/structs/render_camera.py
+++ b/mani_skill/utils/structs/render_camera.py
@@ -154,22 +154,27 @@ def get_model_matrix(self):
def get_near(self) -> float:
return self._render_cameras[0].get_near()
- def get_picture(self, name: str):
- if physx.is_gpu_enabled():
+ def get_picture(self, names: Union[str, List[str]]) -> List[torch.Tensor]:
+ if isinstance(names, str):
+ names = [names]
+ if physx.is_gpu_enabled() and not self.scene.parallel_in_single_scene:
if SAPIEN_RENDER_SYSTEM == "3.0":
- return self.camera_group.get_picture_cuda(name).torch()
+ return [
+ self.camera_group.get_picture_cuda(name).torch() for name in names
+ ]
elif SAPIEN_RENDER_SYSTEM == "3.1":
- return self.camera_group.get_cuda_pictures([name])[0].torch()
+ return [x.torch() for x in self.camera_group.get_cuda_pictures(names)]
else:
- return common.to_tensor(self._render_cameras[0].get_picture(name))[
- None, ...
+ return [
+ common.to_tensor(self._render_cameras[0].get_picture(name))[None, ...]
+ for name in names
]
- def get_picture_cuda(self, name: str):
- return self._render_cameras[0].get_picture_cuda(name)
+ # def get_picture_cuda(self, name: str):
+ # return self._render_cameras[0].get_picture_cuda(name)
- def get_picture_names(self) -> list[str]:
- return self._render_cameras[0].get_picture_names()
+ # def get_picture_names(self) -> list[str]:
+ # return self._render_cameras[0].get_picture_names()
def get_projection_matrix(self):
return self._render_cameras[0].get_projection_matrix()
diff --git a/mani_skill/utils/structs/types.py b/mani_skill/utils/structs/types.py
index 662734cda..ab911f089 100644
--- a/mani_skill/utils/structs/types.py
+++ b/mani_skill/utils/structs/types.py
@@ -87,9 +87,9 @@ class SimConfig:
"""simulation frequency (Hz)"""
control_freq: int = 20
"""control frequency (Hz). Every control step (e.g. env.step) contains sim_freq / control_freq physx simulation steps"""
- gpu_memory_cfg: GPUMemoryConfig = field(default_factory=GPUMemoryConfig)
- scene_cfg: SceneConfig = field(default_factory=SceneConfig)
- default_materials_cfg: DefaultMaterialsConfig = field(
+ gpu_memory_config: GPUMemoryConfig = field(default_factory=GPUMemoryConfig)
+ scene_config: SceneConfig = field(default_factory=SceneConfig)
+ default_materials_config: DefaultMaterialsConfig = field(
default_factory=DefaultMaterialsConfig
)
diff --git a/mani_skill/utils/visualization/__init__.py b/mani_skill/utils/visualization/__init__.py
index 7b2d36824..48e913fd6 100644
--- a/mani_skill/utils/visualization/__init__.py
+++ b/mani_skill/utils/visualization/__init__.py
@@ -1,10 +1,3 @@
from .jupyter_utils import display_images
-from .misc import (
- images_to_video,
- normalize_depth,
- observations_to_images,
- put_info_on_image,
- put_text_on_image,
- tile_images,
-)
+from .misc import images_to_video, put_info_on_image, put_text_on_image, tile_images
from .renderer import ImageRenderer
diff --git a/mani_skill/utils/visualization/misc.py b/mani_skill/utils/visualization/misc.py
index 6539b6e22..4c316c3ad 100644
--- a/mani_skill/utils/visualization/misc.py
+++ b/mani_skill/utils/visualization/misc.py
@@ -51,61 +51,6 @@ def images_to_video(
writer.close()
-def normalize_depth(depth, min_depth=0, max_depth=None):
- if min_depth is None:
- min_depth = depth.min()
- if max_depth is None:
- max_depth = depth.max()
- depth = (depth - min_depth) / (max_depth - min_depth)
- depth = depth.clip(0, 1)
- return depth
-
-
-def observations_to_images(observations, max_depth=None) -> List[Array]:
- """Parse images from camera observations."""
- images = []
- # is_torch = False
- # if torch is not None:
- # is_torch = isinstance(images[0], torch.Tensor)
- for key in observations:
- if "rgb" in key or "Color" in key:
- rgb = observations[key][..., :3]
- if rgb.dtype == np.float32:
- rgb = np.clip(rgb * 255, 0, 255).astype(np.uint8)
- if torch is not None and rgb.dtype == torch.float:
- rgb = torch.clip(rgb * 255, 0, 255).to(torch.uint8)
- images.append(rgb)
- elif "depth" in key or "Position" in key:
- depth = observations[key]
- if "Position" in key: # [H, W, 4]
- depth = -depth[..., 2:3]
- # [H, W, 1]
- depth = normalize_depth(depth, max_depth=max_depth)
- depth = (depth * 255).clip(0, 255)
- if isinstance(depth, np.ndarray):
- depth = depth.astype(np.uint8)
- depth = np.repeat(depth, 3, axis=-1)
- else:
- depth = depth.to(torch.uint8)
- depth = torch.repeat_interleave(depth, 3, dim=-1)
- images.append(depth)
- elif "seg" in key:
- seg: Array = observations[key] # [H, W, 1]
- assert seg.ndim == 3 and seg.shape[-1] == 1, seg.shape
- # A heuristic way to colorize labels
- seg = np.uint8(seg * [11, 61, 127]) # [H, W, 3]
- images.append(seg)
- elif "Segmentation" in key:
- seg: Array = observations[key] # [H, W, 4]
- assert seg.ndim == 3 and seg.shape[-1] == 4, seg.shape
- # A heuristic way to colorize labels
- visual_seg = np.uint8(seg[..., 0:1] * [11, 61, 127]) # [H, W, 3]
- actor_seg = np.uint8(seg[..., 1:2] * [11, 61, 127]) # [H, W, 3]
- images.append(visual_seg)
- images.append(actor_seg)
- return images
-
-
def tile_images(images: List[Array], nrows=1) -> Array:
"""
Tile multiple images to a single image comprised of nrows and an appropriate number of columns to fit all the images.
diff --git a/mani_skill/utils/wrappers/visual_encoders.py b/mani_skill/utils/wrappers/visual_encoders.py
index 8a1c1ad3d..46c6b41c0 100644
--- a/mani_skill/utils/wrappers/visual_encoders.py
+++ b/mani_skill/utils/wrappers/visual_encoders.py
@@ -8,7 +8,7 @@
class VisualEncoderWrapper(gym.ObservationWrapper):
- def __init__(self, env, encoder: Literal["r3m"], encoder_cfg=dict()):
+ def __init__(self, env, encoder: Literal["r3m"], encoder_config=dict()):
self.base_env: BaseEnv = env.unwrapped
assert encoder == "r3m", "Only encoder='r3m' is supported at the moment"
if encoder == "r3m":
diff --git a/mani_skill/viewer/__init__.py b/mani_skill/viewer/__init__.py
new file mode 100644
index 000000000..44bf90cb7
--- /dev/null
+++ b/mani_skill/viewer/__init__.py
@@ -0,0 +1,42 @@
+import sapien
+from sapien.utils import Viewer
+
+from mani_skill.render import SAPIEN_RENDER_SYSTEM
+from mani_skill.sensors.camera import CameraConfig
+
+
+def create_viewer(viewer_camera_config: CameraConfig):
+ """Creates a viewer with the given camera config"""
+ if SAPIEN_RENDER_SYSTEM == "3.0":
+ sapien.render.set_viewer_shader_dir(
+ viewer_camera_config.shader_config.shader_pack
+ )
+ if viewer_camera_config.shader_config.shader_pack[:2] == "rt":
+ sapien.render.set_ray_tracing_denoiser(
+ viewer_camera_config.shader_config.shader_pack_config[
+ "ray_tracing_denoiser"
+ ]
+ )
+ sapien.render.set_ray_tracing_path_depth(
+ viewer_camera_config.shader_config.shader_pack_config[
+ "ray_tracing_path_depth"
+ ]
+ )
+ sapien.render.set_ray_tracing_samples_per_pixel(
+ viewer_camera_config.shader_config.shader_pack_config[
+ "ray_tracing_samples_per_pixel"
+ ]
+ )
+ viewer = Viewer(
+ resolutions=(viewer_camera_config.width, viewer_camera_config.height)
+ )
+ elif SAPIEN_RENDER_SYSTEM == "3.1":
+ # TODO (stao): figure out how shader pack configs can be set at run time
+ viewer = Viewer(
+ resolutions=(viewer_camera_config.width, viewer_camera_config.height),
+ shader_pack=sapien.render.get_shader_pack(
+ viewer_camera_config.shader_config.shader_pack
+ ),
+ )
+
+ return viewer
diff --git a/tests/test_envs.py b/tests/test_envs.py
index 62b76e3b6..ab95cc863 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -50,8 +50,6 @@ def test_envs_obs_modes(env_id, obs_mode):
assert obs["sensor_data"][cam]["rgb"].shape == (128, 128, 3)
assert obs["sensor_data"][cam]["depth"].shape == (128, 128, 1)
assert obs["sensor_data"][cam]["depth"].dtype == np.int16
- assert obs["sensor_data"][cam]["segmentation"].shape == (128, 128, 1)
- assert obs["sensor_data"][cam]["segmentation"].dtype == np.int16
assert obs["sensor_param"][cam]["extrinsic_cv"].shape == (3, 4)
assert obs["sensor_param"][cam]["intrinsic_cv"].shape == (3, 3)
assert obs["sensor_param"][cam]["cam2world_gl"].shape == (4, 4)
@@ -61,6 +59,14 @@ def test_envs_obs_modes(env_id, obs_mode):
assert obs["pointcloud"]["rgb"].shape == (num_pts, 3)
assert obs["pointcloud"]["segmentation"].shape == (num_pts, 1)
assert obs["pointcloud"]["segmentation"].dtype == np.int16
+ elif obs_mode == "rgb":
+ for cam in obs["sensor_data"].keys():
+ assert obs["sensor_data"][cam]["rgb"].shape == (128, 128, 3)
+ assert obs["sensor_param"][cam]["extrinsic_cv"].shape == (3, 4)
+ elif obs_mode == "depth+segmentation":
+ for cam in obs["sensor_data"].keys():
+ assert obs["sensor_data"][cam]["depth"].shape == (128, 128, 1)
+ assert obs["sensor_data"][cam]["segmentation"].shape == (128, 128, 1)
env.close()
del env
@@ -85,8 +91,6 @@ def test_envs_obs_modes_without_cpu_gym_wrapper(env_id, obs_mode):
assert obs["sensor_data"][cam]["rgb"].shape == (1, 128, 128, 3)
assert obs["sensor_data"][cam]["depth"].shape == (1, 128, 128, 1)
assert obs["sensor_data"][cam]["depth"].dtype == torch.int16
- assert obs["sensor_data"][cam]["segmentation"].shape == (1, 128, 128, 1)
- assert obs["sensor_data"][cam]["segmentation"].dtype == torch.int16
assert obs["sensor_param"][cam]["extrinsic_cv"].shape == (1, 3, 4)
assert obs["sensor_param"][cam]["intrinsic_cv"].shape == (1, 3, 3)
assert obs["sensor_param"][cam]["cam2world_gl"].shape == (1, 4, 4)
@@ -96,6 +100,14 @@ def test_envs_obs_modes_without_cpu_gym_wrapper(env_id, obs_mode):
assert obs["pointcloud"]["rgb"].shape == (1, num_pts, 3)
assert obs["pointcloud"]["segmentation"].shape == (1, num_pts, 1)
assert obs["pointcloud"]["segmentation"].dtype == torch.int16
+ elif obs_mode == "rgb":
+ for cam in obs["sensor_data"].keys():
+ assert obs["sensor_data"][cam]["rgb"].shape == (1, 128, 128, 3)
+ assert obs["sensor_param"][cam]["extrinsic_cv"].shape == (1, 3, 4)
+ elif obs_mode == "depth+segmentation":
+ for cam in obs["sensor_data"].keys():
+ assert obs["sensor_data"][cam]["depth"].shape == (1, 128, 128, 1)
+ assert obs["sensor_data"][cam]["segmentation"].shape == (1, 128, 128, 1)
env.close()
del env
diff --git a/tests/test_gpu_envs.py b/tests/test_gpu_envs.py
index dcfe5fba3..ed0f55d0a 100644
--- a/tests/test_gpu_envs.py
+++ b/tests/test_gpu_envs.py
@@ -10,7 +10,7 @@
from tests.utils import (
CONTROL_MODES_STATIONARY_SINGLE_ARM,
ENV_IDS,
- LOW_MEM_SIM_CFG,
+ LOW_MEM_SIM_CONFIG,
MULTI_AGENT_ENV_IDS,
OBS_MODES,
SINGLE_ARM_STATIONARY_ROBOTS,
@@ -24,10 +24,10 @@
@pytest.mark.gpu_sim
@pytest.mark.parametrize("env_id", ENV_IDS)
def test_all_envs(env_id):
- sim_cfg = dict()
+ sim_config = dict()
if "Scene" not in env_id:
- sim_cfg = LOW_MEM_SIM_CFG
- env = gym.make(env_id, num_envs=16, obs_mode="state", sim_cfg=sim_cfg)
+ sim_config = LOW_MEM_SIM_CONFIG
+ env = gym.make(env_id, num_envs=16, obs_mode="state", sim_config=sim_config)
obs, _ = env.reset()
action_space = env.action_space
for _ in range(5):
@@ -47,7 +47,7 @@ def assert_device(x):
env_id,
num_envs=16,
vectorization_mode="custom",
- vector_kwargs=dict(obs_mode=obs_mode, sim_cfg=LOW_MEM_SIM_CFG),
+ vector_kwargs=dict(obs_mode=obs_mode, sim_config=LOW_MEM_SIM_CONFIG),
)
obs, _ = env.reset()
assert_isinstance(obs, torch.Tensor)
@@ -71,8 +71,6 @@ def assert_device(x):
assert obs["sensor_data"][cam]["rgb"].shape == (16, 128, 128, 3)
assert obs["sensor_data"][cam]["depth"].shape == (16, 128, 128, 1)
assert obs["sensor_data"][cam]["depth"].dtype == torch.int16
- assert obs["sensor_data"][cam]["segmentation"].shape == (16, 128, 128, 1)
- assert obs["sensor_data"][cam]["segmentation"].dtype == torch.int16
assert obs["sensor_param"][cam]["extrinsic_cv"].shape == (16, 3, 4)
assert obs["sensor_param"][cam]["intrinsic_cv"].shape == (16, 3, 3)
assert obs["sensor_param"][cam]["cam2world_gl"].shape == (16, 4, 4)
@@ -82,6 +80,14 @@ def assert_device(x):
assert obs["pointcloud"]["rgb"].shape == (16, num_pts, 3)
assert obs["pointcloud"]["segmentation"].shape == (16, num_pts, 1)
assert obs["pointcloud"]["segmentation"].dtype == torch.int16
+ elif obs_mode == "rgb":
+ for cam in obs["sensor_data"].keys():
+ assert obs["sensor_data"][cam]["rgb"].shape == (16, 128, 128, 3)
+ assert obs["sensor_param"][cam]["extrinsic_cv"].shape == (16, 3, 4)
+ elif obs_mode == "depth+segmentation":
+ for cam in obs["sensor_data"].keys():
+ assert obs["sensor_data"][cam]["depth"].shape == (16, 128, 128, 1)
+ assert obs["sensor_data"][cam]["segmentation"].shape == (16, 128, 128, 1)
env.close()
del env
@@ -94,7 +100,7 @@ def assert_device(x):
# env_id,
# num_envs=16,
# vectorization_mode="custom",
-# vector_kwargs=dict(obs_mode=obs_mode, sim_cfg=LOW_MEM_SIM_CFG),
+# vector_kwargs=dict(obs_mode=obs_mode, sim_config=LOW_MEM_SIM_CONFIG),
# )
# obs, _ = env.reset()
# assert_isinstance(obs, torch.Tensor)
@@ -118,7 +124,7 @@ def test_env_control_modes(env_id, control_mode):
env_id,
num_envs=16,
vectorization_mode="custom",
- vector_kwargs=dict(control_mode=control_mode, sim_cfg=LOW_MEM_SIM_CFG),
+ vector_kwargs=dict(control_mode=control_mode, sim_config=LOW_MEM_SIM_CONFIG),
)
env.reset()
action_space = env.action_space
@@ -191,7 +197,7 @@ def test_env_reconfiguration(env_id):
def test_raw_sim_states():
# Test sim state get and set works for environment without overriden get_state_dict functions
env = gym.make(
- "PickCube-v1", num_envs=16, obs_mode="state_dict", sim_cfg=LOW_MEM_SIM_CFG
+ "PickCube-v1", num_envs=16, obs_mode="state_dict", sim_config=LOW_MEM_SIM_CONFIG
)
base_env: BaseEnv = env.unwrapped
obs1, _ = env.reset()
@@ -233,7 +239,7 @@ def test_robots(env_id, robot_uids):
env_id,
num_envs=16,
vectorization_mode="custom",
- vector_kwargs=dict(robot_uids=robot_uids, sim_cfg=LOW_MEM_SIM_CFG),
+ vector_kwargs=dict(robot_uids=robot_uids, sim_config=LOW_MEM_SIM_CONFIG),
)
env.reset()
action_space = env.action_space
@@ -250,7 +256,7 @@ def test_multi_agent(env_id):
env_id,
num_envs=16,
vectorization_mode="custom",
- vector_kwargs=dict(sim_cfg=LOW_MEM_SIM_CFG),
+ vector_kwargs=dict(sim_config=LOW_MEM_SIM_CONFIG),
)
env.reset()
action_space = env.action_space
@@ -270,7 +276,7 @@ def test_partial_resets(env_id):
env_id,
num_envs=16,
vectorization_mode="custom",
- vector_kwargs=dict(sim_cfg=LOW_MEM_SIM_CFG),
+ vector_kwargs=dict(sim_config=LOW_MEM_SIM_CONFIG),
)
obs, _ = env.reset()
action_space = env.action_space
@@ -298,7 +304,7 @@ def test_timelimits():
"PickCube-v1",
num_envs=16,
vectorization_mode="custom",
- vector_kwargs=dict(sim_cfg=LOW_MEM_SIM_CFG),
+ vector_kwargs=dict(sim_config=LOW_MEM_SIM_CONFIG),
)
obs, _ = env.reset()
for _ in range(50):
diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py
index ab902d003..f72699a3e 100644
--- a/tests/test_wrappers.py
+++ b/tests/test_wrappers.py
@@ -10,7 +10,7 @@
from mani_skill.utils.wrappers.visual_encoders import VisualEncoderWrapper
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from tests.utils import (
- LOW_MEM_SIM_CFG,
+ LOW_MEM_SIM_CONFIG,
MULTI_AGENT_ENV_IDS,
OBS_MODES,
STATIONARY_ENV_IDS,
@@ -27,7 +27,7 @@ def test_recordepisode_wrapper_gpu(env_id, obs_mode):
render_mode="rgb_array",
max_episode_steps=10,
num_envs=16,
- sim_cfg=LOW_MEM_SIM_CFG,
+ sim_config=LOW_MEM_SIM_CONFIG,
)
env = RecordEpisode(
env,
@@ -81,11 +81,11 @@ def test_recordepisode_wrapper_gpu_render_sensor(env_id, obs_mode):
obs_mode=obs_mode,
render_mode="sensors",
num_envs=16,
- sim_cfg=LOW_MEM_SIM_CFG,
+ sim_config=LOW_MEM_SIM_CONFIG,
)
env = RecordEpisode(
env,
- output_dir=f"videos/pytest/{env_id}-gpu-render-sensor",
+ output_dir=f"videos/pytest/{env_id}-gpu-{obs_mode}-render-sensor",
trajectory_name=f"test_traj_{obs_mode}",
save_trajectory=True,
max_steps_per_video=50,
@@ -113,7 +113,7 @@ def test_recordepisode_wrapper_render_sensor(env_id, obs_mode):
)
env = RecordEpisode(
env,
- output_dir=f"videos/pytest/{env_id}-render-sensor",
+ output_dir=f"videos/pytest/{env_id}-{obs_mode}-render-sensor",
trajectory_name=f"test_traj_{obs_mode}",
info_on_video=True,
)
@@ -136,11 +136,11 @@ def test_recordepisode_wrapper_partial_reset_gpu(env_id, obs_mode):
obs_mode=obs_mode,
render_mode="rgb_array",
num_envs=16,
- sim_cfg=LOW_MEM_SIM_CFG,
+ sim_config=LOW_MEM_SIM_CONFIG,
)
env = RecordEpisode(
env,
- output_dir=f"videos/pytest/{env_id}-gpu-partial-resets",
+ output_dir=f"videos/pytest/{env_id}-gpu-{obs_mode}-partial-resets",
trajectory_name=f"test_traj_{obs_mode}",
save_trajectory=True,
max_steps_per_video=50,
@@ -169,11 +169,11 @@ def test_recordepisode_wrapper_partial_reset(env_id, obs_mode):
obs_mode=obs_mode,
num_envs=1,
render_mode="rgb_array",
- sim_cfg=LOW_MEM_SIM_CFG,
+ sim_config=LOW_MEM_SIM_CONFIG,
)
env = RecordEpisode(
env,
- output_dir=f"videos/pytest/{env_id}-partial-resets",
+ output_dir=f"videos/pytest/{env_id}-{obs_mode}-partial-resets",
trajectory_name=f"test_traj_{obs_mode}",
save_trajectory=True,
max_steps_per_video=50,
@@ -202,7 +202,7 @@ def test_visualencoders_gpu(env_id):
render_mode="rgb_array",
max_episode_steps=10,
num_envs=16,
- sim_cfg=LOW_MEM_SIM_CFG,
+ sim_config=LOW_MEM_SIM_CONFIG,
)
assert (
"embedding" not in env.observation_space.keys()
@@ -238,7 +238,7 @@ def test_visualencoder_flatten_gpu(env_id):
render_mode="rgb_array",
max_episode_steps=10,
num_envs=16,
- sim_cfg=LOW_MEM_SIM_CFG,
+ sim_config=LOW_MEM_SIM_CONFIG,
)
env = VisualEncoderWrapper(env, encoder="r3m")
env = FlattenObservationWrapper(env)
@@ -264,7 +264,7 @@ def test_visualencoder_flatten_gpu(env_id):
@pytest.mark.gpu_sim
@pytest.mark.parametrize("env_id", MULTI_AGENT_ENV_IDS[:1])
def test_multi_agent_flatten_action_space_gpu(env_id):
- env = gym.make(env_id, num_envs=16, sim_cfg=LOW_MEM_SIM_CFG)
+ env = gym.make(env_id, num_envs=16, sim_config=LOW_MEM_SIM_CONFIG)
env = FlattenActionSpaceWrapper(env)
env.reset()
action_space = env.action_space
diff --git a/tests/utils.py b/tests/utils.py
index 1d9a0bcac..20678da99 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -28,7 +28,10 @@
OBS_MODES = [
"state_dict",
"state",
+ "rgb",
"rgbd",
+ "rgb+depth+segmentation",
+ "depth+segmentation",
"pointcloud",
# "rgbd_robot_seg",
# "pointcloud_robot_seg",
@@ -42,8 +45,10 @@
]
SINGLE_ARM_STATIONARY_ROBOTS = ["panda", "xmate3_robotiq"]
-LOW_MEM_SIM_CFG = dict(
- gpu_memory_cfg=dict(max_rigid_patch_count=81920, found_lost_pairs_capacity=262144)
+LOW_MEM_SIM_CONFIG = dict(
+ gpu_memory_config=dict(
+ max_rigid_patch_count=81920, found_lost_pairs_capacity=262144
+ )
)