diff --git a/mesa/agent.py b/mesa/agent.py index 243a187337b..57edf1a032a 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -264,17 +264,29 @@ def do( return res if return_results else self - def get(self, attr_name: str) -> list[Any]: + def get(self, attr_names: str | list[str]) -> list[Any]: """ - Retrieve a specified attribute from each agent in the AgentSet. + Retrieve the specified attribute(s) from each agent in the AgentSet. Args: - attr_name (str): The name of the attribute to retrieve from each agent. + attr_names (str | list[str]): The name(s) of the attribute(s) to retrieve from each agent. Returns: - list[Any]: A list of attribute values from each agent in the set. + list[Any]: A list with the attribute value for each agent in the set if attr_names is a str + list[list[Any]]: A list with a list of attribute values for each agent in the set if attr_names is a list of str + + Raises: + AttributeError if an agent does not have the specified attribute(s) + """ - return [getattr(agent, attr_name) for agent in self._agents] + + if isinstance(attr_names, str): + return [getattr(agent, attr_names) for agent in self._agents] + else: + return [ + [getattr(agent, attr_name) for attr_name in attr_names] + for agent in self._agents + ] def __getitem__(self, item: int | slice) -> Agent: """ diff --git a/tests/test_agent.py b/tests/test_agent.py index 7ad538eba27..0cd211123e1 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -221,6 +221,24 @@ def test_agentset_get_attribute(): with pytest.raises(AttributeError): agentset.get("non_existing_attribute") + model = Model() + agents = [] + for i in range(10): + agent = TestAgent(model.next_id(), model) + agent.i = i**2 + agents.append(agent) + agentset = AgentSet(agents, model) + + values = agentset.get(["unique_id", "i"]) + + for value, agent in zip(values, agents): + ( + unique_id, + i, + ) = value + assert agent.unique_id == unique_id + assert agent.i == i + class OtherAgentType(Agent): def get_unique_identifier(self):