diff --git a/mesa/datacollection.py b/mesa/datacollection.py index 309d56d8e27..ba1d3d50eee 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -134,13 +134,88 @@ def __init__( for name, columns in tables.items(): self._new_table(name, columns) + def _validate_model_reporter(self, name, reporter, model): + """Validate model reporter and issue warnings if necessary. + + Args: + name: Name of the reporter + reporter: Reporter definition + model: Model instance + """ + # Type 1: Lambda function + if isinstance(reporter, types.LambdaType): + try: + # Try to call the lambda with a model instance + reporter(model) + except Exception as e: + warnings.warn( + f"Warning: Lambda reporter '{name}' failed: {e!s}\n" + f"Example of valid lambda: lambda m: len(m.agents)", + UserWarning, + stacklevel=2, + ) + return + + # Type 2: Method of class/instance + if callable(reporter) and not isinstance(reporter, types.LambdaType): + try: + # Try to call the method + reporter(model) + except Exception as e: + warnings.warn( + f"Warning: Method reporter '{name}' failed: {e!s}\n" + f"Example of valid method: self.get_agent_count or Model.get_agent_count", + UserWarning, + stacklevel=2, + ) + return + + # Type 3: Class attributes (string) + if isinstance(reporter, str): + if not hasattr(model, reporter): + warnings.warn( + f"Warning: Model reporter '{name}' references attribute '{reporter}' " + f"which is not defined in the model.\n" + f"Example of valid attribute: 'model_attribute'", + UserWarning, + stacklevel=2, + ) + return + + # Type 4: Function with parameters in list + if isinstance(reporter, list): + if not reporter or not callable(reporter[0]): + warnings.warn( + f"Warning: Invalid function list format for reporter '{name}'.\n" + f"First element must be a callable function.\n" + f"Example: [function, [param1, param2]]", + UserWarning, + stacklevel=2, + ) + return + + # If none of the above types match + warnings.warn( + f"Warning: Model reporter '{name}' has invalid type: {type(reporter)}.\n" + f"Must be one of:\n" + f"1. Lambda function: lambda m: len(m.agents)\n" + f"2. Method: self.get_count or Model.get_count\n" + f"3. Attribute name (str): 'model_attribute'\n" + f"4. Function list: [function, [param1, param2]]", + UserWarning, + stacklevel=2, + ) + def _new_model_reporter(self, name, reporter): """Add a new model-level reporter to collect. Args: name: Name of the model-level variable to collect. - reporter: Attribute string, or function object that returns the - variable when given a model instance. + reporter: Can be one of four types: + 1. Attribute name (str): "attribute_name" + 2. Lambda function: lambda m: len(m.agents) + 3. Method: model.get_count or Model.get_count + 4. List of [function, [parameters]] """ self.model_reporters[name] = reporter self.model_vars[name] = [] @@ -263,6 +338,9 @@ def collect(self, model): """Collect all the data for the given model object.""" if self.model_reporters: for var, reporter in self.model_reporters.items(): + # Add validation + self._validate_model_reporter(var, reporter, model) + # Check if lambda or partial function if isinstance(reporter, types.LambdaType | partial): # Use deepcopy to store a copy of the data,