Skip to content

Commit

Permalink
Fix matmul memory test by passing dtype properly (#358)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Jan 25, 2024
1 parent 929f52c commit 2e28072
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion cubed/array_api/linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ def _sum_wo_cat(a, axis=None, dtype=None):
if a.shape[axis] == 1:
return squeeze(a, axis)

return reduction(a, _chunk_sum, axis=axis, dtype=dtype)
extra_func_kwargs = dict(dtype=dtype)
return reduction(
a, _chunk_sum, axis=axis, dtype=dtype, extra_func_kwargs=extra_func_kwargs
)


def _chunk_sum(a, axis=None, dtype=None, keepdims=None):
Expand Down

0 comments on commit 2e28072

Please sign in to comment.