From 23b35bcc04431b658b391b867025703b56cafd7f Mon Sep 17 00:00:00 2001 From: stu1130 Date: Wed, 23 Oct 2019 10:25:13 -0700 Subject: [PATCH] Make MXIsNumpyShape return enum --- include/mxnet/c_api.h | 2 +- include/mxnet/imperative.h | 10 ++++++---- .../src/main/native/org_apache_mxnet_native_c_api.cc | 2 +- src/c_api/c_api_ndarray.cc | 2 +- src/ndarray/ndarray.cc | 3 ++- 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 177ec5d40146..ac0c6726f2c7 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1170,7 +1170,7 @@ MXNET_DLL int MXAutogradIsTraining(bool* curr); * \param curr returns the current status * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXIsNumpyShape(bool* curr); +MXNET_DLL int MXIsNumpyShape(int* curr); /*! * \brief set numpy compatibility switch * \param is_np_shape 1 when numpy shape semantics is thread local on, diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h index 18f6424e54f7..dbd81e575872 100644 --- a/include/mxnet/imperative.h +++ b/include/mxnet/imperative.h @@ -108,12 +108,14 @@ class Imperative { is_recording_ = is_recording; return old; } - /*! \brief whether numpy compatibility is on. */ - bool is_np_shape() const { + /*! \brief return current numpy compatibility status, + * GlobalOn(2), ThreadLocalOn(1), Off(0). + * */ + int is_np_shape() const { if (is_np_shape_global_) { - return true; + return 2; } - return is_np_shape_thread_local_; + return is_np_shape_thread_local_ ? 1 : 0; } /*! \brief specify numpy compatibility off, thread local on or global on. */ bool set_is_np_shape(int is_np_shape) { diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc index 5c704c9646a2..00ead4a0147c 100644 --- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc @@ -2778,7 +2778,7 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDumpProfile JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyShape (JNIEnv *env, jobject obj, jobject compatibleRef) { bool isNumpyShape; - int ret = MXIsNumpyShape(&isNumpyShape); + int ret = MXIsNumpyShape(static_cast(&isNumpyShape)); SetIntField(env, compatibleRef, static_cast(isNumpyShape)); return ret; } diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index b80e17c18071..de208c0fed99 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -276,7 +276,7 @@ int MXAutogradSetIsRecording(int is_recording, int* prev) { API_END(); } -int MXIsNumpyShape(bool* curr) { +int MXIsNumpyShape(int* curr) { API_BEGIN(); *curr = Imperative::Get()->is_np_shape(); API_END(); diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index e15f72fa6cfa..44da670b800d 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -1732,7 +1732,8 @@ bool NDArray::Load(dmlc::Stream *strm) { " Please turn on np shape semantics in Python using `with np_shape(True)`" " or decorator `use_np_shape` to scope the code of loading the ndarray."; } else { - CHECK(!Imperative::Get()->is_np_shape()) + // when the flag is global on, skip the check since it would be always global on. + CHECK(Imperative::Get()->is_np_shape() == GlobalOn || !Imperative::Get()->is_np_shape()) << "ndarray was not saved in np shape semantics, but being loaded in np shape semantics." " Please turn off np shape semantics in Python using `with np_shape(False)`" " to scope the code of loading the ndarray.";