Additional fix while retraining policies #629

Merged
13 changes: 8 additions & 5 deletions lerobot/common/policies/normalize.py
@@ -37,8 +37,8 @@ def create_stats_buffers(
     stats_buffers = {}

     for key, ft in features.items():
-        norm_mode = norm_map.get(ft.type, None)
-        if norm_mode is None:
+        norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
+        if norm_mode is NormalizationMode.IDENTITY:
             continue

         assert isinstance(norm_mode, NormalizationMode)
@@ -140,8 +140,8 @@ def __init__(
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         batch = dict(batch)  # shallow copy avoids mutating the input batch
         for key, ft in self.features.items():
-            norm_mode = self.norm_map.get(ft.type, None)
-            if norm_mode is None:
+            norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
+            if norm_mode is NormalizationMode.IDENTITY:
                 continue

             buffer = getattr(self, "buffer_" + key.replace(".", "_"))
@@ -210,7 +210,10 @@ def __init__(
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         batch = dict(batch)  # shallow copy avoids mutating the input batch
         for key, ft in self.features.items():
-            norm_mode = self.norm_map.get(ft.type, None)
+            norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
+            if norm_mode is NormalizationMode.IDENTITY:
+                continue
+
             buffer = getattr(self, "buffer_" + key.replace(".", "_"))

             if norm_mode is NormalizationMode.MEAN_STD:
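
Replacing the None sentinel with an explicit NormalizationMode.IDENTITY member keeps the values in norm_map uniformly typed while preserving the old skip behavior. A minimal, self-contained sketch of the resulting dispatch (hypothetical feature types and mapping, not the library's actual code):

from enum import Enum

class NormalizationMode(str, Enum):
    MIN_MAX = "MIN_MAX"
    MEAN_STD = "MEAN_STD"
    IDENTITY = "IDENTITY"

norm_map = {"VISUAL": NormalizationMode.IDENTITY, "ACTION": NormalizationMode.MIN_MAX}
features = {"observation.image": "VISUAL", "action": "ACTION", "reward": "SCALAR"}

for key, ftype in features.items():
    # Feature types missing from the map now default to IDENTITY instead of None ...
    norm_mode = norm_map.get(ftype, NormalizationMode.IDENTITY)
    # ... and IDENTITY features are skipped, exactly like the `continue` in the hunks above.
    if norm_mode is NormalizationMode.IDENTITY:
        continue
    print(f"{key}: would create stats buffers for {norm_mode.value}")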
8 changes: 4 additions & 4 deletions lerobot/common/policies/tdmpc/configuration_tdmpc.py
@@ -112,11 +112,11 @@ class TDMPCConfig(PreTrainedConfig):
     horizon: int = 5
     n_action_steps: int = 1

-    normalization_mapping: dict[str, NormalizationMode | None] = field(
+    normalization_mapping: dict[str, NormalizationMode] = field(
         default_factory=lambda: {
-            "VISUAL": None,
-            "STATE": None,
-            "ENV": None,
+            "VISUAL": NormalizationMode.IDENTITY,
+            "STATE": NormalizationMode.IDENTITY,
+            "ENV": NormalizationMode.IDENTITY,
             "ACTION": NormalizationMode.MIN_MAX,
         }
     )
2 changes: 1 addition & 1 deletion lerobot/common/policies/vqbet/configuration_vqbet.py
@@ -98,7 +98,7 @@ class VQBeTConfig(PreTrainedConfig):

     normalization_mapping: dict[str, NormalizationMode] = field(
         default_factory=lambda: {
-            "VISUAL": NormalizationMode.MEAN_STD,
+            "VISUAL": NormalizationMode.IDENTITY,
[Collaborator note: this is because, although input_normalization_modes was mean_std in the vqbet.yaml config, that was a hack for not normalizing images (with normalization values of 0.5).]

"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
Expand Down
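
Concretely, mapping "VISUAL" to IDENTITY means the normalization forward pass hits `continue` for image keys and returns them untouched, making explicit what, per the note above, the old config was hacking around. A toy illustration (hypothetical shapes, not the vqbet pipeline):

import torch

batch = {"observation.image": torch.rand(1, 3, 96, 96)}  # made-up image shape
out = dict(batch)  # the shallow copy from normalize.py's forward
# With VISUAL -> IDENTITY, the loop skips "observation.image" entirely,
# so the output tensor is the input tensor, unchanged.
assert torch.equal(out["observation.image"], batch["observation.image"])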
13 changes: 13 additions & 0 deletions lerobot/configs/train.py
@@ -68,6 +68,19 @@ class OnlineConfig:
     # + eval + environment rendering simultaneously.
     do_rollout_async: bool = False

+    def __post_init__(self):
+        if self.steps == 0:
+            return
+
+        if self.steps_between_rollouts is None:
+            raise ValueError(
+                "'steps_between_rollouts' must be set to a positive integer, but it is currently None."
+            )
+        if self.env_seed is None:
+            raise ValueError("'env_seed' must be set to a positive integer, but it is currently None.")
+        if self.buffer_capacity is None:
+            raise ValueError("'buffer_capacity' must be set to a positive integer, but it is currently None.")
+

@dataclass
class TrainPipelineConfig(HubMixin):
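
The new __post_init__ turns missing online-training settings into construction-time errors rather than mid-run failures. A self-contained sketch of that behavior, using only the fields named in the diff (the real OnlineConfig has more):

from dataclasses import dataclass

@dataclass
class OnlineConfig:
    steps: int = 0
    steps_between_rollouts: int | None = None
    env_seed: int | None = None
    buffer_capacity: int | None = None

    def __post_init__(self):
        if self.steps == 0:
            return  # online training disabled: the None defaults are fine
        for name in ("steps_between_rollouts", "env_seed", "buffer_capacity"):
            if getattr(self, name) is None:
                raise ValueError(f"'{name}' must be set to a positive integer, but it is currently None.")

OnlineConfig()  # ok: steps == 0, so validation short-circuits
try:
    OnlineConfig(steps=1000)  # fails fast, before any training starts
except ValueError as e:
    print(e)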
1 change: 1 addition & 0 deletions lerobot/configs/types.py
@@ -15,6 +15,7 @@ class FeatureType(str, Enum):
 class NormalizationMode(str, Enum):
     MIN_MAX = "MIN_MAX"
     MEAN_STD = "MEAN_STD"
+    IDENTITY = "IDENTITY"


class DictLike(Protocol):
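
Since NormalizationMode mixes in str, the new member round-trips cleanly through string-keyed configs; this is standard Python enum behavior, shown here on a copy of the class:

from enum import Enum

class NormalizationMode(str, Enum):
    MIN_MAX = "MIN_MAX"
    MEAN_STD = "MEAN_STD"
    IDENTITY = "IDENTITY"

assert NormalizationMode("IDENTITY") is NormalizationMode.IDENTITY  # parse from a config string
assert NormalizationMode.IDENTITY == "IDENTITY"                     # plain string comparison works too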
51 changes: 34 additions & 17 deletions lerobot/scripts/train.py
@@ -363,7 +363,9 @@ def evaluate_and_checkpoint_if_needed(step: int, is_online: bool):
             },
             "next.reward": {"shape": (), "dtype": np.dtype("float32")},
             "next.done": {"shape": (), "dtype": np.dtype("?")},
-            "next.success": {"shape": (), "dtype": np.dtype("?")},
+            "task_index": {"shape": (), "dtype": np.dtype("int64")},
+            # Removed next.success, since it's not used anywhere for now and the offline dataset doesn't have it.
+            # "next.success": {"shape": (), "dtype": np.dtype("?")},
         },
         buffer_capacity=cfg.online.buffer_capacity,
         fps=online_env.unwrapped.metadata["render_fps"],
@@ -400,12 +402,14 @@ def evaluate_and_checkpoint_if_needed(step: int, is_online: bool):
     )
     dl_iter = cycle(dataloader)

-    # Lock and thread pool executor for asynchronous online rollouts. When asynchronous mode is disabled,
-    # these are still used but effectively do nothing.
-    lock = Lock()
-    # Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
-    # parallelization of rollouts is handled within the job.
-    executor = ThreadPoolExecutor(max_workers=1)
+    if cfg.online.do_rollout_async:
+        # Lock and thread pool executor for asynchronous online rollouts.
+        lock = Lock()
+        # Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
+        # parallelization of rollouts is handled within the job.
+        executor = ThreadPoolExecutor(max_workers=1)
+    else:
+        lock = None

     online_step = 0
     online_rollout_s = 0  # time taken to do an online rollout
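
The `with lock if lock is not None else nullcontext():` guard used in the hunks below relies on contextlib.nullcontext as a no-op stand-in for the lock (the import isn't visible in this diff). A standalone sketch of the pattern:

from contextlib import nullcontext
from threading import Lock

do_rollout_async = False  # flip to True to exercise the locked path
lock = Lock() if do_rollout_async else None

counter = 0
# Identical guard to the diff: a real critical section in async mode,
# a no-op context manager otherwise.
with lock if lock is not None else nullcontext():
    counter += 1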
@@ -424,10 +428,13 @@ def evaluate_and_checkpoint_if_needed(step: int, is_online: bool):

def sample_trajectory_and_update_buffer():
nonlocal rollout_start_seed
with lock:

with lock if lock is not None else nullcontext():
online_rollout_policy.load_state_dict(policy.state_dict())

online_rollout_policy.eval()
start_rollout_time = time.perf_counter()

with torch.no_grad():
eval_info = eval_policy(
online_env,
@@ -440,7 +447,14 @@ def sample_trajectory_and_update_buffer():
             )
         online_rollout_s = time.perf_counter() - start_rollout_time

-        with lock:
+        if len(offline_dataset.meta.tasks) > 1:
+            raise NotImplementedError("Add support for multi task.")
+
+        # Hack to add a task to the online_dataset (0 is the first task of the offline_dataset).
+        total_num_frames = eval_info["episodes"]["index"].shape[0]
+        eval_info["episodes"]["task_index"] = torch.tensor([0] * total_num_frames, dtype=torch.int64)
+
+        with lock if lock is not None else nullcontext():
             start_update_buffer_time = time.perf_counter()
             online_dataset.add_data(eval_info["episodes"])

@@ -463,19 +477,22 @@ def sample_trajectory_and_update_buffer():

return online_rollout_s, update_online_buffer_s

future = executor.submit(sample_trajectory_and_update_buffer)
# If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
# here until the rollout and buffer update is done, before proceeding to the policy update steps.
if not cfg.online.do_rollout_async or len(online_dataset) <= cfg.online.buffer_seed_size:
online_rollout_s, update_online_buffer_s = future.result()
if lock is None:
online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()
else:
future = executor.submit(sample_trajectory_and_update_buffer)
# If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
# here until the rollout and buffer update is done, before proceeding to the policy update steps.
if len(online_dataset) <= cfg.online.buffer_seed_size:
online_rollout_s, update_online_buffer_s = future.result()

if len(online_dataset) <= cfg.online.buffer_seed_size:
logging.info(f"Seeding online buffer: {len(online_dataset)}/{cfg.online.buffer_seed_size}")
continue

policy.train()
for _ in range(cfg.online.steps_between_rollouts):
with lock:
with lock if lock is not None else nullcontext():
start_time = time.perf_counter()
batch = next(dl_iter)
dataloading_s = time.perf_counter() - start_time
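
The dispatch earlier in this hunk (call the rollout directly when lock is None, otherwise submit it to the single-worker executor) is the usual sync/async toggle with concurrent.futures. A stripped-down sketch with a stand-in rollout function:

from concurrent.futures import ThreadPoolExecutor

def rollout():
    return 0.1, 0.2  # stand-in for sample_trajectory_and_update_buffer's timings

do_rollout_async = True
executor = ThreadPoolExecutor(max_workers=1) if do_rollout_async else None

if executor is None:
    online_rollout_s, update_online_buffer_s = rollout()  # blocking path
else:
    future = executor.submit(rollout)  # runs on the background worker
    # ... policy update steps could proceed here ...
    online_rollout_s, update_online_buffer_s = future.result()  # join when the timings are needed
    executor.shutdown()
print(online_rollout_s, update_online_buffer_s)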
@@ -498,7 +515,7 @@ def sample_trajectory_and_update_buffer():
train_info["online_rollout_s"] = online_rollout_s
train_info["update_online_buffer_s"] = update_online_buffer_s
train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
with lock:
with lock if lock is not None else nullcontext():
train_info["online_buffer_size"] = len(online_dataset)

if step % cfg.log_freq == 0:
@@ -513,7 +530,7 @@ def sample_trajectory_and_update_buffer():

# If we're doing async rollouts, we should now wait until we've completed them before proceeding
# to do the next batch of rollouts.
if future.running():
if cfg.online.do_rollout_async and future.running():
start = time.perf_counter()
online_rollout_s, update_online_buffer_s = future.result()
await_update_online_buffer_s = time.perf_counter() - start