Skip to content

Commit

Permalink
适配新代码,利用axclrtSetDevice完成context的get/set,从而支持跨线程推理
Browse files Browse the repository at this point in the history
  • Loading branch information
zylo117 authored and kalcohol committed Jan 21, 2025
1 parent 54bab3b commit 95d3b87
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions axengine/_axclrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ def __init__(
self.soc_name = axclrt_cffi.string(axclrt_lib.axclrtGetSocName()).decode()
print(f"[INFO] SOC Name: {self.soc_name}")

self._thread_context = axclrt_cffi.new("axclrtContext *")
ret = axclrt_lib.axclrtGetCurrentContext(self._thread_context)
if ret != 0:
raise RuntimeError("axclrtGetCurrentContext failed")

# model handle, context, info, io
self._model_id = axclrt_cffi.new("uint64_t *")
self._context_id = axclrt_cffi.new("uint64_t *")
Expand Down Expand Up @@ -322,6 +327,10 @@ def run(
self._validate_input(input_feed)
self._validate_output(output_names)

ret = axclrt_lib.axclrtSetCurrentContext(self._thread_context[0])
if ret != 0:
raise RuntimeError("axclrtSetCurrentContext failed")

if None is output_names:
output_names = [o.name for o in self.get_outputs()]

Expand Down

0 comments on commit 95d3b87

Please sign in to comment.