Skip to content

Commit

Permalink
[Autotvm] Support override
Browse files Browse the repository at this point in the history
  • Loading branch information
hlu1 committed Jun 9, 2019
1 parent 98a91af commit 2849fea
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
8 changes: 4 additions & 4 deletions python/tvm/autotvm/task/topi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def get(allow_duplicate=False):
return TaskExtractEnv.current


def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
def register_topi_compute(topi_compute, target_keys, template_keys, func=None, override=False):
"""Register a tunable template for a topi compute function.
After the registration, this topi compute will become a configuration dispatcher. It uses
Expand Down Expand Up @@ -333,7 +333,7 @@ def config_dispatcher(*args, **kwargs):

config_dispatcher = _REGISTERED_DISPATCHER[target_key][topi_compute]

@config_dispatcher.register(template_keys)
@config_dispatcher.register(template_keys, override=override)
def template_call(cfg, *args, **kwargs):
"""call the topi func and attach workload to compute node"""
assert not kwargs, "Do not support kwargs in template function call"
Expand Down Expand Up @@ -372,7 +372,7 @@ def template_call(cfg, *args, **kwargs):
return _decorator


def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None):
def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None, override=False):
"""Register a tunable template for a topi schedule function.
After the registration. This topi schedule will become a configuration dispatcher. It dispatches
Expand Down Expand Up @@ -438,7 +438,7 @@ def traverse(tensors):

config_dispatcher = _REGISTERED_DISPATCHER[target_key][topi_schedule]

@config_dispatcher.register(template_keys)
@config_dispatcher.register(template_keys, override=override)
def template_call(cfg, outs, *args, **kwargs):
"""call the schedule func"""
if f == topi_schedule.fdefault:
Expand Down
11 changes: 11 additions & 0 deletions tests/python/unittest/test_autotvm_dispatch_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ def simple_template(a, b):
simple_template(2, 3)


def test_override():
from topi import generic, nn
@autotvm.register_topi_compute(nn.dense, 'cpu', ['direct'], override=True)
def dense(cfg, data, weight, bias=None, out_dtype=None):
pass

@autotvm.register_topi_schedule(generic.schedule_dense, 'cpu', ['direct'], override=True)
def schedule_dense(cfg, outs):
pass

if __name__ == "__main__":
test_dispatch()
test_fallback()
test_override()

0 comments on commit 2849fea

Please sign in to comment.