diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index fc50e03f30080..8aa6064e1b522 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -117,7 +117,7 @@ def alter_ethosu_op_layout(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Ca input_count += 1 if not isinstance(arg, tvm.relay.expr.Call): continue - if is_ethosu_op(arg): + if isinstance(arg.op, tvm.ir.op.Op) and arg.op.name in self.optimize_op: layout_string = "ifm_layout" if input_count <= 1 else f"ifm{input_count}_layout" new_attrs[layout_string] = "NHCWB16" parents.append(arg) @@ -126,7 +126,11 @@ def alter_ethosu_op_layout(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Ca if call in self.children: children = self.children[call] if all( - is_ethosu_op(child) and child.attrs["ifm_layout"] == "NHCWB16" for child in children + isinstance(child, tvm.relay.expr.Call) + and isinstance(child.op, tvm.ir.op.Op) + and child.op.name in self.optimize_op + and child.attrs["ifm_layout"] == "NHCWB16" + for child in children ): new_attrs["ofm_layout"] = "NHCWB16" @@ -144,6 +148,8 @@ def alter_ethosu_op_layout(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Ca else: self.children[input_arg] = [new_call] + print(new_call) + return super().visit_call(new_call) def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: