[AutoTVM] Add support for text buffers to ApplyHistoryBest #12521

Merged · 7 commits · Aug 23, 2022
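In short, ApplyHistoryBest can now consume tuning records from an open text buffer as well as from file paths or in-memory record pairs. A minimal sketch of the usage this enables, assuming a tuning log already exists on disk (the file name is a placeholder):

```python
from io import StringIO

from tvm.autotvm.task import ApplyHistoryBest

# Load an existing AutoTVM log into an in-memory text buffer
# ("tuning.log" is a placeholder path).
with open("tuning.log") as f:
    buffer = StringIO(f.read())

# Previously only paths and iterables of (MeasureInput, MeasureResult) were
# accepted; a file-like text buffer now works as well.
with ApplyHistoryBest(buffer):
    pass  # build the tuned operator or model inside this dispatch context
```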
31 changes: 27 additions & 4 deletions python/tvm/autotvm/record.py
@@ -20,10 +20,12 @@

import argparse
import base64
from io import TextIOBase
import logging
import pickle
import json
import time
from typing import Union
import os
import itertools
from collections import OrderedDict
@@ -194,20 +196,41 @@ def clean_json_to_python(x):
raise RuntimeError("Invalid log protocol: " + protocol)


def load_from_file(filename):
"""Generator: load records from file.
def load_from_buffer(file: TextIOBase):
"""Generator: load records from buffer.
This is a generator that yields the records.

Parameters
----------
filename: str
file: io.TextIOBase

Yields
------
input: autotvm.measure.MeasureInput
result: autotvm.measure.MeasureResult
"""
with open(filename) as f:
for row in file:
if row and not row.startswith("#"):
ret = decode(row)
if ret is None:
continue
yield ret


def load_from_file(filepath: Union[str, bytes, os.PathLike]):
"""Generator: load records from path.
This is a generator that yields the records.

Parameters
----------
filepath: str, bytes, or os.PathLike

Yields
------
input: autotvm.measure.MeasureInput
result: autotvm.measure.MeasureResult
"""
with open(filepath) as f:
for row in f:
if row and not row.startswith("#"):
ret = decode(row)
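For reference, a short usage sketch of the two loaders side by side; the log path is a placeholder, and each yielded pair is a (MeasureInput, MeasureResult):

```python
from io import StringIO

from tvm.autotvm.record import load_from_buffer, load_from_file

# Path-based loading, now annotated as accepting str, bytes, or os.PathLike
# ("tuning.log" is a placeholder path).
for inp, res in load_from_file("tuning.log"):
    print(inp.task, res.costs)

# Buffer-based loading: any io.TextIOBase works, e.g. an in-memory StringIO.
with open("tuning.log") as f:
    buf = StringIO(f.read())
for inp, res in load_from_buffer(buf):
    print(inp.task, res.costs)
```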
87 changes: 51 additions & 36 deletions python/tvm/autotvm/task/dispatcher.py
@@ -30,18 +30,26 @@

from __future__ import absolute_import as _abs

from io import TextIOBase
import logging
import typing
from typing import Union
from collections.abc import Iterable
from os import PathLike
from pathlib import Path
from typing import List, Iterable, Tuple, Union

import numpy as np

from .space import FallbackConfigEntity
from .. import env as _env
from ..measure import MeasureInput, MeasureResult

logger = logging.getLogger("autotvm")

Records = Union[
Union[str, bytes, Path], # Path-like objects
TextIOBase, # File-like objects
Iterable[Tuple[MeasureInput, MeasureResult]],
Contributor:
Is this use case supported currently? I'm not sure whether we should support bypassing the load path here (this is technically "defensive programming", which I recognize is not great, but I also think it might be prudent here).

Member Author:
Yes, ApplyHistoryBest currently supports taking an input of type Iterable[Tuple[MeasureInput, MeasureResult]]. This functionality is basically never used, though, as autotuning (AFAIK) does not export logs in this form. Do you think this functionality can be removed?

Contributor:
Ah, OK. If we remove it, let's do that in a follow-up rather than in this PR.
]
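To make the alias concrete, each of the following values satisfies Records (the paths and variable names below are illustrative only):

```python
from io import StringIO
from pathlib import Path

records_as_str = "tuning.log"          # str path (bytes also accepted)
records_as_path = Path("tuning.log")   # os.PathLike object
records_as_buffer = StringIO("")       # file-like text buffer
records_as_pairs = []                  # iterable of (MeasureInput, MeasureResult) pairs
```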


class DispatchContext(object):
"""
@@ -194,7 +202,7 @@ class ApplyFixedConfig(DispatchContext):
Name of schedules to use.
"""

def __init__(self, tasks, schedule_names: Union[str, typing.List[str]]):
def __init__(self, tasks, schedule_names: Union[str, List[str]]):
super(ApplyFixedConfig, self).__init__()
if isinstance(schedule_names, str):
self._schedule_names = list(schedule_names)
@@ -238,15 +246,15 @@ class ApplyHistoryBest(DispatchContext):

Parameters
----------
records : str, list of str, or iterator of (autotvm.measure.MeasureInput,\
autotvm.measure.MeasureResult)
Collection of tuning records.
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair. If it is a list, it can either be
a list of paths to log files that will be loaded jointly or an iterator or records.
records : None, Records, or iterator of Records objects, where a
Records object is a path-like object, a file-like object,
or an iterator of (MeasureInput, MeasureResult).

Collection of tuning records. If multiple Records objects are passed, their
contents will be merged.
"""

def __init__(self, records):
def __init__(self, records: Union[None, Records, Iterable[Records]]):
super(ApplyHistoryBest, self).__init__()

self.best_by_targetkey = {}
@@ -256,46 +264,48 @@ def __init__(self, records):
if records:
self.load(records)

def load(self, records):
def load(self, records: Union[Records, Iterable[Records]]):
"""Load records to this dispatch context

Parameters
----------
records : str, list of str, or iterator of (autotvm.measure.MeasureInput,\
autotvm.measure.MeasureResult)
Collection of tuning records.
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair. If it is a list
it can either be a list of paths to logs that will be loaded jointly or
an iterator of measurement results.

Collection of tuning records. If multiple Records objects are passed, their
contents will be merged.
"""
# pylint: disable=import-outside-toplevel
from pathlib import Path
from ..record import load_from_file
from ..record import load_from_file, load_from_buffer

joint_records = []
if not isinstance(records, Iterable) or isinstance(records, str):
records = [records]
def _unpack_records(
records: Union[Records, Iterable[Records]]
) -> List[Tuple[MeasureInput, MeasureResult]]:

for rec in records:
if isinstance(rec, Path):
rec = str(rec)
if isinstance(records, (str, bytes, PathLike)):
return load_from_file(records)

if isinstance(rec, str):
rec = load_from_file(rec)
joint_records += rec
else:
if rec is not None:
joint_records.append(rec)
if isinstance(records, TextIOBase):
return load_from_buffer(records)

if not joint_records:
joint_records = []
for record in records:
if isinstance(record, Tuple) and isinstance(record[0], MeasureInput):
joint_records.append(record)
else:
joint_records += _unpack_records(record)

return joint_records

flattened_records = _unpack_records(records)
if not flattened_records:
return

best_by_targetkey = self.best_by_targetkey
best_by_model = self.best_by_model

counter = 0
for inp, res in joint_records:
for inp, res in flattened_records:
counter += 1
if res.error_no != 0:
continue
@@ -447,7 +457,7 @@ class ApplyGraphBest(DispatchContext):
node index.
"""

def __init__(self, records):
def __init__(self, records: Records):
"""
Parameters
----------
@@ -458,11 +468,16 @@ def __init__(self, records):
Otherwise, it is an iterator.
"""
# pylint: disable=import-outside-toplevel
from ..record import load_from_file
from ..record import load_from_file, load_from_buffer

super(ApplyGraphBest, self).__init__()
if isinstance(records, str):
if isinstance(records, (str, bytes, PathLike)):
records = load_from_file(records)
elif isinstance(records, TextIOBase):
records = load_from_buffer(records)
else:
records = list(records)

self._records = list(records)
self._counter = 0
self._global_cfg_dict = {}
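A hedged sketch of how the reworked load path merges heterogeneous sources; the file names are placeholders and are assumed to contain valid AutoTVM records:

```python
from io import StringIO
from pathlib import Path

from tvm.autotvm.task import ApplyHistoryBest

# Read one batch into a text buffer ("batch_3.log" is a placeholder path).
with open("batch_3.log") as f:
    batch_3 = StringIO(f.read())

# Any mix of Records objects can be passed as one iterable; _unpack_records
# flattens them into a single list of (MeasureInput, MeasureResult) pairs.
sources = ["batch_1.log", Path("batch_2.log"), batch_3]
hist_best = ApplyHistoryBest(sources)
```

ApplyGraphBest, by contrast, still takes a single Records object rather than an iterable of them.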
1 change: 1 addition & 0 deletions tests/micro/common/test_autotune.py
@@ -61,6 +61,7 @@ def test_kws_autotune_workflow(platform, board, tmp_path):
assert logs[0]["config"]["entity"] != logs[1]["config"]["entity"]

# Compile the best model with AOT and connect to it
str_io_logs.seek(0)
with tvm.micro.testing.create_aot_session(
platform,
board,
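The seek(0) added above matters because writing the tuning log into the StringIO leaves its cursor at the end of the stream, so a subsequent read would see an empty log. A minimal, self-contained illustration (the record text is a placeholder):

```python
from io import StringIO

buf = StringIO()
buf.write('{"placeholder": "tuning record"}\n')  # writing leaves the cursor at EOF
assert buf.read() == ""                          # nothing left to read from here
buf.seek(0)                                      # rewind before handing the buffer off
assert buf.read() != ""                          # now the full contents are visible
```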
92 changes: 76 additions & 16 deletions tests/python/unittest/test_autotvm_record.py
@@ -15,10 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""test the correctness of dump and load of data log"""
from io import StringIO
from os import PathLike
import time

import tvm
from tvm import te
from tvm.contrib import utils

from tvm import autotvm
@@ -78,23 +78,83 @@ def test_file_io():
assert str(x) == str(inputs[0][2])


def test_apply_history_best():
def test_apply_history_best(tmpdir):
tsk, target = get_sample_task()
best = str(tsk.config_space.get(2))

records = [
(MeasureInput(target, tsk, tsk.config_space.get(0)), MeasureResult((0.1,), 0, 2.3, 0)),
(MeasureInput(target, tsk, tsk.config_space.get(1)), MeasureResult((0.3,), 0, 2.3, 0)),
(MeasureInput(target, tsk, tsk.config_space.get(2)), MeasureResult((0.01,), 0, 2.3, 0)),
(MeasureInput(target, tsk, tsk.config_space.get(4)), MeasureResult((0.4,), 0, 2.3, 0)),
]
hist_best = ApplyHistoryBest(records)
x = hist_best.query(target, tsk.workload)
assert str(x) == str(tsk.config_space.get(2))
inputs_batch_1 = [MeasureInput(target, tsk, tsk.config_space.get(i)) for i in range(3)]
results_batch_1 = [MeasureResult((i,), 0, 0, 0) for i in range(1, 3)]
results_batch_1.append(MeasureResult((0.5,), 0, 2.3, 0))

# Confirm same functionality for iterators.
hist_best = ApplyHistoryBest(iter(records))
x = hist_best.query(target, tsk.workload)
assert str(x) == str(tsk.config_space.get(2))
# Write data out to file
filepath_batch_1 = tmpdir / "batch_1.log"
with open(filepath_batch_1, "w") as file:
autotvm.callback.log_to_file(file)(None, inputs_batch_1, results_batch_1)

# Load best results from Path
assert isinstance(filepath_batch_1, PathLike)
hist_best = ApplyHistoryBest(filepath_batch_1)
assert str(hist_best.query(target, tsk.workload)) == best

# Load best results from str(Path)
hist_best = ApplyHistoryBest(str(filepath_batch_1))
assert str(hist_best.query(target, tsk.workload)) == best

# Write data into StringIO buffer
stringio_batch_1 = StringIO()
assert isinstance(filepath_batch_1, PathLike)
callback = autotvm.callback.log_to_file(stringio_batch_1)
callback(None, inputs_batch_1, results_batch_1)
stringio_batch_1.seek(0)

# Load best results from strIO
hist_best = ApplyHistoryBest(stringio_batch_1)
assert str(hist_best.query(target, tsk.workload)) == best

# Load best result from list of tuples (MeasureInput, MeasureResult)
hist_best = ApplyHistoryBest(list(zip(inputs_batch_1, results_batch_1)))
assert str(hist_best.query(target, tsk.workload)) == best

# Same thing, but iterable instead of list (i.e. no subscripting)
hist_best = ApplyHistoryBest(zip(inputs_batch_1, results_batch_1))
assert str(hist_best.query(target, tsk.workload)) == best


def test_apply_history_best_multiple_batches(tmpdir):
tsk, target = get_sample_task()
best = str(tsk.config_space.get(2))

inputs_batch_1 = [MeasureInput(target, tsk, tsk.config_space.get(i)) for i in range(2)]
results_batch_1 = [MeasureResult((i,), 0, 0, 0) for i in range(1, 3)]
filepath_batch_1 = tmpdir / "batch_1.log"
with open(filepath_batch_1, "w") as file:
autotvm.callback.log_to_file(file)(None, inputs_batch_1, results_batch_1)

inputs_batch_2 = [MeasureInput(target, tsk, tsk.config_space.get(i)) for i in range(2, 4)]
results_batch_2 = [MeasureResult((0.5,), 0, 0, 0), MeasureResult((3,), 0, 0, 0)]
filepath_batch_2 = tmpdir / "batch_2.log"
with open(filepath_batch_2, "w") as file:
autotvm.callback.log_to_file(file)(None, inputs_batch_2, results_batch_2)

# Check two Path filepaths works
hist_best = ApplyHistoryBest([filepath_batch_1, filepath_batch_2])
assert str(hist_best.query(target, tsk.workload)) == best

# Check that an arbitrary Iterable of Paths works
# Calling zip() on a single list gives a non-subscriptable Iterable
hist_best = ApplyHistoryBest(zip([filepath_batch_1, filepath_batch_2]))
assert str(hist_best.query(target, tsk.workload)) == best

# Check that Iterable of Iterable of tuples is correctly merged
hist_best = ApplyHistoryBest(
zip(
[
zip(inputs_batch_1, results_batch_1),
zip(inputs_batch_2, results_batch_2),
]
)
)
assert str(hist_best.query(target, tsk.workload)) == best


if __name__ == "__main__":