Skip to content

Commit

Permalink
Add feature types to envs
Browse files Browse the repository at this point in the history
  • Loading branch information
aliberts committed Dec 27, 2024
1 parent fa55e67 commit ba31014
Showing 1 changed file with 25 additions and 15 deletions.
40 changes: 25 additions & 15 deletions lerobot/common/envs/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,15 @@

import draccus


@dataclass
class GymConfig:
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
from lerobot.configs.types import FeatureType


@dataclass
class EnvConfig(draccus.ChoiceRegistry):
n_envs: int | None = None
task: str | None = None
state_dim: int = 18
action_dim: int = 18
fps: int = 30
feature_types: dict = field(default_factory=dict)

@property
def type(self) -> str:
Expand All @@ -32,10 +27,17 @@ class RealEnv(EnvConfig):
@dataclass
class AlohaEnv(EnvConfig):
task: str = "AlohaInsertion-v0"
state_dim: int = 14
action_dim: int = 14
fps: int = 50
episode_length: int = 400
feature_types: dict = field(
default_factory=lambda: {
"agent_pos": FeatureType.STATE,
"pixels": {
"top": FeatureType.VISUAL,
},
"action": FeatureType.ACTION,
}
)
gym: dict = field(
default_factory=lambda: {
"obs_type": "pixels_agent_pos",
Expand All @@ -48,11 +50,15 @@ class AlohaEnv(EnvConfig):
@dataclass
class PushtEnv(EnvConfig):
task: str = "PushT-v0"
state_dim: int = 2
action_dim: int = 2
image_size: int = 96
fps: int = 10
episode_length: int = 300
feature_types: dict = field(
default_factory=lambda: {
"agent_pos": FeatureType.STATE,
"pixels": FeatureType.VISUAL,
"action": FeatureType.ACTION,
}
)
gym: dict = field(
default_factory=lambda: {
"obs_type": "pixels_agent_pos",
Expand All @@ -67,11 +73,15 @@ class PushtEnv(EnvConfig):
@dataclass
class XarmEnv(EnvConfig):
task: str = "XarmLift-v0"
state_dim: int = 4
action_dim: int = 4
image_size: int = 84
fps: int = 15
episode_length: int = 200
feature_types: dict = field(
default_factory=lambda: {
"agent_pos": FeatureType.STATE,
"pixels": FeatureType.VISUAL,
"action": FeatureType.ACTION,
}
)
gym: dict = field(
default_factory=lambda: {
"obs_type": "pixels_agent_pos",
Expand Down

0 comments on commit ba31014

Please sign in to comment.