diff --git a/guppylang/module.py b/guppylang/module.py index 28f16cd0..934c2b99 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -5,7 +5,6 @@ from types import ModuleType from typing import Any, Union -from guppylang.ast_util import AstNode from guppylang.checker.core import Globals, PyScope from guppylang.compiler.core import CompiledGlobals from guppylang.definition.common import ( @@ -104,6 +103,9 @@ def load(self, m: Union[ModuleType, "GuppyModule"]) -> None: def register_def(self, defn: RawDef, instance: TypeDef | None = None) -> None: """Registers a definition with this module. + If the name of the definition is already defined, the new definition + replaces the old. + Optionally, the definition can be marked as an instance method by passing the corresponding instance type definition. """ @@ -111,7 +113,6 @@ def register_def(self, defn: RawDef, instance: TypeDef | None = None) -> None: if self._instance_func_buffer is not None and not isinstance(defn, TypeDef): self._instance_func_buffer[defn.name] = defn else: - self._check_name_available(defn.name, defn.defined_at) if isinstance(defn, TypeDef | ParamDef): self._raw_type_defs[defn.id] = defn else: @@ -228,13 +229,6 @@ def _check_not_yet_compiled(self) -> None: if self._compiled: raise GuppyError(f"The module `{self.name}` has already been compiled") - def _check_name_available(self, name: str, node: AstNode | None) -> None: - if self.contains(name): - raise GuppyError( - f"Module `{self.name}` already contains a definition named `{name}`", - node, - ) - def get_py_scope(f: PyFunc) -> PyScope: """Returns a mapping of all variables captured by a Python function. diff --git a/tests/integration/test_redefinition.py b/tests/integration/test_redefinition.py new file mode 100644 index 00000000..f7497a95 --- /dev/null +++ b/tests/integration/test_redefinition.py @@ -0,0 +1,19 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +import guppylang.prelude.quantum as quantum + + +def test_redefinition(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test() -> bool: + return True + + @guppy(module) + def test() -> bool: # noqa: F811 + return False + + validate(module.compile())