Skip to content

Commit

Permalink
[update] generate simulator callbacks from config file
Browse files Browse the repository at this point in the history
  • Loading branch information
phython96 committed Jan 7, 2025
1 parent 985ea1a commit 91c3869
Show file tree
Hide file tree
Showing 40 changed files with 127 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ defaults:
- base
- _self_
text: Dig three blocks down and fill one block.
reference_video: build_dig3fill1
1 change: 1 addition & 0 deletions minestudio/benchmark/task_configs/simple/build_gate.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ defaults:
- base
- _self_
text: Build a gate to secure your area in Minecraft.
reference_video: build_gate
1 change: 1 addition & 0 deletions minestudio/benchmark/task_configs/simple/build_golems.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ defaults:
- base
- _self_
text: Build golems to protect your area from hostile mobs.
reference_video: build_golems
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ defaults:
- base
- _self_
text: Create obsidian blocks by pouring water over lava source blocks.
reference_video: build_obsidian
1 change: 1 addition & 0 deletions minestudio/benchmark/task_configs/simple/build_pillar.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ defaults:
- base
- _self_
text: Build a pillar using blocks of your choice.
reference_video: build_pillar
1 change: 1 addition & 0 deletions minestudio/benchmark/task_configs/simple/collect_dirt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ defaults:
- base
- _self_
text: Collect dirt blocks from the ground.
reference_video: collect_dirt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ defaults:
- base
- _self_
text: Collect grass blocks using the appropriate tool.
reference_video: collect_grass
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ defaults:
- base
- _self_
text: Collect seagrass from underwater areas.
reference_video: collect_seagrass
1 change: 1 addition & 0 deletions minestudio/benchmark/task_configs/simple/collect_wood.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ defaults:
- base
- _self_
text: 'Collect wood from trees in your surroundings. '
reference_video: collect_wood
1 change: 1 addition & 0 deletions minestudio/benchmark/task_configs/simple/collect_wool.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ defaults:
- base
- _self_
text: Shear sheep to collect wool.
reference_video: collect_wool
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ defaults:
- base
- _self_
text: Craft an enchanting table to enchant your items.
reference_video: craft_enchantment
1 change: 1 addition & 0 deletions minestudio/benchmark/task_configs/simple/craft_ladder.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ defaults:
- base
- _self_
text: Open your inventory and craft a ladder using sticks.
reference_video: craft_ladder
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ defaults:
- base
- _self_
text: Craft a furnace for smelting items.
reference_video: craft_smelting
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ defaults:
- base
- _self_
text: Open your inventory and craft a stonecutter.
reference_video: craft_stonecut
1 change: 1 addition & 0 deletions minestudio/benchmark/task_configs/simple/craft_table.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ defaults:
- base
- _self_
text: 'Open your inventory and craft a crafting table using wooden planks. '
reference_video: craft_table
1 change: 1 addition & 0 deletions minestudio/benchmark/task_configs/simple/explore_boat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ defaults:
- base
- _self_
text: Build and use a boat to explore water bodies in Minecraft.
reference_video: explore_boat
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ defaults:
- base
- _self_
text: Explore and open a chest to discover its contents.
reference_video: explore_chest
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ defaults:
- base
- _self_
text: Explore and climb a mountain.
reference_video: explore_climb
1 change: 1 addition & 0 deletions minestudio/benchmark/task_configs/simple/explore_mine.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ defaults:
- base
- _self_
text: Explore a mine to gather resources and discover treasures.
reference_video: explore_mine
1 change: 1 addition & 0 deletions minestudio/benchmark/task_configs/simple/explore_run.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ defaults:
- _self_
text: Explore the world of Minecraft by traveling across various biomes and finding
interesting locations.
reference_video: explore_run
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ defaults:
- base
- _self_
text: Survive combat against hostile mobs in Minecraft.
reference_video: survive_combat
1 change: 1 addition & 0 deletions minestudio/benchmark/task_configs/simple/survive_hunt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ defaults:
- base
- _self_
text: Survive a hunt by using tools, weapons, and strategic defenses.
reference_video: survive_hunt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ defaults:
- base
- _self_
text: Survive by planting crops and gathering resources.
reference_video: survive_plant
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ defaults:
- base
- _self_
text: Find a safe place and rest in a bed to skip the night.
reference_video: survive_sleep
1 change: 1 addition & 0 deletions minestudio/benchmark/task_configs/simple/tool_bow.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ defaults:
- base
- _self_
text: Craft a bow to use as your ranged weapon.
reference_video: tool_bow
1 change: 1 addition & 0 deletions minestudio/benchmark/task_configs/simple/tool_flint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ defaults:
- base
- _self_
text: 'Obtain flint by breaking gravel blocks. '
reference_video: tool_flint
1 change: 1 addition & 0 deletions minestudio/benchmark/task_configs/simple/tool_lead.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ defaults:
- base
- _self_
text: Craft a lead using slimeballs and string.
reference_video: tool_lead
1 change: 1 addition & 0 deletions minestudio/benchmark/task_configs/simple/tool_pumpkin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ defaults:
- base
- _self_
text: Create a carved pumpkin to use as a decorative item or for crafting.
reference_video: tool_pumpkin
1 change: 1 addition & 0 deletions minestudio/benchmark/task_configs/simple/tool_trident.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ defaults:
- base
- _self_
text: Craft a trident using the necessary materials.
reference_video: tool_trident
8 changes: 5 additions & 3 deletions minestudio/benchmark/task_configs/upload.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'''
Date: 2025-01-06 20:08:00
LastEditors: caishaofei-mus1 [email protected]
LastEditTime: 2025-01-07 10:24:24
LastEditors: caishaofei [email protected]
LastEditTime: 2025-01-07 11:17:00
FilePath: /MineStudio/minestudio/benchmark/task_configs/upload.py
'''
from huggingface_hub import HfApi
Expand All @@ -10,5 +10,7 @@
local_dir = "./simple"
repo_id = "CraftJarvis/MineStudio_task_group.simple"
api = HfApi()
api.create_repo(repo_id=repo_id, repo_type='dataset')
# api.create_repo(repo_id=repo_id, repo_type='dataset')
# if not api.repo_exists(repo_id=repo_id, repo_type='dataset'):
# api.create_repo(repo_id=repo_id, repo_type='dataset')
api.upload_folder(folder_path=local_dir, repo_id=repo_id, repo_type="dataset")
11 changes: 7 additions & 4 deletions minestudio/benchmark/utility/read_conf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'''
Date: 2024-12-06 16:42:49
LastEditors: caishaofei-mus1 [email protected]
LastEditTime: 2025-01-07 10:40:24
LastEditors: caishaofei [email protected]
LastEditTime: 2025-01-07 11:08:39
FilePath: /MineStudio/minestudio/benchmark/utility/read_conf.py
'''
import os
Expand All @@ -27,13 +27,16 @@ def convert_yaml_to_callbacks(yaml_file):

return commands, task_dict

def prepare_task_configs(group_name: str, path: Optional[str] = None):
def prepare_task_configs(group_name: str, path: Optional[str] = None, refresh: bool = False) -> Dict:
"""
group_name: str - used to specify the group name
path: str - can be a local directory or a huggingface repo_id
"""
root_dir = get_mine_studio_dir()
local_dir = os.path.join(root_dir, "task_configs", group_name)
if refresh and os.path.exists(local_dir):
print(f"Refreshing the cache: removing existing task configs from: {local_dir}")
shutil.rmtree(local_dir)
if not os.path.exists(local_dir):
if os.path.isdir(path):
shutil.copytree(path, local_dir)
Expand All @@ -48,5 +51,5 @@ def prepare_task_configs(group_name: str, path: Optional[str] = None):
local_dir=local_dir,
repo_type='dataset'
)
yaml_files = [ (file_path.stem, str(file_path)) for file_path in Path(local_dir).rglob("*.yaml") ]
yaml_files = { file_path.stem: str(file_path) for file_path in Path(local_dir).rglob("*.yaml") }
return yaml_files
16 changes: 11 additions & 5 deletions minestudio/inference/filter/info_base_filter.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
'''
Date: 2024-11-25 12:39:01
LastEditors: caishaofei [email protected]
LastEditTime: 2024-11-25 13:19:22
LastEditTime: 2025-01-07 08:19:00
FilePath: /MineStudio/minestudio/inference/filter/info_base_filter.py
'''
import re
import pickle
from minestudio.inference.filter.base_filter import EpisodeFilter

class InfoBaseFilter(EpisodeFilter):

def __init__(self, key: str, val: str, num: int, label: str = "status"):
def __init__(self, key: str, regex: str, num: int, label: str = "status"):
self.key = key
self.val = val
self.regex = regex
self.num = num
self.label = label

def filter(self, episode_generator):
for episode in episode_generator:
info = pickle.loads(open(episode["info_path"], "rb").read())
if info[-1][self.key].get(self.val, 0) >= self.num:
total = 0
last_info = info[-1][self.key]
for event in last_info:
if re.match(self.regex, event):
total += last_info.get(event, 0)
if total >= self.num:
episode[self.label] = "yes"
yield episode
yield episode
18 changes: 9 additions & 9 deletions minestudio/models/groot_one/body.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'''
Date: 2024-11-25 07:03:41
LastEditors: caishaofei [email protected]
LastEditTime: 2025-01-04 17:05:47
LastEditTime: 2025-01-07 10:25:24
FilePath: /MineStudio/minestudio/models/groot_one/body.py
'''
import torch
Expand All @@ -14,6 +14,8 @@
import av

import timm
from huggingface_hub import PyTorchModelHubMixin

from minestudio.models.base_policy import MinePolicy
from minestudio.utils.vpt_lib.util import FanInInitReLULayer, ResidualRecurrentBlocks
from minestudio.utils.register import Registers
Expand Down Expand Up @@ -163,7 +165,7 @@ def initial_state(self, batch_size: int = None) -> List[torch.Tensor]:
return [t.to(device) for t in self.recurrent.initial_state(batch_size)]

@Registers.model.register
class GrootPolicy(MinePolicy):
class GrootPolicy(MinePolicy, PyTorchModelHubMixin):

def __init__(
self,
Expand Down Expand Up @@ -251,7 +253,7 @@ def forward(self, input: Dict, memory: Optional[List[torch.Tensor]] = None) -> D
image = self.updim(image)
image = rearrange(image, '(b t) c h w -> b t c h w', b=b)

if 'ref_video_path' in input or self.condition is not None:
if 'ref_video_path' in input:
if self.condition is None:
self.encode_video(input['ref_video_path'])
condition = self.condition
Expand Down Expand Up @@ -288,11 +290,8 @@ def initial_state(self, *args, **kwargs) -> Any:
@Registers.model_loader.register
def load_groot_policy(ckpt_path: str = None):
if ckpt_path is None:
from minestudio.models.utils.download import download_model
local_dir = download_model("GROOT")
if local_dir is None:
assert False, "Please specify the ckpt_path or download the model first."
ckpt_path = os.path.join(local_dir, "groot.ckpt")
repo_id = "CraftJarvis/MineStudio_GROOT.18w_EMA"
return GrootPolicy.from_pretrained("CraftJarvis/MineStudio_GROOT.18w_EMA")

ckpt = torch.load(ckpt_path)
model = GrootPolicy(**ckpt['hyper_parameters']['model'])
Expand All @@ -301,8 +300,9 @@ def load_groot_policy(ckpt_path: str = None):
return model

if __name__ == '__main__':
model = load_groot_policy()
model = GrootPolicy(
backbone='timm/vit_base_patch16_224.dino',
backbone='timm/vit_base_patch32_clip_224.openai',
hiddim=1024,
freeze_backbone=False,
video_encoder_kwargs=dict(
Expand Down
14 changes: 6 additions & 8 deletions minestudio/simulator/callbacks/commands.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'''
Date: 2024-11-11 19:31:53
LastEditors: caishaofei [email protected]
LastEditTime: 2025-01-07 03:19:46
LastEditTime: 2025-01-07 11:58:44
FilePath: /MineStudio/minestudio/simulator/callbacks/commands.py
'''
import os
Expand All @@ -28,13 +28,11 @@ def __init__(self, commands):

def after_reset(self, sim, obs, info):
for command in self.commands:
obs, reward, done, info = sim.env.execute_cmd(command)
_obs, reward, done, info = sim.env.execute_cmd(command)
obs.update(_obs)
info.update(info)
obs, info = sim._wrap_obs_info(obs, info)
return obs, info

if __name__ == '__main__':
yaml_file = '/home/caishaofei/tmpdir/MineStudio/task_configs/debug_task/build_gate.yaml'
commands_callback = CommandsCallback.create_from_conf(yaml_file)
print(commands_callback)


def __repr__(self):
return f"CommandsCallback(commands={self.commands})"
Loading

0 comments on commit 91c3869

Please sign in to comment.