From 43c281053e9b49561f86cdc4390efb39fda239d1 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 12 Feb 2023 06:40:50 -0800 Subject: [PATCH] [MetaSchedule] Fix a typo in MemoryDatabase (#13928) This typo was introduced a while ago, but was not uncovered until I was rebasing Relax when a unittest crashes. --- src/meta_schedule/database/memory_database.cc | 2 +- .../python/unittest/test_meta_schedule_database.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc index 8cbde46f83b7..533a86acacfd 100644 --- a/src/meta_schedule/database/memory_database.cc +++ b/src/meta_schedule/database/memory_database.cc @@ -85,7 +85,7 @@ class MemoryDatabaseNode : public DatabaseNode { } std::stable_sort(results.begin(), results.end(), SortTuningRecordByMeanRunSecs()); if (results.size() > static_cast(top_k)) { - return {results.begin(), results.end() + top_k}; + return {results.begin(), results.begin() + top_k}; } else { if (results.size() < static_cast(top_k)) { LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not " diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py index 806ea2d1827b..11fbeb811ea7 100644 --- a/tests/python/unittest/test_meta_schedule_database.py +++ b/tests/python/unittest/test_meta_schedule_database.py @@ -18,19 +18,19 @@ """Test Meta Schedule Database""" import os.path as osp import tempfile -import pytest -from typing import Callable, Optional, List +from typing import Callable, List, Optional +import pytest import tvm import tvm.testing -from tvm.target import Target from tvm import meta_schedule as ms -from tvm.meta_schedule.database import TuningRecord, Workload -from tvm import tir +from tvm import relay, tir from tvm.ir.module import IRModule +from tvm.meta_schedule.database import TuningRecord, Workload from tvm.script import tir as T +from tvm.target import Target from tvm.tir import Schedule -from tvm import relay + # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument # fmt: off @@ -556,6 +556,7 @@ def call_get_top_k(run_secs_list, database, k): "k,expected", [ (0, []), + (1, [[0.0, 2.0]]), (4, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]), (5, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]), ],