-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rewrite batched dots that do not reduce as multiplication #1178
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from pytensor.graph.rewriting.utils import rewrite_graph | ||
|
||
|
||
all = ("rewrite_graph",) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -716,6 +716,32 @@ def test_masked_array_not_implemented( | |
ptb.as_tensor(x) | ||
|
||
|
||
def check_alloc_runtime_broadcast(mode): | ||
"""Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules.""" | ||
floatX = config.floatX | ||
x_v = vector("x", shape=(None,)) | ||
|
||
out = alloc(x_v, 5, 3) | ||
f = pytensor.function([x_v], out, mode=mode) | ||
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why did you refactor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I refactored the other out because trying to use the method from another module led pytest to execute all the tests of the class. The class calling a method on the class itself should be fine, which I assume is what's happening here? |
||
|
||
np.testing.assert_array_equal( | ||
f(x=np.zeros((3,), dtype=floatX)), | ||
np.zeros((5, 3), dtype=floatX), | ||
) | ||
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"): | ||
f(x=np.zeros((1,), dtype=floatX)) | ||
|
||
out = alloc(specify_shape(x_v, (1,)), 5, 3) | ||
f = pytensor.function([x_v], out, mode=mode) | ||
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1) | ||
|
||
np.testing.assert_array_equal( | ||
f(x=np.zeros((1,), dtype=floatX)), | ||
np.zeros((5, 3), dtype=floatX), | ||
) | ||
|
||
|
||
class TestAlloc: | ||
dtype = config.floatX | ||
mode = mode_opt | ||
|
@@ -729,32 +755,6 @@ def check_allocs_in_fgraph(fgraph, n): | |
== n | ||
) | ||
|
||
@staticmethod | ||
def check_runtime_broadcast(mode): | ||
"""Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules.""" | ||
floatX = config.floatX | ||
x_v = vector("x", shape=(None,)) | ||
|
||
out = alloc(x_v, 5, 3) | ||
f = pytensor.function([x_v], out, mode=mode) | ||
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1) | ||
|
||
np.testing.assert_array_equal( | ||
f(x=np.zeros((3,), dtype=floatX)), | ||
np.zeros((5, 3), dtype=floatX), | ||
) | ||
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"): | ||
f(x=np.zeros((1,), dtype=floatX)) | ||
|
||
out = alloc(specify_shape(x_v, (1,)), 5, 3) | ||
f = pytensor.function([x_v], out, mode=mode) | ||
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1) | ||
|
||
np.testing.assert_array_equal( | ||
f(x=np.zeros((1,), dtype=floatX)), | ||
np.zeros((5, 3), dtype=floatX), | ||
) | ||
|
||
def setup_method(self): | ||
self.rng = np.random.default_rng(seed=utt.fetch_seed()) | ||
|
||
|
@@ -912,7 +912,7 @@ def test_alloc_of_view_linker(self): | |
|
||
@pytest.mark.parametrize("mode", (Mode("py"), Mode("c"))) | ||
def test_runtime_broadcast(self, mode): | ||
self.check_runtime_broadcast(mode) | ||
check_alloc_runtime_broadcast(mode) | ||
|
||
|
||
def test_infer_static_shape(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's match week here at pymc-devs!