Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoTVM][Autoscheduler] Default build funcs inherit PassContext #11632

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,22 @@
We implement these in python to utilize python's multiprocessing and error handling.
"""

import logging
import multiprocessing
import os
import time
import shutil
import tempfile
import multiprocessing
import logging
import time

import tvm._ffi
from tvm.runtime import Object, module, ndarray
from tvm.autotvm.env import AutotvmGlobalScope, reset_global_scope
from tvm.contrib import ndk, tar
from tvm.contrib.popen_pool import PopenPoolExecutor, PopenWorker, StatusKind
from tvm.driver import build_module
from tvm.ir import transform
from tvm.autotvm.env import AutotvmGlobalScope, reset_global_scope
from tvm.contrib import tar, ndk
from tvm.contrib.popen_pool import PopenWorker, PopenPoolExecutor, StatusKind
from tvm.runtime import Object, module, ndarray
from tvm.target import Target


from . import _ffi_api
from .loop_state import StateObject
from .utils import (
Expand All @@ -59,8 +58,8 @@
request_remote,
)
from .workload_registry import (
serialize_workload_registry_entry,
deserialize_workload_registry_entry,
serialize_workload_registry_entry,
)

# pylint: disable=invalid-name
Expand Down Expand Up @@ -555,8 +554,8 @@ def __init__(
device=0,
):
# pylint: disable=import-outside-toplevel
from tvm.rpc.tracker import Tracker
from tvm.rpc.server import Server
from tvm.rpc.tracker import Tracker

self.tracker = Tracker(port=9000, port_end=10000, silent=True)
device_key = "$local$device$%d" % self.tracker.port
Expand Down Expand Up @@ -630,7 +629,7 @@ def _local_build_worker(inp_serialized, build_func, verbose):
filename = os.path.join(dirname, "tmp_func." + build_func.output_format)

try:
with transform.PassContext():
with transform.PassContext().current():
func = build_module.build(sch, args, target=task.target)
func.export_library(filename, build_func)
# pylint: disable=broad-except
Expand Down
29 changes: 23 additions & 6 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
import time
import traceback
import typing
import warnings
from collections import namedtuple
from random import getrandbits
import warnings

import tvm._ffi
import tvm.ir.transform
Expand Down Expand Up @@ -505,10 +505,6 @@ def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option
if not config.valid():
raise InstantiationError(config.errors)

opts = build_option or {}
if check_gpu: # Add verify pass to filter out invalid configs in advance.
opts["tir.add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))]

# if target is vta, we need to use vta build
if (
hasattr(measure_input.target, "device_name")
Expand All @@ -519,7 +515,28 @@ def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option

func = vta.build(s, args, target_host=task.target_host)
else:
with tvm.ir.transform.PassContext(config=opts):
current_pass_context: tvm.ir.transform.PassContext = (
tvm.ir.transform.PassContext.current()
)
current_config = dict(current_pass_context.config)
if build_option is not None:
current_config.update(build_option)

if "tir.add_lower_pass" in current_config:
current_add_lower_pass = list(current_config["tir.add_lower_pass"])
else:
current_add_lower_pass = []
if check_gpu:
current_add_lower_pass.append((2, gpu_verify_pass(**check_gpu)))
current_config["tir.add_lower_pass"] = current_add_lower_pass

with tvm.ir.transform.PassContext(
opt_level=current_pass_context.opt_level,
required_pass=current_pass_context.required_pass,
disabled_pass=current_pass_context.disabled_pass,
instruments=current_pass_context.instruments,
config=current_config,
):
func = build(s, args, target_host=task.target_host, runtime=runtime)
return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)

Expand Down
112 changes: 112 additions & 0 deletions tests/python/integration/test_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@
import tvm.relay
import tvm.testing
from tvm import autotvm, te
from tvm.autotvm.measure import measure_methods
from tvm.autotvm.tuner import RandomTuner
from tvm.contrib import tar
from tvm.ir.instrument import pass_instrument
from tvm.ir.transform import PassContext
from tvm.target import Target


Expand Down Expand Up @@ -180,6 +184,114 @@ def runner(target, dev):
run_test_with_all_multiprocessing(runner, target, dev)


@tvm.testing.parametrize_targets("cuda", "opencl")
def test_tuning_gpu_inherits_pass_context(target, dev):
"""Autotvm tuner inherits PassContexts but also adds a gpu verification pass by default.

Test that using PassContext inherits passes properly but also runs gpu verification pass.
"""
from tvm.tir.analysis import _ffi_api as _analysis_ffi_api

@pass_instrument
class PassInstrumentChecker:
    """Pass Instrument that simply sees if it's been run."""

    def __init__(self):
        # Flipped to True the first time any pass completes under a
        # PassContext carrying this instrument.
        self.has_been_run = False

    def run_after_pass(self, mod, info):
        # Invoked by the pass infrastructure after each pass runs; we only
        # care that it fired at all, not which pass triggered it.
        self.has_been_run = True

class GPUVerifyPassMocked:
    """Context manager that mocks tir.analysis.verify_gpu_code meant
    to verify the pass has been run. This is done by patching the ffi func handles."""

    # Global FFI registry key of the function being spied on.
    FFI_FUNC_HANDLE = "tir.analysis.verify_gpu_code"
    # Attribute name of the same function on the python bindings module.
    FUNC_NAME = "verify_gpu_code"

    def __init__(self) -> None:
        # Keep a handle to the real implementation so __exit__ can restore it.
        self.old_impl = tvm._ffi.get_global_func(self.FFI_FUNC_HANDLE)
        # Set to True when the mocked verification pass is actually invoked.
        self.has_been_run = False

    def gpu_verify_pass_mocked(self):
        """Get the replacement for the gpu verification pass."""

        def _gpu_verify_pass_mocked(*args, **kwargs):
            # Record the call, then delegate to the real verifier so the
            # observable behavior is unchanged apart from the bookkeeping.
            self.has_been_run = True
            return self.old_impl(*args, **kwargs)

        return _gpu_verify_pass_mocked

    def __enter__(self):
        # Replace the globally registered FFI function with the spying wrapper.
        tvm._ffi.register_func(
            self.FFI_FUNC_HANDLE, self.gpu_verify_pass_mocked(), override=True
        )

        # Also overwrite the python bindings
        setattr(
            _analysis_ffi_api, self.FUNC_NAME, tvm._ffi.get_global_func(self.FFI_FUNC_HANDLE)
        )

    def __exit__(self, *args, **kwargs):
        # Restore FFI status back to normal
        tvm._ffi.register_func(self.FFI_FUNC_HANDLE, self.old_impl, override=True)
        setattr(_analysis_ffi_api, self.FUNC_NAME, self.old_impl)

class OverwrittenBuildFunc(measure_methods._WrappedBuildFunc):
    """BuildFunc that mocks and patches as necessary to test proper passes are run."""

    def __call__(self, measure_input, tmp_dir, **kwargs):
        # One fresh instrument per build so has_been_run reflects this call only.
        instrument = PassInstrumentChecker()
        mocked_pass_checker = GPUVerifyPassMocked()
        with mocked_pass_checker:
            with PassContext(instruments=[instrument]):
                regular_result = super().__call__(measure_input, tmp_dir, **kwargs)

                # Check instrument has been run, meaning context was inherited by builder
                assert instrument.has_been_run

                # But also check the gpu verification pass has been run
                # (which was not in the inherited ctx)
                assert mocked_pass_checker.has_been_run

                return regular_result

class MockedLocalBuilder(measure_methods.LocalBuilder):
    """As measure_methods.LocalBuilder but overwrites the PassContext for testing."""

    def __init__(
        self,
        timeout=10,
        n_parallel=None,
        build_kwargs=None,
        build_func="default",
        do_fork=False,
        runtime=None,
    ):
        super().__init__(timeout, n_parallel, build_kwargs, build_func, do_fork, runtime)
        # Swap in the spying/asserting build func after normal initialization
        # so the builder otherwise behaves exactly like LocalBuilder.
        self.build_func = OverwrittenBuildFunc(tar.tar, runtime)

def runner(target, dev):
    """Tune one trial of a sample task using the mocked builder and assert
    that exactly one measurement result was produced.

    `dev` is unused here but kept to match the runner signature expected by
    run_test_with_all_multiprocessing — presumably; confirm against caller.
    """
    task, target = get_sample_task(target, None)
    logging.info("task config space: %s", task.config_space)

    # Note: we use the MockedLocalBuilder here instead of autotvm.LocalBuilder()
    measure_option = autotvm.measure_option(MockedLocalBuilder(), autotvm.LocalRunner())

    results = []

    tuner = RandomTuner(task)
    tuner.tune(
        n_trial=1,
        measure_option=measure_option,
        # Collect every MeasureResult so we can assert on the count below.
        callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),),
    )

    assert len(results) == 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want to check that the pass also succeeded? I think if one of those asserts fails, we just get a measure error here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assertions will fail the test, so I think it's safe — we just want to make sure the proper passes are run.

We don't want to check for success in tuning, since the tuning process on GPU is actually flaky (see test_tuning_gpu() above).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, OK — maybe the runner is where exceptions don't fail the test. Good enough, if you've proven locally that an exception causes the test to fail.


run_test_with_all_multiprocessing(runner, target, dev)


def test_tuning_cpu():
def runner():
ir_mod = tvm.parser.fromtext(
Expand Down