From fd327e4465b178924c387a5dccc6828b4aa76d74 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Mon, 27 Nov 2023 21:02:17 -0800 Subject: [PATCH 01/54] Reenable develop --- coconut/root.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coconut/root.py b/coconut/root.py index 2d622b4d8..e671b7e19 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = False +DEVELOP = 1 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" From ceeecfc0bfa213aa4002528b29be3380d244593a Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Tue, 28 Nov 2023 00:01:22 -0800 Subject: [PATCH 02/54] Make sure we always install the kernel --- coconut/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coconut/util.py b/coconut/util.py index b0e04be68..b8b2be601 100644 --- a/coconut/util.py +++ b/coconut/util.py @@ -374,7 +374,7 @@ def get_kernel_data_files(argv): elif any(arg.startswith("install") for arg in argv): executable = sys.executable else: - return [] + executable = "python" install_custom_kernel(executable) return [ ( From 88f571c0a1e8f2e92b0461dca2f40ebbc94978fd Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Tue, 28 Nov 2023 00:08:44 -0800 Subject: [PATCH 03/54] Further improve kernel install --- coconut/util.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/coconut/util.py b/coconut/util.py index b8b2be601..962c278d4 100644 --- a/coconut/util.py +++ b/coconut/util.py @@ -369,12 +369,10 @@ def get_displayable_target(target): def get_kernel_data_files(argv): """Given sys.argv, write the custom kernel file and return data_files.""" - if any(arg.startswith("bdist") for arg in argv): + if any(arg.startswith("bdist") or arg.startswith("sdist") for arg in argv): executable = "python" - elif any(arg.startswith("install") for arg in argv): - executable = sys.executable else: - executable = "python" + executable = sys.executable install_custom_kernel(executable) return [ ( From a920e7fd464897c5167e0464c43bd2061650e005 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Tue, 28 Nov 2023 02:12:11 -0800 Subject: [PATCH 04/54] Fix kernel parsing --- coconut/compiler/compiler.py | 3 +- coconut/compiler/util.py | 2 +- coconut/constants.py | 3 ++ coconut/icoconut/root.py | 62 +++++++++++++++++++++++++++++++++-- coconut/root.py | 2 +- coconut/tests/src/extras.coco | 12 ++++++- 6 files changed, 77 insertions(+), 7 deletions(-) diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 24307a965..8d2f38570 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -485,6 +485,7 @@ class Compiler(Grammar, pickleable_obj): def __init__(self, *args, **kwargs): """Creates a new compiler with the given parsing parameters.""" self.setup(*args, **kwargs) + self.reset() # changes here should be reflected in __reduce__, get_cli_args, and in the stub for coconut.api.setup def setup(self, target=None, strict=False, minify=False, line_numbers=True, keep_lines=False, no_tco=False, no_wrap=False): @@ -998,7 +999,7 @@ def remove_strs(self, inputstring, inner_environment=True, **kwargs): try: with (self.inner_environment() if inner_environment else noop_ctx()): return self.str_proc(inputstring, **kwargs) - except Exception: + except CoconutSyntaxError: logger.log_exc() return None diff --git a/coconut/compiler/util.py b/coconut/compiler/util.py 
index 64a0ff84f..760cf6bd1 100644 --- a/coconut/compiler/util.py +++ b/coconut/compiler/util.py @@ -1806,7 +1806,7 @@ def collapse_indents(indentation): def is_blank(line): """Determine whether a line is blank.""" line, _ = rem_and_count_indents(rem_comment(line)) - return line.strip() == "" + return not line or line.isspace() def final_indentation_level(code): diff --git a/coconut/constants.py b/coconut/constants.py index a6c276a8e..a5bd61d6d 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -1276,6 +1276,9 @@ def get_path_env_var(env_var, default): enabled_xonsh_modes = ("single",) +# 1 is safe, 2 seems to work okay, and 3 breaks stuff like '"""\n(\n)\n"""' +num_assemble_logical_lines_tries = 1 + # ----------------------------------------------------------------------------------------------------------------------- # DOCUMENTATION CONSTANTS: # ----------------------------------------------------------------------------------------------------------------------- diff --git a/coconut/icoconut/root.py b/coconut/icoconut/root.py index 0b0cb77f9..56e2ba161 100644 --- a/coconut/icoconut/root.py +++ b/coconut/icoconut/root.py @@ -43,15 +43,17 @@ conda_build_env_var, coconut_kernel_kwargs, default_whitespace_chars, + num_assemble_logical_lines_tries, ) from coconut.terminal import logger from coconut.util import override, memoize_with_exceptions, replace_all from coconut.compiler import Compiler -from coconut.compiler.util import should_indent +from coconut.compiler.util import should_indent, paren_change from coconut.command.util import Runner try: from IPython.core.inputsplitter import IPythonInputSplitter + from IPython.core.inputtransformer import CoroutineInputTransformer from IPython.core.interactiveshell import InteractiveShellABC from IPython.core.compilerop import CachingCompiler from IPython.terminal.embed import InteractiveShellEmbed @@ -154,8 +156,8 @@ class CoconutSplitter(IPythonInputSplitter, object): def __init__(self, *args, **kwargs): """Version of __init__ that sets up Coconut code compilation.""" super(CoconutSplitter, self).__init__(*args, **kwargs) - self._original_compile = self._compile - self._compile = self._coconut_compile + self._original_compile, self._compile = self._compile, self._coconut_compile + self.assemble_logical_lines = self._coconut_assemble_logical_lines() def _coconut_compile(self, source, *args, **kwargs): """Version of _compile that checks Coconut code. 
@@ -170,6 +172,60 @@ def _coconut_compile(self, source, *args, **kwargs): else: return True + @staticmethod + @CoroutineInputTransformer.wrap + def _coconut_assemble_logical_lines(): + """Version of assemble_logical_lines() that respects strings/parentheses/brackets/braces.""" + line = "" + while True: + line = (yield line) + if not line or line.isspace(): + continue + + parts = [] + level = 0 + while line is not None: + + # get no_strs_line + no_strs_line = None + while no_strs_line is None: + no_strs_line = line.strip() + if no_strs_line: + no_strs_line = COMPILER.remove_strs(no_strs_line) + if no_strs_line is None: + # if we're in the middle of a string, fetch a new line + for _ in range(num_assemble_logical_lines_tries): + new_line = (yield None) + if new_line is not None: + break + if new_line is None: + # if we're not able to build a no_strs_line, we should stop doing line joining + level = 0 + no_strs_line = "" + break + else: + line += new_line + + # update paren level + level += paren_change(no_strs_line) + + # put line in parts and break if done + if level < 0: + parts.append(line) + elif no_strs_line.endswith("\\"): + parts.append(line[:-1]) + else: + parts.append(line) + break + + # if we're not done, fetch a new line + for _ in range(num_assemble_logical_lines_tries): + line = (yield None) + if line is not None: + break + + line = ''.join(parts) + INTERACTIVE_SHELL_CODE = ''' input_splitter = CoconutSplitter(line_input_checker=True) input_transformer_manager = CoconutSplitter(line_input_checker=False) diff --git a/coconut/root.py b/coconut/root.py index e671b7e19..389b16740 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 1 +DEVELOP = 2 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index 0d13f39d3..ed97ddd48 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -462,7 +462,7 @@ def test_kernel() -> bool: captured_messages: list[tuple] = [] else: captured_messages: list = [] - def send(self, stream, msg_or_type, content, *args, **kwargs): + def send(self, stream, msg_or_type, content=None, *args, **kwargs): self.captured_messages.append((msg_or_type, content)) if PY35: @@ -515,6 +515,16 @@ def test_kernel() -> bool: assert keyword_complete_result["cursor_start"] == 0 assert keyword_complete_result["cursor_end"] == 1 + assert k.do_execute("ident$(\n?,\n)(99)", False, True, {}, True) |> unwrap_future$(loop) |> .["status"] == "ok" + captured_msg_type, captured_msg_content = fake_session.captured_messages[-1] + assert captured_msg_content is None + assert captured_msg_type["content"]["data"]["text/plain"] == "99" + + assert k.do_execute('"""\n(\n)\n"""', False, True, {}, True) |> unwrap_future$(loop) |> .["status"] == "ok" + captured_msg_type, captured_msg_content = fake_session.captured_messages[-1] + assert captured_msg_content is None + assert captured_msg_type["content"]["data"]["text/plain"] == "'()'" + return True From 5ab65910e28be7344c2c20e51c61501a465413bc Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Tue, 28 Nov 2023 15:25:31 -0800 Subject: [PATCH 05/54] Improve kernel --- DOCS.md | 2 +- coconut/constants.py | 2 ++ coconut/icoconut/root.py | 7 ++++--- coconut/tests/constants_test.py | 2 ++ coconut/tests/main_test.py | 17 ++++++++++++----- 
coconut/tests/src/extras.coco | 4 ++-- coconut/util.py | 3 ++- 7 files changed, 25 insertions(+), 12 deletions(-) diff --git a/DOCS.md b/DOCS.md index 1355ca8fb..b7a9fb561 100644 --- a/DOCS.md +++ b/DOCS.md @@ -4716,7 +4716,7 @@ Switches the [`breakpoint` built-in](https://www.python.org/dev/peps/pep-0553/) Both functions behave identically to [`setuptools.find_packages`](https://setuptools.pypa.io/en/latest/userguide/quickstart.html#package-discovery), except that they find Coconut packages rather than Python packages. `find_and_compile_packages` additionally compiles any Coconut packages that it finds in-place. -Note that if you want to use either of these functions in your `setup.py`, you'll need to include `coconut` as a [build-time dependency in your `pyproject.toml`](https://pip.pypa.io/en/stable/reference/build-system/pyproject-toml/#build-time-dependencies). If you want `setuptools` to package your Coconut files, you'll also need to add `global-include *.coco` to your [`MANIFEST.in`](https://setuptools.pypa.io/en/latest/userguide/miscellaneous.html). +Note that if you want to use either of these functions in your `setup.py`, you'll need to include `coconut` as a [build-time dependency in your `pyproject.toml`](https://pip.pypa.io/en/stable/reference/build-system/pyproject-toml/#build-time-dependencies). If you want `setuptools` to package your Coconut files, you'll also need to add `global-include *.coco` to your [`MANIFEST.in`](https://setuptools.pypa.io/en/latest/userguide/miscellaneous.html) and [pass `include_package_data=True` to `setuptools.setup`](https://setuptools.pypa.io/en/latest/userguide/datafiles.html). ##### Example diff --git a/coconut/constants.py b/coconut/constants.py index a5bd61d6d..7cde91999 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -1235,6 +1235,8 @@ def get_path_env_var(env_var, default): "coconut_pycon = coconut.highlighter:CoconutPythonConsoleLexer", ) +setuptools_distribution_names = ("bdist", "sdist") + requests_sleep_times = (0, 0.1, 0.2, 0.3, 0.4, 1) # ----------------------------------------------------------------------------------------------------------------------- diff --git a/coconut/icoconut/root.py b/coconut/icoconut/root.py index 56e2ba161..a5ba61726 100644 --- a/coconut/icoconut/root.py +++ b/coconut/icoconut/root.py @@ -110,6 +110,10 @@ def syntaxerr_memoized_parse_block(code): # KERNEL: # ----------------------------------------------------------------------------------------------------------------------- +if papermill_translators is not None: + papermill_translators.register("coconut", PythonTranslator) + + if LOAD_MODULE: COMPILER.warm_up(enable_incremental_mode=True) @@ -349,6 +353,3 @@ class CoconutKernelApp(IPKernelApp, object): classes = IPKernelApp.classes + [CoconutKernel, CoconutShell] kernel_class = CoconutKernel subcommands = {} - - if papermill_translators is not None: - papermill_translators.register("coconut", PythonTranslator) diff --git a/coconut/tests/constants_test.py b/coconut/tests/constants_test.py index eb3250b29..d60976c19 100644 --- a/coconut/tests/constants_test.py +++ b/coconut/tests/constants_test.py @@ -81,6 +81,7 @@ class TestConstants(unittest.TestCase): def test_defaults(self): assert constants.use_fast_pyparsing_reprs assert not constants.embed_on_internal_exc + assert constants.num_assemble_logical_lines_tries >= 1 def test_fixpath(self): assert os.path.basename(fixpath("CamelCase.py")) == "CamelCase.py" @@ -133,6 +134,7 @@ def test_targets(self): def test_tuples(self): assert 
isinstance(constants.indchars, tuple) assert isinstance(constants.comment_chars, tuple) + assert isinstance(constants.setuptools_distribution_names, tuple) # ----------------------------------------------------------------------------------------------------------------------- diff --git a/coconut/tests/main_test.py b/coconut/tests/main_test.py index 2d7bf296e..1bc893103 100644 --- a/coconut/tests/main_test.py +++ b/coconut/tests/main_test.py @@ -580,6 +580,15 @@ def using_env_vars(env_vars): os.environ.update(old_env) +def list_kernel_names(): + """Get a list of installed jupyter kernels.""" + stdout, stderr, retcode = call_output(["jupyter", "kernelspec", "list"]) + if not stdout: + stdout, stderr = stderr, "" + assert not retcode and not stderr, stderr + return stdout + + # ----------------------------------------------------------------------------------------------------------------------- # RUNNERS: # ----------------------------------------------------------------------------------------------------------------------- @@ -933,13 +942,11 @@ def test_ipython_extension(self): ) def test_kernel_installation(self): + assert icoconut_custom_kernel_name in list_kernel_names() call(["coconut", "--jupyter"], assert_output=kernel_installation_msg) - stdout, stderr, retcode = call_output(["jupyter", "kernelspec", "list"]) - if not stdout: - stdout, stderr = stderr, "" - assert not retcode and not stderr, stderr + kernels = list_kernel_names() for kernel in (icoconut_custom_kernel_name,) + icoconut_default_kernel_names: - assert kernel in stdout + assert kernel in kernels if not WINDOWS and not PYPY: def test_jupyter_console(self): diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index ed97ddd48..f97c94d3a 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -235,7 +235,7 @@ def f() = assert_raises(-> parse("return = 1"), CoconutParseError, err_has='invalid use of the keyword "return"') assert_raises(-> parse("if a = b: pass"), CoconutParseError, err_has="misplaced assignment") assert_raises(-> parse("while a == b"), CoconutParseError, err_has="misplaced newline") - assert_raises(-> parse("0xfgf"), CoconutParseError, err_has=" \~~^") + assert_raises(-> parse("0xfgf"), CoconutParseError, err_has=r" \~~^") try: parse(""" @@ -257,7 +257,7 @@ def gam_eps_rate(bitarr) = ( err_str = str(err) assert "misplaced '?'" in err_str if not PYPY: - assert """ + assert r""" |> map$(int(?, 2)) \~~~~^""" in err_str or """ |> map$(int(?, 2)) diff --git a/coconut/util.py b/coconut/util.py index 962c278d4..1e9b91bb1 100644 --- a/coconut/util.py +++ b/coconut/util.py @@ -49,6 +49,7 @@ icoconut_custom_kernel_file_loc, WINDOWS, non_syntactic_newline, + setuptools_distribution_names, ) @@ -369,7 +370,7 @@ def get_displayable_target(target): def get_kernel_data_files(argv): """Given sys.argv, write the custom kernel file and return data_files.""" - if any(arg.startswith("bdist") or arg.startswith("sdist") for arg in argv): + if any(arg.startswith(setuptools_distribution_names) for arg in argv): executable = "python" else: executable = sys.executable From e2befe661eafd9c66afad84e0737058175b74e5c Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Tue, 28 Nov 2023 19:00:30 -0800 Subject: [PATCH 06/54] Add multidim arr concat op funcs Refs #809. 
--- DOCS.md | 8 ++++- __coconut__/__init__.pyi | 31 +++++++++---------- coconut/__coconut__.pyi | 2 +- coconut/compiler/grammar.py | 18 +++++++---- coconut/compiler/header.py | 2 +- coconut/compiler/templates/header.py_template | 3 +- coconut/constants.py | 5 ++- coconut/icoconut/root.py | 7 ++--- coconut/root.py | 2 +- .../src/cocotest/agnostic/primary_1.coco | 6 ++-- .../src/cocotest/agnostic/primary_2.coco | 5 +-- coconut/tests/src/extras.coco | 5 ++- 12 files changed, 55 insertions(+), 39 deletions(-) diff --git a/DOCS.md b/DOCS.md index b7a9fb561..ee1b8615e 100644 --- a/DOCS.md +++ b/DOCS.md @@ -480,7 +480,7 @@ To allow for better use of [`numpy`](https://numpy.org/) objects in Coconut, all - Coconut's [multidimensional array literal and array concatenation syntax](#multidimensional-array-literalconcatenation-syntax) supports `numpy` objects, including using fast `numpy` concatenation methods if given `numpy` arrays rather than Coconut's default much slower implementation built for Python lists of lists. - Many of Coconut's built-ins include special `numpy` support, specifically: * [`fmap`](#fmap) will use [`numpy.vectorize`](https://numpy.org/doc/stable/reference/generated/numpy.vectorize.html) to map over `numpy` arrays. - * [`multi_enumerate`](#multi_enumerate) allows for easily looping over all the multi-dimensional indices in a `numpy` array. + * [`multi_enumerate`](#multi_enumerate) allows for easily looping over all the multidimensional indices in a `numpy` array. * [`cartesian_product`](#cartesian_product) can compute the Cartesian product of given `numpy` arrays as a `numpy` array. * [`all_equal`](#all_equal) allows for easily checking if all the elements in a `numpy` array are the same. - [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html) is registered as a [`collections.abc.Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence), enabling it to be used in [sequence patterns](#semantics-specification). @@ -1822,6 +1822,10 @@ A very common thing to do in functional programming is to make use of function v (not in) => # negative containment (assert) => def (cond, msg=None) => assert cond, msg # (but a better msg if msg is None) (raise) => def (exc=None, from_exc=None) => raise exc from from_exc # or just raise if exc is None +# operator functions for multidimensional array concatenation use brackets: +[;] => def (x, y) => [x; y] +[;;] => def (x, y) => [x;; y] +... # and so on for any number of semicolons # there are two operator functions that don't require parentheses: .[] => (operator.getitem) .$[] => # iterator slicing operator @@ -2067,6 +2071,8 @@ If multiple different concatenation operators are used, the operators with the l [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] ``` +_Note: the [operator functions](#operator-functions) for multidimensional array concatenation are spelled `[;]`, `[;;]`, etc. (for any number of parentheses)._ + ##### Comparison to Julia Coconut's multidimensional array syntax is based on that of [Julia](https://docs.julialang.org/en/v1/manual/arrays/#man-array-literals). The primary difference between Coconut's syntax and Julia's syntax is that multidimensional arrays are row-first in Coconut (following `numpy`), but column-first in Julia. Thus, `;` is vertical concatenation in Julia but **horizontal concatenation** in Coconut and `;;` is horizontal concatenation in Julia but **vertical concatenation** in Coconut. 
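For reference, the behavior documented above is exercised by the test assertions this patch adds to `primary_1.coco` and `primary_2.coco` below; a minimal sketch adapted from those assertions, showing the new bracketed operator functions in action:

```coconut
a = [1, 2 ;; 3, 4]  # the 2x2 array [[1, 2], [3, 4]]

# [;] is horizontal concatenation as a function; [;;] is vertical concatenation
assert [;](a, a) == [a; a] == [[1, 2, 1, 2], [3, 4, 3, 4]]
assert [;;](a, a) == [a ;; a] == [[1, 2], [3, 4], [1, 2], [3, 4]]
```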
diff --git a/__coconut__/__init__.pyi b/__coconut__/__init__.pyi index 70a0646f5..42af159ec 100644 --- a/__coconut__/__init__.pyi +++ b/__coconut__/__init__.pyi @@ -1828,45 +1828,44 @@ def _coconut_mk_anon_namedtuple( # @_t.overload -# def _coconut_multi_dim_arr( -# arrs: _t.Tuple[_coconut.npt.NDArray[_DType], ...], +# def _coconut_arr_concat_op( # dim: int, +# *arrs: _coconut.npt.NDArray[_DType], # ) -> _coconut.npt.NDArray[_DType]: ... # @_t.overload -# def _coconut_multi_dim_arr( -# arrs: _t.Tuple[_DType, ...], +# def _coconut_arr_concat_op( # dim: int, +# *arrs: _DType, # ) -> _coconut.npt.NDArray[_DType]: ... - @_t.overload -def _coconut_multi_dim_arr( - arrs: _t.Tuple[_t.Sequence[_T], ...], +def _coconut_arr_concat_op( dim: _t.Literal[1], + *arrs: _t.Sequence[_T], ) -> _t.Sequence[_T]: ... @_t.overload -def _coconut_multi_dim_arr( - arrs: _t.Tuple[_T, ...], +def _coconut_arr_concat_op( dim: _t.Literal[1], + *arrs: _T, ) -> _t.Sequence[_T]: ... @_t.overload -def _coconut_multi_dim_arr( - arrs: _t.Tuple[_t.Sequence[_t.Sequence[_T]], ...], +def _coconut_arr_concat_op( dim: _t.Literal[2], + *arrs: _t.Sequence[_t.Sequence[_T]], ) -> _t.Sequence[_t.Sequence[_T]]: ... @_t.overload -def _coconut_multi_dim_arr( - arrs: _t.Tuple[_t.Sequence[_T], ...], +def _coconut_arr_concat_op( dim: _t.Literal[2], + *arrs: _t.Sequence[_T], ) -> _t.Sequence[_t.Sequence[_T]]: ... @_t.overload -def _coconut_multi_dim_arr( - arrs: _t.Tuple[_T, ...], +def _coconut_arr_concat_op( dim: _t.Literal[2], + *arrs: _T, ) -> _t.Sequence[_t.Sequence[_T]]: ... @_t.overload -def _coconut_multi_dim_arr(arrs: _Tuple, dim: int) -> _Sequence: ... +def _coconut_arr_concat_op(dim: int, *arrs: _t.Any) -> _Sequence: ... class _coconut_SupportsAdd(_t.Protocol, _t.Generic[_Tco, _Ucontra, _Vco]): diff --git a/coconut/__coconut__.pyi b/coconut/__coconut__.pyi index e56d0e55e..cca933f3f 100644 --- a/coconut/__coconut__.pyi +++ b/coconut/__coconut__.pyi @@ -1,2 +1,2 @@ from __coconut__ import * -from __coconut__ import _coconut_tail_call, _coconut_tco, _coconut_call_set_names, _coconut_handle_cls_kwargs, _coconut_handle_cls_stargs, _namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, _coconut_SupportsAdd, _coconut_SupportsMinus, _coconut_SupportsMul, _coconut_SupportsPow, _coconut_SupportsTruediv, _coconut_SupportsFloordiv, _coconut_SupportsMod, _coconut_SupportsAnd, _coconut_SupportsXor, _coconut_SupportsOr, _coconut_SupportsLshift, _coconut_SupportsRshift, _coconut_SupportsMatmul, _coconut_SupportsInv, _coconut_Expected, _coconut_MatchError, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_complex_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_multi_dim_arr, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, 
_coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in, _coconut_attritemgetter +from __coconut__ import _coconut_tail_call, _coconut_tco, _coconut_call_set_names, _coconut_handle_cls_kwargs, _coconut_handle_cls_stargs, _namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, _coconut_SupportsAdd, _coconut_SupportsMinus, _coconut_SupportsMul, _coconut_SupportsPow, _coconut_SupportsTruediv, _coconut_SupportsFloordiv, _coconut_SupportsMod, _coconut_SupportsAnd, _coconut_SupportsXor, _coconut_SupportsOr, _coconut_SupportsLshift, _coconut_SupportsRshift, _coconut_SupportsMatmul, _coconut_SupportsInv, _coconut_Expected, _coconut_MatchError, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_complex_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_arr_concat_op, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in, _coconut_attritemgetter diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index 7eaed6226..e8d9a2e43 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -109,7 +109,6 @@ labeled_group, any_keyword_in, any_char, - tuple_str_of, any_len_perm, any_len_perm_at_least_one, boundary, @@ -576,20 +575,27 @@ def array_literal_handle(loc, tokens): array_elems = [] for p in pieces: if p: - if len(p) > 1: + if p[0].lstrip(";") == "": + raise CoconutDeferredSyntaxError("invalid initial multidimensional array separator or broken-up multidimensional array concatenation operator function", loc) + elif len(p) > 1: internal_assert(sep_level > 1, "failed to handle array literal tokens", tokens) subarr_item = array_literal_handle(loc, p) - elif p[0].lstrip(";") == "": - raise CoconutDeferredSyntaxError("naked multidimensional array separators are not allowed", loc) else: subarr_item = p[0] array_elems.append(subarr_item) + # if multidimensional array literal is only separators, compile to implicit partial if not array_elems: - raise CoconutDeferredSyntaxError("multidimensional array literal cannot be only separators", loc) + if len(pieces) > 2: + raise CoconutDeferredSyntaxError("invalid empty multidimensional array literal or broken-up multidimensional array concatenation operator function", loc) + return 
"_coconut_partial(_coconut_arr_concat_op, " + str(sep_level) + ")" + + # check for initial top-level separators + if not pieces[0]: + raise CoconutDeferredSyntaxError("invalid initial multidimensional array separator", loc) # build multidimensional array - return "_coconut_multi_dim_arr(" + tuple_str_of(array_elems) + ", " + str(sep_level) + ")" + return "_coconut_arr_concat_op(" + str(sep_level) + ", " + ", ".join(array_elems) + ")" def typedef_op_item_handle(loc, tokens): diff --git a/coconut/compiler/header.py b/coconut/compiler/header.py index 8a60ff8cc..17e1b1041 100644 --- a/coconut/compiler/header.py +++ b/coconut/compiler/header.py @@ -638,7 +638,7 @@ def __anext__(self): # (extra_format_dict is to keep indentation levels matching) extra_format_dict = dict( # when anything is added to this list it must also be added to *both* __coconut__ stub files - underscore_imports="{tco_comma}{call_set_names_comma}{handle_cls_args_comma}_namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, _coconut_SupportsAdd, _coconut_SupportsMinus, _coconut_SupportsMul, _coconut_SupportsPow, _coconut_SupportsTruediv, _coconut_SupportsFloordiv, _coconut_SupportsMod, _coconut_SupportsAnd, _coconut_SupportsXor, _coconut_SupportsOr, _coconut_SupportsLshift, _coconut_SupportsRshift, _coconut_SupportsMatmul, _coconut_SupportsInv, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_complex_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_multi_dim_arr, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in, _coconut_attritemgetter".format(**format_dict), + underscore_imports="{tco_comma}{call_set_names_comma}{handle_cls_args_comma}_namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, _coconut_SupportsAdd, _coconut_SupportsMinus, _coconut_SupportsMul, _coconut_SupportsPow, _coconut_SupportsTruediv, _coconut_SupportsFloordiv, _coconut_SupportsMod, _coconut_SupportsAnd, _coconut_SupportsXor, _coconut_SupportsOr, _coconut_SupportsLshift, _coconut_SupportsRshift, _coconut_SupportsMatmul, _coconut_SupportsInv, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, 
_coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_complex_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_arr_concat_op, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in, _coconut_attritemgetter".format(**format_dict), import_typing=pycondition( (3, 5), if_ge=''' diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index e7ec5f6f1..2f401ad5a 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -2036,7 +2036,8 @@ def _coconut_concatenate(arrs, axis): if not axis: return _coconut.list(_coconut.itertools.chain.from_iterable(arrs)) return [_coconut_concatenate(rows, axis - 1) for rows in _coconut.zip(*arrs)] -def _coconut_multi_dim_arr(arrs, dim): +def _coconut_arr_concat_op(dim, *arrs): + """Coconut multi-dimensional array concatenation operator.""" arr_dims = [_coconut_ndim(a) for a in arrs] arrs = [_coconut_expand_arr(a, dim - d) if d < dim else a for a, d in _coconut.zip(arrs, arr_dims)] arr_dims.append(dim) diff --git a/coconut/constants.py b/coconut/constants.py index 7cde91999..0c648eb51 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -1269,8 +1269,11 @@ def get_path_env_var(env_var, default): "coconut3", ) -py_syntax_version = 3 mimetype = "text/x-python3" +codemirror_mode = { + "name": "ipython", + "version": 3, +} all_keywords = keyword_vars + const_vars + reserved_vars diff --git a/coconut/icoconut/root.py b/coconut/icoconut/root.py index a5ba61726..7fd1d968c 100644 --- a/coconut/icoconut/root.py +++ b/coconut/icoconut/root.py @@ -34,7 +34,7 @@ ) from coconut.constants import ( PY311, - py_syntax_version, + codemirror_mode, mimetype, version_banner, tutorial_url, @@ -307,10 +307,7 @@ class CoconutKernel(IPythonKernel, object): "name": "coconut", "version": VERSION, "mimetype": mimetype, - "codemirror_mode": { - "name": "ipython", - "version": py_syntax_version, - }, + "codemirror_mode": codemirror_mode, "pygments_lexer": "coconut", "file_extension": code_exts[0], } diff --git a/coconut/root.py b/coconut/root.py index 389b16740..a89c05c85 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 2 +DEVELOP = 3 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/src/cocotest/agnostic/primary_1.coco b/coconut/tests/src/cocotest/agnostic/primary_1.coco index bc85179c7..b8e9a44d5 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_1.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_1.coco @@ -973,8 +973,8 @@ def primary_test_1() -> bool: def ret1() = 1 assert ret1() == 1 assert (.,2)(1) == (1, 2) == (1,.)(2) - assert 
[[];] == [] - assert [[];;] == [[]] + assert [[];] == [] == [;]([]) + assert [[];;] == [[]] == [;;]([]) assert [1;] == [1] == [[1];] assert [1;;] == [[1]] == [[1];;] assert [[[1]];;] == [[1]] == [[1;];;] @@ -1009,7 +1009,7 @@ def primary_test_1() -> bool: 5, 6 ;; 7, 8] == [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] a = [1,2 ;; 3,4] - assert [a; a] == [[1,2,1,2], [3,4,3,4]] + assert [a; a] == [[1,2,1,2], [3,4,3,4]] == [;](a, a) assert [a;; a] == [[1,2],[3,4],[1,2],[3,4]] == [*a, *a] assert [a ;;; a] == [[[1,2],[3,4]], [[1,2],[3,4]]] == [a, a] assert [a ;;;; a] == [[a], [a]] diff --git a/coconut/tests/src/cocotest/agnostic/primary_2.coco b/coconut/tests/src/cocotest/agnostic/primary_2.coco index b4e55fb2e..a7f7560df 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_2.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_2.coco @@ -303,8 +303,8 @@ def primary_test_2() -> bool: a_dict = {"a": 1, "b": 2} a_dict |= {"a": 10, "c": 20} assert a_dict == {"a": 10, "b": 2, "c": 20} == {"a": 1, "b": 2} | {"a": 10, "c": 20} - assert ["abc" ; "def"] == ['abc', 'def'] - assert ["abc" ;; "def"] == [['abc'], ['def']] + assert ["abc" ; "def"] == ['abc', 'def'] == [;] <*| ("abc", "def") + assert ["abc" ;; "def"] == [['abc'], ['def']] == [;;] <*| ("abc", "def") assert {"a":0, "b":1}$[0] == "a" assert (|0, NotImplemented, 2|)$[1] is NotImplemented assert m{1, 1, 2} |> fmap$(.+1) == m{2, 2, 3} @@ -410,6 +410,7 @@ def primary_test_2() -> bool: assert 0x == 0 == 0 x assert 0xff == 255 == 0x100-1 assert 11259375 == 0xabcdef + assert [[] ;; [] ;;;] == [[[], []]] with process_map.multiple_sequential_calls(): # type: ignore assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index f97c94d3a..407094fe9 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -127,8 +127,11 @@ def test_setup_none() -> bool: assert_raises(-> parse("\\("), CoconutSyntaxError) assert_raises(-> parse("if a:\n b\n c"), CoconutSyntaxError) assert_raises(-> parse("_coconut"), CoconutSyntaxError) - assert_raises(-> parse("[;]"), CoconutSyntaxError) + assert_raises(-> parse("[; ;]"), CoconutSyntaxError) assert_raises(-> parse("[; ;; ;]"), CoconutSyntaxError) + assert_raises(-> parse("[; ; ;;]"), CoconutSyntaxError) + assert_raises(-> parse("[[] ;;; ;; [] ;]"), CoconutSyntaxError) + assert_raises(-> parse("[; []]"), CoconutSyntaxError) assert_raises(-> parse("f$()"), CoconutSyntaxError) assert_raises(-> parse("f(**x, y)"), CoconutSyntaxError) assert_raises(-> parse("def f(x) = return x"), CoconutSyntaxError) From 315b2313356321d5704a4784f42e44809f690963 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Tue, 28 Nov 2023 20:21:20 -0800 Subject: [PATCH 07/54] Add multidim arr concat impl partials Resolves #809. --- DOCS.md | 33 ++++++++++++------- coconut/compiler/compiler.py | 22 +++++++++++++ coconut/compiler/grammar.py | 29 ++++++++++++++-- coconut/root.py | 2 +- .../src/cocotest/agnostic/primary_2.coco | 9 +++++ 5 files changed, 80 insertions(+), 15 deletions(-) diff --git a/DOCS.md b/DOCS.md index ee1b8615e..b4a52b807 100644 --- a/DOCS.md +++ b/DOCS.md @@ -1853,25 +1853,34 @@ print(list(map(operator.add, range(0, 5), range(5, 10)))) Coconut supports a number of different syntactical aliases for common partial application use cases. 
These are: ```coconut -.attr => operator.attrgetter("attr") -.method(args) => operator.methodcaller("method", args) -func$ => ($)$(func) -seq[] => operator.getitem$(seq) -iter$[] => # the equivalent of seq[] for iterators -.[a:b:c] => operator.itemgetter(slice(a, b, c)) -.$[a:b:c] => # the equivalent of .[a:b:c] for iterators -``` +# attribute access and method calling +.attr1.attr2 => operator.attrgetter("attr1.attr2") +.method(args) => operator.methodcaller("method", args) +.attr.method(args) => .attr ..> .method(args) + +# indexing +.[a:b:c] => operator.itemgetter(slice(a, b, c)) +.[x][y] => .[x] ..> .[y] +.method[x] => .method ..> .[x] +seq[] => operator.getitem$(seq) -Additionally, `.attr.method(args)`, `.[x][y]`, `.$[x]$[y]`, and `.method[x]` are also supported. +# iterator indexing +.$[a:b:c] => # the equivalent of .[a:b:c] for iterators +.$[x]$[y] => .$[x] ..> .$[y] +iter$[] => # the equivalent of seq[] for iterators + +# currying +func$ => ($)$(func) +``` In addition, for every Coconut [operator function](#operator-functions), Coconut supports syntax for implicitly partially applying that operator function as ``` (. ) ( .) ``` -where `` is the operator function and `` is any expression. Note that, as with operator functions themselves, the parentheses are necessary for this type of implicit partial application. +where `` is the operator function and `` is any expression. Note that, as with operator functions themselves, the parentheses are necessary for this type of implicit partial application. This syntax is slightly different for multidimensional array concatenation operator functions, which use brackets instead of parentheses. -Additionally, Coconut also supports implicit operator function partials for arbitrary functions as +Furthermore, Coconut also supports implicit operator function partials for arbitrary functions as ``` (. `` ) ( `` .) @@ -2071,7 +2080,7 @@ If multiple different concatenation operators are used, the operators with the l [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] ``` -_Note: the [operator functions](#operator-functions) for multidimensional array concatenation are spelled `[;]`, `[;;]`, etc. (for any number of parentheses)._ +_Note: the [operator functions](#operator-functions) for multidimensional array concatenation are spelled `[;]`, `[;;]`, etc. (with any number of parentheses). The [implicit partials](#implicit-partial-application) are similarly spelled `[. 
; x]`, `[x ; .]`, etc._ ##### Comparison to Julia diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 8d2f38570..3bf9fbf3a 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -130,6 +130,7 @@ attrgetter_atom_handle, itemgetter_handle, partial_op_item_handle, + partial_arr_concat_handle, ) from coconut.compiler.util import ( ExceptionNode, @@ -2763,6 +2764,7 @@ def pipe_item_split(self, tokens, loc): - (name, args) for attr/method - (attr, [(op, args)]) for itemgetter - (op, arg) for right op partial + - (op, arg) for right arr concat partial """ # list implies artificial tokens, which must be expr if isinstance(tokens, list) or "expr" in tokens: @@ -2792,6 +2794,18 @@ def pipe_item_split(self, tokens, loc): return "right op partial", (op, arg) else: raise CoconutInternalException("invalid op partial tokens in pipe_item", inner_toks) + elif "arr concat partial" in tokens: + inner_toks, = tokens + if "left arr concat partial" in inner_toks: + arg, op = inner_toks + internal_assert(op.lstrip(";") == "", "invalid arr concat op", op) + return "partial", ("_coconut_arr_concat_op", str(len(op)) + ", " + arg, "") + elif "right arr concat partial" in inner_toks: + op, arg = inner_toks + internal_assert(op.lstrip(";") == "", "invalid arr concat op", op) + return "right arr concat partial", (op, arg) + else: + raise CoconutInternalException("invalid arr concat partial tokens in pipe_item", inner_toks) elif "await" in tokens: internal_assert(len(tokens) == 1 and tokens[0] == "await", "invalid await pipe item tokens", tokens) return "await", [] @@ -2821,6 +2835,8 @@ def pipe_handle(self, original, loc, tokens, **kwargs): return itemgetter_handle(item) elif name == "right op partial": return partial_op_item_handle(item) + elif name == "right arr concat partial": + return partial_arr_concat_handle(item) elif name == "await": raise CoconutDeferredSyntaxError("await in pipe must have something piped into it", loc) else: @@ -2889,6 +2905,12 @@ def pipe_handle(self, original, loc, tokens, **kwargs): raise CoconutDeferredSyntaxError("cannot star pipe into operator partial", loc) op, arg = split_item return "({op})({x}, {arg})".format(op=op, x=subexpr, arg=arg) + elif name == "right arr concat partial": + if stars: + raise CoconutDeferredSyntaxError("cannot star pipe into array concatenation operator partial", loc) + op, arg = split_item + internal_assert(op.lstrip(";") == "", "invalid arr concat op", op) + return "_coconut_arr_concat_op({dim}, {x}, {arg})".format(dim=len(op), x=subexpr, arg=arg) elif name == "await": internal_assert(not split_item, "invalid split await pipe item tokens", split_item) if stars: diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index e8d9a2e43..b17e856a9 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -550,6 +550,21 @@ def partial_op_item_handle(tokens): raise CoconutInternalException("invalid operator function implicit partial token group", tok_grp) +def partial_arr_concat_handle(tokens): + """Handle array concatenation operator function implicit partials.""" + tok_grp, = tokens + if "left arr concat partial" in tok_grp: + arg, op = tok_grp + internal_assert(op.lstrip(";") == "", "invalid arr concat op", op) + return "_coconut_partial(_coconut_arr_concat_op, " + str(len(op)) + ", " + arg + ")" + elif "right arr concat partial" in tok_grp: + op, arg = tok_grp + internal_assert(op.lstrip(";") == "", "invalid arr concat op", op) + return 
"_coconut_complex_partial(_coconut_arr_concat_op, {{0: {dim}, 2: {arg}}}, 3, ())".format(dim=len(op), arg=arg) + else: + raise CoconutInternalException("invalid array concatenation operator function implicit partial token group", tok_grp) + + def array_literal_handle(loc, tokens): """Handle multidimensional array literals.""" internal_assert(len(tokens) >= 1, "invalid array literal tokens", tokens) @@ -1071,7 +1086,7 @@ class Grammar(object): ) partial_op_item = attach(partial_op_item_tokens, partial_op_item_handle) op_item = ( - # partial_op_item must come first, then typedef_op_item must come after base_op_item + # must stay in exactly this order partial_op_item | typedef_op_item | base_op_item @@ -1079,6 +1094,12 @@ class Grammar(object): partial_op_atom_tokens = lparen.suppress() + partial_op_item_tokens + rparen.suppress() + partial_arr_concat_tokens = lbrack.suppress() + ( + labeled_group(dot.suppress() + multisemicolon + test_no_infix + rbrack.suppress(), "right arr concat partial") + | labeled_group(test_no_infix + multisemicolon + dot.suppress() + rbrack.suppress(), "left arr concat partial") + ) + partial_arr_concat = attach(partial_arr_concat_tokens, partial_arr_concat_handle) + # we include (var)arg_comma to ensure the pattern matches the whole arg arg_comma = comma | fixto(FollowedBy(rparen), "") setarg_comma = arg_comma | fixto(FollowedBy(colon), "") @@ -1234,7 +1255,8 @@ class Grammar(object): list_item = ( lbrack.suppress() + list_expr + rbrack.suppress() | condense(lbrack + Optional(comprehension_expr) + rbrack) - # array_literal must come last + # partial_arr_concat and array_literal must come last + | partial_arr_concat | array_literal ) @@ -1544,6 +1566,7 @@ class Grammar(object): | labeled_group(itemgetter_atom_tokens, "itemgetter") + pipe_op | labeled_group(attrgetter_atom_tokens, "attrgetter") + pipe_op | labeled_group(partial_op_atom_tokens, "op partial") + pipe_op + | labeled_group(partial_arr_concat_tokens, "arr concat partial") + pipe_op # expr must come at end | labeled_group(comp_pipe_expr, "expr") + pipe_op ) @@ -1554,6 +1577,7 @@ class Grammar(object): | labeled_group(itemgetter_atom_tokens, "itemgetter") + end_simple_stmt_item | labeled_group(attrgetter_atom_tokens, "attrgetter") + end_simple_stmt_item | labeled_group(partial_op_atom_tokens, "op partial") + end_simple_stmt_item + | labeled_group(partial_arr_concat_tokens, "arr concat partial") + end_simple_stmt_item ) last_pipe_item = Group( lambdef("expr") @@ -1564,6 +1588,7 @@ class Grammar(object): attrgetter_atom_tokens("attrgetter"), partial_atom_tokens("partial"), partial_op_atom_tokens("op partial"), + partial_arr_concat_tokens("arr concat partial"), comp_pipe_expr("expr"), ) ) diff --git a/coconut/root.py b/coconut/root.py index a89c05c85..379efe1f8 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 3 +DEVELOP = 4 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/src/cocotest/agnostic/primary_2.coco b/coconut/tests/src/cocotest/agnostic/primary_2.coco index a7f7560df..b8b3afdd6 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_2.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_2.coco @@ -411,6 +411,15 @@ def primary_test_2() -> bool: assert 0xff == 255 == 0x100-1 assert 11259375 == 0xabcdef assert [[] ;; [] ;;;] == [[[], []]] + assert ( + 1 + |> [. 
; 2] + |> [[3; 4] ;; .] + ) == [3; 4;; 1; 2] == [[3; 4] ;; .]([. ; 2](1)) + arr: Any = 1 + arr |>= [. ; 2] + arr |>= [[3; 4] ;; .] + assert arr == [3; 4;; 1; 2] == [[3; 4] ;; .] |> call$(?, [. ; 2] |> call$(?, 1)) with process_map.multiple_sequential_calls(): # type: ignore assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore From 9df70ae19406fc91b4596d358a556d688e011269 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Tue, 28 Nov 2023 23:43:34 -0800 Subject: [PATCH 08/54] Enable --run for dirs Resolves #799. --- DOCS.md | 10 +- coconut/command/command.py | 241 ++++++++++-------- coconut/root.py | 2 +- coconut/tests/main_test.py | 30 ++- .../tests/src/cocotest/agnostic/__main__.coco | 20 ++ coconut/tests/src/runner.coco | 13 +- 6 files changed, 195 insertions(+), 121 deletions(-) create mode 100644 coconut/tests/src/cocotest/agnostic/__main__.coco diff --git a/DOCS.md b/DOCS.md index b4a52b807..a5f21dd70 100644 --- a/DOCS.md +++ b/DOCS.md @@ -225,22 +225,22 @@ as an alias for ``` coconut --quiet --target sys --keep-lines --run --argv ``` -which will quietly compile and run ``, passing any additional arguments to the script, mimicking how the `python` command works. +which will quietly compile and run ``, passing any additional arguments to the script, mimicking how the `python` command works. To instead pass additional compilation arguments to Coconut itself (e.g. `--no-tco`), put them before the `` file. + +`coconut-run` can be used to compile and run directories rather than files, again mimicking how the `python` command works. Specifically, Coconut will compile the directory and then run the `__main__.coco` in that directory, which must exist. `coconut-run` can be used in a Unix shebang line to create a Coconut script by adding the following line to the start of your script: ```bash #!/usr/bin/env coconut-run ``` -To pass additional compilation arguments to `coconut-run` (e.g. `--no-tco`), put them before the `` file. - `coconut-run` will always enable [automatic compilation](#automatic-compilation), such that Coconut source files can be directly imported from any Coconut files run via `coconut-run`. Additionally, compilation parameters (e.g. `--no-tco`) used in `coconut-run` will be passed along and used for any auto compilation. On Python 3.4+, `coconut-run` will use a `__coconut_cache__` directory to cache the compiled Python. Note that `__coconut_cache__` will always be removed from `__file__`. #### Naming Source Files -Coconut source files should, so the compiler can recognize them, use the extension `.coco` (preferred), `.coc`, or `.coconut`. +Coconut source files should, so the compiler can recognize them, use the extension `.coco`. When Coconut compiles a `.coco` file, it will compile to another file with the same name, except with `.py` instead of `.coco`, which will hold the compiled code. @@ -248,7 +248,7 @@ If an extension other than `.py` is desired for the compiled files, then that ex #### Compilation Modes -Files compiled by the `coconut` command-line utility will vary based on compilation parameters. If an entire directory of files is compiled (which the compiler will search recursively for any folders containing `.coco`, `.coc`, or `.coconut` files), a `__coconut__.py` file will be created to house necessary functions (package mode), whereas if only a single file is compiled, that information will be stored within a header inside the file (standalone mode). 
Standalone mode is better for single files because it gets rid of the overhead involved in importing `__coconut__.py`, but package mode is better for large packages because it gets rid of the need to run the same Coconut header code again in every file, since it can just be imported from `__coconut__.py`. +Files compiled by the `coconut` command-line utility will vary based on compilation parameters. If an entire directory of files is compiled (which the compiler will search recursively for any folders containing `.coco` files), a `__coconut__.py` file will be created to house necessary functions (package mode), whereas if only a single file is compiled, that information will be stored within a header inside the file (standalone mode). Standalone mode is better for single files because it gets rid of the overhead involved in importing `__coconut__.py`, but package mode is better for large packages because it gets rid of the need to run the same Coconut header code again in every file, since it can just be imported from `__coconut__.py`. By default, if the `source` argument to the command-line utility is a file, it will perform standalone compilation on it, whereas if it is a directory, it will recursively search for all `.coco` files and perform package compilation on them. Thus, in most cases, the mode chosen by Coconut automatically will be the right one. But if it is very important that no additional files like `__coconut__.py` be created, for example, then the command-line utility can also be forced to use a specific mode with the `--package` (`-p`) and `--standalone` (`-a`) flags. diff --git a/coconut/command/command.py b/coconut/command/command.py index 95e21d0da..c842dcd75 100644 --- a/coconut/command/command.py +++ b/coconut/command/command.py @@ -253,15 +253,27 @@ def execute_args(self, args, interact=True, original_args=None): logger.log("Directly passed args:", original_args) logger.log("Parsed args:", args) - # validate general command args + # validate args and show warnings if args.stack_size and args.stack_size % 4 != 0: logger.warn("--stack-size should generally be a multiple of 4, not {stack_size} (to support 4 KB pages)".format(stack_size=args.stack_size)) if args.mypy is not None and args.no_line_numbers: logger.warn("using --mypy running with --no-line-numbers is not recommended; mypy error messages won't include Coconut line numbers") + if args.interact and args.run: + logger.warn("extraneous --run argument passed; --interact implies --run") + if args.package and self.mypy: + logger.warn("extraneous --package argument passed; --mypy implies --package") + + # validate args and raise errors if args.line_numbers and args.no_line_numbers: raise CoconutException("cannot compile with both --line-numbers and --no-line-numbers") if args.site_install and args.site_uninstall: raise CoconutException("cannot --site-install and --site-uninstall simultaneously") + if args.standalone and args.package: + raise CoconutException("cannot compile as both --package and --standalone") + if args.standalone and self.mypy: + raise CoconutException("cannot compile as both --package (implied by --mypy) and --standalone") + if args.no_write and self.mypy: + raise CoconutException("cannot compile with --no-write when using --mypy") for and_args in getattr(args, "and") or []: if len(and_args) > 2: raise CoconutException( @@ -271,6 +283,9 @@ def execute_args(self, args, interact=True, original_args=None): ), ) + # modify args + args.run = args.run or args.interact + # process general command args 
self.set_jobs(args.jobs, args.profile) if args.recursion_limit is not None: @@ -338,44 +353,45 @@ def execute_args(self, args, interact=True, original_args=None): # do compilation, keeping track of compiled filepaths filepaths = [] if args.source is not None: - # warnings if source is given - if args.interact and args.run: - logger.warn("extraneous --run argument passed; --interact implies --run") - if args.package and self.mypy: - logger.warn("extraneous --package argument passed; --mypy implies --package") - - # errors if source is given - if args.standalone and args.package: - raise CoconutException("cannot compile as both --package and --standalone") - if args.standalone and self.mypy: - raise CoconutException("cannot compile as both --package (implied by --mypy) and --standalone") - if args.no_write and self.mypy: - raise CoconutException("cannot compile with --no-write when using --mypy") - # process all source, dest pairs - src_dest_package_triples = [] + all_compile_path_kwargs = [] + extra_compile_path_kwargs = [] for and_args in [(args.source, args.dest)] + (getattr(args, "and") or []): if len(and_args) == 1: src, = and_args dest = None else: src, dest = and_args - src_dest_package_triples.append(self.process_source_dest(src, dest, args)) + all_new_main_kwargs, all_new_extra_kwargs = self.process_source_dest(src, dest, args) + all_compile_path_kwargs += all_new_main_kwargs + extra_compile_path_kwargs += all_new_extra_kwargs # disable jobs if we know we're only compiling one file - if len(src_dest_package_triples) <= 1 and not any(os.path.isdir(source) for source, dest, package in src_dest_package_triples): + if len(all_compile_path_kwargs) <= 1 and not any(os.path.isdir(kwargs["source"]) for kwargs in all_compile_path_kwargs): self.disable_jobs() - # do compilation - with self.running_jobs(exit_on_error=not ( + # do main compilation + exit_on_error = extra_compile_path_kwargs or not ( args.watch or args.profile - )): - for source, dest, package in src_dest_package_triples: - filepaths += self.compile_path(source, dest, package, run=args.run or args.interact, force=args.force) + ) + with self.running_jobs(exit_on_error=exit_on_error): + for kwargs in all_compile_path_kwargs: + filepaths += self.compile_path(**kwargs) + + # run mypy on compiled files self.run_mypy(filepaths) + # do extra compilation if there is any + if extra_compile_path_kwargs: + with self.running_jobs(exit_on_error=exit_on_error): + for kwargs in extra_compile_path_kwargs: + extra_filepaths = self.compile_path(**kwargs) + internal_assert(lambda: set(extra_filepaths) <= set(filepaths), "new file paths from extra compilation", (extra_filepaths, filepaths)) + # validate args if no source is given + elif getattr(args, "and"): + raise CoconutException("--and should only be used for extra source/dest pairs, not the first source/dest pair") elif ( args.run or args.no_write @@ -386,8 +402,6 @@ def execute_args(self, args, interact=True, original_args=None): or args.jobs ): raise CoconutException("a source file/folder must be specified when options that depend on the source are enabled") - elif getattr(args, "and"): - raise CoconutException("--and should only be used for extra source/dest pairs, not the first source/dest pair") # handle extra cli tasks if args.code is not None: @@ -417,8 +431,8 @@ def execute_args(self, args, interact=True, original_args=None): ): self.start_prompt() if args.watch: - # src_dest_package_triples is always available here - self.watch(src_dest_package_triples, args.run, args.force) + # 
all_compile_path_kwargs is always available here + self.watch(all_compile_path_kwargs) if args.profile: print_profiling_results() @@ -426,16 +440,11 @@ def execute_args(self, args, interact=True, original_args=None): return filepaths def process_source_dest(self, source, dest, args): - """Determine the correct source, dest, package mode to use for the given source, dest, and args.""" + """Get all the compile_path kwargs to use for the given source, dest, and args.""" # determine source processed_source = fixpath(source) # validate args - if (args.run or args.interact) and os.path.isdir(processed_source): - if args.run: - raise CoconutException("source path %r must point to file not directory when --run is enabled" % (source,)) - if args.interact: - raise CoconutException("source path %r must point to file not directory when --run (implied by --interact) is enabled" % (source,)) if args.watch and os.path.isfile(processed_source): raise CoconutException("source path %r must point to directory not file when --watch is enabled" % (source,)) @@ -464,67 +473,51 @@ def process_source_dest(self, source, dest, args): else: raise CoconutException("could not find source path", source) - return processed_source, processed_dest, package - - def register_exit_code(self, code=1, errmsg=None, err=None): - """Update the exit code and errmsg.""" - if err is not None: - internal_assert(errmsg is None, "register_exit_code accepts only one of errmsg or err") - if logger.verbose: - errmsg = format_error(err) - else: - errmsg = err.__class__.__name__ - if errmsg is not None: - if self.errmsg is None: - self.errmsg = errmsg - elif errmsg not in self.errmsg: - if logger.verbose: - self.errmsg += "\nAnd error: " + errmsg - else: - self.errmsg += "; " + errmsg - if code is not None: - self.exit_code = code or self.exit_code - - @contextmanager - def handling_exceptions(self, exit_on_error=None, on_keyboard_interrupt=None): - """Perform proper exception handling.""" - if exit_on_error is None: - exit_on_error = self.fail_fast - try: - if self.using_jobs: - with handling_broken_process_pool(): - yield - else: - yield - except SystemExit as err: - self.register_exit_code(err.code) - # make sure we don't catch GeneratorExit below - except GeneratorExit: - raise - except BaseException as err: - if isinstance(err, CoconutException): - logger.print_exc() - elif isinstance(err, KeyboardInterrupt): - if on_keyboard_interrupt is not None: - on_keyboard_interrupt() - else: - logger.print_exc() - logger.printerr(report_this_text) - self.register_exit_code(err=err) - if exit_on_error: - self.exit_on_error() + # handle running directories + run = args.run + extra_compilation_tasks = [] + if run and os.path.isdir(processed_source): + main_source = os.path.join(processed_source, "__main__" + code_exts[0]) + if not os.path.isfile(main_source): + raise CoconutException("source directory {source} must contain a __main__{ext} when --run{implied} is enabled".format( + source=source, + ext=code_exts[0], + implied=" (implied by --interact)" if args.interact else "", + )) + # first compile the directory without --run + run = False + # then compile just __main__ with --run + extra_compilation_tasks.append(dict( + source=main_source, + dest=processed_dest, + package=package, + run=True, + force=args.force, + )) + + # compile_path kwargs + main_compilation_tasks = [ + dict( + source=processed_source, + dest=processed_dest, + package=package, + run=run, + force=args.force, + ), + ] + return main_compilation_tasks, extra_compilation_tasks - def 
compile_path(self, path, write=True, package=True, handling_exceptions_kwargs={}, **kwargs): + def compile_path(self, source, dest=True, package=True, handling_exceptions_kwargs={}, **kwargs): """Compile a path and return paths to compiled files.""" - if not isinstance(write, bool): - write = fixpath(write) - if os.path.isfile(path): - destpath = self.compile_file(path, write, package, **kwargs) + if not isinstance(dest, bool): + dest = fixpath(dest) + if os.path.isfile(source): + destpath = self.compile_file(source, dest, package, **kwargs) return [destpath] if destpath is not None else [] - elif os.path.isdir(path): - return self.compile_folder(path, write, package, handling_exceptions_kwargs=handling_exceptions_kwargs, **kwargs) + elif os.path.isdir(source): + return self.compile_folder(source, dest, package, handling_exceptions_kwargs=handling_exceptions_kwargs, **kwargs) else: - raise CoconutException("could not find source path", path) + raise CoconutException("could not find source path", source) def compile_folder(self, directory, write=True, package=True, handling_exceptions_kwargs={}, **kwargs): """Compile a directory and return paths to compiled files.""" @@ -693,6 +686,54 @@ def callback_wrapper(completed_future): callback(result) future.add_done_callback(callback_wrapper) + def register_exit_code(self, code=1, errmsg=None, err=None): + """Update the exit code and errmsg.""" + if err is not None: + internal_assert(errmsg is None, "register_exit_code accepts only one of errmsg or err") + if logger.verbose: + errmsg = format_error(err) + else: + errmsg = err.__class__.__name__ + if errmsg is not None: + if self.errmsg is None: + self.errmsg = errmsg + elif errmsg not in self.errmsg: + if logger.verbose: + self.errmsg += "\nAnd error: " + errmsg + else: + self.errmsg += "; " + errmsg + if code is not None: + self.exit_code = code or self.exit_code + + @contextmanager + def handling_exceptions(self, exit_on_error=None, on_keyboard_interrupt=None): + """Perform proper exception handling.""" + if exit_on_error is None: + exit_on_error = self.fail_fast + try: + if self.using_jobs: + with handling_broken_process_pool(): + yield + else: + yield + except SystemExit as err: + self.register_exit_code(err.code) + # make sure we don't catch GeneratorExit below + except GeneratorExit: + raise + except BaseException as err: + if isinstance(err, CoconutException): + logger.print_exc() + elif isinstance(err, KeyboardInterrupt): + if on_keyboard_interrupt is not None: + on_keyboard_interrupt() + else: + logger.print_exc() + logger.printerr(report_this_text) + self.register_exit_code(err=err) + if exit_on_error: + self.exit_on_error() + def set_jobs(self, jobs, profile=False): """Set --jobs.""" if jobs in (None, "sys"): @@ -1085,21 +1126,23 @@ def start_jupyter(self, args): if run_args is not None: self.register_exit_code(run_cmd(run_args, raise_errs=False), errmsg="Jupyter error") - def watch(self, src_dest_package_triples, run=False, force=False): + def watch(self, all_compile_path_kwargs): """Watch a source and recompile on change.""" from coconut.command.watch import Observer, RecompilationWatcher - for src, _, _ in src_dest_package_triples: + for kwargs in all_compile_path_kwargs: logger.show() - logger.show_tabulated("Watching", showpath(src), "(press Ctrl-C to end)...") + logger.show_tabulated("Watching", showpath(kwargs["source"]), "(press Ctrl-C to end)...") interrupted = [False] # in list to allow modification def interrupt(): interrupted[0] = True - def recompile(path, src, dest, 
package): + def recompile(path, **kwargs): path = fixpath(path) + src = kwargs.pop("source") + dest = kwargs.pop("dest") if os.path.isfile(path) and os.path.splitext(path)[1] in code_exts: with self.handling_exceptions(on_keyboard_interrupt=interrupt): if dest is True or dest is None: @@ -1111,19 +1154,17 @@ def recompile(path, src, dest, package): filepaths = self.compile_path( path, writedir, - package, - run=run, - force=force, show_unchanged=False, handling_exceptions_kwargs=dict(on_keyboard_interrupt=interrupt), + **kwargs, ) self.run_mypy(filepaths) observer = Observer() watchers = [] - for src, dest, package in src_dest_package_triples: - watcher = RecompilationWatcher(recompile, src, dest, package) - observer.schedule(watcher, src, recursive=True) + for kwargs in all_compile_path_kwargs: + watcher = RecompilationWatcher(recompile, **kwargs) + observer.schedule(watcher, kwargs["source"], recursive=True) watchers.append(watcher) with self.running_jobs(): diff --git a/coconut/root.py b/coconut/root.py index 379efe1f8..1d473a589 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 4 +DEVELOP = 5 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/main_test.py b/coconut/tests/main_test.py index 1bc893103..418a5621d 100644 --- a/coconut/tests/main_test.py +++ b/coconut/tests/main_test.py @@ -666,8 +666,19 @@ def run_extras(**kwargs): call_python([os.path.join(dest, "extras.py")], assert_output=True, check_errors=False, stderr_first=True, **kwargs) -def run(args=[], agnostic_target=None, use_run_arg=False, convert_to_import=False, always_sys=False, manage_cache=True, **kwargs): +def run( + args=[], + agnostic_target=None, + use_run_arg=False, + run_directory=False, + convert_to_import=False, + always_sys=False, + manage_cache=True, + **kwargs # no comma for compat +): """Compiles and runs tests.""" + assert use_run_arg + run_directory < 2 + if agnostic_target is None: agnostic_args = args else: @@ -692,12 +703,22 @@ def run(args=[], agnostic_target=None, use_run_arg=False, convert_to_import=Fals if sys.version_info >= (3, 11): comp_311(args, **spec_kwargs) - comp_agnostic(agnostic_args, **kwargs) + if not run_directory: + comp_agnostic(agnostic_args, **kwargs) comp_sys(args, **kwargs) # do non-strict at the end so we get the non-strict header comp_non_strict(args, **kwargs) - if use_run_arg: + if run_directory: + _kwargs = kwargs.copy() + _kwargs["assert_output"] = True + _kwargs["stderr_first"] = True + comp_agnostic( + # remove --strict so that we run with the non-strict header + ["--run"] + [arg for arg in agnostic_args if arg != "--strict"], + **_kwargs + ) + elif use_run_arg: _kwargs = kwargs.copy() _kwargs["assert_output"] = True comp_runner(["--run"] + agnostic_args, **_kwargs) @@ -1028,6 +1049,9 @@ def test_and(self): def test_run_arg(self): run(use_run_arg=True) + def test_run_dir(self): + run(run_directory=True) + if not PYPY and not PY26: def test_jobs_zero(self): run(["--jobs", "0"]) diff --git a/coconut/tests/src/cocotest/agnostic/__main__.coco b/coconut/tests/src/cocotest/agnostic/__main__.coco new file mode 100644 index 000000000..4df76fafc --- /dev/null +++ b/coconut/tests/src/cocotest/agnostic/__main__.coco @@ -0,0 +1,20 @@ +import sys +import os.path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import cocotest +from 
cocotest.main import run_main + + +def main() -> bool: + print(".", end="", flush=True) # . + assert cocotest.__doc__ + assert run_main( + outer_MatchError=MatchError, + test_easter_eggs="--test-easter-eggs" in sys.argv, + ) is True + return True + + +if __name__ == "__main__": + assert main() is True diff --git a/coconut/tests/src/runner.coco b/coconut/tests/src/runner.coco index 3265cf493..62a090d92 100644 --- a/coconut/tests/src/runner.coco +++ b/coconut/tests/src/runner.coco @@ -5,18 +5,7 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__))) import pytest pytest.register_assert_rewrite(py_str("cocotest")) -import cocotest -from cocotest.main import run_main - - -def main() -> bool: - print(".", end="", flush=True) # . - assert cocotest.__doc__ - assert run_main( - outer_MatchError=MatchError, - test_easter_eggs="--test-easter-eggs" in sys.argv, - ) is True - return True +from cocotest.__main__ import main if __name__ == "__main__": From 3c32f230a0d3f07ba69cd7bd7340161b09bccc8f Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Wed, 29 Nov 2023 13:53:54 -0800 Subject: [PATCH 09/54] Fix package test --- coconut/tests/src/cocotest/agnostic/main.coco | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coconut/tests/src/cocotest/agnostic/main.coco b/coconut/tests/src/cocotest/agnostic/main.coco index 97c9d3df7..56bfad400 100644 --- a/coconut/tests/src/cocotest/agnostic/main.coco +++ b/coconut/tests/src/cocotest/agnostic/main.coco @@ -90,7 +90,7 @@ def run_main(outer_MatchError, test_easter_eggs=False) -> bool: if using_tco: assert hasattr(tco_func, "_coconut_tco_func") assert tco_test() is True - if outer_MatchError.__module__ != "__main__": + if not outer_MatchError.__module__.endswith("__main__"): assert package_test(outer_MatchError) is True print_dot() # ....... 
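To make the new directory `--run` behavior above concrete, here is a minimal sketch (the `mypkg/` directory and its `__main__.coco` below are hypothetical, not part of the patch): with these changes, `coconut mypkg/ build/ --run` first compiles the whole directory in package mode and then compiles and runs `mypkg/__main__.coco`, raising an error if no such file exists, much like `python` running a directory that contains a `__main__.py`.

```coconut
# hypothetical mypkg/__main__.coco, illustrating what --run on a directory now expects
import sys

def main() -> bool:
    print("running mypkg as a directory")
    return True

if __name__ == "__main__":
    sys.exit(0 if main() else 1)
```
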
From a360430b934fc9ae1a5a6a892dffc4c17d032657 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Wed, 29 Nov 2023 15:28:27 -0800 Subject: [PATCH 10/54] Fix py2 syntax --- coconut/command/command.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coconut/command/command.py b/coconut/command/command.py index c842dcd75..c49151572 100644 --- a/coconut/command/command.py +++ b/coconut/command/command.py @@ -1156,7 +1156,7 @@ def recompile(path, **kwargs): writedir, show_unchanged=False, handling_exceptions_kwargs=dict(on_keyboard_interrupt=interrupt), - **kwargs, + **kwargs # no comma for py2 ) self.run_mypy(filepaths) From cc967cd36ab8a2e738ef996212211b083f27b206 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Wed, 29 Nov 2023 21:35:29 -0800 Subject: [PATCH 11/54] Disable py312 tests --- coconut/tests/main_test.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/coconut/tests/main_test.py b/coconut/tests/main_test.py index 418a5621d..f7b5a3e26 100644 --- a/coconut/tests/main_test.py +++ b/coconut/tests/main_test.py @@ -89,8 +89,10 @@ os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1" -# run fewer tests on Windows so appveyor doesn't time out -TEST_ALL = get_bool_env_var("COCONUT_TEST_ALL", not WINDOWS) +TEST_ALL = get_bool_env_var("COCONUT_TEST_ALL", ( + # run fewer tests on Windows so appveyor doesn't time out + not WINDOWS +)) # ----------------------------------------------------------------------------------------------------------------------- @@ -1015,18 +1017,20 @@ def test_always_sys(self): def test_target(self): run(agnostic_target=(2 if PY2 else 3)) - def test_standalone(self): - run(["--standalone"]) + def test_no_tco(self): + run(["--no-tco"]) def test_package(self): run(["--package"]) - def test_no_tco(self): - run(["--no-tco"]) + # TODO: re-allow these once we figure out what's causing the strange unreproducible errors with them on py3.12 + if not PY312: + def test_standalone(self): + run(["--standalone"]) - if PY35: - def test_no_wrap(self): - run(["--no-wrap"]) + if PY35: + def test_no_wrap(self): + run(["--no-wrap"]) if TEST_ALL: if CPYTHON: From 7428bde476acfe937d10354f3c495fcc1feb96e7 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Wed, 29 Nov 2023 22:19:35 -0800 Subject: [PATCH 12/54] Fix import --- coconut/tests/main_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/coconut/tests/main_test.py b/coconut/tests/main_test.py index f7b5a3e26..4e51c793e 100644 --- a/coconut/tests/main_test.py +++ b/coconut/tests/main_test.py @@ -57,6 +57,7 @@ PY38, PY39, PY310, + PY312, CPYTHON, adaptive_any_of_env_var, reverse_any_of_env_var, From f2eea1d9cdae04d4d02307e9401faf400182cd49 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Thu, 30 Nov 2023 01:27:33 -0800 Subject: [PATCH 13/54] Improve error formatting Refs #812. 
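One user-visible change in this patch is the clearer message produced when an assignment function ends in a `return` statement; the snippet below is an ad hoc illustration, not taken from the patch, of the construct the reworded error targets.

```coconut
def bad(x) =
    y = x + 1
    return y  # assignment functions implicitly return their final expression, so this
              # now reports: "assignment function expected expression as last statement
              # but got return instead"
```
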
--- coconut/compiler/compiler.py | 3 ++- coconut/compiler/grammar.py | 2 +- coconut/constants.py | 7 ++++--- coconut/exceptions.py | 22 +++++++++++++++++++--- coconut/highlighter.py | 16 ++++++++++++++++ coconut/root.py | 2 +- coconut/terminal.py | 13 ++++++++----- coconut/tests/src/extras.coco | 16 ++++++++++++++++ coconut/util.py | 13 +++++++++++++ 9 files changed, 80 insertions(+), 14 deletions(-) diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 3bf9fbf3a..e9eee011c 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -49,6 +49,7 @@ ) from coconut.constants import ( + PY35, specific_targets, targets, pseudo_targets, @@ -1359,7 +1360,7 @@ def parse( internal_assert(pre_procd is not None, "invalid deferred syntax error in pre-processing", err) raise self.make_syntax_err(err, pre_procd, after_parsing=parsed is not None) # RuntimeError, not RecursionError, for Python < 3.5 - except RuntimeError as err: + except (RecursionError if PY35 else RuntimeError) as err: raise CoconutException( str(err), extra="try again with --recursion-limit greater than the current " + str(sys.getrecursionlimit()) + " (you may also need to increase --stack-size)", diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index b17e856a9..0c62467e6 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -2192,7 +2192,7 @@ class Grammar(object): where_stmt_ref = where_item + where_suite implicit_return = ( - invalid_syntax(return_stmt, "expected expression but got return statement") + invalid_syntax(return_stmt, "assignment function expected expression as last statement but got return instead") | attach(new_testlist_star_expr, implicit_return_handle) ) implicit_return_where = Forward() diff --git a/coconut/constants.py b/coconut/constants.py index 0c648eb51..4677df802 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -315,12 +315,11 @@ def get_path_env_var(env_var, default): ) tabideal = 4 # spaces to indent code for displaying - -taberrfmt = 2 # spaces to indent exceptions - justify_len = 79 # ideal line length +taberrfmt = 2 # spaces to indent exceptions min_squiggles_in_err_msg = 1 +max_err_msg_lines = 10 # for pattern-matching default_matcher_style = "python warn" @@ -645,6 +644,8 @@ def get_path_env_var(env_var, default): log_color_code = "93" default_style = "default" +fake_styles = ("none", "list") + prompt_histfile = get_path_env_var( "COCONUT_HISTORY_FILE", os.path.join(coconut_home, ".coconut_history"), diff --git a/coconut/exceptions.py b/coconut/exceptions.py index 341ef3831..fde3962fc 100644 --- a/coconut/exceptions.py +++ b/coconut/exceptions.py @@ -30,6 +30,7 @@ taberrfmt, report_this_text, min_squiggles_in_err_msg, + max_err_msg_lines, ) from coconut.util import ( pickleable_obj, @@ -38,6 +39,7 @@ clean, get_displayable_target, normalize_newlines, + highlight, ) # ----------------------------------------------------------------------------------------------------------------------- @@ -153,8 +155,10 @@ def message(self, message, source, point, ln, extra=None, endpoint=None, filenam point_ind = clip(point_ind, 0, len(part)) endpoint_ind = clip(endpoint_ind, point_ind, len(part)) - message += "\n" + " " * taberrfmt + part + # add code to message + message += "\n" + " " * taberrfmt + highlight(part) + # add squiggles to message if point_ind > 0 or endpoint_ind > 0: err_len = endpoint_ind - point_ind message += "\n" + " " * (taberrfmt + point_ind) @@ -182,14 +186,26 @@ def message(self, message, 
source, point, ln, extra=None, endpoint=None, filenam max_line_len = max(len(line) for line in lines) + # add top squiggles message += "\n" + " " * (taberrfmt + point_ind) if point_ind >= len(lines[0]): message += "|" else: message += "/" + "~" * (len(lines[0]) - point_ind - 1) message += "~" * (max_line_len - len(lines[0])) + "\n" - for line in lines: - message += "\n" + " " * taberrfmt + line + + # add code + if len(lines) > max_err_msg_lines: + for i in range(max_err_msg_lines // 2): + message += "\n" + " " * taberrfmt + highlight(lines[i]) + message += "\n" + " " * (taberrfmt // 2) + "..." + for i in range(len(lines) - max_err_msg_lines // 2, len(lines)): + message += "\n" + " " * taberrfmt + highlight(lines[i]) + else: + for line in lines: + message += "\n" + " " * taberrfmt + highlight(line) + + # add bottom squiggles message += ( "\n\n" + " " * taberrfmt + "~" * endpoint_ind + ("^" if self.point_to_endpoint else "/" if 0 < endpoint_ind < len(lines[-1]) else "|") diff --git a/coconut/highlighter.py b/coconut/highlighter.py index a12686a06..f7c010b31 100644 --- a/coconut/highlighter.py +++ b/coconut/highlighter.py @@ -19,10 +19,12 @@ from coconut.root import * # NOQA +from pygments import highlight from pygments.lexers import Python3Lexer, PythonConsoleLexer from pygments.token import Text, Operator, Keyword, Name, Number from pygments.lexer import words, bygroups from pygments.util import shebang_matches +from pygments.formatters import Terminal256Formatter from coconut.constants import ( highlight_builtins, @@ -36,6 +38,9 @@ template_ext, coconut_exceptions, main_prompt, + style_env_var, + default_style, + fake_styles, ) # ----------------------------------------------------------------------------------------------------------------------- @@ -113,3 +118,14 @@ def __init__(self, stripnl=False, stripall=False, ensurenl=True, tabsize=tabidea def analyse_text(text): return shebang_matches(text, shebang_regex) + + +def highlight_coconut_for_terminal(code): + """Highlight Coconut code for the terminal.""" + style = os.getenv(style_env_var, default_style) + if style not in fake_styles: + try: + return highlight(code, CoconutLexer(), Terminal256Formatter(style=style)) + except Exception: + pass + return code diff --git a/coconut/root.py b/coconut/root.py index 1d473a589..0a9d623c1 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 5 +DEVELOP = 6 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/terminal.py b/coconut/terminal.py index 11fb41cf7..8a19ca95f 100644 --- a/coconut/terminal.py +++ b/coconut/terminal.py @@ -207,8 +207,13 @@ def __init__(self, other=None): self.patch_logging() @classmethod - def enable_colors(cls): + def enable_colors(cls, file=None): """Attempt to enable CLI colors.""" + if ( + use_color is False + or use_color is None and file is not None and not isatty(file) + ): + return False if not cls.colors_enabled: # necessary to resolve https://bugs.python.org/issue40134 try: @@ -216,6 +221,7 @@ def enable_colors(cls): except BaseException: logger.log_exc() cls.colors_enabled = True + return True def copy_from(self, other): """Copy other onto self.""" @@ -265,11 +271,8 @@ def display( else: raise CoconutInternalException("invalid logging level", level) - if use_color is False or (use_color is None and not isatty(file)): - color = None - if color: - 
self.enable_colors() + color = self.enable_colors(file) and color raw_message = " ".join(str(msg) for msg in messages) # if there's nothing to display but there is a sig, display the sig diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index 407094fe9..7be756796 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -365,6 +365,22 @@ import abc except CoconutStyleError as err: assert str(err) == """found unused import 'abc' (add '# NOQA' to suppress) (remove --strict to downgrade to a warning) (line 1) import abc""" + assert_raises(-> parse("""class A(object): + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11 + 12 + 13 + 14 + 15"""), CoconutStyleError, err_has="\n ...\n") setup(line_numbers=False, strict=True, target="sys") assert_raises(-> parse("await f x"), CoconutParseError, err_has='invalid use of the keyword "await"') diff --git a/coconut/util.py b/coconut/util.py index 1e9b91bb1..38eda7b76 100644 --- a/coconut/util.py +++ b/coconut/util.py @@ -325,6 +325,19 @@ def replace_all(inputstr, all_to_replace, replace_to): return inputstr +def highlight(code): + """Attempt to highlight Coconut code for the terminal.""" + from coconut.terminal import logger # hide to remove circular deps + if logger.enable_colors(sys.stdout) and logger.enable_colors(sys.stderr): + try: + from coconut.highlighter import highlight_coconut_for_terminal + except ImportError: + pass + else: + return highlight_coconut_for_terminal(code) + return code + + # ----------------------------------------------------------------------------------------------------------------------- # VERSIONING: # ----------------------------------------------------------------------------------------------------------------------- From 1ed8227dc560d327dab0c2ca90fc58a28afdbd66 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Thu, 30 Nov 2023 15:21:45 -0800 Subject: [PATCH 14/54] Fix exception highlighting --- coconut/exceptions.py | 66 ++++++++++++++++++++++-------------------- coconut/highlighter.py | 3 +- coconut/root.py | 2 +- coconut/util.py | 2 +- 4 files changed, 38 insertions(+), 35 deletions(-) diff --git a/coconut/exceptions.py b/coconut/exceptions.py index fde3962fc..c431a70e7 100644 --- a/coconut/exceptions.py +++ b/coconut/exceptions.py @@ -104,19 +104,18 @@ def kwargs(self): def message(self, message, source, point, ln, extra=None, endpoint=None, filename=None): """Creates a SyntaxError-like message.""" - if message is None: - message = "parsing failed" + message_parts = ["parsing failed" if message is None else message] if extra is not None: - message += " (" + str(extra) + ")" + message_parts += [" (", str(extra), ")"] if ln is not None: - message += " (line " + str(ln) + message_parts += [" (line ", str(ln)] if filename is not None: - message += " in " + repr(filename) - message += ")" + message_parts += [" in ", repr(filename)] + message_parts += [")"] if source: if point is None: for line in source.splitlines(): - message += "\n" + " " * taberrfmt + clean(line) + message_parts += ["\n", " " * taberrfmt, clean(line)] else: source = normalize_newlines(source) point = clip(point, 0, len(source)) @@ -155,25 +154,25 @@ def message(self, message, source, point, ln, extra=None, endpoint=None, filenam point_ind = clip(point_ind, 0, len(part)) endpoint_ind = clip(endpoint_ind, point_ind, len(part)) - # add code to message - message += "\n" + " " * taberrfmt + highlight(part) + # add code to message, highlighting part only at end so as not to change len(part) + message_parts += 
["\n", " " * taberrfmt, highlight(part)] # add squiggles to message if point_ind > 0 or endpoint_ind > 0: err_len = endpoint_ind - point_ind - message += "\n" + " " * (taberrfmt + point_ind) + message_parts += ["\n", " " * (taberrfmt + point_ind)] if err_len <= min_squiggles_in_err_msg: if not self.point_to_endpoint: - message += "^" - message += "~" * err_len # err_len ~'s when there's only an extra char in one spot + message_parts += ["^"] + message_parts += ["~" * err_len] # err_len ~'s when there's only an extra char in one spot if self.point_to_endpoint: - message += "^" + message_parts += ["^"] else: - message += ( - ("^" if not self.point_to_endpoint else "\\") - + "~" * (err_len - 1) # err_len-1 ~'s when there's an extra char at the start and end - + ("^" if self.point_to_endpoint else "/" if endpoint_ind < len(part) else "|") - ) + message_parts += [ + ("^" if not self.point_to_endpoint else "\\"), + "~" * (err_len - 1), # err_len-1 ~'s when there's an extra char at the start and end + ("^" if self.point_to_endpoint else "/" if endpoint_ind < len(part) else "|"), + ] # multi-line error message else: @@ -187,31 +186,34 @@ def message(self, message, source, point, ln, extra=None, endpoint=None, filenam max_line_len = max(len(line) for line in lines) # add top squiggles - message += "\n" + " " * (taberrfmt + point_ind) + message_parts += ["\n", " " * (taberrfmt + point_ind)] if point_ind >= len(lines[0]): - message += "|" + message_parts += ["|"] else: - message += "/" + "~" * (len(lines[0]) - point_ind - 1) - message += "~" * (max_line_len - len(lines[0])) + "\n" + message_parts += ["/", "~" * (len(lines[0]) - point_ind - 1)] + message_parts += ["~" * (max_line_len - len(lines[0])), "\n"] - # add code + # add code, highlighting all of it together + code_parts = [] if len(lines) > max_err_msg_lines: for i in range(max_err_msg_lines // 2): - message += "\n" + " " * taberrfmt + highlight(lines[i]) - message += "\n" + " " * (taberrfmt // 2) + "..." 
+ code_parts += ["\n", " " * taberrfmt, lines[i]] + code_parts += ["\n", " " * (taberrfmt // 2), "..."] for i in range(len(lines) - max_err_msg_lines // 2, len(lines)): - message += "\n" + " " * taberrfmt + highlight(lines[i]) + code_parts += ["\n", " " * taberrfmt, lines[i]] else: for line in lines: - message += "\n" + " " * taberrfmt + highlight(line) + code_parts += ["\n", " " * taberrfmt, line] + message_parts += highlight("".join(code_parts)) # add bottom squiggles - message += ( - "\n\n" + " " * taberrfmt + "~" * endpoint_ind - + ("^" if self.point_to_endpoint else "/" if 0 < endpoint_ind < len(lines[-1]) else "|") - ) + message_parts += [ + "\n\n", + " " * taberrfmt + "~" * endpoint_ind, + ("^" if self.point_to_endpoint else "/" if 0 < endpoint_ind < len(lines[-1]) else "|"), + ] - return message + return "".join(message_parts) def syntax_err(self): """Creates a SyntaxError.""" diff --git a/coconut/highlighter.py b/coconut/highlighter.py index f7c010b31..cb6ce0e53 100644 --- a/coconut/highlighter.py +++ b/coconut/highlighter.py @@ -42,6 +42,7 @@ default_style, fake_styles, ) +from coconut.terminal import logger # ----------------------------------------------------------------------------------------------------------------------- # LEXERS: @@ -127,5 +128,5 @@ def highlight_coconut_for_terminal(code): try: return highlight(code, CoconutLexer(), Terminal256Formatter(style=style)) except Exception: - pass + logger.log_exc() return code diff --git a/coconut/root.py b/coconut/root.py index 0a9d623c1..d6f1aa528 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 6 +DEVELOP = 7 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/util.py b/coconut/util.py index 38eda7b76..1a07d0c3b 100644 --- a/coconut/util.py +++ b/coconut/util.py @@ -332,7 +332,7 @@ def highlight(code): try: from coconut.highlighter import highlight_coconut_for_terminal except ImportError: - pass + logger.log_exc() else: return highlight_coconut_for_terminal(code) return code From 5a8ae486d39cc2f91d3a9102f184c3edf8195df4 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Thu, 30 Nov 2023 15:48:52 -0800 Subject: [PATCH 15/54] Improve command loading --- coconut/command/command.py | 1 - 1 file changed, 1 deletion(-) diff --git a/coconut/command/command.py b/coconut/command/command.py index c49151572..fc6fe2d3e 100644 --- a/coconut/command/command.py +++ b/coconut/command/command.py @@ -246,7 +246,6 @@ def execute_args(self, args, interact=True, original_args=None): unset_fast_pyparsing_reprs() if args.profile: start_profiling() - logger.enable_colors() logger.log(cli_version) if original_args is not None: From d76196d2fd0eebb8be0dfffbda5767b432a20fb8 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Sat, 2 Dec 2023 17:26:01 -0800 Subject: [PATCH 16/54] Add lift_apart Resolves #807. 
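As a quick illustration of the semantics documented below (the names `double` and `inc` are ad hoc): `lift` feeds the same arguments to every lifted function, while the new `lift_apart` gives each lifted function its own argument.

```coconut
double = (.*2)
inc = (.+1)

# lift: the S' combinator, every function sees the same argument
assert lift((+))(double, inc)(3) == 3*2 + (3+1) == 10
# lift_apart: the D2 combinator, each function gets its own argument
assert lift_apart((+))(double, inc)(3, 4) == 3*2 + (4+1) == 11
```
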
--- DOCS.md | 59 +++++++++- __coconut__/__init__.pyi | 28 ++++- coconut/compiler/grammar.py | 10 +- coconut/compiler/header.py | 9 +- coconut/compiler/templates/header.py_template | 111 +++++++++++------- coconut/constants.py | 3 +- coconut/root.py | 2 +- coconut/terminal.py | 4 +- .../tests/src/cocotest/agnostic/suite.coco | 4 +- coconut/tests/src/cocotest/agnostic/util.coco | 18 +++ coconut/tests/src/extras.coco | 3 + 11 files changed, 188 insertions(+), 63 deletions(-) diff --git a/DOCS.md b/DOCS.md index a5f21dd70..3d565cc5f 100644 --- a/DOCS.md +++ b/DOCS.md @@ -3505,11 +3505,11 @@ def flip(f, nargs=None) = ) ``` -#### `lift` +#### `lift` and `lift_apart` -**lift**(_func_) +##### **lift**(_func_) -**lift**(_func_, *_func\_args_, **_func\_kwargs_) +##### **lift**(_func_, *_func\_args_, **_func\_kwargs_) Coconut's `lift` built-in is a higher-order function that takes in a function and “lifts” it up so that all of its arguments are functions. @@ -3533,7 +3533,33 @@ def lift(f) = ( `lift` also supports a shortcut form such that `lift(f, *func_args, **func_kwargs)` is equivalent to `lift(f)(*func_args, **func_kwargs)`. -##### Example +##### **lift\_apart**(_func_) + +##### **lift\_apart**(_func_, *_func\_args_, **_func\_kwargs_) + +Coconut's `lift_apart` built-in is very similar to `lift`, except instead of duplicating the final arguments to each function, it separates them out. + +For a binary function `f(x, y)` and two unary functions `g(z)` and `h(z)`, `lift_apart` works as +```coconut +lift_apart(f)(g, h)(z, w) == f(g(z), h(w)) +``` +such that in this case `lift_apart` implements the `D2` combinator. + +In the general case, `lift_apart` is equivalent to a pickleable version of +```coconut +def lift_apart(f) = ( + (*func_args, **func_kwargs) => + (*args, **kwargs) => + f( + *(f(x) for f, x in zip(func_args, args, strict=True)), + **{k: func_kwargs[k](kwargs[k]) for k in func_kwargs.keys() | kwargs.keys()}, + ) +) +``` + +`lift_apart` supports the same shortcut form as `lift`. + +##### Examples **Coconut:** ```coconut @@ -3552,8 +3578,33 @@ def plus_and_times(x, y): return x + y, x * y ``` +**Coconut:** +```coconut +first_false_and_last_true = ( + lift(,)(ident, reversed) + ..*> lift_apart(,)(dropwhile$(bool), dropwhile$(not)) + ..*> lift_apart(,)(.$[0], .$[0]) +) +``` + +**Python:** +```coconut_python +from itertools import dropwhile + +def first_false_and_last_true(xs): + rev_xs = reversed(xs) + return ( + next(dropwhile(bool, xs)), + next(dropwhile(lambda x: not x, rev_xs)), + ) +``` + #### `and_then` and `and_then_await` +**and\_then**(_first\_async\_func_, _second\_func_) + +**and\_then\_await**(_first\_async\_func_, _second\_async\_func_) + Coconut provides the `and_then` and `and_then_await` built-ins for composing `async` functions. Specifically: * To forwards compose an async function `async_f` with a normal function `g` (such that `g` is called on the result of `await`ing `async_f`), write ``async_f `and_then` g``. * To forwards compose an async function `async_f` with another async function `async_g` (such that `async_g` is called on the result of `await`ing `async_f`, and then `async_g` is itself awaited), write ``async_f `and_then_await` async_g``. diff --git a/__coconut__/__init__.pyi b/__coconut__/__init__.pyi index 42af159ec..c68c7b69c 100644 --- a/__coconut__/__init__.pyi +++ b/__coconut__/__init__.pyi @@ -1636,20 +1636,42 @@ def lift(func: _t.Callable[[_T, _U], _W]) -> _coconut_lifted_2[_T, _U, _W]: ... 
def lift(func: _t.Callable[[_T, _U, _V], _W]) -> _coconut_lifted_3[_T, _U, _V, _W]: ... @_t.overload def lift(func: _t.Callable[..., _W]) -> _t.Callable[..., _t.Callable[..., _W]]: - """Lift a function up so that all of its arguments are functions. + """Lift a function up so that all of its arguments are functions that all take the same arguments. For a binary function f(x, y) and two unary functions g(z) and h(z), lift works as the S' combinator: lift(f)(g, h)(z) == f(g(z), h(z)) In general, lift is equivalent to: - def lift(f) = ((*func_args, **func_kwargs) -> (*args, **kwargs) -> - f(*(g(*args, **kwargs) for g in func_args), **{lbrace}k: h(*args, **kwargs) for k, h in func_kwargs.items(){rbrace})) + def lift(f) = ((*func_args, **func_kwargs) => (*args, **kwargs) => ( + f(*(g(*args, **kwargs) for g in func_args), **{k: h(*args, **kwargs) for k, h in func_kwargs.items()})) + ) lift also supports a shortcut form such that lift(f, *func_args, **func_kwargs) is equivalent to lift(f)(*func_args, **func_kwargs). """ ... _coconut_lift = lift +@_t.overload +def lift_apart(func: _t.Callable[[_T], _W]) -> _t.Callable[[_t.Callable[[_U], _T]], _t.Callable[[_U], _W]]: ... +@_t.overload +def lift_apart(func: _t.Callable[[_T, _X], _W]) -> _t.Callable[[_t.Callable[[_U], _T], _t.Callable[[_Y], _X]], _t.Callable[[_U, _Y], _W]]: ... +@_t.overload +def lift_apart(func: _t.Callable[..., _W]) -> _t.Callable[..., _t.Callable[..., _W]]: + """Lift a function up so that all of its arguments are functions that each take separate arguments. + + For a binary function f(x, y) and two unary functions g(z) and h(z), lift_apart works as the D2 combinator: + lift_apart(f)(g, h)(z, w) == f(g(z), h(w)) + + In general, lift_apart is equivalent to: + def lift_apart(func) = (*func_args, **func_kwargs) => (*args, **kwargs) => func( + *map(call, func_args, args, strict=True), + **{k: func_kwargs[k](kwargs[k]) for k in func_kwargs.keys() | kwargs.keys()}, + ) + + lift_apart also supports a shortcut form such that lift_apart(f, *func_args, **func_kwargs) is equivalent to lift_apart(f)(*func_args, **func_kwargs). + """ + ... + def all_equal(iterable: _Iterable) -> bool: """For a given iterable, check whether all elements in that iterable are equal to each other. 
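A small follow-on sketch of the keyword handling and the shortcut form described in the stub above (the `scale` function here is made up for illustration): keyword-lifted functions are matched to keyword arguments by name, and the shortcut form passes the lifted functions directly.

```coconut
def scale(x, factor=1) = x * factor

scaled = lift_apart(scale)((.+1), factor=(.*2))
assert scaled(3, factor=5) == (3+1) * (5*2) == 40

# shortcut form: lift_apart(f, *func_args, **func_kwargs) is lift_apart(f)(*func_args, **func_kwargs)
assert lift_apart(scale, (.+1), factor=(.*2))(3, factor=5) == 40
```
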
diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index 0c62467e6..be4d19268 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -1381,9 +1381,13 @@ class Grammar(object): simple_assign = Forward() simple_assign_ref = maybeparens( lparen, - (setname | passthrough_atom) - + ZeroOrMore(ZeroOrMore(complex_trailer) + OneOrMore(simple_trailer)), - rparen, + ( + # refname if there's a trailer, setname if not + (refname | passthrough_atom) + OneOrMore(ZeroOrMore(complex_trailer) + OneOrMore(simple_trailer)) + | setname + | passthrough_atom + ), + rparen ) simple_assignlist = maybeparens(lparen, itemlist(simple_assign, comma, suppress_trailing=False), rparen) diff --git a/coconut/compiler/header.py b/coconut/compiler/header.py index 17e1b1041..37e813c8b 100644 --- a/coconut/compiler/header.py +++ b/coconut/compiler/header.py @@ -290,11 +290,15 @@ def process_header_args(which, use_hash, target, no_tco, strict, no_wrap): report_this_text=report_this_text, from_None=" from None" if target.startswith("3") else "", process_="process_" if target_info >= (3, 13) else "", - numpy_modules=tuple_str_of(numpy_modules, add_quotes=True), pandas_numpy_modules=tuple_str_of(pandas_numpy_modules, add_quotes=True), jax_numpy_modules=tuple_str_of(jax_numpy_modules, add_quotes=True), self_match_types=tuple_str_of(self_match_types), + comma_bytearray=", bytearray" if not target.startswith("3") else "", + lstatic="staticmethod(" if not target.startswith("3") else "", + rstatic=")" if not target.startswith("3") else "", + all_keys="self.func_kwargs.keys() | kwargs.keys()" if target_info >= (3,) else "_coconut.set(self.func_kwargs.keys()) | _coconut.set(kwargs.keys())", + set_super=( # we have to use _coconut_super even on the universal target, since once we set __class__ it becomes a local variable "super = py_super" if target.startswith("3") else "super = _coconut_super" @@ -335,9 +339,6 @@ def process_header_args(which, use_hash, target, no_tco, strict, no_wrap): else "zip_longest = itertools.izip_longest", indent=1, ), - comma_bytearray=", bytearray" if not target.startswith("3") else "", - lstatic="staticmethod(" if not target.startswith("3") else "", - rstatic=")" if not target.startswith("3") else "", zip_iter=prepare( r''' for items in _coconut.iter(_coconut.zip(*self.iters, strict=self.strict) if _coconut_sys.version_info >= (3, 10) else _coconut.zip_longest(*self.iters, fillvalue=_coconut_sentinel) if self.strict else _coconut.zip(*self.iters)): diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index 2f401ad5a..30e75868a 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -432,7 +432,7 @@ def and_then(first_async_func, second_func): first_async_func: async (**T) -> U, second_func: U -> V, ) -> async (**T) -> V = - async def (*args, **kwargs) -> ( + async def (*args, **kwargs) => ( first_async_func(*args, **kwargs) |> await |> second_func @@ -447,7 +447,7 @@ def and_then_await(first_async_func, second_async_func): first_async_func: async (**T) -> U, second_async_func: async U -> V, ) -> async (**T) -> V = - async def (*args, **kwargs) -> ( + async def (*args, **kwargs) => ( first_async_func(*args, **kwargs) |> await |> second_async_func @@ -458,98 +458,98 @@ def and_then_await(first_async_func, second_async_func): def _coconut_forward_compose(func, *funcs): """Forward composition operator (..>). 
- (..>)(f, g) is effectively equivalent to (*args, **kwargs) -> g(f(*args, **kwargs)).""" + (..>)(f, g) is effectively equivalent to (*args, **kwargs) => g(f(*args, **kwargs)).""" return _coconut_base_compose(func, *((f, 0, False) for f in funcs)) def _coconut_back_compose(*funcs): """Backward composition operator (<..). - (<..)(f, g) is effectively equivalent to (*args, **kwargs) -> f(g(*args, **kwargs)).""" + (<..)(f, g) is effectively equivalent to (*args, **kwargs) => f(g(*args, **kwargs)).""" return _coconut_forward_compose(*_coconut.reversed(funcs)) def _coconut_forward_none_compose(func, *funcs): """Forward none-aware composition operator (..?>). - (..?>)(f, g) is effectively equivalent to (*args, **kwargs) -> g?(f(*args, **kwargs)).""" + (..?>)(f, g) is effectively equivalent to (*args, **kwargs) => g?(f(*args, **kwargs)).""" return _coconut_base_compose(func, *((f, 0, True) for f in funcs)) def _coconut_back_none_compose(*funcs): """Backward none-aware composition operator (<..?). - (<..?)(f, g) is effectively equivalent to (*args, **kwargs) -> f?(g(*args, **kwargs)).""" + (<..?)(f, g) is effectively equivalent to (*args, **kwargs) => f?(g(*args, **kwargs)).""" return _coconut_forward_none_compose(*_coconut.reversed(funcs)) def _coconut_forward_star_compose(func, *funcs): """Forward star composition operator (..*>). - (..*>)(f, g) is effectively equivalent to (*args, **kwargs) -> g(*f(*args, **kwargs)).""" + (..*>)(f, g) is effectively equivalent to (*args, **kwargs) => g(*f(*args, **kwargs)).""" return _coconut_base_compose(func, *((f, 1, False) for f in funcs)) def _coconut_back_star_compose(*funcs): """Backward star composition operator (<*..). - (<*..)(f, g) is effectively equivalent to (*args, **kwargs) -> f(*g(*args, **kwargs)).""" + (<*..)(f, g) is effectively equivalent to (*args, **kwargs) => f(*g(*args, **kwargs)).""" return _coconut_forward_star_compose(*_coconut.reversed(funcs)) def _coconut_forward_none_star_compose(func, *funcs): """Forward none-aware star composition operator (..?*>). - (..?*>)(f, g) is effectively equivalent to (*args, **kwargs) -> g?(*f(*args, **kwargs)).""" + (..?*>)(f, g) is effectively equivalent to (*args, **kwargs) => g?(*f(*args, **kwargs)).""" return _coconut_base_compose(func, *((f, 1, True) for f in funcs)) def _coconut_back_none_star_compose(*funcs): """Backward none-aware star composition operator (<*?..). - (<*?..)(f, g) is effectively equivalent to (*args, **kwargs) -> f?(*g(*args, **kwargs)).""" + (<*?..)(f, g) is effectively equivalent to (*args, **kwargs) => f?(*g(*args, **kwargs)).""" return _coconut_forward_none_star_compose(*_coconut.reversed(funcs)) def _coconut_forward_dubstar_compose(func, *funcs): """Forward double star composition operator (..**>). - (..**>)(f, g) is effectively equivalent to (*args, **kwargs) -> g(**f(*args, **kwargs)).""" + (..**>)(f, g) is effectively equivalent to (*args, **kwargs) => g(**f(*args, **kwargs)).""" return _coconut_base_compose(func, *((f, 2, False) for f in funcs)) def _coconut_back_dubstar_compose(*funcs): """Backward double star composition operator (<**..). - (<**..)(f, g) is effectively equivalent to (*args, **kwargs) -> f(**g(*args, **kwargs)).""" + (<**..)(f, g) is effectively equivalent to (*args, **kwargs) => f(**g(*args, **kwargs)).""" return _coconut_forward_dubstar_compose(*_coconut.reversed(funcs)) def _coconut_forward_none_dubstar_compose(func, *funcs): """Forward none-aware double star composition operator (..?**>). 
- (..?**>)(f, g) is effectively equivalent to (*args, **kwargs) -> g?(**f(*args, **kwargs)).""" + (..?**>)(f, g) is effectively equivalent to (*args, **kwargs) => g?(**f(*args, **kwargs)).""" return _coconut_base_compose(func, *((f, 2, True) for f in funcs)) def _coconut_back_none_dubstar_compose(*funcs): """Backward none-aware double star composition operator (<**?..). - (<**?..)(f, g) is effectively equivalent to (*args, **kwargs) -> f?(**g(*args, **kwargs)).""" + (<**?..)(f, g) is effectively equivalent to (*args, **kwargs) => f?(**g(*args, **kwargs)).""" return _coconut_forward_none_dubstar_compose(*_coconut.reversed(funcs)) def _coconut_pipe(x, f): - """Pipe operator (|>). Equivalent to (x, f) -> f(x).""" + """Pipe operator (|>). Equivalent to (x, f) => f(x).""" return f(x) def _coconut_star_pipe(xs, f): - """Star pipe operator (*|>). Equivalent to (xs, f) -> f(*xs).""" + """Star pipe operator (*|>). Equivalent to (xs, f) => f(*xs).""" return f(*xs) def _coconut_dubstar_pipe(kws, f): - """Double star pipe operator (**|>). Equivalent to (kws, f) -> f(**kws).""" + """Double star pipe operator (**|>). Equivalent to (kws, f) => f(**kws).""" return f(**kws) def _coconut_back_pipe(f, x): - """Backward pipe operator (<|). Equivalent to (f, x) -> f(x).""" + """Backward pipe operator (<|). Equivalent to (f, x) => f(x).""" return f(x) def _coconut_back_star_pipe(f, xs): - """Backward star pipe operator (<*|). Equivalent to (f, xs) -> f(*xs).""" + """Backward star pipe operator (<*|). Equivalent to (f, xs) => f(*xs).""" return f(*xs) def _coconut_back_dubstar_pipe(f, kws): - """Backward double star pipe operator (<**|). Equivalent to (f, kws) -> f(**kws).""" + """Backward double star pipe operator (<**|). Equivalent to (f, kws) => f(**kws).""" return f(**kws) def _coconut_none_pipe(x, f): - """Nullable pipe operator (|?>). Equivalent to (x, f) -> f(x) if x is not None else None.""" + """Nullable pipe operator (|?>). Equivalent to (x, f) => f(x) if x is not None else None.""" return None if x is None else f(x) def _coconut_none_star_pipe(xs, f): - """Nullable star pipe operator (|?*>). Equivalent to (xs, f) -> f(*xs) if xs is not None else None.""" + """Nullable star pipe operator (|?*>). Equivalent to (xs, f) => f(*xs) if xs is not None else None.""" return None if xs is None else f(*xs) def _coconut_none_dubstar_pipe(kws, f): - """Nullable double star pipe operator (|?**>). Equivalent to (kws, f) -> f(**kws) if kws is not None else None.""" + """Nullable double star pipe operator (|?**>). Equivalent to (kws, f) => f(**kws) if kws is not None else None.""" return None if kws is None else f(**kws) def _coconut_back_none_pipe(f, x): - """Nullable backward pipe operator ( f(x) if x is not None else None.""" + """Nullable backward pipe operator ( f(x) if x is not None else None.""" return None if x is None else f(x) def _coconut_back_none_star_pipe(f, xs): - """Nullable backward star pipe operator (<*?|). Equivalent to (f, xs) -> f(*xs) if xs is not None else None.""" + """Nullable backward star pipe operator (<*?|). Equivalent to (f, xs) => f(*xs) if xs is not None else None.""" return None if xs is None else f(*xs) def _coconut_back_none_dubstar_pipe(f, kws): - """Nullable backward double star pipe operator (<**?|). Equivalent to (kws, f) -> f(**kws) if kws is not None else None.""" + """Nullable backward double star pipe operator (<**?|). 
Equivalent to (kws, f) => f(**kws) if kws is not None else None.""" return None if kws is None else f(**kws) def _coconut_assert(cond, msg=None): """Assert operator (assert). Asserts condition with optional message.""" @@ -563,27 +563,27 @@ def _coconut_raise(exc=None, from_exc=None): exc.__cause__ = from_exc raise exc def _coconut_bool_and(a, b): - """Boolean and operator (and). Equivalent to (a, b) -> a and b.""" + """Boolean and operator (and). Equivalent to (a, b) => a and b.""" return a and b def _coconut_bool_or(a, b): - """Boolean or operator (or). Equivalent to (a, b) -> a or b.""" + """Boolean or operator (or). Equivalent to (a, b) => a or b.""" return a or b def _coconut_in(a, b): - """Containment operator (in). Equivalent to (a, b) -> a in b.""" + """Containment operator (in). Equivalent to (a, b) => a in b.""" return a in b def _coconut_not_in(a, b): - """Negative containment operator (not in). Equivalent to (a, b) -> a not in b.""" + """Negative containment operator (not in). Equivalent to (a, b) => a not in b.""" return a not in b def _coconut_none_coalesce(a, b): - """None coalescing operator (??). Equivalent to (a, b) -> a if a is not None else b.""" + """None coalescing operator (??). Equivalent to (a, b) => a if a is not None else b.""" return b if a is None else a def _coconut_minus(a, b=_coconut_sentinel): - """Minus operator (-). Effectively equivalent to (a, b=None) -> a - b if b is not None else -a.""" + """Minus operator (-). Effectively equivalent to (a, b=None) => a - b if b is not None else -a.""" if b is _coconut_sentinel: return -a return a - b def _coconut_comma_op(*args): - """Comma operator (,). Equivalent to (*args) -> args.""" + """Comma operator (,). Equivalent to (*args) => args.""" return args {def_coconut_matmul} class scan(_coconut_has_iter): @@ -1678,7 +1678,7 @@ def _coconut_dict_merge(*dicts, **kwargs): prevlen = _coconut.len(newdict) return newdict def ident(x, **kwargs): - """The identity function. Generally equivalent to x -> x. Useful in point-free programming. + """The identity function. Generally equivalent to x => x. Useful in point-free programming. 
Accepts one keyword-only argument, side_effect, which specifies a function to call on the argument before it is returned.""" side_effect = kwargs.pop("side_effect", None) if kwargs: @@ -1874,30 +1874,36 @@ class const(_coconut_base_callable): def __repr__(self): return "const(%s)" % (_coconut.repr(self.value),) class _coconut_lifted(_coconut_base_callable): - __slots__ = ("func", "func_args", "func_kwargs") - def __init__(self, _coconut_func, *func_args, **func_kwargs): - self.func = _coconut_func + __slots__ = ("apart", "func", "func_args", "func_kwargs") + def __init__(self, apart, func, func_args, func_kwargs): + self.apart = apart + self.func = func self.func_args = func_args self.func_kwargs = func_kwargs def __reduce__(self): - return (self.__class__, (self.func,) + self.func_args, {lbrace}"func_kwargs": self.func_kwargs{rbrace}) + return (self.__class__, (self.apart, self.func, self.func_args, self.func_kwargs)) def __call__(self, *args, **kwargs): - return self.func(*(g(*args, **kwargs) for g in self.func_args), **_coconut_py_dict((k, h(*args, **kwargs)) for k, h in self.func_kwargs.items())) + if self.apart: + return self.func(*(f(x) for f, x in {_coconut_}zip(self.func_args, args, strict=True)), **_coconut_py_dict((k, self.func_kwargs[k](kwargs[k])) for k in {all_keys})) + else: + return self.func(*(g(*args, **kwargs) for g in self.func_args), **_coconut_py_dict((k, h(*args, **kwargs)) for k, h in self.func_kwargs.items())) def __repr__(self): - return "lift(%r)(%s%s)" % (self.func, ", ".join(_coconut.repr(g) for g in self.func_args), ", ".join(k + "=" + _coconut.repr(h) for k, h in self.func_kwargs.items())) + return "lift%s(%r)(%s%s)" % (self.func, ("_apart" if self.apart else ""), ", ".join(_coconut.repr(g) for g in self.func_args), ", ".join(k + "=" + _coconut.repr(h) for k, h in self.func_kwargs.items())) class lift(_coconut_base_callable): - """Lift a function up so that all of its arguments are functions. + """Lift a function up so that all of its arguments are functions that all take the same arguments. For a binary function f(x, y) and two unary functions g(z) and h(z), lift works as the S' combinator: lift(f)(g, h)(z) == f(g(z), h(z)) In general, lift is equivalent to: - def lift(f) = ((*func_args, **func_kwargs) -> (*args, **kwargs) -> + def lift(f) = ((*func_args, **func_kwargs) => (*args, **kwargs) => ( f(*(g(*args, **kwargs) for g in func_args), **{lbrace}k: h(*args, **kwargs) for k, h in func_kwargs.items(){rbrace})) + ) lift also supports a shortcut form such that lift(f, *func_args, **func_kwargs) is equivalent to lift(f)(*func_args, **func_kwargs). """ __slots__ = ("func",) + _apart = False def __new__(cls, func, *func_args, **func_kwargs): self = _coconut.super({_coconut_}lift, cls).__new__(cls) self.func = func @@ -1907,9 +1913,24 @@ class lift(_coconut_base_callable): def __reduce__(self): return (self.__class__, (self.func,)) def __repr__(self): - return "lift(%r)" % (self.func,) + return "lift%s(%r)" % (("_apart" if self._apart else ""), self.func) def __call__(self, *func_args, **func_kwargs): - return _coconut_lifted(self.func, *func_args, **func_kwargs) + return _coconut_lifted(self._apart, self.func, func_args, func_kwargs) +class lift_apart(lift): + """Lift a function up so that all of its arguments are functions that each take separate arguments. 
+ + For a binary function f(x, y) and two unary functions g(z) and h(z), lift_apart works as the D2 combinator: + lift_apart(f)(g, h)(z, w) == f(g(z), h(w)) + + In general, lift_apart is equivalent to: + def lift_apart(func) = (*func_args, **func_kwargs) => (*args, **kwargs) => func( + *(f(x) for f, x in zip(func_args, args, strict=True)), + **{lbrace}k: func_kwargs[k](kwargs[k]) for k in func_kwargs.keys() | kwargs.keys(){rbrace}, + ) + + lift_apart also supports a shortcut form such that lift_apart(f, *func_args, **func_kwargs) is equivalent to lift_apart(f)(*func_args, **func_kwargs). + """ + _apart = True def all_equal(iterable): """For a given iterable, check whether all elements in that iterable are equal to each other. @@ -1974,7 +1995,7 @@ def collectby(key_func, iterable, value_func=None, **kwargs): If map_using is passed, calculate key_func and value_func by mapping them over the iterable using map_using as map. Useful with process_map/thread_map. """ - return {_coconut_}mapreduce(_coconut_lifted(_coconut_comma_op, key_func, {_coconut_}ident if value_func is None else value_func), iterable, **kwargs) + return {_coconut_}mapreduce(_coconut_lifted(False, _coconut_comma_op, (key_func, {_coconut_}ident if value_func is None else value_func), {empty_dict}), iterable, **kwargs) collectby.using_processes = _coconut_partial(_coconut_parallel_mapreduce, collectby, process_map) collectby.using_threads = _coconut_partial(_coconut_parallel_mapreduce, collectby, thread_map) def _namedtuple_of(**kwargs): diff --git a/coconut/constants.py b/coconut/constants.py index 4677df802..133e8dda5 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -639,7 +639,7 @@ def get_path_env_var(env_var, default): coconut_home = get_path_env_var(home_env_var, "~") -use_color = get_bool_env_var("COCONUT_USE_COLOR", None) +use_color_env_var = "COCONUT_USE_COLOR" error_color_code = "31" log_color_code = "93" @@ -794,6 +794,7 @@ def get_path_env_var(env_var, default): "flip", "const", "lift", + "lift_apart", "all_equal", "collectby", "mapreduce", diff --git a/coconut/root.py b/coconut/root.py index d6f1aa528..e38f17581 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 7 +DEVELOP = 8 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/terminal.py b/coconut/terminal.py index 8a19ca95f..7247c7641 100644 --- a/coconut/terminal.py +++ b/coconut/terminal.py @@ -47,7 +47,8 @@ taberrfmt, use_packrat_parser, embed_on_internal_exc, - use_color, + use_color_env_var, + get_bool_env_var, error_color_code, log_color_code, ansii_escape, @@ -209,6 +210,7 @@ def __init__(self, other=None): @classmethod def enable_colors(cls, file=None): """Attempt to enable CLI colors.""" + use_color = get_bool_env_var(use_color_env_var) if ( use_color is False or use_color is None and file is not None and not isatty(file) diff --git a/coconut/tests/src/cocotest/agnostic/suite.coco b/coconut/tests/src/cocotest/agnostic/suite.coco index 813fe05b0..1b5309bf1 100644 --- a/coconut/tests/src/cocotest/agnostic/suite.coco +++ b/coconut/tests/src/cocotest/agnostic/suite.coco @@ -1029,7 +1029,7 @@ forward 2""") == 900 assert Phi((,), (.+1), (.-1)) <| 5 == (6, 4) assert Psi((,), (.+1), 3) <| 4 == (4, 5) assert D1((,), 0, 1, (.+1)) <| 1 == (0, 1, 2) - assert D2((+), (.*2), 3, (.+1)) <| 4 == 11 + assert D2((+), (.*2), 3, (.+1)) <| 4 == 
11 == D2_((+), (.*2), (.+1))(3, 4) assert E((+), 10, (*), 2) <| 3 == 16 assert Phi1((,), (+), (*), 2) <| 3 == (5, 6) assert BE((,), (+), 10, 2, (*), 2) <| 3 == (12, 6) @@ -1075,6 +1075,8 @@ forward 2""") == 900 assert pickle_round_trip(.loc[0]) <| (loc=[10]) == 10 assert pickle_round_trip(.method(0)) <| (method=const 10) == 10 assert pickle_round_trip(.method(x=10)) <| (method=x -> x) == 10 + assert sq_and_t2p1(10) == (100, 21) + assert first_false_and_last_true([3, 2, 1, 0, "11", "1", ""]) == (0, "1") with process_map.multiple_sequential_calls(): # type: ignore assert process_map(tuple <.. (|>)$(to_sort), qsorts) |> list == [to_sort |> sorted |> tuple] * len(qsorts) diff --git a/coconut/tests/src/cocotest/agnostic/util.coco b/coconut/tests/src/cocotest/agnostic/util.coco index c06598be7..0feebd3a1 100644 --- a/coconut/tests/src/cocotest/agnostic/util.coco +++ b/coconut/tests/src/cocotest/agnostic/util.coco @@ -1544,6 +1544,24 @@ def BE(f, g, x, y, h, z) = lift(f)(const(g x y), h$(z)) def on(b, u) = (,) ..> map$(u) ..*> b +def D2_(f, g, h) = lift_apart(f)(g, h) + + +# branching +branch = lift(,) +branched = lift_apart(,) + +sq_and_t2p1 = ( + branch(ident, (.*2)) + ..*> branched((.**2), (.+1)) # type: ignore +) + +first_false_and_last_true = ( + lift(,)(ident, reversed) + ..*> lift_apart(,)(dropwhile$(bool), dropwhile$(not)) # type: ignore + ..*> lift_apart(,)(.$[0], .$[0]) # type: ignore +) + # maximum difference def maxdiff1(ns) = ( diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index 7be756796..58bdb6557 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -1,5 +1,8 @@ +import os from collections.abc import Sequence +os.environ["COCONUT_USE_COLOR"] = "False" + from coconut.__coconut__ import consume as coc_consume from coconut.constants import ( IPY, From 58ed6dbc1132a652646a8478b88864c8bd16a3bd Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Wed, 6 Dec 2023 19:42:00 -0800 Subject: [PATCH 17/54] Add (if) op Resolves #813. --- __coconut__/__init__.pyi | 5 +++++ coconut/__coconut__.pyi | 2 +- coconut/compiler/grammar.py | 1 + coconut/compiler/header.py | 2 +- coconut/compiler/templates/header.py_template | 3 +++ coconut/root.py | 2 +- coconut/tests/src/cocotest/agnostic/primary_2.coco | 1 + 7 files changed, 13 insertions(+), 3 deletions(-) diff --git a/__coconut__/__init__.pyi b/__coconut__/__init__.pyi index c68c7b69c..2cba5f7c7 100644 --- a/__coconut__/__init__.pyi +++ b/__coconut__/__init__.pyi @@ -1216,6 +1216,11 @@ def _coconut_comma_op(*args: _t.Any) -> _Tuple: ... +def _coconut_if_op(cond: _t.Any, if_true: _T, if_false: _U) -> _t.Union[_T, _U]: + """If operator (if). Equivalent to (cond, if_true, if_false) => if_true if cond else if_false.""" + ... + + if sys.version_info < (3, 5): @_t.overload def _coconut_matmul(a: _T, b: _T) -> _T: ... 
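For reference, a minimal sketch of the new `(if)` operator function in use (illustrative only; note that, being an ordinary function, it evaluates all three arguments eagerly, unlike the ternary expression it mirrors):

```coconut
assert (if)(1 < 2, "yes", "no") == "yes"
assert (if)(0, "truthy", "falsy") == "falsy"

# like other operator functions, it can be passed to higher-order builtins
assert [True, False] |> map$(b => (if)(b, 1, 0)) |> list == [1, 0]
```
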
diff --git a/coconut/__coconut__.pyi b/coconut/__coconut__.pyi index cca933f3f..520b56973 100644 --- a/coconut/__coconut__.pyi +++ b/coconut/__coconut__.pyi @@ -1,2 +1,2 @@ from __coconut__ import * -from __coconut__ import _coconut_tail_call, _coconut_tco, _coconut_call_set_names, _coconut_handle_cls_kwargs, _coconut_handle_cls_stargs, _namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, _coconut_SupportsAdd, _coconut_SupportsMinus, _coconut_SupportsMul, _coconut_SupportsPow, _coconut_SupportsTruediv, _coconut_SupportsFloordiv, _coconut_SupportsMod, _coconut_SupportsAnd, _coconut_SupportsXor, _coconut_SupportsOr, _coconut_SupportsLshift, _coconut_SupportsRshift, _coconut_SupportsMatmul, _coconut_SupportsInv, _coconut_Expected, _coconut_MatchError, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_complex_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_arr_concat_op, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in, _coconut_attritemgetter +from __coconut__ import _coconut_tail_call, _coconut_tco, _coconut_call_set_names, _coconut_handle_cls_kwargs, _coconut_handle_cls_stargs, _namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, _coconut_SupportsAdd, _coconut_SupportsMinus, _coconut_SupportsMul, _coconut_SupportsPow, _coconut_SupportsTruediv, _coconut_SupportsFloordiv, _coconut_SupportsMod, _coconut_SupportsAnd, _coconut_SupportsXor, _coconut_SupportsOr, _coconut_SupportsLshift, _coconut_SupportsRshift, _coconut_SupportsMatmul, _coconut_SupportsInv, _coconut_Expected, _coconut_MatchError, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_complex_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_arr_concat_op, _coconut_mk_anon_namedtuple, 
_coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in, _coconut_attritemgetter, _coconut_if_op diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index be4d19268..e4e24f46b 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -1067,6 +1067,7 @@ class Grammar(object): | fixto(dollar, "_coconut_partial") | fixto(keyword("assert"), "_coconut_assert") | fixto(keyword("raise"), "_coconut_raise") + | fixto(keyword("if"), "_coconut_if_op") | fixto(keyword("is") + keyword("not"), "_coconut.operator.is_not") | fixto(keyword("not") + keyword("in"), "_coconut_not_in") diff --git a/coconut/compiler/header.py b/coconut/compiler/header.py index 37e813c8b..1306fb2a2 100644 --- a/coconut/compiler/header.py +++ b/coconut/compiler/header.py @@ -639,7 +639,7 @@ def __anext__(self): # (extra_format_dict is to keep indentation levels matching) extra_format_dict = dict( # when anything is added to this list it must also be added to *both* __coconut__ stub files - underscore_imports="{tco_comma}{call_set_names_comma}{handle_cls_args_comma}_namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, _coconut_SupportsAdd, _coconut_SupportsMinus, _coconut_SupportsMul, _coconut_SupportsPow, _coconut_SupportsTruediv, _coconut_SupportsFloordiv, _coconut_SupportsMod, _coconut_SupportsAnd, _coconut_SupportsXor, _coconut_SupportsOr, _coconut_SupportsLshift, _coconut_SupportsRshift, _coconut_SupportsMatmul, _coconut_SupportsInv, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_complex_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_arr_concat_op, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in, _coconut_attritemgetter".format(**format_dict), + underscore_imports="{tco_comma}{call_set_names_comma}{handle_cls_args_comma}_namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, _coconut_SupportsAdd, _coconut_SupportsMinus, _coconut_SupportsMul, _coconut_SupportsPow, _coconut_SupportsTruediv, _coconut_SupportsFloordiv, _coconut_SupportsMod, _coconut_SupportsAnd, _coconut_SupportsXor, _coconut_SupportsOr, 
_coconut_SupportsLshift, _coconut_SupportsRshift, _coconut_SupportsMatmul, _coconut_SupportsInv, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_complex_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_arr_concat_op, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in, _coconut_attritemgetter, _coconut_if_op".format(**format_dict), import_typing=pycondition( (3, 5), if_ge=''' diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index 30e75868a..1e61620a6 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -585,6 +585,9 @@ def _coconut_minus(a, b=_coconut_sentinel): def _coconut_comma_op(*args): """Comma operator (,). Equivalent to (*args) => args.""" return args +def _coconut_if_op(cond, if_true, if_false): + """If operator (if). Equivalent to (cond, if_true, if_false) => if_true if cond else if_false.""" + return if_true if cond else if_false {def_coconut_matmul} class scan(_coconut_has_iter): """Reduce func over iterable, yielding intermediate results, diff --git a/coconut/root.py b/coconut/root.py index e38f17581..0e1a95ebf 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 8 +DEVELOP = 9 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/src/cocotest/agnostic/primary_2.coco b/coconut/tests/src/cocotest/agnostic/primary_2.coco index b8b3afdd6..b562ee789 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_2.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_2.coco @@ -420,6 +420,7 @@ def primary_test_2() -> bool: arr |>= [. ; 2] arr |>= [[3; 4] ;; .] assert arr == [3; 4;; 1; 2] == [[3; 4] ;; .] |> call$(?, [. 
; 2] |> call$(?, 1)) + assert (if)(10, 20, 30) == 20 == (if)(0, 10, 20) with process_map.multiple_sequential_calls(): # type: ignore assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore From 589e0d56bdb39cf183d38cc6eec981e6f023bce6 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Thu, 7 Dec 2023 00:41:29 -0800 Subject: [PATCH 18/54] Disallow partial (if) --- coconut/compiler/grammar.py | 2 +- coconut/tests/src/extras.coco | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index e4e24f46b..0b7497481 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -1080,7 +1080,7 @@ class Grammar(object): | fixto(keyword("is"), "_coconut.operator.is_") | fixto(keyword("in"), "_coconut_in") ) - partialable_op = base_op_item | infix_op + partialable_op = ~keyword("if") + (base_op_item | infix_op) partial_op_item_tokens = ( labeled_group(dot.suppress() + partialable_op + test_no_infix, "right partial") | labeled_group(test_no_infix + partialable_op + dot.suppress(), "left partial") diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index 58bdb6557..6ade7f0c8 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -242,6 +242,7 @@ def f() = assert_raises(-> parse("if a = b: pass"), CoconutParseError, err_has="misplaced assignment") assert_raises(-> parse("while a == b"), CoconutParseError, err_has="misplaced newline") assert_raises(-> parse("0xfgf"), CoconutParseError, err_has=r" \~~^") + assert_raises(-> parse("(. if 1)"), CoconutParseError, err_has=r" \~~^") try: parse(""" From 9d168199d3afb8e4c7c2102295e816cff324e49a Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Fri, 8 Dec 2023 00:22:58 -0800 Subject: [PATCH 19/54] Clarify docs --- DOCS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/DOCS.md b/DOCS.md index 3d565cc5f..5b5150cfe 100644 --- a/DOCS.md +++ b/DOCS.md @@ -1726,6 +1726,8 @@ If the last `statement` (not followed by a semicolon) in a statement lambda is a Statement lambdas also support implicit lambda syntax such that `def => _` is equivalent to `def (_=None) => _` as well as explicitly marking them as pattern-matching such that `match def (x) => x` will be a pattern-matching function. +Importantly, statement lambdas do not capture variables introduced only in the surrounding expression, e.g. inside of a list comprehension or normal lambda. To avoid such situations, only nest statement lambdas inside other statement lambdas, and explicitly partially apply a statement lambda to pass in a value from a list comprehension. + Note that statement lambdas have a lower precedence than normal lambdas and thus capture things like trailing commas. To avoid confusion, statement lambdas should always be wrapped in their own set of parentheses. _Deprecated: Statement lambdas also support `->` instead of `=>`. Note that when using `->`, any lambdas in the body of the statement lambda must also use `->` rather than `=>`._ From 2abb2762875d0ec484b59d41129c553cb23b4655 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Thu, 21 Dec 2023 20:01:10 -0800 Subject: [PATCH 20/54] Add xarray support Resolves #816. 
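The statement lambda clarification in the documentation patch above is easiest to see as a concrete pattern: since a statement lambda does not capture a variable introduced only in the surrounding comprehension, the recommended workaround is to pass the value in explicitly via partial application. A small sketch of that pattern (illustrative only; the `captured` parameter name is hypothetical):

```coconut
# pass the comprehension variable in explicitly rather than relying on capture
funcs = [(def (captured) => captured ** 2)$(x) for x in range(4)]
assert [f() for f in funcs] == [0, 1, 4, 9]
```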
--- DOCS.md | 15 +++-- __coconut__/__init__.pyi | 2 + _coconut/__init__.pyi | 4 +- coconut/compiler/header.py | 6 +- coconut/compiler/templates/header.py_template | 65 ++++++++++++------- coconut/constants.py | 10 ++- coconut/root.py | 2 +- coconut/tests/src/extras.coco | 33 ++++++++-- 8 files changed, 98 insertions(+), 39 deletions(-) diff --git a/DOCS.md b/DOCS.md index 5b5150cfe..9c1671d6c 100644 --- a/DOCS.md +++ b/DOCS.md @@ -487,7 +487,7 @@ To allow for better use of [`numpy`](https://numpy.org/) objects in Coconut, all - `numpy` objects are allowed seamlessly in Coconut's [implicit coefficient syntax](#implicit-function-application-and-coefficients), allowing the use of e.g. `A B**2` shorthand for `A * B**2` when `A` and `B` are `numpy` arrays (note: **not** `A @ B**2`). - Coconut supports `@` for matrix multiplication of `numpy` arrays on all Python versions, as well as supplying the `(@)` [operator function](#operator-functions). -Additionally, Coconut provides the exact same support for [`pandas`](https://pandas.pydata.org/), [`pytorch`](https://pytorch.org/), and [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html), including using `pandas`/`jax`-specific methods over `numpy` methods when given `pandas`/`jax` objects. +Additionally, Coconut provides the exact same support for [`pandas`](https://pandas.pydata.org/), [`xarray`](https://docs.xarray.dev/en/stable/), [`pytorch`](https://pytorch.org/), and [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects. #### `xonsh` Support @@ -3383,14 +3383,8 @@ In Haskell, `fmap(func, obj)` takes a data type `obj` and returns a new data typ `fmap` can also be used on the built-in objects `str`, `dict`, `list`, `tuple`, `set`, `frozenset`, and `dict` as a variant of `map` that returns back an object of the same type. -The behavior of `fmap` for a given object can be overridden by defining an `__fmap__(self, func)` magic method that will be called whenever `fmap` is invoked on that object. Note that `__fmap__` implementations should always satisfy the [Functor Laws](https://wiki.haskell.org/Functor). - For `dict`, or any other `collections.abc.Mapping`, `fmap` will map over the mapping's `.items()` instead of the default iteration through its `.keys()`, with the new mapping reconstructed from the mapped over items. _Deprecated: `fmap$(starmap_over_mappings=True)` will `starmap` over the `.items()` instead of `map` over them._ -For [`numpy`](#numpy-integration) objects, `fmap` will use [`np.vectorize`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html) to produce the result. - -For [`pandas`](https://pandas.pydata.org/) objects, `fmap` will use [`.apply`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.apply.html) along the last axis (so row-wise for `DataFrame`'s, element-wise for `Series`'s). - For asynchronous iterables, `fmap` will map asynchronously, making `fmap` equivalent in that case to ```coconut_python async def fmap_over_async_iters(func, async_iter): @@ -3399,6 +3393,13 @@ async def fmap_over_async_iters(func, async_iter): ``` such that `fmap` can effectively be used as an async map. +Some objects from external libraries are also given special support: +* For [`numpy`](#numpy-integration) objects, `fmap` will use [`np.vectorize`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html) to produce the result. 
+* For [`pandas`](https://pandas.pydata.org/) objects, `fmap` will use [`.apply`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.apply.html) along the last axis (so row-wise for `DataFrame`'s, element-wise for `Series`'s). +* For [`xarray`](https://docs.xarray.dev/en/stable/) objects, `fmap` will first convert them into `pandas` objects, apply `fmap`, then convert them back. + +The behavior of `fmap` for a given object can be overridden by defining an `__fmap__(self, func)` magic method that will be called whenever `fmap` is invoked on that object. Note that `__fmap__` implementations should always satisfy the [Functor Laws](https://wiki.haskell.org/Functor). + _Deprecated: `fmap(func, obj, fallback_to_init=True)` will fall back to `obj.__class__(map(func, obj))` if no `fmap` implementation is available rather than raise `TypeError`._ ##### Example diff --git a/__coconut__/__init__.pyi b/__coconut__/__init__.pyi index 2cba5f7c7..007cdcfab 100644 --- a/__coconut__/__init__.pyi +++ b/__coconut__/__init__.pyi @@ -1466,6 +1466,8 @@ def fmap(func: _t.Callable[[_T, _U], _t.Tuple[_V, _W]], obj: _t.Mapping[_T, _U], """ ... +_coconut_fmap = fmap + def _coconut_handle_cls_kwargs(**kwargs: _t.Dict[_t.Text, _t.Any]) -> _t.Callable[[_T], _T]: ... diff --git a/_coconut/__init__.pyi b/_coconut/__init__.pyi index 31d9fd411..82d320478 100644 --- a/_coconut/__init__.pyi +++ b/_coconut/__init__.pyi @@ -109,8 +109,10 @@ npt = _npt # Fake, like typing zip_longest = _zip_longest numpy_modules: _t.Any = ... -pandas_numpy_modules: _t.Any = ... +xarray_modules: _t.Any = ... +pandas_modules: _t.Any = ... jax_numpy_modules: _t.Any = ... + tee_type: _t.Any = ... reiterables: _t.Any = ... fmappables: _t.Any = ... diff --git a/coconut/compiler/header.py b/coconut/compiler/header.py index 1306fb2a2..2d14cbc88 100644 --- a/coconut/compiler/header.py +++ b/coconut/compiler/header.py @@ -33,8 +33,9 @@ justify_len, report_this_text, numpy_modules, - pandas_numpy_modules, + pandas_modules, jax_numpy_modules, + xarray_modules, self_match_types, is_data_var, data_defaults_var, @@ -291,7 +292,8 @@ def process_header_args(which, use_hash, target, no_tco, strict, no_wrap): from_None=" from None" if target.startswith("3") else "", process_="process_" if target_info >= (3, 13) else "", numpy_modules=tuple_str_of(numpy_modules, add_quotes=True), - pandas_numpy_modules=tuple_str_of(pandas_numpy_modules, add_quotes=True), + xarray_modules=tuple_str_of(xarray_modules, add_quotes=True), + pandas_modules=tuple_str_of(pandas_modules, add_quotes=True), jax_numpy_modules=tuple_str_of(jax_numpy_modules, add_quotes=True), self_match_types=tuple_str_of(self_match_types), comma_bytearray=", bytearray" if not target.startswith("3") else "", diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index 1e61620a6..f67cd339d 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -54,7 +54,8 @@ class _coconut{object}:{COMMENT.EVERYTHING_HERE_MUST_BE_COPIED_TO_STUB_FILE} else: abc.Sequence.register(numpy.ndarray) numpy_modules = {numpy_modules} - pandas_numpy_modules = {pandas_numpy_modules} + xarray_modules = {xarray_modules} + pandas_modules = {pandas_modules} jax_numpy_modules = {jax_numpy_modules} tee_type = type(itertools.tee((), 1)[0]) reiterables = abc.Sequence, abc.Mapping, abc.Set @@ -121,6 +122,20 @@ class _coconut_Sentinel(_coconut_baseclass): _coconut_sentinel = _coconut_Sentinel() def _coconut_get_base_module(obj): 
return obj.__class__.__module__.split(".", 1)[0] +def _coconut_xarray_to_pandas(obj): + import xarray + if isinstance(obj, xarray.Dataset): + return obj.to_dataframe() + elif isinstance(obj, xarray.DataArray): + return obj.to_series() + else: + return obj.to_pandas() +def _coconut_xarray_to_numpy(obj): + import xarray + if isinstance(obj, xarray.Dataset): + return obj.to_dataframe().to_numpy() + else: + return obj.to_numpy() class MatchError(_coconut_baseclass, Exception): """Pattern-matching error. Has attributes .pattern, .value, and .message."""{COMMENT.no_slots_to_allow_setattr_below} max_val_repr_len = 500 @@ -752,8 +767,10 @@ Additionally supports Cartesian products of numpy arrays.""" if iterables: it_modules = [_coconut_get_base_module(it) for it in iterables] if _coconut.all(mod in _coconut.numpy_modules for mod in it_modules): - if _coconut.any(mod in _coconut.pandas_numpy_modules for mod in it_modules): - iterables = tuple((it.to_numpy() if _coconut_get_base_module(it) in _coconut.pandas_numpy_modules else it) for it in iterables) + if _coconut.any(mod in _coconut.xarray_modules for mod in it_modules): + iterables = tuple((_coconut_xarray_to_numpy(it) if mod in _coconut.xarray_modules else it) for it, mod in _coconut.zip(iterables, it_modules)) + if _coconut.any(mod in _coconut.pandas_modules for mod in it_modules): + iterables = tuple((it.to_numpy() if mod in _coconut.pandas_modules else it) for it, mod in _coconut.zip(iterables, it_modules)) if _coconut.any(mod in _coconut.jax_numpy_modules for mod in it_modules): from jax import numpy else: @@ -1605,7 +1622,9 @@ def fmap(func, obj, **kwargs): if result is not _coconut.NotImplemented: return result obj_module = _coconut_get_base_module(obj) - if obj_module in _coconut.pandas_numpy_modules: + if obj_module in _coconut.xarray_modules: + return {_coconut_}fmap(func, _coconut_xarray_to_pandas(obj)).to_xarray() + if obj_module in _coconut.pandas_modules: if obj.ndim <= 1: return obj.apply(func) return obj.apply(func, axis=obj.ndim-1) @@ -1941,7 +1960,9 @@ def all_equal(iterable): """ iterable_module = _coconut_get_base_module(iterable) if iterable_module in _coconut.numpy_modules: - if iterable_module in _coconut.pandas_numpy_modules: + if iterable_module in _coconut.xarray_modules: + iterable = _coconut_xarray_to_numpy(iterable) + elif iterable_module in _coconut.pandas_modules: iterable = iterable.to_numpy() return not _coconut.len(iterable) or (iterable == iterable[0]).all() first_item = _coconut_sentinel @@ -2014,8 +2035,11 @@ def _coconut_mk_anon_namedtuple(fields, types=None, of_kwargs=None): return NT return NT(**of_kwargs) def _coconut_ndim(arr): - if (_coconut_get_base_module(arr) in _coconut.numpy_modules or _coconut.hasattr(arr.__class__, "__matconcat__")) and _coconut.hasattr(arr, "ndim"): + arr_mod = _coconut_get_base_module(arr) + if (arr_mod in _coconut.numpy_modules or _coconut.hasattr(arr.__class__, "__matconcat__")) and _coconut.hasattr(arr, "ndim"): return arr.ndim + if arr_mod in _coconut.xarray_modules:{COMMENT.if_we_got_here_its_a_Dataset_not_a_DataArray} + return 2 if not _coconut.isinstance(arr, _coconut.abc.Sequence) or _coconut.isinstance(arr, (_coconut.str, _coconut.bytes)): return 0 if _coconut.len(arr) == 0: @@ -2040,23 +2064,20 @@ def _coconut_expand_arr(arr, new_dims): arr = [arr] return arr def _coconut_concatenate(arrs, axis): - matconcat = None for a in arrs: if _coconut.hasattr(a.__class__, "__matconcat__"): - matconcat = a.__class__.__matconcat__ - break - a_module = _coconut_get_base_module(a) - 
if a_module in _coconut.pandas_numpy_modules: - from pandas import concat as matconcat - break - if a_module in _coconut.jax_numpy_modules: - from jax.numpy import concatenate as matconcat - break - if a_module in _coconut.numpy_modules: - matconcat = _coconut.numpy.concatenate - break - if matconcat is not None: - return matconcat(arrs, axis=axis) + return a.__class__.__matconcat__(arrs, axis=axis) + arr_modules = [_coconut_get_base_module(a) for a in arrs] + if any(mod in _coconut.xarray_modules for mod in arr_modules): + return _coconut_concatenate([(_coconut_xarray_to_pandas(a) if mod in _coconut.xarray_modules else a) for a, mod in _coconut.zip(arrs, arr_modules)], axis).to_xarray() + if any(mod in _coconut.pandas_modules for mod in arr_modules): + import pandas + return pandas.concat(arrs, axis=axis) + if any(mod in _coconut.jax_numpy_modules for mod in arr_modules): + import jax.numpy + return jax.numpy.concatenate(arrs, axis=axis) + if any(mod in _coconut.numpy_modules for mod in arr_modules): + return _coconut.numpy.concatenate(arrs, axis=axis) if not axis: return _coconut.list(_coconut.itertools.chain.from_iterable(arrs)) return [_coconut_concatenate(rows, axis - 1) for rows in _coconut.zip(*arrs)] @@ -2209,4 +2230,4 @@ class _coconut_SupportsInv(_coconut.typing.Protocol): {def_async_map} {def_aliases} _coconut_self_match_types = {self_match_types} -_coconut_Expected, _coconut_MatchError, _coconut_cartesian_product, _coconut_count, _coconut_cycle, _coconut_enumerate, _coconut_flatten, _coconut_filter, _coconut_groupsof, _coconut_ident, _coconut_lift, _coconut_map, _coconut_mapreduce, _coconut_multiset, _coconut_range, _coconut_reiterable, _coconut_reversed, _coconut_scan, _coconut_starmap, _coconut_tee, _coconut_windowsof, _coconut_zip, _coconut_zip_longest, TYPE_CHECKING, reduce, takewhile, dropwhile = Expected, MatchError, cartesian_product, count, cycle, enumerate, flatten, filter, groupsof, ident, lift, map, mapreduce, multiset, range, reiterable, reversed, scan, starmap, tee, windowsof, zip, zip_longest, False, _coconut.functools.reduce, _coconut.itertools.takewhile, _coconut.itertools.dropwhile{COMMENT.anything_added_here_should_be_copied_to_stub_file} +_coconut_Expected, _coconut_MatchError, _coconut_cartesian_product, _coconut_count, _coconut_cycle, _coconut_enumerate, _coconut_flatten, _coconut_fmap, _coconut_filter, _coconut_groupsof, _coconut_ident, _coconut_lift, _coconut_map, _coconut_mapreduce, _coconut_multiset, _coconut_range, _coconut_reiterable, _coconut_reversed, _coconut_scan, _coconut_starmap, _coconut_tee, _coconut_windowsof, _coconut_zip, _coconut_zip_longest, TYPE_CHECKING, reduce, takewhile, dropwhile = Expected, MatchError, cartesian_product, count, cycle, enumerate, flatten, fmap, filter, groupsof, ident, lift, map, mapreduce, multiset, range, reiterable, reversed, scan, starmap, tee, windowsof, zip, zip_longest, False, _coconut.functools.reduce, _coconut.itertools.takewhile, _coconut.itertools.dropwhile{COMMENT.anything_added_here_should_be_copied_to_stub_file} diff --git a/coconut/constants.py b/coconut/constants.py index 133e8dda5..a3268f3b3 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -180,7 +180,10 @@ def get_path_env_var(env_var, default): sys.setrecursionlimit(default_recursion_limit) # modules that numpy-like arrays can live in -pandas_numpy_modules = ( +xarray_modules = ( + "xarray", +) +pandas_modules = ( "pandas", ) jax_numpy_modules = ( @@ -190,7 +193,8 @@ def get_path_env_var(env_var, default): "numpy", "torch", ) + ( - 
pandas_numpy_modules + xarray_modules + + pandas_modules + jax_numpy_modules ) @@ -999,6 +1003,7 @@ def get_path_env_var(env_var, default): ("numpy", "py34;py<39"), ("numpy", "py39"), ("pandas", "py36"), + ("xarray", "py39"), ), "tests": ( ("pytest", "py<36"), @@ -1021,6 +1026,7 @@ def get_path_env_var(env_var, default): ("trollius", "py<3;cpy"): (2, 2), "requests": (2, 31), ("numpy", "py39"): (1, 26), + ("xarray", "py39"): (2023,), ("dataclasses", "py==36"): (0, 8), ("aenum", "py<34"): (3, 1, 15), "pydata-sphinx-theme": (0, 14), diff --git a/coconut/root.py b/coconut/root.py index 0e1a95ebf..123449d7c 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 9 +DEVELOP = 10 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index 6ade7f0c8..1c5fbd7a1 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -10,6 +10,7 @@ from coconut.constants import ( PY34, PY35, PY36, + PY39, PYPY, ) # type: ignore from coconut._pyparsing import USE_COMPUTATION_GRAPH # type: ignore @@ -664,22 +665,46 @@ def test_pandas() -> bool: return True +def test_xarray() -> bool: + import xarray as xr + import numpy as np + def ds1 `dataset_equal` ds2 = (ds1 == ds2).all().values() |> all + da = xr.DataArray([10, 11;; 12, 13], dims=["x", "y"]) + ds = xr.Dataset({"a": da, "b": da + 10}) + assert ds$[0] == "a" + ds_ = [da; da + 10] + assert ds `dataset_equal` ds_ # type: ignore + ds__ = [da; da |> fmap$(.+10)] + assert ds `dataset_equal` ds__ # type: ignore + assert ds `dataset_equal` (ds |> fmap$(ident)) + assert da.to_numpy() `np.array_equal` (da |> fmap$(ident) |> .to_numpy()) + assert (ds |> fmap$(r -> r["a"] + r["b"]) |> .to_numpy()) `np.array_equal` np.array([30; 32;; 34; 36]) + assert not all_equal(da) + assert not all_equal(ds) + assert multi_enumerate(da) |> list == [((0, 0), 10), ((0, 1), 11), ((1, 0), 12), ((1, 1), 13)] + assert cartesian_product(da.sel(x=0), da.sel(x=1)) `np.array_equal` np.array([10; 12;; 10; 13;; 11; 12;; 11; 13]) # type: ignore + return True + + def test_extras() -> bool: if not PYPY and (PY2 or PY34): assert test_numpy() is True print(".", end="") if not PYPY and PY36: assert test_pandas() is True # . + print(".", end="") + if not PYPY and PY39: + assert test_xarray() is True # .. print(".") # newline bc we print stuff after this - assert test_setup_none() is True # .. + assert test_setup_none() is True # ... print(".") # ditto - assert test_convenience() is True # ... + assert test_convenience() is True # .... # everything after here uses incremental parsing, so it must come last print(".", end="") - assert test_incremental() is True # .... + assert test_incremental() is True # ..... if IPY: print(".", end="") - assert test_kernel() is True # ..... + assert test_kernel() is True # ...... return True From 32ca30626b11ae1f9129ce1e0cb52b25507afe46 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Thu, 21 Dec 2023 20:17:51 -0800 Subject: [PATCH 21/54] Add to arg to all_equal Resolves #817. 
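The `test_xarray` additions above exercise the new integration end to end; distilled down, the behaviors they rely on look roughly like this (requires `xarray` and `numpy`; the values simply mirror the test data):

```coconut
import xarray as xr
import numpy as np

da = xr.DataArray([10, 11;; 12, 13], dims=["x", "y"])

# fmap round-trips xarray objects through pandas (to_series -> apply -> to_xarray)
assert (da |> fmap$(.+10) |> .to_numpy()) `np.array_equal` (da + 10).to_numpy()

# other numpy-aware built-ins now accept xarray objects as well
assert not all_equal(da)
assert multi_enumerate(da) |> list == [((0, 0), 10), ((0, 1), 11), ((1, 0), 12), ((1, 1), 13)]
```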
--- DOCS.md | 10 ++++++++-- __coconut__/__init__.pyi | 2 +- coconut/compiler/templates/header.py_template | 7 ++++--- coconut/root.py | 2 +- coconut/tests/src/cocotest/agnostic/primary_2.coco | 3 +++ coconut/tests/src/extras.coco | 3 +++ 6 files changed, 20 insertions(+), 7 deletions(-) diff --git a/DOCS.md b/DOCS.md index 9c1671d6c..139b798ff 100644 --- a/DOCS.md +++ b/DOCS.md @@ -4210,9 +4210,15 @@ _Can't be done without the definition of `windowsof`; see the compiled header fo #### `all_equal` -**all\_equal**(_iterable_) +**all\_equal**(_iterable_, _to_=`...`) -Coconut's `all_equal` built-in takes in an iterable and determines whether all of its elements are equal to each other. `all_equal` assumes transitivity of equality and that `!=` is the negation of `==`. Special support is provided for [`numpy`](#numpy-integration) objects. +Coconut's `all_equal` built-in takes in an iterable and determines whether all of its elements are equal to each other. + +If _to_ is passed, `all_equal` will check that all the elements are specifically equal to that value, rather than just equal to each other. + +Note that `all_equal` assumes transitivity of equality, that `!=` is the negation of `==`, and that empty arrays always have all their elements equal. + +Special support is provided for [`numpy`](#numpy-integration) objects. ##### Example diff --git a/__coconut__/__init__.pyi b/__coconut__/__init__.pyi index 007cdcfab..5f675ea51 100644 --- a/__coconut__/__init__.pyi +++ b/__coconut__/__init__.pyi @@ -1680,7 +1680,7 @@ def lift_apart(func: _t.Callable[..., _W]) -> _t.Callable[..., _t.Callable[..., ... -def all_equal(iterable: _Iterable) -> bool: +def all_equal(iterable: _t.Iterable[_T], to: _T = ...) -> bool: """For a given iterable, check whether all elements in that iterable are equal to each other. Supports numpy arrays. Assumes transitivity and 'x != y' being equivalent to 'not (x == y)'. diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index f67cd339d..cdf766aee 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -1953,8 +1953,9 @@ class lift_apart(lift): lift_apart also supports a shortcut form such that lift_apart(f, *func_args, **func_kwargs) is equivalent to lift_apart(f)(*func_args, **func_kwargs). """ _apart = True -def all_equal(iterable): +def all_equal(iterable, to=_coconut_sentinel): """For a given iterable, check whether all elements in that iterable are equal to each other. + If 'to' is passed, check that all the elements are equal to that value. Supports numpy arrays. Assumes transitivity and 'x != y' being equivalent to 'not (x == y)'. 
""" @@ -1964,8 +1965,8 @@ def all_equal(iterable): iterable = _coconut_xarray_to_numpy(iterable) elif iterable_module in _coconut.pandas_modules: iterable = iterable.to_numpy() - return not _coconut.len(iterable) or (iterable == iterable[0]).all() - first_item = _coconut_sentinel + return not _coconut.len(iterable) or (iterable == (iterable[0] if to is _coconut_sentinel else to)).all() + first_item = to for item in iterable: if first_item is _coconut_sentinel: first_item = item diff --git a/coconut/root.py b/coconut/root.py index 123449d7c..57bc138f1 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 10 +DEVELOP = 11 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/src/cocotest/agnostic/primary_2.coco b/coconut/tests/src/cocotest/agnostic/primary_2.coco index b562ee789..ccee37e55 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_2.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_2.coco @@ -421,6 +421,9 @@ def primary_test_2() -> bool: arr |>= [[3; 4] ;; .] assert arr == [3; 4;; 1; 2] == [[3; 4] ;; .] |> call$(?, [. ; 2] |> call$(?, 1)) assert (if)(10, 20, 30) == 20 == (if)(0, 10, 20) + assert all_equal([], to=10) + assert all_equal([10; 10; 10; 10], to=10) + assert not all_equal([1, 1], to=10) with process_map.multiple_sequential_calls(): # type: ignore assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index 1c5fbd7a1..87d0a701c 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -604,6 +604,9 @@ def test_numpy() -> bool: assert all_equal(np.array([1, 1])) assert all_equal(np.array([1, 1;; 1, 1])) assert not all_equal(np.array([1, 1;; 1, 2])) + assert all_equal(np.array([]), to=10) + assert all_equal(np.array([10; 10;; 10; 10]), to=10) + assert not all_equal(np.array([1, 1]), to=10) assert ( cartesian_product(np.array([1, 2]), np.array([3, 4])) `np.array_equal` From d6d9e5103c989a8315a2d25db88fa75afcc1062e Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Fri, 22 Dec 2023 02:19:21 -0800 Subject: [PATCH 22/54] Add line-by-line parsing Refs #815. 
--- coconut/_pyparsing.py | 7 +- coconut/compiler/compiler.py | 163 +++++++++++++++++++++++----------- coconut/compiler/grammar.py | 25 +++--- coconut/compiler/util.py | 145 ++++++++++++++++++------------ coconut/constants.py | 4 +- coconut/exceptions.py | 28 +++--- coconut/root.py | 2 +- coconut/tests/src/extras.coco | 22 +++-- coconut/util.py | 3 + 9 files changed, 254 insertions(+), 145 deletions(-) diff --git a/coconut/_pyparsing.py b/coconut/_pyparsing.py index c973208b5..6d08487a6 100644 --- a/coconut/_pyparsing.py +++ b/coconut/_pyparsing.py @@ -49,6 +49,7 @@ warn_on_multiline_regex, num_displayed_timing_items, use_cache_file, + use_line_by_line_parser, ) from coconut.util import get_clock_time # NOQA from coconut.util import ( @@ -183,7 +184,6 @@ def _parseCache(self, instring, loc, doActions=True, callPreParse=True): if isinstance(value, Exception): raise value return value[0], value[1].copy() - ParserElement._parseCache = _parseCache # [CPYPARSING] fix append @@ -249,11 +249,12 @@ def enableIncremental(*args, **kwargs): ) SUPPORTS_ADAPTIVE = ( - hasattr(MatchFirst, "setAdaptiveMode") - and USE_COMPUTATION_GRAPH + USE_COMPUTATION_GRAPH + and hasattr(MatchFirst, "setAdaptiveMode") ) USE_CACHE = SUPPORTS_INCREMENTAL and use_cache_file +USE_LINE_BY_LINE = USE_COMPUTATION_GRAPH and use_line_by_line_parser if MODERN_PYPARSING: _trim_arity = _pyparsing.core._trim_arity diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index e9eee011c..8d7ad0058 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -40,6 +40,7 @@ from coconut._pyparsing import ( USE_COMPUTATION_GRAPH, USE_CACHE, + USE_LINE_BY_LINE, ParseBaseException, ParseResults, col as getcol, @@ -181,6 +182,7 @@ pickle_cache, handle_and_manage, sub_all, + ComputationNode, ) from coconut.compiler.header import ( minify_header, @@ -602,6 +604,7 @@ def reset(self, keep_state=False, filename=None): self.add_code_before_regexes = {} self.add_code_before_replacements = {} self.add_code_before_ignore_names = {} + self.remaining_original = None @contextmanager def inner_environment(self, ln=None): @@ -618,8 +621,10 @@ def inner_environment(self, ln=None): parsing_context, self.parsing_context = self.parsing_context, defaultdict(list) kept_lines, self.kept_lines = self.kept_lines, [] num_lines, self.num_lines = self.num_lines, 0 + remaining_original, self.remaining_original = self.remaining_original, None try: - yield + with ComputationNode.using_overrides(): + yield finally: self.outer_ln = outer_ln self.line_numbers = line_numbers @@ -631,6 +636,7 @@ def inner_environment(self, ln=None): self.parsing_context = parsing_context self.kept_lines = kept_lines self.num_lines = num_lines + self.remaining_original = remaining_original def current_parsing_context(self, name, default=None): """Get the current parsing context for the given name.""" @@ -696,15 +702,15 @@ def method(cls, method_name, is_action=None, **kwargs): trim_arity = should_trim_arity(cls_method) if is_action else False @wraps(cls_method) - def method(original, loc, tokens): + def method(original, loc, tokens_or_item): self_method = getattr(cls.current_compiler, method_name) if kwargs: self_method = partial(self_method, **kwargs) if trim_arity: self_method = _trim_arity(self_method) - return self_method(original, loc, tokens) + return self_method(original, loc, tokens_or_item) internal_assert( - hasattr(cls_method, "ignore_tokens") is hasattr(method, "ignore_tokens") + hasattr(cls_method, "ignore_arguments") is hasattr(method, 
"ignore_arguments") and hasattr(cls_method, "ignore_no_tokens") is hasattr(method, "ignore_no_tokens") and hasattr(cls_method, "ignore_one_token") is hasattr(method, "ignore_one_token"), "failed to properly wrap method", @@ -1163,7 +1169,7 @@ def target_info(self): """Return information on the current target as a version tuple.""" return get_target_info(self.target) - def make_err(self, errtype, message, original, loc=0, ln=None, extra=None, reformat=True, endpoint=None, include_causes=False, **kwargs): + def make_err(self, errtype, message, original, loc=0, ln=None, extra=None, reformat=True, endpoint=None, include_causes=False, use_startpoint=False, **kwargs): """Generate an error of the specified type.""" logger.log_loc("raw_loc", original, loc) logger.log_loc("raw_endpoint", original, endpoint) @@ -1173,13 +1179,19 @@ def make_err(self, errtype, message, original, loc=0, ln=None, extra=None, refor logger.log_loc("loc", original, loc) # get endpoint + startpoint = None if endpoint is None: endpoint = reformat if endpoint is False: endpoint = loc else: if endpoint is True: - endpoint = get_highest_parse_loc(original) + if self.remaining_original is None: + endpoint = get_highest_parse_loc(original) + else: + startpoint = ComputationNode.add_to_loc + raw_endpoint = get_highest_parse_loc(self.remaining_original) + endpoint = startpoint + raw_endpoint logger.log_loc("highest_parse_loc", original, endpoint) endpoint = clip( move_endpt_to_non_whitespace(original, endpoint, backwards=True), @@ -1187,6 +1199,40 @@ def make_err(self, errtype, message, original, loc=0, ln=None, extra=None, refor ) logger.log_loc("endpoint", original, endpoint) + # process startpoint + if startpoint is not None: + startpoint = move_loc_to_non_whitespace(original, startpoint) + logger.log_loc("startpoint", original, startpoint) + + # determine possible causes + if include_causes: + self.internal_assert(extra is None, original, loc, "make_err cannot include causes with extra") + causes = dictset() + for check_loc in dictset((loc, endpoint, startpoint)): + if check_loc is not None: + for cause, _, _ in all_matches(self.parse_err_msg, original[check_loc:], inner=True): + if cause: + causes.add(cause) + if causes: + extra = "possible cause{s}: {causes}".format( + s="s" if len(causes) > 1 else "", + causes=", ".join(ordered(causes)), + ) + else: + extra = None + + # use startpoint if appropriate + if startpoint is None: + use_startpoint = False + else: + if use_startpoint is None: + use_startpoint = ( + "\n" not in original[loc:endpoint] + and "\n" in original[startpoint:loc] + ) + if use_startpoint: + loc = startpoint + # get line number if ln is None: if self.outer_ln is None: @@ -1208,33 +1254,19 @@ def make_err(self, errtype, message, original, loc=0, ln=None, extra=None, refor logger.log_loc("loc_in_snip", snippet, loc_in_snip) logger.log_loc("endpt_in_snip", snippet, endpt_in_snip) - # determine possible causes - if include_causes: - self.internal_assert(extra is None, original, loc, "make_err cannot include causes with extra") - causes = dictset() - for cause, _, _ in all_matches(self.parse_err_msg, snippet[loc_in_snip:]): - if cause: - causes.add(cause) - for cause, _, _ in all_matches(self.parse_err_msg, snippet[endpt_in_snip:]): - if cause: - causes.add(cause) - if causes: - extra = "possible cause{s}: {causes}".format( - s="s" if len(causes) > 1 else "", - causes=", ".join(ordered(causes)), - ) - else: - extra = None - # reformat the snippet and fix error locations to match if reformat: snippet, loc_in_snip, 
endpt_in_snip = self.reformat_locs(snippet, loc_in_snip, endpt_in_snip) logger.log_loc("reformatted_loc", snippet, loc_in_snip) logger.log_loc("reformatted_endpt", snippet, endpt_in_snip) + # build the error if extra is not None: kwargs["extra"] = extra - return errtype(message, snippet, loc_in_snip, ln, endpoint=endpt_in_snip, filename=self.filename, **kwargs) + err = errtype(message, snippet, loc_in_snip, ln, endpoint=endpt_in_snip, filename=self.filename, **kwargs) + if use_startpoint: + err = err.set_formatting(point_to_endpoint=True, max_err_msg_lines=2) + return err def make_syntax_err(self, err, original, after_parsing=False): """Make a CoconutSyntaxError from a CoconutDeferredSyntaxError.""" @@ -1247,7 +1279,7 @@ def make_parse_err(self, err, msg=None, include_ln=True, **kwargs): loc = err.loc ln = self.adjust(err.lineno) if include_ln else None - return self.make_err(CoconutParseError, msg, original, loc, ln, include_causes=True, **kwargs) + return self.make_err(CoconutParseError, msg, original, loc, ln, include_causes=True, use_startpoint=None, **kwargs) def make_internal_syntax_err(self, original, loc, msg, item, extra): """Make a CoconutInternalSyntaxError.""" @@ -1289,23 +1321,24 @@ def parsing(self, keep_state=False, codepath=None): Compiler.current_compiler = self yield - def streamline(self, grammar, inputstring=None, force=False, inner=False): - """Streamline the given grammar for the given inputstring.""" - input_len = 0 if inputstring is None else len(inputstring) - if force or (streamline_grammar_for_len is not None and input_len > streamline_grammar_for_len): - start_time = get_clock_time() - prep_grammar(grammar, streamline=True) - logger.log_lambda( - lambda: "Streamlined {grammar} in {time} seconds{info}.".format( - grammar=get_name(grammar), - time=get_clock_time() - start_time, - info="" if inputstring is None else " (streamlined due to receiving input of length {length})".format( - length=input_len, + def streamline(self, grammars, inputstring=None, force=False, inner=False): + """Streamline the given grammar(s) for the given inputstring.""" + for grammar in grammars if isinstance(grammars, tuple) else (grammars,): + input_len = 0 if inputstring is None else len(inputstring) + if force or (streamline_grammar_for_len is not None and input_len > streamline_grammar_for_len): + start_time = get_clock_time() + prep_grammar(grammar, streamline=True) + logger.log_lambda( + lambda: "Streamlined {grammar} in {time} seconds{info}.".format( + grammar=get_name(grammar), + time=get_clock_time() - start_time, + info="" if inputstring is None else " (streamlined due to receiving input of length {length})".format( + length=input_len, + ), ), - ), - ) - elif inputstring is not None and not inner: - logger.log("No streamlining done for input of length {length}.".format(length=input_len)) + ) + elif inputstring is not None and not inner: + logger.log("No streamlining done for input of length {length}.".format(length=input_len)) def run_final_checks(self, original, keep_state=False): """Run post-parsing checks to raise any necessary errors/warnings.""" @@ -1323,6 +1356,32 @@ def run_final_checks(self, original, keep_state=False): endpoint=False, ) + def parse_line_by_line(self, init_parser, line_parser, original): + """Apply init_parser then line_parser repeatedly.""" + if not USE_LINE_BY_LINE: + raise CoconutException("line-by-line parsing not supported", extra="run 'pip install --upgrade cPyparsing' to fix") + with ComputationNode.using_overrides(): + 
ComputationNode.override_original = original + out_parts = [] + init = True + cur_loc = 0 + while cur_loc < len(original): + self.remaining_original = original[cur_loc:] + ComputationNode.add_to_loc = cur_loc + results = parse(init_parser if init else line_parser, self.remaining_original, inner=False) + if len(results) == 1: + got_loc, = results + else: + got, got_loc = results + out_parts.append(got) + got_loc = int(got_loc) + internal_assert(got_loc >= cur_loc, "invalid line by line parse", (cur_loc, results), extra=lambda: "in: " + repr(self.remaining_original.split("\n", 1)[0])) + if not init and got_loc == cur_loc: + raise self.make_err(CoconutParseError, "parsing could not continue", original, cur_loc, include_causes=True) + cur_loc = got_loc + init = False + return "".join(out_parts) + def parse( self, inputstring, @@ -1352,7 +1411,11 @@ def parse( with logger.gather_parsing_stats(): try: pre_procd = self.pre(inputstring, keep_state=keep_state, **preargs) - parsed = parse(parser, pre_procd, inner=False) + if isinstance(parser, tuple): + init_parser, line_parser = parser + parsed = self.parse_line_by_line(init_parser, line_parser, pre_procd) + else: + parsed = parse(parser, pre_procd, inner=False) out = self.post(parsed, keep_state=keep_state, **postargs) except ParseBaseException as err: raise self.make_parse_err(err) @@ -1817,7 +1880,7 @@ def ind_proc(self, inputstring, **kwargs): original=line, ln=self.adjust(len(new)), **err_kwargs - ).set_point_to_endpoint(True) + ).set_formatting(point_to_endpoint=True) self.set_skips(skips) if new: @@ -2053,7 +2116,7 @@ def split_docstring(self, block): pass else: raw_first_line = split_leading_trailing_indent(rem_comment(first_line))[1] - if match_in(self.just_a_string, raw_first_line): + if match_in(self.just_a_string, raw_first_line, inner=True): return first_line, rest_of_lines return None, block @@ -4098,7 +4161,7 @@ def get_generic_for_typevars(self): return "_coconut.typing.Generic[" + ", ".join(generics) + "]" @contextmanager - def type_alias_stmt_manage(self, item=None, original=None, loc=None): + def type_alias_stmt_manage(self, original=None, loc=None, item=None): """Manage the typevars parsing context.""" prev_typevar_info = self.current_parsing_context("typevars") with self.add_to_parsing_context("typevars", { @@ -4132,7 +4195,7 @@ def where_item_handle(self, tokens): return tokens @contextmanager - def where_stmt_manage(self, item, original, loc): + def where_stmt_manage(self, original, loc, item): """Manage where statements.""" with self.add_to_parsing_context("where", { "assigns": None, @@ -4187,7 +4250,7 @@ def ellipsis_handle(self, tokens=None): else: return "_coconut.Ellipsis" - ellipsis_handle.ignore_tokens = True + ellipsis_handle.ignore_arguments = True def match_case_tokens(self, match_var, check_var, original, tokens, top): """Build code for matching the given case.""" @@ -4634,7 +4697,7 @@ def check_py(self, version, name, original, loc, tokens): return tokens[0] @contextmanager - def class_manage(self, item, original, loc): + def class_manage(self, original, loc, item): """Manage the class parsing context.""" cls_stack = self.parsing_context["class"] if cls_stack: @@ -4660,7 +4723,7 @@ def class_manage(self, item, original, loc): cls_stack.pop() @contextmanager - def func_manage(self, item, original, loc): + def func_manage(self, original, loc, item): """Manage the function parsing context.""" cls_context = self.current_parsing_context("class") if cls_context is not None: diff --git a/coconut/compiler/grammar.py 
b/coconut/compiler/grammar.py index 0b7497481..76f9dc8f9 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -32,6 +32,7 @@ from functools import partial from coconut._pyparsing import ( + USE_LINE_BY_LINE, Forward, Group, Literal, @@ -2472,12 +2473,18 @@ class Grammar(object): line = newline | stmt - single_input = condense(Optional(line) - ZeroOrMore(newline)) file_input = condense(moduledoc_marker - ZeroOrMore(line)) + raw_file_parser = start_marker - file_input - end_marker + line_by_line_file_parser = ( + start_marker - moduledoc_marker - stores_loc_item, + start_marker - line - stores_loc_item, + ) + file_parser = line_by_line_file_parser if USE_LINE_BY_LINE else raw_file_parser + + single_input = condense(Optional(line) - ZeroOrMore(newline)) eval_input = condense(testlist - ZeroOrMore(newline)) single_parser = start_marker - single_input - end_marker - file_parser = start_marker - file_input - end_marker eval_parser = start_marker - eval_input - end_marker some_eval_parser = start_marker + eval_input @@ -2637,14 +2644,9 @@ class Grammar(object): unsafe_equals = Literal("=") - kwd_err_msg = attach(any_keyword_in(keyword_vars + reserved_vars), kwd_err_msg_handle) - parse_err_msg = ( - start_marker + ( - fixto(end_of_line, "misplaced newline (maybe missing ':')") - | fixto(Optional(keyword("if") + skip_to_in_line(unsafe_equals)) + equals, "misplaced assignment (maybe should be '==')") - | kwd_err_msg - ) - | fixto( + parse_err_msg = start_marker + ( + # should be in order of most likely to actually be the source of the error first + ZeroOrMore(~questionmark + ~Literal("\n") + any_char) + fixto( questionmark + ~dollar + ~lparen @@ -2652,6 +2654,9 @@ class Grammar(object): + ~dot, "misplaced '?' (naked '?' is only supported inside partial application arguments)", ) + | fixto(Optional(keyword("if") + skip_to_in_line(unsafe_equals)) + equals, "misplaced assignment (maybe should be '==')") + | attach(any_keyword_in(keyword_vars + reserved_vars), kwd_err_msg_handle) + | fixto(end_of_line, "misplaced newline (maybe missing ':')") ) end_f_str_expr = combine(start_marker + (rbrace | colon | bang)) diff --git a/coconut/compiler/util.py b/coconut/compiler/util.py index 760cf6bd1..b94bafb14 100644 --- a/coconut/compiler/util.py +++ b/coconut/compiler/util.py @@ -264,6 +264,19 @@ class ComputationNode(object): """A single node in the computation graph.""" __slots__ = ("action", "original", "loc", "tokens") pprinting = False + override_original = None + add_to_loc = 0 + + @classmethod + @contextmanager + def using_overrides(cls): + override_original, cls.override_original = cls.override_original, None + add_to_loc, cls.add_to_loc = cls.add_to_loc, 0 + try: + yield + finally: + cls.override_original = override_original + cls.add_to_loc = add_to_loc def __new__(cls, action, original, loc, tokens, ignore_no_tokens=False, ignore_one_token=False, greedy=False, trim_arity=True): """Create a ComputionNode to return from a parse action. 
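The machinery above (the split `file_parser` in the grammar plus `ComputationNode.override_original`/`add_to_loc`) feeds the `parse_line_by_line` driver added to the compiler earlier in this patch: parse the remaining source one top-level piece at a time, require forward progress, and shift locally reported locations by the amount already consumed so errors still point into the original file. A simplified, self-contained sketch of that loop (with a hypothetical `parse_one` helper standing in for the actual pyparsing-based parsers):

```coconut
def parse_chunks(parse_one, source):
    """parse_one(text) -> (compiled_chunk, chars_consumed); hypothetical helper."""
    out = []
    cur_loc = 0  # plays the role of add_to_loc: everything before this offset is already parsed
    while cur_loc < len(source):
        chunk, consumed = parse_one(source[cur_loc:])
        if consumed <= 0:  # no forward progress, so report an error at the absolute offset
            raise SyntaxError("parsing could not continue at offset {}".format(cur_loc))
        out.append(chunk)
        cur_loc += consumed
    return "".join(out)
```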
@@ -281,8 +294,8 @@ def __new__(cls, action, original, loc, tokens, ignore_no_tokens=False, ignore_o self.action = _trim_arity(action) else: self.action = action - self.original = original - self.loc = loc + self.original = original if self.override_original is None else self.override_original + self.loc = self.add_to_loc + loc self.tokens = tokens if greedy: return self.evaluate() @@ -391,12 +404,38 @@ def add_action(item, action, make_copy=None): return item.addParseAction(action) -def attach(item, action, ignore_no_tokens=None, ignore_one_token=None, ignore_tokens=None, trim_arity=None, make_copy=None, **kwargs): +def get_func_args(func): + """Inspect a function to determine its argument names.""" + if PY2: + return inspect.getargspec(func)[0] + else: + return inspect.getfullargspec(func)[0] + + +def should_trim_arity(func): + """Determine if we need to call _trim_arity on func.""" + annotation = getattr(func, "trim_arity", None) + if annotation is not None: + return annotation + try: + func_args = get_func_args(func) + except TypeError: + return True + if not func_args: + return True + if func_args[0] == "self": + func_args.pop(0) + if func_args[:3] == ["original", "loc", "tokens"]: + return False + return True + + +def attach(item, action, ignore_no_tokens=None, ignore_one_token=None, ignore_arguments=None, trim_arity=None, make_copy=None, **kwargs): """Set the parse action for the given item to create a node in the computation graph.""" - if ignore_tokens is None: - ignore_tokens = getattr(action, "ignore_tokens", False) - # if ignore_tokens, then we can just pass in the computation graph and have it be ignored - if not ignore_tokens and USE_COMPUTATION_GRAPH: + if ignore_arguments is None: + ignore_arguments = getattr(action, "ignore_arguments", False) + # if ignore_arguments, then we can just pass in the computation graph and have it be ignored + if not ignore_arguments and USE_COMPUTATION_GRAPH: # use the action's annotations to generate the defaults if ignore_no_tokens is None: ignore_no_tokens = getattr(action, "ignore_no_tokens", False) @@ -422,7 +461,7 @@ def final_evaluate_tokens(tokens): @contextmanager -def adaptive_manager(item, original, loc, reparse=False): +def adaptive_manager(original, loc, item, reparse=False): """Manage the use of MatchFirst.setAdaptiveMode.""" if reparse: cleared_cache = clear_packrat_cache() @@ -489,11 +528,22 @@ def force_reset_packrat_cache(): @contextmanager -def parsing_context(inner_parse=True): +def parsing_context(inner_parse=None): """Context to manage the packrat cache across parse calls.""" - if not inner_parse: - yield - elif should_clear_cache(): + current_cache_matters = ParserElement._packratEnabled + new_cache_matters = ( + not inner_parse + and ParserElement._incrementalEnabled + and not ParserElement._incrementalWithResets + ) + will_clear_cache = ( + not ParserElement._incrementalEnabled + or ParserElement._incrementalWithResets + ) + if ( + current_cache_matters + and not new_cache_matters + ): # store old packrat cache old_cache = ParserElement.packrat_cache old_cache_stats = ParserElement.packrat_cache_stats[:] @@ -507,8 +557,11 @@ def parsing_context(inner_parse=True): if logger.verbose: ParserElement.packrat_cache_stats[0] += old_cache_stats[0] ParserElement.packrat_cache_stats[1] += old_cache_stats[1] - # if we shouldn't clear the cache, but we're using incrementalWithResets, then do this to avoid clearing it - elif ParserElement._incrementalWithResets: + elif ( + current_cache_matters + and new_cache_matters + and 
will_clear_cache + ): incrementalWithResets, ParserElement._incrementalWithResets = ParserElement._incrementalWithResets, False try: yield @@ -529,7 +582,7 @@ def prep_grammar(grammar, streamline=False): return grammar.parseWithTabs() -def parse(grammar, text, inner=True, eval_parse_tree=True): +def parse(grammar, text, inner=None, eval_parse_tree=True): """Parse text using grammar.""" with parsing_context(inner): result = prep_grammar(grammar).parseString(text) @@ -538,7 +591,7 @@ def parse(grammar, text, inner=True, eval_parse_tree=True): return result -def try_parse(grammar, text, inner=True, eval_parse_tree=True): +def try_parse(grammar, text, inner=None, eval_parse_tree=True): """Attempt to parse text using grammar else None.""" try: return parse(grammar, text, inner, eval_parse_tree) @@ -546,12 +599,12 @@ def try_parse(grammar, text, inner=True, eval_parse_tree=True): return None -def does_parse(grammar, text, inner=True): +def does_parse(grammar, text, inner=None): """Determine if text can be parsed using grammar.""" return try_parse(grammar, text, inner, eval_parse_tree=False) -def all_matches(grammar, text, inner=True, eval_parse_tree=True): +def all_matches(grammar, text, inner=None, eval_parse_tree=True): """Find all matches for grammar in text.""" with parsing_context(inner): for tokens, start, stop in prep_grammar(grammar).scanString(text): @@ -560,21 +613,21 @@ def all_matches(grammar, text, inner=True, eval_parse_tree=True): yield tokens, start, stop -def parse_where(grammar, text, inner=True): +def parse_where(grammar, text, inner=None): """Determine where the first parse is.""" for tokens, start, stop in all_matches(grammar, text, inner, eval_parse_tree=False): return start, stop return None, None -def match_in(grammar, text, inner=True): +def match_in(grammar, text, inner=None): """Determine if there is a match for grammar anywhere in text.""" start, stop = parse_where(grammar, text, inner) internal_assert((start is None) == (stop is None), "invalid parse_where results", (start, stop)) return start is not None -def transform(grammar, text, inner=True): +def transform(grammar, text, inner=None): """Transform text by replacing matches to grammar.""" with parsing_context(inner): result = prep_grammar(add_action(grammar, unpack)).transformString(text) @@ -844,11 +897,18 @@ def get_cache_items_for(original, only_useful=False, exclude_stale=True): yield lookup, value -def get_highest_parse_loc(original): +def get_highest_parse_loc(original, only_successes=False): """Get the highest observed parse location.""" - # find the highest observed parse location highest_loc = 0 for lookup, _ in get_cache_items_for(original): + if only_successes: + if SUPPORTS_INCREMENTAL and ParserElement._incrementalEnabled: + # parseIncremental failure + if lookup[1] is True: + continue + # parseCache failure + elif not isinstance(lookup, tuple): + continue loc = lookup[2] if loc > highest_loc: highest_loc = loc @@ -1179,7 +1239,7 @@ def parseImpl(self, original, loc, *args, **kwargs): reparse = False parse_loc = None while parse_loc is None: # lets wrapper catch errors to trigger a reparse - with self.wrapper(self, original, loc, **(dict(reparse=True) if reparse else {})): + with self.wrapper(original, loc, self, **(dict(reparse=True) if reparse else {})): with self.wrapped_context(): parse_loc, tokens = super(Wrap, self).parseImpl(original, loc, *args, **kwargs) if self.greedy: @@ -1215,7 +1275,7 @@ def disable_inside(item, *elems, **kwargs): level = [0] # number of wrapped items deep we are; in a 
list to allow modification @contextmanager - def manage_item(self, original, loc): + def manage_item(original, loc, self): level[0] += 1 try: yield @@ -1225,7 +1285,7 @@ def manage_item(self, original, loc): yield Wrap(item, manage_item, include_in_packrat_context=True) @contextmanager - def manage_elem(self, original, loc): + def manage_elem(original, loc, self): if level[0] == 0 if not _invert else level[0] > 0: yield else: @@ -1259,7 +1319,7 @@ def invalid_syntax(item, msg, **kwargs): def invalid_syntax_handle(loc, tokens): raise CoconutDeferredSyntaxError(msg, loc) - return attach(item, invalid_syntax_handle, ignore_tokens=True, **kwargs) + return attach(item, invalid_syntax_handle, ignore_arguments=True, **kwargs) def skip_to_in_line(item): @@ -1303,7 +1363,7 @@ def regex_item(regex, options=None): def fixto(item, output): """Force an item to result in a specific output.""" - return attach(item, replaceWith(output), ignore_tokens=True) + return attach(item, replaceWith(output), ignore_arguments=True) def addspace(item): @@ -1414,9 +1474,6 @@ def stores_loc_action(loc, tokens): return str(loc) -stores_loc_action.ignore_tokens = True - - always_match = Empty() stores_loc_item = attach(always_match, stores_loc_action) @@ -1883,32 +1940,6 @@ def literal_eval(py_code): raise CoconutInternalException("failed to literal eval", py_code) -def get_func_args(func): - """Inspect a function to determine its argument names.""" - if PY2: - return inspect.getargspec(func)[0] - else: - return inspect.getfullargspec(func)[0] - - -def should_trim_arity(func): - """Determine if we need to call _trim_arity on func.""" - annotation = getattr(func, "trim_arity", None) - if annotation is not None: - return annotation - try: - func_args = get_func_args(func) - except TypeError: - return True - if not func_args: - return True - if func_args[0] == "self": - func_args.pop(0) - if func_args[:3] == ["original", "loc", "tokens"]: - return False - return True - - def sequential_split(inputstr, splits): """Slice off parts of inputstr by sequential splits.""" out = [inputstr] diff --git a/coconut/constants.py b/coconut/constants.py index a3268f3b3..1ebf42bc1 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -131,6 +131,8 @@ def get_path_env_var(env_var, default): # below constants are experimentally determined to maximize performance +use_line_by_line_parser = True + use_packrat_parser = True # True also gives us better error messages packrat_cache_size = None # only works because final() clears the cache @@ -323,7 +325,7 @@ def get_path_env_var(env_var, default): taberrfmt = 2 # spaces to indent exceptions min_squiggles_in_err_msg = 1 -max_err_msg_lines = 10 +default_max_err_msg_lines = 10 # for pattern-matching default_matcher_style = "python warn" diff --git a/coconut/exceptions.py b/coconut/exceptions.py index c431a70e7..1f3eb1730 100644 --- a/coconut/exceptions.py +++ b/coconut/exceptions.py @@ -22,6 +22,7 @@ import traceback from coconut._pyparsing import ( + USE_LINE_BY_LINE, lineno, col as getcol, ) @@ -30,7 +31,7 @@ taberrfmt, report_this_text, min_squiggles_in_err_msg, - max_err_msg_lines, + default_max_err_msg_lines, ) from coconut.util import ( pickleable_obj, @@ -90,7 +91,6 @@ class CoconutException(BaseCoconutException, Exception): class CoconutSyntaxError(CoconutException): """Coconut SyntaxError.""" - point_to_endpoint = False argnames = ("message", "source", "point", "ln", "extra", "endpoint", "filename") def __init__(self, message, source=None, point=None, ln=None, extra=None, 
endpoint=None, filename=None): @@ -102,6 +102,17 @@ def kwargs(self): """Get the arguments as keyword arguments.""" return dict(zip(self.argnames, self.args)) + point_to_endpoint = False + max_err_msg_lines = default_max_err_msg_lines + + def set_formatting(self, point_to_endpoint=None, max_err_msg_lines=None): + """Sets formatting values.""" + if point_to_endpoint is not None: + self.point_to_endpoint = point_to_endpoint + if max_err_msg_lines is not None: + self.max_err_msg_lines = max_err_msg_lines + return self + def message(self, message, source, point, ln, extra=None, endpoint=None, filename=None): """Creates a SyntaxError-like message.""" message_parts = ["parsing failed" if message is None else message] @@ -195,11 +206,11 @@ def message(self, message, source, point, ln, extra=None, endpoint=None, filenam # add code, highlighting all of it together code_parts = [] - if len(lines) > max_err_msg_lines: - for i in range(max_err_msg_lines // 2): + if len(lines) > self.max_err_msg_lines: + for i in range(self.max_err_msg_lines // 2): code_parts += ["\n", " " * taberrfmt, lines[i]] code_parts += ["\n", " " * (taberrfmt // 2), "..."] - for i in range(len(lines) - max_err_msg_lines // 2, len(lines)): + for i in range(len(lines) - self.max_err_msg_lines // 2, len(lines)): code_parts += ["\n", " " * taberrfmt, lines[i]] else: for line in lines: @@ -235,11 +246,6 @@ def syntax_err(self): err.filename = filename return err - def set_point_to_endpoint(self, point_to_endpoint): - """Sets whether to point to the endpoint.""" - self.point_to_endpoint = point_to_endpoint - return self - class CoconutStyleError(CoconutSyntaxError): """Coconut --strict error.""" @@ -268,7 +274,7 @@ def message(self, message, source, point, ln, target, endpoint, filename): class CoconutParseError(CoconutSyntaxError): """Coconut ParseError.""" - point_to_endpoint = True + point_to_endpoint = not USE_LINE_BY_LINE class CoconutWarning(CoconutException): diff --git a/coconut/root.py b/coconut/root.py index 57bc138f1..ffe859e89 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 11 +DEVELOP = 12 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index 87d0a701c..edcc5dffa 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -208,13 +208,11 @@ cannot reassign type variable 'T' (use explicit '\T' syntax if intended) (line 1 assert_raises(-> parse("$"), CoconutParseError) assert_raises(-> parse("@"), CoconutParseError) assert_raises(-> parse("range(1,10) |> reduce$(*, initializer = 1000) |> print"), CoconutParseError, err_has=( - " \\~~~~~~~~~~~~~~~~~~~~~~~^", - " \\~~~~~~~~~~~~^", + "\n \~~^", )) - assert_raises(-> parse("a := b"), CoconutParseError, err_has=" \\~^") + assert_raises(-> parse("a := b"), CoconutParseError, err_has="\n ^") assert_raises(-> parse("1 + return"), CoconutParseError, err_has=( - " \\~~~^", - " \\~~~~^", + "\n \~~^", )) assert_raises(-> parse(""" def f() = @@ -231,19 +229,19 @@ def f() = ~^ """.strip() )) - assert_raises(-> parse('b"abc" "def"'), CoconutParseError, err_has=" \\~~~~~~^") - assert_raises(-> parse('"abc" b"def"'), CoconutParseError, err_has=" \\~~~~~^") - assert_raises(-> parse('"a" 10'), CoconutParseError, err_has=" \\~~~^") - assert_raises(-> parse("A. 
."), CoconutParseError, err_has=" \\~~^") + assert_raises(-> parse('b"abc" "def"'), CoconutParseError, err_has="\n ^") + assert_raises(-> parse('"abc" b"def"'), CoconutParseError, err_has="\n ^") + assert_raises(-> parse('"a" 10'), CoconutParseError, err_has="\n ^") + assert_raises(-> parse("A. ."), CoconutParseError, err_has="\n \~^") assert_raises(-> parse('''f"""{ }"""'''), CoconutSyntaxError, err_has="parsing failed for format string expression") - assert_raises(-> parse("f([] {})"), CoconutParseError, err_has=" \\~~~~^") + assert_raises(-> parse("f([] {})"), CoconutParseError, err_has="\n \~~~^") assert_raises(-> parse("return = 1"), CoconutParseError, err_has='invalid use of the keyword "return"') assert_raises(-> parse("if a = b: pass"), CoconutParseError, err_has="misplaced assignment") assert_raises(-> parse("while a == b"), CoconutParseError, err_has="misplaced newline") - assert_raises(-> parse("0xfgf"), CoconutParseError, err_has=r" \~~^") - assert_raises(-> parse("(. if 1)"), CoconutParseError, err_has=r" \~~^") + assert_raises(-> parse("0xfgf"), CoconutParseError, err_has=r"\n ^") + assert_raises(-> parse("(. if 1)"), CoconutParseError, err_has=r"\n ^") try: parse(""" diff --git a/coconut/util.py b/coconut/util.py index 1a07d0c3b..3862af193 100644 --- a/coconut/util.py +++ b/coconut/util.py @@ -266,6 +266,9 @@ def __missing__(self, key): class dictset(dict, object): """A set implemented using a dictionary to get ordering benefits.""" + def __init__(self, items=()): + super(dictset, self).__init__((x, True) for x in items) + def __bool__(self): return len(self) > 0 # fixes py2 issue From d8941a6b46679184b8a38fd54f089bf137e1799c Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Fri, 22 Dec 2023 15:34:59 -0800 Subject: [PATCH 23/54] Fix line-by-line parsing --- coconut/compiler/compiler.py | 27 +++++++++------- coconut/compiler/grammar.py | 5 +-- coconut/compiler/util.py | 58 +++++++++++++++++------------------ coconut/exceptions.py | 7 ++--- coconut/tests/main_test.py | 5 ++- coconut/tests/src/extras.coco | 17 +++++----- 6 files changed, 62 insertions(+), 57 deletions(-) diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 8d7ad0058..49aa171f4 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -151,7 +151,6 @@ match_in, transform, parse, - all_matches, get_target_info_smart, split_leading_comments, compile_regex, @@ -1210,9 +1209,9 @@ def make_err(self, errtype, message, original, loc=0, ln=None, extra=None, refor causes = dictset() for check_loc in dictset((loc, endpoint, startpoint)): if check_loc is not None: - for cause, _, _ in all_matches(self.parse_err_msg, original[check_loc:], inner=True): - if cause: - causes.add(cause) + cause = try_parse(self.parse_err_msg, original[check_loc:], inner=True) + if cause: + causes.add(cause) if causes: extra = "possible cause{s}: {causes}".format( s="s" if len(causes) > 1 else "", @@ -1263,10 +1262,18 @@ def make_err(self, errtype, message, original, loc=0, ln=None, extra=None, refor # build the error if extra is not None: kwargs["extra"] = extra - err = errtype(message, snippet, loc_in_snip, ln, endpoint=endpt_in_snip, filename=self.filename, **kwargs) - if use_startpoint: - err = err.set_formatting(point_to_endpoint=True, max_err_msg_lines=2) - return err + return errtype( + message, + snippet, + loc_in_snip, + ln, + endpoint=endpt_in_snip, + filename=self.filename, + **kwargs, + ).set_formatting( + point_to_endpoint=True if use_startpoint else None, + max_err_msg_lines=2 if 
use_startpoint else None, + ) def make_syntax_err(self, err, original, after_parsing=False): """Make a CoconutSyntaxError from a CoconutDeferredSyntaxError.""" @@ -1375,9 +1382,7 @@ def parse_line_by_line(self, init_parser, line_parser, original): got, got_loc = results out_parts.append(got) got_loc = int(got_loc) - internal_assert(got_loc >= cur_loc, "invalid line by line parse", (cur_loc, results), extra=lambda: "in: " + repr(self.remaining_original.split("\n", 1)[0])) - if not init and got_loc == cur_loc: - raise self.make_err(CoconutParseError, "parsing could not continue", original, cur_loc, include_causes=True) + internal_assert(got_loc >= cur_loc and (init or got_loc > cur_loc), "invalid line by line parse", (cur_loc, results), extra=lambda: "in: " + repr(self.remaining_original.split("\n", 1)[0])) cur_loc = got_loc init = False return "".join(out_parts) diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index 76f9dc8f9..0a30dd146 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -2646,8 +2646,9 @@ class Grammar(object): parse_err_msg = start_marker + ( # should be in order of most likely to actually be the source of the error first - ZeroOrMore(~questionmark + ~Literal("\n") + any_char) + fixto( - questionmark + fixto( + ZeroOrMore(~questionmark + ~Literal("\n") + any_char) + + questionmark + ~dollar + ~lparen + ~lbrack diff --git a/coconut/compiler/util.py b/coconut/compiler/util.py index b94bafb14..d0925e277 100644 --- a/coconut/compiler/util.py +++ b/coconut/compiler/util.py @@ -454,12 +454,6 @@ def attach(item, action, ignore_no_tokens=None, ignore_one_token=None, ignore_ar return add_action(item, action, make_copy) -def final_evaluate_tokens(tokens): - """Same as evaluate_tokens but should only be used once a parse is assured.""" - clear_packrat_cache() - return evaluate_tokens(tokens, is_final=True) - - @contextmanager def adaptive_manager(original, loc, item, reparse=False): """Manage the use of MatchFirst.setAdaptiveMode.""" @@ -489,6 +483,14 @@ def adaptive_manager(original, loc, item, reparse=False): MatchFirst.setAdaptiveMode(False) +def final_evaluate_tokens(tokens): + """Same as evaluate_tokens but should only be used once a parse is assured.""" + result = evaluate_tokens(tokens, is_final=True) + # clear packrat cache after evaluating tokens so error creation gets to see the cache + clear_packrat_cache() + return result + + def final(item): """Collapse the computation graph upon parsing the given item.""" if SUPPORTS_ADAPTIVE and use_adaptive_if_available: @@ -530,9 +532,12 @@ def force_reset_packrat_cache(): @contextmanager def parsing_context(inner_parse=None): """Context to manage the packrat cache across parse calls.""" - current_cache_matters = ParserElement._packratEnabled + current_cache_matters = ( + inner_parse is not False + and ParserElement._packratEnabled + ) new_cache_matters = ( - not inner_parse + inner_parse is not True and ParserElement._incrementalEnabled and not ParserElement._incrementalWithResets ) @@ -542,7 +547,17 @@ def parsing_context(inner_parse=None): ) if ( current_cache_matters - and not new_cache_matters + and new_cache_matters + and ParserElement._incrementalWithResets + ): + incrementalWithResets, ParserElement._incrementalWithResets = ParserElement._incrementalWithResets, False + try: + yield + finally: + ParserElement._incrementalWithResets = incrementalWithResets + elif ( + current_cache_matters + and will_clear_cache ): # store old packrat cache old_cache = ParserElement.packrat_cache 
@@ -557,16 +572,6 @@ def parsing_context(inner_parse=None): if logger.verbose: ParserElement.packrat_cache_stats[0] += old_cache_stats[0] ParserElement.packrat_cache_stats[1] += old_cache_stats[1] - elif ( - current_cache_matters - and new_cache_matters - and will_clear_cache - ): - incrementalWithResets, ParserElement._incrementalWithResets = ParserElement._incrementalWithResets, False - try: - yield - finally: - ParserElement._incrementalWithResets = incrementalWithResets else: yield @@ -806,7 +811,7 @@ def should_clear_cache(force=False): return True elif not ParserElement._packratEnabled: return False - elif SUPPORTS_INCREMENTAL and ParserElement._incrementalEnabled: + elif ParserElement._incrementalEnabled: if not in_incremental_mode(): return repeatedly_clear_incremental_cache if ( @@ -897,18 +902,11 @@ def get_cache_items_for(original, only_useful=False, exclude_stale=True): yield lookup, value -def get_highest_parse_loc(original, only_successes=False): - """Get the highest observed parse location.""" +def get_highest_parse_loc(original): + """Get the highest observed parse location. + Note that there's no point in filtering for successes/failures, since we always see both at the same locations.""" highest_loc = 0 for lookup, _ in get_cache_items_for(original): - if only_successes: - if SUPPORTS_INCREMENTAL and ParserElement._incrementalEnabled: - # parseIncremental failure - if lookup[1] is True: - continue - # parseCache failure - elif not isinstance(lookup, tuple): - continue loc = lookup[2] if loc > highest_loc: highest_loc = loc diff --git a/coconut/exceptions.py b/coconut/exceptions.py index 1f3eb1730..9edf9f840 100644 --- a/coconut/exceptions.py +++ b/coconut/exceptions.py @@ -22,7 +22,6 @@ import traceback from coconut._pyparsing import ( - USE_LINE_BY_LINE, lineno, col as getcol, ) @@ -169,8 +168,8 @@ def message(self, message, source, point, ln, extra=None, endpoint=None, filenam message_parts += ["\n", " " * taberrfmt, highlight(part)] # add squiggles to message - if point_ind > 0 or endpoint_ind > 0: - err_len = endpoint_ind - point_ind + err_len = endpoint_ind - point_ind + if (point_ind > 0 or endpoint_ind > 0) and err_len < len(part): message_parts += ["\n", " " * (taberrfmt + point_ind)] if err_len <= min_squiggles_in_err_msg: if not self.point_to_endpoint: @@ -274,7 +273,7 @@ def message(self, message, source, point, ln, target, endpoint, filename): class CoconutParseError(CoconutSyntaxError): """Coconut ParseError.""" - point_to_endpoint = not USE_LINE_BY_LINE + point_to_endpoint = True class CoconutWarning(CoconutException): diff --git a/coconut/tests/main_test.py b/coconut/tests/main_test.py index 4e51c793e..0f84ee941 100644 --- a/coconut/tests/main_test.py +++ b/coconut/tests/main_test.py @@ -148,9 +148,8 @@ "INTERNAL ERROR", ) ignore_error_lines_with = ( - # ignore SyntaxWarnings containing assert_raises - "assert_raises(", - " raise ", + # ignore SyntaxWarnings containing assert_raises or raise + "raise", ) mypy_snip = "a: str = count()[0]" diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index edcc5dffa..c81fe0cf7 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -208,11 +208,11 @@ cannot reassign type variable 'T' (use explicit '\T' syntax if intended) (line 1 assert_raises(-> parse("$"), CoconutParseError) assert_raises(-> parse("@"), CoconutParseError) assert_raises(-> parse("range(1,10) |> reduce$(*, initializer = 1000) |> print"), CoconutParseError, err_has=( - "\n \~~^", + "\n \\~~^", )) 
assert_raises(-> parse("a := b"), CoconutParseError, err_has="\n ^") assert_raises(-> parse("1 + return"), CoconutParseError, err_has=( - "\n \~~^", + "\n \\~~^", )) assert_raises(-> parse(""" def f() = @@ -227,21 +227,24 @@ def f() = """ assert 2 ~^ - """.strip() + """.strip(), )) assert_raises(-> parse('b"abc" "def"'), CoconutParseError, err_has="\n ^") assert_raises(-> parse('"abc" b"def"'), CoconutParseError, err_has="\n ^") assert_raises(-> parse('"a" 10'), CoconutParseError, err_has="\n ^") - assert_raises(-> parse("A. ."), CoconutParseError, err_has="\n \~^") + assert_raises(-> parse("A. ."), CoconutParseError, err_has="\n \\~^") assert_raises(-> parse('''f"""{ }"""'''), CoconutSyntaxError, err_has="parsing failed for format string expression") - assert_raises(-> parse("f([] {})"), CoconutParseError, err_has="\n \~~~^") + assert_raises(-> parse("f([] {})"), CoconutParseError, err_has="\n \\~~~^") assert_raises(-> parse("return = 1"), CoconutParseError, err_has='invalid use of the keyword "return"') assert_raises(-> parse("if a = b: pass"), CoconutParseError, err_has="misplaced assignment") assert_raises(-> parse("while a == b"), CoconutParseError, err_has="misplaced newline") - assert_raises(-> parse("0xfgf"), CoconutParseError, err_has=r"\n ^") - assert_raises(-> parse("(. if 1)"), CoconutParseError, err_has=r"\n ^") + assert_raises(-> parse("0xfgf"), CoconutParseError, err_has=( + "\n ^", + "\n \\~^", + )) + assert_raises(-> parse("(. if 1)"), CoconutParseError, err_has="\n ^") try: parse(""" From 5efdabae70cb5ba68e5cf337866f1ef2bde21fb7 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Fri, 22 Dec 2023 16:49:16 -0800 Subject: [PATCH 24/54] Disable line-by-line --- coconut/constants.py | 4 +-- coconut/tests/src/extras.coco | 48 ++++++++++++++++++++++++++--------- 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/coconut/constants.py b/coconut/constants.py index 1ebf42bc1..dd18ca060 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -131,8 +131,6 @@ def get_path_env_var(env_var, default): # below constants are experimentally determined to maximize performance -use_line_by_line_parser = True - use_packrat_parser = True # True also gives us better error messages packrat_cache_size = None # only works because final() clears the cache @@ -148,6 +146,8 @@ def get_path_env_var(env_var, default): # note that _parseIncremental produces much smaller caches use_incremental_if_available = False +use_line_by_line_parser = False + use_adaptive_if_available = False # currently broken adaptive_reparse_usage_weight = 10 diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index c81fe0cf7..7b6811635 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -209,10 +209,15 @@ cannot reassign type variable 'T' (use explicit '\T' syntax if intended) (line 1 assert_raises(-> parse("@"), CoconutParseError) assert_raises(-> parse("range(1,10) |> reduce$(*, initializer = 1000) |> print"), CoconutParseError, err_has=( "\n \\~~^", + "\n \\~~~~~~~~~~~~~~~~~~~~~~~^", + )) + assert_raises(-> parse("a := b"), CoconutParseError, err_has=( + "\n ^", + "\n \\~^", )) - assert_raises(-> parse("a := b"), CoconutParseError, err_has="\n ^") assert_raises(-> parse("1 + return"), CoconutParseError, err_has=( "\n \\~~^", + "\n \\~~~~^", )) assert_raises(-> parse(""" def f() = @@ -229,22 +234,41 @@ def f() = ~^ """.strip(), )) - assert_raises(-> parse('b"abc" "def"'), CoconutParseError, err_has="\n ^") - assert_raises(-> parse('"abc" b"def"'), 
CoconutParseError, err_has="\n ^") - assert_raises(-> parse('"a" 10'), CoconutParseError, err_has="\n ^") - assert_raises(-> parse("A. ."), CoconutParseError, err_has="\n \\~^") + assert_raises(-> parse('b"abc" "def"'), CoconutParseError, err_has=( + "\n ^", + "\n \\~~~~~~^", + )) + assert_raises(-> parse('"abc" b"def"'), CoconutParseError, err_has=( + "\n ^", + "\n \\~~~~~^", + )) + assert_raises(-> parse('"a" 10'), CoconutParseError, err_has=( + "\n ^", + "\n \\~~~^", + )) + assert_raises(-> parse("A. ."), CoconutParseError, err_has=( + "\n \\~^", + "\n \\~~^", + )) + assert_raises(-> parse("f([] {})"), CoconutParseError, err_has=( + "\n \\~~~^", + "\n \\~~~~^", + )) + assert_raises(-> parse("0xfgf"), CoconutParseError, err_has=( + "\n ^", + "\n \\~^", + "\n \\~~^", + )) + assert_raises(-> parse("(. if 1)"), CoconutParseError, err_has=( + "\n ^", + "\n \\~~^", + )) + assert_raises(-> parse('''f"""{ }"""'''), CoconutSyntaxError, err_has="parsing failed for format string expression") - assert_raises(-> parse("f([] {})"), CoconutParseError, err_has="\n \\~~~^") - assert_raises(-> parse("return = 1"), CoconutParseError, err_has='invalid use of the keyword "return"') assert_raises(-> parse("if a = b: pass"), CoconutParseError, err_has="misplaced assignment") assert_raises(-> parse("while a == b"), CoconutParseError, err_has="misplaced newline") - assert_raises(-> parse("0xfgf"), CoconutParseError, err_has=( - "\n ^", - "\n \\~^", - )) - assert_raises(-> parse("(. if 1)"), CoconutParseError, err_has="\n ^") try: parse(""" From bc0bad153987ed37a1020a69c61cbffdc412093d Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Sat, 23 Dec 2023 01:44:23 -0800 Subject: [PATCH 25/54] Fix py2 --- coconut/compiler/compiler.py | 2 +- coconut/compiler/util.py | 64 +++++++++++++++++------------------- coconut/constants.py | 3 -- 3 files changed, 31 insertions(+), 38 deletions(-) diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 49aa171f4..036aefc25 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -1269,7 +1269,7 @@ def make_err(self, errtype, message, original, loc=0, ln=None, extra=None, refor ln, endpoint=endpt_in_snip, filename=self.filename, - **kwargs, + **kwargs # no comma ).set_formatting( point_to_endpoint=True if use_startpoint else None, max_err_msg_lines=2 if use_startpoint else None, diff --git a/coconut/compiler/util.py b/coconut/compiler/util.py index d0925e277..4d7074ca6 100644 --- a/coconut/compiler/util.py +++ b/coconut/compiler/util.py @@ -123,11 +123,9 @@ unwrapper, incremental_cache_limit, incremental_mode_cache_successes, - adaptive_reparse_usage_weight, use_adaptive_any_of, disable_incremental_for_len, coconut_cache_dir, - use_adaptive_if_available, use_fast_pyparsing_reprs, save_new_cache_items, cache_validation_info, @@ -454,35 +452,6 @@ def attach(item, action, ignore_no_tokens=None, ignore_one_token=None, ignore_ar return add_action(item, action, make_copy) -@contextmanager -def adaptive_manager(original, loc, item, reparse=False): - """Manage the use of MatchFirst.setAdaptiveMode.""" - if reparse: - cleared_cache = clear_packrat_cache() - if cleared_cache is not True: - item.include_in_packrat_context = True - MatchFirst.setAdaptiveMode(False, usage_weight=adaptive_reparse_usage_weight) - try: - yield - finally: - MatchFirst.setAdaptiveMode(False, usage_weight=1) - if cleared_cache is not True: - item.include_in_packrat_context = False - else: - MatchFirst.setAdaptiveMode(True) - try: - yield - except Exception as exc: - if 
DEVELOP: - logger.log("reparsing due to:", exc) - logger.record_stat("adaptive", False) - else: - if DEVELOP: - logger.record_stat("adaptive", True) - finally: - MatchFirst.setAdaptiveMode(False) - - def final_evaluate_tokens(tokens): """Same as evaluate_tokens but should only be used once a parse is assured.""" result = evaluate_tokens(tokens, is_final=True) @@ -493,8 +462,6 @@ def final_evaluate_tokens(tokens): def final(item): """Collapse the computation graph upon parsing the given item.""" - if SUPPORTS_ADAPTIVE and use_adaptive_if_available: - item = Wrap(item, adaptive_manager, greedy=True) # evaluate_tokens expects a computation graph, so we just call add_action directly return add_action(trace(item), final_evaluate_tokens) @@ -2040,7 +2007,7 @@ def sub_all(inputstr, regexes, replacements): # ----------------------------------------------------------------------------------------------------------------------- -# PYTEST: +# EXTRAS: # ----------------------------------------------------------------------------------------------------------------------- @@ -2071,3 +2038,32 @@ def pytest_rewrite_asserts(code, module_name=reserved_prefix + "_pytest_module") rewrite_asserts(tree, module_name) fixed_tree = ast.fix_missing_locations(FixPytestNames().visit(tree)) return ast.unparse(fixed_tree) + + +@contextmanager +def adaptive_manager(original, loc, item, reparse=False): + """Manage the use of MatchFirst.setAdaptiveMode.""" + if reparse: + cleared_cache = clear_packrat_cache() + if cleared_cache is not True: + item.include_in_packrat_context = True + MatchFirst.setAdaptiveMode(False, usage_weight=10) + try: + yield + finally: + MatchFirst.setAdaptiveMode(False, usage_weight=1) + if cleared_cache is not True: + item.include_in_packrat_context = False + else: + MatchFirst.setAdaptiveMode(True) + try: + yield + except Exception as exc: + if DEVELOP: + logger.log("reparsing due to:", exc) + logger.record_stat("adaptive", False) + else: + if DEVELOP: + logger.record_stat("adaptive", True) + finally: + MatchFirst.setAdaptiveMode(False) diff --git a/coconut/constants.py b/coconut/constants.py index dd18ca060..2ff67abea 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -148,9 +148,6 @@ def get_path_env_var(env_var, default): use_line_by_line_parser = False -use_adaptive_if_available = False # currently broken -adaptive_reparse_usage_weight = 10 - # these only apply to use_incremental_if_available, not compiler.util.enable_incremental_parsing() default_incremental_cache_size = None repeatedly_clear_incremental_cache = True From b88803b927a48af7e4ff75885a80631137062ab4 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Sat, 23 Dec 2023 02:04:02 -0800 Subject: [PATCH 26/54] Improve tco disabling --- coconut/constants.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/coconut/constants.py b/coconut/constants.py index 2ff67abea..86f1592bc 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -439,11 +439,11 @@ def get_path_env_var(env_var, default): r"locals", r"globals", r"(py_)?super", - r"(typing\.)?cast", - r"(sys\.)?exc_info", - r"(sys\.)?_getframe", - r"(sys\.)?_current_frames", - r"(sys\.)?_current_exceptions", + r"cast", + r"exc_info", + r"sys\.[a-zA-Z0-9_.]+", + r"traceback\.[a-zA-Z0-9_.]+", + r"typing\.[a-zA-Z0-9_.]+", ) py3_to_py2_stdlib = { From 788f8575eb07a9dcb8a98c6c255599d02d7f0f1b Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Wed, 17 Jan 2024 01:48:32 -0800 Subject: [PATCH 27/54] Fix parsing inconsistencies Resolves #819. 
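
As a rough illustration (drawn from the new test cases added in this patch, not an exhaustive spec), inputs like the following should now parse consistently:

    assert not 0in[1,2,3]  # a numeric literal run directly into a keyword, which could
                           # previously be misread via the caseless "i" imaginary-literal
                           # suffix (see the imag_j disambiguation below)
    if"0":assert True      # a keyword run directly into a string literal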
--- coconut/compiler/compiler.py | 12 ++---------- coconut/compiler/grammar.py | 5 ++++- coconut/compiler/util.py | 11 ++++++++--- coconut/constants.py | 2 +- coconut/root.py | 2 +- coconut/terminal.py | 2 +- coconut/tests/src/cocotest/agnostic/primary_2.coco | 4 ++++ 7 files changed, 21 insertions(+), 17 deletions(-) diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 036aefc25..21be94792 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -1630,16 +1630,8 @@ def str_proc(self, inputstring, **kwargs): # start the string hold if we're at the start of a string if hold is not None: - is_f = False - j = i - len(hold["start"]) - while j >= 0: - prev_c = inputstring[j] - if prev_c == "f": - is_f = True - break - elif prev_c != "r": - break - j -= 1 + is_f_check_str = inputstring[clip(i - len(hold["start"]) + 1 - self.start_f_str_regex_len, min=0): i - len(hold["start"]) + 1] + is_f = self.start_f_str_regex.search(is_f_check_str) if is_f: hold.update({ "type": "f string", diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index 0a30dd146..3cf73fb22 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -817,7 +817,7 @@ class Grammar(object): octint = combine(Word("01234567") + ZeroOrMore(underscore.suppress() + Word("01234567"))) hexint = combine(Word(hexnums) + ZeroOrMore(underscore.suppress() + Word(hexnums))) - imag_j = caseless_literal("j") | fixto(caseless_literal("i", suppress=True), "j") + imag_j = caseless_literal("j") | fixto(caseless_literal("i", suppress=True, disambiguate=True), "j") basenum = combine( Optional(integer) + dot + integer | integer + Optional(dot + Optional(integer)) @@ -2660,6 +2660,9 @@ class Grammar(object): | fixto(end_of_line, "misplaced newline (maybe missing ':')") ) + start_f_str_regex = compile_regex(r"\br?fr?$") + start_f_str_regex_len = 4 + end_f_str_expr = combine(start_marker + (rbrace | colon | bang)) string_start = start_marker + python_quoted_string diff --git a/coconut/compiler/util.py b/coconut/compiler/util.py index 4d7074ca6..5e2cd75ac 100644 --- a/coconut/compiler/util.py +++ b/coconut/compiler/util.py @@ -131,6 +131,7 @@ cache_validation_info, require_cache_clear_frac, reverse_any_of, + all_keywords, ) from coconut.exceptions import ( CoconutException, @@ -1537,12 +1538,16 @@ def any_len_perm_at_least_one(*elems, **kwargs): return any_len_perm_with_one_of_each_group(*groups_and_elems) -def caseless_literal(literalstr, suppress=False): +def caseless_literal(literalstr, suppress=False, disambiguate=False): """Version of CaselessLiteral that always parses to the given literalstr.""" + out = CaselessLiteral(literalstr) if suppress: - return CaselessLiteral(literalstr).suppress() + out = out.suppress() else: - return fixto(CaselessLiteral(literalstr), literalstr) + out = fixto(out, literalstr) + if disambiguate: + out = disallow_keywords(k for k in all_keywords if k.startswith((literalstr[0].lower(), literalstr[0].upper()))) + out + return out # ----------------------------------------------------------------------------------------------------------------------- diff --git a/coconut/constants.py b/coconut/constants.py index 86f1592bc..146c9210e 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -37,7 +37,7 @@ def fixpath(path): return os.path.normpath(os.path.realpath(os.path.expanduser(path))) -def get_bool_env_var(env_var, default=False): +def get_bool_env_var(env_var, default=None): """Get a boolean from an environment variable.""" boolstr 
= os.getenv(env_var, "").lower() if boolstr in ("true", "yes", "on", "1", "t"): diff --git a/coconut/root.py b/coconut/root.py index ffe859e89..22fca3377 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 12 +DEVELOP = 13 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/terminal.py b/coconut/terminal.py index 7247c7641..ee1a9335c 100644 --- a/coconut/terminal.py +++ b/coconut/terminal.py @@ -210,7 +210,7 @@ def __init__(self, other=None): @classmethod def enable_colors(cls, file=None): """Attempt to enable CLI colors.""" - use_color = get_bool_env_var(use_color_env_var) + use_color = get_bool_env_var(use_color_env_var, default=None) if ( use_color is False or use_color is None and file is not None and not isatty(file) diff --git a/coconut/tests/src/cocotest/agnostic/primary_2.coco b/coconut/tests/src/cocotest/agnostic/primary_2.coco index ccee37e55..3d7acdbd8 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_2.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_2.coco @@ -424,6 +424,10 @@ def primary_test_2() -> bool: assert all_equal([], to=10) assert all_equal([10; 10; 10; 10], to=10) assert not all_equal([1, 1], to=10) + assert not 0in[1,2,3] + if"0":assert True + if"0": + assert True with process_map.multiple_sequential_calls(): # type: ignore assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore From a53a137faec7496ba07e58c38dafbdf65f64f55e Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Wed, 17 Jan 2024 02:18:11 -0800 Subject: [PATCH 28/54] Fix line splitting Resolves #818. 
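
For context, a minimal sketch of the underlying issue in plain Python (illustrative only; this is not code from the patch itself):

    s = "linebreaks = '\x0b\x0c\x1c\x1d\x1e'"
    # str.splitlines() also treats \x0b, \x0c, \x1c, \x1d, and \x1e as line
    # boundaries, shattering the single statement above into several "lines"
    print(len(s.splitlines()))  # 6
    # splitting only on "\n" keeps it whole, which is effectively what the
    # new literal_lines helper below does
    print(len(s.split("\n")))   # 1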
--- coconut/compiler/compiler.py | 29 ++++++++++++++--------------- coconut/compiler/util.py | 3 ++- coconut/exceptions.py | 4 ++-- coconut/root.py | 2 +- coconut/tests/src/extras.coco | 1 + coconut/util.py | 15 +++++++++------ 6 files changed, 29 insertions(+), 25 deletions(-) diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 21be94792..54465a841 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -97,7 +97,7 @@ pickleable_obj, checksum, clip, - logical_lines, + literal_lines, clean, get_target_info, get_clock_time, @@ -1240,7 +1240,7 @@ def make_err(self, errtype, message, original, loc=0, ln=None, extra=None, refor ln = self.outer_ln # get line indices for the error locs - original_lines = tuple(logical_lines(original, True)) + original_lines = tuple(literal_lines(original, True)) loc_line_ind = clip(lineno(loc, original) - 1, max=len(original_lines) - 1) # build the source snippet that the error is referring to @@ -1449,7 +1449,7 @@ def prepare(self, inputstring, strip=False, nl_at_eof_check=False, **kwargs): if self.strict and nl_at_eof_check and inputstring and not inputstring.endswith("\n"): end_index = len(inputstring) - 1 if inputstring else 0 raise self.make_err(CoconutStyleError, "missing new line at end of file", inputstring, end_index) - kept_lines = inputstring.splitlines() + kept_lines = tuple(literal_lines(inputstring)) self.num_lines = len(kept_lines) if self.keep_lines: self.kept_lines = kept_lines @@ -1719,7 +1719,7 @@ def operator_proc(self, inputstring, keep_state=False, **kwargs): """Process custom operator definitions.""" out = [] skips = self.copy_skips() - for i, raw_line in enumerate(logical_lines(inputstring, keep_newlines=True)): + for i, raw_line in enumerate(literal_lines(inputstring, keep_newlines=True)): ln = i + 1 base_line = rem_comment(raw_line) stripped_line = base_line.lstrip() @@ -1806,7 +1806,7 @@ def leading_whitespace(self, inputstring): def ind_proc(self, inputstring, **kwargs): """Process indentation and ensure balanced parentheses.""" - lines = tuple(logical_lines(inputstring)) + lines = tuple(literal_lines(inputstring)) new = [] # new lines current = None # indentation level of previous line levels = [] # indentation levels of all previous blocks, newest at end @@ -1899,11 +1899,8 @@ def reind_proc(self, inputstring, ignore_errors=False, **kwargs): out_lines = [] level = 0 - next_line_is_fake = False - for line in inputstring.splitlines(True): - is_fake = next_line_is_fake - next_line_is_fake = line.endswith("\f") and line.rstrip("\f") == line.rstrip() - + is_fake = False + for next_line_is_real, line in literal_lines(inputstring, True, yield_next_line_is_real=True): line, comment = split_comment(line.strip()) indent, line = split_leading_indent(line) @@ -1932,6 +1929,8 @@ def reind_proc(self, inputstring, ignore_errors=False, **kwargs): line = (line + comment).rstrip() out_lines.append(line) + is_fake = not next_line_is_real + if not ignore_errors and level != 0: logger.log_lambda(lambda: "failed to reindent:\n" + inputstring) complain("non-zero final indentation level: " + repr(level)) @@ -1978,7 +1977,7 @@ def endline_repl(self, inputstring, reformatting=False, ignore_errors=False, **k """Add end of line comments.""" out_lines = [] ln = 1 # line number in pre-processed original - for line in logical_lines(inputstring): + for line in literal_lines(inputstring): add_one_to_ln = False try: @@ -2331,7 +2330,7 @@ def transform_returns(self, original, loc, raw_lines, tre_return_grammar=None, 
i def proc_funcdef(self, original, loc, decorators, funcdef, is_async, in_method, is_stmt_lambda): """Determines if TCO or TRE can be done and if so does it, handles dotted function names, and universalizes async functions.""" - raw_lines = list(logical_lines(funcdef, True)) + raw_lines = list(literal_lines(funcdef, True)) def_stmt = raw_lines.pop(0) out = [] @@ -2684,7 +2683,7 @@ def deferred_code_proc(self, inputstring, add_code_at_start=False, ignore_names= self.compile_add_code_before_regexes() out = [] - for raw_line in inputstring.splitlines(True): + for raw_line in literal_lines(inputstring, True): bef_ind, line, aft_ind = split_leading_trailing_indent(raw_line) # look for deferred errors @@ -2707,7 +2706,7 @@ def deferred_code_proc(self, inputstring, add_code_at_start=False, ignore_names= # handle any non-function code that was added before the funcdef pre_def_lines = [] post_def_lines = [] - funcdef_lines = list(logical_lines(funcdef, True)) + funcdef_lines = list(literal_lines(funcdef, True)) for i, line in enumerate(funcdef_lines): if self.def_regex.match(line): pre_def_lines = funcdef_lines[:i] @@ -3128,7 +3127,7 @@ def yield_from_handle(self, loc, tokens): def endline_handle(self, original, loc, tokens): """Add line number information to end of line.""" endline, = tokens - lines = endline.splitlines(True) + lines = tuple(literal_lines(endline, True)) if self.minify: lines = lines[0] out = [] diff --git a/coconut/compiler/util.py b/coconut/compiler/util.py index 5e2cd75ac..a29e1821c 100644 --- a/coconut/compiler/util.py +++ b/coconut/compiler/util.py @@ -88,6 +88,7 @@ univ_open, ensure_dir, get_clock_time, + literal_lines, ) from coconut.terminal import ( logger, @@ -1839,7 +1840,7 @@ def is_blank(line): def final_indentation_level(code): """Determine the final indentation level of the given code.""" level = 0 - for line in code.splitlines(): + for line in literal_lines(code): leading_indent, _, trailing_indent = split_leading_trailing_indent(line) level += ind_change(leading_indent) + ind_change(trailing_indent) return level diff --git a/coconut/exceptions.py b/coconut/exceptions.py index 9edf9f840..89843a428 100644 --- a/coconut/exceptions.py +++ b/coconut/exceptions.py @@ -35,7 +35,7 @@ from coconut.util import ( pickleable_obj, clip, - logical_lines, + literal_lines, clean, get_displayable_target, normalize_newlines, @@ -140,7 +140,7 @@ def message(self, message, source, point, ln, extra=None, endpoint=None, filenam point_ind = getcol(point, source) - 1 endpoint_ind = getcol(endpoint, source) - 1 - source_lines = tuple(logical_lines(source, keep_newlines=True)) + source_lines = tuple(literal_lines(source, keep_newlines=True)) # walk the endpoint line back until it points to real text while endpoint_ln > point_ln and not "".join(source_lines[endpoint_ln - 1:endpoint_ln]).strip(): diff --git a/coconut/root.py b/coconut/root.py index 22fca3377..40da17bcc 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 13 +DEVELOP = 14 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index 7b6811635..f326c7416 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -116,6 +116,7 @@ def test_setup_none() -> bool: assert "==" not in parse("None = None") assert parse("(1\f+\f2)", "lenient") == "(1 + 
2)" == parse("(1\f+\f2)", "eval") assert "Ellipsis" not in parse("x: ... = 1") + assert parse("linebreaks = '\x0b\x0c\x1c\x1d\x1e'") # things that don't parse correctly without the computation graph if USE_COMPUTATION_GRAPH: diff --git a/coconut/util.py b/coconut/util.py index 3862af193..fb9c9207c 100644 --- a/coconut/util.py +++ b/coconut/util.py @@ -151,8 +151,8 @@ def clip(num, min=None, max=None): ) -def logical_lines(text, keep_newlines=False): - """Iterate over the logical code lines in text.""" +def literal_lines(text, keep_newlines=False, yield_next_line_is_real=False): + """Iterate over the literal code lines in text.""" prev_content = None for line in text.splitlines(True): real_line = True @@ -163,11 +163,14 @@ def logical_lines(text, keep_newlines=False): if not keep_newlines: line = line[:-1] else: - if prev_content is None: - prev_content = "" - prev_content += line + if not yield_next_line_is_real: + if prev_content is None: + prev_content = "" + prev_content += line real_line = False - if real_line: + if yield_next_line_is_real: + yield real_line, line + elif real_line: if prev_content is not None: line = prev_content + line prev_content = None From 17ff2a3170920b83de36d91f77858e5873f01136 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Wed, 17 Jan 2024 18:44:37 -0800 Subject: [PATCH 29/54] Fix walrus in subscripts Resolves #820. --- coconut/compiler/grammar.py | 14 ++++++++------ coconut/root.py | 2 +- .../tests/src/cocotest/target_311/py311_test.coco | 2 ++ 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index 3cf73fb22..a6ded70cc 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -649,8 +649,8 @@ class Grammar(object): unsafe_fat_arrow = Literal("=>") | fixto(Literal("\u21d2"), "=>") colon_eq = Literal(":=") unsafe_dubcolon = Literal("::") - unsafe_colon = Literal(":") colon = disambiguate_literal(":", ["::", ":="]) + indexing_colon = disambiguate_literal(":", [":="]) # same as : but :: is allowed lt_colon = Literal("<:") semicolon = Literal(";") | invalid_syntax("\u037e", "invalid Greek question mark instead of semicolon", greedy=True) multisemicolon = combine(OneOrMore(semicolon)) @@ -1199,21 +1199,23 @@ class Grammar(object): | op_item ) + # for .[] subscript_star = Forward() subscript_star_ref = star slicetest = Optional(test_no_chain) - sliceop = condense(unsafe_colon + slicetest) + sliceop = condense(indexing_colon + slicetest) subscript = condense( slicetest + sliceop + Optional(sliceop) - | Optional(subscript_star) + test + | Optional(subscript_star) + new_namedexpr_test ) - subscriptlist = itemlist(subscript, comma, suppress_trailing=False) | new_namedexpr_test + subscriptlist = itemlist(subscript, comma, suppress_trailing=False) + # for .$[] slicetestgroup = Optional(test_no_chain, default="") - sliceopgroup = unsafe_colon.suppress() + slicetestgroup + sliceopgroup = indexing_colon.suppress() + slicetestgroup subscriptgroup = attach( slicetestgroup + sliceopgroup + Optional(sliceopgroup) - | test, + | new_namedexpr_test, subscriptgroup_handle, ) subscriptgrouplist = itemlist(subscriptgroup, comma) diff --git a/coconut/root.py b/coconut/root.py index 40da17bcc..3eed681c3 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 14 +DEVELOP = 15 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be 
False or an int >= 1" diff --git a/coconut/tests/src/cocotest/target_311/py311_test.coco b/coconut/tests/src/cocotest/target_311/py311_test.coco index a2c655815..c527cf3a4 100644 --- a/coconut/tests/src/cocotest/target_311/py311_test.coco +++ b/coconut/tests/src/cocotest/target_311/py311_test.coco @@ -7,4 +7,6 @@ def py311_test() -> bool: except* ValueError as err: got_err = err assert repr(got_err) == repr(multi_err), (got_err, multi_err) + assert [1, 2, 3][x := 1] == 2 + assert x == 1 return True From e357041e4014ffe1bf3e38a2cdd1904e05c3d058 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Wed, 17 Jan 2024 18:56:49 -0800 Subject: [PATCH 30/54] Allow strings in impl calls Resolves #821. --- Makefile | 6 +++--- coconut/compiler/grammar.py | 1 - coconut/root.py | 2 +- coconut/tests/src/cocotest/agnostic/primary_2.coco | 3 +++ coconut/tests/src/extras.coco | 4 ---- 5 files changed, 7 insertions(+), 9 deletions(-) diff --git a/Makefile b/Makefile index 93742e5e7..eb2094c8f 100644 --- a/Makefile +++ b/Makefile @@ -141,7 +141,7 @@ test-any-of: test-univ .PHONY: test-mypy-univ test-mypy-univ: export COCONUT_USE_COLOR=TRUE test-mypy-univ: clean - python ./coconut/tests --strict --keep-lines --force --mypy --follow-imports silent --ignore-missing-imports --allow-redefinition + python ./coconut/tests --strict --keep-lines --force --no-cache --mypy --follow-imports silent --ignore-missing-imports --allow-redefinition python ./coconut/tests/dest/runner.py python ./coconut/tests/dest/extras.py @@ -149,7 +149,7 @@ test-mypy-univ: clean .PHONY: test-mypy test-mypy: export COCONUT_USE_COLOR=TRUE test-mypy: clean - python ./coconut/tests --strict --keep-lines --force --target sys --mypy --follow-imports silent --ignore-missing-imports --allow-redefinition + python ./coconut/tests --strict --keep-lines --force --target sys --no-cache --mypy --follow-imports silent --ignore-missing-imports --allow-redefinition python ./coconut/tests/dest/runner.py python ./coconut/tests/dest/extras.py @@ -198,7 +198,7 @@ test-mypy-verbose: clean .PHONY: test-mypy-all test-mypy-all: export COCONUT_USE_COLOR=TRUE test-mypy-all: clean - python ./coconut/tests --strict --keep-lines --force --target sys --mypy --follow-imports silent --ignore-missing-imports --allow-redefinition --check-untyped-defs + python ./coconut/tests --strict --keep-lines --force --target sys --no-cache --mypy --follow-imports silent --ignore-missing-imports --allow-redefinition --check-untyped-defs python ./coconut/tests/dest/runner.py python ./coconut/tests/dest/extras.py diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index a6ded70cc..b32c75957 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -1443,7 +1443,6 @@ class Grammar(object): ) + Optional(power_in_impl_call)) impl_call_item = condense( disallow_keywords(reserved_vars) - + ~any_string + ~non_decimal_num + atom_item + Optional(power_in_impl_call) diff --git a/coconut/root.py b/coconut/root.py index 3eed681c3..87b577dcc 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 15 +DEVELOP = 16 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/src/cocotest/agnostic/primary_2.coco b/coconut/tests/src/cocotest/agnostic/primary_2.coco index 3d7acdbd8..9876cabe1 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_2.coco +++ 
b/coconut/tests/src/cocotest/agnostic/primary_2.coco @@ -428,6 +428,9 @@ def primary_test_2() -> bool: if"0":assert True if"0": assert True + b = "b" + assert "abc".find b == 1 + assert_raises(-> "a" 10, TypeError) with process_map.multiple_sequential_calls(): # type: ignore assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index f326c7416..0b7a55289 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -243,10 +243,6 @@ def f() = "\n ^", "\n \\~~~~~^", )) - assert_raises(-> parse('"a" 10'), CoconutParseError, err_has=( - "\n ^", - "\n \\~~~^", - )) assert_raises(-> parse("A. ."), CoconutParseError, err_has=( "\n \\~^", "\n \\~~^", From 5ff7425e452f08ba407216a80d8bce0095c22528 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Fri, 19 Jan 2024 01:10:15 -0800 Subject: [PATCH 31/54] Fix some unicode alts Resolves #822. --- coconut/compiler/grammar.py | 6 +- coconut/compiler/util.py | 7 +- coconut/constants.py | 71 +++++++++---------- coconut/highlighter.py | 4 +- coconut/root.py | 2 +- .../src/cocotest/agnostic/primary_2.coco | 3 + .../tests/src/cocotest/agnostic/suite.coco | 2 + .../tests/src/cocotest/target_3/py3_test.coco | 6 +- 8 files changed, 51 insertions(+), 50 deletions(-) diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index b32c75957..b64d040fe 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -673,9 +673,9 @@ class Grammar(object): pipe = Literal("|>") | fixto(Literal("\u21a6"), "|>") star_pipe = Literal("|*>") | fixto(Literal("*\u21a6"), "|*>") dubstar_pipe = Literal("|**>") | fixto(Literal("**\u21a6"), "|**>") - back_pipe = Literal("<|") | fixto(Literal("\u21a4"), "<|") - back_star_pipe = Literal("<*|") | ~Literal("\u21a4**") + fixto(Literal("\u21a4*"), "<*|") - back_dubstar_pipe = Literal("<**|") | fixto(Literal("\u21a4**"), "<**|") + back_pipe = Literal("<|") | disambiguate_literal("\u21a4", ["\u21a4*", "\u21a4?"], fixesto="<|") + back_star_pipe = Literal("<*|") | disambiguate_literal("\u21a4*", ["\u21a4**", "\u21a4*?"], fixesto="<*|") + back_dubstar_pipe = Literal("<**|") | disambiguate_literal("\u21a4**", ["\u21a4**?"], fixesto="<**|") none_pipe = Literal("|?>") | fixto(Literal("?\u21a6"), "|?>") none_star_pipe = ( Literal("|?*>") diff --git a/coconut/compiler/util.py b/coconut/compiler/util.py index a29e1821c..6de9b871f 100644 --- a/coconut/compiler/util.py +++ b/coconut/compiler/util.py @@ -1454,12 +1454,15 @@ def disallow_keywords(kwds, with_suffix=""): return regex_item(r"(?!" + "|".join(to_disallow) + r")").suppress() -def disambiguate_literal(literal, not_literals): +def disambiguate_literal(literal, not_literals, fixesto=None): """Get an item that matchesl literal and not any of not_literals.""" - return regex_item( + item = regex_item( r"(?!" 
+ "|".join(re.escape(s) for s in not_literals) + ")" + re.escape(literal) ) + if fixesto is not None: + item = fixto(item, fixesto) + return item def any_keyword_in(kwds): diff --git a/coconut/constants.py b/coconut/constants.py index 146c9210e..269fca1d5 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -576,45 +576,34 @@ def get_path_env_var(env_var, default): ) python_builtins = ( - '__import__', 'abs', 'all', 'any', 'bin', 'bool', 'bytearray', - 'breakpoint', 'bytes', 'chr', 'classmethod', 'compile', 'complex', - 'delattr', 'dict', 'dir', 'divmod', 'enumerate', 'eval', 'filter', - 'float', 'format', 'frozenset', 'getattr', 'globals', 'hasattr', - 'hash', 'hex', 'id', 'input', 'int', 'isinstance', 'issubclass', - 'iter', 'len', 'list', 'locals', 'map', 'max', 'memoryview', - 'min', 'next', 'object', 'oct', 'open', 'ord', 'pow', 'print', - 'property', 'range', 'repr', 'reversed', 'round', 'set', 'setattr', - 'slice', 'sorted', 'staticmethod', 'str', 'sum', 'super', 'tuple', - 'type', 'vars', 'zip', - 'Ellipsis', 'NotImplemented', - 'ArithmeticError', 'AssertionError', 'AttributeError', - 'BaseException', 'BufferError', 'BytesWarning', 'DeprecationWarning', - 'EOFError', 'EnvironmentError', 'Exception', 'FloatingPointError', - 'FutureWarning', 'GeneratorExit', 'IOError', 'ImportError', - 'ImportWarning', 'IndentationError', 'IndexError', 'KeyError', - 'KeyboardInterrupt', 'LookupError', 'MemoryError', 'NameError', - 'NotImplementedError', 'OSError', 'OverflowError', - 'PendingDeprecationWarning', 'ReferenceError', 'ResourceWarning', - 'RuntimeError', 'RuntimeWarning', 'StopIteration', - 'SyntaxError', 'SyntaxWarning', 'SystemError', 'SystemExit', - 'TabError', 'TypeError', 'UnboundLocalError', 'UnicodeDecodeError', - 'UnicodeEncodeError', 'UnicodeError', 'UnicodeTranslateError', - 'UnicodeWarning', 'UserWarning', 'ValueError', 'VMSError', - 'Warning', 'WindowsError', 'ZeroDivisionError', + "abs", "aiter", "all", "anext", "any", "ascii", + "bin", "bool", "breakpoint", "bytearray", "bytes", + "callable", "chr", "classmethod", "compile", "complex", + "delattr", "dict", "dir", "divmod", + "enumerate", "eval", "exec", + "filter", "float", "format", "frozenset", + "getattr", "globals", + "hasattr", "hash", "help", "hex", + "id", "input", "int", "isinstance", "issubclass", "iter", + "len", "list", "locals", + "map", "max", "memoryview", "min", + "next", + "object", "oct", "open", "ord", + "pow", "print", "property", + "range", "repr", "reversed", "round", + "set", "setattr", "slice", "sorted", "staticmethod", "str", "sum", "super", + "tuple", "type", + "vars", + "zip", + "__import__", '__name__', '__file__', '__annotations__', '__debug__', - # we treat these as coconut_exceptions so the highlighter will always know about them: - # 'ExceptionGroup', 'BaseExceptionGroup', - # don't include builtins that aren't always made available by Coconut: - # 'BlockingIOError', 'ChildProcessError', 'ConnectionError', - # 'BrokenPipeError', 'ConnectionAbortedError', 'ConnectionRefusedError', - # 'ConnectionResetError', 'FileExistsError', 'FileNotFoundError', - # 'InterruptedError', 'IsADirectoryError', 'NotADirectoryError', - # 'PermissionError', 'ProcessLookupError', 'TimeoutError', - # 'StopAsyncIteration', 'ModuleNotFoundError', 'RecursionError', - # 'EncodingWarning', +) + +python_exceptions = ( + "BaseException", "BaseExceptionGroup", "GeneratorExit", "KeyboardInterrupt", "SystemExit", "Exception", "ArithmeticError", "FloatingPointError", "OverflowError", "ZeroDivisionError", 
"AssertionError", "AttributeError", "BufferError", "EOFError", "ExceptionGroup", "BaseExceptionGroup", "ImportError", "ModuleNotFoundError", "LookupError", "IndexError", "KeyError", "MemoryError", "NameError", "UnboundLocalError", "OSError", "BlockingIOError", "ChildProcessError", "ConnectionError", "BrokenPipeError", "ConnectionAbortedError", "ConnectionRefusedError", "ConnectionResetError", "FileExistsError", "FileNotFoundError", "InterruptedError", "IsADirectoryError", "NotADirectoryError", "PermissionError", "ProcessLookupError", "TimeoutError", "ReferenceError", "RuntimeError", "NotImplementedError", "RecursionError", "StopAsyncIteration", "StopIteration", "SyntaxError", "IndentationError", "TabError", "SystemError", "TypeError", "ValueError", "UnicodeError", "UnicodeDecodeError", "UnicodeEncodeError", "UnicodeTranslateError", "Warning", "BytesWarning", "DeprecationWarning", "EncodingWarning", "FutureWarning", "ImportWarning", "PendingDeprecationWarning", "ResourceWarning", "RuntimeWarning", "SyntaxWarning", "UnicodeWarning", "UserWarning", ) # ----------------------------------------------------------------------------------------------------------------------- @@ -842,12 +831,16 @@ def get_path_env_var(env_var, default): coconut_exceptions = ( "MatchError", - "ExceptionGroup", - "BaseExceptionGroup", ) -highlight_builtins = coconut_specific_builtins + interp_only_builtins -all_builtins = frozenset(python_builtins + coconut_specific_builtins + coconut_exceptions) +highlight_builtins = coconut_specific_builtins + interp_only_builtins + python_builtins +highlight_exceptions = coconut_exceptions + python_exceptions +all_builtins = frozenset( + python_builtins + + python_exceptions + + coconut_specific_builtins + + coconut_exceptions +) magic_methods = ( "__fmap__", diff --git a/coconut/highlighter.py b/coconut/highlighter.py index cb6ce0e53..9bf2b1c71 100644 --- a/coconut/highlighter.py +++ b/coconut/highlighter.py @@ -36,7 +36,7 @@ shebang_regex, magic_methods, template_ext, - coconut_exceptions, + highlight_exceptions, main_prompt, style_env_var, default_style, @@ -100,7 +100,7 @@ class CoconutLexer(Python3Lexer): ] tokens["builtins"] += [ (words(highlight_builtins, suffix=r"\b"), Name.Builtin), - (words(coconut_exceptions, suffix=r"\b"), Name.Exception), + (words(highlight_exceptions, suffix=r"\b"), Name.Exception), ] tokens["numbers"] = [ (r"0b[01_]+", Number.Integer), diff --git a/coconut/root.py b/coconut/root.py index 87b577dcc..f9636a605 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 16 +DEVELOP = 17 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/src/cocotest/agnostic/primary_2.coco b/coconut/tests/src/cocotest/agnostic/primary_2.coco index 9876cabe1..102a7bf88 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_2.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_2.coco @@ -431,6 +431,9 @@ def primary_test_2() -> bool: b = "b" assert "abc".find b == 1 assert_raises(-> "a" 10, TypeError) + assert (,) ↤* (1, 2, 3) == (1, 2, 3) + assert (,) ↤? None is None + assert (,) ↤*? 
None is None # type: ignore with process_map.multiple_sequential_calls(): # type: ignore assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore diff --git a/coconut/tests/src/cocotest/agnostic/suite.coco b/coconut/tests/src/cocotest/agnostic/suite.coco index 1b5309bf1..45d96810a 100644 --- a/coconut/tests/src/cocotest/agnostic/suite.coco +++ b/coconut/tests/src/cocotest/agnostic/suite.coco @@ -1077,6 +1077,8 @@ forward 2""") == 900 assert pickle_round_trip(.method(x=10)) <| (method=x -> x) == 10 assert sq_and_t2p1(10) == (100, 21) assert first_false_and_last_true([3, 2, 1, 0, "11", "1", ""]) == (0, "1") + assert ret_args_kwargs ↤** dict(a=1) == ((), dict(a=1)) + assert ret_args_kwargs ↤**? None is None with process_map.multiple_sequential_calls(): # type: ignore assert process_map(tuple <.. (|>)$(to_sort), qsorts) |> list == [to_sort |> sorted |> tuple] * len(qsorts) diff --git a/coconut/tests/src/cocotest/target_3/py3_test.coco b/coconut/tests/src/cocotest/target_3/py3_test.coco index acdef4f73..8ace419a2 100644 --- a/coconut/tests/src/cocotest/target_3/py3_test.coco +++ b/coconut/tests/src/cocotest/target_3/py3_test.coco @@ -27,14 +27,14 @@ def py3_test() -> bool: čeština = "czech" assert čeština == "czech" class HasExecMethod: - def exec(self, x) = x() + def \exec(self, x) = x() has_exec = HasExecMethod() assert hasattr(has_exec, "exec") assert has_exec.exec(-> 1) == 1 def exec_rebind_test(): - exec = 1 + \exec = 1 assert exec + 1 == 2 - def exec(x) = x + def \exec(x) = x assert exec(1) == 1 return True assert exec_rebind_test() is True From d132aafe245a3359b451fee7c7584184752fdbc4 Mon Sep 17 00:00:00 2001 From: inventshah <39803835+inventshah@users.noreply.github.com> Date: Fri, 19 Jan 2024 21:32:47 -0500 Subject: [PATCH 32/54] Fix function application style in pure-Python example for flatten built-in docs --- DOCS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DOCS.md b/DOCS.md index 139b798ff..8c8560575 100644 --- a/DOCS.md +++ b/DOCS.md @@ -3960,7 +3960,7 @@ flat_it = iter_of_iters |> flatten |> list ```coconut_python from itertools import chain iter_of_iters = [[1, 2], [3, 4]] -flat_it = iter_of_iters |> chain.from_iterable |> list +flat_it = list(chain.from_iterable(iter_of_iters)) ``` #### `scan` From 651755976d161ea974d4d97372e83dd292605b9c Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Fri, 19 Jan 2024 18:49:44 -0800 Subject: [PATCH 33/54] Fix formatting --- DOCS.md | 2 +- coconut/compiler/matching.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/DOCS.md b/DOCS.md index 8c8560575..2da55905a 100644 --- a/DOCS.md +++ b/DOCS.md @@ -2486,7 +2486,7 @@ where `` is defined as ``` where `` is the name of the function, `` is an optional additional check, `` is the body of the function, `` is defined by Coconut's [`match` statement](#match), `` is the optional default if no argument is passed, and `` is the optional return type annotation (note that argument type annotations are not supported for pattern-matching functions). The `match` keyword at the beginning is optional, but is sometimes necessary to disambiguate pattern-matching function definition from normal function definition, since Python function definition will always take precedence. Note that the `async` and `match` keywords can be in any order. 
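A minimal sketch of the new syntax, based on the py38 test cases added below (assumed usage; the named expression partial lowers to a Python walrus assignment, so it requires `--target 3.8` or higher):

```coconut
# pipe into (x := .) to assign the piped-in value to x and keep piping
result = 10 |> (x := .) |> (. + 1)
assert x == 10 and result == 11
```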
-If `` has a variable name (either directly or with `as`), the resulting pattern-matching function will support keyword arguments using that variable name. +If `` has a variable name (via any variable binding that binds the entire pattern), the resulting pattern-matching function will support keyword arguments using that variable name. In addition to supporting pattern-matching in their arguments, pattern-matching function definitions also have a couple of notable differences compared to Python functions. Specifically: - If pattern-matching function definition fails, it will raise a [`MatchError`](#matcherror) (just like [destructuring assignment](#destructuring-assignment)) instead of a `TypeError`. diff --git a/coconut/compiler/matching.py b/coconut/compiler/matching.py index 96765f91a..e70bdf46e 100644 --- a/coconut/compiler/matching.py +++ b/coconut/compiler/matching.py @@ -1050,7 +1050,7 @@ def match_class(self, tokens, item): handle_indentation( """ raise _coconut.TypeError("too many positional args in class match (pattern requires {num_pos_matches}; '{cls_name}' only supports 1)") - """, + """, ).format( num_pos_matches=len(pos_matches), cls_name=cls_name, @@ -1063,13 +1063,15 @@ def match_class(self, tokens, item): other_cls_matcher.add_check("not _coconut.type(" + item + ") in _coconut_self_match_types") match_args_var = other_cls_matcher.get_temp_var() other_cls_matcher.add_def( - handle_indentation(""" + handle_indentation( + """ {match_args_var} = _coconut.getattr({cls_name}, '__match_args__', ()) {type_any} {type_ignore} if not _coconut.isinstance({match_args_var}, _coconut.tuple): raise _coconut.TypeError("{cls_name}.__match_args__ must be a tuple") if _coconut.len({match_args_var}) < {num_pos_matches}: raise _coconut.TypeError("too many positional args in class match (pattern requires {num_pos_matches}; '{cls_name}' only supports %s)" % (_coconut.len({match_args_var}),)) - """).format( + """, + ).format( cls_name=cls_name, match_args_var=match_args_var, num_pos_matches=len(pos_matches), @@ -1089,7 +1091,7 @@ def match_class(self, tokens, item): """ {match_args_var} = _coconut.getattr({cls_name}, '__match_args__', ()) {star_match_var} = _coconut.tuple(_coconut.getattr({item}, {match_args_var}[i]) for i in _coconut.range({num_pos_matches}, _coconut.len({match_args_var}))) - """, + """, ).format( match_args_var=self.get_temp_var(), cls_name=cls_name, @@ -1164,7 +1166,7 @@ def match_data_or_class(self, tokens, item): handle_indentation( """ {is_data_result_var} = _coconut.getattr({cls_name}, "{is_data_var}", False) or _coconut.isinstance({cls_name}, _coconut.tuple) and _coconut.all(_coconut.getattr(_coconut_x, "{is_data_var}", False) for _coconut_x in {cls_name}) {type_ignore} - """, + """, ).format( is_data_result_var=is_data_result_var, is_data_var=is_data_var, @@ -1241,7 +1243,7 @@ def match_view(self, tokens, item): {func_result_var} = _coconut_sentinel else: raise - """, + """, ).format( func_result_var=func_result_var, view_func=view_func, From e57c05c8d7207011523b8292e5f1da537cd1375a Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Fri, 19 Jan 2024 19:35:05 -0800 Subject: [PATCH 34/54] Add walrus partial in pipes Resolves #823. 
--- DOCS.md | 11 ++-- coconut/compiler/compiler.py | 51 ++++++++++++++----- coconut/compiler/grammar.py | 10 +++- coconut/root.py | 2 +- .../src/cocotest/target_38/py38_test.coco | 3 ++ 5 files changed, 56 insertions(+), 21 deletions(-) diff --git a/DOCS.md b/DOCS.md index 2da55905a..6290ebf87 100644 --- a/DOCS.md +++ b/DOCS.md @@ -688,9 +688,10 @@ Coconut uses pipe operators for pipeline-style function application. All the ope The None-aware pipe operators here are equivalent to a [monadic bind](https://en.wikipedia.org/wiki/Monad_(functional_programming)) treating the object as a `Maybe` monad composed of either `None` or the given object. Thus, `x |?> f` is equivalent to `None if x is None else f(x)`. Note that only the object being piped, not the function being piped into, may be `None` for `None`-aware pipes. -For working with `async` functions in pipes, all non-starred pipes support piping into `await` to await the awaitable piped into them, such that `x |> await` is equivalent to `await x`. - -Additionally, all pipe operators support a lambda as the last argument, despite lambdas having a lower precedence. Thus, `a |> x => b |> c` is equivalent to `a |> (x => b |> c)`, not `a |> (x => b) |> c`. +Additionally, some special syntax constructs are only available in pipes to enable doing as many operations as possible via pipes if so desired: +* For working with `async` functions in pipes, all non-starred pipes support piping into `await` to await the awaitable piped into them, such that `x |> await` is equivalent to `await x`. +* All non-starred pipes support piping into `( := .)` (mirroring the syntax for [operator implicit partials](#implicit-partial-application)) to assign the piped in item to ``. +* All pipe operators support a lambda as the last argument, despite lambdas having a lower precedence. Thus, `a |> x => b |> c` is equivalent to `a |> (x => b |> c)`, not `a |> (x => b) |> c`. _Note: To visually spread operations across several lines, just use [parenthetical continuation](#enhanced-parenthetical-continuation)._ @@ -1766,6 +1767,8 @@ _Deprecated: if the deprecated `->` is used in place of `=>`, then return type a Coconut uses a simple operator function short-hand: surround an operator with parentheses to retrieve its function. Similarly to iterator comprehensions, if the operator function is the only argument to a function, the parentheses of the function call can also serve as the parentheses for the operator function. +All operator functions also support [implicit partial application](#implicit-partial-application), e.g. `(. + 1)` is equivalent to `(=> _ + 1)`. + ##### Rationale A very common thing to do in functional programming is to make use of function versions of built-in operators: currying them, composing them, and piping them. To make this easy, Coconut provides a short-hand syntax to access operator functions. @@ -2486,7 +2489,7 @@ where `` is defined as ``` where `` is the name of the function, `` is an optional additional check, `` is the body of the function, `` is defined by Coconut's [`match` statement](#match), `` is the optional default if no argument is passed, and `` is the optional return type annotation (note that argument type annotations are not supported for pattern-matching functions). The `match` keyword at the beginning is optional, but is sometimes necessary to disambiguate pattern-matching function definition from normal function definition, since Python function definition will always take precedence. 
Note that the `async` and `match` keywords can be in any order. -If `` has a variable name (via any variable binding that binds the entire pattern), the resulting pattern-matching function will support keyword arguments using that variable name. +If `` has a variable name (via any variable binding that binds the entire pattern, e.g. `x` in `int(x)` or `[a, b] as x`), the resulting pattern-matching function will support keyword arguments using that variable name. In addition to supporting pattern-matching in their arguments, pattern-matching function definitions also have a couple of notable differences compared to Python functions. Specifically: - If pattern-matching function definition fails, it will raise a [`MatchError`](#matcherror) (just like [destructuring assignment](#destructuring-assignment)) instead of a `TypeError`. diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 54465a841..f0477bb52 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -2445,7 +2445,8 @@ def proc_funcdef(self, original, loc, decorators, funcdef, is_async, in_method, raise self.make_err( CoconutTargetError, "async function definition requires a specific target", - original, loc, + original, + loc, target="sys", ) elif self.target_info >= (3, 5): @@ -2456,7 +2457,8 @@ def proc_funcdef(self, original, loc, decorators, funcdef, is_async, in_method, raise self.make_err( CoconutTargetError, "found Python 3.6 async generator (Coconut can only backport async generators as far back as 3.5)", - original, loc, + original, + loc, target="35", ) else: @@ -2815,16 +2817,18 @@ def function_call_handle(self, loc, tokens): """Enforce properly ordered function parameters.""" return "(" + join_args(*self.split_function_call(tokens, loc)) + ")" - def pipe_item_split(self, tokens, loc): + def pipe_item_split(self, original, loc, tokens): """Process a pipe item, which could be a partial, an attribute access, a method call, or an expression. 
- Return (type, split) where split is: - - (expr,) for expression - - (func, pos_args, kwd_args) for partial - - (name, args) for attr/method - - (attr, [(op, args)]) for itemgetter - - (op, arg) for right op partial - - (op, arg) for right arr concat partial + Return (type, split) where split is, for each type: + - expr: (expr,) + - partial: (func, pos_args, kwd_args) + - attrgetter: (name, args) + - itemgetter: (attr, [(op, args)]) for itemgetter + - right op partial: (op, arg) + - right arr concat partial: (op, arg) + - await: () + - namedexpr: (varname,) """ # list implies artificial tokens, which must be expr if isinstance(tokens, list) or "expr" in tokens: @@ -2868,7 +2872,18 @@ def pipe_item_split(self, tokens, loc): raise CoconutInternalException("invalid arr concat partial tokens in pipe_item", inner_toks) elif "await" in tokens: internal_assert(len(tokens) == 1 and tokens[0] == "await", "invalid await pipe item tokens", tokens) - return "await", [] + return "await", () + elif "namedexpr" in tokens: + if self.target_info < (3, 8): + raise self.make_err( + CoconutTargetError, + "named expression partial in pipe only supported for targets 3.8+", + original, + loc, + target="38", + ) + varname, = tokens + return "namedexpr", (varname,) else: raise CoconutInternalException("invalid pipe item tokens", tokens) @@ -2882,7 +2897,7 @@ def pipe_handle(self, original, loc, tokens, **kwargs): return item # we've only been given one operand, so we can't do any optimization, so just produce the standard object - name, split_item = self.pipe_item_split(item, loc) + name, split_item = self.pipe_item_split(original, loc, item) if name == "expr": expr, = split_item return expr @@ -2899,6 +2914,8 @@ def pipe_handle(self, original, loc, tokens, **kwargs): return partial_arr_concat_handle(item) elif name == "await": raise CoconutDeferredSyntaxError("await in pipe must have something piped into it", loc) + elif name == "namedexpr": + raise CoconutDeferredSyntaxError("named expression partial in pipe must have something piped into it", loc) else: raise CoconutInternalException("invalid split pipe item", split_item) @@ -2929,7 +2946,7 @@ def pipe_handle(self, original, loc, tokens, **kwargs): elif direction == "forwards": # if this is an implicit partial, we have something to apply it to, so optimize it - name, split_item = self.pipe_item_split(item, loc) + name, split_item = self.pipe_item_split(original, loc, item) subexpr = self.pipe_handle(original, loc, tokens) if name == "expr": @@ -2976,6 +2993,11 @@ def pipe_handle(self, original, loc, tokens, **kwargs): if stars: raise CoconutDeferredSyntaxError("cannot star pipe into await", loc) return self.await_expr_handle(original, loc, [subexpr]) + elif name == "namedexpr": + if stars: + raise CoconutDeferredSyntaxError("cannot star pipe into named expression partial", loc) + varname, = split_item + return "({varname} := {item})".format(varname=varname, item=subexpr) else: raise CoconutInternalException("invalid split pipe item", split_item) @@ -3952,7 +3974,8 @@ def await_expr_handle(self, original, loc, tokens): raise self.make_err( CoconutTargetError, "await requires a specific target", - original, loc, + original, + loc, target="sys", ) elif self.target_info >= (3, 5): diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index b64d040fe..cd6ff8339 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -1565,6 +1565,9 @@ class Grammar(object): back_none_dubstar_pipe, use_adaptive=False, ) + pipe_namedexpr_partial 
= lparen.suppress() + setname + (colon_eq + dot + rparen).suppress() + + # make sure to keep these three definitions in sync pipe_item = ( # we need the pipe_op since any of the atoms could otherwise be the start of an expression labeled_group(keyword("await"), "await") + pipe_op @@ -1574,6 +1577,7 @@ class Grammar(object): | labeled_group(attrgetter_atom_tokens, "attrgetter") + pipe_op | labeled_group(partial_op_atom_tokens, "op partial") + pipe_op | labeled_group(partial_arr_concat_tokens, "arr concat partial") + pipe_op + | labeled_group(pipe_namedexpr_partial, "namedexpr") + pipe_op # expr must come at end | labeled_group(comp_pipe_expr, "expr") + pipe_op ) @@ -1585,23 +1589,25 @@ class Grammar(object): | labeled_group(attrgetter_atom_tokens, "attrgetter") + end_simple_stmt_item | labeled_group(partial_op_atom_tokens, "op partial") + end_simple_stmt_item | labeled_group(partial_arr_concat_tokens, "arr concat partial") + end_simple_stmt_item + | labeled_group(pipe_namedexpr_partial, "namedexpr") + end_simple_stmt_item ) last_pipe_item = Group( lambdef("expr") # we need longest here because there's no following pipe_op we can use as above | longest( keyword("await")("await"), + partial_atom_tokens("partial"), itemgetter_atom_tokens("itemgetter"), attrgetter_atom_tokens("attrgetter"), - partial_atom_tokens("partial"), partial_op_atom_tokens("op partial"), partial_arr_concat_tokens("arr concat partial"), + pipe_namedexpr_partial("namedexpr"), comp_pipe_expr("expr"), ) ) + normal_pipe_expr = Forward() normal_pipe_expr_tokens = OneOrMore(pipe_item) + last_pipe_item - pipe_expr = ( comp_pipe_expr + ~pipe_op | normal_pipe_expr diff --git a/coconut/root.py b/coconut/root.py index f9636a605..4e13ec11b 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 17 +DEVELOP = 18 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/src/cocotest/target_38/py38_test.coco b/coconut/tests/src/cocotest/target_38/py38_test.coco index 5df470874..8c4f30efc 100644 --- a/coconut/tests/src/cocotest/target_38/py38_test.coco +++ b/coconut/tests/src/cocotest/target_38/py38_test.coco @@ -7,4 +7,7 @@ def py38_test() -> bool: assert a == 3 == b def f(x: int, /, y: int) -> int = x + y assert f(1, y=2) == 3 + assert 10 |> (x := .) == 10 == x + assert 10 |> (x := .) |> (. + 1) == 11 + assert x == 10 return True From 4be1bb53b351c86a6fb33be1e247b5fda78b0fd1 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Tue, 23 Jan 2024 15:32:32 -0800 Subject: [PATCH 35/54] Add regression test Refs #825. --- coconut/tests/src/cocotest/agnostic/primary_2.coco | 1 + 1 file changed, 1 insertion(+) diff --git a/coconut/tests/src/cocotest/agnostic/primary_2.coco b/coconut/tests/src/cocotest/agnostic/primary_2.coco index 102a7bf88..b9605646e 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_2.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_2.coco @@ -434,6 +434,7 @@ def primary_test_2() -> bool: assert (,) ↤* (1, 2, 3) == (1, 2, 3) assert (,) ↤? None is None assert (,) ↤*? 
None is None # type: ignore + assert '''\u2029'''!='''\n''' with process_map.multiple_sequential_calls(): # type: ignore assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore From 5ba3d1a1753653a37f9861ab4163c500bae5da7f Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Fri, 26 Jan 2024 21:43:43 -0800 Subject: [PATCH 36/54] Universalize bytes --- DOCS.md | 1 + __coconut__/__init__.pyi | 1 + _coconut/__init__.pyi | 1 + coconut/compiler/templates/header.py_template | 4 ++- coconut/root.py | 26 ++++++++++++++++--- .../src/cocotest/agnostic/primary_2.coco | 6 +++++ 6 files changed, 35 insertions(+), 4 deletions(-) diff --git a/DOCS.md b/DOCS.md index 6290ebf87..4a847f592 100644 --- a/DOCS.md +++ b/DOCS.md @@ -258,6 +258,7 @@ While Coconut syntax is based off of the latest Python 3, Coconut code compiled To make Coconut built-ins universal across Python versions, Coconut makes available on any Python version built-ins that only exist in later versions, including **automatically overwriting Python 2 built-ins with their Python 3 counterparts.** Additionally, Coconut also [overwrites some Python 3 built-ins for optimization and enhancement purposes](#enhanced-built-ins). If access to the original Python versions of any overwritten built-ins is desired, the old built-ins can be retrieved by prefixing them with `py_`. Specifically, the overwritten built-ins are: +- `py_bytes` - `py_chr` - `py_dict` - `py_hex` diff --git a/__coconut__/__init__.pyi b/__coconut__/__init__.pyi index 5f675ea51..a73472dad 100644 --- a/__coconut__/__init__.pyi +++ b/__coconut__/__init__.pyi @@ -155,6 +155,7 @@ if sys.version_info < (3, 7): ... +py_bytes = bytes py_chr = chr py_dict = dict py_hex = hex diff --git a/_coconut/__init__.pyi b/_coconut/__init__.pyi index 82d320478..809c7bf0e 100644 --- a/_coconut/__init__.pyi +++ b/_coconut/__init__.pyi @@ -131,6 +131,7 @@ ValueError = _builtins.ValueError StopIteration = _builtins.StopIteration RuntimeError = _builtins.RuntimeError callable = _builtins.callable +chr = _builtins.chr classmethod = _builtins.classmethod complex = _builtins.complex all = _builtins.all diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index cdf766aee..69b93010a 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -61,7 +61,7 @@ class _coconut{object}:{COMMENT.EVERYTHING_HERE_MUST_BE_COPIED_TO_STUB_FILE} reiterables = abc.Sequence, abc.Mapping, abc.Set fmappables = list, tuple, dict, set, frozenset abc.Sequence.register(collections.deque) - Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, 
locals, globals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray} + Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray} @_coconut.functools.wraps(_coconut.functools.partial) def _coconut_partial(_coconut_func, *args, **kwargs): partial_func = _coconut.functools.partial(_coconut_func, *args, **kwargs) @@ -1583,6 +1583,8 @@ def _coconut_base_makedata(data_type, args, from_fmap=False, fallback_to_init=Fa return args if _coconut.issubclass(data_type, _coconut.str): return "".join(args) + if _coconut.issubclass(data_type, _coconut.bytes): + return b"".join(args) if fallback_to_init or _coconut.issubclass(data_type, _coconut.fmappables): return data_type(args) if from_fmap: diff --git a/coconut/root.py b/coconut/root.py index 4e13ec11b..a0c2bf798 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -61,7 +61,7 @@ def _get_target_info(target): # if a new assignment is added below, a new builtins import should be added alongside it _base_py3_header = r'''from builtins import chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr -py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_repr = chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr +py_bytes, py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_repr = bytes, chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr _coconut_py_str, _coconut_py_super, _coconut_py_dict = str, super, dict from functools import wraps as _coconut_wraps exec("_coconut_exec = exec") @@ -69,8 +69,8 @@ def _get_target_info(target): # if a new assignment is added below, a new builtins import should be added alongside it _base_py2_header = r'''from __builtin__ import chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr, long -py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_raw_input, py_xrange, py_repr = 
chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr -_coconut_py_raw_input, _coconut_py_xrange, _coconut_py_int, _coconut_py_long, _coconut_py_print, _coconut_py_str, _coconut_py_super, _coconut_py_unicode, _coconut_py_repr, _coconut_py_dict = raw_input, xrange, int, long, print, str, super, unicode, repr, dict +py_bytes, py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_raw_input, py_xrange, py_repr = bytes, chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr +_coconut_py_raw_input, _coconut_py_xrange, _coconut_py_int, _coconut_py_long, _coconut_py_print, _coconut_py_str, _coconut_py_super, _coconut_py_unicode, _coconut_py_repr, _coconut_py_dict, _coconut_py_bytes = raw_input, xrange, int, long, print, str, super, unicode, repr, dict, bytes from functools import wraps as _coconut_wraps from collections import Sequence as _coconut_Sequence from future_builtins import * @@ -96,6 +96,26 @@ def __instancecheck__(cls, inst): return _coconut.isinstance(inst, (_coconut_py_int, _coconut_py_long)) def __subclasscheck__(cls, subcls): return _coconut.issubclass(subcls, (_coconut_py_int, _coconut_py_long)) +class bytes(_coconut_py_bytes): + __slots__ = () + __doc__ = getattr(_coconut_py_bytes, "__doc__", "") + class __metaclass__(type): + def __instancecheck__(cls, inst): + return _coconut.isinstance(inst, _coconut_py_bytes) + def __subclasscheck__(cls, subcls): + return _coconut.issubclass(subcls, _coconut_py_bytes) + def __new__(self, *args): + if not args: + return b"" + elif _coconut.len(args) == 1: + if _coconut.isinstance(args[0], _coconut.int): + return b"\x00" * args[0] + elif _coconut.isinstance(args[0], _coconut.bytes): + return _coconut_py_bytes(args[0]) + else: + return b"".join(_coconut.chr(x) for x in args[0]) + else: + return args[0].encode(*args[1:]) class range(object): __slots__ = ("_xrange",) __doc__ = getattr(_coconut_py_xrange, "__doc__", "") diff --git a/coconut/tests/src/cocotest/agnostic/primary_2.coco b/coconut/tests/src/cocotest/agnostic/primary_2.coco index b9605646e..bc6000be4 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_2.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_2.coco @@ -435,6 +435,12 @@ def primary_test_2() -> bool: assert (,) ↤? None is None assert (,) ↤*? None is None # type: ignore assert '''\u2029'''!='''\n''' + assert b"a" `isinstance` bytes + assert b"a" `isinstance` py_bytes + assert bytes() == b"" + assert bytes(10) == b"\x00" * 10 + assert bytes([35, 40]) == b'#(' + assert bytes(b"abc") == b"abc" == bytes("abc", "utf-8") with process_map.multiple_sequential_calls(): # type: ignore assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore From b697bc87f7d8e8f459371c0d3116ec3228cdb396 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Fri, 26 Jan 2024 22:41:50 -0800 Subject: [PATCH 37/54] Support fmap of bytes, bytearray Resolves #826. 
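Illustrative usage, mirroring the primary_2 test cases added below: `fmap` over `bytes`/`bytearray` maps the function over each integer byte value and rebuilds an object of the same type (here `.|32` sets the 0x20 bit, lowercasing ASCII letters):

```coconut
assert b"Abc" |> fmap$(.|32) == b"abc"
assert (bytearray(b"Abc") |> fmap$(.|32)) `isinstance` bytearray
```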
--- DOCS.md | 2 +- __coconut__/__init__.pyi | 2 +- _coconut/__init__.pyi | 1 + coconut/compiler/header.py | 9 +++++++++ coconut/compiler/templates/header.py_template | 15 +++++++-------- coconut/root.py | 2 +- .../tests/src/cocotest/agnostic/primary_2.coco | 3 +++ 7 files changed, 23 insertions(+), 11 deletions(-) diff --git a/DOCS.md b/DOCS.md index 4a847f592..0deb91963 100644 --- a/DOCS.md +++ b/DOCS.md @@ -3385,7 +3385,7 @@ _Can't be done without a series of method definitions for each data type. See th In Haskell, `fmap(func, obj)` takes a data type `obj` and returns a new data type with `func` mapped over the contents. Coconut's `fmap` function does the exact same thing for Coconut's [data types](#data). -`fmap` can also be used on the built-in objects `str`, `dict`, `list`, `tuple`, `set`, `frozenset`, and `dict` as a variant of `map` that returns back an object of the same type. +`fmap` can also be used on the built-in objects `str`, `dict`, `list`, `tuple`, `set`, `frozenset`, `bytes`, `bytearray`, and `dict` as a variant of `map` that returns back an object of the same type. For `dict`, or any other `collections.abc.Mapping`, `fmap` will map over the mapping's `.items()` instead of the default iteration through its `.keys()`, with the new mapping reconstructed from the mapped over items. _Deprecated: `fmap$(starmap_over_mappings=True)` will `starmap` over the `.items()` instead of `map` over them._ diff --git a/__coconut__/__init__.pyi b/__coconut__/__init__.pyi index a73472dad..09313eb57 100644 --- a/__coconut__/__init__.pyi +++ b/__coconut__/__init__.pyi @@ -1457,7 +1457,7 @@ def fmap(func: _t.Callable[[_T, _U], _t.Tuple[_V, _W]], obj: _t.Mapping[_T, _U], Supports: * Coconut data types - * `str`, `dict`, `list`, `tuple`, `set`, `frozenset` + * `str`, `dict`, `list`, `tuple`, `set`, `frozenset`, `bytes`, `bytearray` * `dict` (maps over .items()) * asynchronous iterables * numpy arrays (uses np.vectorize) diff --git a/_coconut/__init__.pyi b/_coconut/__init__.pyi index 809c7bf0e..17c0e3418 100644 --- a/_coconut/__init__.pyi +++ b/_coconut/__init__.pyi @@ -160,6 +160,7 @@ min = _builtins.min max = _builtins.max next = _builtins.next object = _builtins.object +ord = _builtins.ord print = _builtins.print property = _builtins.property range = _builtins.range diff --git a/coconut/compiler/header.py b/coconut/compiler/header.py index 2d14cbc88..39b2d2664 100644 --- a/coconut/compiler/header.py +++ b/coconut/compiler/header.py @@ -778,6 +778,15 @@ def __aiter__(self): {async_def_anext} '''.format(**format_dict), ), + handle_bytes=pycondition( + (3,), + if_lt=''' +if _coconut.isinstance(obj, _coconut.bytes): + return _coconut_base_makedata(_coconut.bytes, [func(_coconut.ord(x)) for x in obj], from_fmap=True, fallback_to_init=fallback_to_init) + ''', + indent=1, + newline=True, + ), maybe_bind_lru_cache=pycondition( (3, 2), if_lt=''' diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index 69b93010a..dbf5c41ba 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -59,9 +59,9 @@ class _coconut{object}:{COMMENT.EVERYTHING_HERE_MUST_BE_COPIED_TO_STUB_FILE} jax_numpy_modules = {jax_numpy_modules} tee_type = type(itertools.tee((), 1)[0]) reiterables = abc.Sequence, abc.Mapping, abc.Set - fmappables = list, tuple, dict, set, frozenset + fmappables = list, tuple, dict, set, frozenset, bytes, bytearray abc.Sequence.register(collections.deque) - Ellipsis, NotImplemented, 
NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray} + Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray} @_coconut.functools.wraps(_coconut.functools.partial) def _coconut_partial(_coconut_func, *args, **kwargs): partial_func = _coconut.functools.partial(_coconut_func, *args, **kwargs) @@ -1583,8 +1583,6 @@ def _coconut_base_makedata(data_type, args, from_fmap=False, fallback_to_init=Fa return args if _coconut.issubclass(data_type, _coconut.str): return "".join(args) - if _coconut.issubclass(data_type, _coconut.bytes): - return b"".join(args) if fallback_to_init or _coconut.issubclass(data_type, _coconut.fmappables): return data_type(args) if from_fmap: @@ -1602,7 +1600,7 @@ def fmap(func, obj, **kwargs): Supports: * Coconut data types - * `str`, `dict`, `list`, `tuple`, `set`, `frozenset` + * `str`, `dict`, `list`, `tuple`, `set`, `frozenset`, `bytes`, `bytearray` * `dict` (maps over .items()) * asynchronous iterables * numpy arrays (uses np.vectorize) @@ -1644,10 +1642,11 @@ def fmap(func, obj, **kwargs): else: if aiter is not _coconut.NotImplemented: return _coconut_amap(func, aiter) - if starmap_over_mappings: - return _coconut_base_makedata(obj.__class__, {_coconut_}starmap(func, obj.items()) if _coconut.isinstance(obj, _coconut.abc.Mapping) else {_coconut_}map(func, obj), from_fmap=True, fallback_to_init=fallback_to_init) +{handle_bytes} if _coconut.isinstance(obj, _coconut.abc.Mapping): + mapped_obj = ({_coconut_}starmap if starmap_over_mappings else {_coconut_}map)(func, 
obj.items()) else: - return _coconut_base_makedata(obj.__class__, {_coconut_}map(func, obj.items() if _coconut.isinstance(obj, _coconut.abc.Mapping) else obj), from_fmap=True, fallback_to_init=fallback_to_init) + mapped_obj = _coconut_map(func, obj) + return _coconut_base_makedata(obj.__class__, mapped_obj, from_fmap=True, fallback_to_init=fallback_to_init) def _coconut_memoize_helper(maxsize=None, typed=False): return maxsize, typed def memoize(*args, **kwargs): diff --git a/coconut/root.py b/coconut/root.py index a0c2bf798..000d9bd0b 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 18 +DEVELOP = 19 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/src/cocotest/agnostic/primary_2.coco b/coconut/tests/src/cocotest/agnostic/primary_2.coco index bc6000be4..322457897 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_2.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_2.coco @@ -441,6 +441,9 @@ def primary_test_2() -> bool: assert bytes(10) == b"\x00" * 10 assert bytes([35, 40]) == b'#(' assert bytes(b"abc") == b"abc" == bytes("abc", "utf-8") + assert b"Abc" |> fmap$(.|32) == b"abc" + assert bytearray(b"Abc") |> fmap$(.|32) == bytearray(b"abc") + assert (bytearray(b"Abc") |> fmap$(.|32)) `isinstance` bytearray with process_map.multiple_sequential_calls(): # type: ignore assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore From e2fa488865dec34fd9cc2246357cda8907922f1c Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Fri, 2 Feb 2024 19:29:18 -0800 Subject: [PATCH 38/54] Improve implicit coefficient error message Resolves #827. 
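For reference, a hedged sketch of what does and does not reach this error (the failing case matches the existing `assert_raises(-> "a" 10, TypeError)` test; the passing cases are assumed standard implicit application/coefficient usage, not taken from this patch):

```coconut
def double(x) = 2 * x
y = 3
assert double y == 6  # implicit function application: behaves as double(y)
assert 2 y == 6       # implicit coefficient: behaves as 2 * y
try:
    "a" 10  # neither callable nor an int/float/complex/numpy object, so the TypeError above is raised
except TypeError:
    pass
else:
    assert False
```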
--- coconut/compiler/templates/header.py_template | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index dbf5c41ba..d6a8a4c9d 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -2094,7 +2094,7 @@ def _coconut_call_or_coefficient(func, *args): if _coconut.callable(func): return func(*args) if not _coconut.isinstance(func, (_coconut.int, _coconut.float, _coconut.complex)) and _coconut_get_base_module(func) not in _coconut.numpy_modules: - raise _coconut.TypeError("implicit function application and coefficient syntax only supported for Callable, int, float, complex, and numpy objects") + raise _coconut.TypeError("first argument in implicit function application and coefficient syntax must be Callable, int, float, complex, or numpy object") func = func for x in args: func = func * x{COMMENT.no_times_equals_to_avoid_modification} From 12b2e73670fdd5c80711807bead8bbc6a8090e41 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Fri, 2 Feb 2024 19:30:21 -0800 Subject: [PATCH 39/54] Further clarify error message --- coconut/compiler/templates/header.py_template | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index d6a8a4c9d..a332de645 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -2094,7 +2094,7 @@ def _coconut_call_or_coefficient(func, *args): if _coconut.callable(func): return func(*args) if not _coconut.isinstance(func, (_coconut.int, _coconut.float, _coconut.complex)) and _coconut_get_base_module(func) not in _coconut.numpy_modules: - raise _coconut.TypeError("first argument in implicit function application and coefficient syntax must be Callable, int, float, complex, or numpy object") + raise _coconut.TypeError("first object in implicit function application and coefficient syntax must be Callable, int, float, complex, or numpy") func = func for x in args: func = func * x{COMMENT.no_times_equals_to_avoid_modification} From 615f063a4a14846aa6a83dbc4d60e780aa046d84 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Sat, 24 Feb 2024 16:30:13 -0800 Subject: [PATCH 40/54] Improve coloring --- coconut/terminal.py | 16 +++++++++++----- coconut/util.py | 6 +++--- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/coconut/terminal.py b/coconut/terminal.py index ee1a9335c..3fe3cad9d 100644 --- a/coconut/terminal.py +++ b/coconut/terminal.py @@ -183,6 +183,16 @@ def logging(self): sys.stdout = old_stdout +def should_use_color(file=None): + """Determine if colors should be used for the given file object.""" + use_color = get_bool_env_var(use_color_env_var, default=None) + if use_color is not None: + return use_color + if get_bool_env_var("CLICOLOR_FORCE") or get_bool_env_var("FORCE_COLOR"): + return True + return file is not None and not isatty(file) + + # ----------------------------------------------------------------------------------------------------------------------- # LOGGER: # ----------------------------------------------------------------------------------------------------------------------- @@ -210,11 +220,7 @@ def __init__(self, other=None): @classmethod def enable_colors(cls, file=None): """Attempt to enable CLI colors.""" - use_color = get_bool_env_var(use_color_env_var, default=None) - if ( - use_color is False - or use_color is 
None and file is not None and not isatty(file) - ): + if not should_use_color(file): return False if not cls.colors_enabled: # necessary to resolve https://bugs.python.org/issue40134 diff --git a/coconut/util.py b/coconut/util.py index fb9c9207c..e0b487870 100644 --- a/coconut/util.py +++ b/coconut/util.py @@ -8,7 +8,7 @@ """ Author: Evan Hubinger License: Apache 2.0 -Description: Installer for the Coconut Jupyter kernel. +Description: Base Coconut utilities. """ # ----------------------------------------------------------------------------------------------------------------------- @@ -331,10 +331,10 @@ def replace_all(inputstr, all_to_replace, replace_to): return inputstr -def highlight(code): +def highlight(code, force=False): """Attempt to highlight Coconut code for the terminal.""" from coconut.terminal import logger # hide to remove circular deps - if logger.enable_colors(sys.stdout) and logger.enable_colors(sys.stderr): + if force or logger.enable_colors(sys.stdout) and logger.enable_colors(sys.stderr): try: from coconut.highlighter import highlight_coconut_for_terminal except ImportError: From cd4f79c4f0737006206995fe17985f9258f091c4 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Sat, 24 Feb 2024 16:39:11 -0800 Subject: [PATCH 41/54] Improve kernel err msgs Resolves #812. --- coconut/root.py | 2 +- coconut/util.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/coconut/root.py b/coconut/root.py index 000d9bd0b..88841b3f7 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 19 +DEVELOP = 20 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/util.py b/coconut/util.py index e0b487870..f9f4905d0 100644 --- a/coconut/util.py +++ b/coconut/util.py @@ -340,7 +340,8 @@ def highlight(code, force=False): except ImportError: logger.log_exc() else: - return highlight_coconut_for_terminal(code) + code_base, code_white = split_trailing_whitespace(code) + return highlight_coconut_for_terminal(code_base).rstrip() + code_white return code From 04d906b928002c59c00949dbe05bf4892d640237 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Sat, 24 Feb 2024 22:18:35 -0800 Subject: [PATCH 42/54] Fix stmt lambda scoping Resolves #814. --- DOCS.md | 2 +- coconut/compiler/compiler.py | 717 ++++++++++-------- coconut/compiler/grammar.py | 73 +- coconut/compiler/util.py | 12 +- coconut/root.py | 2 +- .../src/cocotest/agnostic/primary_2.coco | 8 + coconut/tests/src/cocotest/agnostic/util.coco | 16 +- .../src/cocotest/target_38/py38_test.coco | 2 + coconut/tests/src/extras.coco | 13 +- 9 files changed, 504 insertions(+), 341 deletions(-) diff --git a/DOCS.md b/DOCS.md index 0deb91963..5d3be5c74 100644 --- a/DOCS.md +++ b/DOCS.md @@ -1728,7 +1728,7 @@ If the last `statement` (not followed by a semicolon) in a statement lambda is a Statement lambdas also support implicit lambda syntax such that `def => _` is equivalent to `def (_=None) => _` as well as explicitly marking them as pattern-matching such that `match def (x) => x` will be a pattern-matching function. -Importantly, statement lambdas do not capture variables introduced only in the surrounding expression, e.g. inside of a list comprehension or normal lambda. 
To avoid such situations, only nest statement lambdas inside other statement lambdas, and explicitly partially apply a statement lambda to pass in a value from a list comprehension. +Additionally, statement lambdas have slightly different scoping rules than normal lambdas. When a statement lambda is inside of an expression with an expression-local variable, such as a normal lambda or comprehension, the statement lambda will capture the value of the variable at the time that the statement lambda is defined (rather than a reference to the overall namespace as with normal lambdas). As a result, while `[=> y for y in range(2)] |> map$(call) |> list` is `[1, 1]`, `[def => y for y in range(2)] |> map$(call) |> list` is `[0, 1]`. Note that this only works for expression-local variables: to copy the entire namespace at the time of function definition, use [`copyclosure`](#copyclosure-functions) (which can be used with statement lambdas). Note that statement lambdas have a lower precedence than normal lambdas and thus capture things like trailing commas. To avoid confusion, statement lambdas should always be wrapped in their own set of parentheses. diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index f0477bb52..7f6efdafb 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -17,6 +17,7 @@ # - Compiler # - Processors # - Handlers +# - Managers # - Checking Handlers # - Endpoints # - Binding @@ -180,6 +181,7 @@ load_cache_for, pickle_cache, handle_and_manage, + manage, sub_all, ComputationNode, ) @@ -637,23 +639,6 @@ def inner_environment(self, ln=None): self.num_lines = num_lines self.remaining_original = remaining_original - def current_parsing_context(self, name, default=None): - """Get the current parsing context for the given name.""" - stack = self.parsing_context[name] - if stack: - return stack[-1] - else: - return default - - @contextmanager - def add_to_parsing_context(self, name, obj): - """Add the given object to the parsing context for the given name.""" - self.parsing_context[name].append(obj) - try: - yield - finally: - self.parsing_context[name].pop() - @contextmanager def disable_checks(self): """Run the block without checking names or strict errors.""" @@ -774,6 +759,30 @@ def bind(cls): cls.method("where_stmt_manage"), ) + # handle parsing_context for expr_setnames + # (we need include_in_packrat_context here because some parses will be in an expr_setname context and some won't) + cls.expr_lambdef <<= manage( + cls.expr_lambdef_ref, + cls.method("has_expr_setname_manage"), + include_in_packrat_context=True, + ) + cls.lambdef_no_cond <<= manage( + cls.lambdef_no_cond_ref, + cls.method("has_expr_setname_manage"), + include_in_packrat_context=True, + ) + cls.comprehension_expr <<= manage( + cls.comprehension_expr_ref, + cls.method("has_expr_setname_manage"), + include_in_packrat_context=True, + ) + cls.dict_comp <<= handle_and_manage( + cls.dict_comp_ref, + cls.method("dict_comp_handle"), + cls.method("has_expr_setname_manage"), + include_in_packrat_context=True, + ) + # greedy handlers (we need to know about them even if suppressed and/or they use the parsing_context) cls.comment <<= attach(cls.comment_tokens, cls.method("comment_handle"), greedy=True) cls.type_param <<= attach(cls.type_param_ref, cls.method("type_param_handle"), greedy=True) @@ -783,7 +792,16 @@ def bind(cls): # name handlers cls.refname <<= attach(cls.name_ref, cls.method("name_handle")) cls.setname <<= attach(cls.name_ref, cls.method("name_handle", assign=True)) 
- cls.classname <<= attach(cls.name_ref, cls.method("name_handle", assign=True, classname=True), greedy=True) + cls.classname <<= attach( + cls.name_ref, + cls.method("name_handle", assign=True, classname=True), + greedy=True, + ) + cls.expr_setname <<= attach( + cls.name_ref, + cls.method("name_handle", assign=True, expr_setname=True), + greedy=True, + ) # abnormally named handlers cls.moduledoc_item <<= attach(cls.moduledoc, cls.method("set_moduledoc")) @@ -796,6 +814,11 @@ def bind(cls): cls.trailer_atom <<= attach(cls.trailer_atom_ref, cls.method("item_handle")) cls.no_partial_trailer_atom <<= attach(cls.no_partial_trailer_atom_ref, cls.method("item_handle")) cls.simple_assign <<= attach(cls.simple_assign_ref, cls.method("item_handle")) + cls.expr_simple_assign <<= attach(cls.expr_simple_assign_ref, cls.method("item_handle")) + + # handle all star assignments with star_assign_item_check + cls.star_assign_item <<= attach(cls.star_assign_item_ref, cls.method("star_assign_item_check")) + cls.expr_star_assign_item <<= attach(cls.expr_star_assign_item_ref, cls.method("star_assign_item_check")) # handle all string atoms with string_atom_handle cls.string_atom <<= attach(cls.string_atom_ref, cls.method("string_atom_handle")) @@ -819,7 +842,6 @@ def bind(cls): cls.complex_raise_stmt <<= attach(cls.complex_raise_stmt_ref, cls.method("complex_raise_stmt_handle")) cls.augassign_stmt <<= attach(cls.augassign_stmt_ref, cls.method("augassign_stmt_handle")) cls.kwd_augassign <<= attach(cls.kwd_augassign_ref, cls.method("kwd_augassign_handle")) - cls.dict_comp <<= attach(cls.dict_comp_ref, cls.method("dict_comp_handle")) cls.destructuring_stmt <<= attach(cls.destructuring_stmt_ref, cls.method("destructuring_stmt_handle")) cls.full_match <<= attach(cls.full_match_ref, cls.method("full_match_handle")) cls.name_match_funcdef <<= attach(cls.name_match_funcdef_ref, cls.method("name_match_funcdef_handle")) @@ -849,7 +871,6 @@ def bind(cls): # these handlers just do strict/target checking cls.u_string <<= attach(cls.u_string_ref, cls.method("u_string_check")) cls.nonlocal_stmt <<= attach(cls.nonlocal_stmt_ref, cls.method("nonlocal_check")) - cls.star_assign_item <<= attach(cls.star_assign_item_ref, cls.method("star_assign_item_check")) cls.keyword_lambdef <<= attach(cls.keyword_lambdef_ref, cls.method("lambdef_check")) cls.star_sep_arg <<= attach(cls.star_sep_arg_ref, cls.method("star_sep_check")) cls.star_sep_setarg <<= attach(cls.star_sep_setarg_ref, cls.method("star_sep_check")) @@ -3895,69 +3916,104 @@ def set_letter_literal_handle(self, tokens): def stmt_lambdef_handle(self, original, loc, tokens): """Process multi-line lambdef statements.""" - if len(tokens) == 4: - got_kwds, params, stmts_toks, followed_by = tokens - typedef = None - else: - got_kwds, params, typedef, stmts_toks, followed_by = tokens + name = self.get_temp_var("lambda", loc) - if followed_by == ",": - self.strict_err_or_warn("found statement lambda followed by comma; this isn't recommended as it can be unclear whether the comma is inside or outside the lambda (just wrap the lambda in parentheses)", original, loc) - else: - internal_assert(followed_by == "", "invalid stmt_lambdef followed_by", followed_by) - - is_async = False - add_kwds = [] - for kwd in got_kwds: - if kwd == "async": - self.internal_assert(not is_async, original, loc, "duplicate stmt_lambdef async keyword", kwd) - is_async = True - elif kwd == "copyclosure": - add_kwds.append(kwd) + # avoid regenerating the code if we already built it on a previous call + if name not 
in self.add_code_before: + if len(tokens) == 4: + got_kwds, params, stmts_toks, followed_by = tokens + typedef = None else: - raise CoconutInternalException("invalid stmt_lambdef keyword", kwd) - - if len(stmts_toks) == 1: - stmts, = stmts_toks - elif len(stmts_toks) == 2: - stmts, last = stmts_toks - if "tests" in stmts_toks: - stmts = stmts.asList() + ["return " + last] + got_kwds, params, typedef, stmts_toks, followed_by = tokens + + if followed_by == ",": + self.strict_err_or_warn("found statement lambda followed by comma; this isn't recommended as it can be unclear whether the comma is inside or outside the lambda (just wrap the lambda in parentheses)", original, loc) else: - stmts = stmts.asList() + [last] - else: - raise CoconutInternalException("invalid statement lambda body tokens", stmts_toks) + internal_assert(followed_by == "", "invalid stmt_lambdef followed_by", followed_by) + + is_async = False + add_kwds = [] + for kwd in got_kwds: + if kwd == "async": + self.internal_assert(not is_async, original, loc, "duplicate stmt_lambdef async keyword", kwd) + is_async = True + elif kwd == "copyclosure": + add_kwds.append(kwd) + else: + raise CoconutInternalException("invalid stmt_lambdef keyword", kwd) + + if len(stmts_toks) == 1: + stmts, = stmts_toks + elif len(stmts_toks) == 2: + stmts, last = stmts_toks + if "tests" in stmts_toks: + stmts = stmts.asList() + ["return " + last] + else: + stmts = stmts.asList() + [last] + else: + raise CoconutInternalException("invalid statement lambda body tokens", stmts_toks) - name = self.get_temp_var("lambda", loc) - body = openindent + "\n".join(stmts) + closeindent + body = openindent + "\n".join(stmts) + closeindent - if typedef is None: - colon = ":" - else: - colon = self.typedef_handle([typedef]) - if isinstance(params, str): - decorators = "" - funcdef = "def " + name + params + colon + "\n" + body - else: - match_tokens = [name] + list(params) - before_colon, after_docstring = self.name_match_funcdef_handle(original, loc, match_tokens) - decorators = "@_coconut_mark_as_match\n" - funcdef = ( - before_colon - + colon - + "\n" - + after_docstring - + body - ) + if typedef is None: + colon = ":" + else: + colon = self.typedef_handle([typedef]) + if isinstance(params, str): + decorators = "" + funcdef = "def " + name + params + colon + "\n" + body + else: + match_tokens = [name] + list(params) + before_colon, after_docstring = self.name_match_funcdef_handle(original, loc, match_tokens) + decorators = "@_coconut_mark_as_match\n" + funcdef = ( + before_colon + + colon + + "\n" + + after_docstring + + body + ) - funcdef = " ".join(add_kwds + [funcdef]) + funcdef = " ".join(add_kwds + [funcdef]) - self.add_code_before[name] = self.decoratable_funcdef_stmt_handle(original, loc, [decorators, funcdef], is_async, is_stmt_lambda=True) + self.add_code_before[name] = self.decoratable_funcdef_stmt_handle(original, loc, [decorators, funcdef], is_async, is_stmt_lambda=True) + + expr_setname_context = self.current_parsing_context("expr_setnames") + if expr_setname_context is None: + return name + else: + builder_name = self.get_temp_var("lambda_builder", loc) + + parent_context = expr_setname_context["parent"] + parent_setnames = set() + while parent_context: + parent_setnames |= parent_context["new_names"] + parent_context = parent_context["parent"] + + def stmt_lambdef_callback(): + expr_setnames = parent_setnames | expr_setname_context["new_names"] + expr_setnames_str = ", ".join(sorted(expr_setnames) + ["**_coconut_other_locals"]) + # the actual code 
for the function will automatically be added by add_code_before for name + builder_code = handle_indentation(""" +def {builder_name}({expr_setnames_str}): + del _coconut_other_locals + return {name} + """).format( + builder_name=builder_name, + expr_setnames_str=expr_setnames_str, + name=name, + ) + self.add_code_before[builder_name] = builder_code - return name + expr_setname_context["callbacks"].append(stmt_lambdef_callback) + if parent_setnames: + builder_args = "**({" + ", ".join('"' + name + '": ' + name for name in sorted(parent_setnames)) + "} | _coconut.locals())" + else: + builder_args = "**_coconut.locals()" + return builder_name + "(" + builder_args + ")" def decoratable_funcdef_stmt_handle(self, original, loc, tokens, is_async=False, is_stmt_lambda=False): - """Wraps the given function for later processing""" + """Wrap the given function for later processing.""" if len(tokens) == 1: funcdef, = tokens decorators = "" @@ -4079,178 +4135,6 @@ def typed_assign_stmt_handle(self, tokens): type_ignore=self.type_ignore_comment(), ) - def funcname_typeparams_handle(self, tokens): - """Handle function names with type parameters.""" - if len(tokens) == 1: - name, = tokens - return name - else: - name, paramdefs = tokens - return self.add_code_before_marker_with_replacement(name, "".join(paramdefs), add_spaces=False) - - funcname_typeparams_handle.ignore_one_token = True - - def type_param_handle(self, original, loc, tokens): - """Compile a type param into an assignment.""" - args = "" - bound_op = None - bound_op_type = "" - if "TypeVar" in tokens: - TypeVarFunc = "TypeVar" - bound_op_type = "bound" - if len(tokens) == 2: - name_loc, name = tokens - else: - name_loc, name, bound_op, bound = tokens - args = ", bound=" + self.wrap_typedef(bound, for_py_typedef=False) - elif "TypeVar constraint" in tokens: - TypeVarFunc = "TypeVar" - bound_op_type = "constraint" - name_loc, name, bound_op, constraints = tokens - args = ", " + ", ".join(self.wrap_typedef(c, for_py_typedef=False) for c in constraints) - elif "TypeVarTuple" in tokens: - TypeVarFunc = "TypeVarTuple" - name_loc, name = tokens - elif "ParamSpec" in tokens: - TypeVarFunc = "ParamSpec" - name_loc, name = tokens - else: - raise CoconutInternalException("invalid type_param tokens", tokens) - - kwargs = "" - if bound_op is not None: - self.internal_assert(bound_op_type in ("bound", "constraint"), original, loc, "invalid type_param bound_op", bound_op) - # uncomment this line whenever mypy adds support for infer_variance in TypeVar - # (and remove the warning about it in the DOCS) - # kwargs = ", infer_variance=True" - if bound_op == "<=": - self.strict_err_or_warn( - "use of " + repr(bound_op) + " as a type parameter " + bound_op_type + " declaration operator is deprecated (Coconut style is to use '<:' for bounds and ':' for constaints)", - original, - loc, - ) - else: - self.internal_assert(bound_op in (":", "<:"), original, loc, "invalid type_param bound_op", bound_op) - if bound_op_type == "bound" and bound_op != "<:" or bound_op_type == "constraint" and bound_op != ":": - self.strict_err( - "found use of " + repr(bound_op) + " as a type parameter " + bound_op_type + " declaration operator (Coconut style is to use '<:' for bounds and ':' for constaints)", - original, - loc, - ) - - name_loc = int(name_loc) - internal_assert(name_loc == loc if TypeVarFunc == "TypeVar" else name_loc >= loc, "invalid name location for " + TypeVarFunc, (name_loc, loc, tokens)) - - typevar_info = self.current_parsing_context("typevars") - if typevar_info 
is not None: - # check to see if we already parsed this exact typevar, in which case just reuse the existing temp_name - if typevar_info["typevar_locs"].get(name, None) == name_loc: - name = typevar_info["all_typevars"][name] - else: - if name in typevar_info["all_typevars"]: - raise CoconutDeferredSyntaxError("type variable {name!r} already defined".format(name=name), loc) - temp_name = self.get_temp_var(("typevar", name), name_loc) - typevar_info["all_typevars"][name] = temp_name - typevar_info["new_typevars"].append((TypeVarFunc, temp_name)) - typevar_info["typevar_locs"][name] = name_loc - name = temp_name - - return '{name} = _coconut.typing.{TypeVarFunc}("{name}"{args}{kwargs})\n'.format( - name=name, - TypeVarFunc=TypeVarFunc, - args=args, - kwargs=kwargs, - ) - - def get_generic_for_typevars(self): - """Get the Generic instances for the current typevars.""" - typevar_info = self.current_parsing_context("typevars") - internal_assert(typevar_info is not None, "get_generic_for_typevars called with no typevars") - generics = [] - for TypeVarFunc, name in typevar_info["new_typevars"]: - if TypeVarFunc in ("TypeVar", "ParamSpec"): - generics.append(name) - elif TypeVarFunc == "TypeVarTuple": - if self.target_info >= (3, 11): - generics.append("*" + name) - else: - generics.append("_coconut.typing.Unpack[" + name + "]") - else: - raise CoconutInternalException("invalid TypeVarFunc", TypeVarFunc, "(", name, ")") - return "_coconut.typing.Generic[" + ", ".join(generics) + "]" - - @contextmanager - def type_alias_stmt_manage(self, original=None, loc=None, item=None): - """Manage the typevars parsing context.""" - prev_typevar_info = self.current_parsing_context("typevars") - with self.add_to_parsing_context("typevars", { - "all_typevars": {} if prev_typevar_info is None else prev_typevar_info["all_typevars"].copy(), - "new_typevars": [], - "typevar_locs": {}, - }): - yield - - def type_alias_stmt_handle(self, tokens): - """Handle type alias statements.""" - if len(tokens) == 2: - name, typedef = tokens - paramdefs = () - else: - name, paramdefs, typedef = tokens - if self.target_info >= (3, 12): - return "type " + name + " = " + self.wrap_typedef(typedef, for_py_typedef=True) - else: - return "".join(paramdefs) + self.typed_assign_stmt_handle([ - name, - "_coconut.typing.TypeAlias", - self.wrap_typedef(typedef, for_py_typedef=False), - ]) - - def where_item_handle(self, tokens): - """Manage where items.""" - where_context = self.current_parsing_context("where") - internal_assert(not where_context["assigns"], "invalid where_context", where_context) - where_context["assigns"] = set() - return tokens - - @contextmanager - def where_stmt_manage(self, original, loc, item): - """Manage where statements.""" - with self.add_to_parsing_context("where", { - "assigns": None, - }): - yield - - def where_stmt_handle(self, loc, tokens): - """Process where statements.""" - main_stmt, body_stmts = tokens - - where_assigns = self.current_parsing_context("where")["assigns"] - internal_assert(lambda: where_assigns is not None, "missing where_assigns") - - where_init = "".join(body_stmts) - where_final = main_stmt + "\n" - out = where_init + where_final - if not where_assigns: - return out - - name_regexes = { - name: compile_regex(r"\b" + name + r"\b") - for name in where_assigns - } - name_replacements = { - name: self.get_temp_var(("where", name), loc) - for name in where_assigns - } - - where_init = self.deferred_code_proc(where_init) - where_final = self.deferred_code_proc(where_final) - out = where_init + 
where_final - - out = sub_all(out, name_regexes, name_replacements) - - return self.wrap_passthrough(out, early=True) - def with_stmt_handle(self, tokens): """Process with statements.""" withs, body = tokens @@ -4648,72 +4532,204 @@ class {protocol_var}({tokens}, _coconut.typing.Protocol): pass # end: HANDLERS # ----------------------------------------------------------------------------------------------------------------------- -# CHECKING HANDLERS: +# MANAGERS: # ----------------------------------------------------------------------------------------------------------------------- - def check_strict(self, name, original, loc, tokens=(None,), only_warn=False, always_warn=False): - """Check that syntax meets --strict requirements.""" - self.internal_assert(len(tokens) == 1, original, loc, "invalid " + name + " tokens", tokens) - message = "found " + name - if self.strict: - kwargs = {} - if only_warn: - if not always_warn: - kwargs["extra"] = "remove --strict to dismiss" - self.syntax_warning(message, original, loc, **kwargs) - else: - if always_warn: - kwargs["extra"] = "remove --strict to downgrade to a warning" - return self.raise_or_wrap_error(self.make_err(CoconutStyleError, message, original, loc, **kwargs)) - elif always_warn: - self.syntax_warning(message, original, loc) - return tokens[0] - - def lambdef_check(self, original, loc, tokens): - """Check for Python-style lambdas.""" - return self.check_strict("Python-style lambda", original, loc, tokens) + def current_parsing_context(self, name, default=None): + """Get the current parsing context for the given name.""" + stack = self.parsing_context[name] + if stack: + return stack[-1] + else: + return default - def endline_semicolon_check(self, original, loc, tokens): - """Check for semicolons at the end of lines.""" - return self.check_strict("semicolon at end of line", original, loc, tokens, always_warn=True) + @contextmanager + def add_to_parsing_context(self, name, obj, callbacks_key=None): + """Pur the given object on the parsing context stack for the given name.""" + self.parsing_context[name].append(obj) + try: + yield + finally: + popped_ctx = self.parsing_context[name].pop() + if callbacks_key is not None: + for callback in popped_ctx[callbacks_key]: + callback() - def u_string_check(self, original, loc, tokens): - """Check for Python-2-style unicode strings.""" - return self.check_strict("Python-2-style unicode string (all Coconut strings are unicode strings)", original, loc, tokens, always_warn=True) + def funcname_typeparams_handle(self, tokens): + """Handle function names with type parameters.""" + if len(tokens) == 1: + name, = tokens + return name + else: + name, paramdefs = tokens + return self.add_code_before_marker_with_replacement(name, "".join(paramdefs), add_spaces=False) - def match_dotted_name_const_check(self, original, loc, tokens): - """Check for Python-3.10-style implicit dotted name match check.""" - return self.check_strict("Python-3.10-style dotted name in pattern-matching (Coconut style is to use '=={name}' not '{name}')".format(name=tokens[0]), original, loc, tokens) + funcname_typeparams_handle.ignore_one_token = True - def match_check_equals_check(self, original, loc, tokens): - """Check for old-style =item in pattern-matching.""" - return self.check_strict("deprecated equality-checking '=...' pattern; use '==...' 
instead", original, loc, tokens, always_warn=True) + def type_param_handle(self, original, loc, tokens): + """Compile a type param into an assignment.""" + args = "" + bound_op = None + bound_op_type = "" + if "TypeVar" in tokens: + TypeVarFunc = "TypeVar" + bound_op_type = "bound" + if len(tokens) == 2: + name_loc, name = tokens + else: + name_loc, name, bound_op, bound = tokens + args = ", bound=" + self.wrap_typedef(bound, for_py_typedef=False) + elif "TypeVar constraint" in tokens: + TypeVarFunc = "TypeVar" + bound_op_type = "constraint" + name_loc, name, bound_op, constraints = tokens + args = ", " + ", ".join(self.wrap_typedef(c, for_py_typedef=False) for c in constraints) + elif "TypeVarTuple" in tokens: + TypeVarFunc = "TypeVarTuple" + name_loc, name = tokens + elif "ParamSpec" in tokens: + TypeVarFunc = "ParamSpec" + name_loc, name = tokens + else: + raise CoconutInternalException("invalid type_param tokens", tokens) - def power_in_impl_call_check(self, original, loc, tokens): - """Check for exponentation in implicit function application / coefficient syntax.""" - return self.check_strict( - "syntax with new behavior in Coconut v3; 'f x ** y' is now equivalent to 'f(x**y)' not 'f(x)**y'", - original, - loc, - tokens, - only_warn=True, - always_warn=True, + if bound_op is not None: + self.internal_assert(bound_op_type in ("bound", "constraint"), original, loc, "invalid type_param bound_op", bound_op) + if bound_op == "<=": + self.strict_err_or_warn( + "use of " + repr(bound_op) + " as a type parameter " + bound_op_type + " declaration operator is deprecated (Coconut style is to use '<:' for bounds and ':' for constaints)", + original, + loc, + ) + else: + self.internal_assert(bound_op in (":", "<:"), original, loc, "invalid type_param bound_op", bound_op) + if bound_op_type == "bound" and bound_op != "<:" or bound_op_type == "constraint" and bound_op != ":": + self.strict_err( + "found use of " + repr(bound_op) + " as a type parameter " + bound_op_type + " declaration operator (Coconut style is to use '<:' for bounds and ':' for constaints)", + original, + loc, + ) + + kwargs = "" + # uncomment these lines whenever mypy adds support for infer_variance in TypeVar + # (and remove the warning about it in the DOCS) + # if TypeVarFunc == "TypeVar": + # kwargs += ", infer_variance=True" + + name_loc = int(name_loc) + internal_assert(name_loc == loc if TypeVarFunc == "TypeVar" else name_loc >= loc, "invalid name location for " + TypeVarFunc, (name_loc, loc, tokens)) + + typevar_info = self.current_parsing_context("typevars") + if typevar_info is not None: + # check to see if we already parsed this exact typevar, in which case just reuse the existing temp_name + if typevar_info["typevar_locs"].get(name, None) == name_loc: + name = typevar_info["all_typevars"][name] + else: + if name in typevar_info["all_typevars"]: + raise CoconutDeferredSyntaxError("type variable {name!r} already defined".format(name=name), loc) + temp_name = self.get_temp_var(("typevar", name), name_loc) + typevar_info["all_typevars"][name] = temp_name + typevar_info["new_typevars"].append((TypeVarFunc, temp_name)) + typevar_info["typevar_locs"][name] = name_loc + name = temp_name + + return '{name} = _coconut.typing.{TypeVarFunc}("{name}"{args}{kwargs})\n'.format( + name=name, + TypeVarFunc=TypeVarFunc, + args=args, + kwargs=kwargs, ) - def check_py(self, version, name, original, loc, tokens): - """Check for Python-version-specific syntax.""" - self.internal_assert(len(tokens) == 1, original, loc, "invalid " + name + " 
tokens", tokens) - version_info = get_target_info(version) - if self.target_info < version_info: - return self.raise_or_wrap_error(self.make_err( - CoconutTargetError, - "found Python " + ".".join(str(v) for v in version_info) + " " + name, - original, - loc, - target=version, - )) + def get_generic_for_typevars(self): + """Get the Generic instances for the current typevars.""" + typevar_info = self.current_parsing_context("typevars") + internal_assert(typevar_info is not None, "get_generic_for_typevars called with no typevars") + generics = [] + for TypeVarFunc, name in typevar_info["new_typevars"]: + if TypeVarFunc in ("TypeVar", "ParamSpec"): + generics.append(name) + elif TypeVarFunc == "TypeVarTuple": + if self.target_info >= (3, 11): + generics.append("*" + name) + else: + generics.append("_coconut.typing.Unpack[" + name + "]") + else: + raise CoconutInternalException("invalid TypeVarFunc", TypeVarFunc, "(", name, ")") + return "_coconut.typing.Generic[" + ", ".join(generics) + "]" + + @contextmanager + def type_alias_stmt_manage(self, original=None, loc=None, item=None): + """Manage the typevars parsing context.""" + prev_typevar_info = self.current_parsing_context("typevars") + with self.add_to_parsing_context("typevars", { + "all_typevars": {} if prev_typevar_info is None else prev_typevar_info["all_typevars"].copy(), + "new_typevars": [], + "typevar_locs": {}, + }): + yield + + def type_alias_stmt_handle(self, tokens): + """Handle type alias statements.""" + if len(tokens) == 2: + name, typedef = tokens + paramdefs = () else: - return tokens[0] + name, paramdefs, typedef = tokens + out = "".join(paramdefs) + if self.target_info >= (3, 12): + out += "type " + name + " = " + self.wrap_typedef(typedef, for_py_typedef=True) + else: + out += self.typed_assign_stmt_handle([ + name, + "_coconut.typing.TypeAlias", + self.wrap_typedef(typedef, for_py_typedef=False), + ]) + return out + + def where_item_handle(self, tokens): + """Manage where items.""" + where_context = self.current_parsing_context("where") + internal_assert(not where_context["assigns"], "invalid where_context", where_context) + where_context["assigns"] = set() + return tokens + + @contextmanager + def where_stmt_manage(self, original, loc, item): + """Manage where statements.""" + with self.add_to_parsing_context("where", { + "assigns": None, + }): + yield + + def where_stmt_handle(self, loc, tokens): + """Process where statements.""" + main_stmt, body_stmts = tokens + + where_assigns = self.current_parsing_context("where")["assigns"] + internal_assert(lambda: where_assigns is not None, "missing where_assigns") + + where_init = "".join(body_stmts) + where_final = main_stmt + "\n" + out = where_init + where_final + if not where_assigns: + return out + + name_regexes = { + name: compile_regex(r"\b" + name + r"\b") + for name in where_assigns + } + name_replacements = { + name: self.get_temp_var(("where", name), loc) + for name in where_assigns + } + + where_init = self.deferred_code_proc(where_init) + where_final = self.deferred_code_proc(where_final) + out = where_init + where_final + + out = sub_all(out, name_regexes, name_replacements) + + return self.wrap_passthrough(out, early=True) @contextmanager def class_manage(self, original, loc, item): @@ -4761,8 +4777,25 @@ def in_method(self): cls_context = self.current_parsing_context("class") return cls_context is not None and cls_context["name"] is not None and cls_context["in_method"] - def name_handle(self, original, loc, tokens, assign=False, classname=False): + 
@contextmanager + def has_expr_setname_manage(self, original, loc, item): + """Handle parses that can assign expr_setname.""" + with self.add_to_parsing_context( + "expr_setnames", + { + "parent": self.current_parsing_context("expr_setnames"), + "new_names": set(), + "callbacks": [], + "loc": loc, + }, + callbacks_key="callbacks", + ): + yield + + def name_handle(self, original, loc, tokens, assign=False, classname=False, expr_setname=False): """Handle the given base name.""" + internal_assert(assign if expr_setname else True, "expr_setname should always imply assign", (expr_setname, assign)) + name, = tokens if name.startswith("\\"): name = name[1:] @@ -4785,6 +4818,11 @@ def name_handle(self, original, loc, tokens, assign=False, classname=False): self.internal_assert(cls_context is not None, original, loc, "found classname outside of class", tokens) cls_context["name"] = name + if expr_setname: + expr_setnames_context = self.current_parsing_context("expr_setnames") + self.internal_assert(expr_setnames_context is not None, original, loc, "found expr_setname outside of has_expr_setname_manage", tokens) + expr_setnames_context["new_names"].add(name) + # raise_or_wrap_error for all errors here to make sure we don't # raise spurious errors if not using the computation graph @@ -4819,8 +4857,8 @@ def name_handle(self, original, loc, tokens, assign=False, classname=False): # greedily, which means this might be an invalid parse, in which # case we can't be sure this is actually shadowing a builtin and USE_COMPUTATION_GRAPH - # classnames are handled greedily, so ditto the above - and not classname + # classnames and expr_setnames are handled greedily, so ditto the above + and not (classname or expr_setname) and name in all_builtins ): self.strict_err_or_warn( @@ -4863,6 +4901,75 @@ def name_handle(self, original, loc, tokens, assign=False, classname=False): else: return name +# end: MANAGERS +# ----------------------------------------------------------------------------------------------------------------------- +# CHECKING HANDLERS: +# ----------------------------------------------------------------------------------------------------------------------- + + def check_strict(self, name, original, loc, tokens=(None,), only_warn=False, always_warn=False): + """Check that syntax meets --strict requirements.""" + self.internal_assert(len(tokens) == 1, original, loc, "invalid " + name + " tokens", tokens) + message = "found " + name + if self.strict: + kwargs = {} + if only_warn: + if not always_warn: + kwargs["extra"] = "remove --strict to dismiss" + self.syntax_warning(message, original, loc, **kwargs) + else: + if always_warn: + kwargs["extra"] = "remove --strict to downgrade to a warning" + return self.raise_or_wrap_error(self.make_err(CoconutStyleError, message, original, loc, **kwargs)) + elif always_warn: + self.syntax_warning(message, original, loc) + return tokens[0] + + def lambdef_check(self, original, loc, tokens): + """Check for Python-style lambdas.""" + return self.check_strict("Python-style lambda", original, loc, tokens) + + def endline_semicolon_check(self, original, loc, tokens): + """Check for semicolons at the end of lines.""" + return self.check_strict("semicolon at end of line", original, loc, tokens, always_warn=True) + + def u_string_check(self, original, loc, tokens): + """Check for Python-2-style unicode strings.""" + return self.check_strict("Python-2-style unicode string (all Coconut strings are unicode strings)", original, loc, tokens, always_warn=True) + + def 
match_dotted_name_const_check(self, original, loc, tokens): + """Check for Python-3.10-style implicit dotted name match check.""" + return self.check_strict("Python-3.10-style dotted name in pattern-matching (Coconut style is to use '=={name}' not '{name}')".format(name=tokens[0]), original, loc, tokens) + + def match_check_equals_check(self, original, loc, tokens): + """Check for old-style =item in pattern-matching.""" + return self.check_strict("deprecated equality-checking '=...' pattern; use '==...' instead", original, loc, tokens, always_warn=True) + + def power_in_impl_call_check(self, original, loc, tokens): + """Check for exponentation in implicit function application / coefficient syntax.""" + return self.check_strict( + "syntax with new behavior in Coconut v3; 'f x ** y' is now equivalent to 'f(x**y)' not 'f(x)**y'", + original, + loc, + tokens, + only_warn=True, + always_warn=True, + ) + + def check_py(self, version, name, original, loc, tokens): + """Check for Python-version-specific syntax.""" + self.internal_assert(len(tokens) == 1, original, loc, "invalid " + name + " tokens", tokens) + version_info = get_target_info(version) + if self.target_info < version_info: + return self.raise_or_wrap_error(self.make_err( + CoconutTargetError, + "found Python " + ".".join(str(v) for v in version_info) + " " + name, + original, + loc, + target=version, + )) + else: + return tokens[0] + def nonlocal_check(self, original, loc, tokens): """Check for Python 3 nonlocal statement.""" return self.check_py("3", "nonlocal statement", original, loc, tokens) @@ -4893,7 +5000,7 @@ def namedexpr_check(self, original, loc, tokens): def new_namedexpr_check(self, original, loc, tokens): """Check for Python 3.10 assignment expressions.""" - return self.check_py("310", "assignment expression in set literal or indexing", original, loc, tokens) + return self.check_py("310", "assignment expression in syntactic location only supported for 3.10+", original, loc, tokens) def except_star_clause_check(self, original, loc, tokens): """Check for Python 3.11 except* statements.""" diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index cd6ff8339..ecd180cec 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -802,6 +802,7 @@ class Grammar(object): refname = Forward() setname = Forward() + expr_setname = Forward() classname = Forward() name_ref = combine(Optional(backslash) + base_name) unsafe_name = combine(Optional(backslash.suppress()) + base_name) @@ -955,13 +956,13 @@ class Grammar(object): expr = Forward() star_expr = Forward() dubstar_expr = Forward() - comp_for = Forward() test_no_cond = Forward() infix_op = Forward() namedexpr_test = Forward() # for namedexpr locations only supported in Python 3.10 new_namedexpr_test = Forward() - lambdef = Forward() + comp_for = Forward() + comprehension_expr = Forward() typedef = Forward() typedef_default = Forward() @@ -971,6 +972,10 @@ class Grammar(object): typedef_ellipsis = Forward() typedef_op_item = Forward() + expr_lambdef = Forward() + stmt_lambdef = Forward() + lambdef = expr_lambdef | stmt_lambdef + negable_atom_item = condense(Optional(neg_minus) + atom_item) testlist = itemlist(test, comma, suppress_trailing=False) @@ -1148,8 +1153,8 @@ class Grammar(object): ZeroOrMore( condense( # everything here must end with setarg_comma - setname + Optional(default) + setarg_comma - | (star | dubstar) + setname + setarg_comma + expr_setname + Optional(default) + setarg_comma + | (star | dubstar) + expr_setname + setarg_comma 
| star_sep_setarg | slash_sep_setarg ) @@ -1180,7 +1185,7 @@ class Grammar(object): # everything here must end with rparen rparen.suppress() | tokenlist(Group(call_item), comma) + rparen.suppress() - | Group(attach(addspace(test + comp_for), add_parens_handle)) + rparen.suppress() + | Group(attach(comprehension_expr, add_parens_handle)) + rparen.suppress() | Group(op_item) + rparen.suppress() ) function_call = Forward() @@ -1230,10 +1235,6 @@ class Grammar(object): comma, ) - comprehension_expr = ( - addspace(namedexpr_test + comp_for) - | invalid_syntax(star_expr + comp_for, "iterable unpacking cannot be used in comprehension") - ) paren_atom = condense(lparen + any_of( # everything here must end with rparen rparen, @@ -1282,7 +1283,7 @@ class Grammar(object): setmaker = Group( (new_namedexpr_test + FollowedBy(rbrace))("test") | (new_namedexpr_testlist_has_comma + FollowedBy(rbrace))("list") - | addspace(new_namedexpr_test + comp_for + FollowedBy(rbrace))("comp") + | (comprehension_expr + FollowedBy(rbrace))("comp") | (testlist_star_namedexpr + FollowedBy(rbrace))("testlist_star_expr") ) set_literal_ref = lbrace.suppress() + setmaker + rbrace.suppress() @@ -1382,6 +1383,9 @@ class Grammar(object): no_partial_trailer_atom_ref = atom + ZeroOrMore(no_partial_trailer) partial_atom_tokens = no_partial_trailer_atom + partial_trailer_tokens + # must be kept in sync with expr_assignlist block below + assignlist = Forward() + star_assign_item = Forward() simple_assign = Forward() simple_assign_ref = maybeparens( lparen, @@ -1391,12 +1395,8 @@ class Grammar(object): | setname | passthrough_atom ), - rparen + rparen, ) - simple_assignlist = maybeparens(lparen, itemlist(simple_assign, comma, suppress_trailing=False), rparen) - - assignlist = Forward() - star_assign_item = Forward() base_assign_item = condense( simple_assign | lparen + assignlist + rparen @@ -1406,6 +1406,30 @@ class Grammar(object): assign_item = base_assign_item | star_assign_item assignlist <<= itemlist(assign_item, comma, suppress_trailing=False) + # must be kept in sync with assignlist block above (but with expr_setname) + expr_assignlist = Forward() + expr_star_assign_item = Forward() + expr_simple_assign = Forward() + expr_simple_assign_ref = maybeparens( + lparen, + ( + # refname if there's a trailer, expr_setname if not + (refname | passthrough_atom) + OneOrMore(ZeroOrMore(complex_trailer) + OneOrMore(simple_trailer)) + | expr_setname + | passthrough_atom + ), + rparen, + ) + expr_base_assign_item = condense( + expr_simple_assign + | lparen + expr_assignlist + rparen + | lbrack + expr_assignlist + rbrack + ) + expr_star_assign_item_ref = condense(star + expr_base_assign_item) + expr_assign_item = expr_base_assign_item | expr_star_assign_item + expr_assignlist <<= itemlist(expr_assign_item, comma, suppress_trailing=False) + + simple_assignlist = maybeparens(lparen, itemlist(simple_assign, comma, suppress_trailing=False), rparen) typed_assign_stmt = Forward() typed_assign_stmt_ref = simple_assign + colon.suppress() + typedef_test + Optional(equals.suppress() + test_expr) basic_stmt = addspace(ZeroOrMore(assignlist + equals) + test_expr) @@ -1639,7 +1663,10 @@ class Grammar(object): unsafe_lambda_arrow = any_of(fat_arrow, arrow) keyword_lambdef_params = maybeparens(lparen, set_args_list, rparen) - arrow_lambdef_params = lparen.suppress() + set_args_list + rparen.suppress() | setname + arrow_lambdef_params = ( + lparen.suppress() + set_args_list + rparen.suppress() + | expr_setname + ) keyword_lambdef = Forward() 
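The expr_setname machinery being threaded through the lambda-parameter grammar here, together with the stmt_lambdef builder code in the compiler.py hunk above, exists so that a statement lambda written inside a comprehension captures the comprehension variable by value rather than by closure. A rough pure-Python sketch of the behavioral difference that the primary_2.coco and extras.coco test hunks below check for; make_pair_builder is an illustrative stand-in, not actual compiler output:

def make_pair_builder(y):
    # stand-in for the generated builder: the bound name is passed in
    # explicitly, so each call gets its own binding of y
    def pair(x):
        return (x, y)
    return pair

built = [make_pair_builder(y) for y in range(10)]
assert [f(10) for f in built] == [(10, y) for y in range(10)]

# a plain closure late-binds y instead, so every function sees the final value
closures = [lambda x: (x, y) for y in range(10)]
assert [f(10) for f in closures] == [(10, 9)] * 10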
keyword_lambdef_ref = addspace(keyword("lambda") + condense(keyword_lambdef_params + colon)) @@ -1651,7 +1678,6 @@ class Grammar(object): keyword_lambdef, ) - stmt_lambdef = Forward() match_guard = Optional(keyword("if").suppress() + namedexpr_test) closing_stmt = longest(new_testlist_star_expr("tests"), unsafe_simple_stmt_item) stmt_lambdef_match_params = Group(lparen.suppress() + match_args_list + match_guard + rparen.suppress()) @@ -1698,8 +1724,9 @@ class Grammar(object): | fixto(always_match, "") ) - lambdef <<= addspace(lambdef_base + test) | stmt_lambdef - lambdef_no_cond = addspace(lambdef_base + test_no_cond) + expr_lambdef_ref = addspace(lambdef_base + test) + lambdef_no_cond = Forward() + lambdef_no_cond_ref = addspace(lambdef_base + test_no_cond) typedef_callable_arg = Group( test("arg") @@ -1808,11 +1835,15 @@ class Grammar(object): invalid_syntax(maybeparens(lparen, namedexpr, rparen), "PEP 572 disallows assignment expressions in comprehension iterable expressions") | test_item ) - base_comp_for = addspace(keyword("for") + assignlist + keyword("in") + comp_it_item + Optional(comp_iter)) + base_comp_for = addspace(keyword("for") + expr_assignlist + keyword("in") + comp_it_item + Optional(comp_iter)) async_comp_for_ref = addspace(keyword("async") + base_comp_for) comp_for <<= base_comp_for | async_comp_for comp_if = addspace(keyword("if") + test_no_cond + Optional(comp_iter)) comp_iter <<= any_of(comp_for, comp_if) + comprehension_expr_ref = ( + addspace(namedexpr_test + comp_for) + | invalid_syntax(star_expr + comp_for, "iterable unpacking cannot be used in comprehension") + ) return_stmt = addspace(keyword("return") - Optional(new_testlist_star_expr)) @@ -2547,7 +2578,7 @@ class Grammar(object): original_function_call_tokens = ( lparen.suppress() + rparen.suppress() # we need to keep the parens here, since f(x for x in y) is fine but tail_call(f, x for x in y) is not - | condense(lparen + originalTextFor(test + comp_for) + rparen) + | condense(lparen + originalTextFor(comprehension_expr) + rparen) | attach(parens, strip_parens_handle) ) diff --git a/coconut/compiler/util.py b/coconut/compiler/util.py index 6de9b871f..3ed7744ec 100644 --- a/coconut/compiler/util.py +++ b/coconut/compiler/util.py @@ -1166,7 +1166,7 @@ class Wrap(ParseElementEnhance): global_instance_counter = 0 inside = False - def __init__(self, item, wrapper, greedy=False, include_in_packrat_context=False): + def __init__(self, item, wrapper, greedy=False, include_in_packrat_context=True): super(Wrap, self).__init__(item) self.wrapper = wrapper self.greedy = greedy @@ -1225,10 +1225,14 @@ def __repr__(self): return self.wrapped_name -def handle_and_manage(item, handler, manager): +def manage(item, manager, greedy=True, include_in_packrat_context=False): + """Attach a manager to the given parse item.""" + return Wrap(item, manager, greedy=greedy, include_in_packrat_context=include_in_packrat_context) + + +def handle_and_manage(item, handler, manager, **kwargs): """Attach a handler and a manager to the given parse item.""" - new_item = attach(item, handler) - return Wrap(new_item, manager, greedy=True) + return manage(attach(item, handler), manager, **kwargs) def disable_inside(item, *elems, **kwargs): diff --git a/coconut/root.py b/coconut/root.py index 88841b3f7..7718ae01c 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 20 +DEVELOP = 21 ALPHA = False # for pre releases rather than post releases 
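The manage/handle_and_manage helpers changed in the compiler/util.py hunk above pair a parse handler with a context manager, so that a per-construct parsing context is pushed before the wrapped item runs and reliably popped afterwards, the same push/pop discipline as current_parsing_context/add_to_parsing_context in compiler.py. A stripped-down, pyparsing-free analogue of that discipline; here the context only wraps the handler call, and manage_handler's plumbing is illustrative rather than Coconut's actual API:

from collections import defaultdict
from contextlib import contextmanager

parsing_context = defaultdict(list)

def current_parsing_context(name, default=None):
    # return the innermost context object for the given name, if any
    stack = parsing_context[name]
    return stack[-1] if stack else default

@contextmanager
def add_to_parsing_context(name, obj):
    parsing_context[name].append(obj)
    try:
        yield
    finally:
        # always pop, even if handling the wrapped item raises
        parsing_context[name].pop()

def manage_handler(handler, name, make_ctx):
    # run the handler with a fresh context pushed for its duration
    def managed(tokens):
        with add_to_parsing_context(name, make_ctx()):
            return handler(tokens)
    return managed

handle_typevars = manage_handler(
    lambda tokens: sorted(current_parsing_context("typevars")["all_typevars"]),
    "typevars",
    lambda: {"all_typevars": {}},
)
assert handle_typevars(["T"]) == []
assert current_parsing_context("typevars") is None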
assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/src/cocotest/agnostic/primary_2.coco b/coconut/tests/src/cocotest/agnostic/primary_2.coco index 322457897..b01ab2a24 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_2.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_2.coco @@ -444,6 +444,14 @@ def primary_test_2() -> bool: assert b"Abc" |> fmap$(.|32) == b"abc" assert bytearray(b"Abc") |> fmap$(.|32) == bytearray(b"abc") assert (bytearray(b"Abc") |> fmap$(.|32)) `isinstance` bytearray + assert 10 |> lift(+)((x -> x), (def y -> y)) == 20 + assert (x -> def y -> (x, y))(1)(2) == (1, 2) == (x -> copyclosure def y -> (x, y))(1)(2) # type: ignore + assert ((x, y) -> def z -> (x, y, z))(1, 2)(3) == (1, 2, 3) == (x -> y -> def z -> (x, y, z))(1)(2)(3) # type: ignore + assert [def x -> (x, y) for y in range(10)] |> map$(call$(?, 10)) |> list == [(10, y) for y in range(10)] + assert [x -> (x, y) for y in range(10)] |> map$(call$(?, 10)) |> list == [(10, 9) for y in range(10)] + assert [=> y for y in range(2)] |> map$(call) |> list == [1, 1] + assert [def => y for y in range(2)] |> map$(call) |> list == [0, 1] + assert (x -> x -> def y -> (x, y))(1)(2)(3) == (2, 3) with process_map.multiple_sequential_calls(): # type: ignore assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore diff --git a/coconut/tests/src/cocotest/agnostic/util.coco b/coconut/tests/src/cocotest/agnostic/util.coco index 0feebd3a1..f58003eec 100644 --- a/coconut/tests/src/cocotest/agnostic/util.coco +++ b/coconut/tests/src/cocotest/agnostic/util.coco @@ -34,6 +34,14 @@ def assert_raises(c, exc): else: raise AssertionError(f"{c} failed to raise exception {exc}") +def x `typed_eq` y = (type(x), x) == (type(y), y) + +def pickle_round_trip(obj) = ( + obj + |> pickle.dumps + |> pickle.loads +) + try: prepattern() # type: ignore except NameError, TypeError: @@ -44,14 +52,6 @@ except NameError, TypeError: return addpattern(func, base_func, **kwargs) return pattern_prepender -def x `typed_eq` y = (type(x), x) == (type(y), y) - -def pickle_round_trip(obj) = ( - obj - |> pickle.dumps - |> pickle.loads -) - # Old functions: old_fmap = fmap$(starmap_over_mappings=True) diff --git a/coconut/tests/src/cocotest/target_38/py38_test.coco b/coconut/tests/src/cocotest/target_38/py38_test.coco index 8c4f30efc..13ed72b9c 100644 --- a/coconut/tests/src/cocotest/target_38/py38_test.coco +++ b/coconut/tests/src/cocotest/target_38/py38_test.coco @@ -10,4 +10,6 @@ def py38_test() -> bool: assert 10 |> (x := .) == 10 == x assert 10 |> (x := .) |> (. 
+ 1) == 11 assert x == 10 + assert not consume(y := i for i in range(10)) + assert y == 9 return True diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index 0b7a55289..0bb22fbde 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -313,8 +313,14 @@ def g(x) = x assert parse("def f(x):\n ${var}", "xonsh") == "def f(x):\n ${var}\n" assert "data ABC" not in parse("data ABC:\n ${var}", "xonsh") - assert parse('"abc" "xyz"', "lenient") == "'abcxyz'" + assert "builder" not in parse("def x -> x", "lenient") + assert parse("def x -> x", "lenient").count("def") == 1 + assert "builder" in parse("x -> def y -> (x, y)", "lenient") + assert parse("x -> def y -> (x, y)", "lenient").count("def") == 2 + assert "builder" in parse("[def x -> (x, y) for y in range(10)]", "lenient") + assert parse("[def x -> (x, y) for y in range(10)]", "lenient").count("def") == 2 + assert parse("123 # derp", "lenient") == "123 # derp" return True @@ -465,6 +471,11 @@ async def async_map_test() = # Compiled Coconut: ----------------------------------------------------------- type Num = int | float""".strip()) + assert parse("type L[T] = list[T]").strip().endswith(""" +# Compiled Coconut: ----------------------------------------------------------- + +_coconut_typevar_T_0 = _coconut.typing.TypeVar("_coconut_typevar_T_0") +type L = list[_coconut_typevar_T_0]""".strip()) setup(line_numbers=False, minify=True) assert parse("123 # derp", "lenient") == "123# derp" From ccfc8823ad61c8ded37f5384a45062d13eb69dcf Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Sat, 24 Feb 2024 22:29:58 -0800 Subject: [PATCH 43/54] Make kwarg required --- coconut/compiler/compiler.py | 9 +++++++++ coconut/compiler/util.py | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 7f6efdafb..17c651d23 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -711,16 +711,19 @@ def bind(cls): cls.classdef_ref, cls.method("classdef_handle"), cls.method("class_manage"), + include_in_packrat_context=False, ) cls.datadef <<= handle_and_manage( cls.datadef_ref, cls.method("datadef_handle"), cls.method("class_manage"), + include_in_packrat_context=False, ) cls.match_datadef <<= handle_and_manage( cls.match_datadef_ref, cls.method("match_datadef_handle"), cls.method("class_manage"), + include_in_packrat_context=False, ) # handle parsing_context for function definitions @@ -728,16 +731,19 @@ def bind(cls): cls.stmt_lambdef_ref, cls.method("stmt_lambdef_handle"), cls.method("func_manage"), + include_in_packrat_context=False, ) cls.decoratable_normal_funcdef_stmt <<= handle_and_manage( cls.decoratable_normal_funcdef_stmt_ref, cls.method("decoratable_funcdef_stmt_handle"), cls.method("func_manage"), + include_in_packrat_context=False, ) cls.decoratable_async_funcdef_stmt <<= handle_and_manage( cls.decoratable_async_funcdef_stmt_ref, cls.method("decoratable_funcdef_stmt_handle", is_async=True), cls.method("func_manage"), + include_in_packrat_context=False, ) # handle parsing_context for type aliases @@ -745,6 +751,7 @@ def bind(cls): cls.type_alias_stmt_ref, cls.method("type_alias_stmt_handle"), cls.method("type_alias_stmt_manage"), + include_in_packrat_context=False, ) # handle parsing_context for where statements @@ -752,11 +759,13 @@ def bind(cls): cls.where_stmt_ref, cls.method("where_stmt_handle"), cls.method("where_stmt_manage"), + include_in_packrat_context=False, ) cls.implicit_return_where <<= 
handle_and_manage( cls.implicit_return_where_ref, cls.method("where_stmt_handle"), cls.method("where_stmt_manage"), + include_in_packrat_context=False, ) # handle parsing_context for expr_setnames diff --git a/coconut/compiler/util.py b/coconut/compiler/util.py index 3ed7744ec..ffbbf6151 100644 --- a/coconut/compiler/util.py +++ b/coconut/compiler/util.py @@ -1225,9 +1225,9 @@ def __repr__(self): return self.wrapped_name -def manage(item, manager, greedy=True, include_in_packrat_context=False): +def manage(item, manager, include_in_packrat_context, greedy=True): """Attach a manager to the given parse item.""" - return Wrap(item, manager, greedy=greedy, include_in_packrat_context=include_in_packrat_context) + return Wrap(item, manager, include_in_packrat_context=include_in_packrat_context, greedy=greedy) def handle_and_manage(item, handler, manager, **kwargs): From b65e25e355648cf95e17e4c7d2317e059576ed81 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Sat, 24 Feb 2024 23:11:14 -0800 Subject: [PATCH 44/54] Fix tests --- coconut/tests/main_test.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/coconut/tests/main_test.py b/coconut/tests/main_test.py index 0f84ee941..f50aac556 100644 --- a/coconut/tests/main_test.py +++ b/coconut/tests/main_test.py @@ -153,8 +153,7 @@ ) mypy_snip = "a: str = count()[0]" -mypy_snip_err_2 = '''error: Incompatible types in assignment (expression has type\n"int", variable has type "unicode")''' -mypy_snip_err_3 = '''error: Incompatible types in assignment (expression has type\n"int", variable has type "str")''' +mypy_snip_err = '''error: Incompatible types in assignment (expression has type''' mypy_args = ["--follow-imports", "silent", "--ignore-missing-imports", "--allow-redefinition"] @@ -427,6 +426,7 @@ def comp(path=None, folder=None, file=None, args=[], **kwargs): def rm_path(path, allow_keep=False): """Delete a path.""" + print("DELETING", path) path = os.path.abspath(fixpath(path)) assert not base_dir.startswith(path), "refusing to delete Coconut itself: " + repr(path) if allow_keep and get_bool_env_var("COCONUT_KEEP_TEST_FILES"): @@ -856,7 +856,7 @@ def test_target_3_snip(self): def test_universal_mypy_snip(self): call( ["coconut", "-c", mypy_snip, "--mypy"], - assert_output=mypy_snip_err_3, + assert_output=mypy_snip_err, check_errors=False, check_mypy=False, ) @@ -864,7 +864,7 @@ def test_universal_mypy_snip(self): def test_sys_mypy_snip(self): call( ["coconut", "--target", "sys", "-c", mypy_snip, "--mypy"], - assert_output=mypy_snip_err_3, + assert_output=mypy_snip_err, check_errors=False, check_mypy=False, ) @@ -872,7 +872,7 @@ def test_sys_mypy_snip(self): def test_no_wrap_mypy_snip(self): call( ["coconut", "--target", "sys", "--no-wrap", "-c", mypy_snip, "--mypy"], - assert_output=mypy_snip_err_3, + assert_output=mypy_snip_err, check_errors=False, check_mypy=False, ) @@ -889,7 +889,8 @@ def test_import_hook(self): with using_coconut(): auto_compilation(True) import runnable - reload(runnable) + if not PY2: # triggers a weird metaclass conflict + reload(runnable) assert runnable.success == "" def test_find_packages(self): From 8ad52232dfc63b9db0520505cfbf927b364d149e Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Sat, 24 Feb 2024 23:23:36 -0800 Subject: [PATCH 45/54] Add print statements --- coconut/command/util.py | 1 + coconut/tests/main_test.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/coconut/command/util.py b/coconut/command/util.py index 53cb00bfb..3f1375dc3 100644 --- 
a/coconut/command/util.py +++ b/coconut/command/util.py @@ -420,6 +420,7 @@ def unlink(link_path): def rm_dir_or_link(dir_to_rm): """Safely delete a directory without deleting the contents of symlinks.""" + print("rm_dir_or_link", dir_to_rm) if not unlink(dir_to_rm) and os.path.exists(dir_to_rm): if WINDOWS: try: diff --git a/coconut/tests/main_test.py b/coconut/tests/main_test.py index f50aac556..4599e18f5 100644 --- a/coconut/tests/main_test.py +++ b/coconut/tests/main_test.py @@ -426,7 +426,7 @@ def comp(path=None, folder=None, file=None, args=[], **kwargs): def rm_path(path, allow_keep=False): """Delete a path.""" - print("DELETING", path) + print("rm_path", path) path = os.path.abspath(fixpath(path)) assert not base_dir.startswith(path), "refusing to delete Coconut itself: " + repr(path) if allow_keep and get_bool_env_var("COCONUT_KEEP_TEST_FILES"): From 568a620a233396873ac3915642f5a0ebc6f64fcc Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Sat, 24 Feb 2024 23:31:41 -0800 Subject: [PATCH 46/54] Fix py2 deleting dir --- coconut/command/util.py | 8 ++++++-- coconut/tests/main_test.py | 1 - 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/coconut/command/util.py b/coconut/command/util.py index 3f1375dc3..c4e0b1e7d 100644 --- a/coconut/command/util.py +++ b/coconut/command/util.py @@ -420,9 +420,13 @@ def unlink(link_path): def rm_dir_or_link(dir_to_rm): """Safely delete a directory without deleting the contents of symlinks.""" - print("rm_dir_or_link", dir_to_rm) if not unlink(dir_to_rm) and os.path.exists(dir_to_rm): - if WINDOWS: + if PY2: # shutil.rmtree doesn't seem to be fully safe on Python 2 + try: + os.rmdir(dir_to_rm) + except OSError: + logger.warn_exc() + elif WINDOWS: try: os.rmdir(dir_to_rm) except OSError: diff --git a/coconut/tests/main_test.py b/coconut/tests/main_test.py index 4599e18f5..155f8e17b 100644 --- a/coconut/tests/main_test.py +++ b/coconut/tests/main_test.py @@ -426,7 +426,6 @@ def comp(path=None, folder=None, file=None, args=[], **kwargs): def rm_path(path, allow_keep=False): """Delete a path.""" - print("rm_path", path) path = os.path.abspath(fixpath(path)) assert not base_dir.startswith(path), "refusing to delete Coconut itself: " + repr(path) if allow_keep and get_bool_env_var("COCONUT_KEEP_TEST_FILES"): From 0adf2d60e0b72fa39e8a6eee60aced4d3272ede3 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Sat, 24 Feb 2024 23:48:18 -0800 Subject: [PATCH 47/54] Bump dependencies --- .pre-commit-config.yaml | 2 +- coconut/constants.py | 20 +++++++++++--------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2df5155a4..c7868a2d8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,7 +34,7 @@ repos: args: - --autofix - repo: https://github.com/pycqa/flake8 - rev: 6.1.0 + rev: 7.0.0 hooks: - id: flake8 args: diff --git a/coconut/constants.py b/coconut/constants.py index 269fca1d5..8652033ae 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -999,7 +999,8 @@ def get_path_env_var(env_var, default): ), "tests": ( ("pytest", "py<36"), - ("pytest", "py36"), + ("pytest", "py>=36;py<38"), + ("pytest", "py38"), "pexpect", ), } @@ -1012,30 +1013,30 @@ def get_path_env_var(env_var, default): "jupyter": (1, 0), "types-backports": (0, 1), ("futures", "py<3"): (3, 4), - ("backports.functools-lru-cache", "py<3"): (1, 6), + ("backports.functools-lru-cache", "py<3"): (2,), ("argparse", "py<27"): (1, 4), "pexpect": (4,), ("trollius", "py<3;cpy"): (2, 2), 
"requests": (2, 31), ("numpy", "py39"): (1, 26), - ("xarray", "py39"): (2023,), + ("xarray", "py39"): (2024,), ("dataclasses", "py==36"): (0, 8), ("aenum", "py<34"): (3, 1, 15), - "pydata-sphinx-theme": (0, 14), + "pydata-sphinx-theme": (0, 15), "myst-parser": (2,), "sphinx": (7,), - "mypy[python2]": (1, 7), + "mypy[python2]": (1, 8), ("jupyter-console", "py37"): (6, 6), ("typing", "py<35"): (3, 10), - ("typing_extensions", "py>=38"): (4, 8), + ("typing_extensions", "py>=38"): (4, 9), ("ipykernel", "py38"): (6,), ("jedi", "py39"): (0, 19), ("pygments", "py>=39"): (2, 17), - ("xonsh", "py38"): (0, 14), - ("pytest", "py36"): (7,), + ("xonsh", "py38"): (0, 15), + ("pytest", "py38"): (8,), ("async_generator", "py35"): (1, 10), ("exceptiongroup", "py37;py<311"): (1,), - ("ipython", "py>=39"): (8, 18), + ("ipython", "py>=39"): (8, 22), "py-spy": (0, 3), } @@ -1053,6 +1054,7 @@ def get_path_env_var(env_var, default): ("pandas", "py36"): (1,), ("jupyter-client", "py36"): (7, 1, 2), ("typing_extensions", "py==36"): (4, 1), + ("pytest", "py>=36;py<38"): (7,), # don't upgrade these; they break on Python 3.5 ("ipykernel", "py3;py<38"): (5, 5), ("ipython", "py3;py<37"): (7, 9), From fcd104373f8400cf13f9e101d580252df75f8156 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Sun, 25 Feb 2024 01:14:39 -0800 Subject: [PATCH 48/54] Improve match def default handling Resolves #618. --- DOCS.md | 12 ++- coconut/compiler/matching.py | 97 ++++++++++++++----- coconut/root.py | 2 +- .../src/cocotest/agnostic/primary_2.coco | 3 + 4 files changed, 84 insertions(+), 30 deletions(-) diff --git a/DOCS.md b/DOCS.md index 5d3be5c74..69cd43466 100644 --- a/DOCS.md +++ b/DOCS.md @@ -1149,11 +1149,17 @@ depth: 1 ### `match` -Coconut provides fully-featured, functional pattern-matching through its `match` statements. +Coconut provides fully-featured, functional pattern-matching through its `match` statements. Coconut `match` syntax is a strict superset of [Python's `match` syntax](https://peps.python.org/pep-0636/). + +_Note: In describing Coconut's pattern-matching syntax, this section focuses on `match` statements, but Coconut's pattern-matching can also be used in many other places, such as [pattern-matching function definition](#pattern-matching-functions), [`case` statements](#case), [destructuring assignment](#destructuring-assignment), [`match data`](#match-data), and [`match for`](#match-for)._ ##### Overview -Match statements follow the basic syntax `match in `. The match statement will attempt to match the value against the pattern, and if successful, bind any variables in the pattern to whatever is in the same position in the value, and execute the code below the match statement. Match statements also support, in their basic syntax, an `if ` that will check the condition after executing the match before executing the code below, and an `else` statement afterwards that will only be executed if the `match` statement is not. What is allowed in the match statement's pattern has no equivalent in Python, and thus the specifications below are provided to explain it. +Match statements follow the basic syntax `match in `. The match statement will attempt to match the value against the pattern, and if successful, bind any variables in the pattern to whatever is in the same position in the value, and execute the code below the match statement. 
+ +Match statements also support, in their basic syntax, an `if ` that will check the condition after executing the match before executing the code below, and an `else` statement afterwards that will only be executed if the `match` statement is not. + +All pattern-matching in Coconut is atomic, such that no assignments will be executed unless the whole match succeeds. ##### Syntax Specification @@ -2494,7 +2500,7 @@ If `` has a variable name (via any variable binding that binds the enti In addition to supporting pattern-matching in their arguments, pattern-matching function definitions also have a couple of notable differences compared to Python functions. Specifically: - If pattern-matching function definition fails, it will raise a [`MatchError`](#matcherror) (just like [destructuring assignment](#destructuring-assignment)) instead of a `TypeError`. -- All defaults in pattern-matching function definition are late-bound rather than early-bound. Thus, `match def f(xs=[]) = xs` will instantiate a new list for each call where `xs` is not given, unlike `def f(xs=[]) = xs`, which will use the same list for all calls where `xs` is unspecified. +- All defaults in pattern-matching function definition are late-bound rather than early-bound. Thus, `match def f(xs=[]) = xs` will instantiate a new list for each call where `xs` is not given, unlike `def f(xs=[]) = xs`, which will use the same list for all calls where `xs` is unspecified. This also allows defaults for later arguments to be specified in terms of matched values from earlier arguments, as in `match def f(x, y=x) = (x, y)`. Pattern-matching function definition can also be combined with `async` functions, [`copyclosure` functions](#copyclosure-functions), [`yield` functions](#explicit-generators), [infix function definition](#infix-functions), and [assignment function syntax](#assignment-functions). The various keywords in front of the `def` can be put in any order. 
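The late-bound defaults described in the DOCS change above are what the add_default_expr changes below implement: the compiled function receives _coconut_sentinel for a missing argument and only then evaluates the default expression, with the already-matched names in scope. A small plain-Python emulation of that behavior; SENTINEL and the exec-into-a-namespace shape mirror the generated code only loosely, and maybe_dup simply mimics the primary_2.coco test added in this patch:

SENTINEL = object()  # stand-in for _coconut_sentinel

def maybe_dup(x, y=SENTINEL):
    if y is SENTINEL:
        # evaluate the default expression late, against a namespace that
        # already holds the matched value of x (cf. add_default_expr below)
        namespace = {"x": x}
        exec("y = x", namespace)
        y = namespace["y"]
    return (x, y)

assert maybe_dup(1) == (1, 1) == maybe_dup(x=1)
assert maybe_dup(1, 2) == (1, 2)

def late_bound_list(xs=SENTINEL):
    # `match def f(xs=[]) = xs`: a fresh list per call, not one shared default
    if xs is SENTINEL:
        xs = []
    return xs

assert late_bound_list() is not late_bound_list()

In the real implementation the default is additionally deferred (via down_to_end) until as much of the rest of the pattern as possible has been matched, so that those matched names are available to the default expression.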
diff --git a/coconut/compiler/matching.py b/coconut/compiler/matching.py index e70bdf46e..ff778b528 100644 --- a/coconut/compiler/matching.py +++ b/coconut/compiler/matching.py @@ -307,6 +307,39 @@ def get_set_name_var(self, name): """Gets the var for checking whether a name should be set.""" return match_set_name_var + "_" + name + def add_default_expr(self, assign_to, default_expr): + """Add code that evaluates expr in the context of any names that have been matched so far + and assigns the result to assign_to if assign_to is currently _coconut_sentinel.""" + vars_var = self.get_temp_var() + add_names_code = [] + for name in self.names: + add_names_code.append( + handle_indentation( + """ +if {set_name_var} is not _coconut_sentinel: + {vars_var}["{name}"] = {set_name_var} + """, + add_newline=True, + ).format( + set_name_var=self.get_set_name_var(name), + vars_var=vars_var, + name=name, + ) + ) + code = self.comp.reformat_post_deferred_code_proc(assign_to + " = " + default_expr) + self.add_def(handle_indentation(""" +if {assign_to} is _coconut_sentinel: + {vars_var} = _coconut.globals().copy() + {vars_var}.update(_coconut.locals().copy()) + {add_names_code}_coconut_exec({code_str}, {vars_var}) + {assign_to} = {vars_var}["{assign_to}"] + """).format( + vars_var=vars_var, + add_names_code="".join(add_names_code), + assign_to=assign_to, + code_str=self.comp.wrap_str_of(code), + )) + def register_name(self, name): """Register a new name at the current position.""" internal_assert(lambda: name not in self.parent_names and name not in self.names, "attempt to register duplicate name", name) @@ -373,7 +406,7 @@ def match_function( ).format( first_arg=first_arg, args=args, - ), + ) ) with self.down_a_level(): @@ -418,7 +451,7 @@ def match_in_args_kwargs(self, pos_only_match_args, match_args, args, kwargs, al # if i >= req_len "_coconut.sum((_coconut.len(" + args + ") > " + str(i) + ", " + ", ".join('"' + name + '" in ' + kwargs for name in names) - + ")) == 1", + + ")) == 1" ) tempvar = self.get_temp_var() self.add_def( @@ -428,16 +461,19 @@ def match_in_args_kwargs(self, pos_only_match_args, match_args, args, kwargs, al kwargs + '.pop("' + name + '") if "' + name + '" in ' + kwargs + " else " for name in names[:-1] ) - + kwargs + '.pop("' + names[-1] + '")', + + kwargs + '.pop("' + names[-1] + '")' ) with self.down_a_level(): self.match(match, tempvar) else: if not names: tempvar = self.get_temp_var() - self.add_def(tempvar + " = " + args + "[" + str(i) + "] if _coconut.len(" + args + ") > " + str(i) + " else " + default) - with self.down_a_level(): - self.match(match, tempvar) + self.add_def(tempvar + " = " + args + "[" + str(i) + "] if _coconut.len(" + args + ") > " + str(i) + " else _coconut_sentinel") + # go down to end to ensure we've matched as much as possible before evaluating the default + with self.down_to_end(): + self.add_default_expr(tempvar, default) + with self.down_a_level(): + self.match(match, tempvar) else: arg_checks[i] = ( # if i < req_len @@ -445,7 +481,7 @@ def match_in_args_kwargs(self, pos_only_match_args, match_args, args, kwargs, al # if i >= req_len "_coconut.sum((_coconut.len(" + args + ") > " + str(i) + ", " + ", ".join('"' + name + '" in ' + kwargs for name in names) - + ")) <= 1", + + ")) <= 1" ) tempvar = self.get_temp_var() self.add_def( @@ -455,10 +491,13 @@ def match_in_args_kwargs(self, pos_only_match_args, match_args, args, kwargs, al kwargs + '.pop("' + name + '") if "' + name + '" in ' + kwargs + " else " for name in names ) - + default, + + 
"_coconut_sentinel" ) - with self.down_a_level(): - self.match(match, tempvar) + # go down to end to ensure we've matched as much as possible before evaluating the default + with self.down_to_end(): + self.add_default_expr(tempvar, default) + with self.down_a_level(): + self.match(match, tempvar) # length checking max_len = None if allow_star_args else len(pos_only_match_args) + len(match_args) @@ -484,12 +523,18 @@ def match_in_kwargs(self, match_args, kwargs): kwargs + '.pop("' + name + '") if "' + name + '" in ' + kwargs + " else " for name in names ) - + (default if default is not None else "_coconut_sentinel"), + + "_coconut_sentinel" ) - with self.down_a_level(): - if default is None: + if default is None: + with self.down_a_level(): self.add_check(tempvar + " is not _coconut_sentinel") - self.match(match, tempvar) + self.match(match, tempvar) + else: + # go down to end to ensure we've matched as much as possible before evaluating the default + with self.down_to_end(): + self.add_default_expr(tempvar, default) + with self.down_a_level(): + self.match(match, tempvar) def match_dict(self, tokens, item): """Matches a dictionary.""" @@ -1054,7 +1099,7 @@ def match_class(self, tokens, item): ).format( num_pos_matches=len(pos_matches), cls_name=cls_name, - ), + ) ) else: self_match_matcher.match(pos_matches[0], item) @@ -1077,7 +1122,7 @@ def match_class(self, tokens, item): num_pos_matches=len(pos_matches), type_any=self.comp.wrap_comment(" type: _coconut.typing.Any"), type_ignore=self.comp.type_ignore_comment(), - ), + ) ) with other_cls_matcher.down_a_level(): for i, match in enumerate(pos_matches): @@ -1098,7 +1143,7 @@ def match_class(self, tokens, item): star_match_var=star_match_var, item=item, num_pos_matches=len(pos_matches), - ), + ) ) with self.down_a_level(): self.match(star_match, star_match_var) @@ -1118,7 +1163,7 @@ def match_data(self, tokens, item): "_coconut.len({item}) >= {min_len}".format( item=item, min_len=len(pos_matches), - ), + ) ) self.match_all_in(pos_matches, item) @@ -1152,7 +1197,7 @@ def match_data(self, tokens, item): min_len=len(pos_matches), name_matches=tuple_str_of(name_matches, add_quotes=True), type_ignore=self.comp.type_ignore_comment(), - ), + ) ) with self.down_a_level(): self.add_check(temp_var) @@ -1172,7 +1217,7 @@ def match_data_or_class(self, tokens, item): is_data_var=is_data_var, cls_name=cls_name, type_ignore=self.comp.type_ignore_comment(), - ), + ) ) if_data, if_class = self.branches(2) @@ -1248,7 +1293,7 @@ def match_view(self, tokens, item): func_result_var=func_result_var, view_func=view_func, item=item, - ), + ) ) with self.down_a_level(): @@ -1325,7 +1370,7 @@ def out(self): check_var=self.check_var, parameterization=parameterization, child_checks=child.out().rstrip(), - ), + ) ) # handle normal child groups @@ -1353,7 +1398,7 @@ def out(self): ).format( check_var=self.check_var, children_checks=children_checks, - ), + ) ) # commit variable definitions @@ -1369,7 +1414,7 @@ def out(self): ).format( set_name_var=self.get_set_name_var(name), name=name, - ), + ) ) if name_set_code: out.append( @@ -1381,7 +1426,7 @@ def out(self): ).format( check_var=self.check_var, name_set_code="".join(name_set_code), - ), + ) ) # handle guards @@ -1396,7 +1441,7 @@ def out(self): ).format( check_var=self.check_var, guards=paren_join(self.guards, "and"), - ), + ) ) return "".join(out) diff --git a/coconut/root.py b/coconut/root.py index 7718ae01c..0ea88022b 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "3.0.4" 
VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 21 +DEVELOP = 22 ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" diff --git a/coconut/tests/src/cocotest/agnostic/primary_2.coco b/coconut/tests/src/cocotest/agnostic/primary_2.coco index b01ab2a24..ee8ca556b 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_2.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_2.coco @@ -452,6 +452,9 @@ def primary_test_2() -> bool: assert [=> y for y in range(2)] |> map$(call) |> list == [1, 1] assert [def => y for y in range(2)] |> map$(call) |> list == [0, 1] assert (x -> x -> def y -> (x, y))(1)(2)(3) == (2, 3) + match def maybe_dup(x, y=x) = (x, y) + assert maybe_dup(1) == (1, 1) == maybe_dup(x=1) + assert maybe_dup(1, 2) == (1, 2) == maybe_dup(x=1, y=2) with process_map.multiple_sequential_calls(): # type: ignore assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore From 16905312d87071f8af41acfd88866b7c2c8a1ad9 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Sun, 25 Feb 2024 01:23:25 -0800 Subject: [PATCH 49/54] Remove unnecessary copying --- coconut/compiler/compiler.py | 2 +- coconut/compiler/matching.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 17c651d23..94f13a99c 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -2662,7 +2662,7 @@ def {mock_var}({mock_paramdef}): {vars_var} = {{"{def_name}": {def_name}}} else: {vars_var} = _coconut.globals().copy() - {vars_var}.update(_coconut.locals().copy()) + {vars_var}.update(_coconut.locals()) _coconut_exec({code_str}, {vars_var}) {func_name} = {func_from_vars} ''', diff --git a/coconut/compiler/matching.py b/coconut/compiler/matching.py index ff778b528..df35745bb 100644 --- a/coconut/compiler/matching.py +++ b/coconut/compiler/matching.py @@ -330,7 +330,7 @@ def add_default_expr(self, assign_to, default_expr): self.add_def(handle_indentation(""" if {assign_to} is _coconut_sentinel: {vars_var} = _coconut.globals().copy() - {vars_var}.update(_coconut.locals().copy()) + {vars_var}.update(_coconut.locals()) {add_names_code}_coconut_exec({code_str}, {vars_var}) {assign_to} = {vars_var}["{assign_to}"] """).format( From f38d95ba075d42e45e6d44decb3a28989e28fb7e Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Sun, 25 Feb 2024 16:53:36 -0800 Subject: [PATCH 50/54] Fix bugs --- coconut/compiler/compiler.py | 3 ++- coconut/constants.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 94f13a99c..b567759dc 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -4016,7 +4016,8 @@ def {builder_name}({expr_setnames_str}): expr_setname_context["callbacks"].append(stmt_lambdef_callback) if parent_setnames: - builder_args = "**({" + ", ".join('"' + name + '": ' + name for name in sorted(parent_setnames)) + "} | _coconut.locals())" + # use _coconut.dict to ensure it supports | + builder_args = "**(_coconut.dict(" + ", ".join(name + '=' + name for name in sorted(parent_setnames)) + ") | _coconut.locals())" else: builder_args = "**_coconut.locals()" return builder_name + "(" + builder_args + ")" diff --git a/coconut/constants.py b/coconut/constants.py index 8652033ae..019964231 100644 --- 
a/coconut/constants.py +++ b/coconut/constants.py @@ -940,7 +940,8 @@ def get_path_env_var(env_var, default): ("ipython", "py3;py<37"), ("ipython", "py==37"), ("ipython", "py==38"), - ("ipython", "py>=39"), + ("ipython", "py==39"), + ("ipython", "py>=310"), ("ipykernel", "py<3"), ("ipykernel", "py3;py<38"), ("ipykernel", "py38"), @@ -974,8 +975,8 @@ def get_path_env_var(env_var, default): ), "xonsh": ( ("xonsh", "py<36"), - ("xonsh", "py>=36;py<38"), - ("xonsh", "py38"), + ("xonsh", "py>=36;py<39"), + ("xonsh", "py39"), ), "dev": ( ("pre-commit", "py3"), @@ -1032,17 +1033,18 @@ def get_path_env_var(env_var, default): ("ipykernel", "py38"): (6,), ("jedi", "py39"): (0, 19), ("pygments", "py>=39"): (2, 17), - ("xonsh", "py38"): (0, 15), + ("xonsh", "py39"): (0, 15), ("pytest", "py38"): (8,), ("async_generator", "py35"): (1, 10), ("exceptiongroup", "py37;py<311"): (1,), - ("ipython", "py>=39"): (8, 22), + ("ipython", "py>=310"): (8, 22), "py-spy": (0, 3), } pinned_min_versions = { # don't upgrade these; they break on Python 3.9 ("numpy", "py34;py<39"): (1, 18), + ("ipython", "py==39"): (8, 18), # don't upgrade these; they break on Python 3.8 ("ipython", "py==38"): (8, 12), # don't upgrade these; they break on Python 3.7 @@ -1050,7 +1052,7 @@ def get_path_env_var(env_var, default): ("typing_extensions", "py==37"): (4, 7), # don't upgrade these; they break on Python 3.6 ("anyio", "py36"): (3,), - ("xonsh", "py>=36;py<38"): (0, 11), + ("xonsh", "py>=36;py<39"): (0, 11), ("pandas", "py36"): (1,), ("jupyter-client", "py36"): (7, 1, 2), ("typing_extensions", "py==36"): (4, 1), From 5a331c47f0d8e3192f1ac7a0fafcf8406d0da7d7 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Sun, 25 Feb 2024 19:39:10 -0800 Subject: [PATCH 51/54] Fix py2 --- coconut/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coconut/constants.py b/coconut/constants.py index 019964231..a827cdbdb 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -1014,7 +1014,6 @@ def get_path_env_var(env_var, default): "jupyter": (1, 0), "types-backports": (0, 1), ("futures", "py<3"): (3, 4), - ("backports.functools-lru-cache", "py<3"): (2,), ("argparse", "py<27"): (1, 4), "pexpect": (4,), ("trollius", "py<3;cpy"): (2, 2), @@ -1085,6 +1084,7 @@ def get_path_env_var(env_var, default): "watchdog": (0, 10), "papermill": (1, 2), ("numpy", "py<3;cpy"): (1, 16), + ("backports.functools-lru-cache", "py<3"): (1, 6), # don't upgrade this; it breaks with old IPython versions ("jedi", "py<39"): (0, 17), # Coconut requires pyparsing 2 From d0acf3dba97fb36e504ae15f41c9a4f1a04d8d20 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Thu, 29 Feb 2024 23:35:33 -0800 Subject: [PATCH 52/54] Optimize match def defaults --- coconut/compiler/grammar.py | 24 +++++++++++----- coconut/compiler/matching.py | 55 +++++++++++++++++++++--------------- 2 files changed, 50 insertions(+), 29 deletions(-) diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index ecd180cec..967930699 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -953,6 +953,7 @@ class Grammar(object): ) atom_item = Forward() + const_atom = Forward() expr = Forward() star_expr = Forward() dubstar_expr = Forward() @@ -1161,13 +1162,17 @@ class Grammar(object): ) ) ) + match_arg_default = Group( + const_atom("const") + | test("expr") + ) match_args_list = Group(Optional( tokenlist( Group( (star | dubstar) + match | star # not star_sep because pattern-matching can handle star separators on any Python version | slash # not 
slash_sep as above - | match + Optional(equals.suppress() + test) + | match + Optional(equals.suppress() + match_arg_default) ), comma, ) @@ -1292,19 +1297,24 @@ class Grammar(object): lazy_items = Optional(tokenlist(test, comma)) lazy_list = attach(lbanana.suppress() + lazy_items + rbanana.suppress(), lazy_list_handle) - known_atom = ( + # for const_atom, value should be known at compile time + const_atom <<= ( keyword_atom - | string_atom | num_atom + # typedef ellipsis must come before ellipsis + | typedef_ellipsis + | ellipsis + ) + # for known_atom, type should be known at compile time + known_atom = ( + const_atom + | string_atom | list_item | dict_literal | dict_comp | set_literal | set_letter_literal | lazy_list - # typedef ellipsis must come before ellipsis - | typedef_ellipsis - | ellipsis ) atom = ( # known_atom must come before name to properly parse string prefixes @@ -2197,7 +2207,7 @@ class Grammar(object): ( lparen.suppress() + match - + Optional(equals.suppress() + test) + + Optional(equals.suppress() + match_arg_default) + rparen.suppress() ) | interior_name_match ) diff --git a/coconut/compiler/matching.py b/coconut/compiler/matching.py index df35745bb..99e5457f5 100644 --- a/coconut/compiler/matching.py +++ b/coconut/compiler/matching.py @@ -307,38 +307,49 @@ def get_set_name_var(self, name): """Gets the var for checking whether a name should be set.""" return match_set_name_var + "_" + name - def add_default_expr(self, assign_to, default_expr): + def add_default_expr(self, assign_to, default): """Add code that evaluates expr in the context of any names that have been matched so far and assigns the result to assign_to if assign_to is currently _coconut_sentinel.""" - vars_var = self.get_temp_var() - add_names_code = [] - for name in self.names: - add_names_code.append( - handle_indentation( - """ + default_expr, = default + if "const" in default: + self.add_def(handle_indentation(""" +if {assign_to} is _coconut_sentinel: + {assign_to} = {default_expr} + """.format( + assign_to=assign_to, + default_expr=default_expr, + ))) + else: + internal_assert("expr" in default, "invalid match default tokens", default) + vars_var = self.get_temp_var() + add_names_code = [] + for name in self.names: + add_names_code.append( + handle_indentation( + """ if {set_name_var} is not _coconut_sentinel: {vars_var}["{name}"] = {set_name_var} - """, - add_newline=True, - ).format( - set_name_var=self.get_set_name_var(name), - vars_var=vars_var, - name=name, + """, + add_newline=True, + ).format( + set_name_var=self.get_set_name_var(name), + vars_var=vars_var, + name=name, + ) ) - ) - code = self.comp.reformat_post_deferred_code_proc(assign_to + " = " + default_expr) - self.add_def(handle_indentation(""" + code = self.comp.reformat_post_deferred_code_proc(assign_to + " = " + default_expr) + self.add_def(handle_indentation(""" if {assign_to} is _coconut_sentinel: {vars_var} = _coconut.globals().copy() {vars_var}.update(_coconut.locals()) {add_names_code}_coconut_exec({code_str}, {vars_var}) {assign_to} = {vars_var}["{assign_to}"] - """).format( - vars_var=vars_var, - add_names_code="".join(add_names_code), - assign_to=assign_to, - code_str=self.comp.wrap_str_of(code), - )) + """).format( + vars_var=vars_var, + add_names_code="".join(add_names_code), + assign_to=assign_to, + code_str=self.comp.wrap_str_of(code), + )) def register_name(self, name): """Register a new name at the current position.""" From 77ce54ca50a07523441843776a3ca7db6a7d9cff Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Thu, 
29 Feb 2024 23:38:15 -0800 Subject: [PATCH 53/54] Set to v3.1.0 --- coconut/root.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/coconut/root.py b/coconut/root.py index 0ea88022b..44fe2b5c8 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -23,10 +23,10 @@ # VERSION: # ----------------------------------------------------------------------------------------------------------------------- -VERSION = "3.0.4" +VERSION = "3.1.0" VERSION_NAME = None # False for release, int >= 1 for develop -DEVELOP = 22 +DEVELOP = False ALPHA = False # for pre releases rather than post releases assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" From d95706861671ed34e4731e12d48b9421cd5bde23 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Fri, 1 Mar 2024 01:33:13 -0800 Subject: [PATCH 54/54] Fix py35 test --- coconut/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coconut/constants.py b/coconut/constants.py index a827cdbdb..c9d7d095a 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -84,7 +84,7 @@ def get_path_env_var(env_var, default): PY311 = sys.version_info >= (3, 11) PY312 = sys.version_info >= (3, 12) IPY = ( - PY35 + PY36 and (PY37 or not PYPY) and not (PYPY and WINDOWS) and sys.version_info[:2] != (3, 7)
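
A minimal sketch of the match-def default behavior these patches implement and test (see the maybe_dup test added in primary_2.coco above), assuming Coconut 3.1.0 semantics; the scale function and its argument names are illustrative only, not part of the patches. Per PATCH 52, a constant default such as factor=10 can be compiled down to a plain sentinel check, while an expression default such as y=x may refer to previously matched arguments and is evaluated only after the rest of the pattern has matched:

    match def scale(x, factor=10) = x * factor  # constant default: compiled to a simple sentinel check
    match def maybe_dup(x, y=x) = (x, y)        # expression default: evaluated with the already-matched x in scope

    assert scale(3) == 30 and scale(3, factor=2) == 6
    assert maybe_dup(1) == (1, 1) and maybe_dup(1, 2) == (1, 2)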