Skip to content

Commit

Permalink
[Relay] CaptureIndexInSpans debugging pass
Browse files Browse the repository at this point in the history
This pass will update (most) expression nodes to capture their post-dfs
indexes. That makes it easy to connect pretty-printed fragments back to
the overall model, and is very handy for Collage which uses post-dfs indexes
extensively.
  • Loading branch information
mbs-octoml committed Jun 28, 2022
1 parent 1115fd9 commit 0009ecc
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 0 deletions.
19 changes: 19 additions & 0 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,3 +1420,22 @@ def MarkCompilerFunctionsAsExtern(compiler_filter=""):
The pass.
"""
return _ffi_api.MarkCompilerFunctionsAsExtern(compiler_filter)


def CaptureIndexInSpans():
"""Captures the post-dfs index and dominator post-dfs index of (most) expression nodes in
their span, in the form "index:<post-dfs index>:<dominator post-dfs index>".
This is useful for debugging since a) it helps identify pretty-printed sub-expressions within
the overall model and b) the indexes are heavily used by Collage for its compact representation
of sub-graphs.
Note that Op and Constructor nodes are not changed even though they are assigned an
post-dfs index.
Returns
-------
ret : tvm.transform.Pass
The pass.
"""
return _ffi_api.CaptureIndexInSpans()
133 changes: 133 additions & 0 deletions src/relay/transforms/capture_index_in_spans.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relay/transform/capture_index_in_spans.cc
* \brief A pass to set spans to capture the post-dfs index of every node.
*/

#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

#include "../ir/indexed_graph.h"

namespace tvm {
namespace relay {
namespace transform {

namespace {

/*! \brief Update all the spans to capture their post-dfs index. */
class CaptureIndexInSpansRewriter : public ExprRewriter {
public:
explicit CaptureIndexInSpansRewriter(const IndexedGraph<Expr>* indexed_graph)
: source_name_(SourceName::Get("index")), indexed_graph_(indexed_graph) {}

private:
Expr Rewrite_(const VarNode* var_node, const Expr& post) final {
return WithFields(Downcast<Var>(post), {}, {}, {}, MakeSpan(GetRef<Var>(var_node)));
}

Expr Rewrite_(const GlobalVarNode* global_var_node, const Expr& post) final {
return WithFields(Downcast<GlobalVar>(post), {}, {}, {},
MakeSpan(GetRef<GlobalVar>(global_var_node)));
}

Expr Rewrite_(const ConstantNode* constant_node, const Expr& post) final {
return WithFields(Downcast<Constant>(post), {}, {}, MakeSpan(GetRef<Constant>(constant_node)));
}

Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) final {
return WithFields(Downcast<Tuple>(post), {}, {}, MakeSpan(GetRef<Tuple>(tuple_node)));
}

Expr Rewrite_(const FunctionNode* function_node, const Expr& post) final {
return WithFields(Downcast<Function>(post), {}, {}, {}, {}, {}, {},
MakeSpan(GetRef<Function>(function_node)));
}

Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
return WithFields(Downcast<Call>(post), {}, {}, {}, {}, {}, MakeSpan(GetRef<Call>(call_node)));
}

Expr Rewrite_(const LetNode* let_node, const Expr& post) final {
return WithFields(Downcast<Let>(post), {}, {}, {}, {}, MakeSpan(GetRef<Let>(let_node)));
}

Expr Rewrite_(const IfNode* if_node, const Expr& post) final {
return WithFields(Downcast<If>(post), {}, {}, {}, {}, MakeSpan(GetRef<If>(if_node)));
}

// OpNodes are not rewritten.

Expr Rewrite_(const TupleGetItemNode* tuple_get_item_node, const Expr& post) final {
return WithFields(Downcast<TupleGetItem>(post), {}, {}, {},
MakeSpan(GetRef<TupleGetItem>(tuple_get_item_node)));
}

Expr Rewrite_(const RefCreateNode* ref_create_node, const Expr& post) final {
return WithFields(Downcast<RefCreate>(post), {}, {},
MakeSpan(GetRef<RefCreate>(ref_create_node)));
}

Expr Rewrite_(const RefReadNode* ref_read_node, const Expr& post) final {
return WithFields(Downcast<RefRead>(post), {}, {}, MakeSpan(GetRef<RefRead>(ref_read_node)));
}

Expr Rewrite_(const RefWriteNode* ref_write_node, const Expr& post) final {
return WithFields(Downcast<RefWrite>(post), {}, {}, {},
MakeSpan(GetRef<RefWrite>(ref_write_node)));
}

// ConstructorNodes are not rewritten.

Expr Rewrite_(const MatchNode* match_node, const Expr& post) final {
return WithFields(Downcast<Match>(post), {}, {}, {}, MakeSpan(GetRef<Match>(match_node)));
}

Span MakeSpan(const Expr& expr) {
auto node = indexed_graph_->item_to_node(expr);
int node_index = static_cast<int>(node->index_);
int dominator_index =
node->dominator_parent_ ? static_cast<int>(node->dominator_parent_->index_) : -1;
Span span(source_name_, /*line=*/node_index, /*end_line=*/node_index,
/*column=*/dominator_index, /*end_column=*/dominator_index);
return span;
}

SourceName source_name_;
const IndexedGraph<Expr>* indexed_graph_;
};

} // namespace

tvm::transform::Pass CaptureIndexInSpans() {
auto pass_func = [](Function f, IRModule m, transform::PassContext ctxt) {
std::unique_ptr<IndexedGraph<Expr>> indexed_graph = CreateIndexedGraph(f);
CaptureIndexInSpansRewriter rewriter(indexed_graph.get());
return Downcast<Function>(PostOrderRewrite(f, &rewriter));
};
return CreateFunctionPass(pass_func, 0, "CaptureIndexInSpans", {});
}

TVM_REGISTER_GLOBAL("relay._transform.CaptureIndexInSpans").set_body_typed(CaptureIndexInSpans);

} // namespace transform
} // namespace relay
} // namespace tvm
91 changes: 91 additions & 0 deletions tests/python/relay/transform/test_capture_index_in_spans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License
"""Unit tests for the CaptureIndexInSpans debugging pass."""

import tvm
import tvm.testing
import numpy as np


def make_const(dtype, shape):
return tvm.relay.const(np.random.rand(*shape).astype(dtype))


def make_consts(dtype, shapes):
return [make_const(dtype, shape) for shape in shapes]


metatable = {
"relay.Constant": make_consts(
"float16",
[
(2304, 768), # 0
(2304,), # 1
(600, 32, 64), # 2
],
)
}


def input_mod():
return tvm.parser.parse(
"""
#[version = "0.0.5"]
def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) {
%0 = nn.dense(%x0, meta[relay.Constant][0], units=2304);
%1 = add(%0, meta[relay.Constant][1]);
%2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: Tensor[(600, 32, 64), float16],
Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] {
%6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], %FunctionVar_0_11: Tensor[(600, 32, 64), float16],
PartitionedFromPattern="nn.batch_matmul_", Composite="cublas.batch_matmul") -> Tensor[(600, 32, 32), float16] {
nn.batch_matmul(%FunctionVar_0_01, %FunctionVar_0_11, out_dtype="float16", transpose_b=True)
};
%6(%y_3_i0, %y_3_i1)
};
%3 = %2(%x3, meta[relay.Constant][2]);
(%1, %3)
}
""",
"from_string",
None,
metatable,
)


expected_pretty_printed_output_mod = r"""def @main(%x0: Tensor[(1600, 768), float16] /* ty=Tensor[(1600, 768), float16] span=index:0:5 */, %x3: Tensor[(600, 32, 64), float16] /* ty=Tensor[(600, 32, 64), float16] span=index:1:18 */) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) {
%0 = nn.dense(%x0, meta[relay.Constant][0] /* ty=Tensor[(2304, 768), float16] span=index:4:5 */, units=2304) /* ty=Tensor[(1600, 2304), float16] span=index:5:7 */;
%2 = fn (%y_3_i0: Tensor[(600, 32, 64), float16] /* ty=Tensor[(600, 32, 64), float16] span=index:8:15 */, %y_3_i1: Tensor[(600, 32, 64), float16] /* ty=Tensor[(600, 32, 64), float16] span=index:9:15 */, Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] {
%1 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16] /* ty=Tensor[(600, 32, 64), float16] span=index:10:13 */, %FunctionVar_0_11: Tensor[(600, 32, 64), float16] /* ty=Tensor[(600, 32, 64), float16] span=index:11:13 */, PartitionedFromPattern="nn.batch_matmul_", Composite="cublas.batch_matmul") -> Tensor[(600, 32, 32), float16] {
nn.batch_matmul(%FunctionVar_0_01, %FunctionVar_0_11, out_dtype="float16", transpose_b=True) /* ty=Tensor[(600, 32, 32), float16] span=index:13:14 */
} /* ty=fn (Tensor[(600, 32, 64), float16], Tensor[(600, 32, 64), float16]) -> Tensor[(600, 32, 32), float16] span=index:14:15 */;
%1(%y_3_i0, %y_3_i1) /* ty=Tensor[(600, 32, 32), float16] span=index:15:16 */
} /* ty=fn (Tensor[(600, 32, 64), float16], Tensor[(600, 32, 64), float16]) -> Tensor[(600, 32, 32), float16] span=index:16:18 */;
%3 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(2304), float16] span=index:6:7 */) /* ty=Tensor[(1600, 2304), float16] span=index:7:19 */;
%4 = %2(%x3, meta[relay.Constant][2] /* ty=Tensor[(600, 32, 64), float16] span=index:17:18 */) /* ty=Tensor[(600, 32, 32), float16] span=index:18:19 */;
(%3, %4) /* ty=(Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) span=index:19:20 */
}
"""


def test_capture_index_in_spans():
output_mod = str(tvm.relay.transform.CaptureIndexInSpans()(input_mod()))
assert output_mod == expected_pretty_printed_output_mod


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 0009ecc

Please sign in to comment.