From 854dea588288af7824ffc48fcf01dfdc2d1e235c Mon Sep 17 00:00:00 2001 From: phlrain Date: Wed, 6 Dec 2023 13:29:03 +0000 Subject: [PATCH] fix pir cinn transpose bug --- paddle/cinn/hlir/framework/pir/op_mapper.cc | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/paddle/cinn/hlir/framework/pir/op_mapper.cc b/paddle/cinn/hlir/framework/pir/op_mapper.cc index 477ceb52613fb..183e68a183cd5 100644 --- a/paddle/cinn/hlir/framework/pir/op_mapper.cc +++ b/paddle/cinn/hlir/framework/pir/op_mapper.cc @@ -44,13 +44,22 @@ void AppendAttrForReduceOp(const ::pir::Operation& op, void AppendAttrForTransposeOp(const ::pir::Operation& op, utils::AttributeMap& attrs) { // NOLINT + auto rank = op.operand_source(0) + .type() + .dyn_cast() + .dims() + .size(); auto attr = op.attributes().at("perm"); auto attr_vec = attr.dyn_cast<::pir::ArrayAttribute>().AsVector(); std::vector dim; for (auto vec_element : attr_vec) { - dim.push_back(vec_element.dyn_cast<::pir::Int32Attribute>().data()); + auto ele = vec_element.dyn_cast<::pir::Int32Attribute>().data(); + if (ele < 0) { + ele += rank; + } + dim.push_back(ele); } attrs["axis"] = dim;