Add EvalPipelineConfig, parse features from envs
aliberts committed Dec 27, 2024
1 parent ba31014 commit 87d92f9
Showing 24 changed files with 506 additions and 460 deletions.
4 changes: 2 additions & 2 deletions lerobot/common/datasets/factory.py
@@ -24,8 +24,8 @@
MultiLeRobotDataset,
)
from lerobot.common.datasets.transforms import ImageTransforms
from lerobot.configs.default import MainConfig
from lerobot.configs.policies import PretrainedConfig
from lerobot.configs.training import TrainPipelineConfig

IMAGENET_STATS = {
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
@@ -56,7 +56,7 @@ def resolve_delta_timestamps(
return delta_timestamps


def make_dataset(cfg: MainConfig, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
"""
Args:
cfg: A Hydra config as per the LeRobot config scheme.
13 changes: 13 additions & 0 deletions lerobot/common/datasets/utils.py
@@ -35,6 +35,7 @@
from torchvision import transforms

from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.configs.types import DictLike

DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk

@@ -98,6 +99,18 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
return outdict


def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
split_keys = flattened_key.split(sep)
getter = obj[split_keys[0]]
if len(split_keys) == 1:
return getter

for key in split_keys[1:]:
getter = getter[key]

return getter


def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()}
return unflatten_dict(serialized_dict)
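
A quick usage sketch of the new get_nested_item helper (the dictionary below is made up for illustration): it walks a "/"-separated key through a nested mapping.

    from lerobot.common.datasets.utils import get_nested_item

    stats = {"observation": {"images": {"top": {"mean": 0.5}}}}
    get_nested_item(stats, "observation/images/top/mean")  # -> 0.5
    get_nested_item(stats, "observation")                  # -> the whole nested sub-dict
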
25 changes: 12 additions & 13 deletions lerobot/common/envs/factory.py
@@ -17,39 +17,38 @@

import gymnasium as gym

from lerobot.configs.default import MainConfig
from lerobot.common.envs.configs import EnvConfig


def make_env(cfg: MainConfig) -> gym.vector.VectorEnv | None:
def make_env(
cfg: EnvConfig, n_envs: int | None = None, use_async_envs: bool = False
) -> gym.vector.VectorEnv | None:
"""Makes a gym vector environment according to the evaluation config.
n_envs can be used to override eval.batch_size in the configuration. Must be at least 1.
"""
n_envs = cfg.training.online.rollout_batch_size
if n_envs is not None and n_envs < 1:
raise ValueError("`n_envs must be at least 1")

if cfg.env.type == "real_world":
if cfg.type == "real_world":
return

package_name = f"gym_{cfg.env.type}"
package_name = f"gym_{cfg.type}"

try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
print(
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.env.type}]'`"
)
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
raise e

gym_handle = f"{package_name}/{cfg.env.task}"
gym_kwgs = getattr(cfg.env, "gym", {})
gym_handle = f"{package_name}/{cfg.task}"
gym_kwgs = getattr(cfg, "gym", {})

if getattr(cfg.env, "episode_length", None):
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
if getattr(cfg, "episode_length", None):
gym_kwgs["max_episode_steps"] = cfg.episode_length

# batched version of the env that returns an observation of shape (b, c)
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
env = env_cls(
[
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
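
make_env now takes the EnvConfig directly, with n_envs and use_async_envs passed explicitly instead of being read from the training/eval sections of a global config. A rough sketch of the vectorized env it builds, shown with a standard gymnasium task rather than a LeRobot one:

    import gymnasium as gym

    n_envs = 4  # must be at least 1
    env_cls = gym.vector.SyncVectorEnv  # AsyncVectorEnv when use_async_envs=True
    env = env_cls(
        [lambda: gym.make("CartPole-v1", disable_env_checker=True) for _ in range(n_envs)]
    )
    obs, info = env.reset(seed=0)
    print(obs.shape)  # (4, 4): batched observations, one row per env
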
1 change: 1 addition & 0 deletions lerobot/common/envs/utils.py
@@ -20,6 +20,7 @@


def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
"""Convert environment observation to LeRobot format observation.
Args:
observation: Dictionary of observation batches from a Gym vector environment.
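
The new TODO flags that preprocess_observation still hardcodes which observation keys it expects. For context, its image handling amounts to roughly the following (a sketch; the exact LeRobot key names are omitted):

    import numpy as np
    import torch

    img = np.random.randint(0, 256, size=(2, 96, 96, 3), dtype=np.uint8)  # (b, h, w, c) from gym
    chw = torch.from_numpy(img).permute(0, 3, 1, 2).float() / 255.0       # (b, c, h, w) in [0, 1]
    print(chw.shape, chw.dtype)  # torch.Size([2, 3, 96, 96]) torch.float32
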
6 changes: 3 additions & 3 deletions lerobot/common/logger.py
@@ -34,8 +34,8 @@

from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
from lerobot.configs.default import MainConfig
from lerobot.configs.policies import FeatureType, NormalizationMode
from lerobot.configs.training import TrainPipelineConfig
from lerobot.configs.types import FeatureType, NormalizationMode


def log_output_dir(out_dir):
@@ -86,7 +86,7 @@ class Logger:
pretrained_model_dir_name = "pretrained_model"
training_state_file_name = "training_state.pth"

def __init__(self, cfg: MainConfig):
def __init__(self, cfg: TrainPipelineConfig):
self._cfg = cfg
self.log_dir = cfg.dir
self.log_dir.mkdir(parents=True, exist_ok=True)
8 changes: 5 additions & 3 deletions lerobot/common/optim/factory.py
@@ -18,11 +18,13 @@
from torch.optim.lr_scheduler import LRScheduler

from lerobot.common.policies import Policy
from lerobot.configs.default import MainConfig
from lerobot.configs.training import TrainPipelineConfig


def make_optimizer_and_scheduler(cfg: MainConfig, policy: Policy) -> tuple[Optimizer, LRScheduler | None]:
def make_optimizer_and_scheduler(
cfg: TrainPipelineConfig, policy: Policy
) -> tuple[Optimizer, LRScheduler | None]:
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
optimizer = cfg.optimizer.build(params)
lr_scheduler = cfg.scheduler.build(optimizer, cfg.training.offline.steps)
lr_scheduler = cfg.scheduler.build(optimizer, cfg.offline.steps)
return optimizer, lr_scheduler
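
make_optimizer_and_scheduler now pulls everything from the TrainPipelineConfig: with use_policy_training_preset the policy supplies its own parameter groups via get_optim_params(), otherwise plain policy.parameters() is used. A self-contained sketch of that switch (the module and learning rates are invented for illustration):

    import torch
    from torch import nn

    class TinyPolicy(nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(8, 8)
            self.head = nn.Linear(8, 2)

        def get_optim_params(self):
            # e.g. a lower learning rate for the backbone, optimizer defaults for the head
            return [
                {"params": self.backbone.parameters(), "lr": 1e-5},
                {"params": self.head.parameters()},
            ]

    use_policy_training_preset = True
    policy = TinyPolicy()
    params = policy.get_optim_params() if use_policy_training_preset else policy.parameters()
    optimizer = torch.optim.AdamW(params, lr=1e-4)
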
5 changes: 4 additions & 1 deletion lerobot/common/policies/act/configuration_act.py
@@ -17,7 +17,8 @@

from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import NoneSchedulerConfig
from lerobot.configs.policies import NormalizationMode, PretrainedConfig
from lerobot.configs.policies import PretrainedConfig
from lerobot.configs.types import NormalizationMode


@PretrainedConfig.register_subclass("act")
@@ -138,6 +139,8 @@ class ACTConfig(PretrainedConfig):
optimizer_lr_backbone: float = 1e-5

def __post_init__(self):
super().__post_init__()

"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError(
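
Each policy config touched here (ACT, Diffusion, TDMPC, VQBeT below) now calls super().__post_init__() before its own validation, so any shared setup in PretrainedConfig runs first. The dataclass mechanics look like this (class and field names are simplified stand-ins, not the real LeRobot ones):

    from dataclasses import dataclass

    @dataclass
    class BasePolicyConfig:
        n_obs_steps: int = 1

        def __post_init__(self):
            if self.n_obs_steps < 1:
                raise ValueError("n_obs_steps must be >= 1")

    @dataclass
    class ACTLikeConfig(BasePolicyConfig):
        vision_backbone: str = "resnet18"

        def __post_init__(self):
            super().__post_init__()  # shared checks run first
            if not self.vision_backbone.startswith("resnet"):
                raise ValueError("vision_backbone must be a ResNet variant")

    cfg = ACTLikeConfig()  # both __post_init__ bodies run
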
5 changes: 4 additions & 1 deletion lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -18,7 +18,8 @@

from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
from lerobot.configs.policies import NormalizationMode, PretrainedConfig
from lerobot.configs.policies import PretrainedConfig
from lerobot.configs.types import NormalizationMode


@PretrainedConfig.register_subclass("diffusion")
@@ -159,6 +160,8 @@ class DiffusionConfig(PretrainedConfig):
scheduler_warmup_steps: int = 500

def __post_init__(self):
super().__post_init__()

"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError(
2 changes: 1 addition & 1 deletion lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -215,7 +215,7 @@ def conditional_sample(

# Sample prior.
sample = torch.randn(
size=(batch_size, self.config.horizon, self.action_feature.shape[0]),
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
dtype=dtype,
device=device,
generator=generator,
38 changes: 29 additions & 9 deletions lerobot/common/policies/factory.py
@@ -14,11 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gymnasium as gym
from torch import nn

from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.policies.policy_protocol import Policy
from lerobot.configs.default import MainConfig
from lerobot.configs.policies import PretrainedConfig


def get_policy_class(name: str) -> Policy:
@@ -44,7 +46,12 @@ def get_policy_class(name: str) -> Policy:


def make_policy(
cfg: MainConfig, ds_meta: LeRobotDatasetMetadata, pretrained_policy_name_or_path: str | None = None
cfg: PretrainedConfig,
device: str,
ds_meta: LeRobotDatasetMetadata | None = None,
env: gym.Env | None = None,
env_cfg: EnvConfig | None = None,
pretrained_policy_name_or_path: str | None = None,
) -> Policy:
"""Make an instance of a policy class.
@@ -57,35 +64,48 @@
directory containing weights saved using `Policy.save_pretrained`. Note that providing this
argument overrides everything in `hydra_cfg.policy` apart from `hydra_cfg.policy.type`.
"""
if not (ds_meta is None) ^ (env is None and env_cfg is None):
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")

# Note: Currently, if you try to run vqbet with mps backend, you'll get this error.
# NotImplementedError: The operator 'aten::unique_dim' is not currently implemented for the MPS device. If
# you want this op to be added in priority during the prototype phase of this feature, please comment on
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
# slower than running natively on MPS.
if cfg.policy.type == "vqbet" and cfg.device == "mps":
if cfg.type == "vqbet" and device == "mps":
raise NotImplementedError(
"Current implementation of VQBeT does not support `mps` backend. "
"Please use `cpu` or `cuda` backend."
)

policy_cls = get_policy_class(cfg.policy.type)
cfg.policy.parse_features_from_dataset(ds_meta)
policy_cls = get_policy_class(cfg.type)

kwargs = {}
if ds_meta is not None:
cfg.parse_features_from_dataset(ds_meta)
kwargs["dataset_stats"] = ds_meta.stats
else:
cfg.parse_features_from_env(env, env_cfg)

kwargs["config"] = cfg

if pretrained_policy_name_or_path is None:
# Make a fresh policy.
policy = policy_cls(cfg.policy, ds_meta.stats)
policy = policy_cls(**kwargs)
else:
kwargs["pretrained_model_name_or_path"] = pretrained_policy_name_or_path
policy = policy_cls.from_pretrained(**kwargs)
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary).
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with,
# pretrained weights which are then loaded into a fresh policy with the desired config. This PR in
# huggingface_hub should make it possible to avoid the hack:
# https://github.com/huggingface/huggingface_hub/pull/2274.
policy = policy_cls(cfg.policy)
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
# policy = policy_cls(cfg)
# policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())

policy.to(cfg.device)
policy.to(device)
assert isinstance(policy, nn.Module)

return policy
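
make_policy now takes the policy config plus a device, and parses features either from dataset metadata or from a sim env with its EnvConfig, but not both and not neither. The guard at the top of the function behaves like this sketch:

    def check_feature_source(ds_meta, env, env_cfg):
        # Exactly one feature source: dataset metadata, or a sim env plus its config.
        if not (ds_meta is None) ^ (env is None and env_cfg is None):
            raise ValueError("Either one of a dataset metadata or a sim env must be provided.")

    check_feature_source(ds_meta="meta", env=None, env_cfg=None)  # ok: dataset path
    check_feature_source(ds_meta=None, env="env", env_cfg="cfg")  # ok: env path
    # check_feature_source(None, None, None)  # raises ValueError
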
3 changes: 2 additions & 1 deletion lerobot/common/policies/normalize.py
@@ -16,7 +16,8 @@
import torch
from torch import Tensor, nn

from lerobot.configs.policies import FeatureType, NormalizationMode, PolicyFeature
from lerobot.configs.policies import PolicyFeature
from lerobot.configs.types import FeatureType, NormalizationMode


def create_stats_buffers(
5 changes: 4 additions & 1 deletion lerobot/common/policies/tdmpc/configuration_tdmpc.py
@@ -18,7 +18,8 @@

from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import NoneSchedulerConfig
from lerobot.configs.policies import NormalizationMode, PretrainedConfig
from lerobot.configs.policies import PretrainedConfig
from lerobot.configs.types import NormalizationMode


@PretrainedConfig.register_subclass("tdmpc")
@@ -161,6 +162,8 @@ class TDMPCConfig(PretrainedConfig):
optimizer_lr: float = 3e-4

def __post_init__(self):
super().__post_init__()

"""Input validation (not exhaustive)."""
if self.n_gaussian_samples <= 0:
raise ValueError(
31 changes: 31 additions & 0 deletions lerobot/common/policies/utils.py
@@ -13,7 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path

import torch
from huggingface_hub import snapshot_download
from huggingface_hub.errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
from torch import nn


@@ -64,3 +70,28 @@ def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
with torch.inference_mode():
output = module(dummy_input)
return tuple(output.shape)


def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
try:
pretrained_policy_path = Path(
snapshot_download(str(pretrained_policy_name_or_path), revision=revision)
)
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
)
else:
error_message = (
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
)

logging.warning(f"{error_message} Treating it as a local directory.")
pretrained_policy_path = Path(pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
"repo ID, nor is it an existing local directory."
)
return pretrained_policy_path
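
get_pretrained_policy_path resolves a Hub repo ID to a local snapshot directory, and falls back to treating the argument as a local path when it is not a valid or existing repo. A usage sketch (the repo ID and checkpoint path are illustrative):

    from pathlib import Path
    from lerobot.common.policies.utils import get_pretrained_policy_path

    # Hub repo ID: downloads (or reuses) a snapshot and returns its local path.
    path = get_pretrained_policy_path("lerobot/diffusion_pusht")
    assert isinstance(path, Path)

    # Local checkpoint directory: returned as-is, provided it exists.
    # path = get_pretrained_policy_path("outputs/train/run/checkpoints/last/pretrained_model")
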
6 changes: 5 additions & 1 deletion lerobot/common/policies/vqbet/configuration_vqbet.py
@@ -20,7 +20,8 @@

from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
from lerobot.configs.policies import NormalizationMode, PretrainedConfig
from lerobot.configs.policies import PretrainedConfig
from lerobot.configs.types import NormalizationMode


@PretrainedConfig.register_subclass("vqbet")
@@ -137,9 +138,12 @@ class VQBeTConfig(PretrainedConfig):
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-6
optimizer_vqvae_lr: float = 1e-3
optimizer_vqvae_weight_decay: float = 1e-4
scheduler_warmup_steps: int = 500

def __post_init__(self):
super().__post_init__()

"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError(