diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index 16e6c0ecd3fc..a6f638e29321 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -433,6 +433,8 @@ void TopKImpl(RunContext ctx, // 3. Assign results to the ret blob // When returning indices, only update(modulo) required elements instead of full elements // to avoid redundant calculation. + // Casting `ret_indices` from int to real_t before taking the modulo could introduce conversion + // errors when element_num is large enough, so the modulo is applied before the cast. if (param.ret_typ == topk_enum::kReturnMask) { Tensor ret_mask = ret[0].get_with_shape(Shape2(ret[0].Size(), 1), s); @@ -452,20 +454,21 @@ void TopKImpl(RunContext ctx, } else if (param.ret_typ == topk_enum::kReturnIndices) { if (do_transpose) { Tensor ret_indices = ret[0].FlatTo3D(axis, axis, s); - ret_indices = tcast(transpose( - slice<2>(inplace_reshape(indices, - Shape3(ret_indices.shape_[0], - ret_indices.shape_[2], - element_num)), - 0, k), - Shape3(0, 2, 1))); - ret_indices = F(ret_indices, element_num); + ret_indices = tcast(F( + transpose(slice<2>(inplace_reshape(indices, + Shape3(ret_indices.shape_[0], + ret_indices.shape_[2], + element_num)), + 0, k), + Shape3(0, 2, 1)), + element_num)); } else { Tensor ret_indices = ret[0].get_with_shape(Shape2(batch_size, k), s); - ret_indices = tcast(slice<1>( - inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k)); - ret_indices = F(ret_indices, element_num); + ret_indices = tcast(F( + slice<1>(inplace_reshape(indices, Shape2(batch_size, element_num)), + 0, k), + element_num)); } } else { if (do_transpose) { @@ -476,23 +479,24 @@ void TopKImpl(RunContext ctx, Shape3(ret_value.shape_[0], ret_value.shape_[2], element_num)), 0, k), Shape3(0, 2, 1)); - ret_indices = tcast(transpose( - slice<2>(inplace_reshape(indices, - Shape3(ret_indices.shape_[0], - ret_indices.shape_[2], - element_num)), - 0, k), - Shape3(0, 2, 1))); - ret_indices = F(ret_indices, element_num); + 
ret_indices = tcast(F( + transpose(slice<2>(inplace_reshape(indices, + Shape3(ret_indices.shape_[0], + ret_indices.shape_[2], + element_num)), + 0, k), + Shape3(0, 2, 1)), + element_num)); } else { Tensor ret_value = ret[0].get_with_shape(Shape2(batch_size, k), s); Tensor ret_indices = ret[1].get_with_shape(Shape2(batch_size, k), s); ret_value = slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k); - ret_indices = tcast(slice<1>( - inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k)); - ret_indices = F(ret_indices, element_num); + ret_indices = tcast(F( + slice<1>(inplace_reshape(indices, Shape2(batch_size, element_num)), + 0, k), + element_num)); } } }