diff --git a/Makefile b/Makefile
index 073b4183a..02ac76020 100644
--- a/Makefile
+++ b/Makefile
@@ -131,13 +131,12 @@ test-tdmpc-ete-eval:
 	--eval.batch_size=1 \
 	--device=$(DEVICE)
 
-# FIXME: currently broken
-# test-tdmpc-ete-train-with-online:
-# 	python lerobot/scripts/train.py \
+test-tdmpc-ete-train-with-online:
+	python lerobot/scripts/train.py \
 	--policy.type=tdmpc \
 	--env.type=pusht \
 	--env.obs_type=environment_state_agent_pos \
-	--env.episode_length=10 \
+	--env.episode_length=5 \
 	--dataset.repo_id=lerobot/pusht_keypoints \
 	--dataset.image_transforms.enable=true \
 	--dataset.episodes='[0]' \
@@ -147,7 +146,7 @@ test-tdmpc-ete-eval:
 	--online.rollout_n_episodes=2 \
 	--online.rollout_batch_size=2 \
 	--online.steps_between_rollouts=10 \
-	--online.buffer_capacity=15 \
+	--online.buffer_capacity=1000 \
 	--online.env_seed=10000 \
 	--save_checkpoint=false \
 	--save_freq=10 \
diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py
index ac96f5c59..6259ca94a 100644
--- a/lerobot/common/envs/configs.py
+++ b/lerobot/common/envs/configs.py
@@ -99,6 +99,7 @@ def gym_kwargs(self) -> dict:
             "render_mode": self.render_mode,
             "visualization_width": self.visualization_width,
             "visualization_height": self.visualization_height,
+            "max_episode_steps": self.episode_length,
         }
 
 
@@ -137,4 +138,5 @@ def gym_kwargs(self) -> dict:
             "render_mode": self.render_mode,
             "visualization_width": self.visualization_width,
             "visualization_height": self.visualization_height,
+            "max_episode_steps": self.episode_length,
         }
diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py
index d8f021d93..0d188d815 100644
--- a/lerobot/common/policies/normalize.py
+++ b/lerobot/common/policies/normalize.py
@@ -37,8 +37,8 @@ def create_stats_buffers(
     stats_buffers = {}
 
     for key, ft in features.items():
-        norm_mode = norm_map.get(ft.type, None)
-        if norm_mode is None:
+        norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
+        if norm_mode is NormalizationMode.IDENTITY:
             continue
 
         assert isinstance(norm_mode, NormalizationMode)
@@ -140,8 +140,8 @@ def __init__(
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         batch = dict(batch)  # shallow copy avoids mutating the input batch
         for key, ft in self.features.items():
-            norm_mode = self.norm_map.get(ft.type, None)
-            if norm_mode is None:
+            norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
+            if norm_mode is NormalizationMode.IDENTITY:
                 continue
 
             buffer = getattr(self, "buffer_" + key.replace(".", "_"))
@@ -210,7 +210,10 @@ def __init__(
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        batch = dict(batch)  # shallow copy avoids mutating the input batch
        for key, ft in self.features.items():
-            norm_mode = self.norm_map.get(ft.type, None)
+            norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
+            if norm_mode is NormalizationMode.IDENTITY:
+                continue
+
             buffer = getattr(self, "buffer_" + key.replace(".", "_"))
 
             if norm_mode is NormalizationMode.MEAN_STD:
diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py
index 5b128ec55..c3e8aee68 100644
--- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py
+++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py
@@ -112,11 +112,11 @@ class TDMPCConfig(PreTrainedConfig):
     horizon: int = 5
     n_action_steps: int = 1
 
-    normalization_mapping: dict[str, NormalizationMode | None] = field(
+    normalization_mapping: dict[str, NormalizationMode] = field(
         default_factory=lambda: {
-            "VISUAL": None,
-            "STATE": None,
-            "ENV": None,
+            "VISUAL": NormalizationMode.IDENTITY,
+            "STATE": NormalizationMode.IDENTITY,
+            "ENV": NormalizationMode.IDENTITY,
             "ACTION": NormalizationMode.MIN_MAX,
         }
     )
diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py
index 2f3617cc6..47007e823 100644
--- a/lerobot/common/policies/vqbet/configuration_vqbet.py
+++ b/lerobot/common/policies/vqbet/configuration_vqbet.py
@@ -98,7 +98,7 @@ class VQBeTConfig(PreTrainedConfig):
     normalization_mapping: dict[str, NormalizationMode] = field(
         default_factory=lambda: {
-            "VISUAL": NormalizationMode.MEAN_STD,
+            "VISUAL": NormalizationMode.IDENTITY,
             "STATE": NormalizationMode.MIN_MAX,
             "ACTION": NormalizationMode.MIN_MAX,
         }
     )
diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py
index e567f5b9a..b540f2f0a 100644
--- a/lerobot/configs/train.py
+++ b/lerobot/configs/train.py
@@ -68,6 +68,19 @@ class OnlineConfig:
     # + eval + environment rendering simultaneously.
     do_rollout_async: bool = False
 
+    def __post_init__(self):
+        if self.steps == 0:
+            return
+
+        if self.steps_between_rollouts is None:
+            raise ValueError(
+                "'steps_between_rollouts' must be set to a positive integer, but it is currently None."
+            )
+        if self.env_seed is None:
+            raise ValueError("'env_seed' must be set to a positive integer, but it is currently None.")
+        if self.buffer_capacity is None:
+            raise ValueError("'buffer_capacity' must be set to a positive integer, but it is currently None.")
+
 
 @dataclass
 class TrainPipelineConfig(HubMixin):
diff --git a/lerobot/configs/types.py b/lerobot/configs/types.py
index f31f437b1..0ca45a197 100644
--- a/lerobot/configs/types.py
+++ b/lerobot/configs/types.py
@@ -15,6 +15,7 @@ class FeatureType(str, Enum):
 class NormalizationMode(str, Enum):
     MIN_MAX = "MIN_MAX"
     MEAN_STD = "MEAN_STD"
+    IDENTITY = "IDENTITY"
 
 
 class DictLike(Protocol):
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index a7d3fa543..78facdeea 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -354,15 +354,17 @@ def evaluate_and_checkpoint_if_needed(step: int, is_online: bool):
         online_buffer_path,
         data_spec={
             **{
-                ft.key: {"shape": ft.shape, "dtype": np.dtype("float32")}
-                for ft in policy.config.input_features
+                key: {"shape": ft.shape, "dtype": np.dtype("float32")}
+                for key, ft in policy.config.input_features.items()
             },
             **{
-                ft.key: {"shape": ft.shape, "dtype": np.dtype("float32")}
-                for ft in policy.config.output_features
+                key: {"shape": ft.shape, "dtype": np.dtype("float32")}
+                for key, ft in policy.config.output_features.items()
             },
             "next.reward": {"shape": (), "dtype": np.dtype("float32")},
             "next.done": {"shape": (), "dtype": np.dtype("?")},
+            "task_index": {"shape": (), "dtype": np.dtype("int64")},
+            # FIXME: 'next.success' is expected by pusht env but not xarm
             "next.success": {"shape": (), "dtype": np.dtype("?")},
         },
         buffer_capacity=cfg.online.buffer_capacity,
@@ -400,12 +402,14 @@ def evaluate_and_checkpoint_if_needed(step: int, is_online: bool):
     )
     dl_iter = cycle(dataloader)
 
-    # Lock and thread pool executor for asynchronous online rollouts. When asynchronous mode is disabled,
-    # these are still used but effectively do nothing.
-    lock = Lock()
-    # Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
-    # parallelization of rollouts is handled within the job.
-    executor = ThreadPoolExecutor(max_workers=1)
+    if cfg.online.do_rollout_async:
+        # Lock and thread pool executor for asynchronous online rollouts.
+        lock = Lock()
+        # Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
+        # parallelization of rollouts is handled within the job.
+        executor = ThreadPoolExecutor(max_workers=1)
+    else:
+        lock = None
 
     online_step = 0
     online_rollout_s = 0  # time take to do online rollout
@@ -424,10 +428,13 @@ def evaluate_and_checkpoint_if_needed(step: int, is_online: bool):
 
         def sample_trajectory_and_update_buffer():
             nonlocal rollout_start_seed
-            with lock:
+
+            with lock if lock is not None else nullcontext():
                 online_rollout_policy.load_state_dict(policy.state_dict())
+                online_rollout_policy.eval()
 
             start_rollout_time = time.perf_counter()
+
             with torch.no_grad():
                 eval_info = eval_policy(
                     online_env,
@@ -440,7 +447,14 @@ def sample_trajectory_and_update_buffer():
                 )
             online_rollout_s = time.perf_counter() - start_rollout_time
 
-            with lock:
+            if len(offline_dataset.meta.tasks) > 1:
+                raise NotImplementedError("Add support for multi task.")
+
+            # Hack to add a task to the online_dataset (0 is the first task of the offline_dataset)
+            total_num_frames = eval_info["episodes"]["index"].shape[0]
+            eval_info["episodes"]["task_index"] = torch.tensor([0] * total_num_frames, dtype=torch.int64)
+
+            with lock if lock is not None else nullcontext():
                 start_update_buffer_time = time.perf_counter()
                 online_dataset.add_data(eval_info["episodes"])
 
@@ -463,11 +477,14 @@ def sample_trajectory_and_update_buffer():
 
             return online_rollout_s, update_online_buffer_s
 
-        future = executor.submit(sample_trajectory_and_update_buffer)
-        # If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
-        # here until the rollout and buffer update is done, before proceeding to the policy update steps.
-        if not cfg.online.do_rollout_async or len(online_dataset) <= cfg.online.buffer_seed_size:
-            online_rollout_s, update_online_buffer_s = future.result()
+        if lock is None:
+            online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()
+        else:
+            future = executor.submit(sample_trajectory_and_update_buffer)
+            # If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
+            # here until the rollout and buffer update is done, before proceeding to the policy update steps.
+            if len(online_dataset) <= cfg.online.buffer_seed_size:
+                online_rollout_s, update_online_buffer_s = future.result()
 
         if len(online_dataset) <= cfg.online.buffer_seed_size:
             logging.info(f"Seeding online buffer: {len(online_dataset)}/{cfg.online.buffer_seed_size}")
@@ -475,7 +492,7 @@ def sample_trajectory_and_update_buffer():
 
         policy.train()
         for _ in range(cfg.online.steps_between_rollouts):
-            with lock:
+            with lock if lock is not None else nullcontext():
                 start_time = time.perf_counter()
                 batch = next(dl_iter)
                 dataloading_s = time.perf_counter() - start_time
@@ -498,7 +515,7 @@ def sample_trajectory_and_update_buffer():
             train_info["online_rollout_s"] = online_rollout_s
             train_info["update_online_buffer_s"] = update_online_buffer_s
             train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
-            with lock:
+            with lock if lock is not None else nullcontext():
                 train_info["online_buffer_size"] = len(online_dataset)
 
             if step % cfg.log_freq == 0:
@@ -513,7 +530,7 @@ def sample_trajectory_and_update_buffer():
 
         # If we're doing async rollouts, we should now wait until we've completed them before proceeding
        # to do the next batch of rollouts.
-        if future.running():
+        if cfg.online.do_rollout_async and future.running():
             start = time.perf_counter()
             online_rollout_s, update_online_buffer_s = future.result()
             await_update_online_buffer_s = time.perf_counter() - start
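To make the normalize.py and types.py hunks above easier to scan, here is a minimal, self-contained sketch of the pass-through semantics this patch gives `NormalizationMode.IDENTITY`. The enum copy and the `normalize` helper below are illustrative stand-ins, not lerobot's actual `Normalize` module: features whose type maps to IDENTITY are returned unchanged, exactly like the added `continue` branches, while MIN_MAX and MEAN_STD keep their previous behavior.

```python
from enum import Enum

import torch


class NormalizationMode(str, Enum):
    MIN_MAX = "MIN_MAX"
    MEAN_STD = "MEAN_STD"
    IDENTITY = "IDENTITY"


def normalize(x: torch.Tensor, mode: NormalizationMode, stats: dict[str, torch.Tensor]) -> torch.Tensor:
    """Illustrative helper mirroring the per-feature branches in Normalize.forward."""
    if mode is NormalizationMode.IDENTITY:
        # Same effect as the added `continue`: the feature is passed through untouched.
        return x
    if mode is NormalizationMode.MEAN_STD:
        return (x - stats["mean"]) / (stats["std"] + 1e-8)
    if mode is NormalizationMode.MIN_MAX:
        # Map to [0, 1], then rescale to [-1, 1] (the convention lerobot uses for MIN_MAX).
        x = (x - stats["min"]) / (stats["max"] - stats["min"] + 1e-8)
        return x * 2 - 1
    raise ValueError(f"Unsupported normalization mode: {mode}")


if __name__ == "__main__":
    x = torch.tensor([0.0, 5.0, 10.0])
    stats = {
        "min": torch.tensor(0.0),
        "max": torch.tensor(10.0),
        "mean": torch.tensor(5.0),
        "std": torch.tensor(5.0),
    }
    print(normalize(x, NormalizationMode.IDENTITY, stats))  # tensor([ 0.,  5., 10.])
    print(normalize(x, NormalizationMode.MIN_MAX, stats))   # approximately tensor([-1., 0., 1.])
```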