From 0a1e04de6bdd32fbf01a10aad553ca2b4a768409 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 15 Aug 2022 12:33:04 -0700 Subject: [PATCH] [TVMScript] New Parser: Part B (#217) This PR is organized as follows: - The basic infrastructure: `tvm/script/ir_builder/base.{py/h/cc} ` - IRBuilder for package `tvm.ir`: `tvm/script/ir_builder/ir/*.{py/h/cc}` - IRBuilder for package `tvm.tir`: `tvm/script/ir_builder/tir/*.{py/h/cc}` --- include/tvm/script/ir_builder/base.h | 166 ++++ include/tvm/script/ir_builder/ir/frame.h | 61 ++ include/tvm/script/ir_builder/ir/ir.h | 39 + include/tvm/script/ir_builder/tir/frame.h | 459 +++++++++ include/tvm/script/ir_builder/tir/ir.h | 142 +++ python/tvm/script/__init__.py | 6 +- python/tvm/script/ir_builder/__init__.py | 19 + python/tvm/script/ir_builder/_ffi_api.py | 20 + python/tvm/script/ir_builder/base.py | 76 ++ python/tvm/script/ir_builder/ir/__init__.py | 19 + python/tvm/script/ir_builder/ir/_ffi_api.py | 20 + python/tvm/script/ir_builder/ir/frame.py | 26 + python/tvm/script/ir_builder/ir/ir.py | 24 + python/tvm/script/ir_builder/tir/__init__.py | 20 + python/tvm/script/ir_builder/tir/_ffi_api.py | 20 + python/tvm/script/ir_builder/tir/frame.py | 116 +++ python/tvm/script/ir_builder/tir/ir.py | 954 +++++++++++++++++++ src/script/ir_builder/base.cc | 115 +++ src/script/ir_builder/ir/frame.cc | 43 + src/script/ir_builder/ir/ir.cc | 38 + src/script/ir_builder/tir/frame.cc | 210 ++++ src/script/ir_builder/tir/ir.cc | 665 +++++++++++++ src/script/ir_builder/tir/utils.h | 95 ++ 23 files changed, 3349 insertions(+), 4 deletions(-) create mode 100644 include/tvm/script/ir_builder/base.h create mode 100644 include/tvm/script/ir_builder/ir/frame.h create mode 100644 include/tvm/script/ir_builder/ir/ir.h create mode 100644 include/tvm/script/ir_builder/tir/frame.h create mode 100644 include/tvm/script/ir_builder/tir/ir.h create mode 100644 python/tvm/script/ir_builder/__init__.py create mode 100644 python/tvm/script/ir_builder/_ffi_api.py create mode 100644 python/tvm/script/ir_builder/base.py create mode 100644 python/tvm/script/ir_builder/ir/__init__.py create mode 100644 python/tvm/script/ir_builder/ir/_ffi_api.py create mode 100644 python/tvm/script/ir_builder/ir/frame.py create mode 100644 python/tvm/script/ir_builder/ir/ir.py create mode 100644 python/tvm/script/ir_builder/tir/__init__.py create mode 100644 python/tvm/script/ir_builder/tir/_ffi_api.py create mode 100644 python/tvm/script/ir_builder/tir/frame.py create mode 100644 python/tvm/script/ir_builder/tir/ir.py create mode 100644 src/script/ir_builder/base.cc create mode 100644 src/script/ir_builder/ir/frame.cc create mode 100644 src/script/ir_builder/ir/ir.cc create mode 100644 src/script/ir_builder/tir/frame.cc create mode 100644 src/script/ir_builder/tir/ir.cc create mode 100644 src/script/ir_builder/tir/utils.h diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h new file mode 100644 index 0000000000..179cca42df --- /dev/null +++ b/include/tvm/script/ir_builder/base.h @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_BASE_H_ +#define TVM_SCRIPT_IR_BUILDER_BASE_H_ + +#include +#include +#include + +#include + +namespace tvm { +namespace script { +namespace ir_builder { + +////////////////////////////// Core Infra: Frame ////////////////////////////// + +class IRBuilderFrameNode : public runtime::Object { + public: + std::vector> callbacks; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `callbacks` is not visited. + } + + static constexpr const char* _type_key = "script.ir_builder.IRBuilderFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(IRBuilderFrameNode, runtime::Object); + + public: + virtual ~IRBuilderFrameNode() = default; + virtual void EnterWithScope(); + virtual void ExitWithScope(); + + void AddCallback(runtime::TypedPackedFunc callback); +}; + +class IRBuilderFrame : public runtime::ObjectRef { + public: + virtual ~IRBuilderFrame() = default; + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilderFrame, ObjectRef, IRBuilderFrameNode); + + protected: + IRBuilderFrame() = default; + + public: + inline void EnterWithScope(); + inline void ExitWithScope(); +}; + +////////////////////////////// Core Infra: Builder ////////////////////////////// +/// +class IRBuilderNode : public runtime::Object { + public: + runtime::Array frames; + Optional result; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("frames", &frames); + v->Visit("result", &result); + } + + static constexpr const char* _type_key = "script.ir_builder.IRBuilder"; + TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, runtime::Object); + + public: + template + inline Optional FindFrame() const; + template + inline Optional GetLastFrame() const; + template + inline TObjectRef Get() const; +}; + +class IRBuilder : public runtime::ObjectRef { + public: + IRBuilder(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilder, ObjectRef, IRBuilderNode); + + public: + void EnterWithScope(); + void ExitWithScope(); + static IRBuilder Current(); + template + inline static TObjectRef Name(String name, TObjectRef obj); +}; + +////////////////////////////// Details ////////////////////////////// + +namespace details { + +class Namer { + public: + using FType = NodeFunctor; + static FType& vtable(); + static void Name(ObjectRef node, String name); +}; + +} // namespace details + +template +inline TObjectRef IRBuilder::Name(String name, TObjectRef obj) { + details::Namer::Name(obj, name); + return Downcast(obj); +} + +inline void IRBuilderFrame::EnterWithScope() { + ICHECK(data_ != nullptr); + static_cast(data_.get())->EnterWithScope(); +} + +inline void IRBuilderFrame::ExitWithScope() { + ICHECK(data_ != nullptr); + static_cast(data_.get())->ExitWithScope(); + data_.reset(); +} + +template +inline Optional IRBuilderNode::FindFrame() const { + using TFrameNode = typename TFrame::ContainerType; + for (auto it = frames.rbegin(); it != frames.rend(); ++it) { + if (const TFrameNode* p = (*it).template as()) { + return GetRef(p); + } + } + return NullOpt; +} + +template +inline Optional IRBuilderNode::GetLastFrame() const { + using TFrameNode = typename TFrame::ContainerType; + if (!frames.empty() && frames.back()->IsInstance()) { + return Downcast(frames.back()); + } + return NullOpt; +} + +template +inline TObjectRef IRBuilderNode::Get() const { + using TObject = typename TObjectRef::ContainerType; + CHECK(result.defined()) << "IndexError: No result exists in IRBuilder yet"; + const auto* n = result.as(); + CHECK(n != nullptr) << "IndexError: IRBuilder result is not of type: " << TObject::_type_key; + return GetRef(n); +} + +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_BASE_H_ diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h new file mode 100644 index 0000000000..9a8791be7c --- /dev/null +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_IR_FRAME_H_ +#define TVM_SCRIPT_IR_BUILDER_IR_FRAME_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace script { +namespace ir_builder { + +class IRModuleFrameNode : public IRBuilderFrameNode { + public: + Array global_vars; + Array functions; + + void VisitAttrs(tvm::AttrVisitor* v) { + IRBuilderFrameNode::VisitAttrs(v); + v->Visit("global_vars", &global_vars); + v->Visit("functions", &functions); + } + + static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleFrameNode, IRBuilderFrameNode); + + public: + void ExitWithScope() final; +}; + +class IRModuleFrame : public IRBuilderFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, IRBuilderFrame, + IRModuleFrameNode); +}; + +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_IR_FRAME_H_ diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h new file mode 100644 index 0000000000..b58e51a945 --- /dev/null +++ b/include/tvm/script/ir_builder/ir/ir.h @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_IR_IR_H_ +#define TVM_SCRIPT_IR_BUILDER_IR_IR_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace script { +namespace ir_builder { + +TVM_DLL IRModuleFrame IRModule(); + +} +} // namespace script +} // namespace tvm + +#endif // TVM_IR_IR_BUILDER_IR_IR_H_ diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h new file mode 100644 index 0000000000..d2d2485bbe --- /dev/null +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -0,0 +1,459 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_ +#define TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_ + +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace tir { + +class TIRFrameNode : public IRBuilderFrameNode { + public: + Array stmts; + + void VisitAttrs(tvm::AttrVisitor* v) { + IRBuilderFrameNode::VisitAttrs(v); + v->Visit("stmts", &stmts); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.TIRFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, IRBuilderFrameNode); +}; + +class TIRFrame : public IRBuilderFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, IRBuilderFrame, TIRFrameNode); + + protected: + TIRFrame() = default; +}; + +class BlockFrameNode : public TIRFrameNode { + public: + String name; + Array iter_vars; + Optional> reads; + Optional> writes; + Optional init; + Array alloc_buffers; + Array match_buffers; + Optional> annotations; + + Array iter_values; + Optional predicate; + bool no_realize; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("name", &name); + v->Visit("iter_vars", &iter_vars); + v->Visit("reads", &reads); + v->Visit("writes", &writes); + v->Visit("init", &init); + v->Visit("alloc_buffers", &alloc_buffers); + v->Visit("match_buffers", &match_buffers); + v->Visit("annotations", &annotations); + v->Visit("iter_values", &iter_values); + v->Visit("predicate", &predicate); + v->Visit("no_realize", &no_realize); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.BlockFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class BlockFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode); +}; + +class BlockInitFrameNode : public TIRFrameNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) { TIRFrameNode::VisitAttrs(v); } + + static constexpr const char* _type_key = "script.ir_builder.tir.BlockInitFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(BlockInitFrameNode, TIRFrameNode); + + public: + void EnterWithScope() final; + void ExitWithScope() final; +}; + +class BlockInitFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockInitFrame, TIRFrame, BlockInitFrameNode); +}; + +class ForFrameNode : public TIRFrameNode { + public: + using FMakeForLoop = + runtime::TypedPackedFunc, Array, tvm::tir::Stmt)>; + + Array vars; + Array doms; + FMakeForLoop f_make_for_loop; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("vars", &vars); + v->Visit("doms", &doms); + // `f_make_for_loop` is not visited. + } + + static constexpr const char* _type_key = "script.ir_builder.tir.ForFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ForFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class ForFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode); +}; + +class PrimFuncFrameNode : public TIRFrameNode { + public: + Optional name; + Array args; + Optional ret_type; + Map buffer_map; + Map preflattened_buffer_map; + Optional> attrs; + Map env_threads; + Array root_alloc_buffers; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("name", &name); + v->Visit("args", &args); + v->Visit("ret_type", &ret_type); + v->Visit("buffer_map", &buffer_map); + v->Visit("preflattened_buffer_map", &preflattened_buffer_map); + v->Visit("attrs", &attrs); + v->Visit("env_threads", &env_threads); + v->Visit("root_alloc_buffers", &root_alloc_buffers); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.PrimFuncFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class PrimFuncFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode); +}; + +class AssertFrameNode : public TIRFrameNode { + public: + PrimExpr condition; + PrimExpr message; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("condition", &condition); + v->Visit("message", &message); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.AssertFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(AssertFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class AssertFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode); +}; + +class LetFrameNode : public TIRFrameNode { + public: + tvm::tir::Var var; + PrimExpr value; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("var", &var); + v->Visit("value", &value); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.LetFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(LetFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class LetFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode); +}; + +class AllocateFrameNode : public TIRFrameNode { + public: + Array extents; + DataType dtype; + String storage_scope; + PrimExpr condition; + Map annotations; + tvm::tir::Buffer buffer; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("extents", &extents); + v->Visit("dtype", &dtype); + v->Visit("storage_scope", &storage_scope); + v->Visit("condition", &condition); + v->Visit("annotations", &annotations); + v->Visit("buffer", &buffer); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.AllocateFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(AllocateFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class AllocateFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode); +}; + +class AllocateConstFrameNode : public TIRFrameNode { + public: + DataType dtype; + Array extents; + tvm::runtime::NDArray data; + tvm::tir::Buffer buffer; + Map annotations; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("dtype", &dtype); + v->Visit("extents", &extents); + v->Visit("data", &data); + v->Visit("buffer", &buffer); + v->Visit("annotations", &annotations); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.AllocateConstFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class AllocateConstFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame, + AllocateConstFrameNode); +}; + +class LaunchThreadFrameNode : public TIRFrameNode { + public: + PrimExpr extent; + String attr_key; + tvm::tir::IterVar iter_var; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("extent", &extent); + v->Visit("attr_key", &attr_key); + v->Visit("iter_var", &iter_var); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.LaunchThreadFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(LaunchThreadFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class LaunchThreadFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame, + LaunchThreadFrameNode); +}; + +class RealizeFrameNode : public TIRFrameNode { + public: + tvm::tir::BufferRegion buffer_slice; + String storage_scope; + PrimExpr condition; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("buffer_slice", &buffer_slice); + v->Visit("storage_scope", &storage_scope); + v->Visit("condition", &condition); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.RealizeFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(RealizeFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class RealizeFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode); +}; + +class AttrFrameNode : public TIRFrameNode { + public: + ObjectRef node; + String attr_key; + PrimExpr value; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("node", &node); + v->Visit("attr_key", &attr_key); + v->Visit("value", &value); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.AttrFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttrFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class AttrFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode); +}; + +class WhileFrameNode : public TIRFrameNode { + public: + PrimExpr condition; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("condition", &condition); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.WhileFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(WhileFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class WhileFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode); +}; + +class IfFrameNode : public TIRFrameNode { + public: + PrimExpr condition; + Optional> then_stmts; + Optional> else_stmts; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("condition", &condition); + v->Visit("then_stmts", &then_stmts); + v->Visit("else_stmts", &else_stmts); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.IfFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class IfFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode); +}; + +class ThenFrameNode : public TIRFrameNode { + public: + static constexpr const char* _type_key = "script.ir_builder.tir.ThenFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, TIRFrameNode); + + public: + void EnterWithScope() final; + void ExitWithScope() final; +}; + +class ThenFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode); +}; + +class ElseFrameNode : public TIRFrameNode { + public: + static constexpr const char* _type_key = "script.ir_builder.tir.ElseFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, TIRFrameNode); + + public: + void EnterWithScope() final; + void ExitWithScope() final; +}; + +class ElseFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode); +}; + +class DeclBufferFrameNode : public TIRFrameNode { + public: + tvm::tir::Buffer buffer; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("buffer", &buffer); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.DeclBufferFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class DeclBufferFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DeclBufferFrame, TIRFrame, DeclBufferFrameNode); +}; + +} // namespace tir +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_ diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h new file mode 100644 index 0000000000..c26d552737 --- /dev/null +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_TIR_IR_H_ +#define TVM_SCRIPT_IR_BUILDER_TIR_IR_H_ + +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace tir { + +using tvm::runtime::NDArray; +using tvm::tir::Buffer; +using tvm::tir::Var; + +Buffer BufferDecl(Array shape, DataType dtype, String buffer_name, Optional data, + Optional> strides, Optional elem_offset, + String storage_scope, int align, int offset_factor, String buffer_type, + Optional> axis_separators); +PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global"); + +BlockFrame Block(String name, bool no_realize = false); +BlockInitFrame Init(); +void Where(PrimExpr predicate); +void Reads(Array buffer_slices); +void Writes(Array buffer_slices); +void BlockAttrs(Map attrs); +Buffer AllocBuffer(Array shape, DataType dtype = DataType::Float(32), + Optional data = NullOpt, Array strides = {}, + PrimExpr elem_offset = PrimExpr(), String storage_scope = "", int align = -1, + int offset_factor = 0, String buffer_type = "default", + Array axis_separators = {}); + +namespace axis { +Var Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); +Var Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); +Var Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); +Var Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); +Array Remap(String kinds, Array bindings, DataType dtype = DataType::Int(32)); +} // namespace axis + +ForFrame Serial(PrimExpr start, PrimExpr stop, + Optional> annotations = NullOpt); +ForFrame Parallel(PrimExpr start, PrimExpr stop, + Optional> annotations = NullOpt); +ForFrame Vectorized(PrimExpr start, PrimExpr stop, + Optional> annotations = NullOpt); +ForFrame Unroll(PrimExpr start, PrimExpr stop, + Optional> annotations = NullOpt); +ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, + Optional> annotations = NullOpt); +ForFrame Grid(Array extents); + +PrimFuncFrame PrimFunc(); +Var Arg(String name, Var var); +Buffer Arg(String name, Buffer buffer); +void FuncName(String name); +void FuncAttrs(Map attrs); +Type FuncRet(Type ret_type); +Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype = DataType::Float(32), + Optional data = NullOpt, Array strides = {}, + PrimExpr elem_offset = PrimExpr(), String storage_scope = "global", + int align = -1, int offset_factor = 0, String buffer_type = "default", + Array axis_separators = {}); +void PreflattenedBuffer(Buffer postflattened_buffer, Array shape, + DataType dtype = DataType::Float(32), Optional data = NullOpt, + Array strides = {}, PrimExpr elem_offset = PrimExpr(), + String storage_scope = "global", int align = -1, int offset_factor = 0, + String buffer_type = "default", Array axis_separators = {}); + +AssertFrame Assert(PrimExpr condition, String message); +LetFrame Let(Var var, PrimExpr value); +AllocateFrame Allocate(Array extents, DataType dtype, String storage_scope = "", + Optional condition = NullOpt, + Optional> annotations = NullOpt); +AllocateConstFrame AllocateConst( + NDArray data, DataType dtype, Array extents, + Map annotations = NullValue>()); +RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition); +AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value); +WhileFrame While(PrimExpr condition); +IfFrame If(PrimExpr condition); +ThenFrame Then(); +ElseFrame Else(); +LaunchThreadFrame LaunchThread(Var var, PrimExpr extent); +Var EnvThread(String thread_tag); +void BufferStore(Buffer buffer, PrimExpr value, Array indices); +void Prefetch(Buffer buffer, Array bounds); +void Evaluate(PrimExpr value); + +#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ + inline PrimExpr FuncName(Optional expr = NullOpt) { \ + DataType dtype = DType; \ + return expr.defined() ? tvm::cast(dtype, expr.value()) : tvm::tir::Var("", dtype); \ + } + +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int8, DataType::Int(8)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int16, DataType::Int(16)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32, DataType::Int(32)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int64, DataType::Int(64)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt8, DataType::UInt(8)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt16, DataType::UInt(16)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt32, DataType::UInt(32)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt64, DataType::UInt(64)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float8, DataType::Float(8)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float16, DataType::Float(16)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float32, DataType::Float(32)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float64, DataType::Float(64)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x4, DataType::Int(32, 4)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x8, DataType::Int(32, 8)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x16, DataType::Int(32, 16)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool()); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle()); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void()); + +#undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST + +} // namespace tir +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_TIR_IR_BUILDER_H_ diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py index 62279b46c1..309a8d2741 100644 --- a/python/tvm/script/__init__.py +++ b/python/tvm/script/__init__.py @@ -16,7 +16,5 @@ # under the License. """TVM Script APIs of TVM Python Package, aimed to support TIR""" -from . import tir -from . import relax - -from .parser import ir_module, from_source +from . import ir_builder, relax, tir +from .parser import from_source, ir_module diff --git a/python/tvm/script/ir_builder/__init__.py b/python/tvm/script/ir_builder/__init__.py new file mode 100644 index 0000000000..237e19e5d7 --- /dev/null +++ b/python/tvm/script/ir_builder/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.script.ir_builder is a generic IR builder for TVM.""" +from . import tir +from .ir import ir_module diff --git a/python/tvm/script/ir_builder/_ffi_api.py b/python/tvm/script/ir_builder/_ffi_api.py new file mode 100644 index 0000000000..68811c9e01 --- /dev/null +++ b/python/tvm/script/ir_builder/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.script.ir_builder""" +import tvm._ffi + +tvm._ffi._init_api("script.ir_builder", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py new file mode 100644 index 0000000000..d8b965d03b --- /dev/null +++ b/python/tvm/script/ir_builder/base.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +"""A generic IRBuilder across the TVM stack""" +from typing import List, TypeVar + +from tvm._ffi import register_object as _register_object +from tvm.runtime import Object as _Object + +from . import _ffi_api + + +@_register_object("script.ir_builder.IRBuilderFrame") +class IRBuilderFrame(_Object): + def __enter__(self) -> "IRBuilderFrame": + _ffi_api.IRBuilderFrameEnter(self) # pylint: disable=no-member # type: ignore + return self + + def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument + _ffi_api.IRBuilderFrameExit(self) # pylint: disable=no-member # type: ignore + + def add_callback(self, callback) -> None: # pylint: disable=unused-argument + _ffi_api.IRBuilderFrameAddCallback( # pylint: disable=no-member # type: ignore + self, callback + ) + + +@_register_object("script.ir_builder.IRBuilder") +class IRBuilder(_Object): + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.IRBuilder # pylint: disable=no-member # type: ignore + ) + + def __enter__(self) -> "IRBuilder": + _ffi_api.IRBuilderEnter(self) # pylint: disable=no-member # type: ignore + return self + + def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument + _ffi_api.IRBuilderExit(self) # pylint: disable=no-member # type: ignore + + @staticmethod + def current() -> "IRBuilder": + return _ffi_api.IRBuilderCurrent() # pylint: disable=no-member # type: ignore + + def get(self) -> _Object: + return _ffi_api.IRBuilderGet(self) # pylint: disable=no-member # type: ignore + + +DefType = TypeVar("DefType", bound=_Object) + + +def name(s: str, v: DefType) -> DefType: + return _ffi_api.IRBuilderName(s, v) # pylint: disable=no-member # type: ignore + + +def name_many( # pylint: disable=invalid-name + s: List[str], + vs: List[DefType], +) -> List[DefType]: + assert len(s) == len(vs) + return [name(i, v) for i, v in zip(s, vs)] diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py new file mode 100644 index 0000000000..ebb9728737 --- /dev/null +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Package tvm.script.ir_builder.ir""" +from .frame import IRModuleFrame +from .ir import ir_module diff --git a/python/tvm/script/ir_builder/ir/_ffi_api.py b/python/tvm/script/ir_builder/ir/_ffi_api.py new file mode 100644 index 0000000000..874cc278af --- /dev/null +++ b/python/tvm/script/ir_builder/ir/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs""" +import tvm._ffi + +tvm._ffi._init_api("script.ir_builder.ir", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/ir/frame.py b/python/tvm/script/ir_builder/ir/frame.py new file mode 100644 index 0000000000..e16d86dc22 --- /dev/null +++ b/python/tvm/script/ir_builder/ir/frame.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Package tvm.script.ir_builder.ir.frame""" + +from tvm._ffi import register_object as _register_object + +from ..base import IRBuilderFrame + + +@_register_object("script.ir_builder.IRModuleFrame") +class IRModuleFrame(IRBuilderFrame): + ... diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py new file mode 100644 index 0000000000..df92036435 --- /dev/null +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Package tvm.script.ir_builder.ir.ir""" + +from . import _ffi_api +from .frame import IRModuleFrame + + +def ir_module() -> IRModuleFrame: + return _ffi_api.IRModule() # pylint: disable=no-member # type: ignore diff --git a/python/tvm/script/ir_builder/tir/__init__.py b/python/tvm/script/ir_builder/tir/__init__.py new file mode 100644 index 0000000000..2ba5df8a8e --- /dev/null +++ b/python/tvm/script/ir_builder/tir/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Package tvm.script.ir_builder.tir""" +from . import frame + +# from .ir import diff --git a/python/tvm/script/ir_builder/tir/_ffi_api.py b/python/tvm/script/ir_builder/tir/_ffi_api.py new file mode 100644 index 0000000000..876f5f3a35 --- /dev/null +++ b/python/tvm/script/ir_builder/tir/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs""" +import tvm._ffi + +tvm._ffi._init_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py new file mode 100644 index 0000000000..22b03ccdd4 --- /dev/null +++ b/python/tvm/script/ir_builder/tir/frame.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""IRBuilder for TIR""" +from typing import List + +from tvm._ffi import register_object as _register_object +from tvm.tir import Buffer, Var + +from ..base import IRBuilderFrame + + +@_register_object("script.ir_builder.tir.TIRFrame") +class TIRFrame(IRBuilderFrame): + ... + + +@_register_object("script.ir_builder.tir.BlockFrame") +class BlockFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.BlockInitFrame") +class BlockInitFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.ForFrame") +class ForFrame(TIRFrame): + def __enter__(self) -> List[Var]: + super().__enter__() + return self.vars if len(self.vars) > 1 else self.vars[0] + + +@_register_object("script.ir_builder.tir.PrimFuncFrame") +class PrimFuncFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.AssertFrame") +class AssertFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.LetFrame") +class LetFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.AllocateFrame") +class AllocateFrame(TIRFrame): + def __enter__(self) -> Buffer: + super().__enter__() + return self.buffer + + +@_register_object("script.ir_builder.tir.AllocateConstFrame") +class AllocateConstFrame(TIRFrame): + def __enter__(self) -> Buffer: + super().__enter__() + return self.buffer + + +@_register_object("script.ir_builder.tir.LaunchThreadFrame") +class LaunchThreadFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.RealizeFrame") +class RealizeFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.AttrFrame") +class AttrFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.WhileFrame") +class WhileFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.IfFrame") +class IfFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.ThenFrame") +class ThenFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.ElseFrame") +class ElseFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.DeclBufferFrame") +class DeclBufferFrame(TIRFrame): + def __enter__(self) -> Buffer: + super().__enter__() + return self.buffer diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py new file mode 100644 index 0000000000..ebd764cf1d --- /dev/null +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -0,0 +1,954 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +"""IRBuilder for TIR""" +import functools +import inspect +from numbers import Integral +from typing import Any, Dict, List, Optional, Union + +import numpy as np +from tvm.ir import Range, Type +from tvm.runtime import convert, ndarray +from tvm.tir import Broadcast as broadcast +from tvm.tir import ( + Buffer, + BufferLoad, + BufferRegion, + Cast, + CommReducer, + IntImm, + IterVar, + Let, + PrimExpr, +) +from tvm.tir import Ramp as ramp +from tvm.tir import Select, Shuffle, StringImm, Var, cast +from tvm.tir import op as _tir_op +from tvm.tir import type_annotation + +from . import _ffi_api, frame + + +def buffer_decl( + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="", + align=0, + offset_factor=0, + buffer_type="", + axis_separators=None, +) -> Buffer: + shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape + return _ffi_api.BufferDecl( # pylint: disable=no-member # type: ignore + shape, + dtype, + "", + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + axis_separators, + ) + + +def ptr(dtype, storage_scope="global"): + return _ffi_api.Ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore + + +def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame: + return _ffi_api.Block(name, no_realize) # pylint: disable=no-member # type: ignore + + +def init() -> frame.BlockInitFrame: + return _ffi_api.Init() # pylint: disable=no-member # type: ignore + + +def where(predicate) -> None: + if isinstance(predicate, bool): + predicate = IntImm("bool", predicate) + if isinstance(predicate, int): + if predicate in [0, 1]: + predicate = IntImm("bool", predicate) + else: + raise ValueError("Invalid value for predicate: {}".format(predicate)) + _ffi_api.Where(predicate) # pylint: disable=no-member # type: ignore + + +def reads(*buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None: + if len(buffer_slices) == 1: + if isinstance(buffer_slices[0], tuple): + buffer_slices = list(buffer_slices[0]) + elif isinstance(buffer_slices[0], list): + buffer_slices = buffer_slices[0] # type: ignore + else: + buffer_slices = [buffer_slices[0]] # type: ignore + else: + buffer_slices = list(buffer_slices) # type: ignore + _ffi_api.Reads(buffer_slices) # pylint: disable=no-member # type: ignore + + +def writes(*buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None: + if len(buffer_slices) == 1: + if isinstance(buffer_slices[0], tuple): + buffer_slices = list(buffer_slices[0]) + elif isinstance(buffer_slices[0], list): + buffer_slices = buffer_slices[0] # type: ignore + else: + buffer_slices = [buffer_slices[0]] + else: + buffer_slices = list(buffer_slices) # type: ignore + _ffi_api.Writes(buffer_slices) # pylint: disable=no-member # type: ignore + + +def block_attr(attrs: Dict[str, Any]) -> None: + return _ffi_api.BlockAttrs(attrs) # pylint: disable=no-member # type: ignore + + +def alloc_buffer( + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="", + align=-1, + offset_factor=0, + buffer_type="default", + axis_separators=None, +) -> Buffer: + shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape + if strides is None: + strides = [] + return _ffi_api.AllocBuffer( # pylint: disable=no-member # type: ignore + shape, + dtype, + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + axis_separators, + ) + + +def _as_range(dom) -> Range: + if isinstance(dom, Range): + return dom + if isinstance(dom, (list, tuple)): + return Range(dom[0], dom[1]) + return Range(0, dom) + + +class axis: # pylint: disable=invalid-name + @staticmethod + def spatial(dom, binding, dtype="int32") -> IterVar: + return _ffi_api.AxisSpatial( # pylint: disable=no-member # type: ignore + _as_range(dom), binding, dtype + ) + + @staticmethod + def reduce(dom, binding, dtype="int32") -> IterVar: + return _ffi_api.AxisReduce( # pylint: disable=no-member # type: ignore + _as_range(dom), binding, dtype + ) + + @staticmethod + def scan(dom, binding, dtype="int32") -> IterVar: + return _ffi_api.AxisScan( # pylint: disable=no-member # type: ignore + _as_range(dom), binding, dtype + ) + + @staticmethod + def opaque(dom, binding, dtype="int32") -> IterVar: + return _ffi_api.AxisOpaque( # pylint: disable=no-member # type: ignore + _as_range(dom), binding, dtype + ) + + @staticmethod + def remap(kinds, bindings, dtype="int32") -> Union[List[IterVar], IterVar]: + iter_vars = _ffi_api.AxisRemap( # pylint: disable=no-member # type: ignore + kinds, bindings, dtype + ) + return iter_vars[0] if len(iter_vars) == 1 else iter_vars + + S = spatial # pylint: disable=invalid-name + R = reduce # pylint: disable=invalid-name + + +def serial(start, stop=None, *, annotations=None) -> frame.ForFrame: + if stop is None: + stop = start + start = 0 + return _ffi_api.Serial(start, stop, annotations) # pylint: disable=no-member # type: ignore + + +def parallel(start, stop=None, *, annotations=None) -> frame.ForFrame: + if stop is None: + stop = start + start = 0 + return _ffi_api.Parallel(start, stop, annotations) # pylint: disable=no-member # type: ignore + + +def vectorized(start, stop=None, *, annotations=None) -> frame.ForFrame: + if stop is None: + stop = start + start = 0 + return _ffi_api.Vectorized(start, stop, annotations) # pylint: disable=no-member # type: ignore + + +def unroll(start, stop=None, *, annotations=None) -> frame.ForFrame: + if stop is None: + stop = start + start = 0 + return _ffi_api.Unroll(start, stop, annotations) # pylint: disable=no-member # type: ignore + + +def thread_binding( + start, + stop=None, + thread=None, + *, + annotations=None, +) -> frame.ForFrame: + if thread is None: + if not isinstance(stop, str): + raise ValueError("Thread cannot be None for thread_binding") + thread = stop + stop = start + start = 0 + elif stop is None: + stop = start + start = 0 + return _ffi_api.ThreadBinding( # pylint: disable=no-member # type: ignore + start, stop, thread, annotations + ) + + +def grid(*extents) -> frame.ForFrame: + return _ffi_api.Grid(extents) # pylint: disable=no-member # type: ignore + + +def prim_func() -> frame.PrimFuncFrame: + return _ffi_api.PrimFunc() # pylint: disable=no-member # type: ignore + + +def arg(name, obj): + return _ffi_api.Arg(name, obj) # pylint: disable=no-member # type: ignore + + +def func_name(name: str) -> str: + return _ffi_api.FuncName(name) # pylint: disable=no-member # type: ignore + + +def func_attr(attrs: Dict[str, Any]) -> None: + return _ffi_api.FuncAttrs(attrs) # pylint: disable=no-member # type: ignore + + +def func_ret(ret_type) -> Type: + return _ffi_api.FuncRet(ret_type) # pylint: disable=no-member # type: ignore + + +def match_buffer( + param, + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="global", + align=-1, + offset_factor=0, + buffer_type="default", + axis_separators=None, +) -> Buffer: + shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape + if strides is None: + strides = [] + return _ffi_api.MatchBuffer( # pylint: disable=no-member # type: ignore + param, + shape, + dtype, + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + axis_separators, + ) + + +def preflattened_buffer( + postflattened, + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="global", + align=-1, + offset_factor=0, + buffer_type="default", + axis_separators=None, +) -> None: + shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape + if strides is None: + strides = [] + _ffi_api.PreflattenedBuffer( # pylint: disable=no-member # type: ignore + postflattened, + shape, + dtype, + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + axis_separators, + ) + + +def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: disable=invalid-name + return _ffi_api.Assert(condition, message) # pylint: disable=no-member # type: ignore + + +def let( + v: Var, + value: PrimExpr, + body: PrimExpr = None, +) -> frame.LetFrame: + if body is None: + return _ffi_api.Let(v, value) # pylint: disable=no-member # type: ignore + return Let(v, value, body) + + +def allocate( + extents: List[PrimExpr], + dtype: str, + scope: str = "", + condition: PrimExpr = None, + annotations=None, +) -> frame.AllocateFrame: + if isinstance(condition, bool): + condition = IntImm("bool", condition) + return _ffi_api.Allocate( # pylint: disable=no-member # type: ignore + extents, dtype, scope, condition, annotations + ) + + +def allocate_const( + data: List[PrimExpr], + dtype: str, + extents: List[PrimExpr], + annotations=None, +) -> frame.AllocateConstFrame: + + return _ffi_api.AllocateConst( # pylint: disable=no-member # type: ignore + ndarray.array(np.asarray(data, dtype)), dtype, extents, annotations + ) + + +def realize( + buffer_slice: BufferRegion, + storage_scope: str, + condition: PrimExpr = True, +) -> frame.RealizeFrame: + return _ffi_api.Realize( # pylint: disable=no-member # type: ignore + buffer_slice, storage_scope, condition + ) + + +def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame: + node = convert(node) + value = convert(value) + return _ffi_api.Attr(node, attr_key, value) # pylint: disable=no-member # type: ignore + + +def While(condition: PrimExpr) -> frame.WhileFrame: # pylint: disable=invalid-name + if isinstance(condition, bool): + condition = IntImm("bool", condition) + return _ffi_api.While(condition) # pylint: disable=no-member # type: ignore + + +def If(condition: PrimExpr) -> frame.IfFrame: # pylint: disable=invalid-name + if isinstance(condition, bool): + condition = IntImm("bool", condition) + return _ffi_api.If(condition) # pylint: disable=no-member # type: ignore + + +def Then() -> frame.ThenFrame: # pylint: disable=invalid-name + return _ffi_api.Then() # pylint: disable=no-member # type: ignore + + +def Else() -> frame.ElseFrame: # pylint: disable=invalid-name + return _ffi_api.Else() # pylint: disable=no-member # type: ignore + + +def decl_buffer( + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="", + align=0, + offset_factor=0, + buffer_type="", + axis_separators=None, +) -> frame.DeclBufferFrame: + + shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape + return _ffi_api.DeclBuffer( # pylint: disable=no-member # type: ignore + shape, + dtype, + "", + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + axis_separators, + ) + + +def launch_thread( + iter_var: IterVar, # pylint: disable=redefined-outer-name + extent: PrimExpr, +) -> frame.LaunchThreadFrame: + return _ffi_api.LaunchThread(iter_var, extent) # pylint: disable=no-member # type: ignore + + +def env_thread(thread_tag: str) -> IterVar: + return _ffi_api.EnvThread(thread_tag) # pylint: disable=no-member # type: ignore + + +def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, slice]]) -> None: + from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel + + expr_indices = [] + for index in indices: + if isinstance(index, slice): + step = 1 if index.step is None else index.step + lanes = Analyzer().simplify((index.stop - index.start + step - 1) // step) + if lanes == 1: + expr_indices.append(index.start) + else: + expr_indices.append(ramp(index.start, step, int(lanes))) + else: + expr_indices.append(index) + if isinstance(value, bool) and buffer.dtype == "bool": + value = IntImm("bool", value) + return _ffi_api.BufferStore( # pylint: disable=no-member # type: ignore + buffer, value, expr_indices + ) + + +def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None: + return _ffi_api.Prefetch(buffer, indices) # pylint: disable=no-member # type: ignore + + +def evaluate(value: PrimExpr) -> None: + if isinstance(value, str): + value = StringImm(value) + return _ffi_api.Evaluate(value) # pylint: disable=no-member # type: ignore + + +def int8(expr: Optional[PrimExpr] = None) -> PrimExpr: + return _ffi_api.Int8(expr) # pylint: disable=no-member # type: ignore + + +def int16(expr: Optional[PrimExpr] = None) -> PrimExpr: + return _ffi_api.Int16(expr) # pylint: disable=no-member # type: ignore + + +def int32(expr: Optional[PrimExpr] = None) -> PrimExpr: + return _ffi_api.Int32(expr) # pylint: disable=no-member # type: ignore + + +def int64(expr: Optional[PrimExpr] = None) -> PrimExpr: + return _ffi_api.Int64(expr) # pylint: disable=no-member # type: ignore + + +def uint8(expr: Optional[PrimExpr] = None) -> PrimExpr: + return _ffi_api.UInt8(expr) # pylint: disable=no-member # type: ignore + + +def uint16(expr: Optional[PrimExpr] = None) -> PrimExpr: + return _ffi_api.UInt16(expr) # pylint: disable=no-member # type: ignore + + +def uint32(expr: Optional[PrimExpr] = None) -> PrimExpr: + return _ffi_api.UInt32(expr) # pylint: disable=no-member # type: ignore + + +def uint64(expr: Optional[PrimExpr] = None) -> PrimExpr: + return _ffi_api.UInt64(expr) # pylint: disable=no-member # type: ignore + + +def float8(expr: Optional[PrimExpr] = None) -> PrimExpr: + if not isinstance(expr, PrimExpr): + expr = convert(expr) + return _ffi_api.Float8(expr) # pylint: disable=no-member # type: ignore + + +def float16(expr: Optional[PrimExpr] = None) -> PrimExpr: + if not isinstance(expr, PrimExpr): + expr = convert(expr) + return _ffi_api.Float16(expr) # pylint: disable=no-member # type: ignore + + +def float32(expr: Optional[PrimExpr] = None) -> PrimExpr: + if not isinstance(expr, PrimExpr): + expr = convert(expr) + return _ffi_api.Float32(expr) # pylint: disable=no-member # type: ignore + + +def float64(expr: Optional[PrimExpr] = None) -> PrimExpr: + if not isinstance(expr, PrimExpr): + expr = convert(expr) + return _ffi_api.Float64(expr) # pylint: disable=no-member # type: ignore + + +def int32x4(expr: Optional[PrimExpr] = None) -> PrimExpr: + return _ffi_api.Int32x4(expr) # pylint: disable=no-member # type: ignore + + +def int32x8(expr: Optional[PrimExpr] = None) -> PrimExpr: + return _ffi_api.Int32x8(expr) # pylint: disable=no-member # type: ignore + + +def int32x16(expr: Optional[PrimExpr] = None) -> PrimExpr: + return _ffi_api.Int32x16(expr) # pylint: disable=no-member # type: ignore + + +def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr: + return _ffi_api.Boolean(expr) # pylint: disable=no-member # type: ignore + + +def handle(expr: Optional[PrimExpr] = None) -> PrimExpr: + return _ffi_api.Handle(expr) # pylint: disable=no-member # type: ignore + + +def void(expr: Optional[PrimExpr] = None) -> PrimExpr: + return _ffi_api.Void(expr) # pylint: disable=no-member # type: ignore + + +def min(a, b): # pylint: disable=redefined-builtin + """Compute the minimum value of two expressions. + + Parameters + ---------- + a : PrimExpr + The left hand operand + + b : PrimExpr + The right hand operand + + Returns + ------- + res : PrimExpr + The result expression. + """ + return _ffi_api.min(a, b) # pylint: disable=no-member # type: ignore + + +def max(a, b): # pylint: disable=redefined-builtin + """Compute the maximum value of two expressions. + + Parameters + ---------- + a : PrimExpr + The left hand operand + + b : PrimExpr + The right hand operand + + Returns + ------- + res : PrimExpr + The result expression. + """ + return _ffi_api.max(a, b) # pylint: disable=no-member # type: ignore + + +def var(dtype, name="") -> Var: + return Var(name, dtype) # pylint: disable=no-member # type: ignore + + +def iter_var(v, dom, iter_type, thread_tag): + iter_type = getattr(IterVar, iter_type) + return IterVar(dom, v, iter_type, thread_tag) + + +def comm_reducer(combiner, identity): + """Create a CommReducer from lambda inputs/outputs and the identities""" + params = inspect.signature(combiner).parameters + num_args = len(params) + args = [] + for name, i in zip(params.keys(), identity + identity): + if isinstance(i, int): + args.append(Var(name, "int32")) + else: + args.append(Var(name, i.dtype)) + res = combiner(*args) + if not isinstance(res, tuple): + res = (res,) + return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity) + + +def llvm_lookup_intrinsic_id(name): + # pylint: disable=import-outside-toplevel + from tvm.target.codegen import llvm_lookup_intrinsic_id as f + + # pylint: enable=import-outside-toplevel + return f(name) + + +def _op_wrapper(func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + if "dtype" in kwargs: + kwargs.pop("dtype") + return func(*args, **kwargs) + + return wrapped + + +def _dtype_forward(func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + if "dtype" in kwargs: + args = (kwargs.pop("dtype"),) + args + return func(*args, **kwargs) + + return wrapped + + +# pylint: disable=invalid-name + +buffer_var = ptr +abs = _op_wrapper(_tir_op.abs) # pylint: disable=redefined-builtin +fabs = abs +acos = _op_wrapper(_tir_op.acos) +acosh = _op_wrapper(_tir_op.acosh) +address_of = _op_wrapper(_tir_op.address_of) +asin = _op_wrapper(_tir_op.asin) +asinh = _op_wrapper(_tir_op.asinh) +atan = _op_wrapper(_tir_op.atan) +atan2 = _op_wrapper(_tir_op.atan2) +atanh = _op_wrapper(_tir_op.atanh) +ceil = _op_wrapper(_tir_op.ceil) +clz = _op_wrapper(_tir_op.clz) +copysign = _op_wrapper(_tir_op.copysign) +cos = _op_wrapper(_tir_op.cos) +cosh = _op_wrapper(_tir_op.cosh) +erf = _op_wrapper(_tir_op.erf) +exp = _op_wrapper(_tir_op.exp) +exp2 = _op_wrapper(_tir_op.exp2) +exp10 = _op_wrapper(_tir_op.exp10) +floor = _op_wrapper(_tir_op.floor) +ceildiv = _op_wrapper(_tir_op.ceildiv) +floordiv = _op_wrapper(_tir_op.floordiv) +floormod = _op_wrapper(_tir_op.floormod) +fmod = _op_wrapper(_tir_op.fmod) +hypot = _op_wrapper(_tir_op.hypot) +if_then_else = _op_wrapper(_tir_op.if_then_else) +infinity = _op_wrapper(_tir_op.infinity) +isfinite = _op_wrapper(_tir_op.isfinite) +isinf = _op_wrapper(_tir_op.isinf) +isnan = _op_wrapper(_tir_op.isnan) +isnullptr = _op_wrapper(_tir_op.isnullptr) +ldexp = _op_wrapper(_tir_op.ldexp) +likely = _op_wrapper(_tir_op.likely) +log = _op_wrapper(_tir_op.log) +log1p = _op_wrapper(_tir_op.log1p) +log2 = _op_wrapper(_tir_op.log2) +log10 = _op_wrapper(_tir_op.log10) +lookup_param = _op_wrapper(_tir_op.lookup_param) +max_value = _op_wrapper(_tir_op.max_value) +min_value = _op_wrapper(_tir_op.min_value) +nearbyint = _op_wrapper(_tir_op.nearbyint) +nextafter = _op_wrapper(_tir_op.nextafter) +popcount = _op_wrapper(_tir_op.popcount) +power = _op_wrapper(_tir_op.power) +q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift) +ret = _op_wrapper(_tir_op.ret) +reinterpret = _dtype_forward(_tir_op.reinterpret) +round = _op_wrapper(_tir_op.round) # pylint: disable=redefined-builtin +rsqrt = _op_wrapper(_tir_op.rsqrt) +shift_left = _op_wrapper(_tir_op.shift_left) +shift_right = _op_wrapper(_tir_op.shift_right) +sigmoid = _op_wrapper(_tir_op.sigmoid) +sin = _op_wrapper(_tir_op.sin) +sinh = _op_wrapper(_tir_op.sinh) +sqrt = _op_wrapper(_tir_op.sqrt) +tan = _op_wrapper(_tir_op.tan) +tanh = _op_wrapper(_tir_op.tanh) +trunc = _op_wrapper(_tir_op.trunc) +truncdiv = _op_wrapper(_tir_op.truncdiv) +truncmod = _op_wrapper(_tir_op.truncmod) +tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr) +tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error) +tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca) +tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape) +tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array) +call_packed = _op_wrapper(_tir_op.call_packed) +call_cpacked = _op_wrapper(_tir_op.call_cpacked) +call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered) +call_cpacked_lowered = _op_wrapper(_tir_op.call_cpacked_lowered) +call_extern = _dtype_forward(_tir_op.call_extern) +call_intrin = _dtype_forward(_tir_op.call_intrin) +call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin) +call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin) +call_pure_extern = _dtype_forward(_tir_op.call_pure_extern) +tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr) +tvm_tuple = _op_wrapper(_tir_op.tvm_tuple) +tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set) +tvm_struct_get = _tir_op.tvm_struct_get +tvm_thread_allreduce = _op_wrapper(_tir_op.tvm_thread_allreduce) +tvm_load_matrix_sync = _op_wrapper(_tir_op.tvm_load_matrix_sync) +tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync) +tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync) +tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment) +tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync) +ptx_mma = _dtype_forward(_tir_op.ptx_mma) +ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp) +ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix) +ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async) +ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group) +ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group) +mma_store = _dtype_forward(_tir_op.mma_store) +mma_fill = _dtype_forward(_tir_op.mma_fill) +vectorlow = _dtype_forward(_tir_op.vectorlow) +vectorhigh = _dtype_forward(_tir_op.vectorhigh) +vectorcombine = _dtype_forward(_tir_op.vectorcombine) +assume = _op_wrapper(_tir_op.assume) +undef = _op_wrapper(_tir_op.undef) +tvm_call_packed = call_packed +tvm_call_cpacked = call_cpacked +tvm_call_packed_lowered = call_packed_lowered +tvm_call_cpacked_lowered = call_cpacked_lowered +TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace) +TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace) + + +class inline: + def __init__(self, value) -> None: + self.value = value + self.i = 0 + + def __iter__(self): + def f(): + for i in self.value: + yield inline(i) + + return f() + + +# pylint: enable=invalid-name + + +__all__ = [ + "Assert", + "Cast", + "Else", + "If", + "Let", + "Select", + "Shuffle", + "TVMBackendAllocWorkspace", + "TVMBackendFreeWorkspace", + "Then", + "While", + "abs", + "acos", + "acosh", + "address_of", + "alloc_buffer", + "allocate", + "allocate_const", + "arg", + "asin", + "asinh", + "assume", + "atan", + "atan2", + "atanh", + "attr", + "axis", + "block", + "block_attr", + "boolean", + "broadcast", + "buffer_decl", + "buffer_store", + "buffer_var", + "call_cpacked", + "call_cpacked_lowered", + "call_extern", + "call_intrin", + "call_llvm_intrin", + "call_llvm_pure_intrin", + "call_packed", + "call_packed_lowered", + "call_pure_extern", + "cast", + "ceil", + "ceildiv", + "clz", + "comm_reducer", + "copysign", + "cos", + "cosh", + "env_thread", + "erf", + "evaluate", + "exp", + "exp10", + "exp2", + "decl_buffer", + "fabs", + "float16", + "float32", + "float64", + "float8", + "floor", + "floordiv", + "floormod", + "fmod", + "func_attr", + "func_name", + "func_ret", + "grid", + "handle", + "hypot", + "if_then_else", + "infinity", + "init", + "inline", + "int16", + "int32", + "int32x16", + "int32x4", + "int32x8", + "int64", + "int8", + "isfinite", + "isinf", + "isnan", + "isnullptr", + "iter_var", + "launch_thread", + "ldexp", + "let", + "likely", + "llvm_lookup_intrinsic_id", + "log", + "log10", + "log1p", + "log2", + "lookup_param", + "match_buffer", + "max", + "max_value", + "min", + "min_value", + "mma_fill", + "mma_store", + "nearbyint", + "nextafter", + "parallel", + "popcount", + "power", + "prefetch", + "preflattened_buffer", + "prim_func", + "ptr", + "ptx_commit_group", + "ptx_cp_async", + "ptx_ldmatrix", + "ptx_mma", + "ptx_mma_sp", + "ptx_wait_group", + "q_multiply_shift", + "ramp", + "reads", + "realize", + "reinterpret", + "ret", + "round", + "rsqrt", + "serial", + "shift_left", + "shift_right", + "sigmoid", + "sin", + "sinh", + "sqrt", + "tan", + "tanh", + "thread_binding", + "trunc", + "truncdiv", + "truncmod", + "tvm_access_ptr", + "tvm_bmma_sync", + "tvm_call_cpacked", + "tvm_call_cpacked_lowered", + "tvm_call_packed", + "tvm_call_packed_lowered", + "tvm_fill_fragment", + "tvm_load_matrix_sync", + "tvm_mma_sync", + "tvm_stack_alloca", + "tvm_stack_make_array", + "tvm_stack_make_shape", + "tvm_store_matrix_sync", + "tvm_struct_get", + "tvm_struct_set", + "tvm_thread_allreduce", + "tvm_throw_last_error", + "tvm_tuple", + "type_annotation", + "uint16", + "uint32", + "uint64", + "uint8", + "undef", + "unroll", + "var", + "vectorcombine", + "vectorhigh", + "vectorized", + "vectorlow", + "void", + "where", + "writes", +] diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc new file mode 100644 index 0000000000..8303efff4f --- /dev/null +++ b/src/script/ir_builder/base.cc @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { + +void IRBuilderFrameNode::EnterWithScope() { + IRBuilder::Current()->frames.push_back(GetRef(this)); +} + +void IRBuilderFrameNode::ExitWithScope() { + for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) { + (*it)(); + } + this->callbacks.clear(); + IRBuilder::Current()->frames.pop_back(); +} + +void IRBuilderFrameNode::AddCallback(runtime::TypedPackedFunc callback) { + if (IRBuilder::Current()->frames.empty()) { + LOG(FATAL) << "ValueError: No frames in Builder to add callback"; + } + IRBuilder::Current()->frames.back()->callbacks.push_back(callback); +} + +IRBuilder::IRBuilder() { + ObjectPtr n = make_object(); + n->frames.clear(); + n->result = NullOpt; + data_ = n; +} + +std::vector* ThreadLocalBuilderStack() { + thread_local std::vector stack; + return &stack; +} + +void IRBuilder::EnterWithScope() { + IRBuilderNode* n = this->get(); + CHECK(n->frames.empty()) << "ValueError: There are frame(s) left in the builder: " + << n->frames.size() + << ". Please use a fresh new builder every time building IRs"; + n->result = NullOpt; + std::vector* stack = ThreadLocalBuilderStack(); + stack->push_back(*this); +} + +void IRBuilder::ExitWithScope() { + std::vector* stack = ThreadLocalBuilderStack(); + ICHECK(!stack->empty()); + stack->pop_back(); +} + +IRBuilder IRBuilder::Current() { + std::vector* stack = ThreadLocalBuilderStack(); + CHECK(!stack->empty()) << "ValueError: No builder in current scope"; + return stack->back(); +} + +namespace details { + +Namer::FType& Namer::vtable() { + static FType inst; + return inst; +} + +void Namer::Name(ObjectRef node, String name) { + static const FType& f = vtable(); + CHECK(node.defined()) << "ValueError: Cannot name nullptr with: " << name; + CHECK(f.can_dispatch(node)) << "ValueError: Do not know how to name type \"" + << node->GetTypeKey(); + f(node, name); +} + +} // namespace details + +TVM_REGISTER_NODE_TYPE(IRBuilderFrameNode); +TVM_REGISTER_NODE_TYPE(IRBuilderNode); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameEnter") + .set_body_method(&IRBuilderFrameNode::EnterWithScope); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameExit") + .set_body_method(&IRBuilderFrameNode::ExitWithScope); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameAddCallback") + .set_body_method(&IRBuilderFrameNode::AddCallback); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return IRBuilder(); }); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet") + .set_body_method(&IRBuilderNode::Get); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name); + +} // namespace ir_builder +} // namespace script +} // namespace tvm diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc new file mode 100644 index 0000000000..c85e30544a --- /dev/null +++ b/src/script/ir_builder/ir/frame.cc @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { + +void IRModuleFrameNode::ExitWithScope() { + ICHECK_EQ(functions.size(), global_vars.size()); + int n = functions.size(); + Map func_map; + for (int i = 0; i < n; ++i) { + func_map.Set(global_vars[i], functions[i]); + } + IRBuilder builder = IRBuilder::Current(); + ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; + builder->result = tvm::IRModule(func_map); +} + +TVM_REGISTER_NODE_TYPE(IRModuleFrameNode); + +} // namespace ir_builder +} // namespace script +} // namespace tvm diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc new file mode 100644 index 0000000000..bcaee5dcaa --- /dev/null +++ b/src/script/ir_builder/ir/ir.cc @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { + +IRModuleFrame IRModule() { + ObjectPtr n = make_object(); + n->global_vars.clear(); + n->functions.clear(); + return IRModuleFrame(n); +} + +TVM_REGISTER_GLOBAL("script.ir_builder.IRModule").set_body_typed(IRModule); + +} // namespace ir_builder +} // namespace script +} // namespace tvm diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc new file mode 100644 index 0000000000..85a9c145a5 --- /dev/null +++ b/src/script/ir_builder/tir/frame.cc @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include "../../../tir/ir/script/script_complete.h" +#include "./utils.h" + +namespace tvm { +namespace script { +namespace ir_builder { +namespace tir { + +void BlockFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + Array tir_alloc_buffers; + for (const tvm::tir::Buffer& buffer : alloc_buffers) { + tir_alloc_buffers.push_back(buffer); + } + Map attrs = annotations.value_or({}); + if (int detect_access = (!reads.defined()) | (!writes.defined() << 1)) { + attrs.Set("tir.script_parsing_detect_access", tvm::IntImm(DataType::Int(64), detect_access)); + } + tvm::tir::Block block(iter_vars, reads.value_or(Array()), + writes.value_or(Array()), name, AsStmt(stmts), init, + tir_alloc_buffers, match_buffers, attrs); + if (no_realize) { + CHECK(iter_values.empty()) + << "ValueError: Block bindings are not allowed when `no_realize=True`"; + CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`"; + AddToParent(block); + } else { + AddToParent(tvm::tir::BlockRealize(iter_values, predicate.value_or(Bool(true)), block)); + } +} + +void BlockInitFrameNode::EnterWithScope() { + BlockFrame frame = FindBlockFrame("T.init"); + if (frame->init.defined()) { + LOG(FATAL) << "ValueError: Duplicate block init declaration"; + } + TIRFrameNode::EnterWithScope(); +} + +void BlockInitFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + BlockFrame frame = FindBlockFrame("T.init"); + frame->init = AsStmt(stmts); +} + +void ForFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(this->f_make_for_loop(vars, doms, AsStmt(stmts))); +} + +void PrimFuncFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + tvm::tir::PrimFunc func( + /*params=*/args, + /*body=*/AsStmt(stmts), + /*ret_type=*/ret_type.value_or(TupleType::Empty()), + /*buffer_map=*/buffer_map, + /*preflattened_buffer_map=*/preflattened_buffer_map, + /*attrs=*/attrs.defined() ? DictAttrs(attrs.value()) : NullValue()); + func = tvm::tir::ScriptComplete(func, root_alloc_buffers); + IRBuilder builder = IRBuilder::Current(); + if (builder->frames.empty()) { + ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; + builder->result = func; + } else if (Optional opt_frame = builder->FindFrame()) { + IRModuleFrame frame = opt_frame.value(); + frame->global_vars.push_back(GlobalVar(name.value_or(""))); + frame->functions.push_back(func); + } else { + LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc"; + } +} + +void AssertFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::AssertStmt(condition, message, AsStmt(stmts))); +} + +void LetFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::LetStmt(var, value, AsStmt(stmts))); +} + +void AllocateFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::Allocate(buffer->data, buffer->dtype, buffer->shape, condition, + AsStmt(stmts), annotations)); +} + +void AllocateConstFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent( + tvm::tir::AllocateConst(buffer->data, dtype, extents, data, AsStmt(stmts), annotations)); +} + +void LaunchThreadFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::AttrStmt(iter_var, attr_key, extent, AsStmt(stmts))); +} + +void RealizeFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::AttrStmt(buffer_slice->buffer, "realize_scope", + tvm::tir::StringImm(storage_scope), + tvm::tir::BufferRealize(buffer_slice->buffer, buffer_slice->region, + condition, AsStmt(stmts)))); +} + +void AttrFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::AttrStmt(node, attr_key, value, AsStmt(stmts))); +} + +void WhileFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::While(condition, AsStmt(stmts))); +} + +void IfFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + if (!stmts.empty()) { + LOG(FATAL) << "stmt within IfThenElse frame should be either in ThenFrame or ElseFrame"; + } + if (!then_stmts.defined()) { + LOG(FATAL) << "IfThenElse frame should have at least one then branch"; + } + AddToParent(tvm::tir::IfThenElse( + condition, AsStmt(then_stmts.value()), + else_stmts.defined() ? AsStmt(else_stmts.value()) : tvm::tir::Stmt(nullptr))); +} + +void ThenFrameNode::EnterWithScope() { + IfFrame frame = FindIfFrame("T.then_"); + if (frame->then_stmts.defined()) { + LOG(FATAL) << "ValueError: Duplicate then branch declaration, previous one is " + << frame->then_stmts.value(); + } + TIRFrameNode::EnterWithScope(); +} + +void ThenFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + FindIfFrame("T.then_")->then_stmts = stmts; +} + +void ElseFrameNode::EnterWithScope() { + IfFrame frame = FindIfFrame("T.else_"); + if (!frame->then_stmts.defined()) { + LOG(FATAL) << "The else branch should follow then branch"; + } + if (frame->else_stmts.defined()) { + LOG(FATAL) << "ValueError: Duplicate else branch declaration, previous one is " + << frame->else_stmts.value(); + } + TIRFrameNode::EnterWithScope(); +} + +void ElseFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + FindIfFrame("T.else_")->else_stmts = stmts; +} + +void DeclBufferFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::DeclBuffer(buffer, AsStmt(stmts))); +} + +TVM_REGISTER_NODE_TYPE(TIRFrameNode); +TVM_REGISTER_NODE_TYPE(BlockFrameNode); +TVM_REGISTER_NODE_TYPE(BlockInitFrameNode); +TVM_REGISTER_NODE_TYPE(ForFrameNode); +TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode); +TVM_REGISTER_NODE_TYPE(AssertFrameNode); +TVM_REGISTER_NODE_TYPE(LetFrameNode); +TVM_REGISTER_NODE_TYPE(AllocateFrameNode); +TVM_REGISTER_NODE_TYPE(AllocateConstFrameNode); +TVM_REGISTER_NODE_TYPE(LaunchThreadFrameNode); +TVM_REGISTER_NODE_TYPE(RealizeFrameNode); +TVM_REGISTER_NODE_TYPE(AttrFrameNode); +TVM_REGISTER_NODE_TYPE(WhileFrameNode); +TVM_REGISTER_NODE_TYPE(IfFrameNode); +TVM_REGISTER_NODE_TYPE(ThenFrameNode); +TVM_REGISTER_NODE_TYPE(ElseFrameNode); +TVM_REGISTER_NODE_TYPE(DeclBufferFrameNode); + +} // namespace tir +} // namespace ir_builder +} // namespace script +} // namespace tvm diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc new file mode 100644 index 0000000000..463fad0c4c --- /dev/null +++ b/src/script/ir_builder/tir/ir.cc @@ -0,0 +1,665 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include "./utils.h" + +namespace tvm { +namespace script { +namespace ir_builder { +namespace tir { + +using tvm::tir::IterVar; + +Buffer BufferDecl(Array shape, DataType dtype, String buffer_name, Optional data, + Optional> strides, Optional elem_offset, + String storage_scope, int align, int offset_factor, String buffer_type, + Optional> axis_separators) { + Var buffer_data; + if (!data.defined()) { + DataType storage_dtype = dtype; + if (storage_dtype == DataType::Bool()) { + storage_dtype = DataType::Int(8); + } + buffer_data = tvm::tir::Var(buffer_name, PointerType(PrimType(storage_dtype), storage_scope)); + } else { + buffer_data = data.value(); + } + if (!elem_offset.defined() && offset_factor) { + DataType shape_dtype = shape.empty() ? DataType::Int(32) : shape[0]->dtype; + elem_offset = tvm::tir::Var("elem_offset", shape_dtype); + } + return Buffer(buffer_data, dtype, shape, strides.value_or(Array()), + elem_offset.value_or(PrimExpr()), buffer_name, align, offset_factor, + (buffer_type == "auto_broadcast") ? tvm::tir::kAutoBroadcast : tvm::tir::kDefault, + axis_separators.value_or(Array())); +} + +DeclBufferFrame DeclBuffer(Array shape, DataType dtype, String buffer_name, + Optional data, Optional> strides, + Optional elem_offset, String storage_scope, int align, + int offset_factor, String buffer_type, + Optional> axis_separators) { + ObjectPtr n = make_object(); + n->buffer = BufferDecl(shape, dtype, buffer_name, data, strides, elem_offset, storage_scope, + align, offset_factor, buffer_type, axis_separators); + return DeclBufferFrame(n); +} + +PrimExpr Ptr(runtime::DataType dtype, String storage_scope) { + return tvm::tir::Var("", tvm::PointerType(PrimType(dtype), storage_scope)); +} + +BlockFrame Block(String name, bool no_realize) { + ObjectPtr n = make_object(); + n->name = name; + n->iter_vars.clear(); + n->reads = NullOpt; + n->writes = NullOpt; + n->init = NullOpt; + n->alloc_buffers.clear(); + n->match_buffers.clear(); + n->annotations = NullOpt; + n->iter_values.clear(); + n->predicate = NullOpt; + n->no_realize = no_realize; + return BlockFrame(n); +} + +BlockInitFrame Init() { return BlockInitFrame(make_object()); } + +void Where(PrimExpr predicate) { + BlockFrame frame = FindBlockFrame("T.where"); + if (frame->predicate.defined()) { + LOG(FATAL) << "ValueError: Duplicate block predicate declaration, previous one is " + << frame->predicate; + } + frame->predicate = predicate; +} + +void Reads(Array buffer_slices) { + using namespace tvm::tir; + BlockFrame frame = FindBlockFrame("T.reads"); + if (frame->reads.defined()) { + LOG(FATAL) << "ValueError: Duplicate read region declaration, previous one is " << frame->reads; + } + Array reads; + for (const ObjectRef& obj : buffer_slices) { + if (const auto* buffer_region = obj.as()) { + reads.push_back(GetRef(buffer_region)); + } else if (const auto* buffer_load = obj.as()) { + reads.push_back(BufferRegionFromLoad(GetRef(buffer_load))); + } else { + LOG(FATAL) << "Invalid type for buffer reads."; + } + } + frame->reads = reads; +} + +void Writes(Array buffer_slices) { + using namespace tvm::tir; + BlockFrame frame = FindBlockFrame("T.writes"); + if (frame->writes.defined()) { + LOG(FATAL) << "ValueError: Duplicate write region declaration, previous one is " + << frame->writes; + } + Array writes; + for (const ObjectRef& obj : buffer_slices) { + if (const auto* buffer_region = obj.as()) { + writes.push_back(GetRef(buffer_region)); + } else if (const auto* buffer_load = obj.as()) { + writes.push_back(BufferRegionFromLoad(GetRef(buffer_load))); + } else { + LOG(FATAL) << "Invalid type for buffer writes."; + } + } + frame->writes = writes; +} + +void BlockAttrs(Map attrs) { + BlockFrame frame = FindBlockFrame("T.block_attr"); + if (frame->annotations.defined()) { + LOG(FATAL) << "ValueError: Duplicate block annotations, previous one is " << frame->annotations; + } + frame->annotations = attrs; +} + +Buffer AllocBuffer(Array shape, DataType dtype, Optional data, + Array strides, PrimExpr elem_offset, String storage_scope, int align, + int offset_factor, String buffer_type_str, Array axis_separators) { + Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align, + offset_factor, buffer_type_str, axis_separators); + IRBuilder builder = IRBuilder::Current(); + if (Optional frame = builder->GetLastFrame()) { + frame.value()->alloc_buffers.push_back(buffer); + } else if (Optional frame = builder->GetLastFrame()) { + frame.value()->root_alloc_buffers.push_back(buffer); + } else { + LOG(FATAL) << "ValueError: Block frame or PrimFunc frame not find. Please ensure " + "'T.alloc_buffer' is called under T.block() or T.prim_func()"; + } + return buffer; +} + +namespace axis { + +IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) { + if (Optional opt_frame = IRBuilder::Current()->GetLastFrame()) { + BlockFrame frame = opt_frame.value(); + frame->iter_vars.push_back(iter_var); + frame->iter_values.push_back(binding); + } else { + LOG(FATAL) << "TypeError: The last frame is not BlockFrame"; + } + return iter_var; +} + +#define TVM_TIR_IR_BUILDER_AXIS(Method, Kind, Name) \ + Var Method(Range dom, PrimExpr binding, DataType dtype) { \ + ICHECK(dom.defined()) << Name << " axis must have a domain"; \ + int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()}); \ + return PushBlockVar(IterVar(/*dom=*/dom, /*var=*/Var("", dtype.with_bits(bits)), \ + /*iter_type=*/Kind, /*thread_tag=*/""), \ + binding) \ + ->var; \ + } +TVM_TIR_IR_BUILDER_AXIS(Spatial, tvm::tir::IterVarType::kDataPar, "Spatial"); +TVM_TIR_IR_BUILDER_AXIS(Reduce, tvm::tir::IterVarType::kCommReduce, "Reduction"); +TVM_TIR_IR_BUILDER_AXIS(Scan, tvm::tir::IterVarType::kOrdered, "Scan"); +TVM_TIR_IR_BUILDER_AXIS(Opaque, tvm::tir::IterVarType::kOpaque, "Opaque"); +#undef TVM_TIR_IR_BUILDER_AXIS + +Array Remap(String kinds, Array bindings, DataType dtype) { + using namespace tvm::tir; + Array results; + ICHECK_EQ(kinds.size(), bindings.size()); + int n = bindings.size(); + results.reserve(n); + for (int i = 0; i < n; ++i) { + char c = kinds.c_str()[i]; + PrimExpr e = bindings[i]; + const VarNode* v = e.as(); + ICHECK(v) << "TypeError: Only Var is supported in T.axis.remap"; + Range dom{nullptr}; + for (const auto& frame : IRBuilder::Current()->frames) { + if (const auto* for_frame = frame.as()) { + ICHECK_EQ(for_frame->doms.size(), for_frame->vars.size()); + int n = for_frame->doms.size(); + for (int i = 0; i < n; ++i) { + if (for_frame->vars[i].get() == v) { + dom = for_frame->doms[i]; + break; + } + } + if (dom.defined()) { + break; + } + } + } + ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << GetRef(v); + DataType dtype = v->dtype; + if (c == 'S') { + results.push_back(PushBlockVar(IterVar(/*dom=*/dom, + /*var=*/Var("", dtype), + /*iter_type=*/IterVarType::kDataPar, + /*thread_tag=*/""), + e) + ->var); + } else if (c == 'R') { + results.push_back(PushBlockVar(IterVar(/*dom=*/dom, + /*var=*/Var("", dtype), + /*iter_type=*/IterVarType::kCommReduce, + /*thread_tag=*/""), + e) + ->var); + } else { + LOG(FATAL) << "Unknown axis kind: " << c; + } + } + return results; +} + +} // namespace axis + +#define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind) \ + ForFrame Method(PrimExpr start, PrimExpr stop, Optional> annotations) { \ + PrimExpr min = start; \ + PrimExpr extent = arith::Analyzer().Simplify(stop - start); \ + ObjectPtr n = make_object(); \ + int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \ + n->vars = {Var("v", DataType::Int(bits))}; \ + n->doms = {Range::FromMinExtent(min, extent)}; \ + n->f_make_for_loop = [annotations](Array vars, Array doms, tvm::tir::Stmt body) { \ + ICHECK_EQ(vars.size(), 1); \ + ICHECK_EQ(doms.size(), 1); \ + return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt, \ + annotations.value_or(Map())); \ + }; \ + return ForFrame(n); \ + } + +TVM_TIR_IR_BUILDER_FOR_FRAME(Serial, tvm::tir::ForKind::kSerial); +TVM_TIR_IR_BUILDER_FOR_FRAME(Parallel, tvm::tir::ForKind::kParallel); +TVM_TIR_IR_BUILDER_FOR_FRAME(Vectorized, tvm::tir::ForKind::kVectorized); +TVM_TIR_IR_BUILDER_FOR_FRAME(Unroll, tvm::tir::ForKind::kUnrolled); + +#undef TVM_TIR_IR_BUILDER_FOR_FRAME + +ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, + Optional> annotations) { + using namespace tvm::tir; + PrimExpr min = start; + PrimExpr extent = arith::Analyzer().Simplify(stop - start); + ObjectPtr n = make_object(); + int bits = std::max(min.dtype().bits(), extent.dtype().bits()); + n->vars = {Var("v", DataType::Int(bits))}; + n->doms = {Range::FromMinExtent(min, extent)}; + n->f_make_for_loop = [annotations, thread](Array vars, Array doms, Stmt body) -> For { + ICHECK_EQ(vars.size(), 1); + ICHECK_EQ(doms.size(), 1); + IterVar iter_var(Range(nullptr), Var("iter", DataType::Int(32)), IterVarType::kThreadIndex, + thread); + return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var, + annotations.value_or(Map())); + }; + return ForFrame(n); +} + +ForFrame Grid(Array extents) { + using namespace tvm::tir; + ObjectPtr n = make_object(); + n->vars.reserve(extents.size()); + n->doms.reserve(extents.size()); + for (const auto& extent : extents) { + DataType dtype = extent.dtype(); + n->vars.push_back(Var("v", extent.dtype())); + n->doms.push_back(Range(make_const(dtype, 0), extent)); + } + n->f_make_for_loop = [](Array vars, Array doms, Stmt body) -> Stmt { + ICHECK_EQ(vars.size(), doms.size()); + int n = vars.size(); + for (int i = n - 1; i >= 0; --i) { + Range dom = doms[i]; + Var var = vars[i]; + body = For(var, dom->min, dom->extent, ForKind::kSerial, std::move(body), + /*thread_binding=*/NullOpt, /*annotations=*/{}); + } + return body; + }; + return ForFrame(n); +} + +PrimFuncFrame PrimFunc() { + ObjectPtr n = make_object(); + n->name = NullOpt; + n->args.clear(); + n->ret_type = NullOpt; + n->buffer_map.clear(); + n->preflattened_buffer_map.clear(); + n->attrs = NullOpt; + n->env_threads.clear(); + n->root_alloc_buffers.clear(); + return PrimFuncFrame(n); +} + +Var Arg(String name, Var var) { + PrimFuncFrame frame = FindPrimFuncFrame("T.Arg"); + details::Namer::Name(var, name); + frame->args.push_back(var); + return var; +} + +Buffer Arg(String name, Buffer buffer) { + PrimFuncFrame frame = FindPrimFuncFrame("T.Arg"); + details::Namer::Name(buffer, name); + Var handle(buffer->name + "_handle", DataType::Handle()); + frame->args.push_back(handle); + frame->buffer_map.Set(handle, buffer); + return buffer; +} + +void FuncName(String name) { + PrimFuncFrame frame = FindPrimFuncFrame("T.func_name"); + if (frame->name.defined()) { + LOG(FATAL) << "ValueError: Duplicate prim func name, previous one is " << frame->name.value(); + } + frame->name = name; +} + +void FuncAttrs(Map attrs) { + using namespace tvm::tir; + PrimFuncFrame frame = FindPrimFuncFrame("T.func_attr"); + if (frame->attrs.defined()) { + LOG(FATAL) << "ValueError: Duplicate prim func annotations, previous one is " << frame->attrs; + } + frame->attrs = attrs; +} + +tvm::Type FuncRet(tvm::Type ret_type) { + PrimFuncFrame frame = FindPrimFuncFrame("T.ret_type"); + if (frame->ret_type.defined()) { + LOG(FATAL) << "ValueError: Duplicate prim func return type, previous one is " + << frame->ret_type.value(); + } + frame->ret_type = ret_type; + return ret_type; +} + +Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype, Optional data, + Array strides, PrimExpr elem_offset, String storage_scope, int align, + int offset_factor, String buffer_type_str, Array axis_separators) { + Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align, + offset_factor, buffer_type_str, axis_separators); + if (const auto* var = param.as()) { + PrimFuncFrame frame = FindPrimFuncFrame("T.match_buffer"); + Var v = GetRef(var); + for (auto const& arg : frame->args) { + if (arg.same_as(v)) { + frame->buffer_map.Set(v, buffer); + return buffer; + } + } + LOG(FATAL) << "ValueError: Can not bind non-input param to buffer."; + } else if (const auto* buffer_load = param.as()) { + BlockFrame frame = FindBlockFrame("T.match_buffer"); + frame->match_buffers.push_back(tvm::tir::MatchBufferRegion( + buffer, BufferRegionFromLoad(GetRef(buffer_load)))); + } else if (const auto* buffer_region = param.as()) { + BlockFrame frame = FindBlockFrame("T.match_buffer"); + frame->match_buffers.push_back( + tvm::tir::MatchBufferRegion(buffer, GetRef(buffer_region))); + } else { + LOG(FATAL) << "ValueError: Unexpected type for TIR MatchBuffer."; + } + return buffer; +} + +void PreflattenedBuffer(Buffer postflattened_buffer, Array shape, DataType dtype, + Optional data, Array strides, PrimExpr elem_offset, + String storage_scope, int align, int offset_factor, String buffer_type_str, + Array axis_separators) { + PrimFuncFrame frame = FindPrimFuncFrame("T.preflattened_buffer"); + for (auto const& p : frame->buffer_map) { + if (p.second.same_as(postflattened_buffer)) { + String buffer_name(postflattened_buffer->name + "_preflatten"); + Buffer buffer = + BufferDecl(shape, dtype, buffer_name, data.value_or(p.second->data), strides, elem_offset, + storage_scope, align, offset_factor, buffer_type_str, axis_separators); + details::Namer::Name(buffer, buffer_name); + frame->preflattened_buffer_map.Set(p.first, buffer); + return; + } + } + LOG(FATAL) << "ValueError: postflattened buffer " << postflattened_buffer->name + << " does not exist."; +} + +AssertFrame Assert(PrimExpr condition, String message) { + ObjectPtr n = make_object(); + n->condition = condition; + n->message = tvm::tir::StringImm(message); + return AssertFrame(n); +} + +LetFrame Let(Var var, PrimExpr value) { + ObjectPtr n = make_object(); + n->var = var; + n->value = value; + return LetFrame(n); +} + +AllocateFrame Allocate(Array extents, DataType dtype, String storage_scope, + Optional condition, Optional> annotations) { + ObjectPtr n = make_object(); + n->extents = extents; + n->dtype = dtype; + n->storage_scope = storage_scope; + n->condition = condition.value_or(tvm::Bool(true)); + n->annotations = annotations.value_or(Map()); + n->buffer = BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, storage_scope, 0, 0, + "default", NullOpt); + return AllocateFrame(n); +} + +AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype, + Array extents, Map annotations) { + ObjectPtr n = make_object(); + n->dtype = dtype; + n->extents = extents; + n->data = data; + n->annotations = annotations; + n->buffer = + BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, "", 0, 0, "default", NullOpt); + return AllocateConstFrame(n); +} + +LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) { + IterVar iter_var{nullptr}; + + if (Optional opt_frame = IRBuilder::Current()->FindFrame()) { + if (Optional opt_iter_var = opt_frame.value()->env_threads.Get(var)) { + iter_var = opt_iter_var.value(); + } else { + LOG(FATAL) << "ValueError: " << var->name_hint + << " is not an env_thread created using T.env_thread."; + } + } else { + LOG(FATAL) << "LaunchThread can only be used inside a PrimFunc"; + } + ObjectPtr n = make_object(); + if (!iter_var->dom.defined()) { + const_cast(iter_var.get())->dom = Range(0, extent); + } else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) { + LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. " + << iter_var->dom->extent << " vs " << extent; + } + n->iter_var = iter_var; + n->extent = extent; + n->attr_key = iter_var->thread_tag == "vthread" ? "virtual_thread" : "thread_extent"; + return LaunchThreadFrame(n); +} + +RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, + PrimExpr condition) { + ObjectPtr n = make_object(); + n->buffer_slice = buffer_slice; + n->storage_scope = storage_scope; + n->condition = condition; + return RealizeFrame(n); +} + +AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value) { + ObjectPtr n = make_object(); + n->node = node; + n->attr_key = attr_key; + n->value = value; + return AttrFrame(n); +} + +WhileFrame While(PrimExpr condition) { + ObjectPtr n = make_object(); + n->condition = condition; + return WhileFrame(n); +} + +IfFrame If(PrimExpr condition) { + ObjectPtr n = make_object(); + n->condition = condition; + n->then_stmts = NullOpt; + n->else_stmts = NullOpt; + return IfFrame(n); +} + +ThenFrame Then() { + ObjectPtr n = make_object(); + return ThenFrame(n); +} + +ElseFrame Else() { + ObjectPtr n = make_object(); + return ElseFrame(n); +} + +Var EnvThread(String thread_tag) { + IterVar iter_var(Range{nullptr}, Var("", DataType::Int(32)), tvm::tir::IterVarType::kThreadIndex, + thread_tag); + Var var = iter_var->var; + if (Optional opt_frame = IRBuilder::Current()->FindFrame()) { + opt_frame.value()->env_threads.Set(var, iter_var); + } else { + LOG(FATAL) << "EnvThread can only be used inside a PrimFunc"; + } + return var; +} + +void BufferStore(Buffer buffer, PrimExpr value, Array indices) { + AddToParent(tvm::tir::BufferStore(buffer, value, indices)); +} + +void Prefetch(Buffer buffer, Array bounds) { + AddToParent(tvm::tir::Prefetch(buffer, bounds)); +} + +void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); } + +using tvm::script::ir_builder::details::Namer; + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + tvm::tir::BufferNode* buffer = + const_cast(node.as()); + buffer->name = name; + Namer::Name(buffer->data, name); + int n = buffer->strides.size(); + for (int i = 0; i < n; ++i) { + PrimExpr e = buffer->strides[i]; + if (const tvm::tir::VarNode* v = e.as()) { + Namer::Name(GetRef(v), name + "_s" + std::to_string(i)); + } + } + }); + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using namespace tvm::tir; + SizeVarNode* var = const_cast(node.as()); + var->name_hint = name; + }); + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using namespace tvm::tir; + VarNode* var = const_cast(node.as()); + var->name_hint = name; + }); + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using namespace tvm::tir; + IterVarNode* var = const_cast(node.as()); + Namer::Name(var->var, name); + }); + +TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferDecl").set_body_typed(BufferDecl); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Ptr").set_body_typed(Ptr); + +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Init").set_body_typed(Init); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Where").set_body_typed(Where); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Reads").set_body_typed(Reads); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Writes").set_body_typed(Writes); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.BlockAttrs").set_body_typed(BlockAttrs); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.AllocBuffer").set_body_typed(AllocBuffer); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisSpatial").set_body_typed(axis::Spatial); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisReduce").set_body_typed(axis::Reduce); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisScan").set_body_typed(axis::Scan); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisOpaque").set_body_typed(axis::Opaque); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisRemap").set_body_typed(axis::Remap); + +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Serial").set_body_typed(Serial); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Parallel").set_body_typed(Parallel); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Vectorized").set_body_typed(Vectorized); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Unroll").set_body_typed(Unroll); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.ThreadBinding").set_body_typed(ThreadBinding); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Grid").set_body_typed(Grid); + +TVM_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Arg") + .set_body_typed([](String name, ObjectRef obj) -> ObjectRef { + using namespace tvm::tir; + if (const auto* var = obj.as()) { + return Arg(name, GetRef(var)); + } + if (const auto* buffer = obj.as()) { + return Arg(name, GetRef(buffer)); + } + LOG(FATAL) << "ValueError: Unexpected type for TIR Arg: " << obj->GetTypeKey(); + throw; + }); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncName").set_body_typed(FuncName); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncAttrs").set_body_typed(FuncAttrs); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncRet").set_body_typed(FuncRet); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.MatchBuffer").set_body_typed(MatchBuffer); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.PreflattenedBuffer").set_body_typed(PreflattenedBuffer); + +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Assert").set_body_typed(Assert); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Let").set_body_typed(Let); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Allocate").set_body_typed(Allocate); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.AllocateConst").set_body_typed(AllocateConst); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Realize").set_body_typed(Realize); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Attr").set_body_typed(Attr); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.While").set_body_typed(While); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.If").set_body_typed(If); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread").set_body_typed(LaunchThread); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Prefetch").set_body_typed(Prefetch); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate); + +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int8").set_body_typed(Int8); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int16").set_body_typed(Int16); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32").set_body_typed(Int32); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int64").set_body_typed(Int64); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt8").set_body_typed(UInt8); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt16").set_body_typed(UInt16); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt32").set_body_typed(UInt32); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt64").set_body_typed(UInt64); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8").set_body_typed(Float8); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float16").set_body_typed(Float16); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float32").set_body_typed(Float32); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float64").set_body_typed(Float64); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32x4").set_body_typed(Int32x4); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32x8").set_body_typed(Int32x8); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32x16").set_body_typed(Int32x16); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.min") + .set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::min(a, b); }); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.max") + .set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::max(a, b); }); + +} // namespace tir +} // namespace ir_builder +} // namespace script +} // namespace tvm diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h new file mode 100644 index 0000000000..8a712c07dd --- /dev/null +++ b/src/script/ir_builder/tir/utils.h @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_TIR_UTILS_H_ +#define TVM_SCRIPT_IR_BUILDER_TIR_UTILS_H_ + +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace tir { + +inline void AddToParent(tvm::tir::Stmt stmt) { + IRBuilder builder = IRBuilder::Current(); + if (builder->frames.empty()) { + ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; + builder->result = stmt; + } else if (const auto* tir_frame = builder->frames.back().as()) { + GetRef(tir_frame)->stmts.push_back(stmt); + } else { + LOG(FATAL) << "TypeError: Unsupported frame type: " << builder->frames.back(); + } +} + +inline tvm::tir::Stmt AsStmt(const Array& stmt) { + using namespace tvm::tir; + if (stmt.empty()) { + return tvm::tir::Evaluate(0); + } else if (stmt.size() == 1) { + return stmt[0]; + } else { + return SeqStmt(stmt); + } +} + +inline BlockFrame FindBlockFrame(const String& method) { + if (Optional frame = IRBuilder::Current()->GetLastFrame()) { + return frame.value(); + } + LOG(FATAL) << "ValueError: Block frame not find. Please ensure '" << method + << "' is called under T.block()"; + throw; +} + +inline PrimFuncFrame FindPrimFuncFrame(const String& method) { + if (Optional frame = IRBuilder::Current()->GetLastFrame()) { + return frame.value(); + } + LOG(FATAL) << "ValueError: PrimFunc frame not find. Please ensure '" << method + << "' is called under T.prim_func()"; + throw; +} + +inline IfFrame FindIfFrame(const String& method) { + if (Optional frame = IRBuilder::Current()->GetLastFrame()) { + return frame.value(); + } else { + LOG(FATAL) << "ValueError: IfThenElse frame not find. Please ensure '" << method + << "' is called under T.if_()"; + } + throw; +} + +inline tvm::tir::BufferRegion BufferRegionFromLoad(tvm::tir::BufferLoad buffer_load) { + Array ranges; + for (const PrimExpr& index : buffer_load->indices) { + ranges.push_back(Range::FromMinExtent(index, IntImm(index->dtype, 1))); + } + return tvm::tir::BufferRegion(buffer_load->buffer, ranges); +} + +} // namespace tir +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_TIR_IR_BUILDER_UTILS_H_