[Relay][Op] Add unbiased variance op and corresponding support in pytorch frontend #6232

Merged: 1 commit, Aug 10, 2020
31 changes: 31 additions & 0 deletions include/tvm/relay/attrs/reduce.h
@@ -60,6 +60,37 @@ struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
         "Whether to perform reduction on axis that are NOT in axis instead.");
   }
 };
+
+struct VarianceAttrs : public tvm::AttrsNode<VarianceAttrs> {
+  Array<Integer> axis;
+  bool keepdims;
+  bool exclude;
+  bool unbiased;
+
+  TVM_DECLARE_ATTRS(VarianceAttrs, "relay.attrs.VarianceAttrs") {
+    TVM_ATTR_FIELD(axis)
+        .set_default(NullValue<Array<Integer>>())
+        .describe(R"code(The axis or axes along which to perform the reduction.
+
+      The default, `axis=()`, will compute over all elements into a
+      scalar array with shape `(1,)`.
+
+      If `axis` is int, a reduction is performed on a particular axis.
+
+      If `axis` is a tuple of ints, a reduction is performed on all the axes
+      specified in the tuple.
+
+      If `exclude` is true, reduction will be performed on the axes that are
+      NOT in axis instead.)code");
+
+    TVM_ATTR_FIELD(keepdims).set_default(false).describe(
+        "If this is set to `True`, the reduced axes are left "
+        "in the result as dimension with size one.");
+    TVM_ATTR_FIELD(exclude).set_default(false).describe(
+        "Whether to perform reduction on axis that are NOT in axis instead.");
+    TVM_ATTR_FIELD(unbiased).set_default(false).describe("Whether to use the unbiased estimation.");
+  }
+};
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_REDUCE_H_
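
For context, the new `unbiased` attribute selects Bessel's correction: the sum of squared deviations is divided by n - 1 instead of n. A minimal NumPy sketch of the two estimators (illustration only, not part of this diff):

    import numpy as np

    x = np.array([1.0, 2.0, 4.0, 7.0])
    n = x.size
    mean = x.mean()

    biased = np.sum((x - mean) ** 2) / n          # divide by n (unbiased=False)
    unbiased = np.sum((x - mean) ** 2) / (n - 1)  # Bessel's correction (unbiased=True)

    assert np.isclose(biased, np.var(x))            # NumPy defaults to ddof=0
    assert np.isclose(unbiased, np.var(x, ddof=1))  # ddof=1 is the unbiased estimator
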
25 changes: 10 additions & 15 deletions python/tvm/relay/frontend/pytorch.py
@@ -1262,28 +1262,23 @@ def _impl(inputs, input_types):
             keepdims = bool(inputs[3])
             unbiased = bool(inputs[2])
 
-        if unbiased:
-            msg = "Currently only supports standard-deviation calculated via the biased "\
-                  "estimator. PyTorch's Bessel's correction is not supported."
-            raise NotImplementedError(msg)
-
-        return _op.reduce.std(data, axis=axis, keepdims=keepdims)
+        return _op.reduce.std(data, axis=axis, keepdims=keepdims, unbiased=unbiased)
 
     return _impl
 
 
 def _variance():
     def _impl(inputs, input_types):
         data = inputs[0]
-        axis = list(_infer_shape(inputs[1]))
-        keepdims = bool(inputs[3])
-        unbiased = bool(inputs[2])
-
-        if unbiased:
-            msg = "Currently only supports standard-deviation calculated via the biased "\
-                  "estimator. PyTorch's Bessel's correction is not supported."
-            raise NotImplementedError(msg)
+        if len(inputs) == 2:
+            axis = None
+            keepdims = False
+            unbiased = bool(inputs[1])
+        else:
+            axis = list(_infer_shape(inputs[1]))
+            keepdims = bool(inputs[3])
+            unbiased = bool(inputs[2])
 
-        return _op.reduce.variance(data, axis=axis, keepdims=keepdims)
+        return _op.reduce.variance(data, axis=axis, keepdims=keepdims, unbiased=unbiased)
 
     return _impl
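
The new `len(inputs) == 2` branch handles the overload of `aten::var` that carries no `dim`/`keepdim` arguments, as produced by a full-tensor reduction. Roughly, the two call forms map to the traced overloads like this (a sketch; the schemas come from PyTorch, not from this diff):

    import torch

    x = torch.rand(1, 3, 8, 8)

    # aten::var(Tensor self, bool unbiased): two traced inputs,
    # so the converter implies axis=None, keepdims=False
    full = x.var(unbiased=True)

    # aten::var(Tensor self, int[] dim, bool unbiased, bool keepdim): four inputs
    partial = x.var(dim=(2, 3), unbiased=True, keepdim=False)
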
15 changes: 11 additions & 4 deletions python/tvm/relay/op/_tensor_grad.py
@@ -589,16 +589,23 @@ def mean_grad(orig, grad):
 def variance_grad(orig, grad):
     """Note that we take mean as an argument in the variance node"""
     data, data_mean, axis = orig.args[0], orig.args[1], _get_reduce_axis(orig)
+    unbiased = orig.attrs.unbiased
     shape = data.checked_type.concrete_shape
     if axis is None:
         axis = list(range(len(data.checked_type.concrete_shape)))
     if not orig.attrs.keepdims:
         grad = _unreduce_expand(grad, axis)
-    mult = 2.0
+    mult1 = 2.0
+    mult2 = -2.0
+    count = 1
     for a in axis:
-        mult /= shape[a]
-    return [(grad * const(mult, dtype=data.checked_type.dtype)) * data,
-            const(-2, dtype=data.checked_type.dtype) * grad * data_mean]
+        count *= shape[a]
+    if unbiased:
+        mult2 = mult2 * count / (count - 1)
+        count -= 1
+    mult1 /= count
+    return [(grad * const(mult1, dtype=data.checked_type.dtype)) * data,
+            const(mult2, dtype=data.checked_type.dtype) * grad * data_mean]
 
 
 @register_gradient("copy")
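
For reference, the `mult1`/`mult2` constants follow from differentiating the variance with the mean treated as a separate input (a sketch of the algebra, with n the number of reduced elements):

    V = \frac{1}{n-1} \sum_i (x_i - \mu)^2, \qquad
    \frac{\partial V}{\partial x_i} = \frac{2}{n-1} (x_i - \mu), \qquad
    \frac{\partial V}{\partial \mu} = -\frac{2}{n-1} \sum_i (x_i - \mu)

So `mult1 = 2 / (n - 1)` scales the `data` term directly. The `mean` term receives `mult2 = -2n / (n - 1)`: `mean_grad` later divides the incoming gradient by n and broadcasts it, so the extra factor of n cancels and the two contributions sum to `2 * (x - mean) / (n - 1) * grad`. The biased case is recovered with n in place of n - 1, i.e. `mult1 = 2 / n` and `mult2 = -2`.
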
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -460,6 +460,11 @@ class ReduceAttrs(Attrs):
     """Attributes used in reduction operators (e.g. sum)"""
 
 
+@tvm._ffi.register_object("relay.attrs.VarianceAttrs")
+class VarianceAttrs(Attrs):
+    """Attributes used in the variance operator"""
+
+
 @tvm._ffi.register_object("relay.attrs.RequantizeAttrs")
 class RequantizeAttrs(Attrs):
     """Attributes used in requantize operators"""
18 changes: 12 additions & 6 deletions python/tvm/relay/op/reduce.py
@@ -312,7 +312,7 @@ def mean(data, axis=None, keepdims=False, exclude=False):
     return _make.mean(data, axis, keepdims, exclude)
 
 
-def variance(data, axis=None, keepdims=False, exclude=False):
+def variance(data, axis=None, keepdims=False, exclude=False, unbiased=False):
     """Computes the variance of data over given axes.
 
     Parameters
@@ -334,17 +334,20 @@ def variance(data, axis=None, keepdims=False, exclude=False):
         If `exclude` is true, reduction will be performed on the axes that are
         NOT in axis instead.
 
+    unbiased : bool
+        If this is set to True, the unbiased estimation will be used.
+
     Returns
     -------
     result : relay.Expr
         The computed result.
     """
     axis = [axis] if isinstance(axis, int) else axis
     m = mean(data, axis, True, exclude)
-    return _make._variance(data, m, axis, keepdims, exclude)
+    return _make._variance(data, m, axis, keepdims, exclude, unbiased)
 
 
-def std(data, axis=None, keepdims=False, exclude=False):
+def std(data, axis=None, keepdims=False, exclude=False, unbiased=False):
     """Computes the standard deviation of data over given axes.
 
     Parameters
@@ -366,14 +369,17 @@ def std(data, axis=None, keepdims=False, exclude=False):
         If `exclude` is true, reduction will be performed on the axes that are
         NOT in axis instead.
 
+    unbiased : bool
+        If this is set to True, the unbiased estimation will be used.
+
     Returns
     -------
     result : relay.Expr
         The computed result.
     """
     axis = [axis] if isinstance(axis, int) else axis
     m = mean(data, axis, True, exclude)
-    return sqrt(_make._variance(data, m, axis, keepdims, exclude))
+    return sqrt(_make._variance(data, m, axis, keepdims, exclude, unbiased))
 
 
 def mean_variance(data, axis=None, keepdims=False, exclude=False):
@@ -405,7 +411,7 @@ def mean_variance(data, axis=None, keepdims=False, exclude=False):
     """
     axis = [axis] if isinstance(axis, int) else axis
     m = mean(data, axis, True, exclude)
-    var = _make._variance(data, m, axis, keepdims, exclude)
+    var = _make._variance(data, m, axis, keepdims, exclude, False)
     if not keepdims:
         m = squeeze(m)
     return TupleWrapper(Tuple((m, var)), 2)
@@ -440,7 +446,7 @@ def mean_std(data, axis=None, keepdims=False, exclude=False):
     """
     axis = [axis] if isinstance(axis, int) else axis
     m = mean(data, axis, True, exclude)
-    s = sqrt(_make._variance(data, m, axis, keepdims, exclude))
+    s = sqrt(_make._variance(data, m, axis, keepdims, exclude, False))
     if not keepdims:
         m = squeeze(m)
     return TupleWrapper(Tuple((m, s)), 2)
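
A quick end-to-end check of the new keyword against NumPy's `ddof=1` (a sketch with assumed shapes, separate from the test suite below):

    import numpy as np
    from tvm import relay

    x = relay.var("x", shape=(2, 5), dtype="float32")
    f = relay.Function([x], relay.variance(x, axis=1, unbiased=True))

    data = np.random.rand(2, 5).astype("float32")
    out = relay.create_executor().evaluate(f)(data).asnumpy()

    # unbiased=True divides by n - 1, matching NumPy's ddof=1
    np.testing.assert_allclose(out, data.var(axis=1, ddof=1), rtol=1e-5)
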
3 changes: 2 additions & 1 deletion src/relay/op/make_op.h
@@ -74,7 +74,8 @@ Expr MakeTile(Expr data, Array<Integer> reps);
 
 Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype);
 
-Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude);
+Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude,
+                  bool unbiased);
 
 Expr MakeZeros(Array<Integer> shape, DataType dtype);
31 changes: 22 additions & 9 deletions src/relay/op/tensor/reduce.cc
@@ -38,6 +38,7 @@ namespace tvm {
 namespace relay {
 
 TVM_REGISTER_NODE_TYPE(ReduceAttrs);
+TVM_REGISTER_NODE_TYPE(VarianceAttrs);
 
 /*!
  * \brief GetReduceAxes, get the new axis from indim and other arguments
@@ -193,12 +194,14 @@ Array<te::Tensor> ReduceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
 /*!
  * \brief ReduceShapeImpl get the outshape for the reduction operator
  * \param in_shape Shape of input data.
- * \param param ReduceAttrs details.
+ * \param param Attrs details.
  * \param reporter The reporter to report solution to.
  * \return oshape Output shape inferred.
+ * \tparam AttrsType The attribute type.
 */
+template <typename AttrsType>
 inline std::vector<IndexExpr> ReduceShapeImpl(const std::vector<IndexExpr>& in_shape,
-                                              const ReduceAttrs* param,
+                                              const AttrsType* param,
                                               const TypeReporter& reporter) {
   uint32_t indim = in_shape.size();
   auto r_axes = GetReduceAxes(indim, param->axis, param->exclude);
@@ -542,7 +545,7 @@ bool VarianceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   std::vector<IndexExpr> mean_shape(mean->shape.begin(), mean->shape.end());
   CHECK_EQ(in_shape.size(), mean_shape.size());
 
-  const ReduceAttrs* param = attrs.as<ReduceAttrs>();
+  const VarianceAttrs* param = attrs.as<VarianceAttrs>();
   CHECK(param != nullptr);
 
   // assign output type and shape
@@ -554,39 +557,49 @@ Array<te::Tensor> VarianceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                   const Type& out_type) {
   IndexExpr count = tir::make_const(inputs[0]->dtype, 1);
-  const ReduceAttrs* param = attrs.as<ReduceAttrs>();
+  const VarianceAttrs* param = attrs.as<VarianceAttrs>();
   CHECK(param != nullptr);
+  auto axes = param->axis;
+  bool unbiased = param->unbiased;
   auto data = inputs[0];
   auto mean = inputs[1];
   for (int64_t i : GetReduceAxes(data->shape.size(), param->axis, param->exclude)) {
     count *= data->shape[i];
   }
+  if (unbiased) {
+    count -= 1;
+  }
   std::vector<Integer> expand_shape;
   auto sq_diff = topi::power(topi::subtract(data, mean), 2);
-  auto var = topi::divide(ReduceCompute(attrs, {sq_diff}, out_type, topi::sum)[0], count);
+  if (param->exclude) {
+    axes = GetExcludeAxes(sq_diff->shape.size(), param->axis);
+    CHECK_NE(axes.size(), 0);
+  }
+  auto var = topi::divide(topi::sum(sq_diff, axes, param->keepdims, false), count);
 
   return {var};
 }
 
-Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude) {
-  auto attrs = make_object<ReduceAttrs>();
+Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude,
+                  bool unbiased = false) {
+  auto attrs = make_object<VarianceAttrs>();
   attrs->axis = std::move(axis);
   attrs->keepdims = keepdims;
   attrs->exclude = exclude;
+  attrs->unbiased = unbiased;
   static const Op& op = Op::Get("variance");
   return Call(op, {data, mean}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op._make._variance").set_body([](const TVMArgs& args, TVMRetValue* rv) {
-  runtime::detail::unpack_call<Expr, 5>(MakeVariance, args, rv);
+  runtime::detail::unpack_call<Expr, 6>(MakeVariance, args, rv);
 });
 
 RELAY_REGISTER_OP("variance")
     .describe(R"code(Computes the variance of array elements over given axes.
 
 )code" TVM_ADD_FILELINE)
-    .set_attrs_type<ReduceAttrs>()
+    .set_attrs_type<VarianceAttrs>()
     .set_support_level(4)
     .set_num_inputs(2)
     .add_argument("data", "Tensor", "The input tensor.")
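
The divisor is the product of the reduced extents, minus one when `unbiased` is set; with `exclude=True` the reduced axes are the complement of `axis`, hence the `GetExcludeAxes` call before the final `topi::sum`. The same bookkeeping as a small Python sketch (`reduce_count` is a hypothetical helper, for illustration only):

    def reduce_count(shape, axis, exclude=False, unbiased=False):
        # Axes actually reduced: `axis`, or its complement when exclude is set.
        axes = set(range(len(shape))) if not axis else {a % len(shape) for a in axis}
        if exclude:
            axes = set(range(len(shape))) - axes
        count = 1
        for a in axes:
            count *= shape[a]
        return count - 1 if unbiased else count

    assert reduce_count((2, 3, 4), axis=(1, 2)) == 12
    assert reduce_count((2, 3, 4), axis=(1, 2), exclude=True) == 2    # only axis 0 reduced
    assert reduce_count((2, 3, 4), axis=(1, 2), unbiased=True) == 11  # Bessel's correction
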
5 changes: 3 additions & 2 deletions src/relay/transforms/pattern_util.h
@@ -580,8 +580,9 @@ inline Expr Mean(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
   return MakeReduce(data, axis, keepdims, exclude, "mean");
 }
 
-inline Expr Variance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude) {
-  return MakeVariance(data, mean, axis, keepdims, exclude);
+inline Expr Variance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude,
+                     bool unbiased = false) {
+  return MakeVariance(data, mean, axis, keepdims, exclude, unbiased);
 }
 
 static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) {
35 changes: 35 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -1873,13 +1873,28 @@ class Std6(Module):
         def forward(self, *args):
             return args[0].std(unbiased=False)
 
+    class Std7(Module):
+        def forward(self, *args):
+            return args[0].std(dim=1, keepdim=False, unbiased=True)
+
+    class Std8(Module):
+        def forward(self, *args):
+            return args[0].std(dim=(2,3), keepdim=True, unbiased=True)
+
+    class Std9(Module):
+        def forward(self, *args):
+            return args[0].std(unbiased=True)
+
     input_data = torch.rand(input_shape).float()
     verify_model(Std1().float().eval(), input_data=input_data)
     verify_model(Std2().float().eval(), input_data=input_data)
     verify_model(Std3().float().eval(), input_data=input_data)
     verify_model(Std4().float().eval(), input_data=input_data)
     verify_model(Std5().float().eval(), input_data=input_data)
     verify_model(Std6().float().eval(), input_data=input_data)
+    verify_model(Std7().float().eval(), input_data=input_data)
+    verify_model(Std8().float().eval(), input_data=input_data)
+    verify_model(Std9().float().eval(), input_data=input_data)
 
 
 def test_forward_variance():
@@ -1906,12 +1921,32 @@ class Variance5(Module):
         def forward(self, *args):
             return args[0].var(dim=(2,3), keepdim=False, unbiased=False)
 
+    class Variance6(Module):
+        def forward(self, *args):
+            return args[0].var(unbiased=False)
+
+    class Variance7(Module):
+        def forward(self, *args):
+            return args[0].var(dim=1, keepdim=False, unbiased=True)
+
+    class Variance8(Module):
+        def forward(self, *args):
+            return args[0].var(dim=(2,3), keepdim=True, unbiased=True)
+
+    class Variance9(Module):
+        def forward(self, *args):
+            return args[0].var(unbiased=True)
+
     input_data = torch.rand(input_shape).float()
     verify_model(Variance1().float().eval(), input_data=input_data)
     verify_model(Variance2().float().eval(), input_data=input_data)
     verify_model(Variance3().float().eval(), input_data=input_data)
     verify_model(Variance4().float().eval(), input_data=input_data)
     verify_model(Variance5().float().eval(), input_data=input_data)
+    verify_model(Variance6().float().eval(), input_data=input_data)
+    verify_model(Variance7().float().eval(), input_data=input_data)
+    verify_model(Variance8().float().eval(), input_data=input_data)
+    verify_model(Variance9().float().eval(), input_data=input_data)
 
 
 def test_forward_rsub():
5 changes: 4 additions & 1 deletion tests/python/relay/test_op_grad_level4.py
@@ -26,7 +26,10 @@ def verify_reduction_grad(red_fn, d_shape, axis=None, keepdims=False, exclude=False):
 
 
 def test_reduction_grad():
-    for op in (relay.sum, relay.variance, relay.mean):
+    def _unbiased_variance(x, axis=None, keepdims=False, exclude=False):
+        return relay.variance(x, axis=axis, keepdims=keepdims, exclude=exclude, unbiased=True)
+
+    for op in (relay.sum, relay.variance, _unbiased_variance, relay.mean):
         verify_reduction_grad(op, (4, 2))
         verify_reduction_grad(op, (4, 2), axis=-1, keepdims=True)
         verify_reduction_grad(op, (4, 2, 1), axis=(1, 2), exclude=True)
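
The `_unbiased_variance` case exercises the `mult1`/`mult2` path from `variance_grad` above. The same gradient can be sanity-checked outside TVM with central differences (a standalone sketch, assuming a full reduction):

    import numpy as np

    def unbiased_var(x):
        return np.sum((x - x.mean()) ** 2) / (x.size - 1)

    x = np.random.rand(4, 2)
    analytic = 2.0 * (x - x.mean()) / (x.size - 1)  # expected d var / d x

    numeric = np.zeros_like(x)
    eps = 1e-5
    for idx in np.ndindex(*x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        numeric[idx] = (unbiased_var(xp) - unbiased_var(xm)) / (2 * eps)

    np.testing.assert_allclose(numeric, analytic, rtol=1e-4, atol=1e-6)
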
12 changes: 12 additions & 0 deletions tests/python/relay/test_op_level4.py
@@ -225,14 +225,26 @@ def _np_log_sum_exp(x, axis, keepdims=False):
         if not keepdims:
             x = np.squeeze(x, axis=axis)
         return x
 
+    def _unbiased_relay_wrapper(f):
+        def _unbiased_func(x, axis=None, keepdims=False, exclude=False):
+            return f(x, axis=axis, keepdims=keepdims, exclude=exclude, unbiased=True)
+        return _unbiased_func
+
+    def _unbiased_np_wrapper(f):
+        def _unbiased_func(a, axis=None, dtype=None, keepdims=None):
+            return f(a, axis=axis, dtype=dtype, ddof=1, keepdims=keepdims)
+        return _unbiased_func
+
     d1, d2, d3, d4 = te.var("d1"), te.var("d2"), te.var("d3"), te.var("d4")
     for func in [[relay.sum, np.sum],
                  [relay.max, np.max],
                  [relay.min, np.min],
                  [relay.mean, np.mean],
                  [relay.variance, np.var],
+                 [_unbiased_relay_wrapper(relay.variance), _unbiased_np_wrapper(np.var)],
                  [relay.std, np.std],
+                 [_unbiased_relay_wrapper(relay.std), _unbiased_np_wrapper(np.std)],
                  [relay.prod, np.prod],
                  [relay.all, np.all],
                  [relay.any, np.any],