From 57a60285765b0d4888ac5f4aef0f96ae6ff10286 Mon Sep 17 00:00:00 2001 From: xiguadong <55774832+xiguadong@users.noreply.github.com> Date: Tue, 10 Dec 2024 16:19:53 +0800 Subject: [PATCH] add check for inputs (#14) * session add check inputs * fix check coniguous * fix assert inputs shape --- axengine/session.py | 79 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 62 insertions(+), 17 deletions(-) diff --git a/axengine/session.py b/axengine/session.py index 6981dd7..ab143de 100644 --- a/axengine/session.py +++ b/axengine/session.py @@ -88,7 +88,10 @@ def __init__( raise ValueError( f"Model type '{self._model_type}' is not allowed when vnpu is inited as {self._vnpu_type}." ) - if self._vnpu_type is VNPUType.BIG_LITTLE or self._vnpu_type is VNPUType.LITTLE_BIG: + if ( + self._vnpu_type is VNPUType.BIG_LITTLE + or self._vnpu_type is VNPUType.LITTLE_BIG + ): if self._model_type is ModelType.TRIPLE: raise ValueError( f"Model type '{self._model_type}' is not allowed when vnpu is inited as {self._vnpu_type}." @@ -123,8 +126,12 @@ def __init__( self._cmm_token = self._engine_ffi.new("AX_S8[]", b"PyEngine") self._io[0].nInputSize = len(self.get_inputs()) self._io[0].nOutputSize = len(self.get_outputs()) - self._io[0].pInputs = self._engine_ffi.new("AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nInputSize)) - self._io[0].pOutputs = self._engine_ffi.new("AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nOutputSize)) + self._io[0].pInputs = self._engine_ffi.new( + "AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nInputSize) + ) + self._io[0].pOutputs = self._engine_ffi.new( + "AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nOutputSize) + ) for i in range(len(self.get_inputs())): max_buf = 0 for j in range(self._shape_count): @@ -173,7 +180,9 @@ def _init(self, vnpu=VNPUType.DISABLED): # vnpu type, the default is disabled ret = self._engine_lib.AX_ENGINE_GetVNPUAttr(vnpu_type) if 0 != ret: # this means the NPU was not initialized - vnpu_type.eHardMode = self._engine_ffi.cast("AX_ENGINE_NPU_MODE_T", vnpu.value) + vnpu_type.eHardMode = self._engine_ffi.cast( + "AX_ENGINE_NPU_MODE_T", vnpu.value + ) return self._engine_lib.AX_ENGINE_Init(vnpu_type) @@ -196,13 +205,17 @@ def _get_vnpu_type(self) -> VNPUType: def _get_model_type(self) -> ModelType: model_type = self._engine_ffi.new("AX_ENGINE_MODEL_TYPE_T *") - ret = self._engine_lib.AX_ENGINE_GetModelType(self._model_buffer, self._model_buffer_size, model_type) + ret = self._engine_lib.AX_ENGINE_GetModelType( + self._model_buffer, self._model_buffer_size, model_type + ) if 0 != ret: raise RuntimeError("Failed to get model type.") return ModelType(model_type[0]) def _get_model_tool_version(self): - model_tool_version = self._engine_lib.AX_ENGINE_GetModelToolsVersion(self._handle[0]) + model_tool_version = self._engine_lib.AX_ENGINE_GetModelToolsVersion( + self._handle[0] + ) return self._engine_ffi.string(model_tool_version).decode("utf-8") def _load(self): @@ -216,7 +229,9 @@ def _load(self): self._handle, self._model_buffer, self._model_buffer_size, extra ) if 0 == ret: - ret = self._engine_lib.AX_ENGINE_CreateContextV2(self._handle[0], self._context) + ret = self._engine_lib.AX_ENGINE_CreateContextV2( + self._handle[0], self._context + ) return ret def _get_info(self): @@ -230,7 +245,9 @@ def _get_info(self): else: for i in range(self._shape_count): info = self._engine_ffi.new("AX_ENGINE_IO_INFO_T **") - ret = self._engine_lib.AX_ENGINE_GetGroupIOInfo(self._handle[0], i, info) + ret = self._engine_lib.AX_ENGINE_GetGroupIOInfo( + self._handle[0], i, info + ) if 0 != ret: raise RuntimeError(f"Failed to get model the {i}th shape.") total_info.append(info) @@ -256,7 +273,9 @@ def _get_inputs(self): shape = [] for i in range(current_input.nShapeSize): shape.append(current_input.pShape[i]) - dtype = _transform_dtype(self._engine_ffi, self._engine_lib, current_input.eDataType) + dtype = _transform_dtype( + self._engine_ffi, self._engine_lib, current_input.eDataType + ) meta = NodeArg(name, dtype, shape) one_group_input.append(meta) inputs.append(one_group_input) @@ -272,7 +291,9 @@ def _get_outputs(self): shape = [] for i in range(current_output.nShapeSize): shape.append(current_output.pShape[i]) - dtype = _transform_dtype(self._engine_ffi, self._engine_lib, current_output.eDataType) + dtype = _transform_dtype( + self._engine_ffi, self._engine_lib, current_output.eDataType + ) meta = NodeArg(name, dtype, shape) one_group_output.append(meta) outputs.append(one_group_output) @@ -280,13 +301,17 @@ def _get_outputs(self): def get_inputs(self, shape_group=0) -> list[NodeArg]: if shape_group > self._shape_count: - raise ValueError(f"Shape group '{shape_group}' is out of range, total {self._shape_count}.") + raise ValueError( + f"Shape group '{shape_group}' is out of range, total {self._shape_count}." + ) selected_info = self._inputs[shape_group] return selected_info def get_outputs(self, shape_group=0) -> list[NodeArg]: if shape_group > self._shape_count: - raise ValueError(f"Shape group '{shape_group}' is out of range, total {self._shape_count}.") + raise ValueError( + f"Shape group '{shape_group}' is out of range, total {self._shape_count}." + ) selected_info = self._outputs[shape_group] return selected_info @@ -318,25 +343,45 @@ def run(self, output_names, input_feed, run_options=None): for key, npy in input_feed.items(): for i, one in enumerate(self.get_inputs()): if one.name == key: + assert ( + list(one.shape) == list(npy.shape) and one.dtype == npy.dtype + ), f"model inputs({key}) expect shape {one.shape} and dtype {one.dtype}, howerver gets input with shape {npy.shape} and dtype {npy.dtype}" + + if not ( + not npy.flags.c_contiguous + and npy.flags.f_contiguous + and npy.flags.contiguous + ): + npy = np.ascontiguousarray(npy) npy_ptr = self._engine_ffi.cast("void *", npy.ctypes.data) - self._engine_ffi.memmove(self._io[0].pInputs[i].pVirAddr, npy_ptr, npy.nbytes) + self._engine_ffi.memmove( + self._io[0].pInputs[i].pVirAddr, npy_ptr, npy.nbytes + ) self._sys_lib.AX_SYS_MflushCache( - self._io[0].pInputs[i].phyAddr, self._io[0].pInputs[i].pVirAddr, self._io[0].pInputs[i].nSize + self._io[0].pInputs[i].phyAddr, + self._io[0].pInputs[i].pVirAddr, + self._io[0].pInputs[i].nSize, ) break # execute model - ret = self._engine_lib.AX_ENGINE_RunSyncV2(self._handle[0], self._context[0], self._io) + ret = self._engine_lib.AX_ENGINE_RunSyncV2( + self._handle[0], self._context[0], self._io + ) # flush output outputs = [] if 0 == ret: for i in range(len(self.get_outputs())): self._sys_lib.AX_SYS_MinvalidateCache( - self._io[0].pOutputs[i].phyAddr, self._io[0].pOutputs[i].pVirAddr, self._io[0].pOutputs[i].nSize + self._io[0].pOutputs[i].phyAddr, + self._io[0].pOutputs[i].pVirAddr, + self._io[0].pOutputs[i].nSize, ) npy = np.frombuffer( - self._engine_ffi.buffer(self._io[0].pOutputs[i].pVirAddr, self._io[0].pOutputs[i].nSize), + self._engine_ffi.buffer( + self._io[0].pOutputs[i].pVirAddr, self._io[0].pOutputs[i].nSize + ), dtype=self.get_outputs()[i].dtype, ).reshape(self.get_outputs()[i].shape) name = self.get_outputs()[i].name