Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN] Infer the shape of cf.tuple_pop by adding a grad cache interface for cf.tuple_push #70723

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion paddle/pir/include/dialect/control_flow/ir/cf_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ class IR_API YieldOp : public Op<YieldOp, SideEffectTrait> {
///
/// \brief Push a value tuple to a container.
///
class IR_API TuplePushOp : public Op<TuplePushOp, SideEffectTrait> {
class IR_API TuplePushOp : public Op<TuplePushOp,
SideEffectTrait,
CacheGradOpSymbolicShapeInterface> {
public:
using Op::Op;
static const char *name() { return "cf.tuple_push"; }
Expand Down Expand Up @@ -70,6 +72,8 @@ class IR_API TuplePushOp : public Op<TuplePushOp, SideEffectTrait> {
return inlet().defining_op<ContainerOpInterface>();
}
TuplePopOp tuple_pop_op();

CacheGradOpSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
};

class IR_API TuplePopOp : public Op<TuplePopOp, SideEffectTrait> {
Expand Down
24 changes: 19 additions & 5 deletions paddle/pir/src/dialect/control_flow/ir/cf_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,21 @@ TuplePopOp TuplePushOp::tuple_pop_op() {
return container_interface().tuple_pop_op();
}

void TuplePushOp::CacheGradOpSymbolicShape(
pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape = GetInputShape(infer_context, this->operation(), 0);
pir::InferSymbolicShapeCacheKey op_shape_info("cf.tuple_pop", {x_shape}, );

std::vector<symbol::ShapeOrDataDimExprs> pop_value_shape_list;
for (size_t index = 1; index < num_operands(); ++index) {
const auto &pop_value_shape =
GetGradVarShapeFromInput(infer_context, this->operation(), index);
pop_value_shape_list.emplace_back(pop_value_shape);
}
infer_context->SetOpInferSymbolicShapeCache(op_shape_info,
pop_value_shape_list);
}

void TuplePopOp::Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
Value outlet) {
Expand Down Expand Up @@ -202,11 +217,10 @@ void StackCreateOp::VerifySig() {

bool StackCreateOp::InferSymbolicShape(
pir::InferSymbolicShapeContext *infer_context) {
const auto &null_shape_or_data =
symbol::ShapeOrDataDimExprs(symbol::NullShapeOrDataDimExpr());
infer_context->SetShapeOrDataForValue(result(0), null_shape_or_data);
infer_context->SetShapeOrDataForValue(result(1), null_shape_or_data);
infer_context->SetShapeOrDataForValue(result(2), null_shape_or_data);
symbol::DimExpr mark_symbol = infer_context->GetNextSymName();
infer_context->SetShapeOrDataForValue(result(0), mark_symbol);
infer_context->SetShapeOrDataForValue(result(1), mark_symbol);
infer_context->SetShapeOrDataForValue(result(2), mark_symbol);
return true;
}

Expand Down
Loading