GPU memory representation functions (#557)
dorgun authored Nov 28, 2023
1 parent d736450 commit 13bb632
Showing 11 changed files with 487 additions and 3 deletions.
19 changes: 18 additions & 1 deletion Makefile
@@ -60,6 +60,15 @@ build-docs:
-f docker/$(DOCKER_FILE) \
-t savant-docs:$(SAVANT_VERSION) .

build-tests:
docker buildx build \
--target tests \
--build-arg DEEPSTREAM_VERSION=$(DEEPSTREAM_VERSION) \
--build-arg USER_UID=`id -u` \
--build-arg USER_GID=`id -g` \
-f docker/$(DOCKER_FILE) \
-t savant-tests:$(SAVANT_VERSION) .

build-opencv: opencv-build-amd64 opencv-build-arm64 opencv-cp-amd64 opencv-cp-arm64

opencv-build-amd64:
@@ -96,6 +105,14 @@ run-docs:
--name savant-docs \
savant-docs:$(SAVANT_VERSION)

run-tests:
docker run -it --rm \
-v `pwd`/savant:$(PROJECT_PATH)/savant \
-v `pwd`/tests:$(PROJECT_PATH)/tests \
--gpus=all \
--name savant-tests \
savant-tests:$(SAVANT_VERSION)

run-dev:
xhost +local:docker
docker run -it --rm --gpus=all \
@@ -131,7 +148,7 @@ check-unify:
check: check-black check-unify check-isort

run-unify:
unify --in-place --recursive savant adapters gst_plugins samples scripts
unify --in-place --recursive savant adapters gst_plugins samples scripts tests

run-black:
black .
11 changes: 11 additions & 0 deletions docker/Dockerfile.deepstream
@@ -246,3 +246,14 @@ WORKDIR $PROJECT_PATH/docs

ENTRYPOINT ["make"]
CMD ["clean", "html"]


# Savant test image, x86 only
FROM base AS tests
COPY requirements/dev.txt requirements/dev.txt
COPY tests /opt/savant/tests

RUN python -m pip install --no-cache-dir -r requirements/dev.txt
RUN python -m pip install torch torchvision torchaudio

ENTRYPOINT ["pytest", "-s", "/opt/savant/tests"]
144 changes: 144 additions & 0 deletions docs/source/advanced_topics/11_memory_representation_function.rst
@@ -0,0 +1,144 @@
Conversions Between GPU Memory Formats
---------------------------------------------

When working with images, there are many ways to represent them as arrays of pixels. Depending on the model you work with, an image may be represented using the OpenCV GpuMat class, a PyTorch tensor, or a CuPy array.

The Savant framework aims to use the GPU efficiently and avoid excessive data transfers. To achieve this, Savant provides functions for converting between these image representations. Data is exchanged between the different views with zero copying, except for some cases of conversion to OpenCV GpuMat.
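
As a quick illustration of the zero-copy behaviour, here is a minimal sketch (assuming a CUDA-capable device and using the conversion helper documented below) in which two views of the same GPU buffer observe each other's writes:

.. code-block:: python

    import cupy as cp

    from savant.utils.memory_repr import cupy_array_as_opencv_gpu_mat

    # allocate a C-contiguous HWC array on the GPU
    cupy_array = cp.zeros((10, 20, 3), dtype=cp.uint8)
    # wrap the same GPU buffer as an OpenCV GpuMat (zero-copy)
    gpu_mat = cupy_array_as_opencv_gpu_mat(cupy_array)
    # a write through the CuPy view ...
    cupy_array[0, 0, 0] = 255
    # ... is visible when the GpuMat is downloaded to the host
    assert gpu_mat.download()[0, 0, 0] == 255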

Conversion to OpenCV
^^^^^^^^^^^^^^^^^^^^

**From PyTorch tensor**

.. py:currentmodule:: savant.utils.memory_repr_pytorch

:py:func:`pytorch_tensor_as_opencv_gpu_mat <pytorch_tensor_as_opencv_gpu_mat>` allows you to convert a PyTorch tensor into an OpenCV GpuMat. The input tensor must be on the GPU, have its shape in HWC format, and use a C-contiguous memory layout.

.. code-block:: python

    import torch

    from savant.utils.memory_repr_pytorch import pytorch_tensor_as_opencv_gpu_mat

    # original in HWC
    pytorch_tensor = torch.randint(0, 255, size=(10, 20, 3), device='cuda').to(torch.uint8)
    # map to opencv gpu mat (zero-copy)
    opencv_gpu_mat = pytorch_tensor_as_opencv_gpu_mat(pytorch_tensor)

If the shape format of the tensor is different, you can transform it into the required format using, e.g., `Tensor.permute() <https://pytorch.org/docs/stable/generated/torch.Tensor.permute.html>`__. Keep in mind that the result of such a transformation is usually not contiguous in memory, so the tensor additionally has to be converted to a C-contiguous layout, which copies the data. You can do this with `Tensor.contiguous() <https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html>`__.

.. code-block:: python

    import torch

    from savant.utils.memory_repr_pytorch import pytorch_tensor_as_opencv_gpu_mat

    # original in CHW
    tensor0 = torch.randint(0, 255, size=(3, 10, 20), device='cuda').to(torch.uint8)
    # transform to HWC
    tensor1 = tensor0.permute(1, 2, 0)
    # to contiguous (copy)
    tensor2 = tensor1.contiguous()
    # map to opencv gpu mat (zero-copy)
    gpu_mat = pytorch_tensor_as_opencv_gpu_mat(tensor2)

**From CuPy array**

.. py:currentmodule:: savant.utils.memory_repr

:py:func:`cupy_array_as_opencv_gpu_mat <cupy_array_as_opencv_gpu_mat>` allows you to convert a CuPy array into an OpenCV GpuMat. The input array must have 2 or 3 dimensions, its shape in HWC format, and a C-contiguous memory layout.

.. code-block:: python

    import cupy as cp

    from savant.utils.memory_repr import cupy_array_as_opencv_gpu_mat

    # original in HWC
    cupy_array = cp.random.randint(0, 255, (10, 20, 3)).astype(cp.uint8)
    # map to opencv gpu mat (zero-copy)
    opencv_gpu_mat = cupy_array_as_opencv_gpu_mat(cupy_array)

If the shape format of the array is different, you can transform it into the required format using, e.g., `cupy.transpose() <https://docs.cupy.dev/en/stable/reference/generated/cupy.transpose.html>`__. Keep in mind that the result of such a transformation is usually not contiguous in memory, so the array additionally has to be converted to a C-contiguous layout, which copies the data. You can do this with `cupy.ascontiguousarray() <https://docs.cupy.dev/en/stable/reference/generated/cupy.ascontiguousarray.html>`__.

.. code-block:: python

    import cupy as cp

    from savant.utils.memory_repr import cupy_array_as_opencv_gpu_mat

    # original in CHW
    arr0 = cp.random.randint(0, 255, (3, 10, 20)).astype(cp.uint8)
    # transform to HWC
    arr1 = arr0.transpose((1, 2, 0))
    # to contiguous (copy)
    arr2 = cp.ascontiguousarray(arr1)
    # map to opencv gpu mat (zero-copy)
    gpu_mat = cupy_array_as_opencv_gpu_mat(arr2)

Conversion to PyTorch Tensor
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

**From OpenCV GpuMat**

.. py:currentmodule:: savant.utils.memory_repr_pytorch

:py:func:`opencv_gpu_mat_as_pytorch_tensor <opencv_gpu_mat_as_pytorch_tensor>` allows you to convert an OpenCV GpuMat into a PyTorch tensor on the GPU.


.. code-block:: python

    import cv2
    import numpy as np

    from savant.utils.memory_repr_pytorch import opencv_gpu_mat_as_pytorch_tensor

    opencv_gpu_mat = cv2.cuda_GpuMat()
    opencv_gpu_mat.upload(np.random.randint(0, 255, (10, 20, 3)).astype(np.uint8))
    # zero-copy, HWC format
    torch_tensor = opencv_gpu_mat_as_pytorch_tensor(opencv_gpu_mat)

**From CuPy array**

Conversion from a CuPy array to a PyTorch tensor is performed using the standard PyTorch function `torch.as_tensor <https://pytorch.org/docs/stable/generated/torch.as_tensor.html>`__.

.. code-block:: python

    import cupy as cp
    import torch

    cupy_array = cp.random.randint(0, 255, (10, 20, 3)).astype(cp.uint8)
    # zero-copy, original array format
    torch_tensor = torch.as_tensor(cupy_array)

Conversion to CuPy Array
^^^^^^^^^^^^^^^^^^^^^^^^

**From OpenCV GpuMat**

.. py:currentmodule:: savant.utils.memory_repr

:py:func:`opencv_gpu_mat_as_cupy_array <opencv_gpu_mat_as_cupy_array>` allows you to convert an OpenCV GpuMat into a CuPy array.

.. code-block:: python

    import cv2
    import cupy as cp
    import numpy as np

    from savant.utils.memory_repr import opencv_gpu_mat_as_cupy_array

    opencv_gpu_mat = cv2.cuda_GpuMat()
    opencv_gpu_mat.upload(np.random.randint(0, 255, (10, 20, 3)).astype(np.uint8))
    # zero-copy, HWC format
    cupy_array = opencv_gpu_mat_as_cupy_array(opencv_gpu_mat)

**From PyTorch tensor**

Conversion from a PyTorch tensor to a CuPy array is performed using the standard CuPy function `cupy.asarray <https://docs.cupy.dev/en/stable/reference/generated/cupy.asarray.html>`__.

.. code-block:: python

    import torch
    import cupy as cp

    torch_tensor = torch.randint(0, 255, size=(3, 10, 20), device='cuda').to(torch.uint8)
    # zero-copy, original tensor format
    cupy_array = cp.asarray(torch_tensor)
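
These conversions compose. A minimal end-to-end sketch (assuming a CUDA-capable device) that chains the helpers above and checks that every representation points at the same GPU buffer:

.. code-block:: python

    import cupy as cp

    from savant.utils.memory_repr import cupy_array_as_opencv_gpu_mat
    from savant.utils.memory_repr_pytorch import opencv_gpu_mat_as_pytorch_tensor

    # C-contiguous HWC array on the GPU
    cupy_array = cp.random.randint(0, 255, (10, 20, 3)).astype(cp.uint8)
    gpu_mat = cupy_array_as_opencv_gpu_mat(cupy_array)        # zero-copy
    torch_tensor = opencv_gpu_mat_as_pytorch_tensor(gpu_mat)  # zero-copy
    cupy_view = cp.asarray(torch_tensor)                      # zero-copy
    # every view addresses the same device memory
    assert cupy_view.data.ptr == cupy_array.data.ptr == torch_tensor.data_ptr()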
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -82,7 +82,7 @@

# List of modules that will be excluded from import to prevent import errors to stop
# the build process when some external dependencies cannot be imported during the build.
autodoc_mock_imports = ['cv2', 'pyds', 'pysavantboost']
autodoc_mock_imports = ['cv2', 'pyds', 'pysavantboost', 'cupy', 'torch']

# -- Options for HTML output -------------------------------------------------

1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -178,6 +178,7 @@ Savant supports processing parallelization; it helps to utilize the available re
advanced_topics/9_open_telemetry
advanced_topics/9_input_json_metadata
advanced_topics/10_client_sdk
advanced_topics/11_memory_representation_function.rst

.. toctree::
:maxdepth: 0
15 changes: 15 additions & 0 deletions docs/source/reference/api/utils.rst
@@ -20,6 +20,21 @@ General utilities
artist.Artist
logging.LoggerMixin

GPU Memory Formats
------------------

.. currentmodule:: savant.utils

.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/function.rst

memory_repr.opencv_gpu_mat_as_cupy_array
memory_repr_pytorch.opencv_gpu_mat_as_pytorch_tensor
memory_repr.cupy_array_as_opencv_gpu_mat
memory_repr_pytorch.pytorch_tensor_as_opencv_gpu_mat

DeepStream utilities
--------------------

4 changes: 4 additions & 0 deletions requirements/base.txt
@@ -1,4 +1,8 @@
numpy~=1.22.4
# cupy
cupy-cuda12x; platform_machine=='x86_64'
cupy-cuda11x; platform_machine=='aarch64'

numba~=0.57
scipy~=1.10

2 changes: 1 addition & 1 deletion requirements/dev.txt
@@ -1,4 +1,4 @@
black~=22.3.0
unify~=0.5
pytest~=7.1.2
pytest~=7.4.3
isort~=5.12.0
81 changes: 81 additions & 0 deletions savant/utils/memory_repr.py
@@ -0,0 +1,81 @@
import cupy as cp
import cv2

OPENCV_TO_NUMPY_TYPE_MAP = {
cv2.CV_8U: '|u1',
cv2.CV_8S: '|i1',
cv2.CV_16U: '<u2',
cv2.CV_16S: '<i2',
cv2.CV_32S: '<i4',
cv2.CV_32F: '<f4',
cv2.CV_64F: '<f8',
}

NUMPY_TO_OPENCV_TYPE_MAP = {v: k for k, v in OPENCV_TO_NUMPY_TYPE_MAP.items()}


class OpenCVGpuMatCudaArrayInterface:
"""OpenCV GpuMat __cuda_array_interface__.
https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html
"""

def __init__(self, gpu_mat: cv2.cuda.GpuMat):
width, height = gpu_mat.size()
channels = gpu_mat.channels()
type_str = OPENCV_TO_NUMPY_TYPE_MAP.get(gpu_mat.depth())
assert type_str is not None, 'Unsupported OpenCV GpuMat type.'
self.__cuda_array_interface__ = {
'version': 3,
'shape': (height, width, channels) if channels > 1 else (height, width),
'data': (gpu_mat.cudaPtr(), False),
'typestr': type_str,
'descr': [('', type_str)],
# 'stream': 1, # TODO: Investigate
'strides': (gpu_mat.step, gpu_mat.elemSize(), gpu_mat.elemSize1())
if channels > 1
else (gpu_mat.step, gpu_mat.elemSize()),
}


def cuda_array_as_opencv_gpu_mat(arr) -> cv2.cuda.GpuMat:
"""Returns OpenCV GpuMat for the given cuda array.
The array must support __cuda_array_interface__ (CuPy, PyTorch),
have 2 or 3 dims and be in C-contiguous layout.
"""
shape = arr.__cuda_array_interface__['shape']
assert len(shape) in (2, 3), 'Array must have 2 or 3 dimensions.'

dtype = arr.__cuda_array_interface__['typestr']
depth = NUMPY_TO_OPENCV_TYPE_MAP.get(dtype)
assert (
depth is not None
), f'Array must be of one of the following types {list(NUMPY_TO_OPENCV_TYPE_MAP)}.'

strides = arr.__cuda_array_interface__['strides']
assert strides is None, 'Array must be in C-contiguous layout.'

channels = 1 if len(shape) == 2 else shape[2]
# equivalent to unexposed opencv C++ macro CV_MAKETYPE(depth,channels):
mat_type = depth + ((channels - 1) << 3)

return cv2.cuda.createGpuMatFromCudaMemory(
shape[1::-1], mat_type, arr.__cuda_array_interface__['data'][0]
)


def opencv_gpu_mat_as_cupy_array(gpu_mat: cv2.cuda.GpuMat) -> cp.ndarray:
"""Returns CuPy ndarray in HWC format for the given OpenCV GpuMat (zero-copy)."""
return cp.asarray(OpenCVGpuMatCudaArrayInterface(gpu_mat))


def cupy_array_as_opencv_gpu_mat(arr: cp.ndarray) -> cv2.cuda.GpuMat:
"""Returns OpenCV GpuMat for the given CuPy ndarray (zero-copy).
The array must have 2 or 3 dims in HWC format and C-contiguous layout.
Use `cupy.shape` and `cupy.strides` to check if an array
has supported shape format and is contiguous in memory.
Use `cupy.transpose()` and `cupy.ascontiguousarray()` to transform an array
if necessary (creates a copy of the array).
"""
return cuda_array_as_opencv_gpu_mat(arr)
25 changes: 25 additions & 0 deletions savant/utils/memory_repr_pytorch.py
@@ -0,0 +1,25 @@
import cv2
import torch

from savant.utils.memory_repr import (
OpenCVGpuMatCudaArrayInterface,
cuda_array_as_opencv_gpu_mat,
)


def opencv_gpu_mat_as_pytorch_tensor(gpu_mat: cv2.cuda.GpuMat) -> torch.Tensor:
"""Returns PyTorch tensor in HWC format for the given OpenCV GpuMat (zero-copy)."""
return torch.as_tensor(OpenCVGpuMatCudaArrayInterface(gpu_mat), device='cuda')


def pytorch_tensor_as_opencv_gpu_mat(tensor: torch.Tensor) -> cv2.cuda_GpuMat:
"""Returns OpenCV GpuMat for the given PyTorch tensor (zero-copy).
The tensor must have 2 or 3 dims in HWC format and C-contiguous layout.
Use `Tensor.size()` and `Tensor.is_contiguous()` to check if a tensor
has supported shape format and is contiguous in memory.
Use `Tensor.transpose()` and `Tensor.contiguous()` to transform a tensor
if necessary (creates a copy of the tensor).
"""
return cuda_array_as_opencv_gpu_mat(tensor)