Skip to content

Commit

Permalink
[Meta Schedule][M3a] Traced Schedule
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Aug 2, 2021
1 parent 7653972 commit 568064d
Show file tree
Hide file tree
Showing 6 changed files with 285 additions and 4 deletions.
17 changes: 17 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#define TVM_TIR_SCHEDULE_SCHEDULE_H_

#include <tvm/tir/schedule/state.h>
#include <tvm/tir/schedule/trace.h>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -102,6 +103,8 @@ class ScheduleNode : public runtime::Object {
virtual IRModule mod() const { return state()->mod; }
/*! \return The internal state of scheduling */
virtual ScheduleState state() const = 0;
/*! \return The internally maintained trace of scheduling program execution */
virtual Optional<Trace> trace() const = 0;
/*!
* \brief Returns a copy of the schedule, including both its state and its symbol table,
* guaranteeing that
Expand Down Expand Up @@ -299,6 +302,20 @@ class Schedule : public runtime::ObjectRef {
*/
TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode,
ScheduleErrorRenderLevel error_render_level);
/*!
* \brief Construct a traced concrete TensorIR schedule from an IRModule
* \param mod The IRModule to be scheduled
* \param debug_mode Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param error_render_level The level of error rendering
* \return The concrete schedule created
* \sa ScheduleDebugMask
* \note The checks performed includes:
* 1) VerifySRefTree
* 2) VerifyCachedFlags
*/
TVM_DLL static Schedule Traced(IRModule mod, int debug_mode,
ScheduleErrorRenderLevel error_render_level);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
};

Expand Down
24 changes: 22 additions & 2 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from . import _ffi_api
from .state import ScheduleState, StmtSRef
from .trace import Trace


@register_error
Expand Down Expand Up @@ -93,6 +94,7 @@ def __init__(
*,
debug_mode: Union[bool, int] = False,
error_render_level: ERROR_RENDER_LEVEL_CANDIDATES = "detail",
traced: bool = False,
) -> None:
"""Construct a concrete TensorIR schedule from an IRModule or a PrimFunc
Expand All @@ -108,29 +110,42 @@ def __init__(
"detail": Render a detailed error message, with the TIR and error locations printed
"fast: Show a simple error message without rendering or string manipulation
"none": Do not show any error message.
traced : bool = False
A flag indicating if the scheduling process is being traced.
If set to true, users are able to print the inspect the instructions executed so far
by printing `Schedule.trace`
Note
----
The checks performed includes:
1) VerifySRefTree
2) VerifyCachedFlags
"""
# preprocess `mod`
if isinstance(mod, PrimFunc):
mod = IRModule({"main": mod})
# preprocess `debug_mode`
if isinstance(debug_mode, bool):
if debug_mode:
debug_mode = -1
else:
debug_mode = 0
if not isinstance(debug_mode, int):
raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}")
# preprocess `error_render_level`
if error_render_level not in Schedule.ERROR_RENDER_LEVEL:
raise ValueError(
'error_render_level can be "detail", "fast", or "none", but got: '
+ f"{error_render_level}"
)
# preprocess `traced`
if traced:
f_constructor = _ffi_api.TracedSchedule # type: ignore # pylint: disable=no-member
else:
f_constructor = _ffi_api.ConcreteSchedule # type: ignore # pylint: disable=no-member
# call the constructor
self.__init_handle_by_constructor__(
_ffi_api.ConcreteSchedule, # type: ignore # pylint: disable=no-member
f_constructor,
mod,
debug_mode,
Schedule.ERROR_RENDER_LEVEL.get(error_render_level),
Expand All @@ -141,13 +156,18 @@ def __init__(
@property
def mod(self) -> IRModule:
"""Returns the AST of the module being scheduled"""
return _ffi_api.ScheduleModule(self) # type: ignore # pylint: disable=no-member
return _ffi_api.ScheduleGetMod(self) # type: ignore # pylint: disable=no-member

@property
def state(self) -> ScheduleState:
"""Returns the ScheduleState in the current schedule class"""
return _ffi_api.ScheduleGetState(self) # type: ignore # pylint: disable=no-member

@property
def trace(self) -> Optional[Trace]:
"""Returns the internally maintained trace of scheduling program execution"""
return _ffi_api.ScheduleGetTrace(self) # type: ignore # pylint: disable=no-member

def copy(self) -> "Schedule":
"""Returns a copy of the schedule, including both the state and the symbol table,
* guaranteeing that
Expand Down
3 changes: 2 additions & 1 deletion src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ class ConcreteScheduleNode : public ScheduleNode {

public:
void VisitAttrs(tvm::AttrVisitor* v) {
// `error_render_level_` is not visited
// `state_` is not visited
// `error_render_level_` is not visited
// `symbol_table_` is not visited
// `analyzer_` is not visitied
}
Expand All @@ -59,6 +59,7 @@ class ConcreteScheduleNode : public ScheduleNode {

public:
ScheduleState state() const final { return state_; }
Optional<Trace> trace() const override { return NullOpt; }
Schedule Copy() const override;

public:
Expand Down
4 changes: 3 additions & 1 deletion src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ TVM_REGISTER_NODE_TYPE(BlockRVNode);
TVM_REGISTER_NODE_TYPE(LoopRVNode);
TVM_REGISTER_OBJECT_TYPE(ScheduleNode);

TVM_REGISTER_GLOBAL("tir.schedule.ScheduleModule") //
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetMod") //
.set_body_method<Schedule>(&ScheduleNode::mod);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") //
.set_body_method<Schedule>(&ScheduleNode::state);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") //
.set_body_method<Schedule>(&ScheduleNode::trace);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") //
.set_body_method<Schedule>(&ScheduleNode::Seed);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") //
Expand Down
165 changes: 165 additions & 0 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "./traced_schedule.h"

namespace tvm {
namespace tir {

Schedule Schedule::Traced(IRModule mod, int debug_mode,
ScheduleErrorRenderLevel error_render_level) {
ObjectPtr<TracedScheduleNode> n = make_object<TracedScheduleNode>();
n->state_ = ScheduleState(mod, debug_mode);
n->error_render_level_ = error_render_level;
n->symbol_table_ = {};
n->analyzer_ = std::make_unique<arith::Analyzer>();
n->trace_ = Trace();
return Schedule(std::move(n));
}

Schedule TracedScheduleNode::Copy() const {
ObjectPtr<TracedScheduleNode> n = make_object<TracedScheduleNode>();
ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);
n->error_render_level_ = this->error_render_level_;
n->analyzer_ = std::make_unique<arith::Analyzer>();
n->trace_ = Trace(this->trace_->insts, this->trace_->decisions);
return Schedule(std::move(n));
}

/******** Schedule: Sampling ********/

/******** Schedule: Get blocks & loops ********/

BlockRV TracedScheduleNode::GetBlock(const String& name, const String& func_name) {
BlockRV result = ConcreteScheduleNode::GetBlock(name, func_name);

static const InstructionKind& kind = InstructionKind::Get("GetBlock");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
/*inputs=*/{},
/*attrs=*/{name, func_name},
/*outputs=*/{result}));
return result;
}

Array<LoopRV> TracedScheduleNode::GetLoops(const BlockRV& block_rv) {
Array<LoopRV> results = ConcreteScheduleNode::GetLoops(block_rv);

static const InstructionKind& kind = InstructionKind::Get("GetLoops");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
/*inputs=*/{block_rv},
/*attrs=*/{},
/*outputs=*/{results.begin(), results.end()}));
return results;
}

/******** Schedule: Transform loops ********/

LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs);

static const InstructionKind& kind = InstructionKind::Get("Fuse");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{loop_rvs.begin(), loop_rvs.end()},
/*attrs=*/{},
/*outputs=*/{result}));
return result;
}

Array<LoopRV> TracedScheduleNode::Split(const LoopRV& loop_rv,
const Array<Optional<ExprRV>>& factor_rvs) {
Array<LoopRV> results = ConcreteScheduleNode::Split(loop_rv, factor_rvs);

std::vector<ObjectRef> inputs;
inputs.reserve(1 + factor_rvs.size());
inputs.push_back(loop_rv);
for (const ObjectRef& obj : factor_rvs) {
inputs.push_back(obj);
}

static const InstructionKind& kind = InstructionKind::Get("Split");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/inputs,
/*attrs=*/{},
/*outputs=*/{results.begin(), results.end()}));
return results;
}

/******** Schedule: Manipulate ForKind ********/

/******** Schedule: Insert cache stages ********/

/******** Schedule: Compute location ********/

void TracedScheduleNode::ComputeInline(const BlockRV& block_rv) {
ConcreteScheduleNode::ComputeInline(block_rv);

static const InstructionKind& kind = InstructionKind::Get("ComputeInline");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{block_rv},
/*attrs=*/{},
/*outputs=*/{}));
}

void TracedScheduleNode::ReverseComputeInline(const BlockRV& block_rv) {
ConcreteScheduleNode::ReverseComputeInline(block_rv);

static const InstructionKind& kind = InstructionKind::Get("ReverseComputeInline");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{block_rv},
/*attrs=*/{},
/*outputs=*/{}));
}

/******** Schedule: Reduction ********/

BlockRV TracedScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {
BlockRV result = ConcreteScheduleNode::RFactor(loop_rv, factor_axis);
static const InstructionKind& kind = InstructionKind::Get("RFactor");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{loop_rv},
/*attrs=*/{Integer(factor_axis)},
/*outputs=*/{result}));
return result;
}

/******** Schedule: Blockize & Tensorize ********/

/******** Schedule: Annotation ********/

/******** Schedule: Misc ********/

void TracedScheduleNode::EnterPostproc() {
ConcreteScheduleNode::EnterPostproc();
static const InstructionKind& kind = InstructionKind::Get("EnterPostproc");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{},
/*attrs=*/{},
/*outputs=*/{}));
}

/******** FFI ********/

TVM_REGISTER_NODE_TYPE(TracedScheduleNode);
TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule")
.set_body_typed([](IRModule mod, int debug_mode, int error_render_level) -> Schedule {
return Schedule::Traced(mod, debug_mode,
static_cast<ScheduleErrorRenderLevel>(error_render_level));
});

} // namespace tir
} // namespace tvm
76 changes: 76 additions & 0 deletions src/tir/schedule/traced_schedule.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_
#define TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_

#include "./concrete_schedule.h"

namespace tvm {
namespace tir {

class TracedScheduleNode : public ConcreteScheduleNode {
friend class Schedule;

protected:
Trace trace_;

public:
void VisitAttrs(tvm::AttrVisitor* v) {
// `state_` is not visited
// `error_render_level_` is not visited
// `symbol_table_` is not visited
// `analyzer_` is not visitied
// `trace_` is not visited
}

~TracedScheduleNode() = default;

static constexpr const char* _type_key = "tir.TracedSchedule";
TVM_DECLARE_FINAL_OBJECT_INFO(TracedScheduleNode, ScheduleNode);

public:
Optional<Trace> trace() const final { return trace_; }
Schedule Copy() const final;

public:
/******** Schedule: Sampling ********/

/******** Schedule: Get blocks & loops ********/
BlockRV GetBlock(const String& name, const String& func_name = "main") final;
Array<LoopRV> GetLoops(const BlockRV& block_rv) final;
/******** Schedule: Transform loops ********/
LoopRV Fuse(const Array<LoopRV>& loop_rvs) final;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs) final;
/******** Schedule: Manipulate ForKind ********/
/******** Schedule: Insert cache stages ********/
/******** Schedule: Compute location ********/
void ComputeInline(const BlockRV& block_rv) final;
void ReverseComputeInline(const BlockRV& block_rv) final;
/******** Schedule: Reduction ********/
BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) final;
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/******** Schedule: Misc ********/
void EnterPostproc() final;
};

} // namespace tir
} // namespace tvm

#endif // TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_

0 comments on commit 568064d

Please sign in to comment.