From b77e2cf3d1162a5a7bab97dca487f435333a74ca Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Sun, 21 Aug 2022 02:07:18 -0700 Subject: [PATCH 1/7] Type hints for load_history_best --- python/tvm/autotvm/record.py | 27 +++++++++++-- python/tvm/autotvm/task/dispatcher.py | 56 ++++++++++++++++----------- 2 files changed, 58 insertions(+), 25 deletions(-) diff --git a/python/tvm/autotvm/record.py b/python/tvm/autotvm/record.py index b2faee243be0..8a3753326d26 100644 --- a/python/tvm/autotvm/record.py +++ b/python/tvm/autotvm/record.py @@ -194,20 +194,41 @@ def clean_json_to_python(x): raise RuntimeError("Invalid log protocol: " + protocol) -def load_from_file(filename): +def load_from_io(file): """Generator: load records from file. This is a generator that yields the records. Parameters ---------- - filename: str + file: os.TextIOBase Yields ------ input: autotvm.measure.MeasureInput result: autotvm.measure.MeasureResult """ - with open(filename) as f: + for row in file: + if row and not row.startswith("#"): + ret = decode(row) + if ret is None: + continue + yield ret + + +def load_from_file(filepath): + """Generator: load records from path. + This is a generator that yields the records. + + Parameters + ---------- + filepath: str, bytes, os.PathLike + + Yields + ------ + input: autotvm.measure.MeasureInput + result: autotvm.measure.MeasureResult + """ + with open(filepath) as f: for row in f: if row and not row.startswith("#"): ret = decode(row) diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py index 11a608d4cbbf..076f3fb96422 100644 --- a/python/tvm/autotvm/task/dispatcher.py +++ b/python/tvm/autotvm/task/dispatcher.py @@ -30,18 +30,26 @@ from __future__ import absolute_import as _abs +from io import TextIOBase import logging -import typing -from typing import Union -from collections.abc import Iterable +from os import PathLike +from pathlib import Path +from typing import List, Iterable, Tuple, Union import numpy as np from .space import FallbackConfigEntity from .. import env as _env +from ..measure import MeasureInput, MeasureResult logger = logging.getLogger("autotvm") +Records = Union[ + Union[str, bytes, Path], # Path-like objects + TextIOBase, # File-like objects + Iterable[Tuple[MeasureInput, MeasureResult]], +] + class DispatchContext(object): """ @@ -194,7 +202,7 @@ class ApplyFixedConfig(DispatchContext): Name of schedules to use. """ - def __init__(self, tasks, schedule_names: Union[str, typing.List[str]]): + def __init__(self, tasks, schedule_names: Union[str, List[str]]): super(ApplyFixedConfig, self).__init__() if isinstance(schedule_names, str): self._schedule_names = list(schedule_names) @@ -256,7 +264,7 @@ def __init__(self, records): if records: self.load(records) - def load(self, records): + def load(self, records: Union[Records, Iterable[Records]]): """Load records to this dispatch context Parameters @@ -270,32 +278,36 @@ def load(self, records): an iterator of measurement results. """ # pylint: disable=import-outside-toplevel - from pathlib import Path - from ..record import load_from_file + from ..record import load_from_file, load_from_io - joint_records = [] - if not isinstance(records, Iterable) or isinstance(records, str): - records = [records] + def _unpack_records( + records: Union[Records, Iterable[Records]] + ) -> List[Tuple[MeasureInput, MeasureResult]]: - for rec in records: - if isinstance(rec, Path): - rec = str(rec) + if isinstance(records, (str, bytes, PathLike)): + return load_from_file(records) - if isinstance(rec, str): - rec = load_from_file(rec) - joint_records += rec - else: - if rec is not None: - joint_records.append(rec) + if isinstance(records, TextIOBase): + return load_from_io(records) - if not joint_records: + joint_records = [] + for record in records: + if isinstance(record, Tuple) and isinstance(record[0], MeasureInput): + joint_records.append(record) + else: + joint_records += _unpack_records(record) + + return joint_records + + flattened_records = _unpack_records(records) + if not flattened_records: return best_by_targetkey = self.best_by_targetkey best_by_model = self.best_by_model counter = 0 - for inp, res in joint_records: + for inp, res in flattened_records: counter += 1 if res.error_no != 0: continue @@ -447,7 +459,7 @@ class ApplyGraphBest(DispatchContext): node index. """ - def __init__(self, records): + def __init__(self, records: Records): """ Parameters ---------- From 6fabf84ee647aff68d48ef6b153df34c52b99b9a Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Sun, 21 Aug 2022 06:17:31 -0700 Subject: [PATCH 2/7] Update docstring --- python/tvm/autotvm/task/dispatcher.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py index 076f3fb96422..83ce90b39aa4 100644 --- a/python/tvm/autotvm/task/dispatcher.py +++ b/python/tvm/autotvm/task/dispatcher.py @@ -246,12 +246,12 @@ class ApplyHistoryBest(DispatchContext): Parameters ---------- - records : str, list of str, or iterator of (autotvm.measure.MeasureInput,\ - autotvm.measure.MeasureResult) - Collection of tuning records. - If is str, then it should be the filename of a records log file. - Each row of this file is an encoded record pair. If it is a list, it can either be - a list of paths to log files that will be loaded jointly or an iterator or records. + records : Records or iterator of Records objects, where a Records + object is a path-like object, a file-like object, or an + iterator of (MeasureInput, MeasureResult). + + Collection of tuning records. If multiple Records objects are passed, their + contents will be merged. """ def __init__(self, records): @@ -271,11 +271,9 @@ def load(self, records: Union[Records, Iterable[Records]]): ---------- records : str, list of str, or iterator of (autotvm.measure.MeasureInput,\ autotvm.measure.MeasureResult) - Collection of tuning records. - If is str, then it should be the filename of a records log file. - Each row of this file is an encoded record pair. If it is a list - it can either be a list of paths to logs that will be loaded jointly or - an iterator of measurement results. + + Collection of tuning records. If multiple Records objects are passed, their + contents will be merged. """ # pylint: disable=import-outside-toplevel from ..record import load_from_file, load_from_io From 49c57545b4b5cdd5c8c389870ffd19a65e5e9724 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Sun, 21 Aug 2022 06:17:49 -0700 Subject: [PATCH 3/7] Add unit tests for new functionality --- tests/python/unittest/test_autotvm_record.py | 92 ++++++++++++++++---- 1 file changed, 76 insertions(+), 16 deletions(-) diff --git a/tests/python/unittest/test_autotvm_record.py b/tests/python/unittest/test_autotvm_record.py index 147122ff10d6..693810d3f979 100644 --- a/tests/python/unittest/test_autotvm_record.py +++ b/tests/python/unittest/test_autotvm_record.py @@ -15,10 +15,10 @@ # specific language governing permissions and limitations # under the License. """test the correctness of dump and load of data log""" +from io import StringIO +from os import PathLike import time -import tvm -from tvm import te from tvm.contrib import utils from tvm import autotvm @@ -78,23 +78,83 @@ def test_file_io(): assert str(x) == str(inputs[0][2]) -def test_apply_history_best(): +def test_apply_history_best(tmpdir): tsk, target = get_sample_task() + best = str(tsk.config_space.get(2)) - records = [ - (MeasureInput(target, tsk, tsk.config_space.get(0)), MeasureResult((0.1,), 0, 2.3, 0)), - (MeasureInput(target, tsk, tsk.config_space.get(1)), MeasureResult((0.3,), 0, 2.3, 0)), - (MeasureInput(target, tsk, tsk.config_space.get(2)), MeasureResult((0.01,), 0, 2.3, 0)), - (MeasureInput(target, tsk, tsk.config_space.get(4)), MeasureResult((0.4,), 0, 2.3, 0)), - ] - hist_best = ApplyHistoryBest(records) - x = hist_best.query(target, tsk.workload) - assert str(x) == str(tsk.config_space.get(2)) + inputs_batch_1 = [MeasureInput(target, tsk, tsk.config_space.get(i)) for i in range(3)] + results_batch_1 = [MeasureResult((i,), 0, 0, 0) for i in range(1, 3)] + results_batch_1.append(MeasureResult((0.5,), 0, 2.3, 0)) - # Confirm same functionality for iterators. - hist_best = ApplyHistoryBest(iter(records)) - x = hist_best.query(target, tsk.workload) - assert str(x) == str(tsk.config_space.get(2)) + # Write data out to file + filepath_batch_1 = tmpdir / "batch_1.log" + with open(filepath_batch_1, "w") as file: + autotvm.callback.log_to_file(file)(None, inputs_batch_1, results_batch_1) + + # Load best results from Path + assert isinstance(filepath_batch_1, PathLike) + hist_best = ApplyHistoryBest(filepath_batch_1) + assert str(hist_best.query(target, tsk.workload)) == best + + # Load best results from str(Path) + hist_best = ApplyHistoryBest(str(filepath_batch_1)) + assert str(hist_best.query(target, tsk.workload)) == best + + # Write data into StringIO buffer + stringio_batch_1 = StringIO() + assert isinstance(filepath_batch_1, PathLike) + callback = autotvm.callback.log_to_file(stringio_batch_1) + callback(None, inputs_batch_1, results_batch_1) + stringio_batch_1.seek(0) + + # Load best results from strIO + hist_best = ApplyHistoryBest(stringio_batch_1) + assert str(hist_best.query(target, tsk.workload)) == best + + # Load best result from list of tuples (MeasureInput, MeasureResult) + hist_best = ApplyHistoryBest(list(zip(inputs_batch_1, results_batch_1))) + assert str(hist_best.query(target, tsk.workload)) == best + + # Same thing, but iterable instead of list (i.e. no subscripting) + hist_best = ApplyHistoryBest(zip(inputs_batch_1, results_batch_1)) + assert str(hist_best.query(target, tsk.workload)) == best + + +def test_apply_history_best_multiple_batches(tmpdir): + tsk, target = get_sample_task() + best = str(tsk.config_space.get(2)) + + inputs_batch_1 = [MeasureInput(target, tsk, tsk.config_space.get(i)) for i in range(2)] + results_batch_1 = [MeasureResult((i,), 0, 0, 0) for i in range(1, 3)] + filepath_batch_1 = tmpdir / "batch_1.log" + with open(filepath_batch_1, "w") as file: + autotvm.callback.log_to_file(file)(None, inputs_batch_1, results_batch_1) + + inputs_batch_2 = [MeasureInput(target, tsk, tsk.config_space.get(i)) for i in range(2, 4)] + results_batch_2 = [MeasureResult((0.5,), 0, 0, 0), MeasureResult((3,), 0, 0, 0)] + filepath_batch_2 = tmpdir / "batch_2.log" + with open(filepath_batch_2, "w") as file: + autotvm.callback.log_to_file(file)(None, inputs_batch_2, results_batch_2) + + # Check two Path filepaths works + hist_best = ApplyHistoryBest([filepath_batch_1, filepath_batch_2]) + assert str(hist_best.query(target, tsk.workload)) == best + + # Check that an arbitrary Iterable of Paths works + # Calling zip() on a single list gives a non-subscriptable Iterable + hist_best = ApplyHistoryBest(zip([filepath_batch_1, filepath_batch_2])) + assert str(hist_best.query(target, tsk.workload)) == best + + # Check that Iterable of Iterable of tuples is correctly merged + hist_best = ApplyHistoryBest( + zip( + [ + zip(inputs_batch_1, results_batch_1), + zip(inputs_batch_2, results_batch_2), + ] + ) + ) + assert str(hist_best.query(target, tsk.workload)) == best if __name__ == "__main__": From c8c69833179dc74f470a912a068cb276c99d14d0 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Sun, 21 Aug 2022 06:18:00 -0700 Subject: [PATCH 4/7] Fix relevant bug in test_autotune --- tests/micro/common/test_autotune.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/micro/common/test_autotune.py b/tests/micro/common/test_autotune.py index b79260dd46ed..46f6d8889a9a 100644 --- a/tests/micro/common/test_autotune.py +++ b/tests/micro/common/test_autotune.py @@ -61,6 +61,7 @@ def test_kws_autotune_workflow(platform, board, tmp_path): assert logs[0]["config"]["entity"] != logs[1]["config"]["entity"] # Compile the best model with AOT and connect to it + str_io_logs.seek(0) with tvm.micro.testing.create_aot_session( platform, board, From 04211e16cb4eb83ea6ee23cac542aadd5cd39770 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Sun, 21 Aug 2022 06:28:37 -0700 Subject: [PATCH 5/7] Rename load_from_io to load_from_buffer --- python/tvm/autotvm/record.py | 12 +++++++----- python/tvm/autotvm/task/dispatcher.py | 4 ++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/python/tvm/autotvm/record.py b/python/tvm/autotvm/record.py index 8a3753326d26..8e54e011c0b7 100644 --- a/python/tvm/autotvm/record.py +++ b/python/tvm/autotvm/record.py @@ -20,10 +20,12 @@ import argparse import base64 +from io import TextIOBase import logging import pickle import json import time +from typing import Union import os import itertools from collections import OrderedDict @@ -194,13 +196,13 @@ def clean_json_to_python(x): raise RuntimeError("Invalid log protocol: " + protocol) -def load_from_io(file): - """Generator: load records from file. +def load_from_buffer(file: TextIOBase): + """Generator: load records from buffer. This is a generator that yields the records. Parameters ---------- - file: os.TextIOBase + file: io.TextIOBase Yields ------ @@ -215,13 +217,13 @@ def load_from_io(file): yield ret -def load_from_file(filepath): +def load_from_file(filepath: Union[str, bytes, os.PathLike]): """Generator: load records from path. This is a generator that yields the records. Parameters ---------- - filepath: str, bytes, os.PathLike + filepath: str, bytes, or os.PathLike Yields ------ diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py index 83ce90b39aa4..009cf620ceb7 100644 --- a/python/tvm/autotvm/task/dispatcher.py +++ b/python/tvm/autotvm/task/dispatcher.py @@ -276,7 +276,7 @@ def load(self, records: Union[Records, Iterable[Records]]): contents will be merged. """ # pylint: disable=import-outside-toplevel - from ..record import load_from_file, load_from_io + from ..record import load_from_file, load_from_buffer def _unpack_records( records: Union[Records, Iterable[Records]] @@ -286,7 +286,7 @@ def _unpack_records( return load_from_file(records) if isinstance(records, TextIOBase): - return load_from_io(records) + return load_from_buffer(records) joint_records = [] for record in records: From 6960e5bf4f763cc15d078547a6966992b81549d8 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Sun, 21 Aug 2022 06:47:35 -0700 Subject: [PATCH 6/7] Modify ApplyGraphBest to take a Records object as input --- python/tvm/autotvm/task/dispatcher.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py index 009cf620ceb7..f50a72822552 100644 --- a/python/tvm/autotvm/task/dispatcher.py +++ b/python/tvm/autotvm/task/dispatcher.py @@ -468,11 +468,16 @@ def __init__(self, records: Records): Otherwise, it is an iterator. """ # pylint: disable=import-outside-toplevel - from ..record import load_from_file + from ..record import load_from_file, load_from_buffer super(ApplyGraphBest, self).__init__() - if isinstance(records, str): + if isinstance(records, str, bytes, PathLike): records = load_from_file(records) + elif isinstance(records, TextIOBase): + records = load_from_buffer(records) + else: + records = list(records) + self._records = list(records) self._counter = 0 self._global_cfg_dict = {} From 131e56ab5ba4a5f4c73077693543aae707601f41 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Tue, 23 Aug 2022 02:57:20 -0700 Subject: [PATCH 7/7] Address some comments --- python/tvm/autotvm/task/dispatcher.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py index f50a72822552..8b2e7eb01fe6 100644 --- a/python/tvm/autotvm/task/dispatcher.py +++ b/python/tvm/autotvm/task/dispatcher.py @@ -246,15 +246,15 @@ class ApplyHistoryBest(DispatchContext): Parameters ---------- - records : Records or iterator of Records objects, where a Records - object is a path-like object, a file-like object, or an - iterator of (MeasureInput, MeasureResult). + records : None, Records, or iterator of Records objects, where a + Records object is a path-like object, a file-like object, + or an iterator of (MeasureInput, MeasureResult). Collection of tuning records. If multiple Records objects are passed, their contents will be merged. """ - def __init__(self, records): + def __init__(self, records: Union[None, Records, Iterable[Records]]): super(ApplyHistoryBest, self).__init__() self.best_by_targetkey = {} @@ -471,7 +471,7 @@ def __init__(self, records: Records): from ..record import load_from_file, load_from_buffer super(ApplyGraphBest, self).__init__() - if isinstance(records, str, bytes, PathLike): + if isinstance(records, (str, bytes, PathLike)): records = load_from_file(records) elif isinstance(records, TextIOBase): records = load_from_buffer(records)