79 lines
2.7 KiB
Python
79 lines
2.7 KiB
Python
"""Utility functions for HA core."""
|
|
import inspect
import logging
from collections.abc import Awaitable, Callable

from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.util.read_only_dict import ReadOnlyDict

from .adaptation_utils import ServiceData
|
|
|
|
# Module-level logger, named after this module (``__name__``).
_LOGGER: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def setup_service_call_interceptor(
    hass: HomeAssistant,
    domain: str,
    service: str,
    intercept_func: Callable[[ServiceCall, ServiceData], Awaitable[None] | None],
) -> Callable[[], None]:
    """Inject a function into a registered service call to preprocess service data.

    The injected interceptor function receives the service call and a writeable
    data dictionary (the data of the service call is read-only) before the
    service call is executed. If the interceptor completes without raising, the
    (possibly modified) dictionary replaces the call's data before the original
    handler runs.

    Both the interceptor and the original service handler may be coroutine
    functions or plain callables; their results are awaited only when they are
    awaitable.

    Args:
        hass: Home Assistant instance whose service registry is patched.
        domain: Domain of the service to intercept.
        service: Name of the service to intercept.
        intercept_func: Called as ``intercept_func(call, data)`` before the
            original handler. May return an awaitable or ``None``.

    Returns:
        A callable that removes the interceptor by re-registering the original
        service handler with its original schema.

    Raises:
        RuntimeError: If the service registry's internal mapping is not
            accessible, or if ``domain.service`` is not registered.
    """
    try:
        # HACK: Access protected attribute of HA service registry.
        # This is necessary to replace a registered service handler with our
        # proxy handler to intercept calls.
        registered_services = (
            hass.services._services  # pylint: disable=protected-access
        )
    except AttributeError as error:
        msg = (
            "Intercept failed because registered services are no longer"
            " accessible (internal API may have changed)"
        )
        raise RuntimeError(msg) from error

    if domain not in registered_services or service not in registered_services[domain]:
        msg = f"Intercept failed because service {domain}.{service} is not registered"
        raise RuntimeError(msg)

    existing_service = registered_services[domain][service]

    async def service_func_proxy(call: ServiceCall) -> None:
        try:
            # Convert read-only data to writeable dictionary for modification by interceptor
            data = dict(call.data)

            # Call interceptor. Its declared return type allows a synchronous
            # function returning None, so only await an awaitable result;
            # awaiting None would raise TypeError and drop the changes to `data`.
            intercept_result = intercept_func(call, data)
            if inspect.isawaitable(intercept_result):
                await intercept_result

            # Convert data back to read-only
            call.data = ReadOnlyDict(data)
        except Exception:
            # Blindly catch all exceptions to avoid breaking the intercepted
            # service (e.g. light.turn_on): the modified copy is discarded and
            # the original handler below still runs.
            _LOGGER.exception(
                "Error for call '%s' in service_func_proxy",
                call.data,
            )

        # Call original service handler with processed data. The handler may be
        # a plain (non-coroutine) callable, so only await an awaitable result.
        handler_result = existing_service.job.target(call)
        if inspect.isawaitable(handler_result):
            await handler_result

    # NOTE(review): only the handler and schema are passed when re-registering;
    # any other registration options of the original service (e.g. response
    # support on newer HA versions) are not carried over - confirm this is fine
    # for every service that gets intercepted.
    hass.services.async_register(
        domain,
        service,
        service_func_proxy,
        existing_service.schema,
    )

    def remove() -> None:
        # Remove the interceptor by reinstalling the original service handler
        hass.services.async_register(
            domain,
            service,
            existing_service.job.target,
            existing_service.schema,
        )

    return remove
|