This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Wrong result when using new numpy ffi in deferred compute #18004

Closed
hgt312 opened this issue Apr 9, 2020 · 5 comments

@hgt312
Contributor

hgt312 commented Apr 9, 2020

Found in the CI results of #17958.

Reproduce:

Sample code (run against the master branch):

import mxnet as mx
import mxnet._deferred_compute as dc
from mxnet import np, npx
npx.set_np()
with dc.context():
    a = np.ones((2, 2))
    b = np.tril(a, 1)
    c = np.tril(a, -1)

sym = dc.get_symbol([b, c], sym_cls=mx.sym.np._Symbol)
res = sym.bind(mx.context.current_context(), args={}).forward()
res

Results (both outputs are identical, but c should differ from b):

[
[[1. 1.]
 [1. 1.]]
<NDArray 2x2 @cpu(0)>,
[[1. 1.]
 [1. 1.]]
<NDArray 2x2 @cpu(0)>]
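
For reference, the expected values can be checked without MXNet. The sketch below is a plain-Python restatement of the usual tril semantics (keep element (i, j) when j <= i + k); the helper name tril_reference is hypothetical and not part of MXNet or NumPy.

```python
# Reference semantics of tril in plain Python (no MXNet or NumPy needed).
# tril_reference is a hypothetical helper used only for illustration.
def tril_reference(matrix, k=0):
    """Keep element (i, j) when j <= i + k, zero it otherwise."""
    return [
        [value if j <= i + k else 0.0 for j, value in enumerate(row)]
        for i, row in enumerate(matrix)
    ]

a = [[1.0, 1.0], [1.0, 1.0]]
b = tril_reference(a, 1)    # [[1.0, 1.0], [1.0, 1.0]]
c = tril_reference(a, -1)   # [[0.0, 0.0], [1.0, 0.0]]
print(b)
print(c)
```

So only the second output of the reproduction above is wrong: it should be [[0. 0.] [1. 0.]].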

If _api_internal.tril is replaced by _npi.tril, the result is correct.

@leezu Can you take a look at this?

@hgt312 hgt312 added the Bug label Apr 9, 2020
@hgt312
Contributor Author

hgt312 commented Apr 9, 2020

@hzfan

@leezu
Contributor

leezu commented Apr 9, 2020

Thanks for reporting the issue. I can confirm that disabling the new FFI for tril via the patch below avoids the issue. I'll look into why the two FFIs behave differently.

diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 4820f56f3..8ff7ba073 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -2051,7 +2051,7 @@ def tril(m, k=0):
            [ 7.,  8.,  0.],
            [10., 11., 12.]])
     """
-    return _api_internal.tril(m, k)
+    return _npi.tril(m, k)


 def _unary_func_helper(x, fn_array, fn_scalar, out=None, **kwargs):

@haojin2
Contributor

haojin2 commented Apr 17, 2020

@leezu Have you been able to find the root cause for this bug?

@leezu
Contributor

leezu commented Apr 19, 2020

@haojin2 Sorry, I haven't had a chance to dig into this issue yet. I'll work on it during our bug bash in the coming week.

@leezu
Contributor

leezu commented May 12, 2020

This is caused by common expression elimination. As a workaround, one can set MXNET_ELIMINATE_COMMON_EXPR=0.
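
A minimal sketch of applying that workaround from Python follows. It assumes the variable should be in the environment before MXNet is imported; that ordering is a cautious assumption, not something confirmed in this thread. Exporting the variable in the shell before starting Python would have the same effect.

```python
import os

# Assumption: set the variable before importing MXNet so that it is visible
# by the time the executor runs its graph passes.
os.environ["MXNET_ELIMINATE_COMMON_EXPR"] = "0"

try:
    import mxnet as mx  # only available if MXNet is installed
except ImportError:
    mx = None  # the sketch still runs without MXNet
```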

https://github.com/apache/incubator-mxnet/blob/ab4f7f6a7335e88034edcf61402aec170cdca5fd/src/executor/eliminate_common_expr_pass.cc#L59-L94

For the old FFI, n->attrs.dict != m->attrs.dict and the two nodes are preserved.
But in the new FFI, n->attrs.dict == m->attrs.dict, so the np.tril(a, -1) node is replaced by the np.tril(a, 1) node.

For the new FFI, the reason is that the code at

https://github.com/apache/incubator-mxnet/blob/7bef85ecbf9cb064d45479efa5cc46c6f8a4e947/src/api/operator/utils.h#L52-L57

only parses attributes if autograd is enabled, whereas the old FFI always parses attributes. This was an oversight when I rebased the DC PR on top of the FFI update PR.
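
The effect can be illustrated with a toy model. This is only a sketch: Node, eliminate_common_expr and the attribute layout are hypothetical, and it assumes two nodes are merged when their operator, inputs and attribute dict all compare equal. The actual pass in eliminate_common_expr_pass.cc may check more than that.

```python
# Toy model only: Node and eliminate_common_expr are hypothetical names that
# mimic the attrs.dict comparison described above, not MXNet code.
from dataclasses import dataclass, field


@dataclass
class Node:
    op: str
    inputs: tuple
    attrs_dict: dict = field(default_factory=dict)


def eliminate_common_expr(nodes):
    """Map each node to the first earlier node with equal op, inputs and attrs dict."""
    seen = {}
    replacement = []
    for node in nodes:
        key = (node.op, node.inputs, tuple(sorted(node.attrs_dict.items())))
        replacement.append(seen.setdefault(key, node))
    return replacement


# Old FFI: attributes are always parsed into the dict, so the nodes differ.
old = eliminate_common_expr([
    Node("tril", ("a",), {"k": "1"}),
    Node("tril", ("a",), {"k": "-1"}),
])
old_merged = old[1] is old[0]

# New FFI without autograd enabled: the dict stays empty, so the nodes look identical.
new = eliminate_common_expr([
    Node("tril", ("a",)),
    Node("tril", ("a",)),
])
new_merged = new[1] is new[0]

print(old_merged, new_merged)  # False True
```

In the second case the node for k=-1 is replaced by the node for k=1, which matches the identical outputs seen in the reproduction.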

@leezu leezu closed this as completed May 15, 2020