-
Notifications
You must be signed in to change notification settings - Fork 6.8k
adding error message when attempting to use Large tensor with linalg_syevd #18807
Conversation
Hey @access2rohit , Thanks for submitting the PR
CI supported jobs: [website, miscellaneous, centos-cpu, centos-gpu, windows-cpu, windows-gpu, unix-gpu, clang, sanity, edge, unix-cpu] Note: |
3rdparty/mshadow/mshadow/base.h
Outdated
@@ -336,6 +336,8 @@ const float kPi = 3.1415926f; | |||
typedef int32_t index_t; | |||
#endif | |||
|
|||
const index_t kInt32Limit = (int64_t{1} << 31) - 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a comment to make it clear where is this being used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the advantage of manually defining this over using INT_MAX in <climits?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Zha0q1 good suggestion !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Do you think adding a test with assert_exception in test_large_array.py
For example
from mxnet.test_utils import assert_exception
def test_linalg_syevd():
input = mx.nd.array(LARGE_X,SMALL_Y)
assert_exception(mx.nd.linalg(data), MXNetError)
Thats why WIP |
|
ed5a883
to
6411cee
Compare
Test Run
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
left some comments on minor changes. Otherwise looking good!
src/operator/tensor/la_op.h
Outdated
@@ -470,11 +470,14 @@ inline bool DetType(const nnvm::NodeAttrs& attrs, | |||
inline bool LaEigFactShape(const nnvm::NodeAttrs& attrs, | |||
mxnet::ShapeVector* in_attrs, | |||
mxnet::ShapeVector* out_attrs) { | |||
using namespace mshadow; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this can be removed then.
CHECK_EQ(in_attrs->size(), 1); | ||
CHECK_EQ(out_attrs->size(), 2); | ||
const mxnet::TShape& in_a = (*in_attrs)[0]; | ||
const mxnet::TShape& out_u = (*out_attrs)[0]; | ||
const mxnet::TShape& out_l = (*out_attrs)[1]; | ||
CHECK_LE(in_a.Size(), INT_MAX) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
did you include climit ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is included in base.h
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's better to be explicit about your dependencies.
"Avoid surprises. Avoid having to change #includes if an #included header changes. Avoid accidentally becoming dependent on implementation details and logically separate entities included in a header."
https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rs-implicit
A = get_identity_mat(LARGE_SQ_X) | ||
for i in range(LARGE_SQ_X): | ||
A[i,i] = 1 | ||
assertRaises(MXNetError, mx.nd.linalg.syevd, A) |
This comment was marked as outdated.
This comment was marked as outdated.
Sorry, something went wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. LGTM!
@mxnet-bot run ci [unix-gpu] |
Jenkins CI successfully triggered : [unix-gpu] |
@mxnet-bot run ci [clang] |
Jenkins CI successfully triggered : [clang] |
Description
This PR adds error message when very large inputs(>2^32-1) are passed to linear algebra operator 'syevd'
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.