Skip to content

Commit

Permalink
minor cases (apache#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
cyx-6 authored and junrushao committed Jul 13, 2022
1 parent 8d1f91c commit 6937916
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 69 deletions.
10 changes: 6 additions & 4 deletions python/tvm/script/builder/tir/block_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
# under the License.
"""TVM Script TIR Block Frame"""
from typing import Any, Dict, List, Union
from numbers import Integral

from tvm._ffi import register_object as _register_object
from tvm.tir import Buffer, BufferLoad, BufferRegion
from tvm.tir import Buffer, BufferLoad, BufferRegion, PrimExpr

from . import _ffi_api
from .base import TIRFrame
Expand All @@ -34,7 +35,7 @@ class BlockInitFrame(TIRFrame):
...


def block(name: str, no_realize: bool = False) -> BlockFrame:
def block(name: str = "", no_realize: bool = False) -> BlockFrame:
return _ffi_api.BlockFrame(name, no_realize) # pylint: disable=no-member # type: ignore


Expand Down Expand Up @@ -82,19 +83,20 @@ def alloc_buffer(
data=None,
strides=[],
elem_offset=None,
storage_scope="",
scope="",
align=-1,
offset_factor=0,
buffer_type="default",
axis_separators=None,
) -> Buffer:
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
return _ffi_api.AllocBuffer(
shape,
dtype,
data,
strides,
elem_offset,
storage_scope,
scope,
align,
offset_factor,
buffer_type,
Expand Down
124 changes: 65 additions & 59 deletions python/tvm/script/builder/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,61 +18,71 @@

from tvm.tir.expr import Broadcast, Ramp as ramp, Select, Shuffle
from tvm.tir.generic import cast
from tvm.tir.op import (
abs,
acos,
acosh,
asin,
asinh,
atan,
atan2,
atanh,
call_extern,
call_packed,
ceil,
clz,
comm_reducer,
copysign,
cos,
cosh,
erf,
exp,
exp2,
exp10,
floor,
floordiv,
floormod,
fmod,
hypot,
if_then_else as if_then_else_,
infinity,
isfinite,
isinf,
isnan,
ldexp,
log,
log1p,
log2,
log10,
max_value,
min_value,
nearbyint,
nextafter,
popcount,
power,
reinterpret,
round,
rsqrt,
sigmoid,
sin,
sinh,
sqrt,
tan,
tanh,
trunc,
truncdiv,
truncmod,
)
from tvm.tir import op


def op_wrapper(func):
def wrapped(*args, **kwargs):
if "dtype" in kwargs:
kwargs.pop("dtype")
return func(*args, **kwargs)

return wrapped


abs = op_wrapper(op.abs)
acos = op_wrapper(op.acos)
acosh = op_wrapper(op.acosh)
asin = op_wrapper(op.asin)
asinh = op_wrapper(op.asinh)
atan = op_wrapper(op.atan)
atan2 = op_wrapper(op.atan2)
atanh = op_wrapper(op.atanh)
call_extern = op_wrapper(op.call_extern)
call_packed = op_wrapper(op.call_packed)
ceil = op_wrapper(op.ceil)
clz = op_wrapper(op.clz)
comm_reducer = op_wrapper(op.comm_reducer)
copysign = op_wrapper(op.copysign)
cos = op_wrapper(op.cos)
cosh = op_wrapper(op.cosh)
erf = op_wrapper(op.erf)
exp = op_wrapper(op.exp)
exp2 = op_wrapper(op.exp2)
exp10 = op_wrapper(op.exp10)
floor = op_wrapper(op.floor)
floordiv = op_wrapper(op.floordiv)
floormod = op_wrapper(op.floormod)
fmod = op_wrapper(op.fmod)
hypot = op_wrapper(op.hypot)
if_then_else = op_wrapper(op.if_then_else)
infinity = op_wrapper(op.infinity)
isfinite = op_wrapper(op.isfinite)
isinf = op_wrapper(op.isinf)
isnan = op_wrapper(op.isnan)
ldexp = op_wrapper(op.ldexp)
log = op_wrapper(op.log)
log1p = op_wrapper(op.log1p)
log2 = op_wrapper(op.log2)
log10 = op_wrapper(op.log10)
max_value = op_wrapper(op.max_value)
min_value = op_wrapper(op.min_value)
nearbyint = op_wrapper(op.nearbyint)
nextafter = op_wrapper(op.nextafter)
popcount = op_wrapper(op.popcount)
power = op_wrapper(op.power)
reinterpret = op_wrapper(op.reinterpret)
round = op_wrapper(op.round)
rsqrt = op_wrapper(op.rsqrt)
sigmoid = op_wrapper(op.sigmoid)
sin = op_wrapper(op.sin)
sinh = op_wrapper(op.sinh)
sqrt = op_wrapper(op.sqrt)
tan = op_wrapper(op.tan)
tanh = op_wrapper(op.tanh)
trunc = op_wrapper(op.trunc)
truncdiv = op_wrapper(op.truncdiv)
truncmod = op_wrapper(op.truncmod)

from . import _ffi_api

Expand Down Expand Up @@ -133,10 +143,6 @@ def handle():
return _ffi_api.Handle()


def if_then_else(cond, t, f, dtype=None):
return if_then_else_(cond, t, f)


def min(a, b):
"""Compute the minimum value of two expressions.
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/script/builder/tir/prim_func_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
"""TVM Script TIR Prim Func Frame"""
import inspect
from typing import Any, Callable, Dict, Optional, Union
from numbers import Integral

from tvm._ffi import register_object as _register_object
from tvm.ir import Type
from tvm.tir import Buffer, PrimFunc
from tvm.tir import Buffer, PrimFunc, PrimExpr
from tvm.tir.expr import Var

from . import _ffi_api
Expand Down Expand Up @@ -87,6 +88,7 @@ def match_buffer(
buffer_type="default",
axis_separators=None,
) -> Buffer:
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
return _ffi_api.MatchBuffer( # pylint: disable=no-member # type: ignore
param,
shape,
Expand All @@ -109,20 +111,21 @@ def preflattened_buffer(
data=None,
strides=[],
elem_offset=None,
storage_scope="",
scope="",
align=-1,
offset_factor=0,
buffer_type="default",
axis_separators=None,
) -> None:
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
_ffi_api.PreflattenedBuffer( # pylint: disable=no-member # type: ignore
postflattened,
shape,
dtype,
data,
strides,
elem_offset,
storage_scope,
scope,
align,
offset_factor,
buffer_type,
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/script/builder/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,12 @@ def let(var: Var, value: PrimExpr) -> LetFrame:
def allocate(
extents: List[PrimExpr],
dtype: str,
storage_scope: str = "",
scope: str = "",
condition: PrimExpr = None,
annotations=None,
) -> AllocateFrame:
return _ffi_api.AllocateFrame(
extents, dtype, storage_scope, condition, annotations
extents, dtype, scope, condition, annotations
) # pylint: disable=no-member # type: ignore


Expand Down Expand Up @@ -160,4 +160,6 @@ def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None:


def evaluate(value: PrimExpr) -> None:
if isinstance(value, str):
value = StringImm(value)
return _ffi_api.Evaluate(value) # pylint: disable=no-member # type: ignore
32 changes: 31 additions & 1 deletion python/tvm/script/builder/tir/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR Buffer"""
from numbers import Integral

from tvm._ffi import register_object as _register_object
from tvm.ir import Array, PrimExpr, Range
from tvm.runtime import DataType, Object
Expand All @@ -39,6 +41,7 @@ def __init__(
buffer_type="",
axis_separators=None,
) -> None:
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
self.__init_handle_by_constructor__(
_ffi_api.Buffer,
shape,
Expand Down Expand Up @@ -142,5 +145,32 @@ def __getitem__(self, keys) -> Buffer_:
return self(*keys) # pylint: disable=no-member # type: ignore


buffer_decl = Buffer_
def buffer_decl(
shape,
dtype="float32",
name="buffer",
data=None,
strides=None,
elem_offset=None,
scope="",
alignment=0,
offset_factor=0,
buffer_type="",
axis_separators=None,
) -> Buffer_:
return Buffer_(
shape,
dtype,
name,
data,
strides,
elem_offset,
scope,
alignment,
offset_factor,
buffer_type,
axis_separators,
)


Buffer = BufferProxy()

0 comments on commit 6937916

Please sign in to comment.