Skip to content

Commit

Permalink
fix(pu): fix cache update()
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Jan 20, 2025
1 parent 51d4f17 commit db958d7
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 22 deletions.
24 changes: 11 additions & 13 deletions lzero/mcts/tree_search/mcts_ctree_sampled.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,19 +137,17 @@ def search(
for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch):
latent_states.append(latent_state_batch_in_search_path[ix][iy])


try:
latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device)
except Exception as e:
print("="*20)
print(e)
# print("latent_states raw:", latent_states)
print("roots:", roots, "latent_state_roots:", latent_state_roots)
print ("latent_state_roots.shape:", latent_state_roots.shape)
# if not all(isinstance(x, np.ndarray) and x.shape == latent_states[0].shape for x in latent_states):
# raise ValueError(f"Inconsistent latent_states shapes: {[x.shape if isinstance(x, np.ndarray) else type(x) for x in latent_states]}")
import ipdb; ipdb.set_trace()

# try:
latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device)
# except Exception as e:
# print("="*20)
# print(e)
# # print("latent_states raw:", latent_states)
# print("roots:", roots, "latent_state_roots:", latent_state_roots)
# print ("latent_state_roots.shape:", latent_state_roots.shape)
# # if not all(isinstance(x, np.ndarray) and x.shape == latent_states[0].shape for x in latent_states):
# # raise ValueError(f"Inconsistent latent_states shapes: {[x.shape if isinstance(x, np.ndarray) else type(x) for x in latent_states]}")
# import ipdb; ipdb.set_trace()

if self._cfg.model.continuous_action_space is True:
# continuous action
Expand Down
3 changes: 2 additions & 1 deletion lzero/model/unizero_world_models/kv_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def update(self, x: torch.Tensor, tokens: int) -> None:
# Shift existing cache data by removing the oldest entries
shift_amount = required_capacity - self._cache.shape[2]
# =======TODO: should drop an even number of (z, a) entries so the head's output pattern stays unchanged=======
shift_amount = shift_amount+1
if shift_amount%2 != 0:
shift_amount = shift_amount+1
if shift_amount >= self._size:
# If the shift amount exceeds or equals the current size, just reset the cache
print("Cache too small; resetting the entire cache")
Expand Down
16 changes: 10 additions & 6 deletions lzero/model/unizero_world_models/world_model_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,12 +748,12 @@ def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_in
else:
valid_context_lengths = torch.tensor(self.keys_values_wm_size_list_current, device=self.device)

try:
position_embeddings = self.pos_emb(
valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1)
except Exception as e:
print(e)
import ipdb; ipdb.set_trace()
# try:
position_embeddings = self.pos_emb(
valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1)
# except Exception as e:
# print(e)
# import ipdb; ipdb.set_trace()

return embeddings + position_embeddings

Expand Down Expand Up @@ -1099,6 +1099,8 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0,
Returns:
- tuple: A tuple containing output sequence, updated latent state, reward, logits policy, and logits value.
"""
# import ipdb; ipdb.set_trace()

latest_state, action = state_action_history[-1]
ready_env_num = latest_state.shape[0]

Expand Down Expand Up @@ -1162,6 +1164,8 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0,
token = outputs_wm.logits_observations
if len(token.shape) != 3:
token = token.unsqueeze(1) # (8,1024) -> (8,1,1024)
# print(f'token.shape:{token.shape}')

latent_state_list.append(token)

del self.latent_state # Very important to minimize cuda memory usage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
action_space_size_list=action_space_size_list,
from_pixels=False,
# ===== TODO: only for debug =====
# frame_skip=10, # 10
# frame_skip=100, # 10
frame_skip=2,
continuous=True, # Assuming all DMC tasks use continuous action spaces
collector_env_num=collector_env_num,
Expand Down Expand Up @@ -323,7 +323,7 @@ def create_env_manager():
num_segments = 2
n_episode = 2
evaluator_env_num = 2
num_simulations = 1
num_simulations = 50
batch_size = [2 for _ in range(len(env_id_list))]
# =======================================

Expand Down

0 comments on commit db958d7

Please sign in to comment.