diff --git a/include/tvm/relay/attrs/reduce.h b/include/tvm/relay/attrs/reduce.h index f57c1f4ddc58..14b75ff1c0a8 100644 --- a/include/tvm/relay/attrs/reduce.h +++ b/include/tvm/relay/attrs/reduce.h @@ -60,6 +60,37 @@ struct ReduceAttrs : public tvm::AttrsNode { "Whether to perform reduction on axis that are NOT in axis instead."); } }; + +struct VarianceAttrs : public tvm::AttrsNode { + Array axis; + bool keepdims; + bool exclude; + bool unbiased; + + TVM_DECLARE_ATTRS(VarianceAttrs, "relay.attrs.VarianceAttrs") { + TVM_ATTR_FIELD(axis) + .set_default(NullValue>()) + .describe(R"code(The axis or axes along which to perform the reduction. + + The default, `axis=()`, will compute over all elements into a + scalar array with shape `(1,)`. + + If `axis` is int, a reduction is performed on a particular axis. + + If `axis` is a tuple of ints, a reduction is performed on all the axes + specified in the tuple. + + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead.)code"); + + TVM_ATTR_FIELD(keepdims).set_default(false).describe( + "If this is set to `True`, the reduced axes are left " + "in the result as dimension with size one."); + TVM_ATTR_FIELD(exclude).set_default(false).describe( + "Whether to perform reduction on axis that are NOT in axis instead."); + TVM_ATTR_FIELD(unbiased).set_default(false).describe("Whether to use the unbiased estimation."); + } +}; } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_REDUCE_H_ diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index bbc684ea8a4c..a1cabcd5ae22 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1262,28 +1262,23 @@ def _impl(inputs, input_types): keepdims = bool(inputs[3]) unbiased = bool(inputs[2]) - if unbiased: - msg = "Currently only supports standard-deviation calculated via the biased "\ - "estimator. PyTorch's Bessel's correction is not supported." - raise NotImplementedError(msg) - - return _op.reduce.std(data, axis=axis, keepdims=keepdims) + return _op.reduce.std(data, axis=axis, keepdims=keepdims, unbiased=unbiased) return _impl def _variance(): def _impl(inputs, input_types): data = inputs[0] - axis = list(_infer_shape(inputs[1])) - keepdims = bool(inputs[3]) - unbiased = bool(inputs[2]) - - if unbiased: - msg = "Currently only supports standard-deviation calculated via the biased "\ - "estimator. PyTorch's Bessel's correction is not supported." - raise NotImplementedError(msg) + if len(inputs) == 2: + axis = None + keepdims = False + unbiased = bool(inputs[1]) + else: + axis = list(_infer_shape(inputs[1])) + keepdims = bool(inputs[3]) + unbiased = bool(inputs[2]) - return _op.reduce.variance(data, axis=axis, keepdims=keepdims) + return _op.reduce.variance(data, axis=axis, keepdims=keepdims, unbiased=unbiased) return _impl diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index aee860392723..46a45354a9cc 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -589,16 +589,23 @@ def mean_grad(orig, grad): def variance_grad(orig, grad): """Note that we take mean as an argument in the variance node""" data, data_mean, axis = orig.args[0], orig.args[1], _get_reduce_axis(orig) + unbiased = orig.attrs.unbiased shape = data.checked_type.concrete_shape if axis is None: axis = list(range(len(data.checked_type.concrete_shape))) if not orig.attrs.keepdims: grad = _unreduce_expand(grad, axis) - mult = 2.0 + mult1 = 2.0 + mult2 = -2.0 + count = 1 for a in axis: - mult /= shape[a] - return [(grad * const(mult, dtype=data.checked_type.dtype)) * data, - const(-2, dtype=data.checked_type.dtype) * grad * data_mean] + count *= shape[a] + if unbiased: + mult2 = mult2 * count / (count - 1) + count -= 1 + mult1 /= count + return [(grad * const(mult1, dtype=data.checked_type.dtype)) * data, + const(mult2, dtype=data.checked_type.dtype) * grad * data_mean] @register_gradient("copy") diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 32540a56491b..7f9198994974 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -460,6 +460,11 @@ class ReduceAttrs(Attrs): """Attributes used in reduction operators (e.g. sum)""" +@tvm._ffi.register_object("relay.attrs.VarianceAttrs") +class VarianceAttrs(Attrs): + """Attributes used in reduction operators (e.g. sum)""" + + @tvm._ffi.register_object("relay.attrs.RequantizeAttrs") class RequantizeAttrs(Attrs): """Attributes used in requantize operators""" diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index 988c94928d33..99189f8fabac 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -312,7 +312,7 @@ def mean(data, axis=None, keepdims=False, exclude=False): return _make.mean(data, axis, keepdims, exclude) -def variance(data, axis=None, keepdims=False, exclude=False): +def variance(data, axis=None, keepdims=False, exclude=False, unbiased=False): """Computes the variance of data over given axes. Parameters @@ -334,6 +334,9 @@ def variance(data, axis=None, keepdims=False, exclude=False): If `exclude` is true, reduction will be performed on the axes that are NOT in axis instead. + unbiased : bool + If this is set to True, the unbiased estimation will be used. + Returns ------- result : relay.Expr @@ -341,10 +344,10 @@ def variance(data, axis=None, keepdims=False, exclude=False): """ axis = [axis] if isinstance(axis, int) else axis m = mean(data, axis, True, exclude) - return _make._variance(data, m, axis, keepdims, exclude) + return _make._variance(data, m, axis, keepdims, exclude, unbiased) -def std(data, axis=None, keepdims=False, exclude=False): +def std(data, axis=None, keepdims=False, exclude=False, unbiased=False): """Computes the standard deviation of data over given axes. Parameters @@ -366,6 +369,9 @@ def std(data, axis=None, keepdims=False, exclude=False): If `exclude` is true, reduction will be performed on the axes that are NOT in axis instead. + unbiased : bool + If this is set to True, the unbiased estimation will be used. + Returns ------- result : relay.Expr @@ -373,7 +379,7 @@ def std(data, axis=None, keepdims=False, exclude=False): """ axis = [axis] if isinstance(axis, int) else axis m = mean(data, axis, True, exclude) - return sqrt(_make._variance(data, m, axis, keepdims, exclude)) + return sqrt(_make._variance(data, m, axis, keepdims, exclude, unbiased)) def mean_variance(data, axis=None, keepdims=False, exclude=False): @@ -405,7 +411,7 @@ def mean_variance(data, axis=None, keepdims=False, exclude=False): """ axis = [axis] if isinstance(axis, int) else axis m = mean(data, axis, True, exclude) - var = _make._variance(data, m, axis, keepdims, exclude) + var = _make._variance(data, m, axis, keepdims, exclude, False) if not keepdims: m = squeeze(m) return TupleWrapper(Tuple((m, var)), 2) @@ -440,7 +446,7 @@ def mean_std(data, axis=None, keepdims=False, exclude=False): """ axis = [axis] if isinstance(axis, int) else axis m = mean(data, axis, True, exclude) - s = sqrt(_make._variance(data, m, axis, keepdims, exclude)) + s = sqrt(_make._variance(data, m, axis, keepdims, exclude, False)) if not keepdims: m = squeeze(m) return TupleWrapper(Tuple((m, s)), 2) diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index c03a7bff854f..8ca22039974c 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -74,7 +74,8 @@ Expr MakeTile(Expr data, Array reps); Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype); -Expr MakeVariance(Expr data, Expr mean, Array axis, bool keepdims, bool exclude); +Expr MakeVariance(Expr data, Expr mean, Array axis, bool keepdims, bool exclude, + bool unbiased); Expr MakeZeros(Array shape, DataType dtype); diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 9fd140092954..e16ecb6d7119 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -38,6 +38,7 @@ namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(ReduceAttrs); +TVM_REGISTER_NODE_TYPE(VarianceAttrs); /*! * \brief GetReduceAxes, get the new axis from indim and other arguments @@ -193,12 +194,14 @@ Array ReduceCompute(const Attrs& attrs, const Array& inp /*! * \brief ReduceShapeImpl get the outshape for the reduction operator * \param in_shape Shape of input data. - * \param param ReduceAttrs details. + * \param param Attrs details. * \param reporter The reporter to report solution to. * \return oshape Output shape inferred. + * \tparam AttrsType The attribute type. */ +template inline std::vector ReduceShapeImpl(const std::vector& in_shape, - const ReduceAttrs* param, + const AttrsType* param, const TypeReporter& reporter) { uint32_t indim = in_shape.size(); auto r_axes = GetReduceAxes(indim, param->axis, param->exclude); @@ -542,7 +545,7 @@ bool VarianceRel(const Array& types, int num_inputs, const Attrs& attrs, std::vector mean_shape(mean->shape.begin(), mean->shape.end()); CHECK_EQ(in_shape.size(), mean_shape.size()); - const ReduceAttrs* param = attrs.as(); + const VarianceAttrs* param = attrs.as(); CHECK(param != nullptr); // assign output type and shape @@ -554,39 +557,49 @@ bool VarianceRel(const Array& types, int num_inputs, const Attrs& attrs, Array VarianceCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { IndexExpr count = tir::make_const(inputs[0]->dtype, 1); - const ReduceAttrs* param = attrs.as(); + const VarianceAttrs* param = attrs.as(); CHECK(param != nullptr); auto axes = param->axis; + bool unbiased = param->unbiased; auto data = inputs[0]; auto mean = inputs[1]; for (int64_t i : GetReduceAxes(data->shape.size(), param->axis, param->exclude)) { count *= data->shape[i]; } + if (unbiased) { + count -= 1; + } std::vector expand_shape; auto sq_diff = topi::power(topi::subtract(data, mean), 2); - auto var = topi::divide(ReduceCompute(attrs, {sq_diff}, out_type, topi::sum)[0], count); + if (param->exclude) { + axes = GetExcludeAxes(sq_diff->shape.size(), param->axis); + CHECK_NE(axes.size(), 0); + } + auto var = topi::divide(topi::sum(sq_diff, axes, param->keepdims, false), count); return {var}; } -Expr MakeVariance(Expr data, Expr mean, Array axis, bool keepdims, bool exclude) { - auto attrs = make_object(); +Expr MakeVariance(Expr data, Expr mean, Array axis, bool keepdims, bool exclude, + bool unbiased = false) { + auto attrs = make_object(); attrs->axis = std::move(axis); attrs->keepdims = keepdims; attrs->exclude = exclude; + attrs->unbiased = unbiased; static const Op& op = Op::Get("variance"); return Call(op, {data, mean}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make._variance").set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeVariance, args, rv); + runtime::detail::unpack_call(MakeVariance, args, rv); }); RELAY_REGISTER_OP("variance") .describe(R"code(Computes the variance of array elements over given axes. )code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_support_level(4) .set_num_inputs(2) .add_argument("data", "Tensor", "The input tensor.") diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index b3e36818870a..ee655037bda0 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -580,8 +580,9 @@ inline Expr Mean(Expr data, Array axis, bool keepdims, bool exclude) { return MakeReduce(data, axis, keepdims, exclude, "mean"); } -inline Expr Variance(Expr data, Expr mean, Array axis, bool keepdims, bool exclude) { - return MakeVariance(data, mean, axis, keepdims, exclude); +inline Expr Variance(Expr data, Expr mean, Array axis, bool keepdims, bool exclude, + bool unbiased = false) { + return MakeVariance(data, mean, axis, keepdims, exclude, unbiased); } static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) { diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 3c9dfb13fc4c..ae03a70379c6 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1873,6 +1873,18 @@ class Std6(Module): def forward(self, *args): return args[0].std(unbiased=False) + class Std7(Module): + def forward(self, *args): + return args[0].std(dim=1, keepdim=False, unbiased=True) + + class Std8(Module): + def forward(self, *args): + return args[0].std(dim=(2,3), keepdim=True, unbiased=True) + + class Std9(Module): + def forward(self, *args): + return args[0].std(unbiased=True) + input_data = torch.rand(input_shape).float() verify_model(Std1().float().eval(), input_data=input_data) verify_model(Std2().float().eval(), input_data=input_data) @@ -1880,6 +1892,9 @@ def forward(self, *args): verify_model(Std4().float().eval(), input_data=input_data) verify_model(Std5().float().eval(), input_data=input_data) verify_model(Std6().float().eval(), input_data=input_data) + verify_model(Std7().float().eval(), input_data=input_data) + verify_model(Std8().float().eval(), input_data=input_data) + verify_model(Std9().float().eval(), input_data=input_data) def test_forward_variance(): @@ -1906,12 +1921,32 @@ class Variance5(Module): def forward(self, *args): return args[0].var(dim=(2,3), keepdim=False, unbiased=False) + class Variance6(Module): + def forward(self, *args): + return args[0].var(unbiased=False) + + class Variance7(Module): + def forward(self, *args): + return args[0].var(dim=1, keepdim=False, unbiased=True) + + class Variance8(Module): + def forward(self, *args): + return args[0].var(dim=(2,3), keepdim=True, unbiased=True) + + class Variance9(Module): + def forward(self, *args): + return args[0].var(unbiased=True) + input_data = torch.rand(input_shape).float() verify_model(Variance1().float().eval(), input_data=input_data) verify_model(Variance2().float().eval(), input_data=input_data) verify_model(Variance3().float().eval(), input_data=input_data) verify_model(Variance4().float().eval(), input_data=input_data) verify_model(Variance5().float().eval(), input_data=input_data) + verify_model(Variance6().float().eval(), input_data=input_data) + verify_model(Variance7().float().eval(), input_data=input_data) + verify_model(Variance8().float().eval(), input_data=input_data) + verify_model(Variance9().float().eval(), input_data=input_data) def test_forward_rsub(): diff --git a/tests/python/relay/test_op_grad_level4.py b/tests/python/relay/test_op_grad_level4.py index 956c6af8d5cb..b35ffe923b24 100644 --- a/tests/python/relay/test_op_grad_level4.py +++ b/tests/python/relay/test_op_grad_level4.py @@ -26,7 +26,10 @@ def verify_reduction_grad(red_fn, d_shape, axis=None, keepdims=False, exclude=Fa def test_reduction_grad(): - for op in (relay.sum, relay.variance, relay.mean): + def _unbiased_variance(x, axis=None, keepdims=False, exclude=False): + return relay.variance(x, axis=axis, keepdims=keepdims, exclude=exclude, unbiased=True) + + for op in (relay.sum, relay.variance, _unbiased_variance, relay.mean): verify_reduction_grad(op, (4, 2)) verify_reduction_grad(op, (4, 2), axis=-1, keepdims=True) verify_reduction_grad(op, (4, 2, 1), axis=(1, 2), exclude=True) diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index c800b1c947c6..8e01fa2a89cd 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -225,6 +225,16 @@ def _np_log_sum_exp(x, axis, keepdims=False): if not keepdims: x = np.squeeze(x, axis=axis) return x + + def _unbiased_relay_wrapper(f): + def _unbiased_func(x, axis=None, keepdims=False, exclude=False): + return f(x, axis=axis, keepdims=keepdims, exclude=exclude, unbiased=True) + return _unbiased_func + + def _unbiased_np_wrapper(f): + def _unbiased_func(a, axis=None, dtype=None, keepdims=None): + return f(a, axis=axis, dtype=dtype, ddof=1, keepdims=keepdims) + return _unbiased_func d1, d2, d3, d4 = te.var("d1"), te.var("d2"), te.var("d3"), te.var("d4") for func in [[relay.sum, np.sum], @@ -232,7 +242,9 @@ def _np_log_sum_exp(x, axis, keepdims=False): [relay.min, np.min], [relay.mean, np.mean], [relay.variance, np.var], + [_unbiased_relay_wrapper(relay.variance), _unbiased_np_wrapper(np.var)], [relay.std, np.std], + [_unbiased_relay_wrapper(relay.std), _unbiased_np_wrapper(np.std)], [relay.prod, np.prod], [relay.all, np.all], [relay.any, np.any],