diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 1fef02557e09..042ad1ef02da 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -569,6 +569,17 @@ TVM_DLL Pass FlattenAtrousConv(); */ TVM_DLL Pass AnnotateUsedMemory(); +/*! + * \brief Captures the post-dfs index and dominator post-dfs index of (most) expression nodes in + * their span, in the form "index::". This is useful for + * debugging since a) it helps identify pretty-printed sub-expressions within the overall model + * and b) the indexes are heavily used by Collage for its compact representation of sub-graphs. + * + * Note that Op and Constructor nodes are not changed even though they are assigned an + * post-dfs index. + */ +TVM_DLL Pass CapturePostDfsIndexInSpans(); + } // namespace transform /*! diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 05ea67b56a7b..c931289d40c6 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1422,7 +1422,7 @@ def MarkCompilerFunctionsAsExtern(compiler_filter=""): return _ffi_api.MarkCompilerFunctionsAsExtern(compiler_filter) -def CaptureIndexInSpans(): +def CapturePostDfsIndexInSpans(): """Captures the post-dfs index and dominator post-dfs index of (most) expression nodes in their span, in the form "index::". @@ -1438,7 +1438,7 @@ def CaptureIndexInSpans(): ret : tvm.transform.Pass The pass. """ - return _ffi_api.CaptureIndexInSpans() + return _ffi_api.CapturePostDfsIndexInSpans() def InlineCompilerFunctionsBoundTo(global_vars): diff --git a/src/relay/transforms/capture_index_in_spans.cc b/src/relay/transforms/capture_postdfsindex_in_spans.cc similarity index 91% rename from src/relay/transforms/capture_index_in_spans.cc rename to src/relay/transforms/capture_postdfsindex_in_spans.cc index c7cc10478012..17c7e59c7f60 100644 --- a/src/relay/transforms/capture_index_in_spans.cc +++ b/src/relay/transforms/capture_postdfsindex_in_spans.cc @@ -34,9 +34,9 @@ namespace transform { namespace { /*! \brief Update all the spans to capture their post-dfs index. */ -class CaptureIndexInSpansRewriter : public ExprRewriter { +class SpansRewriter : public ExprRewriter { public: - explicit CaptureIndexInSpansRewriter(const IndexedGraph* indexed_graph) + explicit SpansRewriter(const IndexedGraph* indexed_graph) : source_name_(SourceName::Get("index")), indexed_graph_(indexed_graph) {} private: @@ -117,16 +117,17 @@ class CaptureIndexInSpansRewriter : public ExprRewriter { } // namespace -tvm::transform::Pass CaptureIndexInSpans() { +tvm::transform::Pass CapturePostDfsIndexInSpans() { auto pass_func = [](Function f, IRModule m, transform::PassContext ctxt) { std::unique_ptr> indexed_graph = CreateIndexedGraph(f); - CaptureIndexInSpansRewriter rewriter(indexed_graph.get()); + SpansRewriter rewriter(indexed_graph.get()); return Downcast(PostOrderRewrite(f, &rewriter)); }; - return CreateFunctionPass(pass_func, 0, "CaptureIndexInSpans", {}); + return CreateFunctionPass(pass_func, 0, "CapturePostDfsIndexInSpans", {}); } -TVM_REGISTER_GLOBAL("relay._transform.CaptureIndexInSpans").set_body_typed(CaptureIndexInSpans); +TVM_REGISTER_GLOBAL("relay._transform.CapturePostDfsIndexInSpans") + .set_body_typed(CapturePostDfsIndexInSpans); } // namespace transform } // namespace relay diff --git a/tests/python/relay/transform/test_capture_index_in_spans.py b/tests/python/relay/transform/test_capture_postdfsindex_in_spans.py similarity index 96% rename from tests/python/relay/transform/test_capture_index_in_spans.py rename to tests/python/relay/transform/test_capture_postdfsindex_in_spans.py index c60ffccb7aed..16a7bd447992 100644 --- a/tests/python/relay/transform/test_capture_index_in_spans.py +++ b/tests/python/relay/transform/test_capture_postdfsindex_in_spans.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License -"""Unit tests for the CaptureIndexInSpans debugging pass.""" +"""Unit tests for the CapturePostDfsIndexInSpans debugging pass.""" import tvm import tvm.testing @@ -83,7 +83,7 @@ def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float1 def test_capture_index_in_spans(): - output_mod = str(tvm.relay.transform.CaptureIndexInSpans()(input_mod())) + output_mod = str(tvm.relay.transform.CapturePostDfsIndexInSpans()(input_mod())) assert output_mod == expected_pretty_printed_output_mod