Fix several issues regarding recent mapping update #4551
Conversation
LGTM.
Is it possible to add a unit test to check that the warning is raised when using the dictionary and not raised when using the function?
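(For illustration, such a pair of tests might look like the sketch below. It assumes the deprecated NP_TYPE_TO_TENSOR_TYPE dictionary warns on item access, which is the approach this PR takes; the tests actually added may differ.)

```python
import unittest
import warnings

import numpy as np
from onnx import TensorProto, helper
from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE  # deprecated access path


class DeprecationTest(unittest.TestCase):
    def test_dict_access_warns(self) -> None:
        # Accessing the deprecated dictionary should emit the warning.
        with self.assertWarns(DeprecationWarning):
            _ = NP_TYPE_TO_TENSOR_TYPE[np.dtype("float32")]

    def test_helper_function_does_not_warn(self) -> None:
        # Escalate DeprecationWarning to an error: this test fails if
        # the new helper ever goes through the deprecated path.
        with warnings.catch_warnings():
            warnings.simplefilter("error", DeprecationWarning)
            _ = helper.tensor_dtype_to_np_dtype(TensorProto.FLOAT)
```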
Good idea. I just added tests:

onnx/test/helper_test.py (Outdated)
```diff
@@ -724,6 +724,26 @@ def test_make_tensor_raw(tensor_dtype: int) -> None:
     np.testing.assert_equal(np_array, numpy_helper.to_array(tensor))


+# TODO (#4554): remove this test after the deprecation period
+# Test these new functions should not raise any depreaction warnings
```
Misspelling deprecation
Corrected. Thanks!
```python
def test_tensor_dtype_to_np_dtype_not_throw_warning(self) -> None:
    _ = helper.tensor_dtype_to_np_dtype(TensorProto.FLOAT)


@pytest.mark.filterwarnings("error::DeprecationWarning")
```
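For context, the filterwarnings marker above tells pytest to escalate any DeprecationWarning raised inside the decorated test into an error, so the test fails if the deprecated path is ever hit. A minimal standalone sketch of the mechanism (not the PR's code):

```python
import warnings

import pytest


@pytest.mark.filterwarnings("error::DeprecationWarning")
def test_stays_silent() -> None:
    # Under the marker, any DeprecationWarning raised here becomes an
    # exception, so the test fails instead of passing silently.
    pass


def test_warning_is_emitted() -> None:
    # Without the marker, pytest.warns can assert the warning occurs.
    with pytest.warns(DeprecationWarning):
        warnings.warn("use the new helper instead", DeprecationWarning)
```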
I wonder if it is worth checking what happens in the case of bfloat16.
Tests for bfloat16 look useful since the mapping for bfloat16 is quite confusing, and adding these tests can prevent future regressions. Just added them. PTAL. Thanks!
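A sketch of a regression test in this spirit, using only facts visible in this PR's diffs (the storage field for bfloat16); the tests actually added may differ:

```python
import unittest

from onnx import TensorProto, helper


class Bfloat16MappingTest(unittest.TestCase):
    def test_bfloat16_storage_field(self) -> None:
        # bfloat16 values are packed into the int32_data field rather
        # than a dedicated field -- an easy mapping to get wrong.
        self.assertEqual(
            helper.tensor_dtype_to_field(TensorProto.BFLOAT16), "int32_data"
        )
```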
```diff
@@ -542,7 +542,7 @@ message TensorProto {
 // float16 values must be bit-wise converted to an uint16_t prior
 // to writing to the buffer.
 // When this field is present, the data_type field MUST be
-// INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16
+// INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16 or BFLOAT16
```
Please note that I think BFLOAT16 is missing here.
```python
self.assertEqual(
    helper.tensor_dtype_to_field(TensorProto.BFLOAT16), "int32_data"
)
```
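As the proto comment above notes, (b)float16 values are bit-wise converted to uint16 before being written to the raw buffer. A hedged sketch of that conversion for bfloat16 (the truncation variant; a real implementation may round to nearest):

```python
import numpy as np


def float32_to_bfloat16_bits(x: float) -> np.uint16:
    # A bfloat16 value is the upper 16 bits of the float32 bit pattern:
    # view the float32 as uint32, then drop the low 16 bits.
    return np.uint16(np.float32(x).view(np.uint32) >> 16)


assert float32_to_bfloat16_bits(1.0) == 0x3F80  # float32 1.0 is 0x3F800000
```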
It can be done in another PR, but I would check all the types:
```python
# Suggested additions to helper_test.py; assumes that file's existing
# imports (numpy as np, from_array/to_array from onnx.numpy_helper,
# make_tensor, TENSOR_TYPE_TO_NP_TYPE, assert_almost_equal) and that
# these are methods of a unittest.TestCase subclass.
def test_numeric_types(self):  # type: ignore
    dtypes = [
        np.float16,
        np.float32,
        np.float64,
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.uint8,
        np.uint16,
        np.uint32,
        np.uint64,
        np.complex64,
        np.complex128,
    ]
    for dt in dtypes:
        with self.subTest(dtype=dt):
            t = np.array([0, 1, 2], dtype=dt)
            ot = from_array(t)
            u = to_array(ot)
            self.assertEqual(t.dtype, u.dtype)
            assert_almost_equal(t, u)

def test_make_tensor(self):  # type: ignore
    for pt, dt in TENSOR_TYPE_TO_NP_TYPE.items():
        if pt == TensorProto.BFLOAT16:
            continue
        with self.subTest(dt=dt, pt=pt, raw=False):
            if pt == TensorProto.STRING:
                t = np.array([["i0", "i1", "i2"], ["i6", "i7", "i8"]], dtype=dt)
            else:
                t = np.array([[0, 1, 2], [6, 7, 8]], dtype=dt)
            ot = make_tensor("test", pt, t.shape, t, raw=False)
            self.assertFalse(ot is None)
            u = to_array(ot)
            self.assertEqual(t.dtype, u.dtype)
            self.assertEqual(t.tolist(), u.tolist())
        with self.subTest(dt=dt, pt=pt, raw=True):
            t = np.array([[0, 1, 2], [6, 7, 8]], dtype=dt)
            if pt == TensorProto.STRING:
                with self.assertRaises(TypeError):
                    make_tensor("test", pt, t.shape, t.tobytes(), raw=True)
            else:
                ot = make_tensor("test", pt, t.shape, t.tobytes(), raw=True)
                self.assertFalse(ot is None)
                u = to_array(ot)
                self.assertEqual(t.dtype, u.dtype)
                assert_almost_equal(t, u)
```
Thank you @xadupre for the suggestion. It looks good to me, but I am inclined to do it in another PR (this PR focuses more on bug fixes). Feel free to let me know if you still have other concerns.
Commits:
* add missing case for Sequence.MAP
* remove warning and improve printable_graph
* fix lint and mypy
* _NP_TYPE_TO_TENSOR_TYPE and mention available in ONNX 1.13
* add test to prevent throwing deprecation in new functions
* add bfloat test for tensor_dtype_to_field
* add more tests for bfloat16
Description
- Make tensor_dtype_to_field not throw a deprecation warning by introducing an identical private dictionary _STORAGE_TENSOR_TYPE_TO_FIELD. Same case for NP_TYPE_TO_TENSOR_TYPE. (See the sketch after this list.)
- Fix printable_attribute to handle bfloat16.
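A minimal sketch of the pattern described above (illustrative only, with a hypothetical _DeprecatedMapping wrapper; the actual onnx code may be structured differently):

```python
import warnings

# Private copy used internally; reading it emits no warning.
_STORAGE_TENSOR_TYPE_TO_FIELD = {
    1: "float_data",   # TensorProto.FLOAT (illustrative subset only)
    16: "int32_data",  # TensorProto.BFLOAT16
}


class _DeprecatedMapping(dict):  # hypothetical wrapper, for illustration
    def __getitem__(self, key):
        warnings.warn(
            "STORAGE_TENSOR_TYPE_TO_FIELD is deprecated; use "
            "helper.tensor_dtype_to_field instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return super().__getitem__(key)


# Public name kept for compatibility; item access now warns.
STORAGE_TENSOR_TYPE_TO_FIELD = _DeprecatedMapping(_STORAGE_TENSOR_TYPE_TO_FIELD)


def tensor_dtype_to_field(tensor_dtype: int) -> str:
    # Reads the private dictionary, so no DeprecationWarning is raised.
    return _STORAGE_TENSOR_TYPE_TO_FIELD[tensor_dtype]
```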
Motivation and Context
Recent #4270 refactored existing code, but I forgot to add the case for SequenceProto.MAP in to_list.
cc @gramalingam Please review this PR. Sorry, I should have caught this before my previous PR got merged. Thank you!