[MetaSchedule] Add Script for TorchBench Model Tuning & Benchmarking #12914

Merged: 1 commit merged into apache:main on Sep 29, 2022

Conversation

Contributor @yelite commented on Sep 27, 2022:

This PR adds a script to tune and benchmark TorchBench models, using torchdynamo and the PyTorch importer in TVM.

cc @junrushao @zxybazh

Comment on lines 31 to 35
pip3 install --pre \
    --extra-index-url https://download.pytorch.org/whl/nightly/cu116 \
    torch==1.13.0.dev20220926 \
    torchvision==0.14.0.dev20220926 \
    torchtext==0.14.0.dev20220926
Member:

Quick question: how long is a PyTorch nightly guaranteed to persist at that URL? That is, will this instruction expire at some point because of server cleanup?

Contributor Author (@yelite):

The earliest wheel on that index was built on 07/30, so I guess the retention is 60 days. Those wheels are pretty large (about 1 GB). I added a message in the comment string to suggest trying the latest nightly if this one cannot be found.

Comment on lines 321 to 324
mod.run()
result = [torch.from_dlpack(mod.get_output(i)) for i in range(mod.get_num_outputs())]
if IS_CUDA:
    torch.cuda.synchronize()
Member:

Let's think twice about the synchronization here. Two questions:

  • torch.cuda.synchronize or tvm.cuda(0).sync()?
  • do we want to move the sync one line up, before mod.get_output()?

Contributor Author @yelite commented on Sep 27, 2022:

My thoughts are:

  1. We should call torch's synchronize before the computation here, then call tvm's synchronize afterward, letting each library's sync wait for the computation on its own side.
  2. We probably want to move the sync before get_output. I believe it doesn't matter as long as torch.from_dlpack(mod.get_output(i)) is zero-copy, but it doesn't hurt to move the sync before that line. (A sketch follows below.)
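
A minimal sketch of that ordering, assuming mod is a TVM graph executor module as in the script (illustrative only, not the PR's final code):

```python
import torch
import tvm

IS_CUDA = torch.cuda.is_available()  # as in the script under review

def run_and_fetch(mod):
    """Run a TVM graph module and return its outputs as torch tensors."""
    if IS_CUDA:
        torch.cuda.synchronize()  # PyTorch-side sync: inputs are ready
    mod.run()                     # launch the TVM computation
    if IS_CUDA:
        tvm.cuda(0).sync()        # TVM-side sync: outputs are ready
    return [torch.from_dlpack(mod.get_output(i)) for i in range(mod.get_num_outputs())]
```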

def forward(*args):
    if IS_CUDA:
        torch.cuda.synchronize()
    args = [arg.contiguous() for arg in args]
Member:

qq: does this incur an unnecessary extra copy if arg is already contiguous, or is it a no-op?

Contributor Author (@yelite):

It's a no-op if arg is already contiguous.
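
For reference, this is easy to verify (a quick illustration, not part of the PR): torch.Tensor.contiguous returns the tensor itself when no copy is needed.

```python
import torch

t = torch.zeros(3, 4)        # freshly allocated tensors are contiguous
assert t.contiguous() is t   # no-op: the very same tensor object is returned

u = t.t()                    # a transposed view is non-contiguous
assert not u.is_contiguous()
v = u.contiguous()           # this call does copy into new storage
assert v.data_ptr() != u.data_ptr()
```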

Member @junrushao left a comment:

High-level comment: per-method documentation is needed.

Member @zxybazh left a comment:

Generally looking good. Would you please address my comments and add some type annotations? Thanks.

"--benchmark-repeat",
type=int,
default=10,
help="The number of times to repeat the benchmark measurement.",
Member:

Just curious, can we customize other benchmarking details like warm-up rounds, time between measurements, etc.?

In torchdynamo there are warm-up rounds. Adding this could be an option here.

Contributor Author (@yelite):

The benchmark runner from TorchDynamo doesn't have these parameters exposed (its args can be found at https://github.com/pytorch/torchdynamo/blob/main/benchmarks/common.py#L1363), but we can still implement this customization if needed.
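
A hypothetical sketch of what such an option could look like, mirroring the existing --benchmark-repeat flag (the argument name and default are illustrative, not part of the PR):

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical extra knob; the flag name and default are illustrative only.
parser.add_argument(
    "--benchmark-warmup-rounds",
    type=int,
    default=5,
    help="The number of warm-up runs before the benchmark measurement.",
)
```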

machine than the one that executes tuning.
```bash
python python/tvm/meta_schedule/testing/torchbench/run.py \
    --mode eval \
```

NIT: since no perf evaluation will be done with --tuning, I feel like we should combine the tuning and all modes, if perf evaluation doesn't really take much time.

Contributor Author (@yelite):

This option exists to support running this script on a machine without the target GPU. For example, tuning can be done with the help of RPC runners on a machine with a 3070 while targeting an A100.

We still require the host machine to have a GPU because the model provided by TorchBench could potentially differ on CPU versus CUDA. If we implement remote task extraction (we probably will), we could even run this script on a machine without a GPU to tune the model.


for idx, arg in enumerate(args, 0):
    mod.set_input(
        f"inp_{idx}",
        tvm.nd.from_dlpack(arg),
    )


I think this could potentially be a problem, and that is the reason why, in torchdynamo's TVM backend, torch.Tensor is converted to numpy and then to a TVM NDArray. And if arg is typed torch.Tensor, you may need torch.utils.dlpack.to_dlpack(arg) for this approach as well.

Contributor Author @yelite commented on Sep 28, 2022:

That TVM backend actually doesn't work on CUDA; converting a tensor to a numpy array fails if it's on CUDA:

>>> torch.zeros((5, 5), device="cuda").numpy()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

I guess the only problem here is with boolean tensors (and maybe also tensors with unaligned memory). I will create a follow-up PR that uses Yaoda's work on the TVM PyTorch integration to replace this code. That integration can handle these edge cases with a minimal number of data copies.

to_dlpack is considered a legacy interface (https://pytorch.org/docs/stable/dlpack.html#torch.utils.dlpack.to_dlpack). The new approach is to have a __dlpack__ method on the object (like torch.Tensor) so that the importing function (like tvm.nd.from_dlpack) can call it to get the capsule.
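
A small illustration of the two interfaces (illustrative only, assuming a TVM build with DLPack support):

```python
import torch
import tvm
from torch.utils.dlpack import to_dlpack

t = torch.arange(4, dtype=torch.float32)

# Legacy interface: explicitly create a DLPack capsule, then import it.
a = tvm.nd.from_dlpack(to_dlpack(t))

# Newer protocol: pass the tensor itself; the importer calls the tensor's
# __dlpack__ method internally to obtain the capsule.
b = tvm.nd.from_dlpack(t)
```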

Member @zxybazh left a comment:

Looking good to me, would you please fix the CI?

@yelite force-pushed the torchdynamo-tuning-script branch from fb909b8 to c737c1c on Sep 28, 2022.
@zxybazh changed the title from "Add a script to tune and benchmark models from TorchBench" to "[MetaSchedule] Add Script for TorchBench Model Tuning & Benchmarking" on Sep 28, 2022.
from enum import Enum
from typing import Callable, List, Tuple

import numpy as np # type: ignore
Member:

Not sure if we need type: ignore here for imported libraries; any particular reason?

Contributor Author (@yelite):

They are for suppressing errors like 'Cannot find implementation or library stub for module named "torch"' (https://ci.tlcpack.ai/blue/organizations/jenkins/tvm/detail/PR-12914/7/pipeline#step-97-log-74).
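
For context, this is the pattern being discussed (a minimal illustration, assuming mypy runs without type stubs for these libraries installed):

```python
# Without the trailing comment, mypy fails with:
#   error: Cannot find implementation or library stub for module named "torch"
import torch  # type: ignore
```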

@yelite force-pushed the torchdynamo-tuning-script branch 4 times, most recently from 8a50453 to 5affc12 on Sep 29, 2022.
@yelite force-pushed the torchdynamo-tuning-script branch from 5affc12 to 6b7fbbc on Sep 29, 2022.
@junrushao merged commit 2379917 into apache:main on Sep 29, 2022.
xinetzone pushed a commit to daobook/tvm that referenced this pull request on Nov 25, 2022: [MetaSchedule] Add Script for TorchBench Model Tuning & Benchmarking (apache#12914)