Skip to content

Commit

Permalink
[Relax][OP] More high-level operators (tlc-pack#18)
Browse files Browse the repository at this point in the history
* relax.cumsum

* Legalizer for expand_dims

* relax.trilu

* relax.cast

* Legalizer for batch_norm and flatten

* relax.take

* relax.full

* relax.split

* relax.broadcast_to

* relax.strided_slice

* relax.image.resize2d

* relax.nn.max_pool2d

* relax.nn.adaptive_avg_pool2d
  • Loading branch information
MasterJH5574 committed Dec 14, 2022
1 parent fd64f2d commit 81f831b
Show file tree
Hide file tree
Showing 21 changed files with 2,728 additions and 112 deletions.
165 changes: 165 additions & 0 deletions include/tvm/relax/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,171 @@ struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
}
}; // struct ReduceAttrs

/*! \brief Attributes used in cumsum operator */
struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
Optional<Integer> axis;

TVM_DECLARE_ATTRS(CumsumAttrs, "relax.attrs.CumsumAttrs") {
TVM_ATTR_FIELD(axis).set_default(Optional<Integer>{NullOpt});
}
}; // struct CumsumAttrs

/*! \brief Attributes used in trilu operator */
struct TriluAttrs : public tvm::AttrsNode<TriluAttrs> {
int k;
bool is_upper;

TVM_DECLARE_ATTRS(TriluAttrs, "relax.attrs.TriluAttrs") {
TVM_ATTR_FIELD(k).describe(
"The number of diagonals above or below the main diagonal to exclude or include.");
TVM_ATTR_FIELD(is_upper).set_default(true).describe(
"Whether to keep the upper or lower half of the diagonal.");
}
}; // struct TriluAttrs

/*! \brief Attributes used in cast operator */
struct CastAttrs : public tvm::AttrsNode<CastAttrs> {
DataType dtype;

TVM_DECLARE_ATTRS(CastAttrs, "relax.attrs.CastAttrs") {
TVM_ATTR_FIELD(dtype).describe("Target data type");
}
}; // struct CastAttrs.

/*! \brief Attributes used in take operator */
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
Optional<Integer> axis;
int batch_dims;
String mode;

TVM_DECLARE_ATTRS(TakeAttrs, "relax.attrs.TakeAttrs") {
TVM_ATTR_FIELD(axis)
.set_default(Optional<Integer>{NullOpt})
.describe("The axis over which to select values.");
TVM_ATTR_FIELD(batch_dims)
.set_default(0)
.describe("The batch_dims over which to select values.");
TVM_ATTR_FIELD(mode).set_default("clip").describe(
"Specify how out-of-bound indices will behave."
"clip - clip to the range (default)"
"wrap - wrap around the indices"
"fast - no clip or wrap around (user must make sure indices are in-bound)");
}
}; // struct TakeAttrs

/*! \brief Attributes used in full operator */
struct FullAttrs : public tvm::AttrsNode<FullAttrs> {
DataType dtype;

TVM_DECLARE_ATTRS(FullAttrs, "relax.attrs.FullAttrs") {
TVM_ATTR_FIELD(dtype).describe("Target data type.");
}
}; // struct FullAttrs

/*! \brief Attributes used in split operator */
struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
ObjectRef indices_or_sections;
int axis;

TVM_DECLARE_ATTRS(SplitAttrs, "relax.attrs.SplitAttrs") {
TVM_ATTR_FIELD(indices_or_sections)
.describe("The input array of indices or the number of split sections.");
TVM_ATTR_FIELD(axis).describe("The axis to be splitted");
}
}; // struct SplitAttrs

/*! \brief Attributes used in strided_slice operator */
struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
Array<PrimExpr> begin;
Array<PrimExpr> end;
Optional<Array<PrimExpr>> strides;
Optional<Array<Integer>> axes;
String slice_mode;

TVM_DECLARE_ATTRS(StridedSliceAttrs, "relax.attrs.StridedSliceAttrs") {
TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive");
TVM_ATTR_FIELD(end).describe("Indices for end of slice, end index is exclusive");
TVM_ATTR_FIELD(strides).describe(
"Stride values of the slice, a stride can be negative, which causes a reverse slice.");
TVM_ATTR_FIELD(axes).describe(
"Axes along which slicing is applied. When it is specified, the length of begin, end, "
"strides, and axes must be equal.");
TVM_ATTR_FIELD(slice_mode)
.set_default("end")
.describe(
"The slice mode [end, size]."
"end - The default slice mode, ending indices for the slice."
"size - The input strides will be ignored, input end in this mode indicates the size"
"of a slice starting at the location specified by begin. If end[i] is -1,"
"all remaining elements in that dimension are included in the slice");
}
}; // struct StridedSliceAttrs

/*! \brief Attributes used in image resize2d operator */
struct Resize2DAttrs : public tvm::AttrsNode<Resize2DAttrs> {
Array<PrimExpr> size;
Array<FloatImm> roi;
String layout;
String method;
String coordinate_transformation_mode;
String rounding_method;
double cubic_alpha;
int cubic_exclude;
double extrapolation_value;

TVM_DECLARE_ATTRS(Resize2DAttrs, "relax.attrs.Resize2DAttrs") {
TVM_ATTR_FIELD(size).describe("Output image size.");
TVM_ATTR_FIELD(roi).describe(
"Region of Interest for coordinate transformation mode 'tf_crop_and_resize'");
TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
"Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Resize is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(method).set_default("linear").describe(
"Specify the mode to use for scaling."
"nearest_neighbor - Nearest Neighbor"
"linear - Bilinear Interpolation"
"cubic - Bicubic Interpolation");
TVM_ATTR_FIELD(coordinate_transformation_mode)
.set_default("half_pixel")
.describe(
"Describes how to transform the coordinate in the resized tensor"
"to the coordinate in the original tensor."
"Refer to the ONNX Resize operator specification for details"
"Available options are half_pixel, align_corners and asymmetric");
TVM_ATTR_FIELD(rounding_method)
.set_default("round")
.describe(
"indicates how to find the \"nearest\" pixel in nearest_neighbor method"
"Available options are round, floor, and ceil.");
TVM_ATTR_FIELD(cubic_alpha)
.set_default(-0.5)
.describe("Spline Coefficient for Bicubic Interpolation");
TVM_ATTR_FIELD(cubic_exclude)
.set_default(0)
.describe("Flag to exclude exterior of the image during bicubic interpolation");
TVM_ATTR_FIELD(extrapolation_value)
.set_default(0.0)
.describe("Value to return when roi is outside of the image");
}
}; // struct Resize2dAttrs

/*! \brief Attributes for 2d adaptive pool operator */
struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
Optional<Array<PrimExpr>> output_size;
String layout;

TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relax.attrs.AdaptivePool2DAttrs") {
TVM_ATTR_FIELD(output_size).describe("Output height and width.");
TVM_ATTR_FIELD(layout).describe(
"Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Pooling is applied on the 'H' and"
"'W' dimensions.");
}
}; // struct AdaptivePool2DAttrs

} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_OP_ATTR_TYPES_H_
4 changes: 3 additions & 1 deletion python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ def _convert_te_arg_helper(arg):
), "emit_te only supports dict with string as the key currently"
return {k: _convert_te_arg_helper(arg[k]) for k in arg}
elif (
isinstance(arg, (int, float, str, tir.IntImm, tvm.ir.Type, tvm.ir.Attrs))
isinstance(
arg, (int, float, str, tir.IntImm, tir.FloatImm, tvm.ir.Type, tvm.ir.Attrs)
)
or arg is None
):
return arg
Expand Down
Loading

0 comments on commit 81f831b

Please sign in to comment.