Skip to content

Commit

Permalink
[MetaSchedule] Postproc: Rewrite-Layout (apache#11884)
Browse files Browse the repository at this point in the history
  • Loading branch information
jinhongyii authored and zxybazh committed Jun 26, 2022
1 parent 2fa01c0 commit c582779
Show file tree
Hide file tree
Showing 7 changed files with 358 additions and 26 deletions.
7 changes: 6 additions & 1 deletion include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,17 @@ class Postproc : public runtime::ObjectRef {
* \return The postprocessor created.
*/
TVM_DLL static Postproc RewriteTensorize(bool vectorize_init_loop = false);

/*!
* \brief Creates a postprocessor that verifies if the GPU code is correct
* \return The postprocessor created
*/
TVM_DLL static Postproc VerifyGPUCode();
/*!
* \brief Creates a postprocessor that rewrites the layout of the input tensors
* \note Only weight layout rewrite is supported so far; activation layout rewrite will be added.
* \return The postprocessor created
*/
TVM_DLL static Postproc RewriteLayout();
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode);
};

Expand Down
65 changes: 42 additions & 23 deletions python/tvm/auto_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,45 +17,64 @@
# pylint: disable=unused-import, redefined-builtin
""" Namespace for TVM Auto-scheduler. """

from . import compute_dag
from . import dispatcher
from . import feature
from . import loop_state
from . import measure
from . import measure_record
from . import relay_integration
from . import search_policy
from . import search_task
from . import task_scheduler
from . import utils
from . import workload_registry
from . import (
compute_dag,
dispatcher,
feature,
loop_state,
measure,
measure_record,
relay_integration,
search_policy,
search_task,
task_scheduler,
utils,
workload_registry,
)

# Shortcut
from .compute_dag import ComputeDAG, LayoutRewriteOption, get_shape_from_rewritten_layout
from .compute_dag import (
ComputeDAG,
LayoutRewriteOption,
get_shape_from_rewritten_layout,
)
from .cost_model import RandomModel, XGBModel
from .dispatcher import DispatchContext, ApplyHistoryBest, ApplyHistoryBestOrSample
from .dispatcher import ApplyHistoryBest, ApplyHistoryBestOrSample, DispatchContext
from .measure import (
MeasureInput,
MeasureResult,
LocalBuilder,
LocalRPCMeasureContext,
LocalRunner,
MeasureInput,
MeasureResult,
RPCRunner,
LocalRPCMeasureContext,
register_task_input_check_func,
)
from .measure_record import RecordToFile, RecordReader, load_best_record, load_records, save_records
from .measure_record import (
RecordReader,
RecordToFile,
load_best_record,
load_records,
save_records,
)
from .relay_integration import (
extract_tasks,
is_auto_scheduler_enabled,
remove_index_check,
rewrite_compute_body,
is_auto_scheduler_enabled,
rewrite_tensor_shape,
)
from .search_task import SearchTask, TuningOptions, HardwareParams, create_task, auto_schedule
from .search_policy import (
EmptyPolicy,
SketchPolicy,
PreloadMeasuredStates,
PreloadCustomSketchRule,
PreloadMeasuredStates,
SketchPolicy,
)
from .search_task import (
HardwareParams,
SearchTask,
TuningOptions,
auto_schedule,
create_task,
)
from .task_scheduler import TaskScheduler
from .workload_registry import register_workload, make_workload_key
from .workload_registry import make_workload_key, register_workload
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def postprocs() -> List[Postproc]:
M.DisallowDynamicLoop(),
M.RewriteParallelVectorizeUnroll(),
M.RewriteReductionBlock(),
M.RewriteLayout(),
]

@staticmethod
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/meta_schedule/postproc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
# specific language governing permissions and limitations
# under the License.
"""The tvm.meta_schedule.postproc package."""
from .postproc import Postproc, PyPostproc
from .disallow_dynamic_loop import DisallowDynamicLoop
from .postproc import Postproc, PyPostproc
from .rewrite_cooperative_fetch import RewriteCooperativeFetch
from .rewrite_layout import RewriteLayout
from .rewrite_parallel_vectorize_unroll import RewriteParallelVectorizeUnroll
from .rewrite_reduction_block import RewriteReductionBlock
from .rewrite_tensorize import RewriteTensorize
from .rewrite_unbound_block import RewriteUnboundBlock
from .verify_gpu_code import VerifyGPUCode
from .rewrite_tensorize import RewriteTensorize
32 changes: 32 additions & 0 deletions python/tvm/meta_schedule/postproc/rewrite_layout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A postprocessor that rewrites the layout of input tensor"""

from tvm._ffi.registry import register_object

from .. import _ffi_api
from .postproc import Postproc


@register_object("meta_schedule.RewriteLayout")
class RewriteLayout(Postproc):
    """Postprocessor that rewrites the layout of input tensors.

    The object is a handle created through the FFI constructor
    ``_ffi_api.PostprocRewriteLayout``; it takes no arguments.
    """

    def __init__(self) -> None:
        constructor = _ffi_api.PostprocRewriteLayout  # type: ignore # pylint: disable=no-member
        self.__init_handle_by_constructor__(constructor)
183 changes: 183 additions & 0 deletions src/meta_schedule/postproc/rewrite_layout.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace tir {

/*!
 * \brief Visitor that records, for every buffer of interest, the block that reads it, the
 *        position of the buffer among that block's read regions, and the index map suggested
 *        for the access.
 * \note Each buffer is expected to be read by a single BufferLoad. If several loads of the same
 *       buffer are visited, the entry written by the last visited load is the one kept.
 */
class BufferReadPosCollector : public StmtExprVisitor {
 public:
  /*!
   * \brief Constructor
   * \param buffers The buffers whose read positions should be collected.
   */
  explicit BufferReadPosCollector(const Array<Buffer>& buffers) {
    for (const Buffer& tracked : buffers) {
      buffers_.insert(tracked.get());
    }
  }

  /*! \brief Mapping from a collected buffer to (reading block, index in the block's reads). */
  const std::unordered_map<const BufferNode*, std::pair<Block, int>>& GetBufferLocations() const {
    return buffer_locs_;
  }

  /*! \brief Mapping from a collected buffer to the result of SuggestIndexMap for its read. */
  const std::unordered_map<const BufferNode*, Optional<IndexMap>>& GetBufferIndexMap() const {
    return buffer_index_maps_;
  }

 private:
  // Keep `loop_stack_` equal to the chain of loops enclosing the statement being visited.
  void VisitStmt_(const ForNode* op) final {
    loop_stack_.push_back(GetRef<For>(op));
    StmtVisitor::VisitStmt_(op);
    loop_stack_.pop_back();
  }

  // Make `cur_realize_` point at this BlockRealize while its subtree is visited, then restore
  // the previously active one.
  void VisitStmt_(const BlockRealizeNode* op) final {
    BlockRealize enclosing_realize = cur_realize_;
    cur_realize_ = GetRef<BlockRealize>(op);
    StmtVisitor::VisitStmt_(op);
    cur_realize_ = enclosing_realize;
  }

  void VisitExpr_(const BufferLoadNode* op) final {
    const Buffer& buffer = op->buffer;
    // Loads of buffers we were not asked about are ignored.
    if (buffers_.count(buffer.get()) == 0) {
      return;
    }
    // Replace every block iterator variable by its binding value from the enclosing
    // BlockRealize before handing the indices to SuggestIndexMap.
    Map<Var, PrimExpr> iter_bindings;
    const size_t num_iters = cur_realize_->iter_values.size();
    for (size_t k = 0; k < num_iters; ++k) {
      iter_bindings.Set(cur_realize_->block->iter_vars[k]->var, cur_realize_->iter_values[k]);
    }
    Array<PrimExpr> bound_indices;
    for (const PrimExpr& index : op->indices) {
      bound_indices.push_back(Substitute(index, iter_bindings));
    }
    const BufferNode* key = buffer.get();
    buffer_index_maps_[key] = SuggestIndexMap(/*buffer=*/buffer,                      //
                                              /*indices=*/bound_indices,              //
                                              /*loops=*/loop_stack_,                  //
                                              /*predicate=*/cur_realize_->predicate,  //
                                              /*analyzer=*/&analyzer_);
    // The loaded buffer must appear among the read regions of the enclosing block.
    int buffer_index = GetReadBufferIndex(cur_realize_->block, buffer);
    ICHECK(buffer_index != -1);
    buffer_locs_[key] = std::make_pair(cur_realize_->block, buffer_index);
  }

  /*!
   * \brief Find the position of `buffer` among the read regions of `block`.
   * \return The index of the first read region referring to `buffer`, or -1 if there is none.
   */
  static int GetReadBufferIndex(const Block& block, const Buffer& buffer) {
    int position = 0;
    for (const auto& region : block->reads) {
      if (region->buffer.same_as(buffer)) {
        return position;
      }
      ++position;
    }
    return -1;
  }

 private:
  /*! \brief The buffers whose reads are being collected. */
  std::unordered_set<const BufferNode*> buffers_;
  /*! \brief Result: buffer -> (block reading it, index of the buffer in the block's reads). */
  std::unordered_map<const BufferNode*, std::pair<Block, int>> buffer_locs_;
  /*! \brief Result: buffer -> index map suggested for its read access. */
  std::unordered_map<const BufferNode*, Optional<IndexMap>> buffer_index_maps_;

  /*! \brief Loops enclosing the current visiting position, outermost first. */
  Array<For> loop_stack_;
  /*! \brief Arithmetic analyzer passed to SuggestIndexMap. */
  arith::Analyzer analyzer_;
  /*! \brief The BlockRealize whose subtree is currently being visited. */
  BlockRealize cur_realize_;
};

bool RewriteLayout(const Schedule& sch) {
std::vector<std::pair<StmtSRef, String>> results;
for (const auto& kv : sch->mod()->functions) {
const GlobalVar& g_var = kv.first;
const String& func_name = g_var->name_hint;
const auto* prim_func = kv.second.as<PrimFuncNode>();
// Only consider PrimFunc
if (prim_func == nullptr) {
continue;
}
// Only rewrite PrimFuncs with attr "layout_free_buffers"
Array<Integer> layout_free_buffer_index =
prim_func->GetAttr(attr::layout_free_buffers, Array<Integer>()).value();

Array<Buffer> layout_free_buffers;
for (const Integer& index : layout_free_buffer_index) {
const Var& param = prim_func->params[index->value];
layout_free_buffers.push_back(prim_func->buffer_map.at(param));
}
// Collect Buffer read positions
BufferReadPosCollector collector(layout_free_buffers);
collector(prim_func->body);
const auto& locations = collector.GetBufferLocations();
const auto& index_maps = collector.GetBufferIndexMap();
// Check all buffers are collected
if (locations.size() != layout_free_buffers.size() ||
index_maps.size() != layout_free_buffer_index.size()) {
return false;
}

for (const auto& kv : locations) {
const Buffer& buffer = GetRef<Buffer>(kv.first);
const Block& block = kv.second.first;
int buffer_index = kv.second.second;

// Get IndexMap
const Optional<IndexMap> index_map = index_maps.at(buffer.get());
if (!index_map.defined()) {
continue;
}

// Apply schedule
BlockRV block_rv = sch->GetBlock(block->name_hint, func_name);
BlockRV cached_block_rv = sch->CacheRead(block_rv, buffer_index, "global");
sch->TransformLayout(block_rv, buffer_index, BufferIndexType::kRead, index_map.value());
sch->Annotate(cached_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true());
}
}
return true;
}

} // namespace tir

namespace meta_schedule {
/*!
 * \brief Postprocessor node that applies tir::RewriteLayout to a schedule.
 */
class RewriteLayoutNode : public PostprocNode {
 public:
  /*!
   * \brief Inherited from PostprocNode. Nothing is read from the tuning context.
   * \param context The tuning context (unused).
   */
  void InitializeWithTuneContext(const TuneContext& context) final {}

  /*!
   * \brief Inherited from PostprocNode. Run the layout rewrite on the given schedule.
   * \param sch The schedule to be post-processed.
   * \return The result reported by tir::RewriteLayout for `sch`.
   */
  bool Apply(const tir::Schedule& sch) final {
    const bool succeeded = tir::RewriteLayout(sch);
    return succeeded;
  }

  static constexpr const char* _type_key = "meta_schedule.RewriteLayout";
  TVM_DECLARE_FINAL_OBJECT_INFO(RewriteLayoutNode, PostprocNode);
};

/*! \brief Create a postprocessor backed by a newly allocated RewriteLayoutNode. */
Postproc Postproc::RewriteLayout() {
  return Postproc(make_object<RewriteLayoutNode>());
}

// Register the RewriteLayoutNode type (type key "meta_schedule.RewriteLayout").
TVM_REGISTER_NODE_TYPE(RewriteLayoutNode);
// Expose the factory under the global function name "meta_schedule.PostprocRewriteLayout".
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteLayout").set_body_typed(Postproc::RewriteLayout);

} // namespace meta_schedule
} // namespace tvm
Loading

0 comments on commit c582779

Please sign in to comment.