[MetaSchedule] Generate MetaSchedule Dataset
To build a dataset for improving the MetaSchedule cost model, I added scripts for importing models into TVM, extracting tuning tasks, and sampling measure candidates. I also exposed several C++ methods to the Python side to support the process.
1 parent 6fca5c6 · commit eaf2c64 · Showing 5 changed files with 366 additions and 0 deletions.
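For orientation, here is a minimal sketch (not part of the commit) of how the three stages could be chained end to end. The cache directories are placeholders, and the filename of the task-extraction script (dataset_extract_tasks.py) is an assumption; only the other two filenames appear in this diff.

# Hedged end-to-end sketch: model collection -> task extraction -> candidate sampling.
# Paths are placeholders; "dataset_extract_tasks.py" is an assumed filename.
import subprocess
import sys

BASE = "python/tvm/meta_schedule/testing"
subprocess.run(
    [sys.executable, f"{BASE}/dataset_collect_models.py", "--model_cache_dir", "/tmp/models"],
    check=True,
)
subprocess.run(
    [
        sys.executable,
        f"{BASE}/dataset_extract_tasks.py",  # assumed filename, not shown in this diff
        "--model_cache_dir", "/tmp/models",
        "--task_cache_dir", "/tmp/tasks",
        "--target", "cuda",
    ],
    check=True,
)
subprocess.run(
    [
        sys.executable,
        f"{BASE}/dataset_sample_candidates.py",
        "--task_cache_dir", "/tmp/tasks",
        "--candidate_cache_dir", "/tmp/candidates",
        "--target", "nvidia/geforce-rtx-3070",
    ],
    check=True,
)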
73 changes: 73 additions & 0 deletions
python/tvm/meta_schedule/testing/dataset_collect_models.py
@@ -0,0 +1,73 @@
"""
Import models to TVM.
"""

import argparse
import os
from typing import List, Tuple
from tqdm import tqdm  # type: ignore

from tvm.meta_schedule.testing.relay_workload import get_network


# pylint: disable=too-many-branches
def _build_dataset() -> List[Tuple[str, List[int]]]:
    network_keys = []
    for name in [
        "resnet_18",
        "resnet_50",
        "mobilenet_v2",
        "mobilenet_v3",
        "wide_resnet_50",
        "resnext_50",
        "densenet_121",
        "vgg_16",
    ]:
        for batch_size in [1, 4, 8]:
            for image_size in [224, 240, 256]:
                network_keys.append((name, [batch_size, 3, image_size, image_size]))
    # inception-v3
    for name in ["inception_v3"]:
        for batch_size in [1, 2, 4]:
            for image_size in [299]:
                network_keys.append((name, [batch_size, 3, image_size, image_size]))
    # resnet3d
    for name in ["resnet3d_18"]:
        for batch_size in [1, 2, 4]:
            for image_size in [112, 128, 144]:
                network_keys.append((name, [batch_size, 3, image_size, image_size, 16]))
    # bert
    for name in ["bert_tiny", "bert_base", "bert_medium", "bert_large"]:
        for batch_size in [1, 2, 4]:
            for seq_length in [64, 128, 256]:
                network_keys.append((name, [batch_size, seq_length]))
    # dcgan
    for name in ["dcgan"]:
        for batch_size in [1, 4, 8]:
            for image_size in [64]:
                network_keys.append((name, [batch_size, 3, image_size, image_size]))

    return network_keys


def cache_models(network_keys, cache_dir):
    """Download the model and cache it in the given directory."""

    for name, input_shape in tqdm(network_keys):
        get_network(name=name, input_shape=input_shape, cache_dir=cache_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_cache_dir", type=str, help="Please provide the full path to the model cache dir."
    )
    args = parser.parse_args()
    model_cache_dir = args.model_cache_dir

    try:
        os.makedirs(model_cache_dir, exist_ok=True)
    except OSError as error:
        print(f"Directory {model_cache_dir} cannot be created successfully.")
    keys = _build_dataset()
    cache_models(keys, model_cache_dir)
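As a usage note, here is a minimal sketch (not part of the commit) of driving the helpers above programmatically instead of through the CLI; the module path follows the file path above, and the cache directory is a placeholder.

# Minimal sketch: cache only the ResNet-18 variants from the network list above.
from tvm.meta_schedule.testing.dataset_collect_models import _build_dataset, cache_models

resnet_keys = [key for key in _build_dataset() if key[0] == "resnet_18"]
cache_models(resnet_keys, cache_dir="/tmp/model_cache")  # placeholder directory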
@@ -0,0 +1,88 @@
"""
Extract tuning tasks using MetaSchedule, and filter out spatial tasks.
"""

import argparse
import glob
import json
import os
from tqdm import tqdm  # type: ignore

import tvm
from tvm import meta_schedule as ms
from tvm.ir import save_json
from tvm.meta_schedule.testing.relay_workload import _load_cache
from tvm.runtime import load_param_dict


def _parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_cache_dir", type=str, help="Please provide the full path to the model cache dir."
    )
    parser.add_argument(
        "--task_cache_dir", type=str, help="Please provide the full path to save extracted tasks."
    )
    parser.add_argument(
        "--target", type=str, default="cuda", help="Please specify the target hardware for tuning."
    )
    return parser.parse_args()


# pylint: disable=too-many-locals
def extract_and_save_tasks(cache_file, is_spatial, args):
    """Extract tuning tasks and cache the nonspatial ones in the given directory.
    Parameters
    ----------
    cache_file : str
        The filename of the cached model.
    is_spatial : PackedFunc
        The function for checking whether a task is spatial.
    args : argparse.Namespace
        The parsed arguments.
    Returns
    -------
    None
    """

    mod, params_bytearray, _ = _load_cache(args.model_cache_dir, cache_file)
    params = load_param_dict(params_bytearray)
    try:
        extracted_tasks = ms.extract_task_from_relay(mod, target=args.target, params=params)
    except tvm._ffi.base.TVMError:  # pylint: disable=protected-access
        return
    task_cache_path = os.path.join(
        args.task_cache_dir, cache_file.split(".")[0] + "_extracted_tasks.json"
    )
    with open(task_cache_path, "w", encoding="utf8") as file:
        for i, task in enumerate(extracted_tasks):
            subgraph = task.dispatched[0]
            prim_func = subgraph[subgraph.get_global_vars()[0]]
            if not is_spatial(prim_func):
                subgraph_str = save_json(subgraph)
                json_obj = [json.loads(subgraph_str), task.task_name]
                json_str = json.dumps(json_obj)
                assert "\n" not in json_str, "Failed to generate single line string."
                if i == len(extracted_tasks) - 1:
                    file.write(json_str)
                else:
                    file.write(json_str + "\n")


if __name__ == "__main__":
    parsed_args = _parse_args()
    if not os.path.isdir(parsed_args.model_cache_dir):
        raise Exception("Please provide a correct model cache dir.")
    try:
        os.makedirs(parsed_args.task_cache_dir, exist_ok=True)
    except OSError as error:
        print(f"Directory {parsed_args.task_cache_dir} cannot be created successfully.")

    check_spatial_fn = tvm.get_global_func("tir.schedule.IsSpatialPrimFunc")
    cache_paths = glob.glob(os.path.join(parsed_args.model_cache_dir, "*.json"))
    for cache_path in tqdm(cache_paths):
        filename = cache_path.split("/")[-1]
        extract_and_save_tasks(filename, check_spatial_fn, parsed_args)
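To illustrate the on-disk format written above (one JSON list per line holding the serialized IRModule and the task name), here is a minimal reading sketch; the filename is a placeholder.

# Minimal sketch: read one non-spatial task back from a *_extracted_tasks.json file.
import json

from tvm.ir import load_json

with open("/tmp/tasks/relay-resnet_18_extracted_tasks.json", encoding="utf8") as f:  # placeholder path
    mod_json, task_name = json.loads(f.readline())  # one [IRModule-JSON, task_name] per line
task_mod = load_json(json.dumps(mod_json))  # recover the tvm.IRModule
print(task_name, task_mod)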
177 changes: 177 additions & 0 deletions
python/tvm/meta_schedule/testing/dataset_sample_candidates.py
@@ -0,0 +1,177 @@
"""
Sample measure candidates for each tuning task by evolutionary search.
"""

import argparse
import glob
import json
import os
from typing import List
from tqdm import tqdm  # type: ignore

import tvm
from tvm import meta_schedule as ms
from tvm.ir import load_json
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.database import TuningRecord, Workload
from tvm.meta_schedule.search_strategy import EvolutionarySearch
from tvm.meta_schedule.space_generator import PostOrderApply
from tvm.meta_schedule.testing.utils import DummyDatabase
from tvm.meta_schedule.tune import DefaultCUDA, DefaultLLVM
from tvm.target import Target


def _parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task_cache_dir", type=str, help="Please provide the full path to the extracted tasks."
    )
    parser.add_argument(
        "--candidate_cache_dir",
        type=str,
        help="Please provide the full path to save the sampled candidates.",
    )
    parser.add_argument(
        "--target",
        type=str,
        default="nvidia/geforce-rtx-3070",
        help="Please specify the target hardware for tuning.\
            Note: for generating dataset, the hardware does not need to be present.",
    )
    parser.add_argument(
        "--init_population_size",
        type=int,
        default=256,
        help="The initial population size used in evolutionary search.",
    )
    parser.add_argument(
        "--num_samples_per_task",
        type=int,
        default=400,
        help="The number of samples to gather per tuning task.",
    )
    parser.add_argument(
        "--num_trials_per_iter",
        type=int,
        default=64,
        help="The number of trials per iteration in evolutionary search.",
    )
    parser.add_argument(
        "--max_trials_per_task",
        type=int,
        default=400,
        help="The maximum number of trials per task in evolutionary search.",
    )
    parser.add_argument(
        "--max_retry_per_task",
        type=int,
        default=10,
        help="The maximum number of retry attempts allowed.",
    )
    parser.add_argument(
        "--file_group",
        type=int,
        default=0,
        help="To enable running multiple scripts in parallel, files [idx * 10 : (idx + 1) * 10]\
            in the sorted file list from the given directory will be run.",
    )
    return parser.parse_args()


# pylint: disable=too-many-locals
def sample_candidates(task, task_name, model_name):
    """Randomly sample candidates for a task and save the candidates in the given directory.
    Parameters
    ----------
    task : IRModule
        The initial ir module used for generating the search space.
    task_name : str
        The name of the task.
    model_name : str
        The name of the model.
    Returns
    -------
    None
    """

    strategy = EvolutionarySearch(
        num_trials_per_iter=args.num_trials_per_iter,
        max_trials_per_task=args.max_trials_per_task,
    )
    default_config = DefaultCUDA if args.target != "llvm" else DefaultLLVM
    # pylint: disable=protected-access
    context = TuneContext(
        mod=task,
        target=Target(args.target),
        space_generator=PostOrderApply(),
        search_strategy=strategy,
        sch_rules=default_config._sch_rules(),  # type: ignore
        postprocs=default_config._postproc(),  # type: ignore
        mutator_probs=default_config._mutator_probs(),  # type: ignore
        task_name=task_name,
    )
    context.initialize()
    spaces = context.space_generator.generate_design_space(context.mod)
    # type: ignore
    strategy.pre_tuning(spaces, database=DummyDatabase(), cost_model=ms.cost_model.RandomModel())

    all_states: List[tvm.tir.schedule.schedule.Schedule] = []
    num_retry, itr = 0, 0
    states = sample_init_population(strategy, args.init_population_size)
    while len(all_states) < args.num_samples_per_task and num_retry < args.max_retry_per_task:
        states = evolve_with_cost_model(strategy, states, len(states))
        all_states += states
        if len(states) == 0:
            states = sample_init_population(strategy, args.init_population_size)
            num_retry += 1
        else:
            num_retry = 0
        print(f"iter: {itr}, number of states sampled: {len(all_states)}")
        itr += 1
    all_states = all_states[: args.num_samples_per_task]

    workload = Workload(context.mod)
    file_path = os.path.join(args.candidate_cache_dir, model_name, task_name + ".json")
    with open(file_path, "w", encoding="utf8") as file:
        for i, state in enumerate(all_states):
            tuning_record = TuningRecord(state.trace, workload)
            json_str = json.dumps(tuning_record.as_json())
            assert "\n" not in json_str, "Failed to generate single line string."
            if i == len(all_states) - 1:
                file.write(json_str)
            else:
                file.write(json_str + "\n")


if __name__ == "__main__":
    args = _parse_args()
    if not os.path.isdir(args.task_cache_dir):
        raise Exception("Please provide a correct task cache dir.")
    try:
        os.makedirs(args.candidate_cache_dir, exist_ok=True)
    except OSError as error:
        print(f"Directory {args.candidate_cache_dir} cannot be created successfully.")

    sample_init_population = tvm.get_global_func(
        "meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation"
    )
    evolve_with_cost_model = tvm.get_global_func(
        "meta_schedule.SearchStrategyEvolutionarySearchEvolveWithCostModel"
    )
    task_paths = sorted(glob.glob(os.path.join(args.task_cache_dir, "*.json")))[
        args.file_group * 10 : (args.file_group + 1) * 10
    ]
    print(f"Selected models: {task_paths}")
    for num, task_path in enumerate(task_paths):
        print(f"Processing model {num} ...")
        with open(task_path, "rb") as f:
            tasks = f.readlines()
        model_n = task_path.split("/")[-1][len("relay-") :][: -len("_extracted_tasks.json")]
        os.makedirs(os.path.join(args.candidate_cache_dir, model_n), exist_ok=True)
        for task_str in tqdm(tasks):
            task_mod, task_n = json.loads(task_str)
            task_mod = load_json(json.dumps(task_mod))
            sample_candidates(task_mod, task_n, model_n)
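To close the loop, a hedged sketch of loading the sampled candidates back as TuningRecord objects; the paths are placeholders, and it assumes TuningRecord.from_json from tvm.meta_schedule.database accepts the per-line JSON written by as_json above.

# Hedged sketch: reload sampled candidates as TuningRecord objects.
# Paths are placeholders; TuningRecord.from_json is assumed to accept the
# per-line JSON produced by tuning_record.as_json() above.
import json

from tvm.ir import load_json
from tvm.meta_schedule.database import TuningRecord, Workload

# Rebuild the workload from the corresponding extracted-task entry.
with open("/tmp/tasks/relay-resnet_18_extracted_tasks.json", encoding="utf8") as f:  # placeholder
    mod_json, task_name = json.loads(f.readline())
workload = Workload(load_json(json.dumps(mod_json)))

records = []
with open(f"/tmp/candidates/resnet_18/{task_name}.json", encoding="utf8") as f:  # placeholder
    for line in f:
        records.append(TuningRecord.from_json(json.loads(line), workload))
print(f"Loaded {len(records)} candidate traces for task {task_name}")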