This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Bug][Numpy] The data type of power is not correct #16653

Closed
sxjscience opened this issue Oct 28, 2019 · 7 comments

Comments

@sxjscience
Member

```python
import mxnet as mx
mx.npx.set_np()
a = mx.np.array(2, dtype=mx.np.int32)
b = a ** 3.1
print(b)
```

Returns

```
8
```

In NumPy, the result should be:

```python
import numpy as np
print(np.array(2) ** 3.1)
```

Returns

```
8.574187700290345
```
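For reference, here is a minimal NumPy-only check (no MXNet needed) of the dtype NumPy assigns when an `int32` array is raised to a Python `float` power. It uses `int32` to mirror the MXNet repro; `np.array(2)` in the snippet above uses the platform default integer instead. The point is that NumPy promotes the result to `float64` rather than keeping the integer input dtype.

```python
import numpy as np

a = np.array(2, dtype=np.int32)  # 0-d int32 array, same as the MXNet repro
b = a ** 3.1                     # Python float exponent

# NumPy promotes int32 combined with a Python float to float64,
# so the fractional part of 2 ** 3.1 is preserved.
print(b, b.dtype)  # dtype is float64
```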
@xidulu
Contributor

xidulu commented Oct 28, 2019

I believe https://github.com/apache/incubator-mxnet/blob/master/src/operator/numpy/np_elemwise_broadcast_op.cc#L32 is the root cause of this problem.
Also, power is not the only operator affected by this bug:

```python
>>> from mxnet import np, npx, autograd
>>> npx.set_np()
>>> a = np.array(2, dtype=np.int32)
>>> a
array(2, dtype=int32)
>>> a + 1.1
array(3, dtype=int32)
```
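The symptom is consistent with type inference that simply copies the input dtype to the output. This has not been checked against the linked source line, so the following is an assumption: the sketch emulates that rule in NumPy by computing the exact result in `float64` and then casting it back to the input dtype, which reproduces both reported outputs. `same_dtype_scalar_op` is a hypothetical helper, not MXNet code, and MXNet's real kernel may arrive at the same numbers differently.

```python
import numpy as np

def same_dtype_scalar_op(op, arr, scalar):
    """Hypothetical emulation of a scalar op whose type inference forces
    output dtype == input dtype: compute in float64, then cast back."""
    exact = op(arr.astype(np.float64), scalar)
    return exact.astype(arr.dtype)  # float -> int32 cast truncates toward zero

a = np.array(2, dtype=np.int32)
print(same_dtype_scalar_op(np.power, a, 3.1))  # 8, same as the reported power output
print(same_dtype_scalar_op(np.add, a, 1.1))    # 3, same as the reported add output
```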

The mixed precision mechanism introduced in #16631 may be a solution to this issue. @haojin2

@haojin2
Contributor

haojin2 commented Oct 28, 2019

@xidulu yes, the mixed-precision binary ops are the effort aimed at solving this kind of problem. Right now we support multiplication between bool and every other type; later we'll extend this support to more combinations. @sxjscience
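As a reference point for the bool case mentioned above, this sketch shows what NumPy returns when a bool array is multiplied by arrays of other dtypes: the bool operand is promoted to the other operand's dtype. It only shows NumPy's reference behavior and makes no claim about how the MXNet mixed-precision kernels are implemented; the array names are illustrative.

```python
import numpy as np

mask = np.array([True, False, True])
vals_f32 = np.array([1.5, 2.5, 3.5], dtype=np.float32)
vals_i32 = np.array([4, 5, 6], dtype=np.int32)

# In NumPy, bool * X takes X's dtype; False acts as 0 and True as 1.
prod_f = mask * vals_f32
prod_i = mask * vals_i32
print(prod_f.dtype)  # float32
print(prod_i.dtype)  # int32
```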

@reminisce
Contributor

This is a known issue with many operators that are expected to return floats for integral inputs. Instead of tweaking many TYPE_SWITCH macros to make it work, we plan to solve it by using TVM to generate appropriate kernels for integral dtypes.
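To make "operators that are expected to return floats for integral inputs" concrete, the sketch below lists a few NumPy ufuncs that select a floating-point loop for `int32` input. It illustrates NumPy's reference behavior only and says nothing about how the TVM-generated kernels are written; the choice of operators is illustrative, not an exhaustive list of the affected MXNet operators.

```python
import numpy as np

x = np.array([1, 4, 9], dtype=np.int32)

# Each of these returns floating-point output for int32 input in NumPy.
results = {
    "sqrt": np.sqrt(x),
    "true_divide": x / np.array([2, 2, 2], dtype=np.int32),
    "exp": np.exp(x),
    "power_float_exponent": np.power(x, 0.5),
}
for name, r in results.items():
    print(name, r.dtype)  # float64 for every entry
```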

@samskalicky
Contributor

@lanking520 assign @reminisce

@haojin2
Contributor

haojin2 commented Dec 9, 2019

The latest master branch today has this behavior:

```python
>>> import mxnet as mx
>>> mx.npx.set_np()
>>> a = mx.np.array(2, dtype=mx.np.int32)
>>> b = a ** mx.np.array(3.1, dtype=mx.np.float32)
>>> b
array(8.574187)
```
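For comparison, this is what NumPy does in the same array-array case. The MXNet output above is printed without a dtype, so no assumption is made here about which float dtype MXNet picked; the sketch only shows that NumPy's promotion table maps `int32` with `float32` to `float64` (because `float32` cannot represent every `int32` value exactly).

```python
import numpy as np

a = np.array(2, dtype=np.int32)
e = np.array(3.1, dtype=np.float32)

# NumPy's dtype-level promotion rule for this pair of operand types.
print(np.promote_types(np.int32, np.float32))  # float64

b = a ** e
print(b, b.dtype)  # about 8.5741870, dtype float64
```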

But there is still a small problem with the scalar case:

```python
>>> import mxnet as mx
>>> mx.npx.set_np()
>>> a = mx.np.array(2, dtype=mx.np.int32)
>>> b = a ** 3.1
>>> print(b)
8
```

I think this is a new case for the binary scalar op; I'll work on a solution.
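The thread does not describe how the eventual fix works, so the following is only one possible output-dtype rule for the scalar case, written in NumPy. `scalar_op_out_dtype` and `scalar_power` are hypothetical names, not MXNet functions. The rule promotes a bool or integer array combined with a Python `float` to `float64` (NumPy's result for this case); an MXNet implementation might use its default float dtype instead, which is not assumed here. All other combinations keep the array dtype, which is a deliberate simplification.

```python
import numpy as np

def scalar_op_out_dtype(in_dtype, scalar):
    """Hypothetical output-dtype rule for an array-scalar binary op.

    Covers only the case from this issue: a bool/int/uint array combined
    with a Python float becomes float64. Every other combination keeps the
    array dtype (a simplification; e.g. NumPy turns bool + int into an int)."""
    in_dtype = np.dtype(in_dtype)
    if isinstance(scalar, float) and in_dtype.kind in "biu":
        return np.dtype(np.float64)
    return in_dtype

def scalar_power(arr, scalar):
    """Hypothetical array ** scalar that computes in the inferred dtype,
    so a float exponent on an integer array is no longer truncated."""
    out_dtype = scalar_op_out_dtype(arr.dtype, scalar)
    return np.power(arr.astype(out_dtype), scalar, dtype=out_dtype)

a = np.array(2, dtype=np.int32)
print(scalar_op_out_dtype(a.dtype, 3.1))  # float64
print(scalar_op_out_dtype(a.dtype, 3))    # int32
print(scalar_power(a, 3.1))               # about 8.5742 instead of 8
```

Computing directly in the inferred dtype (rather than computing in the input dtype and fixing it up afterwards) is what avoids the truncation seen in the original report.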

@yzhliu
Member

yzhliu commented Apr 30, 2020

@Tommliu is working on this.

Tommliu added a commit to Tommliu/incubator-mxnet that referenced this issue May 11, 2020
@yzhliu
Member

yzhliu commented May 22, 2020

Closed by #18277.

@yzhliu yzhliu closed this as completed May 22, 2020