Skip to content

Commit

Permalink
Add pusht hack
Browse files Browse the repository at this point in the history
  • Loading branch information
aliberts committed Jan 28, 2025
1 parent 214083f commit aa65bb7
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions lerobot/common/datasets/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@
"std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
}

# Stats used in the original Diffusion repo
PUSHT_STATS = {
"observation.state": {
"min": [13.456424, 32.938293],
"max": [496.14618, 510.9579],
},
"action": {
"min": [12.0, 25.0],
"max": [511.0, 511.0],
},
}


def resolve_delta_timestamps(
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
Expand Down Expand Up @@ -112,4 +124,11 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
for stats_type, stats in IMAGENET_STATS.items():
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)

# HACK for pusht
if cfg.dataset.repo_id.startswith("lerobot/pusht"):
for key in PUSHT_STATS:
if key in dataset.meta.features:
for stats_type, stats in PUSHT_STATS[key].items():
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)

return dataset

0 comments on commit aa65bb7

Please sign in to comment.