Skip to content

Commit

Permalink
Properly universalize dicts
Browse files Browse the repository at this point in the history
Resolves   #723.
  • Loading branch information
evhub committed Mar 11, 2023
1 parent f94bf8f commit 26b3b7c
Show file tree
Hide file tree
Showing 9 changed files with 95 additions and 39 deletions.
1 change: 1 addition & 0 deletions DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ While Coconut syntax is based off of the latest Python 3, Coconut code compiled
To make Coconut built-ins universal across Python versions, Coconut makes available on any Python version built-ins that only exist in later versions, including **automatically overwriting Python 2 built-ins with their Python 3 counterparts.** Additionally, Coconut also [overwrites some Python 3 built-ins for optimization and enhancement purposes](#enhanced-built-ins). If access to the original Python versions of any overwritten built-ins is desired, the old built-ins can be retrieved by prefixing them with `py_`. Specifically, the overwritten built-ins are:

- `py_chr`
- `py_dict`
- `py_hex`
- `py_input`
- `py_int`
Expand Down
1 change: 1 addition & 0 deletions __coconut__/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ if sys.version_info < (3, 7):


py_chr = chr
py_dict = dict
py_hex = hex
py_input = input
py_int = int
Expand Down
80 changes: 52 additions & 28 deletions coconut/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,48 @@ def reconstitute_paramdef(pos_only_args, req_args, default_args, star_arg, kwd_o
return ", ".join(args_list)


def split_star_expr_tokens(tokens, is_dict=False):
"""Split testlist_star_expr or dict_literal tokens."""
groups = [[]]
has_star = False
has_comma = False
for tok_grp in tokens:
if tok_grp == ",":
has_comma = True
elif len(tok_grp) == 1:
internal_assert(not is_dict, "found non-star non-pair item in dict literal", tok_grp)
groups[-1].append(tok_grp[0])
elif len(tok_grp) == 2:
internal_assert(not tok_grp[0].lstrip("*"), "invalid star expr item signifier", tok_grp[0])
has_star = True
groups.append(tok_grp[1])
groups.append([])
elif len(tok_grp) == 3:
internal_assert(is_dict, "found dict key-value pair in non-dict tokens", tok_grp)
k, c, v = tok_grp
internal_assert(c == ":", "invalid colon in dict literal item", c)
groups[-1].append((k, v))
else:
raise CoconutInternalException("invalid testlist_star_expr tokens", tokens)
if not groups[-1]:
groups.pop()
return groups, has_star, has_comma


def join_dict_group(group, as_tuples=False):
"""Join group from split_star_expr_tokens$(is_dict=True)."""
items = []
for k, v in group:
if as_tuples:
items.append("(" + k + ", " + v + ")")
else:
items.append(k + ": " + v)
if as_tuples:
return tuple_str_of(items, add_parens=False)
else:
return ", ".join(items)


# end: UTILITIES
# -----------------------------------------------------------------------------------------------------------------------
# COMPILER:
Expand Down Expand Up @@ -3011,7 +3053,7 @@ def dict_comp_handle(self, loc, tokens):
if self.target.startswith("3"):
return "{" + key + ": " + val + " " + comp + "}"
else:
return "dict(((" + key + "), (" + val + ")) " + comp + ")"
return "_coconut.dict(((" + key + "), (" + val + ")) " + comp + ")"

def pattern_error(self, original, loc, value_var, check_var, match_error_class='_coconut_MatchError'):
"""Construct a pattern-matching error message."""
Expand Down Expand Up @@ -3594,30 +3636,9 @@ def unsafe_typedef_or_expr_handle(self, tokens):
else:
return "_coconut.typing.Union[" + ", ".join(tokens) + "]"

def split_star_expr_tokens(self, tokens):
"""Split testlist_star_expr or dict_literal tokens."""
groups = [[]]
has_star = False
has_comma = False
for tok_grp in tokens:
if tok_grp == ",":
has_comma = True
elif len(tok_grp) == 1:
groups[-1].append(tok_grp[0])
elif len(tok_grp) == 2:
internal_assert(not tok_grp[0].lstrip("*"), "invalid star expr item signifier", tok_grp[0])
has_star = True
groups.append(tok_grp[1])
groups.append([])
else:
raise CoconutInternalException("invalid testlist_star_expr tokens", tokens)
if not groups[-1]:
groups.pop()
return groups, has_star, has_comma

def testlist_star_expr_handle(self, original, loc, tokens, is_list=False):
"""Handle naked a, *b."""
groups, has_star, has_comma = self.split_star_expr_tokens(tokens)
groups, has_star, has_comma = split_star_expr_tokens(tokens)
is_sequence = has_comma or is_list

if not is_sequence and not has_star:
Expand Down Expand Up @@ -3667,20 +3688,23 @@ def list_expr_handle(self, original, loc, tokens):
def dict_literal_handle(self, tokens):
"""Handle {**d1, **d2}."""
if not tokens:
return "{}"
return "{}" if self.target.startswith("3") else "_coconut.dict()"

groups, has_star, _ = self.split_star_expr_tokens(tokens)
groups, has_star, _ = split_star_expr_tokens(tokens, is_dict=True)

if not has_star:
internal_assert(len(groups) == 1, "dict_literal group splitting failed on", tokens)
return "{" + ", ".join(groups[0]) + "}"
if self.target.startswith("3"):
return "{" + join_dict_group(groups[0]) + "}"
else:
return "_coconut.dict((" + join_dict_group(groups[0], as_tuples=True) + "))"

# naturally supported on 3.5+
elif self.target_info >= (3, 5):
to_literal = []
for g in groups:
if isinstance(g, list):
to_literal.extend(g)
to_literal.append(join_dict_group(g))
else:
to_literal.append("**" + g)
return "{" + ", ".join(to_literal) + "}"
Expand All @@ -3690,7 +3714,7 @@ def dict_literal_handle(self, tokens):
to_merge = []
for g in groups:
if isinstance(g, list):
to_merge.append("{" + ", ".join(g) + "}")
to_merge.append("{" + join_dict_group(g) + "}")
else:
to_merge.append(g)
return "_coconut_dict_merge(" + ", ".join(to_merge) + ")"
Expand Down
3 changes: 2 additions & 1 deletion coconut/compiler/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,8 @@ class Grammar(object):
lbrace.suppress()
+ Optional(
tokenlist(
Group(addspace(condense(test + colon) + test)) | dubstar_expr,
Group(test + colon + test)
| dubstar_expr,
comma,
),
)
Expand Down
2 changes: 1 addition & 1 deletion coconut/compiler/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def process_header_args(which, use_hash, target, no_tco, strict, no_wrap):

format_dict = dict(
COMMENT=COMMENT,
empty_dict="{}",
empty_dict="{}" if target_startswith == "3" else "_coconut.dict()",
lbrace="{",
rbrace="}",
is_data_var=is_data_var,
Expand Down
4 changes: 2 additions & 2 deletions coconut/compiler/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,8 @@ def match_dict(self, tokens, item):
if rest is not None and rest != wildcard:
match_keys = [k for k, v in matches]
rest_item = (
"dict((k, v) for k, v in "
+ item + ".items() if k not in set(("
"_coconut.dict((k, v) for k, v in "
+ item + ".items() if k not in _coconut.set(("
+ ", ".join(match_keys) + ("," if len(match_keys) == 1 else "")
+ ")))"
)
Expand Down
1 change: 1 addition & 0 deletions coconut/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,7 @@ def get_bool_env_var(env_var, default=False):
"cycle",
"windowsof",
"py_chr",
"py_dict",
"py_hex",
"py_input",
"py_int",
Expand Down
28 changes: 21 additions & 7 deletions coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "3.0.0"
VERSION_NAME = None
# False for release, int >= 1 for develop
DEVELOP = 11
DEVELOP = 12
ALPHA = True # for pre releases rather than post releases

# -----------------------------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -80,8 +80,8 @@ def breakpoint(*args, **kwargs):
'''

# if a new assignment is added below, a new builtins import should be added alongside it
_base_py3_header = r'''from builtins import chr, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr
py_chr, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_repr = chr, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr
_base_py3_header = r'''from builtins import chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr
py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_repr = chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr
_coconut_py_str, _coconut_py_super = str, super
from functools import wraps as _coconut_wraps
exec("_coconut_exec = exec")
Expand All @@ -96,10 +96,11 @@ def breakpoint(*args, **kwargs):
'''

# if a new assignment is added below, a new builtins import should be added alongside it
PY27_HEADER = r'''from __builtin__ import chr, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr, long
py_chr, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_raw_input, py_xrange, py_repr = chr, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr
_coconut_py_raw_input, _coconut_py_xrange, _coconut_py_int, _coconut_py_long, _coconut_py_print, _coconut_py_str, _coconut_py_super, _coconut_py_unicode, _coconut_py_repr = raw_input, xrange, int, long, print, str, super, unicode, repr
PY27_HEADER = r'''from __builtin__ import chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr, long
py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_raw_input, py_xrange, py_repr = chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr
_coconut_py_raw_input, _coconut_py_xrange, _coconut_py_int, _coconut_py_long, _coconut_py_print, _coconut_py_str, _coconut_py_super, _coconut_py_unicode, _coconut_py_repr, _coconut_py_dict = raw_input, xrange, int, long, print, str, super, unicode, repr, dict
from functools import wraps as _coconut_wraps
from collections import Sequence as _coconut_Sequence, OrderedDict as _coconut_OrderedDict
from future_builtins import *
chr, str = unichr, unicode
from io import open
Expand All @@ -123,6 +124,20 @@ def __instancecheck__(cls, inst):
return _coconut.isinstance(inst, (_coconut_py_int, _coconut_py_long))
def __subclasscheck__(cls, subcls):
return _coconut.issubclass(subcls, (_coconut_py_int, _coconut_py_long))
class dict(_coconut_OrderedDict):
__slots__ = ()
__doc__ = getattr(_coconut_OrderedDict, "__doc__", "<see help(py_dict)>")
class __metaclass__(type):
def __instancecheck__(cls, inst):
return _coconut.isinstance(inst, _coconut_py_dict)
def __subclasscheck__(cls, subcls):
return _coconut.issubclass(subcls, _coconut_py_dict)
__eq__ = _coconut_py_dict.__eq__
__repr__ = _coconut_py_dict.__repr__
__str__ = _coconut_py_dict.__str__
keys = _coconut_OrderedDict.viewkeys
values = _coconut_OrderedDict.viewvalues
items = _coconut_OrderedDict.viewitems
class range(object):
__slots__ = ("_xrange",)
__doc__ = getattr(_coconut_py_xrange, "__doc__", "<see help(py_xrange)>")
Expand Down Expand Up @@ -189,7 +204,6 @@ def __copy__(self):
return self.__class__(*self._args)
def __eq__(self, other):
return self.__class__ is other.__class__ and self._args == other._args
from collections import Sequence as _coconut_Sequence
_coconut_Sequence.register(range)
@_coconut_wraps(_coconut_py_print)
def print(*args, **kwargs):
Expand Down
14 changes: 14 additions & 0 deletions coconut/tests/src/cocotest/agnostic/primary.coco
Original file line number Diff line number Diff line change
Expand Up @@ -1558,4 +1558,18 @@ def primary_test() -> bool:
assert not (is not)(False, False)
assert (True is not .)(1)
assert (. is not True)(1)
a_dict = {}
a_dict[1] = 1
a_dict[3] = 2
a_dict[2] = 3
assert a_dict.keys() |> tuple == (1, 3, 2)
assert not a_dict.keys() `isinstance` list
assert not a_dict.values() `isinstance` list
assert not a_dict.items() `isinstance` list
assert {1: 1, 3: 2, 2: 3}.keys() |> tuple == (1, 3, 2)
assert {**[(1, 1), (3, 2), (2, 3)]}.keys() |> tuple == (1, 3, 2)
assert a_dict == {1: 1, 2: 3, 3: 2}
assert {1: 1} |> str == "{1: 1}" == {1: 1} |> repr
assert py_dict `issubclass` dict
assert py_dict() `isinstance` dict
return True

0 comments on commit 26b3b7c

Please sign in to comment.