Skip to content

Commit

Permalink
Add support for eager tasks (#111425)
Browse files Browse the repository at this point in the history
* Add support for eager tasks

python 3.12 supports eager tasks

reading:
https://docs.python.org/3/library/asyncio-task.html#eager-task-factory
python/cpython#97696

There are lots of places were we are unlikely to suspend, but we might
suspend so creating a task makes sense

* reduce

* revert entity

* revert

* coverage

* coverage

* coverage

* coverage

* fix test
  • Loading branch information
bdraco authored Feb 26, 2024
1 parent 93cc6e0 commit 67e3569
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 18 deletions.
1 change: 1 addition & 0 deletions homeassistant/components/websocket_api/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def schedule_handler(
hass.async_create_background_task(
_handle_async_response(func, hass, connection, msg),
task_name,
eager_start=True,
)

return schedule_handler
Expand Down
11 changes: 8 additions & 3 deletions homeassistant/config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,7 @@ def async_create_task(
hass: HomeAssistant,
target: Coroutine[Any, Any, _R],
name: str | None = None,
eager_start: bool = False,
) -> asyncio.Task[_R]:
"""Create a task from within the event loop.
Expand All @@ -923,7 +924,7 @@ def async_create_task(
target: target to call.
"""
task = hass.async_create_task(
target, f"{name} {self.title} {self.domain} {self.entry_id}"
target, f"{name} {self.title} {self.domain} {self.entry_id}", eager_start
)
self._tasks.add(task)
task.add_done_callback(self._tasks.remove)
Expand All @@ -932,15 +933,19 @@ def async_create_task(

@callback
def async_create_background_task(
self, hass: HomeAssistant, target: Coroutine[Any, Any, _R], name: str
self,
hass: HomeAssistant,
target: Coroutine[Any, Any, _R],
name: str,
eager_start: bool = False,
) -> asyncio.Task[_R]:
"""Create a background task tied to the config entry lifecycle.
Background tasks are automatically canceled when config entry is unloaded.
target: target to call.
"""
task = hass.async_create_background_task(target, name)
task = hass.async_create_background_task(target, name, eager_start)
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.remove)
return task
Expand Down
20 changes: 14 additions & 6 deletions homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
from .util import dt as dt_util, location
from .util.async_ import (
cancelling,
create_eager_task,
run_callback_threadsafe,
shutdown_run_callback_threadsafe,
)
Expand Down Expand Up @@ -622,7 +623,10 @@ def create_task(

@callback
def async_create_task(
self, target: Coroutine[Any, Any, _R], name: str | None = None
self,
target: Coroutine[Any, Any, _R],
name: str | None = None,
eager_start: bool = False,
) -> asyncio.Task[_R]:
"""Create a task from within the event loop.
Expand All @@ -631,16 +635,17 @@ def async_create_task(
target: target to call.
"""
task = self.loop.create_task(target, name=name)
if eager_start:
task = create_eager_task(target, name=name, loop=self.loop)
else:
task = self.loop.create_task(target, name=name)
self._tasks.add(task)
task.add_done_callback(self._tasks.remove)
return task

@callback
def async_create_background_task(
self,
target: Coroutine[Any, Any, _R],
name: str,
self, target: Coroutine[Any, Any, _R], name: str, eager_start: bool = False
) -> asyncio.Task[_R]:
"""Create a task from within the event loop.
Expand All @@ -650,7 +655,10 @@ def async_create_background_task(
This method must be run in the event loop.
"""
task = self.loop.create_task(target, name=name)
if eager_start:
task = create_eager_task(target, name=name, loop=self.loop)
else:
task = self.loop.create_task(target, name=name)
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.remove)
return task
Expand Down
34 changes: 32 additions & 2 deletions homeassistant/util/async_.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Asyncio utilities."""
from __future__ import annotations

from asyncio import Future, Semaphore, gather, get_running_loop
from asyncio.events import AbstractEventLoop
from asyncio import AbstractEventLoop, Future, Semaphore, Task, gather, get_running_loop
from collections.abc import Awaitable, Callable
import concurrent.futures
from contextlib import suppress
import functools
import logging
import sys
import threading
from traceback import extract_stack
from typing import Any, ParamSpec, TypeVar, TypeVarTuple
Expand All @@ -23,6 +23,36 @@
_P = ParamSpec("_P")
_Ts = TypeVarTuple("_Ts")

if sys.version_info >= (3, 12, 0):

def create_eager_task(
coro: Awaitable[_T],
*,
name: str | None = None,
loop: AbstractEventLoop | None = None,
) -> Task[_T]:
"""Create a task from a coroutine and schedule it to run immediately."""
return Task(
coro,
loop=loop or get_running_loop(),
name=name,
eager_start=True, # type: ignore[call-arg]
)
else:

def create_eager_task(
coro: Awaitable[_T],
*,
name: str | None = None,
loop: AbstractEventLoop | None = None,
) -> Task[_T]:
"""Create a task from a coroutine and schedule it to run immediately."""
return Task(
coro,
loop=loop or get_running_loop(),
name=name,
)


def cancelling(task: Future[Any]) -> bool:
"""Return True if task is cancelling."""
Expand Down
4 changes: 2 additions & 2 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,14 +260,14 @@ def async_add_executor_job(target, *args):

return orig_async_add_executor_job(target, *args)

def async_create_task(coroutine, name=None):
def async_create_task(coroutine, name=None, eager_start=False):
"""Create task."""
if isinstance(coroutine, Mock) and not isinstance(coroutine, AsyncMock):
fut = asyncio.Future()
fut.set_result(None)
return fut

return orig_async_create_task(coroutine, name)
return orig_async_create_task(coroutine, name, eager_start)

hass.async_add_job = async_add_job
hass.async_add_executor_job = async_add_executor_job
Expand Down
9 changes: 7 additions & 2 deletions tests/test_config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4228,11 +4228,16 @@ async def test_unload() -> None:

entry.async_on_unload(test_unload)
entry.async_create_task(hass, test_task())
entry.async_create_background_task(hass, test_task(), "background-task-name")
entry.async_create_background_task(
hass, test_task(), "background-task-name", eager_start=True
)
entry.async_create_background_task(
hass, test_task(), "background-task-name", eager_start=False
)
await asyncio.sleep(0)
hass.loop.call_soon(event.set)
await entry._async_process_on_unload(hass)
assert results == ["on_unload", "background", "normal"]
assert results == ["on_unload", "background", "background", "normal"]


async def test_preview_supported(
Expand Down
50 changes: 47 additions & 3 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import gc
import logging
import os
import sys
from tempfile import TemporaryDirectory
import threading
import time
Expand Down Expand Up @@ -161,7 +162,9 @@ def job():
assert len(hass.loop.run_in_executor.mock_calls) == 2


def test_async_create_task_schedule_coroutine(event_loop) -> None:
def test_async_create_task_schedule_coroutine(
event_loop: asyncio.AbstractEventLoop,
) -> None:
"""Test that we schedule coroutines and add jobs to the job pool."""
hass = MagicMock(loop=MagicMock(wraps=event_loop))

Expand All @@ -174,6 +177,44 @@ async def job():
assert len(hass.add_job.mock_calls) == 0


@pytest.mark.skipif(
sys.version_info < (3, 12), reason="eager_start is only supported for Python 3.12"
)
def test_async_create_task_eager_start_schedule_coroutine(
event_loop: asyncio.AbstractEventLoop,
) -> None:
"""Test that we schedule coroutines and add jobs to the job pool."""
hass = MagicMock(loop=MagicMock(wraps=event_loop))

async def job():
pass

ha.HomeAssistant.async_create_task(hass, job(), eager_start=True)
# Should create the task directly since 3.12 supports eager_start
assert len(hass.loop.create_task.mock_calls) == 0
assert len(hass.add_job.mock_calls) == 0


@pytest.mark.skipif(
sys.version_info >= (3, 12), reason="eager_start is not supported on < 3.12"
)
def test_async_create_task_eager_start_fallback_schedule_coroutine(
event_loop: asyncio.AbstractEventLoop,
) -> None:
"""Test that we schedule coroutines and add jobs to the job pool."""
hass = MagicMock(loop=MagicMock(wraps=event_loop))

async def job():
pass

ha.HomeAssistant.async_create_task(hass, job(), eager_start=True)
assert len(hass.loop.call_soon.mock_calls) == 1
# Should fallback to loop.create_task since 3.11 does
# not support eager_start
assert len(hass.loop.create_task.mock_calls) == 0
assert len(hass.add_job.mock_calls) == 0


def test_async_create_task_schedule_coroutine_with_name(event_loop) -> None:
"""Test that we schedule coroutines and add jobs to the job pool with a name."""
hass = MagicMock(loop=MagicMock(wraps=event_loop))
Expand Down Expand Up @@ -2598,7 +2639,8 @@ async def test_state_changed_events_to_not_leak_contexts(hass: HomeAssistant) ->
assert len(_get_by_type("homeassistant.core.Context")) == init_count


async def test_background_task(hass: HomeAssistant) -> None:
@pytest.mark.parametrize("eager_start", (True, False))
async def test_background_task(hass: HomeAssistant, eager_start: bool) -> None:
"""Test background tasks being quit."""
result = asyncio.Future()

Expand All @@ -2609,7 +2651,9 @@ async def test_task():
result.set_result(hass.state)
raise

task = hass.async_create_background_task(test_task(), "happy task")
task = hass.async_create_background_task(
test_task(), "happy task", eager_start=eager_start
)
assert "happy task" in str(task)
await asyncio.sleep(0)
await hass.async_stop()
Expand Down
51 changes: 51 additions & 0 deletions tests/util/test_async.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for async util methods from Python source."""
import asyncio
import sys
import time
from unittest.mock import MagicMock, Mock, patch

Expand Down Expand Up @@ -246,3 +247,53 @@ async def test_callback_is_always_scheduled(hass: HomeAssistant) -> None:
hasync.run_callback_threadsafe(hass.loop, callback)

mock_call_soon_threadsafe.assert_called_once()


@pytest.mark.skipif(sys.version_info < (3, 12), reason="Test requires Python 3.12+")
async def test_create_eager_task_312(hass: HomeAssistant) -> None:
"""Test create_eager_task schedules a task eagerly in the event loop.
For Python 3.12+, the task is scheduled eagerly in the event loop.
"""
events = []

async def _normal_task():
events.append("normal")

async def _eager_task():
events.append("eager")

task1 = hasync.create_eager_task(_eager_task())
task2 = asyncio.create_task(_normal_task())

assert events == ["eager"]

await asyncio.sleep(0)
assert events == ["eager", "normal"]
await task1
await task2


@pytest.mark.skipif(sys.version_info >= (3, 12), reason="Test requires < Python 3.12")
async def test_create_eager_task_pre_312(hass: HomeAssistant) -> None:
"""Test create_eager_task schedules a task in the event loop.
For older python versions, the task is scheduled normally.
"""
events = []

async def _normal_task():
events.append("normal")

async def _eager_task():
events.append("eager")

task1 = hasync.create_eager_task(_eager_task())
task2 = asyncio.create_task(_normal_task())

assert events == []

await asyncio.sleep(0)
assert events == ["eager", "normal"]
await task1
await task2

0 comments on commit 67e3569

Please sign in to comment.