diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py
index c1edc7309408..2ad448808bdb 100644
--- a/python/tvm/meta_schedule/testing/relay_workload.py
+++ b/python/tvm/meta_schedule/testing/relay_workload.py
@@ -28,6 +28,7 @@ class MODEL_TYPE(Enum):  # pylint: disable=invalid-name
     VIDEO_CLASSIFICATION = (2,)
     SEGMENTATION = (3,)
     OBJECT_DETECTION = (4,)
+    TEXT_CLASSIFICATION = (5,)
 
 
 # Specify the type of each model
@@ -95,6 +96,11 @@ class MODEL_TYPE(Enum):  # pylint: disable=invalid-name
     "r3d_18": MODEL_TYPE.VIDEO_CLASSIFICATION,
     "mc3_18": MODEL_TYPE.VIDEO_CLASSIFICATION,
     "r2plus1d_18": MODEL_TYPE.VIDEO_CLASSIFICATION,
+    # Text classification
+    "bert_tiny": MODEL_TYPE.TEXT_CLASSIFICATION,
+    "bert_base": MODEL_TYPE.TEXT_CLASSIFICATION,
+    "bert_medium": MODEL_TYPE.TEXT_CLASSIFICATION,
+    "bert_large": MODEL_TYPE.TEXT_CLASSIFICATION,
 }
 
 
@@ -121,6 +127,8 @@ def get_torch_model(
 
     import torch  # type: ignore # pylint: disable=import-error,import-outside-toplevel
     from torchvision import models  # type: ignore # pylint: disable=import-error,import-outside-toplevel
+    import transformers  # type: ignore # pylint: disable=import-error,import-outside-toplevel
+    import os  # pylint: disable=import-outside-toplevel
 
     def do_trace(model, inp):
         model_trace = torch.jit.trace(model, inp)
@@ -136,6 +144,51 @@ def do_trace(model, inp):
         model = getattr(models.detection, model_name)()
     elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION:
         model = getattr(models.video, model_name)()
+    elif MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
+        # Silence the huggingface tokenizers fork-parallelism warning
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+        config_dict = {
+            "bert_tiny": transformers.BertConfig(
+                num_hidden_layers=6,
+                hidden_size=512,
+                intermediate_size=2048,
+                num_attention_heads=8,
+                return_dict=False,
+            ),
+            "bert_base": transformers.BertConfig(
+                num_hidden_layers=12,
+                hidden_size=768,
+                intermediate_size=3072,
+                num_attention_heads=12,
+                return_dict=False,
+            ),
+            "bert_medium": transformers.BertConfig(
+                num_hidden_layers=12,
+                hidden_size=1024,
+                intermediate_size=4096,
+                num_attention_heads=16,
+                return_dict=False,
+            ),
+            "bert_large": transformers.BertConfig(
+                num_hidden_layers=24,
+                hidden_size=1024,
+                intermediate_size=4096,
+                num_attention_heads=16,
+                return_dict=False,
+            ),
+        }
+        configuration = config_dict[model_name]
+        model = transformers.BertModel(configuration)
+        # Dummy token ids in [0, 10000) are enough to trace the model
+        input_data = torch.randint(10000, input_shape)
+
+        model.eval()
+        scripted_model = torch.jit.trace(model, [input_data], strict=False)
+
+        input_name = "input_ids"
+        shape_list = [(input_name, input_shape)]
+        mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
+        return mod, params
     else:
         raise ValueError("Unsupported model in Torch model zoo.")
 
diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc
index 1f3943dc14dc..e3cc51a479ef 100644
--- a/src/meta_schedule/task_scheduler/task_scheduler.cc
+++ b/src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -124,7 +124,7 @@ void TaskSchedulerNode::Tune() {
 
   int running_tasks = tasks.size();
   for (int task_id; (task_id = NextTaskId()) != -1;) {
-    LOG(INFO) << "Scheduler picks Task #" << task_id << ": " << tasks[task_id]->task_name;
+    LOG(INFO) << "Scheduler picks Task #" << task_id + 1 << ": " << tasks[task_id]->task_name;
     TuneContext task = tasks[task_id];
     ICHECK(!task->is_stopped);
     ICHECK(!task->runner_futures.defined());
diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py
index 799a77979c01..02264f797127 100644
--- a/tests/python/unittest/test_meta_schedule_tune_relay.py
+++ b/tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -31,7 +31,7 @@
 
 
 @pytest.mark.skip("Integration test")
-@pytest.mark.parametrize("model_name", ["resnet18"])
+@pytest.mark.parametrize("model_name", ["resnet18", "mobilenet_v2", "bert_base"])
 @pytest.mark.parametrize("batch_size", [1])
 @pytest.mark.parametrize("target", ["llvm --num-cores=16", "nvidia/geforce-rtx-3070"])
 def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str):
@@ -47,6 +47,9 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str)
         input_shape = (1, 3, 300, 300)
     elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION:
         input_shape = (batch_size, 3, 3, 299, 299)
+    elif MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
+        seq_length = 128
+        input_shape = (batch_size, seq_length)
     else:
         raise ValueError("Unsupported model: " + model_name)
     output_shape: Tuple[int, int] = (batch_size, 1000)
@@ -71,7 +74,7 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str)
         work_dir=work_dir,
     )
     for i, sch in enumerate(schs):
-        print("-" * 10 + f" Part {i}/{len(schs)} " + "-" * 10)
+        print("-" * 10 + f" Part {i+1}/{len(schs)} " + "-" * 10)
         if sch is None:
             print("No valid schedule found!")
         else:
@@ -82,3 +85,7 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str)
 if __name__ == """__main__""":
     test_meta_schedule_tune_relay("resnet18", 1, "llvm --num-cores=16")
     test_meta_schedule_tune_relay("resnet18", 1, "nvidia/geforce-rtx-3070")
+    test_meta_schedule_tune_relay("mobilenet_v2", 1, "llvm --num-cores=16")
+    test_meta_schedule_tune_relay("mobilenet_v2", 1, "nvidia/geforce-rtx-3070")
+    test_meta_schedule_tune_relay("bert_base", 1, "llvm --num-cores=16")
+    test_meta_schedule_tune_relay("bert_base", 1, "nvidia/geforce-rtx-3070")
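
For context, a minimal standalone sketch of the BERT import path that the relay_workload.py hunk above adds. It mirrors the "bert_tiny" configuration from the patch; the batch size, sequence length, and the vocabulary bound of 10000 are illustrative values taken from the test, not requirements of the import path:

    import torch
    import transformers
    from tvm import relay

    # Hand-written config mirroring the "bert_tiny" entry in the patch;
    # return_dict=False makes the model return plain tuples, which
    # torch.jit.trace can handle (ModelOutput objects cannot be traced).
    config = transformers.BertConfig(
        num_hidden_layers=6,
        hidden_size=512,
        intermediate_size=2048,
        num_attention_heads=8,
        return_dict=False,
    )
    model = transformers.BertModel(config)
    model.eval()

    input_shape = (1, 128)  # (batch_size, seq_length)
    input_data = torch.randint(10000, input_shape)  # dummy token ids for tracing

    scripted_model = torch.jit.trace(model, [input_data], strict=False)
    mod, params = relay.frontend.from_pytorch(scripted_model, [("input_ids", input_shape)])

This traced-module-to-Relay route is what get_torch_model returns to the tuning test, with input_shape supplied by the new TEXT_CLASSIFICATION branch in test_meta_schedule_tune_relay.py.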