From f13c522d697a0c08c0d7044fa35b032619c4e4dc Mon Sep 17 00:00:00 2001 From: Trey Wenger Date: Tue, 12 Nov 2024 10:40:02 -0600 Subject: [PATCH 1/5] support on_unused_input for string parameter names in eval --- pytensor/graph/basic.py | 8 ++++++-- tests/graph/test_basic.py | 9 +++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 29b2043f6d..844930226a 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -616,16 +616,20 @@ def eval( """ from pytensor.compile.function import function + on_unused_input = kwargs.get("on_unused_input", None) + def convert_string_keys_to_variables(inputs_to_values) -> dict["Variable", Any]: new_input_to_values = {} for key, value in inputs_to_values.items(): if isinstance(key, str): matching_vars = get_var_by_name([self], key) if not matching_vars: - raise ValueError(f"{key} not found in graph") + if on_unused_input in ["raise", None]: + raise ValueError(f"{key} not found in graph") elif len(matching_vars) > 1: raise ValueError(f"Found multiple variables with name {key}") - new_input_to_values[matching_vars[0]] = value + else: + new_input_to_values[matching_vars[0]] = value else: new_input_to_values[key] = value return new_input_to_values diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 08c352ab71..a3860a7ad2 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -365,8 +365,17 @@ def test_eval_with_strings_no_match(self): def test_eval_kwargs(self): with pytest.raises(UnusedInputError): self.w.eval({self.z: 3, self.x: 2.5}) + with pytest.warns( + UserWarning, + match="pytensor.function was asked to create a function", + ): + self.w.eval({self.z: 3, self.x: 2.5}, on_unused_input="warn") assert self.w.eval({self.z: 3, self.x: 2.5}, on_unused_input="ignore") == 6.0 + # regression test for + q = self.x + 1 + assert q.eval({"x": 1, "y": 2}, on_unused_input="ignore") == 2.0 + @pytest.mark.filterwarnings("error") def test_eval_unashable_kwargs(self): y_repl = constant(2.0, dtype="floatX") From 1d40471ebdc03e465fc36338a4788f2d62103b7e Mon Sep 17 00:00:00 2001 From: Trey Wenger Date: Tue, 12 Nov 2024 10:42:54 -0600 Subject: [PATCH 2/5] update on_unused_input in test_eval_kwargs --- tests/graph/test_basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index a3860a7ad2..62bad2f69b 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -372,7 +372,7 @@ def test_eval_kwargs(self): self.w.eval({self.z: 3, self.x: 2.5}, on_unused_input="warn") assert self.w.eval({self.z: 3, self.x: 2.5}, on_unused_input="ignore") == 6.0 - # regression test for + # regression test for https://github.com/pymc-devs/pytensor/issues/1084 q = self.x + 1 assert q.eval({"x": 1, "y": 2}, on_unused_input="ignore") == 2.0 From 827d72d2efcb538d15bdd114fcad7cd72b290888 Mon Sep 17 00:00:00 2001 From: Trey Wenger Date: Tue, 12 Nov 2024 11:54:59 -0600 Subject: [PATCH 3/5] support on_unused_input for string parameter names in eval --- pytensor/graph/basic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 844930226a..3d94ce83b7 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -616,7 +616,7 @@ def eval( """ from pytensor.compile.function import function - on_unused_input = kwargs.get("on_unused_input", None) + ignore_unused_input = kwargs.get("on_unused_input", None) in ("ignore", "warn") def convert_string_keys_to_variables(inputs_to_values) -> dict["Variable", Any]: new_input_to_values = {} @@ -624,7 +624,7 @@ def convert_string_keys_to_variables(inputs_to_values) -> dict["Variable", Any]: if isinstance(key, str): matching_vars = get_var_by_name([self], key) if not matching_vars: - if on_unused_input in ["raise", None]: + if not ignore_unused_input: raise ValueError(f"{key} not found in graph") elif len(matching_vars) > 1: raise ValueError(f"Found multiple variables with name {key}") From b263ef0a340108350ea669818135cfd913d418af Mon Sep 17 00:00:00 2001 From: Trey Wenger Date: Tue, 12 Nov 2024 12:02:56 -0600 Subject: [PATCH 4/5] update test_eval for unused_input --- tests/graph/test_basic.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 62bad2f69b..13a6f593f7 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -365,15 +365,15 @@ def test_eval_with_strings_no_match(self): def test_eval_kwargs(self): with pytest.raises(UnusedInputError): self.w.eval({self.z: 3, self.x: 2.5}) + assert self.w.eval({self.z: 3, self.x: 2.5}, on_unused_input="ignore") == 6.0 + + # regression tests for https://github.com/pymc-devs/pytensor/issues/1084 + q = self.x + 1 with pytest.warns( UserWarning, match="pytensor.function was asked to create a function", ): - self.w.eval({self.z: 3, self.x: 2.5}, on_unused_input="warn") - assert self.w.eval({self.z: 3, self.x: 2.5}, on_unused_input="ignore") == 6.0 - - # regression test for https://github.com/pymc-devs/pytensor/issues/1084 - q = self.x + 1 + assert q.eval({"x": 1, "y": 2}, on_unused_input="warn") == 2.0 assert q.eval({"x": 1, "y": 2}, on_unused_input="ignore") == 2.0 @pytest.mark.filterwarnings("error") From e4255313c2780245e1af74949a736686549a8ba4 Mon Sep 17 00:00:00 2001 From: Trey Wenger Date: Tue, 12 Nov 2024 12:04:14 -0600 Subject: [PATCH 5/5] update test_eval for unused_input --- tests/graph/test_basic.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 13a6f593f7..84ffb365b5 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -367,13 +367,8 @@ def test_eval_kwargs(self): self.w.eval({self.z: 3, self.x: 2.5}) assert self.w.eval({self.z: 3, self.x: 2.5}, on_unused_input="ignore") == 6.0 - # regression tests for https://github.com/pymc-devs/pytensor/issues/1084 + # regression test for https://github.com/pymc-devs/pytensor/issues/1084 q = self.x + 1 - with pytest.warns( - UserWarning, - match="pytensor.function was asked to create a function", - ): - assert q.eval({"x": 1, "y": 2}, on_unused_input="warn") == 2.0 assert q.eval({"x": 1, "y": 2}, on_unused_input="ignore") == 2.0 @pytest.mark.filterwarnings("error")