From c29eb5a04bed48c7d5b366b507792c0d09a23364 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 9 Oct 2020 11:44:09 -0700 Subject: [PATCH] [Diagnostics][Relay][InferType] Refactor InferType to work on whole module, and use new diagnostics. (#6274) * Refactor the type checker to use diagnostics Although this patch is very large and seemingly disjoint the fixes are required to get it working for the entire stack. I started with first changing InferType to use the diagnostics, these weren't yet in the pass manager so this required changes to module and module pass. InferType wasn't actually written correctly as a pass requring refactoring there, then in order to add spans to AST it required turning on AnnotateSpans which in term required changes to the parser, and module to make it possible to use the errors. These changes to parse and module required changes to diagnostics and InferType. Althought seemingly disconnected there are hidden cycles between the components which require simultaneous change in order to remove the old error reporting. A huge change due to this patch is that the module no longer implicitly type checks functions which are added. * Apply suggestions from code review Co-authored-by: Robert Kimball Co-authored-by: Junru Shao * Apply suggestions from code review Co-authored-by: Tristan Konolige * Clean up parser * CR feedback * Apply Bobs suggestions * Fix up Python interface for diagnostics * Fix test_ir_parser and formatting * Fix cpplint * Fix lint * Fix format * More lint * Fix format * Kill dead doc comment * Fix documentation comment * Rebase fixups * Add docs for type.h * Fix parser.cc * Fix unittests * Fix black * Skip previously typechecked functions * fix ACL * Fix numerous issues * Add repr method * Fix issue with Pytest, I am ready to cry * Fix the rest of tests * Kill dead code * Fix dignostic tests * Fix more tests * fix more tests (#11) * Fix diagnostic.py deinit bug * Fix deinit issue * Format * Tweak disabling of override * Format * Fix BYOC * Fix TensorArray stuff * Fix PyTorch * Format * Format Co-authored-by: Robert Kimball Co-authored-by: Junru Shao Co-authored-by: Tristan Konolige Co-authored-by: Cody Yu Co-authored-by: Zhi <5145158+zhiics@users.noreply.github.com> --- CMakeLists.txt | 1 + Makefile | 5 + .../install/ubuntu_install_arm_compute_lib.sh | 0 .../ubuntu_install_ethosn_driver_stack.sh | 1 - include/tvm/ir/diagnostic.h | 262 +++++ include/tvm/ir/module.h | 12 +- include/tvm/ir/span.h | 2 +- include/tvm/ir/transform.h | 14 +- include/tvm/ir/type.h | 17 +- include/tvm/ir/type_relation.h | 13 +- include/tvm/parser/parser.h | 3 +- include/tvm/parser/source_map.h | 69 +- include/tvm/relay/analysis.h | 9 +- include/tvm/relay/transform.h | 12 - include/tvm/runtime/container.h | 14 + include/tvm/runtime/object.h | 2 +- include/tvm/tir/op.h | 1 + pyproject.toml | 2 +- python/tvm/__init__.py | 9 +- python/tvm/_ffi/runtime_ctypes.py | 2 +- python/tvm/ir/__init__.py | 1 + python/tvm/ir/diagnostics/__init__.py | 118 +++ python/tvm/ir/diagnostics/_ffi_api.py | 21 + python/tvm/ir/module.py | 23 +- python/tvm/parser/__init__.py | 13 + .../relay/backend/graph_runtime_factory.py | 4 +- python/tvm/relay/backend/interpreter.py | 4 +- python/tvm/relay/build_module.py | 2 + python/tvm/relay/dataflow_pattern/__init__.py | 2 +- python/tvm/relay/frontend/common.py | 3 +- python/tvm/relay/frontend/pytorch.py | 23 +- python/tvm/relay/frontend/tensorflow.py | 73 +- .../tvm/relay/op/contrib/arm_compute_lib.py | 1 + python/tvm/relay/prelude.py | 599 ++++++----- .../relay/quantize/_partition_conversions.py | 8 +- python/tvm/relay/quantize/quantize.py | 2 +- python/tvm/relay/std/nat.rly | 78 ++ python/tvm/relay/std/prelude.rly | 5 +- python/tvm/relay/testing/__init__.py | 3 +- python/tvm/relay/testing/nat.py | 149 +-- python/tvm/relay/testing/py_converter.py | 27 +- python/tvm/relay/transform/memory_alloc.py | 2 + python/tvm/relay/transform/transform.py | 14 + src/driver/driver_api.cc | 22 +- src/ir/diagnostic.cc | 299 ++++++ src/ir/module.cc | 114 ++- src/ir/span.cc | 6 +- src/ir/transform.cc | 26 +- src/ir/type.cc | 18 +- src/parser/diagnostic.h | 179 ---- src/parser/parser.cc | 807 +++++++++------ src/parser/source_map.cc | 90 +- src/parser/span_check.cc | 109 ++ src/parser/span_check.h | 80 ++ src/parser/token.h | 15 + src/parser/tokenizer.h | 64 +- src/printer/relay_text_printer.cc | 15 +- src/relay/analysis/kind_check.cc | 53 +- src/relay/analysis/type_solver.cc | 107 +- src/relay/analysis/type_solver.h | 23 +- src/relay/analysis/well_formed.cc | 43 +- src/relay/backend/build_module.cc | 5 + src/relay/backend/compile_engine.h | 2 +- src/relay/backend/graph_runtime_codegen.cc | 4 +- src/relay/backend/interpreter.cc | 4 +- src/relay/backend/vm/compiler.cc | 1 + src/relay/ir/transform.cc | 32 +- src/relay/op/nn/convolution.h | 90 +- src/relay/transforms/partition_graph.cc | 5 + src/relay/transforms/type_infer.cc | 281 +++--- .../contrib/random/mt_random_engine.cc | 2 +- src/tir/transforms/split_host_device.cc | 3 +- tests/cpp/relay_build_module_test.cc | 2 + .../contrib/test_ethosn/infrastructure.py | 19 +- .../contrib/test_ethosn/test_reshape.py | 1 + tests/python/frontend/pytorch/test_lstm.py | 5 +- .../frontend/tensorflow/test_forward.py | 2 + tests/python/relay/test_adt.py | 938 +++--------------- .../test_analysis_get_calibration_data.py | 3 + tests/python/relay/test_any.py | 7 +- .../relay/test_backend_graph_runtime.py | 2 + .../python/relay/test_backend_interpreter.py | 9 +- tests/python/relay/test_error_reporting.py | 68 -- tests/python/relay/test_expr_functor.py | 3 +- tests/python/relay/test_ir_module.py | 20 +- tests/python/relay/test_ir_parser.py | 47 +- .../relay/test_ir_structural_equal_hash.py | 35 +- tests/python/relay/test_ir_text_printer.py | 15 +- tests/python/relay/test_ir_well_formed.py | 7 +- tests/python/relay/test_json_runtime.py | 24 + tests/python/relay/test_op_qnn_add.py | 7 + tests/python/relay/test_op_qnn_concatenate.py | 4 + tests/python/relay/test_op_qnn_dense.py | 1 + tests/python/relay/test_op_qnn_mul.py | 5 + tests/python/relay/test_op_qnn_subtract.py | 1 + .../python/relay/test_pass_alter_op_layout.py | 5 + .../python/relay/test_pass_annotate_target.py | 2 + ...test_pass_combine_parallel_batch_matmul.py | 1 + .../test_pass_combine_parallel_conv2d.py | 1 + .../relay/test_pass_combine_parallel_dense.py | 1 + tests/python/relay/test_pass_eta_expand.py | 4 +- tests/python/relay/test_pass_fold_constant.py | 4 +- tests/python/relay/test_pass_fuse_ops.py | 17 +- tests/python/relay/test_pass_gradient.py | 9 +- tests/python/relay/test_pass_inline.py | 5 +- .../relay/test_pass_lazy_gradient_init.py | 20 +- tests/python/relay/test_pass_mac_count.py | 1 + tests/python/relay/test_pass_manager.py | 5 +- .../relay/test_pass_merge_compiler_regions.py | 2 + tests/python/relay/test_pass_partial_eval.py | 91 +- .../python/relay/test_pass_partition_graph.py | 43 +- tests/python/relay/test_pass_qnn_legalize.py | 4 + .../test_pass_remove_unused_functions.py | 26 +- .../relay/test_pass_simplify_inference.py | 4 +- .../relay/test_pass_to_a_normal_form.py | 11 +- .../test_pass_to_basic_block_normal_form.py | 20 +- tests/python/relay/test_pass_to_cps.py | 9 +- .../python/relay/test_pass_unmatched_cases.py | 82 +- tests/python/relay/test_pass_vars.py | 13 +- tests/python/relay/test_py_converter.py | 54 +- tests/python/relay/test_tensor_array.py | 785 +++++++++++++++ tests/python/relay/test_type_infer.py | 142 +-- tests/python/relay/test_vm.py | 65 +- tests/python/relay/test_vm_serialization.py | 19 +- tests/python/relay/util/assert_diagnostic.py | 68 ++ .../python/unittest/test_custom_datatypes.py | 1 + .../test_tir_transform_narrow_datatype.py | 2 + tutorials/dev/bring_your_own_datatypes.py | 3 + 128 files changed, 4274 insertions(+), 2594 deletions(-) mode change 100644 => 100755 docker/install/ubuntu_install_arm_compute_lib.sh create mode 100644 include/tvm/ir/diagnostic.h create mode 100644 python/tvm/ir/diagnostics/__init__.py create mode 100644 python/tvm/ir/diagnostics/_ffi_api.py create mode 100644 python/tvm/relay/std/nat.rly create mode 100644 src/ir/diagnostic.cc delete mode 100644 src/parser/diagnostic.h create mode 100644 src/parser/span_check.cc create mode 100644 src/parser/span_check.h delete mode 100644 tests/python/relay/test_error_reporting.py create mode 100644 tests/python/relay/test_tensor_array.py create mode 100644 tests/python/relay/util/assert_diagnostic.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 2818262df0727..e24bbeb5acd84 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -99,6 +99,7 @@ if(MSVC) add_definitions(-D_CRT_SECURE_NO_WARNINGS) add_definitions(-D_SCL_SECURE_NO_WARNINGS) add_definitions(-D_ENABLE_EXTENDED_ALIGNED_STORAGE) + add_definitions(-DNOMINMAX) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /bigobj") diff --git a/Makefile b/Makefile index 9823c5c2b5688..0896246a92eed 100644 --- a/Makefile +++ b/Makefile @@ -135,6 +135,11 @@ jvminstall: mvn install -P$(JVM_PKG_PROFILE) -Dcxx="$(CXX)" \ -Dcflags="$(PKG_CFLAGS)" -Dldflags="$(PKG_LDFLAGS)" \ -Dcurrent_libdir="$(ROOTDIR)/$(OUTPUTDIR)" $(JVM_TEST_ARGS)) +format: + ./tests/lint/git-clang-format.sh -i origin/master + black . + cd rust; which cargo && cargo fmt --all; cd .. + # clean rule clean: diff --git a/docker/install/ubuntu_install_arm_compute_lib.sh b/docker/install/ubuntu_install_arm_compute_lib.sh old mode 100644 new mode 100755 diff --git a/docker/install/ubuntu_install_ethosn_driver_stack.sh b/docker/install/ubuntu_install_ethosn_driver_stack.sh index ecf25a6281990..15b93bbdf901e 100755 --- a/docker/install/ubuntu_install_ethosn_driver_stack.sh +++ b/docker/install/ubuntu_install_ethosn_driver_stack.sh @@ -57,4 +57,3 @@ git checkout "$repo_revision" cd "driver" scons install_prefix="$install_path" install - diff --git a/include/tvm/ir/diagnostic.h b/include/tvm/ir/diagnostic.h new file mode 100644 index 0000000000000..6b9807487bae7 --- /dev/null +++ b/include/tvm/ir/diagnostic.h @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file diagnostic.h + * \brief A new diagnostic interface for TVM error reporting. + * + * A prototype of the new diagnostic reporting interface for TVM. + * + * Eventually we hope to promote this file to the top-level and + * replace the existing errors.h. + */ + +#ifndef TVM_IR_DIAGNOSTIC_H_ +#define TVM_IR_DIAGNOSTIC_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { + +using tvm::parser::SourceMap; +using tvm::runtime::TypedPackedFunc; + +extern const char* kTVM_INTERNAL_ERROR_MESSAGE; + +#define ICHECK_INDENT " " + +#define ICHECK_BINARY_OP(name, op, x, y) \ + if (dmlc::LogCheckError _check_err = dmlc::LogCheck##name(x, y)) \ + dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \ + << kTVM_INTERNAL_ERROR_MESSAGE << std::endl \ + << ICHECK_INDENT << "Check failed: " << #x " " #op " " #y << *(_check_err.str) << ": " + +#define ICHECK(x) \ + if (!(x)) \ + dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \ + << kTVM_INTERNAL_ERROR_MESSAGE << ICHECK_INDENT << "Check failed: " #x << " == false: " + +#define ICHECK_LT(x, y) ICHECK_BINARY_OP(_LT, <, x, y) +#define ICHECK_GT(x, y) ICHECK_BINARY_OP(_GT, >, x, y) +#define ICHECK_LE(x, y) ICHECK_BINARY_OP(_LE, <=, x, y) +#define ICHECK_GE(x, y) ICHECK_BINARY_OP(_GE, >=, x, y) +#define ICHECK_EQ(x, y) ICHECK_BINARY_OP(_EQ, ==, x, y) +#define ICHECK_NE(x, y) ICHECK_BINARY_OP(_NE, !=, x, y) +#define ICHECK_NOTNULL(x) \ + ((x) == nullptr ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \ + << kTVM_INTERNAL_ERROR_MESSAGE << __INDENT << "Check not null: " #x \ + << ' ', \ + (x) : (x)) // NOLINT(*) + +/*! \brief The diagnostic level, controls the printing of the message. */ +enum class DiagnosticLevel : int { + kBug = 10, + kError = 20, + kWarning = 30, + kNote = 40, + kHelp = 50, +}; + +class DiagnosticBuilder; + +/*! \brief A compiler diagnostic. */ +class Diagnostic; + +/*! \brief A compiler diagnostic message. */ +class DiagnosticNode : public Object { + public: + /*! \brief The level. */ + DiagnosticLevel level; + /*! \brief The span at which to report an error. */ + Span span; + /*! \brief The diagnostic message. */ + String message; + + // override attr visitor + void VisitAttrs(AttrVisitor* v) { + v->Visit("level", &level); + v->Visit("span", &span); + v->Visit("message", &message); + } + + bool SEqualReduce(const DiagnosticNode* other, SEqualReducer equal) const { + return equal(this->level, other->level) && equal(this->span, other->span) && + equal(this->message, other->message); + } + + static constexpr const char* _type_key = "Diagnostic"; + TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticNode, Object); +}; + +class Diagnostic : public ObjectRef { + public: + TVM_DLL Diagnostic(DiagnosticLevel level, Span span, const std::string& message); + + static DiagnosticBuilder Bug(Span span); + static DiagnosticBuilder Error(Span span); + static DiagnosticBuilder Warning(Span span); + static DiagnosticBuilder Note(Span span); + static DiagnosticBuilder Help(Span span); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Diagnostic, ObjectRef, DiagnosticNode); +}; + +/*! + * \brief A wrapper around std::stringstream to build a diagnostic. + */ +class DiagnosticBuilder { + public: + /*! \brief The level. */ + DiagnosticLevel level; + + /*! \brief The source name. */ + SourceName source_name; + + /*! \brief The span of the diagnostic. */ + Span span; + + template + DiagnosticBuilder& operator<<(const T& val) { // NOLINT(*) + stream_ << val; + return *this; + } + + DiagnosticBuilder() : level(DiagnosticLevel::kError), source_name(), span(Span()) {} + + DiagnosticBuilder(const DiagnosticBuilder& builder) + : level(builder.level), source_name(builder.source_name), span(builder.span) {} + + DiagnosticBuilder(DiagnosticLevel level, Span span) : level(level), span(span) {} + + operator Diagnostic() { return Diagnostic(this->level, this->span, this->stream_.str()); } + + private: + std::stringstream stream_; + friend class Diagnostic; +}; + +/*! + * \brief A diagnostic context for recording errors against a source file. + */ +class DiagnosticContext; + +/*! \brief Display diagnostics in a given display format. + * + * A diagnostic renderer is responsible for converting the + * raw diagnostics into consumable output. + * + * For example the terminal renderer will render a sequence + * of compiler diagnostics to std::out and std::err in + * a human readable form. + */ +class DiagnosticRendererNode : public Object { + public: + TypedPackedFunc renderer; + + // override attr visitor + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const char* _type_key = "DiagnosticRenderer"; + TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticRendererNode, Object); +}; + +class DiagnosticRenderer : public ObjectRef { + public: + TVM_DLL DiagnosticRenderer(TypedPackedFunc render); + TVM_DLL DiagnosticRenderer() + : DiagnosticRenderer(TypedPackedFunc()) {} + + void Render(const DiagnosticContext& ctx); + + DiagnosticRendererNode* operator->() { + CHECK(get() != nullptr); + return static_cast(get_mutable()); + } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DiagnosticRenderer, ObjectRef, DiagnosticRendererNode); +}; + +class DiagnosticContextNode : public Object { + public: + /*! \brief The Module to report against. */ + IRModule module; + + /*! \brief The set of diagnostics to report. */ + Array diagnostics; + + /*! \brief The renderer set for the context. */ + DiagnosticRenderer renderer; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("module", &module); + v->Visit("diagnostics", &diagnostics); + } + + bool SEqualReduce(const DiagnosticContextNode* other, SEqualReducer equal) const { + return equal(module, other->module) && equal(diagnostics, other->diagnostics); + } + + static constexpr const char* _type_key = "DiagnosticContext"; + TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticContextNode, Object); +}; + +class DiagnosticContext : public ObjectRef { + public: + TVM_DLL DiagnosticContext(const IRModule& module, const DiagnosticRenderer& renderer); + TVM_DLL static DiagnosticContext Default(const IRModule& source_map); + + /*! \brief Emit a diagnostic. + * \param diagnostic The diagnostic to emit. + */ + void Emit(const Diagnostic& diagnostic); + + /*! \brief Emit a diagnostic and then immediately attempt to render all errors. + * + * \param diagnostic The diagnostic to emit. + * + * Note: this will raise an exception if you would like to instead continue execution + * use the Emit method instead. + */ + void EmitFatal(const Diagnostic& diagnostic); + + /*! \brief Render the errors and raise a DiagnosticError exception. */ + void Render(); + + DiagnosticContextNode* operator->() { + CHECK(get() != nullptr); + return static_cast(get_mutable()); + } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DiagnosticContext, ObjectRef, DiagnosticContextNode); +}; + +DiagnosticRenderer TerminalRenderer(std::ostream& ostream); + +} // namespace tvm +#endif // TVM_IR_DIAGNOSTIC_H_ diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 7af84b687f5fd..b3f8438f6ec93 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -53,14 +54,17 @@ class IRModuleNode : public Object { Map functions; /*! \brief A map from global type vars to ADT type data. */ Map type_definitions; + /*! \brief The source map for the module. */ + parser::SourceMap source_map; - IRModuleNode() {} + IRModuleNode() : source_map() {} void VisitAttrs(AttrVisitor* v) { v->Visit("functions", &functions); v->Visit("type_definitions", &type_definitions); v->Visit("global_var_map_", &global_var_map_); v->Visit("global_type_var_map_", &global_type_var_map_); + v->Visit("source_map", &source_map); } TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const; @@ -280,12 +284,14 @@ class IRModule : public ObjectRef { * \param functions Functions in the module. * \param type_definitions Type definitions in the module. * \param import_set Set of imported files in the module + * \param map The module source map. */ TVM_DLL explicit IRModule(Map functions, Map type_definitions = {}, - std::unordered_set import_set = {}); + std::unordered_set import_set = {}, parser::SourceMap map = {}); + /*! \brief default constructor */ - IRModule() : IRModule(Map()) {} + IRModule() : IRModule(Map({})) {} /*! * \brief constructor * \param n The object pointer. diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h index 95a1acb9412db..6a7f3f3190a08 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/span.h @@ -114,7 +114,7 @@ class Span : public ObjectRef { TVM_DLL Span(SourceName source_name, int line, int end_line, int column, int end_column); /*! \brief Merge two spans into one which captures the combined regions. */ - TVM_DLL Span Merge(const Span& other); + TVM_DLL Span Merge(const Span& other) const; TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode); }; diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 5bfb51adb0ac0..2bbf28311b30b 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -56,6 +56,7 @@ #ifndef TVM_IR_TRANSFORM_H_ #define TVM_IR_TRANSFORM_H_ +#include #include #include #include @@ -84,11 +85,6 @@ using TraceFunc = */ class PassContextNode : public Object { public: - /*! - * \brief The error reporter used to notify users why an optimization fails. - */ - ErrorReporter err_reporter; - /*! \brief The default optimization level. */ int opt_level{2}; @@ -96,11 +92,12 @@ class PassContextNode : public Object { Array required_pass; /*! \brief The list of disabled passes. */ Array disabled_pass; - /*! \brief Trace function to be invoked before and after each pass. */ - TraceFunc trace_func; - + /*! \brief The diagnostic context. */ + mutable Optional diag_ctx; /*! \brief Pass specific configurations. */ Map config; + /*! \brief Trace function to be invoked before and after each pass. */ + TraceFunc trace_func; PassContextNode() = default; @@ -139,6 +136,7 @@ class PassContextNode : public Object { v->Visit("required_pass", &required_pass); v->Visit("disabled_pass", &disabled_pass); v->Visit("config", &config); + v->Visit("diag_ctx", &diag_ctx); } static constexpr const char* _type_key = "transform.PassContext"; diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 65b454f08b52f..19b1ad0a0d835 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -250,8 +250,9 @@ class TypeVar : public Type { * \brief Constructor * \param name_hint The name of the type var. * \param kind The kind of the type var. + * \param span The span information. */ - TVM_DLL TypeVar(String name_hint, TypeKind kind); + TVM_DLL TypeVar(String name_hint, TypeKind kind, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode); }; @@ -300,8 +301,9 @@ class GlobalTypeVar : public Type { * \brief Constructor * \param name_hint The name of the type var. * \param kind The kind of the type var. + * \param span The span of the type. */ - TVM_DLL GlobalTypeVar(String name_hint, TypeKind kind); + TVM_DLL GlobalTypeVar(String name_hint, TypeKind kind, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode); }; @@ -341,8 +343,9 @@ class TupleType : public Type { /*! * \brief Constructor * \param fields Fields in the tuple. + * \param span The span of the type. */ - TVM_DLL explicit TupleType(Array fields); + TVM_DLL explicit TupleType(Array fields, Span span = Span()); /*! * \brief Create an empty tuple type that constains nothing. @@ -448,10 +451,11 @@ class FuncType : public Type { * \param ret_type The type of the return value. * \param type_params The type parameters. * \param type_constraints The type constraints. + * \param span The span information. * \sa FuncTypeNode for more docs about these fields. */ TVM_DLL FuncType(Array arg_types, Type ret_type, Array type_params, - Array type_constraints); + Array type_constraints, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode); }; @@ -495,8 +499,9 @@ class IncompleteType : public Type { /*! * \brief Constructor. * \param kind kind of the type. + * \param span The span information. */ - TVM_DLL explicit IncompleteType(TypeKind kind); + TVM_DLL explicit IncompleteType(TypeKind kind, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode); }; @@ -536,7 +541,7 @@ class RelayRefTypeNode : public TypeNode { */ class RelayRefType : public Type { public: - TVM_DLL explicit RelayRefType(Type value); + TVM_DLL explicit RelayRefType(Type value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(RelayRefType, Type, RelayRefTypeNode); }; } // namespace tvm diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h index dbd241afa4580..83323b01e419d 100644 --- a/include/tvm/ir/type_relation.h +++ b/include/tvm/ir/type_relation.h @@ -25,6 +25,7 @@ #define TVM_IR_TYPE_RELATION_H_ #include +#include #include #include #include @@ -100,7 +101,7 @@ class TypeReporterNode : public Object { * \brief assert shape expression comparison. * \note Use assert only if any of the condition input is symbolic. * \param cond The condition of operation. - * \return false if assertation can be proven to have failed + * \return false if assertion can be proven to have failed * true if solver can still proceed. */ TVM_DLL virtual bool Assert(const PrimExpr& cond) = 0; @@ -108,16 +109,20 @@ class TypeReporterNode : public Object { * \brief assert shape expression equals each other. * \param lhs The left operand. * \param rhs The right operand. - * \return false if assertation can be proven to have failed + * \return false if assertion can be proven to have failed * true if solver can still proceed. */ TVM_DLL virtual bool AssertEQ(const PrimExpr& lhs, const PrimExpr& rhs) = 0; /*! * \brief Set the location at which to report unification errors. - * \param ref The program node to report the error. + * \param span The span at which to report the error. */ - TVM_DLL virtual void SetLocation(const ObjectRef& ref) = 0; + TVM_DLL virtual void SetSpan(const Span& span) = 0; + + TVM_DLL virtual Span GetSpan() = 0; + + TVM_DLL virtual DiagnosticContext GetDiagCtx() = 0; /*! * \brief Retrieve the current global module. diff --git a/include/tvm/parser/parser.h b/include/tvm/parser/parser.h index 5c1239b1f59e9..7673eec2a337f 100644 --- a/include/tvm/parser/parser.h +++ b/include/tvm/parser/parser.h @@ -32,7 +32,8 @@ namespace tvm { namespace parser { -IRModule ParseModule(std::string file_name, std::string file_content); +IRModule ParseModule(std::string file_name, std::string file_content, + Optional init_module = Optional()); } // namespace parser } // namespace tvm diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h index cf926665e2185..5595574265c6b 100644 --- a/include/tvm/parser/source_map.h +++ b/include/tvm/parser/source_map.h @@ -37,39 +37,38 @@ namespace parser { /*! \brief A program source in any language. * - * Could represent the source from an ML framework or the internal - * source of a TVM program. + * Could represent the source from an ML framework or a source + * representing a tvm::IRModule. */ -struct Source { +class Source; + +class SourceNode : public Object { + public: /*! \brief The source name. */ SourceName source_name; /*! \brief The raw source. */ - std::string source; + String source; + /*! \brief A mapping of line breaks into the raw source. */ std::vector> line_map; - /*! \brief An empty source. */ - Source() : source_name(), source(), line_map() {} - - /*! \brief Construct a source from a string. */ - TVM_DLL explicit Source(const SourceName& src_name, const std::string& source); - - TVM_DLL Source(const Source& source) - : source_name(source.source_name), source(source.source), line_map(source.line_map) {} - - /*! \brief Generate an error message at a specific line and column with the - * annotated message. - * - * The error is written directly to the `out` std::ostream. - * - * \param out The output ostream. - * \param span The span to report the error at. - * \param msg The message to attach. - * - */ - // TODO(@jroesch): replace the ostream with an interface for rendering errors. - TVM_DLL void ReportAt(std::ostream& out, const Span& span, const std::string& msg) const; + // override attr visitor + void VisitAttrs(AttrVisitor* v) { + v->Visit("source_name", &source_name); + v->Visit("source", &source); + } + + static constexpr const char* _type_key = "Source"; + TVM_DECLARE_FINAL_OBJECT_INFO(SourceNode, Object); +}; + +class Source : public ObjectRef { + public: + TVM_DLL Source(SourceName src_name, std::string source); + TVM_DLL tvm::String GetLine(int line); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Source, ObjectRef, SourceNode); }; /*! @@ -82,7 +81,7 @@ class SourceMap; class SourceMapNode : public Object { public: /*! \brief The source mapping. */ - Map source_map; + Map source_map; // override attr visitor void VisitAttrs(AttrVisitor* v) { v->Visit("source_map", &source_map); } @@ -97,11 +96,23 @@ class SourceMapNode : public Object { class SourceMap : public ObjectRef { public: - TVM_DLL SourceMap(Map source_map); + TVM_DLL SourceMap(Map source_map); + + TVM_DLL SourceMap(std::initializer_list> source_map) + : SourceMap(Map(source_map)) {} + + TVM_DLL SourceMap() : SourceMap({}) {} - TVM_DLL static SourceMap* Get(); + TVM_DLL static SourceMap Global(); + + void Add(const Source& source); + + SourceMapNode* operator->() { + CHECK(get() != nullptr); + return static_cast(get_mutable()); + } - TVM_DEFINE_OBJECT_REF_METHODS(SourceMap, ObjectRef, SourceMapNode); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SourceMap, ObjectRef, SourceMapNode); }; } // namespace parser diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index b2e7c500edddc..26e5a65ddb5e4 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -24,6 +24,7 @@ #ifndef TVM_RELAY_ANALYSIS_H_ #define TVM_RELAY_ANALYSIS_H_ +#include #include #include #include @@ -49,10 +50,12 @@ namespace relay { * * \param t The type to check. * \param mod The global module. + * \param diag_ctx The Diagnostic context. * * \return The kind of the passed type. */ -TVM_DLL Kind KindCheck(const Type& t, const IRModule& mod); +TVM_DLL Kind KindCheck(const Type& t, const IRModule& mod, + Optional diag_ctx = Optional()); /*! * \brief Check whether an expression is constant. @@ -84,10 +87,12 @@ TVM_DLL bool BasicBlockNormalFormCheck(const Expr& e); * although x is not shadowed. * * \param expr the expression to check. + * \param diag_ctx the diagnostic context * * \return true iff all Var in expr is bound at most once. */ -TVM_DLL bool WellFormed(const Expr& expr); +TVM_DLL bool WellFormed(const Expr& expr, + Optional diag_ctx = Optional()); /*! * \brief Get all bound variables from expression expr. diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index faa2698fdcbcd..cbd6a88e584e8 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -418,18 +418,6 @@ TVM_DLL Pass SimplifyExpr(); */ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); -/*! - * \brief Infer the type of a function as if it is mapped to var in the mod. - * - * \param f the function. - * \param mod The module used for referencing global functions. - * \param var The global variable corresponding to the function. - * - * \return A type checked Function with its checked_type field populated. - * \note this function mutates mod and is not thread-safe. - */ -TVM_DLL Function InferType(const Function& f, const IRModule& mod, const GlobalVar& var); - /*! * \brief Apply rewrite rules to rewrite the expr in post DFS order. This * function is used as a helper function to rewrtie an expression in a pass. diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 24e268494ae61..7778c5d8424c7 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -1281,6 +1281,20 @@ class String : public ObjectRef { */ bool empty() const { return size() == 0; } + /*! + * \brief Read an element. + * \param pos The position at which to read the character. + * + * \return The char at position + */ + char at(size_t pos) const { + if (pos < size()) { + return data()[pos]; + } else { + throw std::out_of_range("tvm::String index out of bounds"); + } + } + /*! * \brief Return the data pointer * diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index be613f1549731..e6ca832c70c21 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -633,7 +633,7 @@ struct ObjectPtrEqual { * \param ParentType The name of the ParentType */ #define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ - static_assert(!ParentType::_type_final, "ParentObj maked as final"); \ + static_assert(!ParentType::_type_final, "ParentObj marked as final"); \ static uint32_t RuntimeTypeIndex() { \ static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \ TypeName::_type_child_slots < ParentType::_type_child_slots, \ diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 9e53e97e5d7ff..867acfc97aabd 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -28,6 +28,7 @@ #ifndef TVM_TIR_OP_H_ #define TVM_TIR_OP_H_ +#include #include #include #include diff --git a/pyproject.toml b/pyproject.toml index 8cf53c927c3d9..5cca711ddbe6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ exclude = ''' | \.tvm_test_data | \.vscode | \.venv - | 3rdparty\/ + | 3rdparty | build\/ | cmake\/ | conda\/ diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index c3c5e3d357e50..569e8f042486c 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -18,6 +18,7 @@ """TVM: Open Deep Learning Compiler Stack.""" import multiprocessing import sys +import os import traceback # top-level alias @@ -72,7 +73,13 @@ def tvm_wrap_excepthook(exception_hook): def wrapper(exctype, value, trbk): """Clean subprocesses when TVM is interrupted.""" - exception_hook(exctype, value, trbk) + in_pytest = "PYTEST_CURRENT_TEST" in os.environ + + if exctype is error.DiagnosticError and not in_pytest: + pass + else: + exception_hook(exctype, value, trbk) + if hasattr(multiprocessing, "active_children"): # pylint: disable=not-callable for p in multiprocessing.active_children(): diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 58be070f9ce10..3a874ebb1208e 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -193,7 +193,7 @@ class TVMContext(ctypes.Structure): def __init__(self, device_type, device_id): super(TVMContext, self).__init__() - self.device_type = device_type + self.device_type = int(device_type) self.device_id = device_id def _GetDeviceAttr(self, device_type, device_id, attr_id): diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index f1d1d502a27ee..e35077bb5aab4 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -31,3 +31,4 @@ from .container import Array, Map from . import transform +from . import diagnostics diff --git a/python/tvm/ir/diagnostics/__init__.py b/python/tvm/ir/diagnostics/__init__.py new file mode 100644 index 0000000000000..6503743aaa51b --- /dev/null +++ b/python/tvm/ir/diagnostics/__init__.py @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +""" +The diagnostic interface to TVM, used for reporting and rendering +diagnostic information by the compiler. This module exposes +three key abstractions: a Diagnostic, the DiagnosticContext, +and the DiagnosticRenderer. +""" +import enum +import tvm._ffi +from . import _ffi_api +from ... import get_global_func, register_func, Object + + +def get_renderer(): + """ + Get the the diagnostic renderer. + + Returns + ------- + renderer: DiagnosticRenderer + """ + return _ffi_api.GetRenderer() + + +def override_renderer(render_func): + """ + Sets a custom renderer for diagnostics. + + Params + ------ + render_func: Option[Callable[[DiagnosticContext], None]] + If the render_func is None it will remove the current custom renderer + and return to default behavior. + """ + if render_func: + + def _render_factory(): + return DiagnosticRenderer(render_func) + + register_func("diagnostics.OverrideRenderer", _render_factory, override=True) + else: + _ffi_api.ClearRenderer() + + +class DiagnosticLevel(enum.IntEnum): + """The diagnostic level, see diagnostic.h for more details.""" + + BUG = 10 + ERROR = 20 + WARNING = 30 + NOTE = 40 + HELP = 50 + + +@tvm._ffi.register_object("Diagnostic") +class Diagnostic(Object): + """A single diagnostic object from TVM.""" + + def __init__(self, level, span, message): + self.__init_handle_by_constructor__(_ffi_api.Diagnostic, level, span, message) + + +@tvm._ffi.register_object("DiagnosticRenderer") +class DiagnosticRenderer(Object): + """ + A diagnostic renderer, which given a diagnostic context produces a "rendered" + form of the diagnostics for either human or computer consumption. + """ + + def __init__(self, render_func): + self.__init_handle_by_constructor__(_ffi_api.DiagnosticRenderer, render_func) + + def render(self, ctx): + """ + Render the provided context. + + Params + ------ + ctx: DiagnosticContext + The diagnostic context to render. + """ + return _ffi_api.DiagnosticRendererRender(self, ctx) + + +# Register the diagnostic context. +@tvm._ffi.register_object("DiagnosticContext") +class DiagnosticContext(Object): + """ + A diagnostic context which records active errors + and contains a renderer. + """ + + def __init__(self, module, renderer): + self.__init_handle_by_constructor__(_ffi_api.DiagnosticContext, module, renderer) + + def emit(self, diagnostic): + """Emit a diagnostic.""" + _ffi_api.Emit(self, diagnostic) + + def render(self): + """Render the current context using its renderer member.""" + _ffi_api.DiagnosticContextRender(self) diff --git a/python/tvm/ir/diagnostics/_ffi_api.py b/python/tvm/ir/diagnostics/_ffi_api.py new file mode 100644 index 0000000000000..430fd17f4d8a0 --- /dev/null +++ b/python/tvm/ir/diagnostics/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI for TVM diagnostics.""" +import tvm._ffi + + +tvm._ffi._init_api("diagnostics", __name__) diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 378991a96b6a3..352f8aaf04b60 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -72,9 +72,9 @@ def __setitem__(self, var, val): val: Union[Function, Type] The value. """ - return self._add(var, val) + return self._add(var, val, True) - def _add(self, var, val, update=False): + def _add(self, var, val, update=True): if isinstance(val, _expr.RelayExpr): if isinstance(var, string_types): if _ffi_api.Module_ContainGlobalVar(self, var): @@ -116,7 +116,8 @@ def update(self, other): The module to merge into the current Module. """ if isinstance(other, dict): - other = Module(other) + other = IRModule(other) + return _ffi_api.Module_Update(self, other) def update_func(self, var, func): @@ -210,6 +211,11 @@ def get_constructor(self, tag): """ return _ffi_api.Module_LookupTag(self, tag) + def get_type(self, name): + ty_var = self.get_global_type_var(name) + ty_data = self.type_definitions[ty_var] + return tuple([ty_var] + list(ty_data.constructors)) + @staticmethod def from_expr(expr, functions=None, type_defs=None): """Construct a module from a standalone expression. @@ -240,4 +246,13 @@ def _import(self, file_to_import): return _ffi_api.Module_Import(self, file_to_import) def import_from_std(self, file_to_import): - return _ffi_api.Module_ImportFromStd(self, file_to_import) + # TODO(@jroesch): clean up prelude + _ffi_api.Module_ImportFromStd(self, file_to_import) + return tvm.relay.transform.InferType()(self) + + def __str__(self): + # TODO(jroesch): why does this hang sometimes? + return self.astext() + + def __repr__(self): + return self.astext() diff --git a/python/tvm/parser/__init__.py b/python/tvm/parser/__init__.py index 11892339d2d75..60fcddb17f08b 100644 --- a/python/tvm/parser/__init__.py +++ b/python/tvm/parser/__init__.py @@ -14,10 +14,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name """The under development unified IR parsing infrastructure.""" +from .. import _ffi, Object from . import _ffi_api +@_ffi.register_object("SourceMap") +class SourceMap(Object): + def add(self, name, content): + return _ffi.get_global_func("SourceMapAdd")(self, name, content) + + def parse(source, source_name="from_string"): return _ffi_api.ParseModule(source_name, source) @@ -28,3 +36,8 @@ def parse_expr(source): def fromtext(source, source_name="from_string"): return parse(source, source_name) + + +def SpanCheck(): + """A debugging utility for reporting missing span information.""" + return _ffi_api.SpanCheck() diff --git a/python/tvm/relay/backend/graph_runtime_factory.py b/python/tvm/relay/backend/graph_runtime_factory.py index 681a8426dc4ce..4c6ac47b71b4c 100644 --- a/python/tvm/relay/backend/graph_runtime_factory.py +++ b/python/tvm/relay/backend/graph_runtime_factory.py @@ -73,8 +73,8 @@ def __getitem__(self, item): def __iter__(self): warnings.warn( - "legacy graph runtime behaviour of producing json / lib / params will be " - " removed in the next release." + "legacy graph runtime behavior of producing json / lib / params will be " + "removed in the next release." " Please see documents of tvm.contrib.graph_runtime.GraphModule for the " " new recommended usage.", DeprecationWarning, diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 218bc9f47f944..ba09094afca1f 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -213,13 +213,15 @@ def optimize(self): """ seq = tvm.transform.Sequential( [ + # tvm.parser.AnnotateSpans(), transform.SimplifyInference(), transform.FuseOps(0), transform.ToANormalForm(), transform.InferType(), ] ) - return seq(self.mod) + mod = seq(self.mod) + return mod def _make_executor(self, expr=None): if expr is None or isinstance(expr, GlobalVar): diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 0b68c8ebbfe22..bd0c3e5f4d733 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -31,6 +31,7 @@ from . import ty as _ty from . import expr as _expr from . import function as _function +from .transform import InferType from .backend import graph_runtime_factory as _graph_runtime_factory from .backend import interpreter as _interpreter from .backend.vm import VMExecutor @@ -363,6 +364,7 @@ def __init__(self, mod, ctx, target): def _make_executor(self, expr=None): if expr: self.mod["main"] = expr + self.mod = InferType()(self.mod) ret_type = self.mod["main"].checked_type.ret_type if _ty.is_dynamic(ret_type): raise ValueError("Graph Runtime only supports static graphs, got output type", ret_type) diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 19ad5954e4206..7178bff2c1145 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -151,7 +151,7 @@ def partition( check: Callable[[Expr], bool] = lambda x: True, ) -> Expr: """ - Parition the expression into functions defined by this pattern + Partition the expression into functions defined by this pattern Parameters ---------- diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 027d6bd76141a..b27c759b8d03a 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -478,7 +478,8 @@ def infer_type(node, mod=None): new_mod = IRModule.from_expr(node) if mod is not None: new_mod.update(mod) - new_mod = _transform.InferType()(new_mod) + + new_mod = _transform.InferType()(new_mod) entry = new_mod["main"] ret = entry if isinstance(node, _function.Function) else entry.body diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 063158f35dc3e..c82b487acff31 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -57,22 +57,26 @@ def _convert_to_list_adt(py_lst, prelude): msg = "List elements should have identical types" assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg - adt_lst = prelude.nil() + # get_type returns type_name, ctor1, ..., ctorN + # 1 is nil + _, cons, nil = prelude.mod.get_type("List") + adt_lst = nil() for elem in reversed(py_lst): - adt_lst = prelude.cons(elem, adt_lst) + adt_lst = cons(elem, adt_lst) return adt_lst def _map_tensor_array_constructor(adt_lst, prelude, shape): static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) static_tensor_array_ops.register() - tensor_create = prelude.get_var_static("tensor_constructor", "float32", shape) + tensor_create = prelude.get_tensor_ctor_static("tensor_constructor", "float32", shape) return prelude.map(tensor_create, adt_lst) def _convert_to_tensor_array(adt_lst, prelude): + _, cons, nil = prelude.mod.get_type("List") if prelude.length(adt_lst) == 0: - return prelude.nil() + return nil() checked_type = _infer_type_with_prelude(prelude.hd(adt_lst), prelude) shape = checked_type.shape @@ -262,12 +266,12 @@ def tensor_array_concat(lst, axis): assert axis == 0, "Tensor array concat supported only for axis 0" tensor_array, shape = _convert_to_tensor_array(lst, prelude) concat_shape = (Any(),) + shape[1:] - concat = prelude.get_var_static("tensor_array_concat", "float32", shape) + concat = prelude.get_global_var_static("tensor_array_concat", "float32", shape) concatenated = concat(tensor_array) static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", concat_shape) static_tensor_array_ops.register() - get_tensor = prelude.get_var_static("tensor_get_data", "float32", concat_shape) + get_tensor = prelude.get_global_var_static("tensor_get_data", "float32", concat_shape) return get_tensor(concatenated) def _impl(inputs, input_types): @@ -2041,12 +2045,12 @@ def _impl(inputs, input_types): tensor_array, shape = _convert_to_tensor_array(inputs[0], prelude) stacked_shape = (Any(),) + shape - stack = prelude.get_var_static("tensor_array_stack", "float32", shape) + stack = prelude.get_global_var_static("tensor_array_stack", "float32", shape) stacked = stack(tensor_array) static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", stacked_shape) static_tensor_array_ops.register() - get_tensor = prelude.get_var_static("tensor_get_data", "float32", stacked_shape) + get_tensor = prelude.get_global_var_static("tensor_get_data", "float32", stacked_shape) return get_tensor(stacked) return _impl @@ -2897,7 +2901,8 @@ def get_relay_ty(ishape, itype, pt_type): if len(elem_tys) > 0 and not all(map(lambda ty: ty == elem_tys[0], elem_tys)): msg = "List elements need have identical types" raise RuntimeError(msg) - return prelude.l(elem_tys[0]) + rlist, _, _ = prelude.mod.get_type("List") + return rlist(elem_tys[0]) elif pt_type.kind() == "OptionalType": # we do not support None yet, so we fill in the type return get_relay_ty(ishape, itype, pt_type.getElementType()) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 1cd14c33b8b1f..9fe5fa01091a6 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -26,6 +26,7 @@ from tvm.ir import IRModule from tvm.relay.prelude import Prelude, StaticTensorArrayOps, get_tensor_array_shape +from tvm.relay.transform import InferType from tvm.topi.util import get_const_tuple from .. import analysis @@ -924,10 +925,10 @@ def _impl(inputs, attr, params, prelude): shape = attr["shape"] static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, shape) static_tensor_array_ops.register() - tensor_array_constructor = prelude.get_var_static("tensor_array", dtype_str, shape) + tensor_array_constructor = static_tensor_array_ops.get_global_var("tensor_array") tensor_array = tensor_array_constructor(inputs[0]) else: - tensor_array_constructor = prelude.get_var("tensor_array", dtype_str) + tensor_array_constructor = prelude.get_global_var("tensor_array", dtype_str) tensor_array = tensor_array_constructor(inputs[0]) return tensor_array @@ -946,9 +947,9 @@ def _impl(inputs, attr, params, prelude): if input_shape is None: values_rank = len(values_shape) unstack_name = "tensor_array_unstack_tensor{}".format(values_rank) - unstack_function = prelude.get_var(unstack_name, dtype_str) + unstack_function = prelude.get_global_var(unstack_name, dtype_str) values = unstack_function(inputs[2]) - tensor_array_scatter_func = prelude.get_var("tensor_array_scatter", dtype_str) + tensor_array_scatter_func = prelude.get_global_var("tensor_array_scatter", dtype_str) else: input_t_shape = _get_more_static_shape(input_t_shape, input_shape) values_shape = (values_shape[0],) + input_t_shape @@ -957,13 +958,13 @@ def _impl(inputs, attr, params, prelude): # Register static indices shape if isinstance(indices_shape[0], int): static_tensor_array_ops.define_tensor_array_scatter(indices_shape, True) - tensor_array_scatter_func = prelude.get_var_static( + tensor_array_scatter_func = prelude.get_global_var_static( "tensor_array_scatter", dtype_str, input_t_shape ) static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, values_shape) static_tensor_array_ops.register() - unstack_function = prelude.get_var_static( + unstack_function = prelude.get_global_var_static( "tensor_array_unstack", dtype_str, values_shape ) values = unstack_function(inputs[2]) @@ -987,7 +988,7 @@ def _impl(inputs, attr, params, prelude): static_tensor_array_ops.register() if not isinstance(indices_shape[0], int): - gather_function = prelude.get_var_static( + gather_function = prelude.get_global_var_static( "tensor_array_gather", dtype_str, input_shape ) out_tensor_t = gather_function(inputs[2], inputs[1]) @@ -996,12 +997,18 @@ def _impl(inputs, attr, params, prelude): static_tensor_array_ops.register() # Output shape is (indices_shape[0],) + input_shape - get_data_func = prelude.get_var_static("tensor_get_data", dtype_str, out_shape) + get_data_func = prelude.get_global_var_static( + "tensor_get_data", dtype_str, out_shape + ) out = get_data_func(out_tensor_t) else: # For fixed length indices, directly generate static shape output - read_func = prelude.get_var_static("tensor_array_read", dtype_str, input_shape) - get_data_func = prelude.get_var_static("tensor_get_data", dtype_str, input_shape) + read_func = prelude.get_global_var_static( + "tensor_array_read", dtype_str, input_shape + ) + get_data_func = prelude.get_global_var_static( + "tensor_get_data", dtype_str, input_shape + ) tensor_list = [] for i in range(indices_shape[0]): index = _op.take(inputs[1], tvm.relay.const(i)) @@ -1035,9 +1042,9 @@ def _impl(inputs, attr, params, prelude): if input_ta_shape is None: tensor_name = "tensor{}".format(input_rank) - tensor_func = prelude.get_var(tensor_name, dtype_str) + tensor_func = prelude.get_tensor_ctor(tensor_name, dtype_str) v = tensor_func(inputs[2]) - write_func = prelude.get_var("tensor_array_write", dtype_str) + write_func = prelude.get_global_var("tensor_array_write", dtype_str) else: input_ta_rank = len(input_ta_shape) assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}".format( @@ -1045,8 +1052,7 @@ def _impl(inputs, attr, params, prelude): ) static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_ta_shape) static_tensor_array_ops.register() - - tensor_func = prelude.get_var_static("tensor_constructor", dtype_str, input_ta_shape) + tensor_func = static_tensor_array_ops.get_ctor("tensor_constructor") v = tensor_func(inputs[2]) # Write tensor with more static shape actual_shape = _get_more_static_shape(input_t_shape, input_ta_shape) @@ -1060,7 +1066,9 @@ def _impl(inputs, attr, params, prelude): if num_any_dim <= 1: v = tensor_func(_op.reshape(inputs[2], new_shape)) - write_func = prelude.get_var_static("tensor_array_write", dtype_str, input_ta_shape) + write_func = prelude.get_global_var_static( + "tensor_array_write", dtype_str, input_ta_shape + ) return write_func(input_ta, _op.take(inputs[1], tvm.relay.const(0)), v) @@ -1073,14 +1081,14 @@ def _impl(inputs, attr, params, prelude): input_shape = get_tensor_array_shape(inputs[2], dtype_str, prelude) if input_shape is None: - read_func = prelude.get_var("tensor_array_read", dtype_str) + read_func = prelude.get_global_var("tensor_array_read", dtype_str) out = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) else: static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_shape) static_tensor_array_ops.register() - read_func = prelude.get_var_static("tensor_array_read", dtype_str, input_shape) + read_func = static_tensor_array_ops.get_global_var("tensor_array_read") out_tensor = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) - get_data_func = prelude.get_var_static("tensor_get_data", dtype_str, input_shape) + get_data_func = static_tensor_array_ops.get_global_var("tensor_get_data") out = get_data_func(out_tensor) return out @@ -1099,8 +1107,10 @@ def _impl(inputs, attr, params, prelude): input_rank = len(value_shape) if input_ta_shape is None: - v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1]) - split_func = prelude.get_var("tensor_array_split", dtype_str) + tensor_name = "tensor{}".format(input_rank) + tensor_ctor = prelude.get_tensor_ctor(tensor_name, dtype_str) + v = tensor_ctor(inputs[1]) + split_func = prelude.get_global_var("tensor_array_split", dtype_str) else: input_ta_rank = len(input_ta_shape) assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}".format( @@ -1113,13 +1123,13 @@ def _impl(inputs, attr, params, prelude): if isinstance(value_shape[0], int) or isinstance(lengths_shape[0], int): static_tensor_array_ops.define_tensor_array_split(value_shape, lengths_shape, True) - tensor_func_name = prelude.get_name_static("tensor_constructor", dtype_str, value_shape) - if not hasattr(prelude, tensor_func_name): - static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, value_shape) - static_tensor_array_ops.register() - tensor_func = prelude.get_var_static("tensor_constructor", dtype_str, value_shape) - v = tensor_func(inputs[1]) - split_func = prelude.get_var_static("tensor_array_split", dtype_str, input_ta_shape) + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, value_shape) + static_tensor_array_ops.register() + tensor_ctor = static_tensor_array_ops.get_ctor("tensor_constructor") + v = tensor_ctor(inputs[1]) + split_func = prelude.get_global_var_static( + "tensor_array_split", dtype_str, input_ta_shape + ) return split_func(input_ta, v, lengths) @@ -1132,17 +1142,19 @@ def _impl(inputs, attr, params, prelude): input_shape = get_tensor_array_shape(inputs[1], dtype_str, prelude) if input_shape is None: - concat_func = prelude.get_var("tensor_array_concat", dtype_str) + concat_func = prelude.get_global_var("tensor_array_concat", dtype_str) out = concat_func(inputs[1]) else: static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_shape) static_tensor_array_ops.register() - concat_func = prelude.get_var_static("tensor_array_concat", dtype_str, input_shape) + concat_func = prelude.get_global_var_static( + "tensor_array_concat", dtype_str, input_shape + ) out_tensor = concat_func(inputs[1]) out_shape = (Any(),) + input_shape[1:] static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, out_shape) static_tensor_array_ops.register() - get_data_func = prelude.get_var_static("tensor_get_data", dtype_str, out_shape) + get_data_func = prelude.get_global_var_static("tensor_get_data", dtype_str, out_shape) out = get_data_func(out_tensor) return out @@ -3257,6 +3269,7 @@ def _partition_call_operator(self, inputs, attr): func_expr = _function.Function(sub_func.params, sub_func.body) global_func = tvm.relay.GlobalVar(func_name) main_graph_proto._mod[global_func] = func_expr + main_graph_proto._mod = InferType()(main_graph_proto._mod) param_exprs = [] for param_expr in sub_func.params: diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py index 77fdbbd4006cd..586d98d8e1c85 100644 --- a/python/tvm/relay/op/contrib/arm_compute_lib.py +++ b/python/tvm/relay/op/contrib/arm_compute_lib.py @@ -59,6 +59,7 @@ def partition_for_arm_compute_lib(mod, params=None): seq = tvm.transform.Sequential( [ + transform.InferType(), transform.MergeComposite(arm_compute_lib_pattern_table()), transform.AnnotateTarget("arm_compute_lib"), transform.PartitionGraph(), diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 81d82ddb04d2f..376bf4a4804d8 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -17,9 +17,10 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """A prelude containing useful global functions and ADT definitions.""" from tvm.ir import IRModule, TypeCall +from tvm.tir import Any from tvm.relay.transform import ToANormalFormExpr -from .ty import GlobalTypeVar, TensorType, Any, scalar_type +from .ty import GlobalTypeVar, TensorType, scalar_type from .expr import Var, GlobalVar, If, const from .function import Function from .op.tensor import add, subtract, equal @@ -64,7 +65,7 @@ def get_tensor_array_shape(expr, dtype, prelude): shape = [] if "scalar" not in shape_str: for dim_str in shape_str.split("_"): - if dim_str == "?": + if dim_str in ["?", "any"]: shape.append(Any()) else: shape.append(int(dim_str)) @@ -75,7 +76,15 @@ def get_tensor_array_shape(expr, dtype, prelude): def _get_name_static(canonical, dtype, shape): """Get name for static shape tensor array op corresponding to the canonical name""" - shape_str = "_".join([str(dim) for dim in shape]) + dim_names = [] + for dim in shape: + if isinstance(dim, Any): + dim_names.append("any") + else: + dim_names.append(str(dim)) + + shape_str = "_".join(dim_names) + if len(shape_str) == 0: shape_str = "scalar" if canonical == "tensor_t": @@ -91,40 +100,53 @@ def __init__(self, prelude, dtype, shape): self.prelude = prelude self.dtype = dtype self.shape = shape + self.list, self.cons, self.nil = self.prelude.mod.get_type("List") def get_name(self, canonical): """Get name corresponding to the canonical name""" return _get_name_static(canonical, self.dtype, self.shape) - def get_var(self, canonical): - """Get var corresponding to the canonical name""" - name = self.get_name(canonical) - return getattr(self.prelude, name) + def get_global_var(self, canonical): + """Get global corresponding to the canonical name""" + return self.prelude.get_global_var_static(canonical, self.dtype, self.shape) + + def get_type(self, canonical): + """Get type corresponding to the canonical name""" + return self.prelude.get_type_static(canonical, self.dtype, self.shape) + + def get_ctor(self, canonical): + """Get ctor corresponding to the canonical name""" + return self.prelude.get_ctor_static("tensor_t", canonical, self.dtype, self.shape) def define_tensor_adt(self): """Defines the static tensor ADT, which is the container for tensors with fixed shapes.""" tensor_type_name = self.get_name("tensor_t") + + # This is effectively functioning as a monomorphizer. + # TODO(@jroesch): we should add full shape polymoprhism + # and do monomorphization. + # # Skip register if tensor type is already registered. global_type_names = set() for g_ty_var in self.prelude.mod.get_global_type_vars(): global_type_names.add(g_ty_var.name_hint) + if tensor_type_name in global_type_names: + self.tensor_type_var = self.get_type("tensor_t") return - tensor_type_var = GlobalTypeVar(tensor_type_name) - setattr(self.prelude, tensor_type_name, tensor_type_var) + self.tensor_type_var = GlobalTypeVar(tensor_type_name) + tensor_type = TensorType(self.shape, self.dtype) tensor_constructor_name = self.get_name("tensor_constructor") tensor_nil_name = self.get_name("tensor_nil") - tensor_nil_case = Constructor(tensor_nil_name, [], tensor_type_var) - tensor_case = Constructor(tensor_constructor_name, [tensor_type], tensor_type_var) + tensor_nil_case = Constructor(tensor_nil_name, [], self.tensor_type_var) + tensor_case = Constructor(tensor_constructor_name, [tensor_type], self.tensor_type_var) - setattr(self.prelude, tensor_nil_name, tensor_nil_case) - setattr(self.prelude, tensor_constructor_name, tensor_case) - self.prelude.mod[tensor_type_var] = TypeData( - tensor_type_var, [], [tensor_nil_case, tensor_case] + self.prelude.mod[self.tensor_type_var] = TypeData( + self.tensor_type_var, [], [tensor_nil_case, tensor_case] ) def define_tensor_array(self): @@ -133,19 +155,17 @@ def define_tensor_array(self): """ tensor_array_constructor_name = self.get_name("tensor_array") tensor_array_constructor_var = self._create_global_var(tensor_array_constructor_name) - setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var) - tensor_nil_var = self.get_var("tensor_nil") - tensor_type_var = self.get_var("tensor_t") + + tensor_nil_var = self.get_ctor("tensor_nil") + tensor_type_var = self.get_ctor("tensor_t") n = Var("x", scalar_type("int32")) body = If( equal(n, const(0)), - self.prelude.nil(), - self.prelude.cons( - tensor_nil_var(), tensor_array_constructor_var(subtract(n, const(1))) - ), + self.nil(), + self.cons(tensor_nil_var(), tensor_array_constructor_var(subtract(n, const(1)))), ) self.prelude.mod[tensor_array_constructor_var] = Function( - [n], body, self.prelude.l(tensor_type_var()), [] + [n], body, self.list(tensor_type_var()), [] ) def define_tensor_take(self): @@ -159,16 +179,20 @@ def define_tensor_take(self): return take_name = self.get_name("tensor_take") - take_var = self._create_global_var(take_name) - setattr(self.prelude, take_name, take_var) - origin_tensor_constructor = self.get_var("tensor_constructor") + + if self.is_cached(take_name): + return + + take_var = GlobalVar(take_name) + + origin_tensor_constructor = self.get_ctor("tensor_constructor") output_shape = [ Any(), ] + list(self.shape[1:]) - tensor_type_var, tensor_constructor = self._get_adt_by_shape(output_shape) + tensor_type_var, tensor_constructor, _ = self._get_adt_by_shape(output_shape) - t = Var("tensor", self.get_var("tensor_t")()) + t = Var("tensor", self.tensor_type_var()) lower = Var("lower", scalar_type("int32")) upper = Var("upper", scalar_type("int32")) tvar = Var("t") @@ -190,15 +214,17 @@ def define_tensor_concatenate(self): return concat_name = self.get_name("tensor_concatenate") - concat_var = self._create_global_var(concat_name) - setattr(self.prelude, concat_name, concat_var) + concat_var = GlobalVar(concat_name) + if self.is_cached(concat_name): + return + output_shape = [ Any(), ] + list(self.shape[1:]) - tensor_type_var, tensor_constructor = self._get_adt_by_shape(output_shape) + tensor_type_var, tensor_constructor, _ = self._get_adt_by_shape(output_shape) - origin_tensor_constructor = self.get_var("tensor_constructor") - origin_tensor_type_var = self.get_var("tensor_t") + origin_tensor_constructor = self.get_ctor("tensor_constructor") + origin_tensor_type_var = self.tensor_type_var x = Var("x", origin_tensor_type_var()) y = Var("y", origin_tensor_type_var()) t1 = Var("t1") @@ -230,13 +256,13 @@ def define_tensor_expand_dims(self): expand_dims_name = self.get_name("tensor_expand_dims") expand_dims_var = self._create_global_var(expand_dims_name) setattr(self.prelude, expand_dims_name, expand_dims_var) - origin_tensor_type_var = self.get_var("tensor_t") - origin_tensor_constructor = self.get_var("tensor_constructor") + origin_tensor_type_var = self.tensor_type_var + origin_tensor_constructor = self.get_ctor("tensor_constructor") x = Var("x", origin_tensor_type_var()) # Note: we set the added axis to be Any() instead of 1 due to # in stack op, we need to recursively concatenate. - tensor_type_var, tensor_constructor = self._get_adt_by_shape( + tensor_type_var, tensor_constructor, _ = self._get_adt_by_shape( [ Any(), ] @@ -259,16 +285,25 @@ def define_tensor_array_read(self): Tensor[self.shape, self.dtype] """ read_name = self.get_name("tensor_array_read") - read_var = self._create_global_var(read_name) - setattr(self.prelude, read_name, read_var) - tensor_type_var = self.get_var("tensor_t") - tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + if self.is_cached(read_name): + return + + read_var = GlobalVar(read_name) + + tensor_array = Var("tensor_array", self.list(self.tensor_type_var())) n = Var("x", scalar_type("int32")) self.prelude.mod[read_var] = Function( - [tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), [] + [tensor_array, n], self.prelude.nth(tensor_array, n), self.tensor_type_var(), [] ) + def is_cached(self, name): + try: + self.prelude.mod.get_global_var(name) + return True + except ValueError: + return False + def define_tensor_array_write(self): """Defines a function to update a tensor array at index n with value v. tensor_array_write(ta, n, v) : @@ -276,16 +311,17 @@ def define_tensor_array_write(self): list[static_tensor_t] """ write_name = self.get_name("tensor_array_write") - write_var = self._create_global_var(write_name) - setattr(self.prelude, write_name, write_var) - tensor_type_var = self.get_var("tensor_t") - tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + if self.is_cached(write_name): + return + + write_var = GlobalVar(write_name) + tensor_array = Var("tensor_array", self.list(self.tensor_type_var())) n = Var("x", scalar_type("int32")) - v = Var("v", tensor_type_var()) + v = Var("v", self.tensor_type_var()) self.prelude.mod[write_var] = Function( [tensor_array, n, v], self.prelude.update(tensor_array, n, v), - self.prelude.l(tensor_type_var()), + self.list(self.tensor_type_var()), [], ) @@ -306,17 +342,17 @@ def define_tensor_array_unstack(self): i = Var("i", scalar_type("int32")) tensor_var = Var("tensor", TensorType(self.shape, self.dtype)) - reduced_tensor_type_var, tensor_constructor = self._get_adt_by_shape(self.shape[1:]) + reduced_tensor_type_var, tensor_constructor, _ = self._get_adt_by_shape(self.shape[1:]) helper_body = If( equal(i, up), - self.prelude.nil(), - self.prelude.cons( + self.nil(), + self.cons( tensor_constructor(op.take(tensor, i, axis=0)), helper_var(add(i, const(1)), up, tensor), ), ) self.prelude.mod[helper_var] = Function( - [i, up, tensor], helper_body, self.prelude.l(reduced_tensor_type_var()), [] + [i, up, tensor], helper_body, self.list(reduced_tensor_type_var()), [] ) unstack_name = self.get_name("tensor_array_unstack") @@ -327,7 +363,7 @@ def define_tensor_array_unstack(self): self.prelude.mod[unstack_var] = Function( [tensor_var], helper_var(const(0), unstack_length, tensor_var), - self.prelude.l(reduced_tensor_type_var()), + self.list(reduced_tensor_type_var()), [], ) @@ -348,14 +384,13 @@ def define_tensor_array_scatter(self, indices_shape=None, force_update=False): tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper") tensor_array_scatter_helper_var = self._create_global_var(tensor_array_scatter_helper_name) - tensor_type_var = self.get_var("tensor_t") - ta = Var("ta", self.prelude.l(tensor_type_var())) + ta = Var("ta", self.list(self.tensor_type_var())) current = Var("current", scalar_type("int32")) limit = Var("limit", scalar_type("int32")) indices_ = Var("indices_", TensorType(indices_shape or [Any()], "int32")) - values_ = Var("values_", self.prelude.l(tensor_type_var())) - write_var = self.get_var("tensor_array_write") - read_var = self.get_var("tensor_array_read") + values_ = Var("values_", self.list(self.tensor_type_var())) + write_var = self.get_global_var("tensor_array_write") + read_var = self.get_global_var("tensor_array_read") helper_body = If( equal(current, limit), ta, @@ -370,16 +405,16 @@ def define_tensor_array_scatter(self, indices_shape=None, force_update=False): self.prelude.mod[tensor_array_scatter_helper_var] = Function( [ta, current, limit, indices_, values_], helper_body, - self.prelude.l(tensor_type_var()), + self.list(self.tensor_type_var()), [], ) tensor_array_scatter_var = self._create_global_var(tensor_array_scatter_name) setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var) - tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + tensor_array = Var("tensor_array", self.list(self.tensor_type_var())) indices = Var("indices", TensorType(indices_shape or [Any()], "int32")) - values = Var("values", self.prelude.l(tensor_type_var())) + values = Var("values", self.list(self.tensor_type_var())) if indices_shape is None: indices_shape = op.shape_of(indices) limit = op.take(indices_shape, const(0)) @@ -388,7 +423,7 @@ def define_tensor_array_scatter(self, indices_shape=None, force_update=False): body = tensor_array_scatter_helper_var(tensor_array, const(0), limit, indices, values) self.prelude.mod[tensor_array_scatter_var] = Function( - [tensor_array, indices, values], body, self.prelude.l(tensor_type_var()), [] + [tensor_array, indices, values], body, self.list(self.tensor_type_var()), [] ) def define_tensor_array_split(self, value_shape=None, lengths_shape=None, force_update=False): @@ -408,43 +443,39 @@ def define_tensor_array_split(self, value_shape=None, lengths_shape=None, force_ # when force_update is set. This should be used only when we need to # redefine this op for static value/indices shape. split_name = self.get_name("tensor_array_split") - if hasattr(self.prelude, split_name) and not force_update: - return - tensor_type_var = self.get_var("tensor_t") - tensor_array_split_helper_name = self.get_name("ta_split_helper") - tensor_array_split_helper_var = self._create_global_var(tensor_array_split_helper_name) - setattr(self.prelude, tensor_array_split_helper_name, tensor_array_split_helper_var) + if self.is_cached(split_name): + if not force_update: + return + tensor_array_split_helper_var = self.get_global_var("ta_split_helper") + split_var = self.get_global_var("tensor_array_split") + else: + tensor_array_split_helper_name = self.get_name("ta_split_helper") + tensor_array_split_helper_var = GlobalVar(tensor_array_split_helper_name) + split_var = GlobalVar(split_name) + output_shape = [ Any(), ] + list(self.shape[1:]) - output_tensor_type_var, _ = self._get_adt_by_shape(output_shape) + output_tensor_type_var, _, output_ops = self._get_adt_by_shape(output_shape) + output_ops.define_tensor_array_write() + write_var = output_ops.get_global_var("tensor_array_write") if value_shape is None: - value_type_var = tensor_type_var - take_var = self.get_var("tensor_take") + value_type_var = self.tensor_type_var + take_var = self.get_global_var("tensor_take") else: - value_type_var, _ = self._get_adt_by_shape(value_shape) - # Also get static shape take operator - origin_shape = list(self.shape) - self.shape = value_shape - self.define_tensor_take() - take_var = self.get_var("tensor_take") - self.shape = origin_shape - - ta1 = Var("tensor_array", self.prelude.l(output_tensor_type_var())) + value_type_var, _, value_adts = self._get_adt_by_shape(value_shape) + value_adts.define_tensor_take() + take_var = value_adts.get_global_var("tensor_take") + + ta1 = Var("tensor_array", self.list(output_tensor_type_var())) value1 = Var("value1", value_type_var()) offset1 = Var("offset1", scalar_type("int32")) current1 = Var("current1", scalar_type("int32")) limit1 = Var("limit1", scalar_type("int32")) lengths1 = Var("lengths", TensorType(lengths_shape or [Any()], "int32")) - # Register write for output shape - origin_shape = list(self.shape) - self.shape = output_shape - self.define_tensor_array_write() - write_var = self.get_var("tensor_array_write") - self.shape = origin_shape helper1_body = If( equal(current1, limit1), ta1, @@ -461,15 +492,14 @@ def define_tensor_array_split(self, value_shape=None, lengths_shape=None, force_ take_var(value1, offset1, add(op.take(lengths1, current1), offset1)), ), ) + self.prelude.mod[tensor_array_split_helper_var] = Function( [ta1, value1, offset1, current1, limit1, lengths1], helper1_body, - self.prelude.l(output_tensor_type_var()), + self.list(output_tensor_type_var()), [], ) - split_var = self._create_global_var(split_name) - setattr(self.prelude, split_name, split_var) - tensor_array = Var("tensor_array", self.prelude.l(output_tensor_type_var())) + tensor_array = Var("tensor_array", self.list(output_tensor_type_var())) value = Var("value", value_type_var()) lengths = Var("lengths", TensorType(lengths_shape or [Any()], "int32")) @@ -483,7 +513,7 @@ def define_tensor_array_split(self, value_shape=None, lengths_shape=None, force_ ) self.prelude.mod[split_var] = Function( - [tensor_array, value, lengths], body, self.prelude.l(output_tensor_type_var()), [] + [tensor_array, value, lengths], body, self.list(output_tensor_type_var()), [] ) def define_tensor_array_concat(self): @@ -496,32 +526,33 @@ def define_tensor_array_concat(self): return concat_name = self.get_name("tensor_array_concat") - concat_var = self._create_global_var(concat_name) - setattr(self.prelude, concat_name, concat_var) + + if self.is_cached(concat_name): + return + + concat_var = GlobalVar(concat_name) output_shape = [ Any(), ] + list(self.shape[1:]) - tensor_type_var, _ = self._get_adt_by_shape(output_shape) + + tensor_type_var, _, output_ops = self._get_adt_by_shape(output_shape) # Register tensor concatenate and get tensor_nil var for output shape - origin_shape = self.shape - self.shape = output_shape - self.define_tensor_concatenate() - tensor_concat_var = self.get_var("tensor_concatenate") - tensor_nil_var = self.get_var("tensor_nil") - self.shape = origin_shape + output_ops.define_tensor_concatenate() + tensor_concat_var = output_ops.get_global_var("tensor_concatenate") + tensor_nil_var = output_ops.get_ctor("tensor_nil") - tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + tensor_array = Var("tensor_array", self.list(tensor_type_var())) hd = Var("hd") tl = Var("tl") - nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var()) + nil_case = Clause(PatternConstructor(self.nil), tensor_nil_var()) cons_case = Clause( - PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]), + PatternConstructor(self.cons, [PatternVar(hd), PatternVar(tl)]), Match( tl, [ - Clause(PatternConstructor(self.prelude.nil), hd), + Clause(PatternConstructor(self.nil), hd), Clause(PatternWildcard(), tensor_concat_var(hd, concat_var(tl))), ], False, @@ -538,19 +569,17 @@ def define_tensor_array_stack(self): stack_name = self.get_name("tensor_array_stack") stack_var = self._create_global_var(stack_name) setattr(self.prelude, stack_name, stack_var) - tensor_type_var = self.get_var("tensor_t") - tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) - expand_dims_var = self.get_var("tensor_expand_dims") + tensor_array = Var("tensor_array", self.list(self.tensor_type_var())) + expand_dims_var = self.get_global_var("tensor_expand_dims") # Register tensor_concatenate for output_shape - origin_shape = self.shape output_shape = [ Any(), ] + list(self.shape) - self.shape = output_shape - self.define_tensor_concatenate() - concat_var = self.get_var("tensor_concatenate") - self.shape = origin_shape + + _, _, output_ops = self._get_adt_by_shape(output_shape) + output_ops.define_tensor_concatenate() + concat_var = output_ops.get_global_var("tensor_concatenate") tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array) tensors = self.prelude.foldl( @@ -558,7 +587,7 @@ def define_tensor_array_stack(self): self.prelude.hd(tensor_array_expand_dims), self.prelude.tl(tensor_array_expand_dims), ) - output_tensor_type_var, _ = self._get_adt_by_shape(output_shape) + output_tensor_type_var, _, _ = self._get_adt_by_shape(output_shape) self.prelude.mod[stack_var] = Function( [tensor_array], tensors, output_tensor_type_var(), [] ) @@ -569,16 +598,15 @@ def define_tensor_array_gather(self): """ helper_name = self.get_name("tensor_array_gather_helper") helper_var = self._create_global_var(helper_name) - setattr(self.prelude, helper_name, helper_var) - tensor_type_var = self.get_var("tensor_t") + output_shape = [ Any(), ] + list(self.shape) - output_tensor_type_var, _ = self._get_adt_by_shape(output_shape) - stack_var = self.get_var("tensor_array_stack") - read_var = self.get_var("tensor_array_read") - ta = Var("ta", self.prelude.l(tensor_type_var())) - accu = Var("accu", self.prelude.l(tensor_type_var())) + output_tensor_type_var, _, _ = self._get_adt_by_shape(output_shape) + stack_var = self.get_global_var("tensor_array_stack") + read_var = self.get_global_var("tensor_array_read") + ta = Var("ta", self.list(self.tensor_type_var())) + accu = Var("accu", self.list(self.tensor_type_var())) current = Var("current", scalar_type("int32")) limit = Var("limit", scalar_type("int32")) indices_ = Var("indices_", TensorType([Any()], "int32")) @@ -587,9 +615,7 @@ def define_tensor_array_gather(self): stack_var(accu), helper_var( ta, - self.prelude.cons( - read_var(ta, op.take(indices_, subtract(current, const(1)))), accu - ), + self.cons(read_var(ta, op.take(indices_, subtract(current, const(1)))), accu), subtract(current, const(1)), limit, indices_, @@ -600,12 +626,12 @@ def define_tensor_array_gather(self): ) gather_name = self.get_name("tensor_array_gather") gather_var = self._create_global_var(gather_name) - setattr(self.prelude, gather_name, gather_var) - tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + + tensor_array = Var("tensor_array", self.list(self.tensor_type_var())) indices = Var("indices", TensorType([Any()], "int32")) indices_shape = op.shape_of(indices) limit = op.take(indices_shape, const(0)) - body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices) + body = helper_var(tensor_array, self.nil(), limit, limit, indices) self.prelude.mod[gather_var] = Function( [tensor_array, indices], body, output_tensor_type_var(), [] ) @@ -614,10 +640,9 @@ def define_tensor_get_data(self): """Defines a function to get a Tensor from tensor_t with given shape.""" tensor_get_data_name = self.get_name("tensor_get_data") tensor_get_data_var = self._create_global_var(tensor_get_data_name) - setattr(self.prelude, tensor_get_data_name, tensor_get_data_var) - tensor_type_var = self.get_var("tensor_t") - tensor_constructor = self.get_var("tensor_constructor") - t = Var("tensor", tensor_type_var()) + + tensor_constructor = self.get_ctor("tensor_constructor") + t = Var("tensor", self.tensor_type_var()) tvar = Var("t") case = Clause(PatternConstructor(tensor_constructor, [PatternVar(tvar)]), tvar) self.prelude.mod[tensor_get_data_var] = Function( @@ -643,13 +668,11 @@ def register(self): def _get_adt_by_shape(self, shape): """Get ADT type and constructor with given shape.""" - origin_shape = self.shape - self.shape = shape - self.define_tensor_adt() - tensor_type_var = self.get_var("tensor_t") - tensor_constructor = self.get_var("tensor_constructor") - self.shape = origin_shape - return tensor_type_var, tensor_constructor + adt_ops = StaticTensorArrayOps(self.prelude, self.dtype, shape) + adt_ops.define_tensor_adt() + tensor_type_var = adt_ops.get_type("tensor_t") + tensor_constructor = adt_ops.get_ctor("tensor_constructor") + return tensor_type_var, tensor_constructor, adt_ops def _create_global_var(self, name): """Create a GlobalVar if doesn't exist in prelude.""" @@ -671,21 +694,30 @@ def __init__(self, prelude, dtype): """Create tensor array ops registry""" self.prelude = prelude self.dtype = dtype + self.list, self.cons, self.nil = self.prelude.mod.get_type("List") def get_name(self, canonical): """Get name corresponding to the canonical name""" return self.prelude.get_name(canonical, self.dtype) - def get_var(self, canonical): - """Get var corresponding to the canonical name""" - return self.prelude.get_var(canonical, self.dtype) + def get_global_var(self, canonical): + """Get global corresponding to the canonical name""" + return self.prelude.get_global_var(canonical, self.dtype) + + def get_type(self, canonical): + """Get type corresponding to the canonical name""" + return self.prelude.get_type(canonical, self.dtype) + + def get_ctor(self, canonical): + """Get ctor corresponding to the canonical name""" + return self.prelude.get_ctor(self.tensor_type_var.name_hint, canonical, self.dtype) def define_tensor_adt(self): """Defines the dynamic tensor ADT, which is the container for tensors with variable shapes.""" tensor_type_name = self.get_name("tensor_t") - tensor_type_var = GlobalTypeVar(tensor_type_name) - setattr(self.prelude, tensor_type_name, tensor_type_var) + self.tensor_type_var = tensor_type_var = GlobalTypeVar(tensor_type_name) + tensor0_type = TensorType([], self.dtype) tensor1_type = TensorType([Any()], self.dtype) tensor2_type = TensorType([Any(), Any()], self.dtype) @@ -709,14 +741,7 @@ def define_tensor_adt(self): tensor4_case = Constructor(tensor4_name, [tensor4_type], tensor_type_var) tensor5_case = Constructor(tensor5_name, [tensor5_type], tensor_type_var) tensor6_case = Constructor(tensor6_name, [tensor6_type], tensor_type_var) - setattr(self.prelude, tensor_nil_name, tensor_nil_case) - setattr(self.prelude, tensor0_name, tensor0_case) - setattr(self.prelude, tensor1_name, tensor1_case) - setattr(self.prelude, tensor2_name, tensor2_case) - setattr(self.prelude, tensor3_name, tensor3_case) - setattr(self.prelude, tensor4_name, tensor4_case) - setattr(self.prelude, tensor5_name, tensor5_case) - setattr(self.prelude, tensor6_name, tensor6_case) + self.prelude.mod[tensor_type_var] = TypeData( tensor_type_var, [], @@ -739,14 +764,15 @@ def define_tensor_take(self): """ take_name = self.get_name("tensor_take") take_var = GlobalVar(take_name) - setattr(self.prelude, take_name, take_var) - tensor_t = self.get_var("tensor_t") - tensor1_var = self.get_var("tensor1") - tensor2_var = self.get_var("tensor2") - tensor3_var = self.get_var("tensor3") - tensor4_var = self.get_var("tensor4") - tensor5_var = self.get_var("tensor5") - tensor6_var = self.get_var("tensor6") + + tensor_t = self.tensor_type_var + tensor1_var = self.get_ctor("tensor1") + tensor2_var = self.get_ctor("tensor2") + tensor3_var = self.get_ctor("tensor3") + tensor4_var = self.get_ctor("tensor4") + tensor5_var = self.get_ctor("tensor5") + tensor6_var = self.get_ctor("tensor6") + t = Var("tensor", tensor_t()) lower = Var("lower", scalar_type("int32")) upper = Var("upper", scalar_type("int32")) @@ -805,8 +831,8 @@ def define_tensor_expand_dims(self): """ expand_dims_name = self.get_name("tensor_expand_dims") expand_dims_var = GlobalVar(expand_dims_name) - setattr(self.prelude, expand_dims_name, expand_dims_var) - tensor_type_var = self.get_var("tensor_t") + tensor_type_var = self.tensor_type_var + x = Var("x", tensor_type_var()) t0 = Var("t0") t1 = Var("t1") @@ -814,13 +840,13 @@ def define_tensor_expand_dims(self): t3 = Var("t3") t4 = Var("t4") t5 = Var("t5") - tensor0_var = self.get_var("tensor0") - tensor1_var = self.get_var("tensor1") - tensor2_var = self.get_var("tensor2") - tensor3_var = self.get_var("tensor3") - tensor4_var = self.get_var("tensor4") - tensor5_var = self.get_var("tensor5") - tensor6_var = self.get_var("tensor6") + tensor0_var = self.get_ctor("tensor0") + tensor1_var = self.get_ctor("tensor1") + tensor2_var = self.get_ctor("tensor2") + tensor3_var = self.get_ctor("tensor3") + tensor4_var = self.get_ctor("tensor4") + tensor5_var = self.get_ctor("tensor5") + tensor6_var = self.get_ctor("tensor6") tensor0_case = Clause( PatternConstructor(tensor0_var, [PatternVar(t0)]), tensor1_var(op.expand_dims(t0, 0, 1)) ) @@ -853,6 +879,7 @@ def define_tensor_expand_dims(self): ], False, ), + tensor_type_var(), ) def define_tensor_concat(self): @@ -862,15 +889,15 @@ def define_tensor_concat(self): """ concat_name = self.get_name("tensor_concatenate") concat_var = GlobalVar(concat_name) - setattr(self.prelude, concat_name, concat_var) - tensor_type_var = self.get_var("tensor_t") + + tensor_type_var = self.tensor_type_var x = Var("x", tensor_type_var()) y = Var("y", tensor_type_var()) - tensor1_var = self.get_var("tensor1") - tensor2_var = self.get_var("tensor2") - tensor3_var = self.get_var("tensor3") - tensor4_var = self.get_var("tensor4") + tensor1_var = self.get_ctor("tensor1") + tensor2_var = self.get_ctor("tensor2") + tensor3_var = self.get_ctor("tensor3") + tensor4_var = self.get_ctor("tensor4") t11 = Var("t11") t12 = Var("t12") t21 = Var("t21") @@ -933,7 +960,9 @@ def define_tensor_concat(self): ) # op.concatenate does not support tensor with rank higher than 4 self.prelude.mod[concat_var] = Function( - [x, y], Match(x, [tensor1_case, tensor2_case, tensor3_case, tensor4_case], False) + [x, y], + Match(x, [tensor1_case, tensor2_case, tensor3_case, tensor4_case], False), + tensor_type_var(), ) def define_tensor_array(self): @@ -943,18 +972,16 @@ def define_tensor_array(self): tensor_array_constructor_name = self.get_name("tensor_array") tensor_array_constructor_var = GlobalVar(tensor_array_constructor_name) setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var) - tensor_nil_var = self.get_var("tensor_nil") - tensor_type_var = self.get_var("tensor_t") + tensor_nil_var = self.get_ctor("tensor_nil") + tensor_type_var = self.get_ctor("tensor_t") n = Var("x", scalar_type("int32")) body = If( equal(n, const(0)), - self.prelude.nil(), - self.prelude.cons( - tensor_nil_var(), tensor_array_constructor_var(subtract(n, const(1))) - ), + self.nil(), + self.cons(tensor_nil_var(), tensor_array_constructor_var(subtract(n, const(1)))), ) self.prelude.mod[tensor_array_constructor_var] = Function( - [n], body, self.prelude.l(tensor_type_var()), [] + [n], body, self.list(tensor_type_var()), [] ) def define_tensor_array_read(self): @@ -966,9 +993,9 @@ def define_tensor_array_read(self): read_name = self.get_name("tensor_array_read") read_var = GlobalVar(read_name) setattr(self.prelude, read_name, read_var) - tensor_type_var = self.get_var("tensor_t") + tensor_type_var = self.tensor_type_var - tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + tensor_array = Var("tensor_array", self.list(tensor_type_var())) n = Var("x", scalar_type("int32")) self.prelude.mod[read_var] = Function( [tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), [] @@ -981,15 +1008,15 @@ def define_tensor_array_write(self): """ write_name = self.get_name("tensor_array_write") write_var = GlobalVar(write_name) - setattr(self.prelude, write_name, write_var) - tensor_type_var = self.get_var("tensor_t") - tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + + tensor_type_var = self.tensor_type_var + tensor_array = Var("tensor_array", self.list(tensor_type_var())) n = Var("x", scalar_type("int32")) v = Var("v", tensor_type_var()) self.prelude.mod[write_var] = Function( [tensor_array, n, v], self.prelude.update(tensor_array, n, v), - self.prelude.l(tensor_type_var()), + self.list(tensor_type_var()), [], ) @@ -999,30 +1026,26 @@ def define_tensor_array_unstack_tensor1(self): """ helper_name = self.get_name("tensor_array_unstack_tensor1_helper") helper_var = GlobalVar(helper_name) - setattr(self.prelude, helper_name, helper_var) tensor = Var("t", TensorType([Any()], self.dtype)) up = Var("up", scalar_type("int32")) i = Var("i", scalar_type("int32")) - tensor_type_var = self.get_var("tensor_t") - tensor0_var = self.get_var("tensor0") + tensor_type_var = self.tensor_type_var + tensor0_var = self.get_ctor("tensor0") helper_body = If( equal(i, up), - self.prelude.nil(), - self.prelude.cons( - tensor0_var(op.take(tensor, i)), helper_var(add(i, const(1)), up, tensor) - ), + self.nil(), + self.cons(tensor0_var(op.take(tensor, i)), helper_var(add(i, const(1)), up, tensor)), ) self.prelude.mod[helper_var] = Function( - [i, up, tensor], helper_body, self.prelude.l(tensor_type_var()), [] + [i, up, tensor], helper_body, self.list(tensor_type_var()), [] ) unstack_name = self.get_name("tensor_array_unstack_tensor1") unstack_var = GlobalVar(unstack_name) - setattr(self.prelude, unstack_name, unstack_var) tensor1 = Var("tensor", TensorType([Any()], self.dtype)) shape = op.shape_of(tensor1) ndim = op.take(shape, const(0)) self.prelude.mod[unstack_var] = Function( - [tensor1], helper_var(const(0), ndim, tensor1), self.prelude.l(tensor_type_var()), [] + [tensor1], helper_var(const(0), ndim, tensor1), self.list(tensor_type_var()), [] ) def define_tensor_array_unstack_tensor2(self): @@ -1039,14 +1062,14 @@ def define_tensor_array_unstack_tensor2(self): helper_body = If( equal(i, up), - self.prelude.nil(), - self.prelude.cons( - self.get_var("tensor1")(op.take(tensor, i, axis=0)), + self.nil(), + self.cons( + self.get_ctor("tensor1")(op.take(tensor, i, axis=0)), helper_var(add(i, const(1)), up, tensor), ), ) self.prelude.mod[helper_var] = Function( - [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), [] + [i, up, tensor], helper_body, self.list(self.tensor_type_var()), [] ) tensor_array_unstack_tensor2_name = self.get_name("tensor_array_unstack_tensor2") @@ -1058,7 +1081,7 @@ def define_tensor_array_unstack_tensor2(self): self.prelude.mod[tensor_array_unstack_tensor2_var] = Function( [tensor2], helper_var(const(0), ndim, tensor2), - self.prelude.l(self.get_var("tensor_t")()), + self.list(self.tensor_type_var()), [], ) @@ -1076,14 +1099,14 @@ def define_tensor_array_unstack_tensor3(self): helper_body = If( equal(i, up), - self.prelude.nil(), - self.prelude.cons( - self.get_var("tensor2")(op.take(tensor, i, axis=0)), + self.nil(), + self.cons( + self.get_ctor("tensor2")(op.take(tensor, i, axis=0)), helper_var(add(i, const(1)), up, tensor), ), ) self.prelude.mod[helper_var] = Function( - [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), [] + [i, up, tensor], helper_body, self.list(self.tensor_type_var()), [] ) tensor_array_unstack_tensor3_name = self.get_name("tensor_array_unstack_tensor3") @@ -1095,7 +1118,7 @@ def define_tensor_array_unstack_tensor3(self): self.prelude.mod[tensor_array_unstack_tensor3_var] = Function( [tensor3], helper_var(const(0), ndim, tensor3), - self.prelude.l(self.get_var("tensor_t")()), + self.list(self.tensor_type_var()), [], ) @@ -1113,14 +1136,14 @@ def define_tensor_array_unstack_tensor4(self): helper_body = If( equal(i, up), - self.prelude.nil(), - self.prelude.cons( - self.get_var("tensor3")(op.take(tensor, i, axis=0)), + self.nil(), + self.cons( + self.get_ctor("tensor3")(op.take(tensor, i, axis=0)), helper_var(add(i, const(1)), up, tensor), ), ) self.prelude.mod[helper_var] = Function( - [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), [] + [i, up, tensor], helper_body, self.list(self.tensor_type_var()), [] ) tensor_array_unstack_tensor4_name = self.get_name("tensor_array_unstack_tensor4") @@ -1132,7 +1155,7 @@ def define_tensor_array_unstack_tensor4(self): self.prelude.mod[tensor_array_unstack_tensor4_var] = Function( [tensor4], helper_var(const(0), ndim, tensor4), - self.prelude.l(self.get_var("tensor_t")()), + self.list(self.tensor_type_var()), [], ) @@ -1150,14 +1173,14 @@ def define_tensor_array_unstack_tensor5(self): helper_body = If( equal(i, up), - self.prelude.nil(), - self.prelude.cons( - self.get_var("tensor4")(op.take(tensor, i, axis=0)), + self.nil(), + self.cons( + self.get_ctor("tensor4")(op.take(tensor, i, axis=0)), helper_var(add(i, const(1)), up, tensor), ), ) self.prelude.mod[helper_var] = Function( - [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), [] + [i, up, tensor], helper_body, self.list(self.tensor_type_var()), [] ) tensor_array_unstack_tensor5_name = self.get_name("tensor_array_unstack_tensor5") @@ -1169,7 +1192,7 @@ def define_tensor_array_unstack_tensor5(self): self.prelude.mod[tensor_array_unstack_tensor5_var] = Function( [tensor5], helper_var(const(0), ndim, tensor5), - self.prelude.l(self.get_var("tensor_t")()), + self.list(self.tensor_type_var()), [], ) @@ -1187,14 +1210,14 @@ def define_tensor_array_unstack_tensor6(self): helper_body = If( equal(i, up), - self.prelude.nil(), - self.prelude.cons( - self.get_var("tensor5")(op.take(tensor, i, axis=0)), + self.nil(), + self.cons( + self.get_ctor("tensor5")(op.take(tensor, i, axis=0)), helper_var(add(i, const(1)), up, tensor), ), ) self.prelude.mod[helper_var] = Function( - [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), [] + [i, up, tensor], helper_body, self.list(self.tensor_type_var()), [] ) tensor_array_unstack_tensor6_name = self.get_name("tensor_array_unstack_tensor6") @@ -1206,7 +1229,7 @@ def define_tensor_array_unstack_tensor6(self): self.prelude.mod[tensor_array_unstack_tensor6_var] = Function( [tensor6], helper_var(const(0), ndim, tensor6), - self.prelude.l(self.get_var("tensor_t")()), + self.list(self.tensor_type_var()), [], ) @@ -1217,14 +1240,14 @@ def define_tensor_array_scatter(self): """ tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper") tensor_array_scatter_helper_var = GlobalVar(tensor_array_scatter_helper_name) - tensor_t = self.get_var("tensor_t") - ta = Var("ta", self.prelude.l(tensor_t())) + tensor_t = self.tensor_type_var + ta = Var("ta", self.list(tensor_t())) current = Var("current", scalar_type("int32")) limit = Var("limit", scalar_type("int32")) indices_ = Var("indices_", TensorType([Any()], "int32")) - values_ = Var("values_", self.prelude.l(tensor_t())) - write_var = self.get_var("tensor_array_write") - read_var = self.get_var("tensor_array_read") + values_ = Var("values_", self.list(tensor_t())) + write_var = self.get_global_var("tensor_array_write") + read_var = self.get_global_var("tensor_array_read") helper_body = If( equal(current, limit), ta, @@ -1237,19 +1260,19 @@ def define_tensor_array_scatter(self): ), ) self.prelude.mod[tensor_array_scatter_helper_var] = Function( - [ta, current, limit, indices_, values_], helper_body, self.prelude.l(tensor_t()), [] + [ta, current, limit, indices_, values_], helper_body, self.list(tensor_t()), [] ) tensor_array_scatter_name = self.get_name("tensor_array_scatter") tensor_array_scatter_var = GlobalVar(tensor_array_scatter_name) setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var) - tensor_array = Var("tensor_array", self.prelude.l(tensor_t())) + tensor_array = Var("tensor_array", self.list(tensor_t())) indices = Var("indices", TensorType([Any()], "int32")) - values = Var("values", self.prelude.l(tensor_t())) + values = Var("values", self.list(tensor_t())) indices_shape = op.shape_of(indices) limit = op.take(indices_shape, const(0)) body = tensor_array_scatter_helper_var(tensor_array, const(0), limit, indices, values) self.prelude.mod[tensor_array_scatter_var] = Function( - [tensor_array, indices, values], body, self.prelude.l(tensor_t()), [] + [tensor_array, indices, values], body, self.list(tensor_t()), [] ) def define_tensor_array_split(self): @@ -1257,18 +1280,18 @@ def define_tensor_array_split(self): tensor_array_split(ta, value, lengths) : list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t] """ - tensor_t = self.get_var("tensor_t") + tensor_t = self.tensor_type_var tensor_array_split_helper_name = self.get_name("ta_split_helper") tensor_array_split_helper_var = GlobalVar(tensor_array_split_helper_name) setattr(self.prelude, tensor_array_split_helper_name, tensor_array_split_helper_var) - ta1 = Var("tensor_array", self.prelude.l(tensor_t())) + ta1 = Var("tensor_array", self.list(tensor_t())) value1 = Var("value1", tensor_t()) offset1 = Var("offset1", scalar_type("int32")) current1 = Var("current1", scalar_type("int32")) limit1 = Var("limit1", scalar_type("int32")) lengths1 = Var("lengths", TensorType([Any()], "int32")) - write_var = self.get_var("tensor_array_write") - take_var = self.get_var("tensor_take") + write_var = self.get_global_var("tensor_array_write") + take_var = self.get_global_var("tensor_take") helper1_body = If( equal(current1, limit1), ta1, @@ -1288,13 +1311,13 @@ def define_tensor_array_split(self): self.prelude.mod[tensor_array_split_helper_var] = Function( [ta1, value1, offset1, current1, limit1, lengths1], helper1_body, - self.prelude.l(tensor_t()), + self.list(tensor_t()), [], ) split_name = self.get_name("tensor_array_split") split_var = GlobalVar(split_name) setattr(self.prelude, split_name, split_var) - tensor_array = Var("tensor_array", self.prelude.l(tensor_t())) + tensor_array = Var("tensor_array", self.list(tensor_t())) value = Var("value", tensor_t()) lengths = Var("lengths", TensorType([Any()], "int32")) lengths_shape = op.shape_of(lengths) @@ -1303,7 +1326,7 @@ def define_tensor_array_split(self): tensor_array, value, const(0), const(0), lengths_limit, lengths ) self.prelude.mod[split_var] = Function( - [tensor_array, value, lengths], body, self.prelude.l(tensor_t()), [] + [tensor_array, value, lengths], body, self.list(tensor_t()), [] ) def define_tensor_array_concat(self): @@ -1313,19 +1336,19 @@ def define_tensor_array_concat(self): concat_name = self.get_name("tensor_array_concat") concat_var = GlobalVar(concat_name) setattr(self.prelude, concat_name, concat_var) - tensor_concat_var = self.get_var("tensor_concatenate") - tensor_t = self.get_var("tensor_t") - tensor_nil_var = self.get_var("tensor_nil") - tensor_array = Var("tensor_array", self.prelude.l(tensor_t())) + tensor_concat_var = self.get_global_var("tensor_concatenate") + tensor_t = self.tensor_type_var + tensor_nil_var = self.get_ctor("tensor_nil") + tensor_array = Var("tensor_array", self.list(tensor_t())) hd = Var("hd") tl = Var("tl") - nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var()) + nil_case = Clause(PatternConstructor(self.nil), tensor_nil_var()) cons_case = Clause( - PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]), + PatternConstructor(self.cons, [PatternVar(hd), PatternVar(tl)]), Match( tl, [ - Clause(PatternConstructor(self.prelude.nil), hd), + Clause(PatternConstructor(self.nil), hd), Clause(PatternWildcard(), tensor_concat_var(hd, concat_var(tl))), ], False, @@ -1342,11 +1365,11 @@ def define_tensor_array_gather(self): helper_name = self.get_name("tensor_array_gather_helper") helper_var = GlobalVar(helper_name) setattr(self.prelude, helper_name, helper_var) - tensor_type_var = self.get_var("tensor_t") + tensor_type_var = self.tensor_type_var stack_var = self.get_var("tensor_array_stack") read_var = self.get_var("tensor_array_read") - ta = Var("ta", self.prelude.l(tensor_type_var())) - accu = Var("accu", self.prelude.l(tensor_type_var())) + ta = Var("ta", self.list(tensor_type_var())) + accu = Var("accu", self.list(tensor_type_var())) current = Var("current", scalar_type("int32")) limit = Var("limit", scalar_type("int32")) indices_ = Var("indices_", TensorType([Any()], "int32")) @@ -1355,9 +1378,7 @@ def define_tensor_array_gather(self): stack_var(accu), helper_var( ta, - self.prelude.cons( - read_var(ta, op.take(indices_, subtract(current, const(1)))), accu - ), + self.cons(read_var(ta, op.take(indices_, subtract(current, const(1)))), accu), subtract(current, const(1)), limit, indices_, @@ -1369,11 +1390,11 @@ def define_tensor_array_gather(self): gather_name = self.get_name("tensor_array_gather") gather_var = GlobalVar(gather_name) setattr(self.prelude, gather_name, gather_var) - tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + tensor_array = Var("tensor_array", self.list(tensor_type_var())) indices = Var("indices", TensorType([Any()], "int32")) indices_shape = op.shape_of(indices) limit = op.take(indices_shape, const(0)) - body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices) + body = helper_var(tensor_array, self.nil(), limit, limit, indices) self.prelude.mod[gather_var] = Function( [tensor_array, indices], body, tensor_type_var(), [] ) @@ -1385,10 +1406,11 @@ def define_tensor_array_stack(self): stack_name = self.get_name("tensor_array_stack") stack_var = GlobalVar(stack_name) setattr(self.prelude, stack_name, stack_var) - tensor_type_var = self.get_var("tensor_t") - tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) - expand_dims_var = self.get_var("tensor_expand_dims") - concat_var = self.get_var("tensor_concatenate") + tensor_type_var = self.tensor_type_var + tensor_array = Var("tensor_array", self.list(tensor_type_var())) + expand_dims_var = self.get_global_var("tensor_expand_dims") + concat_var = self.get_global_var("tensor_concatenate") + tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array) tensors = self.prelude.foldl( concat_var, @@ -1437,39 +1459,62 @@ def get_name(self, canonical, dtype): return "tensor_{}_t".format(dtype) return "{}_{}".format(canonical, dtype) - def get_var(self, canonical, dtype): - """Get var corresponding to the canonical name""" + def get_global_var(self, canonical, dtype): + """Get global var corresponding to the canonical name""" name = self.get_name(canonical, dtype) - return getattr(self, name) + return self.mod.get_global_var(name) + + def get_type(self, canonical, dtype): + """Get type corresponding to the canonical name""" + name = self.get_name(canonical, dtype) + return self.mod.get_global_type_var(name) + + def get_ctor(self, ty_name, canonical, dtype): + """Get constructor corresponding to the canonical name""" + name = self.get_name(canonical, dtype) + ctors = self.mod.get_type(ty_name) + for ctor in ctors: + if ctor.name_hint == name: + return ctor + raise Exception(f"could not find {name}") + + def get_tensor_ctor(self, canonical, dtype): + ty = self.get_type("tensor_t", dtype) + return self.get_ctor(ty.name_hint, canonical, dtype) def get_name_static(self, canonical, dtype, shape): """Get name corresponding to the canonical name""" return _get_name_static(canonical, dtype, shape) - def get_var_static(self, canonical, dtype, shape): + def get_global_var_static(self, canonical, dtype, shape): """Get var corresponding to the canonical name""" name = self.get_name_static(canonical, dtype, shape) - return getattr(self, name) + return self.mod.get_global_var(name) + + def get_type_static(self, canonical, dtype, shape): + """Get type corresponding to the canonical name""" + name = self.get_name_static(canonical, dtype, shape) + return self.mod.get_global_type_var(name) + + def get_ctor_static(self, ty_name, name, dtype, shape): + """Get constructor corresponding to the canonical name""" + ty_name = self.get_name_static(ty_name, dtype, shape) + name = self.get_name_static(name, dtype, shape) + ctors = self.mod.get_type(ty_name) + for ctor in ctors: + if ctor.name_hint == name: + return ctor + raise Exception(f"could not find {name}") + + def get_tensor_ctor_static(self, name, dtype, shape): + """Get constructor corresponding to the canonical name""" + return self.get_ctor_static("tensor_t", name, dtype, shape) def load_prelude(self): """Parses the Prelude from Relay's text format into a module.""" # TODO(@jroesch): we should remove this helper when we port over prelude self.mod.import_from_std("prelude.rly") - self.l = self.mod.get_global_type_var("List") - list_adt = self.mod[self.l] - self.cons = list_adt.constructors[0] - self.nil = list_adt.constructors[1] - - self.optional = self.mod.get_global_type_var("Option") - optional_adt = self.mod[self.optional] - self.some = optional_adt.constructors[0] - self.none = optional_adt.constructors[1] - - self.tree = self.mod.get_global_type_var("Tree") - tree_adt = self.mod[self.tree] - self.rose = tree_adt.constructors[0] - GLOBAL_DEFS = [ "id", "compose", @@ -1496,6 +1541,7 @@ def load_prelude(self): "size", "iterate", ] + for global_def in GLOBAL_DEFS: setattr(self, global_def, self.mod.get_global_var(global_def)) @@ -1512,3 +1558,6 @@ def load_prelude(self): ]: tensor_array_ops = TensorArrayOps(self, dtype) tensor_array_ops.register() + + # Renamer doesn't properly deal with constructors, etc + # self.mod = AnnotateSpans()(self.mod) diff --git a/python/tvm/relay/quantize/_partition_conversions.py b/python/tvm/relay/quantize/_partition_conversions.py index 166e86483fa24..8ba5c9ae2f205 100644 --- a/python/tvm/relay/quantize/_partition_conversions.py +++ b/python/tvm/relay/quantize/_partition_conversions.py @@ -121,6 +121,7 @@ def fuse_partitions(pre_mod, mid_mod, post_mod): relay.GlobalVar("dequantize_outputs"): post_func, } ) + # construct a `main` that strings together the partitions, such that its # behaviour is equivalent to `main` in an *unpartitioned* module scope_builder = relay.ScopeBuilder() @@ -142,7 +143,7 @@ def fuse_partitions(pre_mod, mid_mod, post_mod): ) scope_builder.ret(dequantized_outputs) fused_mod["main"] = relay.Function(fused_mod_main_params, scope_builder.get()) - return fused_mod + return relay.transform.InferType()(fused_mod) class PrefixCutter(ExprMutator): @@ -217,6 +218,7 @@ def partition_prefix(mod, quantized_dtypes): assert func.attrs is None, "unimplemented" mid_func = relay.Function(relay.analysis.free_vars(mid_body), mid_body) mid_mod = tvm.IRModule.from_expr(mid_func) + mid_mod = relay.transform.InferType()(mid_mod) scope_builder = prefix_cutter.prefix_sb # make sure we pass through all inputs in the prefix function's return expr @@ -237,6 +239,7 @@ def partition_prefix(mod, quantized_dtypes): pre_func_body = scope_builder.get() pre_func = relay.Function(relay.analysis.free_vars(pre_func_body), pre_func_body) pre_mod = tvm.IRModule.from_expr(pre_func) + pre_mod = relay.transform.InferType()(pre_mod) return pre_mod, mid_mod @@ -288,6 +291,7 @@ def partition_suffix(mod, quantized_dtypes): assert func.attrs is None, "unimplemented" post_func = relay.Function(relay.analysis.free_vars(post_body), post_body, func.ret_type) post_mod = tvm.IRModule.from_expr(post_func) + post_mod = relay.transform.InferType()(post_mod) mid_body = suffix_cutter.mid_body if mid_body is None: @@ -298,9 +302,11 @@ def partition_suffix(mod, quantized_dtypes): post_body = relay.Var("input", mid_mod["main"].ret_type) post_func = relay.Function([post_body], post_body) post_mod = tvm.IRModule.from_expr(post_func) + post_mod = relay.transform.InferType()(post_mod) else: mid_func = relay.Function(func.params, mid_body) mid_mod = tvm.IRModule.from_expr(mid_func) + mid_mod = relay.transform.InferType()(mid_mod) return mid_mod, post_mod diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 3d6870f75f84a..8f7333051a4cc 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -359,7 +359,7 @@ def quantize(mod, params=None, dataset=None): calibrate_pass = tvm.transform.module_pass( calibrate(dataset), opt_level=1, name="QuantizeCalibrate" ) - quant_passes = [partition(), annotate(), calibrate_pass] + quant_passes = [partition(), annotate(), calibrate_pass, tvm.relay.transform.InferType()] if not current_qconfig().do_simulation: quant_passes.append(realize()) quant_passes.append(_transform.FoldConstant()) diff --git a/python/tvm/relay/std/nat.rly b/python/tvm/relay/std/nat.rly new file mode 100644 index 0000000000000..de71beb3379c1 --- /dev/null +++ b/python/tvm/relay/std/nat.rly @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#[version = "0.0.5"] + +/* Defines a Peano (unary) natural number ADT. + Zero is represented by z(). s(n) adds 1 to a nat n. + Adds the fields nat, z, and s to the prelude, representing + (respectively) the nat ADT and the z and s constructors. +*/ +type nat { + zero, + succ(nat), +} + +/* + Defines a function that doubles a nat. Adds a field called + 'double' to the prelude, giving the GlobalVar pointing to + the function. +*/ +def @nat_double(%x: nat) -> nat { + match %x { + zero => zero, + succ(%y) => succ(succ(@nat_double(%y))) + } +} + +def @nat_add(%x: nat, %y: nat) -> nat { + match %x { + zero => %y, + succ(%z) => succ(@nat_add(%z, %y)) + } +} + +/* Defines a function to get the nth eleemnt of a list using + a nat to index into the list. +*/ +def @nat_nth[A](%l: List[A], %n: nat) -> A { + match %n { + zero => @hd(%l), + succ(%y) => @nat_nth(@tl(%l), %y) + } +} + +/* Defines a function to update the nth element of a list and return the updated list. */ +def @nat_update[A](%list: List[A], %index: nat, %value: A) -> List[A] { + match %index { + zero => Cons(%value, @tl(%list)), + succ(%index_pred) => @nat_update(@tl(%list), %index_pred, %value) + } +} + +/* Defines a function that takes a number n and a function f; + returns a closure that takes an argument and applies f + n times to its argument. +*/ +def @nat_iterate[A](%f: fn(A) -> A, %num: nat) -> fn(A) -> A { + match %num { + zero => fn(%x: A) -> A { %x }, + succ(%y) => fn (%i: A) { %f(@nat_iterate(%f, %y)(%i)) } + } +} diff --git a/python/tvm/relay/std/prelude.rly b/python/tvm/relay/std/prelude.rly index 17c91283f4d2b..57512a0369b3c 100644 --- a/python/tvm/relay/std/prelude.rly +++ b/python/tvm/relay/std/prelude.rly @@ -146,7 +146,8 @@ def @foldr1[A](%f: fn(A, A) -> A, %xs: List[A]) -> A { /* * Computes the sum of a list of integer scalars. */ -def @sum(%xs: List[Tensor[(), int32]]) { + // (@jroesch): if we leave off the return type this doesn't work +def @sum(%xs: List[Tensor[(), int32]]) -> int32 { let %add_f = fn(%x: Tensor[(), int32], %y: Tensor[(), int32]) -> Tensor[(), int32] { %x + %y }; @@ -193,7 +194,7 @@ def @zip[A, B](%xs: List[A], %ys: List[B]) -> List[(A, B)] { * Reverses a list. */ def @rev[A](%xs: List[A]) -> List[A] { - @foldl(@flip(Cons), Nil, %xs) + @foldl(@flip(fn (%h: A, %t: List[A]) { Cons(%h, %t) }), Nil, %xs) } /* diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index d9e4f1eec81e1..6eb71b581ab20 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -43,7 +43,7 @@ from . import synthetic from .init import create_workload -from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr +from .nat import count, make_nat_value, make_nat_expr from .py_converter import to_python, run_as_python from ..transform import gradient @@ -53,6 +53,7 @@ def run_opt_pass(expr, opt_pass, import_prelude=False): mod = tvm.IRModule.from_expr(expr) if import_prelude: Prelude(mod) + mod = relay.transform.InferType()(mod) mod = opt_pass(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body diff --git a/python/tvm/relay/testing/nat.py b/python/tvm/relay/testing/nat.py index b694f63126162..914a7ffdde743 100644 --- a/python/tvm/relay/testing/nat.py +++ b/python/tvm/relay/testing/nat.py @@ -14,143 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name """Defines a unary natural number (Peano natural number) abstract data type for Relay and provides some utility functions for it. Nats are useful for testing purposes, as they make it easy to write test cases for recursion and pattern matching.""" -from tvm.relay.adt import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar from tvm.relay.backend.interpreter import ConstructorValue -from tvm.relay.expr import Var, GlobalVar -from tvm.relay.function import Function -from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType -def define_nat_adt(prelude): - """Defines a Peano (unary) natural number ADT. - Zero is represented by z(). s(n) adds 1 to a nat n. - Adds the fields nat, z, and s to the preluide, representing - (respectively) the nat ADT and the z and s constructors. - """ - prelude.nat = GlobalTypeVar("nat") - prelude.z = Constructor("z", [], prelude.nat) - prelude.s = Constructor("s", [prelude.nat()], prelude.nat) - prelude.mod[prelude.nat] = TypeData(prelude.nat, [], [prelude.z, prelude.s]) - - -def define_nat_double(prelude): - """Defines a function that doubles a nat. Adds a field called - 'double' to the prelude, giving the GlobalVar pointing to - the function. - """ - prelude.double = GlobalVar("double") - x = Var("x", prelude.nat()) - y = Var("y") - z_case = Clause(PatternConstructor(prelude.z), prelude.z()) - s_case = Clause( - PatternConstructor(prelude.s, [PatternVar(y)]), prelude.s(prelude.s(prelude.double(y))) - ) - prelude.mod[prelude.double] = Function([x], Match(x, [z_case, s_case])) - - -def define_nat_add(prelude): - """Defines a function that adds two nats and adds a field to the - prelude 'add' giving the GlobalVar pointing to that function. - """ - prelude.add = GlobalVar("add") - x = Var("x", prelude.nat()) - y = Var("y", prelude.nat()) - a = Var("a") - z_case = Clause(PatternConstructor(prelude.z), y) - s_case = Clause(PatternConstructor(prelude.s, [PatternVar(a)]), prelude.s(prelude.add(a, y))) - prelude.mod[prelude.add] = Function([x, y], Match(x, [z_case, s_case])) - - -# versions of prelude functions that use nats instead of scalars - - -def define_nat_nth(prelude): - """Defines a function to get the nth eleemnt of a list using - a nat to index into the list. - - nat_nth(l, n): fun(list[a], nat) -> a - """ - prelude.nat_nth = GlobalVar("nat_nth") - a = TypeVar("a") - x = Var("x", prelude.l(a)) - n = Var("n", prelude.nat()) - y = Var("y") - - z_case = Clause(PatternConstructor(prelude.z), prelude.hd(x)) - s_case = Clause( - PatternConstructor(prelude.s, [PatternVar(y)]), prelude.nat_nth(prelude.tl(x), y) - ) - - prelude.mod[prelude.nat_nth] = Function([x, n], Match(n, [z_case, s_case]), a, [a]) - - -def define_nat_update(prelude): - """Defines a function to update the nth element of a list and return the updated list. - - nat_update(l, i, v) : fun(list[a], nat, a) -> list[a] - """ - prelude.nat_update = GlobalVar("nat_update") - a = TypeVar("a") - # pylint: disable=invalid-name - l = Var("l", prelude.l(a)) - n = Var("n", prelude.nat()) - v = Var("v", a) - y = Var("y") - - z_case = Clause(PatternConstructor(prelude.z), prelude.cons(v, prelude.tl(l))) - s_case = Clause( - PatternConstructor(prelude.s, [PatternVar(y)]), - prelude.cons(prelude.hd(l), prelude.nat_update(prelude.tl(l), y, v)), - ) - - prelude.mod[prelude.nat_update] = Function( - [l, n, v], Match(n, [z_case, s_case]), prelude.l(a), [a] - ) - - -def define_nat_iterate(prelude): - """Defines a function that takes a number n and a function f; - returns a closure that takes an argument and applies f - n times to its argument. - - Signature: fn(fn(a) -> a, nat) -> fn(a) -> a - """ - prelude.nat_iterate = GlobalVar("nat_iterate") - a = TypeVar("a") - f = Var("f", FuncType([a], a)) - x = Var("x", prelude.nat()) - y = Var("y", prelude.nat()) - - z_case = Clause(PatternConstructor(prelude.z), prelude.id) - s_case = Clause( - PatternConstructor(prelude.s, [PatternVar(y)]), - prelude.compose(f, prelude.nat_iterate(f, y)), - ) - - prelude.mod[prelude.nat_iterate] = Function( - [f, x], Match(x, [z_case, s_case]), FuncType([a], a), [a] - ) - - -def add_nat_definitions(prelude): - """Given a Relay prelude, adds a Peano nat ADT, as well as functions - for adding nats and doubling nats. It also adds versions of - update, nth, and iterate that take nats instead of scalars (the - names are prefixed with `nat_`).""" - define_nat_adt(prelude) - define_nat_double(prelude) - define_nat_add(prelude) - define_nat_nth(prelude) - define_nat_update(prelude) - define_nat_iterate(prelude) - - -# helper functions for working with nats +def get_type(prelude, name): + ty_var = prelude.mod.get_global_type_var(name) + ty_data = prelude.mod.type_definitions[ty_var] + return tuple([ty_var] + list(ty_data.constructors)) def count(prelude, n): @@ -159,9 +35,10 @@ def count(prelude, n): using an ADT value in Python. """ assert isinstance(n, ConstructorValue) - if n.tag == prelude.z.tag: + _, z, s = prelude.mod.get_type("nat") + if n.tag == z.tag: return 0 - assert n.tag == prelude.s.tag + assert n.tag == s.tag return 1 + count(prelude, n.fields[0]) @@ -169,9 +46,10 @@ def make_nat_value(prelude, n): """The inverse of count(): Given a non-negative Python integer, constructs a ConstructorValue representing that value as a nat. """ + _, z, s = prelude.mod.get_type("nat") if n == 0: - return ConstructorValue(prelude.z.tag, [], None) - return ConstructorValue(prelude.s.tag, [make_nat_value(prelude, n - 1)], None) + return ConstructorValue(z.tag, [], z) + return ConstructorValue(s.tag, [make_nat_value(prelude, n - 1)], s) def make_nat_expr(prelude, n): @@ -179,8 +57,9 @@ def make_nat_expr(prelude, n): expression representing that integer's value as a nat. """ assert n >= 0 - ret = prelude.z() + _, z, s = prelude.mod.get_type("nat") + ret = z() while n > 0: - ret = prelude.s(ret) + ret = s(ret) n = n - 1 return ret diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index c0dc97cc4b224..283a238a76260 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=no-else-return """Utility for converting Relay code into a Python script with equivalent semantics""" +import sys import ast from ast import alias, Assign, Load, Name, NameConstant, Num, Return, Store, Str import re @@ -27,6 +29,8 @@ from tvm.relay.function import Function from tvm.relay.expr_functor import ExprFunctor +__MAJOR__, __MINOR__, _, _, _ = sys.version_info + OUTPUT_VAR_NAME = "_py_out" # corresponds to: @@ -82,8 +86,12 @@ def convert(self, prog: Expr): # we finally must assign the final expression to the output var # so it can be read after running EXEC body.append(Assign([Name(OUTPUT_VAR_NAME, Store())], prog_body)) + global __MAJOR__, __MINOR__ - return ast.fix_missing_locations(ast.Module(body=body)) + if __MAJOR__ == 3 and __MINOR__ == 8: + return ast.fix_missing_locations(ast.Module(body=body, type_ignores=[])) + else: + return ast.fix_missing_locations(ast.Module(body=body)) def optimize(self, prog: Expr): """Performs optimizations necessary to be able to generate code for prog.""" @@ -210,11 +218,17 @@ def create_call(self, func_name: str, arguments): def create_def(self, func_name: str, arguments: [str], body): """Wrapper over function definition AST node, whose constructor is inconvenient.""" + inner_args = [ast.arg(argument, None) for argument in arguments] + + global __MAJOR__, __MINOR__ + if __MAJOR__ == 3 and __MINOR__ == 8: + arguments = ast.arguments([], inner_args, None, [], [], None, []) + else: + arguments = ast.arguments(inner_args, None, [], [], None, []) + return ast.FunctionDef( func_name, - ast.arguments( - [ast.arg(argument, None) for argument in arguments], None, [], [], None, [] - ), + arguments, body, [], None, @@ -576,8 +590,11 @@ def to_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")): """Converts the given Relay expression into a Python script (as a Python AST object). For easiest debugging, import the astor package and use to_source().""" mod = mod if mod is not None else tvm.IRModule() + mod = relay.transform.InferType()(mod) converter = PythonConverter(mod, target) - return converter.convert(expr) + python = converter.convert(expr) + assert python + return python def run_as_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")): diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index ccc4f76e6d047..f611c1cc14c13 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -21,6 +21,7 @@ import numpy as np from tvm.ir.transform import PassContext, module_pass +from tvm.relay.transform import InferType from tvm import nd, container from ..function import Function from ..expr_functor import ExprVisitor, ExprMutator @@ -351,6 +352,7 @@ def transform_module(self, mod, _): # TODO(@jroesch): Is there a way to do one shot initialization? # can we have def pass_init? mod.import_from_std("core.rly") + mod = InferType()(mod) assert isinstance(self.targets, (dict, container.Map)) if len(self.targets) > 1: diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index ade071c70ad51..e155f83a7c5de 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1032,3 +1032,17 @@ def SimplifyExpr(): The registered SimplifyExpr pass. """ return _ffi_api.SimplifyExpr() + + +def AnnotateSpans(): + """ + Annotate a program with span information by first generating its textual + representation and then parsing it back into a Relay AST annotated with + span information. + + Returns + ------- + ret : tvm.transform.Pass + The regsistered AnnotateSpans pass. + """ + return _ffi_api.AnnotateSpans() diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index d05d84614d574..2e41f0bee9213 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -215,6 +215,7 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target tir::transform::CombineContextCall(), }; auto opt_host = transform::Sequential(host_pass_list); + CHECK(mod_mixed.defined()) << "This module must be defined"; auto mhost = opt_host(mod_mixed); // device pipeline @@ -271,14 +272,23 @@ runtime::Module build(const Map& inputs, const Target& target_ IRModule mhost_all = IRModule(Map()); + CHECK(mhost_all.defined()) << "The host module must be defined"; + for (const auto& it : inputs) { - auto pair = SplitDevHostFuncs(it.second, it.first, target_host_val, pass_ctx); - auto& mhost = pair.first; - auto& mdevice = pair.second; + if (it.second.defined()) { + auto pair = SplitDevHostFuncs(it.second, it.first, target_host_val, pass_ctx); + auto& mhost = pair.first; + auto& mdevice = pair.second; + + CHECK(mhost.defined()) << "The split host module must be defined"; + + CHECK(mhost_all.defined()) << "The host module must be defined"; - mhost_all->Update(mhost); - if (mdevice->functions.size() != 0) { - device_modules.push_back(codegen::Build(mdevice, it.first)); + mhost_all->Update(mhost); + + if (mdevice->functions.size() != 0) { + device_modules.push_back(codegen::Build(mdevice, it.first)); + } } } diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc new file mode 100644 index 0000000000000..ceadf78e2cfcc --- /dev/null +++ b/src/ir/diagnostic.cc @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/ir/transform.cc + * \brief Infrastructure for transformation passes. + */ +#include +#include + +#include + +namespace tvm { + +using tvm::parser::Source; + +const char* kTVM_INTERNAL_ERROR_MESSAGE = + "\n---------------------------------------------------------------\n" + "An internal invariant was violated during the execution of TVM.\n" + "Please read TVM's error reporting guidelines.\n" + "More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.\n" + "---------------------------------------------------------------\n"; + +// failed to check to argument arg0.dims[0] != 0 + +/* Diagnostic */ +TVM_REGISTER_NODE_TYPE(DiagnosticNode); + +TVM_REGISTER_GLOBAL("diagnostics.Diagnostic") + .set_body_typed([](int level, Span span, String message) { + return Diagnostic(static_cast(level), span, message); + }); + +Diagnostic::Diagnostic(DiagnosticLevel level, Span span, const std::string& message) { + auto n = make_object(); + n->level = level; + n->span = span; + n->message = message; + data_ = std::move(n); +} + +DiagnosticBuilder Diagnostic::Bug(Span span) { + return DiagnosticBuilder(DiagnosticLevel::kBug, span); +} + +DiagnosticBuilder Diagnostic::Error(Span span) { + return DiagnosticBuilder(DiagnosticLevel::kError, span); +} + +DiagnosticBuilder Diagnostic::Warning(Span span) { + return DiagnosticBuilder(DiagnosticLevel::kWarning, span); +} + +DiagnosticBuilder Diagnostic::Note(Span span) { + return DiagnosticBuilder(DiagnosticLevel::kNote, span); +} + +DiagnosticBuilder Diagnostic::Help(Span span) { + return DiagnosticBuilder(DiagnosticLevel::kHelp, span); +} + +/* Diagnostic Renderer */ +TVM_REGISTER_NODE_TYPE(DiagnosticRendererNode); + +void DiagnosticRenderer::Render(const DiagnosticContext& ctx) { (*this)->renderer(ctx); } + +TVM_DLL DiagnosticRenderer::DiagnosticRenderer( + TypedPackedFunc renderer) { + auto n = make_object(); + n->renderer = renderer; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("diagnostics.DiagnosticRenderer") + .set_body_typed([](TypedPackedFunc renderer) { + return DiagnosticRenderer(renderer); + }); + +/* Diagnostic Context */ +TVM_REGISTER_NODE_TYPE(DiagnosticContextNode); + +void DiagnosticContext::Render() { + (*this)->renderer.Render(*this); + + int errs = 0; + if ((*this)->diagnostics.size()) { + for (auto diagnostic : (*this)->diagnostics) { + if (diagnostic->level == DiagnosticLevel::kError) { + errs += 1; + } + } + } + + if (errs) { + (*this)->renderer = DiagnosticRenderer(); + LOG(FATAL) << "DiagnosticError: one or more error diagnostics were " + << "emitted, please check diagnostic render for output."; + } +} + +TVM_REGISTER_GLOBAL("diagnostics.DiagnosticRendererRender") + .set_body_typed([](DiagnosticRenderer renderer, DiagnosticContext ctx) { + renderer.Render(ctx); + }); + +DiagnosticContext::DiagnosticContext(const IRModule& module, const DiagnosticRenderer& renderer) { + auto n = make_object(); + n->module = module; + n->renderer = renderer; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("diagnostics.DiagnosticContext") + .set_body_typed([](const IRModule& module, const DiagnosticRenderer& renderer) { + return DiagnosticContext(module, renderer); + }); + +/*! \brief Emit a diagnostic. */ +void DiagnosticContext::Emit(const Diagnostic& diagnostic) { + (*this)->diagnostics.push_back(diagnostic); +} + +TVM_REGISTER_GLOBAL("diagnostics.Emit") + .set_body_typed([](DiagnosticContext ctx, const Diagnostic& diagnostic) { + return ctx.Emit(diagnostic); + }); + +TVM_REGISTER_GLOBAL("diagnostics.DiagnosticContextRender") + .set_body_typed([](DiagnosticContext context) { return context.Render(); }); + +/*! \brief Emit a diagnostic. */ +void DiagnosticContext::EmitFatal(const Diagnostic& diagnostic) { + Emit(diagnostic); + Render(); +} + +/* Default Terminal Renderer. */ +static const char* DEFAULT_RENDERER = "diagnostics.DefaultRenderer"; +static const char* OVERRIDE_RENDERER = "diagnostics.OverrideRenderer"; + +DiagnosticRenderer GetRenderer() { + auto override_pf = tvm::runtime::Registry::Get(OVERRIDE_RENDERER); + tvm::runtime::TypedPackedFunc pf; + if (override_pf) { + pf = tvm::runtime::TypedPackedFunc(*override_pf); + } else { + auto default_pf = tvm::runtime::Registry::Get(DEFAULT_RENDERER); + ICHECK(default_pf != nullptr) + << "Can not find registered function for " << DEFAULT_RENDERER << "." << std::endl + << "Either this is an internal error or the default function was overloaded incorrectly."; + pf = tvm::runtime::TypedPackedFunc(*default_pf); + } + return Downcast(pf()); +} + +DiagnosticContext DiagnosticContext::Default(const IRModule& module) { + auto renderer = GetRenderer(); + return DiagnosticContext(module, renderer); +} + +std::ostream& EmitDiagnosticHeader(std::ostream& out, const Span& span, DiagnosticLevel level, + std::string msg) { + rang::fg diagnostic_color = rang::fg::reset; + std::string diagnostic_type; + + switch (level) { + case DiagnosticLevel::kWarning: { + diagnostic_color = rang::fg::yellow; + diagnostic_type = "warning"; + break; + } + case DiagnosticLevel::kError: { + diagnostic_color = rang::fg::red; + diagnostic_type = "error"; + break; + } + case DiagnosticLevel::kBug: { + diagnostic_color = rang::fg::blue; + diagnostic_type = "bug"; + break; + } + case DiagnosticLevel::kNote: { + diagnostic_color = rang::fg::reset; + diagnostic_type = "note"; + break; + } + case DiagnosticLevel::kHelp: { + diagnostic_color = rang::fg::reset; + diagnostic_type = "help"; + break; + } + } + + out << rang::style::bold << diagnostic_color << diagnostic_type << ": " << rang::fg::reset << msg + << std::endl + << rang::fg::blue << " --> " << rang::fg::reset << rang::style::reset + << span->source_name->name << ":" << span->line << ":" << span->column << std::endl; + + return out; +} + +/*! \brief Generate an error message at a specific line and column with the + * annotated message. + * + * The error is written directly to the `out` std::ostream. + * + * \param out The output ostream. + * \param line The line at which to report a diagnostic. + * \param line The column at which to report a diagnostic. + * \param msg The message to attach. + */ +void ReportAt(const DiagnosticContext& context, std::ostream& out, const Span& span, + const Diagnostic& diagnostic) { + if (!span.defined()) { + out << diagnostic->message << std::endl; + return; + } + + CHECK(context->module->source_map.defined()); + auto it = context->module->source_map->source_map.find(span->source_name); + + // If the source name is not in the current source map, sources were not annotated. + if (it == context->module->source_map->source_map.end()) { + LOG(FATAL) << "The source maps are not populated for this module. " + << "Please use `tvm.relay.transform.AnnotateSpans` to attach source maps for error " + "reporting. " + << "Error: " << diagnostic->message; + } + + auto source = (*it).second; + DLOG(INFO) << "Source: " << std::endl << source->source; + + DLOG(INFO) << "ReportAt " + << "span = " << span << " msg = " << diagnostic->message; + + auto line_text = source.GetLine(span->line); + + std::stringstream line_header_s; + line_header_s << " " << span->line << " "; + auto line_header = line_header_s.str(); + + std::stringstream no_line_header_s; + for (size_t i = 0; i < line_header.size(); i++) { + no_line_header_s << " "; + } + auto no_line_header = no_line_header_s.str(); + + EmitDiagnosticHeader(out, span, diagnostic->level, diagnostic->message) + << no_line_header << "| " << std::endl + << line_header << "| " << line_text << std::endl + << no_line_header << "| "; + + std::stringstream marker; + for (size_t i = 1; i <= line_text.size(); i++) { + if (static_cast(i) >= span->column && static_cast(i) < span->end_column) { + marker << "^"; + } else { + marker << " "; + } + } + out << marker.str(); + out << std::endl; +} + +// TODO(@jroesch): eventually modularize the rendering interface to provide control of how to +// format errors. +DiagnosticRenderer TerminalRenderer(std::ostream& out) { + return DiagnosticRenderer([&](const DiagnosticContext& ctx) { + for (auto diagnostic : ctx->diagnostics) { + ReportAt(ctx, out, diagnostic->span, diagnostic); + } + }); +} + +TVM_REGISTER_GLOBAL(DEFAULT_RENDERER).set_body_typed([]() { return TerminalRenderer(std::cout); }); + +TVM_REGISTER_GLOBAL("diagnostics.GetRenderer").set_body_typed([]() { return GetRenderer(); }); + +TVM_REGISTER_GLOBAL("diagnostics.ClearRenderer").set_body_typed([]() { + tvm::runtime::Registry::Remove(OVERRIDE_RENDERER); +}); + +} // namespace tvm diff --git a/src/ir/module.cc b/src/ir/module.cc index 66bce0f6b8824..231ae68dd4e04 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -29,8 +29,10 @@ // and are only used in minimum cases where they are clearly marked. // // Rationale: We calls into relay's analysis module to verify correctness. +#include #include #include +#include #include #include @@ -41,7 +43,7 @@ namespace tvm { IRModule::IRModule(tvm::Map functions, tvm::Map type_definitions, - std::unordered_set import_set) { + std::unordered_set import_set, parser::SourceMap source_map) { auto n = make_object(); n->functions = std::move(functions); n->type_definitions = std::move(type_definitions); @@ -49,6 +51,7 @@ IRModule::IRModule(tvm::Map functions, n->global_var_map_ = {}; n->constructor_tag_map_ = {}; n->import_set_ = std::move(import_set); + n->source_map = source_map; for (const auto& kv : n->functions) { // set global var map @@ -174,46 +177,23 @@ tvm::Array IRModuleNode::GetGlobalTypeVars() const { return tvm::Array(global_type_vars); } -template -tvm::Array concat(const tvm::Array& l, const tvm::Array& r) { - tvm::Array ret(l); - for (const T& t : r) { - ret.push_back(t); - } - return ret; -} - -// helper function to run type check -relay::Function RunTypeCheck(const IRModule& mod, const GlobalVar& var, relay::Function f) { - auto func = Downcast(relay::DeDup(std::move(f))); +void WarnIfMalformed(const IRModule& mod, relay::Function func) { + func = Downcast(relay::DeDup(func)); // Type check the item before we add it to the module. auto fv = relay::FreeVars(func); auto ftv = relay::FreeTypeVars(func, mod); - CHECK_EQ(fv.size(), 0) << "There are free variables: " << fv - << " in function: " << AsText(func, false); + // TODO(@jroesch): refactor to use diagnostic context + CHECK_EQ(fv.size(), 0) << "There are free variables: " << fv << std::endl; CHECK_EQ(ftv.size(), 0) << "There are free type variables: " << fv << " in function: " << AsText(func, false); - // Type check the item before we add it to the module. - relay::Function checked_func = InferType(func, mod, var); - return checked_func; } void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) { BaseFunc checked_func = f; if (auto* ptr = f.as()) { - checked_func = RunTypeCheck(GetRef(this), var, GetRef(ptr)); + WarnIfMalformed(GetRef(this), GetRef(ptr)); } - Type type = checked_func->checked_type(); - CHECK(type.as() == nullptr); - - if (functions.find(var) != functions.end()) { - CHECK(update) << "Already have definition for " << var->name_hint; - auto old_type = functions[var]->checked_type(); - CHECK(tvm::StructuralEqual()(type, old_type)) - << "Module#update changes type, not possible in this mode."; - } - var->checked_type_ = type; AddUnchecked(var, checked_func); } @@ -244,11 +224,9 @@ void IRModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData } void IRModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) { + // TODO(@jroesch): we have temporarily removed kind checking here, and will consolidate + // to the type checker in follow up PR. AddTypeDefUnchecked(var, type, update); - // need to kind check at the end because the check can look up - // a definition potentially - CHECK(relay::KindCheck(type, GetRef(this)) == TypeKind::kTypeData) - << "Invalid or malformed typedata given to module: " << type; } void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type, @@ -306,20 +284,66 @@ Constructor IRModuleNode::LookupTag(const int32_t tag) { return (*it).second; } -void IRModuleNode::Update(const IRModule& mod) { - // add functions and type defs. we add them unchecked first, so all definitions - // can reference each other, independent of the order in which they were defined. - for (auto pair : mod->functions) { - this->AddUnchecked(pair.first, pair.second); +struct Renamer : relay::ExprMutator, TypeMutator { + Map defs; + Map types; + std::unordered_map ctors; + + Renamer(Map defs_one, Map defs_two, + Map types_one, Map types_two, + std::unordered_map ctors_one, + std::unordered_map ctor_two) { + for (auto pair : defs_one) { + defs.Set(pair.first, pair.second); + } + + for (auto pair : defs_two) { + auto it = defs.find(pair.first); + if (it == defs.end()) { + defs.Set(pair.first, pair.second); + } + } + + for (auto pair : types_one) { + types.Set(pair.first, pair.second); + } + + for (auto pair : types_two) { + auto it = types.find(pair.first); + if (it == types.end()) { + types.Set(pair.first, pair.second); + } + } } + + relay::Expr VisitExpr_(const GlobalVarNode* node) override { return defs.at(node->name_hint); } + + Type VisitType_(const GlobalTypeVarNode* node) override { return types.at(node->name_hint); } +}; + +void IRModuleNode::Update(const IRModule& mod) { + Renamer renamer(this->global_var_map_, mod->global_var_map_, this->global_type_var_map_, + mod->global_type_var_map_, this->constructor_tag_map_, mod->constructor_tag_map_); + + this->global_var_map_ = renamer.defs; + this->global_type_var_map_ = renamer.types; + this->constructor_tag_map_ = renamer.ctors; + for (auto pair : mod->type_definitions) { - this->AddTypeDefUnchecked(pair.first, pair.second); + auto tvar = renamer.types.at(pair.first->name_hint); + auto ty = renamer.ExprMutator::VisitType(pair.second); + this->AddTypeDefUnchecked(tvar, Downcast(ty), true); } + for (auto pair : mod->functions) { - this->Update(pair.first, pair.second); - } - for (auto pair : mod->type_definitions) { - this->UpdateTypeDef(pair.first, pair.second); + if (auto rfn = pair.second.as()) { + auto gvar = renamer.defs.at(pair.first->name_hint); + auto fn = renamer.VisitExpr(GetRef(rfn)); + this->AddUnchecked(gvar, Downcast(fn)); + } else { + // TODO(@jroesch): rename into IRModule. + this->AddUnchecked(pair.first, pair.second); + } } } @@ -351,7 +375,7 @@ void IRModuleNode::Import(const String& path) { std::fstream src_file(path, std::fstream::in); std::string file_contents{std::istreambuf_iterator(src_file), std::istreambuf_iterator()}; - auto mod_to_import = IRModule::FromText(file_contents, path); + auto mod_to_import = parser::ParseModule(path, file_contents, GetRef(this)); Update(mod_to_import); } } @@ -462,7 +486,7 @@ TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, S TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); - p->stream << "IRModuleNode( " << node->functions << ")"; + p->stream << "IRModule(" << node->functions << ")"; }); } // namespace tvm diff --git a/src/ir/span.cc b/src/ir/span.cc index d9c9bbc47c341..667c14e4a7aed 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -23,6 +23,8 @@ #include #include +#include + namespace tvm { ObjectPtr GetSourceNameNode(const String& name) { @@ -71,7 +73,9 @@ Span::Span(SourceName source_name, int line, int end_line, int column, int end_c data_ = std::move(n); } -Span Span::Merge(const Span& other) { +Span Span::Merge(const Span& other) const { + CHECK(this->defined() && other.defined()) << "Span::Merge: both spans must be defined"; + CHECK((*this)->source_name == other->source_name); return Span((*this)->source_name, std::min((*this)->line, other->line), std::max((*this)->end_line, other->end_line), diff --git a/src/ir/transform.cc b/src/ir/transform.cc index d74b95abebdb8..ec88482ee3bf3 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -282,14 +282,36 @@ ModulePass::ModulePass(runtime::TypedPackedFunc // Module -> Module optimizations. IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { + DiagnosticContext previous = DiagnosticContext::Default(mod); + + if (pass_ctx->diag_ctx) { + DiagnosticContext tmp = pass_ctx->diag_ctx.value(); + pass_ctx->diag_ctx = previous; + previous = tmp; + } else { + pass_ctx->diag_ctx = previous; + } + + ICHECK(pass_ctx->diag_ctx) + << "The diagnostic context was set at the top of this block this is a bug."; + const PassInfo& pass_info = Info(); DLOG(INFO) << "Executing module pass : " << pass_info->name << " with opt level: " << pass_info->opt_level; - CHECK(mod.defined()); + ICHECK(mod.defined()) << "The input module must be set."; + pass_ctx.Trace(mod, pass_info, true); mod = pass_func(std::move(mod), pass_ctx); - CHECK(mod.defined()); + + ICHECK(mod.defined()) << "The return value of a module pass must be set."; + + ICHECK(pass_ctx->diag_ctx) + << "The diagnostic context was set at the top of this block this is a bug."; + + pass_ctx->diag_ctx.value().Render(); + pass_ctx->diag_ctx = previous; + pass_ctx.Trace(mod, pass_info, false); return mod; } diff --git a/src/ir/type.cc b/src/ir/type.cc index 38a6ec3e68051..1781dfd6e57ff 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -62,10 +62,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << '*'; }); -TypeVar::TypeVar(String name, TypeKind kind) { +TypeVar::TypeVar(String name, TypeKind kind, Span span) { ObjectPtr n = make_object(); n->name_hint = std::move(name); n->kind = std::move(kind); + n->span = std::move(span); data_ = std::move(n); } @@ -81,10 +82,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "TypeVar(" << node->name_hint << ", " << node->kind << ")"; }); -GlobalTypeVar::GlobalTypeVar(String name, TypeKind kind) { +GlobalTypeVar::GlobalTypeVar(String name, TypeKind kind, Span span) { ObjectPtr n = make_object(); n->name_hint = std::move(name); n->kind = std::move(kind); + n->span = std::move(span); data_ = std::move(n); } @@ -101,12 +103,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); FuncType::FuncType(tvm::Array arg_types, Type ret_type, tvm::Array type_params, - tvm::Array type_constraints) { + tvm::Array type_constraints, Span span) { ObjectPtr n = make_object(); n->arg_types = std::move(arg_types); n->ret_type = std::move(ret_type); n->type_params = std::move(type_params); n->type_constraints = std::move(type_constraints); + n->span = std::move(span); data_ = std::move(n); } @@ -125,9 +128,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << node->ret_type << ", " << node->type_constraints << ")"; }); -TupleType::TupleType(Array fields) { +TupleType::TupleType(Array fields, Span span) { ObjectPtr n = make_object(); n->fields = std::move(fields); + n->span = std::move(span); data_ = std::move(n); } @@ -145,9 +149,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "TupleTypeNode(" << node->fields << ")"; }); -IncompleteType::IncompleteType(TypeKind kind) { +IncompleteType::IncompleteType(TypeKind kind, Span span) { auto n = make_object(); n->kind = std::move(kind); + n->span = std::move(span); data_ = std::move(n); } @@ -163,9 +168,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; }); -RelayRefType::RelayRefType(Type value) { +RelayRefType::RelayRefType(Type value, Span span) { ObjectPtr n = make_object(); n->value = std::move(value); + n->span = std::move(span); data_ = std::move(n); } diff --git a/src/parser/diagnostic.h b/src/parser/diagnostic.h deleted file mode 100644 index 085d1c4ea8fb6..0000000000000 --- a/src/parser/diagnostic.h +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file diagnostic.h - * \brief A new diagnostic interface for TVM error reporting. - * - * A prototype of the new diagnostic reporting interface for TVM. - * - * Eventually we hope to promote this file to the top-level and - * replace the existing errors.h. - */ - -#ifndef TVM_PARSER_DIAGNOSTIC_H_ -#define TVM_PARSER_DIAGNOSTIC_H_ - -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace parser { - -/*! \brief The diagnostic level, controls the printing of the message. */ -enum class DiagnosticLevel { - kBug, - kError, - kWarning, - kNote, - kHelp, -}; - -struct DiagnosticBuilder; - -/*! \brief A diagnostic message. */ -struct Diagnostic { - /*! \brief The level. */ - DiagnosticLevel level; - /*! \brief The span at which to report an error. */ - Span span; - /*! \brief The diagnostic message. */ - std::string message; - - Diagnostic(DiagnosticLevel level, Span span, const std::string& message) - : level(level), span(span), message(message) {} - - static DiagnosticBuilder Bug(Span span); - static DiagnosticBuilder Error(Span span); - static DiagnosticBuilder Warning(Span span); - static DiagnosticBuilder Note(Span span); - static DiagnosticBuilder Help(Span span); -}; - -/*! - * \brief A wrapper around std::stringstream to build a diagnostic. - * - * \code - * - * void ReportError(const Error& err); - * - * void Test(int number) { - * // Use error reporter to construct an error. - * ReportError(ErrorBuilder() << "This is an error number=" << number); - * } - * - * \endcode - */ -struct DiagnosticBuilder { - public: - /*! \brief The level. */ - DiagnosticLevel level; - - /*! \brief The source name. */ - SourceName source_name; - - /*! \brief The span of the diagnostic. */ - Span span; - - template - DiagnosticBuilder& operator<<(const T& val) { // NOLINT(*) - stream_ << val; - return *this; - } - - DiagnosticBuilder() : level(DiagnosticLevel::kError), source_name(), span(Span()) {} - - DiagnosticBuilder(const DiagnosticBuilder& builder) - : level(builder.level), source_name(builder.source_name), span(builder.span) {} - - DiagnosticBuilder(DiagnosticLevel level, Span span) : level(level), span(span) {} - - operator Diagnostic() { return Diagnostic(this->level, this->span, this->stream_.str()); } - - private: - std::stringstream stream_; - friend struct Diagnostic; -}; - -DiagnosticBuilder Diagnostic::Bug(Span span) { - return DiagnosticBuilder(DiagnosticLevel::kBug, span); -} - -DiagnosticBuilder Diagnostic::Error(Span span) { - return DiagnosticBuilder(DiagnosticLevel::kError, span); -} - -DiagnosticBuilder Diagnostic::Warning(Span span) { - return DiagnosticBuilder(DiagnosticLevel::kWarning, span); -} - -DiagnosticBuilder Diagnostic::Note(Span span) { - return DiagnosticBuilder(DiagnosticLevel::kNote, span); -} - -DiagnosticBuilder Diagnostic::Help(Span span) { - return DiagnosticBuilder(DiagnosticLevel::kHelp, span); -} - -/*! \brief A diagnostic context for recording errors against a source file. - * TODO(jroesch): convert source map and improve in follow up PR, the parser - * assumes a single global file for now. - */ -struct DiagnosticContext { - /*! \brief The source to report against. */ - Source source; - - /*! \brief The set of diagnostics to report. */ - std::vector diagnostics; - - explicit DiagnosticContext(const Source& source) : source(source) {} - - /*! \brief Emit a diagnostic. */ - void Emit(const Diagnostic& diagnostic) { diagnostics.push_back(diagnostic); } - - /*! \brief Emit a diagnostic. */ - void EmitFatal(const Diagnostic& diagnostic) { - diagnostics.push_back(diagnostic); - Render(std::cout); - } - - // TODO(jroesch): eventually modularize the rendering interface to provide control of how to - // format errors. - void Render(std::ostream& ostream) { - for (auto diagnostic : diagnostics) { - source.ReportAt(ostream, diagnostic.span, diagnostic.message); - } - - if (diagnostics.size()) { - LOG(FATAL) << "DiagnosticError: one or more error diagnostics were " - << "emitted, please check diagnostic render for output."; - } - } -}; - -} // namespace parser -} // namespace tvm -#endif // TVM_PARSER_DIAGNOSTIC_H_ diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 8055d91382353..7dc55b0b519a0 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -21,19 +21,22 @@ * \file parser.cc * \brief A parser for TVM IR. */ +#include #include #include +#include #include #include #include +#include #include #include #include -#include "./diagnostic.h" #include "./meta_ref.h" #include "./op_table.h" +#include "./span_check.h" #include "./tokenizer.h" namespace tvm { @@ -45,6 +48,22 @@ using Expr = relay::Expr; /*! \brief The meta table maps from type key to a sequence of objects. */ using MetaTable = Map>; +using tvm::transform::CreateModulePass; +using tvm::transform::PassContext; + +/*! \brief A helper for passing around spans with data structures with + * no span field. + */ +template +struct Spanned { + T data; + Span span; + + Spanned() = default; + Spanned(const Spanned& other) = default; + Spanned(T data, Span span) : data(data), span(span) {} +}; + /*! \brief A wrapper structure for capturing the result of parsing * a global definition *before* we add it to the IRModule. * @@ -175,7 +194,7 @@ struct InternTable { } /*! \brief Return the unique allocation. */ - Optional Get(const std::string& name) { + Optional Get(const std::string& name) const { auto it = table.find(name); if (it != table.end()) { return Optional(it->second); @@ -185,6 +204,32 @@ struct InternTable { } }; +GlobalVar AddOrGet(InternTable* table, const std::string& name) { + auto var = table->Get(name); + if (var) { + return var.value(); + } else { + auto gvar = GlobalVar(name); + table->Add(name, gvar); + return gvar; + } +} + +GlobalTypeVar AddOrGet(InternTable* table, const std::string& name, + TypeKind kind = TypeKind::kType) { + auto var = table->Get(name); + if (var) { + auto tvar = var.value(); + TypeKind& tvar_kind = const_cast(tvar->kind); + tvar_kind = kind; + return tvar; + } else { + auto gvar = GlobalTypeVar(name, kind); + table->Add(name, gvar); + return gvar; + } +} + /*! \brief The parser class is the main interface to the parser. * the parser is not currently exposed beyond this .cc file. * @@ -228,10 +273,13 @@ class Parser { /*! \brief The version that the parser is parsing. */ SemVer version; + /*! \brief The IRModule we are building. */ + IRModule module; + /*! \brief The diagnostic context used for error reporting. */ - DiagnosticContext* diag_ctx; + DiagnosticContext diag_ctx; - const SourceName& source_name; + const Source& source; /*! \brief The current position in the token stream. */ int pos; @@ -266,15 +314,37 @@ class Parser { /*! \brief The metadata section. */ MetaTable meta_table; - Parser(DiagnosticContext* ctx, const SourceName& source_name, std::vector tokens, - OperatorTable op_table, Source source, MetaTable table) - : diag_ctx(ctx), - source_name(source_name), + Parser(IRModule module, DiagnosticContext ctx, const Source& source, std::vector tokens, + OperatorTable op_table, MetaTable table) + : module(module), + diag_ctx(ctx), + source(source), pos(0), tokens(tokens), op_table(op_table), ignore_whitespace(true), - meta_table(table) {} + meta_table(table) { + InitializeGlobals(); + InitializeTypeDefs(); + } + + /*! If we are parsing into a module with previously loaded data types we need to + * map constructor names and variable names in the global tables. + */ + void InitializeTypeDefs() { + for (auto pair : this->module->type_definitions) { + type_names.Add(pair.first->name_hint, pair.first); + for (auto ctor : pair.second->constructors) { + ctors.Add(ctor->name_hint, ctor); + } + } + } + + void InitializeGlobals() { + for (auto pair : this->module->functions) { + global_names.Add(pair.first->name_hint, pair.first); + } + } /*! \brief Examine the next token in the stream, the current parser is configured to be * whitespace insensitive so we will skip all whitespace or comment tokens. */ @@ -322,9 +392,9 @@ class Parser { */ void Consume(const TokenType& token_type) { if (tokens[pos]->token_type != token_type) { - this->diag_ctx->EmitFatal(Diagnostic::Error(tokens[pos]->span) - << "expected a " << Pretty(token_type) << " found " - << Pretty(Peek()->token_type)); + this->diag_ctx.EmitFatal(Diagnostic::Error(tokens[pos]->span) + << "expected a " << Pretty(token_type) << " found " + << Pretty(Peek()->token_type)); } pos++; } @@ -347,6 +417,7 @@ class Parser { * Useful for matching optional tokens, effectively looksahead by one. */ bool WhenMatch(const TokenType& token_type) { + DLOG(INFO) << "Parser::WhenMatch: Peek() == " << Peek(); if (Peek()->token_type == token_type) { Consume(token_type); return true; @@ -414,8 +485,8 @@ class Parser { Var LookupLocal(const Token& local) { auto var = this->expr_scopes.Lookup(local.ToString()); if (!var.defined()) { - diag_ctx->Emit(Diagnostic::Error(local->span) - << "this local variable has not been previously declared"); + diag_ctx.Emit(Diagnostic::Error(local->span) + << "this local variable has not been previously declared"); } return var; } @@ -426,11 +497,6 @@ class Parser { */ TypeVar LookupTypeVar(const Token& ident) { auto var = this->type_scopes.Lookup(ident.ToString()); - if (!var.defined()) { - diag_ctx->Emit( - Diagnostic::Error(ident->span) - << "this type variable has not been previously declared anywhere, perhaps a typo?"); - } return var; } @@ -467,12 +533,12 @@ class Parser { return data; } else if (token->token_type == TokenType::kFloat) { DLContext ctx = {DLDeviceType::kDLCPU, 0}; - auto dtype = String2DLDataType("float32"); - auto data = NDArray::Empty({}, dtype, ctx); + auto float_imm = Downcast(token->data); + auto data = NDArray::Empty({}, float_imm->dtype, ctx); auto array = reinterpret_cast(data->data); // revisit this, literal node issue. // TODO(@jroesch): bounds checking - float value = Downcast(token->data)->value; + float value = float_imm->value; array[0] = value; return data; } else { @@ -516,6 +582,33 @@ class Parser { return Bracket(TokenType::kLCurly, TokenType::kRCurly, parser); } + template + R WithSpan(std::function parser) { + auto start_span = Peek()->span; + DLOG(INFO) << "WithSpan: start_span = " << start_span; + R ast = parser(); + if (ast.defined()) { + // The token at the head of the stream is now 1 past where we parsed. So we find its start + // position as its start and end, so that when we merge we only grow the spanned region + // to the start of the current stream. + auto span_pos = pos - 1; + while ((tokens.at(span_pos)->token_type == TokenType::kWhitespace || + tokens.at(span_pos)->token_type == TokenType::kNewline || + tokens.at(span_pos)->token_type == TokenType::kLineComment || + tokens.at(span_pos)->token_type == TokenType::kComment)) { + span_pos--; + } + auto end_token = tokens.at(span_pos); + DLOG(INFO) << "WithSpan: end_span = " << end_token->span; + ast->span = start_span.Merge(end_token->span); + } + return ast; + } + + /*! \brief Parse a meta reference of the form `meta[type_key][node_index]`. + * For example `meta[relay.Constant][0]` references the first constant, `meta[relay.Constant][1]` + * the second, and so on. + */ ObjectRef ParseMetaRef() { auto meta_ref = Match(TokenType::kMetaReference); Call ref = Downcast(meta_ref->data); @@ -528,14 +621,14 @@ class Parser { if (index < nodes.size()) { return nodes[index]; } else { - this->diag_ctx->Emit(Diagnostic::Error(meta_ref->span) - << "the node index `" << index << "` is out of bounds for `" - << type_key << "`"); + this->diag_ctx.Emit(Diagnostic::Error(meta_ref->span) + << "the node index `" << index << "` is out of bounds for `" << type_key + << "`"); return ObjectRef(); } } else { - this->diag_ctx->Emit(Diagnostic::Error(meta_ref->span) - << "no entry in the meta table for `" << type_key << "`"); + this->diag_ctx.Emit(Diagnostic::Error(meta_ref->span) + << "no entry in the meta table for `" << type_key << "`"); return ObjectRef(); } } @@ -553,8 +646,8 @@ class Parser { template Array ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function parse, std::function before_stop = nullptr) { - DLOG(INFO) << "Parser::ParseSequence: start=" << ToString(start) << "sep=" << ToString(sep) - << "stop=" << ToString(stop); + DLOG(INFO) << "Parser::ParseSequence: start=" << ToString(start) << " sep=" << ToString(sep) + << " stop=" << ToString(stop); Match(start); // This is for the empty arguments list case, if we have token stream @@ -571,6 +664,7 @@ class Parser { if (WhenMatch(stop)) { return Array(); } else { + DLOG(INFO) << "Parser::ParseSequence: parse first"; auto data = parse(); Array elements = {data}; @@ -579,6 +673,7 @@ class Parser { // parse '( expr ',' * ')' } else if (WhenMatch(sep)) { while (true) { + DLOG(INFO) << "Parser::ParseSequence: parse element"; if (WhenMatch(stop)) { break; } else { @@ -598,9 +693,9 @@ class Parser { return elements; } else { auto next = Peek(); - this->diag_ctx->EmitFatal(Diagnostic::Error(next->span) - << "expected a " << Pretty(stop) << " found " - << Pretty(next->token_type)); + this->diag_ctx.EmitFatal(Diagnostic::Error(next->span) + << "expected a " << Pretty(stop) << " found " + << Pretty(next->token_type)); return Array(nullptr); } } @@ -616,20 +711,16 @@ class Parser { auto metadata = ParseMetadata(); Match(TokenType::kEndOfFile); - Map funcs; - Map types; for (auto type_def : defs.types) { - types.Set(type_def->header, type_def); + module->AddTypeDef(type_def->header, type_def); } - auto mod = IRModule({}, types); - for (auto func : defs.funcs) { - mod->Add(func.global, func.function); + module->Add(func.global, func.function, true); } - return mod; + return module; } /*! \brief Parse the semantic versioning header. */ @@ -638,14 +729,16 @@ class Parser { auto version = Match(TokenType::kVersion); // TODO(@jroesch): we currently only support 0.0.5. if (version.ToString() != "\"0.0.5\"") { - this->diag_ctx->Emit(Diagnostic::Error(version->span) - << "invalid semantic version `" << version.ToString() << "`"); + this->diag_ctx.Emit(Diagnostic::Error(version->span) + << "invalid semantic version `" << version.ToString() << "`"); } } else if (required) { - this->diag_ctx->Emit(Diagnostic::Error(Peek()->span) - << "expected text format semantic version, found a " - << PrettyPrint(Peek()) - << "you can annotate it as #[version = \"0.0.5\"]"); + this->diag_ctx.Emit(Diagnostic::Error(Peek()->span) + << "expected text format semantic version, found a " + << PrettyPrint(Peek())); + + this->diag_ctx.Emit(Diagnostic::Help(Peek()->span) + << "you can annotate it as #[version = \"0.0.5\"]"); } return SemVer(0, 0, 5); } @@ -661,15 +754,9 @@ class Parser { Consume(TokenType::kDefn); auto global_tok = Match(TokenType::kGlobal); auto global_name = global_tok.ToString(); - auto global = GlobalVar(global_name); - try { - global_names.Add(global_name, global); - } catch (const DuplicateKeyError& e) { - this->diag_ctx->Emit(Diagnostic::Error(global_tok->span) << "a function with the name " - << "`@" << global_name << "` " - << "was previously defined"); - } - auto func = ParseFunctionDef(); + auto global = AddOrGet(&global_names, global_name); + auto func = WithSpan([&]() { return ParseFunctionDef(); }); + ICHECK(func->span.defined()) << "spans must be set in parser"; defs.funcs.push_back(GlobalFunc(global, func)); continue; } @@ -681,8 +768,8 @@ class Parser { Consume(TokenType::kExtern); auto type_def = ParseTypeDef(); if (type_def->constructors.size()) { - diag_ctx->Emit(Diagnostic::Error(next->span) - << "an external type may not have any constructors"); + diag_ctx.Emit(Diagnostic::Error(next->span) + << "an external type may not have any constructors"); } defs.types.push_back(type_def); } @@ -699,15 +786,7 @@ class Parser { // Parse the type's identifier. auto type_tok = Match(TokenType::kIdentifier); auto type_id = type_tok.ToString(); - auto type_global = tvm::GlobalTypeVar(type_id, TypeKind::kAdtHandle); - - try { - type_names.Add(type_id, type_global); - } catch (const DuplicateKeyError& e) { - this->diag_ctx->Emit(Diagnostic::Error(type_tok->span) << "a type definition with the name " - << "`" << type_id << "` " - << "was previously defined"); - } + auto type_global = AddOrGet(&type_names, type_id, TypeKind::kAdtHandle); Array generics; @@ -748,10 +827,10 @@ class Parser { try { this->ctors.Add(ctor_name, ctor); } catch (const DuplicateKeyError& e) { - this->diag_ctx->EmitFatal(Diagnostic::Error(ctor_tok->span) - << "a constructor with the name " - << "`" << ctor_name << "` " - << "was previously defined"); + this->diag_ctx.EmitFatal(Diagnostic::Error(ctor_tok->span) + << "a constructor with the name " + << "`" << ctor_name << "` " + << "was previously defined"); } return ctor; @@ -793,7 +872,7 @@ class Parser { /*! \brief Parse a single Relay expression. */ Expr ParseExpr() { DLOG(INFO) << "Parser::ParseExpr"; - return ConsumeWhitespace([this] { + return WithSpan([this] { std::vector exprs; while (true) { @@ -807,11 +886,13 @@ class Parser { // Stack should only grow proportionally to the number of // nested scopes. // Parses `{` expression `}`. - auto block = Bracket(TokenType::kLCurly, TokenType::kRCurly, [&]() { - PushScope(); - auto expr = ParseExpr(); - PopScopes(1); - return expr; + auto block = WithSpan([&]() { + return Bracket(TokenType::kLCurly, TokenType::kRCurly, [&]() { + PushScope(); + auto expr = ParseExpr(); + PopScopes(1); + return expr; + }); }); exprs.push_back(block); break; @@ -866,15 +947,20 @@ class Parser { CHECK_GE(exprs.size(), 1); if (exprs.size() == 1) { + // ICHECK(exprs[0].defined() && exprs[0]->span.defined()) + // << "parser must set expression spans.\n" + // << exprs[0]; return exprs[0]; } else { auto body = exprs.back(); exprs.pop_back(); while (exprs.size()) { auto value = exprs.back(); + ICHECK(value->span.defined()) << "parser must set expression spans."; exprs.pop_back(); - body = relay::Let(Var("", IncompleteType()), value, body); + body = relay::Let(Var("", IncompleteType()), value, body, value->span.Merge(body->span)); } + ICHECK(body->span.defined()) << "parser must set expression spans."; return body; } }); @@ -907,7 +993,7 @@ class Parser { // the call depth will be the same before // and after parsing the n bindings. DLOG(INFO) << "Parser::ParseBindingExpr"; - std::vector> bindings; + std::vector> bindings; int scopes = 0; while (true) { @@ -919,6 +1005,7 @@ class Parser { Match(TokenType::kSemicolon); AddGraphBinding(next, val); } else if (next->token_type == TokenType::kLet) { + auto span = next->span; // Parse the 'let'. Consume(TokenType::kLet); @@ -942,7 +1029,8 @@ class Parser { Consume(TokenType::kSemicolon); // Add the bindings to the local data structure. - bindings.push_back({var, val}); + std::tuple tuple(var, val, span); + bindings.push_back(tuple); scopes++; PushScope(); } else { @@ -964,7 +1052,8 @@ class Parser { } else { // We can now build the let binding up backwards. for (auto binding = bindings.rbegin(); binding != bindings.rend(); binding++) { - body = relay::Let(binding->first, binding->second, body); + auto span = body->span.Merge(std::get<2>(*binding)); + body = relay::Let(std::get<0>(*binding), std::get<1>(*binding), body, span); } return body; } @@ -978,87 +1067,92 @@ class Parser { */ Function ParseFunctionDef() { DLOG(INFO) << "Parser::ParseFunctionDef"; - PushScope(); - PushTypeScope(); - - Array generics; - if (Peek()->token_type == TokenType::kLSquare) { - // If we have generics we need to add a type scope. + return WithSpan([&]() { + PushScope(); PushTypeScope(); - generics = ParseSequence( - TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { - auto type_var_name = Match(TokenType::kIdentifier).ToString(); - return BindTypeVar(type_var_name, TypeKind::kType); - }); - } - Map raw_attrs; + Array generics; + if (Peek()->token_type == TokenType::kLSquare) { + // If we have generics we need to add a type scope. + PushTypeScope(); + generics = ParseSequence( + TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { + auto type_var_name = Match(TokenType::kIdentifier).ToString(); + return BindTypeVar(type_var_name, TypeKind::kType); + }); + } - auto params = ParseSequence( - TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, - [&]() { - auto token = Match(TokenType::kLocal); - auto string = token.ToString(); - Type type; - if (WhenMatch(TokenType::kColon)) { - type = ParseType(); - } - return BindVar(string, type); - }, - [&] { - auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier; - auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual; - - if (is_ident && next_is_equal) { - raw_attrs = ParseAttrs(); - return true; - } + Map raw_attrs; - return false; - }); + auto params = ParseSequence( + TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, + [&]() { + auto token = Match(TokenType::kLocal); + auto string = token.ToString(); + Type type; + if (WhenMatch(TokenType::kColon)) { + type = ParseType(); + } + return BindVar(string, type); + }, + [&] { + auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier; + auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual; + + if (is_ident && next_is_equal) { + raw_attrs = ParseAttrs(); + return true; + } - Type ret_type; - if (WhenMatch(TokenType::kMinus)) { - Match(TokenType::kRAngle); - ret_type = ParseType(); - } + return false; + }); + + Type ret_type; + if (WhenMatch(TokenType::kMinus)) { + Match(TokenType::kRAngle); + ret_type = ParseType(); + } - auto body = Block([&]() { return ParseExpr(); }); + auto body = Block([&]() { return ParseExpr(); }); - PopTypeScopes(1); - PopScopes(1); + PopTypeScopes(1); + PopScopes(1); - // TODO(@jroesch): attributes should never be null, they should always be empty. - if (raw_attrs.size()) { - return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs)); - } else { - return relay::Function(params, body, ret_type, generics); - } + // TODO(@jroesch): attributes should never be null, they should always be empty. + if (raw_attrs.size()) { + return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs)); + } else { + return relay::Function(params, body, ret_type, generics, tvm::DictAttrs()); + } + }); } /*! \brief Parse an if-expression. */ Expr ParseIf() { - DLOG(INFO) << "Parser::ParseIf"; - Consume(TokenType::kIf); - auto guard = Parens([&] { return ParseExpr(); }); - - auto true_branch = Block([&] { - this->PushScope(); - auto expr = ParseExpr(); - this->PopScopes(1); - return expr; - }); + return WithSpan([&]() { + DLOG(INFO) << "Parser::ParseIf"; + Consume(TokenType::kIf); - Match(TokenType::kElse); + auto guard = WithSpan([&] { return Parens([&] { return ParseExpr(); }); }); - auto false_branch = Block([&] { - this->PushScope(); - auto expr = ParseExpr(); - this->PopScopes(1); - return expr; - }); + auto true_branch = Block([&] { + this->PushScope(); + auto expr = ParseExpr(); + this->PopScopes(1); + return expr; + }); + + Match(TokenType::kElse); + + auto false_branch = Block([&] { + this->PushScope(); + auto expr = ParseExpr(); + this->PopScopes(1); + return expr; + }); - return relay::If(guard, true_branch, false_branch); + return relay::If(guard, true_branch, false_branch); + }); } /* This factors parsing a list of patterns for both tuples, and constructors. */ @@ -1094,7 +1188,15 @@ class Parser { case TokenType::kIdentifier: { auto id = Match(TokenType::kIdentifier); auto ctor = ctors.Get(id.ToString()); - CHECK(ctor) << "undefined identifier"; + if (!ctor) { + diag_ctx.EmitFatal( + // TODO(@jroesch): split into error and help + // deal with multiple rendering + Diagnostic::Error(id->span) + << "undefined constructor name `" << id.ToString() + << "`, perhaps you intended to write a" + << "pattern variable, considering changing this to `%" << id.ToString() << "`"); + } if (Peek()->token_type == TokenType::kOpenParen) { auto fields = ParsePatternList(); return PatternConstructor(ctor.value(), fields); @@ -1118,22 +1220,27 @@ class Parser { } Expr ParseMatch(bool is_total) { - Expr scrutinee = ParseExpr(); + return WithSpan([&]() { + Expr scrutinee = ParseAtomicExpr(); - Array clauses = ParseSequence( - TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, [&] { return ParseMatchArm(); }); + Array clauses = + ParseSequence(TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, + [&] { return ParseMatchArm(); }); - return relay::Match(scrutinee, clauses, is_total); + return relay::Match(scrutinee, clauses, is_total); + }); } Expr ParseExprBinOp() { DLOG(INFO) << "Parser::ParseExprBinOp"; - return ConsumeWhitespace([this] { + return WithSpan([this] { // We must parse at least one expression, the default // case is that there is no operator and we will fall // through. std::vector exprs; - exprs.push_back(ParseCallExpr()); + Expr expr = WithSpan([this] { return ParseCallExpr(); }); + + exprs.push_back(expr); // Now we parse an optional op. std::vector ops; @@ -1150,7 +1257,8 @@ class Parser { // Read the operation we parsed; auto op = opt_op[0]; - Expr right = ParseCallExpr(); + Expr right = WithSpan([this] { return ParseCallExpr(); }); + CHECK(right->span.defined()); // If the operator stack is empty // we parse an operator and expression @@ -1177,7 +1285,9 @@ class Parser { exprs.pop_back(); Expr left = exprs.back(); exprs.pop_back(); - exprs.push_back(relay::Call(new_op.op, {left, right})); + CHECK(new_op.op.defined()) << "a call op must be set " << new_op.op; + exprs.push_back( + relay::Call(new_op.op, {left, right}, Attrs(), {}, left->span.Merge(right->span))); } exprs.push_back(right); @@ -1191,11 +1301,16 @@ class Parser { exprs.pop_back(); Expr left = exprs.back(); exprs.pop_back(); - exprs.push_back(relay::Call(new_op.op, {left, right})); + CHECK(new_op.op.defined()) << "a call op must be set " << new_op.op; + exprs.push_back( + relay::Call(new_op.op, {left, right}, Attrs(), {}, left->span.Merge(right->span))); } - CHECK_EQ(ops.size(), 0); - CHECK_EQ(exprs.size(), 1); + ICHECK_EQ(ops.size(), 0) << "No operations should be left on the operation stack."; + + ICHECK_EQ(exprs.size(), 1) + << "Only a single expression should be left on the expression stack."; + return exprs[0]; }); } @@ -1254,49 +1369,46 @@ class Parser { } Expr ParseCallArgs(Expr op) { - try { - DLOG(INFO) << "Parser::ParseCallArgs"; - Map raw_attrs; - std::string op_key; - bool is_op = false; + CHECK(op.defined()) << "the operator must be defined"; - if (auto op_node = op.as()) { - is_op = true; - op_key = op_node->attrs_type_key; - } + DLOG(INFO) << "Parser::ParseCallArgs"; + Map raw_attrs; + std::string op_key; + bool is_op = false; - if (Peek()->token_type == TokenType::kOpenParen) { - Array args = ParseSequence( - TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, - [&] { return ParseExpr(); }, - [&] { - auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier; - auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual; - - if (is_op && is_ident && next_is_equal) { - raw_attrs = ParseAttrs(); - return true; - } + if (auto op_node = op.as()) { + is_op = true; + op_key = op_node->attrs_type_key; + } - return false; - }); + if (Peek()->token_type == TokenType::kOpenParen) { + Array args = ParseSequence( + TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, + [&] { return ParseExpr(); }, + [&] { + auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier; + auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual; + + if (is_op && is_ident && next_is_equal) { + raw_attrs = ParseAttrs(); + return true; + } - Attrs attrs; + return false; + }); - if (is_op && op_key.size()) { - auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); - CHECK(attr_obj.defined()); - attrs = Downcast(attr_obj); - } + Attrs attrs; - return Expr(Call(op, args, attrs, {})); - } else { - return Expr(); + if (is_op && op_key.size()) { + auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); + CHECK(attr_obj.defined()); + attrs = Downcast(attr_obj); } - } catch (...) { - // TODO(@jroesch): AttrErrors should have fields - this->diag_ctx->Emit(Diagnostic::Error(Peek()->span)); - // << err.what()); + + // TODO(@jroesch): in a secondary pass adjust spans. + return Expr(Call(op, args, attrs, {})); + } else { + return Expr(); } return Expr(); @@ -1304,25 +1416,19 @@ class Parser { Expr ParseCallExpr() { DLOG(INFO) << "Parser::ParseCallExpr"; - return ConsumeWhitespace([this] { + return WithSpan([this] { Expr expr = ParseAtomicExpr(); // Parse as many call args as possible, building up expression // // NB(@jroesch): this seems like a hack but in order to parse curried functions // and avoid complex grammar we will parse multiple call lists in a row. while (Peek()->token_type == TokenType::kOpenParen) { - try { - auto new_expr = ParseCallArgs(expr); + auto new_expr = ParseCallArgs(expr); - if (new_expr.defined()) { - expr = new_expr; - } else { - break; - } - } catch (...) { - // TODO(@jroesch): AttrErrors should have fields - this->diag_ctx->EmitFatal(Diagnostic::Error(Peek()->span)); - // << err.what()); + if (new_expr.defined()) { + expr = new_expr; + } else { + break; } } @@ -1337,21 +1443,22 @@ class Parser { }); } - Expr GetOp(const std::string& op_name, const Token& tok) { - DLOG(INFO) << "op_name=" << op_name << " token=" << tok; + Expr GetOp(const std::string& op_name, const Span& span) { + DLOG(INFO) << "op_name=" << op_name << " span=" << span; try { return Op::Get(op_name); } catch (const dmlc::Error& e) { - this->diag_ctx->Emit(Diagnostic::Error(tok->span) - << "operator `" << op_name - << "` not found, perhaps you forgot to register it?"); + // we can relax this, but probably need to relax checks or return non-null here. + this->diag_ctx.EmitFatal(Diagnostic::Error(span) + << "operator `" << op_name + << "` not found, perhaps you forgot to register it?"); return Expr(); } } Expr ParseAtomicExpr() { DLOG(INFO) << "Parser::ParseAtomicExpr"; - auto expr = ConsumeWhitespace([this] { + Expr expr = WithSpan([this] { auto next = Peek(); switch (next->token_type) { case TokenType::kInteger: @@ -1359,6 +1466,7 @@ class Parser { Consume(next->token_type); auto number = NumberToNDArray(next); Expr e = Constant(number, next->span); + ICHECK(e->span.defined()) << "constant spans must be defined"; return e; } case TokenType::kBoolean: { @@ -1366,6 +1474,7 @@ class Parser { int value = Downcast(next->data); auto boolean = BooleanToNDarray(value); Expr e = Constant(boolean, next->span); + ICHECK(e->span.defined()) << "constant spans must be defined"; return e; } // Parse a local of the form `%x`. @@ -1375,17 +1484,10 @@ class Parser { } // Parse a local of the form `@x`. case TokenType::kGlobal: { - auto string = next.ToString(); + auto global_name = next.ToString(); Consume(TokenType::kGlobal); - auto global = global_names.Get(string); - if (!global) { - // TODO(@jroesch): fix global's needing span information - auto global_var = GlobalVar(string); - global_names.Add(string, global_var); - return Expr(global_var); - } else { - return Expr(global.value()); - } + auto global = AddOrGet(&global_names, global_name); + return Expr(global); } // Parse a local of the form `x`. // Right now we fail to parse `x.y`. @@ -1395,7 +1497,9 @@ class Parser { Consume(TokenType::kIdentifier); return Expr(ctor.value()); } else { - auto idents = ParseHierarchicalName(); + auto spanned_idents = ParseHierarchicalName(); + auto idents = spanned_idents.data; + auto span = spanned_idents.span; CHECK_NE(idents.size(), 0); std::stringstream op_name; int i = 0; @@ -1407,7 +1511,7 @@ class Parser { i++; } } - return GetOp(op_name.str(), next); + return GetOp(op_name.str(), span); } } case TokenType::kGraph: { @@ -1419,49 +1523,88 @@ class Parser { } case TokenType::kFn: { Consume(TokenType::kFn); - return Expr(ParseFunctionDef()); + Expr e = ParseFunctionDef(); + ICHECK(e->span.defined()) << "function spans must be defined.\n" << e; + return e; + } + case TokenType::kRef: { + Consume(TokenType::kRef); + Match(TokenType::kOpenParen); + auto ref_value = ParseExpr(); + Match(TokenType::kCloseParen); + return static_cast(RefCreate(ref_value)); + } + case TokenType::kRefRead: { + return WithSpan([&]() { + Consume(TokenType::kRefRead); + Match(TokenType::kOpenParen); + auto ref = ParseExpr(); + Match(TokenType::kCloseParen); + return static_cast(RefRead(ref)); + }); + } + case TokenType::kRefWrite: { + return WithSpan([&]() { + Consume(TokenType::kRefWrite); + Match(TokenType::kOpenParen); + auto ref = ParseExpr(); + Match(TokenType::kComma); + auto value = ParseExpr(); + Match(TokenType::kCloseParen); + return static_cast(RefWrite(ref, value)); + }); } case TokenType::kOpenParen: { + Span sp = next->span; Consume(TokenType::kOpenParen); // parse '(' ')' if (WhenMatch(TokenType::kCloseParen)) { return Expr(Tuple(Array())); } else { - auto expr = ParseExpr(); + Expr subexpr = ParseExpr(); // parse '(' expr ')' if (WhenMatch(TokenType::kCloseParen)) { - return expr; + return subexpr; // parse '( expr ',' * ')' } else if (WhenMatch(TokenType::kComma)) { - Array exprs = {expr}; + Array exprs = {subexpr}; while (true) { if (WhenMatch(TokenType::kCloseParen)) { break; } else { - auto expr = ParseExpr(); - WhenMatch(TokenType::kComma); - exprs.push_back(expr); + auto element = ParseExpr(); + auto comma = Peek(); + if (WhenMatch(TokenType::kComma)) { + sp = sp.Merge(element->span.Merge(comma->span)); + } else { + sp = sp.Merge(element->span); + } + exprs.push_back(element); } } - return static_cast(Tuple(exprs)); + Expr tuple = Tuple(exprs, sp); + ICHECK(tuple->span.defined()) << "tuple span should be defined"; + return tuple; } } } default: { - this->diag_ctx->EmitFatal(Diagnostic::Error(next->span) - << "expected an expression found " - << Pretty(next->token_type)); + this->diag_ctx.EmitFatal(Diagnostic::Error(next->span) + << "expected an expression found " << Pretty(next->token_type)); return Expr(); } } }); if (WhenMatch(TokenType::kPeriod)) { - auto index = Match(TokenType::kInteger).ToNumber(); - expr = relay::TupleGetItem(expr, index); + auto token = Match(TokenType::kInteger); + auto index = token.ToNumber(); + auto span = token->span.Merge(expr->span); + DLOG(INFO) << "Parser::ParseAtomicExpr: tuple get item"; + return relay::TupleGetItem(expr, index, span); + } else { + return expr; } - - return expr; } /*! \brief Parse a hierarchical name. @@ -1475,10 +1618,19 @@ class Parser { * single stream inserting the required periods needed * to look up registered names. */ - Array ParseHierarchicalName() { + Spanned> ParseHierarchicalName() { Array idents; + Span span; while (Peek()->token_type == TokenType::kIdentifier) { - auto name = Peek().ToString(); + auto token = Peek(); + + if (span.defined()) { + span = span.Merge(token->span); + } else { + span = token->span; + } + + auto name = token.ToString(); idents.push_back(name); Consume(TokenType::kIdentifier); @@ -1492,7 +1644,7 @@ class Parser { } } - return idents; + return Spanned>(idents, span); } /*! \brief Parse a shape. */ @@ -1527,29 +1679,36 @@ class Parser { // Parses a user defined ADT or type variable. Type ParseNonPrimitiveType(const Token& tok) { - auto name = tok.ToString(); - Type head_type; - auto global_type = type_names.Get(name); + return WithSpan([&]() { + auto name = tok.ToString(); + Type head_type = LookupTypeVar(tok); - if (!global_type) { - head_type = LookupTypeVar(tok); - } else { - head_type = global_type.value(); - } + if (!head_type.defined()) { + // head_type = type_names.Get(name); + head_type = AddOrGet(&type_names, name, TypeKind::kAdtHandle); + } - CHECK(head_type.defined()) << "internal error: head type must be defined"; + if (!head_type.defined()) { + diag_ctx.EmitFatal(Diagnostic::Error(tok->span) + << "the type constructor `" << name << "` is undefined"); + } - Array arg_types; - if (Peek()->token_type == TokenType::kLSquare) { - arg_types = ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, - [&]() { return ParseType(); }); - } + Array arg_types; + if (Peek()->token_type == TokenType::kLSquare) { + arg_types = ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, + [&]() { return ParseType(); }); + } - if (arg_types.size()) { - return TypeCall(head_type, arg_types); - } else { - return head_type; - } + if (arg_types.size()) { + return static_cast(TypeCall(head_type, arg_types)); + } else { + if (head_type.as()) { + return static_cast(TypeCall(head_type, {})); + } else { + return static_cast(head_type); + } + } + }); } /*! \brief Parses a TVM type. @@ -1558,42 +1717,45 @@ class Parser { * a scalar type or an incomplete type `_`. */ Type ParseType() { - auto tok = Peek(); - - if (tok->token_type == TokenType::kOpenParen) { - auto tys = ParseSequence(TokenType::kOpenParen, TokenType::kComma, - TokenType::kCloseParen, [&]() { return ParseType(); }); - return relay::TupleType(tys); - } else if (WhenMatch(TokenType::kFn)) { - return ParseFunctionType(); - } else if (WhenMatch(TokenType::kIdentifier)) { - auto id = tok.ToString(); - if (id == "Tensor") { - Match(TokenType::kLSquare); - auto shape = ParseShape(); - Match(TokenType::kComma); - auto dtype_tok = Match(TokenType::kIdentifier); - auto dtype = DataType(String2DLDataType(dtype_tok.ToString())); - Match(TokenType::kRSquare); - return TensorType(shape, dtype); - } else { - auto ty = tok.ToString(); - if (ty.rfind("int", 0) == 0 || ty.find("float", 0) == 0 || ty.find("uint", 0) == 0 || - ty.find("bool", 0) == 0) { - // Need to do better error handling here. - auto dtype = DataType(String2DLDataType(tok.ToString())); - return TensorType({}, dtype); + return WithSpan([&]() -> Type { + auto tok = Peek(); + + if (tok->token_type == TokenType::kOpenParen) { + auto tys = + ParseSequence(TokenType::kOpenParen, TokenType::kComma, + TokenType::kCloseParen, [&]() { return ParseType(); }); + return relay::TupleType(tys); + } else if (WhenMatch(TokenType::kFn)) { + return ParseFunctionType(); + } else if (WhenMatch(TokenType::kIdentifier)) { + auto id = tok.ToString(); + if (id == "Tensor") { + Match(TokenType::kLSquare); + auto shape = ParseShape(); + Match(TokenType::kComma); + auto dtype_tok = Match(TokenType::kIdentifier); + auto dtype = DataType(String2DLDataType(dtype_tok.ToString())); + Match(TokenType::kRSquare); + return TensorType(shape, dtype); } else { - return ParseNonPrimitiveType(tok); + auto ty = tok.ToString(); + if (ty.rfind("int", 0) == 0 || ty.find("float", 0) == 0 || ty.find("uint", 0) == 0 || + ty.find("bool", 0) == 0) { + // Need to do better error handling here. + auto dtype = DataType(String2DLDataType(tok.ToString())); + return TensorType({}, dtype); + } else { + return ParseNonPrimitiveType(tok); + } } + } else if (WhenMatch(TokenType::kUnderscore)) { + return IncompleteType(); + } else { + this->diag_ctx.EmitFatal(Diagnostic::Error(tok->span) + << "failed to parse type found " << tok); + return Type(); } - } else if (WhenMatch(TokenType::kUnderscore)) { - return IncompleteType(); - } else { - this->diag_ctx->EmitFatal(Diagnostic::Error(tok->span) - << "failed to parse type found " << tok); - return Type(); - } + }); } template @@ -1642,32 +1804,50 @@ class Parser { } }; -IRModule ParseModule(std::string file_name, std::string file_content) { - DLOG(INFO) << "ParseModule"; +Parser InitParser(const std::string& file_name, const std::string& file_content, + Optional init_module) { + DLOG(INFO) << "InitParser: file_name: " << file_name + << "file_content_size: " << file_content.size(); SourceName src_name = SourceName::Get(file_name); - Source src(src_name, file_content); - DiagnosticContext ctx(src); - auto tokens_and_table = Tokenize(&ctx, src_name, file_content); + Source source(src_name, file_content); + + IRModule module; + if (!init_module) { + SourceMap source_map; + module = IRModule({}, {}, {}, source_map); + } else { + module = init_module.value(); + } + + module->source_map.Add(source); + + auto diag_ctx = DiagnosticContext::Default(module); + auto tokens_and_table = Tokenize(diag_ctx, source); + auto tokens = tokens_and_table.first; auto meta_data_table = tokens_and_table.second; - Parser parser(&ctx, src_name, tokens, DefaultOpTable(), src, meta_data_table.ToMetadata()); + + return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), meta_data_table.ToMetadata()); +} + +IRModule ParseModule(std::string file_name, std::string file_content, + Optional init_module) { + DLOG(INFO) << "ParseModule"; + auto parser = InitParser(file_name, file_content, init_module); auto mod = parser.ParseModule(); + ICHECK(mod.defined()) << "The parser must return a non-null module."; // NB(@jroesch): it is very important that we render any errors before we procede // if there were any errors which allow the parser to procede we must render them // here. - parser.diag_ctx->Render(std::cout); - return mod; + parser.diag_ctx.Render(); + auto infer_type = tvm::relay::transform::InferType(); + ICHECK(infer_type.defined()) << "The type inferencer must be non-null."; + return infer_type(mod); } Expr ParseExpr(std::string file_name, std::string file_content) { DLOG(INFO) << "ParseExpr"; - SourceName src_name = SourceName::Get(file_name); - Source src(src_name, file_content); - DiagnosticContext ctx(src); - auto tokens_and_table = Tokenize(&ctx, src_name, file_content); - auto tokens = tokens_and_table.first; - auto meta_data_table = tokens_and_table.second; - Parser parser(&ctx, src_name, tokens, DefaultOpTable(), src, meta_data_table.ToMetadata()); + auto parser = InitParser(file_name, file_content, Optional()); parser.ParseSemVer(false); parser.PushScope(); auto expr = parser.ParseExpr(); @@ -1675,7 +1855,7 @@ Expr ParseExpr(std::string file_name, std::string file_content) { // NB(@jroesch): it is very important that we render any errors before we procede // if there were any errors which allow the parser to procede we must render them // here. - parser.diag_ctx->Render(std::cout); + parser.diag_ctx.Render(); return expr; } @@ -1689,5 +1869,14 @@ TVM_REGISTER_GLOBAL("parser.ParseExpr") return ParseExpr(file_name, file_content); }); +TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed([]() { + return CreateModulePass( + [](const IRModule& mod, const PassContext& ctx) { + auto text = AsText(mod, true); + return ParseModule("GeneratedSource", text); + }, + 0, "AnnotateSpans", {}); +}); + } // namespace parser } // namespace tvm diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc index a2efdb5a88fd5..40998b0c9dc4f 100644 --- a/src/parser/source_map.cc +++ b/src/parser/source_map.cc @@ -26,20 +26,28 @@ namespace tvm { namespace parser { +TVM_REGISTER_NODE_TYPE(SourceNode); + /*! \brief Construct a source from a string. */ -Source::Source(const SourceName& src_name, const std::string& source) - : source_name(src_name), source(source) { +Source::Source(SourceName src_name, std::string source) { + auto n = make_object(); + n->source_name = std::move(src_name); + n->source = std::move(source); + int index = 0; int length = 0; - line_map.push_back({index, length}); - for (auto c : source) { + n->line_map.push_back({index, length}); + // NB(@jroesch): + std::string source_str = n->source; + for (auto c : source_str) { + DLOG(INFO) << "char=" << c; if (c == '\n') { // Record the length of the line. - line_map.back().second = length; + n->line_map.back().second = length; // Bump past the newline. index += 1; // Record the start of the next line, and put placeholder for length. - line_map.push_back({index, 0}); + n->line_map.push_back({index, 0}); // Reset length to zero. length = 0; } else { @@ -47,54 +55,28 @@ Source::Source(const SourceName& src_name, const std::string& source) index += 1; } } - line_map.back().second = length; -} + n->line_map.back().second = length; -/*! \brief Generate an error message at a specific line and column with the - * annotated message. - * - * The error is written directly to the `out` std::ostream. - * - * \param out The output ostream. - * \param line The line at which to report a diagnostic. - * \param line The column at which to report a diagnostic. - * \param msg The message to attach. - */ -void Source::ReportAt(std::ostream& out, const Span& span, const std::string& msg) const { - DLOG(INFO) << "Source::ReportAt" - << "span = " << span << "msg = " << msg; - int line = span->line; - int column = span->column; + data_ = n; +} - CHECK(line - 1 <= static_cast(line_map.size())) - << "requested line: " << (line - 1) << "line_map size: " << line_map.size() - << "source: " << source; +tvm::String Source::GetLine(int line) { + DLOG(INFO) << "Source::GetLine: line=" << line; + CHECK(line - 1 < static_cast((*this)->line_map.size())) + << "requested line: " << line << "at index: " << (line - 1) + << "line_map size: " << (*this)->line_map.size() << "source: " << (*this)->source; // Adjust for zero indexing, now have (line_start, line_length); - auto range = line_map.at(line - 1); + auto range = (*this)->line_map.at(line - 1); int line_start = range.first; int line_length = range.second; - out << "file:" << line << ":" << column << ": parse error: " << msg << std::endl; - out << " " << source.substr(line_start, line_length) << std::endl; - out << " "; - std::stringstream marker; - for (int i = 1; i <= line_length; i++) { - if (i == column) { - marker << "^"; - } else if ((column - i) < 3) { - marker << "~"; - } else if ((i - column) < 3) { - marker << "~"; - } else { - marker << " "; - } - } - out << marker.str(); - out << std::endl; + DLOG(INFO) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length; + // TODO(@jroesch): expose substring on tvm::String. + auto line_text = std::string((*this)->source).substr(line_start, line_length); + DLOG(INFO) << "Source::GetLine: line_text=" << line_text; + return line_text; } -// TVM_REGISTER_GLOBAL("ir.SourceName").set_body_typed(SourceName::Get); - // TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { // auto* node = static_cast(ref.get()); @@ -103,11 +85,25 @@ void Source::ReportAt(std::ostream& out, const Span& span, const std::string& ms TVM_REGISTER_NODE_TYPE(SourceMapNode); -SourceMap::SourceMap(Map source_map) { +SourceMap::SourceMap(Map source_map) { auto n = make_object(); n->source_map = std::move(source_map); data_ = std::move(n); } +// TODO(@jroesch): fix this +static SourceMap global_source_map = SourceMap(Map()); + +SourceMap SourceMap::Global() { return global_source_map; } + +void SourceMap::Add(const Source& source) { (*this)->source_map.Set(source->source_name, source); } + +TVM_REGISTER_GLOBAL("SourceMapAdd").set_body_typed([](SourceMap map, String name, String content) { + auto src_name = SourceName::Get(name); + Source source(src_name, content); + map.Add(source); + return src_name; +}); + } // namespace parser } // namespace tvm diff --git a/src/parser/span_check.cc b/src/parser/span_check.cc new file mode 100644 index 0000000000000..a72db5b9e4b65 --- /dev/null +++ b/src/parser/span_check.cc @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file span_check.cc + * \brief A utility for checking and reporting malformed span information. + */ +#include "./span_check.h" + +#include + +namespace tvm { +namespace parser { + +using tvm::relay::transform::CreateFunctionPass; +using tvm::transform::PassContext; + +void SpanChecker::VisitExpr(const Expr& e) { + this->expression = e; + VisitSpan(e->span); + span_stack.push_back(e->span); + ExprVisitor::VisitExpr(e); + this->expression = e; + span_stack.pop_back(); +} + +// TODO(@jroesch, @junru): we need to deal with unique spans for global/var. +void SpanChecker::VisitExpr_(const VarNode* op) {} +void SpanChecker::VisitExpr_(const GlobalVarNode* op) {} +void SpanChecker::VisitExpr_(const ConstantNode* op) {} + +void SpanChecker::VisitExpr_(const TupleNode* op) { ExprVisitor::VisitExpr_(op); } + +void SpanChecker::VisitExpr_(const FunctionNode* op) { ExprVisitor::VisitExpr_(op); } + +void SpanChecker::VisitExpr_(const CallNode* op) { ExprVisitor::VisitExpr_(op); } + +void SpanChecker::VisitExpr_(const LetNode* op) { ExprVisitor::VisitExpr_(op); } + +void SpanChecker::VisitExpr_(const IfNode* op) { ExprVisitor::VisitExpr_(op); } + +void SpanChecker::VisitExpr_(const OpNode* op) {} + +void SpanChecker::VisitExpr_(const TupleGetItemNode* op) { ExprVisitor::VisitExpr_(op); } + +void SpanChecker::VisitExpr_(const RefCreateNode* op) { ExprVisitor::VisitExpr_(op); } + +void SpanChecker::VisitExpr_(const RefReadNode* op) { ExprVisitor::VisitExpr_(op); } + +void SpanChecker::VisitExpr_(const RefWriteNode* op) { ExprVisitor::VisitExpr_(op); } + +void SpanChecker::VisitExpr_(const ConstructorNode* op) {} // ExprVisitor::VisitExpr_(op); } + +void SpanChecker::VisitExpr_(const MatchNode* op) { ExprVisitor::VisitExpr_(op); } + +void SpanChecker::VisitSpan(const Span& sp) { + if (!sp.defined()) { + Span span; + int i = 0; + for (auto spans = this->span_stack.rbegin(); spans != this->span_stack.rend(); spans++) { + i += 1; + span = this->span_stack.back(); + if (span.defined()) { + diag_ctx.Emit(Diagnostic::Warning(span) << "found null-span, i-nodes deep from this span."); + return; + } + } + auto warning = Diagnostic::Warning(span); + warning << "\tAll spans are null\n"; + warning << "\t" << this->expression; + diag_ctx.Emit(warning); + } +} + +void SpanChecker::VisitType(const Type& t) {} +void SpanChecker::VisitClause(const Clause& c) {} +void SpanChecker::VisitPattern(const Pattern& c) {} + +Pass SpanCheck() { + return CreateFunctionPass( + [](const Function& func, const IRModule& mod, const PassContext& ctx) { + ICHECK(ctx->diag_ctx) << "Diagnostic context must be set."; + SpanChecker checker(ctx->diag_ctx.value()); + checker.VisitExpr(func); + ctx->diag_ctx.value().Render(); + return func; + }, + 0, "SpanCheck", {}); +} + +TVM_REGISTER_GLOBAL("parser.SpanCheck").set_body_typed([]() { return SpanCheck(); }); + +} // namespace parser +} // namespace tvm diff --git a/src/parser/span_check.h b/src/parser/span_check.h new file mode 100644 index 0000000000000..b9ba76df4b8fa --- /dev/null +++ b/src/parser/span_check.h @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file span_check.h + * \brief Check that the Relay IR has correctly attached span information. + */ + +#ifndef TVM_PARSER_SPAN_CHECK_H_ +#define TVM_PARSER_SPAN_CHECK_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace parser { + +using namespace tvm::relay; +using tvm::transform::Pass; + +struct SpanChecker : ExprVisitor { + Expr expression; + DiagnosticContext diag_ctx; + std::vector span_stack; + + explicit SpanChecker(DiagnosticContext diag_ctx) : diag_ctx(diag_ctx) {} + + void VisitExpr(const Expr& expr) override; + void VisitExpr_(const VarNode* op) override; + void VisitExpr_(const GlobalVarNode* op) override; + void VisitExpr_(const ConstantNode* op) override; + void VisitExpr_(const TupleNode* op) override; + void VisitExpr_(const FunctionNode* op) override; + void VisitExpr_(const CallNode* op) override; + void VisitExpr_(const LetNode* op) override; + void VisitExpr_(const IfNode* op) override; + void VisitExpr_(const OpNode* op) override; + void VisitExpr_(const TupleGetItemNode* op) override; + void VisitExpr_(const RefCreateNode* op) override; + void VisitExpr_(const RefReadNode* op) override; + void VisitExpr_(const RefWriteNode* op) override; + void VisitExpr_(const ConstructorNode* op) override; + void VisitExpr_(const MatchNode* op) override; + void VisitType(const Type& t) override; + void VisitClause(const Clause& c) override; + void VisitPattern(const Pattern& c) override; + void VisitSpan(const Span& span) override; +}; + +Pass SpanCheck(); + +} // namespace parser +} // namespace tvm +#endif // TVM_PARSER_SPAN_CHECK_H_ diff --git a/src/parser/token.h b/src/parser/token.h index 3750ec568cc84..1133483fa8f82 100644 --- a/src/parser/token.h +++ b/src/parser/token.h @@ -89,6 +89,9 @@ enum class TokenType { kMetadata, kMetaReference, kFreeVar, + kRef, + kRefRead, + kRefWrite, kVersion, kUnknown, kEndOfFile, @@ -199,6 +202,12 @@ std::string ToString(const TokenType& token_type) { return "FreeVar"; case TokenType::kVersion: return "Version"; + case TokenType::kRef: + return "Ref"; + case TokenType::kRefRead: + return "RefRead"; + case TokenType::kRefWrite: + return "RefWrite"; case TokenType::kUnknown: return "Unknown"; case TokenType::kEndOfFile: @@ -314,6 +323,12 @@ std::string Pretty(const TokenType& token_type) { return "`match?`"; case TokenType::kQuestion: return "`?`"; + case TokenType::kRef: + return "`ref`"; + case TokenType::kRefRead: + return "`ref_read`"; + case TokenType::kRefWrite: + return "`ref_write`"; case TokenType::kUnknown: return "unknown"; case TokenType::kEndOfFile: diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 88a49290dc3d7..20ad1734e5730 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -67,21 +67,22 @@ bool IsIdentLetter(char c) { return '_' == c || ('a' <= c && c <= 'z') || ('A' < bool IsIdent(char c) { return IsIdentLetter(c) || IsDigit(c); } static std::unordered_map KEYWORD_TABLE = { - {"let", TokenType::kLet}, {"fn", TokenType::kFn}, - {"def", TokenType::kDefn}, {"if", TokenType::kIf}, - {"else", TokenType::kElse}, {"type", TokenType::kTypeDef}, - {"match", TokenType::kMatch}, {"extern", TokenType::kExtern}, - {"free_var", TokenType::kFreeVar}}; + {"let", TokenType::kLet}, {"fn", TokenType::kFn}, + {"def", TokenType::kDefn}, {"if", TokenType::kIf}, + {"else", TokenType::kElse}, {"type", TokenType::kTypeDef}, + {"match", TokenType::kMatch}, {"extern", TokenType::kExtern}, + {"free_var", TokenType::kFreeVar}, {"ref", TokenType::kRef}, + {"ref_read", TokenType::kRefRead}, {"ref_write", TokenType::kRefWrite}}; struct Tokenizer { - DiagnosticContext* diag_ctx; + DiagnosticContext diag_ctx; const SourceName& source_name; size_t pos; int col; int line; char next_char; - const std::string& source; + String source; std::vector tokens; char Next() { @@ -187,13 +188,26 @@ struct Tokenizer { } catch (const std::invalid_argument& ia) { auto token = NewToken(TokenType::kFloat); - if (number.back() == 'f') { - number.pop_back(); + auto suffix_pos = number.rfind("f"); + + auto literal_text = number.substr(0, suffix_pos); + + auto suffix = number.substr(suffix_pos + 1, number.size() - suffix_pos); + + int width = 32; + + if (suffix.size()) { + try { + width = std::stoi(suffix); + } catch (const std::invalid_argument& err) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "invalid numeric suffix `" << suffix << "`"); + } } - double value = stod(number); + double value = stod(literal_text); value = is_pos ? value : -value; - token->data = tvm::FloatImm(DataType::Float(64), value); + token->data = tvm::FloatImm(DataType::Float(width), value); return token; } } @@ -278,15 +292,15 @@ struct Tokenizer { } else { // TOOD(@jroesch): maybe make this a warning an continue parsing? auto span = SpanFrom(line, column); - this->diag_ctx->EmitFatal(Diagnostic::Error(span) << "unsupported attribute " << attribute); + this->diag_ctx.EmitFatal(Diagnostic::Error(span) << "unsupported attribute " << attribute); return Token(); } } else { auto span = SpanFrom(line, column); this->diag_ctx - ->EmitFatal(Diagnostic::Error(span) - << "`#` denotes the start of an attribute can only be followed by `[`" - << " found `" << Peek() << "`"); + .EmitFatal(Diagnostic::Error(span) + << "`#` denotes the start of an attribute can only be followed by `[`" + << " found `" << Peek() << "`"); return Token(); } } @@ -307,7 +321,7 @@ struct Tokenizer { return token; } else { auto span = SpanFrom(line, col); - this->diag_ctx->EmitFatal( + this->diag_ctx.EmitFatal( Diagnostic::Error(span) << "\\r carriage returns must be followed by a \\n in the TVM text format"); return Token(); @@ -347,9 +361,13 @@ struct Tokenizer { } bool is_float = false; + // Remove trailing floating point prefix. if (More() && Peek() == 'f') { - Next(); + ss << Next(); + while (More() && IsNumeric(Peek())) { + ss << Next(); + } is_float = true; } @@ -525,14 +543,13 @@ struct Tokenizer { this->tokens.push_back(NewToken(TokenType::kEndOfFile)); } - explicit Tokenizer(DiagnosticContext* ctx, const SourceName& source_name, - const std::string& source) + explicit Tokenizer(const DiagnosticContext& ctx, const Source& source) : diag_ctx(ctx), - source_name(source_name), + source_name(source->source_name), pos(0), col(1), line(1), - source(source), + source(source->source), tokens() {} }; @@ -615,9 +632,8 @@ std::vector Condense(const std::vector& tokens, Token* table) { return out; } -std::pair, Token> Tokenize(DiagnosticContext* ctx, const SourceName& source_name, - const std::string& source) { - auto tokenizer = Tokenizer(ctx, source_name, source); +std::pair, Token> Tokenize(const DiagnosticContext& ctx, const Source& source) { + auto tokenizer = Tokenizer(ctx, source); tokenizer.Tokenize(); Token meta_table(Span(), TokenType::kUnknown, ObjectRef()); auto tokens = Condense(tokenizer.tokens, &meta_table); diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index aa8775db53117..555d335a51dab 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -308,7 +308,7 @@ Doc RelayTextPrinter::ScalarLiteral(DataType dtype, const T& value) { } else if (dtype == DataType::Float(32)) { os << value << 'f'; } else if (dtype == DataType::Float(64)) { - os << value; + os << value << "f64"; } else if (dtype == DataType::Bool()) { return Doc::PyBoolLiteral(value != 0); } else { @@ -500,12 +500,12 @@ Doc RelayTextPrinter::VisitExpr_(const RefCreateNode* op) { Doc RelayTextPrinter::VisitExpr_(const RefReadNode* op) { Doc doc; - return doc << Print(op->ref) << "^"; + return doc << "ref_read(" << Print(op->ref) << ")"; } Doc RelayTextPrinter::VisitExpr_(const RefWriteNode* op) { Doc doc; - return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")"; + return doc << "ref_write(" << Print(op->ref) << ", " << Print(op->value) << ")"; } Doc RelayTextPrinter::VisitExpr_(const MatchNode* op) { @@ -522,10 +522,11 @@ Doc RelayTextPrinter::VisitExpr_(const MatchNode* op) { Doc clause_doc; clause_doc << PrintPattern(clause->lhs, false) << " => "; Doc rhs_doc = PrintScope(clause->rhs); - if (clause->rhs.as()) { - // only add braces if there are multiple lines on the rhs - rhs_doc = Doc::Brace("{", rhs_doc, "}"); - } + // TODO(@jroesch): This is unsound right now, and we need to revist it. + // if (clause->rhs.as()) { + // only add braces if there are multiple lines on the rhs + rhs_doc = Doc::Brace("{", rhs_doc, "}"); + // } clause_doc << rhs_doc << ","; clause_docs.push_back(clause_doc); } diff --git a/src/relay/analysis/kind_check.cc b/src/relay/analysis/kind_check.cc index ac0abc0655578..c7c5a0a9f0832 100644 --- a/src/relay/analysis/kind_check.cc +++ b/src/relay/analysis/kind_check.cc @@ -42,22 +42,26 @@ using namespace tvm::runtime; struct KindChecker : TypeFunctor { const IRModule& mod; - ErrorReporter err_reporter; + Optional diag_ctx; - explicit KindChecker(const IRModule& mod) : mod(mod), err_reporter() {} + explicit KindChecker(const IRModule& mod, Optional diag_ctx) + : mod(mod), diag_ctx(diag_ctx) {} - void ReportFatalError(const Error& err) { - this->err_reporter.Report(err); - this->err_reporter.RenderErrors(mod); + void EmitFatal(Diagnostic diagnostic) { + if (this->diag_ctx) { + this->diag_ctx.value().EmitFatal(diagnostic); + } else { + LOG(FATAL) << diagnostic->message; + } } void CheckKindMatches(const Type& t, const Type& outer, Kind expected, const std::string& description) { Kind k = this->VisitType(t); if (k != expected) { - ReportFatalError(ErrorBuilder() - << "Incorrect kind for a " << description << ". Type " << t << " inside " - << outer << " is of kind " << k << " but was expected to be " << expected); + EmitFatal(Diagnostic::Error(t->span) + << "Incorrect kind for a " << description << ". Type " << t << " inside " << outer + << " is of kind " << k << " but was expected to be " << expected); } } @@ -115,8 +119,8 @@ struct KindChecker : TypeFunctor { TypeCall tc = GetRef(op); const auto* gtv = op->func.as(); if (gtv == nullptr) { - ReportFatalError(ErrorBuilder() << "The callee in " << tc - << " is not a global type var, but is " << op->func); + EmitFatal(Diagnostic::Error(op->span) + << "The callee in " << tc << " is not a global type var, but is " << op->func); } CheckKindMatches(op->func, tc, Kind::kAdtHandle, "type call function"); @@ -127,11 +131,20 @@ struct KindChecker : TypeFunctor { // finally we need to check the module to check the number of type params auto var = GetRef(gtv); - auto data = mod->LookupTypeDef(var); - if (data->type_vars.size() != op->args.size()) { - ReportFatalError(ErrorBuilder() << "Expected " << data->type_vars.size() << "arguments for " - << tc << "; got " << op->args.size()); + try { + auto data = mod->LookupTypeDef(var); + + if (data->type_vars.size() != op->args.size()) { + EmitFatal(Diagnostic::Error(op->span) + << "Expected " << data->type_vars.size() << "arguments for " << tc << "; got " + << op->args.size()); + } + } catch (const dmlc::Error& err) { + // TODO(@jroesch): can probably relax to just emit + EmitFatal(Diagnostic::Error(op->span) + << "the type variable : `" << var->name_hint << "` is undefined"); } + return Kind::kType; } @@ -149,8 +162,8 @@ struct KindChecker : TypeFunctor { for (const auto& con : op->constructors) { if (!con->belong_to.same_as(op->header)) { - ReportFatalError(ErrorBuilder() << con << " has header " << con->belong_to << " but " << op - << " has header " << op->header); + EmitFatal(Diagnostic::Error(op->span) << con << " has header " << con->belong_to << " but " + << op << " has header " << op->header); } for (const Type& t : con->inputs) { @@ -163,16 +176,18 @@ struct KindChecker : TypeFunctor { Kind Check(const Type& t) { return this->VisitType(t); } }; -Kind KindCheck(const Type& t, const IRModule& mod) { - KindChecker kc(mod); +Kind KindCheck(const Type& t, const IRModule& mod, Optional diag_ctx) { + KindChecker kc(mod, diag_ctx); return kc.Check(t); } TVM_REGISTER_GLOBAL("relay.analysis.check_kind").set_body([](TVMArgs args, TVMRetValue* ret) { if (args.size() == 1) { *ret = KindCheck(args[0], IRModule({}, {})); + } else if (args.size() == 2) { + *ret = KindCheck(args[0], args[1], Optional()); } else { - *ret = KindCheck(args[0], args[1]); + *ret = KindCheck(args[0], args[1], args[2]); } }); diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index a674265a88de5..2b0e30b378632 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -39,7 +39,7 @@ class TypeSolver::Reporter : public TypeReporterNode { public: explicit Reporter(TypeSolver* solver) : solver_(solver) {} - void Assign(const Type& dst, const Type& src) final { solver_->Unify(dst, src, location); } + void Assign(const Type& dst, const Type& src) final { solver_->Unify(dst, src, span); } bool Assert(const IndexExpr& cond) final { if (const int64_t* pdiff = tir::as_const_int(cond)) { @@ -57,13 +57,21 @@ class TypeSolver::Reporter : public TypeReporterNode { return true; } - TVM_DLL void SetLocation(const ObjectRef& ref) final { location = ref; } + TVM_DLL void SetSpan(const Span& span) final { this->span = span; } + + TVM_DLL Span GetSpan() final { return this->span; } + + TVM_DLL DiagnosticContext GetDiagCtx() final { return this->solver_->diag_ctx_; } + + // TVM_DLL void Emit(Diagnostic diagnostic) final { + // return this->solver_-> + // } TVM_DLL IRModule GetModule() final { return this->solver_->module_; } private: - /*! \brief The location to report unification errors at. */ - mutable ObjectRef location; + /*! \brief The span to report unification errors at. */ + mutable Span span; TypeSolver* solver_; }; @@ -92,7 +100,7 @@ class TypeSolver::OccursChecker : public TypeVisitor { class TypeSolver::Unifier : public TypeFunctor { public: - explicit Unifier(TypeSolver* solver, const ObjectRef& loc) : solver_(solver), loc(loc) {} + explicit Unifier(TypeSolver* solver, const Span& span) : solver_(solver), span(span) {} Type Unify(const Type& src, const Type& dst) { // Known limitation @@ -120,11 +128,14 @@ class TypeSolver::Unifier : public TypeFunctor { return lhs->resolved_type; } else { Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type); + if (!resolved.defined()) { - solver_->ReportError(ErrorBuilder() << "unable to unify: " - << "`" << PrettyPrint(lhs->resolved_type) << "` and `" - << PrettyPrint(rhs->resolved_type) << "`", - this->loc); + solver_->diag_ctx_.Emit( + Diagnostic::Error(this->span) + << "The Relay type checker is unable to show the following types match.\n" + << "In particular " + << "`" << PrettyPrint(lhs->resolved_type) << "` does not match `" + << PrettyPrint(rhs->resolved_type) << "`"); return lhs->resolved_type; } else { TypeNode* top = solver_->GetTypeNode(resolved); @@ -232,11 +243,11 @@ class TypeSolver::Unifier : public TypeFunctor { tvm::Array shape; if (tt1->shape.size() != tt2->shape.size()) { - this->solver_->ReportError(ErrorBuilder() << "tensor type `" << PrettyPrint(tt1) << "` has " - << tt1->shape.size() << " dimensions, while `" - << PrettyPrint(tt2) << "` has " << tt2->shape.size() - << " dimensions", - this->loc); + this->solver_->diag_ctx_.Emit(Diagnostic::Error(this->span) + << "tensor type `" << PrettyPrint(tt1) << "` has " + << tt1->shape.size() << " dimensions, while `" + << PrettyPrint(tt2) << "` has " << tt2->shape.size() + << " dimensions"); return Type(nullptr); } @@ -258,14 +269,13 @@ class TypeSolver::Unifier : public TypeFunctor { } if (mismatches.size() != 0) { - ErrorBuilder err; + auto err = Diagnostic::Error(this->span); err << "in particular "; for (auto mismatch : mismatches) { err << "dimension " << std::get<0>(mismatch) << " conflicts " << std::get<1>(mismatch) << " does not match " << std::get<2>(mismatch); } - Error error(err); - this->solver_->ReportError(error, this->loc); + this->solver_->diag_ctx_.Emit(err); return Type(nullptr); } @@ -361,7 +371,7 @@ class TypeSolver::Unifier : public TypeFunctor { private: TypeSolver* solver_; - ObjectRef loc; + Span span; }; class TypeSolver::Resolver : public TypeMutator { @@ -523,13 +533,12 @@ class TypeSolver::Merger : public TypeFunctor { }; // constructor -TypeSolver::TypeSolver(const GlobalVar& current_func, const IRModule& module, - ErrorReporter* err_reporter) +TypeSolver::TypeSolver(const GlobalVar& current_func, DiagnosticContext diag_ctx) : reporter_(make_object(this)), current_func(current_func), - err_reporter_(err_reporter), - module_(module) { - CHECK(module_.defined()) << "internal error: module must be defined"; + diag_ctx_(diag_ctx), + module_(diag_ctx->module) { + CHECK(module_.defined()); } // destructor @@ -550,23 +559,17 @@ void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) { } // Add equality constraint -Type TypeSolver::Unify(const Type& dst, const Type& src, const ObjectRef& loc) { - Unifier unifier(this, loc); +Type TypeSolver::Unify(const Type& dst, const Type& src, const Span& span) { + Unifier unifier(this, span); return unifier.Unify(dst, src); } -void TypeSolver::ReportError(const Error& err, const ObjectRef& location) { - CHECK(location.defined()); - CHECK(current_func.defined()); - err_reporter_->ReportAt(current_func, location, err); -} - // Add type constraint to the solver. -void TypeSolver::AddConstraint(const TypeConstraint& constraint, const ObjectRef& loc) { +void TypeSolver::AddConstraint(const TypeConstraint& constraint, const Span& span) { if (const auto* op = constraint.as()) { // create a new relation node. RelationNode* rnode = arena_.make(); - rnode->location = loc; + rnode->span = span; rnode->rel = GetRef(op); rel_nodes_.push_back(rnode); // populate the type information. @@ -609,12 +612,9 @@ bool TypeSolver::Solve() { CHECK_LE(args.size(), rel->args.size()); } - CHECK(rnode->location.defined()) - << "undefined location, should be set when constructing relation node"; - // We need to set this in order to understand where unification // errors generated by the error reporting are coming from. - reporter_->SetLocation(rnode->location); + reporter_->SetSpan(rnode->span); try { // Call the Type Relation's function. @@ -626,13 +626,10 @@ bool TypeSolver::Solve() { rnode->resolved = resolved; } catch (const Error& err) { - this->ReportError(err, rnode->location); + this->diag_ctx_.Emit(Diagnostic::Error(rnode->span) << "err"); rnode->resolved = false; - } catch (const dmlc::Error& err) { - rnode->resolved = false; - this->ReportError(ErrorBuilder() << "an internal invariant was violated while " - << "typechecking your program " << err.what(), - rnode->location); + } catch (const dmlc::Error& e) { + ICHECK(false) << e.what(); } // Mark inqueue as false after the function call @@ -650,30 +647,28 @@ TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver") .set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) { using runtime::PackedFunc; using runtime::TypedPackedFunc; - ErrorReporter* err_reporter = new ErrorReporter(); auto module = IRModule({}, {}); + DiagnosticContext diag_ctx = DiagnosticContext::Default(module); auto dummy_fn_name = GlobalVar("test"); module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array({})), Type(), {}, {})); - auto solver = std::make_shared(dummy_fn_name, module, err_reporter); + auto solver = std::make_shared(dummy_fn_name, diag_ctx); - auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc { + auto mod = [module, solver, diag_ctx](std::string name) -> PackedFunc { if (name == "Solve") { return TypedPackedFunc([solver]() { return solver->Solve(); }); } else if (name == "Unify") { - return TypedPackedFunc( - [module, solver, err_reporter](Type lhs, Type rhs) { - auto res = solver->Unify(lhs, rhs, lhs); - if (err_reporter->AnyErrors()) { - err_reporter->RenderErrors(module, true); - } - return res; - }); + return TypedPackedFunc([module, solver, diag_ctx](Type lhs, Type rhs) { + auto res = solver->Unify(lhs, rhs, Span()); + DiagnosticContext ctx = diag_ctx; + ctx.Render(); + return res; + }); } else if (name == "Resolve") { return TypedPackedFunc([solver](Type t) { return solver->Resolve(t); }); } else if (name == "AddConstraint") { return TypedPackedFunc([solver](TypeConstraint c) { - Expr e = Var("dummy_var", IncompleteType(Kind::kType)); - return solver->AddConstraint(c, e); + Expr e = Var("dummy_var", IncompleteType(Kind::kType), Span(SourceName(), 0, 0, 0, 0)); + return solver->AddConstraint(c, e->span); }); } else { return PackedFunc(); diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h index dcd8de0758549..1fc0525d6bcaf 100644 --- a/src/relay/analysis/type_solver.h +++ b/src/relay/analysis/type_solver.h @@ -63,14 +63,14 @@ using support::LinkNode; */ class TypeSolver { public: - TypeSolver(const GlobalVar& current_func, const IRModule& _mod, ErrorReporter* err_reporter); + TypeSolver(const GlobalVar& current_func, DiagnosticContext diag_ctx); ~TypeSolver(); /*! * \brief Add a type constraint to the solver. * \param constraint The constraint to be added. * \param location The location at which the constraint was incurred. */ - void AddConstraint(const TypeConstraint& constraint, const ObjectRef& lcoation); + void AddConstraint(const TypeConstraint& constraint, const Span& span); /*! * \brief Resolve type to the solution type in the solver. * \param type The type to be resolved. @@ -88,13 +88,12 @@ class TypeSolver { * \param rhs The right operand * \param location The location at which the unification problem arose. */ - Type Unify(const Type& lhs, const Type& rhs, const ObjectRef& location); + Type Unify(const Type& lhs, const Type& rhs, const Span& span); /*! - * \brief Report an error at the provided location. - * \param err The error to report. - * \param loc The location at which to report the error. + * \brief Report a diagnostic. + * \param diag The diagnostic to report. */ - void ReportError(const Error& err, const ObjectRef& location); + void EmitDiagnostic(const Diagnostic& diag); private: class OccursChecker; @@ -156,7 +155,7 @@ class TypeSolver { /*! \brief list types to this relation */ LinkedList type_list; /*! \brief The location this type relation originated from. */ - ObjectRef location; + Span span; }; /*! \brief A simple union find between shapes. */ @@ -177,8 +176,12 @@ class TypeSolver { TypeReporter reporter_; /*! \brief The global representing the current function. */ GlobalVar current_func; - /*! \brief Error reporting. */ - ErrorReporter* err_reporter_; + + public: + /*! \brief The diagnostic context. */ + DiagnosticContext diag_ctx_; + + private: /*! \brief The module. */ IRModule module_; diff --git a/src/relay/analysis/well_formed.cc b/src/relay/analysis/well_formed.cc index bccb1e1bf1ce2..3e409d10b8855 100644 --- a/src/relay/analysis/well_formed.cc +++ b/src/relay/analysis/well_formed.cc @@ -21,6 +21,7 @@ * \file well_formed.cc * \brief check that expression is well formed. */ +#include #include #include #include @@ -32,8 +33,23 @@ namespace relay { //! brief make sure each Var is bound at most once in a scope. class WellFormedChecker : private ExprVisitor, PatternVisitor { + public: + Optional diag_ctx; + Span occurs_in; + + explicit WellFormedChecker(const Optional& ctx) : diag_ctx(ctx) {} + bool well_formed = true; + void Illformed(Diagnostic diag) { + well_formed = false; + if (diag_ctx) { + diag_ctx.value().Emit(diag); + } else { + LOG(INFO) << "The IR is not well formed with: " << diag->message; + } + } + std::vector> scope; std::unordered_set current_bound; std::unordered_set total_bound; @@ -54,7 +70,8 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { void Bound(const Var& v) { if (current_bound.count(v) != 0 || total_bound.count(v) != 0 || free.count(v) != 0) { - well_formed = false; + Illformed(Diagnostic::Error(v->span) << "the variable " << v->name_hint() + << "is bound more then once, this is not valid IR"); } CHECK_GE(scope.size(), 0); scope.back().insert(v); @@ -66,7 +83,8 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { Var v = GetRef(op); if (current_bound.count(v) == 0) { if (total_bound.count(v) != 0) { - well_formed = false; + Illformed(Diagnostic::Error(v->span) << "the variable " << v->name_hint() + << "is bound more then once, this is not valid IR"); } else { free.insert(v); } @@ -99,6 +117,18 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { CheckWellFormed(f->body); } + void VisitExpr_(const CallNode* call) final { + CHECK(call->op.defined()); + + for (auto arg : call->args) { + CHECK(arg.defined()); + } + + // CHECK(call->attrs.defined()); + CHECK(call->type_args.defined()); + ExprVisitor::VisitExpr_(call); + } + void VisitClause(const Clause& c) final { Scope s(this); VisitPattern(c->lhs); @@ -113,6 +143,7 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { if (auto v = e.as()) { VisitExpr_(v); } else { + // this->occurs_in = e->span; ExprVisitor::VisitExpr(e); } } @@ -124,9 +155,13 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { } }; -bool WellFormed(const Expr& e) { return WellFormedChecker().CheckWellFormed(e); } +bool WellFormed(const Expr& e, Optional diag_ctx) { + return WellFormedChecker(diag_ctx).CheckWellFormed(e); +} -TVM_REGISTER_GLOBAL("relay.analysis.well_formed").set_body_typed(WellFormed); +TVM_REGISTER_GLOBAL("relay.analysis.well_formed").set_body_typed([](Expr e) { + return WellFormed(e); +}); } // namespace relay } // namespace tvm diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index b95e0962bd27a..64f1253ff9db2 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -241,6 +241,8 @@ class RelayBuildModule : public runtime::ModuleNode { */ IRModule Optimize(IRModule relay_module, const TargetsMap& targets, const std::unordered_map& params) { + ICHECK(relay_module.defined()) << "The IRModule must be defined for the Relay compiler."; + if (params.size()) { CHECK(relay_module->ContainGlobalVar("main")) << "Missing the main entry function"; GlobalVar main_glb_var = relay_module->GetGlobalVar("main"); @@ -293,6 +295,7 @@ class RelayBuildModule : public runtime::ModuleNode { // Alter layout transformation is only applied to homogeneous execution yet. if (targets.size() == 1) { + pass_seqs.push_back(transform::InferType()); pass_seqs.push_back(transform::AlterOpLayout()); } @@ -330,6 +333,8 @@ class RelayBuildModule : public runtime::ModuleNode { // inline functions. However, this should be very unlikely for accelerators // and vendor-provided libraries. So we don't handle for now. relay_module = transform::Inline()(relay_module); + relay_module = transform::InferType()(relay_module); + CHECK(relay_module.defined()); return relay_module; diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index e392c79d8cdef..95166c74f891b 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -82,7 +82,7 @@ struct CachedFuncNode : public Object { /*! \brief The schedule to the function */ te::Schedule schedule; /*! \brief The lowered functions to support the function. */ - IRModule funcs = IRModule(); + IRModule funcs = IRModule(Map({})); /*! \brief Parameter usage states in the shape function. */ tvm::Array shape_func_param_states; diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 959a7306668f9..acc99c51b69b8 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -207,7 +207,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator({}))); } auto& mod = ret.lowered_funcs[kv.first]; mod->Update(kv.second); @@ -403,7 +403,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorstr())) { - lowered_funcs_[target->str()] = IRModule(); + lowered_funcs_[target->str()] = IRModule(Map({})); } lowered_funcs_[target->str()]->Update(lowered_func->funcs); return GraphAddCallNode(op, _GetUniqueName(lowered_func->func_name), lowered_func->func_name); diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 2afaa86a32ac7..e58c23b766703 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -719,7 +719,9 @@ TypedPackedFunc CreateInterpreter(IRModule mod, DLContext conte if (mod.defined()) { // eta expand to support constructors in argument position transform::Sequential seq({transform::EtaExpand( - /* expand_constructor */ true, /* expand_global_var */ false)}); + /* expand_constructor */ true, /* expand_global_var */ false), + transform::InferType()}); + transform::PassContext pass_ctx = transform::PassContext::Current(); tvm::With ctx(pass_ctx); mod = seq(mod); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index aeb0c5aa55b2c..99ffea45b269b 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1088,6 +1088,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe pass_seqs.push_back(transform::Inline()); pass_seqs.push_back(MemoryOpt(target_host, targets)); + pass_seqs.push_back(transform::InferType()); transform::Sequential seq(pass_seqs); transform::PassContext pass_ctx = PassContext::Current(); diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index b540dd47bcd9b..b5f4d152ee00a 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -113,14 +113,32 @@ FunctionPass::FunctionPass( // Perform Module -> Module optimizations at the Function level. IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { + DiagnosticContext previous = DiagnosticContext::Default(mod); + + if (pass_ctx->diag_ctx) { + DiagnosticContext tmp = pass_ctx->diag_ctx.value(); + pass_ctx->diag_ctx = previous; + previous = tmp; + } else { + pass_ctx->diag_ctx = previous; + } + + ICHECK(pass_ctx->diag_ctx) + << "The diagnostic context was set at the top of this block this is a bug."; + const PassInfo& pass_info = Info(); + CHECK(mod.defined()); + DLOG(INFO) << "Executing function pass : " << pass_info->name << " with opt level: " << pass_info->opt_level; + pass_ctx.Trace(mod, pass_info, true); // Execute the pass function and return a new module. - IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports()); + IRModule updated_mod = + IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map); + std::vector > updates; for (const auto& it : updated_mod->functions) { // only picks up relay::Function @@ -134,8 +152,18 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) for (const auto& pair : updates) { updated_mod->Add(pair.first, pair.second, true); } + + ICHECK(pass_ctx->diag_ctx) + << "The diagnostic context was set at the top of this block this is a bug."; + + pass_ctx->diag_ctx.value().Render(); + pass_ctx->diag_ctx = previous; + pass_ctx.Trace(updated_mod, pass_info, false); - return updated_mod; + + // TODO(@jroesch): move away from eager type checking for performance reasons + // make issue. + return transform::InferType()(updated_mod); } bool FunctionPassNode::SkipFunction(const Function& func) const { diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index 2311585deb601..cd334d7269ab1 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -24,6 +24,7 @@ #ifndef TVM_RELAY_OP_NN_CONVOLUTION_H_ #define TVM_RELAY_OP_NN_CONVOLUTION_H_ +#include #include #include @@ -130,7 +131,7 @@ bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, template bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - CHECK_EQ(types.size(), 3); + ICHECK_EQ(types.size(), 3); const auto* data = types[0].as(); const auto* weight = types[1].as(); if (data == nullptr) return false; @@ -143,26 +144,43 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, const Layout kernel_layout(param->kernel_layout); const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); - CHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCHW." - << " But got " << in_layout; + if (!trans_in_layout.defined()) { + reporter->GetDiagCtx().Emit( + Diagnostic::Error(reporter->GetSpan()) + << "conv2d only support input layouts that are convertible from NCHW." + << " The provided layout is: " << in_layout); + return false; + } const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); - CHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIHW." - << " But got " << kernel_layout; + if (!trans_kernel_layout.defined()) { + reporter->GetDiagCtx().Emit( + Diagnostic::Error(reporter->GetSpan()) + << "conv2d only support kernel layouts that are convertible from OIHW." + << " The provided layout is: " << kernel_layout); + return false; + } Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); - CHECK(trans_out_layout.defined()) - << "Conv only support output layouts that are convertible from NCHW." - << " But got " << out_layout; + if (!trans_out_layout.defined()) { + reporter->GetDiagCtx().Emit( + Diagnostic::Error(reporter->GetSpan()) + << "conv2d only support output layouts that are convertible from NCHW." + << "The provided layout is: " << out_layout); + return false; + } Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); bool is_depthwise = false; if (param->groups > 1) { - CHECK(weight && weight->shape.defined()) - << "Weight shape must be specified when groups is greater than 1."; + if (!(weight && weight->shape.defined())) { + reporter->GetDiagCtx().Emit( + Diagnostic::Error(reporter->GetSpan()) + << "Weight shape must be specified when groups is greater than 1."); + return false; + } + Array wshape_oihw = trans_kernel_layout.ForwardShape(weight->shape); if (tvm::tir::ExprDeepEqual()(param->groups, dshape_nchw[1]) && tvm::tir::ExprDeepEqual()(param->groups, wshape_oihw[0])) { @@ -201,20 +219,44 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, if (weight == nullptr) return false; auto wshape = trans_kernel_layout.ForwardShape(weight->shape); if (param->kernel_size.defined()) { - CHECK_EQ(param->kernel_size.size(), 2); - // check the size - CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && - reporter->AssertEQ(param->kernel_size[1], wshape[3])) - << "Conv2D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size << " wshape=" << wshape; + ICHECK_EQ(param->kernel_size.size(), 2); + + if (!reporter->AssertEQ(param->kernel_size[0], wshape[2])) { + reporter->GetDiagCtx().Emit(Diagnostic::Error(reporter->GetSpan()) + << "Conv2D: shape of weight is inconsistent with kernel_size," + << " kernel_size=" << param->kernel_size + << " wshape=" << wshape); + } + + if (!reporter->AssertEQ(param->kernel_size[1], wshape[3])) { + reporter->GetDiagCtx().Emit(Diagnostic::Error(reporter->GetSpan()) + << "Conv2D: shape of weight is inconsistent with kernel_size," + << " kernel_size=" << param->kernel_size + << " wshape=" << wshape); + return false; + } } - if (param->channels.defined()) { - CHECK(reporter->AssertEQ(param->channels, wshape[0])) - << "Conv2D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels << " wshape=" << wshape; + + if (param->channels.defined() && !reporter->AssertEQ(param->channels, wshape[0])) { + reporter->GetDiagCtx().Emit( + Diagnostic::Error(reporter->GetSpan()) + << "conv2D: the first dimensions of the weight tensor (" << wshape << ")" + << "does not match the number of channels (" << param->channels << ")."); + return false; } + if (!dshape_nchw[1].as() && !wshape[1].as()) { - CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1])); + if (!reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1])) { + reporter->GetDiagCtx().Emit(Diagnostic::Error(reporter->GetSpan()) + << "conv2d: requires that `" + << indexdiv(dshape_nchw[1], param->groups) << "`," + << " the input channels (" << dshape_nchw[1] << ")" + << " divided by groups (" << param->groups << ")" + << ",\n must match the input channels" + << " of the weight `" << wshape[1] + << "`, where the weight shape is (" << wshape << ")."); + return false; + } } channels = wshape[0]; dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; @@ -321,11 +363,13 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, << "Conv3D: shape of weight is inconsistent with kernel_size, " << " kernel_size=" << param->kernel_size << " wshape=" << wshape; } + if (param->channels.defined()) { CHECK(reporter->AssertEQ(param->channels, wshape[0])) << "Conv3D: shape of weight is inconsistent with channels, " << " channels=" << param->channels << " wshape=" << wshape; } + if (!dshape_ncdhw[1].as() && !wshape[1].as()) { CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[1])); } diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index a82dc0a0697c2..e4560c093115c 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -207,6 +207,7 @@ class Partitioner : public MixedModeMutator { func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, func->attrs); module_->Update(pair.first, func); + module_ = transform::InferType()(module_); } } return module_; @@ -311,6 +312,7 @@ class Partitioner : public MixedModeMutator { auto pf = tvm::runtime::Registry::Get(ext_opt); if (pf != nullptr) { auto mod = IRModule::FromExpr(global_region_func); + mod = transform::InferType()(mod); mod = (*pf)(mod); global_region_func = Downcast(mod->Lookup("main")); } @@ -331,6 +333,7 @@ class Partitioner : public MixedModeMutator { // optimizing it. GlobalVar glob_func(fname); module_->Add(glob_func, global_region_func); + module_ = relay::transform::InferType()(module_); // Create a call node for the function. auto call = Call(glob_func, param_expr); @@ -415,6 +418,7 @@ IRModule RemoveDefaultAnnotations(IRModule module) { auto removed = PostOrderRewrite(func->body, &remover); func = Function(func->params, removed, func->ret_type, func->type_params, func->attrs); module->Update(pair.first, func); + module = relay::transform::InferType()(module); } } return module; @@ -470,6 +474,7 @@ IRModule FlattenTupleOutputs(IRModule module) { auto removed = PostOrderRewrite(func->body, &to_flattener); func = Function(func->params, removed, func->ret_type, func->type_params, func->attrs); module->Update(pair.first, func); + module = relay::transform::InferType()(module); } } return module; diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index e110737d62263..d34a662778a4a 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -37,7 +37,8 @@ * If we can not infer a type or there is a conflicting * constraint it will emit errors. */ -#include + +#include #include #include #include @@ -87,7 +88,7 @@ struct ResolvedTypeInfo { }; // -// The inference algorithm can roughly be devided into three stages: +// The inference algorithm can roughly be divided into three stages: // - Populate the constraints by visiting the expression (TypeInferencer.GetType) // - solver.AddConstraint and solver.Unify are called to populate the necessary constraints // - Solve the constraints (solver_.Solve) @@ -98,16 +99,13 @@ class TypeInferencer : private ExprFunctor, public: // constructors - explicit TypeInferencer(IRModule mod, GlobalVar current_func) - : mod_(mod), - current_func_(current_func), - err_reporter(), - solver_(current_func, mod, &this->err_reporter) { - CHECK(mod.defined()) << "internal error: Module must be set in the type inferencer"; + explicit TypeInferencer(IRModule mod, DiagnosticContext diag_ctx) + : mod_(mod), diag_ctx(diag_ctx), solver_(GlobalVar(), diag_ctx) { + ICHECK(mod.defined()) << "Module must not be null in the type inferencer."; } - // inference the type of expr. - Expr Infer(Expr expr); + // Infer the types inside of a function. + Expr Infer(GlobalVar var, Function expr); private: // type resolver that maps back to type @@ -118,8 +116,8 @@ class TypeInferencer : private ExprFunctor, // The current function being type checked. GlobalVar current_func_; - // The error reporter. - ErrorReporter err_reporter; + /*! \brief The diagnostic context. */ + DiagnosticContext diag_ctx; // map from expression to checked type // type inferencer will populate it up @@ -133,12 +131,12 @@ class TypeInferencer : private ExprFunctor, // Perform unification on two types and report the error at the expression // or the span of the expression. - Type Unify(const Type& t1, const Type& t2, const ObjectRef& expr) { + Type Unify(const Type& t1, const Type& t2, const Span& span) { try { - return solver_.Unify(t1, t2, expr); + return solver_.Unify(t1, t2, span); } catch (const dmlc::Error& e) { - this->ReportFatalError( - expr, ErrorBuilder() << "Error unifying `" << t1 << "` and `" << t2 << "`: " << e.what()); + this->EmitFatal(Diagnostic::Error(span) + << "Error unifying `" << t1 << "` and `" << t2 << "`: " << e.what()); return Type(); } } @@ -152,17 +150,13 @@ class TypeInferencer : private ExprFunctor, } Type ret = this->VisitExpr(expr); CHECK(ret.defined()); - KindCheck(ret, mod_); + KindCheck(ret, mod_, this->diag_ctx); ResolvedTypeInfo& rti = type_map_[expr]; rti.checked_type = ret; return ret; } - void ReportFatalError(const ObjectRef& expr, const Error& err) { - CHECK(this->current_func_.defined()); - this->err_reporter.ReportAt(this->current_func_, expr, err); - this->err_reporter.RenderErrors(this->mod_); - } + void EmitFatal(const Diagnostic& diag) { this->diag_ctx.EmitFatal(diag); } // Visitor Logic Type VisitExpr_(const VarNode* op) final { @@ -176,11 +170,10 @@ class TypeInferencer : private ExprFunctor, Type VisitExpr_(const GlobalVarNode* op) final { GlobalVar var = GetRef(op); if (!mod_.defined()) { - this->ReportFatalError(GetRef(op), - ErrorBuilder() << "Cannot do type inference on global variables " - "without a module"); + this->EmitFatal(Diagnostic::Error(op->span) << "Cannot do type inference on global variables " + << "without a module"); } - Expr e = mod_->Lookup(var); + relay::Function e = Downcast(mod_->Lookup(var)); return e->checked_type(); } @@ -204,7 +197,7 @@ class TypeInferencer : private ExprFunctor, auto attrs = make_object(); attrs->index = op->index; solver_.AddConstraint(TypeRelation(tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)), - GetRef(op)); + op->span); return rtype; } @@ -219,33 +212,37 @@ class TypeInferencer : private ExprFunctor, for (size_t i = 0; i < td->type_vars.size(); i++) { unknown_args.push_back(IncompleteType(Kind::kType)); } + Type expected = TypeCall(con->constructor->belong_to, unknown_args); - Type unified = Unify(t, expected, GetRef(con)); + Type unified = Unify(t, expected, pc->span); auto* tc = unified.as(); if (!tc) { - this->ReportFatalError(pc, ErrorBuilder() << "Expected a type call, got " << unified); + this->EmitFatal(Diagnostic::Error(pc->span) << "Expected a type call, got " << unified); } + if (td->header != tc->func) { - this->ReportFatalError(pc, ErrorBuilder() << "ADT headers must match, but we have " - << td->header << " and " << tc->func); + this->EmitFatal(Diagnostic::Error(pc->span) << "ADT headers must match, but we have " + << td->header << " and " << tc->func); } + if (td->type_vars.size() != tc->args.size()) { - this->ReportFatalError( - pc, ErrorBuilder() << "The number of type args must match" - << "the number of type vars in the type data: " << td->type_vars.size() - << " != " << tc->args.size()); + this->EmitFatal(Diagnostic::Error(pc->span) + << "The number of type args must match" + << "the number of type vars in the type data: " << td->type_vars.size() + << " != " << tc->args.size()); } std::unordered_map type_var_map_; for (size_t i = 0; i < td->type_vars.size(); ++i) { type_var_map_[td->type_vars[i]] = tc->args[i]; } - CHECK(con->constructor->inputs.size() == con->patterns.size()) << "not enough pattern"; + if (con->constructor->inputs.size() != con->patterns.size()) { - this->ReportFatalError(pc, ErrorBuilder() << "Not enough inputs for the constructor; " - << "expected " << con->constructor->inputs.size() - << ", got " << con->patterns.size()); + this->EmitFatal(Diagnostic::Error(pc->span) << "Not enough inputs for the constructor; " + << "expected " << con->constructor->inputs.size() + << ", got " << con->patterns.size()); } + for (size_t i = 0; i < con->constructor->inputs.size(); ++i) { VisitPattern(con->patterns[i], Bind(con->constructor->inputs[i], type_var_map_)); } @@ -259,12 +256,13 @@ class TypeInferencer : private ExprFunctor, for (size_t i = 0; i < tup->patterns.size(); i++) { unknown_args.push_back(IncompleteType(Kind::kType)); } + Type expected = TupleType(unknown_args); - Type unified = Unify(t, expected, GetRef(tup)); + Type unified = Unify(t, expected, tup->span); auto* tt = unified.as(); if (!tt) { - this->ReportFatalError(pt, ErrorBuilder() << "Expected a tuple type, got " << unified); + this->EmitFatal(Diagnostic::Error(pt->span) << "Expected a tuple type, got " << unified); } CHECK(tup->patterns.size() == tt->fields.size()) << "not enough pattern"; for (size_t i = 0; i < tup->patterns.size(); ++i) { @@ -295,12 +293,13 @@ class TypeInferencer : private ExprFunctor, Array unmatched_cases = UnmatchedCases(match, this->mod_); if (unmatched_cases.size() != 0) { ErrorBuilder ss; - ss << "match expression does not handle the following cases: "; + auto err = Diagnostic::Error(match->span); + err << "match expression does not handle the following cases: "; int i = 0; for (auto cs : unmatched_cases) { - ss << "case " << i++ << ": \n" << PrettyPrint(cs); + err << "case " << i++ << ": \n" << PrettyPrint(cs); } - this->ReportFatalError(match, ss); + this->EmitFatal(err); } } @@ -320,11 +319,11 @@ class TypeInferencer : private ExprFunctor, } if (let->var->type_annotation.defined()) { - let_type = Unify(let_type, let->var->type_annotation, GetRef(let)); + let_type = Unify(let_type, let->var->type_annotation, let->span); } Type vtype = GetType(let->value); - let_type = Unify(let_type, vtype, GetRef(let)); + let_type = Unify(let_type, vtype, let->span); CHECK(is_functional_literal || !type_map_.count(let->var)); // NOTE: no scoping is necessary because var are unique in program @@ -336,10 +335,10 @@ class TypeInferencer : private ExprFunctor, // Ensure the type of the guard is of Tensor[Bool, ()], // that is a rank-0 boolean tensor. Type cond_type = this->GetType(ite->cond); - this->Unify(cond_type, TensorType::Scalar(tvm::DataType::Bool()), ite->cond); + this->Unify(cond_type, TensorType::Scalar(tvm::DataType::Bool()), ite->cond->span); Type checked_true = this->GetType(ite->true_branch); Type checked_false = this->GetType(ite->false_branch); - return this->Unify(checked_true, checked_false, GetRef(ite)); + return this->Unify(checked_true, checked_false, ite->span); } // This code is special-cased for primitive operators, @@ -347,7 +346,7 @@ class TypeInferencer : private ExprFunctor, // // The result will be the return type of the operator. Type PrimitiveCall(const FuncTypeNode* op, Array arg_types, const Attrs& attrs, - const ObjectRef& loc) { + const Span& span) { if (op->type_params.size() != arg_types.size() + 1) return Type(); if (op->type_constraints.size() != 1) return Type(); const TypeRelationNode* rel = op->type_constraints[0].as(); @@ -359,7 +358,7 @@ class TypeInferencer : private ExprFunctor, Type rtype = IncompleteType(Kind::kType); arg_types.push_back(rtype); // we can do simple replacement here - solver_.AddConstraint(TypeRelation(rel->func, arg_types, arg_types.size() - 1, attrs), loc); + solver_.AddConstraint(TypeRelation(rel->func, arg_types, arg_types.size() - 1, attrs), span); return rtype; } @@ -421,9 +420,8 @@ class TypeInferencer : private ExprFunctor, auto* inc_ty_node = ftype.as(); if (fn_ty_node == nullptr && inc_ty_node == nullptr) { - this->ReportFatalError( - GetRef(call), - ErrorBuilder() << "only expressions with function types can be called, found " << ftype); + this->EmitFatal(Diagnostic::Error(call->span) + << "only expressions with function types can be called, found " << ftype); } // incomplete type => it must be a function taking the arg types @@ -431,17 +429,16 @@ class TypeInferencer : private ExprFunctor, if (inc_ty_node != nullptr) { Type ret_type = IncompleteType(Kind::kType); Type func_type = FuncType(arg_types, ret_type, {}, {}); - Type unified = this->Unify(ftype, func_type, GetRef(call)); + Type unified = this->Unify(ftype, func_type, call->op->span); fn_ty_node = unified.as(); } Array type_args = call->type_args; if (type_args.size() > fn_ty_node->type_params.size()) { - this->ReportFatalError(GetRef(call), - ErrorBuilder() - << "Incorrect number of type args in " << call->span << ": " - << "Expected " << fn_ty_node->type_params.size() << "but got " - << type_args.size()); + this->EmitFatal(Diagnostic::Error(call->span) + << "Incorrect number of type args in " << call->span << ": " + << "Expected " << fn_ty_node->type_params.size() << "but got " + << type_args.size()); } for (size_t i = type_args.size(); i < fn_ty_node->type_params.size(); i++) { type_args.push_back(IncompleteType(TypeKind::kType)); @@ -456,28 +453,26 @@ class TypeInferencer : private ExprFunctor, if (type_arity != number_of_args) { if (type_arity < number_of_args) { - this->ReportFatalError(GetRef(call), - ErrorBuilder() - << "the function is provided too many arguments " - << "expected " << type_arity << ", found " << number_of_args); + this->EmitFatal(Diagnostic::Error(call->span) + << "the function is provided too many arguments " + << "expected " << type_arity << ", found " << number_of_args); } else { - this->ReportFatalError(GetRef(call), - ErrorBuilder() - << "the function is provided too few arguments " - << "expected " << type_arity << ", found " << number_of_args); + this->EmitFatal(Diagnostic::Error(call->span) + << "the function is provided too few arguments " + << "expected " << type_arity << ", found " << number_of_args); } } for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { - this->Unify(fn_ty->arg_types[i], arg_types[i], GetRef(call)); + this->Unify(fn_ty->arg_types[i], arg_types[i], call->span); } for (auto cs : fn_ty->type_constraints) { if (const auto* tr = cs.as()) { solver_.AddConstraint(TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs), - GetRef(call)); + call->span); } else { - solver_.AddConstraint(cs, GetRef(call)); + solver_.AddConstraint(cs, call->span); } } @@ -491,8 +486,9 @@ class TypeInferencer : private ExprFunctor, } if (const OpNode* opnode = call->op.as()) { - Type rtype = PrimitiveCall(opnode->op_type.as(), arg_types, call->attrs, - GetRef(call)); + Type rtype = + PrimitiveCall(opnode->op_type.as(), arg_types, call->attrs, call->span); + if (rtype.defined()) { AddTypeArgs(GetRef(call), arg_types); return rtype; @@ -513,7 +509,7 @@ class TypeInferencer : private ExprFunctor, rtype = InstantiateFuncType(ft); } if (f->ret_type.defined()) { - rtype = this->Unify(f->ret_type, rtype, GetRef(f)); + rtype = this->Unify(f->ret_type, rtype, GetRef(f)->span); } CHECK(rtype.defined()); auto ret = FuncType(arg_types, rtype, f->type_params, {}); @@ -524,14 +520,14 @@ class TypeInferencer : private ExprFunctor, Type VisitExpr_(const RefReadNode* op) final { Type it = IncompleteType(Kind::kType); - this->Unify(GetType(op->ref), RelayRefType(it), GetRef(op)); + this->Unify(GetType(op->ref), RelayRefType(it), op->span); return it; } Type VisitExpr_(const RefWriteNode* op) final { Type it = IncompleteType(Kind::kType); - this->Unify(GetType(op->ref), RelayRefType(it), GetRef(op)); - this->Unify(GetType(op->value), it, GetRef(op)); + this->Unify(GetType(op->ref), RelayRefType(it), op->span); + this->Unify(GetType(op->value), it, op->span); return TupleType::Empty(); } @@ -547,10 +543,6 @@ class TypeInferencer : private ExprFunctor, void Solve() { solver_.Solve(); - - if (err_reporter.AnyErrors()) { - err_reporter.RenderErrors(mod_); - } } }; @@ -606,10 +598,13 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { CHECK(it != tmap_.end()); Type checked_type = solver_->Resolve(it->second.checked_type); - // TODO(@jroesch): it would be nice if we would report resolution - // errors directly on the program. - CHECK(checked_type.as() == nullptr) - << "Cannot resolve type of " << GetRef(op) << " at " << op->span; + if (checked_type.as() != nullptr) { + this->solver_->diag_ctx_.Emit( + Diagnostic::Error(op->span) + << "The type inference pass was unable to infer a type for this expression.\n" + << "This usually occurs when an operator call is under constrained in some way," + << " check other reported errors for hints of what may of happened."); + } Expr new_e = ExprMutator::VisitExpr_(op); // new_call and new_var's code is only going to be valid for VarNode/CallNode. @@ -686,16 +681,24 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { bool update_missing_type_annotation_{true}; }; -Expr TypeInferencer::Infer(Expr expr) { +Expr TypeInferencer::Infer(GlobalVar var, Function function) { + // Set the current function being type checked. + this->current_func_ = var; + // Step 1: Populate the constraints. - GetType(expr); + GetType(function); // Step 2: Solve the constraints. Solve(); // Step 3: Attach resolved types to checked_type field. - auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr); - CHECK(WellFormed(resolved_expr)); + auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(function); + + if (!WellFormed(resolved_expr, this->diag_ctx)) { + this->diag_ctx.Emit(Diagnostic::Bug(function->span) + << "the type checked function is malformed, please report this"); + } + return resolved_expr; } @@ -717,38 +720,88 @@ struct AllCheckTypePopulated : ExprVisitor { void EnsureCheckedType(const Expr& e) { AllCheckTypePopulated().VisitExpr(e); } -Expr InferType(const Expr& expr, const IRModule& mod) { - auto main = mod->GetGlobalVar("main"); - auto inferencer = TypeInferencer(mod, main); - auto e = inferencer.Infer(expr); - CHECK(WellFormed(e)); - auto free_tvars = FreeTypeVars(e, mod); - CHECK(free_tvars.size() == 0) << "Found unbound type variables in " << e << ": " << free_tvars; - EnsureCheckedType(e); - return e; -} +// TODO(@jroesch): Can we optimize this? +void AddGlobalTypes(IRModule mod) { + std::vector > updates; + for (const auto& it : mod->functions) { + // Currently we don't type check TIR. + // The inferencer will only check Relay functions + // the future plan is to have a unified type checker + // that works on TIR and Relay at the same time. + if (auto* func_node = it.second.as()) { + Function func = Function(make_object(*func_node)); + func->checked_type_ = func->func_type_annotation(); + updates.push_back({it.first, Downcast(func)}); + } + } -Function InferType(const Function& func, const IRModule& mod, const GlobalVar& var) { - CHECK(mod.defined()) << "internal error: module must be set for type inference"; - Function func_copy = Function(make_object(*func.operator->())); - func_copy->checked_type_ = func_copy->func_type_annotation(); - mod->AddUnchecked(var, func_copy); - Expr func_ret = TypeInferencer(mod, var).Infer(func_copy); - mod->Remove(var); - CHECK(WellFormed(func_ret)); - auto free_tvars = FreeTypeVars(func_ret, mod); - CHECK(free_tvars.size() == 0) << "Found unbound type variables in: " << std::endl - << AsText(func, true) << std::endl - << free_tvars; - return Downcast(func_ret); + for (const auto& pair : updates) { + mod->Add(pair.first, pair.second, true); + } } namespace transform { Pass InferType() { - runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { return Downcast(InferType(f, m)); }; - return CreateFunctionPass(pass_func, 0, "InferType", {}); + auto pass_info = PassInfo(0, "InferType", {}); + return tvm::transform::CreateModulePass( + [=](IRModule mod, const PassContext& pass_ctx) { + DLOG(INFO) << "tvm::relay::transform::InferType"; + // Execute the pass function and return a new module. + IRModule updated_mod = + IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map); + + pass_ctx->diag_ctx = DiagnosticContext::Default(updated_mod); + + // Add all the type annotations to the functions in the model. + AddGlobalTypes(mod); + + std::vector > updates; + for (const auto& it : updated_mod->functions) { + // Currently we don't type check TIR. + // + // The inferencer will only check Relay functions. + + // In the future we plan a unified type checker + // that works on TIR and Relay at the same time. + if (auto* func_node = it.second.as()) { + auto func = GetRef(func_node); + + // // If a function already has type information we can skip checking it. + // if (func->checked_type_.defined()) { + // continue; + // } + + // TODO(@jroesch): we should be able to move the type inferencer outside + // of this function but it seems to be more stateful then I expect. + auto inferencer = TypeInferencer(mod, pass_ctx->diag_ctx.value()); + auto updated_func = inferencer.Infer(it.first, func); + + pass_ctx->diag_ctx.value().Render(); + + // After we are done checking write the global type back + // into the global var. + it.first->checked_type_ = updated_func->checked_type(); + + if (!WellFormed(updated_func, pass_ctx->diag_ctx)) { + LOG(FATAL) << "The type checked intermediate representation is malformed"; + } + + auto free_tvars = FreeTypeVars(updated_func, mod); + CHECK(free_tvars.size() == 0) + << "Found unbound type variables in " << updated_func << ": " << free_tvars; + EnsureCheckedType(updated_func); + updates.push_back({it.first, Downcast(updated_func)}); + } + } + + for (const auto& pair : updates) { + updated_mod->Add(pair.first, pair.second, true); + } + + return updated_mod; + }, + 0, "InferType", {}); } TVM_REGISTER_GLOBAL("relay._transform.InferType").set_body_typed([]() { return InferType(); }); diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc index 8a4ee9af24bb4..8c20f0700ee74 100644 --- a/src/runtime/contrib/random/mt_random_engine.cc +++ b/src/runtime/contrib/random/mt_random_engine.cc @@ -42,7 +42,7 @@ class RandomEngine { /*! * \brief Creates a RandomEngine using a default seed. */ - RandomEngine() { this->Seed(time(0)); } + RandomEngine() { this->Seed(time(nullptr)); } /*! * \brief Creates a RandomEngine, suggesting the use of a provided seed. diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 5298added83be..c121285e23143 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -311,11 +311,12 @@ Pass SplitHostDevice() { auto pass_func = [](IRModule mod, PassContext ctx) { IRModuleNode* mod_ptr = mod.CopyOnWrite(); auto* func_dict = mod_ptr->functions.CopyOnWrite(); - IRModule device_mod = IRModule(); + IRModule device_mod = IRModule(Map({})); for (auto& kv : *func_dict) { if (kv.second->IsInstance()) { PrimFunc func = Downcast(std::move(kv.second)); + CHECK(device_mod.defined()) << "The device module must be defined."; kv.second = SplitHostDevice(std::move(func), &device_mod); } } diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index 2b5eb961e6b2e..fcab1b85edd92 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -116,12 +116,14 @@ TEST(Relay, BuildModule) { Target llvm_tgt = Target("llvm"); targets.Set(0, llvm_tgt); auto relay_mod = tvm::IRModule::FromExpr(func); + CHECK(relay_mod.defined()) << "Module must be defined"; build_f(relay_mod, targets, llvm_tgt); std::string json = json_f(); tvm::runtime::Module mod = mod_f(); // run auto ctx = A->ctx; auto pfr = tvm::runtime::Registry::Get("tvm.graph_runtime.create"); + CHECK(mod.defined()) << "Module must be defined"; tvm::runtime::Module run_mod = (*pfr)(json, mod, (int)ctx.device_type, (int)ctx.device_id); auto set_input_f = run_mod.GetFunction("set_input_zero_copy", false); auto run_f = run_mod.GetFunction("run", false); diff --git a/tests/python/contrib/test_ethosn/infrastructure.py b/tests/python/contrib/test_ethosn/infrastructure.py index 8503123760d67..e2c40550289a4 100644 --- a/tests/python/contrib/test_ethosn/infrastructure.py +++ b/tests/python/contrib/test_ethosn/infrastructure.py @@ -77,7 +77,8 @@ def make_module(func, params): func = relay.Function(relay.analysis.free_vars(func), func) if params: relay.build_module.bind_params_by_name(func, params) - return tvm.IRModule.from_expr(func) + mod = tvm.IRModule.from_expr(func) + return relay.transform.InferType()(mod) def make_ethosn_composite(ethosn_expr, name): @@ -92,20 +93,28 @@ def make_ethosn_partition(ethosn_expr): # Create an Ethos-N global function mod = tvm.IRModule({}) vars = relay.analysis.free_vars(ethosn_expr) - func = relay.Function(vars, ethosn_expr) + # NB: it is illegal to reuse variables inside and outside a scope in Relay + # if you want to duplicate types and names you must re-allocate them. + fresh_vars = [relay.Var(v.name_hint, v.type_annotation) for v in vars] + binds = {} + for var, fresh_var in zip(vars, fresh_vars): + binds[var] = fresh_var + ethosn_expr_fresh = relay.bind(ethosn_expr, binds) + func = relay.Function(fresh_vars, ethosn_expr_fresh) func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Compiler", "ethos-n") func = func.with_attr("global_symbol", "ethos-n_0") g1 = relay.GlobalVar("ethos-n_0") mod[g1] = func + mod = relay.transform.InferType()(mod) # These are the vars to call the Ethos-N partition with more_vars = relay.analysis.free_vars(ethosn_expr) # Call the Ethos-N partition in main call_fn1 = g1(*more_vars) mod["main"] = relay.Function(more_vars, call_fn1) - return mod + return relay.transform.InferType()(mod) def get_host_op_count(mod): @@ -150,9 +159,12 @@ def build(mod, params, npu=True, expected_host_ops=0, npu_partitions=1): mod = tvm.IRModule() mod["main"] = f pattern = get_pattern_table("ethos-n") + mod = relay.transform.InferType()(mod) mod = relay.transform.MergeComposite(pattern)(mod) mod = relay.transform.AnnotateTarget("ethos-n")(mod) + mod = relay.transform.InferType()(mod) mod = relay.transform.MergeCompilerRegions()(mod) + mod = relay.transform.InferType()(mod) mod = relay.transform.PartitionGraph()(mod) host_op_count = get_host_op_count(mod) assert ( @@ -245,6 +257,7 @@ def test_error(mod, params, err_msg): with tvm.transform.PassContext(opt_level=3): with tvm.target.Target("llvm"): try: + mod = relay.transform.InferType()(mod) relay.build(mod, params) except tvm.error.TVMError as e: caught = e.args[0] diff --git a/tests/python/contrib/test_ethosn/test_reshape.py b/tests/python/contrib/test_ethosn/test_reshape.py index e15ddd6fefaf8..4afec557e569e 100644 --- a/tests/python/contrib/test_ethosn/test_reshape.py +++ b/tests/python/contrib/test_ethosn/test_reshape.py @@ -76,6 +76,7 @@ def test_reshape_failure(): model, params = _get_model(input_shape, output_shape, dtype) mod = tei.make_module(model, params) pattern = get_pattern_table("ethos-n") + mod = tei.make_module(model, params) mod = relay.transform.MergeComposite(pattern)(mod) mod = tei.make_ethosn_partition(mod["main"].body) tei.test_error(mod, {}, err_msg) diff --git a/tests/python/frontend/pytorch/test_lstm.py b/tests/python/frontend/pytorch/test_lstm.py index 4d7a4063fdd64..39d78c70c0fb9 100644 --- a/tests/python/frontend/pytorch/test_lstm.py +++ b/tests/python/frontend/pytorch/test_lstm.py @@ -253,7 +253,8 @@ def wrap_nd_array(arr): mod = tvm.IRModule() prelude = Prelude(mod) - adt_lst = ADT(prelude.nil.tag, []) + list, cons, nil = mod.get_type("List") + adt_lst = ADT(nil.tag, []) for elem in reversed(py_lst): if isinstance(elem, np.ndarray): vmobj = wrap_nd_array(elem) @@ -261,7 +262,7 @@ def wrap_nd_array(arr): vmobj = tuple_object([wrap_nd_array(e) for e in elem]) elif isinstance(elem, list): vmobj = convert_list_to_vmobj(elem) - adt_lst = ADT(prelude.cons.tag, [vmobj, adt_lst]) + adt_lst = ADT(cons.tag, [vmobj, adt_lst]) return adt_lst diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index f55ef6aec36b6..2bd45b7a5a4ef 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -150,6 +150,8 @@ def run_tvm_graph( return vmobj_to_list(result) elif mode == "vm": with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass): + print(mod["main"]) + mod = relay.transform.InferType()(mod) vm_exec = relay.vm.compile(mod, target="llvm", params=params) if serialize: code, lib = vm_exec.save() diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index c443fcdb33404..122fa67d65dfd 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -21,13 +21,12 @@ from tvm.relay.backend.interpreter import ConstructorValue from tvm.relay import create_executor from tvm.relay.prelude import Prelude, StaticTensorArrayOps -from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr +from tvm.relay.testing import count as count_, make_nat_value, make_nat_expr import numpy as np -mod = tvm.IRModule() -p = Prelude(mod) -add_nat_definitions(p) +prelude = p = Prelude(tvm.IRModule({})) +p.mod.import_from_std("nat.rly") def count(e): @@ -35,21 +34,16 @@ def count(e): ctx = tvm.context("llvm", 0) -intrp = create_executor(mod=mod, ctx=ctx, target="llvm") +intrp = create_executor(mod=prelude.mod, ctx=ctx, target="llvm") -z = p.z -s = p.s -nat = p.nat -double = p.double -add = p.add +nat, z, s = prelude.mod.get_type("nat") -optional = p.optional -some = p.some -none = p.none +double = p.mod.get_global_var("nat_double") +add = p.mod.get_global_var("nat_add") + +optional, some, none = prelude.mod.get_type("Option") +rlist, cons, nil = prelude.mod.get_type("List") -nil = p.nil -cons = p.cons -l = p.l hd = p.hd tl = p.tl nth = p.nth @@ -70,41 +64,25 @@ def count(e): map_accumr = p.map_accumr map_accuml = p.map_accuml -tree = p.tree -rose = p.rose +tree, rose = prelude.mod.get_type("Tree") + tmap = p.tmap size = p.size compose = p.compose iterate = p.iterate -# this is an example of creating the adt value in python side -def make_nat(n): - if n != 0: - return ConstructorValue(s, [make_nat(n - 1)]) - else: - return ConstructorValue(z, []) - - -def make_nat_expr(n): - assert n >= 0 - ret = z() - while n > 0: - ret = s(ret) - n = n - 1 - return ret - def to_list(l): assert isinstance(l, ConstructorValue) val = l ret = [] while True: - if val.tag == p.cons.tag: + if val.tag == cons.tag: ret.append(val.fields[0]) val = val.fields[1] else: - assert val.tag == p.nil.tag + assert val.tag == nil.tag break return ret @@ -112,7 +90,7 @@ def to_list(l): def tree_to_dict(t): assert isinstance(t, ConstructorValue) ret = {} - assert t.tag == p.rose.tag + assert t.tag == rose.tag ret["member"] = t.fields[0] ret["children"] = [] for subtree in to_list(t.fields[1]): @@ -158,7 +136,7 @@ def get_scalar(tv): return tv.asnumpy().item() -@tvm.testing.uses_gpu +# @tvm.testing.uses_gpu def test_nat_value(): assert count(make_nat_value(p, 10)) == 10 assert count(intrp.evaluate(s(s(z())))) == 2 @@ -168,24 +146,25 @@ def test_nat_value(): def test_nat_constructor(): func = relay.Function([], z()) test_z = relay.GlobalVar("test_z") - mod[test_z] = func - assert mod[test_z].body.checked_type == nat() test_sz = relay.GlobalVar("test_sz") + prelude.mod[test_z] = func func = relay.Function([], s(z())) - mod[test_sz] = func - assert mod[test_sz].body.checked_type == nat() + prelude.mod[test_sz] = func + ck_mod = relay.transform.InferType()(prelude.mod) + assert ck_mod[test_z].body.checked_type == nat() + assert ck_mod[test_sz].body.checked_type == nat() @tvm.testing.uses_gpu def test_double(): - assert mod[double].checked_type == relay.FuncType([nat()], nat()) + assert prelude.mod[double].checked_type == relay.FuncType([nat()], nat()) res = intrp.evaluate(double(s(z()))) assert count(res) == 2 @tvm.testing.uses_gpu def test_add(): - assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) + assert prelude.mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) res = intrp.evaluate(add(s(z()), s(z()))) assert count(res) == 2 @@ -194,8 +173,9 @@ def test_add(): def test_list_constructor(): test_consz = relay.GlobalVar("test_consz") func = relay.Function([], cons(z(), nil())) - mod[test_consz] = func - assert mod[test_consz].body.checked_type == l(nat()) + prelude.mod[test_consz] = func + ck_mod = relay.transform.InferType()(prelude.mod) + assert ck_mod[test_consz].body.checked_type == rlist(nat()) @tvm.testing.uses_gpu @@ -203,7 +183,7 @@ def test_hd_tl(): expected = list(range(10)) l = nil() for i in reversed(expected): - l = cons(make_nat_expr(i), l) + l = cons(make_nat_expr(prelude, i), l) got = [] for i in range(len(expected)): @@ -221,6 +201,7 @@ def test_nth(): l = cons(relay.const(i), l) for i in range(len(expected)): + nth = prelude.mod.get_global_var("nth") item = intrp.evaluate(nth(l, relay.const(i))) assert get_scalar(item) == i @@ -231,11 +212,11 @@ def test_update(): l = nil() # create zero initialized list for i in range(len(expected)): - l = cons(make_nat_expr(0), l) + l = cons(make_nat_expr(prelude, 0), l) # set value for i, v in enumerate(expected): - l = update(l, relay.const(i), make_nat_expr(v)) + l = update(l, relay.const(i), make_nat_expr(prelude, v)) got = [] for i in range(len(expected)): @@ -247,7 +228,9 @@ def test_update(): @tvm.testing.uses_gpu def test_length(): a = relay.TypeVar("a") - assert mod[length].checked_type == relay.FuncType([l(a)], relay.scalar_type("int32"), [a]) + assert prelude.mod[length].checked_type == relay.FuncType( + [rlist(a)], relay.scalar_type("int32"), [a] + ) res = intrp.evaluate(length(cons(z(), cons(z(), cons(z(), nil()))))) assert get_scalar(res) == 3 @@ -256,8 +239,8 @@ def test_length(): def test_map(): a = relay.TypeVar("a") b = relay.TypeVar("b") - lhs = mod[map].checked_type - rhs = relay.FuncType([relay.FuncType([a], b), l(a)], l(b), [a, b]) + lhs = prelude.mod[map].checked_type + rhs = relay.FuncType([relay.FuncType([a], b), rlist(a)], rlist(b), [a, b]) assert lhs == rhs x = relay.Var("x") @@ -272,8 +255,9 @@ def test_map(): def test_foldl(): a = relay.TypeVar("a") b = relay.TypeVar("b") - lhs = mod[foldl].checked_type - rhs = relay.FuncType([relay.FuncType([a, b], a), a, l(b)], a, [a, b]) + + lhs = prelude.mod[foldl].checked_type + rhs = relay.FuncType([relay.FuncType([a, b], a), a, rlist(b)], a, [a, b]) assert lhs == rhs x = relay.Var("x") @@ -283,7 +267,10 @@ def test_foldl(): foldl( rev_dup, nil(), - cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil()))), + cons( + make_nat_expr(prelude, 1), + cons(make_nat_expr(prelude, 2), cons(make_nat_expr(prelude, 3), nil())), + ), ) ) reversed = to_list(res) @@ -297,8 +284,8 @@ def test_foldl(): def test_foldr(): a = relay.TypeVar("a") b = relay.TypeVar("b") - lhs = mod[foldr].checked_type - rhs = relay.FuncType([relay.FuncType([a, b], b), b, l(a)], b, [a, b]) + lhs = prelude.mod[foldr].checked_type + rhs = relay.FuncType([relay.FuncType([a, b], b), b, rlist(a)], b, [a, b]) assert lhs == rhs x = relay.Var("x") @@ -308,7 +295,10 @@ def test_foldr(): foldr( identity, nil(), - cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil()))), + cons( + make_nat_expr(prelude, 1), + cons(make_nat_expr(prelude, 2), cons(make_nat_expr(prelude, 3), nil())), + ), ) ) same = to_list(res) @@ -319,15 +309,21 @@ def test_foldr(): @tvm.testing.uses_gpu def test_foldr1(): a = relay.TypeVar("a") - lhs = mod[p.foldr1].checked_type - rhs = relay.FuncType([relay.FuncType([a, a], a), l(a)], a, [a]) + lhs = prelude.mod[foldr1].checked_type + rhs = relay.FuncType([relay.FuncType([a, a], a), rlist(a)], a, [a]) assert lhs == rhs x = relay.Var("x") y = relay.Var("y") f = relay.Function([x, y], add(x, y)) res = intrp.evaluate( - foldr1(f, cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))) + foldr1( + f, + cons( + make_nat_expr(prelude, 1), + cons(make_nat_expr(prelude, 2), cons(make_nat_expr(prelude, 3), nil())), + ), + ) ) assert count(res) == 6 @@ -335,8 +331,8 @@ def test_foldr1(): @tvm.testing.uses_gpu def test_sum(): - assert mod[sum].checked_type == relay.FuncType( - [l(relay.scalar_type("int32"))], relay.scalar_type("int32") + assert prelude.mod[sum].checked_type == relay.FuncType( + [rlist(relay.scalar_type("int32"))], relay.scalar_type("int32") ) res = intrp.evaluate(sum(cons(relay.const(1), cons(relay.const(2), nil())))) assert get_scalar(res) == 3 @@ -345,10 +341,10 @@ def test_sum(): @tvm.testing.uses_gpu def test_concat(): a = relay.TypeVar("a") - assert mod[concat].checked_type == relay.FuncType([l(a), l(a)], l(a), [a]) + assert prelude.mod[concat].checked_type == relay.FuncType([rlist(a), rlist(a)], rlist(a), [a]) - l1 = cons(make_nat_expr(1), cons(make_nat_expr(2), nil())) - l2 = cons(make_nat_expr(3), cons(make_nat_expr(4), nil())) + l1 = cons(make_nat_expr(prelude, 1), cons(make_nat_expr(prelude, 2), nil())) + l2 = cons(make_nat_expr(prelude, 3), cons(make_nat_expr(prelude, 4), nil())) res = intrp.evaluate(concat(l1, l2)) catted = to_list(res) @@ -363,9 +359,9 @@ def test_concat(): def test_filter(): a = relay.TypeVar("a") expected_type = relay.FuncType( - [relay.FuncType([a], relay.scalar_type("bool")), l(a)], l(a), [a] + [relay.FuncType([a], relay.scalar_type("bool")), rlist(a)], rlist(a), [a] ) - assert mod[filter].checked_type == expected_type + assert prelude.mod[filter].checked_type == expected_type x = relay.Var("x", nat()) greater_than_one = relay.Function( @@ -387,13 +383,14 @@ def test_filter(): filter( greater_than_one, cons( - make_nat_expr(1), + make_nat_expr(prelude, 1), cons( - make_nat_expr(1), + make_nat_expr(prelude, 1), cons( - make_nat_expr(3), + make_nat_expr(prelude, 3), cons( - make_nat_expr(1), cons(make_nat_expr(5), cons(make_nat_expr(1), nil())) + make_nat_expr(prelude, 1), + cons(make_nat_expr(prelude, 5), cons(make_nat_expr(prelude, 1), nil())), ), ), ), @@ -410,10 +407,13 @@ def test_filter(): def test_zip(): a = relay.TypeVar("a") b = relay.TypeVar("b") - expected_type = relay.FuncType([l(a), l(b)], l(relay.TupleType([a, b])), [a, b]) - assert mod[zip].checked_type == expected_type + expected_type = relay.FuncType([rlist(a), rlist(b)], rlist(relay.TupleType([a, b])), [a, b]) + assert prelude.mod[zip].checked_type == expected_type - l1 = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil()))) + l1 = cons( + make_nat_expr(prelude, 1), + cons(make_nat_expr(prelude, 2), cons(make_nat_expr(prelude, 3), nil())), + ) l2 = cons(nil(), cons(cons(nil(), nil()), cons(cons(nil(), cons(nil(), nil())), nil()))) res = intrp.evaluate(zip(l1, l2)) @@ -427,7 +427,7 @@ def test_zip(): assert len(to_list(zipped[2][1])) == 2 # test truncation - l3 = cons(make_nat_expr(4), cons(make_nat_expr(5), nil())) + l3 = cons(make_nat_expr(prelude, 4), cons(make_nat_expr(prelude, 5), nil())) shorter_res = intrp.evaluate(zip(l3, l2)) truncated = to_list(shorter_res) assert len(truncated) == 2 @@ -447,10 +447,15 @@ def test_zip(): @tvm.testing.uses_gpu def test_rev(): a = relay.TypeVar("a") - assert mod[rev].checked_type == relay.FuncType([l(a)], l(a), [a]) + assert prelude.mod[rev].checked_type == relay.FuncType([rlist(a)], rlist(a), [a]) res = intrp.evaluate( - rev(cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))) + rev( + cons( + make_nat_expr(prelude, 1), + cons(make_nat_expr(prelude, 2), cons(make_nat_expr(prelude, 3), nil())), + ) + ) ) reversed = to_list(res) @@ -465,7 +470,7 @@ def test_unfoldr(): a = relay.TypeVar("a") b = relay.TypeVar("b") expected_type = relay.FuncType( - [relay.FuncType([a], optional(relay.TupleType([a, b]))), a], l(b), [a, b] + [relay.FuncType([a], optional(relay.TupleType([a, b]))), a], rlist(b), [a, b] ) x = relay.Var("x", nat()) @@ -483,7 +488,7 @@ def test_unfoldr(): ), ) - res = intrp.evaluate(unfoldr(count_down, make_nat_expr(3))) + res = intrp.evaluate(unfoldr(count_down, make_nat_expr(prelude, 3))) unfolded = to_list(res) assert len(unfolded) == 3 @@ -497,7 +502,7 @@ def test_unfoldl(): a = relay.TypeVar("a") b = relay.TypeVar("b") expected_type = relay.FuncType( - [relay.FuncType([a], optional(relay.TupleType([a, b]))), a], l(b), [a, b] + [relay.FuncType([a], optional(relay.TupleType([a, b]))), a], rlist(b), [a, b] ) x = relay.Var("x", nat()) @@ -515,7 +520,7 @@ def test_unfoldl(): ), ) - res = intrp.evaluate(unfoldl(count_down, make_nat_expr(3))) + res = intrp.evaluate(unfoldl(count_down, make_nat_expr(prelude, 3))) unfolded = to_list(res) assert len(unfolded) == 3 @@ -530,17 +535,20 @@ def test_map_accumr(): b = relay.TypeVar("b") c = relay.TypeVar("c") expected_type = relay.FuncType( - [relay.FuncType([a, b], relay.TupleType([a, c])), a, l(b)], - relay.TupleType([a, l(c)]), + [relay.FuncType([a, b], relay.TupleType([a, c])), a, rlist(b)], + relay.TupleType([a, rlist(c)]), [a, b, c], ) - assert mod[map_accumr].checked_type == expected_type + assert prelude.mod[map_accumr].checked_type == expected_type acc = relay.Var("acc", nat()) x = relay.Var("x", nat()) add_acc_to_each = relay.Function([acc, x], relay.Tuple([add(x, acc), add(x, acc)])) - vals = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil()))) + vals = cons( + make_nat_expr(prelude, 1), + cons(make_nat_expr(prelude, 2), cons(make_nat_expr(prelude, 3), nil())), + ) res = intrp.evaluate(map_accumr(add_acc_to_each, z(), vals)) sum = count(res[0]) @@ -559,17 +567,20 @@ def test_map_accuml(): b = relay.TypeVar("b") c = relay.TypeVar("c") expected_type = relay.FuncType( - [relay.FuncType([a, b], relay.TupleType([a, c])), a, l(b)], - relay.TupleType([a, l(c)]), + [relay.FuncType([a, b], relay.TupleType([a, c])), a, rlist(b)], + relay.TupleType([a, rlist(c)]), [a, b, c], ) - assert mod[map_accuml].checked_type == expected_type + assert prelude.mod[map_accuml].checked_type == expected_type acc = relay.Var("acc", nat()) x = relay.Var("x", nat()) add_to_acc = relay.Function([acc, x], relay.Tuple([add(x, acc), x])) - vals = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil()))) + vals = cons( + make_nat_expr(prelude, 1), + cons(make_nat_expr(prelude, 2), cons(make_nat_expr(prelude, 3), nil())), + ) res = intrp.evaluate(map_accuml(add_to_acc, z(), vals)) sum = count(res[0]) @@ -602,7 +613,10 @@ def test_optional_matching(): foldr( condense, nil(), - cons(some(make_nat_expr(3)), cons(none(), cons(some(make_nat_expr(1)), nil()))), + cons( + some(make_nat_expr(prelude, 3)), + cons(none(), cons(some(make_nat_expr(prelude, 1)), nil())), + ), ) ) @@ -616,7 +630,7 @@ def test_optional_matching(): def test_tmap(): a = relay.TypeVar("a") b = relay.TypeVar("b") - lhs = mod[tmap].checked_type + lhs = prelude.mod[tmap].checked_type rhs = relay.FuncType([relay.FuncType([a], b), tree(a)], tree(b), [a, b]) assert lhs == rhs @@ -637,7 +651,7 @@ def test_tmap(): @tvm.testing.uses_gpu def test_size(): a = relay.TypeVar("a") - lhs = mod[size].checked_type + lhs = prelude.mod[size].checked_type rhs = relay.FuncType([tree(a)], relay.scalar_type("int32"), [a]) assert lhs == rhs @@ -658,7 +672,7 @@ def test_wildcard_match_solo(): @tvm.testing.uses_gpu def test_wildcard_match_order(): - x = relay.Var("x", l(nat())) + x = relay.Var("x", rlist(nat())) y = relay.Var("y") a = relay.Var("a") return_zero = relay.Function( @@ -684,7 +698,8 @@ def test_wildcard_match_order(): @tvm.testing.uses_gpu def test_nested_matches(): a = relay.TypeVar("a") - x = relay.Var("x") + # TODO(@jroesch): inference should be able to handle this one + x = relay.Var("x", type_annotation=rlist(rlist(a))) y = relay.Var("y") w = relay.Var("w") h = relay.Var("h") @@ -703,7 +718,7 @@ def test_nested_matches(): ], ) - mod[flatten] = relay.Function( + prelude.mod[flatten] = relay.Function( [x], relay.Match( x, @@ -715,12 +730,18 @@ def test_nested_matches(): ), ], ), - l(a), + rlist(a), [a], ) - first_list = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil()))) - second_list = cons(make_nat_expr(4), cons(make_nat_expr(5), cons(make_nat_expr(6), nil()))) + first_list = cons( + make_nat_expr(prelude, 1), + cons(make_nat_expr(prelude, 2), cons(make_nat_expr(prelude, 3), nil())), + ) + second_list = cons( + make_nat_expr(prelude, 4), + cons(make_nat_expr(prelude, 5), cons(make_nat_expr(prelude, 6), nil())), + ) final_list = cons(first_list, cons(second_list, nil())) res = intrp.evaluate(flatten(final_list)) @@ -751,7 +772,7 @@ def test_match_full_var(): @tvm.testing.uses_gpu def test_nested_pattern_match(): - x = relay.Var("x", l(nat())) + x = relay.Var("x", rlist(nat())) h1 = relay.Var("h1") h2 = relay.Var("h2") t = relay.Var("t") @@ -789,725 +810,10 @@ def test_compose(): @tvm.testing.uses_gpu def test_iterate(): - expr = relay.Call(iterate(double, relay.const(2)), [make_nat_expr(3)]) + expr = relay.Call(iterate(double, relay.const(2)), [make_nat_expr(prelude, 3)]) res = intrp.evaluate(relay.Function([], expr)()) assert count(res) == 12 -def check_tensor_array(ta_mod, ref_res, *args, dtype="float32", rtol=1e-5): - for kind in ["debug", "vm"]: - for target, ctx in testing.enabled_targets(): - if kind == "debug" and ctx.device_type != tvm.cpu().device_type: - continue - ex = relay.create_executor(kind, mod=ta_mod, ctx=ctx, target=target) - result = ex.evaluate()(*args) - got = vmobj_to_list(result, dtype) - tvm.testing.assert_allclose(ref_res, got, rtol=rtol, atol=rtol) - - -@tvm.testing.uses_gpu -def test_tensor_expand_dims(): - def run(dtype): - x = relay.var("x") - mod = tvm.IRModule() - p = Prelude(mod) - expand_dims_func = p.get_var("tensor_expand_dims", dtype) - tensor1 = p.get_var("tensor1", dtype) - mod["main"] = relay.Function([x], expand_dims_func(tensor1(x))) - x_np = np.random.uniform(low=0.0, high=8.0, size=(1,)).astype(dtype) - expected = [np.expand_dims(x_np, axis=0)] - check_tensor_array(mod, expected, x_np) - - run("float32") - run("int32") - - -@tvm.testing.uses_gpu -def test_tensor_array_constructor(): - def run(dtype): - x = relay.var("x") - mod = tvm.IRModule() - p = Prelude(mod) - tensor_array = p.get_var("tensor_array", dtype) - mod["main"] = relay.Function([x], tensor_array(x)) - expected = np.array([0, 0, 0, 0, 0]) - check_tensor_array(mod, expected, 5, dtype=dtype) - - run("float32") - run("int32") - - -@tvm.testing.uses_gpu -def test_tensor_array_read(): - def run(dtype): - mod = tvm.IRModule() - p = Prelude(mod) - l = relay.var("l") - i = relay.var("i") - read_func = p.get_var("tensor_array_read", dtype) - tensor_array = p.get_var("tensor_array", dtype) - mod["main"] = relay.Function([l, i], read_func(tensor_array(l), i)) - expected = [0] - check_tensor_array(mod, expected, *(1, 0), dtype=dtype) - check_tensor_array(mod, expected, *(5, 1), dtype=dtype) - - run("float32") - run("int32") - - -@tvm.testing.uses_gpu -def test_tensor_array_write(): - def run(dtype): - mod = tvm.IRModule() - p = Prelude(mod) - v1 = relay.var("v1") - v2 = relay.var("v2") - tensor_array = p.get_var("tensor_array", dtype) - init_tensor_array = tensor_array(relay.const(2)) - write_func = p.get_var("tensor_array_write", dtype) - tensor1 = p.get_var("tensor1", dtype) - tensor_array1 = write_func(init_tensor_array, relay.const(0), tensor1(v1)) - tensor_array2 = write_func(tensor_array1, relay.const(1), tensor1(v2)) - mod["main"] = relay.Function([v1, v2], tensor_array2) - expected = [3, 7] - check_tensor_array(mod, expected, *(3, 7), dtype=dtype) - - run("float32") - run("int32") - - -@tvm.testing.uses_gpu -def test_tensor_array_stack(): - def run(dtype): - mod = tvm.IRModule() - p = Prelude(mod) - tensor_array = p.get_var("tensor_array", dtype) - tensor1 = p.get_var("tensor1", dtype) - write = p.get_var("tensor_array_write", dtype) - stack = p.get_var("tensor_array_stack", dtype) - v = relay.var("v") - init_tensor_array = tensor_array(relay.const(3)) - tensor_array1 = write(init_tensor_array, relay.const(0), tensor1(v)) - tensor_array2 = write(tensor_array1, relay.const(1), tensor1(v)) - tensor_array3 = write(tensor_array2, relay.const(2), tensor1(v)) - tensor_array4 = stack(tensor_array3) - mod["main"] = relay.Function([v], tensor_array4) - t = np.random.uniform(low=0.0, high=8.0, size=(1,)).astype(dtype) - expected = [np.stack([t, t, t])] - check_tensor_array(mod, expected, t, dtype=dtype) - - run("float32") - run("int32") - - -@tvm.testing.uses_gpu -def test_tensor_array_unstack(): - def run(dtype): - mod = tvm.IRModule() - p = Prelude(mod) - unstack_tensor1 = p.get_var("tensor_array_unstack_tensor1", dtype) - v = relay.var("v") - mod["main"] = relay.Function([v], unstack_tensor1(v)) - t = np.random.uniform(low=0.0, high=8.0, size=(1,)).astype(dtype) - check_tensor_array(mod, t, t, dtype=dtype) - - run("float32") - run("int32") - - -@tvm.testing.uses_gpu -def test_tensor_take(): - def run(dtype): - mod = tvm.IRModule() - p = Prelude(mod) - take = p.get_var("tensor_take", dtype) - tensor2 = p.get_var("tensor2", dtype) - v = relay.var("v") - lower = relay.var("lower") - upper = relay.var("upper") - mod["main"] = relay.Function([v, lower, upper], take(tensor2(v), lower, upper)) - v_data = np.random.uniform(low=0.0, high=8.0, size=(10, 10)).astype(dtype) - expected = [np.take(v_data, range(2, 5), axis=0)] - check_tensor_array(mod, expected, *(v_data, 2, 5), dtype=dtype) - expected = [np.take(v_data, range(0, 9), axis=0)] - check_tensor_array(mod, expected, *(v_data, 0, 9), dtype=dtype) - - run("float32") - run("int32") - - -@tvm.testing.uses_gpu -def test_tensor_concatenate(): - def run(dtype): - mod = tvm.IRModule() - p = Prelude(mod) - concat = p.get_var("tensor_concatenate", dtype) - tensor1 = p.get_var("tensor1", dtype) - v1 = relay.var("v1") - v2 = relay.var("v2") - mod["main"] = relay.Function([v1, v2], concat(tensor1(v1), tensor1(v2))) - v1_data = np.random.uniform(low=0.0, high=8.0, size=(5,)).astype(dtype) - v2_data = np.random.uniform(low=0.0, high=8.0, size=(5,)).astype(dtype) - expected = [np.concatenate((v1_data, v2_data))] - check_tensor_array(mod, expected, *(v1_data, v2_data), dtype=dtype) - - run("float32") - run("int32") - - -@tvm.testing.uses_gpu -def test_tensor_array_concat(): - def run(dtype): - mod = tvm.IRModule() - p = Prelude(mod) - v1 = relay.var("v1") - v2 = relay.var("v2") - tensor_array = p.get_var("tensor_array", dtype) - tensor_array1 = tensor_array(relay.const(2)) - write_func = p.get_var("tensor_array_write", dtype) - concat_func = p.get_var("tensor_array_concat", dtype) - tensor1 = p.get_var("tensor2", dtype) - tensor_array1 = write_func(tensor_array1, relay.const(0), tensor1(v1)) - tensor_array1 = write_func(tensor_array1, relay.const(1), tensor1(v2)) - tensor_array_concat = concat_func(tensor_array1) - mod["main"] = relay.Function([v1, v2], tensor_array_concat) - v1_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) - v2_data = np.random.uniform(low=0.0, high=8.0, size=(1, 3)).astype(dtype) - expected = [np.concatenate((v1_data, v2_data), axis=0)] - check_tensor_array(mod, expected, *(v1_data, v2_data), dtype=dtype) - - run("float32") - run("int32") - - -@tvm.testing.uses_gpu -def test_tensor_array_scatter(): - def run(dtype): - mod = tvm.IRModule() - p = Prelude(mod) - - # tensor array - v1 = relay.var("v1") - v2 = relay.var("v2") - v3 = relay.var("v2") - tensor_array = p.get_var("tensor_array", dtype) - tensor_array1 = tensor_array(relay.const(3)) - write_func = p.get_var("tensor_array_write", dtype) - scatter_func = p.get_var("tensor_array_scatter", dtype) - tensor2 = p.get_var("tensor2", dtype) - tensor_array1 = write_func(tensor_array1, relay.const(0), tensor2(v1)) - tensor_array1 = write_func(tensor_array1, relay.const(1), tensor2(v2)) - tensor_array1 = write_func(tensor_array1, relay.const(2), tensor2(v3)) - - # indices array - index = relay.var("index") - - # values array - value_0 = relay.var("value_0") - value_1 = relay.var("value_1") - values_array = tensor_array(relay.const(2)) - values_array = write_func(values_array, relay.const(0), tensor2(value_0)) - values_array = write_func(values_array, relay.const(1), tensor2(value_1)) - - # create the scatter function - tensor_array_scatter = scatter_func(tensor_array1, index, values_array) - mod["main"] = relay.Function([v1, v2, v3, index, value_0, value_1], tensor_array_scatter) - - # initialize and check - v1_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) - v2_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) - v3_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) - index_data = np.array([0, 1], dtype="int32") - val1_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) - val2_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) - expected = [val1_data, val2_data, v3_data] - check_tensor_array( - mod, - expected, - *(v1_data, v2_data, v3_data, index_data, val1_data, val2_data), - dtype=dtype, - ) - - run("float32") - run("int32") - - -@tvm.testing.uses_gpu -def test_tensor_array_split(): - def run(dtype): - mod = tvm.IRModule() - p = Prelude(mod) - - # tensor array - v1 = relay.var("v1") - v2 = relay.var("v2") - v3 = relay.var("v2") - tensor_array = p.get_var("tensor_array", dtype) - tensor_array1 = tensor_array(relay.const(3)) - write_func = p.get_var("tensor_array_write", dtype) - split_func = p.get_var("tensor_array_split", dtype) - tensor2 = p.get_var("tensor2", dtype) - tensor_array1 = write_func(tensor_array1, relay.const(0), tensor2(v1)) - tensor_array1 = write_func(tensor_array1, relay.const(1), tensor2(v2)) - tensor_array1 = write_func(tensor_array1, relay.const(2), tensor2(v3)) - - # value tensor - value = relay.var("value") - - # lengths tensor - ta_len = relay.var("length") - - # create the scatter function - tensor_array_split = split_func(tensor_array1, tensor2(value), ta_len) - mod["main"] = relay.Function([v1, v2, v3, value, ta_len], tensor_array_split) - - # initialize and check - v1_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) - v2_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) - v3_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) - value_data = np.random.uniform(low=0.0, high=8.0, size=(4, 3)).astype(dtype) - length_data = np.array([2, 2], dtype="int32") - expected = np.concatenate([value_data, v3_data]) - expected = np.split(expected, indices_or_sections=[2, 4]) - check_tensor_array( - mod, expected, *(v1_data, v2_data, v3_data, value_data, length_data), dtype=dtype - ) - - run("float32") - run("int32") - - -@tvm.testing.uses_gpu -def test_static_tensor_take(): - def run(dtype, shape): - mod = tvm.IRModule() - p = Prelude(mod) - static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) - static_tensor_array_ops.register() - - take = p.get_var_static("tensor_take", dtype, shape) - tensor_constructor = p.get_var_static("tensor_constructor", dtype, shape) - v = relay.var("v") - lower = relay.var("lower") - upper = relay.var("upper") - mod["main"] = relay.Function([v, lower, upper], take(tensor_constructor(v), lower, upper)) - v_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) - expected = [np.take(v_data, range(2, 5), axis=0)] - check_tensor_array(mod, expected, *(v_data, 2, 5), dtype=dtype) - expected = [np.take(v_data, range(0, 9), axis=0)] - check_tensor_array(mod, expected, *(v_data, 0, 9), dtype=dtype) - - run("float32", [10, 10]) - run("int32", [15, 11]) - - -@tvm.testing.uses_gpu -def test_static_tensor_concatenate(): - def run(dtype, shape): - mod = tvm.IRModule() - p = Prelude(mod) - static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) - static_tensor_array_ops.register() - - concat = p.get_var_static("tensor_concatenate", dtype, shape) - tensor = p.get_var_static("tensor_constructor", dtype, shape) - v1 = relay.var("v1") - v2 = relay.var("v2") - mod["main"] = relay.Function([v1, v2], concat(tensor(v1), tensor(v2))) - v1_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) - v2_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) - expected = [np.concatenate((v1_data, v2_data))] - check_tensor_array(mod, expected, *(v1_data, v2_data), dtype=dtype) - - run( - "float32", - [ - 5, - ], - ) - run("int32", [2, 3]) - - -@tvm.testing.uses_gpu -def test_static_tensor_expand_dims(): - def run(dtype, shape): - x = relay.var("x") - mod = tvm.IRModule() - p = Prelude(mod) - static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) - static_tensor_array_ops.register() - - expand_dims_func = p.get_var_static("tensor_expand_dims", dtype, shape) - tensor = p.get_var_static("tensor_constructor", dtype, shape) - mod["main"] = relay.Function([x], expand_dims_func(tensor(x))) - x_np = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) - expected = [np.expand_dims(x_np, axis=0)] - check_tensor_array(mod, expected, x_np) - - run("float32", []) - run( - "int32", - [ - 2, - ], - ) - - -@tvm.testing.uses_gpu -def test_static_tensor_array_constructor(): - def run(dtype, shape): - mod = tvm.IRModule() - p = Prelude(mod) - static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) - static_tensor_array_ops.register() - tensor_constructor = p.get_name_static("tensor_constructor", dtype, shape) - assert tensor_constructor != None - - run("float32", [1, 1]) - - -@tvm.testing.uses_gpu -def test_static_tensor_array_read(): - def run(dtype, shape): - mod = tvm.IRModule() - p = Prelude(mod) - static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) - static_tensor_array_ops.register() - - np_data_list = [] - ta_length = 3 - for _ in range(ta_length): - np_data_list.append(np.random.uniform(0, 10, size=shape).astype(dtype)) - - v0 = relay.var("v0") - v1 = relay.var("v1") - v2 = relay.var("v2") - n = relay.var("n") - tensor = p.get_var_static("tensor_constructor", dtype, shape) - tensor_array = p.get_var_static("tensor_array", dtype, shape) - init_tensor_array = tensor_array(relay.const(ta_length)) - read_func = p.get_var_static("tensor_array_read", dtype, shape) - write_func = p.get_var_static("tensor_array_write", dtype, shape) - tensor_array0 = write_func(init_tensor_array, relay.const(0), tensor(v0)) - tensor_array1 = write_func(tensor_array0, relay.const(1), tensor(v1)) - tensor_array2 = write_func(tensor_array1, relay.const(2), tensor(v2)) - - mod["main"] = relay.Function([v0, v1, v2, n], read_func(tensor_array2, n)) - expected = [np_data_list[0]] - check_tensor_array(mod, expected, *list(np_data_list + [0]), dtype=dtype) - expected = [np_data_list[1]] - check_tensor_array(mod, expected, *list(np_data_list + [1]), dtype=dtype) - expected = [np_data_list[2]] - check_tensor_array(mod, expected, *list(np_data_list + [2]), dtype=dtype) - - run("float32", []) - run("int32", [2, 3]) - - -@tvm.testing.uses_gpu -def test_static_tensor_array_write(): - def run(dtype, shape): - mod = tvm.IRModule() - p = Prelude(mod) - static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) - static_tensor_array_ops.register() - - ta_length = 2 - np_data_list = [ - np.random.uniform(0, 10, size=shape).astype(dtype) for _ in range(ta_length) - ] - - v0 = relay.var("v0") - v1 = relay.var("v1") - tensor_array = p.get_var_static("tensor_array", dtype, shape) - init_tensor_array = tensor_array(relay.const(ta_length)) - write_func = p.get_var_static("tensor_array_write", dtype, shape) - tensor = p.get_var_static("tensor_constructor", dtype, shape) - tensor_array0 = write_func(init_tensor_array, relay.const(0), tensor(v0)) - tensor_array1 = write_func(tensor_array0, relay.const(1), tensor(v1)) - mod["main"] = relay.Function([v0, v1], tensor_array1) - expected = np_data_list - check_tensor_array(mod, expected, *np_data_list, dtype=dtype) - - run("float32", []) - run("int32", [2, 3]) - - -@tvm.testing.uses_gpu -def test_static_tensor_array_unstack(): - def run(dtype, shape): - mod = tvm.IRModule() - p = Prelude(mod) - static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) - static_tensor_array_ops.register() - - unstack_tensor = p.get_var_static("tensor_array_unstack", dtype, shape) - v = relay.var("v") - mod["main"] = relay.Function([v], unstack_tensor(v)) - t = np.random.uniform(low=0, high=10, size=shape).astype(dtype) - (*expected,) = t - check_tensor_array(mod, expected, t, dtype=dtype) - - run("float32", [4]) - run("int32", [2, 3]) - - -@tvm.testing.uses_gpu -def test_static_tensor_array_scatter(): - def run(dtype, shape, indices_shape=None): - mod = tvm.IRModule() - p = Prelude(mod) - static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) - static_tensor_array_ops.register() - if indices_shape is not None: - static_tensor_array_ops.define_tensor_array_scatter(indices_shape, True) - - # tensor array - v1 = relay.var("v1") - v2 = relay.var("v2") - v3 = relay.var("v2") - tensor_array = p.get_var_static("tensor_array", dtype, shape) - tensor_array0 = tensor_array(relay.const(3)) - write_func = p.get_var_static("tensor_array_write", dtype, shape) - scatter_func = p.get_var_static("tensor_array_scatter", dtype, shape) - tensor = p.get_var_static("tensor_constructor", dtype, shape) - tensor_array1 = write_func(tensor_array0, relay.const(0), tensor(v1)) - tensor_array1 = write_func(tensor_array1, relay.const(1), tensor(v2)) - tensor_array1 = write_func(tensor_array1, relay.const(2), tensor(v3)) - - # indices array - index = relay.var("index") - - # values array - value_0 = relay.var("value_0") - value_1 = relay.var("value_1") - values_array = tensor_array(relay.const(2)) - values_array = write_func(values_array, relay.const(0), tensor(value_0)) - values_array = write_func(values_array, relay.const(1), tensor(value_1)) - - # create the scatter function - tensor_array_scatter = scatter_func(tensor_array1, index, values_array) - mod["main"] = relay.Function([v1, v2, v3, index, value_0, value_1], tensor_array_scatter) - - # initialize and check - v1_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) - v2_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) - v3_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) - index_data = np.array([0, 1], dtype="int32") - val1_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) - val2_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) - expected = [val1_data, val2_data, v3_data] - check_tensor_array( - mod, - expected, - *(v1_data, v2_data, v3_data, index_data, val1_data, val2_data), - dtype=dtype, - ) - - run("float32", [2, 3]) - run("int32", [2, 3]) - run( - "float32", - [2, 3], - [ - 2, - ], - ) - - -@tvm.testing.uses_gpu -def test_static_tensor_array_split(): - def run(dtype, shape, value_shape=None, lengths_shape=None): - mod = tvm.IRModule() - p = Prelude(mod) - static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) - static_tensor_array_ops.register() - if value_shape is not None or lengths_shape is not None: - static_tensor_array_ops.define_tensor_array_split(value_shape, lengths_shape, True) - - # tensor array - v1 = relay.var("v1") - v2 = relay.var("v2") - v3 = relay.var("v2") - - adt_shape = [ - relay.Any(), - ] + shape[1:] - origin_shape = static_tensor_array_ops.shape - static_tensor_array_ops.shape = adt_shape - static_tensor_array_ops.define_tensor_array() - tensor_array = p.get_var_static("tensor_array", dtype, adt_shape) - static_tensor_array_ops.shape = origin_shape - tensor_array1 = tensor_array(relay.const(3)) - write_func = p.get_var_static("tensor_array_write", dtype, adt_shape) - split_func = p.get_var_static("tensor_array_split", dtype, shape) - tensor = p.get_var_static("tensor_constructor", dtype, adt_shape) - tensor_array1 = write_func(tensor_array1, relay.const(0), tensor(v1)) - tensor_array1 = write_func(tensor_array1, relay.const(1), tensor(v2)) - tensor_array1 = write_func(tensor_array1, relay.const(2), tensor(v3)) - - # value tensor - value = relay.var("value") - - # lengths tensor - ta_len = relay.var("length") - - # create the split function - if value_shape is None: - tensor1 = p.get_var_static("tensor_constructor", dtype, shape) - else: - static_tensor_array_ops = StaticTensorArrayOps(p, dtype, value_shape) - static_tensor_array_ops.register() - tensor1 = p.get_var_static("tensor_constructor", dtype, value_shape) - tensor_array_split = split_func(tensor_array1, tensor1(value), ta_len) - mod["main"] = relay.Function([v1, v2, v3, value, ta_len], tensor_array_split) - - # initialize and check - v1_data = np.random.uniform(low=0.0, high=8.0, size=[2, 3]).astype(dtype) - v2_data = np.random.uniform(low=0.0, high=8.0, size=[2, 3]).astype(dtype) - v3_data = np.random.uniform(low=0.0, high=8.0, size=[2, 3]).astype(dtype) - value_data = np.random.uniform(low=0.0, high=8.0, size=value_shape or shape).astype(dtype) - length_data = np.array([2, 2], dtype="int32") - expected = np.concatenate([value_data, v3_data]) - expected = np.split(expected, indices_or_sections=[2, 4]) - check_tensor_array( - mod, expected, *(v1_data, v2_data, v3_data, value_data, length_data), dtype=dtype - ) - - run("float32", [4, 3]) - run("int32", [4, 3]) - run( - "int32", - [relay.Any(), 3], - [4, 3], - [ - 2, - ], - ) - - -@tvm.testing.uses_gpu -def test_static_tensor_array_concat(): - def run(dtype, shape): - mod = tvm.IRModule() - p = Prelude(mod) - static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) - static_tensor_array_ops.register() - - v1 = relay.var("v1") - v2 = relay.var("v2") - tensor_array = p.get_var_static("tensor_array", dtype, shape) - tensor_array1 = tensor_array(relay.const(2)) - write_func = p.get_var_static("tensor_array_write", dtype, shape) - concat_func = p.get_var_static("tensor_array_concat", dtype, shape) - tensor = p.get_var_static("tensor_constructor", dtype, shape) - tensor_array1 = write_func(tensor_array1, relay.const(0), tensor(v1)) - tensor_array1 = write_func(tensor_array1, relay.const(1), tensor(v2)) - tensor_array_concat = concat_func(tensor_array1) - mod["main"] = relay.Function([v1, v2], tensor_array_concat) - v1_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) - v2_data = np.random.uniform(low=0.0, high=8.0, size=(1, 3)).astype(dtype) - expected = [np.concatenate((v1_data, v2_data), axis=0)] - check_tensor_array(mod, expected, *(v1_data, v2_data), dtype=dtype) - - run("float32", [relay.Any(), 3]) - run("int32", [relay.Any(), 3]) - - -@tvm.testing.uses_gpu -def test_static_tensor_array_gather(): - def run(dtype, shape): - mod = tvm.IRModule() - p = Prelude(mod) - static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) - static_tensor_array_ops.register() - - tensor_array = p.get_var_static("tensor_array", dtype, shape) - tensor = p.get_var_static("tensor_constructor", dtype, shape) - write = p.get_var_static("tensor_array_write", dtype, shape) - gather = p.get_var_static("tensor_array_gather", dtype, shape) - v = relay.var("v") - indice = relay.var("indice") - init_tensor_array = tensor_array(relay.const(3)) - tensor_array1 = write(init_tensor_array, relay.const(0), tensor(v)) - tensor_array2 = write(tensor_array1, relay.const(1), tensor(v)) - tensor_array3 = write(tensor_array2, relay.const(2), tensor(v)) - out = gather(tensor_array3, indice) - mod["main"] = relay.Function([v, indice], out) - t = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) - indice_data = np.array([0, 2], dtype="int32") - expected = [np.stack([t, t])] - check_tensor_array(mod, expected, *(t, indice_data), dtype=dtype) - - run("float32", []) - run("int32", [2, 3]) - - -@tvm.testing.uses_gpu -def test_static_tensor_array_stack(): - def run(dtype, shape): - mod = tvm.IRModule() - p = Prelude(mod) - static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) - static_tensor_array_ops.register() - - tensor_array = p.get_var_static("tensor_array", dtype, shape) - tensor = p.get_var_static("tensor_constructor", dtype, shape) - write = p.get_var_static("tensor_array_write", dtype, shape) - stack = p.get_var_static("tensor_array_stack", dtype, shape) - v = relay.var("v") - init_tensor_array = tensor_array(relay.const(3)) - tensor_array1 = write(init_tensor_array, relay.const(0), tensor(v)) - tensor_array2 = write(tensor_array1, relay.const(1), tensor(v)) - tensor_array3 = write(tensor_array2, relay.const(2), tensor(v)) - tensor_array4 = stack(tensor_array3) - mod["main"] = relay.Function([v], tensor_array4) - t = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) - expected = [np.stack([t, t, t])] - check_tensor_array(mod, expected, t, dtype=dtype) - - run("float32", []) - run("int32", [2, 3]) - - -@tvm.testing.uses_gpu -def test_static_tensor_get_data(): - def run(dtype, shape): - mod = tvm.IRModule() - p = Prelude(mod) - static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) - static_tensor_array_ops.register() - - np_data_list = [] - ta_length = 3 - for _ in range(ta_length): - np_data_list.append(np.random.uniform(0, 10, size=shape).astype(dtype)) - - v0 = relay.var("v0") - v1 = relay.var("v1") - v2 = relay.var("v2") - n = relay.var("n") - tensor = p.get_var_static("tensor_constructor", dtype, shape) - tensor_array = p.get_var_static("tensor_array", dtype, shape) - init_tensor_array = tensor_array(relay.const(ta_length)) - read_func = p.get_var_static("tensor_array_read", dtype, shape) - write_func = p.get_var_static("tensor_array_write", dtype, shape) - get_data_func = p.get_var_static("tensor_get_data", dtype, shape) - tensor_array0 = write_func(init_tensor_array, relay.const(0), tensor(v0)) - tensor_array1 = write_func(tensor_array0, relay.const(1), tensor(v1)) - tensor_array2 = write_func(tensor_array1, relay.const(2), tensor(v2)) - - mod["main"] = relay.Function([v0, v1, v2, n], get_data_func(read_func(tensor_array2, n))) - expected = [np_data_list[0]] - check_tensor_array(mod, expected, *list(np_data_list + [0]), dtype=dtype) - expected = [np_data_list[1]] - check_tensor_array(mod, expected, *list(np_data_list + [1]), dtype=dtype) - expected = [np_data_list[2]] - check_tensor_array(mod, expected, *list(np_data_list + [2]), dtype=dtype) - - run("float32", []) - run("int32", [2, 3]) - - if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_analysis_get_calibration_data.py b/tests/python/relay/test_analysis_get_calibration_data.py index 72c1c8181af6d..66500f84db2c5 100644 --- a/tests/python/relay/test_analysis_get_calibration_data.py +++ b/tests/python/relay/test_analysis_get_calibration_data.py @@ -48,6 +48,7 @@ def test_simple_graph(): f0 = f0.with_attr("Compiler", "test_graph") g0 = relay.GlobalVar("g0") mod[g0] = f0 + mod = relay.transform.InferType()(mod) x1 = relay.var("x1", shape=(8, 8)) y1 = relay.var("y1", shape=(8, 8)) @@ -56,6 +57,7 @@ def test_simple_graph(): f1 = f1.with_attr("Compiler", "test_graph") g1 = relay.GlobalVar("g1") mod[g1] = f1 + mod = relay.transform.InferType()(mod) x = relay.var("x", shape=(8, 8)) y = relay.var("y", shape=(8, 8)) @@ -64,6 +66,7 @@ def test_simple_graph(): c1 = relay.Call(g1, [relay.TupleGetItem(c0, 0), z]) fm = relay.Function([x, y, z], c1) mod["main"] = fm + mod = relay.transform.InferType()(mod) x_data = np.random.rand(8, 8).astype("float32") y_data = np.random.rand(8, 8).astype("float32") diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index b9b58a5a491a2..c445cd1944009 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -22,6 +22,7 @@ from tvm import relay from tvm.relay.loops import while_loop from tvm.relay.testing import run_infer_type as infer_type +from util.assert_diagnostic import DiagnosticTesting import tvm.topi.testing @@ -985,11 +986,9 @@ def _body(i, st): start = relay.var("start", shape=(), dtype="int32") body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1))) func = relay.Function([start], relay.TupleGetItem(body, 1)) - try: + with DiagnosticTesting() as diagnostics: + diagnostics.assert_message("in particular dimension 0 conflicts 2 does not match 1") func = infer_type(func) - assert False - except Exception as e: - assert "in particular dimension 0 conflicts 2 does not match 1" in str(e) @tvm.testing.uses_gpu diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py index 5550caa93d3cf..1bd551004ad7e 100644 --- a/tests/python/relay/test_backend_graph_runtime.py +++ b/tests/python/relay/test_backend_graph_runtime.py @@ -125,8 +125,10 @@ def test_plan_memory(): z = relay.exp(z) func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) mod = relay.transform.FuseOps(0)(mod) func = mod["main"] + mod = relay.transform.InferType()(mod) smap = relay.backend._backend.GraphPlanMemory(func) storage_ids = set() device_types = set() diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index 08a840148276b..0beb93deaef5e 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -170,12 +170,13 @@ def test_function_taking_adt_ref_tuple(): mod = tvm.IRModule() prelude = relay.prelude.Prelude(mod) intrp = create_executor("debug", mod) + _, cons, nil = prelude.mod.get_type("List") - nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil) + nil_value = ConstructorValue(nil.tag, [], nil) cons_value = ConstructorValue( - prelude.cons.tag, + cons.tag, [nd.array(np.random.rand(1, 10).astype("float32")), nil_value], - prelude.cons, + cons, ) ref_value = RefValue(nd.array(np.random.rand(1, 10).astype("float32"))) @@ -194,7 +195,7 @@ def test_function_taking_adt_ref_tuple(): assert len(res_cons.fields) == len(cons_value.fields) tvm.testing.assert_allclose(res_cons.fields[0].asnumpy(), cons_value.fields[0].asnumpy()) assert isinstance(res_cons.fields[1], ConstructorValue) - assert res_cons.fields[1].tag == prelude.nil.tag + assert res_cons.fields[1].tag == nil.tag assert len(res_cons.fields[1].fields) == 0 res_ref = id_func(ref_value) diff --git a/tests/python/relay/test_error_reporting.py b/tests/python/relay/test_error_reporting.py deleted file mode 100644 index fc5c743cf290d..0000000000000 --- a/tests/python/relay/test_error_reporting.py +++ /dev/null @@ -1,68 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -from tvm import te -from tvm import relay - - -def check_type_err(expr, msg): - try: - mod = tvm.IRModule.from_expr(expr) - mod = relay.transform.InferType()(mod) - entry = mod["main"] - expr = entry if isinstance(expr, relay.Function) else entry.body - assert False - except tvm.error.TVMError as err: - assert msg in str(err) - - -def test_wellformed(): - x = relay.var("x", shape=(10, 10)) - f = relay.Function([x], x) - check_type_err(f(x), "Check failed: WellFormed") - - -def test_too_many_args(): - x = relay.var("x", shape=(10, 10)) - f = relay.Function([x], x) - y = relay.var("y", shape=(10, 10)) - check_type_err(f(y, y), "the function is provided too many arguments expected 1, found 2;") - - -def test_too_few_args(): - x = relay.var("x", shape=(10, 10)) - y = relay.var("y", shape=(10, 10)) - z = relay.var("z", shape=(10, 10)) - f = relay.Function([x, y], x) - check_type_err(f(z), "the function is provided too few arguments expected 2, found 1;") - - -def test_rel_fail(): - x = relay.var("x", shape=(10, 10)) - y = relay.var("y", shape=(11, 10)) - f = relay.Function([x, y], x + y) - check_type_err( - f, - "Incompatible broadcast type TensorType([10, 10], float32) and TensorType([11, 10], float32);", - ) - - -if __name__ == "__main__": - test_wellformed() - test_too_many_args() - test_too_few_args() - test_rel_fail() diff --git a/tests/python/relay/test_expr_functor.py b/tests/python/relay/test_expr_functor.py index f8ae7a91fc025..45317836faf07 100644 --- a/tests/python/relay/test_expr_functor.py +++ b/tests/python/relay/test_expr_functor.py @@ -124,8 +124,9 @@ def test_match(): def test_match_completeness(): p = relay.prelude.Prelude() + _, _, nil = p.mod.get_type("List") for completeness in [True, False]: - match_expr = relay.adt.Match(p.nil, [], complete=completeness) + match_expr = relay.adt.Match(nil, [], complete=completeness) result_expr = ExprMutator().visit(match_expr) # ensure the mutator doesn't mangle the completeness flag assert result_expr.complete == completeness diff --git a/tests/python/relay/test_ir_module.py b/tests/python/relay/test_ir_module.py index d3f8f2ccf84fa..c87ca19f117f3 100644 --- a/tests/python/relay/test_ir_module.py +++ b/tests/python/relay/test_ir_module.py @@ -19,24 +19,32 @@ from tvm import te from tvm import relay from tvm.relay.prelude import Prelude -from tvm.relay.testing import add_nat_definitions def constructor_list(p): - return [p.nil, p.cons, p.rose, p.some, p.none, p.z, p.s] + list_ctors = p.mod.get_type("List") + optional_ctors = p.mod.get_type("Option") + nat_ctors = p.mod.get_type("nat") + rose_ctors = p.mod.get_type("Tree") + return list_ctors[1:] + optional_ctors[1:] + nat_ctors[1:] + rose_ctors[1:] def adt_list(p): - return [p.nat, p.l, p.optional, p.tree] + list_ctors = p.mod.get_type("List") + optional_ctors = p.mod.get_type("Option") + nat_ctors = p.mod.get_type("nat") + rose_ctors = p.mod.get_type("Tree") + return list_ctors[:1] + optional_ctors[:1] + nat_ctors[:1] + rose_ctors[:1] def test_constructor_tag_round_trip(): mod1 = tvm.IRModule() p1 = Prelude(mod1) - add_nat_definitions(p1) + p1.mod.import_from_std("nat.rly") + mod2 = tvm.IRModule() p2 = Prelude(mod2) - add_nat_definitions(p2) + p2.mod.import_from_std("nat.rly") # ensure hashes match across modules ctors1 = constructor_list(p1) @@ -55,7 +63,7 @@ def test_constructor_tag_differences(): # each other mod = tvm.IRModule() p = Prelude(mod) - add_nat_definitions(p) + p.mod.import_from_std("nat.rly") adts = adt_list(p) for adt in adts: diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 2ac98d2066617..c5217ba41bfdb 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -111,6 +111,7 @@ def assert_parses_as(code, expr): def assert_parse_module_as(code, mod): + mod = tvm.relay.transform.InferType()(mod) parsed = parse_module(code) assert_graph_equal(parsed, mod) @@ -295,6 +296,14 @@ def test_tuple(): assert_parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1), relay.const(2)])) +def test_tuple_proj(): + x = relay.var("x", shape=()) + assert_parses_as( + "free_var %x: float32; %x((%x,).0, %x)", + relay.Call(x, [relay.TupleGetItem(relay.Tuple([x]), 0), x]), + ) + + def test_func(): # 0 args assert_parses_as("fn () { 0 }", relay.Function([], relay.const(0), None, [])) @@ -367,6 +376,18 @@ def test_ifelse_scope(): ) +def test_ref(): + program = """ + #[version = "0.0.5"] + def @main(%x: float32) { + %0 = ref(%x); + ref_write(%0, 1f); + ref_read(%0) + } + """ + tvm.parser.parse(program) + + def test_call(): # select right function to call: simple ident case id_func = relay.Var("id") @@ -825,18 +846,35 @@ def inline_params(mod, params): body = relay.bind(main_fn.body, bind_map) main_fn = relay.Function(relay.analysis.free_vars(body), body) - mod["main_fn"] = main_fn + mod._add("main", main_fn, True) return mod def test_resnet_inlined_params(): mod, params = relay.testing.resnet.get_workload() mod = inline_params(mod, params) + mod = relay.transform.InferType()(mod) text = mod.astext() parsed_mod = tvm.parser.parse(text) tvm.ir.assert_structural_equal(mod, parsed_mod) +def test_tuple_return_value(): + program = """ + type Box[T] { + constructor(T) + } + + def @example() { + %0 = (); + %1 = constructor(%0); + %2 = constructor(0f); + (%1, %2,) + } + """ + parse_module(program) + + def test_op_string_attr(): call = parse_text( """ @@ -845,12 +883,19 @@ def test_op_string_attr(): nn.conv2d(%x, %y, data_layout="NHWC", kernel_layout="HWIO") """ ) + assert isinstance(call.op, tvm.ir.Op) assert call.op.name == "nn.conv2d" assert call.attrs.data_layout == "NHWC" assert call.attrs.kernel_layout == "HWIO" +def test_load_prelude(): + mod = tvm.IRModule() + mod.import_from_std("prelude.rly") + tvm.parser.parse(mod.astext()) + + if __name__ == "__main__": import sys diff --git a/tests/python/relay/test_ir_structural_equal_hash.py b/tests/python/relay/test_ir_structural_equal_hash.py index 65919e051831c..eec5a7b5f126c 100644 --- a/tests/python/relay/test_ir_structural_equal_hash.py +++ b/tests/python/relay/test_ir_structural_equal_hash.py @@ -589,63 +589,62 @@ def test_constructor_sequal(): # smoke test: it should be pointer equality mod = tvm.IRModule() p = relay.prelude.Prelude(mod) + _, cons, nil = p.mod.get_type("List") - assert consistent_equal(p.nil, p.nil) - assert consistent_equal(p.cons, p.cons) - assert not consistent_equal(p.nil, p.cons) + assert consistent_equal(nil, nil) + assert consistent_equal(cons, cons) + assert not consistent_equal(nil, cons) def test_match_sequal(): mod = tvm.IRModule() p = relay.prelude.Prelude(mod) + _, cons, nil = p.mod.get_type("List") + _, none, some = p.mod.get_type("Option") x = relay.Var("x") y = relay.Var("y") - nil_case = relay.Clause(relay.PatternConstructor(p.nil), p.nil()) + nil_case = relay.Clause(relay.PatternConstructor(nil), nil()) cons_case = relay.Clause( - relay.PatternConstructor(p.cons, [relay.PatternVar(x), relay.PatternVar(y)]), p.cons(x, y) + relay.PatternConstructor(cons, [relay.PatternVar(x), relay.PatternVar(y)]), cons(x, y) ) z = relay.Var("z") a = relay.Var("a") equivalent_cons = relay.Clause( - relay.PatternConstructor(p.cons, [relay.PatternVar(z), relay.PatternVar(a)]), p.cons(z, a) + relay.PatternConstructor(cons, [relay.PatternVar(z), relay.PatternVar(a)]), cons(z, a) ) - data = p.cons(relay.const(1), p.cons(relay.const(2), p.nil())) + data = cons(relay.const(1), cons(relay.const(2), nil())) match = relay.Match(data, [nil_case, cons_case]) equivalent = relay.Match(data, [nil_case, equivalent_cons]) empty = relay.Match(data, []) no_cons = relay.Match(data, [nil_case]) no_nil = relay.Match(data, [cons_case]) - different_data = relay.Match(p.nil(), [nil_case, cons_case]) + different_data = relay.Match(nil(), [nil_case, cons_case]) different_order = relay.Match(data, [cons_case, nil_case]) different_nil = relay.Match( - data, [relay.Clause(relay.PatternConstructor(p.nil), p.cons(p.nil(), p.nil())), cons_case] + data, [relay.Clause(relay.PatternConstructor(nil), cons(nil(), nil())), cons_case] ) different_cons = relay.Match( data, [ nil_case, relay.Clause( - relay.PatternConstructor( - p.cons, [relay.PatternWildcard(), relay.PatternWildcard()] - ), - p.nil(), + relay.PatternConstructor(cons, [relay.PatternWildcard(), relay.PatternWildcard()]), + nil(), ), ], ) another_case = relay.Match( - data, [nil_case, cons_case, relay.Clause(relay.PatternWildcard(), p.nil())] + data, [nil_case, cons_case, relay.Clause(relay.PatternWildcard(), nil())] ) wrong_constructors = relay.Match( data, [ - relay.Clause(relay.PatternConstructor(p.none), p.nil()), - relay.Clause( - relay.PatternConstructor(p.some, [relay.PatternVar(x)]), p.cons(x, p.nil()) - ), + relay.Clause(relay.PatternConstructor(none), nil()), + relay.Clause(relay.PatternConstructor(some, [relay.PatternVar(x)]), cons(x, nil())), ], ) diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index fd0853ac6d09d..6c2f7166f4463 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -57,20 +57,21 @@ def test_func(): show(astext(f)) -def test_env(): +def test_mod(): x = relay.var("x", "float32") y = relay.var("y", "float32") z = relay.add(x, y) z = relay.add(z, z) f = relay.Function([x, y], z) - env = tvm.IRModule() - env["myf"] = f - text = astext(env) + mod = tvm.IRModule() + mod["myf"] = f + mod = relay.transform.InferType()(mod) + text = astext(mod) assert "def @myf" in text - assert "def @myf" in str(env) + assert "def @myf" in str(mod) assert "add(%0, %0) /* ty=float32 */" in text - assert "add(%0, %0) /* ty=float32 */" in str(env) - show(env.astext(annotate=lambda x: str(x.checked_type.dtype) if type(x) == relay.Call else "")) + assert "add(%0, %0) /* ty=float32 */" in str(mod) + show(mod.astext(annotate=lambda x: str(x.checked_type.dtype) if type(x) == relay.Call else "")) show(text) diff --git a/tests/python/relay/test_ir_well_formed.py b/tests/python/relay/test_ir_well_formed.py index 53333d1131c67..44750ad0643e0 100644 --- a/tests/python/relay/test_ir_well_formed.py +++ b/tests/python/relay/test_ir_well_formed.py @@ -52,11 +52,12 @@ def test_tuple_get_item(): def test_adt(): mod = tvm.IRModule() p = Prelude(mod) + _, none, some = p.mod.get_type("Option") x = relay.Var("x") - some_case = relay.Clause(relay.PatternConstructor(p.some, [relay.PatternVar(x)]), x) + some_case = relay.Clause(relay.PatternConstructor(some, [relay.PatternVar(x)]), x) default_case = relay.Clause(relay.PatternVar(x), x) - m0 = relay.Match(p.none(), [default_case]) - m1 = relay.Match(p.none(), [some_case, default_case]) + m0 = relay.Match(none(), [default_case]) + m1 = relay.Match(none(), [some_case, default_case]) assert well_formed(m0) assert not well_formed(m1) diff --git a/tests/python/relay/test_json_runtime.py b/tests/python/relay/test_json_runtime.py index ef567161063ba..c09dab34be1e9 100644 --- a/tests/python/relay/test_json_runtime.py +++ b/tests/python/relay/test_json_runtime.py @@ -108,11 +108,13 @@ def conv2d_direct(): glb_var = relay.GlobalVar("dnnl_0") mod = tvm.IRModule() mod[glb_var] = func + mod = transform.InferType()(mod) data = relay.var("data", shape=(ishape), dtype=dtype) weight = relay.var("weight", shape=(w1shape), dtype=dtype) main_f = relay.Function([data, weight], glb_var(data, weight)) mod["main"] = main_f + mod = transform.InferType()(mod) data0 = relay.var("data", shape=ishape, dtype=dtype) weight0 = relay.var("weight", shape=w1shape, dtype=dtype) @@ -120,6 +122,7 @@ def conv2d_direct(): main_f = relay.Function([data0, weight0], out) ref_mod = tvm.IRModule() ref_mod["main"] = main_f + ref_mod = transform.InferType()(ref_mod) i_data = np.random.uniform(0, 1, ishape).astype(dtype) w1_data = np.random.uniform(0, 1, w1shape).astype(dtype) @@ -140,11 +143,13 @@ def group_conv2d(): glb_var = relay.GlobalVar("dnnl_0") mod = tvm.IRModule() mod[glb_var] = func + mod = transform.InferType()(mod) data = relay.var("data", shape=(ishape), dtype=dtype) weight = relay.var("weight", shape=(w2shape), dtype=dtype) main_f = relay.Function([data, weight], glb_var(data, weight)) mod["main"] = main_f + mod = transform.InferType()(mod) data0 = relay.var("data", shape=(ishape), dtype=dtype) weight0 = relay.var("weight", shape=(w2shape), dtype=dtype) @@ -152,6 +157,7 @@ def group_conv2d(): main_f = relay.Function([data0, weight0], out) ref_mod = tvm.IRModule() ref_mod["main"] = main_f + ref_mod = transform.InferType()(ref_mod) i_data = np.random.uniform(0, 1, ishape).astype(dtype) w_data = np.random.uniform(0, 1, w2shape).astype(dtype) @@ -181,11 +187,13 @@ def gen_add(): glb_var = relay.GlobalVar("dnnl_0") mod = tvm.IRModule() mod[glb_var] = func + mod = transform.InferType()(mod) data0 = relay.var("data0", shape=shape, dtype=dtype) data1 = relay.var("data1", shape=shape, dtype=dtype) main_f = relay.Function([data0, data1], glb_var(data0, data1)) mod["main"] = main_f + mod = transform.InferType()(mod) data0 = relay.var("data0", shape=shape, dtype=dtype) data1 = relay.var("data1", shape=shape, dtype=dtype) @@ -193,6 +201,7 @@ def gen_add(): main_f = relay.Function([data0, data1], out) ref_mod = tvm.IRModule() ref_mod["main"] = main_f + ref_mod = transform.InferType()(ref_mod) return mod, ref_mod @@ -221,16 +230,19 @@ def gen_relu(): glb_var = relay.GlobalVar("dnnl_0") mod = tvm.IRModule() mod[glb_var] = func + mod = transform.InferType()(mod) data0 = relay.var("data0", shape=shape, dtype=dtype) main_f = relay.Function([data0], glb_var(data0)) mod["main"] = main_f + mod = transform.InferType()(mod) data0 = relay.var("data0", shape=shape, dtype=dtype) out = relay.nn.relu(data0) main_f = relay.Function([data0], out) ref_mod = tvm.IRModule() ref_mod["main"] = main_f + ref_mod = transform.InferType()(ref_mod) return mod, ref_mod @@ -268,11 +280,13 @@ def gen_dense(): glb_var = relay.GlobalVar("dnnl_0") mod = tvm.IRModule() mod[glb_var] = func + mod = transform.InferType()(mod) a = relay.var("A", shape=a_shape, dtype=dtype) b = relay.var("B", shape=b_shape, dtype=dtype) main_f = relay.Function([a, b], glb_var(a, b)) mod["main"] = main_f + mod = transform.InferType()(mod) a = relay.var("A", shape=a_shape, dtype=dtype) b = relay.var("B", shape=b_shape, dtype=dtype) @@ -280,6 +294,7 @@ def gen_dense(): main_f = relay.Function([a, b], out) ref_mod = tvm.IRModule() ref_mod["main"] = main_f + ref_mod = transform.InferType()(ref_mod) return mod, ref_mod @@ -314,6 +329,7 @@ def gen_bn(): glb_var = relay.GlobalVar("dnnl_0") mod = tvm.IRModule() mod[glb_var] = func + mod = transform.InferType()(mod) data = relay.var("data", shape=d_shape) gamma = relay.var("gamma", shape=c_shape) @@ -325,6 +341,7 @@ def gen_bn(): glb_var(data, gamma, beta, moving_mean, moving_var), ) mod["main"] = main_f + mod = transform.InferType()(mod) data = relay.var("data", shape=d_shape) gamma = relay.var("gamma", shape=c_shape) @@ -336,6 +353,7 @@ def gen_bn(): main_f = relay.Function([data, gamma, beta, moving_mean, moving_var], out) ref_mod = tvm.IRModule() ref_mod["main"] = main_f + ref_mod = transform.InferType()(ref_mod) return mod, ref_mod @@ -457,12 +475,14 @@ def conv2d_relu(): glb_var = relay.GlobalVar("dnnl_0") mod = tvm.IRModule() mod[glb_var] = p_func + mod = transform.InferType()(mod) # Main function data = relay.var("data", shape=ishape, dtype=dtype) weight = relay.var("weight", shape=w1shape, dtype=dtype) main_func = relay.Function([data, weight], glb_var(data, weight)) mod["main"] = main_func + mod = transform.InferType()(mod) # Reference module data = relay.var("data", shape=ishape, dtype=dtype) @@ -472,6 +492,7 @@ def conv2d_relu(): main_func = relay.Function([data, weight], relu) ref_mod = tvm.IRModule() ref_mod["main"] = main_func + ref_mod = transform.InferType()(ref_mod) i_data = np.random.uniform(0, 1, ishape).astype(dtype) w1_data = np.random.uniform(0, 1, w1shape).astype(dtype) @@ -504,6 +525,7 @@ def conv2d_bias_relu(): glb_var = relay.GlobalVar("dnnl_0") mod = tvm.IRModule() mod[glb_var] = p_func + mod = transform.InferType()(mod) # Main function data = relay.var("data", shape=ishape, dtype=dtype) @@ -511,6 +533,7 @@ def conv2d_bias_relu(): bias = relay.var("bias", shape=bshape, dtype=dtype) main_func = relay.Function([data, weight, bias], glb_var(data, weight, bias)) mod["main"] = main_func + mod = transform.InferType()(mod) # Reference module data = relay.var("data", shape=ishape, dtype=dtype) @@ -522,6 +545,7 @@ def conv2d_bias_relu(): main_func = relay.Function([data, weight, bias], relu) ref_mod = tvm.IRModule() ref_mod["main"] = main_func + ref_mod = transform.InferType()(ref_mod) i_data = np.random.uniform(0, 1, ishape).astype(dtype) w1_data = np.random.uniform(0, 1, w1shape).astype(dtype) diff --git a/tests/python/relay/test_op_qnn_add.py b/tests/python/relay/test_op_qnn_add.py index 41ebfea78bfc3..6f33a7bb0b51e 100644 --- a/tests/python/relay/test_op_qnn_add.py +++ b/tests/python/relay/test_op_qnn_add.py @@ -38,6 +38,7 @@ def test_tflite_same_io_qnn_params(): func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -85,6 +86,7 @@ def test_tflite_different_io_qnn_params(): func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -132,8 +134,10 @@ def test_saturation(): func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] + mod = relay.transform.InferType()(mod) x_data = np.array((255, 1, 1, 0)).reshape((1, 4)) y_data = np.array((255, 255, 128, 0)).reshape((1, 4)) @@ -157,6 +161,7 @@ def test_saturation(): func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -182,6 +187,7 @@ def test_saturation(): func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -207,6 +213,7 @@ def test_saturation(): func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] diff --git a/tests/python/relay/test_op_qnn_concatenate.py b/tests/python/relay/test_op_qnn_concatenate.py index 230e8a871e2a7..55836dc1ee52d 100644 --- a/tests/python/relay/test_op_qnn_concatenate.py +++ b/tests/python/relay/test_op_qnn_concatenate.py @@ -45,6 +45,7 @@ def test_same_io_qnn_params(): func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -79,6 +80,7 @@ def test_different_io_qnn_params(): func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -113,6 +115,7 @@ def test_few_same_io_qnn_params(): func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -147,6 +150,7 @@ def test_same_i_qnn_params(): func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] diff --git a/tests/python/relay/test_op_qnn_dense.py b/tests/python/relay/test_op_qnn_dense.py index a05ee3f7a0a20..923940b5382df 100644 --- a/tests/python/relay/test_op_qnn_dense.py +++ b/tests/python/relay/test_op_qnn_dense.py @@ -207,6 +207,7 @@ def qnn_dense_driver(test_configuration): mod = relay.Function(relay.analysis.free_vars(mod), mod) mod = tvm.IRModule.from_expr(mod) + mod = relay.transform.InferType()(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) with tvm.transform.PassContext(opt_level=2): graph, lib, params = relay.build(mod, "llvm", params=None) diff --git a/tests/python/relay/test_op_qnn_mul.py b/tests/python/relay/test_op_qnn_mul.py index 17ec137b2a638..7a846cbf47175 100644 --- a/tests/python/relay/test_op_qnn_mul.py +++ b/tests/python/relay/test_op_qnn_mul.py @@ -57,6 +57,7 @@ def test_tflite_same_io_qnn_params(): func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -110,6 +111,7 @@ def test_tflite_different_io_qnn_params(): func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -158,6 +160,7 @@ def test_saturation(): func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -191,6 +194,7 @@ def test_saturation(): func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -225,6 +229,7 @@ def test_saturation(): func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] diff --git a/tests/python/relay/test_op_qnn_subtract.py b/tests/python/relay/test_op_qnn_subtract.py index 6a1501cf1b31e..a76b05c31564c 100644 --- a/tests/python/relay/test_op_qnn_subtract.py +++ b/tests/python/relay/test_op_qnn_subtract.py @@ -45,6 +45,7 @@ def qnn_subtract_driver(x_datas, y_datas, golden_outputs, scale_and_zp, data_dty ) func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] for i in range(0, len(x_datas)): diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 9df838bfd7175..a7ae9f77fcb71 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -715,6 +715,8 @@ def expected(): mod_new = tvm.IRModule() mod_before["main"] = a mod_new["main"] = b + mod_before = transform.InferType()(mod_before) + mod_new = transform.InferType()(mod_new) with relay.build_config(opt_level=3): for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug", "vm"]: @@ -1171,7 +1173,9 @@ def before(): mod = tvm.IRModule() foo = relay.GlobalVar("foo") mod[foo] = relay.Function([x, weight], y) + mod = transform.InferType()(mod) mod["main"] = relay.Function([x, weight], foo(x, weight)) + mod = transform.InferType()(mod) return mod def alter_conv2d(attrs, inputs, tinfos, out_type): @@ -1193,6 +1197,7 @@ def expected(): mod = tvm.IRModule() foo = relay.GlobalVar("foo") mod[foo] = relay.Function([x, weight], y) + mod = transform.InferType()(mod) mod["main"] = relay.Function([x, weight], foo(x, weight)) return mod diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index b7c43498a69a0..4f355f60c901e 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -127,7 +127,9 @@ def expected(dtype, ishape, w1shape): def test_annotate(): mod = annotated(dtype, ishape, w1shape) mod = transform.AnnotateTarget("dnnl")(mod) + mod = relay.transform.InferType()(mod) ref_mod = expected(dtype, ishape, w1shape) + ref_mod = relay.transform.InferType()(ref_mod) tvm.ir.assert_structural_equal(mod, ref_mod) def test_run(): diff --git a/tests/python/relay/test_pass_combine_parallel_batch_matmul.py b/tests/python/relay/test_pass_combine_parallel_batch_matmul.py index 84fa40a5cfcea..1c09e15e92a53 100644 --- a/tests/python/relay/test_pass_combine_parallel_batch_matmul.py +++ b/tests/python/relay/test_pass_combine_parallel_batch_matmul.py @@ -25,6 +25,7 @@ def run_opt_pass(expr, opt_pass): "runs the opt_pass on the expr of a function the function" assert isinstance(opt_pass, tvm.transform.Pass) mod = tvm.IRModule.from_expr(expr) + mod = tvm.relay.transform.InferType()(mod) mod = opt_pass(mod) return mod["main"] diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py index fbc836c4bad83..b9a5cca85cd2e 100644 --- a/tests/python/relay/test_pass_combine_parallel_conv2d.py +++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py @@ -28,6 +28,7 @@ def run_combine_parallel(expr, min_num_branches=3): def run_opt_pass(expr, opt_pass): assert isinstance(opt_pass, tvm.transform.Pass) mod = tvm.IRModule.from_expr(expr) + mod = tvm.relay.transform.InferType()(mod) mod = opt_pass(mod) return mod["main"] diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py index 7cf8867b4ed1a..a8c9782953bb1 100644 --- a/tests/python/relay/test_pass_combine_parallel_dense.py +++ b/tests/python/relay/test_pass_combine_parallel_dense.py @@ -29,6 +29,7 @@ def run_combine_parallel(expr, min_num_branches=3, to_batch=True): def run_opt_pass(expr, opt_pass): assert isinstance(opt_pass, tvm.transform.Pass) mod = tvm.IRModule.from_expr(expr) + mod = tvm.relay.transform.InferType()(mod) mod = opt_pass(mod) return mod["main"] diff --git a/tests/python/relay/test_pass_eta_expand.py b/tests/python/relay/test_pass_eta_expand.py index bbb2c2be4224d..62cc27d9c94b9 100644 --- a/tests/python/relay/test_pass_eta_expand.py +++ b/tests/python/relay/test_pass_eta_expand.py @@ -68,7 +68,9 @@ def @main[A]() -> fn(A, List[A]) -> List[A] { } """ ) - seq = tvm.transform.Sequential([_transform.EtaExpand(expand_constructor=True)]) + seq = tvm.transform.Sequential( + [_transform.EtaExpand(expand_constructor=True), _transform.InferType()] + ) with tvm.transform.PassContext(opt_level=3): mod = seq(mod) expected = tvm.parser.fromtext( diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index b3ea4225426e5..549596d616936 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -27,6 +27,7 @@ def run_opt_pass(expr, opt_pass): assert isinstance(opt_pass, tvm.transform.Pass) mod = tvm.IRModule.from_expr(expr) + mod = relay.transform.InferType()(mod) mod = opt_pass(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body @@ -182,7 +183,8 @@ def expected(dtype): y = relay.var("y", shape=c_shape, dtype="float32") z = relay.const(np.size(np.zeros(c_shape)), dtype=dtype) func = relay.Function([x, y], z) - return func + mod = tvm.IRModule.from_expr(func) + return mod["main"] for dtype in ["int32", "float32"]: zz = run_opt_pass(before(dtype), transform.FoldConstant()) diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 1d9cfb258531c..ff282df7c832d 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -331,8 +331,14 @@ def expected(dim): assert tvm.ir.structural_equal(zz, after) -fuse0 = relay.transform.FuseOps(fuse_opt_level=0) -fuse2 = relay.transform.FuseOps(fuse_opt_level=2) +def fuse0(mod): + mod = relay.transform.InferType()(mod) + return relay.transform.FuseOps(fuse_opt_level=0)(mod) + + +def fuse2(mod): + mod = relay.transform.InferType()(mod) + return relay.transform.FuseOps(fuse_opt_level=2)(mod) def test_tuple_intermediate(): @@ -550,10 +556,10 @@ def expected(): mod["main"] = relay.Function([x], y) return mod - mod = before() + mod = transform.InferType()(before()) new_mod = transform.FuseOps(fuse_opt_level=2)(mod) - assert tvm.ir.structural_equal(mod, before()) - assert tvm.ir.structural_equal(new_mod, expected()) + assert tvm.ir.structural_equal(mod, transform.InferType()(before())) + assert tvm.ir.structural_equal(new_mod, transform.InferType()(expected())) def test_split(): @@ -565,6 +571,7 @@ def test_split(): c = relay.TupleGetItem(y, 2) mod = tvm.IRModule() mod["main"] = relay.Function([x], a + relay.RefRead(relay.RefCreate(b)) + c) + mod = transform.InferType()(mod) mod = transform.FuseOps()(mod) diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py index 5d792054da296..93bad3a19c53c 100644 --- a/tests/python/relay/test_pass_gradient.py +++ b/tests/python/relay/test_pass_gradient.py @@ -27,7 +27,6 @@ from tvm.relay.transform import gradient from tvm.relay.prelude import Prelude from tvm.relay.testing import ( - add_nat_definitions, make_nat_expr, run_infer_type, check_grad, @@ -267,15 +266,17 @@ def test_tuple_first_order(): def test_pow(): mod = tvm.IRModule() p = Prelude(mod) - add_nat_definitions(p) + p.mod.import_from_std("nat.rly") + nat_iterate = mod.get_global_var("nat_iterate") shape = (10, 10) dtype = "float32" t = relay.TensorType(shape, dtype) x = relay.var("x", t) double = relay.Function([x], x + x) i = relay.var("i", t) - func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i)) + func = relay.Function([i], nat_iterate(double, make_nat_expr(p, 3))(i)) mod["main"] = func + mod = transform.InferType()(mod) mod["main"] = gradient(mod["main"], mod=mod) m = transform.InferType()(mod) back_func = m["main"] @@ -407,7 +408,9 @@ def test_global_function(): q = GlobalVar("q") m[q] = relay.Function([y], d(d(y))) g = GlobalVar("grad") + m = tvm.relay.transform.InferType()(m) m[g] = tvm.relay.transform.gradient(q, m) + m = tvm.relay.transform.InferType()(m) back_func = m[g] assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) ex = create_executor(mod=m) diff --git a/tests/python/relay/test_pass_inline.py b/tests/python/relay/test_pass_inline.py index aea3a38529b05..fb58b9032e5a1 100644 --- a/tests/python/relay/test_pass_inline.py +++ b/tests/python/relay/test_pass_inline.py @@ -507,10 +507,11 @@ def get_mod(): g1 = relay.GlobalVar("g1") g2 = relay.GlobalVar("g2") mod[g1] = fn1 + mod = relay.transform.InferType()(mod) mod[g2] = fn2 p = relay.var("p", "bool") mod["main"] = relay.Function([p], relay.Call(relay.If(p, g1, g2), [])) - return mod + return relay.transform.InferType()(mod) def expected(): mod = tvm.IRModule({}) @@ -520,7 +521,7 @@ def expected(): fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) p = relay.var("p", "bool") mod["main"] = relay.Function([p], relay.Call(relay.If(p, fn1, fn2), [])) - return mod + return relay.transform.InferType()(mod) mod = get_mod() mod = relay.transform.Inline()(mod) diff --git a/tests/python/relay/test_pass_lazy_gradient_init.py b/tests/python/relay/test_pass_lazy_gradient_init.py index 403f88d8ab0a3..9c85ac0a2242b 100644 --- a/tests/python/relay/test_pass_lazy_gradient_init.py +++ b/tests/python/relay/test_pass_lazy_gradient_init.py @@ -39,6 +39,7 @@ def test_tc(): y = relay.Function([x1, x2], (x1 - x2) * x2) mod["main"] = y + mod = transform.InferType()(mod) mod = transform.LazyGradientInit()(mod) # function input/output types should remain the same @@ -58,6 +59,7 @@ def test_add(): y = relay.Function([x], x + x) mod["main"] = y + mod = transform.InferType()(mod) mod = transform.LazyGradientInit()(mod) y = mod["main"] @@ -83,6 +85,7 @@ def test_add_tuple(): y = relay.Function([x], relay.TupleGetItem(x, 0) + relay.TupleGetItem(x, 1)) mod["main"] = y + mod = transform.InferType()(mod) mod = transform.LazyGradientInit()(mod) mod = tvm.transform.PrintIR(show_meta_data=True)(mod) y = mod["main"] @@ -108,6 +111,7 @@ def test_mult(): y = relay.Function([x], x * x) mod["main"] = y + mod = transform.InferType()(mod) mod = transform.LazyGradientInit()(mod) y = mod["main"] @@ -133,6 +137,7 @@ def test_ret_tuple(): func = run_infer_type(func) mod["main"] = func + mod = transform.InferType()(mod) mod = transform.LazyGradientInit()(mod) func = mod["main"] @@ -161,6 +166,7 @@ def test_add_broadcast(): func = run_infer_type(func) mod["main"] = func + mod = transform.InferType()(mod) mod = transform.LazyGradientInit()(mod) func = mod["main"] @@ -194,6 +200,7 @@ def test_reverse_ad_identity(): back_func = run_infer_type(back_func) mod["main"] = back_func + mod = transform.InferType()(mod) mod = transform.LazyGradientInit()(mod) back_func = mod["main"] @@ -225,6 +232,7 @@ def test_multivar_reverse_ad(): back_func = run_infer_type(back_func) mod["main"] = back_func + mod = transform.InferType()(mod) mod = transform.LazyGradientInit()(mod) back_func = mod["main"] @@ -257,6 +265,7 @@ def test_partial_eval(): back_func = run_infer_type(back_func) mod["main"] = back_func + mod = transform.InferType()(mod) back_func = mod["main"] transform.PartialEvaluate()(mod) @@ -282,7 +291,12 @@ def test_after_partial_eval(): back_func = mod["main"] seq = tvm.transform.Sequential( - [transform.PartialEvaluate(), transform.LazyGradientInit(), transform.DeadCodeElimination()] + [ + transform.PartialEvaluate(), + transform.InferType(), + transform.LazyGradientInit(), + transform.DeadCodeElimination(), + ] ) mod = seq(mod) @@ -352,6 +366,7 @@ def test_zeros(): y = relay.Function([x], x + relay.zeros(shape, dtype)) mod["main"] = y + mod = transform.InferType()(mod) mod = transform.LazyGradientInit()(mod) y = mod["main"] @@ -375,6 +390,7 @@ def test_ones(): y = relay.Function([x], x + relay.ones(shape, dtype)) mod["main"] = y + mod = transform.InferType()(mod) mod = transform.LazyGradientInit()(mod) y = mod["main"] @@ -398,6 +414,7 @@ def test_zeros_like(): y = relay.Function([x], x + relay.zeros_like(x)) mod["main"] = y + mod = transform.InferType()(mod) mod = transform.LazyGradientInit()(mod) y = mod["main"] @@ -421,6 +438,7 @@ def test_ones_like(): y = relay.Function([x], x + relay.ones_like(x)) mod["main"] = y + mod = transform.InferType()(mod) mod = transform.LazyGradientInit()(mod) y = mod["main"] diff --git a/tests/python/relay/test_pass_mac_count.py b/tests/python/relay/test_pass_mac_count.py index b3a062ffb455e..68f8851526b60 100644 --- a/tests/python/relay/test_pass_mac_count.py +++ b/tests/python/relay/test_pass_mac_count.py @@ -25,6 +25,7 @@ def run_opt_pass(expr, opt_pass): assert isinstance(opt_pass, tvm.transform.Pass) mod = tvm.IRModule.from_expr(expr) + mod = tvm.relay.transform.InferType()(mod) mod = opt_pass(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index c34ca07f62837..7e2282809f765 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -212,6 +212,7 @@ def transform_function(self, func, mod, ctx): mod = fpass(mod) # wrap in expr mod2 = tvm.IRModule.from_expr(f1) + mod2 = tvm.relay.transform.InferType()(mod2) assert tvm.ir.structural_equal(mod["main"], mod2["main"]) @@ -564,7 +565,9 @@ def test_print_debug_callback(): with tvm.transform.PassContext(opt_level=3, trace=_tracer): mod = seq(mod) - assert __TRACE_COUNTER__ == 3 + # TODO(@jroesch): when we remove new fn pass behavior we need to remove + # change this back to 3 + assert __TRACE_COUNTER__ == 5 if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_merge_compiler_regions.py b/tests/python/relay/test_pass_merge_compiler_regions.py index 5dee0526e452b..8447eeffa6bac 100644 --- a/tests/python/relay/test_pass_merge_compiler_regions.py +++ b/tests/python/relay/test_pass_merge_compiler_regions.py @@ -208,7 +208,9 @@ def expected(): mod = annotated() mod = relay.transform.MergeCompilerRegions()(mod) + mod = relay.transform.InferType()(mod) ref_mod = expected() + ref_mod = relay.transform.InferType()(ref_mod) assert tvm.ir.structural_equal(mod, ref_mod) diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index f2368fcec1cca..45749c31f38fa 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -26,7 +26,7 @@ from tvm.relay import TensorType, Tuple, If, Clause, PatternConstructor, PatternVar, Match from tvm.relay import GlobalVar, Call from tvm.relay.transform import gradient -from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type +from tvm.relay.testing import make_nat_expr, run_infer_type def check_eval(expr, expected_result, mod=None, rtol=1e-07): @@ -52,7 +52,12 @@ def tipe(expr): def dcpe(expr, mod=None, grad=False): - passes = [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)] + passes = [ + transform.PartialEvaluate(), + transform.InferType(), + transform.DeadCodeElimination(inline_once=True), + transform.InferType(), + ] if grad: expr = gradient(run_infer_type(expr)) if mod: @@ -175,23 +180,29 @@ def test_head_cons(): p = Prelude(mod) t = TypeVar("t") x = Var("x", t) - body = p.hd(p.cons(x, p.nil())) + rlist, cons, nil = p.mod.get_type("List") + hd = p.mod.get_global_var("hd") + body = hd(cons(x, nil())) f = Function([x], body, None, [t]) res = dcpe(f, mod) - assert tvm.ir.structural_equal(res, Function([x], x, t, [t])) + expected_mod = tvm.IRModule.from_expr(Function([x], x, t, [t])) + assert tvm.ir.structural_equal(res, expected_mod["main"]) def test_map(): mod = tvm.IRModule() p = Prelude(mod) + rlist, cons, nil = p.mod.get_type("List") + rmap = p.mod.get_global_var("map") f = GlobalVar("f") t = TypeVar("t") a = Var("a", t) mod[f] = Function([a], a, t, [t]) - orig = p.map(f, p.cons(const(1), p.cons(const(2), p.cons(const(3), p.nil())))) - expected = p.cons((const(1)), p.cons((const(2)), p.cons((const(3)), p.nil()))) + orig = rmap(f, cons(const(1), cons(const(2), cons(const(3), nil())))) + expected = cons((const(1)), cons((const(2)), cons((const(3)), nil()))) expected = Function([], expected) mod["main"] = expected + mod = transform.InferType()(mod) expected = mod["main"] orig = Function([], orig) res = dcpe(orig, mod=mod) @@ -206,6 +217,7 @@ def test_loop(): mod[loop] = Function([x], loop(x), t, [t]) expected = Call(loop, [const(1)]) mod["main"] = Function([], expected) + mod = transform.InferType()(mod) expected = mod["main"].body call = Function([], loop(const(1))) res = dcpe(call, mod=mod) @@ -215,12 +227,12 @@ def test_loop(): def test_swap_loop(): mod = tvm.IRModule() p = Prelude(mod) - add_nat_definitions(p) - nat = p.nat() - x = Var("x", nat) - y = Var("y", nat) + p.mod.import_from_std("nat.rly") + nat, _, _ = p.mod.get_type("nat") + x = Var("x", nat()) + y = Var("y", nat()) loop = GlobalVar("loop") - mod[loop] = Function([x, y], loop(y, x), nat) + mod[loop] = Function([x, y], loop(y, x), nat()) prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2)) res = Function([], prog) res = dcpe(res, mod=mod) @@ -231,17 +243,17 @@ def test_abs_diff(): # TODO(@M.K.): refactor using tuple pattern (not yet implemented) mod = tvm.IRModule() p = Prelude(mod) - add_nat_definitions(p) - nat = p.nat() - x = Var("x", nat) - y = Var("y", nat) - xp = Var("x'", nat) - yp = Var("y'", nat) + p.mod.import_from_std("nat.rly") + nat, z, s = p.mod.get_type("nat") + x = Var("x", nat()) + y = Var("y", nat()) + xp = Var("x'", nat()) + yp = Var("y'", nat()) diff = GlobalVar("diff") - y_z_case = Clause(PatternConstructor(p.z, []), x) - y_s_case = Clause(PatternConstructor(p.s, [PatternVar(yp)]), diff(yp, xp)) - x_z_case = Clause(PatternConstructor(p.z, []), y) - x_s_case = Clause(PatternConstructor(p.s, [PatternVar(xp)]), Match(y, [y_z_case, y_s_case])) + y_z_case = Clause(PatternConstructor(z, []), x) + y_s_case = Clause(PatternConstructor(s, [PatternVar(yp)]), diff(yp, xp)) + x_z_case = Clause(PatternConstructor(z, []), y) + x_s_case = Clause(PatternConstructor(s, [PatternVar(xp)]), Match(y, [y_z_case, y_s_case])) mod[diff] = Function([x, y], Match(x, [x_z_case, x_s_case])) orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3)) orig = Function([], orig) @@ -252,13 +264,13 @@ def test_abs_diff(): def test_match_nat_id(): mod = tvm.IRModule() p = Prelude(mod) - add_nat_definitions(p) - nat = p.nat() - x = Var("x", nat) - y = Var("y", nat) + p.mod.import_from_std("nat.rly") + nat, z, s = p.mod.get_type("nat") + x = Var("x", nat()) + y = Var("y", nat()) nat_id = GlobalVar("nat_id") - z_case = Clause(PatternConstructor(p.z, []), p.z()) - s_case = Clause(PatternConstructor(p.s, [PatternVar(y)]), p.s(y)) + z_case = Clause(PatternConstructor(z, []), z()) + s_case = Clause(PatternConstructor(s, [PatternVar(y)]), s(y)) mod[nat_id] = Function([x], Match(x, [z_case, s_case])) orig = nat_id(make_nat_expr(p, 3)) orig = Function([], orig) @@ -269,10 +281,10 @@ def test_match_nat_id(): def test_nat_id(): mod = tvm.IRModule() p = Prelude(mod) - add_nat_definitions(p) - nat = p.nat() - x = Var("x", nat) - y = Var("y", nat) + p.mod.import_from_std("nat.rly") + nat, _, _ = p.mod.get_type("nat") + x = Var("x", nat()) + y = Var("y", nat()) nat_id = GlobalVar("nat_id") mod[nat_id] = Function([x], x) orig = nat_id(make_nat_expr(p, 3)) @@ -284,11 +296,11 @@ def test_nat_id(): def test_global_match_nat_id(): mod = tvm.IRModule() p = Prelude(mod) - add_nat_definitions(p) - nat = p.nat() - x = Var("x", nat) - z_case = Clause(PatternConstructor(p.z, []), p.z()) - s_case = Clause(PatternConstructor(p.s, [PatternVar(x)]), p.s(x)) + p.mod.import_from_std("nat.rly") + nat, z, s = p.mod.get_type("nat") + x = Var("x", nat()) + z_case = Clause(PatternConstructor(z, []), z()) + s_case = Clause(PatternConstructor(s, [PatternVar(x)]), s(x)) orig = Match(make_nat_expr(p, 3), [z_case, s_case]) orig = Function([], orig) res = dcpe(orig, mod=mod) @@ -298,8 +310,9 @@ def test_global_match_nat_id(): def test_double(): mod = tvm.IRModule() p = Prelude(mod) - add_nat_definitions(p) - orig = p.double(make_nat_expr(p, 3)) + p.mod.import_from_std("nat.rly") + double = p.mod.get_global_var("nat_double") + orig = double(make_nat_expr(p, 3)) orig = Function([], orig) res = dcpe(orig, mod=mod) assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 6)) @@ -325,7 +338,7 @@ def test_triangle_number(): def test_nat_update(): m = tvm.IRModule() p = Prelude(m) - add_nat_definitions(p) + p.mod.import_from_std("nat.rly") m = transform.ToANormalForm()(m) transform.PartialEvaluate()(m) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 75218c5a9855d..2fd440e1c2c9b 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -349,6 +349,7 @@ def expected(): fused_call = relay.Call(fused_func, [add_call]) main = relay.Function([x, y], fused_call) mod["main"] = main + mod = transform.InferType()(mod) return mod x = relay.var("x", shape=(8, 8)) @@ -415,11 +416,13 @@ def expected(): glb_var = relay.GlobalVar("dnnl_0") mod = tvm.IRModule() mod[glb_var] = func + mod = transform.InferType()(mod) data = relay.var("data", shape=(ishape), dtype=dtype) weight = relay.var("input", shape=(w1shape), dtype=dtype) main_f = relay.Function([data, weight], glb_var(data, weight)) mod["main"] = main_f + mod = transform.InferType()(mod) return mod @@ -439,6 +442,7 @@ def get_func(): mod = tvm.IRModule() mod["main"] = WholeGraphAnnotator("dnnl").visit(get_func()) mod = transform.PartitionGraph()(mod) + mod = transform.InferType()(mod) assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) @@ -494,6 +498,7 @@ def partition(): ) mod = tvm.IRModule() mod["main"] = func + mod = relay.transform.InferType()(mod) op_list = ["nn.batch_norm", "nn.conv2d"] mod = WhiteListAnnotator(op_list, "test_compiler")(mod) @@ -526,6 +531,7 @@ def expected(): func0 = set_func_attr(func0, "test_compiler", "test_compiler_2") gv0 = relay.GlobalVar("test_compiler_2") mod[gv0] = func0 + mod = transform.InferType()(mod) # function for conv2d data1 = relay.var("data1", relay.TensorType((1, 3, 224, 224), "float32")) @@ -537,6 +543,7 @@ def expected(): func1 = set_func_attr(func1, "test_compiler", "test_compiler_0") gv1 = relay.GlobalVar("test_compiler_0") mod[gv1] = func1 + mod = transform.InferType()(mod) # main function data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32")) @@ -635,10 +642,12 @@ def expected(): func = set_func_attr(func, "ccompiler", "ccompiler_0") glb_0 = relay.GlobalVar("ccompiler_0") mod[glb_0] = func + mod = relay.transform.InferType()(mod) add_call = relay.Call(glb_0, [y]) log = relay.log(add_call) main = relay.Function([y], log) mod["main"] = main + mod = relay.transform.InferType()(mod) return mod x = relay.var("x", shape=(8, 8)) @@ -651,8 +660,10 @@ def expected(): mod["main"] = f mod = WhiteListAnnotator(["add"], "ccompiler")(mod) mod = transform.PartitionGraph()(mod) + mod = relay.transform.InferType()(mod) expected_mod = expected() + expected_mod = relay.transform.InferType()(expected_mod) assert tvm.ir.structural_equal(mod, expected_mod, map_free_vars=True) y_data = np.random.rand(8, 8).astype("float32") @@ -721,6 +732,7 @@ def expected(): func0 = set_func_attr(func0, "test_target", "test_target_0") gv0 = relay.GlobalVar("test_target_0") mod[gv0] = func0 + mod = relay.transform.InferType()(mod) # body data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32")) @@ -741,6 +753,7 @@ def expected(): func = relay.Function([data, weight, bn_gamma, bn_beta, bn_mean, bn_var], main_tuple) mod["main"] = func + mod = relay.transform.InferType()(mod) return mod mod = tvm.IRModule() @@ -782,6 +795,7 @@ def expected(): func1 = set_func_attr(func1, "test_target", "test_target_0") gv1 = relay.GlobalVar("test_target_0") mod[gv1] = func1 + mod = relay.transform.InferType()(mod) # function 0 f2_cb3 = relay.var("test_target_1_i0", shape=(10, 10)) @@ -791,6 +805,7 @@ def expected(): func0 = set_func_attr(func0, "test_target", "test_target_1") gv0 = relay.GlobalVar("test_target_1") mod[gv0] = func0 + mod = relay.transform.InferType()(mod) # body data = relay.var("data", shape=(10, 10)) @@ -802,13 +817,15 @@ def expected(): ce_4 = gv0(ce_3, X) func = relay.Function([data], ce_4) mod["main"] = func - + mod = relay.transform.InferType()(mod) return mod mod = tvm.IRModule() mod["main"] = create_graph() + mod = transform.InferType()(mod) ref_mod = expected() + partitioned = transform.PartitionGraph()(mod) assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True) @@ -937,10 +954,13 @@ def expected_same_output_region(): func = set_func_attr(func, "ccompiler", "ccompiler_0") glb_0 = relay.GlobalVar("ccompiler_0") mod[glb_0] = func + mod = transform.InferType()(mod) + add = x + y call = relay.Call(glb_0, [add, z]) main = relay.Function([x, y, z], call) mod["main"] = main + mod = transform.InferType()(mod) return mod def expected_different_output_region(): @@ -956,6 +976,7 @@ def expected_different_output_region(): func = set_func_attr(func, "ccompiler", "ccompiler_0") glb_0 = relay.GlobalVar("ccompiler_0") mod[glb_0] = func + mod = transform.InferType()(mod) # The partitioned graph contains subtract x0 = relay.var("x0", shape=(8, 8)) @@ -965,12 +986,14 @@ def expected_different_output_region(): func = set_func_attr(func, "ccompiler", "ccompiler_1") glb_1 = relay.GlobalVar("ccompiler_1") mod[glb_1] = func + mod = transform.InferType()(mod) add = x + y call_log = relay.Call(glb_0, [add]) call_sub = relay.Call(glb_1, [add, z]) main = relay.Function([x, y, z], call_log * call_sub) mod["main"] = main + mod = transform.InferType()(mod) return mod def get_mod(): @@ -1039,6 +1062,7 @@ def expected(): func0 = func0.with_attr("global_symbol", target + "_0") gv0 = relay.GlobalVar(target + "_0") mod[gv0] = func0 + mod = transform.InferType()(mod) # body data = relay.var("data", shape=(10, 10)) @@ -1049,6 +1073,7 @@ def expected(): out = relay.Tuple([out_1, out_2, out_3]) func = relay.Function([data], out) mod["main"] = func + mod = transform.InferType()(mod) return mod mod = tvm.IRModule() @@ -1114,6 +1139,7 @@ def expected(): func0 = func0.with_attr("global_symbol", target + "_0") gv0 = relay.GlobalVar(target + "_0") mod[gv0] = func0 + mod = transform.InferType()(mod) # body data = relay.var("data", shape=(10, 10)) @@ -1129,10 +1155,12 @@ def expected(): out = relay.Tuple([get_out0, out_2, out_3]) func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out) mod["main"] = func + mod = transform.InferType()(mod) return mod mod = tvm.IRModule() mod["main"] = create_graph() + mod = transform.InferType()(mod) seq = tvm.transform.Sequential( [ @@ -1170,17 +1198,20 @@ def create_graph(): f = relay.Function([a, b], con) mod = tvm.IRModule.from_expr(f) + mod = transform.InferType()(mod) return mod seq = tvm.transform.Sequential( [ transform.AnnotateTarget("const_tuples"), + transform.InferType(), transform.MergeCompilerRegions(), transform.PartitionGraph(), ] ) partitioned = seq(create_graph()) + concat = partitioned["const_tuples_0"].body assert type(concat.args[1]) == relay.Tuple assert type(concat.args[2]) == relay.Tuple @@ -1212,6 +1243,7 @@ def create_graph(): out = relay.Tuple((a_con, a_split_0_relu)) f = relay.Function([a], out) mod = tvm.IRModule.from_expr(f) + mod = transform.InferType()(mod) return mod def expected(): @@ -1233,6 +1265,7 @@ def expected(): func0 = func0.with_attr("global_symbol", target + "_0") gv0 = relay.GlobalVar(target + "_0") mod[gv0] = func0 + mod = transform.InferType()(mod) # body data = relay.var("a", shape=(10, 10), dtype="uint8") @@ -1245,6 +1278,7 @@ def expected(): relu = relay.nn.relu(f_out_2) ret_tuple = relay.Tuple((concat, relu)) mod["main"] = relay.Function([data], ret_tuple) + mod = transform.InferType()(mod) return mod seq = tvm.transform.Sequential( @@ -1256,7 +1290,9 @@ def expected(): ) partitioned = seq(create_graph()) - assert tvm.ir.structural_equal(partitioned, expected(), map_free_vars=True) + partitioned = transform.InferType()(partitioned) + expected_mod = transform.InferType()(expected()) + assert tvm.ir.structural_equal(partitioned, expected_mod, map_free_vars=True) def test_tuple_output_exec(): @@ -1270,8 +1306,10 @@ def test_tuple_output_exec(): out = relay.Tuple((add, sub)) eout = relay.annotation.compiler_end(out, "ccompiler") func = relay.Function([a, b], eout) + mod = tvm.IRModule() mod["main"] = func + mod = transform.InferType()(mod) mod = transform.PartitionGraph()(mod) a_data = np.random.rand(10, 10).astype("float32") @@ -1303,6 +1341,7 @@ def Optimize(mod): f = bind_params_by_name(f, {"y0": tvm.nd.array(c), "y1": tvm.nd.array(c)}) mod = tvm.IRModule() mod["main"] = f + mod = transform.InferType()(mod) mod = transform.PartitionGraph()(mod) try: diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py index 37da3ab1f2e22..6a5c8f7cd6476 100644 --- a/tests/python/relay/test_pass_qnn_legalize.py +++ b/tests/python/relay/test_pass_qnn_legalize.py @@ -138,6 +138,7 @@ def _get_mod(data_dtype, kernel_dtype): ############################################################# # Check that Intel VNNI gets picked up. with tvm.target.Target("llvm -mcpu=skylake-avx512"): + mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) assert "cast" in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext() @@ -168,6 +169,7 @@ def _get_mod(data_dtype, kernel_dtype): ############################################################# # Check no transformation for Intel VNNI. with tvm.target.Target("llvm -mcpu=skylake-avx512"): + mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) assert tvm.ir.structural_equal(mod, legalized_mod) @@ -229,6 +231,7 @@ def _get_mod(data_dtype, kernel_dtype): ############################################################# # Check that Intel VNNI gets picked up. with tvm.target.Target("llvm -mcpu=skylake-avx512"): + mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) assert "cast" in legalized_mod.astext() and "qnn.dense" in legalized_mod.astext() @@ -259,6 +262,7 @@ def _get_mod(data_dtype, kernel_dtype): ############################################################# # Check no transformation for Intel VNNI. with tvm.target.Target("llvm -mcpu=skylake-avx512"): + mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) assert tvm.ir.structural_equal(mod, legalized_mod) diff --git a/tests/python/relay/test_pass_remove_unused_functions.py b/tests/python/relay/test_pass_remove_unused_functions.py index 271dc8e0f2982..0764a88b31591 100644 --- a/tests/python/relay/test_pass_remove_unused_functions.py +++ b/tests/python/relay/test_pass_remove_unused_functions.py @@ -49,10 +49,13 @@ def test_remove_all_prelude_functions_but_referenced_functions(): def test_keep_only_referenced_prelude_functions(): mod = tvm.IRModule() p = Prelude(mod) - l = p.nil() + _, cons, nil = p.mod.get_type("List") + hd = p.mod.get_global_var("hd") + tl = p.mod.get_global_var("tl") + l = nil() for i in [4, 3, 2, 1, 0]: - l = p.cons(relay.const(i), l) - body = p.hd(p.tl(p.tl(l))) + l = cons(relay.const(i), l) + body = hd(tl(tl(l))) mod["main"] = relay.Function([], body) mod = relay.transform.RemoveUnusedFunctions()(mod) l = set([x[0].name_hint for x in mod.functions.items()]) @@ -62,10 +65,13 @@ def test_keep_only_referenced_prelude_functions(): def test_multiple_entry_functions(): mod = tvm.IRModule() p = Prelude(mod) - l = p.nil() + _, cons, nil = p.mod.get_type("List") + hd = p.mod.get_global_var("hd") + tl = p.mod.get_global_var("tl") + l = nil() for i in [4, 3, 2, 1, 0]: - l = p.cons(relay.const(i), l) - body = p.hd(p.tl(p.tl(l))) + l = cons(relay.const(i), l) + body = hd(tl(tl(l))) mod["main1"] = relay.Function([], body) x = relay.var("x", shape=(1, 16)) @@ -81,10 +87,10 @@ def test_multiple_entry_functions(): def test_globalvar_as_call_arg(): mod = tvm.IRModule() p = Prelude(mod) - tensor_array = p.get_var("tensor_array", "int32") - tensor1 = p.get_var("tensor1", "int32") - write = p.get_var("tensor_array_write", "int32") - stack = p.get_var("tensor_array_stack", "int32") + tensor_array = p.get_global_var("tensor_array", "int32") + tensor1 = p.get_ctor(p.get_name("tensor_t", "int32"), "tensor1", "int32") + write = p.get_global_var("tensor_array_write", "int32") + stack = p.get_global_var("tensor_array_stack", "int32") v = relay.var("v") init_tensor_array = tensor_array(relay.const(3)) tensor_array1 = write(init_tensor_array, relay.const(0), tensor1(v)) diff --git a/tests/python/relay/test_pass_simplify_inference.py b/tests/python/relay/test_pass_simplify_inference.py index f557e2c5ac4de..24a63e97b30ed 100644 --- a/tests/python/relay/test_pass_simplify_inference.py +++ b/tests/python/relay/test_pass_simplify_inference.py @@ -16,7 +16,7 @@ # under the License. from tvm.ir import IRModule, structural_equal from tvm import relay as rly -from tvm.relay.transform import SimplifyInference +from tvm.relay.transform import SimplifyInference, InferType def test_simplify_batchnorm(dtype="float32"): @@ -66,7 +66,9 @@ def check(dim, axis, nstep): ) mod = IRModule.from_expr(y1) + simplify = SimplifyInference() + mod = InferType()(mod) mod = simplify(mod) y1 = mod["main"].body diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index 7f5d6a4c4ec59..72325e537c0ed 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -21,7 +21,7 @@ from tvm.relay.analysis import detect_feature from tvm.relay import op, create_executor, transform from tvm.relay.prelude import Prelude -from tvm.relay.testing import add_nat_definitions, count +from tvm.relay.testing import count from tvm.relay.analysis import Feature @@ -146,11 +146,9 @@ def test_ref(): def test_nat_add(): mod = tvm.IRModule() p = Prelude(mod) - add_nat_definitions(p) - nat = p.nat - add = p.add - s = p.s - z = p.z + p.mod.import_from_std("nat.rly") + nat, z, s = p.mod.get_type("nat") + add = p.mod.get_global_var("nat_add") ctx = tvm.context("llvm", 0) intrp = create_executor(mod=mod, ctx=ctx, target="llvm") assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) @@ -195,6 +193,7 @@ def test_gradient_if(): net = relay.Function([cond, x, y], net) mod = tvm.IRModule.from_expr(net) mod = relay.transform.ToANormalForm()(mod) + mod = relay.transform.InferType()(mod) mod["main"] = relay.transform.gradient(mod["main"], mode="higher_order") mod = relay.transform.ToANormalForm()(mod) diff --git a/tests/python/relay/test_pass_to_basic_block_normal_form.py b/tests/python/relay/test_pass_to_basic_block_normal_form.py index dafd1d153d39d..a52d51ad49606 100644 --- a/tests/python/relay/test_pass_to_basic_block_normal_form.py +++ b/tests/python/relay/test_pass_to_basic_block_normal_form.py @@ -22,7 +22,7 @@ from tvm.relay.analysis import detect_feature from tvm.relay import op, create_executor, transform from tvm.relay.prelude import Prelude -from tvm.relay.testing import add_nat_definitions, count +from tvm.relay.testing import count from tvm.relay.analysis import Feature from tvm.relay.analysis import check_basic_block_normal_form @@ -116,7 +116,7 @@ def expected(): body = relay.Let(z2, relay.add(z, z), body) return body - bblock = run_opt_pass(body, [transform.ToBasicBlockNormalForm()]) + bblock = run_opt_pass(body, [transform.ToBasicBlockNormalForm(), transform.InferType()]) """ free_var %z: float32 let %x: float32 = add(%z, %z) /* ty=float32 */; @@ -187,7 +187,7 @@ def expected(): body = relay.If(x, true_branch, false_branch) return body - bblock = run_opt_pass(body, [transform.ToBasicBlockNormalForm()]) + bblock = run_opt_pass(body, [transform.ToBasicBlockNormalForm(), transform.InferType()]) """ free_var %x: bool if (%x) { @@ -263,11 +263,9 @@ def test_ref(): def test_nat_add(): mod = tvm.IRModule() p = Prelude(mod) - add_nat_definitions(p) - nat = p.nat - add = p.add - s = p.s - z = p.z + p.mod.import_from_std("nat.rly") + nat, z, s = p.mod.get_type("nat") + add = p.mod.get_global_var("nat_add") ctx = tvm.context("llvm", 0) intrp = create_executor(mod=mod, ctx=ctx, target="llvm") assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) @@ -275,6 +273,7 @@ def test_nat_add(): expr = add(s(z()), s(z())) f = relay.GlobalVar("f") mod[f] = relay.Function([], expr) + mod = transform.InferType()(mod) mod = transform.ToBasicBlockNormalForm()(mod) opt_expr = mod["f"] assert count(p, intrp.evaluate(opt_expr.body)) == 2 @@ -368,6 +367,7 @@ def test_gradient_if(): net = relay.Function([cond, x, y], net) mod = tvm.IRModule.from_expr(net) mod = relay.transform.ToBasicBlockNormalForm()(mod) + mod = relay.transform.InferType()(mod) net_grad = relay.transform.gradient(mod["main"], mode="higher_order") mod["main"] = net_grad mod_grad = relay.transform.ToBasicBlockNormalForm()(mod) @@ -420,14 +420,14 @@ def expected_if_expr(x): x = relay.var("x", shape=(), dtype="float32") body = if_expr(x) expected_body = expected_if_expr(x) - bblock = run_opt_pass(body, transform.ToBasicBlockNormalForm()) + bblock = run_opt_pass(body, [transform.ToBasicBlockNormalForm(), transform.InferType()]) expected_bblock = run_opt_pass(expected_body, transform.InferType()) assert tvm.ir.structural_equal(bblock, expected_bblock, map_free_vars=True) check_basic_block_normal_form(bblock) func = relay.Function([x], body) expected_func = relay.Function([x], expected_body) - bblock = run_opt_pass(func, transform.ToBasicBlockNormalForm()) + bblock = run_opt_pass(func, [transform.ToBasicBlockNormalForm(), transform.InferType()]) expected_bblock = run_opt_pass(expected_func, transform.InferType()) assert tvm.ir.structural_equal(bblock, expected_bblock) check_basic_block_normal_form(bblock) diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py index 85e3a622bd9f5..023bcb224d2bb 100644 --- a/tests/python/relay/test_pass_to_cps.py +++ b/tests/python/relay/test_pass_to_cps.py @@ -21,7 +21,7 @@ from tvm.relay.transform import to_cps, un_cps from tvm.relay.analysis import Feature from tvm.relay.prelude import Prelude -from tvm.relay.testing import add_nat_definitions, make_nat_expr, rand, run_infer_type, run_opt_pass +from tvm.relay.testing import make_nat_expr, rand, run_infer_type, run_opt_pass from tvm.relay import create_executor from tvm.relay import transform @@ -44,16 +44,19 @@ def test_double(): def test_recursion(): mod = tvm.IRModule() p = Prelude(mod) - add_nat_definitions(p) + p.mod.import_from_std("nat.rly") + nat_iterate = p.mod.get_global_var("nat_iterate") shape = (10, 10) dtype = "float32" t = relay.TensorType(shape, dtype) x = relay.var("x", t) double = relay.Function([x], x + x) i = relay.var("i", t) - func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i)) + func = relay.Function([i], nat_iterate(double, make_nat_expr(p, 3))(i)) mod["main"] = func + mod = relay.transform.InferType()(mod) mod["main"] = to_cps(mod["main"], mod=mod) + mod = relay.transform.InferType()(mod) mod["main"] = un_cps(mod["main"]) ex = create_executor(mod=mod) i_nd = rand(dtype, *shape) diff --git a/tests/python/relay/test_pass_unmatched_cases.py b/tests/python/relay/test_pass_unmatched_cases.py index a410347adc95a..c6b4deb0b2c2b 100644 --- a/tests/python/relay/test_pass_unmatched_cases.py +++ b/tests/python/relay/test_pass_unmatched_cases.py @@ -85,6 +85,7 @@ def test_single_constructor_adt(): def test_too_specific_match(): mod = tvm.IRModule() p = Prelude(mod) + _, cons, nil = mod.get_type("List") v = relay.Var("v") match = relay.Match( @@ -92,11 +93,11 @@ def test_too_specific_match(): [ relay.Clause( relay.PatternConstructor( - p.cons, + cons, [ relay.PatternWildcard(), relay.PatternConstructor( - p.cons, [relay.PatternWildcard(), relay.PatternWildcard()] + cons, [relay.PatternWildcard(), relay.PatternWildcard()] ), ], ), @@ -113,11 +114,11 @@ def test_too_specific_match(): assert len(unmatched) == 2 for case in unmatched: assert isinstance(case, relay.PatternConstructor) - if case.constructor == p.nil: + if case.constructor == nil: nil_found = True - if case.constructor == p.cons: + if case.constructor == cons: assert isinstance(case.patterns[1], relay.PatternConstructor) - assert case.patterns[1].constructor == p.nil + assert case.patterns[1].constructor == nil single_length_found = True assert nil_found and single_length_found @@ -127,11 +128,11 @@ def test_too_specific_match(): [ relay.Clause( relay.PatternConstructor( - p.cons, + cons, [ relay.PatternWildcard(), relay.PatternConstructor( - p.cons, [relay.PatternWildcard(), relay.PatternWildcard()] + cons, [relay.PatternWildcard(), relay.PatternWildcard()] ), ], ), @@ -146,6 +147,7 @@ def test_too_specific_match(): def test_multiple_constructor_clauses(): mod = tvm.IRModule() p = Prelude(mod) + _, cons, nil = mod.get_type("List") v = relay.Var("v") match = relay.Match( @@ -154,33 +156,33 @@ def test_multiple_constructor_clauses(): # list of length exactly 1 relay.Clause( relay.PatternConstructor( - p.cons, [relay.PatternWildcard(), relay.PatternConstructor(p.nil, [])] + cons, [relay.PatternWildcard(), relay.PatternConstructor(nil, [])] ), v, ), # list of length exactly 2 relay.Clause( relay.PatternConstructor( - p.cons, + cons, [ relay.PatternWildcard(), relay.PatternConstructor( - p.cons, [relay.PatternWildcard(), relay.PatternConstructor(p.nil, [])] + cons, [relay.PatternWildcard(), relay.PatternConstructor(nil, [])] ), ], ), v, ), # empty list - relay.Clause(relay.PatternConstructor(p.nil, []), v), + relay.Clause(relay.PatternConstructor(nil, []), v), # list of length 2 or more relay.Clause( relay.PatternConstructor( - p.cons, + cons, [ relay.PatternWildcard(), relay.PatternConstructor( - p.cons, [relay.PatternWildcard(), relay.PatternWildcard()] + cons, [relay.PatternWildcard(), relay.PatternWildcard()] ), ], ), @@ -194,6 +196,7 @@ def test_multiple_constructor_clauses(): def test_missing_in_the_middle(): mod = tvm.IRModule() p = Prelude(mod) + _, cons, nil = mod.get_type("List") v = relay.Var("v") match = relay.Match( @@ -202,24 +205,24 @@ def test_missing_in_the_middle(): # list of length exactly 1 relay.Clause( relay.PatternConstructor( - p.cons, [relay.PatternWildcard(), relay.PatternConstructor(p.nil, [])] + cons, [relay.PatternWildcard(), relay.PatternConstructor(nil, [])] ), v, ), # empty list - relay.Clause(relay.PatternConstructor(p.nil, []), v), + relay.Clause(relay.PatternConstructor(nil, []), v), # list of length 3 or more relay.Clause( relay.PatternConstructor( - p.cons, + cons, [ relay.PatternWildcard(), relay.PatternConstructor( - p.cons, + cons, [ relay.PatternWildcard(), relay.PatternConstructor( - p.cons, [relay.PatternWildcard(), relay.PatternWildcard()] + cons, [relay.PatternWildcard(), relay.PatternWildcard()] ), ], ), @@ -234,11 +237,11 @@ def test_missing_in_the_middle(): unmatched = unmatched_cases(match, mod) assert len(unmatched) == 1 assert isinstance(unmatched[0], relay.PatternConstructor) - assert unmatched[0].constructor == p.cons + assert unmatched[0].constructor == cons assert isinstance(unmatched[0].patterns[1], relay.PatternConstructor) - assert unmatched[0].patterns[1].constructor == p.cons + assert unmatched[0].patterns[1].constructor == cons assert isinstance(unmatched[0].patterns[1].patterns[1], relay.PatternConstructor) - assert unmatched[0].patterns[1].patterns[1].constructor == p.nil + assert unmatched[0].patterns[1].patterns[1].constructor == nil def test_mixed_adt_constructors(): @@ -250,6 +253,7 @@ def test_mixed_adt_constructors(): mod[box] = box_data p = Prelude(mod) + _, cons, nil = p.mod.get_type("List") v = relay.Var("v") box_of_lists_inc = relay.Match( @@ -260,7 +264,7 @@ def test_mixed_adt_constructors(): box_ctor, [ relay.PatternConstructor( - p.cons, [relay.PatternWildcard(), relay.PatternWildcard()] + cons, [relay.PatternWildcard(), relay.PatternWildcard()] ) ], ), @@ -274,20 +278,20 @@ def test_mixed_adt_constructors(): assert len(unmatched) == 1 assert isinstance(unmatched[0], relay.PatternConstructor) assert unmatched[0].constructor == box_ctor - assert len(unmatched[0].patterns) == 1 and unmatched[0].patterns[0].constructor == p.nil + assert len(unmatched[0].patterns) == 1 and unmatched[0].patterns[0].constructor == nil box_of_lists_comp = relay.Match( v, [ relay.Clause( - relay.PatternConstructor(box_ctor, [relay.PatternConstructor(p.nil, [])]), v + relay.PatternConstructor(box_ctor, [relay.PatternConstructor(nil, [])]), v ), relay.Clause( relay.PatternConstructor( box_ctor, [ relay.PatternConstructor( - p.cons, [relay.PatternWildcard(), relay.PatternWildcard()] + cons, [relay.PatternWildcard(), relay.PatternWildcard()] ) ], ), @@ -302,7 +306,7 @@ def test_mixed_adt_constructors(): [ relay.Clause( relay.PatternConstructor( - p.cons, + cons, [ relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), relay.PatternWildcard(), @@ -317,7 +321,7 @@ def test_mixed_adt_constructors(): unmatched = unmatched_cases(list_of_boxes_inc, mod) assert len(unmatched) == 1 assert isinstance(unmatched[0], relay.PatternConstructor) - assert unmatched[0].constructor == p.nil + assert unmatched[0].constructor == nil list_of_boxes_comp = relay.Match( v, @@ -325,10 +329,10 @@ def test_mixed_adt_constructors(): # exactly one box relay.Clause( relay.PatternConstructor( - p.cons, + cons, [ relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), - relay.PatternConstructor(p.nil, []), + relay.PatternConstructor(nil, []), ], ), v, @@ -336,14 +340,14 @@ def test_mixed_adt_constructors(): # exactly two boxes relay.Clause( relay.PatternConstructor( - p.cons, + cons, [ relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), relay.PatternConstructor( - p.cons, + cons, [ relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), - relay.PatternConstructor(p.nil, []), + relay.PatternConstructor(nil, []), ], ), ], @@ -353,20 +357,20 @@ def test_mixed_adt_constructors(): # exactly three boxes relay.Clause( relay.PatternConstructor( - p.cons, + cons, [ relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), relay.PatternConstructor( - p.cons, + cons, [ relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), relay.PatternConstructor( - p.cons, + cons, [ relay.PatternConstructor( box_ctor, [relay.PatternWildcard()] ), - relay.PatternConstructor(p.nil, []), + relay.PatternConstructor(nil, []), ], ), ], @@ -377,13 +381,11 @@ def test_mixed_adt_constructors(): ), # one or more boxes relay.Clause( - relay.PatternConstructor( - p.cons, [relay.PatternWildcard(), relay.PatternWildcard()] - ), + relay.PatternConstructor(cons, [relay.PatternWildcard(), relay.PatternWildcard()]), v, ), # no boxes - relay.Clause(relay.PatternConstructor(p.nil, []), v), + relay.Clause(relay.PatternConstructor(nil, []), v), ], ) assert len(unmatched_cases(list_of_boxes_comp, mod)) == 0 diff --git a/tests/python/relay/test_pass_vars.py b/tests/python/relay/test_pass_vars.py index a5a0e506d0675..d823f6ea4bff5 100644 --- a/tests/python/relay/test_pass_vars.py +++ b/tests/python/relay/test_pass_vars.py @@ -91,27 +91,28 @@ def test_bound_vars(): def test_match_vars(): mod = tvm.IRModule() p = relay.prelude.Prelude(mod) + rlist, cons, nil = p.mod.get_type("List") x = relay.Var("x") y = relay.Var("y") z = relay.Var("z") match1 = relay.Match( - p.nil(), + nil(), [ - relay.Clause(relay.PatternConstructor(p.nil), z), + relay.Clause(relay.PatternConstructor(nil), z), relay.Clause( - relay.PatternConstructor(p.cons, [relay.PatternVar(x), relay.PatternVar(y)]), - p.cons(x, y), + relay.PatternConstructor(cons, [relay.PatternVar(x), relay.PatternVar(y)]), + cons(x, y), ), ], ) match2 = relay.Match( - p.nil(), + nil(), [ relay.Clause( - relay.PatternConstructor(p.cons, [relay.PatternWildcard(), relay.PatternVar(x)]), y + relay.PatternConstructor(cons, [relay.PatternWildcard(), relay.PatternVar(x)]), y ), relay.Clause(relay.PatternWildcard(), z), ], diff --git a/tests/python/relay/test_py_converter.py b/tests/python/relay/test_py_converter.py index 4eefa7116515e..dda7471bcd52d 100644 --- a/tests/python/relay/test_py_converter.py +++ b/tests/python/relay/test_py_converter.py @@ -335,6 +335,7 @@ def test_match_order(): def test_local_recursion(): mod = tvm.IRModule() p = Prelude(mod) + _, cons, nil = p.mod.get_type("List") v = relay.Var("v") h = relay.Var("h") @@ -350,35 +351,35 @@ def test_local_recursion(): v, [ relay.Clause( - relay.PatternConstructor( - p.cons, [relay.PatternVar(h), relay.PatternVar(t)] - ), - p.cons(h, f(t)), + relay.PatternConstructor(cons, [relay.PatternVar(h), relay.PatternVar(t)]), + cons(h, f(t)), ), - relay.Clause(relay.PatternConstructor(p.nil, []), p.nil()), + relay.Clause(relay.PatternConstructor(nil, []), nil()), ], ), ), - f(p.cons(relay.const(1), p.cons(relay.const(2), p.cons(relay.const(3), p.nil())))), + f(cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))), ) val = run_as_python(let, mod) - assert_constructor_value(val, p.cons, 2) + assert_constructor_value(val, cons, 2) assert_tensor_value(val.fields[0], 1) - assert_constructor_value(val.fields[1], p.cons, 2) + assert_constructor_value(val.fields[1], cons, 2) assert_tensor_value(val.fields[1].fields[0], 2) - assert_constructor_value(val.fields[1].fields[1], p.cons, 2) + assert_constructor_value(val.fields[1].fields[1], cons, 2) assert_tensor_value(val.fields[1].fields[1].fields[0], 3) - assert_constructor_value(val.fields[1].fields[1].fields[1], p.nil, 0) + assert_constructor_value(val.fields[1].fields[1].fields[1], nil, 0) def test_global_recursion(): mod = tvm.IRModule() p = Prelude(mod) + rlist, cons, nil = p.mod.get_type("List") + copy = relay.GlobalVar("copy") # same as above: it copies the given list a = relay.TypeVar("a") - v = relay.Var("v", p.l(a)) + v = relay.Var("v", rlist(a)) h = relay.Var("h") t = relay.Var("t") copy_def = relay.Function( @@ -387,30 +388,30 @@ def test_global_recursion(): v, [ relay.Clause( - relay.PatternConstructor(p.cons, [relay.PatternVar(h), relay.PatternVar(t)]), - p.cons(h, copy(t)), + relay.PatternConstructor(cons, [relay.PatternVar(h), relay.PatternVar(t)]), + cons(h, copy(t)), ), - relay.Clause(relay.PatternConstructor(p.nil, []), p.nil()), + relay.Clause(relay.PatternConstructor(nil, []), nil()), ], ), - p.l(a), + rlist(a), [a], ) mod[copy] = copy_def - call1 = copy_def(p.cons(relay.const(1), p.cons(relay.const(2), p.nil()))) + call1 = copy_def(cons(relay.const(1), cons(relay.const(2), nil()))) val1 = run_as_python(call1, mod) - assert_constructor_value(val1, p.cons, 2) + assert_constructor_value(val1, cons, 2) assert_tensor_value(val1.fields[0], 1) - assert_constructor_value(val1.fields[1], p.cons, 2) + assert_constructor_value(val1.fields[1], cons, 2) assert_tensor_value(val1.fields[1].fields[0], 2) - assert_constructor_value(val1.fields[1].fields[1], p.nil, 0) + assert_constructor_value(val1.fields[1].fields[1], nil, 0) - call2 = copy_def(p.cons(relay.Tuple([]), p.nil())) + call2 = copy_def(cons(relay.Tuple([]), nil())) val2 = run_as_python(call2, mod) - assert_constructor_value(val2, p.cons, 2) + assert_constructor_value(val2, cons, 2) assert_adt_len(val2.fields[0], 0) - assert_constructor_value(val2.fields[1], p.nil, 0) + assert_constructor_value(val2.fields[1], nil, 0) def test_higher_order_call(): @@ -439,21 +440,22 @@ def test_higher_order_call(): def test_match_effect_exactly_once(): mod = tvm.IRModule() p = Prelude(mod) + _, cons, nil = p.mod.get_type("List") # the list should be of length 1! # Unless we mistakenly execute the data clause more than once r = relay.Var("r") - data = seq(relay.RefWrite(r, p.cons(relay.Tuple([]), relay.RefRead(r))), relay.RefRead(r)) + data = seq(relay.RefWrite(r, cons(relay.Tuple([]), relay.RefRead(r))), relay.RefRead(r)) match = relay.Let( r, - relay.RefCreate(p.nil()), + relay.RefCreate(nil()), relay.Match( data, [ - relay.Clause(relay.PatternConstructor(p.nil, []), relay.const(0)), + relay.Clause(relay.PatternConstructor(nil, []), relay.const(0)), relay.Clause( relay.PatternConstructor( - p.cons, [relay.PatternWildcard(), relay.PatternConstructor(p.nil, [])] + cons, [relay.PatternWildcard(), relay.PatternConstructor(nil, [])] ), relay.const(1), ), diff --git a/tests/python/relay/test_tensor_array.py b/tests/python/relay/test_tensor_array.py new file mode 100644 index 0000000000000..76e9d4a6d8a04 --- /dev/null +++ b/tests/python/relay/test_tensor_array.py @@ -0,0 +1,785 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +from tvm import relay +from tvm.relay import testing +from tvm.relay.backend.interpreter import ConstructorValue +from tvm.relay import create_executor +from tvm.relay.prelude import Prelude, StaticTensorArrayOps +from tvm.relay.testing import count as count_, make_nat_value, make_nat_expr + +import numpy as np + + +def vmobj_to_list(mod, o, dtype="float32"): + _, tensor_nil, _, _, _, _, _, _, _ = mod.get_type(f"tensor_{dtype}_t") + if isinstance(o, tvm.nd.NDArray): + return [o.asnumpy().tolist()] + elif isinstance(o, tvm.runtime.container.ADT): + if len(o) == 0: + if tensor_nil.tag == o.tag: + return [0] + return [] + + result = [] + for f in o: + result.extend(vmobj_to_list(mod, f, dtype)) + return result + elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue): + if o.constructor.name_hint == "Cons": + tl = vmobj_to_list(mod, o.fields[1], dtype) + hd = vmobj_to_list(mod, o.fields[0], dtype) + hd.extend(tl) + return hd + elif o.constructor.name_hint == "Nil": + return [] + elif "tensor_nil" in o.constructor.name_hint: + return [0] + elif "tensor" in o.constructor.name_hint: + return [o.fields[0].asnumpy()] + else: + raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint) + else: + raise RuntimeError("Unknown object type: %s" % type(o)) + + +def check_tensor_array(ta_mod, ref_res, *args, dtype="float32", rtol=1e-5): + for kind in ["debug", "vm"]: + for target, ctx in [("llvm", tvm.cpu(0))]: # testing.enabled_targets(): + # for target, ctx in testing.enabled_targets(): + if kind == "debug" and ctx.device_type != tvm.cpu().device_type: + continue + ex = relay.create_executor(kind, mod=ta_mod, ctx=ctx, target=target) + result = ex.evaluate()(*args) + got = vmobj_to_list(ta_mod, result, dtype) + tvm.testing.assert_allclose(ref_res, got, rtol=rtol, atol=rtol) + + +@tvm.testing.uses_gpu +def test_tensor_expand_dims(): + def run(dtype): + x = relay.var("x") + mod = tvm.IRModule() + p = Prelude(mod) + expand_dims_func = p.get_global_var("tensor_expand_dims", dtype) + tensor1 = p.get_tensor_ctor("tensor1", dtype) + mod["main"] = relay.Function([x], expand_dims_func(tensor1(x))) + x_np = np.random.uniform(low=0.0, high=8.0, size=(1,)).astype(dtype) + expected = [np.expand_dims(x_np, axis=0)] + check_tensor_array(mod, expected, x_np) + + run("float32") + run("int32") + + +@tvm.testing.uses_gpu +def test_tensor_array_constructor(): + def run(dtype): + x = relay.var("x") + mod = tvm.IRModule() + p = Prelude(mod) + tensor_array = p.get_global_var("tensor_array", dtype) + mod["main"] = relay.Function([x], tensor_array(x)) + expected = np.array([0, 0, 0, 0, 0]) + check_tensor_array(mod, expected, 5, dtype=dtype) + + run("float32") + run("int32") + + +@tvm.testing.uses_gpu +def test_tensor_array_read(): + def run(dtype): + mod = tvm.IRModule() + p = Prelude(mod) + l = relay.var("l") + i = relay.var("i") + read_func = p.get_global_var("tensor_array_read", dtype) + tensor_array = p.get_global_var("tensor_array", dtype) + mod["main"] = relay.Function([l, i], read_func(tensor_array(l), i)) + expected = [0] + check_tensor_array(mod, expected, *(1, 0), dtype=dtype) + check_tensor_array(mod, expected, *(5, 1), dtype=dtype) + + run("float32") + run("int32") + + +@tvm.testing.uses_gpu +def test_tensor_array_write(): + def run(dtype): + mod = tvm.IRModule() + p = Prelude(mod) + tensor_t = p.get_type("tensor_t", dtype) + v1 = relay.var("v1") + v2 = relay.var("v2") + tensor_array = p.get_global_var("tensor_array", dtype) + init_tensor_array = tensor_array(relay.const(2)) + write_func = p.get_global_var("tensor_array_write", dtype) + tensor1 = p.get_tensor_ctor("tensor1", dtype) + tensor_array1 = write_func(init_tensor_array, relay.const(0), tensor1(v1)) + tensor_array2 = write_func(tensor_array1, relay.const(1), tensor1(v2)) + mod["main"] = relay.Function([v1, v2], tensor_array2) + expected = [3, 7] + check_tensor_array(mod, expected, *(3, 7), dtype=dtype) + + run("float32") + run("int32") + + +@tvm.testing.uses_gpu +def test_tensor_array_stack(): + def run(dtype): + mod = tvm.IRModule() + p = Prelude(mod) + tensor_t = p.get_type("tensor_t", dtype) + rlist = p.mod.get_global_type_var(f"List") + tensor_array = p.get_global_var("tensor_array", dtype) + tensor1 = p.get_tensor_ctor("tensor1", dtype) + write = p.get_global_var("tensor_array_write", dtype) + stack = p.get_global_var("tensor_array_stack", dtype) + # TODO extract test case from inference failures + # setting this wrong causes crashes + v = relay.var("v", shape=(1,), dtype=dtype) + init_tensor_array = tensor_array(relay.const(3)) + tensor_array1 = write(init_tensor_array, relay.const(0), tensor1(v)) + tensor_array2 = write(tensor_array1, relay.const(1), tensor1(v)) + tensor_array3 = write(tensor_array2, relay.const(2), tensor1(v)) + tensor_array4 = stack(tensor_array3) + mod["main"] = relay.Function([v], tensor_array4, tensor_t()) + t = np.random.uniform(low=0.0, high=8.0, size=(1,)).astype(dtype) + expected = [np.stack([t, t, t])] + check_tensor_array(mod, expected, t, dtype=dtype) + + run("float32") + run("int32") + + +@tvm.testing.uses_gpu +def test_tensor_array_unstack(): + def run(dtype): + mod = tvm.IRModule() + p = Prelude(mod) + unstack_tensor1 = p.get_global_var("tensor_array_unstack_tensor1", dtype) + v = relay.var("v") + mod["main"] = relay.Function([v], unstack_tensor1(v)) + t = np.random.uniform(low=0.0, high=8.0, size=(1,)).astype(dtype) + check_tensor_array(mod, t, t, dtype=dtype) + + run("float32") + run("int32") + + +@tvm.testing.uses_gpu +def test_tensor_take(): + def run(dtype): + mod = tvm.IRModule() + p = Prelude(mod) + take = p.get_global_var("tensor_take", dtype) + tensor2 = p.get_tensor_ctor("tensor2", dtype) + v = relay.var("v") + lower = relay.var("lower") + upper = relay.var("upper") + mod["main"] = relay.Function([v, lower, upper], take(tensor2(v), lower, upper)) + v_data = np.random.uniform(low=0.0, high=8.0, size=(10, 10)).astype(dtype) + expected = [np.take(v_data, range(2, 5), axis=0)] + check_tensor_array(mod, expected, *(v_data, 2, 5), dtype=dtype) + expected = [np.take(v_data, range(0, 9), axis=0)] + check_tensor_array(mod, expected, *(v_data, 0, 9), dtype=dtype) + + run("float32") + run("int32") + + +@tvm.testing.uses_gpu +def test_tensor_concatenate(): + def run(dtype): + mod = tvm.IRModule() + p = Prelude(mod) + concat = p.get_global_var("tensor_concatenate", dtype) + tensor1 = p.get_tensor_ctor("tensor1", dtype) + v1 = relay.var("v1", shape=(tvm.tir.Any(),), dtype=dtype) + v2 = relay.var("v2", shape=(tvm.tir.Any(),), dtype=dtype) + mod["main"] = relay.Function([v1, v2], concat(tensor1(v1), tensor1(v2))) + v1_data = np.random.uniform(low=0.0, high=8.0, size=(5,)).astype(dtype) + v2_data = np.random.uniform(low=0.0, high=8.0, size=(5,)).astype(dtype) + expected = [np.concatenate((v1_data, v2_data))] + check_tensor_array(mod, expected, *(v1_data, v2_data), dtype=dtype) + + run("float32") + run("int32") + + +@tvm.testing.uses_gpu +def test_tensor_array_concat(): + def run(dtype): + mod = tvm.IRModule() + p = Prelude(mod) + v1 = relay.var("v1") + v2 = relay.var("v2") + tensor_array = p.get_global_var("tensor_array", dtype) + tensor_array1 = tensor_array(relay.const(2)) + write_func = p.get_global_var("tensor_array_write", dtype) + concat_func = p.get_global_var("tensor_array_concat", dtype) + tensor1 = p.get_tensor_ctor("tensor2", dtype) + tensor_array1 = write_func(tensor_array1, relay.const(0), tensor1(v1)) + tensor_array1 = write_func(tensor_array1, relay.const(1), tensor1(v2)) + tensor_array_concat = concat_func(tensor_array1) + mod["main"] = relay.Function([v1, v2], tensor_array_concat) + v1_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) + v2_data = np.random.uniform(low=0.0, high=8.0, size=(1, 3)).astype(dtype) + expected = [np.concatenate((v1_data, v2_data), axis=0)] + check_tensor_array(mod, expected, *(v1_data, v2_data), dtype=dtype) + + run("float32") + run("int32") + + +@tvm.testing.uses_gpu +def test_tensor_array_scatter(): + def run(dtype): + mod = tvm.IRModule() + p = Prelude(mod) + + # tensor array + v1 = relay.var("v1") + v2 = relay.var("v2") + v3 = relay.var("v2") + tensor_array = p.get_global_var("tensor_array", dtype) + tensor_array1 = tensor_array(relay.const(3)) + write_func = p.get_global_var("tensor_array_write", dtype) + scatter_func = p.get_global_var("tensor_array_scatter", dtype) + tensor2 = p.get_tensor_ctor("tensor2", dtype) + tensor_array1 = write_func(tensor_array1, relay.const(0), tensor2(v1)) + tensor_array1 = write_func(tensor_array1, relay.const(1), tensor2(v2)) + tensor_array1 = write_func(tensor_array1, relay.const(2), tensor2(v3)) + + # indices array + index = relay.var("index") + + # values array + value_0 = relay.var("value_0") + value_1 = relay.var("value_1") + values_array = tensor_array(relay.const(2)) + values_array = write_func(values_array, relay.const(0), tensor2(value_0)) + values_array = write_func(values_array, relay.const(1), tensor2(value_1)) + + # create the scatter function + tensor_array_scatter = scatter_func(tensor_array1, index, values_array) + mod["main"] = relay.Function([v1, v2, v3, index, value_0, value_1], tensor_array_scatter) + + # initialize and check + v1_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) + v2_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) + v3_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) + index_data = np.array([0, 1], dtype="int32") + val1_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) + val2_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) + expected = [val1_data, val2_data, v3_data] + check_tensor_array( + mod, + expected, + *(v1_data, v2_data, v3_data, index_data, val1_data, val2_data), + dtype=dtype, + ) + + run("float32") + run("int32") + + +@tvm.testing.uses_gpu +def test_tensor_array_split(): + def run(dtype): + mod = tvm.IRModule() + p = Prelude(mod) + + # tensor array + v1 = relay.var("v1") + v2 = relay.var("v2") + v3 = relay.var("v2") + tensor_array = p.get_global_var("tensor_array", dtype) + tensor_array1 = tensor_array(relay.const(3)) + write_func = p.get_global_var("tensor_array_write", dtype) + split_func = p.get_global_var("tensor_array_split", dtype) + tensor2 = p.get_tensor_ctor("tensor2", dtype) + tensor_array1 = write_func(tensor_array1, relay.const(0), tensor2(v1)) + tensor_array1 = write_func(tensor_array1, relay.const(1), tensor2(v2)) + tensor_array1 = write_func(tensor_array1, relay.const(2), tensor2(v3)) + + # value tensor + value = relay.var("value") + + # lengths tensor + ta_len = relay.var("length") + + # create the scatter function + tensor_array_split = split_func(tensor_array1, tensor2(value), ta_len) + mod["main"] = relay.Function([v1, v2, v3, value, ta_len], tensor_array_split) + + # initialize and check + v1_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) + v2_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) + v3_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) + value_data = np.random.uniform(low=0.0, high=8.0, size=(4, 3)).astype(dtype) + length_data = np.array([2, 2], dtype="int32") + expected = np.concatenate([value_data, v3_data]) + expected = np.split(expected, indices_or_sections=[2, 4]) + check_tensor_array( + mod, expected, *(v1_data, v2_data, v3_data, value_data, length_data), dtype=dtype + ) + + run("float32") + run("int32") + + +@tvm.testing.uses_gpu +def test_static_tensor_take(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + take = p.get_global_var_static("tensor_take", dtype, shape) + tensor_constructor = p.get_tensor_ctor_static("tensor_constructor", dtype, shape) + v = relay.var("v") + lower = relay.var("lower") + upper = relay.var("upper") + mod["main"] = relay.Function([v, lower, upper], take(tensor_constructor(v), lower, upper)) + v_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + expected = [np.take(v_data, range(2, 5), axis=0)] + check_tensor_array(mod, expected, *(v_data, 2, 5), dtype=dtype) + expected = [np.take(v_data, range(0, 9), axis=0)] + check_tensor_array(mod, expected, *(v_data, 0, 9), dtype=dtype) + + run("float32", [10, 10]) + run("int32", [15, 11]) + + +@tvm.testing.uses_gpu +def test_static_tensor_concatenate(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + concat = p.get_global_var_static("tensor_concatenate", dtype, shape) + tensor = p.get_tensor_ctor_static("tensor_constructor", dtype, shape) + v1 = relay.var("v1") + v2 = relay.var("v2") + mod["main"] = relay.Function([v1, v2], concat(tensor(v1), tensor(v2))) + v1_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + v2_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + expected = [np.concatenate((v1_data, v2_data))] + check_tensor_array(mod, expected, *(v1_data, v2_data), dtype=dtype) + + run( + "float32", + [ + 5, + ], + ) + run("int32", [2, 3]) + + +@tvm.testing.uses_gpu +def test_static_tensor_expand_dims(): + def run(dtype, shape): + x = relay.var("x") + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + expand_dims_func = p.get_global_var_static("tensor_expand_dims", dtype, shape) + tensor = p.get_tensor_ctor_static("tensor_constructor", dtype, shape) + mod["main"] = relay.Function([x], expand_dims_func(tensor(x))) + x_np = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + expected = [np.expand_dims(x_np, axis=0)] + check_tensor_array(mod, expected, x_np) + + run("float32", []) + run( + "int32", + [ + 2, + ], + ) + + +@tvm.testing.uses_gpu +def test_static_tensor_array_constructor(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + tensor_constructor = p.get_name_static("tensor_constructor", dtype, shape) + assert tensor_constructor != None + + run("float32", [1, 1]) + + +@tvm.testing.uses_gpu +def test_static_tensor_array_read(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + np_data_list = [] + ta_length = 3 + for _ in range(ta_length): + np_data_list.append(np.random.uniform(0, 10, size=shape).astype(dtype)) + + v0 = relay.var("v0") + v1 = relay.var("v1") + v2 = relay.var("v2") + n = relay.var("n") + tensor = p.get_tensor_ctor_static("tensor_constructor", dtype, shape) + tensor_array = p.get_global_var_static("tensor_array", dtype, shape) + init_tensor_array = tensor_array(relay.const(ta_length)) + read_func = p.get_global_var_static("tensor_array_read", dtype, shape) + write_func = p.get_global_var_static("tensor_array_write", dtype, shape) + tensor_array0 = write_func(init_tensor_array, relay.const(0), tensor(v0)) + tensor_array1 = write_func(tensor_array0, relay.const(1), tensor(v1)) + tensor_array2 = write_func(tensor_array1, relay.const(2), tensor(v2)) + + mod["main"] = relay.Function([v0, v1, v2, n], read_func(tensor_array2, n)) + expected = [np_data_list[0]] + check_tensor_array(mod, expected, *list(np_data_list + [0]), dtype=dtype) + expected = [np_data_list[1]] + check_tensor_array(mod, expected, *list(np_data_list + [1]), dtype=dtype) + expected = [np_data_list[2]] + check_tensor_array(mod, expected, *list(np_data_list + [2]), dtype=dtype) + + run("float32", []) + run("int32", [2, 3]) + + +@tvm.testing.uses_gpu +def test_static_tensor_array_write(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + ta_length = 2 + np_data_list = [ + np.random.uniform(0, 10, size=shape).astype(dtype) for _ in range(ta_length) + ] + + v0 = relay.var("v0") + v1 = relay.var("v1") + tensor_array = p.get_global_var_static("tensor_array", dtype, shape) + init_tensor_array = tensor_array(relay.const(ta_length)) + write_func = p.get_global_var_static("tensor_array_write", dtype, shape) + tensor = p.get_tensor_ctor_static("tensor_constructor", dtype, shape) + tensor_array0 = write_func(init_tensor_array, relay.const(0), tensor(v0)) + tensor_array1 = write_func(tensor_array0, relay.const(1), tensor(v1)) + mod["main"] = relay.Function([v0, v1], tensor_array1) + expected = np_data_list + check_tensor_array(mod, expected, *np_data_list, dtype=dtype) + + run("float32", []) + run("int32", [2, 3]) + + +@tvm.testing.uses_gpu +def test_static_tensor_array_unstack(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + unstack_tensor = p.get_global_var_static("tensor_array_unstack", dtype, shape) + v = relay.var("v") + mod["main"] = relay.Function([v], unstack_tensor(v)) + t = np.random.uniform(low=0, high=10, size=shape).astype(dtype) + (*expected,) = t + check_tensor_array(mod, expected, t, dtype=dtype) + + run("float32", [4]) + run("int32", [2, 3]) + + +@tvm.testing.uses_gpu +def test_static_tensor_array_scatter(): + def run(dtype, shape, indices_shape=None): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + if indices_shape is not None: + static_tensor_array_ops.define_tensor_array_scatter(indices_shape, True) + + # tensor array + v1 = relay.var("v1") + v2 = relay.var("v2") + v3 = relay.var("v2") + tensor_array = p.get_global_var_static("tensor_array", dtype, shape) + tensor_array0 = tensor_array(relay.const(3)) + write_func = p.get_global_var_static("tensor_array_write", dtype, shape) + scatter_func = p.get_global_var_static("tensor_array_scatter", dtype, shape) + tensor = p.get_tensor_ctor_static("tensor_constructor", dtype, shape) + tensor_array1 = write_func(tensor_array0, relay.const(0), tensor(v1)) + tensor_array1 = write_func(tensor_array1, relay.const(1), tensor(v2)) + tensor_array1 = write_func(tensor_array1, relay.const(2), tensor(v3)) + + # indices array + index = relay.var("index") + + # values array + value_0 = relay.var("value_0") + value_1 = relay.var("value_1") + values_array = tensor_array(relay.const(2)) + values_array = write_func(values_array, relay.const(0), tensor(value_0)) + values_array = write_func(values_array, relay.const(1), tensor(value_1)) + + # create the scatter function + tensor_array_scatter = scatter_func(tensor_array1, index, values_array) + mod["main"] = relay.Function([v1, v2, v3, index, value_0, value_1], tensor_array_scatter) + + # initialize and check + v1_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + v2_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + v3_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + index_data = np.array([0, 1], dtype="int32") + val1_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + val2_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + expected = [val1_data, val2_data, v3_data] + check_tensor_array( + mod, + expected, + *(v1_data, v2_data, v3_data, index_data, val1_data, val2_data), + dtype=dtype, + ) + + run("float32", [2, 3]) + run("int32", [2, 3]) + run( + "float32", + [2, 3], + [ + 2, + ], + ) + + +@tvm.testing.uses_gpu +def test_static_tensor_array_split(): + def run(dtype, shape, value_shape=None, lengths_shape=None): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + if value_shape is not None or lengths_shape is not None: + static_tensor_array_ops.define_tensor_array_split(value_shape, lengths_shape, False) + + # tensor array + v1 = relay.var("v1") + v2 = relay.var("v2") + v3 = relay.var("v2") + + adt_shape = [ + relay.Any(), + ] + shape[1:] + test_ops = StaticTensorArrayOps(p, dtype, adt_shape) + test_ops.register() + tensor_array = test_ops.get_global_var("tensor_array") + + tensor_array1 = tensor_array(relay.const(3)) + write_func = test_ops.get_global_var("tensor_array_write") + split_ops = StaticTensorArrayOps(p, dtype, shape) + split_ops.register() + split_func = split_ops.get_global_var("tensor_array_split") + tensor = p.get_tensor_ctor_static("tensor_constructor", dtype, test_ops.shape) + tensor_array1 = write_func(tensor_array1, relay.const(0), tensor(v1)) + tensor_array1 = write_func(tensor_array1, relay.const(1), tensor(v2)) + tensor_array1 = write_func(tensor_array1, relay.const(2), tensor(v3)) + + # value tensor + value = relay.var("value") + + # lengths tensor + ta_len = relay.var("length") + + # create the split function + if value_shape is None: + tensor1 = p.get_tensor_ctor_static("tensor_constructor", dtype, shape) + else: + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, value_shape) + static_tensor_array_ops.register() + tensor1 = p.get_tensor_ctor_static("tensor_constructor", dtype, test_ops.shape) + + tensor_array_split = split_func(tensor_array1, tensor1(value), ta_len) + mod["main"] = relay.Function([v1, v2, v3, value, ta_len], tensor_array_split) + + # initialize and check + v1_data = np.random.uniform(low=0.0, high=8.0, size=[2, 3]).astype(dtype) + v2_data = np.random.uniform(low=0.0, high=8.0, size=[2, 3]).astype(dtype) + v3_data = np.random.uniform(low=0.0, high=8.0, size=[2, 3]).astype(dtype) + value_data = np.random.uniform(low=0.0, high=8.0, size=value_shape or shape).astype(dtype) + length_data = np.array([2, 2], dtype="int32") + expected = np.concatenate([value_data, v3_data]) + expected = np.split(expected, indices_or_sections=[2, 4]) + check_tensor_array( + mod, expected, *(v1_data, v2_data, v3_data, value_data, length_data), dtype=dtype + ) + + run("float32", [4, 3]) + run("int32", [4, 3]) + run( + "int32", + [relay.Any(), 3], + [4, 3], + [ + 2, + ], + ) + + +@tvm.testing.uses_gpu +def test_static_tensor_array_concat(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + v1 = relay.var("v1") + v2 = relay.var("v2") + tensor_array = p.get_global_var_static("tensor_array", dtype, shape) + tensor_array1 = tensor_array(relay.const(2)) + write_func = p.get_global_var_static("tensor_array_write", dtype, shape) + concat_func = p.get_global_var_static("tensor_array_concat", dtype, shape) + tensor = p.get_tensor_ctor_static("tensor_constructor", dtype, shape) + tensor_array1 = write_func(tensor_array1, relay.const(0), tensor(v1)) + tensor_array1 = write_func(tensor_array1, relay.const(1), tensor(v2)) + tensor_array_concat = concat_func(tensor_array1) + mod["main"] = relay.Function([v1, v2], tensor_array_concat) + v1_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) + v2_data = np.random.uniform(low=0.0, high=8.0, size=(1, 3)).astype(dtype) + expected = [np.concatenate((v1_data, v2_data), axis=0)] + check_tensor_array(mod, expected, *(v1_data, v2_data), dtype=dtype) + + run("float32", [relay.Any(), 3]) + run("int32", [relay.Any(), 3]) + + +@tvm.testing.uses_gpu +def test_static_tensor_array_gather(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + tensor_array = p.get_global_var_static("tensor_array", dtype, shape) + tensor = p.get_tensor_ctor_static("tensor_constructor", dtype, shape) + write = p.get_global_var_static("tensor_array_write", dtype, shape) + gather = p.get_global_var_static("tensor_array_gather", dtype, shape) + v = relay.var("v") + indice = relay.var("indice") + init_tensor_array = tensor_array(relay.const(3)) + tensor_array1 = write(init_tensor_array, relay.const(0), tensor(v)) + tensor_array2 = write(tensor_array1, relay.const(1), tensor(v)) + tensor_array3 = write(tensor_array2, relay.const(2), tensor(v)) + out = gather(tensor_array3, indice) + mod["main"] = relay.Function([v, indice], out) + t = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + indice_data = np.array([0, 2], dtype="int32") + expected = [np.stack([t, t])] + check_tensor_array(mod, expected, *(t, indice_data), dtype=dtype) + + run("float32", []) + run("int32", [2, 3]) + + +@tvm.testing.uses_gpu +def test_static_tensor_array_stack(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + tensor_array = p.get_global_var_static("tensor_array", dtype, shape) + tensor = p.get_tensor_ctor_static("tensor_constructor", dtype, shape) + write = p.get_global_var_static("tensor_array_write", dtype, shape) + stack = p.get_global_var_static("tensor_array_stack", dtype, shape) + v = relay.var("v") + init_tensor_array = tensor_array(relay.const(3)) + tensor_array1 = write(init_tensor_array, relay.const(0), tensor(v)) + tensor_array2 = write(tensor_array1, relay.const(1), tensor(v)) + tensor_array3 = write(tensor_array2, relay.const(2), tensor(v)) + tensor_array4 = stack(tensor_array3) + mod["main"] = relay.Function([v], tensor_array4) + t = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + expected = [np.stack([t, t, t])] + check_tensor_array(mod, expected, t, dtype=dtype) + + run("float32", []) + run("int32", [2, 3]) + + +@tvm.testing.uses_gpu +def test_static_tensor_get_data(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + np_data_list = [] + ta_length = 3 + for _ in range(ta_length): + np_data_list.append(np.random.uniform(0, 10, size=shape).astype(dtype)) + + v0 = relay.var("v0") + v1 = relay.var("v1") + v2 = relay.var("v2") + n = relay.var("n") + tensor = p.get_tensor_ctor_static("tensor_constructor", dtype, shape) + tensor_array = p.get_global_var_static("tensor_array", dtype, shape) + init_tensor_array = tensor_array(relay.const(ta_length)) + read_func = p.get_global_var_static("tensor_array_read", dtype, shape) + write_func = p.get_global_var_static("tensor_array_write", dtype, shape) + get_data_func = p.get_global_var_static("tensor_get_data", dtype, shape) + tensor_array0 = write_func(init_tensor_array, relay.const(0), tensor(v0)) + tensor_array1 = write_func(tensor_array0, relay.const(1), tensor(v1)) + tensor_array2 = write_func(tensor_array1, relay.const(2), tensor(v2)) + + mod["main"] = relay.Function([v0, v1, v2, n], get_data_func(read_func(tensor_array2, n))) + expected = [np_data_list[0]] + check_tensor_array(mod, expected, *list(np_data_list + [0]), dtype=dtype) + expected = [np_data_list[1]] + check_tensor_array(mod, expected, *list(np_data_list + [1]), dtype=dtype) + expected = [np_data_list[2]] + check_tensor_array(mod, expected, *list(np_data_list + [2]), dtype=dtype) + + run("float32", []) + run("int32", [2, 3]) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 455c8ce7e7bb4..6758d96773a26 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -19,57 +19,58 @@ """ import pytest import tvm -from tvm import te -from tvm import relay + +from tvm import IRModule, te, relay, parser from tvm.relay import op, transform, analysis from tvm.relay import Any -def run_infer_type(expr, mod=None): +def infer_mod(mod, annotate_spans=True): + if annotate_spans: + mod = relay.transform.AnnotateSpans()(mod) + + mod = transform.InferType()(mod) + return mod + + +def infer_expr(expr, annotate_spans=True): + mod = IRModule.from_expr(expr) + mod = infer_mod(mod, annotate_spans) + mod = transform.InferType()(mod) + entry = mod["main"] + return entry if isinstance(expr, relay.Function) else entry.body + + +def assert_has_type(expr, typ, mod=None): if not mod: - mod = tvm.IRModule.from_expr(expr) - mod = transform.InferType()(mod) - entry = mod["main"] - return entry if isinstance(expr, relay.Function) else entry.body - else: - if isinstance(expr, relay.GlobalVar): - gv = expr.name_hint - else: - func = expr - if not isinstance(expr, relay.Function): - func = relay.Function(analysis.free_vars(expr), expr) - mod["main"] = func - gv = "main" - mod = transform.InferType()(mod) - - if isinstance(expr, (relay.GlobalVar, relay.Function)): - return mod[gv] - return mod[gv].body - - -def assert_has_type(expr, typ, mod=tvm.IRModule({})): - checked_expr = run_infer_type(expr, mod) + mod = tvm.IRModule({}) + + mod["main"] = expr + mod = infer_mod(mod) + checked_expr = mod["main"] checked_type = checked_expr.checked_type if checked_type != typ: raise RuntimeError("Type mismatch %s vs %s" % (checked_type, typ)) -# initializes simple ADT for tests def initialize_box_adt(mod): + # initializes simple ADT for tests box = relay.GlobalTypeVar("box") tv = relay.TypeVar("tv") constructor = relay.Constructor("constructor", [tv], box) data = relay.TypeData(box, [tv], [constructor]) mod[box] = data - return (box, constructor) + return box, constructor def test_monomorphic_let(): "Program: let %x = 1; %x" + # TODO(@jroesch): this seems whack. sb = relay.ScopeBuilder() + x = relay.var("x", dtype="float64", shape=()) x = sb.let("x", relay.const(1.0, "float64")) sb.ret(x) - xchecked = run_infer_type(sb.get()) + xchecked = infer_expr(sb.get()) assert xchecked.checked_type == relay.scalar_type("float64") @@ -115,7 +116,7 @@ def test_dual_op(): t2 = sb.let("t2", relay.add(t1, x)) sb.ret(t2) f = relay.Function([x], sb.get()) - fchecked = run_infer_type(f) + fchecked = infer_expr(f) assert fchecked.checked_type == relay.FuncType([tp], tp) @@ -128,7 +129,7 @@ def @f(%x : Tensor[(10, 10), float32]) { tp = relay.TensorType((10, 10)) x = relay.var("x", tp) f = relay.Function([x], relay.log(x)) - fchecked = run_infer_type(f) + fchecked = infer_expr(f) assert fchecked.checked_type == relay.FuncType([tp], tp) @@ -156,8 +157,9 @@ def @f(%n: int32, %data: float32) -> float32 { sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data))) mod = tvm.IRModule() mod[f] = relay.Function([n, data], sb.get()) - assert "@f(%1, %2) /* ty=float32 */" in mod.astext() - assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32) + mod = infer_mod(mod) + assert "@f(%1, %2)" in mod.astext() + assert mod["f"].checked_type == relay.FuncType([ti32, tf32], tf32) def test_incomplete_call(): @@ -166,7 +168,7 @@ def test_incomplete_call(): f = relay.var("f") func = relay.Function([x, f], relay.Call(f, [x]), tt) - ft = run_infer_type(func) + ft = infer_expr(func) f_type = relay.FuncType([tt], tt) assert ft.checked_type == relay.FuncType([tt, f_type], tt) @@ -185,7 +187,7 @@ def test_higher_order_argument(): # function even though id_func takes a type parameter ho_call = ho_func(id_func, relay.const(0, "int32")) - hc = run_infer_type(ho_call) + hc = infer_expr(ho_call) expected = relay.scalar_type("int32") assert hc.checked_type == expected @@ -198,7 +200,7 @@ def test_higher_order_return(): b = relay.TypeVar("b") nested_id = relay.Function([], id_func, relay.FuncType([b], b), [b]) - ft = run_infer_type(nested_id) + ft = infer_expr(nested_id) assert ft.checked_type == relay.FuncType([], relay.FuncType([b], b), [b]) @@ -217,7 +219,7 @@ def test_higher_order_nested(): ) expected = relay.FuncType([choice_t], relay.FuncType([b], b), [b]) - ft = run_infer_type(top) + ft = infer_expr(top) assert ft.checked_type == expected @@ -225,7 +227,7 @@ def test_tuple(): tp = relay.TensorType((10,)) x = relay.var("x", tp) res = relay.Tuple([x, x]) - assert run_infer_type(res).checked_type == relay.TupleType([tp, tp]) + assert infer_expr(res).checked_type == relay.TupleType([tp, tp]) def test_ref(): @@ -233,18 +235,18 @@ def test_ref(): y = relay.var("y", "float32") r = relay.RefCreate(x) st = relay.scalar_type("float32") - assert run_infer_type(r).checked_type == relay.RefType(st) + assert infer_expr(r).checked_type == relay.RefType(st) g = relay.RefRead(r) - assert run_infer_type(g).checked_type == st + assert infer_expr(g).checked_type == st w = relay.RefWrite(r, y) - assert run_infer_type(w).checked_type == relay.TupleType([]) + assert infer_expr(w).checked_type == relay.TupleType([]) def test_free_expr(): - return x = relay.var("x", "float32") y = relay.add(x, x) - yy = run_infer_type(y) + yy = infer_expr(y, annotate_spans=False) + assert tvm.ir.structural_equal(yy.args[0], x, map_free_vars=True) assert yy.checked_type == relay.scalar_type("float32") assert x.vid.same_as(yy.args[0].vid) @@ -253,7 +255,7 @@ def test_type_args(): x = relay.var("x", shape=(10, 10)) y = relay.var("y", shape=(1, 10)) z = relay.add(x, y) - ty_z = run_infer_type(z) + ty_z = infer_expr(z) ty_args = ty_z.type_args assert len(ty_args) == 2 assert ty_args[0].dtype == "float32" @@ -274,16 +276,19 @@ def test_global_var_recursion(): func = relay.Function([x], relay.Call(gv, [x]), tt) mod[gv] = func + mod = infer_mod(mod) + func_ty = mod["main"].checked_type - ft = run_infer_type(gv, mod) - assert ft.checked_type == relay.FuncType([tt], tt) + assert func_ty == relay.FuncType([tt], tt) def test_equal(): i = relay.var("i", shape=[], dtype="int32") eq = op.equal(i, relay.const(0, dtype="int32")) func = relay.Function([i], eq) - ft = run_infer_type(func) + ft = infer_expr(func) + expected = relay.FuncType([relay.scalar_type("int32")], relay.scalar_type("bool")) + assert ft.checked_type == expected assert ft.checked_type == relay.FuncType( [relay.scalar_type("int32")], relay.scalar_type("bool") @@ -296,9 +301,13 @@ def test_constructor_type(): a = relay.TypeVar("a") x = relay.Var("x", a) - ct = run_infer_type(relay.Function([x], constructor(x), box(a), [a]), mod) + func = relay.Function([x], constructor(x), box(a), [a]) + mod["main"] = func + mod = infer_mod(mod) + func_ty = mod["main"].checked_type + box = mod.get_global_type_var("box") expected = relay.FuncType([a], box(a), [a]) - assert ct.checked_type == expected + assert func_ty == expected def test_constructor_call(): @@ -308,10 +317,17 @@ def test_constructor_call(): box_unit = constructor(relay.Tuple([])) box_constant = constructor(relay.const(0, "float32")) - ut = run_infer_type(box_unit, mod) - ct = run_infer_type(box_constant, mod) - assert ut.checked_type == box(relay.TupleType([])) - assert ct.checked_type == box(relay.TensorType((), "float32")) + func = relay.Function([], relay.Tuple([box_unit, box_constant])) + mod["main"] = func + mod = infer_mod(mod) + ret_type = mod["main"].checked_type.ret_type.fields + # NB(@jroesch): when we annotate spans the ast fragments before + # annotation the previous fragments will no longer be directly equal. + box = mod.get_global_type_var("box") + expected1 = box(relay.TupleType([])) + expected2 = box(relay.TensorType((), "float32")) + assert ret_type[0] == expected1 + assert ret_type[1] == expected2 def test_adt_match(): @@ -330,8 +346,11 @@ def test_adt_match(): ], ) - mt = run_infer_type(match, mod) - assert mt.checked_type == relay.TupleType([]) + func = relay.Function([], match) + mod["main"] = func + mod = infer_mod(mod) + actual = mod["main"].checked_type.ret_type + assert actual == relay.TupleType([]) def test_adt_match_type_annotations(): @@ -352,9 +371,10 @@ def test_adt_match_type_annotations(): ], ) - func = relay.Function([x], match) - ft = run_infer_type(func, mod) - assert ft.checked_type == relay.FuncType([tt], relay.TupleType([])) + mod["main"] = relay.Function([x], match) + mod = infer_mod(mod) + ft = mod["main"].checked_type + assert ft == relay.FuncType([tt], relay.TupleType([])) def test_let_polymorphism(): @@ -363,7 +383,7 @@ def test_let_polymorphism(): x = relay.Var("x", xt) body = relay.Tuple([id(relay.const(1)), id(relay.Tuple([]))]) body = relay.Let(id, relay.Function([x], x, xt, [xt]), body) - body = run_infer_type(body) + body = infer_expr(body) int32 = relay.TensorType((), "int32") tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])])) @@ -374,7 +394,7 @@ def test_if(): true_branch = relay.Var("True", relay.TensorType([Any(), 1], dtype="float32")) false_branch = relay.Var("False", relay.TensorType([Any(), Any()], dtype="float32")) top = relay.Function([f, true_branch, false_branch], relay.If(f(), true_branch, false_branch)) - ft = run_infer_type(top) + ft = infer_expr(top) tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype="float32")) @@ -394,4 +414,6 @@ def @main(%f: float32) -> float32 { if __name__ == "__main__": - pytest.main([__file__]) + import sys + + pytest.main(sys.argv) diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 0ee6accdc1f4d..038b5c5ed9e15 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -235,6 +235,7 @@ def test_sum_loop(): sb.ret(relay.Call(sum_up, [one_less, new_accum])) func = relay.Function([i, accum], sb.get()) mod[sum_up] = func + mod = relay.transform.InferType()(mod) loop_bound = 0 i_data = np.array(loop_bound, dtype="int32") accum_data = np.array(0, dtype="int32") @@ -273,9 +274,7 @@ def test_list_constructor(): mod = tvm.IRModule() p = Prelude(mod) - nil = p.nil - cons = p.cons - l = p.l + l, cons, nil = mod.get_type("List") one2 = cons(relay.const(1), nil()) one3 = cons(relay.const(2), one2) @@ -372,10 +371,8 @@ def test_list_hd(): mod = tvm.IRModule() p = Prelude(mod) - nil = p.nil - cons = p.cons - l = p.l - hd = p.hd + l, cons, nil = mod.get_type("List") + hd = mod.get_global_var("hd") one2 = cons(relay.const(1), nil()) one3 = cons(relay.const(2), one2) @@ -395,9 +392,8 @@ def test_list_tl_empty_list(): mod = tvm.IRModule() p = Prelude(mod) - nil = p.nil - l = p.l - tl = p.tl + l, cons, nil = mod.get_type("List") + tl = mod.get_global_var("tl") f = relay.Function([], tl(nil())) @@ -412,10 +408,8 @@ def test_list_tl(): mod = tvm.IRModule() p = Prelude(mod) - nil = p.nil - cons = p.cons - l = p.l - tl = p.tl + l, cons, nil = mod.get_type("List") + tl = mod.get_global_var("tl") one2 = cons(relay.const(1), nil()) one3 = cons(relay.const(2), one2) @@ -438,9 +432,9 @@ def test_list_nth(): mod = tvm.IRModule() p = Prelude(mod) - nil = p.nil - cons = p.cons - nth = p.nth + _, cons, nil = mod.get_type("List") + nth = mod.get_global_var("nth") + l = nil() for i in reversed(expected): l = cons(relay.const(i), l) @@ -459,9 +453,8 @@ def test_list_update(): mod = tvm.IRModule() p = Prelude(mod) - nil = p.nil - cons = p.cons - update = p.update + _, cons, nil = mod.get_type("List") + update = mod.get_global_var("update") l = nil() # create zero initialized list @@ -486,13 +479,12 @@ def test_list_length(): mod = tvm.IRModule() p = Prelude(mod) - nil = p.nil - cons = p.cons - length = p.length + _, cons, nil = mod.get_type("List") + length = mod.get_global_var("length") l = nil() # create zero initialized list - for i in range(len(expected)): + for _ in range(len(expected)): l = cons(relay.const(0), l) l = length(l) @@ -512,9 +504,8 @@ def test_list_map(): x = relay.var("x", "int32") add_one_func = relay.Function([x], relay.const(1) + x) - nil = p.nil - cons = p.cons - map = p.map + _, cons, nil = mod.get_type("List") + map = mod.get_global_var("map") l = cons(relay.const(2), cons(relay.const(1), nil())) @@ -530,9 +521,8 @@ def test_list_foldl(): mod = tvm.IRModule() p = Prelude(mod) - nil = p.nil - cons = p.cons - foldl = p.foldl + _, cons, nil = mod.get_type("List") + foldl = mod.get_global_var("foldl") x = relay.var("x") y = relay.var("y") @@ -551,9 +541,8 @@ def test_list_foldr(): mod = tvm.IRModule() p = Prelude(mod) - nil = p.nil - cons = p.cons - foldr = p.foldr + _, cons, nil = mod.get_type("List") + foldr = mod.get_global_var("foldr") x = relay.var("x") y = relay.var("y") @@ -572,9 +561,8 @@ def test_list_sum(): mod = tvm.IRModule() p = Prelude(mod) - nil = p.nil - cons = p.cons - sum = p.sum + _, cons, nil = mod.get_type("List") + sum = mod.get_global_var("sum") l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil()))) f = relay.Function([], sum(l)) @@ -589,9 +577,8 @@ def test_list_filter(): mod = tvm.IRModule() p = Prelude(mod) - nil = p.nil - cons = p.cons - filter = p.filter + _, cons, nil = mod.get_type("List") + filter = mod.get_global_var("filter") x = relay.var("x", "int32") greater_than_one = relay.Function([x], x > relay.const(1)) diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index 4be2fe9037c4f..14f003e3500b8 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -25,6 +25,7 @@ from tvm import relay from tvm.relay.scope_builder import ScopeBuilder +from tvm.relay import transform from tvm.relay.prelude import Prelude from tvm.contrib import util from tvm.relay import testing @@ -57,7 +58,6 @@ def get_vm_output(mod, data, params, target, ctx, dtype="float32"): result = ex.evaluate()(data, **params) return result.asnumpy().astype(dtype) - print(mod["main"]) data_shape = [int(x) for x in mod["main"].checked_type.arg_types[0].shape] data = np.random.uniform(size=data_shape).astype(dtype) target = "llvm" @@ -78,12 +78,18 @@ def test_serializer(): glb_f1 = relay.GlobalVar("f1") mod[glb_f1] = f1 + # TODO(@jroesch): look into optimizing away the need to do this + mod = transform.InferType()(mod) + b = relay.const(2.0, "float32") y = relay.var("y", shape=(10, 10), dtype="float32") f2 = relay.Function([y], y - b) glb_f2 = relay.GlobalVar("f2") mod[glb_f2] = f2 + # TODO(@jroesch): look into optimizing away the need to do this + mod = transform.InferType()(mod) + x1 = relay.var("x1", shape=(10, 10), dtype="float32") y1 = relay.var("y1", shape=(10, 10), dtype="float32") main = relay.Function([x1, y1], glb_f1(x1) * glb_f2(y1)) @@ -181,6 +187,7 @@ def test_loop(): sb.ret(relay.Call(sum_up, [one_less, new_accum])) func = relay.Function([i, accum], sb.get()) mod[sum_up] = func + mod = transform.InferType()(mod) loop_bound = 0 i_data = np.array(loop_bound, dtype="int32") accum_data = np.array(0, dtype="int32") @@ -206,10 +213,10 @@ def test_tuple(): def test_adt_list(): mod = tvm.IRModule() p = Prelude(mod) - - l1 = p.cons(relay.const(1), p.nil()) - l21 = p.cons(relay.const(2), l1) - l321 = p.cons(relay.const(3), l21) + _, cons, nil = mod.get_type("List") + l1 = cons(relay.const(1), nil()) + l21 = cons(relay.const(2), l1) + l321 = cons(relay.const(3), l21) f = relay.Function([], l321) mod["main"] = f @@ -229,7 +236,7 @@ def test_adt_compose(): mod = tvm.IRModule() p = Prelude(mod) - compose = p.compose + compose = mod.get_global_var("compose") # add_one = fun x -> x + 1 sb = relay.ScopeBuilder() diff --git a/tests/python/relay/util/assert_diagnostic.py b/tests/python/relay/util/assert_diagnostic.py new file mode 100644 index 0000000000000..ba73d8755e0cc --- /dev/null +++ b/tests/python/relay/util/assert_diagnostic.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm + +from tvm import register_func, get_global_func, IRModule +from tvm import relay +from tvm.parser import SpanCheck +from tvm.relay.transform import AnnotateSpans +from tvm.runtime import Object +from tvm.ir.diagnostics import get_renderer, override_renderer +from tvm.error import DiagnosticError + +DEFAULT_RENDERER = get_renderer() + +__TESTING__ = None + + +def testing_renderer(diag_ctx): + global __TESTING__ + if __TESTING__ and __TESTING__.mirror: + DEFAULT_RENDERER.render(diag_ctx) + + if __TESTING__: + __TESTING__._render(diag_ctx) + + +class DiagnosticTesting: + def __init__(self, mirror=False): + self.mirror = mirror + self.messages = [] + + def __enter__(self): + global __TESTING__ + __TESTING__ = self + override_renderer(testing_renderer) + return self + + def __exit__(self, type, value, traceback): + global __TESTING__ + __TESTING__ = None + override_renderer(None) + if type is DiagnosticError and self.matches: + return True + + def assert_message(self, in_message): + self.messages.append(in_message) + + def _render(self, diag_ctx): + self.matches = False + for diagnostic in diag_ctx.diagnostics: + message = diagnostic.message + for partial_msg in self.messages: + if partial_msg in message: + self.matches = True diff --git a/tests/python/unittest/test_custom_datatypes.py b/tests/python/unittest/test_custom_datatypes.py index 239a67acb0d2d..6aad93abd5102 100644 --- a/tests/python/unittest/test_custom_datatypes.py +++ b/tests/python/unittest/test_custom_datatypes.py @@ -88,6 +88,7 @@ def change_dtype(src, dst, module, params): def compare(module, input, src_dtype, dst_dtype, rtol, atol, params={}, target="llvm"): + module = relay.transform.InferType()(module) module = relay.transform.SimplifyInference()(module) ex = relay.create_executor("graph", mod=module) diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index e1e5adb58cec1..b1a9eae7893a3 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -195,6 +195,7 @@ def check(shapex, shapey, target_bits, target_dtype): z = relay.add(x, y) func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) func = mod["main"] z = engine.lower(func, "llvm") stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32) @@ -235,6 +236,7 @@ def check(shape, index, target_bits, target_dtype): y = relay.op.take(x, indices=index) func = relay.Function([x], y) mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) func = mod["main"] z = engine.lower(func, "llvm") stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32) diff --git a/tutorials/dev/bring_your_own_datatypes.py b/tutorials/dev/bring_your_own_datatypes.py index cbb1b99413cf1..07592e7d6d4cd 100644 --- a/tutorials/dev/bring_your_own_datatypes.py +++ b/tutorials/dev/bring_your_own_datatypes.py @@ -125,6 +125,7 @@ z = relay.cast(z_myfloat, dtype="float32") program = relay.Function([x, y], z) module = tvm.IRModule.from_expr(program) +module = relay.transform.InferType()(module) ###################################################################### # Now we have a Relay program that uses myfloat! @@ -287,6 +288,8 @@ def convert_ndarray(dst_dtype, array): src_dtype = "float32" dst_dtype = "custom[myfloat]32" +module = relay.transform.InferType()(module) + # Currently, custom datatypes only work if you run simplify_inference beforehand module = tvm.relay.transform.SimplifyInference()(module)