diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst index f8da87d8cfd2c..122bcd95e6907 100644 --- a/docs/langref/hybrid_script.rst +++ b/docs/langref/hybrid_script.rst @@ -52,7 +52,8 @@ The current parse interface looks like: parser = tvm.hybrid.parse(outer_product, [a, b]) # return the parser of this function -If we pass these tvm tensors to this function, it returns a op node: +If we pass these tvm data structures, like ``Tensor``, ``Var``, ``Expr.*Imm``, +or ``tvm.container.Array``, to this function, it returns a op node: .. code-block:: python @@ -60,12 +61,14 @@ If we pass these tvm tensors to this function, it returns a op node: b = tvm.placeholder((99, ), name='b') c = outer_product(a, b, c) # return the output tensor(s) of the operator -**Under construction, we are still deciding what kind of node should be returned.** +You can use any methods that can be applied on a TVM ``OpNode``, like create_schedule, although +so far, the functionality of schedule is as limited as ``ExternOpNode``. At least, it can be built +to LLVM module. Tuning ~~~~~~ -**Under construction, not truly supported yet.** +**Under construction, not supported yet.** Follow up the example above, you can use some tvm like interfaces to tune the code: @@ -86,6 +89,21 @@ Here we use ``range`` aka ``serial``, ``unroll``, ``parallel``, and ``vectorize` these **4** keywords to annotate the corresponding types of for loops. The the usage is roughly the same as Python standard ``range``. +Besides all the loop types supported in Halide, ``const_range`` is supported for some specific conditions. +Sometimes, ``tvm.container.Array`` is desired to pass as an argument, but in TVM-HalideIR, there is no +such support that converts ``tvm.container.Array`` to an ``Expr``. Thus, a limited feature is supported. +Users can access containers by either constants or constants loops annotated. + +.. code-block:: python + + @tvm.hybrid.script + def foo(a, b): # b is a tvm.container.Array + c = output_tensor(a.shape, a.dtype) + for i in const_range(len(a)): # because you have b access, i should be explicitly annotated as const_range + c[i] = a[i] + b[i] + return c + + Variables ~~~~~~~~~ @@ -111,14 +129,14 @@ It regards the first store of a variable as its declaration. s += a[i, j] # do something with sum b[i] = sum # you can still use sum in this level a[0] = s # you CANNOT use s here, even though it is allowed in conventional Python - b = (1, 2) # this has NOT been supported yet! Attributes ~~~~~~~~~~ -So far, ONLY tensors' ``shape`` attribute is supported! The ``shape`` atrribute is essentailly a -tuple, so you MUST access it as an array. Also, currently, only constant-indexed access is supported. +So far, ONLY tensors' ``shape`` and ``dtype`` attribute are supported! +The ``shape`` atrribute is essentailly a tuple, so you MUST access it as an array. +Currently, only constant-indexed access is supported. .. code-block:: python @@ -133,8 +151,11 @@ Conditional Statement and Expression .. code-block:: python - if condition: - # do something + if condition1 and condition2 and condition3: + # do something + else: + # do something else + # Select a = b if condition else c However, NO ``True`` and ``False`` keyword supported yet. @@ -153,7 +174,9 @@ Array Allocation **Under construction, this function will be supported later!** Use a function call ``allocation(shape, type, share/local)`` to declare an array buffer. -The basic usage is roughly the same as a normal array. +The basic usage is roughly the same as a normal ``numpy.array``, and you should access +high-dim array in ``a[i, j, k]`` fashion instead of ``a[i][j][k]``, +even for ``tvm.container.Array`` for compilation. Thread Bind @@ -170,5 +193,5 @@ You can also do loop-thread bind by writing code like this: Keywords ~~~~~~~~ -- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind`` +- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``, ``const_expr`` - Math keywords: ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, ``popcount`` diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py index da58280701a5b..3fd472c57afc6 100644 --- a/python/tvm/hybrid/calls.py +++ b/python/tvm/hybrid/calls.py @@ -12,15 +12,17 @@ #pylint: disable=redefined-builtin LOOP_INTRIN = { - 'range' : For.Serial, - 'unroll' : For.Unrolled, - 'parallel' : For.Parallel, - 'vectorize': For.Vectorized, + 'range' : For.Serial, + 'unroll' : For.Unrolled, + 'parallel' : For.Parallel, + 'vectorize' : For.Vectorized, + 'const_range' : (For.Unrolled, ), } + def _range(annotation, args): """Handling TVM loop types""" - n = len(args) + n = args.__len__() if n == 1: low, ext = _api.const(0, dtype='int32'), args[0] else: @@ -33,13 +35,13 @@ def _range(annotation, args): return iter_var, low, ext, for_type -range = unroll = vectorize = parallel = _range #pylint: disable=invalid-name +range = unroll = vectorize = parallel = const_range = _range #pylint: disable=invalid-name def bind(func_id, args): """Handling TVM thread binding""" _internal_assert(func_id == "bind", "This function cannot be directly invoked!") - _internal_assert(len(args) == 2, "A loop bind should only have 2 arguments!") + _internal_assert(args.__len__() == 2, "A loop bind should only have 2 arguments!") _internal_assert(isinstance(args[0], str), \ "A loop bind's first argument should be a string!") iter_var = _api.thread_axis(args[0]) @@ -56,7 +58,7 @@ def _math_intrin(func_id, args): def _min_max(func_id, args): - _internal_assert(len(args) == 2, "Max/Min function should have 2 elements") + _internal_assert(args.__len__() == 2, "Max/Min function should have 2 elements") return getattr(_make, func_id.title())(args[0], args[1]) @@ -66,7 +68,7 @@ def _min_max(func_id, args): def _allocate_tensor(func_id, args): """Handling TVM tensor allocation. You may refer hybrid.intrin.allocate for more details.""" - n = len(args) + n = args.__len__() _internal_assert(isinstance(_api.convert(args[0]), Array), \ "allocate's first argument should be a tuple of shape!") shape = args[0] @@ -89,4 +91,16 @@ def _allocate_tensor(func_id, args): scope = 'global' if func_id != 'output_tensor' else 'output' return (shape, dtype, scope) + output_tensor = allocate = _allocate_tensor #pylint: disable=invalid-name + + +def len(func_id, args): + """Iterpret the len function""" + _internal_assert(args.__len__() == 1, "Only 1 argument is expected!") + _internal_assert(func_id == "len", "This function cannot be directly invoked!") + try: + return _api.convert(args[0].__len__()) + except: #pylint: disable=bare-except + _internal_assert(args[0].shape.__len__() == 1, "Only one-dimension array can get len") + return _api.convert(args[0].shape[0]) diff --git a/python/tvm/hybrid/intrin.py b/python/tvm/hybrid/intrin.py index 48e92a8bf5acc..cb6d0fdb74b88 100644 --- a/python/tvm/hybrid/intrin.py +++ b/python/tvm/hybrid/intrin.py @@ -2,32 +2,19 @@ import numpy -class _range(object): - """Base class of the loop ranges in hybrid script""" - def __init__(self, a, b=None): - if b is None: - self.low = 0 - self.ext = a - else: - self.low = a - self.ext = b + +class bind(object): #pylint: disable=invalid-name + """GPU bind software emulataion runtime.""" + def __init__(self, _, ext): + self.ext = ext def __iter__(self): i = 0 while i < self.ext: - yield i + self.low + yield i i += 1 -class bind(_range): #pylint: disable=invalid-name - def __init__(self, tag, ext): - super(bind, self).__init__(ext) - self.tag = tag - - -unroll = vectorize = parallel = _range #pylint: disable=invalid-name - - def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-argument """Allocate a buffer with given shape @@ -47,7 +34,6 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-ar """ return numpy.zeros(shape).astype(dtype) -output_tensor = allocate #pylint: disable=invalid-name def popcount(x): """ @@ -87,17 +73,19 @@ def sigmoid(x): HYBRID_GLOBALS = { - 'unroll' : unroll, - 'vectorize' : vectorize, - 'parallel' : parallel, - 'allocate' : allocate, - 'output_tensor': output_tensor, + 'len' : len, + 'unroll' : range, + 'vectorize' : range, + 'parallel' : range, + 'const_range' : range, 'bind' : bind, + 'allocate' : allocate, + 'output_tensor': allocate, 'sqrt' : numpy.sqrt, 'log' : numpy.log, 'tanh' : numpy.tanh, 'power' : numpy.power, 'exp' : numpy.exp, 'sigmoid' : sigmoid, - 'popcount' : popcount + 'popcount' : popcount, } diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index b3a5e1351edaf..df32144e01508 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -4,7 +4,10 @@ import operator import logging import sys -from numbers import Integral +import types +import numbers + +from enum import Enum from .util import _internal_assert from . import calls @@ -12,18 +15,15 @@ from .var_decl import determine_variable_usage from ..api import all as _all from ..api import any as _any +from ..container import Array from ..tensor import Tensor, Operation from .. import expr as _expr from .. import make as _make from .. import api as _api from .. import ir_pass as _ir_pass -def list_to_block(visit, lst): - """Convert a list of Python IR nodes to HalideIR Block""" - lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)] - lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())] - if not lst: - return util.make_nop() + +def pack_list_to_block(lst): if len(lst) == 1: return lst[0] body = lst[0] @@ -32,6 +32,29 @@ def list_to_block(visit, lst): return body +def visit_list_to_block(visit, lst): + """Convert a list of Python IR nodes to HalideIR Block""" + lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)] + lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())] + if not lst: + return util.make_nop() + return pack_list_to_block(lst) + + +class Symbol(Enum): + """Enumerates types in the symbol table""" + Callable = 0 + Input = 1 + OutputBuffer = 2 + GlobalBuffer = 3 + LocalBuffer = 4 + SharedBuffer = 5 + ConstVar = 6 + BufferVar = 7 + LoopVar = 8 + ConstLoopVar = 9 + + class HybridParser(ast.NodeVisitor): """Python AST visitor pass which finally lowers it to HalideIR""" @@ -82,77 +105,55 @@ def __init__(self, args, usage, symbols, func_name=None): """ self.args = list(args) self.usage = usage.copy() - self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer) - self.alloc_buffers = {} # Buffers formed by explicit allocate instructions - self.loops_above = {} # State variable that indicates loop levels above the current node - self.variables = {} # The status of defined variables + + self.symbols = {} # Symbol table + for k, v in symbols.items(): + if isinstance(v, types.FunctionType): + self.symbols[k] = Symbol.Callable, v + self.func_name = func_name # The name of the function to be lowered self.outputs = [] # Output tensors' name self.side_effect = set() # Tensors with side effects self.parsed_body = None # The parsed HalideIR body self.returned = False # If this function has a valid return - self.symbols = symbols # The global context + def wrap_up_realize(self, node, body): """Wrap up all the variables which will no longer be used""" - pop_buf = [] - pop_var = [] + to_pop = [] for key, val in self.usage.items(): _, level, _ = val if level != node: continue - if key in self._args.keys(): + _internal_assert(key in self.symbols.keys(), "Unknown symbol %s!" % key) + + ty, entry = self.symbols[key] #pylint: disable=invalid-name + if ty in [Symbol.Input, Symbol.OutputBuffer]: continue - if key in self.alloc_buffers.keys(): - _buf, _scope = self.alloc_buffers[key] - if _scope == 'output': - continue - pop_buf.append(key) + elif 'Buffer' in ty.name: + _buf = entry + _scope = ty.name[:-6].lower() if ty is not Symbol.BufferVar else 'global' + to_pop.append(key) else: - _internal_assert(key in self.variables.keys(), - "Key should be either in one of args, buffers, and vars") - if not isinstance(self.variables[key], tuple): - continue - _buf, _scope = self.variables[key] - pop_var.append(key) + continue + _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape] _dtype = _buf.dtype _true = _api.convert(True) body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body) body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body) - for elem in pop_buf: - self.alloc_buffers.pop(elem) - for elem in pop_var: - self.variables.pop(elem) - return body + for elem in to_pop: + self.symbols.pop(elem) + return body - def _get_buffer_from_id(self, s, for_provide=False): - _internal_assert((s in self._args.keys()) + (s in self.alloc_buffers.keys()) == 1, - "This %s is expected to be in either \ - argument list or allocated buffer!" % s) - if s in self._args.keys(): - if for_provide: - self.side_effect.add(self._args[s]) - return self._args[s] - return self.alloc_buffers[s][0] - - def _const(self, value, dtype=None): - if dtype is None: - if isinstance(value, bool): - dtype = "bool" - elif isinstance(value, Integral): - dtype = "int32" - else: - dtype = "float32" - return _api.const(value, dtype) #pylint: disable=invalid-name, missing-docstring def visit_Module(self, node): _internal_assert(len(node.body) == 1, \ - "Only one-function source code can be fed to this parser!") + "Only one-function source code will be fed to this parser!") return self.visit(node.body[0]) @@ -164,8 +165,8 @@ def visit_FunctionDef(self, node): self.func_name = node.name for idx, arg in enumerate(node.args.args): _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible - self._args[getattr(arg, _attr)] = self.args[idx] - res = list_to_block(self.visit, node.body) + self.symbols[getattr(arg, _attr)] = (Symbol.Input, self.args[idx]) + res = visit_list_to_block(self.visit, node.body) res = self.wrap_up_realize(node, res) return res @@ -176,25 +177,31 @@ def visit_Expr(self, node): def visit_Name(self, node): name = node.id - if name in self.loops_above.keys(): - return self.loops_above[name] - elif name in self.variables.keys(): - res = self.variables[name] - if isinstance(res, tuple): - buf = res[0] - if isinstance(node.ctx, ast.Load): - return _make.Call(buf.dtype, buf.name, [self._const(0)], \ - _expr.Call.Halide, buf.op, buf.value_index) - return buf, [self._const(0)] + ty, entry = self.symbols[name] + _internal_assert(name in self.symbols, "Unknown symbol %s!" % name) + if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]: + return entry + elif ty is Symbol.ConstVar: + return entry if isinstance(node.ctx, ast.Load) else None + elif ty is Symbol.BufferVar: if isinstance(node.ctx, ast.Load): - return res - return None - buf = self._get_buffer_from_id(name) - return buf + return _make.Call(entry.dtype, entry.name, [_api.const(0, 'int32')], \ + _expr.Call.Halide, entry.op, entry.value_index) + return entry, [_api.const(0, 'int32')] + # Do I need any assertion here? + return entry def visit_Num(self, node): - return self._const(node.n) + if isinstance(node.n, numbers.Integral): + dtype = "int32" + elif isinstance(node.n, float): + dtype = "float32" + else: + _internal_assert(isinstance(node.n, bool), + "The data type should be one of (int, float, bool)") + dtype = "bool" + return _api.const(node.n, dtype) def visit_AugAssign(self, node): @@ -204,7 +211,7 @@ def visit_AugAssign(self, node): _internal_assert(len(buf) == 2, "LHS is supposed to be (buf, args)!") buf, args = buf else: - args = [self._const(0)] + args = [_api.const(0, 'int32')] _internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!") read = _make.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index) @@ -222,7 +229,7 @@ def visit_Assign(self, node): for i in range(rhs.num_outputs): _internal_assert(isinstance(node.targets[i], ast.Name), "You should bind a pure name to the tensors") - self.alloc_buffers[node.targets[i].id] = (rhs.output(i), 'global') + self.symbols[node.targets[i].id] = Symbol.GlobalBuffer, rhs.output(i) rmap[rhs.outputs[i].op] = rhs.output(i) return util.replace_io(rhs.body, rmap) @@ -234,25 +241,26 @@ def visit_Assign(self, node): #TODO: support defined intermediate buffer later lhs_ = lhs lhs = lhs.id - _internal_assert(lhs not in self.loops_above.keys(), \ - "Loop variable cannot be overwritten!") + if lhs in self.symbols.keys(): + ty, _ = self.symbols[lhs] + _internal_assert(ty != Symbol.LoopVar, \ + "Loop variable cannot be overwritten!") decl, _, rw = self.usage[lhs] if decl == lhs_: - _internal_assert(lhs not in self.variables.keys() and - lhs not in self.alloc_buffers.keys(), \ + _internal_assert(lhs not in self.symbols.keys(), "This value should not be defined before this point!") if isinstance(rhs, tuple): shape, dtype, scope = rhs ph = _api.placeholder(shape, dtype=dtype, name=lhs) - self.alloc_buffers[lhs] = (ph, scope) + self.symbols[lhs] = getattr(Symbol, scope.title() + "Buffer"), ph if scope == 'output': self.outputs.append(lhs) return util.make_nop() if isinstance(rhs, util.halide_imm_types) and ast.Store not in rw: - self.variables[lhs] = rhs + self.symbols[lhs] = Symbol.ConstVar, rhs else: ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs) - self.variables[lhs] = (ph, 'global') + self.symbols[lhs] = Symbol.BufferVar, ph lhs = self.visit(lhs_) if lhs is not None: buf, args = lhs @@ -275,17 +283,30 @@ def visit_Index(self, node): def visit_Attribute(self, node): _internal_assert(isinstance(node.value, ast.Name), \ "For atrribute access, only both names are supported so far!") - buf = self._get_buffer_from_id(node.value.id) + buf = self.visit(node.value) return getattr(buf, node.attr) def visit_Subscript(self, node): args = self.visit(node.slice) if isinstance(node.value, ast.Name): + buf = self.visit(node.value) + if isinstance(buf, Array): + for i in args: + if isinstance(i, numbers.Integral): + buf = buf[i] + else: + _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \ + "All indices are supposed to be constants") + buf = buf[i.value] + + return buf + if isinstance(node.ctx, ast.Load): return _make.Call(buf.dtype, buf.name, args, \ _expr.Call.Halide, buf.op, buf.value_index) + return buf, args shape = self.visit(node.value) @@ -308,14 +329,14 @@ def visit_With(self, node): _internal_assert(isinstance(context, ast.Call), "The object must be a Python func call!") _internal_assert(isinstance(option, ast.Name), "The object after 'as' must be an id!") self.annotation[option.id] = context.func.id - return list_to_block(self.visit, node.body) + return visit_list_to_block(self.visit, node.body) def visit_If(self, node): cond = self.visit(node.test) - if_body = list_to_block(self.visit, node.body) + if_body = visit_list_to_block(self.visit, node.body) if node.orelse: - else_body = list_to_block(self.visit, node.orelse) + else_body = visit_list_to_block(self.visit, node.orelse) else: else_body = util.make_nop() return _make.IfThenElse(cond, if_body, else_body) @@ -376,7 +397,10 @@ def visit_Call(self, node): except AttributeError: _internal_assert(func_id in self.symbols.keys(), \ "The function called is not in the context either!") - outs = self.symbols[func_id](*args) + ty, entry = self.symbols[func_id] + _internal_assert(ty is Symbol.Callable, \ + "Are you sure what you call is a function?!") + outs = entry(*args) op = outs.op if isinstance(outs, Tensor) else outs[0].op return op @@ -385,41 +409,66 @@ def visit_For(self, node): iter_var, low, ext, for_type = self.visit(node.iter) _internal_assert(isinstance(node.target, ast.Name), \ "The loop iterator should be a variable!") + _name = node.target.id - if iter_var is None: + + if isinstance(for_type, tuple): + low = _ir_pass.Simplify(low) + ext = _ir_pass.Simplify(ext) + _internal_assert(isinstance(low, _expr.ConstExpr) and + isinstance(ext, _expr.ConstExpr), \ + "Const range should start from a const" + \ + "and iterate const times") + + low, ext = low.value, ext.value + if ext > 114514: + logging.log(logging.CRITICAL, \ + '[Warning] Are you sure to unroll a large loop in Python?') + + bodies = [] + for i in range(low, low + ext): + self.symbols[_name] = Symbol.ConstLoopVar, i + bodies.append(visit_list_to_block(self.visit, node.body)) + return pack_list_to_block(bodies) + + elif iter_var is None: _internal_assert(for_type is not None, "The loop bind function parse error!") offset = iter_var = _api.var(_name) - if not _ir_pass.Equal(low, self._const(0)): + if not _ir_pass.Equal(low, _api.const(0, 'int32')): offset = iter_var + low - self.loops_above[_name] = offset + self.symbols[_name] = Symbol.LoopVar, offset + _body = visit_list_to_block(self.visit, node.body) else: _internal_assert(for_type is None, "The loop iterating function parse error!") - self.loops_above[_name] = iter_var.var - _body = list_to_block(self.visit, node.body) + self.symbols[_name] = Symbol.LoopVar, iter_var.var + _body = visit_list_to_block(self.visit, node.body) + _body = self.wrap_up_realize(node, _body) + if for_type is None: res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body) - else: - res = _make.For(iter_var, self._const(0), ext, for_type, 0, _body) - self.loops_above.pop(_name) + elif not isinstance(for_type, tuple): + res = _make.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body) + self.symbols.pop(_name) return res def visit_Return(self, node): - _internal_assert(not self.loops_above, "Return should not be in a loop body!") + _internal_assert(all(ty != Symbol.LoopVar for ty, _ in self.symbols.values()), \ + "Return should not be in a loop body!") ids = [] if isinstance(node.value, ast.Name): - ids.append(node.value.id) + ids = [node.value.id] else: _internal_assert(isinstance(node.value, ast.Tuple), \ "You should return either a single tensor or a tuple") - for i in node.value.elts: - _internal_assert(isinstance(i, ast.Name), "What do you return?") - ids.append(i.id) + _internal_assert(all(isinstance(i, ast.Name) for i in node.value.elts), \ + "What do you return?") + ids = [i.id for i in node.value.elts] _internal_assert(len(set(ids)) == len(ids), "Duplicated tensors in the return tuples") if len(ids) < len(self.outputs): logging.log(logging.CRITICAL, '[Warning] Not all the output buffers returned!') - self.outputs = [self.alloc_buffers[i][0] for i in ids] + self.outputs = [self.symbols[i][1] for i in ids] self.returned = True return util.make_nop() diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py index aa86d55a6fcf1..44222d2d80f76 100644 --- a/python/tvm/hybrid/util.py +++ b/python/tvm/hybrid/util.py @@ -11,12 +11,13 @@ from .. import make as _make from .. import expr as _expr from .. import stmt as _stmt +from ..container import Array from ..tensor import Tensor #pylint: disable=invalid-name np_arg_types = tuple(list(numeric_types) + [numpy.ndarray]) -tvm_arg_types = (Tensor, _expr.Var, _expr.ConstExpr) +tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr) halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm) def _internal_assert(cond, err): diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index f87c75f7929d9..7c6d31b297ba3 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -13,7 +13,7 @@ def tvm_val_2_py_val(val): ctx = tvm.context(target, 0) op = None - outs = func(*args) + outs = func(*tuple(tvm.convert(i) if isinstance(i, list) else i for i in args)) op = outs[0].op if isinstance(outs, list) else outs.op emu_args = [] @@ -23,13 +23,18 @@ def tvm_val_2_py_val(val): shape = [tvm_val_2_py_val(j) for j in i.shape] emu_args.append(numpy.random.randn(*shape).astype(i.dtype)) nd_args.append(tvm.nd.array(emu_args[-1], ctx)) - else: - assert isinstance(i, tvm.expr.Var) + elif isinstance(i, tvm.expr.Var): emu_args.append(tvm_val_2_py_val(i)) nd_args.append(emu_args[-1]) + else: + assert isinstance(i, list) + emu_args.append(numpy.array(i)) sch = tvm.create_schedule(op) - module = tvm.build(sch, args + (outs if isinstance(outs, list) else [outs]), target=target) + module = tvm.build(sch, + [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + \ + (outs if isinstance(outs, list) else [outs]), + target=target) assert module out_tensors = [] @@ -192,20 +197,20 @@ def fanout(n, a): def test_looptype(): @script def looptype(a, b, c): - d = output_tensor((8, ), 'int32') - e = output_tensor((8, ), 'int32') - f = output_tensor((8, ), 'int32') - for i in parallel(8): + d = output_tensor((16, ), 'int32') + e = output_tensor((16, ), 'int32') + f = output_tensor((16, ), 'int32') + for i in parallel(16): d[i] = a[i] - for j in vectorize(8): + for j in vectorize(16): e[j] = b[j] - for k in unroll(8): + for k in unroll(16): f[k] = c[k] return d, e, f - a = tvm.placeholder((8, ), name='a', dtype='int32') - b = tvm.placeholder((8, ), name='b', dtype='int32') - c = tvm.placeholder((8, ), name='c', dtype='int32') + a = tvm.placeholder((16, ), name='a', dtype='int32') + b = tvm.placeholder((16, ), name='b', dtype='int32') + c = tvm.placeholder((16, ), name='c', dtype='int32') try: d, e, f = looptype(a, b, c) ir = d.op.body @@ -509,9 +514,9 @@ def kernel_b(b, a): def test_func_call(): @tvm.hybrid.script def foo(a, b): - for i in range(10): + for i in range(len(a)): a[i] = i + 1.0 - for i in range(10): + for i in range(len(a)): b[i] = i + 1.0 c = outer_product(10, 10, a, b) d = output_tensor(c.shape, c.dtype) @@ -538,6 +543,26 @@ def foo(a): a = tvm.placeholder((10, ), name='a') run_and_check(foo, [a]) +def test_const_range(): + @tvm.hybrid.script + def foo(a, b): + c = output_tensor(a.shape, a.dtype) + d = output_tensor(a.shape, a.dtype) + + for i in const_range(2): + for j in const_range(5): + c[i, j] = a[i, j] + b[i, j] + + for i in const_range(len(b)): + for j in const_range(len(b[0])): + d[i, j] = a[i, j] + b[i, j] + + return c, d + + a = tvm.placeholder((2, 5), name='a', dtype='int32') + b = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]] + run_and_check(foo, [a, b]) + if __name__ == "__main__": test_outer_product() test_fanout() @@ -553,5 +578,6 @@ def foo(a): test_value_index() test_func_call() test_bool() + test_const_range() # TODO: # test_inplace()