diff --git a/advanced_source/cpp_cuda_graphs.rst b/advanced_source/cpp_cuda_graphs.rst new file mode 100644 index 0000000000..494d6426d4 --- /dev/null +++ b/advanced_source/cpp_cuda_graphs.rst @@ -0,0 +1,193 @@ +Using CUDA Graphs in PyTorch C++ API +==================================== + +.. note:: + |edit| View and edit this tutorial in `GitHub `__. The full source code is available on `GitHub `__. + +Prerequisites: + +- `Using the PyTorch C++ Frontend <../advanced_source/cpp_frontend.html>`__ +- `CUDA semantics `__ +- PyTorch 2.0 or later +- CUDA 11 or later + +NVIDIA’s CUDA Graphs have been a part of the CUDA Toolkit library since the +release of `version 10 `_. +They are capable of greatly reducing CPU overhead, thereby increasing the +performance of applications. + +In this tutorial, we will be focusing on using CUDA Graphs for the `C++ +frontend of PyTorch `_. +The C++ frontend is mostly utilized in production and deployment applications, which +are important parts of PyTorch use cases. Since `their first appearance +`_, +CUDA Graphs have won users’ and developers’ hearts for being a very performant +and at the same time simple-to-use tool. In fact, CUDA Graphs are used by default +in ``torch.compile`` of PyTorch 2.0 to boost the productivity of training and inference. + +We would like to demonstrate CUDA Graphs usage on PyTorch’s `MNIST +example `_. +The usage of CUDA Graphs in LibTorch (C++ Frontend) is very similar to its +`Python counterpart `_ +but with some differences in syntax and functionality. + +Getting Started +--------------- + +The main training loop consists of several steps, as depicted in the +following code chunk: + +.. 
code-block:: cpp + + for (auto& batch : data_loader) { + auto data = batch.data.to(device); + auto targets = batch.target.to(device); + optimizer.zero_grad(); + auto output = model.forward(data); + auto loss = torch::nll_loss(output, targets); + loss.backward(); + optimizer.step(); + } + +The example above includes a forward pass, a backward pass, and weight updates. + +In this tutorial, we will be applying a CUDA Graph to all the compute steps through whole-network +graph capture. But before doing so, we need to slightly modify the source code. What we need +to do is preallocate tensors so that they can be reused in the main training loop. Here is an example +implementation: + +.. code-block:: cpp + + torch::TensorOptions FloatCUDA = + torch::TensorOptions(device).dtype(torch::kFloat); + torch::TensorOptions LongCUDA = + torch::TensorOptions(device).dtype(torch::kLong); + + torch::Tensor data = torch::zeros({kTrainBatchSize, 1, 28, 28}, FloatCUDA); + torch::Tensor targets = torch::zeros({kTrainBatchSize}, LongCUDA); + torch::Tensor output = torch::zeros({1}, FloatCUDA); + torch::Tensor loss = torch::zeros({1}, FloatCUDA); + + for (auto& batch : data_loader) { + data.copy_(batch.data); + targets.copy_(batch.target); + training_step(model, optimizer, data, targets, output, loss); + } + +Here, ``training_step`` simply consists of the forward and backward passes with the corresponding optimizer calls: + +.. code-block:: cpp + + void training_step( + Net& model, + torch::optim::Optimizer& optimizer, + torch::Tensor& data, + torch::Tensor& targets, + torch::Tensor& output, + torch::Tensor& loss) { + optimizer.zero_grad(); + output = model.forward(data); + loss = torch::nll_loss(output, targets); + loss.backward(); + optimizer.step(); + } + +PyTorch’s CUDA Graphs API relies on Stream Capture, which in our case would be used like this: + +.. 
code-block:: cpp + + at::cuda::CUDAGraph graph; + at::cuda::CUDAStream captureStream = at::cuda::getStreamFromPool(); + at::cuda::setCurrentCUDAStream(captureStream); + + graph.capture_begin(); + training_step(model, optimizer, data, targets, output, loss); + graph.capture_end(); + +Before the actual graph capture, it is important to run several warm-up iterations on a side stream to +prepare the CUDA cache as well as the CUDA libraries (like cuBLAS and cuDNN) that will be used during +the training: + +.. code-block:: cpp + + at::cuda::CUDAStream warmupStream = at::cuda::getStreamFromPool(); + at::cuda::setCurrentCUDAStream(warmupStream); + for (int iter = 0; iter < num_warmup_iters; iter++) { + training_step(model, optimizer, data, targets, output, loss); + } + +After the successful graph capture, we can replace the ``training_step(model, optimizer, data, targets, output, loss);`` +call with ``graph.replay();`` to do the training step. + +Training Results +---------------- + +Taking the code for a spin, we can see the following output from ordinary non-graphed training: + +.. 
code-block:: shell + + $ time ./mnist + Train Epoch: 1 [59584/60000] Loss: 0.3921 + Test set: Average loss: 0.2051 | Accuracy: 0.938 + Train Epoch: 2 [59584/60000] Loss: 0.1826 + Test set: Average loss: 0.1273 | Accuracy: 0.960 + Train Epoch: 3 [59584/60000] Loss: 0.1796 + Test set: Average loss: 0.1012 | Accuracy: 0.968 + Train Epoch: 4 [59584/60000] Loss: 0.1603 + Test set: Average loss: 0.0869 | Accuracy: 0.973 + Train Epoch: 5 [59584/60000] Loss: 0.2315 + Test set: Average loss: 0.0736 | Accuracy: 0.978 + Train Epoch: 6 [59584/60000] Loss: 0.0511 + Test set: Average loss: 0.0704 | Accuracy: 0.977 + Train Epoch: 7 [59584/60000] Loss: 0.0802 + Test set: Average loss: 0.0654 | Accuracy: 0.979 + Train Epoch: 8 [59584/60000] Loss: 0.0774 + Test set: Average loss: 0.0604 | Accuracy: 0.980 + Train Epoch: 9 [59584/60000] Loss: 0.0669 + Test set: Average loss: 0.0544 | Accuracy: 0.984 + Train Epoch: 10 [59584/60000] Loss: 0.0219 + Test set: Average loss: 0.0517 | Accuracy: 0.983 + + real 0m44.287s + user 0m44.018s + sys 0m1.116s + +While the training with the CUDA Graph produces the following output: + +.. 
code-block:: shell + + $ time ./mnist --use-train-graph + Train Epoch: 1 [59584/60000] Loss: 0.4092 + Test set: Average loss: 0.2037 | Accuracy: 0.938 + Train Epoch: 2 [59584/60000] Loss: 0.2039 + Test set: Average loss: 0.1274 | Accuracy: 0.961 + Train Epoch: 3 [59584/60000] Loss: 0.1779 + Test set: Average loss: 0.1017 | Accuracy: 0.968 + Train Epoch: 4 [59584/60000] Loss: 0.1559 + Test set: Average loss: 0.0871 | Accuracy: 0.972 + Train Epoch: 5 [59584/60000] Loss: 0.2240 + Test set: Average loss: 0.0735 | Accuracy: 0.977 + Train Epoch: 6 [59584/60000] Loss: 0.0520 + Test set: Average loss: 0.0710 | Accuracy: 0.978 + Train Epoch: 7 [59584/60000] Loss: 0.0935 + Test set: Average loss: 0.0666 | Accuracy: 0.979 + Train Epoch: 8 [59584/60000] Loss: 0.0744 + Test set: Average loss: 0.0603 | Accuracy: 0.981 + Train Epoch: 9 [59584/60000] Loss: 0.0762 + Test set: Average loss: 0.0547 | Accuracy: 0.983 + Train Epoch: 10 [59584/60000] Loss: 0.0207 + Test set: Average loss: 0.0525 | Accuracy: 0.983 + + real 0m6.952s + user 0m7.048s + sys 0m0.619s + +Conclusion +---------- + +As we can see, just by applying a CUDA Graph to the `MNIST example +`_ we were able to speed up +training by more than six times (about 44 s versus 7 s of wall-clock time). Such a large improvement was achievable due to +the small model size. For larger models with heavy GPU usage, the CPU overhead is less impactful, +so the improvement will be smaller. Nevertheless, it is always advantageous to use CUDA Graphs to +get the most performance out of GPUs. 
diff --git a/advanced_source/cpp_cuda_graphs/CMakeLists.txt b/advanced_source/cpp_cuda_graphs/CMakeLists.txt new file mode 100644 index 0000000000..76fc5bc676 --- /dev/null +++ b/advanced_source/cpp_cuda_graphs/CMakeLists.txt @@ -0,0 +1,31 @@ +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) +project(mnist) +set(CMAKE_CXX_STANDARD 17) + +find_package(Torch REQUIRED) +find_package(Threads REQUIRED) + +option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON) +if (DOWNLOAD_MNIST) + message(STATUS "Downloading MNIST dataset") + execute_process( + COMMAND python ${CMAKE_CURRENT_LIST_DIR}/../tools/download_mnist.py + -d ${CMAKE_BINARY_DIR}/data + ERROR_VARIABLE DOWNLOAD_ERROR) + if (DOWNLOAD_ERROR) + message(FATAL_ERROR "Error downloading MNIST dataset: ${DOWNLOAD_ERROR}") + endif() +endif() + +add_executable(mnist mnist.cpp) +target_compile_features(mnist PUBLIC cxx_range_for) +target_link_libraries(mnist ${TORCH_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) + +if (MSVC) + file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll") + add_custom_command(TARGET mnist + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${TORCH_DLLS} + $<TARGET_FILE_DIR:mnist>) +endif (MSVC) diff --git a/advanced_source/cpp_cuda_graphs/README.md b/advanced_source/cpp_cuda_graphs/README.md new file mode 100644 index 0000000000..cbe368d1e9 --- /dev/null +++ b/advanced_source/cpp_cuda_graphs/README.md @@ -0,0 +1,38 @@ +# MNIST Example with the PyTorch C++ Frontend + +This folder contains an example of training a computer vision model to recognize +digits in images from the MNIST dataset, using the PyTorch C++ frontend. + +The entire training code is contained in `mnist.cpp`. + +To build the code, run the following commands from your terminal: + +```shell +$ cd cpp_cuda_graphs +$ mkdir build +$ cd build +$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch .. 
+$ make +``` + +where `/path/to/libtorch` should be the path to the unzipped _LibTorch_ +distribution, which you can get from the [PyTorch +homepage](https://pytorch.org/get-started/locally/). + +Execute the compiled binary to train the model: + +```shell +$ ./mnist +Train Epoch: 1 [59584/60000] Loss: 0.4232 +Test set: Average loss: 0.1989 | Accuracy: 0.940 +Train Epoch: 2 [59584/60000] Loss: 0.1926 +Test set: Average loss: 0.1338 | Accuracy: 0.959 +Train Epoch: 3 [59584/60000] Loss: 0.1390 +Test set: Average loss: 0.0997 | Accuracy: 0.969 +Train Epoch: 4 [59584/60000] Loss: 0.1239 +Test set: Average loss: 0.0875 | Accuracy: 0.972 +... +``` + +To run with CUDA Graphs, add `--use-train-graph` and/or `--use-test-graph` +for the training and testing passes, respectively. diff --git a/advanced_source/cpp_cuda_graphs/mnist.cpp b/advanced_source/cpp_cuda_graphs/mnist.cpp new file mode 100644 index 0000000000..97c5fb80ca --- /dev/null +++ b/advanced_source/cpp_cuda_graphs/mnist.cpp @@ -0,0 +1,372 @@ +#include <ATen/cuda/CUDAEvent.h> +#include <ATen/cuda/CUDAGraph.h> +#include <c10/cuda/CUDAStream.h> +#include <torch/torch.h> + +#include <cstddef> +#include <cstdio> +#include <iostream> +#include <string> +#include <vector> + +// Where to find the MNIST dataset. +const char* kDataRoot = "./data"; + +// The batch size for training. +const int64_t kTrainBatchSize = 64; + +// The batch size for testing. +const int64_t kTestBatchSize = 1000; + +// The number of epochs to train. +const int64_t kNumberOfEpochs = 10; + +// After how many batches to log a new update with the loss value. 
+const int64_t kLogInterval = 10;
+
+// Small convolutional classifier: two 5x5 convolutions (1 -> 10 -> 20
+// channels), each followed by 2x2 max pooling and ReLU, then two linear
+// layers (320 -> 50 -> 10). forward() returns log-softmax over dim 1.
+struct Net : torch::nn::Module {
+  Net()
+      : conv1(torch::nn::Conv2dOptions(1, 10, /*kernel_size=*/5)),
+        conv2(torch::nn::Conv2dOptions(10, 20, /*kernel_size=*/5)),
+        fc1(320, 50),
+        fc2(50, 10) {
+    register_module("conv1", conv1);
+    register_module("conv2", conv2);
+    register_module("conv2_drop", conv2_drop);
+    register_module("fc1", fc1);
+    register_module("fc2", fc2);
+  }
+
+  torch::Tensor forward(torch::Tensor x) {
+    x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));
+    x = torch::relu(
+        torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));
+    // Flatten to (batch, 320) for the fully connected layers.
+    x = x.view({-1, 320});
+    x = torch::relu(fc1->forward(x));
+    // Dropout is only active while the module is in training mode.
+    x = torch::dropout(x, /*p=*/0.5, /*training=*/is_training());
+    x = fc2->forward(x);
+    return torch::log_softmax(x, /*dim=*/1);
+  }
+
+  torch::nn::Conv2d conv1;
+  torch::nn::Conv2d conv2;
+  torch::nn::Dropout2d conv2_drop;
+  torch::nn::Linear fc1;
+  torch::nn::Linear fc2;
+};
+
+// Makes `dependent` wait for all work currently enqueued on `dependency`:
+// an event is recorded on `dependency` and `dependent` is blocked on it.
+void stream_sync(
+    at::cuda::CUDAStream& dependency,
+    at::cuda::CUDAStream& dependent) {
+  at::cuda::CUDAEvent cuda_ev;
+  cuda_ev.record(dependency);
+  cuda_ev.block(dependent);
+}
+
+// One full optimization step: zero gradients, forward, NLL loss, backward,
+// optimizer update. `output` and `loss` are assigned the tensors produced by
+// this step, so the caller's references point at the step's results.
+void training_step(
+    Net& model,
+    torch::optim::Optimizer& optimizer,
+    torch::Tensor& data,
+    torch::Tensor& targets,
+    torch::Tensor& output,
+    torch::Tensor& loss) {
+  optimizer.zero_grad();
+  output = model.forward(data);
+  loss = torch::nll_loss(output, targets);
+  loss.backward();
+  optimizer.step();
+}
+
+// Captures one training_step() into `graph`.
+//
+// Runs `num_warmup_iters` eager training steps on a side stream first, then
+// captures a single step on a second pool stream, and finally restores the
+// stream that was current on entry. Note that the warm-up steps are real
+// optimizer steps on whatever `data`/`targets` currently hold, so they do
+// modify the model parameters and the optimizer state.
+void capture_train_graph(
+    Net& model,
+    torch::optim::Optimizer& optimizer,
+    torch::Tensor& data,
+    torch::Tensor& targets,
+    torch::Tensor& output,
+    torch::Tensor& loss,
+    at::cuda::CUDAGraph& graph,
+    const int num_warmup_iters = 7) {
+  model.train();
+
+  auto warmupStream = at::cuda::getStreamFromPool();
+  auto captureStream = at::cuda::getStreamFromPool();
+  auto legacyStream = at::cuda::getCurrentCUDAStream();
+
+  at::cuda::setCurrentCUDAStream(warmupStream);
+
+  // Warm-up must not start before pending work on the original stream
+  // (parameter/tensor initialization) has finished.
+  stream_sync(legacyStream, warmupStream);
+
+  for ([[maybe_unused]] const auto iter : c10::irange(num_warmup_iters)) {
+    training_step(model, optimizer, data, targets, output, loss);
+  }
+
+  // Capture must observe the state left behind by the warm-up iterations.
+  stream_sync(warmupStream, captureStream);
+  at::cuda::setCurrentCUDAStream(captureStream);
+
+  graph.capture_begin();
+  training_step(model, optimizer, data, targets, output, loss);
+  graph.capture_end();
+
+  // Hand control back to the original stream once capture has completed.
+  stream_sync(captureStream, legacyStream);
+  at::cuda::setCurrentCUDAStream(legacyStream);
+}
+
+// Trains for one epoch, either eagerly or by replaying `graph`.
+//
+// Each batch is copied into the preallocated `data`/`targets` tensors.
+// Batches whose size differs from kTrainBatchSize are skipped because those
+// tensors have a fixed batch dimension. `device` is not used by the body and
+// is kept only for interface compatibility.
+template <typename DataLoader>
+void train(
+    size_t epoch,
+    Net& model,
+    torch::Device device,
+    DataLoader& data_loader,
+    torch::optim::Optimizer& optimizer,
+    size_t dataset_size,
+    torch::Tensor& data,
+    torch::Tensor& targets,
+    torch::Tensor& output,
+    torch::Tensor& loss,
+    at::cuda::CUDAGraph& graph,
+    bool use_graph) {
+  model.train();
+
+  size_t batch_idx = 0;
+
+  for (const auto& batch : data_loader) {
+    if (batch.data.size(0) != kTrainBatchSize ||
+        batch.target.size(0) != kTrainBatchSize) {
+      continue;
+    }
+
+    data.copy_(batch.data);
+    targets.copy_(batch.target);
+
+    if (use_graph) {
+      graph.replay();
+    } else {
+      training_step(model, optimizer, data, targets, output, loss);
+    }
+
+    if (batch_idx++ % kLogInterval == 0) {
+      // item() copies the scalar to the host, so only do it when logging.
+      float train_loss = loss.item<float>();
+      std::cout << "\rTrain Epoch: " << epoch << " ["
+                << batch_idx * batch.data.size(0) << "/" << dataset_size
+                << "] Loss: " << train_loss;
+    }
+  }
+}
+
+// Forward pass plus summed (not averaged) NLL loss for one test batch.
+void test_step(
+    Net& model,
+    torch::Tensor& data,
+    torch::Tensor& targets,
+    torch::Tensor& output,
+    torch::Tensor& loss) {
+  output = model.forward(data);
+  loss = torch::nll_loss(output, targets, {}, torch::Reduction::Sum);
+}
+
+// Captures one test_step() into `graph`, mirroring capture_train_graph().
+//
+// The warm-up iterations also run the accumulation into `total_loss` and
+// `total_correct`; only test_step() itself is captured. The stream that was
+// current on entry is restored before returning.
+void capture_test_graph(
+    Net& model,
+    torch::Tensor& data,
+    torch::Tensor& targets,
+    torch::Tensor& output,
+    torch::Tensor& loss,
+    torch::Tensor& total_loss,
+    torch::Tensor& total_correct,
+    at::cuda::CUDAGraph& graph,
+    const int num_warmup_iters = 7) {
+  torch::NoGradGuard no_grad;
+  model.eval();
+
+  auto warmupStream = at::cuda::getStreamFromPool();
+  auto captureStream = at::cuda::getStreamFromPool();
+  auto legacyStream = at::cuda::getCurrentCUDAStream();
+
+  at::cuda::setCurrentCUDAStream(warmupStream);
+  // The warm-up stream has to wait for work already queued on the original
+  // stream. (This previously synchronized captureStream -> legacyStream,
+  // which left the warm-up stream unordered with respect to prior work.)
+  stream_sync(legacyStream, warmupStream);
+
+  for ([[maybe_unused]] const auto iter : c10::irange(num_warmup_iters)) {
+    test_step(model, data, targets, output, loss);
+    total_loss += loss;
+    total_correct += output.argmax(1).eq(targets).sum();
+  }
+
+  stream_sync(warmupStream, captureStream);
+  at::cuda::setCurrentCUDAStream(captureStream);
+
+  graph.capture_begin();
+  test_step(model, data, targets, output, loss);
+  graph.capture_end();
+
+  stream_sync(captureStream, legacyStream);
+  at::cuda::setCurrentCUDAStream(legacyStream);
+}
+
+// Evaluates the model on `data_loader`, eagerly or by replaying `graph`, and
+// prints the average loss and accuracy (both divided by `dataset_size`).
+//
+// The accumulators are zeroed first, which also discards anything added to
+// them during graph warm-up. Batches whose size differs from kTestBatchSize
+// are skipped. `device` is not used by the body and is kept only for
+// interface compatibility.
+template <typename DataLoader>
+void test(
+    Net& model,
+    torch::Device device,
+    DataLoader& data_loader,
+    size_t dataset_size,
+    torch::Tensor& data,
+    torch::Tensor& targets,
+    torch::Tensor& output,
+    torch::Tensor& loss,
+    torch::Tensor& total_loss,
+    torch::Tensor& total_correct,
+    at::cuda::CUDAGraph& graph,
+    bool use_graph) {
+  torch::NoGradGuard no_grad;
+
+  model.eval();
+  loss.zero_();
+  total_loss.zero_();
+  total_correct.zero_();
+
+  for (const auto& batch : data_loader) {
+    if (batch.data.size(0) != kTestBatchSize ||
+        batch.target.size(0) != kTestBatchSize) {
+      continue;
+    }
+    data.copy_(batch.data);
+    targets.copy_(batch.target);
+
+    if (use_graph) {
+      graph.replay();
+    } else {
+      test_step(model, data, targets, output, loss);
+    }
+    // Accumulation stays outside the graph in both modes.
+    total_loss += loss;
+    total_correct += output.argmax(1).eq(targets).sum();
+  }
+
+  float test_loss = total_loss.item<float>() / dataset_size;
+  float test_accuracy =
+      static_cast<float>(total_correct.item<int64_t>()) / dataset_size;
+
+  std::cout << std::endl
+            << "Test set: Average loss: " << test_loss
+            << " | Accuracy: " << test_accuracy << std::endl;
+}
+
+// Entry point. Recognized flags: --use-train-graph, --use-test-graph.
+// Returns -1 when CUDA is unavailable, 0 otherwise.
+int main(int argc, char* argv[]) {
+  if (!torch::cuda::is_available()) {
+    std::cout << "CUDA is not available!" << std::endl;
+    return -1;
+  }
+
+  bool use_train_graph = false;
+  bool use_test_graph = false;
+
+  std::vector<std::string> arguments(argv + 1, argv + argc);
+  for (const std::string& arg : arguments) {
+    if (arg == "--use-train-graph") {
+      std::cout << "Using CUDA Graph for training." << std::endl;
+      use_train_graph = true;
+    }
+    if (arg == "--use-test-graph") {
+      std::cout << "Using CUDA Graph for testing." << std::endl;
+      use_test_graph = true;
+    }
+  }
+
+  torch::manual_seed(1);
+  torch::cuda::manual_seed(1);
+  torch::Device device(torch::kCUDA);
+
+  Net model;
+  model.to(device);
+
+  auto train_dataset =
+      torch::data::datasets::MNIST(kDataRoot)
+          .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
+          .map(torch::data::transforms::Stack<>());
+  const size_t train_dataset_size = train_dataset.size().value();
+  // NOTE(review): the sampler template argument was lost from the extracted
+  // source; SequentialSampler matches the upstream MNIST example. Confirm.
+  auto train_loader =
+      torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
+          std::move(train_dataset), kTrainBatchSize);
+
+  auto test_dataset =
+      torch::data::datasets::MNIST(
+          kDataRoot, torch::data::datasets::MNIST::Mode::kTest)
+          .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
+          .map(torch::data::transforms::Stack<>());
+  const size_t test_dataset_size = test_dataset.size().value();
+  auto test_loader =
+      torch::data::make_data_loader(std::move(test_dataset), kTestBatchSize);
+
+  torch::optim::SGD optimizer(
+      model.parameters(), torch::optim::SGDOptions(0.01).momentum(0.5));
+
+  torch::TensorOptions FloatCUDA =
+      torch::TensorOptions(device).dtype(torch::kFloat);
+  torch::TensorOptions LongCUDA =
+      torch::TensorOptions(device).dtype(torch::kLong);
+
+  // Static input tensors: batches are copied into these so that the captured
+  // graphs always read from the same tensors. The output/loss tensors are
+  // placeholders that the first training_step()/test_step() call reassigns.
+  torch::Tensor train_data =
+      torch::zeros({kTrainBatchSize, 1, 28, 28}, FloatCUDA);
+  torch::Tensor train_targets = torch::zeros({kTrainBatchSize}, LongCUDA);
+  torch::Tensor train_output = torch::zeros({1}, FloatCUDA);
+  torch::Tensor train_loss = torch::zeros({1}, FloatCUDA);
+
+  torch::Tensor test_data =
+      torch::zeros({kTestBatchSize, 1, 28, 28}, FloatCUDA);
+  torch::Tensor test_targets = torch::zeros({kTestBatchSize}, LongCUDA);
+  torch::Tensor test_output = torch::zeros({1}, FloatCUDA);
+  torch::Tensor test_loss = torch::zeros({1}, FloatCUDA);
+  torch::Tensor test_total_loss = torch::zeros({1}, FloatCUDA);
+  torch::Tensor test_total_correct = torch::zeros({1}, LongCUDA);
+
+  at::cuda::CUDAGraph train_graph;
+  at::cuda::CUDAGraph test_graph;
+
+  // Both graphs are captured regardless of the flags, so graphed and
+  // non-graphed runs go through the same warm-up steps before epoch 1.
+  capture_train_graph(
+      model,
+      optimizer,
+      train_data,
+      train_targets,
+      train_output,
+      train_loss,
+      train_graph);
+
+  capture_test_graph(
+      model,
+      test_data,
+      test_targets,
+      test_output,
+      test_loss,
+      test_total_loss,
+      test_total_correct,
+      test_graph);
+
+  for (size_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
+    train(
+        epoch,
+        model,
+        device,
+        *train_loader,
+        optimizer,
+        train_dataset_size,
+        train_data,
+        train_targets,
+        train_output,
+        train_loss,
+        train_graph,
+        use_train_graph);
+    test(
+        model,
+        device,
+        *test_loader,
+        test_dataset_size,
+        test_data,
+        test_targets,
+        test_output,
+        test_loss,
+        test_total_loss,
+        test_total_correct,
+        test_graph,
+        use_test_graph);
+  }
+
+  std::cout << " Training/testing complete" << std::endl;
+  return 0;
+}