This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[OpPerf] Add Indexing ops #16253

Merged: 8 commits, Feb 5, 2020
65 changes: 65 additions & 0 deletions benchmark/opperf/nd_operations/indexing_routines.py
@@ -0,0 +1,65 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
from benchmark.opperf.utils.benchmark_utils import run_op_benchmarks
from benchmark.opperf.utils.op_registry_utils import get_all_indexing_routines

"""Performance benchmark tests for MXNet Indexing routines.

1. slice
2. slice_axis
3. slice_like
4. take
5. pick
6. where
7. ravel_multi_index
8. unravel_index [to do]
9. gather_nd
10. scatter_nd [to do]
11. one_hot
"""


def run_indexing_routines_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='native', warmup=25, runs=100):
"""Runs benchmarks with the given context and precision (dtype) for all the indexing routines
in MXNet.

Parameters
----------
ctx: mx.ctx
Context to run benchmarks
dtype: str, default 'float32'
Precision to use for benchmarks
profiler: str, default 'native'
Type of Profiler to use (native/python)
warmup: int, default 25
Number of times to run for warmup
runs: int, default 100
Number of runs to capture benchmark results

Returns
-------
Dictionary of results. Key -> Name of the operator, Value -> Benchmark results.

"""
# Fetch all indexing routines
mx_indexing_ops = get_all_indexing_routines()

# Run benchmarks
mx_indexing_op_results = run_op_benchmarks(mx_indexing_ops, dtype, ctx, profiler, warmup, runs)
return mx_indexing_op_results
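
For context, a minimal usage sketch of the new entry point (not part of the diff; it assumes the MXNet source root is on `PYTHONPATH` so the `benchmark` package is importable, and the `warmup`/`runs` values here are arbitrary):

```python
# Hypothetical driver script, run from the MXNet source root.
import mxnet as mx

from benchmark.opperf.nd_operations.indexing_routines import run_indexing_routines_benchmarks

# Time each registered indexing routine on CPU; returns
# {operator_name: benchmark results}.
results = run_indexing_routines_benchmarks(ctx=mx.cpu(), dtype='float32',
                                           profiler='native', warmup=10, runs=50)
for op_name, measurements in results.items():
    print(op_name, measurements)
```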
4 changes: 4 additions & 0 deletions benchmark/opperf/opperf.py
@@ -41,6 +41,7 @@
from benchmark.opperf.nd_operations.nn_basic_operators import run_nn_basic_operators_benchmarks
from benchmark.opperf.nd_operations.nn_optimizer_operators import run_optimizer_operators_benchmarks
from benchmark.opperf.nd_operations.array_rearrange import run_rearrange_operators_benchmarks
from benchmark.opperf.nd_operations.indexing_routines import run_indexing_routines_benchmarks
from benchmark.opperf.nd_operations.nn_loss_operators import run_loss_operators_benchmarks

from benchmark.opperf.utils.common_utils import merge_map_list, save_to_file
@@ -84,6 +85,9 @@ def run_all_mxnet_operator_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='n
# Run all Array Rearrange operations benchmarks with default input values
mxnet_operator_benchmark_results.append(run_rearrange_operators_benchmarks(ctx=ctx, dtype=dtype, profiler=profiler))

# Run all Indexing routines benchmarks with default input values
mxnet_operator_benchmark_results.append(run_indexing_routines_benchmarks(ctx=ctx, dtype=dtype, profiler=profiler))

# ************************ MXNET NN OPERATOR BENCHMARKS ****************************

# Run all basic NN operations benchmarks with default input values
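With this registration in place, the indexing routines also run as part of the full operator suite. A sketch under the same path assumptions as above:

```python
# Hypothetical: call the full-suite runner directly rather than via the CLI.
import mxnet as mx

from benchmark.opperf.opperf import run_all_mxnet_operator_benchmarks

# The merged result map now includes the indexing-routine benchmarks.
all_results = run_all_mxnet_operator_benchmarks(ctx=mx.cpu(), dtype='float32',
                                                profiler='native')
print(len(all_results), "operators benchmarked")
```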
99 changes: 63 additions & 36 deletions benchmark/opperf/rules/default_params.py
@@ -79,30 +79,45 @@
DEFAULT_Z = [(1024, 1024), (10000, 1), (10000, 100)]
DEFAULT_G = [(1024, 1024), (10000, 1), (10000, 100)]
DEFAULT_DELTA = [(1024, 1024), (10000, 1), (10000, 100)]
-DEFAULT_LRS = [(0.1,0.1)]
-DEFAULT_LR = [0.1,0.5,0.9]
-DEFAULT_GAMMA_1 = [0.1,0.5,0.9]
-DEFAULT_GAMMA_2 = [0.1,0.5,0.9]
+DEFAULT_LRS = [(0.1, 0.1)]
+DEFAULT_LR = [0.1, 0.5, 0.9]
+DEFAULT_GAMMA_1 = [0.1, 0.5, 0.9]
+DEFAULT_GAMMA_2 = [0.1, 0.5, 0.9]
DEFAULT_EPSILON = [1e-08]
-DEFAULT_BETA_1 = [0.1,0.5,0.9]
-DEFAULT_BETA_2 = [0.1,0.5,0.9]
-DEFAULT_T = [1,5]
+DEFAULT_BETA_1 = [0.1, 0.5, 0.9]
+DEFAULT_BETA_2 = [0.1, 0.5, 0.9]
+DEFAULT_T = [1, 5]
DEFAULT_RESCALE_GRAD = [0.4, 0.77]
-DEFAULT_CLIP_GRADIENT = [-1.0,0.8]
-DEFAULT_CLIP_WEIGHTS = [-1.0,0.8]
-DEFAULT_LAZY_UPDATE = [0,1]
+DEFAULT_CLIP_GRADIENT = [-1.0, 0.8]
+DEFAULT_CLIP_WEIGHTS = [-1.0, 0.8]
+DEFAULT_LAZY_UPDATE = [0, 1]

# For rearrange operators
# NOTE: Data needs to be a 4D tensor for operators like space_to_depth and depth_to_space
# Hence below we append 4d to mark the difference.
# For depth_to_space, dimension 3 needs to be a multiple of 'block' and 1 should be a multiple of `block^2`
-DEFAULT_DATA_4d = [(1, 4, 2, 4), (10,25,10,100)]
+DEFAULT_DATA_4d = [(1, 4, 2, 4), (10, 25, 10, 100)]
DEFAULT_BLOCK_SIZE = [2, 5]

# For swapaxis operator
DEFAULT_DIM_1 = [0]
DEFAULT_DIM_2 = [1]

# For indexing routines
DEFAULT_INDEX = [(1, 1024), (1, 1), (1, 100)]
DEFAULT_INDICES = [(1, 1)]
DEFAULT_BEGIN = [0] # slice_axis expects int, slice can have tuple/int
DEFAULT_END = [1]  # same as above
DEFAULT_SHAPE_LIKE = [(100, 100), (10, 1), (100, 10)]
DEFAULT_X = [(1024, 1024), (10000, 1), (10000, 100)]
DEFAULT_Y = [(1024, 1024), (10000, 1), (10000, 100)]
DEFAULT_COND = [(1024,), (10000,), (10000,)]
DEFAULT_DEPTH = [0]
# For ravel_multi_index op, ndim(shape) = 2; hence data NDArray's first dim = 2
# First dimension of input of ravel operator should match shape parameter dimension
# DEFAULT_SHAPE is reused for ravel_multi_index op
RAVEL_DATA = [(2, 1024)]

# For loss operators
DEFAULT_DATA_3d = [(1024, 100, 100)]
DEFAULT_LABEL = [(100,100)]
@@ -130,35 +145,46 @@
"p_nd": DEFAULT_P_ND,
"axis_shape": DEFAULT_AXIS_SHAPE,
"axis": DEFAULT_AXIS,
"weight" : DEFAULT_WEIGHT,
"weight32" : DEFAULT_WEIGHT,
"grad" : DEFAULT_GRAD,
"mean" : DEFAULT_MEAN,
"var" : DEFAULT_VAR,
"mom" : DEFAULT_MOM,
"n" : DEFAULT_N,
"d" : DEFAULT_D,
"v" : DEFAULT_V,
"z" : DEFAULT_Z,
"g" : DEFAULT_G,
"delta" : DEFAULT_DELTA,
"lr" : DEFAULT_LR,
"lrs" : DEFAULT_LRS,
"wds" : DEFAULT_LRS,
"gamma1" : DEFAULT_GAMMA_1,
"gamma2" : DEFAULT_GAMMA_2,
"epsilon" : DEFAULT_EPSILON,
"beta1" : DEFAULT_BETA_1,
"beta2" : DEFAULT_BETA_2,
"t" : DEFAULT_T,
"rescale_grad" : DEFAULT_RESCALE_GRAD,
"clip_grad" : DEFAULT_CLIP_GRADIENT,
"lazy_update" : DEFAULT_LAZY_UPDATE,
"weight": DEFAULT_WEIGHT,
"weight32": DEFAULT_WEIGHT,
"grad": DEFAULT_GRAD,
"mean": DEFAULT_MEAN,
"var": DEFAULT_VAR,
"mom": DEFAULT_MOM,
"n": DEFAULT_N,
"d": DEFAULT_D,
"v": DEFAULT_V,
"z": DEFAULT_Z,
"g": DEFAULT_G,
"delta": DEFAULT_DELTA,
"lr": DEFAULT_LR,
"lrs": DEFAULT_LRS,
"wds": DEFAULT_LRS,
"gamma1": DEFAULT_GAMMA_1,
"gamma2": DEFAULT_GAMMA_2,
"epsilon": DEFAULT_EPSILON,
"beta1": DEFAULT_BETA_1,
"beta2": DEFAULT_BETA_2,
"t": DEFAULT_T,
"rescale_grad": DEFAULT_RESCALE_GRAD,
"clip_grad": DEFAULT_CLIP_GRADIENT,
"lazy_update": DEFAULT_LAZY_UPDATE,
"data_4d": DEFAULT_DATA_4d,
"dim1": DEFAULT_DIM_1,
"dim2": DEFAULT_DIM_2,
"block_size": DEFAULT_BLOCK_SIZE,
"args": DEFAULT_ARGS,
"a": DEFAULT_DATA,
"index": DEFAULT_INDEX,
"indices": DEFAULT_INDICES,
"begin": DEFAULT_BEGIN,
"end": DEFAULT_END,
"shape_like": DEFAULT_SHAPE_LIKE,
"x": DEFAULT_X,
"y": DEFAULT_Y,
"condition": DEFAULT_COND,
"depth": DEFAULT_DEPTH,
"ravel_data": RAVEL_DATA,
"data_smce": DEFAULT_DATA_SMCE,
"data_3d": DEFAULT_DATA_3d,
"label_smce": DEFAULT_LABEL_SMCE,
@@ -174,4 +200,5 @@
"mu", "sigma", "lam", "alpha", "beta", "gamma", "k", "p",
"low", "high", "weight", "bias", "moving_mean", "moving_var",
"weight", "weight32", "grad", "mean", "var", "mom", "n", "d",
"v", "z", "g", "delta", "args", "label"]
"v", "z", "g", "delta", "args", "indices", "shape_like", "y",
"x", "condition", "a", "index", "raveL_data", "label"]
11 changes: 5 additions & 6 deletions benchmark/opperf/utils/benchmark_utils.py
@@ -24,10 +24,9 @@
from .common_utils import merge_map_list
from .op_registry_utils import prepare_op_inputs
from benchmark.opperf.rules.default_params import PARAMS_OF_TYPE_NDARRAY
-from .profiler_utils import cpp_profile,python_profile
+from .profiler_utils import cpp_profile, python_profile


-no_backward = ['softmax_cross_entropy']
+no_backward = ['gather_nd', 'softmax_cross_entropy']

def _prepare_op_inputs(inputs, run_backward, dtype, ctx):
mx.random.seed(41)
@@ -70,7 +69,7 @@ def _run_nd_operator_performance_test(op, inputs, run_backward, warmup, runs, ar
# Warm up, ignore the profiler output
if not args_list:
_, _ = benchmark_helper_func(op, warmup, [], **kwargs_list[0])
else:
_, _ = benchmark_helper_func(op, warmup, args_list[0], **kwargs_list[0])

# Run Benchmarks
@@ -84,7 +83,7 @@ def _run_nd_operator_performance_test(op, inputs, run_backward, warmup, runs, ar
profiler_output["inputs"] = inputs[idx]
op_benchmark_result[op.__name__].append(profiler_output)
else:
-for idx, (args,kwargs) in enumerate(zip(args_list,kwargs_list)):
+for idx, (args, kwargs) in enumerate(zip(args_list, kwargs_list)):
_, profiler_output = benchmark_helper_func(op, runs, args, **kwargs)

# Add inputs used for profiling this operator into result
Expand Down Expand Up @@ -153,7 +152,7 @@ def run_op_benchmarks(ops, dtype, ctx, profiler, warmup, runs):

# setting backward false for ops with known issue
if op in no_backward:
op_params["has_backward"] = False

# Run benchmarks
cur_op_res = run_performance_test(op_params["nd_op_handle"],
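On the `no_backward` change above: `gather_nd`'s forward pass is well-defined, and it is presumably the backward pass (a scatter-style accumulation, related to the `scatter_nd` issue referenced elsewhere in this PR) that misbehaves, so the op is benchmarked forward-only. A quick forward sketch:

```python
import mxnet as mx

data = mx.nd.array([[0, 1],
                    [2, 3]])
indices = mx.nd.array([[1, 1, 0],
                       [0, 1, 0]])
# output[i] = data[indices[0, i], indices[1, i]]
print(mx.nd.gather_nd(data, indices).asnumpy())  # [2. 3. 0.]
```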
4 changes: 2 additions & 2 deletions benchmark/opperf/utils/ndarray_utils.py
@@ -44,7 +44,7 @@ def nd_forward_backward_and_profile(op, runs, *args, **kwargs):
"""
for _ in range(runs):
with mx.autograd.record():
-if not isinstance(args[0],nd.NDArray):
+if not isinstance(args[0], nd.NDArray):
res = op(**kwargs)
else:
res = op(*args, **kwargs)
@@ -75,7 +75,7 @@ def nd_forward_and_profile(op, runs, *args, **kwargs):
any results from NDArray operation execution
"""
for _ in range(runs):
-if not isinstance(args[0],nd.NDArray):
+if not isinstance(args[0], nd.NDArray):
res = op(**kwargs)
else:
res = op(*args, **kwargs)
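The `isinstance(args[0], nd.NDArray)` guard matters for ops whose generated inputs arrive purely as keyword arguments. A stripped-down, hypothetical version of the dispatch (`call_op` is not a real helper in this module):

```python
import mxnet.ndarray as nd

def call_op(op, *args, **kwargs):
    # Simplified from nd_forward_and_profile: fall back to keyword-only
    # invocation when no leading positional NDArray is present.
    if args and isinstance(args[0], nd.NDArray):
        return op(*args, **kwargs)
    return op(**kwargs)

out = call_op(nd.ones, shape=(2, 3))  # exercises the keyword-only path
```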
46 changes: 37 additions & 9 deletions benchmark/opperf/utils/op_registry_utils.py
@@ -58,7 +58,7 @@ def _select_ops(operator_names, filters=("_contrib", "_"), merge_op_forward_back
# Filter out deprecated operators
filters += ("normal", "uniform", "BatchNorm_v1", "Flatten", "contrib_CTCLoss", "Pad", "Cast",
"Pooling_v1", "Concat", "Reshape", "Convolution_v1", "SliceChannel", "Crop",
"crop", "onehot_encode")
"crop", "onehot_encode", "batch_take")

if merge_op_forward_backward:
filters += ("_backward",)
@@ -117,7 +117,7 @@ def prepare_op_inputs(op, arg_params):
inputs = []

# 4d tensor is needed only by following two ops
-ops_4d = ['depth_to_space','space_to_depth']
+ops_4d = ['depth_to_space', 'space_to_depth']

# 3d tensor is needed by following ops
ops_3d = ['CTCLoss', 'ctc_loss']
@@ -126,7 +126,9 @@ def prepare_op_inputs(op, arg_params):
arg_values = {}
for arg_name, arg_type in zip(arg_params["params"]["arg_names"],
arg_params["params"]["arg_types"]):
if "NDArray" in arg_type and arg_name + "_nd" in DEFAULTS_INPUTS:
if "NDArray" in arg_type and op == "ravel_multi_index":
arg_values[arg_name] = DEFAULTS_INPUTS["ravel_data"]
elif "NDArray" in arg_type and arg_name + "_nd" in DEFAULTS_INPUTS:
arg_values[arg_name] = DEFAULTS_INPUTS[arg_name + "_nd"]
elif "NDArray" in arg_type and op in ops_4d and arg_name + "_4d" in DEFAULTS_INPUTS:
arg_values[arg_name] = DEFAULTS_INPUTS[arg_name + "_4d"]
@@ -235,7 +237,7 @@ def get_all_random_sampling_operators():

# Filter for Random Sampling operators
random_sampling_mx_operators = {}
-for op_name, op_params in mx_operators.items():
+for op_name, _ in mx_operators.items():
if op_name.startswith(("random_", "sample_")) and op_name not in unique_ops:
random_sampling_mx_operators[op_name] = mx_operators[op_name]
return random_sampling_mx_operators
@@ -277,9 +279,9 @@ def get_all_optimizer_operators():

# Filter for Optimizer operators
optimizer_mx_operators = {}
-for op_name, op_params in mx_operators.items():
-    if op_name in optimizer_ops and op_name not in unique_ops:
-        optimizer_mx_operators[op_name] = mx_operators[op_name]
+for op_name, _ in mx_operators.items():
+    if op_name in optimizer_ops and op_name not in unique_ops:
+        optimizer_mx_operators[op_name] = mx_operators[op_name]
return optimizer_mx_operators

def get_all_sorting_searching_operators():
@@ -296,7 +298,7 @@ def get_all_sorting_searching_operators():

# Filter for Sort and search operators
sort_search_mx_operators = {}
-for op_name, op_params in mx_operators.items():
+for op_name, _ in mx_operators.items():
for op_name, _ in mx_operators.items():
if op_name in sort_search_ops and op_name not in unique_ops:
sort_search_mx_operators[op_name] = mx_operators[op_name]
return sort_search_mx_operators
@@ -316,11 +318,37 @@ def get_all_rearrange_operators():

# Filter for Array Rearrange operators
rearrange_mx_operators = {}
-for op_name, op_params in mx_operators.items():
+for op_name, _ in mx_operators.items():
if op_name in rearrange_ops and op_name not in unique_ops:
rearrange_mx_operators[op_name] = mx_operators[op_name]
return rearrange_mx_operators


def get_all_indexing_routines():
"""Gets all indexing routines registered with MXNet.

Returns
-------
{"operator_name": {"has_backward", "nd_op_handle", "params"}}
"""
# @ChaiBapchya unravel_index errors out on certain inputs
# tracked here https://github.com/apache/incubator-mxnet/issues/16771
# @ChaiBapchya scatter_nd errors with core dump
# tracked here https://github.com/apache/incubator-mxnet/issues/17480
indexing_routines = ['slice', 'slice_axis', 'slice_like', 'take', 'one_hot',
'where', 'ravel_multi_index', 'gather_nd', 'pick']

# Get all mxnet operators
mx_operators = _get_all_mxnet_operators()

# Filter for Indexing routines
indexing_mx_routines = {}
for op_name, _ in mx_operators.items():
if op_name in indexing_routines and op_name not in unique_ops:
indexing_mx_routines[op_name] = mx_operators[op_name]
return indexing_mx_routines


def get_all_loss_operators():
"""Gets all Neural Network loss operators registered with MXNet.

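For readers skimming the `prepare_op_inputs` change above: the new branch gives `ravel_multi_index` an op-specific NDArray default before the generic `<arg_name>_nd` lookup runs. A condensed, hypothetical version of that lookup order (`default_for` and the two-entry `DEFAULTS_INPUTS` are illustrative only):

```python
# Assumed simplification, not the verbatim implementation.
DEFAULTS_INPUTS = {"ravel_data": [(2, 1024)], "data_nd": [(1024, 1024)]}

def default_for(op, arg_name, arg_type):
    if "NDArray" in arg_type and op == "ravel_multi_index":
        return DEFAULTS_INPUTS["ravel_data"]      # op-specific override first
    if "NDArray" in arg_type and arg_name + "_nd" in DEFAULTS_INPUTS:
        return DEFAULTS_INPUTS[arg_name + "_nd"]  # generic NDArray default next
    return None

assert default_for("ravel_multi_index", "data", "NDArray-or-Symbol") == [(2, 1024)]
```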
6 changes: 1 addition & 5 deletions src/operator/tensor/broadcast_reduce_op_index.cc
@@ -135,10 +135,6 @@ Examples::
// picks elements with specified indices along axis 1
pick(x, y=[0,1,0], 1) = [ 1., 4., 5.]

-y = [[ 1.],
-[ 0.],
-[ 2.]]

// picks elements with specified indices along axis 1 using 'wrap' mode
// to place indices that would normally be out of bounds
pick(x, y=[2,-1,-2], 1, mode='wrap') = [ 1., 4., 5.]
@@ -148,7 +144,7 @@ Examples::
[ 2.]]

// picks elements with specified indices along axis 1 and dims are maintained
-pick(x,y, 1, keepdims=True) = [[ 2.],
+pick(x, y, 1, keepdims=True) = [[ 2.],
[ 3.],
[ 6.]]

Expand Down