diff --git a/mlflow/models/display_utils.py b/mlflow/models/display_utils.py index 37726fecf22fb..264be17195707 100644 --- a/mlflow/models/display_utils.py +++ b/mlflow/models/display_utils.py @@ -6,7 +6,17 @@ from mlflow.utils import databricks_utils +def _is_input_string(inputs: schema.Schema) -> bool: + return ( + not inputs.has_input_names() + and len(inputs.input_types()) == 1 + and inputs.input_types()[0] == schema.DataType.string + ) + + def _is_input_agent_compatible(inputs: schema.Schema) -> bool: + if _is_input_string(inputs): + return True if not inputs.has_input_names(): return False messages = inputs.input_dict().get("messages") diff --git a/tests/models/test_display_utils.py b/tests/models/test_display_utils.py index dbea5b49913c9..c77ff80c03f93 100644 --- a/tests/models/test_display_utils.py +++ b/tests/models/test_display_utils.py @@ -62,6 +62,11 @@ def test_should_render_eval_template_with_vanilla_string(enable_databricks_env): assert should_render_agent_eval_template(signature) +def test_should_render_eval_template_with_string_input(enable_databricks_env): + signature = infer_signature("A vanilla string input", _STRING_RESPONSE) + assert should_render_agent_eval_template(signature) + + def test_should_not_render_eval_template_generic_signature(enable_databricks_env): signature = infer_signature({"input": "string"}, {"output": "string"}) assert not should_render_agent_eval_template(signature)