Add a contrib operator for Constant #15993

Open
wants to merge 8 commits into
base: master
31 changes: 31 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -1417,6 +1417,37 @@ def convert_floor(node, **kwargs):
"""
return create_basic_op_node('Floor', node, kwargs)


@mx_op.register("_contrib_constant")
def convert_constant(node, **kwargs):
    """Map MXNet's contrib.constant operator attributes to onnx's Constant
    operator and return the created node.
    """
    name, _, attrs = get_inputs(node, kwargs)

    # Parse the stored tuple of values and map the requested dtype to the
    # corresponding ONNX tensor element type.
    value = convert_string_to_list(attrs["value"])
    dtype = attrs.get('dtype', 'int64')
    data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]
    dims = np.shape(value)

    tensor_name = "constant_attr_tensor" + str(kwargs["idx"])

    constant_node = onnx.helper.make_node(
        "Constant",
        [],
        [name],
        value=onnx.helper.make_tensor(
            name=tensor_name,
            data_type=data_type,
            dims=dims,
            vals=value,
            raw=False,
        ),
        name=name
    )

    return [constant_node]
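
For reference, the following is a minimal standalone sketch (not part of this diff) of the Constant node the converter above emits for value=(-1, 2). The tensor and node names are illustrative, and onnx.mapping is assumed to be available (onnx 1.x):

import numpy as np
import onnx
from onnx import helper, numpy_helper

value = [-1, 2]
tensor = helper.make_tensor(
    name="constant_attr_tensor0",   # mirrors "constant_attr_tensor" + str(idx)
    data_type=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')],
    dims=np.shape(value),
    vals=value,
    raw=False,
)
node = helper.make_node("Constant", [], ["const0"], value=tensor, name="const0")
# Round-trip the embedded tensor to confirm the payload survives intact.
assert (numpy_helper.to_array(node.attribute[0].t) == np.array(value)).all()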

# Changing shape and type.
@mx_op.register("Reshape")
def convert_reshape(node, **kwargs):
5 changes: 3 additions & 2 deletions python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
@@ -23,7 +23,8 @@
from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan
from ._op_translations import softplus, shape, gather, lp_pooling, size
from ._op_translations import ceil, floor, hardsigmoid, global_lppooling
-from ._op_translations import concat, hardmax, topk
+from ._op_translations import concat, hardmax, topk, constant
from ._op_translations import leaky_relu, _elu, _prelu, _selu, softmax, fully_connected
from ._op_translations import global_avgpooling, global_maxpooling, linalg_gemm
from ._op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm
@@ -43,7 +44,7 @@
# defined in the op_translations module.
_convert_map = {
    # Generator Functions
-    'Constant' : identity,
+    'Constant' : constant,
    'RandomUniform' : random_uniform,
    'RandomNormal' : random_normal,
    'RandomUniformLike' : random_uniform,
24 changes: 24 additions & 0 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -27,6 +27,30 @@ def identity(attrs, inputs, proto_obj):
"""Returns the identity function of the input."""
return 'identity', attrs, inputs


def constant(attrs, inputs, proto_obj):
    """Creates a symbol holding the constant tensor stored in the 'value' attribute."""
    try:
        from onnx import numpy_helper
        from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
    except ImportError:
        raise ImportError("Onnx and protobuf need to be installed. "
                          "Instructions to install - https://github.com/onnx/onnx")
    data = attrs['value']
    # Default to TensorProto.FLOAT (enum value 1) when 'value' is already a tuple.
    dtype = 1
    if not isinstance(data, tuple):
        # 'value' is an ONNX TensorProto: unpack it into a Python tuple or scalar.
        val = numpy_helper.to_array(data)
        val = tuple(val) if data.dims else val
        dtype = int(data.data_type)
    else:
        val = data

    dtype = TENSOR_TYPE_TO_NP_TYPE[dtype]
    sym = symbol.contrib.constant(value=val, dtype=dtype)
    new_attrs = translation_utils._remove_attributes(attrs, ['value'])
    return sym, new_attrs, inputs
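
As a hedged sketch of what this importer handles, the TensorProto below is built with numpy_helper rather than parsed from a model file, but it exercises the same unpacking path:

import numpy as np
from onnx import numpy_helper
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE

data = numpy_helper.from_array(np.array([-1, 2], dtype='int64'))
val = numpy_helper.to_array(data)
val = tuple(val) if data.dims else val                # (-1, 2)
dtype = TENSOR_TYPE_TO_NP_TYPE[int(data.data_type)]   # dtype('int64')
# val and dtype are exactly the arguments handed to symbol.contrib.constant.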


def random_uniform(attrs, inputs, proto_obj):
    """Draw random samples from a uniform distribution."""
    try:
91 changes: 91 additions & 0 deletions src/operator/contrib/constant-inl.h
@@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file constant-inl.h
*/

#ifndef MXNET_OPERATOR_CONTRIB_CONSTANT_INL_H_
#define MXNET_OPERATOR_CONTRIB_CONSTANT_INL_H_

#include <vector>
#include "../mxnet_op.h"
#include "../mshadow_op.h"
#include "../tensor/matrix_op-inl.h"

namespace mxnet {
namespace op {

struct ConstantParam : public dmlc::Parameter<ConstantParam> {
  mxnet::Tuple<float> value;
  int dtype;
  DMLC_DECLARE_PARAMETER(ConstantParam) {
    DMLC_DECLARE_FIELD(value)
    .set_default({1.0f, 1.0f})
    .describe("The target value");
    DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
    MXNET_ADD_ALL_TYPES
    .describe("Target data type.");
  }
};

inline bool ConstantShape(const nnvm::NodeAttrs& attrs,
                          mxnet::ShapeVector *in_attrs,
                          mxnet::ShapeVector *out_attrs) {
  const ConstantParam& param_ = nnvm::get<ConstantParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 0U);
  CHECK_EQ(out_attrs->size(), 1U);
  // Output is 1-D with one slot per element of `value`.
  const int out_size = param_.value.ndim();
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({static_cast<nnvm::dim_t>(out_size)}));
  return true;
}

struct constant {
  template<typename DType>
  MSHADOW_XINLINE static void Map(index_t i, const mxnet::Tuple<float>& value,
                                  int n, int req, DType* out) {
    // Each kernel instance writes exactly one output element.
    if (i < n) {
      KERNEL_ASSIGN(out[i], req, value[i]);
    }
  }
};

template<typename xpu, typename ParamType>
void ConstantForward(const nnvm::NodeAttrs& attrs,
                     const OpContext& ctx,
                     const std::vector<TBlob>& inputs,
                     const std::vector<OpReqType>& req,
                     const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  Stream<xpu> *s = ctx.get_stream<xpu>();
  const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    Kernel<constant, xpu>::Launch(s,
                                  outputs[0].Size(),
                                  param.value,
                                  param.value.ndim(),
                                  req[0],
                                  outputs[0].dptr<DType>());
  });
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_CONTRIB_CONSTANT_INL_H_
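
ConstantShape always infers a 1-D output with one slot per element of value, and the kernel fills each slot with the corresponding stored float. A quick hedged check from Python, assuming the operator is compiled into the build:

import mxnet as mx

sym = mx.sym.contrib.constant(value=(-5.5, 10.2, 3.7), dtype='float32')
_, out_shapes, _ = sym.infer_shape()
_, out_types, _ = sym.infer_type()
# Expected under the inference functions above:
#   out_shapes == [(3,)]   (1-D, length equals the number of values)
#   out_types  == [numpy.float32]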
65 changes: 65 additions & 0 deletions src/operator/contrib/constant.cc
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file constant.cc
*/

#include "./constant-inl.h"
#include "../tensor/elemwise_binary_op.h"
#include "../elemwise_op_common.h"

namespace mxnet {
namespace op {

inline bool ConstantType(const nnvm::NodeAttrs& attrs,
                         std::vector<int> *in_attrs,
                         std::vector<int> *out_attrs) {
  CHECK_EQ(out_attrs->size(), 1U);
  const ConstantParam& param_ = nnvm::get<ConstantParam>(attrs.parsed);
  TYPE_ASSIGN_CHECK(*out_attrs, 0, param_.dtype);
  return true;
}



DMLC_REGISTER_PARAMETER(ConstantParam);
NNVM_REGISTER_OP(_contrib_constant)
.describe(R"code(Creates a constant tensor holding the given value.
Example::

  v1 = (-1, 2)
  constant_op = symbol.contrib.constant(value=v1)
  executor = constant_op.simple_bind(ctx=cpu())
  executor.forward(is_train=True)
  executor.outputs
  [ -1.  2.]

Contributor: Here's the most confusing part. I assume v1 is the shape of the constant output, but where is the output value specified? In other words, why is the output [-1, 2]?

Contributor Author: v1 is the value we want to store. Since this operator is a constant, the output is equal to the stored value ([-1, 2]).

)code" ADD_FILELINE)
.set_num_inputs(0)
.set_num_outputs(1)
.set_attr_parser(ParamParser<ConstantParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ConstantShape)
.set_attr<nnvm::FInferType>("FInferType", ConstantType)
.set_attr<FCompute>("FCompute<cpu>", ConstantForward<cpu, ConstantParam>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_arguments(ConstantParam::__FIELDS__());


} // namespace op
} // namespace mxnet
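
Because the operator is registered through NNVM, it should also surface under the imperative front end; a hedged sketch of that usage:

import mxnet as mx

# NNVM-registered ops are typically exposed as mx.nd.contrib.<name> too.
out = mx.nd.contrib.constant(value=(-1, 2), dtype='float32')
print(out.asnumpy())   # expected: [-1.  2.]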
4 changes: 3 additions & 1 deletion tests/python-pytest/onnx/test_node.py
@@ -274,7 +274,9 @@ def test_exports(self):
("test_random_normal", mx.sym.random_normal, "RandomNormal", [],
{'shape': (2, 2), 'loc': 0, 'scale': 1}, False, {'modify': {'loc': 'mean'}}, False, True),
("test_random_uniform", mx.sym.random_uniform, "RandomUniform", [],
{'shape': (2, 2), 'low': 0.5, 'high': 1.0}, False, {}, False, True)
{'shape': (2, 2), 'low': 0.5, 'high': 1.0}, False, {}, False, True),
("test_constant", mx.sym.contrib.constant, "Constant", [],
{'value': (-1, 2)}, False, {}, False, True)
]

test_scalar_ops = ['Add', 'Sub', 'rSub', 'Mul', 'Div', 'Pow']
25 changes: 25 additions & 0 deletions tests/python/unittest/test_contrib_operator.py
@@ -22,6 +22,7 @@
import random
import itertools
from numpy.testing import assert_allclose, assert_array_equal
from common import with_seed
from mxnet.test_utils import *
from common import with_seed, assert_raises_cudnn_not_satisfied
import unittest
@@ -334,6 +335,7 @@ def test_multibox_prior_op():
    boxes = Y.reshape((h, w, 5, 4))
    assert_allclose(boxes.asnumpy()[250, 250, 0, :], np.array([-0.948249, 0.362671, 1.636436, 0.530377]), atol=1e-5, rtol=1e-5)


def test_box_encode_op():
    anchors = mx.nd.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]).reshape((1, -1, 4))
    refs = mx.nd.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]).reshape((1, -1, 4))
@@ -445,6 +447,29 @@ def test_modulated_deformable_convolution():
    rtol, atol = 0.05, 1e-3


@with_seed()
def test_constant():
    def constant_testcases(value, dtype='float32'):
        net = mx.sym.contrib.constant(value=value, dtype=dtype)
        js = net.tojson()
        net = mx.sym.load_json(js)
        exe = net.bind(default_context(), {})
        exe.forward(is_train=True)
        assert_almost_equal(exe.outputs[0].asnumpy(), value, rtol=1e-3, atol=1e-3)
        exe.backward()

    test_cases = [
        [(-1, 2)],
        [9216],
        [(3, 5, 1, -1)],
        [()],
        [(2, 3, 5, 5), 'int64'],
        [(-5.5, 10.2, 3.7), 'float32']]

    for test_case in test_cases:
        constant_testcases(*test_case)


if __name__ == '__main__':
    import nose
    nose.runmodule()