diff --git a/3rdparty/mkldnn b/3rdparty/mkldnn
index 52c3052df8ec..cb2cc7ac17ff 160000
--- a/3rdparty/mkldnn
+++ b/3rdparty/mkldnn
@@ -1 +1 @@
-Subproject commit 52c3052df8ec1d5b8b45cb6c350a952840eabd42
+Subproject commit cb2cc7ac17ff4e2ef50805c7048d33256d82be4d
diff --git a/cd/python/pypi/Jenkins_pipeline.groovy b/cd/python/pypi/Jenkins_pipeline.groovy
index e9f172a570fe..fa9300db3ca0 100644
--- a/cd/python/pypi/Jenkins_pipeline.groovy
+++ b/cd/python/pypi/Jenkins_pipeline.groovy
@@ -27,7 +27,7 @@
// This is a temporary solution until we are confident with the packages generated by CI
// This should be removed in the not too distant future.
// We only skip the publish step so we can still QA the other variants.
-pypi_releases = ["cu92", "cu92mkl"]
+pypi_releases = []
def get_pipeline(mxnet_variant) {
def node_type = mxnet_variant.startsWith('cu') ? NODE_LINUX_GPU : NODE_LINUX_CPU
@@ -72,6 +72,7 @@ def push(mxnet_variant) {
} else {
echo "Temporarily skipping publishing PyPI package for '${mxnet_variant}'."
}
+ sh "./ci/docker/runtime_functions.sh cd_s3_publish"
}
}
diff --git a/cd/python/pypi/pypi_publish.py b/cd/python/pypi/pypi_publish.py
index 7e09f644c734..2729068dd503 100755
--- a/cd/python/pypi/pypi_publish.py
+++ b/cd/python/pypi/pypi_publish.py
@@ -35,10 +35,8 @@ def post_wheel(path):
logging.info('Posting {} to PyPI'.format(path))
pypi_credentials = get_secret()
- cmd = 'python3 -m twine upload --username {} --password {} {}'.format(
- pypi_credentials['username'],
- pypi_credentials['password'],
- path)
+ cmd = 'python3 -m twine upload {}'.format(path)
+ version = os.path.basename(path).split('-')[1]
# The PyPI credentials for DEV has username set to 'skipPublish'
# This way we do not attempt to publish the PyPI package
@@ -47,14 +45,15 @@ def post_wheel(path):
print('In DEV account, skipping publish')
print('Would have run: {}'.format(cmd))
return 0
- else:
+ elif any(test_version_mark in version for test_version_mark in ['a', 'b', 'dev']):
print('Skipping publishing nightly builds to Pypi.')
print('See https://github.com/pypa/pypi-support/issues/50 for details')
return 0
-
- # DO NOT PRINT CMD IN THIS BLOCK, includes password
- p = subprocess.run(cmd.split(' '),
- stdout=subprocess.PIPE)
+ else:
+ env = os.environ.copy()
+ env['TWINE_USERNAME'] = pypi_credentials['username']
+ env['TWINE_PASSWORD'] = pypi_credentials['password']
+ p = subprocess.run(cmd.split(' '), stdout=subprocess.PIPE, env=env)
logging.info(p.stdout)
return p.returncode
@@ -85,7 +84,7 @@ def get_secret():
raise e
else:
return json.loads(get_secret_value_response['SecretString'])
-
-
+
+
if __name__ == '__main__':
sys.exit(post_wheel(sys.argv[1]))
diff --git a/ci/docker/install/requirements b/ci/docker/install/requirements
index cbfc521e2c08..fd716f5fa815 100644
--- a/ci/docker/install/requirements
+++ b/ci/docker/install/requirements
@@ -26,8 +26,8 @@ h5py==2.8.0rc1
mock==2.0.0
nose==1.3.7
nose-timer==0.7.3
-numpy>1.16.0,<2.0.0
+numpy>1.16.0,<1.18.0
pylint==2.3.1; python_version >= '3.0'
requests<2.19.0,>=2.18.4
-scipy==1.0.1
+scipy==1.2.1
six==1.11.0
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index b658f953a78a..e078b2a8f89c 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -2065,6 +2065,15 @@ cd_pypi_publish() {
./cd/python/pypi/pypi_publish.py `readlink -f wheel_build/dist/*.whl`
}
+cd_s3_publish() {  # copy the single wheel found in wheel_build/dist to s3://apache-mxnet/dist/<variant>/
+    set -ex
+    pip3 install --user awscli
+    filepath=$(readlink -f wheel_build/dist/*.whl)
+    filename=$(basename "${filepath}")  # must read the variable assigned on the previous line
+    variant=$(echo "${filename}" | cut -d'-' -f1 | cut -d'_' -f2 -s)  # text after '_' in the name field; empty when there is no '_' (cut -s)
+    aws s3 cp "${filepath}" "s3://apache-mxnet/dist/${variant}/${filename}" --grants read=uri=http://acs.amazonaws.com/groups/global/AllUsers full=id=43f628fab72838a4f0b929d7f1993b14411f4b0294b011261bc6bd3e950a6822  # grants are space-separated; a comma adds grantees to the same permission
+}
+
build_static_scala_mkl() {
set -ex
pushd .
diff --git a/example/neural_collaborative_filtering/README.md b/example/neural_collaborative_filtering/README.md
index 819f4d94dff9..00d3ed12295b 100644
--- a/example/neural_collaborative_filtering/README.md
+++ b/example/neural_collaborative_filtering/README.md
@@ -29,15 +29,6 @@ Author: Dr. Xiangnan He (http://www.comp.nus.edu.sg/~xiangnan/)
Code Reference: https://github.com/hexiangnan/neural_collaborative_filtering
-## Environment Settings
-We use MXnet with MKL-DNN as the backend.
-- MXNet version: '1.5.1'
-
-## Install
-```
-pip install -r requirements.txt
-```
-
## Dataset
We provide the processed datasets on [Google Drive](https://drive.google.com/drive/folders/1qACR_Zhc2O2W0RrazzcepM2vJeh0MMdO?usp=sharing): MovieLens 20 Million (ml-20m), you can download directly or
@@ -66,7 +57,9 @@ We provide the pretrained ml-20m model on [Google Drive](https://drive.google.co
|dtype|HR@10|NDCG@10|
|:---:|:--:|:--:|
|float32|0.6393|0.3849|
-|int8|0.6366|0.3824|
+|float32 opt|0.6393|0.3849|
+|int8|0.6395|0.3852|
+|int8 opt|0.6396|0.3852|
## Training
@@ -75,11 +68,20 @@ We provide the pretrained ml-20m model on [Google Drive](https://drive.google.co
python train.py # --gpu=0
```
+## Model Optimizer
+
+```
+# optimize model
+python model_optimizer.py
+```
+
## Calibration
```
# neumf calibration on ml-20m dataset
python ncf.py --prefix=./model/ml-20m/neumf --calibration
+# optimized neumf calibration on ml-20m dataset
+python ncf.py --prefix=./model/ml-20m/neumf-opt --calibration
```
## Evaluation
@@ -87,15 +89,25 @@ python ncf.py --prefix=./model/ml-20m/neumf --calibration
```
# neumf float32 inference on ml-20m dataset
python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf
+# optimized neumf float32 inference on ml-20m dataset
+python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf-opt
# neumf int8 inference on ml-20m dataset
python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf-quantized
+# optimized neumf int8 inference on ml-20m dataset
+python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf-opt-quantized
```
## Benchmark
```
+usage: bash ./benchmark.sh [[[-p prefix ] [-e epoch] [-d dataset] [-b batch_size] [-i instance] [-c cores/instance]] | [-h]]
+
# neumf float32 benchmark on ml-20m dataset
-python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf --benchmark
+bash benchmark.sh -p model/ml-20m/neumf
+# optimized neumf float32 benchmark on ml-20m dataset
+bash benchmark.sh -p model/ml-20m/neumf-opt
# neumf int8 benchmark on ml-20m dataset
-python ncf.py --batch-size=1000 --prefix=./model/ml-20m/neumf-quantized --benchmark
+bash benchmark.sh -p model/ml-20m/neumf-quantized
+# optimized neumf int8 benchmark on ml-20m dataset
+bash benchmark.sh -p model/ml-20m/neumf-opt-quantized
```
diff --git a/example/neural_collaborative_filtering/benchmark.sh b/example/neural_collaborative_filtering/benchmark.sh
new file mode 100644
index 000000000000..60fec746cd20
--- /dev/null
+++ b/example/neural_collaborative_filtering/benchmark.sh
@@ -0,0 +1,114 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+usage()  # print the command-line synopsis on stdout
+{
+    printf '%s\n' "usage: bash ./benchmark.sh [[[-p prefix ] [-e epoch] [-d dataset] [-b batch_size] [-i instance] [-c cores/instance]] | [-h]]"
+}
+
+while [ $# -gt 0 ]; do  # a value-taking flag shifts once in its branch and once at the bottom
+    case "$1" in
+        --prefix | -p)
+            shift
+            PREFIX=$1
+            ;;
+        --epoch | -e)
+            shift
+            EPOCH=$1
+            ;;
+        --dataset | -d)
+            shift
+            DATASET=$1
+            ;;
+        --batch-size | -b)
+            shift
+            BS=$1
+            ;;
+        --instance | -i)
+            shift
+            INS=$1
+            ;;
+        --core | -c)
+            shift
+            CORES=$1
+            ;;
+        --help | -h)
+            usage
+            exit 0  # an explicit help request is not a failure
+            ;;
+        *)  # unrecognised argument: show the synopsis and fail
+            usage
+            exit 1
+    esac
+    shift
+done
+
+NUM_SOCKET=$(lscpu | grep 'Socket(s)' | awk '{print $NF}')  # topology is read from the last field of lscpu's lines
+NUM_NUMA_NODE=$(lscpu | grep 'NUMA node(s)' | awk '{print $NF}')
+CORES_PER_SOCKET=$(lscpu | grep 'Core(s) per socket' | awk '{print $NF}')
+NUM_CORES=$((CORES_PER_SOCKET * NUM_SOCKET))
+CORES_PER_NUMA=$((NUM_CORES / NUM_NUMA_NODE))
+echo "target machine has $NUM_CORES physical core(s) on $NUM_NUMA_NODE numa nodes of $NUM_SOCKET socket(s)."
+
+if [ -z "$PREFIX" ]; then
+    echo "Error: Need a model prefix."
+    exit 1  # a bare exit would return the status of the echo above, i.e. success
+fi
+if [ -z "$EPOCH" ]; then
+    echo "Default: set epoch of model parameters to 7."
+    EPOCH=7
+fi
+if [ -z "$DATASET" ]; then
+    echo "Default: set dataset to ml-20m."
+    DATASET='ml-20m'
+fi
+if [ -z "$INS" ]; then
+    echo "Default: launch one instance per physical core."
+    INS=$NUM_CORES
+fi
+if [ -z "$CORES" ]; then
+    echo "Default: divide full physical cores."
+    CORES=$((NUM_CORES / INS))
+fi
+if [ -z "$BS" ]; then
+    echo "Default: set batch size to 700."
+    BS=700
+fi
+
+echo " cores/instance: $CORES"
+echo " total instances: $INS"
+echo " batch size: $BS"
+echo ""
+
+rm -f NCF_*.log  # -f: stay quiet when no logs from an earlier run exist
+
+for((i=0;i<$INS;i++));
+do
+    ((a=$i*$CORES))  # first core of this instance
+    ((b=$a+$CORES-1))  # last core of this instance
+    memid=$((b/CORES_PER_NUMA))  # NUMA node index derived from the last core
+    LOG=NCF_$i.log
+    echo " $i instance use $a-$b cores with $LOG"
+    KMP_AFFINITY=granularity=fine,noduplicates,compact,1,0 \
+    OMP_NUM_THREADS=$CORES \
+    numactl --physcpubind=$a-$b --membind=$memid python ncf.py --batch-size=$BS --dataset=$DATASET --epoch=$EPOCH --benchmark --prefix="$PREFIX" 2>&1 | tee "$LOG" &
+done
+wait  # all instances were started in the background
+
+grep speed NCF_*.log | awk '{ sum += $(NF-1) }; END { print "Total Performance is " sum " samples/sec"}'
diff --git a/example/neural_collaborative_filtering/convert.py b/example/neural_collaborative_filtering/convert.py
index 4c64d2cdedab..7fb7f1ede9e4 100644
--- a/example/neural_collaborative_filtering/convert.py
+++ b/example/neural_collaborative_filtering/convert.py
@@ -38,7 +38,7 @@ def parse_args():
parser = ArgumentParser()
parser.add_argument('--dataset', nargs='?', default='ml-20m', choices=['ml-1m', 'ml-20m'],
help='The dataset name, temporary support ml-1m and ml-20m.')
- parser.add_argument('path', type=str, default = './data/',
+ parser.add_argument('--path', type=str, default = './data/',
help='Path to reviews CSV file from MovieLens')
parser.add_argument('-n', '--negatives', type=int, default=999,
help='Number of negative samples for each positive'
diff --git a/example/neural_collaborative_filtering/core/model.py b/example/neural_collaborative_filtering/core/model.py
index b516e5039fed..6c03bb01a357 100644
--- a/example/neural_collaborative_filtering/core/model.py
+++ b/example/neural_collaborative_filtering/core/model.py
@@ -37,6 +37,27 @@ def _init_weight(self, _, arr):
limit = np.sqrt(3. / self._fan_in)
mx.random.uniform(-limit, limit, out=arr)
+# only for inference model optimize
+def mlp_opt(user, item, factor_size, model_layers, max_user, max_item):
+    """Build the MLP branch on fused embeddings: both are ``factor_size * 2`` wide and are summed.
+
+    Layer 1 of ``model_layers`` contributes only a ReLU; later layers add FullyConnected + ReLU.
+    """
+    fused_user_w = mx.sym.Variable('fused_mlp_user_weight', init=mx.init.Normal(0.01))
+    fused_item_w = mx.sym.Variable('fused_mlp_item_weight', init=mx.init.Normal(0.01))
+    user_vec = mx.sym.Embedding(data=user, weight=fused_user_w, input_dim=max_user,
+                                output_dim=factor_size * 2, name='fused_embed_user'+str(factor_size))
+    item_vec = mx.sym.Embedding(data=item, weight=fused_item_w, input_dim=max_item,
+                                output_dim=factor_size * 2, name='fused_embed_item'+str(factor_size))
+    net = user_vec + item_vec
+    for layer in range(1, len(model_layers)):
+        tag = str(layer - 1)  # operator names are numbered from 0
+        if layer > 1:  # the first layer has no FullyConnected of its own in this graph
+            fc_init = golorot_uniform(model_layers[layer-1], model_layers[layer])
+            fc_weight = mx.sym.Variable('fc_{}_weight'.format(layer-1), init=fc_init)
+            net = mx.sym.FullyConnected(data=net, weight=fc_weight, num_hidden=model_layers[layer], name='fc_'+tag)
+        net = mx.sym.Activation(data=net, act_type='relu', name='act_'+tag)
+    return net
def mlp(user, item, factor_size, model_layers, max_user, max_item):
user_weight = mx.sym.Variable('mlp_user_weight', init=mx.init.Normal(0.01))
@@ -47,14 +68,11 @@ def mlp(user, item, factor_size, model_layers, max_user, max_item):
output_dim=factor_size, name='embed_item'+str(factor_size))
pre_gemm_concat = mx.sym.concat(embed_user, embed_item, dim=1, name='pre_gemm_concat')
- for i, layer in enumerate(model_layers):
- if i==0:
- mlp_weight_init = golorot_uniform(2 * factor_size, model_layers[i])
- else:
- mlp_weight_init = golorot_uniform(model_layers[i-1], model_layers[i])
- mlp_weight = mx.sym.Variable('fc_{}_weight'.format(i), init=mlp_weight_init)
- pre_gemm_concat = mx.sym.FullyConnected(data=pre_gemm_concat, weight=mlp_weight, num_hidden=layer, name='fc_'+str(i))
- pre_gemm_concat = mx.sym.Activation(data=pre_gemm_concat, act_type='relu', name='act_'+str(i))
+ for i in range(1, len(model_layers)):
+ mlp_weight_init = golorot_uniform(model_layers[i-1], model_layers[i])
+ mlp_weight = mx.sym.Variable('fc_{}_weight'.format(i-1), init=mlp_weight_init)
+ pre_gemm_concat = mx.sym.FullyConnected(data=pre_gemm_concat, weight=mlp_weight, num_hidden=model_layers[i], name='fc_'+str(i-1))
+ pre_gemm_concat = mx.sym.Activation(data=pre_gemm_concat, act_type='relu', name='act_'+str(i-1))
return pre_gemm_concat
@@ -70,24 +88,34 @@ def gmf(user, item, factor_size, max_user, max_item):
return pred
def get_model(model_type='neumf', factor_size_mlp=128, factor_size_gmf=64,
- model_layers=[256, 128, 64], num_hidden=1,
- max_user=138493, max_item=26744):
+ model_layers=[256, 256, 128, 64], num_hidden=1,
+ max_user=138493, max_item=26744, opt=False):
# input
user = mx.sym.Variable('user')
item = mx.sym.Variable('item')
if model_type == 'mlp':
- net = mlp(user=user, item=item,
- factor_size=factor_size_mlp, model_layers=model_layers,
- max_user=max_user, max_item=max_item)
+ if opt:
+ net = mlp_opt(user=user, item=item,
+ factor_size=factor_size_mlp, model_layers=model_layers,
+ max_user=max_user, max_item=max_item)
+ else:
+ net = mlp(user=user, item=item,
+ factor_size=factor_size_mlp, model_layers=model_layers,
+ max_user=max_user, max_item=max_item)
elif model_type == 'gmf':
net = gmf(user=user, item=item,
factor_size=factor_size_gmf,
max_user=max_user, max_item=max_item)
elif model_type == 'neumf':
- net_mlp = mlp(user=user, item=item,
- factor_size=factor_size_mlp, model_layers=model_layers,
- max_user=max_user, max_item=max_item)
+ if opt:
+ net_mlp = mlp_opt(user=user, item=item,
+ factor_size=factor_size_mlp, model_layers=model_layers,
+ max_user=max_user, max_item=max_item)
+ else:
+ net_mlp = mlp(user=user, item=item,
+ factor_size=factor_size_mlp, model_layers=model_layers,
+ max_user=max_user, max_item=max_item)
net_gmf = gmf(user=user, item=item,
factor_size=factor_size_gmf,
max_user=max_user, max_item=max_item)
diff --git a/example/neural_collaborative_filtering/model_optimizer.py b/example/neural_collaborative_filtering/model_optimizer.py
new file mode 100644
index 000000000000..2866ae7e7e05
--- /dev/null
+++ b/example/neural_collaborative_filtering/model_optimizer.py
@@ -0,0 +1,81 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import os
+import time
+import argparse
+import logging
+import math
+import random
+import numpy as np
+import mxnet as mx
+from core.model import get_model
+from core.dataset import NCFTrainData
+
+# NOTE: logging is configured exactly once, further down; a basicConfig() call here would make that one a no-op.
+
+parser = argparse.ArgumentParser(description="Run model optimizer.",
+                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--path', nargs='?', default='./data/',
+                    help='Input data path.')
+parser.add_argument('--dataset', nargs='?', default='ml-20m',
+                    help='The dataset name.')
+parser.add_argument('--model-prefix', type=str, default='./model/ml-20m/neumf')
+parser.add_argument('--epoch', type=int, default=7, help='parameters epoch')
+parser.add_argument('--model-type', type=str, default='neumf', choices=['neumf', 'gmf', 'mlp'],
+                    help="model type")
+parser.add_argument('--layers', default='[256, 256, 128, 64]',
+                    help="list of number hiddens of fc layers in mlp model.")
+parser.add_argument('--factor-size-gmf', type=int, default=64,
+                    help="outdim of gmf embedding layers.")
+parser.add_argument('--num-hidden', type=int, default=1,
+                    help="num-hidden of neumf fc layer")
+
+head = '%(asctime)-15s %(message)s'
+logging.basicConfig(level=logging.INFO, format=head)
+
+# arg parser
+args = parser.parse_args()
+logging.info(args)
+
+model_prefix = args.model_prefix
+model_type = args.model_type
+model_layers = eval(args.layers)  # NOTE(review): eval() on a command-line string; ast.literal_eval would be safer
+factor_size_gmf = args.factor_size_gmf
+factor_size_mlp = int(model_layers[0]/2)
+num_hidden = args.num_hidden
+train_dataset = NCFTrainData(os.path.join(args.path, args.dataset, 'train-ratings.csv'), nb_neg=4)
+net = get_model(model_type, factor_size_mlp, factor_size_gmf,
+                model_layers, num_hidden, train_dataset.nb_users, train_dataset.nb_items, opt=True)
+
+raw_params, _ = mx.model.load_params(model_prefix, args.epoch)
+fc_0_weight_split = mx.nd.split(raw_params['fc_0_weight'], axis=1, num_outputs=2)
+fc_0_left = fc_0_weight_split[0]
+fc_0_right = fc_0_weight_split[1]
+# Push each embedding table through its half of fc_0; the fc_0 bias is added once, on the user side.
+user_weight_fusion = mx.nd.FullyConnected(data = raw_params['mlp_user_weight'], weight=fc_0_left, bias=raw_params['fc_0_bias'], no_bias=False, num_hidden=model_layers[0])
+item_weight_fusion = mx.nd.FullyConnected(data = raw_params['mlp_item_weight'], weight=fc_0_right, no_bias=True, num_hidden=model_layers[0])
+# opt_params aliases raw_params: the deletions below modify the loaded dict in place (fc_0_weight is left in it).
+opt_params = raw_params
+del opt_params['mlp_user_weight']
+del opt_params['mlp_item_weight']
+del opt_params['fc_0_bias']
+opt_params['fused_mlp_user_weight'] = user_weight_fusion
+opt_params['fused_mlp_item_weight'] = item_weight_fusion
+
+mx.model.save_checkpoint(model_prefix + '-opt', args.epoch, net, opt_params, {})
+
diff --git a/example/neural_collaborative_filtering/ncf.py b/example/neural_collaborative_filtering/ncf.py
index 0fd9f733a1bd..b01be01bc8d9 100644
--- a/example/neural_collaborative_filtering/ncf.py
+++ b/example/neural_collaborative_filtering/ncf.py
@@ -42,20 +42,12 @@
help='max number of item index.')
parser.add_argument('--batch-size', type=int, default=256,
help='number of examples per batch')
-parser.add_argument('--model-type', type=str, default='neumf', choices=['neumf', 'gmf', 'mlp'],
- help="mdoel type")
-parser.add_argument('--layers', default='[256, 128, 64]',
- help="list of number hiddens of fc layers in mlp model.")
-parser.add_argument('--factor-size-gmf', type=int, default=64,
- help="outdim of gmf embedding layers.")
-parser.add_argument('--num-hidden', type=int, default=1,
- help="num-hidden of neumf fc layer")
parser.add_argument('--topk', type=int, default=10,
help="topk for accuracy evaluation.")
parser.add_argument('--gpu', type=int, default=None,
help="index of gpu to run, e.g. 0 or 1. None means using cpu().")
parser.add_argument('--benchmark', action='store_true', help="whether to benchmark performance only")
-parser.add_argument('--epoch', type=int, default=0, help='model checkpoint index for inference')
+parser.add_argument('--epoch', type=int, default=7, help='model checkpoint index for inference')
parser.add_argument('--prefix', default='./model/ml-20m/neumf', help="model checkpoint prefix")
parser.add_argument('--calibration', action='store_true', help="whether to calibrate model")
parser.add_argument('--calib-mode', type=str, choices=['naive', 'entropy'], default='naive',
@@ -85,11 +77,6 @@
max_user = args.max_user
max_item = args.max_item
batch_size = args.batch_size
- model_type = args.model_type
- model_layers = eval(args.layers)
- factor_size_gmf = args.factor_size_gmf
- factor_size_mlp = int(model_layers[0]/2)
- num_hidden = args.num_hidden
benchmark = args.benchmark
calibration = args.calibration
calib_mode = args.calib_mode
@@ -129,7 +116,7 @@
cqsym, cqarg_params, aux_params, collector = quantize_graph(sym=net, arg_params=arg_params, aux_params=aux_params,
excluded_sym_names=excluded_sym_names,
calib_mode=calib_mode,
- quantized_dtype=args.quantized_dtype, logger=logging)
+ quantized_dtype=quantized_dtype, logger=logging)
max_num_examples = num_calib_batches * batch_size
mod._exec_group.execs[0].set_monitor_callback(collector.collect, monitor_all=True)
num_batches = 0
@@ -144,12 +131,17 @@
% (num_batches, batch_size))
cqsym, cqarg_params, aux_params = calib_graph(qsym=cqsym, arg_params=arg_params, aux_params=aux_params,
collector=collector, calib_mode=calib_mode,
- quantized_dtype=args.quantized_dtype, logger=logging)
+ quantized_dtype=quantized_dtype, logger=logging)
sym_name = '%s-symbol.json' % (args.prefix + '-quantized')
cqsym = cqsym.get_backend_symbol('MKLDNN_QUANTIZE')
mx.model.save_checkpoint(args.prefix + '-quantized', args.epoch, cqsym, cqarg_params, aux_params)
elif benchmark:
logging.info('Benchmarking...')
+ data = [mx.random.randint(0, 1000, shape=shape, ctx=ctx) for _, shape in mod.data_shapes]
+ batch = mx.io.DataBatch(data, []) # empty label
+ for i in range(2000):
+ mod.forward(batch, is_train=False)
+ logging.info('Benchmarking...')
num_samples = 0
for ib, batch in enumerate(val_iter):
if ib == 5:
diff --git a/example/neural_collaborative_filtering/train.py b/example/neural_collaborative_filtering/train.py
index 0b0cfad1ef39..c68f271a6f0d 100644
--- a/example/neural_collaborative_filtering/train.py
+++ b/example/neural_collaborative_filtering/train.py
@@ -45,7 +45,7 @@
help="mdoel type")
parser.add_argument('--num-negative', type=int, default=4,
help="number of negative samples per positive sample while training.")
-parser.add_argument('--layers', default='[256, 128, 64]',
+parser.add_argument('--layers', default='[256, 256, 128, 64]',
help="list of number hiddens of fc layers in mlp model.")
parser.add_argument('--factor-size-gmf', type=int, default=64,
help="outdim of gmf embedding layers.")
diff --git a/example/quantization/README.md b/example/quantization/README.md
index 8cdc1bb7e06f..b934a811f31d 100644
--- a/example/quantization/README.md
+++ b/example/quantization/README.md
@@ -9,7 +9,7 @@ This folder contains examples of quantizing a FP32 model with Intel® MKL-DNN or
Model Quantization with Intel® MKL-DNN
-Intel® MKL-DNN supports quantization with subgraph features on Intel® CPU Platform and can bring performance improvements on the [Intel® Xeon® Scalable Platform](https://www.intel.com/content/www/us/en/processors/xeon/scalable/xeon-scalable-platform.html). A new quantization script `imagenet_gen_qsym_mkldnn.py` has been designed to launch quantization for image-classification models with Intel® MKL-DNN. This script integrates with [Gluon-CV modelzoo](https://gluon-cv.mxnet.io/model_zoo/classification.html), so that more pre-trained models can be downloaded from Gluon-CV and then converted for quantization. To apply quantization flow to your project directly, please refer [Quantize custom models with MKL-DNN backend](https://mxnet.apache.org/tutorials/mkldnn/mkldnn_quantization.html).
+Intel® MKL-DNN supports quantization with subgraph features on Intel® CPU Platform and can bring performance improvements on the [Intel® Xeon® Scalable Platform](https://www.intel.com/content/www/us/en/processors/xeon/scalable/xeon-scalable-platform.html). A new quantization script `imagenet_gen_qsym_mkldnn.py` has been designed to launch quantization for image-classification models with Intel® MKL-DNN. This script integrates with [Gluon-CV modelzoo](https://gluon-cv.mxnet.io/model_zoo/classification.html), so that more pre-trained models can be downloaded from Gluon-CV and then converted for quantization. To apply the quantization flow to your project directly, please refer to [Quantize custom models with MKL-DNN backend](https://mxnet.apache.org/api/python/docs/tutorials/performance/backend/mkldnn/mkldnn_quantization.html).
```
usage: imagenet_gen_qsym_mkldnn.py [-h] [--model MODEL] [--epoch EPOCH]
diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py
index 0d0e3b64491b..6abb96c62c09 100644
--- a/python/mxnet/_numpy_op_doc.py
+++ b/python/mxnet/_numpy_op_doc.py
@@ -20,6 +20,107 @@
"""Doc placeholder for numpy ops with prefix _np."""
+def _np_all(a, axis=None, keepdims=False, out=None):
+ """
+ Test whether all array elements along a given axis evaluate to True.
+
+ Parameters
+ ----------
+ a : array_like
+ Input array or object that can be converted to an array.
+ axis : None or int or tuple of ints, optional
+ Axis or axes along which a logical AND reduction is performed.
+ The default (axis = None) is to perform a logical AND over
+ all the dimensions of the input array.
+ keepdims : bool, optional
+ If this is set to True, the axes which are reduced are left in
+ the result as dimensions with size one. With this option,
+ the result will broadcast correctly against the input array.
+ out : ndarray, optional
+ Alternate output array in which to place the result. It must have
+ the same shape as the expected output and its type is preserved.
+
+ Returns
+ -------
+ all : ndarray, bool
+ A new boolean or array is returned unless out is specified,
+ in which case a reference to out is returned.
+
+ Examples
+ --------
+ >>> np.all([[True,False],[True,True]])
+ False
+
+ >>> np.all([[True,False],[True,True]], axis=0)
+ array([ True, False])
+
+ >>> np.all([-1, 4, 5])
+ True
+
+ >>> np.all([1.0, np.nan])
+ True
+
+ >>> o=np.array(False)
+ >>> z=np.all([-1, 4, 5], out=o)
+ >>> id(z), id(o), z
+ (28293632, 28293632, array(True)) # may vary
+ """
+ pass
+
+def _np_any(a, axis=None, keepdims=False, out=None):
+ """
+ Test whether any array element along a given axis evaluates to True.
+ Returns single boolean unless axis is not None.
+
+ Parameters
+ ----------
+ a : array_like
+ Input array or object that can be converted to an array.
+ axis : None or int or tuple of ints, optional
+ Axis or axes along which a logical OR reduction is performed.
+ The default (axis = None) is to perform a logical OR over
+ all the dimensions of the input array.
+ keepdims : bool, optional
+ If this is set to True, the axes which are reduced are left in
+ the result as dimensions with size one. With this option,
+ the result will broadcast correctly against the input array.
+ out : ndarray, optional
+ Alternate output array in which to place the result. It must have
+ the same shape as the expected output and its type is preserved.
+
+ Returns
+ -------
+ any : bool or ndarray
+ A new boolean or ndarray is returned unless out is specified,
+ in which case a reference to out is returned.
+
+ Examples
+ --------
+ >>> np.any([[True, False], [True, True]])
+ True
+
+ >>> np.any([[True, False], [False, False]], axis=0)
+ array([ True, False])
+
+ >>> np.any([-1, 0, 5])
+ True
+
+ >>> np.any(np.nan)
+ True
+
+ >>> o=np.array(False)
+ >>> z=np.any([-1, 4, 5], out=o)
+ >>> z, o
+ (array(True), array(True))
+ >>> # Check now that z is a reference to o
+ >>> z is o
+ True
+ >>> id(z), id(o) # identity of z and o # doctest: +SKIP
+ (191614240, 191614240)
+ """
+ pass
+
+
def _np_cumsum(a, axis=None, dtype=None, out=None):
"""
Return the cumulative sum of the elements along a given axis.
@@ -630,7 +731,7 @@ def _np_squeeze(a, axis=None, out=None):
pass
-def _np_max(a, axis=None, out=None, keepdims=False):
+def _np_max(a, axis=None, keepdims=False, out=None):
"""
Return the maximum of an array or maximum along an axis.
@@ -694,7 +795,14 @@ def _np_max(a, axis=None, out=None, keepdims=False):
pass
-def _np_min(a, axis=None, out=None, keepdims=False):
+def _np_amax(a, axis=None, keepdims=False, out=None):
+ """
+ Doc placeholder for `_np_amax`; refer to `_np_max` for the full description.
+ """
+ pass
+
+
+def _np_min(a, axis=None, keepdims=False, out=None):
"""
Return the minimum of an array or minimum along an axis.
diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py
index 01051ab7c8e4..ce22fb753ace 100644
--- a/python/mxnet/contrib/quantization.py
+++ b/python/mxnet/contrib/quantization.py
@@ -27,6 +27,7 @@
import logging
import os
import shutil
+import warnings
import numpy as np
from ..base import _LIB, check_call, py_str
from ..base import c_array, c_str, mx_uint, c_str_array
@@ -419,6 +420,7 @@ def __init__(self, calib_data):
else:
data_example = [data_example]
# suppose there must be one label in data_example
+ # TODO(xinyu-intel): little tricky here, need to refactor.
num_data = len(data_example)
assert num_data > 0
# here reshape is to handle the 5D/6D input data
@@ -426,6 +428,10 @@ def __init__(self, calib_data):
data_example[0] = data_example[0].reshape((-1,) + data_example[0].shape[2:])
self.provide_data = [DataDesc(name='data', shape=(data_example[0].shape))]
self.provide_data += [DataDesc(name='data{}'.format(i), shape=x.shape) for i, x in enumerate(data_example[1:])]
+ # data0, data1, ..., label
+ if num_data >= 3:
+ self.provide_data = [DataDesc(name='data{}'.format(i), shape=x.shape)
+ for i, x in enumerate(data_example[0:])]
self.batch_size = data_example[0].shape[0]
self.reset()
@@ -607,7 +613,9 @@ def quantize_model_mkldnn(sym, arg_params, aux_params,
A tuple of quantized symbol, quantized arg_params, and aux_params.
-------
"""
- if ctx != cpu():
+ if not isinstance(ctx, Context):
+ raise ValueError('currently only supports single ctx, while received %s' % str(ctx))
+ if ctx.device_type != 'cpu':
raise ValueError(
'quantize_model_mkldnn only support Intel cpu platform with MKL-DNN Backend')
@@ -627,8 +635,9 @@ def quantize_model_mkldnn(sym, arg_params, aux_params,
return qsym, qarg_params, aux_params
def quantize_graph(sym, arg_params, aux_params, ctx=cpu(),
- excluded_sym_names=None, excluded_op_names=None, calib_mode='entropy',
- quantized_dtype='int8', quantize_mode='full', logger=None):
+ excluded_sym_names=None, excluded_op_names=None,
+ calib_mode='entropy', quantized_dtype='int8', quantize_mode='full',
+ LayerOutputCollector=None, logger=None):
"""User-level API for generating a quantized model from a FP32 model w/o calibration
and a collector for naive or entropy calibration.
The backend quantized operators are only enabled for Linux systems. Please do not run
@@ -667,6 +676,8 @@ def quantize_graph(sym, arg_params, aux_params, ctx=cpu(),
The mode that quantization pass to apply. Support 'full' and 'smart'.
'full' means quantize all operator if possible.
'smart' means quantization pass will smartly choice which operator should be quantized.
+ LayerOutputCollector : class
+ For customize calibration method usage.
logger : Object
A logging object for printing information during the process of quantization.
Returns
@@ -711,9 +722,14 @@ def quantize_graph(sym, arg_params, aux_params, ctx=cpu(),
if logger:
logger.info(
'Create a layer output minmax collector for naive calibration')
+ elif calib_mode == 'customize' and LayerOutputCollector is not None:
+ collector = LayerOutputCollector
+ if logger:
+ logger.info(
+ 'Create a customize layer output minmax collector for calibration')
else:
raise ValueError('unknown calibration mode %s received,'
- ' expected `none`, `naive`, or `entropy`' % calib_mode)
+ ' expected `none`, `naive`, `entropy` or `customize`' % calib_mode)
if logger:
logger.info('Collector created, please use set_monitor_callback'
' to collect calibration information.')
@@ -770,9 +786,11 @@ def calib_graph(qsym, arg_params, aux_params, collector,
collector.hist_dict, quantized_dtype, logger=logger)
elif calib_mode == 'naive':
th_dict = collector.min_max_dict
+ elif calib_mode == 'customize':
+ th_dict = collector.min_max_dict
else:
raise ValueError('unknown calibration mode %s received,'
- ' expected `none`, `naive`, or `entropy`' % calib_mode)
+ ' expected `none`, `naive`, `entropy` or `customize`' % calib_mode)
qsym = _calibrate_quantized_sym(qsym, th_dict)
else:
raise ValueError('please set calibration mode to naive or entropy.')
@@ -783,10 +801,10 @@ def calib_graph(qsym, arg_params, aux_params, collector,
return qsym, qarg_params, aux_params
-def quantize_net(network, quantized_dtype='auto', quantize_mode='full',
- exclude_layers=None, exclude_layers_match=None, exclude_operators=None,
- calib_data=None, data_shapes=None, calib_mode='none',
- num_calib_examples=None, ctx=cpu(), logger=None):
+def quantize_net_v2(network, quantized_dtype='auto', quantize_mode='full',
+ exclude_layers=None, exclude_layers_match=None, exclude_operators=None,
+ calib_data=None, data_shapes=None, calib_mode='none',
+ num_calib_examples=None, ctx=cpu(), LayerOutputCollector=None, logger=None):
"""User-level API for Gluon users to generate a quantized SymbolBlock from a FP32 HybridBlock w/ or w/o calibration.
The backend quantized operators are only enabled for Linux systems. Please do not run
inference using the quantized models on Windows for now.
@@ -830,6 +848,8 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full',
ctx : Context
Defines the device that users want to run forward propagation on the calibration
dataset for collecting layer output statistics. Currently, only supports single context.
+ LayerOutputCollector : class
+ Custom collector to use when `calib_mode` is 'customize'.
logger : Object
A logging object for printing information during the process of quantization.
@@ -906,7 +926,8 @@ def __exit__(self, exc_type, exc_value, traceback):
qsym, qarg_params, aux_params, collector = quantize_graph(
sym=symnet, arg_params=args, aux_params=auxs, ctx=ctx,
excluded_sym_names=exclude_layers, excluded_op_names=exclude_operators,
- calib_mode=calib_mode, quantized_dtype=quantized_dtype, quantize_mode=quantize_mode, logger=logger)
+ calib_mode=calib_mode, quantized_dtype=quantized_dtype, quantize_mode=quantize_mode,
+ LayerOutputCollector=LayerOutputCollector, logger=logger)
if calib_mode is not None and calib_mode != 'none':
if not isinstance(ctx, Context):
@@ -915,7 +936,7 @@ def __exit__(self, exc_type, exc_value, traceback):
if calib_data is None:
raise ValueError(
'calib_data must be provided when calib_mode=%s' % calib_mode)
- if calib_mode in ['naive', 'entropy']:
+ if calib_mode in ['naive', 'entropy', 'customize']:
data_names = [pair[0] for pair in calib_data.provide_data]
mod = Module(symbol=symnet, context=ctx,
data_names=data_names, label_names=None)
@@ -956,3 +977,19 @@ def __exit__(self, exc_type, exc_value, traceback):
net.collect_params().load(param_name, cast_dtype=True, dtype_source='saved')
net.collect_params().reset_ctx(ctx)
return net
+
+def quantize_net(network, quantized_dtype='auto', quantize_mode='full',
+ exclude_layers=None, exclude_layers_match=None, exclude_operators=None,
+ calib_data=None, data_shapes=None, calib_mode='none',
+ num_calib_examples=None, ctx=cpu(), logger=None):
+ """User-level API for Gluon users to generate a quantized SymbolBlock from a FP32 HybridBlock w/ or w/o calibration.
+ Will be deprecated after MXNet 2.0, please use quantize_net_v2.
+ """
+ warnings.warn('WARNING: This will be deprecated after MXNet 2.0, please use quantize_net_v2.')
+ return quantize_net_v2(network=network, quantized_dtype=quantized_dtype,
+ quantize_mode=quantize_mode, exclude_layers=exclude_layers,
+ exclude_layers_match=exclude_layers_match,
+ exclude_operators=exclude_operators,
+ calib_data=calib_data, data_shapes=data_shapes,
+ calib_mode=calib_mode, num_calib_examples=num_calib_examples,
+ ctx=ctx, LayerOutputCollector=None, logger=logger)
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 3b0858007197..65d015b18c33 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -28,7 +28,7 @@
from . import _internal as _npi
from ..ndarray import NDArray
-__all__ = ['shape', 'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'invert',
+__all__ = ['shape', 'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'invert', 'delete',
'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'bitwise_not',
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs',
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
@@ -914,6 +914,67 @@ def mod(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.mod, _np.mod, _npi.mod_scalar, _npi.rmod_scalar, out)
+@set_module('mxnet.ndarray.numpy')
+def delete(arr, obj, axis=None):
+ """
+ Return a new array with sub-arrays along an axis deleted. For a one
+ dimensional array, this returns those entries not returned by
+ `arr[obj]`.
+
+ Parameters
+ ----------
+ arr : ndarray
+ Input array.
+ obj : slice, int or ndarray of ints
+ Indicate indices of sub-arrays to remove along the specified axis.
+ axis : int, optional
+ The axis along which to delete the subarray defined by `obj`.
+ If `axis` is None, `obj` is applied to the flattened array.
+
+ Returns
+ -------
+ out : ndarray
+ A copy of `arr` with the elements specified by `obj` removed. Note
+ that `delete` does not occur in-place. If `axis` is None, `out` is
+ a flattened array.
+
+ Examples
+ --------
+ >>> arr = np.array([[1,2,3,4], [5,6,7,8], [9,10,11,12]])
+ >>> arr
+ array([[ 1., 2., 3., 4.],
+ [ 5., 6., 7., 8.],
+ [ 9., 10., 11., 12.]])
+
+ >>> np.delete(arr, 1, 0)
+ array([[ 1., 2., 3., 4.],
+ [ 9., 10., 11., 12.]])
+
+ >>> np.delete(arr, slice(None, None, 2), 1)
+ array([[ 2., 4.],
+ [ 6., 8.],
+ [10., 12.]])
+
+ >>> np.delete(arr, np.array([1,3,5]), None)
+ array([ 1., 3., 5., 7., 8., 9., 10., 11., 12.])
+ >>> np.delete(arr, np.array([1,1,5]), None)
+ array([ 1., 3., 4., 5., 7., 8., 9., 10., 11., 12.])
+ """
+ if not isinstance(arr, NDArray):
+ raise TypeError("'arr' can not support type {}".format(str(type(arr))))
+ if isinstance(obj, slice):
+ start = obj.start
+ stop = obj.stop
+ step = 1 if obj.step is None else obj.step
+ return _npi.delete(arr, start=start, stop=stop, step=step, axis=axis)
+ elif isinstance(obj, integer_types):
+ return _npi.delete(arr, int_ind=obj, axis=axis)
+ elif isinstance(obj, NDArray):
+ return _npi.delete(arr, obj, axis=axis)
+ else:
+ raise TypeError("'obj' can not support type {}".format(str(type(obj))))
+
+
@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def remainder(x1, x2, out=None):
@@ -4783,6 +4844,25 @@ def around(x, decimals=0, out=None, **kwargs):
raise TypeError('type {} not supported'.format(str(type(x))))
+@set_module('mxnet.ndarray.numpy')
+def round(x, decimals=0, out=None, **kwargs):
+ r"""
+ round(x, decimals=0, out=None)
+ Round an array to the given number of decimals.
+
+ See Also
+ --------
+ around : equivalent function; see for details.
+ """
+ from ...numpy import ndarray
+ if isinstance(x, numeric_types):
+ return _np.around(x, decimals, **kwargs)
+ elif isinstance(x, ndarray):
+ return _npi.around(x, decimals, out=out, **kwargs)
+ else:
+ raise TypeError('type {} not supported'.format(str(type(x))))
+
+
@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def arctan2(x1, x2, out=None, **kwargs):
diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py
index 4c49c35b4a44..e4fee158bea4 100644
--- a/python/mxnet/ndarray/numpy/linalg.py
+++ b/python/mxnet/ndarray/numpy/linalg.py
@@ -21,7 +21,7 @@
from . import _op as _mx_nd_np
from . import _internal as _npi
-__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv']
+__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve']
def norm(x, ord=None, axis=None, keepdims=False):
@@ -461,3 +461,51 @@ def tensorinv(a, ind=2):
True
"""
return _npi.tensorinv(a, ind)
+
+
+def tensorsolve(a, b, axes=None):
+ r"""
+ Solve the tensor equation ``a x = b`` for x.
+ It is assumed that all indices of `x` are summed over in the product,
+ together with the rightmost indices of `a`, as is done in, for example,
+ ``tensordot(a, x, axes=b.ndim)``.
+
+ Parameters
+ ----------
+ a : ndarray
+ Coefficient tensor, of shape ``b.shape + Q``. `Q`, a tuple, equals
+ the shape of that sub-tensor of `a` consisting of the appropriate
+ number of its rightmost indices, and must be such that
+ ``prod(Q) == prod(b.shape)`` (in which sense `a` is said to be
+ 'square').
+ b : ndarray
+ Right-hand tensor, which can be of any shape.
+ axes : tuple of ints, optional
+ Axes in `a` to reorder to the right, before inversion.
+ If None (default), no reordering is done.
+
+ Returns
+ -------
+ x : ndarray, shape Q
+
+ Raises
+ ------
+ MXNetError
+ If `a` is singular or not 'square' (in the above sense).
+
+ See Also
+ --------
+ numpy.tensordot, tensorinv, numpy.einsum
+
+ Examples
+ --------
+ >>> a = np.eye(2*3*4)
+ >>> a.shape = (2*3, 4, 2, 3, 4)
+ >>> b = np.random.randn(2*3, 4)
+ >>> x = np.linalg.tensorsolve(a, b)
+ >>> x.shape
+ (2, 3, 4)
+ >>> np.allclose(np.tensordot(a, x, axes=3), b)
+ True
+ """
+ return _npi.tensorsolve(a, b, axes)
diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py
index 2ee2d2670693..96fe1d311028 100644
--- a/python/mxnet/numpy/linalg.py
+++ b/python/mxnet/numpy/linalg.py
@@ -20,7 +20,7 @@
from __future__ import absolute_import
from ..ndarray import numpy as _mx_nd_np
-__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv']
+__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve']
def norm(x, ord=None, axis=None, keepdims=False):
@@ -479,3 +479,51 @@ def tensorinv(a, ind=2):
True
"""
return _mx_nd_np.linalg.tensorinv(a, ind)
+
+
+def tensorsolve(a, b, axes=None):
+ r"""
+ Solve the tensor equation ``a x = b`` for x.
+ It is assumed that all indices of `x` are summed over in the product,
+ together with the rightmost indices of `a`, as is done in, for example,
+ ``tensordot(a, x, axes=b.ndim)``.
+
+ Parameters
+ ----------
+ a : ndarray
+ Coefficient tensor, of shape ``b.shape + Q``. `Q`, a tuple, equals
+ the shape of that sub-tensor of `a` consisting of the appropriate
+ number of its rightmost indices, and must be such that
+ ``prod(Q) == prod(b.shape)`` (in which sense `a` is said to be
+ 'square').
+ b : ndarray
+ Right-hand tensor, which can be of any shape.
+ axes : tuple of ints, optional
+ Axes in `a` to reorder to the right, before inversion.
+ If None (default), no reordering is done.
+
+ Returns
+ -------
+ x : ndarray, shape Q
+
+ Raises
+ ------
+ MXNetError
+ If `a` is singular or not 'square' (in the above sense).
+
+ See Also
+ --------
+ numpy.tensordot, tensorinv, numpy.einsum
+
+ Examples
+ --------
+ >>> a = np.eye(2*3*4)
+ >>> a.shape = (2*3, 4, 2, 3, 4)
+ >>> b = np.random.randn(2*3, 4)
+ >>> x = np.linalg.tensorsolve(a, b)
+ >>> x.shape
+ (2, 3, 4)
+ >>> np.allclose(np.tensordot(a, x, axes=3), b)
+ True
+ """
+ return _mx_nd_np.linalg.tensorsolve(a, b, axes)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 5a4035e15223..3eab52596e0d 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -47,7 +47,7 @@
from ..ndarray.ndarray import _storage_type
__all__ = ['ndarray', 'empty', 'array', 'shape', 'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like',
- 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'bitwise_not',
+ 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'bitwise_not', 'delete',
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'invert',
'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'histogram',
@@ -1377,7 +1377,7 @@ def argsort(self, axis=-1, kind=None, order=None): # pylint: disable=arguments-
The arguments are the same as for :py:func:`argsort`, with
this array as data.
"""
- raise argsort(self, axis=axis, kind=kind, order=order)
+ return argsort(self, axis=axis, kind=kind, order=order)
def argmax_channel(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`argmax_channel`.
@@ -1558,13 +1558,13 @@ def norm(self, *args, **kwargs):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute norm')
- def round(self, *args, **kwargs):
+ def round(self, decimals=0, out=None, **kwargs): # pylint: disable=arguments-differ
"""Convenience fluent method for :py:func:`round`.
The arguments are the same as for :py:func:`round`, with
this array as data.
"""
- raise NotImplementedError
+ return round(self, decimals=decimals, out=out, **kwargs)
def rint(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`rint`.
@@ -5887,6 +5887,55 @@ def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint:
return _npi.std(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)
+@set_module('mxnet.numpy')
+def delete(arr, obj, axis=None):
+ """
+ Return a new array with sub-arrays along an axis deleted. For a one
+ dimensional array, this returns those entries not returned by
+ `arr[obj]`.
+
+ Parameters
+ ----------
+ arr : ndarray
+ Input array.
+ obj : slice, int or ndarray of ints
+ Indicate indices of sub-arrays to remove along the specified axis.
+ axis : int, optional
+ The axis along which to delete the subarray defined by `obj`.
+ If `axis` is None, `obj` is applied to the flattened array.
+
+ Returns
+ -------
+ out : ndarray
+ A copy of `arr` with the elements specified by `obj` removed. Note
+ that `delete` does not occur in-place. If `axis` is None, `out` is
+ a flattened array.
+
+ Examples
+ --------
+ >>> arr = np.array([[1,2,3,4], [5,6,7,8], [9,10,11,12]])
+ >>> arr
+ array([[ 1., 2., 3., 4.],
+ [ 5., 6., 7., 8.],
+ [ 9., 10., 11., 12.]])
+
+ >>> np.delete(arr, 1, 0)
+ array([[ 1., 2., 3., 4.],
+ [ 9., 10., 11., 12.]])
+
+ >>> np.delete(arr, slice(None, None, 2), 1)
+ array([[ 2., 4.],
+ [ 6., 8.],
+ [10., 12.]])
+
+ >>> np.delete(arr, np.array([1,3,5]), None)
+ array([ 1., 3., 5., 7., 8., 9., 10., 11., 12.])
+ >>> np.delete(arr, np.array([1,1,5]), None)
+ array([ 1., 3., 4., 5., 7., 8., 9., 10., 11., 12.])
+ """
+ return _mx_nd_np.delete(arr, obj, axis=axis)
+
+
@set_module('mxnet.numpy')
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: disable=too-many-arguments
"""
@@ -6495,6 +6544,19 @@ def around(x, decimals=0, out=None, **kwargs):
return _mx_nd_np.around(x, decimals, out=out, **kwargs)
+@set_module('mxnet.numpy')
+def round(x, decimals=0, out=None, **kwargs):
+ r"""
+ round(x, decimals=0, out=None)
+ Round an array to the given number of decimals.
+
+ See Also
+ --------
+ around : equivalent function; see for details.
+ """
+ return _mx_nd_np.around(x, decimals, out=out, **kwargs)
+
+
@set_module('mxnet.numpy')
@wrap_np_binary_func
def arctan2(x1, x2, out=None, **kwargs):
diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py
index ebc24de63282..95719a005cec 100644
--- a/python/mxnet/numpy/random.py
+++ b/python/mxnet/numpy/random.py
@@ -20,7 +20,7 @@
from __future__ import absolute_import
from ..ndarray import numpy as _mx_nd_np
-__all__ = ["randint", "uniform", "normal", "choice", "rand", "multinomial", "shuffle"]
+__all__ = ["randint", "uniform", "normal", "choice", "rand", "multinomial", "shuffle", "randn"]
def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
@@ -357,3 +357,44 @@ def shuffle(x):
[0., 1., 2.]])
"""
_mx_nd_np.random.shuffle(x)
+
+
+def randn(*size, **kwargs):
+ r"""Return a sample (or samples) from the "standard normal" distribution.
+ If positive, int_like or int-convertible arguments are provided,
+ `randn` generates an array of shape ``(d0, d1, ..., dn)``, filled
+ with random floats sampled from a univariate "normal" (Gaussian)
+ distribution of mean 0 and variance 1 (if any of the :math:`d_i` are
+ floats, they are first converted to integers by truncation). A single
+ float randomly sampled from the distribution is returned if no
+ argument is provided.
+ This is a convenience function. If you want an interface that takes a
+ tuple as the first argument, use `numpy.random.standard_normal` instead.
+ Parameters
+ ----------
+ d0, d1, ..., dn : int, optional
+ The dimensions of the returned array, should be all positive.
+ If no argument is given a single Python float is returned.
+ Returns
+ -------
+ Z : ndarray
+ A ``(d0, d1, ..., dn)``-shaped array of floating-point samples from
+ the standard normal distribution, or a single such float if
+ no parameters were supplied.
+ Notes
+ -----
+ For random samples from :math:`N(\mu, \sigma^2)`, use:
+ ``sigma * np.random.randn(...) + mu``
+ Examples
+ --------
+ >>> np.random.randn()
+ 2.1923875335537315 #random
+ Two-by-four array of samples from N(3, 6.25):
+ >>> 2.5 * np.random.randn(2, 4) + 3
+ array([[-4.49401501, 4.00950034, -1.81814867, 7.29718677], #random
+ [ 0.39924804, 4.68456316, 4.99394529, 4.84057254]]) #random
+ """
+ output_shape = ()
+ for s in size:
+ output_shape += (s,)
+ return _mx_nd_np.random.normal(0, 1, size=output_shape, **kwargs)
diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py
index c7e9dd1398eb..65486e6e5f37 100644
--- a/python/mxnet/numpy_dispatch_protocol.py
+++ b/python/mxnet/numpy_dispatch_protocol.py
@@ -83,9 +83,12 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
_NUMPY_ARRAY_FUNCTION_LIST = [
+ 'all',
+ 'any',
'argmin',
'argmax',
'around',
+ 'round',
'argsort',
'append',
'broadcast_arrays',
@@ -103,6 +106,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'flip',
'inner',
'max',
+ 'amax',
'mean',
'min',
'nonzero',
@@ -125,6 +129,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'transpose',
'unique',
'unravel_index',
+ 'delete',
'var',
'vdot',
'vstack',
@@ -135,6 +140,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'linalg.inv',
'linalg.solve',
'linalg.tensorinv',
+ 'linalg.tensorsolve',
'shape',
'trace',
'tril',
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 86eab23ddfea..6f9d32abd336 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -36,7 +36,7 @@
except ImportError:
from builtins import slice as py_slice
-__all__ = ['zeros', 'zeros_like', 'ones', 'ones_like', 'full_like', 'bitwise_not', 'invert',
+__all__ = ['zeros', 'zeros_like', 'ones', 'ones_like', 'full_like', 'bitwise_not', 'invert', 'delete',
'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2',
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp',
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
@@ -492,7 +492,7 @@ def argsort(self, axis=-1, kind=None, order=None): # pylint: disable=arguments-
The arguments are the same as for :py:func:`argsort`, with
this array as data.
"""
- raise argsort(self, axis=axis, kind=kind, order=order)
+ return argsort(self, axis=axis, kind=kind, order=order)
def argmax_channel(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`argmax_channel`.
@@ -666,13 +666,13 @@ def norm(self, *args, **kwargs):
"""
raise AttributeError('_Symbol object has no attribute norm')
- def round(self, *args, **kwargs):
+ def round(self, decimals=0, out=None, **kwargs): # pylint: disable=arguments-differ
"""Convenience fluent method for :py:func:`round`.
The arguments are the same as for :py:func:`round`, with
this array as data.
"""
- raise NotImplementedError
+ return round(self, decimals=decimals, out=out, **kwargs)
def rint(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`rint`.
@@ -3162,6 +3162,45 @@ def arange(start, stop=None, step=1, dtype=None, ctx=None):
return _npi.arange(start=start, stop=stop, step=step, dtype=dtype, ctx=ctx)
+@set_module('mxnet.symbol.numpy')
+def delete(arr, obj, axis=None):
+ """
+ Return a new array with sub-arrays along an axis deleted. For a one
+ dimensional array, this returns those entries not returned by
+ `arr[obj]`.
+
+ Parameters
+ ----------
+ arr : _Symbol
+ Input array.
+ obj : slice, int or _Symbol of ints
+ Indicate indices of sub-arrays to remove along the specified axis.
+ axis : int, optional
+ The axis along which to delete the subarray defined by `obj`.
+ If `axis` is None, `obj` is applied to the flattened array.
+
+ Returns
+ -------
+ out : _Symbol
+ A copy of `arr` with the elements specified by `obj` removed. Note
+ that `delete` does not occur in-place. If `axis` is None, `out` is
+ a flattened array.
+ """
+ if not isinstance(arr, Symbol):
+ raise TypeError("'arr' can not support type {}".format(str(type(arr))))
+ if isinstance(obj, slice):
+ start = obj.start
+ stop = obj.stop
+ step = 1 if obj.step is None else obj.step
+ return _npi.delete(arr, start=start, stop=stop, step=step, axis=axis)
+ elif isinstance(obj, integer_types):
+ return _npi.delete(arr, int_ind=obj, axis=axis)
+ elif isinstance(obj, Symbol):
+ return _npi.delete(arr, obj, axis=axis)
+ else:
+ raise TypeError("'obj' can not support type {}".format(str(type(obj))))
+
+
# pylint: disable=redefined-outer-name
@set_module('mxnet.symbol.numpy')
def split(ary, indices_or_sections, axis=0):
@@ -4554,6 +4593,24 @@ def around(x, decimals=0, out=None, **kwargs):
raise TypeError('type {} not supported'.format(str(type(x))))
+@set_module('mxnet.symbol.numpy')
+def round(x, decimals=0, out=None, **kwargs):
+ r"""
+ round(x, decimals=0, out=None)
+ Round an array to the given number of decimals.
+
+ See Also
+ --------
+ around : equivalent function; see for details.
+ """
+ if isinstance(x, numeric_types):
+ return _np.around(x, decimals, **kwargs)
+ elif isinstance(x, _Symbol):
+ return _npi.around(x, decimals, out=out, **kwargs)
+ else:
+ raise TypeError('type {} not supported'.format(str(type(x))))
+
+
@set_module('mxnet.symbol.numpy')
@wrap_np_binary_func
def arctan2(x1, x2, out=None, **kwargs):
diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py
index a445c79001ec..0bfbb6ee540f 100644
--- a/python/mxnet/symbol/numpy/linalg.py
+++ b/python/mxnet/symbol/numpy/linalg.py
@@ -22,7 +22,7 @@
from . import _op as _mx_sym_np
from . import _internal as _npi
-__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv']
+__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve']
def norm(x, ord=None, axis=None, keepdims=False):
@@ -448,3 +448,51 @@ def tensorinv(a, ind=2):
True
"""
return _npi.tensorinv(a, ind)
+
+
+def tensorsolve(a, b, axes=None):
+ r"""
+ Solve the tensor equation ``a x = b`` for x.
+ It is assumed that all indices of `x` are summed over in the product,
+ together with the rightmost indices of `a`, as is done in, for example,
+ ``tensordot(a, x, axes=b.ndim)``.
+
+ Parameters
+ ----------
+ a : ndarray
+ Coefficient tensor, of shape ``b.shape + Q``. `Q`, a tuple, equals
+ the shape of that sub-tensor of `a` consisting of the appropriate
+ number of its rightmost indices, and must be such that
+ ``prod(Q) == prod(b.shape)`` (in which sense `a` is said to be
+ 'square').
+ b : ndarray
+ Right-hand tensor, which can be of any shape.
+ axes : tuple of ints, optional
+ Axes in `a` to reorder to the right, before inversion.
+ If None (default), no reordering is done.
+
+ Returns
+ -------
+ x : ndarray, shape Q
+
+ Raises
+ ------
+ MXNetError
+ If `a` is singular or not 'square' (in the above sense).
+
+ See Also
+ --------
+ numpy.tensordot, tensorinv, numpy.einsum
+
+ Examples
+ --------
+ >>> a = np.eye(2*3*4)
+ >>> a.shape = (2*3, 4, 2, 3, 4)
+ >>> b = np.random.randn(2*3, 4)
+ >>> x = np.linalg.tensorsolve(a, b)
+ >>> x.shape
+ (2, 3, 4)
+ >>> np.allclose(np.tensordot(a, x, axes=3), b)
+ True
+ """
+ return _npi.tensorsolve(a, b, axes)
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index e3a3c0443428..4176d3a68792 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -1101,6 +1101,14 @@ struct minimum : public mxnet_op::tunable {
}
};
+/*! \brief boolean any/all kernel that determines whether elem is NonZero */
+struct NonZero {
+ template
+ MSHADOW_XINLINE static bool Map(DType a) {
+ return (a != DType(0));
+ }
+};
+
/*! \brief sum reducer that ignores NaN values in the input */
struct nansum {
/*! \brief do reduction into dst */
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index b15117f9f83b..d7752c4759db 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -1148,6 +1148,29 @@ struct set_to_int : public tunable {
*/
using set_zero = set_to_int<0>;
using set_one = set_to_int<1>;
+
+/*!
+ * \brief Set to immediate scalar value kernel
+ * \tparam val Scalar immediate
+ */
+template
+struct set_to_bool : public tunable {
+ // mxnet_op version (when used directly with Kernel<>::Launch())
+ template
+ MSHADOW_XINLINE static void Map(index_t i, DType *out) {
+ out[i] = DType(val);
+ }
+ // mshadow_op version (when used with op_with_req<>)
+ MSHADOW_XINLINE static int Map() {
+ return val;
+ }
+};
+
+/*!
+ * \brief Special-case kernel shortcut for setting to true and false
+ */
+using set_true = set_to_bool;
+using set_false = set_to_bool;
} // namespace mxnet_op
} // namespace op
diff --git a/src/operator/numpy/linalg/np_tensorsolve-inl.h b/src/operator/numpy/linalg/np_tensorsolve-inl.h
new file mode 100644
index 000000000000..829a119b64a2
--- /dev/null
+++ b/src/operator/numpy/linalg/np_tensorsolve-inl.h
@@ -0,0 +1,557 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_tensorsolve-inl.h
+ * \brief Placeholder for tensor solve
+ */
+#ifndef MXNET_OPERATOR_NUMPY_LINALG_NP_TENSORSOLVE_INL_H_
+#define MXNET_OPERATOR_NUMPY_LINALG_NP_TENSORSOLVE_INL_H_
+
+#include
+#include
+#include "../../operator_common.h"
+#include "../../mshadow_op.h"
+#include "../../tensor/la_op.h"
+#include "../../tensor/la_op-inl.h"
+#include "../np_tensordot_op-inl.h"
+#include "./np_solve-inl.h"
+
+namespace mxnet {
+namespace op {
+
+using namespace mshadow;
+
+struct TensorsolveParam : public dmlc::Parameter {
+ mxnet::Tuple a_axes;
+ DMLC_DECLARE_PARAMETER(TensorsolveParam) {
+ DMLC_DECLARE_FIELD(a_axes)
+ .set_default(mxnet::Tuple())
+ .describe("Tuple of ints, optional. Axes in a to reorder to the right, before inversion.");
+ }
+};
+
+// Fix negative axes.
+inline void FixNegativeAxes(mxnet::Tuple *a_axes_param,
+ const mxnet::TShape& a_shape) {
+ if (-1 == a_axes_param->ndim()) { return; }
+ const int a_ndim = a_shape.ndim();
+ for (auto& i : *a_axes_param) {
+ i = (i + a_ndim) % a_ndim;
+ }
+}
+
+// Get the axes of a left after removing a_axes_param, and the reordered axes (remaining axes followed by a_axes_param).
+inline void GetReorderedAxes(const mxnet::Tuple& a_axes_param,
+ mxnet::Tuple *a_axes_remained,
+ mxnet::Tuple *a_axes,
+ const mxnet::TShape& a_shape) {
+ std::vector a_axes_vec;
+ for (int i = 0; i < a_shape.ndim(); ++i) {
+ a_axes_vec.push_back(i);
+ }
+ // Get the remaining axes and the reordered axes.
+ if (-1 == a_axes_param.ndim()) {
+ *a_axes_remained = mxnet::Tuple(a_axes_vec);
+ *a_axes = mxnet::Tuple(a_axes_vec);
+ return;
+ }
+ for (const auto& i : a_axes_param) {
+ a_axes_vec.erase(std::find(a_axes_vec.begin(), a_axes_vec.end(), i));
+ }
+ *a_axes_remained = mxnet::Tuple(a_axes_vec);
+
+ a_axes_vec.clear();
+ for (const auto& i : *a_axes_remained) {
+ a_axes_vec.push_back(i);
+ }
+ for (const auto& i : a_axes_param) {
+ a_axes_vec.push_back(i);
+ }
+ *a_axes = mxnet::Tuple(a_axes_vec);
+}
+
+// Calculate the output shape when a and b are both tensors.
+inline mxnet::TShape GetOutShape(const mxnet::TShape& a_shape,
+ const mxnet::TShape& b_shape) {
+ const int a_ndim = a_shape.ndim(), b_ndim = b_shape.ndim();
+ const int temp = a_ndim > b_ndim ? b_ndim : b_ndim - a_ndim;
+ mxnet::TShape out_shape(a_ndim - temp, -1);
+ for (int i = temp; i < a_ndim; ++i) {
+ out_shape[i - temp] = a_shape[i];
+ }
+ return out_shape;
+}
+
+// Calculates workspace size of tensorsolve forward.
+template
+size_t TensorsolveForwardWorkspaceSize(const Tuple& a_axes_param,
+ const TBlob& a,
+ const TBlob& b,
+ const TBlob& out,
+ const std::vector& req) {
+ if (kNullOp == req[0]) { return 0U; }
+
+ // Zero-size output, no need to launch kernel
+ if (0U == out.shape_.Size()) { return 0U; }
+
+ const mxnet::TShape& a_shape = a.shape_;
+ const mxnet::TShape& b_shape = b.shape_;
+ MSHADOW_SGL_DBL_TYPE_SWITCH(out.type_flag_, DType, {
+ if (0U == a_shape.Size() || 0U == b_shape.Size()) {
+ // 0-size input
+ return 0U;
+ } else if (0 == a_shape.ndim() || 0 == b_shape.ndim()) {
+ // At least 1 scalar.
+ return (a.Size() + b.Size()) * sizeof(DType) + b.Size() * sizeof(int);
+ } else {
+ // Two tensors of at least 1 dimension.
+ return (2 * a.Size() + b.Size()) * sizeof(DType) + b.Size() * sizeof(int);
+ }
+ });
+ LOG(FATAL) << "InternalError: cannot reach here";
+ return 0U;
+}
+
+template
+struct assign_helper {
+ template
+ MSHADOW_XINLINE static void Map(int i, const DType *in_data, DType *out_data) {
+ KERNEL_ASSIGN(out_data[i], req, in_data[i]);
+ }
+};
+
+struct tensorsolve {
+ template
+ static void op(const Tensor& A,
+ const Tensor& X,
+ const Tensor& ipiv,
+ const OpContext& ctx) {
+ mshadow::Stream *s = ctx.get_stream();
+ linalg_solve(A, X, ipiv, s); // ipiv for work_space in Lapacke_#gesv
+ }
+};
+
+template
+void TensorsolveOpForward(const nnvm::NodeAttrs& attrs,  // forward compute: outputs[0] = x for (a, b)
+ const OpContext& ctx,
+ const std::vector& inputs,
+ const std::vector& req,
+ const std::vector& outputs) {
+ CHECK_EQ(inputs.size(), 2U);
+ CHECK_EQ(outputs.size(), 1U);
+ CHECK_EQ(req.size(), 1U);
+
+ mshadow::Stream *s = ctx.get_stream();
+ const TBlob& a = inputs[0];
+ const TBlob& b = inputs[1];
+ const TBlob& out = outputs[0];
+ const mxnet::TShape a_shape = a.shape_;
+ const mxnet::TShape b_shape = b.shape_;
+ const mxnet::TShape out_shape = out.shape_;
+ const TensorsolveParam& param = nnvm::get(attrs.parsed);
+ mxnet::Tuple a_axes_param = param.a_axes;
+ FixNegativeAxes(&a_axes_param, a_shape);
+
+ size_t workspace_size = TensorsolveForwardWorkspaceSize(a_axes_param, a, b, out, req);
+ Tensor workspace = ctx.requested[0].get_space_typed(  // offsets into it below are in bytes
+ Shape1(workspace_size), ctx.get_stream());
+
+ if (kNullOp == req[0]) { return; }
+
+ // Zero-size output, no need to launch kernel
+ if (0U == out.shape_.Size()) { return; }
+
+ MSHADOW_SGL_DBL_TYPE_SWITCH(out.type_flag_, DType, {
+ if (0U == a_shape.Size() || 0U == b_shape.Size()) {  // 0-size input: zero-fill unless kAddTo
+ if (req[0] != kAddTo) {
+ Tensor out_tensor =
+ out.get_with_shape(Shape1(out.shape_.Size()), s);
+ out_tensor = static_cast(0);
+ }
+ } else if (0U == a_shape.ndim() || 0U == b_shape.ndim()) { // At least 1 scalar.
+ // With a scalar operand both a and b must hold exactly one element.
+ CHECK_EQ(a_shape.Size(), 1U)
+ << "a's and b's dimensions don't match";
+ CHECK_EQ(b_shape.Size(), 1U)
+ << "a's and b's dimensions don't match";
+
+ DType* a_ptr =  // workspace layout for this branch: [ a | b | ipiv ]
+ reinterpret_cast(workspace.dptr_);
+ DType* b_ptr =
+ reinterpret_cast(workspace.dptr_+ a.Size() * sizeof(DType));
+ int* ipiv_ptr =
+ reinterpret_cast(workspace.dptr_ + (a.Size() + b.Size()) * sizeof(DType));
+
+ // Copy a and b into the workspace, converting each from its own dtype to DType
+ MSHADOW_TYPE_SWITCH(a.type_flag_, AType, {
+ mxnet_op::Kernel::Launch(
+ s, a_shape.Size(), a_ptr, a.dptr());
+ });
+ MSHADOW_TYPE_SWITCH(b.type_flag_, BType, {
+ mxnet_op::Kernel::Launch(
+ s, b_shape.Size(), b_ptr, b.dptr());
+ });
+
+ mxnet::TBlob a_tblob(a_ptr, Shape2(1, 1), a.dev_mask(), a.dev_id());
+ mxnet::TBlob b_tblob(b_ptr, Shape2(1, 1), b.dev_mask(), b.dev_id());
+ mxnet::TBlob ipiv_tblob(ipiv_ptr, Shape1(1), out.dev_mask(), out.dev_id());
+ Tensor a_tensor = a_tblob.get(s);
+ Tensor b_tensor = b_tblob.get(s);
+ Tensor ipiv_tensor = ipiv_tblob.get(s);
+
+ // Solve linear equation
+ laop::op(a_tensor, b_tensor, ipiv_tensor, ctx);
+ MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {  // result is read back from the b buffer
+ mxnet_op::Kernel, xpu>::Launch(
+ s, out_shape.Size(), b_tensor.dptr_, out.dptr());
+ });
+ } else {
+ // Two tensors of at least 1 dimensions.
+ Tuple a_axes_remained;
+ Tuple a_axes;
+ GetReorderedAxes(a_axes_param, &a_axes_remained, &a_axes, a_shape);
+ mxnet::TShape a_transpose_shape = GetReorderedShape(a_shape, a_axes);
+ const int N = b_shape.Size();  // a is handled as an N x N matrix, b as N values
+
+ DType* a_ptr =  // workspace layout for this branch: [ a | a_trans | b | ipiv ]
+ reinterpret_cast(workspace.dptr_);
+ DType* a_trans_ptr =
+ reinterpret_cast(workspace.dptr_ + a.Size() * sizeof(DType));
+ DType* b_ptr =
+ reinterpret_cast(workspace.dptr_ + 2 * a.Size() * sizeof(DType));
+ int* ipiv_ptr =
+ reinterpret_cast(workspace.dptr_ + (2 * a.Size() + b.Size()) * sizeof(DType));
+
+ // Copy a into the workspace, converting from its own dtype to DType
+ MSHADOW_TYPE_SWITCH(a.type_flag_, AType, {
+ mxnet_op::Kernel::Launch(
+ s, a_shape.Size(), a_ptr, a.dptr());
+ });
+ // Copy b into the workspace, converting from its own dtype to DType
+ MSHADOW_TYPE_SWITCH(b.type_flag_, BType, {
+ mxnet_op::Kernel::Launch(
+ s, b_shape.Size(), b_ptr, b.dptr());
+ });
+
+ mxnet::TBlob a_tblob =
+ TBlob(a_ptr, a_shape, a.dev_mask(), a.dev_id());
+ mxnet::TBlob a_transpose_tblob =
+ TBlob(a_trans_ptr, a_transpose_shape, a.dev_mask(), a.dev_id());
+ mxnet::TBlob b_tblob =
+ TBlob(b_ptr, b_shape, b.dev_mask(), b.dev_id());
+ mxnet::TBlob ipiv_tblob =
+ TBlob(ipiv_ptr, b_shape, out.dev_mask(), out.dev_id());
+ mxnet::op::TransposeImpl(ctx.run_ctx,  // first pass: reorder a's axes according to a_axes
+ a_tblob, // src
+ a_transpose_tblob, // res
+ mxnet::TShape(a_axes.begin(), a_axes.end()));
+
+ Tensor a_tensor =
+ a_tblob.get_with_shape(Shape2(N, N), s);
+ Tensor ipiv_tensor =
+ ipiv_tblob.get_with_shape(Shape1(N), s);
+ Tensor b_tensor =
+ b_tblob.get_with_shape(Shape2(1, N), s);
+ Tensor out_tensor =
+ out.get_with_shape(Shape2(1, N), s);
+
+ a_tblob = a_tblob.reshape(Shape2(N, N));
+ a_transpose_tblob = a_transpose_tblob.reshape(Shape2(N, N));
+ Tuple a_axes_2D(std::vector{1, 0});
+ mxnet::op::TransposeImpl(ctx.run_ctx,  // second pass: 2-D transpose back into a's buffer
+ a_transpose_tblob, // src
+ a_tblob, // res
+ mxnet::TShape(a_axes_2D.begin(), a_axes_2D.end()));
+ // Solve linear equation
+ laop::op(a_tensor, b_tensor, ipiv_tensor, ctx);
+ MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {  // result is read back from the b buffer
+ mxnet_op::Kernel, xpu>::Launch(
+ s, out_shape.Size(), b_tensor.dptr_, out_tensor.dptr_);
+ });
+ }
+ });
+}
+
+// Calculates the workspace size, in bytes, needed by tensorsolve backward.
+// Returns 0 when a or b has no elements. The per-branch totals below mirror
+// the buffer layout carved out of the workspace in TensorsolveBackwardImpl.
+template
+size_t TensorsolveBackwardWorkspaceSize(const TBlob& out_grad,
+ const TBlob& a,
+ const TBlob& b,
+ const TBlob& x) {
+ const mxnet::TShape& a_shape = a.shape_;
+ const mxnet::TShape& b_shape = b.shape_;
+ const mxnet::TShape& x_shape = x.shape_;
+
+ // Zero-size input, no workspace is needed
+ if (0U == a_shape.Size() || 0U == b_shape.Size()) { return 0U; }
+
+ MSHADOW_SGL_DBL_TYPE_SWITCH(out_grad.type_flag_, DType, {
+ // Accumulate in size_t (the return type): an int byte counter can overflow
+ // once 3 * sizeof(DType) * a_shape.Size() exceeds INT_MAX.
+ size_t work_space_size = 0;
+ if (0U == a_shape.ndim() || 0U == b_shape.ndim()) {
+ // At least 1 scalar.
+ work_space_size += sizeof(DType) * a_shape.Size(); // for tensorinv(a)
+ work_space_size += sizeof(DType) * a_shape.Size(); // for getri work space lu
+ work_space_size += sizeof(int) * b_shape.Size(); // for getri work space pivot
+ } else {
+ // Two tensors of at least 1 dimensions.
+ work_space_size += sizeof(DType) * a_shape.Size(); // for tensorinv(a)
+ work_space_size += sizeof(DType) * a_shape.Size(); // for getri work space lu
+ work_space_size += sizeof(DType) * b_shape.Size(); // for b
+ work_space_size += sizeof(DType) * x_shape.Size(); // for x
+ work_space_size += sizeof(DType) * a_shape.Size(); // for grad_a
+ work_space_size += sizeof(DType) * b_shape.Size(); // for grad_b
+ work_space_size += sizeof(int) * b_shape.Size(); // for getri work space pivot
+ }
+ return work_space_size;
+ });
+ LOG(FATAL) << "InternalError: cannot reach here";
+ return 0U;
+}
+
+// Get original axes for tensor a: builds the inverse of a_axes (a_origin_axes[a_axes[i]] = i).
+inline void GetOriginAxes(const mxnet::TShape& a_shape,
+ const mxnet::Tuple& a_axes,
+ mxnet::Tuple *a_origin_axes) {
+ std::vector a_origin_axes_vec(a_shape.ndim(), -1);  // -1 marks slots not yet assigned
+ for (int i = 0; i < a_shape.ndim(); ++i) {  // assumes a_axes has a_shape.ndim() entries -- TODO confirm
+ a_origin_axes_vec[a_axes[i]] = i;
+ }
+ *a_origin_axes = mxnet::Tuple(a_origin_axes_vec);
+}
+
+struct tensorsolve_backward {  // functor computing dB and dA with two gemm2 calls
+ template
+ static void op(const Tensor& dX,
+ const Tensor& inv_A,
+ const Tensor& B,
+ const Tensor& X,
+ const Tensor& dA,
+ const Tensor& dB,
+ const OpContext& ctx) {
+ // (1) calculate dB = trans(tensorinv(A)) * dX
+ // (2) calculate dA = -dB * trans(X)  (sign presumably from the DType(-1) factor; B is not read)
+ Stream *s = ctx.get_stream();
+ gemm2::op(inv_A, dX, dB, DType(1), true, false, s);
+ gemm2::op(dB, X, dA, DType(-1), false, true, s);
+ }
+};
+
+template
+void TensorsolveBackwardImpl(const Tuple& a_axes_param,  // writes grad_a (req[0]) and grad_b (req[1])
+ const TBlob& out_grad,
+ const TBlob& a,
+ const TBlob& b,
+ const TBlob& x,
+ const TBlob& grad_a,
+ const TBlob& grad_b,
+ const OpContext& ctx,
+ const std::vector& req,
+ const Tensor& workspace) {
+ mshadow::Stream *s = ctx.get_stream();
+ const mxnet::TShape& a_shape = a.shape_;
+ const mxnet::TShape& b_shape = b.shape_;
+ const mxnet::TShape& x_shape = x.shape_;
+
+ if (kNullOp == req[0] && kNullOp == req[1]) { return; }  // neither gradient was requested
+
+ // Zero-size input, no need to launch kernel
+ if (0U == a_shape.Size() || 0U == b_shape.Size()) { return; }
+
+ MSHADOW_SGL_DBL_TYPE_SWITCH(out_grad.type_flag_, DType, {
+ if (0 == a_shape.ndim() || 0 == b_shape.ndim()) {
+ // At least 1 scalar.
+ CHECK_EQ(a_shape.Size(), 1U)
+ << "a's and b's dimensions don't match";
+ CHECK_EQ(b_shape.Size(), 1U)
+ << "a's and b's dimensions don't match";
+
+ // Carve buffers out of the workspace (byte offsets): [ tensorinv_a | lu | ipiv ]
+ DType *tensorinv_a_ptr = reinterpret_cast(workspace.dptr_);
+ DType *lu_ptr = reinterpret_cast(workspace.dptr_ + a_shape.Size() * sizeof(DType));
+ int *ipiv_ptr = reinterpret_cast(workspace.dptr_ + 2 * a_shape.Size() * sizeof(DType));
+ TBlob tensorinv_a(tensorinv_a_ptr, a_shape, xpu::kDevMask);
+ TBlob lu(lu_ptr, a_shape, xpu::kDevMask);
+ TBlob ipiv(ipiv_ptr, b_shape, xpu::kDevMask);
+
+ MSHADOW_TYPE_SWITCH(a.type_flag_, AType, {  // copy a into the tensorinv_a buffer as DType
+ mxnet_op::Kernel::Launch(
+ s, a_shape.Size(),
+ tensorinv_a_ptr,
+ a.dptr());
+ });
+ // Calculate tensorinv(a)
+ Tensor tensorinv_a_tensor =
+ tensorinv_a.get_with_shape(Shape3(1, 1, 1), s);
+ Tensor lu_tensor =
+ lu.get_with_shape(Shape3(1, 1, 1), s);
+ Tensor ipiv_tensor =
+ ipiv.get_with_shape(Shape2(1, 1), s);
+ batch_inverse(tensorinv_a_tensor, lu_tensor, ipiv_tensor, ctx);
+
+ MSHADOW_TYPE_SWITCH(x.type_flag_, XType, {  // NOTE(review): pointers are dereferenced on the host -- confirm for GPU
+ DType temp1 = (*(tensorinv_a_tensor.dptr_)) * (*(out_grad.dptr()));
+ DType temp2 = -temp1 * static_cast(*x.dptr());
+ ASSIGN_DISPATCH(*grad_b.dptr(), req[1], temp1);
+ ASSIGN_DISPATCH(*grad_a.dptr(), req[0], temp2);
+ });
+ } else {
+ // Two tensors of at least 1 dimensions.
+ const int N = b_shape.Size();
+ Tuple a_axes_remained;
+ Tuple a_axes;
+ Tuple a_origin_axes;
+ // Use a_axes to transpose (a_shape) --> (a_reordered_shape).
+ GetReorderedAxes(a_axes_param, &a_axes_remained, &a_axes, a_shape);
+ // Use a_origin_axes to transpose (a_reordered_shape) --> (a_shape).
+ GetOriginAxes(a_shape, a_axes, &a_origin_axes);
+ mxnet::TShape reordered_a_shape = GetReorderedShape(a_shape, a_axes);
+
+ // Carve buffers out of the workspace (byte offsets): [ tensorinv_a | lu | b | x | grad_a | grad_b | ipiv ]
+ DType *tensorinv_a_ptr = reinterpret_cast(
+ workspace.dptr_);
+ DType *lu_ptr = reinterpret_cast(
+ workspace.dptr_ + a_shape.Size() * sizeof(DType));
+ DType *b_ptr = reinterpret_cast(
+ workspace.dptr_ + 2 * a_shape.Size() * sizeof(DType));
+ DType *x_ptr = reinterpret_cast(
+ workspace.dptr_ + (2 * a_shape.Size() + b_shape.Size()) * sizeof(DType));
+ DType *grad_a_ptr = reinterpret_cast(  // offset assumes x holds b_shape.Size() elements -- TODO confirm
+ workspace.dptr_ + 2 * (a_shape.Size() + b_shape.Size()) * sizeof(DType));
+ DType *grad_b_ptr = reinterpret_cast(
+ workspace.dptr_ + (3 * a_shape.Size() + 2 * b_shape.Size()) * sizeof(DType));
+ int *ipiv_ptr = reinterpret_cast(
+ workspace.dptr_ + 3 * (a_shape.Size() + b_shape.Size()) * sizeof(DType));
+
+ TBlob tensorinv_a_data(tensorinv_a_ptr, a_shape, xpu::kDevMask);
+ TBlob lu_data(lu_ptr, a_shape, xpu::kDevMask);
+ TBlob b_data(b_ptr, b_shape, xpu::kDevMask);
+ TBlob x_data(x_ptr, x_shape, xpu::kDevMask);
+ TBlob grad_a_data(grad_a_ptr, reordered_a_shape, xpu::kDevMask);
+ TBlob grad_b_data(grad_b_ptr, b_shape, xpu::kDevMask);
+ TBlob ipiv_data(ipiv_ptr, b_shape, xpu::kDevMask);
+ MSHADOW_TYPE_SWITCH(a.type_flag_, AType, {  // a is staged in the lu buffer, then transposed below
+ mxnet_op::Kernel::Launch(
+ s, a_shape.Size(),
+ lu_ptr,
+ a.dptr());
+ });
+ MSHADOW_TYPE_SWITCH(b.type_flag_, BType, {  // copy b into the workspace as DType
+ mxnet_op::Kernel::Launch(
+ s, b_shape.Size(),
+ b_ptr,
+ b.dptr());
+ });
+ MSHADOW_TYPE_SWITCH(x.type_flag_, XType, {  // copy x into the workspace as DType
+ mxnet_op::Kernel::Launch(
+ s, x_shape.Size(),
+ x_ptr,
+ x.dptr());
+ });
+ // Eg: lu_data(2, 3, 2, 15, 4, 5) -> tensorinv_a_data(3, 4, 5, 15, 2, 2)
+ tensorinv_a_data = tensorinv_a_data.reshape(reordered_a_shape);
+ mxnet::op::TransposeImpl(ctx.run_ctx,
+ lu_data, // src
+ tensorinv_a_data, // res
+ mxnet::TShape(a_axes.begin(), a_axes.end()));
+
+ Tensor tensorinv_a_tensor =
+ tensorinv_a_data.get_with_shape(Shape3(1, N, N), s);
+ Tensor lu_tensor =
+ lu_data.get_with_shape(Shape3(1, N, N), s);
+ Tensor b_tensor =
+ b_data.get_with_shape(Shape3(1, N, 1), s);
+ Tensor x_tensor =
+ x_data.get_with_shape(Shape3(1, N, 1), s);
+ Tensor grad_a_tensor =
+ grad_a_data.get_with_shape(Shape3(1, N, N), s);
+ Tensor grad_b_tensor =
+ grad_b_data.get_with_shape(Shape3(1, N, 1), s);
+ Tensor ipiv_tensor =
+ ipiv_data.get_with_shape(Shape2(1, N), s);
+
+ // Calculate tensorinv(a).
+ batch_inverse(tensorinv_a_tensor, lu_tensor, ipiv_tensor, ctx);
+ // No need to transpose tensorinv_a
+ // because transpose(tensorinv_a).shape == reordered_a_shape.
+ laop::op(out_grad.get_with_shape(x_tensor.shape_, s),
+ tensorinv_a_tensor,
+ b_tensor,
+ x_tensor,
+ grad_a_tensor,
+ grad_b_tensor,
+ ctx);
+ // Eg: grad_a_src(3, 4, 5, 15, 2, 2) --> lu_data(2, 3, 2, 15, 4, 5)
+ mxnet::op::TransposeImpl(ctx.run_ctx,  // the lu buffer is reused to hold grad_a in a's axis order
+ grad_a_data, // src
+ lu_data, // res
+ mxnet::TShape(a_origin_axes.begin(), a_origin_axes.end()));
+
+ MXNET_ASSIGN_REQ_SWITCH(req[1], req_type, {  // grad_b output
+ mxnet_op::Kernel, xpu>::Launch(
+ s, b_shape.Size(), grad_b_tensor.dptr_, grad_b.dptr());
+ });
+ MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {  // grad_a output, read from the reused lu buffer
+ mxnet_op::Kernel, xpu>::Launch(
+ s, a_shape.Size(), lu_tensor.dptr_, grad_a.dptr());
+ });
+ }
+ });
+}
+
+template
+void TensorsolveOpBackward(const nnvm::NodeAttrs& attrs,  // backward compute: unpacks blobs, delegates to Impl
+ const OpContext& ctx,
+ const std::vector& inputs,
+ const std::vector& req,
+ const std::vector& outputs) {
+ using namespace mshadow;
+ CHECK_EQ(inputs.size(), 4U);
+ CHECK_EQ(outputs.size(), 2U);
+ CHECK_EQ(req.size(), 2U);
+
+ const TBlob& out_grad = inputs[0];  // inputs are read as: out_grad, a, b, x
+ const TBlob& a = inputs[1];
+ const TBlob& b = inputs[2];
+ const TBlob& x = inputs[3];
+ const TBlob& grad_a = outputs[0];
+ const TBlob& grad_b = outputs[1];
+ const mxnet::TShape a_shape = a.shape_;
+ const mxnet::TShape b_shape = b.shape_;  // NOTE(review): not used in this function
+ const TensorsolveParam& param = nnvm::get(attrs.parsed);
+ mxnet::Tuple a_axes_param = param.a_axes;
+ FixNegativeAxes(&a_axes_param, a_shape);
+
+ size_t workspace_size = TensorsolveBackwardWorkspaceSize(out_grad, a, b, x);
+ Tensor workspace =
+ ctx.requested[0].get_space_typed(Shape1(workspace_size),
+ ctx.get_stream());
+ TensorsolveBackwardImpl(a_axes_param,
+ out_grad,
+ a, b, x,
+ grad_a, grad_b,
+ ctx, req,
+ workspace);
+}
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_OPERATOR_NUMPY_LINALG_NP_TENSORSOLVE_INL_H_
diff --git a/src/operator/numpy/linalg/np_tensorsolve.cc b/src/operator/numpy/linalg/np_tensorsolve.cc
new file mode 100644
index 000000000000..1dabcdd0eac4
--- /dev/null
+++ b/src/operator/numpy/linalg/np_tensorsolve.cc
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_tensorsolve.cc
+ * \brief CPU implementation placeholder of Tensor Solve Operator
+ */
+#include "./np_tensorsolve-inl.h"
+
+namespace mxnet {
+namespace op {
+
+bool TensorsolveOpShape(const nnvm::NodeAttrs& attrs,  // shape inference: derives the output shape from a and b
+ mxnet::ShapeVector *in_attrs,
+ mxnet::ShapeVector *out_attrs) {
+ CHECK_EQ(in_attrs->size(), 2U);
+ CHECK_EQ(out_attrs->size(), 1U);
+
+ const mxnet::TShape& a_shape = in_attrs->at(0);
+ const mxnet::TShape& b_shape = in_attrs->at(1);
+ const int a_ndim = a_shape.ndim();
+ const int b_ndim = b_shape.ndim();
+
+ if (!ndim_is_known(a_shape) || !ndim_is_known(b_shape)) {  // cannot infer yet
+ return false;
+ }
+
+ if (0 == a_ndim && 0 == b_ndim) {
+ // a and b are both scalars
+ SHAPE_ASSIGN_CHECK(*out_attrs, 0, b_shape);
+ } else if (0 == a_ndim && 0 != b_ndim) {
+ // a is scalar, b is tensor
+ CHECK_EQ(b_shape.Size(), 1U)
+ << "a's and b's dimensions don't match";
+ SHAPE_ASSIGN_CHECK(*out_attrs, 0, a_shape);
+ } else if (0 != a_ndim && 0 == b_ndim) {
+ // a is tensor, b is scalar
+ CHECK_EQ(a_shape.Size(), 1U)
+ << "a's and b's dimensions don't match";
+ SHAPE_ASSIGN_CHECK(*out_attrs, 0, a_shape);
+ } else {
+ // a and b of at least 1 dimensions.
+ const TensorsolveParam& param = nnvm::get(attrs.parsed);
+ mxnet::Tuple a_axes_param = param.a_axes;
+ FixNegativeAxes(&a_axes_param, a_shape);
+
+ mxnet::Tuple a_axes_remained;
+ mxnet::Tuple a_axes;
+ GetReorderedAxes(a_axes_param, &a_axes_remained, &a_axes, a_shape);
+ mxnet::TShape a_transpose_shape = GetReorderedShape(a_shape, a_axes);
+
+ // Calculate output shape: the first `temp` reordered axes of a multiply into prod_front,
+ const int temp = a_ndim > b_ndim ? b_ndim : b_ndim - a_ndim;
+ int prod_front = 1, prod_back = 1;
+ mxnet::TShape out_shape(a_ndim - temp > 0 ? a_ndim - temp : 0, -1);
+ for (int i = 0; i < a_ndim; ++i) {  // the remaining axes multiply into prod_back and form out_shape
+ if (i < temp) {
+ prod_front *= a_transpose_shape[i];
+ } else {
+ prod_back *= a_transpose_shape[i];
+ out_shape[i - temp] = a_transpose_shape[i];
+ }
+ }
+ CHECK_EQ(prod_front, prod_back) << "a shape must be square.";
+ CHECK_EQ(prod_back, b_shape.Size()) << "a's and b's dimensions don't match";
+ SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape);
+ }
+
+ return shape_is_known(*in_attrs) && shape_is_known(*out_attrs);
+}
+
+// Type inference for tensorsolve: the output is float32 only when both inputs
+// are float32, otherwise float64. float16 inputs are rejected. Returns false
+// (inference not finished) while either input dtype is still unknown.
+inline bool TensorsolveOpType(const nnvm::NodeAttrs& attrs,
+ std::vector* in_attrs,
+ std::vector* out_attrs) {
+ CHECK_EQ(in_attrs->size(), 2U);
+ CHECK_EQ(out_attrs->size(), 1U);
+ int a_type = in_attrs->at(0);
+ int b_type = in_attrs->at(1);
+ // unsupport float16
+ CHECK_NE(a_type, mshadow::kFloat16)
+ << "array type float16 is unsupported in linalg";
+ CHECK_NE(b_type, mshadow::kFloat16)
+ << "array type float16 is unsupported in linalg";
+ // An input dtype of -1 means "not inferred yet". Do not fall into the float64
+ // branch below in that case: it would pin the output dtype too early and
+ // conflict once both inputs later resolve to float32.
+ if (-1 == a_type || -1 == b_type) { return false; }
+ if (mshadow::kFloat32 == a_type && mshadow::kFloat32 == b_type) {
+ TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
+ } else {
+ TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat64);
+ }
+ return out_attrs->at(0) != -1;
+}
+
+DMLC_REGISTER_PARAMETER(TensorsolveParam);
+
+NNVM_REGISTER_OP(_npi_tensorsolve)  // forward operator: inputs (a, b), one output
+.set_attr_parser(mxnet::op::ParamParser)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector{"a", "b"};
+ })
+.set_attr("FInferShape", TensorsolveOpShape)
+.set_attr("FInferType", TensorsolveOpType)
+.set_attr("FResourceRequest",  // one temp-space resource; FCompute reads it as ctx.requested[0]
+ [](const NodeAttrs& attrs) {
+ return std::vector(1, ResourceRequest::kTempSpace);
+ })
+.set_attr("THasDeterministicOutput", true)
+.set_attr("FCompute", TensorsolveOpForward)
+.set_attr("FGradient",  // gradient is computed by the backward operator registered below
+ mxnet::op::ElemwiseGradUseInOut{"_backward_npi_tensorsolve"})
+.add_argument("a", "NDArray-or-Symbol", "First input")
+.add_argument("b", "NDArray-or-Symbol", "Second input")
+.add_arguments(TensorsolveParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_npi_tensorsolve)  // backward operator: 4 inputs, 2 outputs
+.set_attr_parser(mxnet::op::ParamParser)
+.set_num_inputs(4)
+.set_num_outputs(2)
+.set_attr("FResourceRequest",  // temp space backing the backward workspace
+ [](const NodeAttrs& ){
+ return std::vector{1, ResourceRequest::kTempSpace};
+ })
+.set_attr("TIsBackward", true)
+.set_attr("FCompute", TensorsolveOpBackward);
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/numpy/linalg/np_tensorsolve.cu b/src/operator/numpy/linalg/np_tensorsolve.cu
new file mode 100644
index 000000000000..07e2121750d5
--- /dev/null
+++ b/src/operator/numpy/linalg/np_tensorsolve.cu
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_tensorsolve.cu
+ * \brief GPU implementation placeholder of Tensor Solve Operator
+ */
+
+#include
+#include "./np_tensorsolve-inl.h"
+
+namespace mxnet {
+namespace op {
+
+#if MXNET_USE_CUSOLVER == 1  // the FCompute attributes below are only set when this macro equals 1
+
+NNVM_REGISTER_OP(_npi_tensorsolve)
+.set_attr("FCompute", TensorsolveOpForward);
+
+NNVM_REGISTER_OP(_backward_npi_tensorsolve)
+.set_attr("FCompute", TensorsolveOpBackward);
+
+#endif
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h
index 7d0025a62ad2..0efe2c2aa3df 100644
--- a/src/operator/numpy/np_broadcast_reduce_op.h
+++ b/src/operator/numpy/np_broadcast_reduce_op.h
@@ -86,6 +86,21 @@ struct NumpyReduceAxesNoDTypeParam : public dmlc::Parameter {
+ dmlc::optional> axis;
+ bool keepdims;
+ DMLC_DECLARE_PARAMETER(NumpyReduceAxesBoolParam) {
+ DMLC_DECLARE_FIELD(axis)
+ .set_default(dmlc::optional>())
+ .describe("Axis or axes along which a sum is performed. The default, axis=None, will sum "
+ "all of the elements of the input array. If axis is negative it counts from the "
+ "last to the first axis.");
+ DMLC_DECLARE_FIELD(keepdims).set_default(false)
+ .describe("If this is set to `True`, the reduced axes are left "
+ "in the result as dimension with size one.");
+ }
+};
+
inline TShape NumpyReduceAxesShapeImpl(const TShape& ishape,
const dmlc::optional>& axis,
bool keepdims) {
@@ -173,6 +188,20 @@ inline bool NumpyReduceAxesShape(const nnvm::NodeAttrs& attrs,
return shape_is_known(out_attrs->at(0));
}
+inline bool NumpyReduceAxesBoolShape(const nnvm::NodeAttrs& attrs,  // shape inference for the boolean reductions
+ std::vector *in_attrs,
+ std::vector *out_attrs) {
+ CHECK_EQ(in_attrs->size(), 1U);
+ CHECK_EQ(out_attrs->size(), 1U);
+ if (!shape_is_known(in_attrs->at(0))) {  // input shape not known yet: cannot infer
+ return false;
+ }
+ const NumpyReduceAxesBoolParam& param = nnvm::get(attrs.parsed);
+ SHAPE_ASSIGN_CHECK(*out_attrs, 0,  // output shape comes from the shared reduce-axes shape helper
+ NumpyReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims));
+ return shape_is_known(out_attrs->at(0));
+}
+
inline bool NumpyReduceAxesNoDTypeShape(const nnvm::NodeAttrs& attrs,
std::vector *in_attrs,
std::vector *out_attrs) {
@@ -298,6 +327,30 @@ void NumpyReduceAxesNoDTypeCompute(const nnvm::NodeAttrs& attrs,
ReduceAxesComputeImpl(ctx, inputs, req, outputs, small);
}
+template
+void NumpyReduceAxesBoolCompute(const nnvm::NodeAttrs& attrs,  // compute function for the boolean reductions
+ const OpContext& ctx,
+ const std::vector& inputs,
+ const std::vector& req,
+ const std::vector& outputs) {
+ const NumpyReduceAxesBoolParam& param = nnvm::get(attrs.parsed);
+ mshadow::Stream* s = ctx.get_stream();
+ if (inputs[0].shape_.Size() == 0 && outputs[0].shape_.Size() != 0) {  // empty input, non-empty output
+ using namespace mxnet_op;
+ Kernel::Launch(s, outputs[0].shape_.Size(), outputs[0].dptr());  // fill kernel; its type is not visible here
+ return;
+ }
+ if (param.axis.has_value() && param.axis.value().ndim() == 0) {  // axis given as an empty tuple
+ UnaryOp::IdentityCompute(attrs, ctx, inputs, req, outputs);
+ }  // NOTE(review): no return after IdentityCompute, so the reduction below still runs -- confirm intent
+ TShape small;
+ if (param.keepdims) {
+ small = outputs[0].shape_;  // output already keeps the reduced axes
+ } else {
+ small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);  // shape with keepdims forced on
+ }
+ ReduceAxesComputeBoolImpl(ctx, inputs, req, outputs, small);
+}
template
inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/numpy/np_broadcast_reduce_op_boolean.cc b/src/operator/numpy/np_broadcast_reduce_op_boolean.cc
new file mode 100644
index 000000000000..7529c0d4e1d3
--- /dev/null
+++ b/src/operator/numpy/np_broadcast_reduce_op_boolean.cc
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_broadcast_reduce_op_boolean.cc
+ * \brief CPU Implementation of broadcast and reduce functions based on boolean.
+ */
+
+#include "./np_broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+
+inline bool NumpyReduceAxesBoolType(const nnvm::NodeAttrs& attrs,  // type inference: output is always kBool
+ std::vector *in_attrs,
+ std::vector *out_attrs) {
+ CHECK_EQ(in_attrs->size(), 1U);
+ CHECK_EQ(out_attrs->size(), 1U);
+ TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool);
+ return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;  // done only once the input dtype is known too
+}
+
+DMLC_REGISTER_PARAMETER(NumpyReduceAxesBoolParam);
+
+NNVM_REGISTER_OP(_np_any)  // one input ("data"), one output
+.set_attr_parser(ParamParser)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector{"data"};
+ })
+.set_attr("FResourceRequest",  // requests one temp-space resource
+ [](const NodeAttrs& attrs) {
+ return std::vector{ResourceRequest::kTempSpace};
+ })
+.set_attr("THasDeterministicOutput", true)
+.set_attr("FInferShape", NumpyReduceAxesBoolShape)
+.set_attr("FInferType", NumpyReduceAxesBoolType)
+.set_attr("FCompute", NumpyReduceAxesBoolCompute)
+.set_attr("FGradient", MakeZeroGradNodes)  // gradient nodes come from MakeZeroGradNodes
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
+.add_arguments(NumpyReduceAxesBoolParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_np_all)  // same attribute set as _np_any above
+.set_attr_parser(ParamParser)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector{"data"};
+ })
+.set_attr("FResourceRequest",  // requests one temp-space resource
+ [](const NodeAttrs& attrs) {
+ return std::vector{ResourceRequest::kTempSpace};
+ })
+.set_attr("THasDeterministicOutput", true)
+.set_attr("FInferShape", NumpyReduceAxesBoolShape)
+.set_attr("FInferType", NumpyReduceAxesBoolType)
+.set_attr("FCompute", NumpyReduceAxesBoolCompute)
+.set_attr("FGradient", MakeZeroGradNodes)  // gradient nodes come from MakeZeroGradNodes
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
+.add_arguments(NumpyReduceAxesBoolParam::__FIELDS__());
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/numpy/np_broadcast_reduce_op_boolean.cu b/src/operator/numpy/np_broadcast_reduce_op_boolean.cu
new file mode 100644
index 000000000000..2c206bf88b2f
--- /dev/null
+++ b/src/operator/numpy/np_broadcast_reduce_op_boolean.cu
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_broadcast_reduce_op_boolean.cu
+ * \brief GPU Implementation of broadcast and reduce functions based on boolean.
+ */
+
+#include "./np_broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_np_any)  // sets the compute function for _np_any in this .cu translation unit
+.set_attr("FCompute", NumpyReduceAxesBoolCompute);
+
+NNVM_REGISTER_OP(_np_all)  // sets the compute function for _np_all in this .cu translation unit
+.set_attr("FCompute", NumpyReduceAxesBoolCompute);
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc
index 2a1bc5261701..cf92da52d1f8 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cc
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -161,6 +161,7 @@ inline bool NumpyReduceAxesNoDTypeType(const nnvm::NodeAttrs& attrs,
}
NNVM_REGISTER_OP(_np_max)
+.add_alias("_np_amax")
.describe(R"code()code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
diff --git a/src/operator/numpy/np_delete_op-inl.h b/src/operator/numpy/np_delete_op-inl.h
new file mode 100644
index 000000000000..a144833f3294
--- /dev/null
+++ b/src/operator/numpy/np_delete_op-inl.h
@@ -0,0 +1,347 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_delete_op-inl.h
+ * \brief Function definition of delete operators
+ */
+#ifndef MXNET_OPERATOR_NUMPY_NP_DELETE_OP_INL_H_
+#define MXNET_OPERATOR_NUMPY_NP_DELETE_OP_INL_H_
+
+#include