Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move ExceptionGroup to types; add asyncio.TaskGroup #1

Merged
merged 1 commit into from
Oct 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Lib/asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .tasks import *
from .threads import *
from .transports import *
from .taskgroup import *

# Exposed for _asynciomodule.c to implement now deprecated
# Task.all_tasks() method. This function will be removed in 3.9.
Expand All @@ -37,7 +38,8 @@
subprocess.__all__ +
tasks.__all__ +
threads.__all__ +
transports.__all__)
transports.__all__ +
taskgroup.__all__)

if sys.platform == 'win32': # pragma: no cover
from .windows_events import *
Expand Down
282 changes: 282 additions & 0 deletions Lib/asyncio/taskgroup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

import asyncio
import functools
import itertools
import sys
import types

__all__ = ('TaskGroup',)


class TaskGroup:

def __init__(self, *, name=None):
if name is None:
self._name = f'tg-{_name_counter()}'
else:
self._name = str(name)

self._entered = False
self._exiting = False
self._aborting = False
self._loop = None
self._parent_task = None
self._parent_cancel_requested = False
self._tasks = set()
self._unfinished_tasks = 0
self._errors = []
self._base_error = None
self._on_completed_fut = None

def get_name(self):
return self._name

def __repr__(self):
msg = f'<TaskGroup {self._name!r}'
if self._tasks:
msg += f' tasks:{len(self._tasks)}'
if self._unfinished_tasks:
msg += f' unfinished:{self._unfinished_tasks}'
if self._errors:
msg += f' errors:{len(self._errors)}'
if self._aborting:
msg += ' cancelling'
elif self._entered:
msg += ' entered'
msg += '>'
return msg

async def __aenter__(self):
if self._entered:
raise RuntimeError(
f"TaskGroup {self!r} has been already entered")
self._entered = True

if self._loop is None:
self._loop = asyncio.get_running_loop()

self._parent_task = asyncio.current_task(self._loop)
if self._parent_task is None:
raise RuntimeError(
f'TaskGroup {self!r} cannot determine the parent task')
self._patch_task(self._parent_task)

return self

async def __aexit__(self, et, exc, tb):
self._exiting = True
propagate_cancelation = False

if (exc is not None and
self._is_base_error(exc) and
self._base_error is None):
self._base_error = exc

if et is asyncio.CancelledError:
if self._parent_cancel_requested:
# Only if we did request task to cancel ourselves
# we mark it as no longer cancelled.
self._parent_task.__cancel_requested__ = False
else:
propagate_cancelation = True

if et is not None and not self._aborting:
# Our parent task is being cancelled:
#
# async with TaskGroup() as g:
# g.create_task(...)
# await ... # <- CancelledError
#
if et is asyncio.CancelledError:
propagate_cancelation = True

# or there's an exception in "async with":
#
# async with TaskGroup() as g:
# g.create_task(...)
# 1 / 0
#
self._abort()

# We use while-loop here because "self._on_completed_fut"
# can be cancelled multiple times if our parent task
# is being cancelled repeatedly (or even once, when
# our own cancellation is already in progress)
while self._unfinished_tasks:
if self._on_completed_fut is None:
self._on_completed_fut = self._loop.create_future()

try:
await self._on_completed_fut
except asyncio.CancelledError:
if not self._aborting:
# Our parent task is being cancelled:
#
# async def wrapper():
# async with TaskGroup() as g:
# g.create_task(foo)
#
# "wrapper" is being cancelled while "foo" is
# still running.
propagate_cancelation = True
self._abort()

self._on_completed_fut = None

assert self._unfinished_tasks == 0
self._on_completed_fut = None # no longer needed

if self._base_error is not None:
raise self._base_error

if propagate_cancelation:
# The wrapping task was cancelled; since we're done with
# closing all child tasks, just propagate the cancellation
# request now.
raise asyncio.CancelledError()

if et is not None and et is not asyncio.CancelledError:
self._errors.append(exc)

if self._errors:
# Exceptions are heavy objects that can have object
# cycles (bad for GC); let's not keep a reference to
# a bunch of them.
errors = self._errors
self._errors = None

raise types.ExceptionGroup(errors)

def create_task(self, coro):
if not self._entered:
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
if self._exiting:
raise RuntimeError(f"TaskGroup {self!r} is awaiting in exit")
task = self._loop.create_task(coro)
task.add_done_callback(self._on_task_done)
self._unfinished_tasks += 1
self._tasks.add(task)
return task

if sys.version_info >= (3, 8):

# In Python 3.8 Tasks propagate all exceptions correctly,
# except for KeyboardInterrupt and SystemExit which are
# still considered special.

def _is_base_error(self, exc: BaseException) -> bool:
assert isinstance(exc, BaseException)
return isinstance(exc, (SystemExit, KeyboardInterrupt))

else:

# In Python prior to 3.8 all BaseExceptions are special and
# are bypassing the proper propagation through async/await
# code, essentially aborting the execution.

def _is_base_error(self, exc: BaseException) -> bool:
assert isinstance(exc, BaseException)
return not isinstance(exc, Exception)

def _patch_task(self, task):
# In Python 3.8 we'll need proper API on asyncio.Task to
# make TaskGroups possible. We need to be able to access
# information about task cancellation, more specifically,
# we need a flag to say if a task was cancelled or not.
# We also need to be able to flip that flag.

def _task_cancel(task, orig_cancel):
task.__cancel_requested__ = True
return orig_cancel()

if hasattr(task, '__cancel_requested__'):
return

task.__cancel_requested__ = False
# confirm that we were successful at adding the new attribute:
assert not task.__cancel_requested__

orig_cancel = task.cancel
task.cancel = functools.partial(_task_cancel, task, orig_cancel)

def _abort(self):
self._aborting = True

for t in self._tasks:
if not t.done():
t.cancel()

def _on_task_done(self, task):
self._unfinished_tasks -= 1
assert self._unfinished_tasks >= 0

if self._exiting and not self._unfinished_tasks:
if not self._on_completed_fut.done():
self._on_completed_fut.set_result(True)

if task.cancelled():
return

exc = task.exception()
if exc is None:
return

self._errors.append(exc)
if self._is_base_error(exc) and self._base_error is None:
self._base_error = exc

if self._parent_task.done():
# Not sure if this case is possible, but we want to handle
# it anyways.
self._loop.call_exception_handler({
'message': f'Task {task!r} has errored out but its parent '
f'task {self._parent_task} is already completed',
'exception': exc,
'task': task,
})
return

self._abort()
if not self._parent_task.__cancel_requested__:
# If parent task *is not* being cancelled, it means that we want
# to manually cancel it to abort whatever is being run right now
# in the TaskGroup. But we want to mark parent task as
# "not cancelled" later in __aexit__. Example situation that
# we need to handle:
#
# async def foo():
# try:
# async with TaskGroup() as g:
# g.create_task(crash_soon())
# await something # <- this needs to be canceled
# # by the TaskGroup, e.g.
# # foo() needs to be cancelled
# except Exception:
# # Ignore any exceptions raised in the TaskGroup
# pass
# await something_else # this line has to be called
# # after TaskGroup is finished.
self._parent_cancel_requested = True
self._parent_task.cancel()

_name_counter = itertools.count(1).__next__
64 changes: 63 additions & 1 deletion Lib/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _m(self): pass
GetSetDescriptorType = type(FunctionType.__code__)
MemberDescriptorType = type(FunctionType.__globals__)

del sys, _f, _g, _C, _c, _ag # Not for export
del _f, _g, _C, _c, _ag # Not for export


# Provide a PEP 3115 compliant mechanism for class creation
Expand Down Expand Up @@ -300,4 +300,66 @@ def wrapped(*args, **kwargs):
NoneType = type(None)
NotImplementedType = type(NotImplemented)

class ExceptionGroup(BaseException):

def __init__(self, excs, tb=None):
self.excs = set(excs)
if tb:
self.__traceback__ = tb
else:
import types
self.__traceback__ = types.TracebackType(
None, sys._getframe(), 0, 0)
for e in excs:
self.add_exc(e)

def add_exc(self, e):
self.excs.add(e)
self.__traceback__.next_map_add(e, e.__traceback__)

def split(self, E):
''' remove the exceptions that match E
and return them in a new ExceptionGroup
'''
matches = []
for e in self.excs:
if isinstance(e, E):
matches.append(e)
[self.excs.remove(m) for m in matches]
gtb = self.__traceback__
while gtb.tb_next:
# there could be normal tbs is the ExceptionGroup propagated
gtb = gtb.tb_next
tb = gtb.group_split(matches)

return ExceptionGroup(matches, tb)

def push_frame(self, frame):
import types
self.__traceback__ = types.TracebackType(
self.__traceback__, frame, 0, 0)

@staticmethod
def render(exc, tb=None, indent=0):
print(exc)
tb = tb or exc.__traceback__
while tb:
print(' '*indent, tb.tb_frame)
if tb.tb_next: # single traceback
tb = tb.tb_next
elif tb.tb_next_map:
indent += 4
for e, t in tb.tb_next_map.items():
print('---------------------------------------')
ExceptionGroup.render(e, t, indent)
tb = None
else:
tb = None

def __str__(self):
return f"ExceptionGroup({self.excs})"

def __repr__(self):
return str(self)

__all__ = [n for n in globals() if n[:1] != '_']
Loading