diff --git a/pufferlib/models.py b/pufferlib/models.py index 5e946f0e..93b1bcf8 100644 --- a/pufferlib/models.py +++ b/pufferlib/models.py @@ -35,7 +35,7 @@ def __init__(self, env, hidden_size=128): if self.is_dict_obs: self.dtype = pufferlib.pytorch.nativize_dtype(env.emulated) - input_size = sum(np.prod(v.shape) for v in env.env.observation_space.values()) + input_size = sum(int(np.prod(v.shape)) for v in env.env.observation_space.values()) self.encoder = nn.Linear(input_size, self.hidden_size) else: self.encoder = nn.Linear(np.prod(env.single_observation_space.shape), hidden_size) diff --git a/pufferlib/vector.py b/pufferlib/vector.py index 62c0f736..73fe3670 100644 --- a/pufferlib/vector.py +++ b/pufferlib/vector.py @@ -372,7 +372,7 @@ def __init__(self, env_creators, env_args, env_kwargs, target=_worker_process, args=(env_creators[start:end], env_args[start:end], env_kwargs[start:end], obs_shape, obs_dtype, - atn_shape, atn_dtype, envs_per_worker, driver_env.num_agents, + atn_shape, atn_dtype, envs_per_worker, agents_per_worker, num_workers, i, w_send_pipes[i], w_recv_pipes[i], self.shm, is_native) )