Skip to content

Commit

Permalink
[TIR][UX] allow override when register TensorIntrin (#12439)
Browse files Browse the repository at this point in the history
* allow override when register TensorIntrin

* lint
  • Loading branch information
Hzfengsy authored Aug 15, 2022
1 parent 25c4a73 commit 55f1d7e
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
3 changes: 2 additions & 1 deletion include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,11 @@ class TensorIntrin : public ObjectRef {
* up with its name.
* \param name The name of the TensorIntrin to register
* \param intrin The TensorIntrin to register.
* \param override Whether override existing intrinsic.
* \throws This method throws an exception if the TensorIntrin with the specified name already
* exists.
*/
TVM_DLL static void Register(String name, TensorIntrin intrin);
TVM_DLL static void Register(String name, TensorIntrin intrin, bool override = false);

/*!
* \brief Look up TensorIntrin by name. Raises an exception if not found.
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def __init__(self, desc, impl):
self.__init_handle_by_constructor__(_ffi_api.TensorIntrin, desc, impl)

@staticmethod
def register(name: str, desc: PrimFunc, impl: PrimFunc):
def register(name: str, desc: PrimFunc, impl: PrimFunc, override: bool = False):
"""Register a tensor intrinsic with its name.
Parameters
Expand All @@ -237,8 +237,12 @@ def register(name: str, desc: PrimFunc, impl: PrimFunc):
The function to describe the computation.
impl : PrimFunc
The function of the implementation for the execution.
override: bool
Whether override existing intrinsic.
"""
return _ffi_api.TensorIntrinRegister(name, TensorIntrin(desc, impl)) # type: ignore
return _ffi_api.TensorIntrinRegister(
name, TensorIntrin(desc, impl), override
) # type: ignore

@staticmethod
def get(name: str):
Expand Down
8 changes: 5 additions & 3 deletions src/tir/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,12 @@ TensorIntrin::TensorIntrin(PrimFunc desc, PrimFunc impl) {
data_ = std::move(n);
}

void TensorIntrin::Register(String name, TensorIntrin intrin) {
void TensorIntrin::Register(String name, TensorIntrin intrin, bool override) {
TensorIntrinManager* manager = TensorIntrinManager::Global();
CHECK_EQ(manager->reg.count(name), 0)
<< "ValueError: TensorIntrin '" << name << "' has already been registered";
if (!override) {
CHECK_EQ(manager->reg.count(name), 0)
<< "ValueError: TensorIntrin '" << name << "' has already been registered";
}
manager->reg.Set(name, intrin);
}

Expand Down

0 comments on commit 55f1d7e

Please sign in to comment.