Skip to content

Commit

Permalink
Squashed commit
Browse files Browse the repository at this point in the history
[Meta Schedule][M3c] Schedule Rules, Mutator & Postprocs (apache#485)

[Meta Schedule][M3c] PostOrderApply (apache#486)

Fix Post Order Apply (apache#490)

[MetaSchedule] Relay Integration (apache#489)

[M3c][Meta Schedule] Add Trace Correctness Test for PostOrderApply (apache#492)

Fix replay trace. (apache#493)

[M3c][Meta Schedule] Implement the Replay Func class. (apache#495)

[PR] Test script for meta-schedule task extraction. Interface to load… (apache#494)

[Meta Schedule Refactor] Get child blocks (apache#500)

Read-at && Write-at (apache#497)

[M3c][Meta Schedule] Measure Callbacks (apache#498)

[Bug] Fix Infinite Loop Caused When Calling Methods Not Overrided In PyClass (apache#496)

[MetaSchedule] Sample-Perfect-Tile (apache#501)

[MetaSchedule] TE Workloads (apache#502)

Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Sunghyun Park <[email protected]>
  • Loading branch information
7 people committed Nov 5, 2021
1 parent 048994b commit b5eb32d
Show file tree
Hide file tree
Showing 64 changed files with 4,500 additions and 61 deletions.
126 changes: 126 additions & 0 deletions include/tvm/meta_schedule/measure_callback.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
#define TVM_META_SCHEDULE_MEASURE_CALLBACK_H_

#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/search_strategy.h>
#include <tvm/meta_schedule/tune_context.h>

namespace tvm {
namespace meta_schedule {

class TaskScheduler;

/*! \brief Rules to apply after measure results is available. */
class MeasureCallbackNode : public runtime::Object {
public:
/*! \brief Virtual destructor. */
virtual ~MeasureCallbackNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) {}

/*!
* \brief Apply a measure callback rule with given arguments.
* \param task_scheduler The task scheduler.
* \param tasks The list of tune context to process.
* \param measure_candidates The measure candidates.
* \param builds The builder results by building the measure candidates.
* \param results The runner results by running the built measure candidates.
* \return Whether the measure callback was successfully applied.
*/
virtual bool Apply(const TaskScheduler& task_scheduler, //
const Array<TuneContext> tasks, //
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results) = 0;

static constexpr const char* _type_key = "meta_schedule.MeasureCallback";
TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object);
};

/*! \brief The measure callback with customized methods on the python-side. */
class PyMeasureCallbackNode : public MeasureCallbackNode {
public:
/*!
* \brief Apply a measure callback to the given schedule.
* \param task_scheduler The task scheduler.
* \param tasks The list of tune context to process.
* \param measure_candidates The measure candidates.
* \param builds The builder results by building the measure candidates.
* \param results The runner results by running the built measure candidates.
* \return Whether the measure callback was successfully applied.
*/
using FApply =
runtime::TypedPackedFunc<bool(const TaskScheduler& task_scheduler, //
const Array<TuneContext> tasks, //
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results)>;
/*!
* \brief Get the measure callback function as string with name.
* \return The string of the measure callback function.
*/
using FAsString = runtime::TypedPackedFunc<String()>;

/*! \brief The packed function to the `Apply` funcion. */
FApply f_apply;
/*! \brief The packed function to the `AsString` funcion. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_apply` is not visited
// `f_as_string` is not visited
}

bool Apply(const TaskScheduler& task_scheduler, //
const Array<TuneContext> tasks, //
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results) final {
ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!";
return this->f_apply(task_scheduler, tasks, measure_candidates, builds, results);
}

static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback";
TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode);
};

/*!
* \brief Managed reference to MeasureCallbackNode
* \sa MeasureCallbackNode
*/
class MeasureCallback : public runtime::ObjectRef {
public:
/*!
* \brief Create a measure callback with customized methods on the python-side.
* \param f_apply The packed function of `Apply`.
* \return The measure callback created.
*/
TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, //
PyMeasureCallbackNode::FAsString f_as_string);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode);
};

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
125 changes: 125 additions & 0 deletions include/tvm/meta_schedule/mutator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_META_SCHEDULE_MUTATOR_H_
#define TVM_META_SCHEDULE_MUTATOR_H_

#include <tvm/tir/schedule/schedule.h>

namespace tvm {
namespace meta_schedule {

class TuneContext;

/*! \brief Mutator is designed to mutate the trace to explore the design space. */
class MutatorNode : public runtime::Object {
public:
/*! \brief Virtual destructor. */
virtual ~MutatorNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) {}

/*!
* \brief The function type of `InitializeWithTuneContext` method.
* \param tune_context The tuning context for initialization.
*/
virtual void InitializeWithTuneContext(const TuneContext& context) = 0;

/*!
* \brief Apply the mutator function to the given trace.
* \param trace The given trace for mutation.
* \return None if mutator failed, otherwise return the mutated trace.
*/
virtual Optional<tir::Trace> Apply(const tir::Trace& trace) = 0;

static constexpr const char* _type_key = "meta_schedule.Mutator";
TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object);
};

/*! \brief The mutator with customized methods on the python-side. */
class PyMutatorNode : public MutatorNode {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
* \param tune_context The tuning context for initialization.
*/
using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
/*!
* \brief Apply the mutator function to the given trace.
* \param trace The given trace for mutation.
* \return None if mutator failed, otherwise return the mutated trace.
*/
using FApply = runtime::TypedPackedFunc<Optional<tir::Trace>(const tir::Trace&)>;
/*!
* \brief Get the mutator as string with name.
* \return The string of the mutator.
*/
using FAsString = runtime::TypedPackedFunc<String()>;

/*! \brief The packed function to the `InitializeWithTuneContext` funcion. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` funcion. */
FApply f_apply;
/*! \brief The packed function to the `AsString` funcion. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_initialize_with_tune_context` is not visited
// `f_apply` is not visited
// `f_as_string` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PyMutator's InitializeWithTuneContext method not implemented!";
this->f_initialize_with_tune_context(context);
}

Optional<tir::Trace> Apply(const tir::Trace& trace) final {
ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!";
return this->f_apply(trace);
}

static constexpr const char* _type_key = "meta_schedule.PyMutator";
TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode);
};

/*!
* \brief Managed reference to MutatorNode
* \sa MutatorNode
*/
class Mutator : public runtime::ObjectRef {
public:
/*!
* \brief Create a mutator with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_apply The packed function of `Apply`.
* \return The mutator created.
*/
TVM_DLL static Mutator PyMutator(
PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PyMutatorNode::FApply f_apply, //
PyMutatorNode::FAsString f_as_string);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode);
};

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_MUTATOR_H_
130 changes: 130 additions & 0 deletions include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_META_SCHEDULE_POSTPROC_H_
#define TVM_META_SCHEDULE_POSTPROC_H_

#include <tvm/tir/schedule/schedule.h>

namespace tvm {
namespace meta_schedule {

class TuneContext;

/*!
* \brief Rules to apply a post processing to a schedule.
* \note Post processing is designed to deal with the problem of undertermined schedule validity
* after applying some schedule primitves at runtime. E.g., Fuse the first X loops to reach the
* maximum number below 1024, X is only decided at runtime.
*/
class PostprocNode : public runtime::Object {
public:
/*! \brief Virtual destructor. */
virtual ~PostprocNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) {}

/*!
* \brief The function type of `InitializeWithTuneContext` method.
* \param tune_context The tuning context for initialization.
*/
virtual void InitializeWithTuneContext(const TuneContext& context) = 0;

/*!
* \brief Apply a post processing to the given schedule.
* \param sch The schedule to be post processed.
* \return Whether the post processing was successfully applied.
*/
virtual bool Apply(const tir::Schedule& schedule) = 0;

static constexpr const char* _type_key = "meta_schedule.Postproc";
TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object);
};

/*! \brief The post processing with customized methods on the python-side. */
class PyPostprocNode : public PostprocNode {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
* \param tune_context The tuning context for initialization.
*/
using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
/*!
* \brief Apply a post processing to the given schedule.
* \param sch The schedule to be post processed.
* \return Whether the post processing was successfully applied.
*/
using FApply = runtime::TypedPackedFunc<bool(const tir::Schedule&)>;
/*!
* \brief Get the post processing function as string with name.
* \return The string of the post processing function.
*/
using FAsString = runtime::TypedPackedFunc<String()>;

/*! \brief The packed function to the `InitializeWithTuneContext` funcion. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` funcion. */
FApply f_apply;
/*! \brief The packed function to the `AsString` funcion. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_initialize_with_tune_context` is not visited
// `f_apply` is not visited
// `f_as_string` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PyPostproc's InitializeWithTuneContext method not implemented!";
this->f_initialize_with_tune_context(context);
}

bool Apply(const tir::Schedule& sch) final {
ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!";
return this->f_apply(sch);
}

static constexpr const char* _type_key = "meta_schedule.PyPostproc";
TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode);
};

/*!
* \brief Managed reference to PostprocNode
* \sa PostprocNode
*/
class Postproc : public runtime::ObjectRef {
public:
/*!
* \brief Create a post processing with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_apply The packed function of `Apply`.
* \return The post processing created.
*/
TVM_DLL static Postproc PyPostproc(
PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PyPostprocNode::FApply f_apply, //
PyPostprocNode::FAsString f_as_string);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode);
};

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_POSTPROC_H_
Loading

0 comments on commit b5eb32d

Please sign in to comment.