Skip to content

Commit

Permalink
[CINN]remove concat op infer symbolic check (#68705)
Browse files Browse the repository at this point in the history
  • Loading branch information
phlrain authored Oct 17, 2024
1 parent 46180c3 commit 30d7c4f
Showing 1 changed file with 2 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1118,26 +1118,8 @@ bool ConcatOpInferSymbolicShape(pir::Operation *op,
return axis >= 0 ? axis : std::max(int64_t(0), int64_t(axis + rank));
}();

if (details::HasCompleteData(x_shape)) {
if (rank == 1) {
ExprVec data = details::GetExprVecFromData(x_shape);
const std::vector<symbol::DimExpr> shape{std::int64_t(data.size())};
symbol::ShapeOrDataDimExprs shape_data{
symbol::TensorShapeOrDataDimExprs(shape, data)};
pir::Value res = op->result(0);
infer_context->SetShapeOrDataForValue(res, shape_data);

return true;
} else {
PADDLE_THROW(common::errors::Unimplemented(
op->name() +
" 's InferSymbolicShape can NOT deal with rank > 1 now."));
}
std::vector<symbol::DimExpr> data;
data.reserve(shape_data_list.size());
for (auto &data_elem : shape_data_list) {
data.push_back(data_elem.data().value().at(0));
}
if (details::HasCompleteData(x_shape) && (rank == 1)) {
ExprVec data = details::GetExprVecFromData(x_shape);
const std::vector<symbol::DimExpr> shape{std::int64_t(data.size())};
symbol::ShapeOrDataDimExprs shape_data{
symbol::TensorShapeOrDataDimExprs(shape, data)};
Expand Down

0 comments on commit 30d7c4f

Please sign in to comment.