Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose scripts with no fields as entities #123061

Merged
merged 1 commit into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 95 additions & 80 deletions homeassistant/helpers/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,9 @@
):
continue

tools.append(ScriptTool(self.hass, state.entity_id))
script_tool = ScriptTool(self.hass, state.entity_id)
if script_tool.parameters.schema:
tools.append(script_tool)

return tools

Expand Down Expand Up @@ -446,12 +448,17 @@
entities = {}

for state in hass.states.async_all():
if state.domain == SCRIPT_DOMAIN:
continue

if not async_should_expose(hass, assistant, state.entity_id):
continue

description: str | None = None
if state.domain == SCRIPT_DOMAIN:
description, parameters = _get_cached_script_parameters(
hass, state.entity_id
)
if parameters.schema: # Only list scripts without input fields here
continue

entity_entry = entity_registry.async_get(state.entity_id)
names = [state.name]
area_names = []
Expand Down Expand Up @@ -480,6 +487,9 @@
"state": state.state,
}

if description:
info["description"] = description

if area_names:
info["areas"] = ", ".join(area_names)

Expand Down Expand Up @@ -605,6 +615,83 @@
return {"type": "string"}


def _get_cached_script_parameters(
hass: HomeAssistant, entity_id: str
) -> tuple[str | None, vol.Schema]:
"""Get script description and schema."""
entity_registry = er.async_get(hass)

description = None
parameters = vol.Schema({})
entity_entry = entity_registry.async_get(entity_id)
if entity_entry and entity_entry.unique_id:
parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE)

if parameters_cache is None:
parameters_cache = hass.data[SCRIPT_PARAMETERS_CACHE] = {}

@callback
def clear_cache(event: Event) -> None:
"""Clear script parameter cache on script reload or delete."""
if (
event.data[ATTR_DOMAIN] == SCRIPT_DOMAIN
and event.data[ATTR_SERVICE] in parameters_cache
):
parameters_cache.pop(event.data[ATTR_SERVICE])

cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)

@callback
def on_homeassistant_close(event: Event) -> None:
"""Cleanup."""
cancel()

hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close
)

if entity_entry.unique_id in parameters_cache:
return parameters_cache[entity_entry.unique_id]

if service_desc := service.async_get_cached_service_description(
hass, SCRIPT_DOMAIN, entity_entry.unique_id
):
description = service_desc.get("description")
schema: dict[vol.Marker, Any] = {}
fields = service_desc.get("fields", {})

for field, config in fields.items():
field_description = config.get("description")
if not field_description:
field_description = config.get("name")
key: vol.Marker
if config.get("required"):
key = vol.Required(field, description=field_description)
else:
key = vol.Optional(field, description=field_description)
if "selector" in config:
schema[key] = selector.selector(config["selector"])
else:
schema[key] = cv.string

parameters = vol.Schema(schema)

aliases: list[str] = []
if entity_entry.name:
aliases.append(entity_entry.name)
if entity_entry.aliases:
aliases.extend(entity_entry.aliases)
if aliases:
if description:
description = description + ". Aliases: " + str(list(aliases))
else:
description = "Aliases: " + str(list(aliases))

Check warning on line 688 in homeassistant/helpers/llm.py

View check run for this annotation

Codecov / codecov/patch

homeassistant/helpers/llm.py#L688

Added line #L688 was not covered by tests

parameters_cache[entity_entry.unique_id] = (description, parameters)

return description, parameters


class ScriptTool(Tool):
"""LLM Tool representing a Script."""

Expand All @@ -614,86 +701,14 @@
script_entity_id: str,
) -> None:
"""Init the class."""
entity_registry = er.async_get(hass)

self.name = split_entity_id(script_entity_id)[1]
if self.name[0].isdigit():
self.name = "_" + self.name
self._entity_id = script_entity_id
self.parameters = vol.Schema({})
entity_entry = entity_registry.async_get(script_entity_id)
if entity_entry and entity_entry.unique_id:
parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE)

if parameters_cache is None:
parameters_cache = hass.data[SCRIPT_PARAMETERS_CACHE] = {}

@callback
def clear_cache(event: Event) -> None:
"""Clear script parameter cache on script reload or delete."""
if (
event.data[ATTR_DOMAIN] == SCRIPT_DOMAIN
and event.data[ATTR_SERVICE] in parameters_cache
):
parameters_cache.pop(event.data[ATTR_SERVICE])

cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)

@callback
def on_homeassistant_close(event: Event) -> None:
"""Cleanup."""
cancel()

hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close
)

if entity_entry.unique_id in parameters_cache:
self.description, self.parameters = parameters_cache[
entity_entry.unique_id
]
return

if service_desc := service.async_get_cached_service_description(
hass, SCRIPT_DOMAIN, entity_entry.unique_id
):
self.description = service_desc.get("description")
schema: dict[vol.Marker, Any] = {}
fields = service_desc.get("fields", {})

for field, config in fields.items():
description = config.get("description")
if not description:
description = config.get("name")
key: vol.Marker
if config.get("required"):
key = vol.Required(field, description=description)
else:
key = vol.Optional(field, description=description)
if "selector" in config:
schema[key] = selector.selector(config["selector"])
else:
schema[key] = cv.string

self.parameters = vol.Schema(schema)

aliases: list[str] = []
if entity_entry.name:
aliases.append(entity_entry.name)
if entity_entry.aliases:
aliases.extend(entity_entry.aliases)
if aliases:
if self.description:
self.description = (
self.description + ". Aliases: " + str(list(aliases))
)
else:
self.description = "Aliases: " + str(list(aliases))

parameters_cache[entity_entry.unique_id] = (
self.description,
self.parameters,
)

self.description, self.parameters = _get_cached_script_parameters(
hass, script_entity_id
)

async def async_call(
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
Expand Down
22 changes: 19 additions & 3 deletions tests/helpers/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,11 +374,16 @@ async def test_assist_api_prompt(
"beer": {"description": "Number of beers"},
"wine": {},
},
}
},
"script_with_no_fields": {
"description": "This is another test script",
"sequence": [],
},
}
},
)
async_expose_entity(hass, "conversation", "script.test_script", True)
async_expose_entity(hass, "conversation", "script.script_with_no_fields", True)

entry = MockConfigEntry(title=None)
entry.add_to_hass(hass)
Expand Down Expand Up @@ -511,6 +516,10 @@ def create_entity(
)
)
exposed_entities_prompt = """An overview of the areas and the devices in this smart home:
- names: script_with_no_fields
domain: script
state: 'off'
description: This is another test script
- names: Kitchen
domain: light
state: 'on'
Expand Down Expand Up @@ -657,13 +666,18 @@ async def test_script_tool(
"extra_field": {"selector": {"area": {}}},
},
},
"script_with_no_fields": {
"description": "This is another test script",
"sequence": [],
},
"unexposed_script": {
"sequence": [],
},
}
},
)
async_expose_entity(hass, "conversation", "script.test_script", True)
async_expose_entity(hass, "conversation", "script.script_with_no_fields", True)

entity_registry.async_update_entity(
"script.test_script", name="script name", aliases={"script alias"}
Expand Down Expand Up @@ -700,7 +714,8 @@ async def test_script_tool(
"test_script": (
"This is a test script. Aliases: ['script name', 'script alias']",
vol.Schema(schema),
)
),
"script_with_no_fields": ("This is another test script", vol.Schema({})),
}

tool_input = llm.ToolInput(
Expand Down Expand Up @@ -781,7 +796,8 @@ async def test_script_tool(
"test_script": (
"This is a new test script. Aliases: ['script name', 'script alias']",
vol.Schema(schema),
)
),
"script_with_no_fields": ("This is another test script", vol.Schema({})),
}


Expand Down