diff --git a/cmake/modules/contrib/BNNS.cmake b/cmake/modules/contrib/BNNS.cmake index 1adb3ba10231..fcac1a26682a 100644 --- a/cmake/modules/contrib/BNNS.cmake +++ b/cmake/modules/contrib/BNNS.cmake @@ -16,10 +16,8 @@ # under the License. if(USE_BNNS STREQUAL "ON") - add_definitions(-DUSE_JSON_RUNTIME=1) tvm_file_glob(GLOB BNNS_RELAY_CONTRIB_SRC src/relay/backend/contrib/bnns/*.cc) list(APPEND COMPILER_SRCS ${BNNS_RELAY_CONTRIB_SRC}) - list(APPEND COMPILER_SRCS ${JSON_RELAY_CONTRIB_SRC}) list(APPEND TVM_RUNTIME_LINKER_LIBS "-framework Accelerate") diff --git a/cmake/modules/contrib/DNNL.cmake b/cmake/modules/contrib/DNNL.cmake index 9e36f39891e1..edfafc195554 100644 --- a/cmake/modules/contrib/DNNL.cmake +++ b/cmake/modules/contrib/DNNL.cmake @@ -15,17 +15,73 @@ # specific language governing permissions and limitations # under the License. +macro(find_dnnl) + # 1. Try to find via dnnl-config.cmake + find_package(dnnl CONFIG) + + if (NOT dnnl_FOUND) + # 2. Try to find dnnl like a lib + headers distribution + find_library(EXTERN_LIBRARY_DNNL dnnl NO_CACHE) + if (EXTERN_LIBRARY_DNNL) + get_filename_component(DNNL_LIB_DIR ${EXTERN_LIBRARY_DNNL} DIRECTORY) + get_filename_component(DNNL_HDR_DIR ${DNNL_LIB_DIR} DIRECTORY) + string(APPEND DNNL_HDR_DIR "/include") + + find_file(DNNL_CONFIG_HDR dnnl_config.h PATHS ${DNNL_HDR_DIR} NO_CACHEEEE) + if (DNNL_CONFIG_HDR) + file(READ ${DNNL_CONFIG_HDR} DNNL_CONFIG) + string(REGEX MATCH "DNNL_CPU_RUNTIME DNNL_RUNTIME_(OMP|SEQ|TBB)" DNNL_CPU_RUNTIME "${DNNL_CONFIG}") + string(REGEX MATCH "(OMP|SEQ|TBB)" DNNL_CPU_RUNTIME "${DNNL_CPU_RUNTIME}") + + if (DNNL_CPU_RUNTIME) + add_library(DNNL::dnnl SHARED IMPORTED) + set_target_properties(DNNL::dnnl PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${DNNL_HDR_DIR}" + IMPORTED_LOCATION "${EXTERN_LIBRARY_DNNL}" + ) + + set(dnnl_FOUND TRUE) + set(dnnl_DIR "${DNNL_LIB_DIR}") + endif() + endif() + + # because find_file put this value to cache + unset(EXTERN_LIBRARY_DNNL CACHE) + unset(DNNL_CONFIG_HDR CACHE) + endif() + endif() + + if (NOT dnnl_FOUND) + message(FATAL_ERROR + "Cannot detect DNNL package. Please make sure that you have it properly installed " + "and corresponding variables are set (CMAKE_PREFIX_PATH or CMAKE_LIBRARY_PATH).") + endif() +endmacro(find_dnnl) + + +if (USE_DNNL_CODEGEN STREQUAL "ON" ) + find_dnnl() + + if (DNNL_CPU_RUNTIME STREQUAL "OMP" AND NOT USE_OPENMP) + message(WARNING + "DNNL and TVM are using different threading runtimes. Mixing of thread " + "pools may lead to significant performance penalty. Suggestion is to " + "switch TVM to use OpenMP (cmake flag: -DUSE_OPENMP=ON).") + endif() +endif() + if((USE_DNNL_CODEGEN STREQUAL "ON") OR (USE_DNNL_CODEGEN STREQUAL "JSON")) - add_definitions(-DUSE_JSON_RUNTIME=1) tvm_file_glob(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc) - list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC}) - list(APPEND COMPILER_SRCS ${JSON_RELAY_CONTRIB_SRC}) + tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/*.cc) - find_library(EXTERN_LIBRARY_DNNL dnnl) - list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL}) - tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/dnnl_json_runtime.cc) + list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC}) list(APPEND RUNTIME_SRCS ${DNNL_CONTRIB_SRC}) - message(STATUS "Build with DNNL JSON runtime: " ${EXTERN_LIBRARY_DNNL}) + list(APPEND TVM_RUNTIME_LINKER_LIBS DNNL::dnnl) + # WA. 
Have to use system include path while TVM doesn't use targets to describe dependencies + include_directories(SYSTEM $) + add_definitions(-DUSE_JSON_RUNTIME=1) + + message(STATUS "Build with DNNL JSON runtime: ${dnnl_DIR} (${DNNL_CPU_RUNTIME})" ) elseif(USE_DNNL_CODEGEN STREQUAL "C_SRC") tvm_file_glob(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc) list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC}) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index 05b588051a1c..ed05b88c1bdb 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -38,9 +38,21 @@ from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name -from ...dataflow_pattern import wildcard, is_op +from ...dataflow_pattern import wildcard, is_op, is_constant from .register import register_pattern_table + +def get_dnnl_version(): + """Return tuple with version or DNNL library if known + Otherwise return unknown value which is bigger than any over real + versions. + """ + f = tvm.get_global_func("runtime.module.dnnl_version", allow_missing=True) + return tuple(int(el) for el in f().split(".")) if f else (100500,) + + +dnnl_version = get_dnnl_version() + logger = logging.getLogger("DNNL") @@ -48,8 +60,8 @@ def _register_external_op_helper(op_name, supported=True): """The helper function to indicate that a given operator can be supported by DNNL. - Paramters - --------- + Parameters + ---------- op_name : Str The name of operator that will be registered. @@ -159,6 +171,90 @@ def make_dnnl_pattern(op, with_bias, with_eltwise): return dnnl_pattern +def make_qnn_conv2d_pattern(with_sum=False): + """Make qnn.conv2d based pattern supported by DNNL + + Parameters + ---------- + with_sum : bool + Indicate to append qnn.sum at the end of pattern + + Returns + ------- + pattern : Tuple(pattern_name, CallPattern) + Created pattern name, along with its CallPattern. + """ + weight = is_constant() # |const requirements, have to recalculate bias to compensate src_zp + bias = is_constant() + + pat = wildcard() + pat = is_op("qnn.conv2d")( + pat, weight, is_constant(), is_constant(), is_constant(), is_constant() + ) + pat = is_op("add")(pat, bias) | pat + pat = is_op("qnn.requantize")(pat, is_constant(), is_constant(), is_constant(), is_constant()) + pat = is_op("clip")(pat) + pat = is_op("cast")(pat) + if with_sum is True: + pat = is_op("qnn.add")( + pat, + wildcard(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + ) + pat = is_op("clip")(pat) + + pat_name = "dnnl.qnn.conv2d_sum" if with_sum else "dnnl.qnn.conv2d" + + return pat_name, pat + + +def make_qnn_dense_pattern(with_sum=False): + """Make qnn.dense based pattern supported by DNNL + + Parameters + ---------- + with_sum : bool + Indicate to append qnn.sum at the end of pattern + + Returns + ------- + pattern : Tuple(pattern_name, CallPattern) + Created pattern name, along with its CallPattern. 
+ """ + weight = is_constant() + bias = is_constant() + + pat = wildcard() + pat = is_op("qnn.dense")( + pat, weight, is_constant(), is_constant(), is_constant(), is_constant() + ) + pat = is_op("add")(pat, bias) | pat + pat = is_op("qnn.requantize")(pat, is_constant(), is_constant(), is_constant(), is_constant()) + pat = is_op("clip")(pat) + pat = is_op("cast")(pat) + if with_sum is True: + pat = is_op("qnn.add")( + pat, + wildcard(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + ) + pat = is_op("clip")(pat) + + pat_name = "dnnl.qnn.dense_sum" if with_sum else "dnnl.qnn.dense" + + return pat_name, pat + + @register_pattern_table("dnnl") def pattern_table(): """Create dnnl patterns. @@ -173,9 +269,16 @@ def pattern_table(): for with_bias in [True, False]: for elt in elt_list: if not with_bias and not elt: - return dnnl_patterns + continue dnnl_patterns.append(make_dnnl_pattern("conv2d", with_bias, elt)) dnnl_patterns.append(make_dnnl_pattern("dense", with_bias, elt)) + + for with_sum in [True, False]: + dnnl_patterns.append(make_qnn_conv2d_pattern(with_sum)) + # Old dnnl version doesn't support per channel o_scale + if dnnl_version >= (2, 2) or not with_sum: + dnnl_patterns.append(make_qnn_dense_pattern(with_sum)) + return dnnl_patterns diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index 0c930dd1153c..622719f608c4 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -434,7 +434,7 @@ def abs(data): def sign(data): - """Compute element-wise absolute of data. + """Compute element-wise sign of data. Parameters ---------- diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index b1b2f580cf94..120d9401007c 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -34,6 +34,7 @@ #include #include "../../utils.h" +#include "composite_op.h" #ifdef USE_JSON_RUNTIME #include "../../../../runtime/contrib/json/json_node.h" @@ -435,63 +436,74 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { #else // DNNL JSON runtime +/*! + * @brief Replace var expr which bind with args of call node + * + * @param args collection of expression (contains vars or constant nodes) + * @param cn call node which describe mapping of internal body vars with args + * @return + */ +static tvm::Array BindToCallNodeArgs(const std::vector& args, const CallNode* cn) { + tvm::Array res; + for (const auto& arg : args) { + if (arg->IsInstance()) { + res.push_back(arg); + } else { + auto body_params = cn->op.as()->params; + auto found = std::find(body_params.begin(), body_params.end(), arg); + ICHECK(found != body_params.end()); + auto idx = std::distance(body_params.begin(), found); + res.push_back(cn->args[idx]); + } + } + return res; +} + +/*! + * @brief Serializer to DNNL JSON runtime module + */ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { using JSONGraphNode = tvm::runtime::json::JSONGraphNode; using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; public: - DNNLJSONSerializer(const std::string& symbol, const Expr& expr) : JSONSerializer(symbol, expr) {} + // "dnnl_" prefix is only because of constraint to has constant naming + // starts from name of code generator. Looks like TVM issue... 
+ DNNLJSONSerializer(const std::string& symbol, const Expr& expr) + : JSONSerializer("dnnl_" + symbol, expr) {} std::vector VisitExpr_(const CallNode* cn) override { Expr expr = GetRef(cn); std::string name; - const CallNode* call = cn; + tvm::Array args; + std::unordered_map attrs; + if (const auto* op_node = cn->op.as()) { name = op_node->name; + args = cn->args; + attrs = extractAttrs(cn); } else if (const auto* fn = cn->op.as()) { auto comp = fn->GetAttr(attr::kComposite); ICHECK(comp.defined()) << "DNNL JSON runtime only supports composite functions."; name = comp.value(); - if (name == "dnnl.conv2d_bias_relu") { - call = GetRootCall(fn->body.as(), 2, {"nn.conv2d", "add", "nn.relu"}); - } else if (name == "dnnl.conv2d_bias_tanh") { - call = GetRootCall(fn->body.as(), 2, {"nn.conv2d", "add", "tanh"}); - ICHECK(call->op.as()) << "Not op node"; - } else if (name == "dnnl.conv2d_bias_sigmoid") { - call = GetRootCall(fn->body.as(), 2, {"nn.conv2d", "add", "sigmoid"}); - ICHECK(call->op.as()) << "Not op node"; - } else if (name == "dnnl.conv2d_bias") { - call = GetRootCall(fn->body.as(), 1, {"nn.conv2d", "add"}); - ICHECK(call->op.as()) << "Not op node"; - } else if (name == "dnnl.conv2d_relu") { - call = GetRootCall(fn->body.as(), 1, {"nn.conv2d", "nn.relu"}); - ICHECK(call->op.as()) << "Not op node"; - } else if (name == "dnnl.conv2d_tanh") { - call = GetRootCall(fn->body.as(), 1, {"nn.conv2d", "tanh"}); - ICHECK(call->op.as()) << "Not op node"; - } else if (name == "dnnl.conv2d_sigmoid") { - call = GetRootCall(fn->body.as(), 1, {"nn.conv2d", "sigmoid"}); - ICHECK(call->op.as()) << "Not op node"; - } else if (name == "dnnl.dense_bias") { - call = GetRootCall(fn->body.as(), 1, {"nn.dense", "add"}); - ICHECK(call->op.as()) << "Not op node"; - } else { - LOG(FATAL) << "Unrecognized DNNL pattern: " << name; - } + std::vector args_loc; + std::tie(args_loc, attrs) = DNNLCompositeFunctionsParser(fn); + args = BindToCallNodeArgs(args_loc, cn); } else { LOG(FATAL) << "DNNL JSON runtime does not support calls to " << cn->op->GetTypeKey(); } std::vector inputs; - for (const auto& arg : cn->args) { + for (const auto& arg : args) { auto res = VisitExpr(arg); inputs.insert(inputs.end(), res.begin(), res.end()); } auto node = std::make_shared(name, /* name_ */ "kernel", /* op_type_ */ inputs, 1 /* num_outputs_ */); - SetCallNodeAttribute(node, call); + for (const auto& kvp : attrs) node->SetAttr(kvp.first, kvp.second); + return AddNode(node, GetRef(cn)); } }; @@ -523,6 +535,61 @@ runtime::Module DNNLCompiler(const ObjectRef& ref) { TVM_REGISTER_GLOBAL("relay.ext.dnnl").set_body_typed(DNNLCompiler); +/*! + * @brief Constant Updater for DNNL JSON runtime + * + * Not all originally existing ConstantNode should be passed to JSON runtime. + * Some of them should be skipped or recalculated. Exactly the same traversing + * as DNNLJSONSerializer. + * + * TODO: DNNLJSONSerializer and DNNLConstantUpdater perform identical constant tensor + * recalculation. Better to reuse results from each other. 
+ */ +struct DNNLConstantUpdater : public ConstantUpdater { + public: + DNNLConstantUpdater(const std::string& symbol, + std::unordered_map* params) + : ConstantUpdater("dnnl_" + symbol, params) {} + using ConstantUpdater::VisitExpr_; + + void VisitExpr_(const CallNode* cn) final { + this->VisitSpan(cn->span); + + if (const auto* fn = cn->op.as()) { + auto args_and_attr = DNNLCompositeFunctionsParser(fn); + auto args = BindToCallNodeArgs(args_and_attr.first, cn); + + // Customized visit order of args + for (const auto& arg : args) { + this->VisitExpr(arg); + } + } else { + // Original visit order of args + for (auto arg : cn->args) { + this->VisitExpr(arg); + } + } + } +}; + +/*! + * \brief The external compiler/codegen tool. It takes a Relay expression/module and + * produce collection of required constant NDArrays. + */ +Map DNNLConstantUpdaterFunc(Expr expr, std::string symbol) { + // Visit all suitable constant nodes + std::unordered_map res; + DNNLConstantUpdater const_updater(symbol, &res); + const_updater(expr); + + // Convert to tvm::Map + Map ret; + for (const auto& kvp : res) ret.Set(kvp.first, kvp.second); + return ret; +} + +TVM_REGISTER_GLOBAL("relay.ext.dnnl.constant_updater").set_body_typed(DNNLConstantUpdaterFunc); + } // namespace contrib } // namespace relay } // namespace tvm diff --git a/src/relay/backend/contrib/dnnl/codegen_tools.h b/src/relay/backend/contrib/dnnl/codegen_tools.h new file mode 100644 index 000000000000..49987f0532a9 --- /dev/null +++ b/src/relay/backend/contrib/dnnl/codegen_tools.h @@ -0,0 +1,458 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RELAY_BACKEND_CONTRIB_DNNL_CODEGEN_TOOLS_H_ +#define TVM_RELAY_BACKEND_CONTRIB_DNNL_CODEGEN_TOOLS_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../../../op/make_op.h" +#include "../codegen_json/codegen_json.h" + +namespace tvm { +namespace relay { +namespace contrib { + +namespace details { + +template +DataType make_dtype(); + +template <> +DataType make_dtype() { + return DataType::Int(32); +} +template <> +DataType make_dtype() { + return DataType::UInt(32); +} +template <> +DataType make_dtype() { + return DataType::Int(8); +} +template <> +DataType make_dtype() { + return DataType::UInt(8); +} +template <> +DataType make_dtype() { + return DataType::Float(32); +} + +} // namespace details + +/*! + * @brief Infer type for provided expression + */ +Expr InferType(const Expr& expr) { + auto mod_and_global = IRModule::FromExprInContext(expr, {}, {}, {}); + auto mod = transform::InferType()(mod_and_global.first); + auto inferred = Downcast(mod->Lookup(mod_and_global.second->name_hint)); + return inferred->body; +} + +/*! 
+ * @brief Evaluate expression if possible + * + * Transformation rules: + * Empty expr -> Empty expr + * Constant expr -> ConstantNode with corresponding values + * All other expr -> original expr + * + * @param expr original expression + * @return resulting expr. Corresponding ConstantNode or original expr + */ +Expr EvalExpr(Expr expr) { + if (!expr.defined()) return expr; + + Device dev{kDLCPU, 0}; + Target target = Target("llvm"); + + With fresh_build_ctx(transform::PassContext::Create()); + auto res = Eval(expr, {}, {}, dev, target); + + if (res->IsInstance()) { + auto nd_array = Downcast(res); + return InferType(Constant(nd_array)); + } else { + LOG(ERROR) << "Unexpected object type"; + } + return {}; +} + +/*! + * @brief Evaluate shape of resulting tensor for provided expression + * @param exp expression to evaluate result shape + * @return shape of tensor + */ +static std::vector shape_of(const Expr& exp) { + auto typed_exp = InferType(exp); + auto tt = typed_exp->checked_type().as(); + + ICHECK(tt) << "Expr has none tensor type"; + + std::vector res; + for (const auto d : tt->shape) { + auto i_d = d.as(); + ICHECK(i_d); + res.push_back(i_d->value); + } + return res; +} + +/*! + * @brief Evaluate shape of resulting tensor + * @param exp expression to evaluate + * @return resulting data type + */ +static DataType dtype_of(const Expr& exp) { + auto typed_exp = InferType(exp); + auto tt = typed_exp->checked_type().as(); + + ICHECK(tt) << "Expr is not tensor type"; + return tt->dtype; +} + +static bool is_scalar(const Expr& exp) { + const Expr typed_exp = exp.defined() ? exp : InferType(exp); + const auto* tt = typed_exp->type_as(); + ICHECK(tt) << "Expression is not Tensor producer"; + return tt->shape.size() == 0; +} + +static bool is_const(const Expr& exp) { return exp->IsInstance(); } + +template +static bool is_const_scalar_eq(const Expr& exp, T val) { + if (details::make_dtype() != dtype_of(exp)) return false; + if (const auto* constant = exp.as()) { + if (constant->data->ndim == 0) { + return *static_cast(constant->data->data) == val; + } + } + return false; +} + +Constant constant(int val) { + auto value = runtime::NDArray::Empty({}, DataType::Int(32), {kDLCPU, 0}); + value.CopyFromBytes(&val, sizeof(val)); + return Constant(value); +} + +Constant constant(float val) { + auto value = runtime::NDArray::Empty({}, DataType::Float(32), {kDLCPU, 0}); + value.CopyFromBytes(&val, sizeof(val)); + return Constant(value); +} + +/*! + * @brief Check if expr produce a tensor with broadcast value. + * If yes return corresponding scalar value otherwise return the same expr. + */ +Expr collapse_to_scalar(const Expr& exp) { + const Expr const_exp = is_const(exp) ? exp : EvalExpr(exp); + if (is_scalar(const_exp)) return const_exp; + + if (const auto* const_node = const_exp.as()) { + auto ptr = static_cast(const_node->data->data); + auto size = const_node->data->shape[0]; + bool is_same = true; + for (int i = 0; i < size; i++) { + is_same &= ptr[i] == ptr[0]; + } + if (is_same) { + return EvalExpr(constant(ptr[0])); + } + } + return exp; +} + +template +static Expr cast(const Expr& that) { + return MakeCast(that, details::make_dtype()); +} + +static Expr squeeze(const Expr& exp) { + const Expr typed_exp = exp.defined() ? exp : InferType(exp); + const auto* tt = typed_exp->type_as(); + ICHECK(tt) << "Expression is not Tensor producer"; + // Empty list doesn't work. 
Have to specify it manually + Array axis_to_squeeze; + for (size_t i = 0; i < tt->shape.size(); i++) + if (tt->shape[i].as()->value == 1) axis_to_squeeze.push_back(i); + + return MakeSqueeze(exp, axis_to_squeeze); +} + +static Expr permute(const Expr& exp, const Array& perm) { + return MakeTranspose(exp, perm); +} + +static Expr broadcast(const Expr& exp, const Array& shape) { + return MakeBroadCastTo(exp, shape); +} + +static Array permutation(const std::string& from, const std::string& to) { + Array perm; + for (const auto& c : to) { + auto found = from.find(c); + ICHECK_NE(found, std::string::npos); + perm.push_back(found); + } + return perm; +} + +/*! + * @brief Helper namespace. Introduce elementwise arithmetic operations for expressions + * + * Broadcast semantic is included(forward and backward). If result tensor is a broadcast value it + * may be collapsed into scalar. + */ +namespace tensor_arithmetic { + +Expr operator+(const Expr& lhs, const Expr& rhs) { + if (is_const_scalar_eq(lhs, 0) || is_const_scalar_eq(lhs, 0.0f)) return rhs; + if (is_const_scalar_eq(rhs, 0) || is_const_scalar_eq(rhs, 0.0f)) return lhs; + + static const Op& op = Op::Get("add"); + return Call(op, {lhs, rhs}, Attrs(), {}); +} + +Expr operator-(const Expr& that) { + static const Op& op = Op::Get("negative"); + return Call(op, {that}, Attrs(), {}); +} + +Expr operator-(const Expr& lhs, const Expr& rhs) { + if (is_const_scalar_eq(lhs, 0) || is_const_scalar_eq(lhs, 0.0f)) return -rhs; + if (is_const_scalar_eq(rhs, 0) || is_const_scalar_eq(rhs, 0.0f)) return lhs; + + static const Op& op = Op::Get("subtract"); + return Call(op, {lhs, rhs}, Attrs(), {}); +} + +Expr operator*(const Expr& lhs, const Expr& rhs) { + if (is_const_scalar_eq(lhs, 1) || is_const_scalar_eq(lhs, 1.0f)) return rhs; + if (is_const_scalar_eq(rhs, 1) || is_const_scalar_eq(rhs, 1.0f)) return lhs; + if (is_const_scalar_eq(lhs, 0) || is_const_scalar_eq(lhs, 0.0f)) return constant(0.0f); + if (is_const_scalar_eq(rhs, 0) || is_const_scalar_eq(rhs, 0.0f)) return constant(0.0f); + + static const Op& op = Op::Get("multiply"); + return Call(op, {lhs, rhs}, Attrs(), {}); +} + +Expr operator/(const Expr& lhs, const Expr& rhs) { + if (is_const_scalar_eq(rhs, 1) || is_const_scalar_eq(rhs, 1.0f)) return lhs; + + static const Op& op = Op::Get("divide"); + return Call(op, {lhs, rhs}, Attrs(), {}); +} + +} // namespace tensor_arithmetic + +/*! + * @brief Graph linearizator. Construct sequence of CallNode objects in post dfs order. + * Helpful to check existence of some op in Function expr. And search it by name. + */ +class OpSeq : public ExprVisitor { + public: + struct Layer { + const CallNode* call_node_ = nullptr; + std::vector extern_args_ = {}; + + operator bool() const { return call_node_ != nullptr; } + }; + + /** return op descriptor for provided name, or empty layer if not exists */ + const Layer& getOpLayer(const std::string& name) const { + static Layer empty; + + auto found = std::find_if(layers_.begin(), layers_.end(), [&name](auto& l) { + return l.call_node_->op.template as()->name == name; + }); + + const auto& res = (found == layers_.end()) ? 
empty : *found; + return res; + } + + /** return list of call node names if post dfs order */ + std::vector getOpNames() const { + std::vector res; + for (auto& l : layers_) res.push_back(l.call_node_->op.as()->name); + return res; + } + + protected: + void VisitExpr_(const CallNode* cn) final { + ExprVisitor::VisitExpr_(cn); + + Layer res{cn}; + for (const auto& arg : cn->args) { + if (arg->IsInstance() || arg->IsInstance()) + res.extern_args_.push_back(arg); + } + layers_.push_back(res); + } + std::vector layers_; +}; + +class OpAttrMapExtractor : public AttrVisitor { + public: + OpAttrMapExtractor() {} + + const std::unordered_map& get() { return attr_map; } + + template ::value>> + std::string Fp2String(const T value) { + std::ostringstream out; + out.precision(std::numeric_limits::max_digits10); + out << value; + return out.str(); + } + + void SetNodeAttr(const char* key, const std::vector& value) { + std::vector attr; + attr.emplace_back(value); + attr_map[key] = dmlc::any{attr}; + } + + void Visit(const char* key, double* value) final { SetNodeAttr(key, {Fp2String(*value)}); } + + void Visit(const char* key, int64_t* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + + void Visit(const char* key, uint64_t* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + + void Visit(const char* key, int* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + + void Visit(const char* key, bool* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + + void Visit(const char* key, std::string* value) final { SetNodeAttr(key, {*value}); } + + void Visit(const char* key, DataType* value) final { + if (!value->is_void()) { + SetNodeAttr(key, {runtime::DLDataType2String(*value)}); + } else { + SetNodeAttr(key, {""}); + } + } + + void Visit(const char* key, runtime::ObjectRef* value) final { + if (const auto* an = (*value).as()) { + std::vector attr; + for (size_t i = 0; i < an->size(); ++i) { + if (const auto* im = (*an)[i].as()) { + attr.push_back(std::to_string(im->value)); + } else if (const auto* fm = (*an)[i].as()) { + attr.push_back(Fp2String(fm->value)); + } else if (const auto* str = (*an)[i].as()) { + String s = GetRef(str); + attr.push_back(s); + } else { + LOG(FATAL) << "Not supported type: " << (*an)[i]->GetTypeKey(); + } + } + SetNodeAttr(key, attr); + } else if (!(*value).defined()) { // Skip NullValue + SetNodeAttr(key, std::vector{""}); + } else if (const auto* im = (*value).as()) { + SetNodeAttr(key, std::vector{std::to_string(im->value)}); + } else if (const auto* fm = (*value).as()) { + SetNodeAttr(key, std::vector{Fp2String(fm->value)}); + } else if (const auto* str = (*value).as()) { + String s = GetRef(str); + SetNodeAttr(key, std::vector{s}); + } else { + LOG(FATAL) << "Not yet supported type: " << (*value)->GetTypeKey() << ": " << *value; + } + } + + void Visit(const char* key, runtime::NDArray* value) final { + LOG(FATAL) << "NDArray is not allowed in op attribute"; + } + + void Visit(const char* key, void** value) final { + LOG(FATAL) << "void pointer is not allowed in op attribute"; + } + + void Extract(Object* node) { + if (node) { + reflection_->VisitAttrs(node, this); + } + } + + private: + std::unordered_map attr_map; + ReflectionVTable* reflection_ = ReflectionVTable::Global(); +}; + +/*! 
+ * @brief Helper function to extract attributes as collection of dmlc objects + * + * @param node node to extract attrs + * @return resulting collection of attributes + */ +std::unordered_map extractAttrs(const CallNode* node) { + OpAttrMapExtractor extractor; + const Object* call_attr = node->attrs.get(); + extractor.Extract(const_cast(call_attr)); + return extractor.get(); +} + +/*! + * @brief Converter attribute to dmlc acceptable format + * + * @tparam T type of value (auto deduction) + * @param val value to convert + * @return resulting dmlc object + */ +template ::value, bool> = true> +dmlc::any dmlc_attr(const T& val) { + std::vector attr; + attr.emplace_back(std::vector{std::to_string(val)}); + return dmlc::any{attr}; +} + +template ::value, bool> = true> +dmlc::any dmlc_attr(const T& val) { + std::vector attr; + attr.emplace_back(std::vector{val}); + return dmlc::any{attr}; +} + +template >::value, bool> = true> +dmlc::any dmlc_attr(const T& val) { + std::vector attr; + attr.emplace_back(val); + return dmlc::any{attr}; +} + +} // namespace contrib +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_CONTRIB_DNNL_CODEGEN_TOOLS_H_ diff --git a/src/relay/backend/contrib/dnnl/composite_op.h b/src/relay/backend/contrib/dnnl/composite_op.h new file mode 100644 index 000000000000..7f58e9fc5f8a --- /dev/null +++ b/src/relay/backend/contrib/dnnl/composite_op.h @@ -0,0 +1,403 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RELAY_BACKEND_CONTRIB_DNNL_COMPOSITE_OP_H_ +#define TVM_RELAY_BACKEND_CONTRIB_DNNL_COMPOSITE_OP_H_ + +#include +#include +#include +#include + +#include "codegen_tools.h" + +namespace tvm { +namespace relay { +namespace contrib { + +using KernelAttrs = std::unordered_map; +using KernelRequisites = std::pair, KernelAttrs>; + +/** + * Rule to replace relay graph with DNNL supported pattern. + * Conv+Bias+Requantize+Clip(cast_int8 or relu)+Sum + * + * @note + * assume wgh_zp == 0. Only symmetric weight are supported right now. 
+ * + * Original relay representation: + * %1 = conv(SRC, WGH) - conv(src_zp, WGH) + BIAS + * %2 = (%1 - rq_in_zp) * rq_in_scl / rq_out_scl + rq_out_zp + * %3 = clip(%2, 0, 255) + * %4 = ((%3 - sum_lh_zp) * sum_lh_scl + (SRC2 - sum_rh_zp) * sum_rh_scl)/sum_out_scl + sum_out_zp + * + * DNNL implemented patern: + * %1 = clip((conv(SRC, WGH + wgh_shft) + (BIAS + bias_shft)) * o_scl, clip_low, clip_high) + * * clip_scl + SRC2 * sum_scl + dst_zp + * + * @note + * dst_zp can be moved into bias_shft + * clip_scl can be moved into o_scl + * + * Possible solution #0: + * clip_scl = sum_lh_scl /sum_out_scl + * clip_low = 0 - sum_lh_zp - sum_rh_zp * sum_rh_scl / sum_lh_scl + * clip_high = 255 - sum_lh_zp - sum_rh_zp * sum_rh_scl / sum_lh_scl + * o_scl = rq_in_scl / rq_out_scl + * bias_shft = - conv(src_zp, WGH) - rq_in_zp + (rq_out_zp - sum_lh_zp - sum_rh_zp * sum_rh_scl / + * sum_lh_scl) * rq_out_scl / rq_in_scl + * wgh_shft = 0 + * sum_scl = sum_rh_scl / sum_out_scl + * dst_zp = sum_out_zp + * + * + * Possible solution #1 (dst_zp == 0): + * new_clip_low = clip_low + dst_zp / clip_scl + * new_clip_high = clip_high + dst_zp / clip_scl + * new_bias_shft = bias_shft + dst_zp / clip_scl / o_scl + * + * clip_scl = sum_lh_scl /sum_out_scl + * clip_low = 0 - sum_lh_zp - sum_rh_zp * sum_rh_scl / sum_lh_scl + + * sum_out_zp * sum_out_scl / sum_lh_scl + * clip_high = 255 - sum_lh_zp - sum_rh_zp * sum_rh_scl / sum_lh_scl + + * sum_out_zp * sum_out_scl / sum_lh_scl + * o_scl = rq_in_scl / rq_out_scl + * bias_shft = - conv(src_zp, WGH) - rq_in_zp + (rq_out_zp - sum_lh_zp - sum_rh_zp * sum_rh_scl / + * sum_lh_scl) * rq_out_scl / rq_in_scl + sum_out_zp * sum_out_scl / sum_lh_scl * + * rq_out_scl / rq_in_scl + * sum_scl = sum_rh_scl / sum_out_scl + * dst_zp = 0 + * + * + * Possible solution #2 (clip_scl == 1.f): + * new_clip_low = clip_low * clip_scl + * new_clip_high = clip_high * clip_scl + * new_o_scl = o_scl * clip_scl + * + * clip_scl = 1.f + * clip_low = (0 - sum_lh_zp - sum_rh_zp * sum_rh_scl / sum_lh_scl + sum_out_zp * sum_out_scl / + * sum_lh_scl) * sum_lh_scl /sum_out_scl + * clip_high = (255 - sum_lh_zp - sum_rh_zp * sum_rh_scl / sum_lh_scl + sum_out_zp * sum_out_scl / + * sum_lh_scl) * sum_lh_scl /sum_out_scl + * o_scl = rq_in_scl / rq_out_scl * sum_lh_scl /sum_out_scl + * bias_shft = - conv(src_zp, WGH) - rq_in_zp + (rq_out_zp - sum_lh_zp - sum_rh_zp * sum_rh_scl / + * sum_lh_scl) * rq_out_scl / rq_in_scl + sum_out_zp * sum_out_scl / sum_lh_scl * + * rq_out_scl / rq_in_scl + * sum_scl = sum_rh_scl / sum_out_scl + * dst_zp = 0 + */ +struct qnn_arg_set_relay { + Expr wgh, bias, src_zp, wgh_zp, src_scl, wgh_scl, rq_in_zp, rq_in_scl, rq_out_zp, rq_out_scl, + sum_lh_zp, sum_lh_scl, sum_rh_zp, sum_rh_scl, sum_out_scl, sum_out_zp; +}; + +struct qnn_arg_set_dnnl { + Expr bias, clip_scl, clip_low, clip_high, o_scl, sum_scl, dst_zp; + + /** Evaluate contained expressions and collapse to scalar if it's broadcast scalar */ + qnn_arg_set_dnnl evalAndCollapseToScalar() const { + qnn_arg_set_dnnl res; + + res.bias = collapse_to_scalar(EvalExpr(this->bias)); + res.clip_low = collapse_to_scalar(EvalExpr(this->clip_low)); + res.clip_high = collapse_to_scalar(EvalExpr(this->clip_high)); + res.clip_scl = collapse_to_scalar(EvalExpr(this->clip_scl)); + res.dst_zp = collapse_to_scalar(EvalExpr(this->dst_zp)); + res.o_scl = collapse_to_scalar(EvalExpr(this->o_scl)); + res.sum_scl = collapse_to_scalar(EvalExpr(this->sum_scl)); + + return res; + } +}; + +qnn_arg_set_dnnl qnnReformulate(const qnn_arg_set_relay& origin) { + 
auto& r = origin; // short alias "relay" + using namespace tensor_arithmetic; + ICHECK(is_const_scalar_eq(r.wgh_zp, 0)) << "Doesn't support patterns with not zero kernel_zp"; + + // Convolution on zp filled data. Also applicable for dense and grouped conv. + auto conv_zp = [](const Expr& zp, const Expr& wgh) -> Expr { + if (is_const_scalar_eq(zp, 0)) return constant(0); + ICHECK(is_scalar(zp)) << "Only scalar data_zp is supported for qnn primitives"; + + // reduce kernel {OC, IC, KH, KW} -> {OC} in case of group that is still correct + auto reduced_kernel = + MakeReduce(cast(wgh), {0}, false /*keepdims*/, true /*exclude*/, "sum"); + return zp * reduced_kernel; + }; + + // If there is no bias will use zero value + auto bias = r.bias.defined() ? r.bias : constant(0); + + // Will use formulas #2 (dst_zp == 0, clip_scl == 1.0f) + qnn_arg_set_dnnl res; + res.dst_zp = constant(0); + res.o_scl = r.rq_in_scl / r.rq_out_scl * r.sum_lh_scl / r.sum_out_scl; + res.sum_scl = r.sum_rh_scl / r.sum_out_scl; + res.clip_scl = constant(1.0f); + res.clip_low = (cast(constant(0) - r.sum_lh_zp) - + cast(r.sum_rh_zp) * r.sum_rh_scl / r.sum_lh_scl + + cast(r.sum_out_zp) * r.sum_out_scl / r.sum_lh_scl) * + r.sum_lh_scl / r.sum_out_scl; + res.clip_high = (cast(constant(255) - r.sum_lh_zp) - + cast(r.sum_rh_zp) * r.sum_rh_scl / r.sum_lh_scl + + cast(r.sum_out_zp) * r.sum_out_scl / r.sum_lh_scl) * + r.sum_lh_scl / r.sum_out_scl; + res.bias = cast(bias) - cast(conv_zp(r.src_zp, r.wgh) + r.rq_in_zp) + + cast(r.rq_out_zp - r.sum_lh_zp) * r.rq_out_scl / r.rq_in_scl - + cast(r.sum_rh_zp) * r.sum_rh_scl / r.sum_lh_scl * r.rq_out_scl / r.rq_in_scl + + cast(r.sum_out_zp) * r.sum_out_scl / r.sum_lh_scl * r.rq_out_scl / r.rq_in_scl; + + return res.evalAndCollapseToScalar(); +} + +/*! + * @brief Specify optional QNN args and attrs if required + * + * @param wgh weight node + * @param bias bias node (constant node) + * @param base base action node (conv or dense) + * @param rq requantize node (optional) + * @param sum sum node (optional) + * @param inputs resulting input collection (will append to it) + * @param attrs resulting attribute collection (will append to it) + */ +void optQnnArgsForRqSumPattern(const Expr& wgh, const Expr& bias, const OpSeq::Layer& base, + const OpSeq::Layer& rq, const OpSeq::Layer& sum, + std::vector* inputs, KernelAttrs* attrs) { + ICHECK(wgh.defined()); + ICHECK(base); + ICHECK(inputs); + ICHECK(attrs); + + qnn_arg_set_relay args_relay; + args_relay.wgh = wgh; + args_relay.bias = bias; + + args_relay.src_zp = base.extern_args_[2]; + args_relay.wgh_zp = base.extern_args_[3]; + args_relay.src_scl = base.extern_args_[4]; + args_relay.wgh_scl = base.extern_args_[5]; + + // Requantize is optional + args_relay.rq_in_scl = rq ? rq.extern_args_[0] : constant(1.f); + args_relay.rq_in_zp = rq ? rq.extern_args_[1] : constant(0); + args_relay.rq_out_scl = rq ? rq.extern_args_[2] : constant(1.f); + args_relay.rq_out_zp = rq ? rq.extern_args_[3] : constant(0); + + // Sum is optional + args_relay.sum_lh_scl = sum ? sum.extern_args_[1] : constant(1.f); + args_relay.sum_lh_zp = sum ? sum.extern_args_[2] : constant(0); + args_relay.sum_rh_scl = sum ? sum.extern_args_[3] : constant(0.f); + args_relay.sum_rh_zp = sum ? sum.extern_args_[4] : constant(0); + args_relay.sum_out_scl = sum ? sum.extern_args_[5] : constant(1.f); + args_relay.sum_out_zp = sum ? 
sum.extern_args_[6] : constant(0); + + // Recalculate QNN specific arguments + auto args_dnnl = qnnReformulate(args_relay); + + // Helper to register optional qnn args + auto put_arg = [&attrs, &inputs](const Expr& expr, std::string name, auto skip_value) { + if (expr.defined() && !is_const_scalar_eq(expr, skip_value)) { + (*attrs)[name] = dmlc_attr(inputs->size()); + inputs->push_back(expr); + } + }; + + // Bias should be a vector {OC}, even if it's scalar + if (is_scalar(args_dnnl.bias) && !is_const_scalar_eq(args_dnnl.bias, 0)) { + int OC = shape_of(wgh)[0]; + args_dnnl.bias = EvalExpr(broadcast(args_dnnl.bias, {OC})); + } + + put_arg(args_dnnl.bias, "bias_idx", 0); + put_arg(args_dnnl.o_scl, "o_scl_idx", 1); + put_arg(args_dnnl.dst_zp, "dst_zp_idx", 0); + + if (!is_const_scalar_eq(args_dnnl.sum_scl, 0.f)) { + put_arg(sum.extern_args_[0], "sum_idx", std::nanf("")); + put_arg(args_dnnl.sum_scl, "sum_scl_idx", 0); + } + + if (args_dnnl.clip_scl.defined()) { + ICHECK(is_scalar(args_dnnl.clip_low)); + ICHECK(is_scalar(args_dnnl.clip_high)); + + std::vector clip_attr{"clip"}; + clip_attr.push_back(std::to_string(inputs->size())); + inputs->push_back(args_dnnl.clip_scl); + clip_attr.push_back(std::to_string(inputs->size())); + inputs->push_back(args_dnnl.clip_low); + clip_attr.push_back(std::to_string(inputs->size())); + inputs->push_back(args_dnnl.clip_high); + + (*attrs)["activation"] = dmlc_attr(clip_attr); + } +} + +/*! + * Legalize bias shape to 1D form + * + * @param orig_bias + * @return 1D version of original bias expr + */ +Expr legalizeBiasShape(const Expr& orig_bias) { return EvalExpr(squeeze(orig_bias)); } + +/** + * Parse qnn.conv2d based fused patterns + * @param fn function to parse + */ +KernelRequisites parseQnnConv2dComposite(const FunctionNode* fn) { + OpSeq ops; + ops(fn->body); + + std::vector qnn_conv_sum_pat{"qnn.conv2d", "add", "qnn.requantize", "clip", "cast", + "qnn.add", "clip"}, + qnn_conv_sum_no_bias_pat{"qnn.conv2d", "qnn.requantize", "clip", "cast", "qnn.add", "clip"}, + qnn_conv_pat{"qnn.conv2d", "add", "qnn.requantize", "clip", "cast"}, + qnn_conv_no_bias_pat{"qnn.conv2d", "qnn.requantize", "clip", "cast"}; + + auto layer_names = ops.getOpNames(); + ICHECK(layer_names == qnn_conv_sum_pat || layer_names == qnn_conv_pat || + layer_names == qnn_conv_no_bias_pat || layer_names == qnn_conv_sum_no_bias_pat) + << "Unsupported patter for DNNL code generator. Looks like some discrepancy " + "between DNNL partitioner pass and code generator."; + + auto conv = ops.getOpLayer("qnn.conv2d"); + auto bs = ops.getOpLayer("add"); + auto rq = ops.getOpLayer("qnn.requantize"); + auto sum = ops.getOpLayer("qnn.add"); + + auto data = conv.extern_args_[0]; + auto wgh = conv.extern_args_[1]; + auto bias = bs ? 
legalizeBiasShape(bs.extern_args_[0]) : Expr{}; + + // make regular wights layout + auto wgh_layout = conv.call_node_->attrs.as()->kernel_layout; + auto oihw_wgh = permute(wgh, permutation(wgh_layout, "OIHW")); + + auto attrs = extractAttrs(conv.call_node_); // extract original attrs + std::vector inputs = {data, wgh}; // args with fixed positions + + optQnnArgsForRqSumPattern(oihw_wgh, bias, conv, rq, sum, &inputs, &attrs); + return {inputs, attrs}; +} + +KernelRequisites parseQnnDenseComposite(const FunctionNode* fn) { + OpSeq ops; + ops(fn->body); + + std::vector qnn_dense_sum_pat{"qnn.dense", "add", "qnn.requantize", "clip", "cast", + "qnn.add", "clip"}, + qnn_dense_sum_no_bias_pat{"qnn.dense", "qnn.requantize", "clip", "cast", "qnn.add", "clip"}, + qnn_dense_pat{"qnn.dense", "add", "qnn.requantize", "clip", "cast"}, + qnn_dense_no_bias_pat{"qnn.dense", "qnn.requantize", "clip", "cast"}; + + auto layer_names = ops.getOpNames(); + ICHECK(layer_names == qnn_dense_sum_pat || layer_names == qnn_dense_sum_no_bias_pat || + layer_names == qnn_dense_pat || layer_names == qnn_dense_no_bias_pat) + << "Unsupported patter for DNNL code generator. Looks like some discrepancy " + "between DNNL partitioner pass and code generator."; + + auto dense = ops.getOpLayer("qnn.dense"); + auto bs = ops.getOpLayer("add"); + auto rq = ops.getOpLayer("qnn.requantize"); + auto sum = ops.getOpLayer("qnn.add"); + + auto data = dense.extern_args_[0]; + auto wgh = dense.extern_args_[1]; + auto bias = bs ? legalizeBiasShape(bs.extern_args_[0]) : Expr{}; + + auto attrs = extractAttrs(dense.call_node_); // extract original attrs + std::vector inputs = {data, wgh}; // args with fixed positions + + optQnnArgsForRqSumPattern(wgh, bias, dense, rq, sum, &inputs, &attrs); + return {inputs, attrs}; +} + +KernelRequisites parseBaseOpComposite(const FunctionNode* fn, const std::string& base_op_name) { + ICHECK(base_op_name == "nn.conv2d" || base_op_name == "nn.dense"); + OpSeq ops; + ops(fn->body); + + auto conv = ops.getOpLayer(base_op_name); + auto bias = ops.getOpLayer("add"); + auto relu = ops.getOpLayer("nn.relu"); + auto tanh = ops.getOpLayer("tanh"); + auto sigm = ops.getOpLayer("sigmoid"); + + auto act = relu ? relu : tanh ? tanh : sigm ? 
sigm : OpSeq::Layer{}; + + auto attrs = extractAttrs(conv.call_node_); + std::vector inputs = { + conv.extern_args_[0], // data + conv.extern_args_[1] // kernel + }; + + if (bias) { + attrs["bias_idx"] = dmlc_attr(inputs.size()); + inputs.push_back(bias.extern_args_[0]); + } + + if (act) { + auto act_name = act.call_node_->op.as()->name; + std::vector act_attr = {act_name}; + act_attr.push_back(std::to_string(inputs.size())); + inputs.push_back(InferType(constant(1.0f))); + act_attr.push_back(std::to_string(inputs.size())); + inputs.push_back(InferType(constant(0.0f))); + act_attr.push_back(std::to_string(inputs.size())); + inputs.push_back(InferType(constant(0.0f))); + + attrs["activation"] = dmlc_attr(act_attr); + } + + return {inputs, attrs}; +} + +KernelRequisites DNNLCompositeFunctionsParser(const FunctionNode* fn) { + auto comp = fn->GetAttr(attr::kComposite); + ICHECK(comp.defined()); + std::string name = comp.value(); + + if (name == "dnnl.qnn.conv2d_sum" || name == "dnnl.qnn.conv2d") { + return parseQnnConv2dComposite(fn); + } else if (name == "dnnl.qnn.dense_sum" || name == "dnnl.qnn.dense") { + return parseQnnDenseComposite(fn); + } else if (name == "dnnl.conv2d_bias_relu" || name == "dnnl.conv2d_bias_tanh" || + name == "dnnl.conv2d_bias_sigmoid" || name == "dnnl.conv2d_bias" || + name == "dnnl.conv2d_relu" || name == "dnnl.conv2d_tanh" || + name == "dnnl.conv2d_sigmoid") { + return parseBaseOpComposite(fn, "nn.conv2d"); + } else if (name == "dnnl.dense_bias_relu" || name == "dnnl.dense_bias_tanh" || + name == "dnnl.dense_bias_sigmoid" || name == "dnnl.dense_bias" || + name == "dnnl.dense_relu" || name == "dnnl.dense_tanh" || + name == "dnnl.dense_sigmoid") { + return parseBaseOpComposite(fn, "nn.dense"); + } else { + LOG(FATAL) << "Unknown composite function " << name; + } + return {}; +} + +} // namespace contrib +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_CONTRIB_DNNL_COMPOSITE_OP_H_ diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index f9f1961e2697..8738633fe668 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -32,6 +32,7 @@ #include "../json/json_node.h" #include "../json/json_runtime.h" #include "dnnl.hpp" +#include "dnnl_node_helper.h" namespace tvm { namespace runtime { @@ -41,85 +42,192 @@ using namespace tvm::runtime; using namespace tvm::runtime::json; class DNNLJSONRuntime : public JSONRuntimeBase { - using tag = dnnl::memory::format_tag; - using dt = dnnl::memory::data_type; - public: DNNLJSONRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array const_names) - : JSONRuntimeBase(symbol_name, graph_json, const_names) {} + const Array& const_names) + : JSONRuntimeBase(symbol_name, graph_json, const_names), + g_explorer_(nodes_, data_entry_, node_row_ptr_, engine_) {} - const char* type_key() const { return "dnnl_json"; } + const char* type_key() const override { return "dnnl_json"; } - void Init(const Array& consts) override { - BuildEngine(); + static std::string get_version() { + auto v = dnnl_version(); + std::stringstream ver_strm; + ver_strm << v->major << '.' << v->minor << '.' << v->patch; + return ver_strm.str(); + } + void Init(const Array& consts) override { ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; - // Setup constants entries for weights. 
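
The "solution #2" rewrite documented at the top of composite_op.h above (dst_zp == 0, clip_scl == 1.0) can be sanity-checked with plain arithmetic. Below is a NumPy sketch using the same names as that comment; it works on real numbers, so the integer rounding and saturation of an actual int8 kernel are deliberately ignored:

```python
import numpy as np

# Inputs of the pattern: conv(SRC, WGH), conv(src_zp, WGH), BIAS and SRC2.
conv = np.array([-50.0, 3.0, 40.0, 400.0])  # values chosen to exercise both clip branches
conv_zp, bias, src2 = 4.0, 5.0, -11.0
rq_in_scl, rq_in_zp, rq_out_scl, rq_out_zp = 0.3, 3.0, 0.05, 7.0
sum_lh_scl, sum_lh_zp = 0.05, 2.0
sum_rh_scl, sum_rh_zp = 0.07, 4.0
sum_out_scl, sum_out_zp = 0.11, 5.0

# Original relay representation
x1 = conv - conv_zp + bias
x2 = (x1 - rq_in_zp) * rq_in_scl / rq_out_scl + rq_out_zp
x3 = np.clip(x2, 0.0, 255.0)
relay_out = ((x3 - sum_lh_zp) * sum_lh_scl + (src2 - sum_rh_zp) * sum_rh_scl) / sum_out_scl + sum_out_zp

# DNNL-side parameters, solution #2 (wgh_shft == 0, dst_zp == 0, clip_scl == 1)
common = -sum_lh_zp - sum_rh_zp * sum_rh_scl / sum_lh_scl + sum_out_zp * sum_out_scl / sum_lh_scl
o_scl = rq_in_scl / rq_out_scl * sum_lh_scl / sum_out_scl
sum_scl = sum_rh_scl / sum_out_scl
clip_low = (0.0 + common) * sum_lh_scl / sum_out_scl
clip_high = (255.0 + common) * sum_lh_scl / sum_out_scl
bias_shft = -conv_zp - rq_in_zp + (rq_out_zp + common) * rq_out_scl / rq_in_scl

dnnl_out = np.clip((conv + bias + bias_shft) * o_scl, clip_low, clip_high) + src2 * sum_scl
assert np.allclose(relay_out, dnnl_out)
```
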
SetupConstants(consts); + // Init internal DNNL specific objects + BuildEngine(); } - void Run() override { - // Fill in the input buffers. - for (size_t i = 0; i < input_nodes_.size(); ++i) { - auto eid = EntryID(input_nodes_[i], 0); - // TODO(@comaniac): Support other data lengths. - size_t offset_in_bytes = entry_out_mem_[eid].second * 4; - size_t buffer_size = GetDataSize(*data_entry_[eid]); - write_to_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size, - offset_in_bytes); - } + /** + * Override of GetFunction methods to replace main symbol_name_ implementation with + * thread safe one. + */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override { + if (this->symbol_name_ == name) { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ICHECK(this->initialized_) << "The module has not been initialized"; + + ICHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size()) + << "Found mismatch in the number of provided data entries and required."; - // Invoke the engine through intepreting the stream. - for (size_t i = 0; i < net_.size(); ++i) { - net_.at(i).execute(stream_, net_args_.at(i)); + Run(args); + }); + } else { + return JSONRuntimeBase::GetFunction(name, sptr_to_self); } - stream_.wait(); - - // Read output buffers. - for (size_t i = 0; i < outputs_.size(); ++i) { - auto eid = EntryID(outputs_[i]); - size_t offset_in_bytes = entry_out_mem_[eid].second * 4; - size_t buffer_size = GetDataSize(*data_entry_[eid]); - read_from_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size, - offset_in_bytes); + } + + /** + * @brief Thread safe version of base method Run. + * + * The main purpose of this overwrite is to make symbol_name_ function thread safe. + * The base implementation of that method is using SetInputOutputBuffers() which + * is not thread safe and lead to changes in DNNLJSONRuntime itself. + * + * @param args kernel arguments + */ + void Run(const TVMArgs& args) const { + auto io_data_provider = makeIoDataProvider(args); + // Execute primitives one by one + for (const auto& act : net_) { + auto req_args = std::get(act); + auto prim = std::get(act); + + // Find proper dnnl::memory buffer based on provided ArgRequisite + auto mem_args = tensor_registry_.solve(req_args, io_data_provider); + prim.execute(stream_, mem_args); } } + /** @brief Stub implementation */ + void Run() override { LOG(ERROR) << "Unimplemented. 
Should never be called."; } + private: + /** Receive tensor memory buffer handler based from provided arg */ + static void* extractDataHandle(const TVMArgValue& val) { + ICHECK(val.type_code() == kTVMNDArrayHandle || val.type_code() == kTVMDLTensorHandle) + << "Expect NDArray or DLTensor"; + void* hdl = nullptr; + if (val.IsObjectRef()) { + NDArray arr = val; + hdl = arr.operator->()->data; + } else { + hdl = val.operator DLTensor*()->data; + } + return hdl; + } + + TensorRegistry::ExtDataProvider makeIoDataProvider(const TVMArgs& args) const { + std::map io_map; // eid to data handler + + int i = 0; + for (auto e : input_var_eid_) io_map[e] = extractDataHandle(args[i++]); + for (auto e : outputs_) io_map[EntryID(e)] = extractDataHandle(args[i++]); + + // lambda with captured IO data handlers + return [io_map](uint32_t eid) -> void* { return io_map.at(eid); }; + } + + std::set makeIoEids() const { + std::set io_set; // eid of inputs and outputs + for (auto e : input_var_eid_) io_set.insert(e); + for (auto e : outputs_) io_set.insert(EntryID(e)); + return io_set; + } + + struct SubmitAttr { + enum AttrType { None, ZeroCopyRequest }; + + SubmitAttr() {} + SubmitAttr(AttrType type, const TensorRequisite& tr, int flag) + : type_(type), tr_(tr), flag_(flag) {} + + AttrType type_ = AttrType::None; + const TensorRequisite tr_ = {}; + int flag_ = 0; + }; + + // Helper function to register primitive into execution queue + void submit(const dnnl::primitive& prim, const std::unordered_map& tr_args, + const SubmitAttr attr = {}) { + // collection of post action. Dst primitive processing will be stored here + TensorRegistry::ActionQue post_actions; + + // Helper func to register TensorRequisite and store corresponding Actions in proper place + auto register_tr = [this, &post_actions](const TensorRequisite& tr) { + TensorRegistry::ArgReq arg_req; + TensorRegistry::ActionQue actions; + std::tie(arg_req, actions) = tensor_registry_.registerTR(tr); + + auto& action_queue = tr.isReversed() ? post_actions : net_; + action_queue.insert(action_queue.end(), actions.begin(), actions.end()); + return arg_req; + }; + + // Register all provided TR arguments + std::unordered_map arg_reqs; + for (const auto& kvp : tr_args) { + const auto& tr = kvp.second; + const auto& key = kvp.first; + + if (!tr.defined()) continue; // empty arg is admitted. Just skip it + arg_reqs[key] = register_tr(tr); + } + + // ZeroCopyRequest or Inplace memory + if (attr.type_ == SubmitAttr::ZeroCopyRequest) { + auto zero_copy_src_tr = attr.tr_; + auto zero_copy_dst_tr = tr_args.at(attr.flag_); + auto zero_copy_src_ar = register_tr(zero_copy_src_tr); + auto zero_copy_dst_ar = arg_reqs.at(attr.flag_); + + // Register copy action direct before main primitive + dnnl::reorder::primitive_desc io_copy_pd(engine_, zero_copy_src_tr.desc(), engine_, + zero_copy_dst_tr.desc()); + net_.push_back({dnnl::reorder(io_copy_pd), + {{DNNL_ARG_SRC, zero_copy_src_ar}, {DNNL_ARG_DST, zero_copy_dst_ar}}}); + } + + // Register main primitive + net_.push_back({prim, arg_reqs}); + + // Register post actions + net_.insert(net_.end(), post_actions.begin(), post_actions.end()); + } + // Build up the engine based on the input graph. void BuildEngine() { engine_ = dnnl::engine(dnnl::engine::kind::cpu, 0); stream_ = dnnl::stream(engine_); + tensor_registry_ = TensorRegistry(engine_, makeIoEids()); // Build subgraph engine. 
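
The submit()/Run() pair above implements a deferred-binding scheme: BuildEngine() records each primitive together with symbolic argument requisites, and the concrete dnnl::memory objects are resolved per call, either from constants registered at build time or from the per-invocation I/O data provider. Because nothing in the runtime object is mutated on the execution path, the symbol function stays safe to call from several threads. A rough Python analogue of the idea (this is not the TensorRegistry / TensorRequisite API, which lives in dnnl_node_helper.h and is not part of this diff excerpt):

```python
import numpy as np

class ExecQueue:
    """Toy analogue of the net_/tensor_registry_ pair: primitives are recorded
    with symbolic argument ids and real buffers are looked up only in run()."""

    def __init__(self):
        self.net = []        # [(callable, {kwarg_name: requisite_id})]
        self.constants = {}  # requisite_id -> buffer bound once at build time

    def submit(self, prim, args):
        self.net.append((prim, args))

    def run(self, io_provider):
        # No member state is modified here, so concurrent calls with
        # distinct external buffers do not interfere with each other.
        for prim, args in self.net:
            bufs = {name: self.constants[rid] if rid in self.constants else io_provider(rid)
                    for name, rid in args.items()}
            prim(**bufs)

q = ExecQueue()
q.constants["wgh"] = np.full((4, 4), 0.5, dtype="float32")
q.submit(lambda src, wgh, dst: np.matmul(src, wgh, out=dst),
         {"src": "in0", "wgh": "wgh", "dst": "out0"})

x, y = np.ones((2, 4), "float32"), np.empty((2, 4), "float32")
q.run(lambda rid: {"in0": x, "out0": y}[rid])
print(y)  # every element equals 2.0
```
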
- for (size_t nid = 0; nid < nodes_.size(); ++nid) { + for (uint32_t nid = 0; nid < nodes_.size(); ++nid) { const auto& node = nodes_[nid]; if (node.GetOpType() == "kernel") { ICHECK_EQ(node.GetOpType(), "kernel"); auto op_name = node.GetOpName(); - if ("nn.conv2d" == op_name) { - Conv2d(nid); - } else if ("dnnl.conv2d_relu" == op_name) { - Conv2d(nid, true, false, dnnl::algorithm::eltwise_relu); - } else if ("dnnl.conv2d_tanh" == op_name) { - Conv2d(nid, true, false, dnnl::algorithm::eltwise_tanh); - } else if ("dnnl.conv2d_sigmoid" == op_name) { - Conv2d(nid, true, false, dnnl::algorithm::eltwise_logistic); - } else if ("dnnl.conv2d_bias" == op_name) { - Conv2d(nid, false, true); - } else if ("dnnl.conv2d_bias_relu" == op_name) { - Conv2d(nid, true, true, dnnl::algorithm::eltwise_relu); - } else if ("dnnl.conv2d_bias_tanh" == op_name) { - Conv2d(nid, true, true, dnnl::algorithm::eltwise_tanh); - } else if ("dnnl.conv2d_bias_sigmoid" == op_name) { - Conv2d(nid, true, true, dnnl::algorithm::eltwise_logistic); - } else if ("nn.dense" == op_name) { - Dense(nid); - } else if ("dnnl.dense_bias" == op_name) { - Dense(nid, true); + + if ("nn.conv2d" == op_name || "dnnl.conv2d_relu" == op_name || + "dnnl.conv2d_tanh" == op_name || "dnnl.conv2d_sigmoid" == op_name || + "dnnl.conv2d_bias" == op_name || "dnnl.conv2d_bias_relu" == op_name || + "dnnl.conv2d_bias_tanh" == op_name || "dnnl.conv2d_bias_sigmoid" == op_name || + "dnnl.qnn.conv2d" == op_name || "dnnl.qnn.conv2d_sum" == op_name) { + UniConv2d(nid); + } else if ("nn.dense" == op_name || "dnnl.dense_relu" == op_name || + "dnnl.dense_tanh" == op_name || "dnnl.dense_sigmoid" == op_name || + "dnnl.dense_bias" == op_name || "dnnl.dense_bias_relu" == op_name || + "dnnl.dense_bias_tanh" == op_name || "dnnl.dense_bias_sigmoid" == op_name || + "dnnl.qnn.dense" == op_name || "dnnl.qnn.dense_sum" == op_name) { + UniDense(nid); } else if ("nn.batch_norm" == op_name) { BatchNorm(nid); } else if ("nn.relu" == op_name) { @@ -137,342 +245,329 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } } } - } - // Bind a JSON graph node entry to a DNNL memory. - dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, dnnl::memory::desc mem_desc, - size_t offset = 0) { - auto eid = EntryID(entry); - if (entry_out_mem_.count(eid) == 0) { - return BindDNNLMemory(entry, dnnl::memory(mem_desc, engine_), offset); - } - return entry_out_mem_[eid].first; + tensor_registry_.finalize(); } - // Bind a JSON graph node entry to a given DNNL memory. - dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, dnnl::memory mem, - size_t offset = 0) { - auto eid = EntryID(entry); - // Since the DNNL memory has been created before calling this function, we assume the entry - // has not yet been bound to the other DNNL memory; otherwise it may have memory leak. - ICHECK_EQ(entry_out_mem_.count(eid), 0); - - // TODO(@comanic): Support other data types (i.e., int8). - auto data_node = nodes_[entry.id_]; - auto dltype = data_node.GetOpDataType()[entry.index_]; - ICHECK_EQ(dltype.bits, 32); - - entry_out_mem_[eid] = {mem, offset}; - return entry_out_mem_[eid].first; - } - - void Conv2d(const size_t& nid, const bool has_elt = false, const bool has_bias = false, - dnnl::algorithm algo = dnnl::algorithm::eltwise_relu) { - auto node = nodes_[nid]; - - // Setup attributes. 
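
The UniConv2d path below first normalizes layouts: the data and output tensors are permuted to NCHW and the kernel to OIHW, with the permutation computed from the layout strings (the same idea as the permutation() helper in codegen_tools.h; the runtime-side utils::permutation comes from dnnl_node_helper.h, which is outside this excerpt). A NumPy illustration of what such a permutation computes:

```python
import numpy as np

def permutation(src_layout: str, dst_layout: str) -> list:
    """Axis order that transposes a src_layout tensor into dst_layout."""
    return [src_layout.index(axis) for axis in dst_layout]

nhwc = np.zeros((1, 14, 14, 32))                    # data in NHWC
nchw = nhwc.transpose(permutation("NHWC", "NCHW"))  # permutation is [0, 3, 1, 2]
assert nchw.shape == (1, 32, 14, 14)

hwio = np.zeros((3, 3, 32, 16))                     # kernel in HWIO
oihw = hwio.transpose(permutation("HWIO", "OIHW"))  # permutation is [3, 2, 0, 1]
assert oihw.shape == (16, 32, 3, 3)
```
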
- auto data_entry = node.GetInputs()[0]; - auto weight_entry = node.GetInputs()[1]; - dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; - std::vector str_strides = node.GetAttr>("strides"); - std::vector str_dilates = node.GetAttr>("dilation"); - std::vector str_padding = node.GetAttr>("padding"); - dnnl::memory::dim groups = std::stoi(node.GetAttr>("groups")[0]); - - dnnl::memory::dim N = input_shape[0], // batch size - IC = input_shape[1], // input channels - IH = input_shape[2], // input height - IW = input_shape[3], // input width - OC = weight_shape[0], // output channels - KH = weight_shape[2], // weight height - KW = weight_shape[3], // weight width - PW_L = std::stoi(str_padding[1]), // width padding: left - PW_R = std::stoi(str_padding[3]), // width padding: right - PH_L = std::stoi(str_padding[0]), // height padding: top - PH_R = std::stoi(str_padding[2]), // height padding: bottom - SH = std::stoi(str_strides[0]), // height-wise stride - SW = std::stoi(str_strides[1]), // weight-wise stride - DH = std::stoi(str_dilates[0]) - 1, // height-wise dilate - DW = std::stoi(str_dilates[1]) - 1, // weight-wise dilate - DKH = 1 + (KH - 1) * (DH + 1), // dilated weight height - DKW = 1 + (KW - 1) * (DW + 1), // dilated weight width - OH = (IH - DKH + PH_L + PH_R) / SH + 1, // output height - OW = (IW - DKW + PW_L + PW_R) / SW + 1; // output width - - // Memory shapes. - dnnl::memory::dims src_dims = {N, IC, IH, IW}; - dnnl::memory::dims weights_dims = {OC, IC, KH, KW}; + void UniConv2d(const uint32_t& nid) { + auto node = NodeHelper{nid, g_explorer_}; + + // Fix position inputs + auto data_tr = node.getInput(0); + auto kernel_tr = node.getInput(1); + auto output_tr = node.getOutput(0); + + // Parse general conv attributes + auto strides = node.getAttr>("strides"); + auto padding = node.getAttr>("padding"); + auto dilation = node.getAttr>("dilation"); + auto groups = node.getAttr("groups"); + + auto data_layout = node.getAttr("data_layout"); + auto kernel_layout = node.getAttr("kernel_layout"); + + auto activation = node.getAttr>("activation", {"none"}); + auto bias_idx = node.getAttr("bias_idx", {"-1"}); + auto sum_idx = node.getAttr("sum_idx", {"-1"}); + auto sum_scl_idx = node.getAttr("sum_scl_idx", {"-1"}); + auto o_scl_idx = node.getAttr("o_scl_idx", {"-1"}); + auto dst_zp_idx = node.getAttr("dst_zp_idx", {"-1"}); + + // may be empty in case if '-1' + auto bias_tr = node.getInput(bias_idx); + auto sum_tr = node.getInput(sum_idx); + auto sum_scl_tr = node.getInput(sum_scl_idx); + auto o_scl_tr = node.getInput(o_scl_idx); + auto dst_zp_tr = node.getInput(dst_zp_idx); + + // permute corresponding with provided layouts + auto data_permutation = utils::permutation(data_layout, "NCHW"); + auto kernel_permutation = utils::permutation(kernel_layout, "OIHW"); + + data_tr = data_tr.permute(data_permutation); + sum_tr = sum_tr.permute(data_permutation); + output_tr = output_tr.permute(data_permutation); + kernel_tr = kernel_tr.permute(kernel_permutation); + + // TODO(@apeskov): temp WA. 
while codegen is not able to guarantee 1D format of bias data + bias_tr = bias_tr.squeeze(); + + // Group weight format if (groups > 1) { - weights_dims = {groups, 1, IC / groups, KH, KW}; - } - dnnl::memory::dims bias_dims = {OC}; - dnnl::memory::dims dst_dims = {N, OC, OH, OW}; - dnnl::memory::dims strides_dims = {SH, SW}; - dnnl::memory::dims dilates_dims = {DH, DW}; - dnnl::memory::dims padding_dims_l = {PH_L, PW_L}; - dnnl::memory::dims padding_dims_r = {PH_R, PW_R}; - - // Memory descriptions. - auto conv_src_md = dnnl::memory::desc(src_dims, dt::f32, tag::any); - auto conv_weights_md = dnnl::memory::desc(weights_dims, dt::f32, tag::any); - auto conv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any); - auto conv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::nchw); - - // Covn2d description. - auto conv_desc = dnnl::convolution_forward::desc( - dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, conv_src_md, - conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, dilates_dims, padding_dims_l, - padding_dims_r); - - // Enable elementwise post-ops - dnnl::primitive_attr attr; - if (has_elt) { - dnnl::post_ops ops; - ops.append_eltwise(1.f, algo, 0.f, 0.f); - attr.set_post_ops(ops); + auto k_dims = kernel_tr.dims(); // OIHW -> GOIHW + k_dims[0] /= groups; + k_dims.insert(k_dims.begin(), groups); + kernel_tr = kernel_tr.reshape(k_dims); } - auto conv2d_prim_desc = dnnl::convolution_forward::primitive_desc(conv_desc, attr, engine_); + // Attributes setting + dnnl::primitive_attr attr; + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); - // Push to the network. - auto conv = dnnl::convolution_forward(conv2d_prim_desc); - net_.push_back(conv); + if (dst_zp_tr) { + auto zp = dst_zp_tr.getConstScalarData(); + // Per channel zp is not supported. It was merged into BIAS + attr.set_zero_points(DNNL_ARG_DST, 0, {zp}); + } - // Data memory. - ICHECK_EQ(node.GetAttr>("data_layout")[0], "NCHW"); - auto conv2d_src_memory = BindDNNLMemory(data_entry, {src_dims, dt::f32, tag::nchw}); + if (o_scl_tr) { + ICHECK(o_scl_tr.isConstant()); + auto data = o_scl_tr.getConstDataLikeVec(); + attr.set_output_scales(data.size() == 1 ? 0 : (1 << 1), data); + } - // Weight memory. - ICHECK_EQ(node.GetAttr>("kernel_layout")[0], "OIHW"); - auto conv2d_weights_memory = BindDNNLMemory( - weight_entry, {weights_dims, dt::f32, (groups > 1) ? tag::goihw : tag::oihw}); + if (activation[0] != "none") { + auto a_type = utils::convert2dnnl_activation(activation[0]); + auto a_scale = node.getInput(std::stoi(activation[1])).getConstScalarData(); + auto a_alfa = node.getInput(std::stoi(activation[2])).getConstScalarData(); + auto a_beta = node.getInput(std::stoi(activation[3])).getConstScalarData(); - // Bias memory. - auto conv2d_bias_memory = dnnl::memory({bias_dims, dt::f32, tag::x}, engine_); - if (has_bias) { - auto bias_entry = node.GetInputs()[2]; - BindDNNLMemory(bias_entry, conv2d_bias_memory); - } else { - float bias[OC] = {0}; - write_to_dnnl_memory(bias, conv2d_bias_memory, OC * sizeof(float)); + auto ops = attr.get_post_ops(); + ops.append_eltwise(a_scale, a_type, a_alfa, a_beta); + attr.set_post_ops(ops); } - // Output memory. - JSONGraphNodeEntry out_entry(nid, 0); - auto conv2d_dst_memory = BindDNNLMemory(out_entry, conv2d_prim_desc.dst_desc()); + if (sum_scl_tr) { + auto scl = sum_scl_tr.getConstScalarData(); + auto ops = attr.get_post_ops(); + ops.append_sum(scl); + attr.set_post_ops(ops); + } - // Bind memory buffers. 
- net_args_.push_back({{DNNL_ARG_SRC, conv2d_src_memory}, - {DNNL_ARG_WEIGHTS, conv2d_weights_memory}, - {DNNL_ARG_BIAS, conv2d_bias_memory}, - {DNNL_ARG_DST, conv2d_dst_memory}}); + dnnl::memory::dim PW_L = padding[1], // width padding: left + PW_R = padding[3], // width padding: right + PH_L = padding[0], // height padding: top + PH_R = padding[2], // height padding: bottom + SH = strides[0], // height-wise stride + SW = strides[1], // weight-wise stride + DH = dilation[0] - 1, // height-wise dilation, DNNL uses dilation format with - 1 + DW = dilation[1] - 1; // weight-wise dilation + + // Conv description + auto conv_d = dnnl::convolution_forward::desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, + data_tr.layoutAny().desc(), kernel_tr.layoutAny().desc(), bias_tr.layoutAny().desc(), + output_tr.layoutAny().desc(), {SH, SW} /*strides*/, {DH, DW} /*dilation*/, + {PH_L, PW_L} /*padding_l*/, {PH_R, PW_R} /*padding_r*/); + + auto conv_pd = dnnl::convolution_forward::primitive_desc(conv_d, attr, engine_); + auto conv = dnnl::convolution_forward(conv_pd); + + // Specify proper layouts + data_tr = data_tr.requestLayout(conv_pd.src_desc()); + kernel_tr = kernel_tr.requestLayout(conv_pd.weights_desc()); + output_tr = output_tr.requestLayout(conv_pd.dst_desc()); + bias_tr = bias_tr.requestLayout(conv_pd.bias_desc()); + + auto scratchpad_tr = node.makeScratchpad(conv_pd.scratchpad_desc()); + + // Inplace request for conv+sum pattern. Match input with dst tensor + auto submit_attr = + sum_tr ? SubmitAttr{SubmitAttr::ZeroCopyRequest, sum_tr, DNNL_ARG_DST} : SubmitAttr{}; + + // Register prim to execute + submit(conv, + {{DNNL_ARG_SRC, data_tr}, + {DNNL_ARG_WEIGHTS, kernel_tr}, + {DNNL_ARG_BIAS, bias_tr}, + {DNNL_ARG_SCRATCHPAD, scratchpad_tr}, + {DNNL_ARG_DST, output_tr}}, + submit_attr); } - void Dense(const size_t& nid, const bool has_bias = false) { - auto node = nodes_[nid]; + void UniDense(const uint32_t& nid) { + auto node = NodeHelper{nid, g_explorer_}; - // Setup attributes. - auto data_entry = node.GetInputs()[0]; - auto weight_entry = node.GetInputs()[1]; - dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; + auto src_tr = node.getInput(0); + auto wgh_tr = node.getInput(1); + auto dst_tr = node.getOutput(0); - dnnl::memory::dim B = input_shape[0], // batch size - IC = input_shape[1], // input channels - OC = weight_shape[0]; // output channels + auto activation = node.getAttr>("activation", {"none"}); + auto bias_idx = node.getAttr("bias_idx", {"-1"}); + auto sum_idx = node.getAttr("sum_idx", {"-1"}); + auto sum_scl_idx = node.getAttr("sum_scl_idx", {"-1"}); + auto o_scl_idx = node.getAttr("o_scl_idx", {"-1"}); + auto dst_zp_idx = node.getAttr("dst_zp_idx", {"-1"}); - // Memory shapes. - dnnl::memory::dims data_dims = {B, IC}; - dnnl::memory::dims weight_dims = {OC, IC}; - dnnl::memory::dims bias_dims = {OC}; - dnnl::memory::dims out_dims = {B, OC}; + // may be empty in case if '-1' + auto bias_tr = node.getInput(bias_idx); + auto sum_tr = node.getInput(sum_idx); + auto sum_scl_tr = node.getInput(sum_scl_idx); + auto o_scl_tr = node.getInput(o_scl_idx); + auto dst_zp_tr = node.getInput(dst_zp_idx); - // Memory descriptions. 
- auto data_md = dnnl::memory::desc({data_dims, dt::f32, tag::nc}); - auto weight_md = dnnl::memory::desc({weight_dims, dt::f32, tag::nc}); - auto bias_md = dnnl::memory::desc({bias_dims, dt::f32, tag::x}); - auto dst_md = dnnl::memory::desc({out_dims, dt::f32, tag::nc}); + // TODO(@apeskov): temp WA. while codegen is not able to guarantee 1D format of bias data + bias_tr = bias_tr.squeeze(); - // Dense description. - auto dense_desc = dnnl::inner_product_forward::desc(dnnl::prop_kind::forward_inference, data_md, - weight_md, bias_md, dst_md); - auto dense_prim_desc = dnnl::inner_product_forward::primitive_desc(dense_desc, engine_); - - auto dense = dnnl::inner_product_forward(dense_prim_desc); - net_.push_back(dense); - - // Memories. - auto data_memory = BindDNNLMemory(data_entry, data_md); - auto weight_memory = BindDNNLMemory(weight_entry, weight_md); - - // Bias memory. - auto bias_memory = dnnl::memory(bias_md, engine_); - if (has_bias) { - auto bias_entry = node.GetInputs()[2]; - BindDNNLMemory(bias_entry, bias_memory); - } else { - float bias[OC] = {0}; - write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float)); - } + // Attributes setting + dnnl::primitive_attr attr; + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); - // Output memory. - JSONGraphNodeEntry out_entry(nid, 0); - auto dst_memory = BindDNNLMemory(out_entry, dense_prim_desc.dst_desc()); + ICHECK(!dst_zp_tr) << "DNNL doesn't support input zero point for optimized primitives." + "Should be merged into bias"; - net_args_.push_back({{DNNL_ARG_SRC, data_memory}, - {DNNL_ARG_WEIGHTS, weight_memory}, - {DNNL_ARG_BIAS, bias_memory}, - {DNNL_ARG_DST, dst_memory}}); - } + if (o_scl_tr) { + ICHECK(o_scl_tr.isConstant()); + auto data = o_scl_tr.getConstDataLikeVec(); + attr.set_output_scales(data.size() == 1 ? 0 : (1 << 1), data); + } - void BatchNorm(const size_t& nid) { - auto node = nodes_[nid]; + if (activation[0] != "none") { + auto a_type = utils::convert2dnnl_activation(activation[0]); + auto a_scale = node.getInput(std::stoi(activation[1])).getConstScalarData(); + auto a_alfa = node.getInput(std::stoi(activation[2])).getConstScalarData(); + auto a_beta = node.getInput(std::stoi(activation[3])).getConstScalarData(); - auto data_entry = node.GetInputs()[0]; - auto gamma_entry = node.GetInputs()[1]; - auto beta_entry = node.GetInputs()[2]; - auto mean_entry = node.GetInputs()[3]; - auto variance_entry = node.GetInputs()[4]; - dnnl::memory::dims data_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dim IC = data_shape[1]; - float epsilon = std::stof(node.GetAttr>("epsilon")[0]); + auto ops = attr.get_post_ops(); + ops.append_eltwise(a_scale, a_type, a_alfa, a_beta); + attr.set_post_ops(ops); + } - // Memory description. - dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32); + if (sum_scl_tr) { + auto scl = sum_scl_tr.getConstScalarData(); + auto ops = attr.get_post_ops(); + ops.append_sum(scl); + attr.set_post_ops(ops); + } - // BN description. - auto bn_desc = dnnl::batch_normalization_forward::desc( - dnnl::prop_kind::forward_inference, data_md, epsilon, - dnnl::normalization_flags::use_global_stats | dnnl::normalization_flags::use_scale_shift); - auto bn_prim_desc = dnnl::batch_normalization_forward::primitive_desc(bn_desc, engine_); - auto bn = dnnl::batch_normalization_forward(bn_prim_desc); - net_.push_back(bn); - - // Memories. 
- auto data_memory = BindDNNLMemory(data_entry, data_md); - JSONGraphNodeEntry out_entry(nid, 0); - auto out_memory = BindDNNLMemory(out_entry, data_md); - auto mean_memory = BindDNNLMemory(mean_entry, bn_prim_desc.mean_desc()); - auto variance_memory = BindDNNLMemory(variance_entry, bn_prim_desc.variance_desc()); - - // In DNNL, weight is composed of gamma+beta, so we point them to the same DNNL memory but - // assign an offset to beta data for runtime serialization. - auto weight_memory = BindDNNLMemory(gamma_entry, bn_prim_desc.weights_desc(), 0); - BindDNNLMemory(beta_entry, weight_memory, IC); - - net_args_.push_back({{DNNL_ARG_SRC, data_memory}, - {DNNL_ARG_DST, out_memory}, - {DNNL_ARG_SCALE_SHIFT, weight_memory}, - {DNNL_ARG_MEAN, mean_memory}, - {DNNL_ARG_VARIANCE, variance_memory}}); + // Dense description. + auto dense_d = dnnl::inner_product_forward::desc( + dnnl::prop_kind::forward_inference, src_tr.layoutAny().desc(), wgh_tr.layoutAny().desc(), + bias_tr.layoutAny().desc(), dst_tr.layoutAny().desc()); + auto dense_pd = dnnl::inner_product_forward::primitive_desc(dense_d, attr, engine_); + auto dense = dnnl::inner_product_forward(dense_pd); + + // Select proper layout + src_tr = src_tr.requestLayout(dense_pd.src_desc()); + wgh_tr = wgh_tr.requestLayout(dense_pd.weights_desc()); + dst_tr = dst_tr.requestLayout(dense_pd.dst_desc()); + + auto scratch_pad_d = node.makeScratchpad(dense_pd.scratchpad_desc()); + + // Inplace request for conv+sum pattern. Match input with dst tensor + auto submit_attr = + sum_tr ? SubmitAttr{SubmitAttr::ZeroCopyRequest, sum_tr, DNNL_ARG_DST} : SubmitAttr{}; + + submit(dense, + {{DNNL_ARG_SRC, src_tr}, + {DNNL_ARG_WEIGHTS, wgh_tr}, + {DNNL_ARG_BIAS, bias_tr}, + {DNNL_ARG_SCRATCHPAD, scratch_pad_d}, + {DNNL_ARG_DST, dst_tr}}, + submit_attr); } - void Eltwise(const size_t& nid, dnnl::algorithm algo) { - auto node = nodes_[nid]; - - auto data_entry = node.GetInputs()[0]; - dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32); + void BatchNorm(const uint32_t& nid) { + auto node = NodeHelper{nid, g_explorer_}; + + auto src_tr = node.getInput(0); + auto gamma_tr = node.getInput(1); + auto beta_tr = node.getInput(2); + auto mean_tr = node.getInput(3); + auto variance_tr = node.getInput(4); + auto dst_tr = node.getOutput(0); + + auto axis = node.getAttr("axis"); + auto epsilon = node.getAttr("epsilon"); + auto center = node.getAttr("center"); + auto scale = node.getAttr("scale"); + + // TODO(@apeskov): Add support of all type of axis, center and scale args + ICHECK(axis == 1); + ICHECK(center); + ICHECK(scale); + + // TODO(@apeskov): Should it use "any" layout to select proper one? + auto bn_d = dnnl::batch_normalization_forward::desc( + dnnl::prop_kind::forward_inference, dst_tr.desc(), epsilon, + dnnl::normalization_flags::use_global_stats | dnnl::normalization_flags::use_scale_shift); + auto bn_pd = dnnl::batch_normalization_forward::primitive_desc(bn_d, engine_); + auto bn = dnnl::batch_normalization_forward(bn_pd); + + src_tr = src_tr.requestLayout(bn_pd.src_desc()); + dst_tr = dst_tr.requestLayout(bn_pd.dst_desc()); + mean_tr = mean_tr.requestLayout(bn_pd.mean_desc()); + variance_tr = variance_tr.requestLayout(bn_pd.variance_desc()); + + // TODO(@apeskov): DNNL v2.5 and late has API for separate scale and shift + // it will eliminate requirements of data copy. 
+ // Prepare concatenated Scale and Shift tensor + auto scale_shift_tr = node.makeTemp(bn_pd.weights_desc(), g_explorer_.generateUniqueEID()); + auto sc_sh_dims = scale_shift_tr.dims(); + ICHECK(sc_sh_dims.size() == 2); + ICHECK(sc_sh_dims[0] == 2); + sc_sh_dims[0] /= 2; + auto scale_tr = scale_shift_tr.crop(sc_sh_dims, {0, 0}).squeeze(); + auto shift_tr = scale_shift_tr.crop(sc_sh_dims, {1, 0}).squeeze(); + + auto register_copy = [this](const TensorRequisite& src, const TensorRequisite& dst) { + dnnl::reorder::primitive_desc copy_pd(engine_, src.desc(), engine_, dst.desc()); + submit(dnnl::reorder(copy_pd), {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}}); + }; + + register_copy(gamma_tr, scale_tr); + register_copy(beta_tr, shift_tr); + + submit(bn, {{DNNL_ARG_SRC, src_tr}, + {DNNL_ARG_DST, dst_tr}, + {DNNL_ARG_SCALE_SHIFT, scale_shift_tr}, + {DNNL_ARG_MEAN, mean_tr}, + {DNNL_ARG_VARIANCE, variance_tr}}); + } - auto elt_desc = - dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, data_md, 0); - auto elt_prim_desc = dnnl::eltwise_forward::primitive_desc(elt_desc, engine_); - ICHECK(data_md == elt_prim_desc.dst_desc()); + void Eltwise(const uint32_t& nid, dnnl::algorithm algo) { + auto node = NodeHelper{nid, g_explorer_}; - auto elt = dnnl::eltwise_forward(elt_prim_desc); - net_.push_back(elt); + auto src_tr = node.getInput(0); + auto dst_tr = node.getOutput(0); + ICHECK(src_tr.dims() == dst_tr.dims()); + // Eltwise op required same layout for src/dst + src_tr = src_tr.requestLayout(dst_tr.desc()); - auto data_memory = BindDNNLMemory(data_entry, data_md); - JSONGraphNodeEntry out_entry(nid, 0); - auto out_memory = BindDNNLMemory(out_entry, data_md); + auto eltwise_d = + dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, dst_tr.desc()); + auto eltwise_pd = dnnl::eltwise_forward::primitive_desc(eltwise_d, engine_); + auto eltwise = dnnl::eltwise_forward(eltwise_pd); - net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, out_memory}}); + submit(eltwise, {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}}); } - void Binary(const size_t& nid, dnnl::algorithm algo) { - auto node = nodes_[nid]; - - // Memory and compute description. - std::vector data_dims; - std::vector data_mds; - std::vector data_memories; + void Binary(const uint32_t& nid, dnnl::algorithm algo) { + auto node = NodeHelper{nid, g_explorer_}; - ICHECK_EQ(node.GetInputs().size(), 2U); - for (auto entry : node.GetInputs()) { - auto data_shape = nodes_[entry.id_].GetOpShape()[entry.index_]; - dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32); + auto lhs_tr = node.getInput(0); + auto rhs_tr = node.getInput(1); + auto out_tr = node.getOutput(0); - data_dims.push_back(data_shape); - data_mds.push_back(data_md); - data_memories.push_back(BindDNNLMemory(entry, data_md)); - } - ICHECK(data_dims[0] == data_dims[1]); - auto out_md = data_mds[0]; - JSONGraphNodeEntry out_entry(nid, 0); - auto out_memory = BindDNNLMemory(out_entry, out_md); - - auto binary_desc = dnnl::binary::desc(algo, data_mds[0], data_mds[1], out_md); - auto binary_prim_desc = dnnl::binary::primitive_desc(binary_desc, engine_); - auto binary = dnnl::binary(binary_prim_desc); - net_.push_back(binary); - - net_args_.push_back({{DNNL_ARG_SRC_0, data_memories[0]}, - {DNNL_ARG_SRC_1, data_memories[1]}, - {DNNL_ARG_DST, out_memory}}); - } + lhs_tr = lhs_tr.broadcast(out_tr.dims()); + rhs_tr = rhs_tr.broadcast(out_tr.dims()); - // Read from DNNL memory (+offset) and write to the handle. 
-  inline void read_from_dnnl_memory(void* handle, const dnnl::memory& mem, size_t size,
-                                     size_t offset = 0) {
-    uint8_t* src = static_cast<uint8_t*>(mem.get_data_handle());
-    std::copy(src + offset, src + offset + size, static_cast<uint8_t*>(handle));
-  }
+    // Any layouts cannot be used for binary prim
+    auto binary_d = dnnl::binary::desc(algo, lhs_tr.desc(), rhs_tr.desc(), out_tr.desc());
+    auto binary_pd = dnnl::binary::primitive_desc(binary_d, engine_);
+    auto binary = dnnl::binary(binary_pd);

-  // Read from the handle and write to DNNL memory (+offset).
-  inline void write_to_dnnl_memory(void* handle, const dnnl::memory& mem, size_t size,
-                                   size_t offset = 0) {
-    uint8_t* dst = static_cast<uint8_t*>(mem.get_data_handle());
-    std::copy(reinterpret_cast<uint8_t*>(handle), reinterpret_cast<uint8_t*>(handle) + size,
-              dst + offset);
-  }
+    // Request proper layouts
+    lhs_tr = lhs_tr.requestLayout(binary_pd.src0_desc());
+    rhs_tr = rhs_tr.requestLayout(binary_pd.src1_desc());
+    out_tr = out_tr.requestLayout(binary_pd.dst_desc());

-  // Generate DNNL memory description and infer the data layout by the given shape.
-  inline dnnl::memory::desc GenDNNLMemDescByShape(const dnnl::memory::dims& shape, dt dtype) {
-    dnnl::memory::desc data_md;
-    switch (shape.size()) {
-      case 2:
-        data_md = dnnl::memory::desc({shape, dtype, tag::ab});
-        break;
-      case 3:
-        data_md = dnnl::memory::desc({shape, dtype, tag::abc});
-        break;
-      case 4:
-        data_md = dnnl::memory::desc({shape, dtype, tag::abcd});
-        break;
-      case 5:
-        data_md = dnnl::memory::desc({shape, dtype, tag::abcde});
-        break;
-      default:
-        LOG(FATAL) << "Unsupported data shape dimension: " << shape.size();
-        break;
-    }
-    return data_md;
+    submit(binary, {{DNNL_ARG_SRC_0, lhs_tr}, {DNNL_ARG_SRC_1, rhs_tr}, {DNNL_ARG_DST, out_tr}});
   }

-  /* The dnnl engine. */
+  /** The dnnl engine. */
   dnnl::engine engine_;
-  /* The dnnl stream. */
+  /** The dnnl stream. */
   dnnl::stream stream_;
-  /* The network layers that are represented in dnnl primitives. */
-  std::vector<dnnl::primitive> net_;
-  /* The memory that is consumed by arguments. */
-  std::vector<std::unordered_map<int, dnnl::memory>> net_args_;
-  /* The entry ID to its corresponding output memory. */
-  std::unordered_map<uint32_t, std::pair<dnnl::memory, size_t>> entry_out_mem_;
+  /** Tensor registry which manages all real dnnl memory objects */
+  TensorRegistry tensor_registry_;
+  /** The network layers that are represented as dnnl primitives plus their args. */
+  TensorRegistry::ActionQue net_;
+  /** Utility object */
+  GraphExplorer g_explorer_;
 };

-runtime::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json,
+runtime::Module DNNLJSONRuntimeCreate(const String& symbol_name, const String& graph_json,
                                       const Array<String>& const_names) {
   auto n = make_object<DNNLJSONRuntime>(symbol_name, graph_json, const_names);
   return runtime::Module(n);
@@ -483,6 +578,8 @@ TVM_REGISTER_GLOBAL("runtime.DNNLJSONRuntimeCreate").set_body_typed(DNNLJSONRunt
 TVM_REGISTER_GLOBAL("runtime.module.loadbinary_dnnl_json")
     .set_body_typed(JSONRuntimeBase::LoadFromBinary<DNNLJSONRuntime>);
 
+TVM_REGISTER_GLOBAL("runtime.module.dnnl_version").set_body_typed(DNNLJSONRuntime::get_version);
+
 }  // namespace contrib
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/contrib/dnnl/dnnl_node_helper.h b/src/runtime/contrib/dnnl/dnnl_node_helper.h
new file mode 100644
index 000000000000..b7f636677b1d
--- /dev/null
+++ b/src/runtime/contrib/dnnl/dnnl_node_helper.h
@@ -0,0 +1,796 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_NODE_HELPER_H_ +#define TVM_RUNTIME_CONTRIB_DNNL_DNNL_NODE_HELPER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../json/json_runtime.h" +#include "dnnl.hpp" + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace dnnl; + +namespace utils { + +/** Converter helper for shape objects */ +inline static dnnl::memory::dims convert2dnnl(std::vector shape) { + if (shape.empty()) return {1}; // DNNL scalar representation + return shape; +} + +/** Converter helper for data type objects */ +inline static dnnl::memory::data_type convert2dnnl(DLDataType dtype) { + if (dtype.code == DLDataTypeCode::kDLInt) { + if (dtype.bits == 8) return dnnl::memory::data_type::s8; + if (dtype.bits == 32) return dnnl::memory::data_type::s32; + } else if (dtype.code == DLDataTypeCode::kDLUInt) { + if (dtype.bits == 8) return dnnl::memory::data_type::u8; + } else if (dtype.code == DLDataTypeCode::kDLFloat) { + if (dtype.bits == 16) return dnnl::memory::data_type::f16; + if (dtype.bits == 32) return dnnl::memory::data_type::f32; + } else if (dtype.code == DLDataTypeCode::kDLBfloat) { + if (dtype.bits == 16) return dnnl::memory::data_type::bf16; + } + LOG(FATAL) << "Data type is not supported"; + return dnnl::memory::data_type::undef; +} + +/** Converter of primitive types to corresponding DNNL data type */ +template +dnnl::memory::data_type dnnlDType(); +template <> +dnnl::memory::data_type dnnlDType() { + return dnnl::memory::data_type::s32; +} +template <> +dnnl::memory::data_type dnnlDType() { + return dnnl::memory::data_type::f32; +} + +/** Generator of dnnl format_tag for plain version of tensor */ +inline static dnnl::memory::format_tag plainLayout(uint32_t rank) { + switch (rank) { + case 0: + case 1: + return dnnl::memory::format_tag::a; + case 2: + return dnnl::memory::format_tag::ab; + case 3: + return dnnl::memory::format_tag::abc; + case 4: + return dnnl::memory::format_tag::abcd; + case 5: + return dnnl::memory::format_tag::abcde; + case 6: + return dnnl::memory::format_tag::abcdef; + default: + LOG(FATAL) << "Unsupported data tensor rank: " << rank; + break; + } + return dnnl::memory::format_tag::undef; +} + +inline static dnnl::memory::desc makePlainTDesc(const std::vector& shape, + const DLDataType& dtype) { + return {convert2dnnl(shape), convert2dnnl(dtype), plainLayout(shape.size())}; +} + +/** Builder of dnnl memory on top of provided DLTensor */ +dnnl::memory convert2dnnl(const DLTensor* dl_tensor, const dnnl::engine engine) { + // TODO(apeskov): assume that data is always in plain format, check if it's true + ICHECK(dl_tensor->strides == nullptr); + ICHECK_EQ(dl_tensor->device.device_type, kDLCPU); + + std::vector dl_dims(dl_tensor->shape, dl_tensor->shape + dl_tensor->ndim); + 
dnnl::memory::desc desc{convert2dnnl(dl_dims), convert2dnnl(dl_tensor->dtype),
+                          plainLayout(dl_dims.size())};
+
+  desc.data.offset0 = dl_tensor->byte_offset;
+  return {desc, engine, dl_tensor->data};
+}
+
+/** Converter helper from an eltwise op name to the proper dnnl::algorithm value */
+dnnl::algorithm convert2dnnl_activation(std::string name) {
+  if (name == "nn.relu")
+    return dnnl::algorithm::eltwise_relu;
+  else if (name == "clip")
+    return dnnl::algorithm::eltwise_clip;
+  else if (name == "gelu")
+    return dnnl::algorithm::eltwise_gelu;
+  else if (name == "tanh")
+    return dnnl::algorithm::eltwise_tanh;
+  else if (name == "sqrt")
+    return dnnl::algorithm::eltwise_sqrt;
+  else if (name == "sigmoid")
+    return dnnl::algorithm::eltwise_logistic;
+  else
+    LOG(FATAL) << "Unknown activation name";
+
+  return dnnl::algorithm::undef;
+}
+
+/** Find the permutation of chars in the src string that achieves the ref string */
+inline static std::vector<int> permutation(const std::string& src, const std::string& ref) {
+  std::set<char> chars(src.begin(), src.end());
+  ICHECK_EQ(chars.size(), src.size()) << "\"" << src << "\" has duplicate symbols";
+
+  std::vector<int> perm;
+  for (const auto& c : src) {
+    auto found = ref.find(c);
+    ICHECK_NE(found, std::string::npos) << "\"" << src << "\" is not a permutation of "
+                                        << "\"" << ref << "\"";
+    perm.push_back(found);
+  }
+  return perm;
+}
+
+/** Data copy function */
+void copy_now(const dnnl::memory& src, const dnnl::memory& dst) {
+  auto reorder = dnnl::reorder(src, dst);
+  auto stream = dnnl::stream(src.get_engine());
+  // The DNNL API requires a non-const ref for src. Have to use const_cast
+  auto src_non_const = const_cast<dnnl::memory&>(src);
+  auto dst_non_const = const_cast<dnnl::memory&>(dst);
+  reorder.execute(stream, src_non_const, dst_non_const);
+}
+
+}  // namespace utils
+
+/**
+ * Helper object to simplify tensor handling
+ *
+ * Allows specifying a future tensor and the actions which should be applied to it.
+ * Can be treated as a pair of a future tensor source reference and a list of actions which
+ * should be applied to this tensor. Finally, the TensorRequisite object should be registered
+ * in a TensorRegistry.
+ *
+ * @note: An empty TR object allows any manipulation. An empty TR will be returned. 
+ * + * Like: + * source - input tensor on position 3 + * actions - reinterpret like a plain 1D tensor + * + * Example: + * auto tr = node.getInput(3); // source is node input #3 + * tr = tr.permute({1, 2, 0}); // permute axes chw -> hwc + * tr = tr.crop({128, 128, 1}, {0, 0, 0}); // extract first channel + * tr = tr.squeeze(); + * + * submit(prim, {DNNL_ARG_SRC, tr}); + */ +class TensorRequisite { + public: + static constexpr uint32_t INVALID_EID = std::numeric_limits::max() - 1; + + TensorRequisite() {} + + /** return shape of tensor */ + dnnl::memory::dims dims() const { return t_desc_.dims(); } + /** return tensor desc */ + dnnl::memory::desc desc() const { return t_desc_; } + + /** Produce tensor with permuted axes */ + TensorRequisite permute(const std::vector& permutation) const { + if (!defined()) return *this; // nothing for empty TR + + auto orig = std::make_shared(*this); + // reinterpret memory buffer with new strides + auto desc = t_desc_.permute_axes(permutation); + return {desc, orig, true, {}, INVALID_EID, reverse_data_flow_}; + } + + /** Produce tensor with reinterpret data of original tr */ + TensorRequisite reshape(const dnnl::memory::dims& shape) const { + if (!defined()) return *this; // nothing for empty TR + if (t_desc_.dims() == shape) return *this; + + auto orig = std::make_shared(*this); + // reinterpret memory buffer with new strides + auto desc = t_desc_.reshape(shape); + return {desc, orig, true, {}, INVALID_EID, reverse_data_flow_}; + } + + /** Produce tensor with broadcasted values */ + TensorRequisite broadcast(const dnnl::memory::dims& shape) const { + if (!defined()) return *this; // nothing for empty TR + if (t_desc_.dims() == shape) return *this; + ICHECK(!reverse_data_flow_); + + auto orig = std::make_shared(*this); + + // numpy like broadcast + auto extended_dims = t_desc_.dims(); + auto one_filled = dnnl::memory::dims(shape.size() - extended_dims.size(), 1); + extended_dims.insert(extended_dims.begin(), one_filled.begin(), one_filled.end()); + auto desc = t_desc_.reshape(extended_dims); + for (size_t i = 0; i < extended_dims.size(); i++) { + if (extended_dims[i] == shape[i]) continue; + ICHECK(extended_dims[i] == 1); + ICHECK(desc.data.dims[i] == desc.data.padded_dims[i]); + + desc.data.dims[i] = shape[i]; + desc.data.padded_dims[i] = shape[i]; + desc.data.format_desc.blocking.strides[i] = 0; + } + + // reinterpret memory buffer with new strides + return {desc, orig, true, {}, INVALID_EID, reverse_data_flow_}; + } + + /** Produce tensor with sub memory view (ROI) */ + TensorRequisite crop(const dnnl::memory::dims& shape, const dnnl::memory::dims& offset) const { + if (!defined()) return *this; // nothing for empty TR + + auto orig = std::make_shared(*this); + // reinterpret memory buffer with new strides + auto desc = t_desc_.submemory_desc(shape, offset); + return {desc, orig, true, {}, INVALID_EID, reverse_data_flow_}; + } + + /** Produce tensor with squeeze shape */ + TensorRequisite squeeze(const dnnl::memory::dims& dims_to_squeeze = {}) const { + if (!defined()) return *this; // nothing for empty TR + + dnnl::memory::dims squeezed_dims; + if (dims_to_squeeze.empty()) { + for (auto d : t_desc_.dims()) + if (d != 1) squeezed_dims.push_back(d); + } else { + for (size_t i = 0; i < t_desc_.dims().size(); i++) + if (std::find(dims_to_squeeze.begin(), dims_to_squeeze.end(), i) == dims_to_squeeze.end()) + squeezed_dims.push_back(t_desc_.dims()[i]); + } + + auto orig = std::make_shared(*this); + // reinterpret memory buffer with new strides + auto desc = 
t_desc_.reshape(squeezed_dims);
+    return {desc, orig, true, {}, INVALID_EID, reverse_data_flow_};
+  }
+
+  /** Produce tensor with specified layout */
+  TensorRequisite requestLayout(dnnl::memory::desc desc) const {
+    if (!defined()) return *this;  // nothing for empty TR
+
+    // If it's the same desc just return self
+    if (desc == t_desc_) return *this;
+
+    ICHECK(t_desc_.dims() == desc.dims()) << "Requested layout is not compatible with "
+                                             "presented shape";
+
+    auto orig = std::make_shared<TensorRequisite>(*this);
+    return {desc, orig, false, {}, INVALID_EID, reverse_data_flow_};
+  }
+
+  /**
+   * Produce tensor with unspecified layout
+   * Cannot be registered in TensorRegistry. Only for querying DNNL for preferred layouts.
+   */
+  TensorRequisite layoutAny() const {
+    auto orig = std::make_shared<TensorRequisite>(*this);
+    // Recreate tensor desc with layout 'any'
+    dnnl::memory::desc any_desc{t_desc_.dims(), t_desc_.data_type(), dnnl::memory::format_tag::any};
+    return {any_desc, orig, false, {}, INVALID_EID, reverse_data_flow_};
+  }
+
+  /** Check if the tensor is constant */
+  bool isConstant() const {
+    if (orig_) return orig_->isConstant();
+    return mem_.operator bool();
+  }
+
+  /** Check if the tensor is a scalar */
+  bool isScalar() const { return t_desc_.dims().size() == 1 && t_desc_.dims()[0] == 1; }
+
+  /** Produce const data memory object with proper content */
+  dnnl::memory getConstData() const {
+    if (reverse_data_flow_ || eid_ != INVALID_EID) return {};
+    if (mem_) return mem_;
+
+    ICHECK(orig_);
+    if (auto orig_const_data = orig_->getConstData()) {
+      if (reinterpret_) {
+        return {t_desc_, orig_const_data.get_engine(), orig_const_data.get_data_handle()};
+      } else {
+        auto res = dnnl::memory{t_desc_, orig_const_data.get_engine()};
+        utils::copy_now(orig_const_data, res);
+        return res;
+      }
+    }
+    return {};
+  }
+
+  /**
+   * Same as getConstData but in the form of a std::vector
+   * Useful for 1D constant tensors like zero_point or per_channel_scale
+   *
+   * @tparam T desired data type
+   * @return resulting data
+   */
+  template <typename T>
+  std::vector<T> getConstDataLikeVec() const {
+    auto const_data = getConstData();
+    auto desc = const_data.get_desc();
+    ICHECK(desc.data_type() == utils::dnnlDType<T>());
+    ICHECK(desc.dims().size() == 1);
+
+    auto size = desc.get_size() / sizeof(T);
+    auto ptr = static_cast<T*>(const_data.get_data_handle());
+
+    return std::vector<T>(ptr, ptr + size);
+  }
+
+  /**
+   * Produce value of constant scalar tensor
+   * @tparam T desired scalar type
+   * @return resulting value of type T
+   */
+  template <typename T>
+  T getConstScalarData() const {
+    ICHECK(isConstant());
+    ICHECK(isScalar());
+    auto const_data = getConstData();
+    auto desc = const_data.get_desc();
+    ICHECK(desc.data_type() == utils::dnnlDType<T>());
+
+    auto ptr = static_cast<T*>(const_data.get_data_handle());
+    return *ptr;
+  }
+
+  /** Check if tensor is not empty */
+  bool defined() const { return !t_desc_.is_zero(); }
+
+  /** Same as defined() */
+  operator bool() const { return defined(); }
+
+  /** Check if tensor represents a reversed action queue (aka is a dst) */
+  bool isReversed() const { return reverse_data_flow_; }
+
+ private:
+  TensorRequisite(const dnnl::memory::desc& t_desc, const std::shared_ptr<TensorRequisite>& orig,
+                  bool reinterpret, const dnnl::memory& const_mem, uint32_t eid,
+                  bool reverse_data_flow)
+      : t_desc_(t_desc),
+        orig_(orig),
+        reinterpret_(reinterpret),
+        mem_(const_mem),
+        eid_(eid),
+        reverse_data_flow_(reverse_data_flow) {}
+
+  /** Descriptor of the tensor */
+  dnnl::memory::desc t_desc_ = {};
+  /** Original TR this one is derived from */
+  std::shared_ptr<TensorRequisite> orig_ = {};
+  /** 
Flag to specify reinterpret orig or do reordering */ + bool reinterpret_ = false; + /** Const memory object if available */ + dnnl::memory mem_ = {}; + /** Entry ID of tensor if it available */ + uint32_t eid_ = INVALID_EID; + + /** + * Flag to describe reverse data flow case + * All operation on queue will be executed in reverse order. Actual for dst tensor description + */ + bool reverse_data_flow_ = false; + + friend class TensorRegistry; + friend class NodeHelper; +}; + +class TensorRegistry { + private: + enum ArgReqFlag { + UNKNOWN, /// < Undefined type of args. Cannot be matched to real tensor + CONST, /// < Constant tensor. ExecutionCTX independent + TMP_STORAGE, /// < Intermediate tensors. Stored inside TensorRegistry. Inaccessible outside + EXT_EID, /// < External data. Input or Output. + SCRATCHPAD /// < Scratchpad tensor. May overlap with other Scratchpad buffers. + }; + + public: + struct ArgReq { + TensorRegistry::ArgReqFlag flag_; + uint32_t idx_; + }; + using ArgReqSet = std::unordered_map; + using Action = std::tuple; + using ActionQue = std::vector; + using ExtDataProvider = std::function; + + TensorRegistry() = default; + TensorRegistry(const dnnl::engine& eng, const std::set& ext_eid_set) + : ext_eid_(ext_eid_set), eng_(eng) {} + + /** + * Register a TensorRequisite + * + * As result corresponding ArgReq and related action which should be executed before + * (or after in case of reverse data flow) usage of this tensor. + * @param tr TensorRequisite to register + * @return Associated ArgReq ar list of actions + */ + std::pair registerTR(const TensorRequisite& tr) { + // 1) Constant tensor. Direct reference + if (auto const_data = tr.getConstData()) { + auto idx = const_mem_collection_.size(); + const_mem_collection_.push_back(const_data); + auto arg_req = makeArgReq(ArgReqFlag::CONST, static_cast(idx)); + return {arg_req, {}}; + } + + // 2) EID mapped tensor. Direct reference + if (tr.eid_ != TensorRequisite::INVALID_EID) { + if (isTempEID(tr.eid_)) { + if (eid2idx_tmp_.count(tr.eid_)) { + auto idx = eid2idx_tmp_.at(tr.eid_); + auto arg_req = makeArgReq(ArgReqFlag::TMP_STORAGE, idx); + return {arg_req, {}}; + } else { + // register himself + auto mem = dnnl::memory{tr.t_desc_, eng_}; + auto idx = tmp_mem_collection_.size(); + tmp_mem_collection_.push_back(mem); + eid2idx_tmp_[tr.eid_] = idx; + auto arg_req = makeArgReq(ArgReqFlag::TMP_STORAGE, static_cast(idx)); + return {arg_req, {}}; + } + } else { + auto idx = ext_mem_collection_.size(); + ext_mem_collection_.push_back({tr.eid_, tr.t_desc_}); + auto arg_req = makeArgReq(ArgReqFlag::EXT_EID, static_cast(idx)); + return {arg_req, {}}; + } + } + + // 3) Tensors with transform actions + if (tr.orig_) { + ArgReq arg_req; + ActionQue actions; + + // recursive register of orig TR + std::tie(arg_req, actions) = registerTR(*tr.orig_); + if (tr.reinterpret_) { + arg_req = register_reinterp(arg_req, tr.t_desc_); + } else { + ActionQue reorder_act; + std::tie(arg_req, reorder_act) = + register_reorder(arg_req, tr.t_desc_, tr.reverse_data_flow_); + + actions.insert(tr.reverse_data_flow_ ? 
actions.begin() : actions.end(), reorder_act.begin(), + reorder_act.end()); + } + return {arg_req, actions}; + } + + // 4) Scratchpad + ICHECK(!tr.orig_ && !tr.mem_ && tr.eid_ == TensorRequisite::INVALID_EID); + auto scratchpad_ar = register_scratchpad(tr.t_desc_); + return {scratchpad_ar, {}}; + } + + std::unordered_map solve(const ArgReqSet& args, + const ExtDataProvider& ext_provider) const { + std::unordered_map res; + for (const auto& kvp : args) res[kvp.first] = solve(kvp.second, ext_provider); + return res; + } + + /** + * Find a proper memory object associated with provided ArgReq + * @param ar ArgReq to + * @param ext_provider + * @return + */ + dnnl::memory solve(const ArgReq& ar, const ExtDataProvider& ext_provider) const { + switch (ar.flag_) { + case CONST: + return const_mem_collection_.at(ar.idx_); + case TMP_STORAGE: + return tmp_mem_collection_.at(ar.idx_); + case EXT_EID: { + auto eid_and_desc = ext_mem_collection_.at(ar.idx_); + auto eid = eid_and_desc.first; + auto desc = eid_and_desc.second; + + auto hdl = ext_provider(eid); + return dnnl::memory{desc, eng_, hdl}; + } + case SCRATCHPAD: { + auto desc = scratchpad_desc_collection_.at(ar.idx_); + // TODO(@apeskov): make it thread local and avoid recreation each time + return dnnl::memory(desc, eng_); + } + default: + LOG(FATAL) << "Unknown case"; + } + return {}; + } + + /** Finalize registry. Should be called before any call of solve() method */ + void finalize() { + // calc total scratchpad size + dnnl::memory::dim scratchpad_size = 0; + for (const auto& scr_desc : scratchpad_desc_collection_) { + dnnl::memory::dim size = scr_desc.get_size(); + scratchpad_size = std::max(scratchpad_size, size); + } + scratchpad_mem_ = dnnl::memory::desc({scratchpad_size}, dnnl::memory::data_type::u8, + dnnl::memory::format_tag::a); + } + + private: + ArgReq register_reinterp(ArgReq src_ar, const dnnl::memory::desc& desc) { + switch (src_ar.flag_) { + case CONST: { + LOG(FATAL) << "Unreachable case"; + return {}; + } + case TMP_STORAGE: { + auto src = tmp_mem_collection_[src_ar.idx_]; + auto dst = dnnl::memory{desc, src.get_engine(), src.get_data_handle()}; + auto idx = tmp_mem_collection_.size(); + tmp_mem_collection_.push_back(dst); + return makeArgReq(TMP_STORAGE, idx); + } + case EXT_EID: { + auto ext_req = ext_mem_collection_[src_ar.idx_]; + auto idx = ext_mem_collection_.size(); + ext_mem_collection_.push_back({ext_req.first, desc}); + return makeArgReq(EXT_EID, idx); + } + default: + LOG(FATAL) << "Unknown case"; + } + return {}; + } + + std::pair register_reorder(ArgReq src_ar, const dnnl::memory::desc& desc, + bool reverse_data_flow) { + switch (src_ar.flag_) { + case CONST: { + LOG(FATAL) << "Unreachable case"; + return {}; + } + case TMP_STORAGE: { + auto src = tmp_mem_collection_[src_ar.idx_]; + + auto dst = dnnl::memory{desc, eng_}; + auto idx = tmp_mem_collection_.size(); + tmp_mem_collection_.push_back(dst); + auto dst_ar = makeArgReq(TMP_STORAGE, idx); + + // Action + Action res_action; + if (reverse_data_flow) { + res_action = {dnnl::reorder(dst, src), {{DNNL_ARG_FROM, dst_ar}, {DNNL_ARG_TO, src_ar}}}; + } else { + res_action = {dnnl::reorder(src, dst), {{DNNL_ARG_FROM, src_ar}, {DNNL_ARG_TO, dst_ar}}}; + } + + return {dst_ar, {res_action}}; + } + case EXT_EID: { + auto src_desc = ext_mem_collection_[src_ar.idx_].second; + + auto dst = dnnl::memory{desc, eng_}; + auto idx = tmp_mem_collection_.size(); + tmp_mem_collection_.push_back(dst); + auto dst_ar = makeArgReq(TMP_STORAGE, idx); + + // Action + Action res_action; + 
if (reverse_data_flow) { + auto reorder_pd = dnnl::reorder::primitive_desc(eng_, desc, eng_, src_desc); + auto reorder = dnnl::reorder(reorder_pd); + res_action = {reorder, {{DNNL_ARG_FROM, dst_ar}, {DNNL_ARG_TO, src_ar}}}; + } else { + auto reorder_pd = dnnl::reorder::primitive_desc(eng_, src_desc, eng_, desc); + auto reorder = dnnl::reorder(reorder_pd); + res_action = {reorder, {{DNNL_ARG_FROM, src_ar}, {DNNL_ARG_TO, dst_ar}}}; + } + + return {dst_ar, {res_action}}; + } + default: + LOG(FATAL) << "Unknown case"; + } + return {}; + } + + ArgReq register_scratchpad(const dnnl::memory::desc& desc) { + auto idx = scratchpad_desc_collection_.size(); + scratchpad_desc_collection_.push_back(desc); + return makeArgReq(SCRATCHPAD, idx); + } + + ArgReq makeArgReq(ArgReqFlag flag, uint32_t idx) { return {flag, idx}; } + + bool isTempEID(uint32_t eid) { return ext_eid_.count(eid) == 0; } + + /** Collection of const memory objects. */ + std::vector const_mem_collection_; + + /** Collection of intermediate memory objects. */ + std::vector tmp_mem_collection_; + + /** Map of eid to index of temp buffer in tmp_mem_collection_ */ + std::unordered_map eid2idx_tmp_; + + /** Collection of external_intermediate memory objects. + * first - eid of external buffer to ask + * second - t_desc describes how to treat external buffer */ + std::vector> ext_mem_collection_; + + /** Scratchpad collection */ + std::vector scratchpad_desc_collection_; + + /** Overall scratchpad memory obj */ + dnnl::memory::desc scratchpad_mem_; + + /** List of external eid */ + std::set ext_eid_; + + /** Engine of all tensors returned form this registry */ + dnnl::engine eng_; +}; + +/** + * GraphExplorer is a list of fields of original JSONRuntimeBase which allows + * to travers through the graph. + * + * Essentially that is a WA for access of protected fields of JSONRuntimeBase. + */ +struct GraphExplorer { + GraphExplorer(const std::vector& nodes, + const std::vector& data_entry, + const std::vector& node_row_ptr, const dnnl::engine& engine) + : nodes_(nodes), + data_entry_(data_entry), + node_row_ptr_(node_row_ptr), + engine_(engine), + gen_eid_offset(data_entry.size()) {} + + const std::vector& nodes_; + const std::vector& data_entry_; + const std::vector& node_row_ptr_; + + const dnnl::engine& engine_; + + uint32_t gen_eid_offset; + + uint32_t generateUniqueEID() { return gen_eid_offset++; } +}; + +class NodeHelper { + public: + NodeHelper(const uint32_t& nid, const GraphExplorer& graph_explorer) + : nid_(nid), node_(graph_explorer.nodes_[nid]), graph_explorer_(graph_explorer) {} + + template + typename std::enable_if::value, T>::type convert( + std::vector val) { + ICHECK_EQ(val.size(), 1); + return std::stol(val[0]); + } + + template + typename std::enable_if::value, T>::type convert( + std::vector val) { + ICHECK_EQ(val.size(), 1); + return std::stof(val[0]); + } + + template + typename std::enable_if>::value, T>::type convert( + std::vector val) { + T res; + for (const auto& el : val) res.push_back(convert({el})); + return res; + } + + template + typename std::enable_if::value, T>::type convert( + std::vector val) { + ICHECK_EQ(val.size(), 1); + return val[0]; + } + + // TODO(apeskov): enhance to any vector type, not only int + template + typename std::enable_if>::value, T>::type convert( + std::vector val) { + T res; + for (const auto& el : val) res.push_back(convert({el})); + return res; + } + + template + const T getAttr(std::string name, std::vector def = {}) { + auto attr = node_.HasAttr(name) ? 
node_.GetAttr>(name) : def; + return convert(attr); + } + + TensorRequisite getInput(int idx) { + if (idx == -1) return {}; // unavailable input + + ICHECK_LT(idx, node_.GetInputs().size()); + auto data_entry = node_.GetInputs()[idx]; + + auto shape = graph_explorer_.nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; + auto dtype = graph_explorer_.nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]; + auto eid = graph_explorer_.node_row_ptr_[data_entry.id_] + data_entry.index_; + auto dl_tensor = graph_explorer_.data_entry_[eid]; + + auto desc = utils::makePlainTDesc(shape, dtype); + + dnnl::memory mem = {}; + if (dl_tensor) { + eid = TensorRequisite::INVALID_EID; + mem = utils::convert2dnnl(dl_tensor, graph_explorer_.engine_); + ICHECK(mem.get_desc() == desc); + } + + return {desc, nullptr, false, mem, eid, false}; + } + + TensorRequisite getOutput(int idx) { + ICHECK_LT(idx, node_.GetNumOutput()); + + auto shape = node_.GetOpShape()[idx]; + auto dtype = node_.GetOpDataType()[idx]; + auto eid = graph_explorer_.node_row_ptr_[nid_] + static_cast(idx); + auto dl_tensor = graph_explorer_.data_entry_[eid]; + + auto desc = utils::makePlainTDesc(shape, dtype); + + ICHECK(!dl_tensor) << "Output of operation node cannot be constant"; + return {desc, nullptr, true, {}, eid, true}; + } + + TensorRequisite makeTemp(const dnnl::memory::desc& desc, uint32_t eid) { + return {desc, nullptr, false, {}, eid, true}; + } + + TensorRequisite makeScratchpad(const dnnl::memory::desc& desc) { + return {desc, nullptr, false, {}, TensorRequisite::INVALID_EID, true}; + } + + private: + const uint32_t nid_; + const json::JSONGraphNode& node_; + const GraphExplorer& graph_explorer_; +}; + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_DNNL_DNNL_NODE_HELPER_H_ diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index 1735d8569215..a151f9b24ffd 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -88,8 +88,11 @@ class JSONRuntimeBase : public ModuleNode { // The function to initialize constant tensors. return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.size(), 1U); - this->Init(args[0]); - this->initialized_ = true; + std::lock_guard guard(this->initialize_mutex_); + if (!this->initialized_) { + this->Init(args[0]); + this->initialized_ = true; + } *rv = 0; }); } else { @@ -269,6 +272,8 @@ class JSONRuntimeBase : public ModuleNode { std::vector const_idx_; /*! \brief Indicate if the engine has been initialized. */ bool initialized_{false}; + /*! \brief Initializer mutex*/ + std::mutex initialize_mutex_; }; } // namespace json diff --git a/tests/python/contrib/test_dnnl/__init__.py b/tests/python/contrib/test_dnnl/__init__.py new file mode 100644 index 000000000000..0586bb59dba5 --- /dev/null +++ b/tests/python/contrib/test_dnnl/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Infrastructure and tests for DNNL runtime""" diff --git a/tests/python/contrib/test_dnnl/common.py b/tests/python/contrib/test_dnnl/common.py new file mode 100644 index 000000000000..965109cb102c --- /dev/null +++ b/tests/python/contrib/test_dnnl/common.py @@ -0,0 +1,236 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test DNNL integration conv2d tests.""" + +import numpy as np +import pytest +import tvm +from tvm import relay, runtime +from tvm.relay.backend import te_compiler +from tvm.contrib import graph_executor + +import collections +from numbers import Number + +requires_dnnl = pytest.mark.skipif( + tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True) is None, + reason="DNNL codegen is not available", +) + + +def parametrized(arg_name, workloads): + return pytest.mark.parametrize( + arg_name, [w[1:] for w in workloads], ids=[w[0] for w in workloads] + ) + + +def permute(shape, l_from="", l_to=""): + res_shape = [] + for label in l_to: + pos = l_from.find(label) + res_shape.append(shape[pos]) + + return res_shape + + +def expand_dim(shape, rank=0): + assert len(shape) == 1 + return shape + [1] * (rank - 1) + + +def check_fully_annotated(mod, desired_compiler): + matched_ops = [] + other_ops = [] + + def _visit(node): + if isinstance(node, tvm.relay.Call): + op = node.op + if isinstance(op, relay.GlobalVar): + func = mod[op] + if "Compiler" in func.attrs and func.attrs["Compiler"] == desired_compiler: + matched_ops.append(op) + return + else: + other_ops.append(op) + + tvm.relay.analysis.post_order_visit(mod["main"].body, _visit) + + assert len(other_ops) == 0 and len(matched_ops) != 0, "Model is not fully DNNL compiled" + + +def check_result( + mod, + ref_mod, + map_inputs, + tol=1e-5, + target="llvm", + device=tvm.cpu(), + params=None, + ref_result=None, + atol=None, + desired_compiler="dnnl", +): + if atol is None: + atol = tol + + if desired_compiler is not None: + check_fully_annotated(mod, desired_compiler) + + if ref_result is None: + # Run the reference result + te_compiler.get().clear() + with tvm.transform.PassContext(opt_level=3): + ref_lib = relay.build(ref_mod, target=target, params=params) + ref_rt_mod = tvm.contrib.graph_executor.GraphModule(ref_lib["default"](device)) + + for name, data in map_inputs.items(): + ref_rt_mod.set_input(name, data) + ref_rt_mod.run() + out = ref_rt_mod.get_output(0) + ref_result = 
out.numpy() + + def check_vm_result(): + te_compiler.get().clear() + with tvm.transform.PassContext(opt_level=3): + exe = relay.vm.compile(mod, target=target, params=params) + code, lib = exe.save() + exe = runtime.vm.Executable.load_exec(code, lib) + vm = runtime.vm.VirtualMachine(exe, device) + output = vm.run(**map_inputs) + tvm.testing.assert_allclose(output.numpy(), ref_result, rtol=tol, atol=atol) + + def check_graph_executor_result(): + te_compiler.get().clear() + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, params=params) + rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](device)) + + rt_mod.run(**map_inputs) + output = rt_mod.get_output(0) + tvm.testing.assert_allclose(output.numpy(), ref_result, rtol=tol, atol=atol) + + check_vm_result() + check_graph_executor_result() + + +def filler_uni(low=0, high=1): + def filler_func(shape): + return np.random.uniform(low, high, shape) + + return filler_func + + +class Builder: + def __init__(self, qnn_profile=None): + self._args = {} + self._args_op = [] + self._qp = qnn_profile + + def arg(self, shape=[], dtype="float32", filler=filler_uni(), is_const=True): + if isinstance(filler, Number): + value = np.full(shape, filler).astype(dtype) + else: + value = filler(shape).astype(dtype) + + if is_const: + res = relay.const(value, dtype=dtype) + else: + name = f"in_{len(self._args)}" + res = relay.var(name, shape=shape, dtype=dtype) + self._args[name] = value + self._args_op.append(res) + + return res + + def make_zp(self, mean_val, num_ch=1, dispersion=0.2): + if num_ch == 1: + return self.arg(shape=[], dtype="int32", filler=mean_val) + else: + low = int(mean_val * (1 - dispersion)) + high = int(mean_val * (1 + dispersion)) + return self.arg(shape=[num_ch], dtype="int32", filler=filler_uni(low, high)) + + def make_scl(self, mean_val, num_ch=1, dispersion=0.2): + if num_ch == 1: + return self.arg(shape=[], dtype="float32", filler=mean_val) + else: + low = mean_val * (1 - dispersion) + high = mean_val * (1 + dispersion) + return self.arg(shape=[num_ch], dtype="float32", filler=filler_uni(low, high)) + + def make_zp_and_scl(self, name, num_ch=1, dispersion=0.2): + is_per_channel = getattr(self._qp, f"{name}_pc") + zp_val = getattr(self._qp, f"{name}_zp") + scl_val = getattr(self._qp, f"{name}_scl") + + zp = self.make_zp(zp_val, num_ch if is_per_channel else 1, dispersion) + scl = self.make_scl(scl_val, num_ch if is_per_channel else 1, dispersion) + return zp, scl + + def finalize(self, op): + func = relay.Function(self._args_op, op) + mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) + return mod, self._args + + +ConvProfile = collections.namedtuple( + "ConvProfile", + [ + "N", + "IH", + "IW", + "IC", + "OC", + "KH", + "KW", + "SH", + "SW", + "PH", + "PW", + "DH", + "DW", + "GR", + "D_LAYOUT", + "K_LAYOUT", + ], +) + +DenseProfile = collections.namedtuple("DenseProfile", ["N", "IC", "OC"]) + +ArgConstConfig = collections.namedtuple("ArgConstConfig", ["Data", "Weights", "Bias", "Sum"]) + +QuantizationConfig = collections.namedtuple( + "QuantizationConfig", + [ + "d_zp", + "d_scl", + "d_pc", + "k_zp", + "k_scl", + "k_pc", + "rq_zp", + "rq_scl", + "rq_pc", + "sum_zp", + "sum_scl", + "sum_pc", + "o_zp", + "o_scl", + "o_pc", + ], +) diff --git a/tests/python/contrib/test_dnnl/test_binary.py b/tests/python/contrib/test_dnnl/test_binary.py new file mode 100644 index 000000000000..4a26c3f6869b --- /dev/null +++ b/tests/python/contrib/test_dnnl/test_binary.py @@ -0,0 +1,62 @@ 
+# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test DNNL integration conv2d tests.""" + +import numpy as np +import tvm +from tvm import relay +from tvm.relay.op.contrib.dnnl import partition_for_dnnl + +from .common import requires_dnnl, parametrized, check_result, Builder + +import collections + +BinaryShapeConfig = collections.namedtuple("BinaryShapeConfig", ["lhs_shape", "rhs_shape"]) +base_0D = BinaryShapeConfig(lhs_shape=[], rhs_shape=[]) +base_3D = BinaryShapeConfig(lhs_shape=[3, 2, 1], rhs_shape=[3, 2, 1]) +base_4D = BinaryShapeConfig(lhs_shape=[4, 3, 2, 1], rhs_shape=[4, 3, 2, 1]) +base_6D = BinaryShapeConfig(lhs_shape=[2, 3, 4, 3, 2, 1], rhs_shape=[2, 3, 4, 3, 2, 1]) + +scalar_broadcast_6D = BinaryShapeConfig(lhs_shape=[2, 3, 4, 3, 2, 1], rhs_shape=[]) +bias_like_broadcast = BinaryShapeConfig(lhs_shape=[2, 7, 8, 8], rhs_shape=[7, 1, 1]) + +BinaryProfile = [ + ("Add_0D", tvm.relay.op.add, "float32", base_0D), + ("Add_4D", tvm.relay.op.add, "float32", base_4D), + ("Add_7D", tvm.relay.op.add, "float32", base_6D), + ("Add_Broadcast_scalar_4D", tvm.relay.op.add, "float32", scalar_broadcast_6D), + ("Add_BiasLike", tvm.relay.op.add, "float32", bias_like_broadcast), + ("Mul_BiasLike", tvm.relay.op.multiply, "float32", bias_like_broadcast), +] + + +@requires_dnnl +@parametrized("profile", BinaryProfile) +def test_binary(profile): + def generate_model(p, b_op_type, dtype): + np.random.seed(0) + bld = Builder() + + lhs = bld.arg(shape=p.lhs_shape, dtype=dtype, is_const=False) + rhs = bld.arg(shape=p.rhs_shape, dtype=dtype, is_const=False) + op = b_op_type(lhs, rhs) + return bld.finalize(op) + + op_type, dtype, shape_p = profile + ref_mod, args = generate_model(shape_p, op_type, dtype) + mod = partition_for_dnnl(ref_mod) + check_result(mod, ref_mod, args, tol=1e-10, atol=1e-10) diff --git a/tests/python/contrib/test_dnnl/test_conv2d.py b/tests/python/contrib/test_dnnl/test_conv2d.py new file mode 100644 index 000000000000..11c4a02709a6 --- /dev/null +++ b/tests/python/contrib/test_dnnl/test_conv2d.py @@ -0,0 +1,395 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test DNNL integration conv2d tests.""" + +import numpy as np +import tvm +from tvm import relay +from tvm.relay.op.contrib.dnnl import partition_for_dnnl + +from .common import ( + requires_dnnl, + parametrized, + check_result, + Builder, + filler_uni, +) +from .common import ConvProfile, ArgConstConfig, QuantizationConfig +from .common import permute, expand_dim + +acp_regular = ArgConstConfig(Data=False, Weights=True, Bias=True, Sum=None) +acp_no_bias = ArgConstConfig(Data=False, Weights=True, Bias=None, Sum=None) +acp_with_sum = ArgConstConfig(Data=False, Weights=True, Bias=True, Sum=False) +acp_no_bias_with_sum = ArgConstConfig(Data=False, Weights=True, Bias=None, Sum=False) + +# Basic convolution 3x3. More trivial, symmetric +base_conv = ConvProfile( + N=1, + IH=5, + IW=5, + IC=8, + OC=16, + KH=3, + KW=3, + SH=1, + SW=1, + PH=1, + PW=1, + DH=1, + DW=1, + GR=1, + D_LAYOUT="NCHW", + K_LAYOUT="OIHW", +) + +# same as Basic but with NHWC data layout +base_conv_nhwc = ConvProfile( + N=1, + IH=5, + IW=5, + IC=8, + OC=16, + KH=3, + KW=3, + SH=1, + SW=1, + PH=1, + PW=1, + DH=1, + DW=1, + GR=1, + D_LAYOUT="NHWC", + K_LAYOUT="HWIO", +) + +base_conv_no_pad = ConvProfile( + N=1, + IH=5, + IW=5, + IC=8, + OC=16, + KH=3, + KW=3, + SH=1, + SW=1, + PH=0, + PW=0, + DH=1, + DW=1, + GR=1, + D_LAYOUT="NCHW", + K_LAYOUT="OIHW", +) + +base_conv_no_pad_nhwc = ConvProfile( + N=1, + IH=5, + IW=5, + IC=8, + OC=16, + KH=3, + KW=3, + SH=1, + SW=1, + PH=0, + PW=0, + DH=1, + DW=1, + GR=1, + D_LAYOUT="NHWC", + K_LAYOUT="HWIO", +) + +# same as Basic but with groups +base_conv_group_no_pad = ConvProfile( + N=1, + IH=5, + IW=5, + IC=8, + OC=16, + KH=3, + KW=3, + SH=1, + SW=1, + PH=0, + PW=0, + DH=1, + DW=1, + GR=2, + D_LAYOUT="NCHW", + K_LAYOUT="OIHW", +) + +# same as Basic but with group == IC == OC +base_conv_dw_no_pad = ConvProfile( + N=1, + IH=5, + IW=5, + IC=16, + OC=16, + KH=3, + KW=3, + SH=1, + SW=1, + PH=0, + PW=0, + DH=1, + DW=1, + GR=16, + D_LAYOUT="NCHW", + K_LAYOUT="OIHW", +) + +base_conv_dilated = ConvProfile( + N=1, + IH=5, + IW=5, + IC=8, + OC=16, + KH=3, + KW=3, + SH=1, + SW=1, + PH=2, + PW=2, + DH=2, + DW=2, + GR=1, + D_LAYOUT="NCHW", + K_LAYOUT="OIHW", +) + +conv_profiles = [ + ("Base", base_conv, acp_regular), + ("NHWC", base_conv_nhwc, acp_regular), + ("Group", base_conv_group_no_pad, acp_regular), + ("DW", base_conv_dw_no_pad, acp_regular), + ("Dilated", base_conv_dilated, acp_regular), +] + + +@requires_dnnl +@parametrized("profile", conv_profiles) +def test_conv2d(profile): + def generate_model(p, c): + np.random.seed(0) + + d_shape = [p.N, p.IC, p.IH, p.IW] + w_shape = [p.OC, p.IC, p.KH, p.KW] + b_shape = [p.OC] + s_shape = [ + p.N, + p.OC, + (p.IH + 2 * p.PH - (p.KH - 1) * p.DH - 1) // p.SH + 1, + (p.IW + 2 * p.PW - (p.KW - 1) * p.DW - 1) // p.SW + 1, + ] + + if p.GR != 1: + w_shape[1] //= p.GR + + d_shape = permute(d_shape, l_from="NCHW", l_to=p.D_LAYOUT) + s_shape = permute(s_shape, l_from="NCHW", l_to=p.D_LAYOUT) + w_shape = permute(w_shape, l_from="OIHW", l_to=p.K_LAYOUT) + + c_dim = p.D_LAYOUT.find("C") + # b_shape = expand_dim(b_shape, rank=len(p.D_LAYOUT) - c_dim) + + bld = Builder() + + op = bld.arg(shape=d_shape, dtype="float32", is_const=c.Data) + wgh = bld.arg(shape=w_shape, dtype="float32", is_const=c.Weights) + op = tvm.relay.nn.conv2d( + op, + wgh, + kernel_size=[p.KH, p.KW], + padding=[p.PH, p.PW], + strides=[p.SH, p.SW], + dilation=[p.DH, p.DW], + groups=p.GR, + 
channels=p.OC, + out_dtype="float32", + data_layout=p.D_LAYOUT, + kernel_layout=p.K_LAYOUT, + ) + + if c.Bias is not None: + bias = bld.arg(shape=b_shape, dtype="float32", is_const=c.Bias) + op = tvm.relay.nn.bias_add(op, bias, axis=c_dim) + # op = tvm.relay.add(op, bias) + + if c.Sum is not None: + sum_in = bld.arg(shape=s_shape, dtype="float32", is_const=c.Sum) + op = tvm.relay.op.add(op, sum_in) + + return bld.finalize(op) + + conv_p, arg_p = profile + ref_mod, args = generate_model(conv_p, arg_p) + mod = partition_for_dnnl(ref_mod) + + # atol=1 means int values should match with +-1 quantum value tolerance + check_result(mod, ref_mod, args, tol=1e-10, atol=1) + + +# Regular and simple quantization scheme. All tensors are quantized per tensor. +# Data and weights quantized symmetrically (zp == 0). +qp_regular = QuantizationConfig( + d_zp=0, + d_scl=0.2, + d_pc=False, + k_zp=0, + k_scl=0.1, + k_pc=False, + rq_zp=30, + rq_scl=0.2, + rq_pc=False, + sum_zp=15, + sum_scl=0.3, + sum_pc=False, + o_zp=5, + o_scl=0.2, + o_pc=False, +) + +# Like a Regular quantization scheme but with asymmetric data quantization. +qp_asymmetric_data = QuantizationConfig( + d_zp=3, + d_scl=0.2, + d_pc=False, + k_zp=0, + k_scl=0.1, + k_pc=False, + rq_zp=10, + rq_scl=0.1, + rq_pc=False, + sum_zp=5, + sum_scl=0.3, + sum_pc=False, + o_zp=4, + o_scl=0.2, + o_pc=False, +) + +qnn_conv_profiles = [ + # Pattern Conv2d + Requantize + ("Base", base_conv, acp_regular, qp_regular), + ("NHWC", base_conv_nhwc, acp_regular, qp_regular), + # Asymmetric input. NOTE: No pad! Input ZP is not compatible with PAD + ("Group", base_conv_group_no_pad, acp_regular, qp_asymmetric_data), + ("DW", base_conv_dw_no_pad, acp_regular, qp_asymmetric_data), + ("NoBias", base_conv, acp_no_bias, qp_regular), + ("AsymmetricInput", base_conv_no_pad, acp_regular, qp_asymmetric_data), + ("AsymmetricInput_NHWC", base_conv_no_pad_nhwc, acp_regular, qp_asymmetric_data), + # Pattern Conv2d + Requantize + Sum + ("WithSum", base_conv_no_pad, acp_with_sum, qp_asymmetric_data), + ("WithSum_NHWC", base_conv_no_pad_nhwc, acp_with_sum, qp_asymmetric_data), + ("WithSum_NoBias", base_conv_no_pad, acp_no_bias_with_sum, qp_asymmetric_data), +] + + +@requires_dnnl +@parametrized("profile", qnn_conv_profiles) +def test_qnn_conv2d(profile): + def generate_model(p, c, q): + np.random.seed(0) + + d_shape = [p.N, p.IC, p.IH, p.IW] + w_shape = [p.OC, p.IC, p.KH, p.KW] + b_shape = [p.OC] + s_shape = [ + p.N, + p.OC, + (p.IH + 2 * p.PH - (p.KH - 1) * p.DH - 1) // p.SH + 1, + (p.IW + 2 * p.PW - (p.KW - 1) * p.DW - 1) // p.SW + 1, + ] + + if p.GR != 1: + w_shape[1] //= p.GR + + d_shape = permute(d_shape, l_from="NCHW", l_to=p.D_LAYOUT) + s_shape = permute(s_shape, l_from="NCHW", l_to=p.D_LAYOUT) + w_shape = permute(w_shape, l_from="OIHW", l_to=p.K_LAYOUT) + + c_dim = p.D_LAYOUT.find("C") + b_shape = expand_dim(b_shape, rank=len(p.D_LAYOUT) - c_dim) + + bld = Builder(qnn_profile=q) + + # Start build a test graph + data = bld.arg(shape=d_shape, dtype="uint8", is_const=c.Data, filler=filler_uni(0, 20)) + d_zp, d_scl = bld.make_zp_and_scl("d", p.IC) + + # Convolution + wgh = bld.arg(shape=w_shape, dtype="int8", is_const=c.Weights, filler=filler_uni(-20, 20)) + w_zp, w_scl = bld.make_zp_and_scl("k") + + op = tvm.relay.qnn.op.conv2d( + data, + wgh, + d_zp, + w_zp, + d_scl, + w_scl, + kernel_size=[p.KH, p.KW], + padding=[p.PH, p.PW], + strides=[p.SH, p.SW], + dilation=[p.DH, p.DW], + groups=p.GR, + channels=p.OC, + out_dtype="int32", + data_layout=p.D_LAYOUT, + kernel_layout=p.K_LAYOUT, 
+ ) + # Optional bias + if c.Bias is not None: + bias = bld.arg( + shape=b_shape, dtype="int32", is_const=c.Bias, filler=filler_uni(-50, 50) + ) + op = tvm.relay.add(op, bias) + + # Re-quantization + rq_in_zp = bld.make_zp(0) + rq_in_scl = bld.make_scl(q.d_scl * q.k_scl) # in real cases that should be a vector + rq_out_zp, rq_out_scl = bld.make_zp_and_scl("rq") + + op = tvm.relay.qnn.op.requantize( + op, rq_in_scl, rq_in_zp, rq_out_scl, rq_out_zp, out_dtype="int32" + ) + op = tvm.relay.clip( + op, a_min=0.0, a_max=255.0 + ) # pytorch frontend specific, I guess it's redundant + op = tvm.relay.cast(op, dtype="uint8") + + # Optional sum (ResNet like) + if c.Sum is not None: + sum_in = bld.arg(dtype="uint8", shape=s_shape, filler=filler_uni(0, 10), is_const=c.Sum) + + lhs_zp, lhs_scl = bld.make_zp_and_scl("rq") + rhs_zp, rhs_scl = bld.make_zp_and_scl("sum") + out_zp, out_scl = bld.make_zp_and_scl("o") + + op = tvm.relay.qnn.op.add(op, sum_in, lhs_scl, lhs_zp, rhs_scl, rhs_zp, out_scl, out_zp) + op = tvm.relay.clip(op, a_min=0.0, a_max=255.0) + + return bld.finalize(op) + + conv_p, arg_p, quant_p = profile + ref_mod, args = generate_model(conv_p, arg_p, quant_p) + mod = partition_for_dnnl(ref_mod) + + # atol=1 means int values should match with +-1 quantum value tolerance + check_result(mod, ref_mod, args, tol=1e-10, atol=1) diff --git a/tests/python/contrib/test_dnnl/test_dense.py b/tests/python/contrib/test_dnnl/test_dense.py new file mode 100644 index 000000000000..f74a0f7b5399 --- /dev/null +++ b/tests/python/contrib/test_dnnl/test_dense.py @@ -0,0 +1,182 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Test DNNL integration dense tests.""" + +import numpy as np +import tvm +from tvm import relay +from tvm.relay.op.contrib.dnnl import partition_for_dnnl, get_dnnl_version + +from .common import requires_dnnl, parametrized, check_result, Builder, filler_uni +from .common import DenseProfile, ArgConstConfig, QuantizationConfig + +base_dense_profile = DenseProfile(N=2, IC=10, OC=16) +regular_const_arg_prof = ArgConstConfig(Data=False, Weights=True, Bias=True, Sum=None) +cp_with_sum = ArgConstConfig(Data=False, Weights=True, Bias=True, Sum=False) + +dense_profiles = [ + ("Base", base_dense_profile, ArgConstConfig(Data=False, Weights=True, Bias=True, Sum=None)), + ("WithSum", base_dense_profile, ArgConstConfig(Data=False, Weights=True, Bias=True, Sum=False)), +] + + +@requires_dnnl +@parametrized("profile", dense_profiles) +def test_dense(profile): + def generate_model(p, c): + np.random.seed(0) + + d_shape = [p.N, p.IC] + w_shape = [p.OC, p.IC] + b_shape = [p.OC] + s_shape = [p.N, p.OC] + + c_dim = 1 + + bld = Builder() + + op = bld.arg(shape=d_shape, dtype="float32", is_const=c.Data) + wgh = bld.arg(shape=w_shape, dtype="float32", is_const=c.Weights) + op = tvm.relay.nn.dense(op, wgh, out_dtype="float32") + + if c.Bias is not None: + bias = bld.arg(shape=b_shape, dtype="float32", is_const=c.Bias) + op = tvm.relay.nn.bias_add(op, bias, axis=c_dim) + + if c.Sum is not None: + sum_in = bld.arg(shape=s_shape, dtype="float32", is_const=c.Sum) + op = tvm.relay.op.add(op, sum_in) + + return bld.finalize(op) + + dense_p, arg_p = profile + ref_mod, args = generate_model(dense_p, arg_p) + mod = partition_for_dnnl(ref_mod) + check_result(mod, ref_mod, args, tol=1e-10, atol=1) + + +qp_regular = QuantizationConfig( + d_zp=0, + d_scl=0.2, + d_pc=False, + k_zp=0, + k_scl=0.1, + k_pc=False, + rq_zp=30, + rq_scl=0.2, + rq_pc=False, # asymmetric + sum_zp=15, + sum_scl=0.3, + sum_pc=False, # asymmetric + o_zp=5, + o_scl=0.2, + o_pc=False, # asymmetric +) +qp_asymmetric_all = QuantizationConfig( + d_zp=3, + d_scl=0.2, + d_pc=False, # asymmetric + k_zp=0, + k_scl=0.1, + k_pc=False, + rq_zp=10, + rq_scl=0.1, + rq_pc=False, # asymmetric + sum_zp=5, + sum_scl=0.3, + sum_pc=False, # asymmetric + o_zp=4, + o_scl=0.2, + o_pc=False, # asymmetric +) + +qnn_dense_profiles = [ + # Pattern Dense + Requantize + ("Base", base_dense_profile, regular_const_arg_prof, qp_regular), + ("AsymmetricInput", base_dense_profile, regular_const_arg_prof, qp_asymmetric_all), + # Pattern Dense + Requantize + Sum + ("AsymmetricInput_Sum", base_dense_profile, cp_with_sum, qp_asymmetric_all), +] + + +@requires_dnnl +@parametrized("profile", qnn_dense_profiles) +def test_qnn_dense(profile): + def generate_model(p, c, q): + np.random.seed(0) + + d_shape = [p.N, p.IC] + w_shape = [p.OC, p.IC] + b_shape = [p.OC] + s_shape = [p.N, p.OC] + + bld = Builder(qnn_profile=q) + + # Start build a test graph + data = bld.arg(shape=d_shape, dtype="uint8", is_const=c.Data, filler=filler_uni(0, 20)) + d_zp, d_scl = bld.make_zp_and_scl("d", p.IC) + + # Convolution + wgh = bld.arg(shape=w_shape, dtype="int8", is_const=c.Weights, filler=filler_uni(-20, 20)) + w_zp, w_scl = bld.make_zp_and_scl("k") + + op = tvm.relay.qnn.op.dense( + data, wgh, d_zp, w_zp, d_scl, w_scl, units=p.OC, out_dtype="int32" + ) + # Optional bias + if c.Bias is not None: + bias = bld.arg( + shape=b_shape, dtype="int32", is_const=c.Bias, filler=filler_uni(-50, 50) + ) + op = tvm.relay.add(op, bias) + + # Re-quantization + rq_in_zp = bld.make_zp(0) + rq_in_scl = bld.make_scl(q.d_scl * 
q.k_scl)  # in real cases that should be a vector
+        rq_out_zp, rq_out_scl = bld.make_zp_and_scl("rq")
+
+        op = tvm.relay.qnn.op.requantize(
+            op, rq_in_scl, rq_in_zp, rq_out_scl, rq_out_zp, out_dtype="int32"
+        )
+        op = tvm.relay.clip(
+            op, a_min=0.0, a_max=255.0
+        )  # pytorch frontend specific, I guess it's redundant
+        op = tvm.relay.cast(op, dtype="uint8")
+
+        # Optional sum (ResNet like)
+        if c.Sum is not None:
+            sum_in = bld.arg(dtype="uint8", shape=s_shape, filler=filler_uni(0, 10), is_const=c.Sum)
+
+            lhs_zp, lhs_scl = bld.make_zp_and_scl("rq")
+            rhs_zp, rhs_scl = bld.make_zp_and_scl("sum")
+            out_zp, out_scl = bld.make_zp_and_scl("o")
+
+            op = tvm.relay.qnn.op.add(op, sum_in, lhs_scl, lhs_zp, rhs_scl, rhs_zp, out_scl, out_zp)
+            op = tvm.relay.clip(op, a_min=0.0, a_max=255.0)
+
+        return bld.finalize(op)
+
+    dense_p, arg_p, quant_p = profile
+    ref_mod, args = generate_model(dense_p, arg_p, quant_p)
+    mod = partition_for_dnnl(ref_mod)
+
+    # Workaround: old DNNL versions don't support the dense+sum int8 pattern (dst_zp and per-channel o_scale).
+    # desired_compiler == None skips verification of the full DNNL offload.
+    desired_compiler = None if arg_p.Sum is not None and get_dnnl_version() < (2, 2) else "dnnl"
+
+    # atol=1 means int values should match with +-1 quantum value tolerance
+    check_result(mod, ref_mod, args, tol=1e-10, atol=1, desired_compiler=desired_compiler)
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl/test_other.py
similarity index 100%
rename from tests/python/contrib/test_dnnl.py
rename to tests/python/contrib/test_dnnl/test_other.py
diff --git a/tests/python/relay/test_json_runtime.py b/tests/python/relay/test_json_runtime.py
index c6eb7531f635..f1131ba64100 100644
--- a/tests/python/relay/test_json_runtime.py
+++ b/tests/python/relay/test_json_runtime.py
@@ -15,8 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
"""Unit tests for JSON codegen and runtime.""" -import os import sys +import pytest import numpy as np @@ -24,13 +24,18 @@ import tvm.relay.op as reg import tvm.relay.testing from tvm import relay, runtime -from tvm.contrib import utils from tvm.relay import transform from tvm.relay.backend import te_compiler from tvm.relay.build_module import bind_params_by_name from tvm.relay.op.contrib.register import get_pattern_table +requires_dnnl = pytest.mark.skipif( + tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True) is None, + reason="DNNL codegen is not available", +) + + def set_func_attr(func, compile_name, symbol_name): func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) @@ -43,26 +48,24 @@ def check_result( mod, ref_mod, map_inputs, out_shape, tol=1e-5, target="llvm", device=tvm.cpu(), params=None ): if sys.platform == "win32": - print("Skip test on Windows for now") - return + pytest.skip("Skip DNNL test on Windows for now") # Run the reference result te_compiler.get().clear() with tvm.transform.PassContext(opt_level=3): - json, lib, param = relay.build(ref_mod, target=target, params=params) - rt_mod = tvm.contrib.graph_executor.create(json, lib, device) + ref_lib = relay.build(ref_mod, target=target, params=params) + ref_rt_mod = tvm.contrib.graph_executor.GraphModule(ref_lib["default"](device)) for name, data in map_inputs.items(): - rt_mod.set_input(name, data) - rt_mod.set_input(**param) - rt_mod.run() - out = tvm.nd.empty(out_shape, device=device) - out = rt_mod.get_output(0, out) - ref_result = out.numpy() + ref_rt_mod.set_input(name, data) + ref_rt_mod.run() + ref_out = tvm.nd.empty(out_shape, device=device) + ref_out = ref_rt_mod.get_output(0, ref_out) + ref_result = ref_out.numpy() def check_vm_result(): te_compiler.get().clear() - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): exe = relay.vm.compile(mod, target=target, params=params) code, lib = exe.save() exe = runtime.vm.Executable.load_exec(code, lib) @@ -72,13 +75,12 @@ def check_vm_result(): def check_graph_executor_result(): te_compiler.get().clear() - with relay.build_config(opt_level=3): - json, lib, param = relay.build(mod, target=target, params=params) - rt_mod = tvm.contrib.graph_executor.create(json, lib, device) + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, params=params) + rt_mod = tvm.contrib.graph_executor.GraphModule(ref_lib["default"](device)) for name, data in map_inputs.items(): rt_mod.set_input(name, data) - rt_mod.set_input(**param) rt_mod.run() out = tvm.nd.empty(out_shape, device=device) out = rt_mod.get_output(0, out) @@ -88,11 +90,10 @@ def check_graph_executor_result(): check_graph_executor_result() +@requires_dnnl def test_conv2d(): """Test a subgraph with a single conv2d operator.""" - if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True): - print("skip because DNNL codegen is not available") - return + np.random.seed(0) def conv2d_direct(): dtype = "float32" @@ -172,12 +173,10 @@ def group_conv2d(): check_result(mod, ref_mod, map_inputs, out_shape, tol=1e-5) +@requires_dnnl def test_add(): """Test a subgraph with a single add operator.""" - if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True): - print("skip because DNNL codegen is not available") - return - + np.random.seed(0) dtype = "float32" shape = (10, 10) @@ -216,12 +215,10 @@ def gen_add(): check_result(mod, ref_mod, {"data0": data0, "data1": data1}, shape, 
tol=1e-5) +@requires_dnnl def test_multiply(): """Test a subgraph with a single add operator.""" - if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True): - print("skip because DNNL codegen is not available") - return - + np.random.seed(0) dtype = "float32" shape = (10, 10) @@ -260,14 +257,11 @@ def gen_multiply(): check_result(mod, ref_mod, {"data0": data0, "data1": data1}, shape, tol=1e-5) +@requires_dnnl def test_relu(): """Test a subgraph with a single ReLU operator.""" - if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True): - print("skip because DNNL codegen is not available") - return - + np.random.seed(0) dtype = "float32" - shape = (1, 32, 14, 14) def gen_relu(shape): data0 = relay.var("data0", shape=shape, dtype=dtype) @@ -312,11 +306,10 @@ def check(shape): check(shape=(1, 32)) +@requires_dnnl def test_dense(): """Test a subgraph with a single dense operator.""" - if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True): - print("skip because DNNL codegen is not available") - return + np.random.seed(0) dtype = "float32" a_shape = (1, 512) @@ -357,11 +350,10 @@ def gen_dense(): check_result(mod, ref_mod, {"A": data_a, "B": data_b}, (1, 1024), tol=1e-5) +@requires_dnnl def test_bn(): """Test a subgraph with a single batch_norm operator.""" - if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True): - print("skip because DNNL codegen is not available") - return + np.random.seed(0) dtype = "float32" d_shape = (1, 8) @@ -431,12 +423,10 @@ def gen_bn(): ) +@requires_dnnl def test_multiple_ops(): """Test a subgraph with multiple operators.""" - if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True): - print("skip because DNNL codegen is not available") - return - + np.random.seed(0) dtype = "float32" ishape = (1, 32, 14, 14) w1shape = (32, 32, 3, 3) @@ -497,12 +487,10 @@ def get_partitoned_mod(mod): ) +@requires_dnnl def test_composite(): """Test DNNL patterns and there composite functions.""" - if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True): - print("skip because DNNL codegen is not available") - return - + np.random.seed(0) dtype = "float32" def conv2d_relu(): @@ -609,12 +597,10 @@ def conv2d_bias_relu(): check_result(mod, ref_mod, input_maps, out_shape, tol=1e-5) +@requires_dnnl def test_constant(): """Test the subgraph with (var, const, ...) arguments.""" - if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True): - print("skip because DNNL codegen is not available") - return - + np.random.seed(0) dtype = "float32" ishape = (1, 32, 14, 14) wshape = (32, 32, 3, 3) @@ -661,12 +647,10 @@ def test_constant(): check_result(mod, ref_mod, {"data": i_data}, (1, 32, 14, 14), tol=1e-5) +@requires_dnnl def test_partial_constant(): """Test the subgraph with (const, var, const, var) arguments.""" - if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True): - print("skip because DNNL codegen is not available") - return - + np.random.seed(0) dtype = "float32" ishape = (10, 10)