Skip to content

Commit

Permalink
Return from scalar constants in get_unique_constant_value
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 31, 2024
1 parent 7a0ea76 commit 4d0aa3f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 22 deletions.
10 changes: 6 additions & 4 deletions pytensor/tensor/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,11 +1045,13 @@ def get_unique_constant_value(x: TensorVariable) -> Number | None:
if isinstance(x, Constant):
data = x.data

if isinstance(data, np.ndarray) and data.ndim > 0:
if isinstance(data, np.ndarray) and data.size > 0:
if data.size == 1:
return data.squeeze()

flat_data = data.ravel()
if flat_data.shape[0]:
if (flat_data == flat_data[0]).all():
return flat_data[0]
if (flat_data == flat_data[0]).all():
return flat_data[0]

return None

Expand Down
34 changes: 16 additions & 18 deletions tests/scan/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,24 +654,22 @@ def no_shared_fn(n, x_tm1, M):
Inner graphs:
Scan{scan_fn, while_loop=False, inplace=all} [id A]
← Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
├─ 0 [id J]
├─ Subtensor{i, j, k} [id K]
│ ├─ *2-<Tensor3(float64, shape=(20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
│ ├─ ScalarFromTensor [id M]
│ │ └─ *0-<Scalar(int64, shape=())> [id N] -> [id C] (inner_in_seqs-0)
│ ├─ ScalarFromTensor [id O]
│ │ └─ *1-<Scalar(int64, shape=())> [id P] -> [id D] (inner_in_sit_sot-0)
│ └─ 0 [id Q]
└─ 1 [id R]
Composite{switch(lt(i0, i1), i2, i0)} [id I]
← Switch [id S] 'o0'
├─ LT [id T]
│ ├─ i0 [id U]
│ └─ i1 [id V]
├─ i2 [id W]
└─ i0 [id U]
← Composite{switch(lt(0, i0), 1, 0)} [id I] (inner_out_sit_sot-0)
└─ Subtensor{i, j, k} [id J]
├─ *2-<Tensor3(float64, shape=(20000, 2, 2))> [id K] -> [id H] (inner_in_non_seqs-0)
├─ ScalarFromTensor [id L]
│ └─ *0-<Scalar(int64, shape=())> [id M] -> [id C] (inner_in_seqs-0)
├─ ScalarFromTensor [id N]
│ └─ *1-<Scalar(int64, shape=())> [id O] -> [id D] (inner_in_sit_sot-0)
└─ 0 [id P]
Composite{switch(lt(0, i0), 1, 0)} [id I]
← Switch [id Q] 'o0'
├─ LT [id R]
│ ├─ 0 [id S]
│ └─ i0 [id T]
├─ 1 [id U]
└─ 0 [id S]
"""

output_str = debugprint(out, file="str", print_op_info=True)
Expand Down

0 comments on commit 4d0aa3f

Please sign in to comment.