Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI][TIR][TE][x86] Extend x86 SIMD (u)int8 coverage for dense & conv2d #15918

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,21 @@ TVM_DLL const Op& ret();
*/
TVM_DLL const Op& reinterpret();

/*!
* \brief Zero extend the value using the target type.
*/
TVM_DLL const Op& zextend();

/*!
* \brief Sign extend the value using the target type.
*/
TVM_DLL const Op& sextend();

/*!
* \brief Truncate the value using the target type.
*/
TVM_DLL const Op& truncate();

/*!
* \brief Marks a condition is likely going to happen.
*/
Expand Down Expand Up @@ -769,9 +784,20 @@ TVM_DLL const Op& vectorlow();
TVM_DLL const Op& vectorcombine();

/*!
* \brief atomic add instruction, corresponding e.g. to atomicAdd in CUDA
* \brief Shuffle two vectors using indices.
*/
TVM_DLL const Op& vectorshuffle();

/*!
* \brief Permute vector using indices.
*/
TVM_DLL const Op& vectorpermute();

/*!
* \brief Atomic add instruction.
*/
TVM_DLL const Op& atomic_add();

/*!
* \brief Create an Nd memory allocation with storage scope
*/
Expand Down
33 changes: 33 additions & 0 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,39 @@ class StringImm : public PrimExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode);
};

/*! \brief Array of integer constants */
class ArrayIntImmNode : public PrimExprNode {
 public:
  /*! \brief The constant value content. */
  Array<Integer> data;

  // Reflection hook: exposes dtype (inherited from PrimExprNode), the
  // constant contents, and the source span to visitors/serializers.
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("dtype", &dtype);
    v->Visit("data", &data);
    v->Visit("span", &span);
  }

  // Structural equality: two nodes are equal iff their element arrays match.
  // NOTE(review): dtype is intentionally(?) excluded here — unlike IntImmNode,
  // two ArrayIntImms differing only in dtype compare equal; confirm this is
  // the desired semantics.
  bool SEqualReduce(const ArrayIntImmNode* other, SEqualReducer equal) const {
    return equal(data, other->data);
  }

  // Structural hash: consistent with SEqualReduce — folds in `data` only.
  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); }

  static constexpr const char* _type_key = "tir.ArrayIntImm";
  TVM_DECLARE_FINAL_OBJECT_INFO(ArrayIntImmNode, PrimExprNode);
};

/*!
 * \brief Managed reference to ArrayIntImmNode.
 * \sa ArrayIntImmNode
 */
class ArrayIntImm : public PrimExpr {
 public:
  /*!
   * \brief Construct an array-of-integer constant expression.
   * \param data The integer values held by the node.
   * \param span The source location of this expression, if available.
   */
  TVM_DLL ArrayIntImm(Array<Integer> data, Span span = Span());
  TVM_DEFINE_OBJECT_REF_METHODS(ArrayIntImm, PrimExpr, ArrayIntImmNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(ArrayIntImmNode);
};

/*!
* \brief Cast value from one data type to another.
* \note The lanes of value should keep fixed.
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/tir/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ArrayIntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const AnyNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
Expand Down Expand Up @@ -192,6 +193,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(IntImmNode);
IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode);
IR_EXPR_FUNCTOR_DISPATCH(StringImmNode);
IR_EXPR_FUNCTOR_DISPATCH(ArrayIntImmNode);
IR_EXPR_FUNCTOR_DISPATCH(AnyNode);
return vtable;
}
Expand Down Expand Up @@ -243,6 +245,7 @@ class TVM_DLL ExprVisitor : public ExprFunctor<void(const PrimExpr&)> {
void VisitExpr_(const IntImmNode* op) override;
void VisitExpr_(const FloatImmNode* op) override;
void VisitExpr_(const StringImmNode* op) override;
void VisitExpr_(const ArrayIntImmNode* op) override;
void VisitExpr_(const AnyNode* op) override;
};

Expand Down Expand Up @@ -289,6 +292,7 @@ class TVM_DLL ExprMutator : protected ExprFunctor<PrimExpr(const PrimExpr&)> {
PrimExpr VisitExpr_(const IntImmNode* op) override;
PrimExpr VisitExpr_(const FloatImmNode* op) override;
PrimExpr VisitExpr_(const StringImmNode* op) override;
PrimExpr VisitExpr_(const ArrayIntImmNode* op) override;
PrimExpr VisitExpr_(const AnyNode* op) override;
};

Expand Down
2 changes: 2 additions & 0 deletions python/tvm/autotvm/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def _encode(x):
return x
if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
return x.value
if isinstance(x, expr.ArrayIntImm):
return x.data
if isinstance(x, runtime.container.String):
return str(x)
if x is None:
Expand Down
1 change: 1 addition & 0 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def _convert(item, nodes):
"Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")],
"SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")],
"StringImm": [_rename("tir.StringImm"), _update_from_std_str("value")],
"ArrayIntImm": [_rename("tir.ArrayIntImm"), _update_from_std_str("data")],
"Cast": _rename("tir.Cast"),
"Add": _rename("tir.Add"),
"Sub": _rename("tir.Sub"),
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def legalize_dense(attrs, inputs, types):
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current convolution
Attributes of current dense operation
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
Expand Down
20 changes: 13 additions & 7 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tvm.meta_schedule import is_meta_schedule_enabled
from tvm.relay.ty import is_dynamic
from tvm.te import SpecializedCondition
from tvm.target.x86 import get_x86_simd_32bit_lanes

from .. import op as _op
from .generic import *
Expand Down Expand Up @@ -588,11 +589,12 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
def dense_pack_strategy_cpu(attrs, inputs, out_type, target):
"""dense_pack x86 strategy"""
strategy = _op.OpStrategy()
vec_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 16
if (
inputs[0].dtype == "uint8"
and inputs[1].dtype == "int8"
inputs[0].dtype in ("uint8", "int8")
and inputs[1].dtype in ("int8", "uint8")
and out_type.dtype == "int32"
and attrs["weight_layout"] == "NC16n4c"
and attrs["weight_layout"] == f"NC{vec_width}n4c"
):
strategy.add_implementation(
wrap_compute_dense(topi.x86.dense_int8),
Expand Down Expand Up @@ -622,10 +624,14 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
if (
not attrs.transpose_a
and attrs.transpose_b
and inputs[0].dtype == "uint8"
and inputs[1].dtype == "int8"
and inputs[1].shape[-2] % 16 == 0
and inputs[1].shape[-1] % 4 == 0
and inputs[0].dtype in ("uint8", "int8")
and inputs[1].dtype in ("int8", "uint8")
and (
# legalized SIMD
get_x86_simd_32bit_lanes()
# unknown SIMD
or (inputs[1].shape[-2] % 16 == 0 and inputs[1].shape[-1] % 4 == 0)
)
):
strategy.add_implementation(
wrap_compute_batch_matmul(topi.x86.batch_matmul_int8_compute, need_out_dtype=True),
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
Shuffle,
SizeVar,
StringImm,
ArrayIntImm,
Sub,
Var,
)
Expand Down Expand Up @@ -1869,6 +1870,11 @@ def wrapped(*args, **kwargs):


reinterpret = _dtype_forward(_tir_op.reinterpret)
sextend = _dtype_forward(_tir_op.sextend)
zextend = _dtype_forward(_tir_op.zextend)
truncate = _dtype_forward(_tir_op.truncate)
vectorpermute = _dtype_forward(_tir_op.vectorpermute)
vectorshuffle = _dtype_forward(_tir_op.vectorshuffle)
call_extern = _dtype_forward(_tir_op.call_extern)
call_intrin = _dtype_forward(_tir_op.call_intrin)
call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin)
Expand Down Expand Up @@ -2072,6 +2078,11 @@ def wrapped(*args, **kwargs):
"q_multiply_shift_per_axis",
"ret",
"reinterpret",
"sextend",
"zextend",
"truncate",
"vectorpermute",
"vectorshuffle",
"round",
"rsqrt",
"shift_left",
Expand Down Expand Up @@ -2155,6 +2166,7 @@ def wrapped(*args, **kwargs):
"FloatImm",
"IntImm",
"StringImm",
"ArrayIntImm",
"Cast",
"Add",
"Sub",
Expand Down
12 changes: 8 additions & 4 deletions python/tvm/target/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from .codegen import target_has_features


@register_func("tvm.topi.x86.utils.get_simd_32bit_lanes")
def get_simd_32bit_lanes():
@register_func("tvm.topi.x86.utils.get_x86_simd_32bit_lanes")
def get_x86_simd_32bit_lanes():
"""X86 SIMD optimal vector length lookup.
Parameters
----------
Expand All @@ -29,9 +29,13 @@ def get_simd_32bit_lanes():
vec_len : int
The optimal vector length of CPU from the global context target.
"""
vec_len = 4
if target_has_features(["avx512bw", "avx512f"]):
vec_len = None
if target_has_features("avx512vnni") or target_has_features("avxvnni"):
vec_len = 16
elif target_has_features(["avx512bw", "avx512f"]):
vec_len = 16
elif target_has_features("avx2"):
vec_len = 8
elif target_has_features("ssse3"):
vec_len = 4
return vec_len
7 changes: 4 additions & 3 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .buffer import Buffer, decl_buffer, DataProducer
from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, ArrayIntImm, Cast
from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
from .expr import Select, BufferLoad, ProducerLoad, Ramp, Broadcast, Shuffle
Expand Down Expand Up @@ -73,8 +73,8 @@
ptx_wait_barrier,
create_barriers,
)
from .op import vectorlow, vectorhigh, vectorcombine
from .op import infinity, reinterpret
from .op import vectorlow, vectorhigh, vectorcombine, vectorpermute, vectorshuffle
from .op import infinity, reinterpret, zextend, sextend, truncate
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
from .op import sin, sinh, asin, asinh
from .op import cos, cosh, acos, acosh
Expand All @@ -88,6 +88,7 @@
from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right
from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
from .op import start_profile_intrinsic, end_profile_intrinsic
from .op import atomic_add
from .generic import add, subtract, multiply

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
Expand Down
30 changes: 30 additions & 0 deletions python/tvm/tir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,36 @@ def __hash__(self):
return PrimExpr.__hash__(self)


@tvm._ffi.register_object("tir.ArrayIntImm")  # type: ignore
class ArrayIntImm(ConstExpr):
    """Array of integer constants.

    Parameters
    ----------
    data : list
        The integer constant values stored in the array.

    span : Optional[Span]
        The location of this expression in the source code.
    """

    def __init__(self, data, span=None):
        self.__init_handle_by_constructor__(_ffi_api.ArrayIntImm, data, span)  # type: ignore

    def __eq__(self, other):
        # Compare by content. Guard on ArrayIntImm specifically: other
        # ConstExpr subclasses (IntImm, FloatImm, StringImm) expose `.value`,
        # not `.data`, so the previous `other.data` access raised
        # AttributeError instead of returning False.
        if isinstance(other, ArrayIntImm):
            return str(self.data) == str(other.data)
        return str(self.data) == str(other)

    def __ne__(self, other):
        # Defined as the negation of __eq__ so the two can never disagree.
        return not self.__eq__(other)

    def __hash__(self):
        # Keep the identity-based hash from PrimExpr; it must be restated
        # because defining __eq__ would otherwise set __hash__ to None.
        return PrimExpr.__hash__(self)


@tvm._ffi.register_object("tir.Cast")
class Cast(PrimExprWithOp):
"""Cast expression.
Expand Down
Loading
Loading