Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-42345: Fix three issues with typing.Literal parameters #23294

Merged
merged 8 commits into from
Nov 17, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions Lib/test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,7 @@ def test_repr(self):
self.assertEqual(repr(Literal[int]), "typing.Literal[int]")
self.assertEqual(repr(Literal), "typing.Literal")
self.assertEqual(repr(Literal[None]), "typing.Literal[None]")
self.assertEqual(repr(Literal[1, 2, 3, 3]), "typing.Literal[1, 2, 3]")
uriyyo marked this conversation as resolved.
Show resolved Hide resolved

def test_cannot_init(self):
with self.assertRaises(TypeError):
Expand Down Expand Up @@ -559,6 +560,20 @@ def test_no_multiple_subscripts(self):
with self.assertRaises(TypeError):
Literal[1][1]

def test_equal(self):
self.assertNotEqual(Literal[0], Literal[False])
self.assertNotEqual(Literal[True], Literal[1])
self.assertNotEqual(Literal[1], Literal[2])
self.assertNotEqual(Literal[1, True], Literal[1])
self.assertEqual(Literal[1], Literal[1])
self.assertEqual(Literal[1, 2], Literal[2, 1])
self.assertEqual(Literal[1, 2, 3], Literal[1, 2, 3, 3])

def test_flatten(self):
self.assertEqual(Literal[Literal[1], Literal[2], Literal[3]], Literal[1, 2, 3])
self.assertEqual(Literal[Literal[1, 2], 3], Literal[1, 2, 3])
self.assertEqual(Literal[Literal[1, 2, 3]], Literal[1, 2, 3])
uriyyo marked this conversation as resolved.
Show resolved Hide resolved


XK = TypeVar('XK', str, bytes)
XV = TypeVar('XV')
Expand Down
99 changes: 76 additions & 23 deletions Lib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,20 @@ def _check_generic(cls, parameters, elen):
f" actual {alen}, expected {elen}")


def _deduplicate(params):
# Weed out strict duplicates, preserving the first of each occurrence.
all_params = set(params)
uriyyo marked this conversation as resolved.
Show resolved Hide resolved
if len(all_params) < len(params):
new_params = []
for t in params:
if t in all_params:
new_params.append(t)
all_params.remove(t)
params = new_params
assert not all_params, all_params
return params


def _remove_dups_flatten(parameters):
"""An internal helper for Union creation and substitution: flatten Unions
among parameters, then remove duplicates.
Expand All @@ -215,38 +229,45 @@ def _remove_dups_flatten(parameters):
params.extend(p[1:])
else:
params.append(p)
# Weed out strict duplicates, preserving the first of each occurrence.
all_params = set(params)
if len(all_params) < len(params):
new_params = []
for t in params:
if t in all_params:
new_params.append(t)
all_params.remove(t)
params = new_params
assert not all_params, all_params

return tuple(_deduplicate(params))


def _flatten_literal_params(parameters):
"""An internal helper for Literal creation: flatten Literals among parameters"""
params = []
for p in parameters:
if isinstance(p, _LiteralGenericAlias):
params.extend(p.__args__)
else:
uriyyo marked this conversation as resolved.
Show resolved Hide resolved
params.append(p)
return tuple(params)


_cleanups = []


def _tp_cache(func):
def _tp_cache(func=None, /, *, typed=False):
"""Internal wrapper caching __getitem__ of generic types with a fallback to
original function for non-hashable arguments.
uriyyo marked this conversation as resolved.
Show resolved Hide resolved
"""
cached = functools.lru_cache()(func)
_cleanups.append(cached.cache_clear)
if func is not None:
return _tp_cache(typed=typed)(func)
uriyyo marked this conversation as resolved.
Show resolved Hide resolved

@functools.wraps(func)
def inner(*args, **kwds):
try:
return cached(*args, **kwds)
except TypeError:
pass # All real errors (not unhashable args) are raised below.
return func(*args, **kwds)
return inner
def decorator(func):
cached = functools.lru_cache(typed=typed)(func)
_cleanups.append(cached.cache_clear)

@functools.wraps(func)
def inner(*args, **kwds):
try:
return cached(*args, **kwds)
except TypeError:
pass # All real errors (not unhashable args) are raised below.
return func(*args, **kwds)
return inner

return decorator

def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
"""Evaluate all forward references in the given type t.
Expand Down Expand Up @@ -319,6 +340,13 @@ def __subclasscheck__(self, cls):
def __getitem__(self, parameters):
return self._getitem(self, parameters)


class _LiteralSpecialForm(_SpecialForm, _root=True):
@_tp_cache(typed=True)
def __getitem__(self, parameters):
return self._getitem(self, parameters)


@_SpecialForm
def Any(self, parameters):
"""Special type indicating an unconstrained type.
Expand Down Expand Up @@ -436,7 +464,7 @@ def Optional(self, parameters):
arg = _type_check(parameters, f"{self} requires a single type.")
return Union[arg, type(None)]

@_SpecialForm
@_LiteralSpecialForm
def Literal(self, parameters):
"""Special typing form to define literal types (a.k.a. value types).

Expand All @@ -460,7 +488,17 @@ def open_helper(file: str, mode: MODE) -> str:
"""
# There is no '_type_check' call because arguments to Literal[...] are
# values, not types.
return _GenericAlias(self, parameters)
if not isinstance(parameters, tuple):
parameters = (parameters,)

parameters = _flatten_literal_params(parameters)

try:
parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters))))
except TypeError: # unhashable parameters
pass
uriyyo marked this conversation as resolved.
Show resolved Hide resolved

return _LiteralGenericAlias(self, parameters)


@_SpecialForm
Expand Down Expand Up @@ -930,6 +968,21 @@ def __subclasscheck__(self, cls):
return True


def _value_and_type_iter(parameters):
return ((p, type(p)) for p in parameters)


class _LiteralGenericAlias(_GenericAlias, _root=True):

def __eq__(self, other):
if not isinstance(other, _LiteralGenericAlias):
return NotImplemented

return set(_value_and_type_iter(self.__args__)) == set(_value_and_type_iter(other.__args__))
uriyyo marked this conversation as resolved.
Show resolved Hide resolved

def __hash__(self):
return hash(tuple(_value_and_type_iter(self.__args__)))


class Generic:
"""Abstract base class for generic types.
Expand Down
1 change: 1 addition & 0 deletions Misc/ACKS
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,7 @@ Jan Kanis
Rafe Kaplan
Jacob Kaplan-Moss
Allison Kaptur
Yurii Karabas
Janne Karila
Per Øyvind Karlsen
Anton Kasyanov
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Fix ``typing.Literal`` equals method to ignore the order of arguments.
Fix issue related to ``typing.Literal`` caching by adding ``typed``
parameter to ``typing._tp_cache`` function. Add deduplication of
``typing.Literal`` arguments. Patch provided by Yurii Karabas.
uriyyo marked this conversation as resolved.
Show resolved Hide resolved