From 00251a002735768cbd2f28358eb584105da99203 Mon Sep 17 00:00:00 2001 From: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> Date: Mon, 5 Jun 2023 18:43:46 -0700 Subject: [PATCH 01/14] add libtorch_cuda_graphs tutorial --- advanced_source/libtorch_cuda_graphs.rst | 156 +++++++ .../libtorch_cuda_graphs/CMakeLists.txt | 30 ++ .../libtorch_cuda_graphs/README.md | 38 ++ .../libtorch_cuda_graphs/mnist.cpp | 379 ++++++++++++++++++ 4 files changed, 603 insertions(+) create mode 100644 advanced_source/libtorch_cuda_graphs.rst create mode 100644 advanced_source/libtorch_cuda_graphs/CMakeLists.txt create mode 100644 advanced_source/libtorch_cuda_graphs/README.md create mode 100644 advanced_source/libtorch_cuda_graphs/mnist.cpp diff --git a/advanced_source/libtorch_cuda_graphs.rst b/advanced_source/libtorch_cuda_graphs.rst new file mode 100644 index 0000000000..8ce93a6f0c --- /dev/null +++ b/advanced_source/libtorch_cuda_graphs.rst @@ -0,0 +1,156 @@ +Using CUDA Graphs in PyTorch C++ API +==================================== + +NVIDIA’s CUDA Graph is a powerful tool. CUDA Graphs are capable of greatly reducing +the CPU overhead increasing the performance. + +In this example we would like to demonstrate CUDA Graphs usage on PyTorch’s `MNIST +example `_. +The usage of CUDA Graphs in LibTorch is very similar to its `Python counterpart +`_ with only major +difference in launguage syntax. + + +The main training training loop consists of the several steps and depicted in the +following code chunk: + +.. code-block:: cpp + + for (auto& batch : data_loader) { + auto data = batch.data.to(device); + auto targets = batch.target.to(device); + optimizer.zero_grad(); + auto output = model.forward(data); + auto loss = torch::nll_loss(output, targets); + loss.backward(); + optimizer.step(); + } + +In this tutorial we will be applying CUDA Graph on all the compute steps via whole-network +graph capture. Before doing so, we need to slightly modify the source code by preallocating +tensors and reusing them in all the work that is being done on GPU: + +.. code-block:: cpp + + torch::TensorOptions FloatCUDA = + torch::TensorOptions(device).dtype(torch::kFloat); + torch::TensorOptions LongCUDA = + torch::TensorOptions(device).dtype(torch::kLong); + + torch::Tensor data = torch::zeros({kTrainBatchSize, 1, 28, 28}, FloatCUDA); + torch::Tensor targets = torch::zeros({kTrainBatchSize}, LongCUDA); + torch::Tensor output = torch::zeros({1}, FloatCUDA); + torch::Tensor loss = torch::zeros({1}, FloatCUDA); + + for (auto& batch : data_loader) { + data.copy_(batch.data); + targets.copy_(batch.target); + training_step(model, optimizer, data, targets, output, loss); + } + +Where training step simply consists of forward and backward passes with corresponding optimizer calls: + +.. code-block:: cpp + + void training_step( + Net& model, + torch::optim::Optimizer& optimizer, + torch::Tensor& data, + torch::Tensor& targets, + torch::Tensor& output, + torch::Tensor& loss) { + optimizer.zero_grad(); + output = model.forward(data); + loss = torch::nll_loss(output, targets); + loss.backward(); + optimizer.step(); + } + +PyTorch’s CUDA Graphs API is relying on Stream Capture which in general looks like the following: + +.. code-block:: cpp + + at::cuda::CUDAGraph test_graph; + at::cuda::CUDAStream captureStream = at::cuda::getStreamFromPool(); + at::cuda::setCurrentCUDAStream(captureStream); + + graph.capture_begin(); + training_step(model, optimizer, data, targets, output, loss); + graph.capture_end(); + +Before actual graph capture it is important to run several warm-up iterations on side stream to +prepare CUDA cache as well as CUDA libraries (like CUBLAS and CUDNN) that will be used during +the training: + +.. code-block:: cpp + + at::cuda::CUDAStream warmupStream = at::cuda::getStreamFromPool(); + at::cuda::setCurrentCUDAStream(warmupStream); + for (int iter = 0; iter < num_warmup_iters; iter++) { + training_step(model, optimizer, data, targets, output, loss); + } + +After successful graph capturing we can replace `training_step(model, optimizer, data, targets, output, loss);` +call via `graph.replay();`. + +The full source code is available in GitHub. + +.. code-block:: shell + +$ time ./mnist +Train Epoch: 1 [59584/60000] Loss: 0.3921 +Test set: Average loss: 0.2051 | Accuracy: 0.938 +Train Epoch: 2 [59584/60000] Loss: 0.1826 +Test set: Average loss: 0.1273 | Accuracy: 0.960 +Train Epoch: 3 [59584/60000] Loss: 0.1796 +Test set: Average loss: 0.1012 | Accuracy: 0.968 +Train Epoch: 4 [59584/60000] Loss: 0.1603 +Test set: Average loss: 0.0869 | Accuracy: 0.973 +Train Epoch: 5 [59584/60000] Loss: 0.2315 +Test set: Average loss: 0.0736 | Accuracy: 0.978 +Train Epoch: 6 [59584/60000] Loss: 0.0511 +Test set: Average loss: 0.0704 | Accuracy: 0.977 +Train Epoch: 7 [59584/60000] Loss: 0.0802 +Test set: Average loss: 0.0654 | Accuracy: 0.979 +Train Epoch: 8 [59584/60000] Loss: 0.0774 +Test set: Average loss: 0.0604 | Accuracy: 0.980 +Train Epoch: 9 [59584/60000] Loss: 0.0669 +Test set: Average loss: 0.0544 | Accuracy: 0.984 +Train Epoch: 10 [59584/60000] Loss: 0.0219 +Test set: Average loss: 0.0517 | Accuracy: 0.983 + +real 0m44.287s +user 0m44.018s +sys 0m1.116s + +.. code-block:: shell + +$ time ./mnist --use-train-graph +CUDA is available! Training on GPU. +Using CUDA Graph for training. +Train Epoch: 1 [59584/60000] Loss: 0.4092 +Test set: Average loss: 0.2037 | Accuracy: 0.938 +Train Epoch: 2 [59584/60000] Loss: 0.2039 +Test set: Average loss: 0.1274 | Accuracy: 0.961 +Train Epoch: 3 [59584/60000] Loss: 0.1779 +Test set: Average loss: 0.1017 | Accuracy: 0.968 +Train Epoch: 4 [59584/60000] Loss: 0.1559 +Test set: Average loss: 0.0871 | Accuracy: 0.972 +Train Epoch: 5 [59584/60000] Loss: 0.2240 +Test set: Average loss: 0.0735 | Accuracy: 0.977 +Train Epoch: 6 [59584/60000] Loss: 0.0520 +Test set: Average loss: 0.0710 | Accuracy: 0.978 +Train Epoch: 7 [59584/60000] Loss: 0.0935 +Test set: Average loss: 0.0666 | Accuracy: 0.979 +Train Epoch: 8 [59584/60000] Loss: 0.0744 +Test set: Average loss: 0.0603 | Accuracy: 0.981 +Train Epoch: 9 [59584/60000] Loss: 0.0762 +Test set: Average loss: 0.0547 | Accuracy: 0.983 +Train Epoch: 10 [59584/60000] Loss: 0.0207 +Test set: Average loss: 0.0525 | Accuracy: 0.983 + +real 0m6.952s +user 0m7.048s +sys 0m0.619s + +As we can see, just applying a CUDA Graph for the training step we were able to gain the performance by more than 6 times. diff --git a/advanced_source/libtorch_cuda_graphs/CMakeLists.txt b/advanced_source/libtorch_cuda_graphs/CMakeLists.txt new file mode 100644 index 0000000000..e525b71bc4 --- /dev/null +++ b/advanced_source/libtorch_cuda_graphs/CMakeLists.txt @@ -0,0 +1,30 @@ +cmake_minimum_required(VERSION 3.1 FATAL_ERROR) +project(mnist) +set(CMAKE_CXX_STANDARD 17) + +find_package(Torch REQUIRED) + +option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON) +if (DOWNLOAD_MNIST) + message(STATUS "Downloading MNIST dataset") + execute_process( + COMMAND python ${CMAKE_CURRENT_LIST_DIR}/../tools/download_mnist.py + -d ${CMAKE_BINARY_DIR}/data + ERROR_VARIABLE DOWNLOAD_ERROR) + if (DOWNLOAD_ERROR) + message(FATAL_ERROR "Error downloading MNIST dataset: ${DOWNLOAD_ERROR}") + endif() +endif() + +add_executable(mnist mnist.cpp) +target_compile_features(mnist PUBLIC cxx_range_for) +target_link_libraries(mnist ${TORCH_LIBRARIES}) + +if (MSVC) + file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll") + add_custom_command(TARGET mnist + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${TORCH_DLLS} + $) +endif (MSVC) diff --git a/advanced_source/libtorch_cuda_graphs/README.md b/advanced_source/libtorch_cuda_graphs/README.md new file mode 100644 index 0000000000..cbe368d1e9 --- /dev/null +++ b/advanced_source/libtorch_cuda_graphs/README.md @@ -0,0 +1,38 @@ +# MNIST Example with the PyTorch C++ Frontend + +This folder contains an example of training a computer vision model to recognize +digits in images from the MNIST dataset, using the PyTorch C++ frontend. + +The entire training code is contained in `mnist.cpp`. + +To build the code, run the following commands from your terminal: + +```shell +$ cd mnist +$ mkdir build +$ cd build +$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch .. +$ make +``` + +where `/path/to/libtorch` should be the path to the unzipped _LibTorch_ +distribution, which you can get from the [PyTorch +homepage](https://pytorch.org/get-started/locally/). + +Execute the compiled binary to train the model: + +```shell +$ ./mnist +Train Epoch: 1 [59584/60000] Loss: 0.4232 +Test set: Average loss: 0.1989 | Accuracy: 0.940 +Train Epoch: 2 [59584/60000] Loss: 0.1926 +Test set: Average loss: 0.1338 | Accuracy: 0.959 +Train Epoch: 3 [59584/60000] Loss: 0.1390 +Test set: Average loss: 0.0997 | Accuracy: 0.969 +Train Epoch: 4 [59584/60000] Loss: 0.1239 +Test set: Average loss: 0.0875 | Accuracy: 0.972 +... +``` + +For running with CUDA Graphs add `--use-train-graph` and/or `--use-test-graph` +for training and testing passes respectively. diff --git a/advanced_source/libtorch_cuda_graphs/mnist.cpp b/advanced_source/libtorch_cuda_graphs/mnist.cpp new file mode 100644 index 0000000000..0121de80fd --- /dev/null +++ b/advanced_source/libtorch_cuda_graphs/mnist.cpp @@ -0,0 +1,379 @@ +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +// Where to find the MNIST dataset. +const char* kDataRoot = "./data"; + +// The batch size for training. +const int64_t kTrainBatchSize = 64; + +// The batch size for testing. +const int64_t kTestBatchSize = 1000; + +// The number of epochs to train. +const int64_t kNumberOfEpochs = 10; + +// After how many batches to log a new update with the loss value. +const int64_t kLogInterval = 10; + +// Model that we will be training +struct Net : torch::nn::Module { + Net() + : conv1(torch::nn::Conv2dOptions(1, 10, /*kernel_size=*/5)), + conv2(torch::nn::Conv2dOptions(10, 20, /*kernel_size=*/5)), + fc1(320, 50), + fc2(50, 10) { + register_module("conv1", conv1); + register_module("conv2", conv2); + register_module("conv2_drop", conv2_drop); + register_module("fc1", fc1); + register_module("fc2", fc2); + } + + torch::Tensor forward(torch::Tensor x) { + x = torch::relu(torch::max_pool2d(conv1->forward(x), 2)); + x = torch::relu( + torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2)); + x = x.view({-1, 320}); + x = torch::relu(fc1->forward(x)); + x = torch::dropout(x, /*p=*/0.5, /*training=*/is_training()); + x = fc2->forward(x); + return torch::log_softmax(x, /*dim=*/1); + } + + torch::nn::Conv2d conv1; + torch::nn::Conv2d conv2; + torch::nn::Dropout2d conv2_drop; + torch::nn::Linear fc1; + torch::nn::Linear fc2; +}; + +void training_step( + Net& model, + torch::optim::Optimizer& optimizer, + torch::Tensor& data, + torch::Tensor& targets, + torch::Tensor& output, + torch::Tensor& loss) { + optimizer.zero_grad(); + output = model.forward(data); + loss = torch::nll_loss(output, targets); + loss.backward(); + optimizer.step(); +} + +void capture_train_graph( + Net& model, + torch::optim::Optimizer& optimizer, + torch::Tensor& data, + torch::Tensor& targets, + torch::Tensor& output, + torch::Tensor& loss, + at::cuda::CUDAGraph& graph, + const short num_warmup_iters = 7) { + model.train(); + + at::cuda::CUDAStream warmupStream = at::cuda::getStreamFromPool(); + at::cuda::CUDAStream captureStream = at::cuda::getStreamFromPool(); + at::cuda::CUDAStream legacyStream = at::cuda::getCurrentCUDAStream(); + + at::cuda::setCurrentCUDAStream(warmupStream); + + at::cuda::CUDAEvent cuda_ev; + cuda_ev.record(legacyStream); + cuda_ev.block(warmupStream); + + for (int iter = 0; iter < num_warmup_iters; iter++) { + training_step(model, optimizer, data, targets, output, loss); + } + + cuda_ev.record(warmupStream); + cuda_ev.block(captureStream); + at::cuda::setCurrentCUDAStream(captureStream); + + graph.capture_begin(); + training_step(model, optimizer, data, targets, output, loss); + graph.capture_end(); + + cuda_ev.record(captureStream); + cuda_ev.block(legacyStream); + at::cuda::setCurrentCUDAStream(legacyStream); +} + +template +void train( + size_t epoch, + Net& model, + torch::Device device, + DataLoader& data_loader, + torch::optim::Optimizer& optimizer, + size_t dataset_size, + torch::Tensor& data, + torch::Tensor& targets, + torch::Tensor& output, + torch::Tensor& loss, + at::cuda::CUDAGraph& graph, + bool use_graph) { + model.train(); + + size_t batch_idx = 0; + + for (auto& batch : data_loader) { + if (batch.data.size(0) != kTrainBatchSize || + batch.target.size(0) != kTrainBatchSize) { + continue; + } + + data.copy_(batch.data); + targets.copy_(batch.target); + + if (use_graph) { + graph.replay(); + } else { + training_step(model, optimizer, data, targets, output, loss); + } + + if (batch_idx++ % kLogInterval == 0) { + float train_loss = loss.item(); + std::printf( + "\rTrain Epoch: %ld [%5ld/%5ld] Loss: %.4f", + epoch, + batch_idx * batch.data.size(0), + dataset_size, + train_loss); + } + } +} + +void test_step( + Net& model, + torch::Tensor& data, + torch::Tensor& targets, + torch::Tensor& output, + torch::Tensor& loss) { + output = model.forward(data); + loss = torch::nll_loss(output, targets, {}, torch::Reduction::Sum); +} + +void capture_test_graph( + Net& model, + torch::Tensor& data, + torch::Tensor& targets, + torch::Tensor& output, + torch::Tensor& loss, + torch::Tensor& total_loss, + torch::Tensor& total_correct, + at::cuda::CUDAGraph& graph) { + torch::NoGradGuard no_grad; + model.eval(); + + const int num_warmup_iters = 7; + + at::cuda::CUDAStream warmupStream = at::cuda::getStreamFromPool(); + at::cuda::CUDAStream captureStream = at::cuda::getStreamFromPool(); + at::cuda::CUDAStream legacyStream = at::cuda::getCurrentCUDAStream(); + + at::cuda::setCurrentCUDAStream(warmupStream); + + at::cuda::CUDAEvent cuda_ev; + cuda_ev.record(legacyStream); + cuda_ev.block(warmupStream); + + for (int iter = 0; iter < num_warmup_iters; iter++) { + test_step(model, data, targets, output, loss); + total_loss = total_loss + loss; + total_correct = total_correct + output.argmax(1).eq(targets).sum(); + } + + cuda_ev.record(warmupStream); + cuda_ev.block(captureStream); + at::cuda::setCurrentCUDAStream(captureStream); + + graph.capture_begin(); + test_step(model, data, targets, output, loss); + graph.capture_end(); + + cuda_ev.record(captureStream); + cuda_ev.block(legacyStream); + at::cuda::setCurrentCUDAStream(legacyStream); +} + +template +void test( + Net& model, + torch::Device device, + DataLoader& data_loader, + size_t dataset_size, + torch::Tensor& data, + torch::Tensor& targets, + torch::Tensor& output, + torch::Tensor& loss, + torch::Tensor& total_loss, + torch::Tensor& total_correct, + at::cuda::CUDAGraph& graph, + bool use_graph) { + torch::NoGradGuard no_grad; + + model.eval(); + loss.zero_(); + total_loss.zero_(); + total_correct.zero_(); + + for (const auto& batch : data_loader) { + if (batch.data.size(0) != kTestBatchSize || + batch.target.size(0) != kTestBatchSize) { + continue; + } + data.copy_(batch.data); + targets.copy_(batch.target); + + if (use_graph) { + graph.replay(); + } else { + test_step(model, data, targets, output, loss); + } + total_loss = total_loss + loss; + total_correct = total_correct + output.argmax(1).eq(targets).sum(); + } + + float test_loss = total_loss.item() / dataset_size; + float test_accuracy = + static_cast(total_correct.item()) / dataset_size; + + std::printf( + "\nTest set: Average loss: %.4f | Accuracy: %.3f\n", + test_loss, + test_accuracy); +} + +int main(int argc, char* argv[]) { + bool use_train_graph = false; + bool use_test_graph = false; + + torch::manual_seed(1); + + torch::DeviceType device_type; + if (torch::cuda::is_available()) { + std::cout << "CUDA is available! Training on GPU." << std::endl; + + std::vector arguments(argv + 1, argv + argc); + for (std::string& arg : arguments) { + if (arg == "--use-train-graph") { + std::cout << "Using CUDA Graph for training." << std::endl; + use_train_graph = true; + } + if (arg == "--use-test-graph") { + std::cout << "Using CUDA Graph for testing." << std::endl; + use_test_graph = true; + } + } + + device_type = torch::kCUDA; + } else { + std::cout << "CUDA is not available!" << std::endl; + return 1; + } + torch::Device device(device_type); + + Net model; + model.to(device); + + auto train_dataset = + torch::data::datasets::MNIST(kDataRoot) + .map(torch::data::transforms::Normalize<>(0.1307, 0.3081)) + .map(torch::data::transforms::Stack<>()); + const size_t train_dataset_size = train_dataset.size().value(); + auto train_loader = + torch::data::make_data_loader( + std::move(train_dataset), kTrainBatchSize); + + auto test_dataset = + torch::data::datasets::MNIST( + kDataRoot, torch::data::datasets::MNIST::Mode::kTest) + .map(torch::data::transforms::Normalize<>(0.1307, 0.3081)) + .map(torch::data::transforms::Stack<>()); + const size_t test_dataset_size = test_dataset.size().value(); + auto test_loader = + torch::data::make_data_loader(std::move(test_dataset), kTestBatchSize); + + torch::optim::SGD optimizer( + model.parameters(), torch::optim::SGDOptions(0.01).momentum(0.5)); + + torch::TensorOptions FloatCUDA = + torch::TensorOptions(device).dtype(torch::kFloat); + torch::TensorOptions LongCUDA = + torch::TensorOptions(device).dtype(torch::kLong); + + torch::Tensor train_data = + torch::zeros({kTrainBatchSize, 1, 28, 28}, FloatCUDA); + torch::Tensor train_targets = torch::zeros({kTrainBatchSize}, LongCUDA); + torch::Tensor train_output = torch::zeros({1}, FloatCUDA); + torch::Tensor train_loss = torch::zeros({1}, FloatCUDA); + + torch::Tensor test_data = + torch::zeros({kTestBatchSize, 1, 28, 28}, FloatCUDA); + torch::Tensor test_targets = torch::zeros({kTestBatchSize}, LongCUDA); + torch::Tensor test_output = torch::zeros({1}, FloatCUDA); + torch::Tensor test_loss = torch::zeros({1}, FloatCUDA); + torch::Tensor test_total_loss = torch::zeros({1}, FloatCUDA); + torch::Tensor test_total_correct = torch::zeros({1}, LongCUDA); + + at::cuda::CUDAGraph train_graph; + at::cuda::CUDAGraph test_graph; + + capture_train_graph( + model, + optimizer, + train_data, + train_targets, + train_output, + train_loss, + train_graph); + + capture_test_graph( + model, + test_data, + test_targets, + test_output, + test_loss, + test_total_loss, + test_total_correct, + test_graph); + + for (size_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) { + train( + epoch, + model, + device, + *train_loader, + optimizer, + train_dataset_size, + train_data, + train_targets, + train_output, + train_loss, + train_graph, + use_train_graph); + test( + model, + device, + *test_loader, + test_dataset_size, + test_data, + test_targets, + test_output, + test_loss, + test_total_loss, + test_total_correct, + test_graph, + use_test_graph); + } +} From c5dcb0b8c396154bedea039da8e6d5d84615b36d Mon Sep 17 00:00:00 2001 From: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> Date: Mon, 5 Jun 2023 18:48:45 -0700 Subject: [PATCH 02/14] minor fixes --- advanced_source/libtorch_cuda_graphs.rst | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/advanced_source/libtorch_cuda_graphs.rst b/advanced_source/libtorch_cuda_graphs.rst index 8ce93a6f0c..6cc87c82bd 100644 --- a/advanced_source/libtorch_cuda_graphs.rst +++ b/advanced_source/libtorch_cuda_graphs.rst @@ -8,7 +8,7 @@ In this example we would like to demonstrate CUDA Graphs usage on PyTorch’s `M example `_. The usage of CUDA Graphs in LibTorch is very similar to its `Python counterpart `_ with only major -difference in launguage syntax. +difference in language syntax. The main training training loop consists of the several steps and depicted in the @@ -95,6 +95,8 @@ call via `graph.replay();`. The full source code is available in GitHub. +The ordinary eager-mode produces the following output: + .. code-block:: shell $ time ./mnist @@ -123,6 +125,8 @@ real 0m44.287s user 0m44.018s sys 0m1.116s +While the CUDA Graph output is the following: + .. code-block:: shell $ time ./mnist --use-train-graph From 48fd4154a311d4903376a37d0e3dd3e0b6e6cc55 Mon Sep 17 00:00:00 2001 From: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> Date: Mon, 5 Jun 2023 18:49:45 -0700 Subject: [PATCH 03/14] minor fixes --- advanced_source/libtorch_cuda_graphs.rst | 104 +++++++++++------------ 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/advanced_source/libtorch_cuda_graphs.rst b/advanced_source/libtorch_cuda_graphs.rst index 6cc87c82bd..d786be2414 100644 --- a/advanced_source/libtorch_cuda_graphs.rst +++ b/advanced_source/libtorch_cuda_graphs.rst @@ -99,62 +99,62 @@ The ordinary eager-mode produces the following output: .. code-block:: shell -$ time ./mnist -Train Epoch: 1 [59584/60000] Loss: 0.3921 -Test set: Average loss: 0.2051 | Accuracy: 0.938 -Train Epoch: 2 [59584/60000] Loss: 0.1826 -Test set: Average loss: 0.1273 | Accuracy: 0.960 -Train Epoch: 3 [59584/60000] Loss: 0.1796 -Test set: Average loss: 0.1012 | Accuracy: 0.968 -Train Epoch: 4 [59584/60000] Loss: 0.1603 -Test set: Average loss: 0.0869 | Accuracy: 0.973 -Train Epoch: 5 [59584/60000] Loss: 0.2315 -Test set: Average loss: 0.0736 | Accuracy: 0.978 -Train Epoch: 6 [59584/60000] Loss: 0.0511 -Test set: Average loss: 0.0704 | Accuracy: 0.977 -Train Epoch: 7 [59584/60000] Loss: 0.0802 -Test set: Average loss: 0.0654 | Accuracy: 0.979 -Train Epoch: 8 [59584/60000] Loss: 0.0774 -Test set: Average loss: 0.0604 | Accuracy: 0.980 -Train Epoch: 9 [59584/60000] Loss: 0.0669 -Test set: Average loss: 0.0544 | Accuracy: 0.984 -Train Epoch: 10 [59584/60000] Loss: 0.0219 -Test set: Average loss: 0.0517 | Accuracy: 0.983 - -real 0m44.287s -user 0m44.018s -sys 0m1.116s + $ time ./mnist + Train Epoch: 1 [59584/60000] Loss: 0.3921 + Test set: Average loss: 0.2051 | Accuracy: 0.938 + Train Epoch: 2 [59584/60000] Loss: 0.1826 + Test set: Average loss: 0.1273 | Accuracy: 0.960 + Train Epoch: 3 [59584/60000] Loss: 0.1796 + Test set: Average loss: 0.1012 | Accuracy: 0.968 + Train Epoch: 4 [59584/60000] Loss: 0.1603 + Test set: Average loss: 0.0869 | Accuracy: 0.973 + Train Epoch: 5 [59584/60000] Loss: 0.2315 + Test set: Average loss: 0.0736 | Accuracy: 0.978 + Train Epoch: 6 [59584/60000] Loss: 0.0511 + Test set: Average loss: 0.0704 | Accuracy: 0.977 + Train Epoch: 7 [59584/60000] Loss: 0.0802 + Test set: Average loss: 0.0654 | Accuracy: 0.979 + Train Epoch: 8 [59584/60000] Loss: 0.0774 + Test set: Average loss: 0.0604 | Accuracy: 0.980 + Train Epoch: 9 [59584/60000] Loss: 0.0669 + Test set: Average loss: 0.0544 | Accuracy: 0.984 + Train Epoch: 10 [59584/60000] Loss: 0.0219 + Test set: Average loss: 0.0517 | Accuracy: 0.983 + + real 0m44.287s + user 0m44.018s + sys 0m1.116s While the CUDA Graph output is the following: .. code-block:: shell -$ time ./mnist --use-train-graph -CUDA is available! Training on GPU. -Using CUDA Graph for training. -Train Epoch: 1 [59584/60000] Loss: 0.4092 -Test set: Average loss: 0.2037 | Accuracy: 0.938 -Train Epoch: 2 [59584/60000] Loss: 0.2039 -Test set: Average loss: 0.1274 | Accuracy: 0.961 -Train Epoch: 3 [59584/60000] Loss: 0.1779 -Test set: Average loss: 0.1017 | Accuracy: 0.968 -Train Epoch: 4 [59584/60000] Loss: 0.1559 -Test set: Average loss: 0.0871 | Accuracy: 0.972 -Train Epoch: 5 [59584/60000] Loss: 0.2240 -Test set: Average loss: 0.0735 | Accuracy: 0.977 -Train Epoch: 6 [59584/60000] Loss: 0.0520 -Test set: Average loss: 0.0710 | Accuracy: 0.978 -Train Epoch: 7 [59584/60000] Loss: 0.0935 -Test set: Average loss: 0.0666 | Accuracy: 0.979 -Train Epoch: 8 [59584/60000] Loss: 0.0744 -Test set: Average loss: 0.0603 | Accuracy: 0.981 -Train Epoch: 9 [59584/60000] Loss: 0.0762 -Test set: Average loss: 0.0547 | Accuracy: 0.983 -Train Epoch: 10 [59584/60000] Loss: 0.0207 -Test set: Average loss: 0.0525 | Accuracy: 0.983 - -real 0m6.952s -user 0m7.048s -sys 0m0.619s + $ time ./mnist --use-train-graph + CUDA is available! Training on GPU. + Using CUDA Graph for training. + Train Epoch: 1 [59584/60000] Loss: 0.4092 + Test set: Average loss: 0.2037 | Accuracy: 0.938 + Train Epoch: 2 [59584/60000] Loss: 0.2039 + Test set: Average loss: 0.1274 | Accuracy: 0.961 + Train Epoch: 3 [59584/60000] Loss: 0.1779 + Test set: Average loss: 0.1017 | Accuracy: 0.968 + Train Epoch: 4 [59584/60000] Loss: 0.1559 + Test set: Average loss: 0.0871 | Accuracy: 0.972 + Train Epoch: 5 [59584/60000] Loss: 0.2240 + Test set: Average loss: 0.0735 | Accuracy: 0.977 + Train Epoch: 6 [59584/60000] Loss: 0.0520 + Test set: Average loss: 0.0710 | Accuracy: 0.978 + Train Epoch: 7 [59584/60000] Loss: 0.0935 + Test set: Average loss: 0.0666 | Accuracy: 0.979 + Train Epoch: 8 [59584/60000] Loss: 0.0744 + Test set: Average loss: 0.0603 | Accuracy: 0.981 + Train Epoch: 9 [59584/60000] Loss: 0.0762 + Test set: Average loss: 0.0547 | Accuracy: 0.983 + Train Epoch: 10 [59584/60000] Loss: 0.0207 + Test set: Average loss: 0.0525 | Accuracy: 0.983 + + real 0m6.952s + user 0m7.048s + sys 0m0.619s As we can see, just applying a CUDA Graph for the training step we were able to gain the performance by more than 6 times. From f1763b20b6213b7cb98a8f70ac96dbddda901c7c Mon Sep 17 00:00:00 2001 From: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> Date: Mon, 5 Jun 2023 18:50:33 -0700 Subject: [PATCH 04/14] remove unnecessary line --- advanced_source/libtorch_cuda_graphs.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/advanced_source/libtorch_cuda_graphs.rst b/advanced_source/libtorch_cuda_graphs.rst index d786be2414..c7f1ff6622 100644 --- a/advanced_source/libtorch_cuda_graphs.rst +++ b/advanced_source/libtorch_cuda_graphs.rst @@ -130,8 +130,6 @@ While the CUDA Graph output is the following: .. code-block:: shell $ time ./mnist --use-train-graph - CUDA is available! Training on GPU. - Using CUDA Graph for training. Train Epoch: 1 [59584/60000] Loss: 0.4092 Test set: Average loss: 0.2037 | Accuracy: 0.938 Train Epoch: 2 [59584/60000] Loss: 0.2039 From 2546e0a44e64ba73d86b9ddc493665be19a6b79e Mon Sep 17 00:00:00 2001 From: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> Date: Mon, 5 Jun 2023 19:11:13 -0700 Subject: [PATCH 05/14] minor fixes --- advanced_source/libtorch_cuda_graphs.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/advanced_source/libtorch_cuda_graphs.rst b/advanced_source/libtorch_cuda_graphs.rst index c7f1ff6622..3db1c80c28 100644 --- a/advanced_source/libtorch_cuda_graphs.rst +++ b/advanced_source/libtorch_cuda_graphs.rst @@ -90,8 +90,8 @@ the training: training_step(model, optimizer, data, targets, output, loss); } -After successful graph capturing we can replace `training_step(model, optimizer, data, targets, output, loss);` -call via `graph.replay();`. +After successful graph capturing we can replace ``training_step(model, optimizer, data, targets, output, loss);`` +call via ``graph.replay();``. The full source code is available in GitHub. From e07989e7aac29b162291ebb0749c813e33ad4b32 Mon Sep 17 00:00:00 2001 From: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> Date: Mon, 5 Jun 2023 19:24:50 -0700 Subject: [PATCH 06/14] add stream_sync --- .../libtorch_cuda_graphs/mnist.cpp | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/advanced_source/libtorch_cuda_graphs/mnist.cpp b/advanced_source/libtorch_cuda_graphs/mnist.cpp index 0121de80fd..9ffb1b1ecf 100644 --- a/advanced_source/libtorch_cuda_graphs/mnist.cpp +++ b/advanced_source/libtorch_cuda_graphs/mnist.cpp @@ -56,6 +56,14 @@ struct Net : torch::nn::Module { torch::nn::Linear fc2; }; +void stream_sync( + at::cuda::CUDAStream& dependency, + at::cuda::CUDAStream& dependent) { + at::cuda::CUDAEvent cuda_ev; + cuda_ev.record(dependency); + cuda_ev.block(dependent); +} + void training_step( Net& model, torch::optim::Optimizer& optimizer, @@ -87,24 +95,20 @@ void capture_train_graph( at::cuda::setCurrentCUDAStream(warmupStream); - at::cuda::CUDAEvent cuda_ev; - cuda_ev.record(legacyStream); - cuda_ev.block(warmupStream); + stream_sync(legacyStream, warmupStream); for (int iter = 0; iter < num_warmup_iters; iter++) { training_step(model, optimizer, data, targets, output, loss); } - cuda_ev.record(warmupStream); - cuda_ev.block(captureStream); + stream_sync(warmupStream, captureStream); at::cuda::setCurrentCUDAStream(captureStream); graph.capture_begin(); training_step(model, optimizer, data, targets, output, loss); graph.capture_end(); - cuda_ev.record(captureStream); - cuda_ev.block(legacyStream); + stream_sync(captureStream, legacyStream); at::cuda::setCurrentCUDAStream(legacyStream); } @@ -182,10 +186,7 @@ void capture_test_graph( at::cuda::CUDAStream legacyStream = at::cuda::getCurrentCUDAStream(); at::cuda::setCurrentCUDAStream(warmupStream); - - at::cuda::CUDAEvent cuda_ev; - cuda_ev.record(legacyStream); - cuda_ev.block(warmupStream); + stream_sync(captureStream, legacyStream); for (int iter = 0; iter < num_warmup_iters; iter++) { test_step(model, data, targets, output, loss); @@ -193,16 +194,14 @@ void capture_test_graph( total_correct = total_correct + output.argmax(1).eq(targets).sum(); } - cuda_ev.record(warmupStream); - cuda_ev.block(captureStream); + stream_sync(warmupStream, captureStream); at::cuda::setCurrentCUDAStream(captureStream); graph.capture_begin(); test_step(model, data, targets, output, loss); graph.capture_end(); - cuda_ev.record(captureStream); - cuda_ev.block(legacyStream); + stream_sync(captureStream, legacyStream); at::cuda::setCurrentCUDAStream(legacyStream); } From 0c9a6bae598f08ffef13fb8f9e53800281470e45 Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Wed, 7 Jun 2023 10:45:19 -0700 Subject: [PATCH 07/14] rename files and folders for consistency --- advanced_source/{libtorch_cuda_graphs.rst => cpp_cuda_graphs.rst} | 0 .../{libtorch_cuda_graphs => cpp_cuda_graphs}/CMakeLists.txt | 0 .../{libtorch_cuda_graphs => cpp_cuda_graphs}/README.md | 0 .../{libtorch_cuda_graphs => cpp_cuda_graphs}/mnist.cpp | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename advanced_source/{libtorch_cuda_graphs.rst => cpp_cuda_graphs.rst} (100%) rename advanced_source/{libtorch_cuda_graphs => cpp_cuda_graphs}/CMakeLists.txt (100%) rename advanced_source/{libtorch_cuda_graphs => cpp_cuda_graphs}/README.md (100%) rename advanced_source/{libtorch_cuda_graphs => cpp_cuda_graphs}/mnist.cpp (100%) diff --git a/advanced_source/libtorch_cuda_graphs.rst b/advanced_source/cpp_cuda_graphs.rst similarity index 100% rename from advanced_source/libtorch_cuda_graphs.rst rename to advanced_source/cpp_cuda_graphs.rst diff --git a/advanced_source/libtorch_cuda_graphs/CMakeLists.txt b/advanced_source/cpp_cuda_graphs/CMakeLists.txt similarity index 100% rename from advanced_source/libtorch_cuda_graphs/CMakeLists.txt rename to advanced_source/cpp_cuda_graphs/CMakeLists.txt diff --git a/advanced_source/libtorch_cuda_graphs/README.md b/advanced_source/cpp_cuda_graphs/README.md similarity index 100% rename from advanced_source/libtorch_cuda_graphs/README.md rename to advanced_source/cpp_cuda_graphs/README.md diff --git a/advanced_source/libtorch_cuda_graphs/mnist.cpp b/advanced_source/cpp_cuda_graphs/mnist.cpp similarity index 100% rename from advanced_source/libtorch_cuda_graphs/mnist.cpp rename to advanced_source/cpp_cuda_graphs/mnist.cpp From af5807b18b612bb865fa9678f68eb43137097b88 Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Wed, 7 Jun 2023 11:47:54 -0700 Subject: [PATCH 08/14] add more text --- advanced_source/cpp_cuda_graphs.rst | 53 ++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/advanced_source/cpp_cuda_graphs.rst b/advanced_source/cpp_cuda_graphs.rst index 3db1c80c28..0647d8fdd8 100644 --- a/advanced_source/cpp_cuda_graphs.rst +++ b/advanced_source/cpp_cuda_graphs.rst @@ -1,17 +1,29 @@ Using CUDA Graphs in PyTorch C++ API ==================================== -NVIDIA’s CUDA Graph is a powerful tool. CUDA Graphs are capable of greatly reducing -the CPU overhead increasing the performance. - -In this example we would like to demonstrate CUDA Graphs usage on PyTorch’s `MNIST +NVIDIA’s CUDA Graphs have been a part of CUDA Toolkit library since the +release of `version 10 `_. +They are capable of greatly reducing the CPU overhead increasing the +performance of applications. + +In this tutorial we will be focusing on using CUDA Graphs for `C++ +frontend of PyTorch `_. +The C++ frontend is mostly utilized in production and deployment applications which +are important parts of PyTorch use cases. Since first `the appearance +`_ +the CUDA Graphs won users’ and developer’s hearts for being a very performant +and at the same time simple-to-use tool. In fact, CUDA Graphs are used by default +in torch.compile of PyTorch 2.0 to boost the productivity of training and inference. + +We would like to demonstrate CUDA Graphs usage on PyTorch’s `MNIST example `_. -The usage of CUDA Graphs in LibTorch is very similar to its `Python counterpart -`_ with only major -difference in language syntax. +The usage of CUDA Graphs in LibTorch (C++ Frontend) is very similar to its +`Python counterpart `_ +but with some differences in syntax and functionality. +Without further ado, let us get started! -The main training training loop consists of the several steps and depicted in the +The main training loop consists of the several steps and depicted in the following code chunk: .. code-block:: cpp @@ -26,9 +38,11 @@ following code chunk: optimizer.step(); } +Which are basically a forward pass, backward pass and weight updates. + In this tutorial we will be applying CUDA Graph on all the compute steps via whole-network -graph capture. Before doing so, we need to slightly modify the source code by preallocating -tensors and reusing them in all the work that is being done on GPU: +graph capture. But before doing so, we need to slightly modify the source code. What we need +to do is preallocate tensors for reusing them in the main training loop like so: .. code-block:: cpp @@ -48,7 +62,7 @@ tensors and reusing them in all the work that is being done on GPU: training_step(model, optimizer, data, targets, output, loss); } -Where training step simply consists of forward and backward passes with corresponding optimizer calls: +Where ``training_step`` simply consists of forward and backward passes with corresponding optimizer calls: .. code-block:: cpp @@ -70,7 +84,7 @@ PyTorch’s CUDA Graphs API is relying on Stream Capture which in general looks .. code-block:: cpp - at::cuda::CUDAGraph test_graph; + at::cuda::CUDAGraph graph; at::cuda::CUDAStream captureStream = at::cuda::getStreamFromPool(); at::cuda::setCurrentCUDAStream(captureStream); @@ -90,12 +104,12 @@ the training: training_step(model, optimizer, data, targets, output, loss); } -After successful graph capturing we can replace ``training_step(model, optimizer, data, targets, output, loss);`` -call via ``graph.replay();``. +After successful graph capture we can replace ``training_step(model, optimizer, data, targets, output, loss);`` +call via ``graph.replay();`` do the training step. The full source code is available in GitHub. -The ordinary eager-mode produces the following output: +Taking the code for a spin we can see the following output from ordinary non-graphed training: .. code-block:: shell @@ -125,7 +139,7 @@ The ordinary eager-mode produces the following output: user 0m44.018s sys 0m1.116s -While the CUDA Graph output is the following: +While the training with the CUDA Graph produces the following output: .. code-block:: shell @@ -155,4 +169,9 @@ While the CUDA Graph output is the following: user 0m7.048s sys 0m0.619s -As we can see, just applying a CUDA Graph for the training step we were able to gain the performance by more than 6 times. +As we can see, just by applying a CUDA Graph on the `MNIST example +`_ we were able to gain the performance +by more than 6 times for training. This kind of large perf improvement was achievable due to +small model size. In case of larger models with heavy GPU usage the CPU overhead is less impactful +so the improvement will be smaller. Nevertheless, it is always advantageous to use CUDA Graphs to +gain the performance of GPUs. From ebe92d09cffaa4fb4119f813b06478eb75e7f35b Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Wed, 7 Jun 2023 11:58:44 -0700 Subject: [PATCH 09/14] fix typos and better phrasing --- advanced_source/cpp_cuda_graphs.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/advanced_source/cpp_cuda_graphs.rst b/advanced_source/cpp_cuda_graphs.rst index 0647d8fdd8..c47c7fd4a4 100644 --- a/advanced_source/cpp_cuda_graphs.rst +++ b/advanced_source/cpp_cuda_graphs.rst @@ -9,11 +9,11 @@ performance of applications. In this tutorial we will be focusing on using CUDA Graphs for `C++ frontend of PyTorch `_. The C++ frontend is mostly utilized in production and deployment applications which -are important parts of PyTorch use cases. Since first `the appearance +are important parts of PyTorch use cases. Since `the first appearance `_ the CUDA Graphs won users’ and developer’s hearts for being a very performant and at the same time simple-to-use tool. In fact, CUDA Graphs are used by default -in torch.compile of PyTorch 2.0 to boost the productivity of training and inference. +in ``torch.compile`` of PyTorch 2.0 to boost the productivity of training and inference. We would like to demonstrate CUDA Graphs usage on PyTorch’s `MNIST example `_. @@ -80,7 +80,7 @@ Where ``training_step`` simply consists of forward and backward passes with corr optimizer.step(); } -PyTorch’s CUDA Graphs API is relying on Stream Capture which in general looks like the following: +PyTorch’s CUDA Graphs API is relying on Stream Capture which in our case would used as following: .. code-block:: cpp From 073021c0fe5d87edd7bd63eadbc2ee820d4df979 Mon Sep 17 00:00:00 2001 From: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> Date: Fri, 9 Jun 2023 10:12:41 -0700 Subject: [PATCH 10/14] apply text comments --- advanced_source/cpp_cuda_graphs.rst | 43 +++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/advanced_source/cpp_cuda_graphs.rst b/advanced_source/cpp_cuda_graphs.rst index c47c7fd4a4..8b23698380 100644 --- a/advanced_source/cpp_cuda_graphs.rst +++ b/advanced_source/cpp_cuda_graphs.rst @@ -1,12 +1,25 @@ Using CUDA Graphs in PyTorch C++ API ==================================== +.. note:: + |edit| View and edit this tutorial in `GitHub `__. + +.. note:: + The full source code is available on `GitHub `__. + +Prerequisites: + +- `Using the PyTorch C++ Frontend <../advanced_source/cpp_frontend.html>`__ +- `CUDA semantics `__ +- `Pytorch 2.0 or later` +- `CUDA 11 or later` + NVIDIA’s CUDA Graphs have been a part of CUDA Toolkit library since the release of `version 10 `_. They are capable of greatly reducing the CPU overhead increasing the performance of applications. -In this tutorial we will be focusing on using CUDA Graphs for `C++ +In this tutorial, we will be focusing on using CUDA Graphs for `C++ frontend of PyTorch `_. The C++ frontend is mostly utilized in production and deployment applications which are important parts of PyTorch use cases. Since `the first appearance @@ -21,7 +34,8 @@ The usage of CUDA Graphs in LibTorch (C++ Frontend) is very similar to its `Python counterpart `_ but with some differences in syntax and functionality. -Without further ado, let us get started! +Getting Started +--------------- The main training loop consists of the several steps and depicted in the following code chunk: @@ -38,11 +52,12 @@ following code chunk: optimizer.step(); } -Which are basically a forward pass, backward pass and weight updates. +The example above includes a forward pass, a backward pass, and weight updates. -In this tutorial we will be applying CUDA Graph on all the compute steps via whole-network +In this tutorial, we will be applying CUDA Graph on all the compute steps through the whole-network graph capture. But before doing so, we need to slightly modify the source code. What we need -to do is preallocate tensors for reusing them in the main training loop like so: +to do is preallocate tensors for reusing them in the main training loop. Here is an example +implementation: .. code-block:: cpp @@ -80,7 +95,7 @@ Where ``training_step`` simply consists of forward and backward passes with corr optimizer.step(); } -PyTorch’s CUDA Graphs API is relying on Stream Capture which in our case would used as following: +PyTorch’s CUDA Graphs API is relying on Stream Capture which in our case would be used like this: .. code-block:: cpp @@ -92,7 +107,7 @@ PyTorch’s CUDA Graphs API is relying on Stream Capture which in our case would training_step(model, optimizer, data, targets, output, loss); graph.capture_end(); -Before actual graph capture it is important to run several warm-up iterations on side stream to +Before the actual graph capture, it is important to run several warm-up iterations on side stream to prepare CUDA cache as well as CUDA libraries (like CUBLAS and CUDNN) that will be used during the training: @@ -104,10 +119,11 @@ the training: training_step(model, optimizer, data, targets, output, loss); } -After successful graph capture we can replace ``training_step(model, optimizer, data, targets, output, loss);`` -call via ``graph.replay();`` do the training step. +After the successful graph capture, we can replace ``training_step(model, optimizer, data, targets, output, loss);`` +call via ``graph.replay();`` to do the training step. -The full source code is available in GitHub. +Training Results +---------------- Taking the code for a spin we can see the following output from ordinary non-graphed training: @@ -169,9 +185,12 @@ While the training with the CUDA Graph produces the following output: user 0m7.048s sys 0m0.619s +Conclusion +---------- + As we can see, just by applying a CUDA Graph on the `MNIST example `_ we were able to gain the performance -by more than 6 times for training. This kind of large perf improvement was achievable due to -small model size. In case of larger models with heavy GPU usage the CPU overhead is less impactful +by more than six times for training. This kind of large performance improvement was achievable due to +the small model size. In case of larger models with heavy GPU usage, the CPU overhead is less impactful so the improvement will be smaller. Nevertheless, it is always advantageous to use CUDA Graphs to gain the performance of GPUs. From 42d27d14e95011b821c703f0d0b6db3589447909 Mon Sep 17 00:00:00 2001 From: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> Date: Fri, 9 Jun 2023 10:12:57 -0700 Subject: [PATCH 11/14] apply source comments --- .../cpp_cuda_graphs/CMakeLists.txt | 3 +- advanced_source/cpp_cuda_graphs/mnist.cpp | 74 +++++++++---------- 2 files changed, 38 insertions(+), 39 deletions(-) diff --git a/advanced_source/cpp_cuda_graphs/CMakeLists.txt b/advanced_source/cpp_cuda_graphs/CMakeLists.txt index e525b71bc4..35ff1ea3e4 100644 --- a/advanced_source/cpp_cuda_graphs/CMakeLists.txt +++ b/advanced_source/cpp_cuda_graphs/CMakeLists.txt @@ -3,6 +3,7 @@ project(mnist) set(CMAKE_CXX_STANDARD 17) find_package(Torch REQUIRED) +find_package(Threads REQUIRED) option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON) if (DOWNLOAD_MNIST) @@ -18,7 +19,7 @@ endif() add_executable(mnist mnist.cpp) target_compile_features(mnist PUBLIC cxx_range_for) -target_link_libraries(mnist ${TORCH_LIBRARIES}) +target_link_libraries(mnist ${TORCH_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) if (MSVC) file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll") diff --git a/advanced_source/cpp_cuda_graphs/mnist.cpp b/advanced_source/cpp_cuda_graphs/mnist.cpp index 9ffb1b1ecf..ce686cc8ea 100644 --- a/advanced_source/cpp_cuda_graphs/mnist.cpp +++ b/advanced_source/cpp_cuda_graphs/mnist.cpp @@ -89,15 +89,15 @@ void capture_train_graph( const short num_warmup_iters = 7) { model.train(); - at::cuda::CUDAStream warmupStream = at::cuda::getStreamFromPool(); - at::cuda::CUDAStream captureStream = at::cuda::getStreamFromPool(); - at::cuda::CUDAStream legacyStream = at::cuda::getCurrentCUDAStream(); + auto warmupStream = at::cuda::getStreamFromPool(); + auto captureStream = at::cuda::getStreamFromPool(); + auto legacyStream = at::cuda::getCurrentCUDAStream(); at::cuda::setCurrentCUDAStream(warmupStream); stream_sync(legacyStream, warmupStream); - for (int iter = 0; iter < num_warmup_iters; iter++) { + for (C10_UNUSED const auto iter : c10::irange(num_warmup_iters)) { training_step(model, optimizer, data, targets, output, loss); } @@ -130,7 +130,7 @@ void train( size_t batch_idx = 0; - for (auto& batch : data_loader) { + for (const auto& batch : data_loader) { if (batch.data.size(0) != kTrainBatchSize || batch.target.size(0) != kTrainBatchSize) { continue; @@ -175,23 +175,22 @@ void capture_test_graph( torch::Tensor& loss, torch::Tensor& total_loss, torch::Tensor& total_correct, - at::cuda::CUDAGraph& graph) { + at::cuda::CUDAGraph& graph, + const int num_warmup_iters = 7) { torch::NoGradGuard no_grad; model.eval(); - const int num_warmup_iters = 7; - - at::cuda::CUDAStream warmupStream = at::cuda::getStreamFromPool(); - at::cuda::CUDAStream captureStream = at::cuda::getStreamFromPool(); - at::cuda::CUDAStream legacyStream = at::cuda::getCurrentCUDAStream(); + auto warmupStream = at::cuda::getStreamFromPool(); + auto captureStream = at::cuda::getStreamFromPool(); + auto legacyStream = at::cuda::getCurrentCUDAStream(); at::cuda::setCurrentCUDAStream(warmupStream); stream_sync(captureStream, legacyStream); - for (int iter = 0; iter < num_warmup_iters; iter++) { + for (C10_UNUSED const auto iter : c10::irange(num_warmup_iters)) { test_step(model, data, targets, output, loss); - total_loss = total_loss + loss; - total_correct = total_correct + output.argmax(1).eq(targets).sum(); + total_loss += loss; + total_correct += output.argmax(1).eq(targets).sum(); } stream_sync(warmupStream, captureStream); @@ -239,8 +238,8 @@ void test( } else { test_step(model, data, targets, output, loss); } - total_loss = total_loss + loss; - total_correct = total_correct + output.argmax(1).eq(targets).sum(); + total_loss += loss; + total_correct += output.argmax(1).eq(targets).sum(); } float test_loss = total_loss.item() / dataset_size; @@ -254,33 +253,29 @@ void test( } int main(int argc, char* argv[]) { + if (!torch::cuda::is_available()) { + std::cout << "CUDA is not available!" << std::endl; + return -1; + } + bool use_train_graph = false; bool use_test_graph = false; - torch::manual_seed(1); - - torch::DeviceType device_type; - if (torch::cuda::is_available()) { - std::cout << "CUDA is available! Training on GPU." << std::endl; - - std::vector arguments(argv + 1, argv + argc); - for (std::string& arg : arguments) { - if (arg == "--use-train-graph") { - std::cout << "Using CUDA Graph for training." << std::endl; - use_train_graph = true; - } - if (arg == "--use-test-graph") { - std::cout << "Using CUDA Graph for testing." << std::endl; - use_test_graph = true; - } + std::vector arguments(argv + 1, argv + argc); + for (std::string& arg : arguments) { + if (arg == "--use-train-graph") { + std::cout << "Using CUDA Graph for training." << std::endl; + use_train_graph = true; + } + if (arg == "--use-test-graph") { + std::cout << "Using CUDA Graph for testing." << std::endl; + use_test_graph = true; } - - device_type = torch::kCUDA; - } else { - std::cout << "CUDA is not available!" << std::endl; - return 1; } - torch::Device device(device_type); + + torch::manual_seed(1); + torch::cuda::manual_seed(1); + torch::Device device(torch::kCUDA); Net model; model.to(device); @@ -375,4 +370,7 @@ int main(int argc, char* argv[]) { test_graph, use_test_graph); } + + std::cout << " Training/testing complete" << std::endl; + return 0; } From a58626506e74e9131d077ce318d8f39317e70052 Mon Sep 17 00:00:00 2001 From: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> Date: Fri, 9 Jun 2023 10:37:23 -0700 Subject: [PATCH 12/14] use cout and apply clang-format --- advanced_source/cpp_cuda_graphs/mnist.cpp | 24 ++++++++++------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/advanced_source/cpp_cuda_graphs/mnist.cpp b/advanced_source/cpp_cuda_graphs/mnist.cpp index ce686cc8ea..97c5fb80ca 100644 --- a/advanced_source/cpp_cuda_graphs/mnist.cpp +++ b/advanced_source/cpp_cuda_graphs/mnist.cpp @@ -97,7 +97,7 @@ void capture_train_graph( stream_sync(legacyStream, warmupStream); - for (C10_UNUSED const auto iter : c10::irange(num_warmup_iters)) { + for (C10_UNUSED const auto iter : c10::irange(num_warmup_iters)) { training_step(model, optimizer, data, targets, output, loss); } @@ -147,12 +147,9 @@ void train( if (batch_idx++ % kLogInterval == 0) { float train_loss = loss.item(); - std::printf( - "\rTrain Epoch: %ld [%5ld/%5ld] Loss: %.4f", - epoch, - batch_idx * batch.data.size(0), - dataset_size, - train_loss); + std::cout << "\rTrain Epoch:" << epoch << " [" + << batch_idx * batch.data.size(0) << "/" << dataset_size + << "] Loss: " << train_loss; } } } @@ -187,7 +184,7 @@ void capture_test_graph( at::cuda::setCurrentCUDAStream(warmupStream); stream_sync(captureStream, legacyStream); - for (C10_UNUSED const auto iter : c10::irange(num_warmup_iters)) { + for (C10_UNUSED const auto iter : c10::irange(num_warmup_iters)) { test_step(model, data, targets, output, loss); total_loss += loss; total_correct += output.argmax(1).eq(targets).sum(); @@ -246,16 +243,15 @@ void test( float test_accuracy = static_cast(total_correct.item()) / dataset_size; - std::printf( - "\nTest set: Average loss: %.4f | Accuracy: %.3f\n", - test_loss, - test_accuracy); + std::cout << std::endl + << "Test set: Average loss: " << test_loss + << " | Accuracy: " << test_accuracy << std::endl; } int main(int argc, char* argv[]) { if (!torch::cuda::is_available()) { - std::cout << "CUDA is not available!" << std::endl; - return -1; + std::cout << "CUDA is not available!" << std::endl; + return -1; } bool use_train_graph = false; From 46dd9cacefd87b38fa036a8a540c5847053c7532 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Tue, 13 Jun 2023 10:23:41 -0700 Subject: [PATCH 13/14] Apply suggestions from code review Editorial --- advanced_source/cpp_cuda_graphs.rst | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/advanced_source/cpp_cuda_graphs.rst b/advanced_source/cpp_cuda_graphs.rst index 8b23698380..494d6426d4 100644 --- a/advanced_source/cpp_cuda_graphs.rst +++ b/advanced_source/cpp_cuda_graphs.rst @@ -2,17 +2,14 @@ Using CUDA Graphs in PyTorch C++ API ==================================== .. note:: - |edit| View and edit this tutorial in `GitHub `__. - -.. note:: - The full source code is available on `GitHub `__. + |edit| View and edit this tutorial in `GitHub `__. The full source code is available on `GitHub `__. Prerequisites: - `Using the PyTorch C++ Frontend <../advanced_source/cpp_frontend.html>`__ - `CUDA semantics `__ -- `Pytorch 2.0 or later` -- `CUDA 11 or later` +- Pytorch 2.0 or later +- CUDA 11 or later NVIDIA’s CUDA Graphs have been a part of CUDA Toolkit library since the release of `version 10 `_. From 6df8287e3b5ecf918345e832af21b287cafe0b22 Mon Sep 17 00:00:00 2001 From: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> Date: Tue, 13 Jun 2023 12:34:09 -0700 Subject: [PATCH 14/14] Require CMake >= 3.18 --- advanced_source/cpp_cuda_graphs/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/advanced_source/cpp_cuda_graphs/CMakeLists.txt b/advanced_source/cpp_cuda_graphs/CMakeLists.txt index 35ff1ea3e4..76fc5bc676 100644 --- a/advanced_source/cpp_cuda_graphs/CMakeLists.txt +++ b/advanced_source/cpp_cuda_graphs/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.1 FATAL_ERROR) +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) project(mnist) set(CMAKE_CXX_STANDARD 17)