Skip to content

Commit

Permalink
- rename
Browse files Browse the repository at this point in the history
- add header decl
  • Loading branch information
mbs-octoml committed Jun 28, 2022
1 parent ec45185 commit 1b3eb16
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 10 deletions.
11 changes: 11 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,17 @@ TVM_DLL Pass FlattenAtrousConv();
*/
TVM_DLL Pass AnnotateUsedMemory();

/*!
* \brief Captures the post-dfs index and dominator post-dfs index of (most) expression nodes in
* their span, in the form "index:<post-dfs index>:<dominator post-dfs index>". This is useful for
* debugging since a) it helps identify pretty-printed sub-expressions within the overall model
* and b) the indexes are heavily used by Collage for its compact representation of sub-graphs.
*
* Note that Op and Constructor nodes are not changed even though they are assigned an
* post-dfs index.
*/
TVM_DLL Pass CapturePostDfsIndexInSpans();

} // namespace transform

/*!
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1422,7 +1422,7 @@ def MarkCompilerFunctionsAsExtern(compiler_filter=""):
return _ffi_api.MarkCompilerFunctionsAsExtern(compiler_filter)


def CaptureIndexInSpans():
def CapturePostDfsIndexInSpans():
"""Captures the post-dfs index and dominator post-dfs index of (most) expression nodes in
their span, in the form "index:<post-dfs index>:<dominator post-dfs index>".
Expand All @@ -1438,7 +1438,7 @@ def CaptureIndexInSpans():
ret : tvm.transform.Pass
The pass.
"""
return _ffi_api.CaptureIndexInSpans()
return _ffi_api.CapturePostDfsIndexInSpans()


def InlineCompilerFunctionsBoundTo(global_vars):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ namespace transform {
namespace {

/*! \brief Update all the spans to capture their post-dfs index. */
class CaptureIndexInSpansRewriter : public ExprRewriter {
class SpansRewriter : public ExprRewriter {
public:
explicit CaptureIndexInSpansRewriter(const IndexedGraph<Expr>* indexed_graph)
explicit SpansRewriter(const IndexedGraph<Expr>* indexed_graph)
: source_name_(SourceName::Get("index")), indexed_graph_(indexed_graph) {}

private:
Expand Down Expand Up @@ -117,16 +117,17 @@ class CaptureIndexInSpansRewriter : public ExprRewriter {

} // namespace

tvm::transform::Pass CaptureIndexInSpans() {
tvm::transform::Pass CapturePostDfsIndexInSpans() {
auto pass_func = [](Function f, IRModule m, transform::PassContext ctxt) {
std::unique_ptr<IndexedGraph<Expr>> indexed_graph = CreateIndexedGraph(f);
CaptureIndexInSpansRewriter rewriter(indexed_graph.get());
SpansRewriter rewriter(indexed_graph.get());
return Downcast<Function>(PostOrderRewrite(f, &rewriter));
};
return CreateFunctionPass(pass_func, 0, "CaptureIndexInSpans", {});
return CreateFunctionPass(pass_func, 0, "CapturePostDfsIndexInSpans", {});
}

TVM_REGISTER_GLOBAL("relay._transform.CaptureIndexInSpans").set_body_typed(CaptureIndexInSpans);
TVM_REGISTER_GLOBAL("relay._transform.CapturePostDfsIndexInSpans")
.set_body_typed(CapturePostDfsIndexInSpans);

} // namespace transform
} // namespace relay
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License
"""Unit tests for the CaptureIndexInSpans debugging pass."""
"""Unit tests for the CapturePostDfsIndexInSpans debugging pass."""

import tvm
import tvm.testing
Expand Down Expand Up @@ -83,7 +83,7 @@ def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float1


def test_capture_index_in_spans():
output_mod = str(tvm.relay.transform.CaptureIndexInSpans()(input_mod()))
output_mod = str(tvm.relay.transform.CapturePostDfsIndexInSpans()(input_mod()))
assert output_mod == expected_pretty_printed_output_mod


Expand Down

0 comments on commit 1b3eb16

Please sign in to comment.