diff --git a/;w b/;w
deleted file mode 100644
index 3f1b4334f2974..0000000000000
--- a/;w
+++ /dev/null
@@ -1,104 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=missing-docstring
-
-import argparse
-import glob
-import json
-import os
-
-import tvm
-from tqdm import tqdm  # type: ignore
-from tvm import meta_schedule as ms
-from tvm.ir import save_json
-from tvm.meta_schedule.testing.relay_workload import _load_cache
-from tvm.runtime import load_param_dict
-
-
-def _parse_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--model_cache_dir", type=str, help="Please provide the full path to the model cache dir."
-    )
-    parser.add_argument(
-        "--task_cache_dir", type=str, help="Please provide the full path to save extracted tasks."
-    )
-    parser.add_argument(
-        "--target", type=str, default="cuda", help="Please specify the target hardware for tuning."
-    )
-    return parser.parse_args()
-
-
-# pylint: disable=too-many-locals
-def extract_and_save_tasks(cache_file):
-    """Extract tuning tasks and cache the nonspatial ones in the given directory.
-
-    Parameters
-    ----------
-    cache_file : str
-        The filename of the cached model.
-
-    Returns
-    -------
-    None
-    """
-
-    mod, params_bytearray, _ = _load_cache(args.model_cache_dir, cache_file)
-    params = load_param_dict(params_bytearray)
-    try:
-        extracted_tasks = ms.extract_task_from_relay(mod, target=args.target, params=params)
-    except tvm.error.TVMError as e:  # pylint: disable=protected-access
-        print(str(e))
-        return
-    task_cache_path = os.path.join(
-        args.task_cache_dir, cache_file.split(".")[0] + "_extracted_tasks.json"
-    )
-    is_spatial = tvm.get_global_func("tir.schedule.IsSpatialPrimFunc")
-    with open(task_cache_path, "w", encoding="utf8") as file:
-        for i, task in enumerate(extracted_tasks):
-            subgraph = task.dispatched[0]
-            prim_func = subgraph[subgraph.get_global_vars()[0]]
-            if not is_spatial(prim_func):
-                subgraph_str = save_json(subgraph)
-                json_obj = [task.task_name, json.loads(subgraph_str)]
-                json_str = json.dumps(json_obj)
-                assert "\n" not in json_str, "Failed to generate single line string."
-                if i == len(extracted_tasks) - 1:
-                    file.write(json_str)
-                else:
-                    file.write(json_str + "\n")
-
-
-args = _parse_args()  # pylint: disable=invalid-name
-
-
-def main():
-    if not os.path.isdir(args.model_cache_dir):
-        raise Exception("Please provide a correct model cache dir.")
-    try:
-        os.makedirs(args.task_cache_dir, exist_ok=True)
-    except OSError:
-        print(f"Directory {args.task_cache_dir} cannot be created successfully.")
-
-    paths = glob.glob(os.path.join(args.model_cache_dir, "*.json"))  # pylint: disable=invalid-name
-    for path in tqdm(paths):
-        filename = path.split("/")[-1]
-        extract_and_save_tasks(filename)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/python/tvm/meta_schedule/testing/dataset_sample_candidates.py b/python/tvm/meta_schedule/testing/dataset_sample_candidates.py
index edffad90e2e50..3b599baafd870 100644
--- a/python/tvm/meta_schedule/testing/dataset_sample_candidates.py
+++ b/python/tvm/meta_schedule/testing/dataset_sample_candidates.py
@@ -109,21 +109,6 @@ def sample_candidates(task, task_name, model_name):
     evolve_with_cost_model = tvm.get_global_func(
         "meta_schedule.SearchStrategyEvolutionarySearchEvolveWithCostModel"
     )
-    tuning_record_path = os.path.join(
-        args.candidate_cache_dir,
-        model_name,
-        task_name + "_tuning_record.json",
-    )
-    workload_path = os.path.join(
-        args.candidate_cache_dir,
-        model_name,
-        task_name + "_workload.json",
-    )
-    database = ms.database.JSONDatabase(
-        path_workload=workload_path,
-        path_tuning_record=tuning_record_path,
-    )
-
     strategy = ms.search_strategy.EvolutionarySearch(
         num_trials_per_iter=args.num_trials_per_iter,
         max_trials_per_task=args.max_trials_per_task,
@@ -143,7 +128,7 @@ def sample_candidates(task, task_name, model_name):
     context.initialize()
     context.pre_tuning(
         context.generate_design_space(),
-        database=database,
+        database=ms.database.MemoryDatabase(),
         cost_model=ms.cost_model.RandomModel(),  # type: ignore
     )
 
@@ -162,9 +147,17 @@ def sample_candidates(task, task_name, model_name):
         itr += 1
     all_states = all_states[: args.num_samples_per_task]
 
-    workload = database.commit_workload(mod=task)
-    for state in all_states:
-        database.commit_tuning_record(ms.database.TuningRecord(state.trace, workload))
+    workload = ms.database.Workload(context.mod)
+    file_path = os.path.join(args.candidate_cache_dir, model_name, task_name + ".json")
+    with open(file_path, "w", encoding="utf8") as file:
+        for i, state in enumerate(all_states):
+            tuning_record = ms.database.TuningRecord(state.trace, workload)
+            json_str = json.dumps(tuning_record.as_json())
+            assert "\n" not in json_str, "Failed to generate single line string."
+            if i == len(all_states) - 1:
+                file.write(json_str)
+            else:
+                file.write(json_str + "\n")
 
 
 args = _parse_args()  # pylint: disable=invalid-name
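
Note on the resulting on-disk format: after this change both dataset scripts emit newline-delimited JSON rather than going through ms.database.JSONDatabase — one [task_name, subgraph] pair per line for the extracted-task files, and one TuningRecord.as_json() object per line for the sampled-candidate files, with no trailing newline after the last record. A minimal sketch of reading such a file back, assuming only the stdlib; the load_jsonl helper is illustrative and not part of this PR or of the TVM API:

import json

def load_jsonl(file_path):
    """Read one JSON object per line, as written by the scripts above.

    For *_extracted_tasks.json each object is a [task_name, subgraph] pair;
    for the per-task candidate files each object is TuningRecord.as_json().
    """
    with open(file_path, "r", encoding="utf8") as file:
        # Skip blank lines so a trailing newline, if present, is harmless.
        return [json.loads(line) for line in file if line.strip()]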