Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix demo minigrid #118

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open

fix demo minigrid #118

wants to merge 1 commit into from

Conversation

conwayz
Copy link

@conwayz conwayz commented Oct 25, 2024

fix two issues trying to run the commands from dev branch

python demo.py --env minigrid --mode train 
python demo.py --env minigrid --mode train  --vec multiprocessing
  1. Discrete has shape () and np.prod(()) = 1.0 so need to cast to int when building network
  2. set num agents correctly in MultiProcessing

however it does look like dev branch is slower (35k vs 50k sps)
1.0
Screenshot 2024-10-25 at 4 02 51 PM
dev
Screenshot 2024-10-25 at 4 04 05 PM

For the vector.py change my understanding is

  • MultiProcessing spawn num_workers workers 'sees' a portion of the buffer
  • The outer buffer has dimension (num_workers, agents_per_worker, *obs_shape) concretely let's say this is (6, 8, 160)
  • So each worker should get a 'slice' of shape of shape (agents_per_worker, *obs_shape) = (8, 160). Before this change the are only getting a slice of (1, 160) because driver_env.num_agents is 1.
  • From my understanding, driver_env is like the first env; in the worker process we use Serial so it's like one of the envs in Serial.
  • If the slice of observations is shape (1, 160) instead of (8, 160), then in Serial._assign_buffers when we assign parts of the buffer to the worker we will end up assigning empty slices and will see errors downstream

@leanke
Copy link
Contributor

leanke commented Oct 26, 2024

@conwayz I can test this in pokemon_red, but my concern is how this effects the envpool.

some of the environments we run use the envpool to batch out data from x amount of agents (agents_per_worker) at a time as the finish an episode. so essentially the batch dim you would see in the policy is equal the the number of agents in a single worker.

to clarify using your above example "(num_workers, agents_per_worker, *obs_shape) concretely let's say this is (6, 8, 160)" you would expect to see (8,160) if you were to check the shape of the observation from within the policy.

@conwayz
Copy link
Author

conwayz commented Oct 26, 2024

right i think this is the case in minigrid no? we pass in something like buffer[worker_idx] so like you mentioned the shape of the observation would be (8, 160) which is indexed by the number of agents per worker (in thise case 8)

@conwayz
Copy link
Author

conwayz commented Oct 26, 2024

ill try to test a bunch of other envs

@conwayz
Copy link
Author

conwayz commented Oct 30, 2024

@jsuarez5341 @leanke are these changes needed to run demo on dev or was my setup somehow busted?

@leanke
Copy link
Contributor

leanke commented Oct 30, 2024

@conwayz let me test minigrid here in a bit but last i tested it had worked

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants