Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Expose TVM Backend API-related Builtins and Misc #12468

Merged
merged 1 commit into from
Aug 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from .op import tvm_stack_alloca, tvm_stack_make_shape, tvm_stack_make_array
from .op import tvm_tuple, tvm_struct_get, tvm_struct_set
from .op import address_of, lookup_param, assume, undef
from .op import tvm_thread_allreduce, type_annotation, tvm_access_ptr, tvm_throw_last_error
from .op import infinity, reinterpret
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
from .op import sin, sinh, asin, asinh
Expand All @@ -62,6 +63,7 @@
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift
from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError

Expand Down
159 changes: 154 additions & 5 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=redefined-builtin, invalid-name
"""Operators used in TIR expression."""
import warnings
from typing import Any, Optional
import tvm._ffi
from tvm.ir.base import Span
Expand Down Expand Up @@ -262,10 +263,22 @@ def call_llvm_intrin(dtype, name, *args, span=None):
# pylint: disable=import-outside-toplevel
from tvm.target import codegen

llvm_id = codegen.llvm_lookup_intrinsic_id(name)
assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
from .expr import IntImm

if isinstance(name, str):
llvm_id = codegen.llvm_lookup_intrinsic_id(name)
elif isinstance(name, IntImm):
llvm_id = name.value
else:
llvm_id = name
if llvm_id == 0:
warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
return call_intrin(
dtype, Op.get("tir.call_llvm_intrin"), tvm.tir.const(llvm_id, "uint32"), *args, span=span
dtype,
Op.get("tir.call_llvm_intrin"),
tvm.tir.const(llvm_id, "uint32"),
*args,
span=span,
)


Expand Down Expand Up @@ -294,8 +307,16 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
# pylint: disable=import-outside-toplevel
from tvm.target import codegen

llvm_id = codegen.llvm_lookup_intrinsic_id(name)
assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
from .expr import IntImm

if isinstance(name, str):
llvm_id = codegen.llvm_lookup_intrinsic_id(name)
elif isinstance(name, IntImm):
llvm_id = name.value
else:
llvm_id = name
if llvm_id == 0:
warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
return call_intrin(
dtype,
Op.get("tir.call_llvm_pure_intrin"),
Expand Down Expand Up @@ -504,6 +525,76 @@ def lookup_param(param_name, span=None):
return call_intrin("handle", "tir.lookup_param", param_name, span=span)


def tvm_thread_allreduce(*freduce_args):
    """Build a call to the cross-thread allreduce builtin.

    Parameters
    ----------
    freduce_args : Expr
        The arguments forwarded verbatim to the builtin
        (combiner, values, predicate and thread axes — TODO confirm exact layout).

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.tvm_thread_allreduce"
    return call_intrin("handle", op_name, *freduce_args)


def type_annotation(dtype):
    """Create a type annotation expression.

    Parameters
    ----------
    dtype : Expr
        The data type being annotated.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # The annotation carries no operands; the dtype itself is the payload.
    return call_intrin(dtype, "tir.type_annotation")


def tvm_access_ptr(ptype, data, offset, extent, rw_mask):
    """Get head access address with memory access pattern info.

    Parameters
    ----------
    ptype : Expr
        The data type of the pointer.

    data : DType*
        The data of the pointer.

    offset : int
        The offset of the pointer.

    extent : int
        The extent of the pointer.

    rw_mask : int
        The read/write mask.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    access_args = (ptype, data, offset, extent, rw_mask)
    return call_intrin("handle", "tir.tvm_access_ptr", *access_args)


def tvm_throw_last_error():
    """Build a call that throws TVMGetLastError() at runtime.

    Returns
    -------
    ret : PrimExpr
        The return expression
    """
    # No operands: the builtin simply rethrows the last recorded error.
    return call_intrin("handle", "tir.tvm_throw_last_error")


def ret(val):
"""Create a tir return expression

Expand Down Expand Up @@ -1857,6 +1948,64 @@ def reducer(expr, axis, where=None, init=None, *args):
return reducer


def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint):
    """Backend function to allocate temporal workspace.

    Parameters
    ----------
    device_type : int
        The device type on which the space will be allocated.

    device_id : int
        The id of the device on which the space will be allocated.

    nbytes : int
        The number of bytes requested.

    dtype_code_hint : int
        The type code of the array elements. Only used in certain
        backends such as OpenGL.

    dtype_bits_hint : int
        The type bits of the array elements. Only used in certain
        backends such as OpenGL.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    workspace_args = (device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint)
    return call_intrin("handle", "tir.TVMBackendAllocWorkspace", *workspace_args)


def TVMBackendFreeWorkspace(device_type, device_id, ptr):
    """Backend function to free temporal workspace.

    Parameters
    ----------
    device_type : int
        The device type on which the space was allocated.

    device_id : int
        The id of the device on which the space was allocated.

    ptr : Var
        The pointer to the previously allocated space.

    Returns
    -------
    call : PrimExpr
        The call expression (an int32 status code, per the backend API).
    """
    free_args = (device_type, device_id, ptr)
    return call_intrin("int32", "tir.TVMBackendFreeWorkspace", *free_args)


# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore
Expand Down
42 changes: 42 additions & 0 deletions tests/python/unittest/test_tir_op_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,42 @@ def test_tir_op_call_likely():
assert expr.op.name == "tir.likely"


def test_tir_op_tvm_thread_allreduce():
    """The allreduce helper should lower to the matching builtin op."""
    reduce_value = tir.Var("x", "int32")
    buf = tir.decl_buffer((128), "float32")
    handle_arg = tir.Var("y", "handle")
    thread_axis = tir.Var("z", "int32")
    call = tir.tvm_thread_allreduce(reduce_value, buf[0], True, handle_arg, thread_axis)
    assert call.op.name == "tir.tvm_thread_allreduce"


def test_tir_op_type_annotation():
    """The type annotation helper should lower to the matching builtin op."""
    annotation = tir.type_annotation("int32")
    assert annotation.op.name == "tir.type_annotation"


def test_tir_op_tvm_access_ptr():
    """The access-ptr helper should lower to the matching builtin op."""
    buf = tir.decl_buffer((128), "float32")
    call = tir.tvm_access_ptr("float32", buf.data, 0, 1, 2)
    assert call.op.name == "tir.tvm_access_ptr"


def test_tir_op_tvm_throw_last_error():
    """The throw-last-error helper should lower to the matching builtin op."""
    call = tir.tvm_throw_last_error()
    assert call.op.name == "tir.tvm_throw_last_error"


def test_tir_op_TVMBackendAllocWorkspace():
    """The workspace-alloc helper should lower to the matching builtin op."""
    call = tir.TVMBackendAllocWorkspace(0, 1, 2, 3, 4)
    assert call.op.name == "tir.TVMBackendAllocWorkspace"


def test_tir_op_TVMBackendFreeWorkspace():
    """The workspace-free helper should lower to the matching builtin op."""
    buf = tir.decl_buffer((128), "float32")
    call = tir.TVMBackendFreeWorkspace(0, 1, buf.data)
    assert call.op.name == "tir.TVMBackendFreeWorkspace"


if __name__ == "__main__":
test_tir_op_tvm_tuple()
test_tir_op_tvm_struct_get()
Expand All @@ -90,3 +126,9 @@ def test_tir_op_call_likely():
test_tir_op_call_assume()
test_tir_op_call_undef()
test_tir_op_call_likely()
test_tir_op_tvm_thread_allreduce()
test_tir_op_type_annotation()
test_tir_op_tvm_access_ptr()
test_tir_op_tvm_throw_last_error()
test_tir_op_TVMBackendAllocWorkspace()
test_tir_op_TVMBackendFreeWorkspace()