Skip to content

Commit

Permalink
support on_unused_input for string parameter names in eval (#1085)
Browse files Browse the repository at this point in the history
  • Loading branch information
tvwenger authored Nov 12, 2024
1 parent d9d8dba commit 2315e69
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
8 changes: 6 additions & 2 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,16 +616,20 @@ def eval(
"""
from pytensor.compile.function import function

ignore_unused_input = kwargs.get("on_unused_input", None) in ("ignore", "warn")

def convert_string_keys_to_variables(inputs_to_values) -> dict["Variable", Any]:
new_input_to_values = {}
for key, value in inputs_to_values.items():
if isinstance(key, str):
matching_vars = get_var_by_name([self], key)
if not matching_vars:
raise ValueError(f"{key} not found in graph")
if not ignore_unused_input:
raise ValueError(f"{key} not found in graph")
elif len(matching_vars) > 1:
raise ValueError(f"Found multiple variables with name {key}")
new_input_to_values[matching_vars[0]] = value
else:
new_input_to_values[matching_vars[0]] = value
else:
new_input_to_values[key] = value
return new_input_to_values
Expand Down
4 changes: 4 additions & 0 deletions tests/graph/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,10 @@ def test_eval_kwargs(self):
self.w.eval({self.z: 3, self.x: 2.5})
assert self.w.eval({self.z: 3, self.x: 2.5}, on_unused_input="ignore") == 6.0

# regression test for https://github.com/pymc-devs/pytensor/issues/1084
q = self.x + 1
assert q.eval({"x": 1, "y": 2}, on_unused_input="ignore") == 2.0

@pytest.mark.filterwarnings("error")
def test_eval_unashable_kwargs(self):
y_repl = constant(2.0, dtype="floatX")
Expand Down

0 comments on commit 2315e69

Please sign in to comment.