Skip to content

Commit

Permalink
[TVMC] Allow section of tasks to be tuned when using tvmc tune command
Browse files Browse the repository at this point in the history
[TVMC] Address PR comments
  • Loading branch information
PhilippvK committed Sep 7, 2022
1 parent 546a7da commit e0feebb
Showing 1 changed file with 81 additions and 8 deletions.
89 changes: 81 additions & 8 deletions python/tvm/driver/tvmc/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ def add_tune_parser(subparsers, _, json_params):
help="enable tuning the graph through the AutoScheduler tuner",
action="store_true",
)
parser.add_argument(
"--tasks",
default="all",
help="which tasks should be tuned, i.e. 0 0,2 3-5 all list",
)

auto_scheduler_group = parser.add_argument_group(
"AutoScheduler options",
Expand Down Expand Up @@ -293,9 +298,51 @@ def drive_tune(args):
include_simple_tasks=args.include_simple_tasks,
log_estimated_latency=args.log_estimated_latency,
additional_target_options=reconstruct_target_args(args),
tasks_filter=args.tasks,
)


def filter_tasks(
tasks: Optional[Union[auto_scheduler.SearchTask, autotvm.task.Task]],
expr: str,
):
assert isinstance(expr, str), "Expected filter expression of string type"
assert len(expr) > 0, "Got empty filter expression"

# groups of keywords are comma-separated
splitted = expr.split(",")

do_list = False
do_filter = False
selected = []
for item in splitted:
if item in ["list", "help"]:
do_list = True
elif item in ["all"]:
selected = list(range(len(tasks)))
else:
do_filter = True
if "-" in item:
lhs, rhs = item.split("-")[:2]
lhs = int(lhs) if lhs else 0
rhs = int(rhs) if rhs else len(tasks) - 1
assert 0 <= lhs < len(tasks), "Left-hand side expression out of range"
assert 0 <= rhs < len(tasks), "Right-hand side expression out of range"
selected.extend(list(range(lhs, rhs + 1)))
else:
assert isinstance(item, str)
idx = int(item)
assert idx < len(tasks) and idx >= 0
selected.append(idx)

if do_filter:
# remove duplicates
selected = list(set(selected))
tasks = [task for i, task in enumerate(tasks) if i in selected]

return tasks, do_list


def tune_model(
tvmc_model: TVMCModel,
target: str,
Expand All @@ -319,6 +366,7 @@ def tune_model(
include_simple_tasks: bool = False,
log_estimated_latency: bool = False,
additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None,
tasks_filter: str = "all",
):
"""Use tuning to automatically optimize the functions in a model.
Expand Down Expand Up @@ -377,6 +425,9 @@ def tune_model(
If using the autoscheduler, write the estimated latency at each step of tuning to file.
additional_target_options: Optional[Dict[str, Dict[str, Any]]]
Additional target options in a dictionary to combine with initial Target arguments
tasks_filter : str, optional
Filter which tasks should be tuned or output a list of the extracted tasks.
Examples: 0 0,2 3-5 all list
Returns
-------
Expand Down Expand Up @@ -444,7 +495,6 @@ def tune_model(
runner = local_server

if enable_autoscheduler:

tasks, weights = autoscheduler_get_tuning_tasks(
mod=mod,
params=params,
Expand All @@ -453,7 +503,37 @@ def tune_model(
hardware_params=hardware_params,
include_simple_tasks=include_simple_tasks,
)
else:
tasks = autotvm_get_tuning_tasks(
mod=mod,
params=params,
target=target,
alter_layout=desired_layout,
)

# Filter extracted tasks by provided user expression
if tasks_filter:
tasks, do_list = filter_tasks(tasks, tasks_filter)
if do_list:
print("Available Tasks for tuning:")
print(
"\n".join(
[
" {}. {}".format(
i, task if len(str(task)) < 100 else str(task)[:97] + "..."
)
for i, task in enumerate(tasks)
]
)
)
return None
if len(tasks) == 0:
logger.info("No tasks have been selected for tuning.")
return None
else:
logger.info(f"Selected {len(tasks)} for tuning.")

if enable_autoscheduler:
# Create the autoscheduler tuning options
tuning_options = auto_scheduler.TuningOptions(
num_measure_trials=trials,
Expand All @@ -467,13 +547,6 @@ def tune_model(
# Schedule the tasks (i.e., produce a schedule for each task)
schedule_tasks(tasks, weights, tuning_options, prior_records, log_estimated_latency)
else:
tasks = autotvm_get_tuning_tasks(
mod=mod,
params=params,
target=target,
alter_layout=desired_layout,
)

# In autotvm, trials is specified per task. We can convert the per-model input
# provided to per-task trials by dividing by the number of tasks.
trials = int(trials / max(len(tasks), 1))
Expand Down

0 comments on commit e0feebb

Please sign in to comment.