add check for inputs (AXERA-TECH#14)
* session: add input checks

* fix contiguity check

* fix assert on input shapes
xiguadong authored Dec 10, 2024
1 parent f7bbee7 commit 57a6028
Showing 1 changed file with 62 additions and 17 deletions.
axengine/session.py
@@ -88,7 +88,10 @@ def __init__(
             raise ValueError(
                 f"Model type '{self._model_type}' is not allowed when vnpu is inited as {self._vnpu_type}."
             )
-        if self._vnpu_type is VNPUType.BIG_LITTLE or self._vnpu_type is VNPUType.LITTLE_BIG:
+        if (
+            self._vnpu_type is VNPUType.BIG_LITTLE
+            or self._vnpu_type is VNPUType.LITTLE_BIG
+        ):
             if self._model_type is ModelType.TRIPLE:
                 raise ValueError(
                     f"Model type '{self._model_type}' is not allowed when vnpu is inited as {self._vnpu_type}."
@@ -123,8 +126,12 @@ def __init__(
         self._cmm_token = self._engine_ffi.new("AX_S8[]", b"PyEngine")
         self._io[0].nInputSize = len(self.get_inputs())
         self._io[0].nOutputSize = len(self.get_outputs())
-        self._io[0].pInputs = self._engine_ffi.new("AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nInputSize))
-        self._io[0].pOutputs = self._engine_ffi.new("AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nOutputSize))
+        self._io[0].pInputs = self._engine_ffi.new(
+            "AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nInputSize)
+        )
+        self._io[0].pOutputs = self._engine_ffi.new(
+            "AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nOutputSize)
+        )
         for i in range(len(self.get_inputs())):
             max_buf = 0
             for j in range(self._shape_count):
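
Aside: the `ffi.new("AX_ENGINE_IO_BUFFER_T[{}]".format(n))` calls above use cffi's owned-array allocation, which returns a zero-initialized C array whose elements behave like C structs. A minimal sketch with a hypothetical stand-in struct (`DEMO_BUFFER_T` and the count are illustrative, not part of this repo):

```python
import cffi

ffi = cffi.FFI()
# Toy stand-in for AX_ENGINE_IO_BUFFER_T; field names are illustrative.
ffi.cdef("typedef struct { void *pVirAddr; unsigned int nSize; } DEMO_BUFFER_T;")

n_inputs = 4  # the session derives this from len(self.get_inputs())
buffers = ffi.new("DEMO_BUFFER_T[{}]".format(n_inputs))  # zero-initialized C array
buffers[0].nSize = 1024               # elements are addressable like C structs
print(len(buffers), buffers[0].nSize)  # -> 4 1024
```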
@@ -173,7 +180,9 @@ def _init(self, vnpu=VNPUType.DISABLED): # vnpu type, the default is disabled
         ret = self._engine_lib.AX_ENGINE_GetVNPUAttr(vnpu_type)
         if 0 != ret:
             # this means the NPU was not initialized
-            vnpu_type.eHardMode = self._engine_ffi.cast("AX_ENGINE_NPU_MODE_T", vnpu.value)
+            vnpu_type.eHardMode = self._engine_ffi.cast(
+                "AX_ENGINE_NPU_MODE_T", vnpu.value
+            )
 
         return self._engine_lib.AX_ENGINE_Init(vnpu_type)
 
@@ -196,13 +205,17 @@ def _get_vnpu_type(self) -> VNPUType:
 
     def _get_model_type(self) -> ModelType:
         model_type = self._engine_ffi.new("AX_ENGINE_MODEL_TYPE_T *")
-        ret = self._engine_lib.AX_ENGINE_GetModelType(self._model_buffer, self._model_buffer_size, model_type)
+        ret = self._engine_lib.AX_ENGINE_GetModelType(
+            self._model_buffer, self._model_buffer_size, model_type
+        )
         if 0 != ret:
             raise RuntimeError("Failed to get model type.")
         return ModelType(model_type[0])
 
     def _get_model_tool_version(self):
-        model_tool_version = self._engine_lib.AX_ENGINE_GetModelToolsVersion(self._handle[0])
+        model_tool_version = self._engine_lib.AX_ENGINE_GetModelToolsVersion(
+            self._handle[0]
+        )
         return self._engine_ffi.string(model_tool_version).decode("utf-8")
 
     def _load(self):
@@ -216,7 +229,9 @@ def _load(self):
             self._handle, self._model_buffer, self._model_buffer_size, extra
         )
         if 0 == ret:
-            ret = self._engine_lib.AX_ENGINE_CreateContextV2(self._handle[0], self._context)
+            ret = self._engine_lib.AX_ENGINE_CreateContextV2(
+                self._handle[0], self._context
+            )
         return ret
 
     def _get_info(self):
@@ -230,7 +245,9 @@ def _get_info(self):
         else:
             for i in range(self._shape_count):
                 info = self._engine_ffi.new("AX_ENGINE_IO_INFO_T **")
-                ret = self._engine_lib.AX_ENGINE_GetGroupIOInfo(self._handle[0], i, info)
+                ret = self._engine_lib.AX_ENGINE_GetGroupIOInfo(
+                    self._handle[0], i, info
+                )
                 if 0 != ret:
                     raise RuntimeError(f"Failed to get model the {i}th shape.")
                 total_info.append(info)
@@ -256,7 +273,9 @@ def _get_inputs(self):
                 shape = []
                 for i in range(current_input.nShapeSize):
                     shape.append(current_input.pShape[i])
-                dtype = _transform_dtype(self._engine_ffi, self._engine_lib, current_input.eDataType)
+                dtype = _transform_dtype(
+                    self._engine_ffi, self._engine_lib, current_input.eDataType
+                )
                 meta = NodeArg(name, dtype, shape)
                 one_group_input.append(meta)
             inputs.append(one_group_input)
@@ -272,21 +291,27 @@ def _get_outputs(self):
                 shape = []
                 for i in range(current_output.nShapeSize):
                     shape.append(current_output.pShape[i])
-                dtype = _transform_dtype(self._engine_ffi, self._engine_lib, current_output.eDataType)
+                dtype = _transform_dtype(
+                    self._engine_ffi, self._engine_lib, current_output.eDataType
+                )
                 meta = NodeArg(name, dtype, shape)
                 one_group_output.append(meta)
             outputs.append(one_group_output)
         return outputs
 
     def get_inputs(self, shape_group=0) -> list[NodeArg]:
         if shape_group > self._shape_count:
-            raise ValueError(f"Shape group '{shape_group}' is out of range, total {self._shape_count}.")
+            raise ValueError(
+                f"Shape group '{shape_group}' is out of range, total {self._shape_count}."
+            )
         selected_info = self._inputs[shape_group]
         return selected_info
 
     def get_outputs(self, shape_group=0) -> list[NodeArg]:
         if shape_group > self._shape_count:
-            raise ValueError(f"Shape group '{shape_group}' is out of range, total {self._shape_count}.")
+            raise ValueError(
+                f"Shape group '{shape_group}' is out of range, total {self._shape_count}."
+            )
         selected_info = self._outputs[shape_group]
         return selected_info
 
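The `shape_group` bound keeps callers inside the cached metadata; note that with groups indexed from zero, `shape_group == self._shape_count` still slips past the `>` comparison, so a stricter `>=` bound may have been intended. A hedged usage sketch, assuming the onnxruntime-style constructor this package exposes (the model path is illustrative):

```python
from axengine import InferenceSession  # assumed public entry point

session = InferenceSession("mobilenetv2.axmodel")  # illustrative path
for arg in session.get_inputs():   # default shape_group=0
    print("input:", arg.name, arg.dtype, arg.shape)
for arg in session.get_outputs():
    print("output:", arg.name, arg.dtype, arg.shape)
```
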
@@ -318,25 +343,45 @@ def run(self, output_names, input_feed, run_options=None):
         for key, npy in input_feed.items():
             for i, one in enumerate(self.get_inputs()):
                 if one.name == key:
+                    assert (
+                        list(one.shape) == list(npy.shape) and one.dtype == npy.dtype
+                    ), f"model inputs({key}) expect shape {one.shape} and dtype {one.dtype}, however gets input with shape {npy.shape} and dtype {npy.dtype}"
+
+                    if not (
+                        not npy.flags.c_contiguous
+                        and npy.flags.f_contiguous
+                        and npy.flags.contiguous
+                    ):
+                        npy = np.ascontiguousarray(npy)
                     npy_ptr = self._engine_ffi.cast("void *", npy.ctypes.data)
-                    self._engine_ffi.memmove(self._io[0].pInputs[i].pVirAddr, npy_ptr, npy.nbytes)
+                    self._engine_ffi.memmove(
+                        self._io[0].pInputs[i].pVirAddr, npy_ptr, npy.nbytes
+                    )
                     self._sys_lib.AX_SYS_MflushCache(
-                        self._io[0].pInputs[i].phyAddr, self._io[0].pInputs[i].pVirAddr, self._io[0].pInputs[i].nSize
+                        self._io[0].pInputs[i].phyAddr,
+                        self._io[0].pInputs[i].pVirAddr,
+                        self._io[0].pInputs[i].nSize,
                     )
                     break
 
         # execute model
-        ret = self._engine_lib.AX_ENGINE_RunSyncV2(self._handle[0], self._context[0], self._io)
+        ret = self._engine_lib.AX_ENGINE_RunSyncV2(
+            self._handle[0], self._context[0], self._io
+        )
 
         # flush output
         outputs = []
         if 0 == ret:
             for i in range(len(self.get_outputs())):
                 self._sys_lib.AX_SYS_MinvalidateCache(
-                    self._io[0].pOutputs[i].phyAddr, self._io[0].pOutputs[i].pVirAddr, self._io[0].pOutputs[i].nSize
+                    self._io[0].pOutputs[i].phyAddr,
+                    self._io[0].pOutputs[i].pVirAddr,
+                    self._io[0].pOutputs[i].nSize,
                 )
                 npy = np.frombuffer(
-                    self._engine_ffi.buffer(self._io[0].pOutputs[i].pVirAddr, self._io[0].pOutputs[i].nSize),
+                    self._engine_ffi.buffer(
+                        self._io[0].pOutputs[i].pVirAddr, self._io[0].pOutputs[i].nSize
+                    ),
                     dtype=self.get_outputs()[i].dtype,
                 ).reshape(self.get_outputs()[i].shape)
                 name = self.get_outputs()[i].name
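A note on the new contiguity guard in `run()`: `npy.flags.contiguous` is an alias for `c_contiguous`, so the condition `not c_contiguous and f_contiguous and contiguous` can never hold, and `np.ascontiguousarray` ends up being called on every input. The behavior is still correct, because `np.ascontiguousarray` returns its argument unchanged when it is already C-contiguous and copies otherwise, which is exactly what the `memmove` into the DMA buffer requires. A minimal sketch of the effective behavior:

```python
import numpy as np

def ensure_c_contiguous(arr: np.ndarray) -> np.ndarray:
    # np.ascontiguousarray is a no-op for C-contiguous input and copies
    # otherwise, so the conditional guard in run() reduces to this call.
    return np.ascontiguousarray(arr)

a = np.asfortranarray(np.ones((2, 3)))  # F-contiguous, not C-contiguous
b = ensure_c_contiguous(a)
print(a.flags.c_contiguous, b.flags.c_contiguous)  # -> False True
```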

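Taken together, the commit makes a bad feed fail fast instead of silently corrupting the input buffer. A hypothetical example of the new check firing (session construction and model path as assumed above; `None` for `output_names` follows the onnxruntime convention):

```python
import numpy as np
from axengine import InferenceSession  # assumed public entry point

session = InferenceSession("mobilenetv2.axmodel")  # illustrative path
inp = session.get_inputs()[0]

bad = np.zeros((1, 2, 2, 3), dtype=np.float32)  # deliberately wrong shape
try:
    session.run(None, {inp.name: bad})
except AssertionError as err:
    print(err)  # reports expected vs. actual shape and dtype
```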