Gpu memory representation functions (#557)
Showing 11 changed files with 487 additions and 3 deletions.
docs/source/advanced_topics/11_memory_representation_function.rst: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
Conversions Between GPU Memory Formats
----------------------------------------

When working with images, there are many ways to represent them as arrays of pixels. Depending on the models you work with, you may encounter an image represented as an OpenCV GpuMat, a PyTorch tensor or a CuPy array.

The Savant framework aims to use the GPU efficiently, without excessive data transfers. To achieve this, Savant provides functions for converting between the different image representations. Data exchange is performed with zero copying between the views, except for some cases of conversion to an OpenCV GpuMat.
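Because the views share the same GPU memory, a write made through one representation is visible through the others. The following minimal sketch (assuming an OpenCV build with CUDA support) uses :py:func:`savant.utils.memory_repr.opencv_gpu_mat_as_cupy_array`, which is described below:

.. code-block:: python

    import cv2
    import numpy as np

    from savant.utils.memory_repr import opencv_gpu_mat_as_cupy_array

    gpu_mat = cv2.cuda_GpuMat()
    gpu_mat.upload(np.zeros((10, 20, 3), dtype=np.uint8))

    # CuPy view over the same GPU buffer (zero-copy)
    cupy_view = opencv_gpu_mat_as_cupy_array(gpu_mat)
    cupy_view[:] = 255

    # the write through the CuPy view is visible in the GpuMat
    assert (gpu_mat.download() == 255).all()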
Conversion to OpenCV
^^^^^^^^^^^^^^^^^^^^

**From PyTorch tensor**

.. py:currentmodule:: savant.utils.memory_repr_pytorch

:py:func:`pytorch_tensor_as_opencv_gpu_mat <pytorch_tensor_as_opencv_gpu_mat>` allows you to convert a PyTorch tensor into an OpenCV GpuMat. The input tensor must be on the GPU, have its shape in HWC format and use a C-contiguous memory layout.

.. code-block:: python

    import torch

    from savant.utils.memory_repr_pytorch import pytorch_tensor_as_opencv_gpu_mat

    # original tensor in HWC format
    pytorch_tensor = torch.randint(0, 255, size=(10, 20, 3), device='cuda').to(torch.uint8)
    # map to OpenCV GpuMat (zero-copy)
    opencv_gpu_mat = pytorch_tensor_as_opencv_gpu_mat(pytorch_tensor)
If the tensor has a different shape format, you can rearrange it with e.g. `Tensor.permute() <https://pytorch.org/docs/stable/generated/torch.Tensor.permute.html>`__. Keep in mind that such transformations produce a non-contiguous view, so the tensor additionally has to be converted back to a C-contiguous memory layout, which copies the data. You can do this with `Tensor.contiguous() <https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html>`__.

.. code-block:: python

    import torch

    from savant.utils.memory_repr_pytorch import pytorch_tensor_as_opencv_gpu_mat

    # original tensor in CHW format
    tensor0 = torch.randint(0, 255, size=(3, 10, 20), device='cuda').to(torch.uint8)
    # rearrange to HWC format (view, no copy)
    tensor1 = tensor0.permute(1, 2, 0)
    # make the layout C-contiguous (copy)
    tensor2 = tensor1.contiguous()
    # map to OpenCV GpuMat (zero-copy)
    gpu_mat = pytorch_tensor_as_opencv_gpu_mat(tensor2)
**From CuPy array**

.. py:currentmodule:: savant.utils.memory_repr

:py:func:`cupy_array_as_opencv_gpu_mat <cupy_array_as_opencv_gpu_mat>` allows you to convert a CuPy array into an OpenCV GpuMat. The input array must have 2 or 3 dimensions, its shape in HWC format and a C-contiguous memory layout.

.. code-block:: python

    import cupy as cp

    from savant.utils.memory_repr import cupy_array_as_opencv_gpu_mat

    # original array in HWC format
    cupy_array = cp.random.randint(0, 255, (10, 20, 3)).astype(cp.uint8)
    # map to OpenCV GpuMat (zero-copy)
    opencv_gpu_mat = cupy_array_as_opencv_gpu_mat(cupy_array)

If the array has a different shape format, you can rearrange it with e.g. `cupy.transpose() <https://docs.cupy.dev/en/stable/reference/generated/cupy.transpose.html>`__. Keep in mind that such transformations produce a non-contiguous view, so the array additionally has to be converted back to a C-contiguous memory layout, which copies the data. You can do this with `cupy.ascontiguousarray() <https://docs.cupy.dev/en/stable/reference/generated/cupy.ascontiguousarray.html>`__.

.. code-block:: python

    import cupy as cp

    from savant.utils.memory_repr import cupy_array_as_opencv_gpu_mat

    # original array in CHW format
    arr0 = cp.random.randint(0, 255, (3, 10, 20)).astype(cp.uint8)
    # rearrange to HWC format (view, no copy)
    arr1 = arr0.transpose((1, 2, 0))
    # make the layout C-contiguous (copy)
    arr2 = cp.ascontiguousarray(arr1)
    # map to OpenCV GpuMat (zero-copy)
    gpu_mat = cupy_array_as_opencv_gpu_mat(arr2)
Conversion to PyTorch Tensor
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

**From OpenCV GpuMat**

.. py:currentmodule:: savant.utils.memory_repr_pytorch

:py:func:`opencv_gpu_mat_as_pytorch_tensor <opencv_gpu_mat_as_pytorch_tensor>` allows you to convert an OpenCV GpuMat into a PyTorch tensor on the GPU.

.. code-block:: python

    import cv2
    import numpy as np

    from savant.utils.memory_repr_pytorch import opencv_gpu_mat_as_pytorch_tensor

    opencv_gpu_mat = cv2.cuda_GpuMat()
    opencv_gpu_mat.upload(np.random.randint(0, 255, (10, 20, 3)).astype(np.uint8))
    # zero-copy, HWC format
    torch_tensor = opencv_gpu_mat_as_pytorch_tensor(opencv_gpu_mat)
**From CuPy array**

Conversion from a CuPy array to a PyTorch tensor is performed with the standard PyTorch function `torch.as_tensor <https://pytorch.org/docs/stable/generated/torch.as_tensor.html>`__.

.. code-block:: python

    import cupy as cp
    import torch

    cupy_array = cp.random.randint(0, 255, (10, 20, 3)).astype(cp.uint8)
    # zero-copy, keeps the original array format
    torch_tensor = torch.as_tensor(cupy_array)
Conversion to CuPy Array
^^^^^^^^^^^^^^^^^^^^^^^^

**From OpenCV GpuMat**

.. py:currentmodule:: savant.utils.memory_repr

:py:func:`opencv_gpu_mat_as_cupy_array <opencv_gpu_mat_as_cupy_array>` allows you to convert an OpenCV GpuMat into a CuPy array.

.. code-block:: python

    import cv2
    import numpy as np

    from savant.utils.memory_repr import opencv_gpu_mat_as_cupy_array

    opencv_gpu_mat = cv2.cuda_GpuMat()
    opencv_gpu_mat.upload(np.random.randint(0, 255, (10, 20, 3)).astype(np.uint8))
    # zero-copy, HWC format
    cupy_array = opencv_gpu_mat_as_cupy_array(opencv_gpu_mat)
**From PyTorch tensor**

Conversion from a PyTorch tensor to a CuPy array is performed with the standard CuPy function `cupy.asarray <https://docs.cupy.dev/en/stable/reference/generated/cupy.asarray.html>`__.

.. code-block:: python

    import cupy as cp
    import torch

    torch_tensor = torch.randint(0, 255, size=(3, 10, 20), device='cuda').to(torch.uint8)
    # zero-copy, keeps the original tensor format
    cupy_array = cp.asarray(torch_tensor)
@@ -1,4 +1,4 @@
 black~=22.3.0
 unify~=0.5
-pytest~=7.1.2
+pytest~=7.4.3
 isort~=5.12.0
@@ -0,0 +1,81 @@
import cupy as cp
import cv2

OPENCV_TO_NUMPY_TYPE_MAP = {
    cv2.CV_8U: '|u1',
    cv2.CV_8S: '|i1',
    cv2.CV_16U: '<u2',
    cv2.CV_16S: '<i2',
    cv2.CV_32S: '<i4',
    cv2.CV_32F: '<f4',
    cv2.CV_64F: '<f8',
}

NUMPY_TO_OPENCV_TYPE_MAP = {v: k for k, v in OPENCV_TO_NUMPY_TYPE_MAP.items()}


class OpenCVGpuMatCudaArrayInterface:
    """OpenCV GpuMat __cuda_array_interface__.
    https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html
    """

    def __init__(self, gpu_mat: cv2.cuda.GpuMat):
        width, height = gpu_mat.size()
        channels = gpu_mat.channels()
        type_str = OPENCV_TO_NUMPY_TYPE_MAP.get(gpu_mat.depth())
        assert type_str is not None, 'Unsupported OpenCV GpuMat type.'
        self.__cuda_array_interface__ = {
            'version': 3,
            'shape': (height, width, channels) if channels > 1 else (height, width),
            'data': (gpu_mat.cudaPtr(), False),
            'typestr': type_str,
            'descr': [('', type_str)],
            # 'stream': 1,  # TODO: Investigate
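            # strides are given in bytes: GpuMat rows may be padded, so the row
            # stride is the row pitch `step`, the pixel stride is `elemSize()`
            # (bytes per pixel, all channels) and the channel stride is
            # `elemSize1()` (bytes per single channel element)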
            'strides': (gpu_mat.step, gpu_mat.elemSize(), gpu_mat.elemSize1())
            if channels > 1
            else (gpu_mat.step, gpu_mat.elemSize()),
        }


def cuda_array_as_opencv_gpu_mat(arr) -> cv2.cuda.GpuMat:
    """Returns OpenCV GpuMat for the given cuda array.
    The array must support __cuda_array_interface__ (CuPy, PyTorch),
    have 2 or 3 dims and be in C-contiguous layout.
    """
    shape = arr.__cuda_array_interface__['shape']
    assert len(shape) in (2, 3), 'Array must have 2 or 3 dimensions.'

    dtype = arr.__cuda_array_interface__['typestr']
    depth = NUMPY_TO_OPENCV_TYPE_MAP.get(dtype)
    assert (
        depth is not None
    ), f'Array must be of one of the following types {list(NUMPY_TO_OPENCV_TYPE_MAP)}.'

    strides = arr.__cuda_array_interface__['strides']
    assert strides is None, 'Array must be in C-contiguous layout.'

    channels = 1 if len(shape) == 2 else shape[2]
    # equivalent to unexposed opencv C++ macro CV_MAKETYPE(depth,channels):
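    # e.g. depth=cv2.CV_8U (0) with 3 channels yields 16, i.e. cv2.CV_8UC3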
    mat_type = depth + ((channels - 1) << 3)

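    # shape is (rows, cols[, channels]), so shape[1::-1] is the GpuMat size as (width, height)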
    return cv2.cuda.createGpuMatFromCudaMemory(
        shape[1::-1], mat_type, arr.__cuda_array_interface__['data'][0]
    )


def opencv_gpu_mat_as_cupy_array(gpu_mat: cv2.cuda.GpuMat) -> cp.ndarray:
    """Returns CuPy ndarray in HWC format for the given OpenCV GpuMat (zero-copy)."""
    return cp.asarray(OpenCVGpuMatCudaArrayInterface(gpu_mat))


def cupy_array_as_opencv_gpu_mat(arr: cp.ndarray) -> cv2.cuda.GpuMat:
    """Returns OpenCV GpuMat for the given CuPy ndarray (zero-copy).
    The array must have 2 or 3 dims in HWC format and C-contiguous layout.
    Use `cupy.shape` and `cupy.strides` to check if an array
    has supported shape format and is contiguous in memory.
    Use `cupy.transpose()` and `cupy.ascontiguousarray()` to transform an array
    if necessary (creates a copy of the array).
    """
    return cuda_array_as_opencv_gpu_mat(arr)
@@ -0,0 +1,25 @@
import cv2
import torch

from savant.utils.memory_repr import (
    OpenCVGpuMatCudaArrayInterface,
    cuda_array_as_opencv_gpu_mat,
)


def opencv_gpu_mat_as_pytorch_tensor(gpu_mat: cv2.cuda.GpuMat) -> torch.Tensor:
    """Returns PyTorch tensor in HWC format for the given OpenCV GpuMat (zero-copy)."""
    return torch.as_tensor(OpenCVGpuMatCudaArrayInterface(gpu_mat), device='cuda')


def pytorch_tensor_as_opencv_gpu_mat(tensor: torch.Tensor) -> cv2.cuda_GpuMat:
    """Returns OpenCV GpuMat for the given PyTorch tensor (zero-copy).
    The tensor must have 2 or 3 dims in HWC format and C-contiguous layout.
    Use `Tensor.size()` and `Tensor.is_contiguous()` to check if a tensor
    has supported shape format and is contiguous in memory.
    Use `Tensor.transpose()` and `Tensor.contiguous()` to transform a tensor
    if necessary (creates a copy of the tensor).
    """
    return cuda_array_as_opencv_gpu_mat(tensor)