From 301eaf84976d25188b6948611fccc50927dee6c7 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Sat, 23 Mar 2024 00:22:06 -0700 Subject: [PATCH] Improve case def Refs #833. --- coconut/compiler/compiler.py | 192 ++++++++++++------ coconut/compiler/grammar.py | 57 +++--- coconut/compiler/util.py | 2 +- coconut/constants.py | 2 + coconut/root.py | 2 +- coconut/tests/main_test.py | 2 +- .../tests/src/cocotest/agnostic/suite.coco | 9 +- coconut/tests/src/cocotest/agnostic/util.coco | 44 +++- 8 files changed, 208 insertions(+), 102 deletions(-) diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 26bf9af5..efe8f14d 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -93,6 +93,7 @@ import_existing, use_adaptive_any_of, reverse_any_of, + tempsep, ) from coconut.util import ( pickleable_obj, @@ -774,6 +775,7 @@ def bind(cls): cls.testlist_star_namedexpr <<= attach(cls.testlist_star_namedexpr_tokens, cls.method("testlist_star_expr_handle")) cls.ellipsis <<= attach(cls.ellipsis_tokens, cls.method("ellipsis_handle")) cls.f_string <<= attach(cls.f_string_tokens, cls.method("f_string_handle")) + cls.funcname_typeparams <<= attach(cls.funcname_typeparams_tokens, cls.method("funcname_typeparams_handle")) # standard handlers of the form name <<= attach(name_ref, method("name_handle")) cls.term <<= attach(cls.term_ref, cls.method("term_handle")) @@ -806,7 +808,6 @@ def bind(cls): cls.base_match_for_stmt <<= attach(cls.base_match_for_stmt_ref, cls.method("base_match_for_stmt_handle")) cls.async_with_for_stmt <<= attach(cls.async_with_for_stmt_ref, cls.method("async_with_for_stmt_handle")) cls.unsafe_typedef_tuple <<= attach(cls.unsafe_typedef_tuple_ref, cls.method("unsafe_typedef_tuple_handle")) - cls.funcname_typeparams <<= attach(cls.funcname_typeparams_ref, cls.method("funcname_typeparams_handle")) cls.impl_call <<= attach(cls.impl_call_ref, cls.method("impl_call_handle")) cls.protocol_intersect_expr <<= attach(cls.protocol_intersect_expr_ref, cls.method("protocol_intersect_expr_handle")) @@ -2297,9 +2298,10 @@ def proc_funcdef(self, original, loc, decorators, funcdef, is_async, in_method, def_stmt = raw_lines.pop(0) out = [] - # detect addpattern/copyclosure functions + # detect keyword functions addpattern = False copyclosure = False + typed_case_def = False done = False while not done: if def_stmt.startswith("addpattern "): @@ -2308,6 +2310,11 @@ def proc_funcdef(self, original, loc, decorators, funcdef, is_async, in_method, elif def_stmt.startswith("copyclosure "): def_stmt = assert_remove_prefix(def_stmt, "copyclosure ") copyclosure = True + elif def_stmt.startswith("case "): + def_stmt = assert_remove_prefix(def_stmt, "case ") + case_def_ref, def_stmt = def_stmt.split(unwrapper, 1) + type_param_code, all_type_defs = self.get_ref("case_def", case_def_ref) + typed_case_def = True elif def_stmt.startswith("def"): done = True else: @@ -2547,6 +2554,37 @@ def {mock_var}({mock_paramdef}): if is_match_func: decorators += "@_coconut_mark_as_match\n" # binds most tightly + # handle typed case def functions (must happen before decorators are cleared out) + type_code = None + if typed_case_def: + if undotted_name is not None: + all_type_defs = [ + "def " + def_name + assert_remove_prefix(type_def, "def " + func_name) + for type_def in all_type_defs + ] + type_code = ( + self.deferred_code_proc(type_param_code) + + "\n".join( + ("@_coconut.typing.overload\n" if len(all_type_defs) > 1 else "") + + decorators + + self.deferred_code_proc(type_def) + for type_def in all_type_defs + ) + ) + if len(all_type_defs) > 1: + type_code += "\n" + decorators + handle_indentation(""" +def {def_name}(*_coconut_args, **_coconut_kwargs): + return {any_type_ellipsis} + """).format( + def_name=def_name, + any_type_ellipsis=self.any_type_ellipsis(), + ) + if undotted_name is not None: + type_code += "\n{func_name} = {def_name}".format( + func_name=func_name, + def_name=def_name, + ) + # handle dotted function definition if undotted_name is not None: out.append( @@ -2578,7 +2616,7 @@ def {mock_var}({mock_paramdef}): out += [decorators, def_stmt, func_code] decorators = "" - # handle copyclosure functions + # handle copyclosure functions and type_code if copyclosure: vars_var = self.get_temp_var("func_vars", loc) func_from_vars = vars_var + '["' + def_name + '"]' @@ -2591,24 +2629,39 @@ def {mock_var}({mock_paramdef}): handle_indentation( ''' if _coconut.typing.TYPE_CHECKING: - {code} + {type_code} {vars_var} = {{"{def_name}": {def_name}}} else: {vars_var} = _coconut.globals().copy() {vars_var}.update(_coconut.locals()) _coconut_exec({code_str}, {vars_var}) {func_name} = {func_from_vars} - ''', + ''', add_newline=True, ).format( func_name=func_name, def_name=def_name, vars_var=vars_var, - code=code, + type_code=code if type_code is None else type_code, code_str=self.wrap_str_of(self.reformat_post_deferred_code_proc(code)), func_from_vars=func_from_vars, ), ] + elif type_code: + out = [ + handle_indentation( + ''' +if _coconut.typing.TYPE_CHECKING: + {type_code} +else: + {code} + ''', + add_newline=True, + ).format( + type_code=type_code, + code="".join(out), + ), + ] internal_assert(not decorators, "unhandled decorators", decorators) return "".join(out) @@ -2664,29 +2717,21 @@ def deferred_code_proc(self, inputstring, add_code_at_start=False, ignore_names= func_id = int(assert_remove_prefix(line, funcwrapper)) original, loc, decorators, funcdef, is_async, in_method, is_stmt_lambda = self.get_ref("func", func_id) - # process inner code + # process inner code (we use tempsep to tell what was newly added before the funcdef) decorators = self.deferred_code_proc(decorators, add_code_at_start=True, ignore_names=ignore_names, **kwargs) - funcdef = self.deferred_code_proc(funcdef, ignore_names=ignore_names, **kwargs) - - # handle any non-function code that was added before the funcdef - pre_def_lines = [] - post_def_lines = [] - funcdef_lines = list(literal_lines(funcdef, True)) - for i, line in enumerate(funcdef_lines): - line_indent, line_base = split_leading_indent(line) - if self.def_regex.match(line_base): - pre_def_lines = funcdef_lines[:i] - post_def_lines = funcdef_lines[i:] - break - internal_assert(post_def_lines, "no def statement found in funcdef", funcdef) - - out.append(bef_ind) - out += pre_def_lines - func_indent, func_code, func_dedent = split_leading_trailing_indent("".join(post_def_lines), symmetric=True) - out.append(func_indent) - out.append(self.proc_funcdef(original, loc, decorators, func_code, is_async, in_method, is_stmt_lambda)) - out.append(func_dedent) - out.append(aft_ind) + raw_funcdef = self.deferred_code_proc(tempsep + funcdef, ignore_names=ignore_names, **kwargs) + + pre_funcdef, post_funcdef = raw_funcdef.split(tempsep) + func_indent, func_code, func_dedent = split_leading_trailing_indent(post_funcdef, symmetric=True) + + out += [ + bef_ind, + pre_funcdef, + func_indent, + self.proc_funcdef(original, loc, decorators, func_code, is_async, in_method, is_stmt_lambda), + func_dedent, + aft_ind, + ] # look for add_code_before regexes else: @@ -3490,7 +3535,7 @@ def __rmul__(self, other): return _coconut.NotImplemented def __eq__(self, other): return self.__class__ is other.__class__ and _coconut.tuple.__eq__(self, other) def __hash__(self): - return _coconut.tuple.__hash__(self) ^ hash(self.__class__) + return _coconut.tuple.__hash__(self) ^ _coconut.hash(self.__class__) """, add_newline=True, ).format( @@ -3824,45 +3869,70 @@ def op_match_funcdef_handle(self, original, loc, tokens): def base_case_funcdef_handle(self, original, loc, tokens): """Process case def function definitions.""" - if len(tokens) == 3: - name, typedef_grp, cases = tokens + if len(tokens) == 2: + name_toks, cases = tokens docstring = None - elif len(tokens) == 4: - name, typedef_grp, docstring, cases = tokens + elif len(tokens) == 3: + name_toks, docstring, cases = tokens else: raise CoconutInternalException("invalid case function definition tokens", tokens) - if typedef_grp: - typedef, = typedef_grp + + type_param_code = "" + if len(name_toks) == 1: + name, = name_toks else: - typedef = None + name, paramdefs = name_toks + # paramdefs are type params on >= 3.12 and type var assignments on < 3.12 + if self.target_info >= (3, 12): + name += "[" + ", ".join(paramdefs) + "]" + else: + type_param_code = "".join(paramdefs) check_var = self.get_temp_var("match_check", loc) all_case_code = [] + all_type_defs = [] for case_toks in cases: - if len(case_toks) == 2: - matches, body = case_toks - cond = None - else: - matches, cond, body = case_toks - matcher = self.get_matcher(original, loc, check_var) - matcher.match_function_toks(matches, include_setup=False) - if cond is not None: - matcher.add_guard(cond) - all_case_code.append(handle_indentation(""" + if "match" in case_toks: + if len(case_toks) == 2: + matches, body = case_toks + cond = None + else: + matches, cond, body = case_toks + matcher = self.get_matcher(original, loc, check_var) + matcher.match_function_toks(matches, include_setup=False) + if cond is not None: + matcher.add_guard(cond) + all_case_code.append(handle_indentation(""" if not {check_var}: {match_to_kwargs_var} = {match_to_kwargs_var}_store.copy() {match_out} if {check_var}: {body} - """).format( - check_var=check_var, - match_to_kwargs_var=match_to_kwargs_var, - match_out=matcher.out(), - body=body, - )) + """).format( + check_var=check_var, + match_to_kwargs_var=match_to_kwargs_var, + match_out=matcher.out(), + body=body, + )) + elif "type" in case_toks: + typed_params, typed_ret = case_toks + all_type_defs.append(handle_indentation(""" +def {name}{typed_params}{typed_ret} + return {ellipsis} + """).format( + name=name, + typed_params=typed_params, + typed_ret=typed_ret, + ellipsis=self.any_type_ellipsis(), + )) + else: + raise CoconutInternalException("invalid case_funcdef case_toks", case_toks) + + if type_param_code and not all_type_defs: + raise CoconutDeferredSyntaxError("type parameters in case def but no type declaration cases", loc) - code = handle_indentation(""" + func_code = handle_indentation(""" def {name}({match_func_paramdef}): {docstring} {check_var} = False @@ -3880,17 +3950,11 @@ def {name}({match_func_paramdef}): all_case_code="\n".join(all_case_code), error=self.pattern_error(original, loc, match_to_args_var, check_var, function_match_error_var), ) - if typedef is None: - return code - else: - return handle_indentation(""" -{typedef_stmt} -if not _coconut.typing.TYPE_CHECKING: - {code} - """).format( - code=code, - typedef_stmt=self.typed_assign_stmt_handle([name, typedef, self.any_type_ellipsis()]), - ) + + if not (type_param_code or all_type_defs): + return func_code + + return "case " + self.add_ref("case_def", (type_param_code, all_type_defs)) + unwrapper + func_code def set_literal_handle(self, tokens): """Converts set literals to the right form for the target Python.""" diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index d27d58bc..e7234e89 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -2251,7 +2251,7 @@ class Grammar(object): with_stmt = Forward() funcname_typeparams = Forward() - funcname_typeparams_ref = dotted_setname + Optional(type_params) + funcname_typeparams_tokens = dotted_setname + Optional(type_params) name_funcdef = condense(funcname_typeparams + parameters) op_tfpdef = unsafe_typedef_default | condense(setname + Optional(default)) op_funcdef_arg = setname | condense(lparen.suppress() + op_tfpdef + rparen.suppress()) @@ -2359,39 +2359,48 @@ class Grammar(object): base_case_funcdef = Forward() base_case_funcdef_ref = ( keyword("def").suppress() - + funcname_typeparams + + Group(funcname_typeparams_tokens) + colon.suppress() - - Group(Optional(typedef_test)) - newline.suppress() - indent.suppress() - Optional(docstring) - - Group(OneOrMore(Group( - keyword("match").suppress() - + lparen.suppress() - + match_args_list - + rparen.suppress() - + match_guard - + ( - colon.suppress() - + ( - newline.suppress() - + indent.suppress() - + attach(condense(OneOrMore(stmt)), make_suite_handle) - + dedent.suppress() - | attach(simple_stmt, make_suite_handle) - ) - | equals.suppress() + - Group(OneOrMore( + labeled_group( + keyword("match").suppress() + + lparen.suppress() + + match_args_list + + match_guard + + rparen.suppress() + ( - ( + colon.suppress() + + ( newline.suppress() + indent.suppress() - + attach(math_funcdef_body, make_suite_handle) + + attach(condense(OneOrMore(stmt)), make_suite_handle) + dedent.suppress() + | attach(simple_stmt, make_suite_handle) ) - | attach(implicit_return_stmt, make_suite_handle) - ) + | equals.suppress() + + ( + ( + newline.suppress() + + indent.suppress() + + attach(math_funcdef_body, make_suite_handle) + + dedent.suppress() + ) + | attach(implicit_return_stmt, make_suite_handle) + ) + ), + "match", ) - ))) + | labeled_group( + keyword("type").suppress() + + parameters + + return_typedef + + newline.suppress(), + "type", + ) + )) - dedent.suppress() ) case_funcdef = keyword("case").suppress() + base_case_funcdef diff --git a/coconut/compiler/util.py b/coconut/compiler/util.py index add897e3..9cea72c6 100644 --- a/coconut/compiler/util.py +++ b/coconut/compiler/util.py @@ -1077,7 +1077,7 @@ def load_cache_for(inputstring, codepath): incremental_info=incremental_info, )) if incremental_enabled: - logger.warn("Populating initial parsing cache (compilation may take longer than usual)...") + logger.warn("Populating initial parsing cache (initial compilation may take a while; pass --no-cache to disable)...") else: cache_path = None logger.log("Declined to load cache for {filename!r} ({incremental_info}).".format( diff --git a/coconut/constants.py b/coconut/constants.py index 62e5f296..50db1383 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -291,6 +291,7 @@ def get_path_env_var(env_var, default): early_passthrough_wrapper = "\u2038" # caret lnwrapper = "\u2021" # double dagger unwrapper = "\u23f9" # stop square +tempsep = "\u22ee" # vertical ellipsis funcwrapper = "def:" # must be tuples for .startswith / .endswith purposes @@ -314,6 +315,7 @@ def get_path_env_var(env_var, default): ) + indchars + comment_chars reserved_compiler_symbols = delimiter_symbols + ( reserved_prefix, + tempsep, funcwrapper, ) diff --git a/coconut/root.py b/coconut/root.py index bcd32000..1e3c3a43 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.1.0" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 6 +DEVELOP = 7 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/main_test.py b/coconut/tests/main_test.py index 19c9900d..22c1be69 100644 --- a/coconut/tests/main_test.py +++ b/coconut/tests/main_test.py @@ -339,7 +339,7 @@ def call( continue # combine mypy error lines - if any(infix in line for infix in mypy_err_infixes): + if any(infix in line for infix in mypy_err_infixes) and i < len(raw_lines) - 1: # always add the next line, since it might be a continuation of the error message line += "\n" + raw_lines[i + 1] i += 1 diff --git a/coconut/tests/src/cocotest/agnostic/suite.coco b/coconut/tests/src/cocotest/agnostic/suite.coco index 483d6471..d14b3ad2 100644 --- a/coconut/tests/src/cocotest/agnostic/suite.coco +++ b/coconut/tests/src/cocotest/agnostic/suite.coco @@ -437,6 +437,7 @@ def suite_test() -> bool: assert partition([1, 2, 3], 2) |> map$(tuple) |> list == [(1,), (3, 2)] == partition_([1, 2, 3], 2) |> map$(tuple) |> list assert myreduce((+), (1, 2, 3)) == 6 assert recurse_n_times(10000) + assert recurse_n_times_(10000) assert fake_recurse_n_times(10000) a = clsA() assert ((not)..a.true)() is False @@ -535,7 +536,7 @@ def suite_test() -> bool: assert False tv = typed_vector() assert repr(tv) == "typed_vector(x=0, y=0)" - for obj in (factorial, iadd, collatz, recurse_n_times): + for obj in (factorial, iadd, collatz, recurse_n_times, recurse_n_times_): assert obj.__doc__ == "this is a docstring", obj assert list_type((|1,2|)) == "at least 2" assert list_type((|1|)) == "at least 1" @@ -632,8 +633,8 @@ def suite_test() -> bool: assert dt.N()$[:2] |> list == [(dt, 0), (dt, 1)] == dt.N_()$[:2] |> list assert map(HasDefs().a_def, range(5)) |> list == range(1, 6) |> list assert HasDefs().a_def 1 == 2 - assert HasDefs().case_def 1 == 0 - assert HasDefs.__annotations__.keys() |> set == {"a_def", "case_def"}, HasDefs.__annotations__ + assert HasDefs().case_def 1 == 0 == HasDefs().case_def_ 1 + assert HasDefs.__annotations__.keys() |> set == {"a_def"}, HasDefs.__annotations__ assert store.plus1 store.one == store.two assert ret_locals()["my_loc"] == 1 assert ret_globals()["my_glob"] == 1 @@ -1085,6 +1086,8 @@ forward 2""") == 900 assert ret_args_kwargs ↤** dict(a=1) == ((), dict(a=1)) assert ret_args_kwargs ↤**? None is None assert [1, 2, 3] |> reduce_with_init$(+) == 6 == (1, 2, 3) |> iter |> reduce_with_init$((+), init=0) + assert min(1, 2) == 1 == my_min(1, 2) + assert min([1, 2]) == 1 == my_min([1, 2]) with process_map.multiple_sequential_calls(): # type: ignore assert process_map(tuple <.. (|>)$(to_sort), qsorts) |> list == [to_sort |> sorted |> tuple] * len(qsorts) diff --git a/coconut/tests/src/cocotest/agnostic/util.coco b/coconut/tests/src/cocotest/agnostic/util.coco index 848a7f7a..aa99e1b2 100644 --- a/coconut/tests/src/cocotest/agnostic/util.coco +++ b/coconut/tests/src/cocotest/agnostic/util.coco @@ -2,6 +2,7 @@ import sys import random import pickle +import typing import operator # NOQA from contextlib import contextmanager from functools import wraps @@ -10,6 +11,8 @@ from collections import defaultdict, deque __doc__ = "docstring" # Helpers: +___ = typing.cast(typing.Any, ...) + def rand_list(n): '''Generate a random list of length n.''' return [random.randrange(10) for x in range(0, n)] @@ -243,7 +246,6 @@ addpattern def x! if x = False # type: ignore addpattern def x! = True # type: ignore # Type aliases: -import typing if sys.version_info >= (3, 5) or TYPE_CHECKING: type list_or_tuple = list | tuple @@ -395,6 +397,11 @@ def recurse_n_times(n) = return True recurse_n_times(n-1) +case def recurse_n_times_: + """this is a docstring""" + match(0) = True + match(n) = recurse_n_times_(n-1) + def is_even(n) = if not n: return True @@ -632,18 +639,20 @@ def factorial5(value): else: return None raise TypeError() -case def factorial6[Num: (int, float)]: (Num, Num) -> Num +case def factorial6[Num: (int, float)]: """Factorial function""" + type(n: Num, acc: Num = ___) -> Num match (0, acc=1): return acc - match (int(n), acc=1) if n > 0: + match (int(n), acc=1 if n > 0): return factorial6(n - 1, acc * n) - match (int(n), acc=...) if n < 0: + match (int(n), acc=... if n < 0): return None -case def factorial7[Num <: int | float]: (Num, Num) -> Num +case def factorial7[Num <: int | float]: + type(n: Num, acc: Num = ___) -> Num match(0, acc=1) = acc - match(int(n), acc=1) if n > 0 = factorial7(n - 1, acc * n) - match(int(n), acc=...) if n < 0 = None + match(int(n), acc=1 if n > 0) = factorial7(n - 1, acc * n) + match(int(n), acc=... if n < 0) = None match def fact(n) = fact(n, 1) match addpattern def fact(0, acc) = acc # type: ignore @@ -1414,13 +1423,20 @@ class HasDefs: a_def: typing.Callable @staticmethod - case def case_def: int -> int + case def case_def: + type(_: int) -> int match(0) = 1 match(1) = 0 def HasDefs.a_def(self, 0) = 1 # type: ignore addpattern def HasDefs.a_def(self, x) = x + 1 # type: ignore +@staticmethod # type: ignore +case def HasDefs.case_def_: # type: ignore + type(_: int) -> int + match(0) = 1 + match(1) = 0 + # Storage class class store: @@ -2080,3 +2096,15 @@ def outer_func_6(): match() = x funcs.append(inner_func) return funcs + + +# case def + +case def my_min[T]: + type(xs: T[]) -> T + match([x]) = x + match([x] + xs) = my_min(x, my_min(xs)) + + type(x: T, y: T) -> T + match(x, y if x <= y) = x + match(x, y) = y