From 26b3b7ca3b691f46ac2c757f400225c62ab72061 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Fri, 10 Mar 2023 16:38:24 -0800 Subject: [PATCH] Properly universalize dicts Resolves #723. --- DOCS.md | 1 + __coconut__/__init__.pyi | 1 + coconut/compiler/compiler.py | 80 ++++++++++++------- coconut/compiler/grammar.py | 3 +- coconut/compiler/header.py | 2 +- coconut/compiler/matching.py | 4 +- coconut/constants.py | 1 + coconut/root.py | 28 +++++-- .../tests/src/cocotest/agnostic/primary.coco | 14 ++++ 9 files changed, 95 insertions(+), 39 deletions(-) diff --git a/DOCS.md b/DOCS.md index 61ce0fcca..06ddd8ed0 100644 --- a/DOCS.md +++ b/DOCS.md @@ -242,6 +242,7 @@ While Coconut syntax is based off of the latest Python 3, Coconut code compiled To make Coconut built-ins universal across Python versions, Coconut makes available on any Python version built-ins that only exist in later versions, including **automatically overwriting Python 2 built-ins with their Python 3 counterparts.** Additionally, Coconut also [overwrites some Python 3 built-ins for optimization and enhancement purposes](#enhanced-built-ins). If access to the original Python versions of any overwritten built-ins is desired, the old built-ins can be retrieved by prefixing them with `py_`. Specifically, the overwritten built-ins are: - `py_chr` +- `py_dict` - `py_hex` - `py_input` - `py_int` diff --git a/__coconut__/__init__.pyi b/__coconut__/__init__.pyi index 1bce36f3c..20d40407c 100644 --- a/__coconut__/__init__.pyi +++ b/__coconut__/__init__.pyi @@ -133,6 +133,7 @@ if sys.version_info < (3, 7): py_chr = chr +py_dict = dict py_hex = hex py_input = input py_int = int diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index a2bf98f7d..34e441ba4 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -356,6 +356,48 @@ def reconstitute_paramdef(pos_only_args, req_args, default_args, star_arg, kwd_o return ", ".join(args_list) +def split_star_expr_tokens(tokens, is_dict=False): + """Split testlist_star_expr or dict_literal tokens.""" + groups = [[]] + has_star = False + has_comma = False + for tok_grp in tokens: + if tok_grp == ",": + has_comma = True + elif len(tok_grp) == 1: + internal_assert(not is_dict, "found non-star non-pair item in dict literal", tok_grp) + groups[-1].append(tok_grp[0]) + elif len(tok_grp) == 2: + internal_assert(not tok_grp[0].lstrip("*"), "invalid star expr item signifier", tok_grp[0]) + has_star = True + groups.append(tok_grp[1]) + groups.append([]) + elif len(tok_grp) == 3: + internal_assert(is_dict, "found dict key-value pair in non-dict tokens", tok_grp) + k, c, v = tok_grp + internal_assert(c == ":", "invalid colon in dict literal item", c) + groups[-1].append((k, v)) + else: + raise CoconutInternalException("invalid testlist_star_expr tokens", tokens) + if not groups[-1]: + groups.pop() + return groups, has_star, has_comma + + +def join_dict_group(group, as_tuples=False): + """Join group from split_star_expr_tokens$(is_dict=True).""" + items = [] + for k, v in group: + if as_tuples: + items.append("(" + k + ", " + v + ")") + else: + items.append(k + ": " + v) + if as_tuples: + return tuple_str_of(items, add_parens=False) + else: + return ", ".join(items) + + # end: UTILITIES # ----------------------------------------------------------------------------------------------------------------------- # COMPILER: @@ -3011,7 +3053,7 @@ def dict_comp_handle(self, loc, tokens): if self.target.startswith("3"): return "{" + key + ": " + val + " " + comp + "}" else: - return "dict(((" + key + "), (" + val + ")) " + comp + ")" + return "_coconut.dict(((" + key + "), (" + val + ")) " + comp + ")" def pattern_error(self, original, loc, value_var, check_var, match_error_class='_coconut_MatchError'): """Construct a pattern-matching error message.""" @@ -3594,30 +3636,9 @@ def unsafe_typedef_or_expr_handle(self, tokens): else: return "_coconut.typing.Union[" + ", ".join(tokens) + "]" - def split_star_expr_tokens(self, tokens): - """Split testlist_star_expr or dict_literal tokens.""" - groups = [[]] - has_star = False - has_comma = False - for tok_grp in tokens: - if tok_grp == ",": - has_comma = True - elif len(tok_grp) == 1: - groups[-1].append(tok_grp[0]) - elif len(tok_grp) == 2: - internal_assert(not tok_grp[0].lstrip("*"), "invalid star expr item signifier", tok_grp[0]) - has_star = True - groups.append(tok_grp[1]) - groups.append([]) - else: - raise CoconutInternalException("invalid testlist_star_expr tokens", tokens) - if not groups[-1]: - groups.pop() - return groups, has_star, has_comma - def testlist_star_expr_handle(self, original, loc, tokens, is_list=False): """Handle naked a, *b.""" - groups, has_star, has_comma = self.split_star_expr_tokens(tokens) + groups, has_star, has_comma = split_star_expr_tokens(tokens) is_sequence = has_comma or is_list if not is_sequence and not has_star: @@ -3667,20 +3688,23 @@ def list_expr_handle(self, original, loc, tokens): def dict_literal_handle(self, tokens): """Handle {**d1, **d2}.""" if not tokens: - return "{}" + return "{}" if self.target.startswith("3") else "_coconut.dict()" - groups, has_star, _ = self.split_star_expr_tokens(tokens) + groups, has_star, _ = split_star_expr_tokens(tokens, is_dict=True) if not has_star: internal_assert(len(groups) == 1, "dict_literal group splitting failed on", tokens) - return "{" + ", ".join(groups[0]) + "}" + if self.target.startswith("3"): + return "{" + join_dict_group(groups[0]) + "}" + else: + return "_coconut.dict((" + join_dict_group(groups[0], as_tuples=True) + "))" # naturally supported on 3.5+ elif self.target_info >= (3, 5): to_literal = [] for g in groups: if isinstance(g, list): - to_literal.extend(g) + to_literal.append(join_dict_group(g)) else: to_literal.append("**" + g) return "{" + ", ".join(to_literal) + "}" @@ -3690,7 +3714,7 @@ def dict_literal_handle(self, tokens): to_merge = [] for g in groups: if isinstance(g, list): - to_merge.append("{" + ", ".join(g) + "}") + to_merge.append("{" + join_dict_group(g) + "}") else: to_merge.append(g) return "_coconut_dict_merge(" + ", ".join(to_merge) + ")" diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index 424e6b84b..de5c1de0d 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -957,7 +957,8 @@ class Grammar(object): lbrace.suppress() + Optional( tokenlist( - Group(addspace(condense(test + colon) + test)) | dubstar_expr, + Group(test + colon + test) + | dubstar_expr, comma, ), ) diff --git a/coconut/compiler/header.py b/coconut/compiler/header.py index 9fbca8463..820759d63 100644 --- a/coconut/compiler/header.py +++ b/coconut/compiler/header.py @@ -208,7 +208,7 @@ def process_header_args(which, use_hash, target, no_tco, strict, no_wrap): format_dict = dict( COMMENT=COMMENT, - empty_dict="{}", + empty_dict="{}" if target_startswith == "3" else "_coconut.dict()", lbrace="{", rbrace="}", is_data_var=is_data_var, diff --git a/coconut/compiler/matching.py b/coconut/compiler/matching.py index 175b037f8..0149479e7 100644 --- a/coconut/compiler/matching.py +++ b/coconut/compiler/matching.py @@ -504,8 +504,8 @@ def match_dict(self, tokens, item): if rest is not None and rest != wildcard: match_keys = [k for k, v in matches] rest_item = ( - "dict((k, v) for k, v in " - + item + ".items() if k not in set((" + "_coconut.dict((k, v) for k, v in " + + item + ".items() if k not in _coconut.set((" + ", ".join(match_keys) + ("," if len(match_keys) == 1 else "") + ")))" ) diff --git a/coconut/constants.py b/coconut/constants.py index a59cba07c..c7c29270f 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -662,6 +662,7 @@ def get_bool_env_var(env_var, default=False): "cycle", "windowsof", "py_chr", + "py_dict", "py_hex", "py_input", "py_int", diff --git a/coconut/root.py b/coconut/root.py index 6aa0f923d..b079e6c6e 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.0" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 11 +DEVELOP = 12 ALPHA = True # for pre releases rather than post releases # ----------------------------------------------------------------------------------------------------------------------- @@ -80,8 +80,8 @@ def breakpoint(*args, **kwargs): ''' # if a new assignment is added below, a new builtins import should be added alongside it -_base_py3_header = r'''from builtins import chr, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr -py_chr, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_repr = chr, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr +_base_py3_header = r'''from builtins import chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr +py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_repr = chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr _coconut_py_str, _coconut_py_super = str, super from functools import wraps as _coconut_wraps exec("_coconut_exec = exec") @@ -96,10 +96,11 @@ def breakpoint(*args, **kwargs): ''' # if a new assignment is added below, a new builtins import should be added alongside it -PY27_HEADER = r'''from __builtin__ import chr, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr, long -py_chr, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_raw_input, py_xrange, py_repr = chr, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr -_coconut_py_raw_input, _coconut_py_xrange, _coconut_py_int, _coconut_py_long, _coconut_py_print, _coconut_py_str, _coconut_py_super, _coconut_py_unicode, _coconut_py_repr = raw_input, xrange, int, long, print, str, super, unicode, repr +PY27_HEADER = r'''from __builtin__ import chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr, long +py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_raw_input, py_xrange, py_repr = chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr +_coconut_py_raw_input, _coconut_py_xrange, _coconut_py_int, _coconut_py_long, _coconut_py_print, _coconut_py_str, _coconut_py_super, _coconut_py_unicode, _coconut_py_repr, _coconut_py_dict = raw_input, xrange, int, long, print, str, super, unicode, repr, dict from functools import wraps as _coconut_wraps +from collections import Sequence as _coconut_Sequence, OrderedDict as _coconut_OrderedDict from future_builtins import * chr, str = unichr, unicode from io import open @@ -123,6 +124,20 @@ def __instancecheck__(cls, inst): return _coconut.isinstance(inst, (_coconut_py_int, _coconut_py_long)) def __subclasscheck__(cls, subcls): return _coconut.issubclass(subcls, (_coconut_py_int, _coconut_py_long)) +class dict(_coconut_OrderedDict): + __slots__ = () + __doc__ = getattr(_coconut_OrderedDict, "__doc__", "") + class __metaclass__(type): + def __instancecheck__(cls, inst): + return _coconut.isinstance(inst, _coconut_py_dict) + def __subclasscheck__(cls, subcls): + return _coconut.issubclass(subcls, _coconut_py_dict) + __eq__ = _coconut_py_dict.__eq__ + __repr__ = _coconut_py_dict.__repr__ + __str__ = _coconut_py_dict.__str__ + keys = _coconut_OrderedDict.viewkeys + values = _coconut_OrderedDict.viewvalues + items = _coconut_OrderedDict.viewitems class range(object): __slots__ = ("_xrange",) __doc__ = getattr(_coconut_py_xrange, "__doc__", "") @@ -189,7 +204,6 @@ def __copy__(self): return self.__class__(*self._args) def __eq__(self, other): return self.__class__ is other.__class__ and self._args == other._args -from collections import Sequence as _coconut_Sequence _coconut_Sequence.register(range) @_coconut_wraps(_coconut_py_print) def print(*args, **kwargs): diff --git a/coconut/tests/src/cocotest/agnostic/primary.coco b/coconut/tests/src/cocotest/agnostic/primary.coco index 595a86cfc..d6dbb1640 100644 --- a/coconut/tests/src/cocotest/agnostic/primary.coco +++ b/coconut/tests/src/cocotest/agnostic/primary.coco @@ -1558,4 +1558,18 @@ def primary_test() -> bool: assert not (is not)(False, False) assert (True is not .)(1) assert (. is not True)(1) + a_dict = {} + a_dict[1] = 1 + a_dict[3] = 2 + a_dict[2] = 3 + assert a_dict.keys() |> tuple == (1, 3, 2) + assert not a_dict.keys() `isinstance` list + assert not a_dict.values() `isinstance` list + assert not a_dict.items() `isinstance` list + assert {1: 1, 3: 2, 2: 3}.keys() |> tuple == (1, 3, 2) + assert {**[(1, 1), (3, 2), (2, 3)]}.keys() |> tuple == (1, 3, 2) + assert a_dict == {1: 1, 2: 3, 3: 2} + assert {1: 1} |> str == "{1: 1}" == {1: 1} |> repr + assert py_dict `issubclass` dict + assert py_dict() `isinstance` dict return True