Skip to content

Commit

Permalink
[MetaSchedule] Rewrite Cooperative-Fetching / Unbound-Block / Reducti…
Browse files Browse the repository at this point in the history
…on-Block (apache#509)
  • Loading branch information
junrushao authored Nov 11, 2021
1 parent ad7adb3 commit dcf7310
Show file tree
Hide file tree
Showing 31 changed files with 1,253 additions and 161 deletions.
45 changes: 29 additions & 16 deletions include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@ namespace meta_schedule {
class TuneContext;

/*!
* \brief Rules to apply a post processing to a schedule.
* \note Post processing is designed to deal with the problem of undertermined schedule validity
* after applying some schedule primitves at runtime. E.g., Fuse the first X loops to reach the
* maximum number below 1024, X is only decided at runtime.
* \brief Rules to apply a postprocessor to a schedule.
*/
class PostprocNode : public runtime::Object {
public:
Expand All @@ -47,17 +44,17 @@ class PostprocNode : public runtime::Object {
virtual void InitializeWithTuneContext(const TuneContext& context) = 0;

/*!
* \brief Apply a post processing to the given schedule.
* \brief Apply a postprocessor to the given schedule.
* \param sch The schedule to be post processed.
* \return Whether the post processing was successfully applied.
* \return Whether the postprocessor was successfully applied.
*/
virtual bool Apply(const tir::Schedule& sch) = 0;

static constexpr const char* _type_key = "meta_schedule.Postproc";
TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object);
};

/*! \brief The post processing with customized methods on the python-side. */
/*! \brief The postprocessor with customized methods on the python-side. */
class PyPostprocNode : public PostprocNode {
public:
/*!
Expand All @@ -66,22 +63,22 @@ class PyPostprocNode : public PostprocNode {
*/
using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
/*!
* \brief Apply a post processing to the given schedule.
* \brief Apply a postprocessor to the given schedule.
* \param sch The schedule to be post processed.
* \return Whether the post processing was successfully applied.
* \return Whether the postprocessor was successfully applied.
*/
using FApply = runtime::TypedPackedFunc<bool(const tir::Schedule&)>;
/*!
* \brief Get the post processing function as string with name.
* \return The string of the post processing function.
* \brief Get the postprocessor function as string with name.
* \return The string of the postprocessor function.
*/
using FAsString = runtime::TypedPackedFunc<String()>;

/*! \brief The packed function to the `InitializeWithTuneContext` funcion. */
/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` funcion. */
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `AsString` funcion. */
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
Expand Down Expand Up @@ -112,15 +109,31 @@ class PyPostprocNode : public PostprocNode {
class Postproc : public runtime::ObjectRef {
public:
/*!
* \brief Create a post processing with customized methods on the python-side.
* \brief Create a postprocessor with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_apply The packed function of `Apply`.
* \return The post processing created.
* \return The postprocessor created.
*/
TVM_DLL static Postproc PyPostproc(
PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PyPostprocNode::FApply f_apply, //
PyPostprocNode::FAsString f_as_string);
/*!
* \brief Create a postprocessor that rewrites the cooperative fetch annotation to
* actual vectorized cooperative fetching in loop bindings.
* \return The postprocessor created.
*/
TVM_DLL static Postproc RewriteCooperativeFetch();
/*!
* \brief Create a postprocessor that rewrites reduction block by moving the init block out.
* \return The postprocessor created.
*/
TVM_DLL static Postproc RewriteReductionBlock();
/*!
* \brief Create a postprocessor that adds thread binding to unbound blocks
* \return The postprocessor created.
*/
TVM_DLL static Postproc RewriteUnboundBlock();
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode);
};

Expand Down
4 changes: 2 additions & 2 deletions include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class TuneContextNode : public runtime::Object {
Optional<SearchStrategy> search_strategy;
/*! \brief The schedule rules. */
Optional<Array<ScheduleRule>> sch_rules;
/*! \brief The post processings. */
/*! \brief The postprocessors. */
Optional<Array<Postproc>> postprocs;
/*! \brief The mutators. */
Optional<Array<Mutator>> mutators;
Expand Down Expand Up @@ -95,7 +95,7 @@ class TuneContext : public runtime::ObjectRef {
* \param space_generator The design space generator.
* \param search_strategy The search strategy.
* \param sch_rules The schedule rules.
* \param postprocs The post processings.
* \param postprocs The postprocessors.
* \param mutators The mutators.
* \param task_name The name of the tuning task.
* \param rand_state The random state.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1359,7 +1359,7 @@ constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";
* \brief Mark that the loop should be further skip and bound to environment threads to enable
* cooperative fetching.
*/
constexpr const char* meta_schedule_lazy_cooperative_fetch = "meta_schedule.lazy_cooperative_fetch";
constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch";

/*!
* \brief Mark a block as generated by cache_read or cache_write block.
Expand Down
10 changes: 4 additions & 6 deletions python/tvm/meta_schedule/postproc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
The tvm.meta_schedule.postproc package.
Meta Schedule post processings that deal with the problem of
undertermined schedule validity after applying some schedule
primitves at runtime.
"""
"""The tvm.meta_schedule.postproc package."""
from .postproc import Postproc, PyPostproc
from .rewrite_cooperative_fetch import RewriteCooperativeFetch
from .rewrite_reduction_block import RewriteReductionBlock
from .rewrite_unbound_block import RewriteUnboundBlock
19 changes: 6 additions & 13 deletions python/tvm/meta_schedule/postproc/postproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,29 +31,22 @@

@register_object("meta_schedule.Postproc")
class Postproc(Object):
"""Rules to apply a post processing to a schedule.
Note
----
Post processing is designed to deal with the problem of undertermined schedule validity after
applying some schedule primitves at runtime. E.g., Fuse the first X loops to reach the maximum
number below 1024, X is only decided at runtime.
"""
"""Rules to apply a postprocessor to a schedule."""

def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
"""Initialize the post processing with a tune context.
"""Initialize the postprocessor with a tune context.
Parameters
----------
tune_context : TuneContext
The tuning context for initializing the post processing.
The tuning context for initializing the postprocessor.
"""
_ffi_api.PostprocInitializeWithTuneContext( # type: ignore # pylint: disable=no-member
self, tune_context
)

def apply(self, sch: Schedule) -> bool:
"""Apply a post processing to the given schedule.
"""Apply a postprocessor to the given schedule.
Parameters
----------
Expand All @@ -63,9 +56,9 @@ def apply(self, sch: Schedule) -> bool:
Returns
-------
result : bool
Whether the post processing was successfully applied.
Whether the postprocessor was successfully applied.
"""
return _ffi_api.PostprocApply(self, sch)
return _ffi_api.PostprocApply(self, sch) # type: ignore # pylint: disable=no-member


@register_object("meta_schedule.PyPostproc")
Expand Down
34 changes: 34 additions & 0 deletions python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A postprocessor that rewrites the cooperative fetch annotation to actual
vectorized cooperative fetching in loop bindings."""

from tvm._ffi.registry import register_object
from .. import _ffi_api
from .postproc import Postproc


@register_object("meta_schedule.RewriteCooperativeFetch")
class RewriteCooperativeFetch(Postproc):
"""A postprocessor that rewrites the cooperative fetch annotation to actual vectorized
cooperative fetching in loop bindings.
"""

def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.PostprocRewriteCooperativeFetch, # type: ignore # pylint: disable=no-member
)
31 changes: 31 additions & 0 deletions python/tvm/meta_schedule/postproc/rewrite_reduction_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A postprocessor that rewrites reduction block by moving the init block out."""

from tvm._ffi.registry import register_object
from .. import _ffi_api
from .postproc import Postproc


@register_object("meta_schedule.RewriteReductionBlock")
class RewriteReductionBlock(Postproc):
"""A postprocessor that rewrites reduction block by moving the init block out."""

def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.PostprocRewriteReductionBlock, # type: ignore # pylint: disable=no-member
)
31 changes: 31 additions & 0 deletions python/tvm/meta_schedule/postproc/rewrite_unbound_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A postprocessor that adds thread binding to unbound blocks"""

from tvm._ffi.registry import register_object
from .. import _ffi_api
from .postproc import Postproc


@register_object("meta_schedule.RewriteUnboundBlock")
class RewriteUnboundBlock(Postproc):
"""A postprocessor that adds thread binding to unbound blocks"""

def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.PostprocRewriteUnboundBlock, # type: ignore # pylint: disable=no-member
)
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/tune_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class TuneContext(Object):
sch_rules: Optional[List[ScheduleRule]] = None,
The schedule rules.
postproc: Optional[List[Postproc"]] = None,
The post processings.
The postprocessors.
mutator: Optional[List[Mutator]] = None,
The mutators.
task_name : Optional[str] = None
Expand Down Expand Up @@ -115,7 +115,7 @@ def __init__(
sch_rules : List[ScheduleRule] = []
The schedule rules.
postproc : List[Postproc] = []
The post-processors.
The postprocessors.
mutator : List[Mutator] = []
The mutators.
task_name : Optional[str] = None
Expand Down
Loading

0 comments on commit dcf7310

Please sign in to comment.