Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
formatted
Browse files Browse the repository at this point in the history
Kathryn-cat committed Jun 10, 2022
1 parent f1e6ff4 commit 7db45ee
Showing 4 changed files with 114 additions and 11 deletions.
104 changes: 104 additions & 0 deletions ;w
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring

import argparse
import glob
import json
import os

import tvm
from tqdm import tqdm # type: ignore
from tvm import meta_schedule as ms
from tvm.ir import save_json
from tvm.meta_schedule.testing.relay_workload import _load_cache
from tvm.runtime import load_param_dict


def _parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_cache_dir", type=str, help="Please provide the full path to the model cache dir."
)
parser.add_argument(
"--task_cache_dir", type=str, help="Please provide the full path to save extracted tasks."
)
parser.add_argument(
"--target", type=str, default="cuda", help="Please specify the target hardware for tuning."
)
return parser.parse_args()


# pylint: disable=too-many-locals
def extract_and_save_tasks(cache_file):
"""Extract tuning tasks and cache the nonspatial ones in the given directory.

Parameters
----------
cache_file : str
The filename of the cached model.

Returns
-------
None
"""

mod, params_bytearray, _ = _load_cache(args.model_cache_dir, cache_file)
params = load_param_dict(params_bytearray)
try:
extracted_tasks = ms.extract_task_from_relay(mod, target=args.target, params=params)
except tvm.error.TVMError as e: # pylint: disable=protected-access
print(str(e))
return
task_cache_path = os.path.join(
args.task_cache_dir, cache_file.split(".")[0] + "_extracted_tasks.json"
)
is_spatial = tvm.get_global_func("tir.schedule.IsSpatialPrimFunc")
with open(task_cache_path, "w", encoding="utf8") as file:
for i, task in enumerate(extracted_tasks):
subgraph = task.dispatched[0]
prim_func = subgraph[subgraph.get_global_vars()[0]]
if not is_spatial(prim_func):
subgraph_str = save_json(subgraph)
json_obj = [task.task_name, json.loads(subgraph_str)]
json_str = json.dumps(json_obj)
assert "\n" not in json_str, "Failed to generate single line string."
if i == len(extracted_tasks) - 1:
file.write(json_str)
else:
file.write(json_str + "\n")


args = _parse_args() # pylint: disable=invalid-name


def main():
if not os.path.isdir(args.model_cache_dir):
raise Exception("Please provide a correct model cache dir.")
try:
os.makedirs(args.task_cache_dir, exist_ok=True)
except OSError:
print(f"Directory {args.task_cache_dir} cannot be created successfully.")

paths = glob.glob(os.path.join(args.model_cache_dir, "*.json")) # pylint: disable=invalid-name
for path in tqdm(paths):
filename = path.split("/")[-1]
extract_and_save_tasks(filename)


if __name__ == "__main__":
main()
1 change: 0 additions & 1 deletion python/tvm/meta_schedule/testing/dataset_collect_models.py
Original file line number Diff line number Diff line change
@@ -65,7 +65,6 @@ def _build_dataset() -> List[Tuple[str, List[int]]]:

def main():
model_cache_dir = args.model_cache_dir

try:
os.makedirs(model_cache_dir, exist_ok=True)
except OSError:
8 changes: 4 additions & 4 deletions python/tvm/meta_schedule/testing/dataset_extract_tasks.py
Original file line number Diff line number Diff line change
@@ -21,8 +21,8 @@
import json
import os

import tvm
from tqdm import tqdm # type: ignore
import tvm
from tvm import meta_schedule as ms
from tvm.ir import save_json
from tvm.meta_schedule.testing.relay_workload import _load_cache
@@ -61,8 +61,8 @@ def extract_and_save_tasks(cache_file):
params = load_param_dict(params_bytearray)
try:
extracted_tasks = ms.extract_task_from_relay(mod, target=args.target, params=params)
except tvm.error.TVMError as e: # pylint: disable=protected-access
print(str(e))
except tvm.error.TVMError as error:
print(str(error))
return
task_cache_path = os.path.join(
args.task_cache_dir, cache_file.split(".")[0] + "_extracted_tasks.json"
@@ -83,7 +83,7 @@ def extract_and_save_tasks(cache_file):
file.write(json_str + "\n")


args = _parse_args()
args = _parse_args() # pylint: disable=invalid-name


def main():
12 changes: 6 additions & 6 deletions python/tvm/meta_schedule/testing/dataset_sample_candidates.py
Original file line number Diff line number Diff line change
@@ -22,8 +22,8 @@
import os
from typing import List

import tvm
from tqdm import tqdm # type: ignore
import tvm
from tvm import meta_schedule as ms
from tvm.ir import load_json
from tvm.target import Target
@@ -184,14 +184,14 @@ def main():
print(f"Selected models: {task_paths}")
for num, task_path in enumerate(task_paths):
print(f"Processing model {num} ...")
with open(task_path, "rb") as f:
tasks = f.readlines()
model_n = task_path.split("/")[-1][len("relay-") :][: -len("_extracted_tasks.json")]
os.makedirs(os.path.join(args.candidate_cache_dir, model_n), exist_ok=True)
with open(task_path, "rb") as file:
tasks = file.readlines()
model_name = task_path.split("/")[-1][len("relay-") :][: -len("_extracted_tasks.json")]
os.makedirs(os.path.join(args.candidate_cache_dir, model_name), exist_ok=True)
for task_str in tqdm(tasks):
task_name, task_mod = json.loads(task_str)
task_mod = load_json(json.dumps(task_mod))
sample_candidates(task_mod, task_name, model_n)
sample_candidates(task_mod, task_name, model_name)


if __name__ == "__main__":

0 comments on commit 7db45ee

Please sign in to comment.