Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix casts in ufunc outputs #5550

Merged
merged 21 commits into from
Aug 2, 2021
Merged

Fix casts in ufunc outputs #5550

merged 21 commits into from
Aug 2, 2021

Conversation

toslunar
Copy link
Member

This PR fixes #5527 and simplifies elementwise_copy (e.g. reverts #5410 except for the test).

See also #5539.

@toslunar
Copy link
Member Author

Let me run the tests because the change affects globally. Jenkins, test this please.

@chainer-ci
Copy link
Member

Jenkins CI test (for commit 682f491, target branch master) failed with status FAILURE.

@kmaehashi kmaehashi added cat:bug Bugs to-be-backported Pull-requests to be backported to stable branch prio:medium labels Jul 26, 2021
Comment on lines 914 to 926
else:
op.append(
'out{0}_type out{0}({1});'
.format(
i,
fix_cast_expr(arginfo.dtype, x, f'_raw_out{i}[_ind.get()]')
))
out_op.append(
'_raw_out{0}[_ind.get()] = {1};'.format(
i,
fix_cast_expr(x, arginfo.dtype, f'out{i}')
)
)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This "load and save" approach is not really a good idea.

  1. The load is not necessary. I'd remove the ability of reading out from cupy.ufunc, while it's fine with cupy.ElementwiseKernel.
  2. The save causes "precision loss" if no-op is expected. For example, CuPy should carefully implement where of ufunc.

The current problem is just a special case of 2.: cupy.copyto(..., where=...) is getting wrong with

elementwise_copy_where = create_ufunc(
    'cupy_copy_where',
    ('??->?', 'b?->b', 'B?->B', 'h?->h', 'H?->H', 'i?->i', 'I?->I', 'l?->l',
     'L?->L', 'q?->q', 'Q?->Q', 'e?->e', 'f?->f', 'd?->d', 'F?->F', 'D?->D'),
    'if (in1) out0 = in0',
    default_casting='unsafe')

, which was found in the tests of cupy.unwrap. (BTW, for float dtypes, cupy.testing.shaped_arange could use arange(size) + 0.5 instead of + 1 in order to detect this kind of bugs.)

cupy.copyto(ddmod, interval_high, where=(
ddmod == interval_low) & (dd > 0))

To resolve it, either

  • implement (private) _where arg of ufunc.__call__ that works at least for cupy_copy, or
  • rewrite the cupy_copy_where kernel as cupy.ElementwiseKernel.

@toslunar
Copy link
Member Author

Jenkins, test this please.

@chainer-ci
Copy link
Member

Jenkins CI test (for commit adad8fa, target branch master) succeeded!

@toslunar
Copy link
Member Author

Jenkins, test this please.

@chainer-ci
Copy link
Member

Jenkins CI test (for commit 03d0973, target branch master) failed with status FAILURE.

@toslunar
Copy link
Member Author

_real_setter and _imag_setter read out0, too.

@toslunar
Copy link
Member Author

toslunar commented Jul 29, 2021

Now we can simplify some ufunc codes e.g. from

out0_type a = _floor_divide(in0, in1);
out0 = a;
out1 = in0 - a * in1

(https://github.com/cupy/cupy/blob/v9.2.0/cupy/_core/core.pyx#L2016-L2019) to

out0 = _floor_divide(in0, in1);
out1 = in0 - out0 * in1

But let's do so in another PR.

cupy/_core/_kernel.pyx Show resolved Hide resolved
src_kind = get_dtype(src_type).kind
dst_kind = get_dtype(dst_type).kind
if src_kind == dst_kind:
return expr
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about

  1. casting explicitly?
Suggested change
return expr
cast_type = _get_typename(dst_type)
return f'static_cast<{cast_type}>({expr})'

, or
2. Could you add a comment that the cast is not needed because the returned fixed expr will be always used as out[i] = fixed_expr?

cdef function.Function _get_ufunc_kernel(
tuple in_types, tuple out_types, routine, tuple arginfos, params,
tuple in_types, tuple out_types, routine, tuple arginfos,
bint has_where, params,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: (internal discussion with @toslunar)
has_where argument can be removed by introducing a constant array, that is a ndaray-like compile-time constant object.
If where option is not given, where becomes a constant True.

@toslunar toslunar force-pushed the ufunc-cast-output branch from 316ac56 to 00264a7 Compare July 30, 2021 13:04
@toslunar
Copy link
Member Author

I squashed some recent commits for commit-by-commit review.

@toslunar
Copy link
Member Author

Jenkins, test this please.

@chainer-ci
Copy link
Member

Jenkins CI test (for commit 00264a7, target branch master) failed with status FAILURE.

if where is not None:
_core.elementwise_copy(src, dst, _where=where)
return

if dst.size == 0:
return
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, cupy.copyto(cupy.empty((2, 0)), cupy.empty((0, 3))) should raise

ValueError: could not broadcast input array from shape (0,3) into shape (2,0)

@asi1024
Copy link
Member

asi1024 commented Aug 2, 2021

Jenkins, test this please.

@chainer-ci
Copy link
Member

Jenkins CI test (for commit 927b110, target branch master) succeeded!

@asi1024
Copy link
Member

asi1024 commented Aug 2, 2021

LGTM!

@asi1024 asi1024 merged commit cb178c9 into cupy:master Aug 2, 2021
chainer-ci pushed a commit to chainer-ci/cupy that referenced this pull request Aug 2, 2021
@toslunar toslunar deleted the ufunc-cast-output branch August 3, 2021 02:10
@kmaehashi kmaehashi added this to the v10.0.0b1 milestone Aug 5, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cat:bug Bugs prio:medium to-be-backported Pull-requests to be backported to stable branch
Projects
None yet
Development

Successfully merging this pull request may close these issues.

unsafe cast from complex becomes NVRTC compile error
4 participants