From bec9f16d42fc11ac97e0f01af007551398b025a2 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 27 Sep 2022 16:51:08 -0500 Subject: [PATCH] [TIR][Transform] Clear buffer_map during MakeUnpackedAPI (#12891) * [TIR][Transform] Clear buffer_map during MakeUnpackedAPI This mimics the behavior in `MakePackedAPI`, and is assumed to be the case for some codegens. * Remove read of buffer_map in ethosu.tir_to_cs_translator This previously relied on `MakeUnpackedAPI` preserving the `PrimFunc::buffer_map`, even after it had been used for lowering. It now reads from the `BufferLoad` and `BufferStore` nodes to determine buffer shapes. * Added more documentation for MakePackedAPI/MakeUnpackedAPI --- .../relay/backend/contrib/ethosu/tir/utils.py | 30 +++++++++++++++ .../contrib/ethosu/tir_to_cs_translator.py | 37 +++++++++++++------ python/tvm/tir/transform/transform.py | 30 +++++++++++++++ src/tir/transforms/make_unpacked_api.cc | 7 +--- 4 files changed, 88 insertions(+), 16 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py index a823667234df..396735a07c4c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py @@ -158,6 +158,36 @@ def get_outer_loops(stmt, layout): return None +def collect_buffer_map(stmt): + """Collect a map of Var -> Buffer + + Generate a map from a buffer's backing `tir.Var` to the + `tir.Buffer` object that uses it. If multiple such buffers exist, + return the first occurrence. + + Parameters + ---------- + stmt : tvm.tir.Stmt + The statement to get the BufferLoads from. + + Returns + ------- + buffer_map : Dict[Var, Buffer] + The map from buffer var to the buffers that use it. + """ + buffer_map = {} + + def _visit(node): + if isinstance(node, (tvm.tir.BufferLoad, tvm.tir.BufferStore)): + buf = node.buffer + if buf.data not in buffer_map: + buffer_map[buf.data] = buf + + tvm.tir.stmt_functor.post_order_visit(stmt, _visit) + + return buffer_map + + def get_loads(stmt): """Get the BufferLoad statements. diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index f5c8994bec77..19f009d284ab 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -29,6 +29,7 @@ from tvm.relay.backend.contrib.ethosu import util from tvm.relay.backend.contrib.ethosu import vela_api from tvm.relay.backend.contrib.ethosu.tir import spec +from tvm.relay.backend.contrib.ethosu.tir import utils as tir_utils class BufferType(Enum): @@ -254,26 +255,40 @@ def extract_param_base_addresses(mod, buffer_info, scratch_region_map) -> List[u assert len(mod.functions.items()) == 1 primfunc = mod.functions.items()[0][1] + buffer_map = tir_utils.collect_buffer_map(primfunc.body) + base_addresses = list() idx = 0 + for param in primfunc.params: # constants are pooled together and handled specially # this will change after tir.allocate_const. # For now, we are skipping generating buffer addresses here if buffer_info[param].btype == BufferType.constant: continue - buffer = primfunc.buffer_map[param] - dtype = buffer.dtype - element_size_bytes = np.iinfo(dtype).bits // 8 - size_bytes = element_size_bytes * np.prod(list(buffer.shape)) - base_addresses.append( - util.BaseAddress( - param.name.replace("-", "_"), - idx, - _get_region(buffer_info[param].btype, param, scratch_region_map), - size_bytes, + + if param in buffer_map: + buffer = buffer_map[param] + dtype = buffer.dtype + element_size_bytes = np.iinfo(dtype).bits // 8 + size_bytes = element_size_bytes * np.prod(list(buffer.shape)) + base_addresses.append( + util.BaseAddress( + param.name.replace("-", "_"), + idx, + _get_region(buffer_info[param].btype, param, scratch_region_map), + size_bytes, + ) + ) + else: + base_addresses.append( + util.BaseAddress( + param.name.replace("-", "_"), + idx, + _get_region(buffer_info[param].btype, param, scratch_region_map), + 0, + ) ) - ) idx += 1 return base_addresses diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 3c1ca196f1b0..d95d15c0dfbe 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -390,6 +390,26 @@ def LowerCustomDatatypes(): def MakePackedAPI(): """Transform the PrimFuncs in the module to a packed func API. + Prior to this pass, the PrimFunc may have Buffer arguments defined + in the `PrimFuncNode::buffer_map`. This pass consumes the + `buffer_map`, using it to generate `TVMArgs` and `TVMRetValue*` + arguments that implement the `PackedFunc` API. + + For static shapes, the `BufferNode::shape`, `BufferNode::strides`, + and `BufferNode::elem_offset` member variables are used to + generate runtime checks on the corresponding member variables in + the user-provided `DLTensor*` or `tvm.nd.array` argument. (e.g. A + PrimFunc that accepts a buffer of shape `[16,32]` validates that + the `DLTensor::shape` array is `[16,32]`.) + + For dynamic Buffers, in which one or more of these `BufferNode` member + variables use `tir.Var` that are not defined by other PrimFunc + parameters, these are instead used to define the variables based on + the corresponding `DLTensor` members. (e.g. A PrimFunc that accepts a + buffer of shape `[tir.Var("n"), tir.Var("m")]`, when passed a + `DLTensor` of shape `[16,32]`, will define `n = 16` and `n=32`, based + on the argument's shape. + Returns ------- fpass : tvm.transform.Pass @@ -401,6 +421,16 @@ def MakePackedAPI(): def MakeUnpackedAPI(): """Transform the PrimFuncs in the module to a C API compatible with internal calls. + Prior to this pass, the PrimFunc may have Buffer arguments defined in + the `PrimFuncNode::buffer_map`. This pass consumes the `buffer_map`, + using it to generate `T*` arguments (e.g. `float32*`) that can be + directly called by a C API. + + For static shapes, no runtime validation is performed to confirm that + the argument buffer's shape matches the expected shape. For dynamic + shapes, `MakeUnpackedAPI` requires that the dynamic parameters be + passed as separate `tir.Var` parameters. + Returns ------- fpass : tvm.transform.Pass diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index c57daeabbe1d..87e8f38895cd 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -59,16 +59,13 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) { // Collect variables and buffers to map between Array args; - Map new_buffer_map; + for (const Var& param : func->params) { // Ideally all func params should have Buffers defined in the buffer_map // We should look to insert buffer_maps for all PrimFuncs that are returned // to the core compiler. if (func->buffer_map.find(param) != func->buffer_map.end()) { args.push_back(func->buffer_map[param]->data); - // Rewiring the buffer_var to map to Buffers for low-level passes - // retain information about the buffer. - new_buffer_map.Set(func->buffer_map[param]->data, func->buffer_map[param]); } else { args.push_back(param); } @@ -82,7 +79,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) { func_ptr->body = MergeNest(device_init, func_ptr->body); func_ptr->params = args; func_ptr->ret_type = PrimType(DataType::Int(32)); - func_ptr->buffer_map = new_buffer_map; + func_ptr->buffer_map = Map(); // return the function. return std::move(func);