Skip to content

Commit

Permalink
fix: update datacollection validation and add test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-kinger committed Jan 27, 2025
1 parent 426c62a commit b5c9918
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 51 deletions.
103 changes: 52 additions & 51 deletions mesa/datacollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ def __init__(
self._agenttype_records = {}
self.tables = {}

# add the signal of the validation of model reporter
self._validated = False

if model_reporters is not None:
for name, reporter in model_reporters.items():
self._new_model_reporter(name, reporter)
Expand All @@ -135,76 +138,73 @@ def __init__(
self._new_table(name, columns)

def _validate_model_reporter(self, name, reporter, model):
"""Validate model reporter and issue warnings if necessary.
"""Validate model reporter and handle validation results appropriately.
Args:
name: Name of the reporter
reporter: Reporter definition
reporter: Reporter definition (lambda/method/attribute/function list)
model: Model instance
Raises:
ValueError: If reporter is None or has invalid format
AttributeError: If model attribute doesn't exist
TypeError: If reporter type is not supported
RuntimeError: If reporter execution fails
"""
self._validated = True # put the change of signal firstly avoid losing efficacy

# Type 1: Lambda function
if isinstance(reporter, types.LambdaType):
try:
# Try to call the lambda with a model instance
reporter(model)
except Exception as e:
warnings.warn(
f"Warning: Lambda reporter '{name}' failed: {e!s}\n"
f"Example of valid lambda: lambda m: len(m.agents)",
UserWarning,
stacklevel=2,
)
return
raise RuntimeError(
f"Lambda reporter '{name}' failed validation: {e!s}\n"
f"Example: lambda m: len(m.agents)"
) from e

# Type 2: Method of class/instance
if callable(reporter) and not isinstance(reporter, types.LambdaType):
try:
# Try to call the method
reporter(model)
callable(reporter)
except Exception as e:
warnings.warn(
f"Warning: Method reporter '{name}' failed: {e!s}\n"
f"Example of valid method: self.get_agent_count or Model.get_agent_count",
UserWarning,
stacklevel=2,
)
return
raise RuntimeError(

Check warning on line 171 in mesa/datacollection.py

View check run for this annotation

Codecov / codecov/patch

mesa/datacollection.py#L170-L171

Added lines #L170 - L171 were not covered by tests
f"Method reporter '{name}' failed validation: {e!s}"
) from e

# Type 3: Class attributes (string)
# Type 3: Model attribute (string)
if isinstance(reporter, str):
if not hasattr(model, reporter):
warnings.warn(
f"Warning: Model reporter '{name}' references attribute '{reporter}' "
f"which is not defined in the model.\n"
f"Example of valid attribute: 'model_attribute'",
UserWarning,
stacklevel=2,
)
return
try:
if not hasattr(model, reporter):
raise AttributeError(
f"Model reporter '{name}' references non-existent attribute '{reporter}'\n"
f"Available attributes: {', '.join(dir(model))}"
)
getattr(model, reporter) # 验证属性是否可访问
except AttributeError as e:
raise AttributeError(
f"Model reporter '{name}' attribute validation failed: {e!s}\n"
f"Available attributes: {', '.join(dir(model))}"
) from e
except Exception as e:
raise RuntimeError(

Check warning on line 190 in mesa/datacollection.py

View check run for this annotation

Codecov / codecov/patch

mesa/datacollection.py#L189-L190

Added lines #L189 - L190 were not covered by tests
f"Model reporter '{name}' attribute validation failed: {e!s}"
) from e

# Type 4: Function with parameters in list
if isinstance(reporter, list):
if not reporter or not callable(reporter[0]):
warnings.warn(
f"Warning: Invalid function list format for reporter '{name}'.\n"
f"First element must be a callable function.\n"
f"Example: [function, [param1, param2]]",
UserWarning,
stacklevel=2,
raise ValueError(
f"Invalid function list format for reporter '{name}'\n"
f"Expected: [function, [param1, param2]], got: {reporter}"
)
return

# If none of the above types match
warnings.warn(
f"Warning: Model reporter '{name}' has invalid type: {type(reporter)}.\n"
f"Must be one of:\n"
f"1. Lambda function: lambda m: len(m.agents)\n"
f"2. Method: self.get_count or Model.get_count\n"
f"3. Attribute name (str): 'model_attribute'\n"
f"4. Function list: [function, [param1, param2]]",
UserWarning,
stacklevel=2,
)
try:
reporter[0](*reporter[1])
except Exception as e:
raise RuntimeError(

Check warning on line 204 in mesa/datacollection.py

View check run for this annotation

Codecov / codecov/patch

mesa/datacollection.py#L203-L204

Added lines #L203 - L204 were not covered by tests
f"Function list reporter '{name}' failed validation: {e!s}\n"
f"Example: [function, [param1, param2]]"
) from e

def _new_model_reporter(self, name, reporter):
"""Add a new model-level reporter to collect.
Expand Down Expand Up @@ -337,10 +337,11 @@ def get_reports(agent):
def collect(self, model):
"""Collect all the data for the given model object."""
if self.model_reporters:
for var, reporter in self.model_reporters.items():
# Add validation
self._validate_model_reporter(var, reporter, model)
if not self._validated:
for name, reporter in self.model_reporters.items():
self._validate_model_reporter(name, reporter, model)

for var, reporter in self.model_reporters.items():
# Check if lambda or partial function
if isinstance(reporter, types.LambdaType | partial):
# Use deepcopy to store a copy of the data,
Expand Down
61 changes: 61 additions & 0 deletions tests/test_datacollector.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,5 +352,66 @@ def test_agenttype_superclass_reporter(self):
self.assertTrue(super_data.equals(agent_data))


class MockModelForErrors(Model):
"""Test model for error handling."""

def __init__(self):
"""Initialize the test model for error handling."""
super().__init__()
self.num_agents = 10
self.valid_attribute = "test"

def valid_method(self):
"""Valid method for testing."""
return self.num_agents


def helper_function(model, param1):
"""Test function with parameters."""
return model.num_agents * param1


class TestDataCollectorErrorHandling(unittest.TestCase):
"""Test error handling in DataCollector."""

def setUp(self):
"""Set up test cases."""
self.model = MockModelForErrors()

def test_lambda_error(self):
"""Test error when lambda tries to access non-existent attribute."""
dc_lambda = DataCollector(
model_reporters={"bad_lambda": lambda m: m.nonexistent_attr}
)
with self.assertRaises(RuntimeError):
dc_lambda.collect(self.model)

def test_method_error(self):
"""Test error when accessing non-existent method."""

def bad_method(model):
raise Exception("Test error")

dc_method = DataCollector(model_reporters={"test": bad_method})
with self.assertRaises(RuntimeError):
dc_method.collect(self.model)

def test_attribute_error(self):
"""Test error when accessing non-existent attribute."""
dc_attribute = DataCollector(
model_reporters={"bad_attribute": "nonexistent_attribute"}
)
with self.assertRaises(Exception):
dc_attribute.collect(self.model)

def test_function_error(self):
"""Test error when function list is not callable."""
dc_function = DataCollector(
model_reporters={"bad_function": ["not_callable", [1, 2]]}
)
with self.assertRaises(ValueError):
dc_function.collect(self.model)


if __name__ == "__main__":
unittest.main()

0 comments on commit b5c9918

Please sign in to comment.