"""Helpers for components that manage entities.""" from __future__ import annotations import asyncio from collections.abc import Callable, Iterable from datetime import timedelta import logging from types import ModuleType from typing import Any from homeassistant import config as conf_util from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( CONF_ENTITY_NAMESPACE, CONF_SCAN_INTERVAL, EVENT_HOMEASSISTANT_STOP, ) from homeassistant.core import ( Event, HassJob, HassJobType, HomeAssistant, ServiceCall, ServiceResponse, SupportsResponse, callback, ) from homeassistant.exceptions import HomeAssistantError from homeassistant.loader import async_get_integration, bind_hass from homeassistant.setup import async_prepare_setup_platform from . import config_validation as cv, discovery, entity, service from .entity_platform import EntityPlatform from .typing import ConfigType, DiscoveryInfoType, VolDictType, VolSchemaType DEFAULT_SCAN_INTERVAL = timedelta(seconds=15) DATA_INSTANCES = "entity_components" @bind_hass async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None: """Trigger an update for an entity.""" domain = entity_id.partition(".")[0] entity_comp: EntityComponent[entity.Entity] | None entity_comp = hass.data.get(DATA_INSTANCES, {}).get(domain) if entity_comp is None: logging.getLogger(__name__).warning( "Forced update failed. Component for %s not loaded.", entity_id ) return if (entity_obj := entity_comp.get_entity(entity_id)) is None: logging.getLogger(__name__).warning( "Forced update failed. Entity %s not found.", entity_id ) return await entity_obj.async_update_ha_state(True) class EntityComponent[_EntityT: entity.Entity = entity.Entity]: """The EntityComponent manages platforms that manage entities. An example of an entity component is 'light', which manages platforms such as 'hue.light'. This class has the following responsibilities: - Process the configuration and set up a platform based component, for example light. - Manage the platforms and their entities. - Help extract the entities from a service call. - Listen for discovery events for platforms related to the domain. """ def __init__( self, logger: logging.Logger, domain: str, hass: HomeAssistant, scan_interval: timedelta = DEFAULT_SCAN_INTERVAL, ) -> None: """Initialize an entity component.""" self.logger = logger self.hass = hass self.domain = domain self.scan_interval = scan_interval self.config: ConfigType | None = None domain_platform = self._async_init_entity_platform(domain, None) self._platforms: dict[ str | tuple[str, timedelta | None, str | None], EntityPlatform ] = {domain: domain_platform} self.async_add_entities = domain_platform.async_add_entities self.add_entities = domain_platform.add_entities self._entities: dict[str, entity.Entity] = domain_platform.domain_entities hass.data.setdefault(DATA_INSTANCES, {})[domain] = self @property def entities(self) -> Iterable[_EntityT]: """Return an iterable that returns all entities. As the underlying dicts may change when async context is lost, callers that iterate over this asynchronously should make a copy using list() before iterating. """ return self._entities.values() # type: ignore[return-value] def get_entity(self, entity_id: str) -> _EntityT | None: """Get an entity.""" return self._entities.get(entity_id) # type: ignore[return-value] def register_shutdown(self) -> None: """Register shutdown on Home Assistant STOP event. Note: this is only required if the integration never calls `setup` or `async_setup`. """ self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._async_shutdown) def setup(self, config: ConfigType) -> None: """Set up a full entity component. This doesn't block the executor to protect from deadlocks. """ self.hass.create_task( self.async_setup(config), f"EntityComponent setup {self.domain}" ) async def async_setup(self, config: ConfigType) -> None: """Set up a full entity component. Loads the platforms from the config and will listen for supported discovered platforms. This method must be run in the event loop. """ self.register_shutdown() self.config = config # Look in config for Domain, Domain 2, Domain 3 etc and load them for p_type, p_config in conf_util.config_per_platform(config, self.domain): if p_type is not None: self.hass.async_create_task_internal( self.async_setup_platform(p_type, p_config), f"EntityComponent setup platform {p_type} {self.domain}", eager_start=True, ) # Generic discovery listener for loading platform dynamically # Refer to: homeassistant.helpers.discovery.async_load_platform() discovery.async_listen_platform( self.hass, self.domain, self._async_component_platform_discovered ) async def _async_component_platform_discovered( self, platform: str, info: dict[str, Any] | None ) -> None: """Handle the loading of a platform.""" await self.async_setup_platform(platform, {}, info) async def async_setup_entry(self, config_entry: ConfigEntry) -> bool: """Set up a config entry.""" platform_type = config_entry.domain platform = await async_prepare_setup_platform( self.hass, # In future PR we should make hass_config part of the constructor # params. self.config or {}, self.domain, platform_type, ) if platform is None: return False key = config_entry.entry_id if key in self._platforms: raise ValueError( f"Config entry {config_entry.title} ({key}) for " f"{platform_type}.{self.domain} has already been setup!" ) self._platforms[key] = self._async_init_entity_platform( platform_type, platform, scan_interval=getattr(platform, "SCAN_INTERVAL", None), ) return await self._platforms[key].async_setup_entry(config_entry) async def async_unload_entry(self, config_entry: ConfigEntry) -> bool: """Unload a config entry.""" key = config_entry.entry_id if (platform := self._platforms.pop(key, None)) is None: raise ValueError("Config entry was never loaded!") await platform.async_reset() return True async def async_extract_from_service( self, service_call: ServiceCall, expand_group: bool = True ) -> list[_EntityT]: """Extract all known and available entities from a service call. Will return an empty list if entities specified but unknown. This method must be run in the event loop. """ return await service.async_extract_entities( self.hass, self.entities, service_call, expand_group ) @callback def async_register_legacy_entity_service( self, name: str, schema: VolDictType | VolSchemaType, func: str | Callable[..., Any], required_features: list[int] | None = None, supports_response: SupportsResponse = SupportsResponse.NONE, ) -> None: """Register an entity service with a legacy response format.""" if isinstance(schema, dict): schema = cv.make_entity_service_schema(schema) service_func: str | HassJob[..., Any] service_func = func if isinstance(func, str) else HassJob(func) async def handle_service( call: ServiceCall, ) -> ServiceResponse: """Handle the service.""" result = await service.entity_service_call( self.hass, self._entities, service_func, call, required_features ) if result: if len(result) > 1: raise HomeAssistantError( "Deprecated service call matched more than one entity" ) return result.popitem()[1] return None self.hass.services.async_register( self.domain, name, handle_service, schema, supports_response ) @callback def async_register_entity_service( self, name: str, schema: VolDictType | VolSchemaType | None, func: str | Callable[..., Any], required_features: list[int] | None = None, supports_response: SupportsResponse = SupportsResponse.NONE, ) -> None: """Register an entity service.""" service.async_register_entity_service( self.hass, self.domain, name, entities=self._entities, func=func, job_type=HassJobType.Coroutinefunction, required_features=required_features, schema=schema, supports_response=supports_response, ) async def async_setup_platform( self, platform_type: str, platform_config: ConfigType, discovery_info: DiscoveryInfoType | None = None, ) -> None: """Set up a platform for this component.""" if self.config is None: raise RuntimeError("async_setup needs to be called first") platform = await async_prepare_setup_platform( self.hass, self.config, self.domain, platform_type ) if platform is None: return # Use config scan interval, fallback to platform if none set scan_interval = platform_config.get( CONF_SCAN_INTERVAL, getattr(platform, "SCAN_INTERVAL", None) ) entity_namespace = platform_config.get(CONF_ENTITY_NAMESPACE) key = (platform_type, scan_interval, entity_namespace) if key not in self._platforms: self._platforms[key] = self._async_init_entity_platform( platform_type, platform, scan_interval, entity_namespace ) await self._platforms[key].async_setup(platform_config, discovery_info) async def _async_reset(self) -> None: """Remove entities and reset the entity component to initial values. This method must be run in the event loop. """ tasks = [] for key, platform in self._platforms.items(): if key == self.domain: tasks.append(platform.async_reset()) else: tasks.append(platform.async_destroy()) if tasks: await asyncio.gather(*tasks) self._platforms = {self.domain: self._platforms[self.domain]} self.config = None async def async_remove_entity(self, entity_id: str) -> None: """Remove an entity managed by one of the platforms.""" found = None for platform in self._platforms.values(): if entity_id in platform.entities: found = platform break if found: await found.async_remove_entity(entity_id) async def async_prepare_reload( self, *, skip_reset: bool = False ) -> ConfigType | None: """Prepare reloading this entity component. This method must be run in the event loop. """ try: conf = await conf_util.async_hass_config_yaml(self.hass) except HomeAssistantError as err: self.logger.error(err) return None integration = await async_get_integration(self.hass, self.domain) processed_conf = await conf_util.async_process_component_and_handle_errors( self.hass, conf, integration ) if processed_conf is None: return None if not skip_reset: await self._async_reset() return processed_conf @callback def _async_init_entity_platform( self, platform_type: str, platform: ModuleType | None, scan_interval: timedelta | None = None, entity_namespace: str | None = None, ) -> EntityPlatform: """Initialize an entity platform.""" if scan_interval is None: scan_interval = self.scan_interval entity_platform = EntityPlatform( hass=self.hass, logger=self.logger, domain=self.domain, platform_name=platform_type, platform=platform, scan_interval=scan_interval, entity_namespace=entity_namespace, ) entity_platform.async_prepare() return entity_platform @callback def _async_shutdown(self, event: Event) -> None: """Call when Home Assistant is stopping.""" for platform in self._platforms.values(): platform.async_shutdown()