Skip to content

Commit

Permalink
Improve case def
Browse files Browse the repository at this point in the history
Refs   #833.
  • Loading branch information
evhub committed Mar 23, 2024
1 parent c3925ef commit 301eaf8
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 102 deletions.
192 changes: 128 additions & 64 deletions coconut/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
import_existing,
use_adaptive_any_of,
reverse_any_of,
tempsep,
)
from coconut.util import (
pickleable_obj,
Expand Down Expand Up @@ -774,6 +775,7 @@ def bind(cls):
cls.testlist_star_namedexpr <<= attach(cls.testlist_star_namedexpr_tokens, cls.method("testlist_star_expr_handle"))
cls.ellipsis <<= attach(cls.ellipsis_tokens, cls.method("ellipsis_handle"))
cls.f_string <<= attach(cls.f_string_tokens, cls.method("f_string_handle"))
cls.funcname_typeparams <<= attach(cls.funcname_typeparams_tokens, cls.method("funcname_typeparams_handle"))

# standard handlers of the form name <<= attach(name_ref, method("name_handle"))
cls.term <<= attach(cls.term_ref, cls.method("term_handle"))
Expand Down Expand Up @@ -806,7 +808,6 @@ def bind(cls):
cls.base_match_for_stmt <<= attach(cls.base_match_for_stmt_ref, cls.method("base_match_for_stmt_handle"))
cls.async_with_for_stmt <<= attach(cls.async_with_for_stmt_ref, cls.method("async_with_for_stmt_handle"))
cls.unsafe_typedef_tuple <<= attach(cls.unsafe_typedef_tuple_ref, cls.method("unsafe_typedef_tuple_handle"))
cls.funcname_typeparams <<= attach(cls.funcname_typeparams_ref, cls.method("funcname_typeparams_handle"))
cls.impl_call <<= attach(cls.impl_call_ref, cls.method("impl_call_handle"))
cls.protocol_intersect_expr <<= attach(cls.protocol_intersect_expr_ref, cls.method("protocol_intersect_expr_handle"))

Expand Down Expand Up @@ -2297,9 +2298,10 @@ def proc_funcdef(self, original, loc, decorators, funcdef, is_async, in_method,
def_stmt = raw_lines.pop(0)
out = []

# detect addpattern/copyclosure functions
# detect keyword functions
addpattern = False
copyclosure = False
typed_case_def = False
done = False
while not done:
if def_stmt.startswith("addpattern "):
Expand All @@ -2308,6 +2310,11 @@ def proc_funcdef(self, original, loc, decorators, funcdef, is_async, in_method,
elif def_stmt.startswith("copyclosure "):
def_stmt = assert_remove_prefix(def_stmt, "copyclosure ")
copyclosure = True
elif def_stmt.startswith("case "):
def_stmt = assert_remove_prefix(def_stmt, "case ")
case_def_ref, def_stmt = def_stmt.split(unwrapper, 1)
type_param_code, all_type_defs = self.get_ref("case_def", case_def_ref)
typed_case_def = True
elif def_stmt.startswith("def"):
done = True
else:
Expand Down Expand Up @@ -2547,6 +2554,37 @@ def {mock_var}({mock_paramdef}):
if is_match_func:
decorators += "@_coconut_mark_as_match\n" # binds most tightly

# handle typed case def functions (must happen before decorators are cleared out)
type_code = None
if typed_case_def:
if undotted_name is not None:
all_type_defs = [
"def " + def_name + assert_remove_prefix(type_def, "def " + func_name)
for type_def in all_type_defs
]
type_code = (
self.deferred_code_proc(type_param_code)
+ "\n".join(
("@_coconut.typing.overload\n" if len(all_type_defs) > 1 else "")
+ decorators
+ self.deferred_code_proc(type_def)
for type_def in all_type_defs
)
)
if len(all_type_defs) > 1:
type_code += "\n" + decorators + handle_indentation("""
def {def_name}(*_coconut_args, **_coconut_kwargs):
return {any_type_ellipsis}
""").format(
def_name=def_name,
any_type_ellipsis=self.any_type_ellipsis(),
)
if undotted_name is not None:
type_code += "\n{func_name} = {def_name}".format(
func_name=func_name,
def_name=def_name,
)

# handle dotted function definition
if undotted_name is not None:
out.append(
Expand Down Expand Up @@ -2578,7 +2616,7 @@ def {mock_var}({mock_paramdef}):
out += [decorators, def_stmt, func_code]
decorators = ""

# handle copyclosure functions
# handle copyclosure functions and type_code
if copyclosure:
vars_var = self.get_temp_var("func_vars", loc)
func_from_vars = vars_var + '["' + def_name + '"]'
Expand All @@ -2591,24 +2629,39 @@ def {mock_var}({mock_paramdef}):
handle_indentation(
'''
if _coconut.typing.TYPE_CHECKING:
{code}
{type_code}
{vars_var} = {{"{def_name}": {def_name}}}
else:
{vars_var} = _coconut.globals().copy()
{vars_var}.update(_coconut.locals())
_coconut_exec({code_str}, {vars_var})
{func_name} = {func_from_vars}
''',
''',
add_newline=True,
).format(
func_name=func_name,
def_name=def_name,
vars_var=vars_var,
code=code,
type_code=code if type_code is None else type_code,
code_str=self.wrap_str_of(self.reformat_post_deferred_code_proc(code)),
func_from_vars=func_from_vars,
),
]
elif type_code:
out = [
handle_indentation(
'''
if _coconut.typing.TYPE_CHECKING:
{type_code}
else:
{code}
''',
add_newline=True,
).format(
type_code=type_code,
code="".join(out),
),
]

internal_assert(not decorators, "unhandled decorators", decorators)
return "".join(out)
Expand Down Expand Up @@ -2664,29 +2717,21 @@ def deferred_code_proc(self, inputstring, add_code_at_start=False, ignore_names=
func_id = int(assert_remove_prefix(line, funcwrapper))
original, loc, decorators, funcdef, is_async, in_method, is_stmt_lambda = self.get_ref("func", func_id)

# process inner code
# process inner code (we use tempsep to tell what was newly added before the funcdef)
decorators = self.deferred_code_proc(decorators, add_code_at_start=True, ignore_names=ignore_names, **kwargs)
funcdef = self.deferred_code_proc(funcdef, ignore_names=ignore_names, **kwargs)

# handle any non-function code that was added before the funcdef
pre_def_lines = []
post_def_lines = []
funcdef_lines = list(literal_lines(funcdef, True))
for i, line in enumerate(funcdef_lines):
line_indent, line_base = split_leading_indent(line)
if self.def_regex.match(line_base):
pre_def_lines = funcdef_lines[:i]
post_def_lines = funcdef_lines[i:]
break
internal_assert(post_def_lines, "no def statement found in funcdef", funcdef)

out.append(bef_ind)
out += pre_def_lines
func_indent, func_code, func_dedent = split_leading_trailing_indent("".join(post_def_lines), symmetric=True)
out.append(func_indent)
out.append(self.proc_funcdef(original, loc, decorators, func_code, is_async, in_method, is_stmt_lambda))
out.append(func_dedent)
out.append(aft_ind)
raw_funcdef = self.deferred_code_proc(tempsep + funcdef, ignore_names=ignore_names, **kwargs)

pre_funcdef, post_funcdef = raw_funcdef.split(tempsep)
func_indent, func_code, func_dedent = split_leading_trailing_indent(post_funcdef, symmetric=True)

out += [
bef_ind,
pre_funcdef,
func_indent,
self.proc_funcdef(original, loc, decorators, func_code, is_async, in_method, is_stmt_lambda),
func_dedent,
aft_ind,
]

# look for add_code_before regexes
else:
Expand Down Expand Up @@ -3490,7 +3535,7 @@ def __rmul__(self, other): return _coconut.NotImplemented
def __eq__(self, other):
return self.__class__ is other.__class__ and _coconut.tuple.__eq__(self, other)
def __hash__(self):
return _coconut.tuple.__hash__(self) ^ hash(self.__class__)
return _coconut.tuple.__hash__(self) ^ _coconut.hash(self.__class__)
""",
add_newline=True,
).format(
Expand Down Expand Up @@ -3824,45 +3869,70 @@ def op_match_funcdef_handle(self, original, loc, tokens):

def base_case_funcdef_handle(self, original, loc, tokens):
"""Process case def function definitions."""
if len(tokens) == 3:
name, typedef_grp, cases = tokens
if len(tokens) == 2:
name_toks, cases = tokens
docstring = None
elif len(tokens) == 4:
name, typedef_grp, docstring, cases = tokens
elif len(tokens) == 3:
name_toks, docstring, cases = tokens
else:
raise CoconutInternalException("invalid case function definition tokens", tokens)
if typedef_grp:
typedef, = typedef_grp

type_param_code = ""
if len(name_toks) == 1:
name, = name_toks
else:
typedef = None
name, paramdefs = name_toks
# paramdefs are type params on >= 3.12 and type var assignments on < 3.12
if self.target_info >= (3, 12):
name += "[" + ", ".join(paramdefs) + "]"
else:
type_param_code = "".join(paramdefs)

check_var = self.get_temp_var("match_check", loc)

all_case_code = []
all_type_defs = []
for case_toks in cases:
if len(case_toks) == 2:
matches, body = case_toks
cond = None
else:
matches, cond, body = case_toks
matcher = self.get_matcher(original, loc, check_var)
matcher.match_function_toks(matches, include_setup=False)
if cond is not None:
matcher.add_guard(cond)
all_case_code.append(handle_indentation("""
if "match" in case_toks:
if len(case_toks) == 2:
matches, body = case_toks
cond = None
else:
matches, cond, body = case_toks
matcher = self.get_matcher(original, loc, check_var)
matcher.match_function_toks(matches, include_setup=False)
if cond is not None:
matcher.add_guard(cond)
all_case_code.append(handle_indentation("""
if not {check_var}:
{match_to_kwargs_var} = {match_to_kwargs_var}_store.copy()
{match_out}
if {check_var}:
{body}
""").format(
check_var=check_var,
match_to_kwargs_var=match_to_kwargs_var,
match_out=matcher.out(),
body=body,
))
""").format(
check_var=check_var,
match_to_kwargs_var=match_to_kwargs_var,
match_out=matcher.out(),
body=body,
))
elif "type" in case_toks:
typed_params, typed_ret = case_toks
all_type_defs.append(handle_indentation("""
def {name}{typed_params}{typed_ret}
return {ellipsis}
""").format(
name=name,
typed_params=typed_params,
typed_ret=typed_ret,
ellipsis=self.any_type_ellipsis(),
))
else:
raise CoconutInternalException("invalid case_funcdef case_toks", case_toks)

if type_param_code and not all_type_defs:
raise CoconutDeferredSyntaxError("type parameters in case def but no type declaration cases", loc)

code = handle_indentation("""
func_code = handle_indentation("""
def {name}({match_func_paramdef}):
{docstring}
{check_var} = False
Expand All @@ -3880,17 +3950,11 @@ def {name}({match_func_paramdef}):
all_case_code="\n".join(all_case_code),
error=self.pattern_error(original, loc, match_to_args_var, check_var, function_match_error_var),
)
if typedef is None:
return code
else:
return handle_indentation("""
{typedef_stmt}
if not _coconut.typing.TYPE_CHECKING:
{code}
""").format(
code=code,
typedef_stmt=self.typed_assign_stmt_handle([name, typedef, self.any_type_ellipsis()]),
)

if not (type_param_code or all_type_defs):
return func_code

return "case " + self.add_ref("case_def", (type_param_code, all_type_defs)) + unwrapper + func_code

def set_literal_handle(self, tokens):
"""Converts set literals to the right form for the target Python."""
Expand Down
57 changes: 33 additions & 24 deletions coconut/compiler/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2251,7 +2251,7 @@ class Grammar(object):
with_stmt = Forward()

funcname_typeparams = Forward()
funcname_typeparams_ref = dotted_setname + Optional(type_params)
funcname_typeparams_tokens = dotted_setname + Optional(type_params)
name_funcdef = condense(funcname_typeparams + parameters)
op_tfpdef = unsafe_typedef_default | condense(setname + Optional(default))
op_funcdef_arg = setname | condense(lparen.suppress() + op_tfpdef + rparen.suppress())
Expand Down Expand Up @@ -2359,39 +2359,48 @@ class Grammar(object):
base_case_funcdef = Forward()
base_case_funcdef_ref = (
keyword("def").suppress()
+ funcname_typeparams
+ Group(funcname_typeparams_tokens)
+ colon.suppress()
- Group(Optional(typedef_test))
- newline.suppress()
- indent.suppress()
- Optional(docstring)
- Group(OneOrMore(Group(
keyword("match").suppress()
+ lparen.suppress()
+ match_args_list
+ rparen.suppress()
+ match_guard
+ (
colon.suppress()
+ (
newline.suppress()
+ indent.suppress()
+ attach(condense(OneOrMore(stmt)), make_suite_handle)
+ dedent.suppress()
| attach(simple_stmt, make_suite_handle)
)
| equals.suppress()
- Group(OneOrMore(
labeled_group(
keyword("match").suppress()
+ lparen.suppress()
+ match_args_list
+ match_guard
+ rparen.suppress()
+ (
(
colon.suppress()
+ (
newline.suppress()
+ indent.suppress()
+ attach(math_funcdef_body, make_suite_handle)
+ attach(condense(OneOrMore(stmt)), make_suite_handle)
+ dedent.suppress()
| attach(simple_stmt, make_suite_handle)
)
| attach(implicit_return_stmt, make_suite_handle)
)
| equals.suppress()
+ (
(
newline.suppress()
+ indent.suppress()
+ attach(math_funcdef_body, make_suite_handle)
+ dedent.suppress()
)
| attach(implicit_return_stmt, make_suite_handle)
)
),
"match",
)
)))
| labeled_group(
keyword("type").suppress()
+ parameters
+ return_typedef
+ newline.suppress(),
"type",
)
))
- dedent.suppress()
)
case_funcdef = keyword("case").suppress() + base_case_funcdef
Expand Down
2 changes: 1 addition & 1 deletion coconut/compiler/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,7 +1077,7 @@ def load_cache_for(inputstring, codepath):
incremental_info=incremental_info,
))
if incremental_enabled:
logger.warn("Populating initial parsing cache (compilation may take longer than usual)...")
logger.warn("Populating initial parsing cache (initial compilation may take a while; pass --no-cache to disable)...")
else:
cache_path = None
logger.log("Declined to load cache for {filename!r} ({incremental_info}).".format(
Expand Down
Loading

0 comments on commit 301eaf8

Please sign in to comment.