From c09944caa5135b813e95923f790c5da768a73574 Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi <86472128+ashutosh-arm@users.noreply.github.com> Date: Sun, 25 Dec 2022 16:15:35 +0000 Subject: [PATCH] [CMSIS-NN] Global function that provides range based on dtype (#13652) Range for dtype was being sought differently from inside various CMSIS-NN tests. This commit has created a common global function inside aot.py under tvm.testing that can provide (min, max) values based on the dtype. In future, other AOT based targets can make use of this function to obtain the range. --- python/tvm/testing/aot.py | 31 ++++++++++++++++- .../contrib/test_cmsisnn/test_binary_ops.py | 14 ++++---- .../contrib/test_cmsisnn/test_conv2d.py | 21 ++++++------ .../test_cmsisnn/test_fully_connected.py | 10 +++--- .../contrib/test_cmsisnn/test_fuse_pads.py | 22 +++++++----- .../test_cmsisnn/test_generate_constants.py | 7 ++-- .../test_cmsisnn/test_invalid_graphs.py | 5 ++- .../contrib/test_cmsisnn/test_networks.py | 6 ++-- .../contrib/test_cmsisnn/test_pooling.py | 4 +-- .../test_cmsisnn/test_remove_reshapes.py | 4 +-- .../contrib/test_cmsisnn/test_softmax.py | 5 ++- tests/python/contrib/test_cmsisnn/utils.py | 34 +++---------------- .../aot/test_crt_forward_declarations.py | 24 ------------- 13 files changed, 85 insertions(+), 102 deletions(-) diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py index 563a7dff4a50..30d3c78ae43b 100644 --- a/python/tvm/testing/aot.py +++ b/python/tvm/testing/aot.py @@ -24,7 +24,7 @@ import subprocess import tarfile import logging -from typing import Any, NamedTuple, Union, Optional, List, Dict +from typing import Any, NamedTuple, Union, Tuple, Optional, List, Dict import numpy as np import tvm @@ -901,6 +901,35 @@ def compile_and_run( ) +def get_dtype_range(dtype: str) -> Tuple[int, int]: + """ + Produces the min,max for a give data type. + + Parameters + ---------- + dtype : str + a type string (e.g., int8, float64) + + Returns + ------- + type_info.min : int + the minimum of the range + type_info.max : int + the maximum of the range + """ + type_info = None + np_dtype = np.dtype(dtype) + kind = np_dtype.kind + + if kind == "f": + type_info = np.finfo(np_dtype) + elif kind in ["i", "u"]: + type_info = np.iinfo(np_dtype) + else: + raise TypeError(f"dtype ({dtype}) must indicate some floating-point or integral data type.") + return type_info.min, type_info.max + + def generate_ref_data(mod, input_data, params=None, target="llvm"): """Generate reference data through executing the relay module""" with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): diff --git a/tests/python/contrib/test_cmsisnn/test_binary_ops.py b/tests/python/contrib/test_cmsisnn/test_binary_ops.py index 29335072bf06..663a1bd45d5c 100644 --- a/tests/python/contrib/test_cmsisnn/test_binary_ops.py +++ b/tests/python/contrib/test_cmsisnn/test_binary_ops.py @@ -25,7 +25,7 @@ import tvm from tvm import relay from tvm.relay.op.contrib import cmsisnn -from tvm.testing.aot import generate_ref_data, AOTTestModel, compile_and_run +from tvm.testing.aot import get_dtype_range, generate_ref_data, AOTTestModel, compile_and_run from tvm.micro.testing.aot_test_utils import ( AOT_USMP_CORSTONE300_RUNNER, ) @@ -34,7 +34,6 @@ skip_if_no_reference_system, make_module, make_qnn_relu, - get_range_for_dtype_str, assert_partitioned_function, assert_no_external_function, create_test_runner, @@ -45,9 +44,8 @@ def generate_tensor_constant(): rng = np.random.default_rng(12321) dtype = "int8" shape = (1, 16, 16, 3) - values = tvm.nd.array( - rng.integers(np.iinfo(dtype).min, high=np.iinfo(dtype).max, size=shape, dtype=dtype) - ) + in_min, in_max = get_dtype_range(dtype) + values = tvm.nd.array(rng.integers(in_min, high=in_max, size=shape, dtype=dtype)) return relay.const(values, dtype) @@ -136,7 +134,7 @@ def test_op_int8( assert_partitioned_function(orig_mod, cmsisnn_mod) # validate the output - in_min, in_max = get_range_for_dtype_str(dtype) + in_min, in_max = get_dtype_range(dtype) inputs = { "input_0": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype), "input_1": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype), @@ -196,7 +194,7 @@ def test_same_input_to_binary_op(op, relu_type): ), "Composite function for the binary op should have only 1 parameter." # validate the output - in_min, in_max = get_range_for_dtype_str(dtype) + in_min, in_max = get_dtype_range(dtype) inputs = { "input": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype), } @@ -275,7 +273,7 @@ def test_constant_input_int8(op, input_0, input_1): assert_partitioned_function(orig_mod, cmsisnn_mod) # validate the output - in_min, in_max = get_range_for_dtype_str(dtype) + in_min, in_max = get_dtype_range(dtype) inputs = {} if isinstance(input_0, tvm.relay.expr.Var): inputs.update({"input_0": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype)}) diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index 66ff5d793880..20e7b9ed2f62 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -24,6 +24,7 @@ from tvm.relay.op.contrib import cmsisnn from tvm.testing.aot import ( + get_dtype_range, generate_ref_data, AOTTestModel, compile_models, @@ -33,7 +34,6 @@ from tvm.micro.testing.aot_test_utils import AOT_USMP_CORSTONE300_RUNNER from .utils import ( make_module, - get_range_for_dtype_str, get_same_padding, get_conv2d_qnn_params, get_kernel_bias_dtype, @@ -82,10 +82,11 @@ def make_model( p = get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides) rng = np.random.default_rng(12321) + kmin, kmax = get_dtype_range(kernel_dtype) kernel = tvm.nd.array( rng.integers( - np.iinfo(kernel_dtype).min, - high=np.iinfo(kernel_dtype).max, + kmin, + high=kmax, size=kernel_shape, dtype=kernel_dtype, ) @@ -157,7 +158,7 @@ def test_conv2d_number_primfunc_args( kernel_w = kernel_size[1] kernel_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels) kernel_zero_point = 0 - in_min, in_max = get_range_for_dtype_str(dtype) + in_min, in_max = get_dtype_range(dtype) relu_type = "RELU" kernel_dtype, bias_dtype = get_kernel_bias_dtype(dtype) @@ -264,7 +265,7 @@ def test_conv2d_symmetric_padding( kernel_w = kernel_size[1] kernel_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels) kernel_zero_point = 0 - in_min, in_max = get_range_for_dtype_str(dtype) + in_min, in_max = get_dtype_range(dtype) kernel_dtype, bias_dtype = get_kernel_bias_dtype(dtype) @@ -358,7 +359,7 @@ def test_conv2d_asymmetric_padding( kernel_w = kernel_size[1] kernel_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels) kernel_zero_point = 0 - in_min, in_max = get_range_for_dtype_str(dtype) + in_min, in_max = get_dtype_range(dtype) kernel_dtype, bias_dtype = get_kernel_bias_dtype(dtype) @@ -454,7 +455,7 @@ def test_pad_conv2d_fusion_int8( kernel_w = kernel_size[1] kernel_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels) kernel_zero_point = 0 - in_min, in_max = get_range_for_dtype_str(dtype) + in_min, in_max = get_dtype_range(dtype) kernel_dtype, bias_dtype = get_kernel_bias_dtype(dtype) output_scale, output_zero_point = get_conv2d_qnn_params( @@ -567,7 +568,7 @@ def test_invalid_pad_conv2d_fusion_int8( kernel_w = kernel_size[1] kernel_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels) kernel_zero_point = 0 - in_min, in_max = get_range_for_dtype_str(dtype) + in_min, in_max = get_dtype_range(dtype) kernel_dtype, bias_dtype = get_kernel_bias_dtype(dtype) @@ -740,7 +741,7 @@ def test_depthwise( kernel_w = kernel_size[1] kernel_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels) kernel_zero_point = 0 - in_min, in_max = get_range_for_dtype_str(dtype) + in_min, in_max = get_dtype_range(dtype) groups = ifm_shape[3] kernel_layout = "HWOI" @@ -844,7 +845,7 @@ def test_relay_conv2d_cmsisnn_depthwise_int8( test_runner = AOT_USMP_CORSTONE300_RUNNER dtype = "int8" - in_min, in_max = get_range_for_dtype_str(dtype) + in_min, in_max = get_dtype_range(dtype) ifm_shape = (1, 24, 24, 1) groups = ifm_shape[3] diff --git a/tests/python/contrib/test_cmsisnn/test_fully_connected.py b/tests/python/contrib/test_cmsisnn/test_fully_connected.py index 3b220eb42c9b..46b1488eb3fe 100644 --- a/tests/python/contrib/test_cmsisnn/test_fully_connected.py +++ b/tests/python/contrib/test_cmsisnn/test_fully_connected.py @@ -23,10 +23,9 @@ from tvm import relay from tvm.relay.op.contrib import cmsisnn -from tvm.testing.aot import generate_ref_data, AOTTestModel, compile_and_run +from tvm.testing.aot import get_dtype_range, generate_ref_data, AOTTestModel, compile_and_run from .utils import ( make_module, - get_range_for_dtype_str, get_conv2d_qnn_params, make_qnn_relu, assert_partitioned_function, @@ -55,10 +54,11 @@ def make_model( """Return a model and any parameters it may have""" input_ = relay.var("input", shape=in_shape, dtype=dtype) rng = np.random.default_rng(12321) + kmin, kmax = get_dtype_range(kernel_dtype) weight = tvm.nd.array( rng.integers( - np.iinfo(kernel_dtype).min, - high=np.iinfo(kernel_dtype).max, + kmin, + high=kmax, size=kernel_shape, dtype=kernel_dtype, ) @@ -123,7 +123,7 @@ def test_ops( kernel_zero_point = 0 kernel_shape = [out_channels, in_shape[1]] conv2d_kernel_shape = (1, 1, kernel_shape[0], kernel_shape[1]) - in_min, in_max = get_range_for_dtype_str(dtype) + in_min, in_max = get_dtype_range(dtype) output_scale, output_zero_point = get_conv2d_qnn_params( conv2d_kernel_shape, diff --git a/tests/python/contrib/test_cmsisnn/test_fuse_pads.py b/tests/python/contrib/test_cmsisnn/test_fuse_pads.py index f57dc5cd5bab..4ea306cc4382 100644 --- a/tests/python/contrib/test_cmsisnn/test_fuse_pads.py +++ b/tests/python/contrib/test_cmsisnn/test_fuse_pads.py @@ -19,7 +19,7 @@ import numpy as np import pytest import tvm -import tvm.testing +from tvm.testing.aot import get_dtype_range from tvm import relay from .utils import CheckForPadsWithinCompositeFunc @@ -59,10 +59,11 @@ def test_invalid_padding_for_fusion(ifm_shape, pad_width, conv2d_padding, ofm_sh pad_mode="constant", ) rng = np.random.default_rng(12321) + in_min, in_max = get_dtype_range(dtype) local_weight = tvm.nd.array( rng.integers( - np.iinfo(dtype).min, - high=np.iinfo(dtype).max, + in_min, + high=in_max, size=(ofm_channels, kernel_size[0], kernel_size[1], ifm_shape[3]), dtype=dtype, ) @@ -139,10 +140,11 @@ def test_pad_conv2d_fusion_noncmsisnn_target(ifm_shape, pad_width, conv2d_paddin pad_mode="constant", ) rng = np.random.default_rng(12321) + in_min, in_max = get_dtype_range(dtype) local_weight = tvm.nd.array( rng.integers( - np.iinfo(dtype).min, - high=np.iinfo(dtype).max, + in_min, + high=in_max, size=(ofm_channels, kernel_size[0], kernel_size[1], ifm_shape[3]), dtype=dtype, ) @@ -217,10 +219,11 @@ def test_pad_conv2d_fusion(ifm_shape, pad_width, conv2d_padding, ofm_shape): pad_mode="constant", ) rng = np.random.default_rng(12321) + kmin, kmax = get_dtype_range(dtype) local_weight = tvm.nd.array( rng.integers( - np.iinfo(dtype).min, - high=np.iinfo(dtype).max, + kmin, + high=kmax, size=(ofm_channels, kernel_size[0], kernel_size[1], ifm_shape[3]), dtype=dtype, ) @@ -281,10 +284,11 @@ def test_without_preceding_pad(): ofm_shape = (1, 56, 56, 64) local_input = relay.var("local_input", shape=ifm_shape, dtype=dtype) rng = np.random.default_rng(12321) + kmin, kmax = get_dtype_range(dtype) local_weight = tvm.nd.array( rng.integers( - np.iinfo(dtype).min, - high=np.iinfo(dtype).max, + kmin, + high=kmax, size=(64, 3, 3, 64), dtype=dtype, ) diff --git a/tests/python/contrib/test_cmsisnn/test_generate_constants.py b/tests/python/contrib/test_cmsisnn/test_generate_constants.py index 86737370bc5d..b83884128441 100644 --- a/tests/python/contrib/test_cmsisnn/test_generate_constants.py +++ b/tests/python/contrib/test_cmsisnn/test_generate_constants.py @@ -20,7 +20,7 @@ import numpy as np import pytest import tvm -import tvm.testing +from tvm.testing.aot import get_dtype_range from tvm import relay from tvm.relay.op.contrib import cmsisnn @@ -107,10 +107,11 @@ def make_model( weight_shape = (kernel_h, kernel_w, shape[3] // groups, out_channels) rng = np.random.default_rng(12321) + kmin, kmax = get_dtype_range(kernel_dtype) weight = tvm.nd.array( rng.integers( - np.iinfo(kernel_dtype).min, - high=np.iinfo(kernel_dtype).max, + kmin, + high=kmax, size=weight_shape, dtype=kernel_dtype, ) diff --git a/tests/python/contrib/test_cmsisnn/test_invalid_graphs.py b/tests/python/contrib/test_cmsisnn/test_invalid_graphs.py index c66f9d0e0726..ace1db7811da 100644 --- a/tests/python/contrib/test_cmsisnn/test_invalid_graphs.py +++ b/tests/python/contrib/test_cmsisnn/test_invalid_graphs.py @@ -19,13 +19,12 @@ import numpy as np import tvm -from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data +from tvm.testing.aot import AOTTestModel, get_dtype_range, compile_and_run, generate_ref_data from tvm.micro.testing.aot_test_utils import ( AOT_USMP_CORSTONE300_RUNNER, ) from .utils import ( skip_if_no_reference_system, - get_range_for_dtype_str, ) @@ -58,7 +57,7 @@ def @main(%data : Tensor[(16, 29), int8]) -> Tensor[(16, 29), int8] { use_unpacked_api = True test_runner = AOT_USMP_CORSTONE300_RUNNER dtype = "int8" - in_min, in_max = get_range_for_dtype_str(dtype) + in_min, in_max = get_dtype_range(dtype) rng = np.random.default_rng(12345) inputs = {"data": rng.integers(in_min, high=in_max, size=(16, 29), dtype=dtype)} outputs = generate_ref_data(orig_mod["main"], inputs, params) diff --git a/tests/python/contrib/test_cmsisnn/test_networks.py b/tests/python/contrib/test_cmsisnn/test_networks.py index 6f9f3743a622..9f64be246182 100644 --- a/tests/python/contrib/test_cmsisnn/test_networks.py +++ b/tests/python/contrib/test_cmsisnn/test_networks.py @@ -24,12 +24,12 @@ from tvm import relay from tvm.contrib.download import download_testdata from tvm.relay.op.contrib import cmsisnn -from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data +from tvm.testing.aot import AOTTestModel, get_dtype_range, compile_and_run, generate_ref_data from tvm.micro.testing.aot_test_utils import ( AOT_CORSTONE300_RUNNER, AOT_USMP_CORSTONE300_RUNNER, ) -from .utils import skip_if_no_reference_system, get_range_for_dtype_str +from .utils import skip_if_no_reference_system # pylint: disable=import-outside-toplevel def _convert_to_relay( @@ -93,7 +93,7 @@ def test_cnn_small(test_runner): input_shape = (1, 490) dtype = "int8" - in_min, in_max = get_range_for_dtype_str(dtype) + in_min, in_max = get_dtype_range(dtype) rng = np.random.default_rng(12345) input_data = rng.integers(in_min, high=in_max, size=input_shape, dtype=dtype) diff --git a/tests/python/contrib/test_cmsisnn/test_pooling.py b/tests/python/contrib/test_cmsisnn/test_pooling.py index 7657e0e63220..c6e5f02e712a 100644 --- a/tests/python/contrib/test_cmsisnn/test_pooling.py +++ b/tests/python/contrib/test_cmsisnn/test_pooling.py @@ -23,6 +23,7 @@ from tvm.relay.op.contrib import cmsisnn from tvm.testing.aot import ( + get_dtype_range, generate_ref_data, AOTTestModel, compile_and_run, @@ -30,7 +31,6 @@ from tvm.micro.testing.aot_test_utils import AOT_USMP_CORSTONE300_RUNNER from .utils import ( make_module, - get_range_for_dtype_str, get_same_padding, make_qnn_relu, assert_partitioned_function, @@ -128,7 +128,7 @@ def test_ops( assert_partitioned_function(orig_mod, cmsisnn_mod) # validate the output - in_min, in_max = get_range_for_dtype_str(dtype) + in_min, in_max = get_dtype_range(dtype) np.random.seed(0) inputs = { "input": np.random.randint(in_min, high=in_max, size=in_shape, dtype=dtype), diff --git a/tests/python/contrib/test_cmsisnn/test_remove_reshapes.py b/tests/python/contrib/test_cmsisnn/test_remove_reshapes.py index 8b33a8a90b76..3cd60341ebfe 100644 --- a/tests/python/contrib/test_cmsisnn/test_remove_reshapes.py +++ b/tests/python/contrib/test_cmsisnn/test_remove_reshapes.py @@ -23,6 +23,7 @@ from tvm.relay.op.contrib import cmsisnn from tvm.testing.aot import ( + get_dtype_range, generate_ref_data, AOTTestModel, compile_models, @@ -31,7 +32,6 @@ from tvm.micro.testing.aot_test_utils import AOT_USMP_CORSTONE300_RUNNER from .utils import ( make_module, - get_range_for_dtype_str, get_same_padding, make_qnn_relu, assert_partitioned_function, @@ -126,7 +126,7 @@ def test_reshape_removal(padding): # generate reference output rng = np.random.default_rng(12345) - in_min, in_max = get_range_for_dtype_str("int8") + in_min, in_max = get_dtype_range("int8") inputs = {"input": rng.integers(in_min, high=in_max, size=in_shape, dtype="int8")} output_list = generate_ref_data(orig_mod["main"], inputs, params=None) diff --git a/tests/python/contrib/test_cmsisnn/test_softmax.py b/tests/python/contrib/test_cmsisnn/test_softmax.py index d048723529e0..0316d567adf4 100644 --- a/tests/python/contrib/test_cmsisnn/test_softmax.py +++ b/tests/python/contrib/test_cmsisnn/test_softmax.py @@ -24,12 +24,11 @@ import tvm.testing from tvm import relay from tvm.relay.op.contrib import cmsisnn -from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data +from tvm.testing.aot import get_dtype_range, AOTTestModel, compile_and_run, generate_ref_data from .utils import ( skip_if_no_reference_system, make_module, - get_range_for_dtype_str, assert_partitioned_function, assert_no_external_function, create_test_runner, @@ -78,7 +77,7 @@ def test_op_int8(zero_point, scale, compiler_cpu, cpu_flags): assert_partitioned_function(orig_mod, cmsisnn_mod) # validate the output - in_min, in_max = get_range_for_dtype_str(dtype) + in_min, in_max = get_dtype_range(dtype) np.random.seed(0) input_data = np.random.randint(in_min, high=in_max, size=shape, dtype=dtype) inputs = {"in0": input_data} diff --git a/tests/python/contrib/test_cmsisnn/utils.py b/tests/python/contrib/test_cmsisnn/utils.py index f3a6b0c1343b..1ec3e609f1a3 100644 --- a/tests/python/contrib/test_cmsisnn/utils.py +++ b/tests/python/contrib/test_cmsisnn/utils.py @@ -23,7 +23,7 @@ import tvm from tvm import relay -from tvm.testing.aot import AOTTestRunner +from tvm.testing.aot import AOTTestRunner, get_dtype_range def skip_if_no_reference_system(func): @@ -86,30 +86,6 @@ def assert_no_external_function(mod): assert not any(attrs), "No function should have an external attribute." -def get_range_for_dtype_str(dtype): - """ - Produces the min,max for a give data type. - - Parameters - ---------- - dtype : str - a type string (e.g., int8) - - Returns - ------- - type_info.min : int - the minimum of the range - type_info.max : int - the maximum of the range - """ - - try: - type_info = np.iinfo(dtype) - except ValueError: - type_info = np.finfo(dtype) - return type_info.min, type_info.max - - def make_module(func): """Creates IRModule from Function""" func = relay.Function(relay.analysis.free_vars(func), func) @@ -193,11 +169,11 @@ def get_conv2d_qnn_params( output_zp : int zero point of the output tensor """ - input_dtype_min, input_dtype_max = get_range_for_dtype_str(input_dtype) + input_dtype_min, input_dtype_max = get_dtype_range(input_dtype) input_max = input_scale * (input_dtype_max - input_zp) input_min = input_scale * (input_dtype_min - input_zp) - kernel_dtype_min, kernel_dtype_max = get_range_for_dtype_str(kernel_dtype) + kernel_dtype_min, kernel_dtype_max = get_dtype_range(kernel_dtype) kernel_sc_max = np.max(kernel_scale) kernel_max = kernel_sc_max * (kernel_dtype_max - kernel_zp) @@ -222,7 +198,7 @@ def get_conv2d_qnn_params( output_max = max(output_limits) output_min = min(output_limits) - output_dtype_min, output_dtype_max = get_range_for_dtype_str(output_dtype) + output_dtype_min, output_dtype_max = get_dtype_range(output_dtype) output_scale = (output_max - output_min) / (output_dtype_max - output_dtype_min) output_zp = int(output_dtype_min - (output_min / output_scale)) @@ -236,7 +212,7 @@ def make_qnn_relu(expr, fused_activation_fn, scale, zero_point, dtype): # Get min/max of the output dtype. This will be used to ensure that clip a_min/a_max are not # beyond the dtype range. - qmin, qmax = get_range_for_dtype_str(dtype) + qmin, qmax = get_dtype_range(dtype) # The input expr is a quantized tensor with its scale and zero point. We calculate the # suitable clip off points based on these scale and zero point. diff --git a/tests/python/relay/aot/test_crt_forward_declarations.py b/tests/python/relay/aot/test_crt_forward_declarations.py index 7454f85ed153..e54846f3aaca 100644 --- a/tests/python/relay/aot/test_crt_forward_declarations.py +++ b/tests/python/relay/aot/test_crt_forward_declarations.py @@ -34,30 +34,6 @@ ) -def get_range_for_dtype_str(dtype): - """ - Produces the min,max for a give data type. - - Parameters - ---------- - dtype : str - a type string (e.g., int8) - - Returns - ------- - type_info.min : int - the minimum of the range - type_info.max : int - the maximum of the range - """ - - try: - type_info = np.iinfo(dtype) - except ValueError: - type_info = np.finfo(dtype) - return type_info.min, type_info.max - - def _change_ndarray_layout(arr, src_layout, dst_layout): """Makes a copy of an ndarray, reshaping it to a new data layout.