From 13bb632b87acad44379fe0c5e5c5d070d0aedaf3 Mon Sep 17 00:00:00 2001 From: Nikolay Bogoslovskiy Date: Tue, 28 Nov 2023 07:11:12 +0700 Subject: [PATCH] Gpu memory representation functions (#557) --- Makefile | 19 +- docker/Dockerfile.deepstream | 11 ++ .../11_memory_representation_function.rst | 144 ++++++++++++++ docs/source/conf.py | 2 +- docs/source/index.rst | 1 + docs/source/reference/api/utils.rst | 15 ++ requirements/base.txt | 4 + requirements/dev.txt | 2 +- savant/utils/memory_repr.py | 81 ++++++++ savant/utils/memory_repr_pytorch.py | 25 +++ tests/test_memory_repr.py | 186 ++++++++++++++++++ 11 files changed, 487 insertions(+), 3 deletions(-) create mode 100644 docs/source/advanced_topics/11_memory_representation_function.rst create mode 100644 savant/utils/memory_repr.py create mode 100644 savant/utils/memory_repr_pytorch.py create mode 100644 tests/test_memory_repr.py diff --git a/Makefile b/Makefile index 79f5c1b6..7e07b448 100644 --- a/Makefile +++ b/Makefile @@ -60,6 +60,15 @@ build-docs: -f docker/$(DOCKER_FILE) \ -t savant-docs:$(SAVANT_VERSION) . +build-tests: + docker buildx build \ + --target tests \ + --build-arg DEEPSTREAM_VERSION=$(DEEPSTREAM_VERSION) \ + --build-arg USER_UID=`id -u` \ + --build-arg USER_GID=`id -g` \ + -f docker/$(DOCKER_FILE) \ + -t savant-tests:$(SAVANT_VERSION) . + build-opencv: opencv-build-amd64 opencv-build-arm64 opencv-cp-amd64 opencv-cp-arm64 opencv-build-amd64: @@ -96,6 +105,14 @@ run-docs: --name savant-docs \ savant-docs:$(SAVANT_VERSION) +run-tests: + docker run -it --rm \ + -v `pwd`/savant:$(PROJECT_PATH)/savant \ + -v `pwd`/tests:$(PROJECT_PATH)/tests \ + --gpus=all \ + --name savant-tests \ + savant-tests:$(SAVANT_VERSION) + run-dev: xhost +local:docker docker run -it --rm --gpus=all \ @@ -131,7 +148,7 @@ check-unify: check: check-black check-unify check-isort run-unify: - unify --in-place --recursive savant adapters gst_plugins samples scripts + unify --in-place --recursive savant adapters gst_plugins samples scripts tests run-black: black . diff --git a/docker/Dockerfile.deepstream b/docker/Dockerfile.deepstream index 673fce6f..7d0ab0b6 100644 --- a/docker/Dockerfile.deepstream +++ b/docker/Dockerfile.deepstream @@ -246,3 +246,14 @@ WORKDIR $PROJECT_PATH/docs ENTRYPOINT ["make"] CMD ["clean", "html"] + + +# Savant test image, x86 only +FROM base AS tests +COPY requirements/dev.txt requirements/dev.txt +COPY tests /opt/savant/tests + +RUN python -m pip install --no-cache-dir -r requirements/dev.txt +RUN python -m pip install torch torchvision torchaudio + +ENTRYPOINT ["pytest", "-s", "/opt/savant/tests"] \ No newline at end of file diff --git a/docs/source/advanced_topics/11_memory_representation_function.rst b/docs/source/advanced_topics/11_memory_representation_function.rst new file mode 100644 index 00000000..4cda8dbf --- /dev/null +++ b/docs/source/advanced_topics/11_memory_representation_function.rst @@ -0,0 +1,144 @@ +Conversions Between GPU Memory Formats +--------------------------------------------- + +When working with images, there are many ways to represent them as arrays of pixels. Working with different models you may encounter representation of an image using OpenCV GpuMat class, PyTorch tensor or CuPy array. + +The Savant framework aims to use GPU efficiently without excessive data transfers. To achieve this, Savant provides functions for converting between different image representations. Data exchange is performed with zero-copying between different views, except for some cases of conversion to GpuMat OpenCV. + +Conversion to OpenCV +^^^^^^^^^^^^^^^^^^^^ + +**From PyTorch tensor** + +.. py:currentmodule:: savant.utils.memory_repr_pytorch + +:py:func:`pytorch_tensor_as_opencv_gpu_mat ` allows you to convert a PyTorch tensor into an OpenCV GpuMat. The input tensor must be on GPU, must have shape in HWC format and be in C-contiguous layout. + +.. code-block:: python + + import torch + from savant.utils.memory_repr_pytorch import pytorch_tensor_as_opencv_gpu_mat + + # original in HWC + pytorch_tensor = torch.randint(0, 255, size=(10, 20, 3), device='cuda').to(torch.uint8) + # map to opencv gpu mat (zero-copy) + opencv_gpu_mat = pytorch_tensor_as_opencv_gpu_mat(torch_tensor) + +If the shape format of the tensor is different, you can transform it into the required format using e.g. `Tensor.permute() `__. You should keep in mind that such transformations usually lead to data copying and additionally require the tensor to be converted to contiguous in memory layout. You can do this with `Tensor.contiguous() `__. + +.. code-block:: python + + import torch + from savant.utils.memory_repr_pytorch import pytorch_tensor_as_opencv_gpu_mat + + # original in CHW + tensor0 = torch.randint(0, 255, size=(3, 10, 20), device='cuda').to(torch.uint8) + # transform to HWC + tensor1 = tensor0.permute(1, 2, 0) + # to contiguous (copy) + tensor2 = tensor1.contiguous() + # map to opencv gpu mat (zero-copy) + gpu_mat = pytorch_tensor_as_opencv_gpu_mat(tensor2) + +**From CuPy array** + +.. py:currentmodule:: savant.utils.memory_repr + +:py:func:`cupy_array_as_opencv_gpu_mat ` allows you to convert a CuPy array into an OpenCV GpuMat. The input array must have shape in HWC format, 2 or 3 dimensions and be in C-contiguous layout. + +.. code-block:: python + + import cupy as cp + from savant.utils.memory_repr import cupy_array_as_opencv_gpu_mat + + # original in HWC + cupy_array = cp.random.randint(0, 255, (10, 20, 3)).astype(cp.uint8) + # map to opencv gpu mat (zero-copy) + opencv_gpu_mat = cupy_array_as_opencv_gpu_mat(cupy_array) + +If the shape format of the array is different, you can transform it into the required format using e.g. `cupy.transpose() `__. You should keep in mind that such transformations usually lead to data copying and additionally require the array to be converted to contiguous in memory layout. You can do this with `cupy.ascontiguousarray() `__. + +.. code-block:: python + + import cupy as cp + from savant.utils.memory_repr import cupy_array_as_opencv_gpu_mat + + # original in CHW + arr0 = cp.random.randint(0, 255, (3, 10, 20)).astype(cp.uint8) + # transform to HWC + arr1 = arr0.transpose((1, 2, 0)) + # to contiguous (copy) + arr2 = cp.ascontiguousarray(arr1) + # map to opencv gpu mat (zero-copy) + gpu_mat = cupy_array_as_opencv_gpu_mat(arr2) + + +Conversion to PyTorch Tensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +**From OpenCV GpuMat** + +.. py:currentmodule:: savant.utils.memory_repr_pytorch + +:py:func:`opencv_gpu_mat_as_pytorch_tensor ` allows you to convert an OpenCV GpuMat into a PyTorch tensor on GPU. + + +.. code-block:: python + + import cv2 + from savant.utils.memory_repr_pytorch import opencv_gpu_mat_as_pytorch_tensor + + opencv_gpu_mat = cv2.cuda_GpuMat() + opencv_gpu_mat.upload(np.random.randint(0, 255, (10, 20, 3)).astype(np.uint8)) + # zero-copy, HWC format + torch_tensor = opencv_gpu_mat_as_pytorch_tensor(opencv_gpu_mat) + + +**From CuPy Array** + +Conversion from CuPy array to PyTorch tensor is performed by using standard PyTorch function `torch.as_tensor `__. + +.. code-block:: python + + import cupy as cp + import torch + + cupy_array = cp.random.randint(0, 255, (10, 20, 3)).astype(cp.uint8) + # zero-copy, original array format + torch_tensor = torch.as_tensor(cupy_array) + + +Conversion to CuPy Array +^^^^^^^^^^^^^^^^^^^^^^^^ + +**From OpenCV GpuMat** + +.. py:currentmodule:: savant.utils.memory_repr + +:py:func:`opencv_gpu_mat_as_cupy_array ` allows you to convert an OpenCV GpuMat into a CuPy array. + +.. code-block:: python + + import cv2 + import cupy as cp + import numpy as np + from savant.utils.memory_repr import opencv_gpu_mat_as_cupy_array + + opencv_gpu_mat = cv2.cuda_GpuMat() + opencv_gpu_mat.upload(np.random.randint(0, 255, (10, 20, 3)).astype(np.uint8)) + # zero-copy, HWC format + cupy_array = opencv_gpu_mat_as_cupy_array(opencv_gpu_mat) + + +**From PyTorch tensor** + +Conversion from PyTorch tensor to CuPy is performed by using standard CuPy function `cupy.asarray `__ . + +.. code-block:: python + + import torch + import cupy as cp + + torch_tensor = torch.randint(0, 255, size=(3, 10, 20), device='cuda').to(torch.uint8) + # zero-copy, original tensor format + cupy_array = cp.asarray(torch_tensor) diff --git a/docs/source/conf.py b/docs/source/conf.py index 1f0b93b4..c1ffb9d5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -82,7 +82,7 @@ # List of modules that will be excluded from import to prevent import errors to stop # the build process when some external dependencies cannot be imported during the build. -autodoc_mock_imports = ['cv2', 'pyds', 'pysavantboost'] +autodoc_mock_imports = ['cv2', 'pyds', 'pysavantboost', 'cupy', 'torch'] # -- Options for HTML output ------------------------------------------------- diff --git a/docs/source/index.rst b/docs/source/index.rst index e628b98e..6573884e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -178,6 +178,7 @@ Savant supports processing parallelization; it helps to utilize the available re advanced_topics/9_open_telemetry advanced_topics/9_input_json_metadata advanced_topics/10_client_sdk + advanced_topics/11_memory_representation_function.rst .. toctree:: :maxdepth: 0 diff --git a/docs/source/reference/api/utils.rst b/docs/source/reference/api/utils.rst index 93fc24a0..a9d1a412 100644 --- a/docs/source/reference/api/utils.rst +++ b/docs/source/reference/api/utils.rst @@ -20,6 +20,21 @@ General utilities artist.Artist logging.LoggerMixin +GPU Memory Formats +------------------ + +.. currentmodule:: savant.utils + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: autosummary/function.rst + + memory_repr.opencv_gpu_mat_as_cupy_array + memory_repr_pytorch.opencv_gpu_mat_as_pytorch_tensor + memory_repr.cupy_array_as_opencv_gpu_mat + memory_repr_pytorch.pytorch_tensor_as_opencv_gpu_mat + DeepStream utilities -------------------- diff --git a/requirements/base.txt b/requirements/base.txt index 08120599..197b0159 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,4 +1,8 @@ numpy~=1.22.4 +# cupy +cupy-cuda12x; platform_machine=='x86_64' +cupy-cuda11x; platform_machine=='aarch64' + numba~=0.57 scipy~=1.10 diff --git a/requirements/dev.txt b/requirements/dev.txt index 5e6ce377..c279d7f3 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,4 +1,4 @@ black~=22.3.0 unify~=0.5 -pytest~=7.1.2 +pytest~=7.4.3 isort~=5.12.0 diff --git a/savant/utils/memory_repr.py b/savant/utils/memory_repr.py new file mode 100644 index 00000000..1fe5f899 --- /dev/null +++ b/savant/utils/memory_repr.py @@ -0,0 +1,81 @@ +import cupy as cp +import cv2 + +OPENCV_TO_NUMPY_TYPE_MAP = { + cv2.CV_8U: '|u1', + cv2.CV_8S: '|i1', + cv2.CV_16U: ' 1 else (height, width), + 'data': (gpu_mat.cudaPtr(), False), + 'typestr': type_str, + 'descr': [('', type_str)], + # 'stream': 1, # TODO: Investigate + 'strides': (gpu_mat.step, gpu_mat.elemSize(), gpu_mat.elemSize1()) + if channels > 1 + else (gpu_mat.step, gpu_mat.elemSize()), + } + + +def cuda_array_as_opencv_gpu_mat(arr) -> cv2.cuda.GpuMat: + """Returns OpenCV GpuMat for the given cuda array. + The array must support __cuda_array_interface__ (CuPy, PyTorch), + have 2 or 3 dims and be in C-contiguous layout. + """ + shape = arr.__cuda_array_interface__['shape'] + assert len(shape) in (2, 3), 'Array must have 2 or 3 dimensions.' + + dtype = arr.__cuda_array_interface__['typestr'] + depth = NUMPY_TO_OPENCV_TYPE_MAP.get(dtype) + assert ( + depth is not None + ), f'Array must be of one of the following types {list(NUMPY_TO_OPENCV_TYPE_MAP)}.' + + strides = arr.__cuda_array_interface__['strides'] + assert strides is None, 'Array must be in C-contiguous layout.' + + channels = 1 if len(shape) == 2 else shape[2] + # equivalent to unexposed opencv C++ macro CV_MAKETYPE(depth,channels): + mat_type = depth + ((channels - 1) << 3) + + return cv2.cuda.createGpuMatFromCudaMemory( + shape[1::-1], mat_type, arr.__cuda_array_interface__['data'][0] + ) + + +def opencv_gpu_mat_as_cupy_array(gpu_mat: cv2.cuda.GpuMat) -> cp.ndarray: + """Returns CuPy ndarray in HWC format for the given OpenCV GpuMat (zero-copy).""" + return cp.asarray(OpenCVGpuMatCudaArrayInterface(gpu_mat)) + + +def cupy_array_as_opencv_gpu_mat(arr: cp.ndarray) -> cv2.cuda.GpuMat: + """Returns OpenCV GpuMat for the given CuPy ndarray (zero-copy). + The array must have 2 or 3 dims in HWC format and C-contiguous layout. + + Use `cupy.shape` and `cupy.strides` to check if an array + has supported shape format and is contiguous in memory. + + Use `cupy.transpose()` and `cupy.ascontiguousarray()` to transform an array + if necessary (creates a copy of the array). + """ + return cuda_array_as_opencv_gpu_mat(arr) diff --git a/savant/utils/memory_repr_pytorch.py b/savant/utils/memory_repr_pytorch.py new file mode 100644 index 00000000..443cfa17 --- /dev/null +++ b/savant/utils/memory_repr_pytorch.py @@ -0,0 +1,25 @@ +import cv2 +import torch + +from savant.utils.memory_repr import ( + OpenCVGpuMatCudaArrayInterface, + cuda_array_as_opencv_gpu_mat, +) + + +def opencv_gpu_mat_as_pytorch_tensor(gpu_mat: cv2.cuda.GpuMat) -> torch.Tensor: + """Returns PyTorch tensor in HWC format for the given OpenCV GpuMat (zero-copy).""" + return torch.as_tensor(OpenCVGpuMatCudaArrayInterface(gpu_mat), device='cuda') + + +def pytorch_tensor_as_opencv_gpu_mat(tensor: torch.Tensor) -> cv2.cuda_GpuMat: + """Returns OpenCV GpuMat for the given PyTorch tensor (zero-copy). + The tensor must have 2 or 3 dims in HWC format and C-contiguous layout. + + Use `Tensor.size()` and `Tensor.is_contiguous()` to check if a tensor + has supported shape format and is contiguous in memory. + + Use `Tensor.transpose()` and `Tensor.contiguous()` to transform a tensor + if necessary (creates a copy of the tensor). + """ + return cuda_array_as_opencv_gpu_mat(tensor) diff --git a/tests/test_memory_repr.py b/tests/test_memory_repr.py new file mode 100644 index 00000000..2249fe78 --- /dev/null +++ b/tests/test_memory_repr.py @@ -0,0 +1,186 @@ +import numpy as np +import pytest +import torch +import cupy as cp +from savant.utils.memory_repr import ( + opencv_gpu_mat_as_cupy_array, + cupy_array_as_opencv_gpu_mat, +) +from savant.utils.memory_repr_pytorch import ( + opencv_gpu_mat_as_pytorch_tensor, + pytorch_tensor_as_opencv_gpu_mat, +) +import cv2 + +TORCH_TYPE = [torch.int8, torch.uint8, torch.float32] +NUMPY_TYPE = [np.int8, np.uint8, np.float32] +CUPY_TYPE = [cp.int8, cp.uint8, cp.float32] + + +class TestAsOpenCV: + @pytest.mark.parametrize('input_type', TORCH_TYPE) + @pytest.mark.parametrize('channels', [1, 3, 4]) + @pytest.mark.parametrize('memory_format', ['channels_first', 'channels_last']) + def test_pytorch_3d(self, input_type, channels, memory_format): + """Test for pytorch 3d tensors emulate color image""" + if memory_format == 'channels_first': + # shape - [channels, height, width] + pytorch_tensor = ( + torch.randint(0, 255, size=(channels, 10, 20), device='cuda') + .to(input_type) + .permute(1, 2, 0) + ) + elif memory_format == 'channels_last': + # shape - [height, width, channels] + pytorch_tensor = torch.randint( + 0, 255, size=(10, 20, channels), device='cuda' + ).to(input_type) + else: + raise ValueError(f'Unsupported memory format {memory_format}') + + if memory_format == 'channels_first' and channels != 1: + with pytest.raises( + AssertionError, + match='Array must be in C-contiguous layout.', + ): + opencv_gpu_mat = pytorch_tensor_as_opencv_gpu_mat( + pytorch_tensor.permute(1, 2, 0) + ) + else: + opencv_gpu_mat = pytorch_tensor_as_opencv_gpu_mat(pytorch_tensor) + np.testing.assert_almost_equal( + opencv_gpu_mat.download(), + pytorch_tensor.squeeze(2).cpu().numpy() + if channels == 1 + else pytorch_tensor.cpu().numpy(), + ) + assert opencv_gpu_mat.cudaPtr() == pytorch_tensor.data_ptr() + + @pytest.mark.parametrize('input_type', TORCH_TYPE) + def test_pytorch_2d(self, input_type): + """Test for pytorch tensors with grayscale image""" + + # shape - [height, width] + pytorch_tensor = torch.randint(0, 255, (10, 20), device='cuda').to(input_type) + + opencv_gpu_mat = pytorch_tensor_as_opencv_gpu_mat(pytorch_tensor) + np.testing.assert_almost_equal( + opencv_gpu_mat.download(), + pytorch_tensor.cpu().numpy(), + ) + + assert opencv_gpu_mat.cudaPtr() == pytorch_tensor.data_ptr() + + @pytest.mark.parametrize('input_type', CUPY_TYPE) + @pytest.mark.parametrize('channels', [1, 3, 4]) + @pytest.mark.parametrize('memory_format', ['channels_first', 'channels_last']) + def test_cupy_3d(self, input_type, channels, memory_format): + """Test for cupy tensors""" + + if memory_format == 'channels_last': + cupy_array = cp.random.randint(0, 255, (10, 20, channels)).astype( + input_type + ) + elif memory_format == 'channels_first': + cupy_array = ( + cp.random.randint(0, 255, (channels, 10, 20)) + .astype(input_type) + .transpose(1, 2, 0) + ) + else: + raise ValueError(f'Unsupported memory format {memory_format}') + if memory_format == 'channels_first' and channels != 1: + with pytest.raises( + AssertionError, + match='Array must be in C-contiguous layout.', + ): + opencv_mat = cupy_array_as_opencv_gpu_mat(np.transpose(cupy_array, (1, 2, 0))) + else: + opencv_mat = cupy_array_as_opencv_gpu_mat(cupy_array) + np.testing.assert_almost_equal( + opencv_mat.download(), + cupy_array.squeeze(2).get() if channels == 1 else cupy_array.get(), + ) + assert opencv_mat.cudaPtr() == cupy_array.data.ptr + + @pytest.mark.parametrize('input_type', CUPY_TYPE) + def test_cupy_2d(self, input_type): + """Test for pytorch tensors with grayscale image""" + + cupy_array = cp.random.randint(0, 255, (10, 20)).astype(input_type) + + opencv_gpu_mat = cupy_array_as_opencv_gpu_mat(cupy_array) + np.testing.assert_almost_equal( + opencv_gpu_mat.download(), + cupy_array.get(), + ) + + assert opencv_gpu_mat.cudaPtr() == cupy_array.data.ptr + + +class TestToTorch: + @pytest.mark.parametrize('input_type', NUMPY_TYPE) + @pytest.mark.parametrize('channels', [1, 3, 4]) + def test_opencv(self, input_type, channels): + opencv_gpu_mat = cv2.cuda_GpuMat() + opencv_gpu_mat.upload( + np.random.randint(0, 255, (10, 20, channels)).astype(input_type) + ) + + torch_tensor = opencv_gpu_mat_as_pytorch_tensor(opencv_gpu_mat) + + np.testing.assert_almost_equal( + opencv_gpu_mat.download() + if channels == 1 + else opencv_gpu_mat.download(), + torch_tensor.cpu().numpy(), + ) + + assert opencv_gpu_mat.cudaPtr() == torch_tensor.data_ptr() + + @pytest.mark.parametrize('input_type', NUMPY_TYPE) + def test_opencv_grayscale(self, input_type): + opencv_gpu_mat = cv2.cuda_GpuMat() + opencv_gpu_mat.upload(np.random.randint(0, 255, (10, 20)).astype(input_type)) + + torch_tensor = opencv_gpu_mat_as_pytorch_tensor(opencv_gpu_mat) + + np.testing.assert_almost_equal( + opencv_gpu_mat.download(), + torch_tensor.cpu().numpy(), + ) + + assert opencv_gpu_mat.cudaPtr() == torch_tensor.data_ptr() + + +class TestToCUPY: + @pytest.mark.parametrize('input_type', NUMPY_TYPE) + @pytest.mark.parametrize('channels', [1, 3, 4]) + def test_opencv(self, input_type, channels): + opencv_gpu_mat = cv2.cuda_GpuMat() + opencv_gpu_mat.upload( + np.random.randint(0, 255, (10, 20, channels)).astype(input_type) + ) + + cupy_array = opencv_gpu_mat_as_cupy_array(opencv_gpu_mat) + + np.testing.assert_almost_equal( + opencv_gpu_mat.download(), + cupy_array.get(), + ) + + assert opencv_gpu_mat.cudaPtr() == cupy_array.data.ptr + + @pytest.mark.parametrize('input_type', NUMPY_TYPE) + def test_opencv_grayscale(self, input_type): + opencv_gpu_mat = cv2.cuda_GpuMat() + opencv_gpu_mat.upload(np.random.randint(0, 255, (10, 20)).astype(input_type)) + + cupy_array = opencv_gpu_mat_as_cupy_array(opencv_gpu_mat) + + np.testing.assert_almost_equal( + opencv_gpu_mat.download(), + cupy_array.get(), + ) + + assert opencv_gpu_mat.cudaPtr() == cupy_array.data.ptr