forked from apache/tvm
Commit
Showing 15 changed files with 665 additions and 258 deletions.
This file was deleted.
@@ -0,0 +1,156 @@
/*!
 * Copyright (c) 2017 by Contributors
 * \file elemwise.cc
 * \brief Elementwise operators
 */
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/top/tensor.h>
#include "../op_common.h"
#include "../elemwise_op_common.h"

namespace nnvm {
namespace top {
// sigmoid
NNVM_REGISTER_ELEMWISE_UNARY_OP(sigmoid)
.describe(R"code(Computes sigmoid.

.. math::
   y = 1 / (1 + exp(-x))

)code" NNVM_ADD_FILELINE)
.set_support_level(1);

// tanh
NNVM_REGISTER_ELEMWISE_UNARY_OP(tanh)
.describe(R"code(Returns the hyperbolic tangent of the input array, computed element-wise.

.. math::
   tanh(x) = sinh(x) / cosh(x)

)code" NNVM_ADD_FILELINE)
.set_support_level(1);

// exp
NNVM_REGISTER_ELEMWISE_UNARY_OP(exp)
.describe(R"code(Returns the exponential of the input array, computed element-wise.

.. math::
   exp(x)

)code" NNVM_ADD_FILELINE)
.set_support_level(1);

// log
NNVM_REGISTER_ELEMWISE_UNARY_OP(log)
.describe(R"code(Returns the natural logarithm of the input array, computed element-wise.

.. math::
   log(x)

)code" NNVM_ADD_FILELINE)
.set_support_level(1);

// binary ops

NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_add)
.describe(R"code(Element-wise addition
)code" NNVM_ADD_FILELINE)
.set_support_level(1);

NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_sub)
.describe(R"code(Element-wise subtraction
)code" NNVM_ADD_FILELINE)
.set_support_level(1);

NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_mul)
.describe(R"code(Element-wise multiplication
)code" NNVM_ADD_FILELINE)
.set_support_level(1);

NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_div)
.describe(R"code(Element-wise division
)code" NNVM_ADD_FILELINE)
.set_support_level(1);

// negative
NNVM_REGISTER_ELEMWISE_UNARY_OP(negative)
.describe(R"code(Element-wise numeric negation
)code" NNVM_ADD_FILELINE)
.set_support_level(3);

// copy
NNVM_REGISTER_ELEMWISE_UNARY_OP(copy)
.describe(R"code(Copies the input tensor to the output.
)code" NNVM_ADD_FILELINE)
.set_support_level(3);

// unary scalar ops
DMLC_REGISTER_PARAMETER(ScalarParam);

NNVM_REGISTER_ELEMWISE_UNARY_OP(__add_scalar__)
.describe(R"code(Tensor plus scalar
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.set_support_level(3);

NNVM_REGISTER_ELEMWISE_UNARY_OP(__sub_scalar__)
.describe(R"code(Tensor minus scalar
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.set_support_level(3);

NNVM_REGISTER_ELEMWISE_UNARY_OP(__rsub_scalar__)
.describe(R"code(Scalar minus tensor
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.set_support_level(3);

NNVM_REGISTER_ELEMWISE_UNARY_OP(__mul_scalar__)
.describe(R"code(Tensor multiplied by scalar
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.set_support_level(3);

NNVM_REGISTER_ELEMWISE_UNARY_OP(__div_scalar__)
.describe(R"code(Tensor divided by scalar
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.set_support_level(3);

NNVM_REGISTER_ELEMWISE_UNARY_OP(__rdiv_scalar__)
.describe(R"code(Scalar divided by tensor
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.set_support_level(3);

NNVM_REGISTER_ELEMWISE_UNARY_OP(__pow_scalar__)
.describe(R"code(Tensor raised to the power of a scalar
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.set_support_level(3);

NNVM_REGISTER_ELEMWISE_UNARY_OP(__rpow_scalar__)
.describe(R"code(Scalar raised to the power of a tensor
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.set_support_level(3);

}  // namespace top
}  // namespace nnvm
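For orientation, a minimal sketch of how these registrations surface in the Python frontend, assuming (as the tests later in this commit suggest) that each registered op gets a generated nnvm.symbol function and that the __*_scalar__ ops back the overloaded arithmetic operators on symbols:

import nnvm.symbol as sym

x = sym.Variable("x")
y = sym.sigmoid(x)           # NNVM_REGISTER_ELEMWISE_UNARY_OP(sigmoid)
z = sym.elemwise_add(y, x)   # NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_add)
w = x * 2.0 + 1.0            # lowers to __mul_scalar__ / __add_scalar__
print(w.list_input_names())  # ['x']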
@@ -0,0 +1,319 @@
/*!
 * Copyright (c) 2017 by Contributors
 * \file transform.cc
 * \brief Injective transformation of shape or type.
 */
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/top/tensor.h>
#include "../op_common.h"
#include "../elemwise_op_common.h"

namespace nnvm {
namespace top {

// flatten
inline bool FlattenInferShape(const NodeAttrs& attrs,
                              std::vector<TShape>* in_attrs,
                              std::vector<TShape>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]";
  CHECK_EQ(out_attrs->size(), 1U);
  const TShape& dshape = (*in_attrs)[0];
  if (dshape.ndim() == 0) return false;
  // Collapse all dimensions after the first into one.
  uint32_t target_dim = 1;
  for (uint32_t i = 1; i < dshape.ndim(); ++i) {
    target_dim *= dshape[i];
  }
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0,
                           TShape({dshape[0], target_dim}));
  return true;
}

NNVM_REGISTER_OP(flatten)
.describe(R"code(Flattens the input array into a 2-D array by collapsing the higher dimensions.

For an input array with shape ``(d1, d2, ..., dk)``, the `flatten` operation
reshapes the input array into an output array of shape ``(d1, d2*...*dk)``.

Example::

  x = [[
      [1,2,3],
      [4,5,6],
      [7,8,9]
  ],
  [   [1,2,3],
      [4,5,6],
      [7,8,9]
  ]],

  flatten(x) = [[ 1., 2., 3., 4., 5., 6., 7., 8., 9.],
                [ 1., 2., 3., 4., 5., 6., 7., 8., 9.]]

)code" NNVM_ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", FlattenInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.add_argument("data", "Tensor", "Input data.")
.set_support_level(1);

// concatenate
DMLC_REGISTER_PARAMETER(ConcatenateParam);

inline bool ConcatenateInferShape(const NodeAttrs& attrs,
                                  std::vector<TShape>* in_shape,
                                  std::vector<TShape>* out_shape) {
  const ConcatenateParam& param = nnvm::get<ConcatenateParam>(attrs.parsed);
  TShape dshape;
  dim_t size = 0;
  bool has_zero = false;
  for (size_t i = 0; i < in_shape->size(); ++i) {
    TShape tmp = (*in_shape)[i];
    if (tmp.ndim()) {
      CHECK_LT(static_cast<dim_t>(param.axis), tmp.ndim())
          << "concat dim " << param.axis << " out of range of input shape " << tmp;
      has_zero = tmp[param.axis] == 0 || has_zero;
      size += tmp[param.axis];
      tmp[param.axis] = 0;
      shape_assign(&dshape, tmp);
    }
  }

  TShape tmp = (*out_shape)[0];
  if (tmp.ndim()) {
    CHECK_LT(static_cast<dim_t>(param.axis), tmp.ndim())
        << "concat dim " << param.axis << " out of range of input shape " << tmp;
    tmp[param.axis] = 0;
    shape_assign(&dshape, tmp);
  }

  if (dshape.ndim() == 0) return false;

  for (size_t i = 0; i < in_shape->size(); ++i) {
    NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, i, dshape);
  }

  if (!has_zero) dshape[param.axis] = size;
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, dshape);
  return dshape.Size() != 0;
}

NNVM_REGISTER_OP(concatenate)
.describe(R"code(Joins input arrays along a given axis.

The dimensions of the input arrays should be the same except along the axis
on which they will be concatenated.

The dimension of the output array along the concatenated axis will be equal
to the sum of the corresponding dimensions of the input arrays.

Example::

  x = [[1,1],[2,2]]
  y = [[3,3],[4,4],[5,5]]
  z = [[6,6],[7,7],[8,8]]

  concatenate(x,y,z,dim=0) = [[ 1., 1.],
                              [ 2., 2.],
                              [ 3., 3.],
                              [ 4., 4.],
                              [ 5., 5.],
                              [ 6., 6.],
                              [ 7., 7.],
                              [ 8., 8.]]

  Note that you cannot concatenate x, y, z along dimension 1, since
  dimension 0 is not the same for all the input arrays.

  concatenate(y,z,dim=1) = [[ 3., 3., 6., 6.],
                            [ 4., 4., 7., 7.],
                            [ 5., 5., 8., 8.]]

)code" NNVM_ADD_FILELINE)
.set_num_outputs(1)
.set_num_inputs(kVarg)
.set_attr_parser(ParamParser<ConcatenateParam>)
.add_argument("data", "Tensor-or-Tensor[]", "List of arrays to concatenate")
.set_attr<FInferShape>("FInferShape", ConcatenateInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.add_arguments(ConcatenateParam::__FIELDS__())
.set_support_level(1);

// cast
DMLC_REGISTER_PARAMETER(CastParam);

inline bool CastInferType(const NodeAttrs& attrs,
                          std::vector<int>* in_attrs,
                          std::vector<int>* out_attrs) {
  const CastParam& param = nnvm::get<CastParam>(attrs.parsed);
  CHECK_EQ(out_attrs->size(), 1U);
  NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_attrs, 0, param.dtype);
  return true;
}

NNVM_REGISTER_OP(cast)
.describe(R"code(Casts the content of the input to the given dtype.
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Input data array")
.set_attr_parser(ParamParser<CastParam>)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", CastInferType)
.add_arguments(CastParam::__FIELDS__())
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(1);

// reshape
DMLC_REGISTER_PARAMETER(ReshapeParam);

inline bool ReshapeInferShape(const NodeAttrs& attrs,
                              std::vector<TShape>* in_attrs,
                              std::vector<TShape>* out_attrs) {
  const ReshapeParam& param = nnvm::get<ReshapeParam>(attrs.parsed);
  CHECK_GT(param.shape.ndim(), 0);
  CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]";
  CHECK_EQ(out_attrs->size(), 1U);

  const TShape& dshape = (*in_attrs)[0];
  if (dshape.ndim() == 0) return false;

  const Tuple<int64_t>& target_shape = param.shape;
  std::vector<int64_t> oshape;
  dim_t src_idx = 0;
  int infer_idx = -1;

  for (dim_t i = 0; i < target_shape.ndim(); ++i) {
    int svalue = target_shape[i];
    // special flag handling for shape inference.
    if (svalue > 0) {
      oshape.push_back(svalue);
      ++src_idx;
    } else if (svalue == 0) {
      // keep the same dimension as the source
      CHECK_LT(src_idx, dshape.ndim());
      oshape.push_back(dshape[src_idx++]);
    } else if (svalue == -1) {
      // infer this dimension from the rest
      CHECK_LT(infer_idx, 0)
          << "One and only one dim can be inferred";
      infer_idx = i;
      oshape.push_back(1);
      ++src_idx;
    } else if (svalue == -2) {
      // copy all remaining dims from the source
      while (src_idx < dshape.ndim()) {
        oshape.push_back(dshape[src_idx++]);
      }
    } else if (svalue == -3) {
      // merge two consecutive dims from the source
      CHECK_LT(src_idx + 1, dshape.ndim());
      dim_t d1 = dshape[src_idx++];
      dim_t d2 = dshape[src_idx++];
      oshape.push_back(d1 * d2);
    } else if (svalue == -4) {
      // split the source dim into two dims;
      // read the left dim and then the right dim (either can be -1)
      CHECK_LT(i + 2, target_shape.ndim());
      CHECK_LT(src_idx, dshape.ndim());
      dim_t d0 = dshape[src_idx++];
      int d1 = target_shape[++i];
      int d2 = target_shape[++i];
      CHECK(d1 != -1 || d2 != -1) << "Split dims cannot both be -1.";
      if (d1 == -1) d1 = d0 / d2;
      if (d2 == -1) d2 = d0 / d1;
      CHECK_EQ(d1 * d2, static_cast<int>(d0))
          << "Split dims " << d1 << ", " << d2
          << " do not divide original dim " << d0;
      oshape.push_back(d1);
      oshape.push_back(d2);
    }
  }

  if (infer_idx >= 0) {
    if (dshape.Size() > 0) {
      int new_size = 1;
      for (int x : oshape) {
        new_size *= x;
      }
      oshape[infer_idx] = dshape.Size() / new_size;
    } else {
      oshape[infer_idx] = 0;
    }
  }
  TShape out_shape(oshape.begin(), oshape.end());
  CHECK_EQ(out_shape.Size(), dshape.Size())
      << "Target shape size is different from the source. "
      << "Target: " << out_shape
      << "\nSource: " << dshape;
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, out_shape);
  return true;
}

NNVM_REGISTER_OP(reshape)
.describe(R"code(Reshapes the input array.

Given an array and a shape, this function returns a copy of the array in the new shape.
The shape is a tuple of integers such as (2,3,4). The size of the new shape must be
the same as the size of the input array.

Example::

  reshape([1,2,3,4], shape=(2,2)) = [[1,2], [3,4]]

For convenience, so that manual shape inference is not needed, some dimensions
of the shape can take special values from the set {0, -1, -2, -3, -4}.
The significance of each is explained below:

- ``0`` copies this dimension from the input to the output shape.

  Example::

  - input shape = (2,3,4), shape = (4,0,2), output shape = (4,3,2)
  - input shape = (2,3,4), shape = (2,0,0), output shape = (2,3,4)

- ``-1`` infers the dimension of the output shape from the remainder of the
  input dimensions, keeping the size of the new array the same as that of the
  input array. At most one dimension of shape can be -1.

  Example::

  - input shape = (2,3,4), shape = (6,1,-1), output shape = (6,1,4)
  - input shape = (2,3,4), shape = (3,-1,8), output shape = (3,1,8)
  - input shape = (2,3,4), shape = (-1,), output shape = (24,)

- ``-2`` copies all/the remainder of the input dimensions to the output shape.

  Example::

  - input shape = (2,3,4), shape = (-2,), output shape = (2,3,4)
  - input shape = (2,3,4), shape = (2,-2), output shape = (2,3,4)
  - input shape = (2,3,4), shape = (-2,1,1), output shape = (2,3,4,1,1)

- ``-3`` uses the product of two consecutive dimensions of the input shape as
  the output dimension.

  Example::

  - input shape = (2,3,4), shape = (-3,4), output shape = (6,4)
  - input shape = (2,3,4,5), shape = (-3,-3), output shape = (6,20)
  - input shape = (2,3,4), shape = (0,-3), output shape = (2,12)
  - input shape = (2,3,4), shape = (-3,-2), output shape = (6,4)

- ``-4`` splits one dimension of the input into the two dimensions passed
  after -4 in the shape (which can contain -1).

  Example::

  - input shape = (2,3,4), shape = (-4,1,2,-2), output shape = (1,2,3,4)
  - input shape = (2,3,4), shape = (2,-4,-1,3,-2), output shape = (2,1,3,4)

)code" NNVM_ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<ReshapeParam>)
.set_attr<FInferShape>("FInferShape", ReshapeInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.add_argument("data", "Tensor", "Input data.")
.set_support_level(3);

}  // namespace top
}  // namespace nnvm
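To make the -4 branch of ReshapeInferShape above concrete, here is a small stand-alone Python rendering of just the split logic (split_dim is a hypothetical helper for illustration, not part of the nnvm API), applied to the docstring's example:

def split_dim(d0, d1, d2):
    # either requested dim may be -1, but not both (mirrors the CHECK above)
    assert d1 != -1 or d2 != -1, "Split dims cannot both be -1."
    if d1 == -1:
        d1 = d0 // d2
    if d2 == -1:
        d2 = d0 // d1
    assert d1 * d2 == d0, "Split dims do not divide the original dim"
    return d1, d2

# reshape input (2,3,4) with shape=(-4,1,2,-2):
# -4 splits the leading dim 2 into (1, 2), then -2 copies (3, 4),
# giving output shape (1, 2, 3, 4).
print(split_dim(2, 1, 2))  # (1, 2)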
File renamed without changes.
@@ -0,0 +1,82 @@
import json
import nnvm.symbol as sym
import nnvm.graph as graph

def infer_shape(sym):
    g = graph.create(sym)
    g._set_json_attr("shape_attr_key", "shape")
    g = g.apply("InferShape")
    jgraph = json.loads(g.apply("SaveJSON").json_attr("json"))
    jnodes = jgraph["nodes"]
    jnode_row_ptr = jgraph["node_row_ptr"]
    sdict = {}
    vshape = g.json_attr("shape")
    for i, n in enumerate(jnodes):
        begin, end = jnode_row_ptr[i], jnode_row_ptr[i + 1]
        sdict[n["name"]] = vshape[begin:end]
    return sdict

# Level 1
def test_dense():
    x = sym.Variable("x", shape=(10, 20))
    y = sym.dense(x, units=30, name="fc")
    sdict = infer_shape(y)
    assert(sdict["fc"][0] == [10, 30])
    assert(sdict["fc_bias"][0] == [30])


def test_concatenate():
    x1 = sym.Variable("x", shape=(10, 20))
    x2 = sym.Variable("y", shape=(10, 30))
    z = sym.concatenate(x1, x2, name="concat")
    sdict = infer_shape(z)
    assert(sdict["concat"][0] == [10, 50])
    z = sym.concatenate(x1, x1, axis=0, name="concat")
    sdict = infer_shape(z)
    assert(sdict["concat"][0] == [20, 20])


def test_batchnorm():
    x = sym.Variable("x", shape=(10, 20))
    y = sym.batch_norm(1 / x, name="bn")
    sdict = infer_shape(y)
    assert(sdict["bn_gamma"][0] == [20])


def test_flatten():
    x = sym.Variable("x", shape=(10, 20, 10))
    y = sym.flatten(x) * 2
    y = sym.exp(y, name="y")
    sdict = infer_shape(y)
    assert(sdict["y"][0] == [10, 200])

# Level 3
def test_reshape():
    def check(in_shape, tshape, out_shape):
        x = sym.Variable("x", shape=in_shape)
        y = sym.reshape(x, shape=tshape, name="y")
        sdict = infer_shape(y)
        assert(tuple(sdict["y"][0]) == tuple(out_shape))

    check((4,), (2, 2), (2, 2))
    check((2, 3, 4), (4, 0, 2), (4, 3, 2))
    check((2, 3, 4), (2, 0, 0), (2, 3, 4))
    check((2, 3, 4), (6, 1, -1), (6, 1, 4))
    check((2, 3, 4), (3, -1, 8), (3, 1, 8))
    check((2, 3, 4), (-1,), (24,))
    check((2, 3, 4), (-2,), (2, 3, 4))
    check((2, 3, 4), (2, -2), (2, 3, 4))
    check((2, 3, 4), (-2, 1, 1), (2, 3, 4, 1, 1))
    check((2, 3, 4), (-3, 4), (6, 4))
    check((2, 3, 4, 5), (-3, -3), (6, 20))
    check((2, 3, 4), (0, -3), (2, 12))
    check((2, 3, 4), (-3, -2), (6, 4))
    check((2, 3, 4), (-4, 1, 2, -2), (1, 2, 3, 4))
    check((2, 3, 4), (2, -4, -1, 3, -2), (2, 1, 3, 4))

if __name__ == "__main__":
    test_dense()
    test_concatenate()
    test_batchnorm()
    test_flatten()
    test_reshape()
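The infer_shape helper above relies on node_row_ptr: node i's output shapes occupy the slice vshape[node_row_ptr[i]:node_row_ptr[i + 1]] of the flat shape vector, so multi-output nodes simply span a wider range. A tiny illustration with made-up values (not taken from a real graph):

jnode_row_ptr = [0, 1, 2, 4]   # three nodes; the last one has two outputs
vshape = [[10, 20], [20, 30], [10, 30], [30]]
for i in range(len(jnode_row_ptr) - 1):
    begin, end = jnode_row_ptr[i], jnode_row_ptr[i + 1]
    print("node", i, "->", vshape[begin:end])
# node 0 -> [[10, 20]]
# node 1 -> [[20, 30]]
# node 2 -> [[10, 30], [30]]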
File renamed without changes.
File renamed without changes.
@@ -0,0 +1,23 @@
import nnvm.symbol as sym

def test_reshape():
    x = sym.Variable("x")
    y = sym.reshape(x, shape=(10, 20), name="y")
    assert(y.list_input_names() == ["x"])


def test_scalar_op():
    x = sym.Variable("x")
    y = (1 / (x * 2) - 1) ** 2
    assert(y.list_input_names() == ["x"])

def test_leaky_relu():
    x = sym.Variable("x")
    y = sym.leaky_relu(x, alpha=0.1)
    assert(y.list_input_names() == ["x"])


if __name__ == "__main__":
    test_scalar_op()
    test_reshape()
    test_leaky_relu()
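test_scalar_op exercises the __*_scalar__ registrations from elemwise.cc through operator overloading. A hedged sketch of how one could confirm which ops the expression lowers to, reusing the SaveJSON route from test_infer_shape.py (the "op" field name and the exact op list here are assumptions, not verified output):

import json
import nnvm.symbol as sym
import nnvm.graph as graph

x = sym.Variable("x")
y = (1 / (x * 2) - 1) ** 2
g = graph.create(y)
jgraph = json.loads(g.apply("SaveJSON").json_attr("json"))
# print the non-variable ops in topological order
print([n["op"] for n in jgraph["nodes"] if n["op"] != "null"])
# expected something like:
# ['__mul_scalar__', '__rdiv_scalar__', '__sub_scalar__', '__pow_scalar__']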