Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of AgentSet and iter_cell_list_contents #1964

Merged
merged 2 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 16 additions & 19 deletions mesa/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

import contextlib
import copy
import operator
import warnings
import weakref
Expand Down Expand Up @@ -124,9 +125,7 @@ def __init__(self, agents: Iterable[Agent], model: Model):
stacklevel=2,
)

self._agents = weakref.WeakKeyDictionary()
for agent in agents:
self._agents[agent] = None
self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents})

def __len__(self) -> int:
"""Return the number of agents in the AgentSet."""
Expand Down Expand Up @@ -161,20 +160,21 @@ def select(
AgentSet: A new AgentSet containing the selected agents, unless inplace is True, in which case the current AgentSet is updated.
"""

def agent_generator():
if filter_func is None and agent_type is None and n == 0:
return self if inplace else copy.copy(self)

def agent_generator(filter_func=None, agent_type=None, n=0):
count = 0
for agent in self:
if filter_func and not filter_func(agent):
continue
if agent_type and not isinstance(agent, agent_type):
continue
yield agent
count += 1
# default of n is zero, zo evaluates to False
if n and count >= n:
break

agents = agent_generator()
if (not filter_func or filter_func(agent)) and (
not agent_type or isinstance(agent, agent_type)
):
yield agent
count += 1
if 0 < n <= count:
break

agents = agent_generator(filter_func, agent_type, n)

return AgentSet(agents, self.model) if not inplace else self._update(agents)

Expand Down Expand Up @@ -229,11 +229,8 @@ def _update(self, agents: Iterable[Agent]):
"""Update the AgentSet with a new set of agents.
This is a private method primarily used internally by other methods like select, shuffle, and sort.
"""
_agents = weakref.WeakKeyDictionary()
for agent in agents:
_agents[agent] = None

self._agents = _agents
self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents})
return self

def do(
Expand Down
31 changes: 21 additions & 10 deletions mesa/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,10 @@ def iter_neighbors(
at most 9 if Moore, 5 if Von-Neumann
(8 and 4 if not including the center).
"""
neighborhood = self.get_neighborhood(pos, moore, include_center, radius)
return self.iter_cell_list_contents(neighborhood)
default_val = self.default_val()
for x, y in self.get_neighborhood(pos, moore, include_center, radius):
if (cell := self._grid[x][y]) != default_val:
yield cell

def get_neighbors(
self,
Expand Down Expand Up @@ -385,11 +387,10 @@ def iter_cell_list_contents(
An iterator of the agents contained in the cells identified in `cell_list`.
"""
# iter_cell_list_contents returns only non-empty contents.
return (
cell
for x, y in cell_list
if (cell := self._grid[x][y]) != self.default_val()
)
default_val = self.default_val()
for x, y in cell_list:
if (cell := self._grid[x][y]) != default_val:
yield cell

@accept_tuple_argument
def get_cell_list_contents(self, cell_list: Iterable[Coordinate]) -> list[Agent]:
Expand Down Expand Up @@ -1045,6 +1046,17 @@ def remove_agent(self, agent: Agent) -> None:
self._empty_mask[agent.pos] = False
agent.pos = None

def iter_neighbors(
self,
pos: Coordinate,
moore: bool,
include_center: bool = False,
radius: int = 1,
) -> Iterator[Agent]:
return itertools.chain.from_iterable(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While I can see caching default_val() as a speedup, how does this line speed up the code?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This just chains the iterable from the parent function - which is already cache-optimized. So now we don't have to call iter_cell_list_contents, which is wrapped in an annotation

super().iter_neighbors(pos, moore, include_center, radius)
)

@accept_tuple_argument
def iter_cell_list_contents(
self, cell_list: Iterable[Coordinate]
Expand All @@ -1058,10 +1070,9 @@ def iter_cell_list_contents(
Returns:
An iterator of the agents contained in the cells identified in `cell_list`.
"""
default_val = self.default_val()
return itertools.chain.from_iterable(
cell
for x, y in cell_list
if (cell := self._grid[x][y]) != self.default_val()
cell for x, y in cell_list if (cell := self._grid[x][y]) != default_val
)


Expand Down