From cb9a98e90fb6be1e744c0ac552a6c095d0235ffd Mon Sep 17 00:00:00 2001 From: haixuantao Date: Fri, 17 Jan 2025 13:50:03 +0100 Subject: [PATCH] Minor python esthetics --- apis/python/node/dora/__init__.py | 15 ++++++++------- apis/python/node/dora/cuda.py | 22 ++++++++-------------- apis/python/node/pyproject.toml | 5 ++++- 3 files changed, 20 insertions(+), 22 deletions(-) diff --git a/apis/python/node/dora/__init__.py b/apis/python/node/dora/__init__.py index 354d9ad40..a574bef66 100644 --- a/apis/python/node/dora/__init__.py +++ b/apis/python/node/dora/__init__.py @@ -10,30 +10,31 @@ from enum import Enum from .dora import * - from .dora import ( Node, Ros2Context, + Ros2Durability, + Ros2Liveliness, Ros2Node, Ros2NodeOptions, - Ros2Topic, Ros2Publisher, + Ros2QosPolicies, Ros2Subscription, - start_runtime, - __version__, + Ros2Topic, __author__, - Ros2QosPolicies, - Ros2Durability, - Ros2Liveliness, + __version__, + start_runtime, ) class DoraStatus(Enum): """Dora status to indicate if operator `on_input` loop should be stopped. + Args: Enum (u8): Status signaling to dora operator to stop or continue the operator. + """ CONTINUE = 0 diff --git a/apis/python/node/dora/cuda.py b/apis/python/node/dora/cuda.py index 1858a7bbe..fdd6fba73 100644 --- a/apis/python/node/dora/cuda.py +++ b/apis/python/node/dora/cuda.py @@ -1,19 +1,18 @@ import pyarrow as pa -# To install pyarrow.cuda, run `conda install pyarrow "arrow-cpp-proc=*=cuda" -c conda-forge` -import pyarrow.cuda as cuda - # Make sure to install torch with cuda import torch +from numba.cuda import to_device # Make sure to install numba with cuda from numba.cuda.cudadrv.devicearray import DeviceNDArray -from numba.cuda import to_device + +# To install pyarrow.cuda, run `conda install pyarrow "arrow-cpp-proc=*=cuda" -c conda-forge` +from pyarrow import cuda def torch_to_ipc_buffer(tensor: torch.TensorType) -> tuple[pa.array, dict]: - """ - Converts a Pytorch tensor into a pyarrow buffer containing the IPC handle and its metadata. + """Converts a Pytorch tensor into a pyarrow buffer containing the IPC handle and its metadata. Example Use: ```python @@ -34,8 +33,7 @@ def torch_to_ipc_buffer(tensor: torch.TensorType) -> tuple[pa.array, dict]: def ipc_buffer_to_ipc_handle(handle_buffer: pa.array) -> cuda.IpcMemHandle: - """ - Converts a buffer containing a serialized handler into cuda IPC MemHandle. + """Converts a buffer containing a serialized handler into cuda IPC MemHandle. example use: ```python @@ -57,8 +55,7 @@ def ipc_buffer_to_ipc_handle(handle_buffer: pa.array) -> cuda.IpcMemHandle: def cudabuffer_to_numba(buffer: cuda.CudaBuffer, metadata: dict) -> DeviceNDArray: - """ - Converts a pyarrow CUDA buffer to numba. + """Converts a pyarrow CUDA buffer to numba. example use: ```python @@ -74,7 +71,6 @@ def cudabuffer_to_numba(buffer: cuda.CudaBuffer, metadata: dict) -> DeviceNDArra numba_tensor = cudabuffer_to_numbda(cudabuffer, event["metadata"]) ``` """ - shape = metadata["shape"] strides = metadata["strides"] dtype = metadata["dtype"] @@ -83,8 +79,7 @@ def cudabuffer_to_numba(buffer: cuda.CudaBuffer, metadata: dict) -> DeviceNDArra def cudabuffer_to_torch(buffer: cuda.CudaBuffer, metadata: dict) -> torch.Tensor: - """ - Converts a pyarrow CUDA buffer to a torch tensor. + """Converts a pyarrow CUDA buffer to a torch tensor. example use: ```python @@ -100,7 +95,6 @@ def cudabuffer_to_torch(buffer: cuda.CudaBuffer, metadata: dict) -> torch.Tensor torch_tensor = cudabuffer_to_torch(cudabuffer, event["metadata"]) # on cuda ``` """ - device_arr = cudabuffer_to_numba(buffer, metadata) torch_tensor = torch.as_tensor(device_arr, device="cuda") return torch_tensor diff --git a/apis/python/node/pyproject.toml b/apis/python/node/pyproject.toml index d1f79b9c3..35c0ace5f 100644 --- a/apis/python/node/pyproject.toml +++ b/apis/python/node/pyproject.toml @@ -4,9 +4,12 @@ build-backend = "maturin" [project] name = "dora-rs" -dynamic = ["version"] +dynamic = ["version", "readme"] # Install pyarrow at the same time of dora-rs dependencies = ['pyarrow'] +[dependency-groups] +dev = ["pytest >=8.1.1", "ruff >=0.9.1"] + [tool.maturin] features = ["pyo3/extension-module"]