diff --git a/mesa/agent.py b/mesa/agent.py index 109a9dcdba8..8688dc702a4 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -8,6 +8,7 @@ from __future__ import annotations import contextlib +import copy import operator import warnings import weakref @@ -124,9 +125,7 @@ def __init__(self, agents: Iterable[Agent], model: Model): stacklevel=2, ) - self._agents = weakref.WeakKeyDictionary() - for agent in agents: - self._agents[agent] = None + self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents}) def __len__(self) -> int: """Return the number of agents in the AgentSet.""" @@ -161,20 +160,21 @@ def select( AgentSet: A new AgentSet containing the selected agents, unless inplace is True, in which case the current AgentSet is updated. """ - def agent_generator(): + if filter_func is None and agent_type is None and n == 0: + return self if inplace else copy.copy(self) + + def agent_generator(filter_func=None, agent_type=None, n=0): count = 0 for agent in self: - if filter_func and not filter_func(agent): - continue - if agent_type and not isinstance(agent, agent_type): - continue - yield agent - count += 1 - # default of n is zero, zo evaluates to False - if n and count >= n: - break - - agents = agent_generator() + if (not filter_func or filter_func(agent)) and ( + not agent_type or isinstance(agent, agent_type) + ): + yield agent + count += 1 + if 0 < n <= count: + break + + agents = agent_generator(filter_func, agent_type, n) return AgentSet(agents, self.model) if not inplace else self._update(agents) @@ -229,11 +229,8 @@ def _update(self, agents: Iterable[Agent]): """Update the AgentSet with a new set of agents. This is a private method primarily used internally by other methods like select, shuffle, and sort. """ - _agents = weakref.WeakKeyDictionary() - for agent in agents: - _agents[agent] = None - self._agents = _agents + self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents}) return self def do( diff --git a/mesa/space.py b/mesa/space.py index 969f18380cf..b730c2da796 100644 --- a/mesa/space.py +++ b/mesa/space.py @@ -326,8 +326,10 @@ def iter_neighbors( at most 9 if Moore, 5 if Von-Neumann (8 and 4 if not including the center). """ - neighborhood = self.get_neighborhood(pos, moore, include_center, radius) - return self.iter_cell_list_contents(neighborhood) + default_val = self.default_val() + for x, y in self.get_neighborhood(pos, moore, include_center, radius): + if (cell := self._grid[x][y]) != default_val: + yield cell def get_neighbors( self, @@ -385,11 +387,10 @@ def iter_cell_list_contents( An iterator of the agents contained in the cells identified in `cell_list`. """ # iter_cell_list_contents returns only non-empty contents. - return ( - cell - for x, y in cell_list - if (cell := self._grid[x][y]) != self.default_val() - ) + default_val = self.default_val() + for x, y in cell_list: + if (cell := self._grid[x][y]) != default_val: + yield cell @accept_tuple_argument def get_cell_list_contents(self, cell_list: Iterable[Coordinate]) -> list[Agent]: @@ -1045,6 +1046,17 @@ def remove_agent(self, agent: Agent) -> None: self._empty_mask[agent.pos] = False agent.pos = None + def iter_neighbors( + self, + pos: Coordinate, + moore: bool, + include_center: bool = False, + radius: int = 1, + ) -> Iterator[Agent]: + return itertools.chain.from_iterable( + super().iter_neighbors(pos, moore, include_center, radius) + ) + @accept_tuple_argument def iter_cell_list_contents( self, cell_list: Iterable[Coordinate] @@ -1058,10 +1070,9 @@ def iter_cell_list_contents( Returns: An iterator of the agents contained in the cells identified in `cell_list`. """ + default_val = self.default_val() return itertools.chain.from_iterable( - cell - for x, y in cell_list - if (cell := self._grid[x][y]) != self.default_val() + cell for x, y in cell_list if (cell := self._grid[x][y]) != default_val )