Skip to content

Commit

Permalink
[mlir][dataflow] Propagate errors from visitOperation (#105448)
Browse files Browse the repository at this point in the history
Base `DataFlowAnalysis::visit` returns `LogicalResult`, but wrappers's
Sparse/Dense/Forward/Backward `visitOperation` doesn't.

Sometimes it's needed to abort solver early if some unrecoverable
condition detected inside analysis.

Update `visitOperation` to return `LogicalResult` and propagate it to
`solver.initializeAndRun()`. Only `visitOperation` is updated for now,
it's possible to update other hooks like `visitNonControlFlowArguments`,
bit it's not needed immediately and let's keep this PR small.

Hijacked `UnderlyingValueAnalysis` test analysis to test it.
  • Loading branch information
Hardcode84 authored Aug 22, 2024
1 parent 14c7e4a commit 15e915a
Show file tree
Hide file tree
Showing 16 changed files with 220 additions and 149 deletions.
29 changes: 16 additions & 13 deletions flang/lib/Optimizer/Transforms/StackArrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,9 @@ class AllocationAnalysis
public:
using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;

void visitOperation(mlir::Operation *op, const LatticePoint &before,
LatticePoint *after) override;
mlir::LogicalResult visitOperation(mlir::Operation *op,
const LatticePoint &before,
LatticePoint *after) override;

/// At an entry point, the last modifications of all memory resources are
/// yet to be determined
Expand All @@ -159,7 +160,7 @@ class AllocationAnalysis
protected:
/// Visit control flow operations and decide whether to call visitOperation
/// to apply the transfer function
void processOperation(mlir::Operation *op) override;
mlir::LogicalResult processOperation(mlir::Operation *op) override;
};

/// Drives analysis to find candidate fir.allocmem operations which could be
Expand Down Expand Up @@ -329,9 +330,8 @@ std::optional<AllocationState> LatticePoint::get(mlir::Value val) const {
return it->second;
}

void AllocationAnalysis::visitOperation(mlir::Operation *op,
const LatticePoint &before,
LatticePoint *after) {
mlir::LogicalResult AllocationAnalysis::visitOperation(
mlir::Operation *op, const LatticePoint &before, LatticePoint *after) {
LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op
<< "\n");
LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n");
Expand All @@ -346,14 +346,14 @@ void AllocationAnalysis::visitOperation(mlir::Operation *op,
if (attr && attr.getValue()) {
LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n");
// skip allocation marked not to be moved
return;
return mlir::success();
}

auto retTy = allocmem.getAllocatedType();
if (!mlir::isa<fir::SequenceType>(retTy)) {
LLVM_DEBUG(llvm::dbgs()
<< "--Allocation is not for an array: skipping\n");
return;
return mlir::success();
}

mlir::Value result = op->getResult(0);
Expand Down Expand Up @@ -387,6 +387,7 @@ void AllocationAnalysis::visitOperation(mlir::Operation *op,

LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n");
propagateIfChanged(after, changed);
return mlir::success();
}

void AllocationAnalysis::setToEntryState(LatticePoint *lattice) {
Expand All @@ -395,18 +396,20 @@ void AllocationAnalysis::setToEntryState(LatticePoint *lattice) {

/// Mostly a copy of AbstractDenseLattice::processOperation - the difference
/// being that call operations are passed through to the transfer function
void AllocationAnalysis::processOperation(mlir::Operation *op) {
mlir::LogicalResult AllocationAnalysis::processOperation(mlir::Operation *op) {
// If the containing block is not executable, bail out.
if (!getOrCreateFor<mlir::dataflow::Executable>(op, op->getBlock())->isLive())
return;
return mlir::success();

// Get the dense lattice to update
mlir::dataflow::AbstractDenseLattice *after = getLattice(op);

// If this op implements region control-flow, then control-flow dictates its
// transfer function.
if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op))
return visitRegionBranchOperation(op, branch, after);
if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
visitRegionBranchOperation(op, branch, after);
return mlir::success();
}

// pass call operations through to the transfer function

Expand All @@ -418,7 +421,7 @@ void AllocationAnalysis::processOperation(mlir::Operation *op) {
before = getLatticeFor(op, op->getBlock());

/// Invoke the operation transfer function
visitOperationImpl(op, *before, after);
return visitOperationImpl(op, *before, after);
}

llvm::LogicalResult
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,10 @@ class SparseConstantPropagation
public:
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;

void visitOperation(Operation *op,
ArrayRef<const Lattice<ConstantValue> *> operands,
ArrayRef<Lattice<ConstantValue> *> results) override;
LogicalResult
visitOperation(Operation *op,
ArrayRef<const Lattice<ConstantValue> *> operands,
ArrayRef<Lattice<ConstantValue> *> results) override;

void setToEntryState(Lattice<ConstantValue> *lattice) override;
};
Expand Down
42 changes: 22 additions & 20 deletions mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
protected:
/// Propagate the dense lattice before the execution of an operation to the
/// lattice after its execution.
virtual void visitOperationImpl(Operation *op,
const AbstractDenseLattice &before,
AbstractDenseLattice *after) = 0;
virtual LogicalResult visitOperationImpl(Operation *op,
const AbstractDenseLattice &before,
AbstractDenseLattice *after) = 0;

/// Get the dense lattice after the execution of the given program point.
virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
Expand All @@ -114,7 +114,7 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// operation, then the state after the execution of the operation is set by
/// control-flow or the callgraph. Otherwise, this function invokes the
/// operation transfer function.
virtual void processOperation(Operation *op);
virtual LogicalResult processOperation(Operation *op);

/// Propagate the dense lattice forward along the control flow edge from
/// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt`
Expand Down Expand Up @@ -191,8 +191,8 @@ class DenseForwardDataFlowAnalysis
/// Visit an operation with the dense lattice before its execution. This
/// function is expected to set the dense lattice after its execution and
/// trigger change propagation in case of change.
virtual void visitOperation(Operation *op, const LatticeT &before,
LatticeT *after) = 0;
virtual LogicalResult visitOperation(Operation *op, const LatticeT &before,
LatticeT *after) = 0;

/// Hook for customizing the behavior of lattice propagation along the call
/// control flow edges. Two types of (forward) propagation are possible here:
Expand Down Expand Up @@ -263,10 +263,11 @@ class DenseForwardDataFlowAnalysis

/// Type-erased wrappers that convert the abstract dense lattice to a derived
/// lattice and invoke the virtual hooks operating on the derived lattice.
void visitOperationImpl(Operation *op, const AbstractDenseLattice &before,
AbstractDenseLattice *after) final {
visitOperation(op, static_cast<const LatticeT &>(before),
static_cast<LatticeT *>(after));
LogicalResult visitOperationImpl(Operation *op,
const AbstractDenseLattice &before,
AbstractDenseLattice *after) final {
return visitOperation(op, static_cast<const LatticeT &>(before),
static_cast<LatticeT *>(after));
}
void visitCallControlFlowTransfer(CallOpInterface call,
CallControlFlowAction action,
Expand Down Expand Up @@ -326,9 +327,9 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
protected:
/// Propagate the dense lattice after the execution of an operation to the
/// lattice before its execution.
virtual void visitOperationImpl(Operation *op,
const AbstractDenseLattice &after,
AbstractDenseLattice *before) = 0;
virtual LogicalResult visitOperationImpl(Operation *op,
const AbstractDenseLattice &after,
AbstractDenseLattice *before) = 0;

/// Get the dense lattice before the execution of the program point. That is,
/// before the execution of the given operation or after the execution of the
Expand All @@ -353,7 +354,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// Visit an operation. Dispatches to specialized methods for call or region
/// control-flow operations. Otherwise, this function invokes the operation
/// transfer function.
virtual void processOperation(Operation *op);
virtual LogicalResult processOperation(Operation *op);

/// Propagate the dense lattice backwards along the control flow edge from
/// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt`
Expand Down Expand Up @@ -442,8 +443,8 @@ class DenseBackwardDataFlowAnalysis
/// Transfer function. Visits an operation with the dense lattice after its
/// execution. This function is expected to set the dense lattice before its
/// execution and trigger propagation in case of change.
virtual void visitOperation(Operation *op, const LatticeT &after,
LatticeT *before) = 0;
virtual LogicalResult visitOperation(Operation *op, const LatticeT &after,
LatticeT *before) = 0;

/// Hook for customizing the behavior of lattice propagation along the call
/// control flow edges. Two types of (back) propagation are possible here:
Expand Down Expand Up @@ -513,10 +514,11 @@ class DenseBackwardDataFlowAnalysis

/// Type-erased wrappers that convert the abstract dense lattice to a derived
/// lattice and invoke the virtual hooks operating on the derived lattice.
void visitOperationImpl(Operation *op, const AbstractDenseLattice &after,
AbstractDenseLattice *before) final {
visitOperation(op, static_cast<const LatticeT &>(after),
static_cast<LatticeT *>(before));
LogicalResult visitOperationImpl(Operation *op,
const AbstractDenseLattice &after,
AbstractDenseLattice *before) final {
return visitOperation(op, static_cast<const LatticeT &>(after),
static_cast<LatticeT *>(before));
}
void visitCallControlFlowTransfer(CallOpInterface call,
CallControlFlowAction action,
Expand Down
7 changes: 4 additions & 3 deletions mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@ class IntegerRangeAnalysis

/// Visit an operation. Invoke the transfer function on each operation that
/// implements `InferIntRangeInterface`.
void visitOperation(Operation *op,
ArrayRef<const IntegerValueRangeLattice *> operands,
ArrayRef<IntegerValueRangeLattice *> results) override;
LogicalResult
visitOperation(Operation *op,
ArrayRef<const IntegerValueRangeLattice *> operands,
ArrayRef<IntegerValueRangeLattice *> results) override;

/// Visit block arguments or operation results of an operation with region
/// control-flow for which values are not defined by region control-flow. This
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ class LivenessAnalysis : public SparseBackwardDataFlowAnalysis<Liveness> {
public:
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;

void visitOperation(Operation *op, ArrayRef<Liveness *> operands,
ArrayRef<const Liveness *> results) override;
LogicalResult visitOperation(Operation *op, ArrayRef<Liveness *> operands,
ArrayRef<const Liveness *> results) override;

void visitBranchOperand(OpOperand &operand) override;

Expand Down
26 changes: 14 additions & 12 deletions mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {

/// The operation transfer function. Given the operand lattices, this
/// function is expected to set the result lattices.
virtual void
virtual LogicalResult
visitOperationImpl(Operation *op,
ArrayRef<const AbstractSparseLattice *> operandLattices,
ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
Expand Down Expand Up @@ -238,7 +238,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// Visit an operation. If this is a call operation or an operation with
/// region control-flow, then its result lattices are set accordingly.
/// Otherwise, the operation transfer function is invoked.
void visitOperation(Operation *op);
LogicalResult visitOperation(Operation *op);

/// Visit a block to compute the lattice values of its arguments. If this is
/// an entry block, then the argument values are determined from the block's
Expand Down Expand Up @@ -277,8 +277,9 @@ class SparseForwardDataFlowAnalysis

/// Visit an operation with the lattices of its operands. This function is
/// expected to set the lattices of the operation's results.
virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
ArrayRef<StateT *> results) = 0;
virtual LogicalResult visitOperation(Operation *op,
ArrayRef<const StateT *> operands,
ArrayRef<StateT *> results) = 0;

/// Visit a call operation to an externally defined function given the
/// lattices of its arguments.
Expand Down Expand Up @@ -328,10 +329,10 @@ class SparseForwardDataFlowAnalysis
private:
/// Type-erased wrappers that convert the abstract lattice operands to derived
/// lattices and invoke the virtual hooks operating on the derived lattices.
void visitOperationImpl(
LogicalResult visitOperationImpl(
Operation *op, ArrayRef<const AbstractSparseLattice *> operandLattices,
ArrayRef<AbstractSparseLattice *> resultLattices) override {
visitOperation(
return visitOperation(
op,
{reinterpret_cast<const StateT *const *>(operandLattices.begin()),
operandLattices.size()},
Expand Down Expand Up @@ -387,7 +388,7 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {

/// The operation transfer function. Given the result lattices, this
/// function is expected to set the operand lattices.
virtual void visitOperationImpl(
virtual LogicalResult visitOperationImpl(
Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;

Expand Down Expand Up @@ -424,7 +425,7 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// Visit an operation. If this is a call operation or an operation with
/// region control-flow, then its operand lattices are set accordingly.
/// Otherwise, the operation transfer function is invoked.
void visitOperation(Operation *op);
LogicalResult visitOperation(Operation *op);

/// Visit a block.
void visitBlock(Block *block);
Expand Down Expand Up @@ -474,8 +475,9 @@ class SparseBackwardDataFlowAnalysis

/// Visit an operation with the lattices of its results. This function is
/// expected to set the lattices of the operation's operands.
virtual void visitOperation(Operation *op, ArrayRef<StateT *> operands,
ArrayRef<const StateT *> results) = 0;
virtual LogicalResult visitOperation(Operation *op,
ArrayRef<StateT *> operands,
ArrayRef<const StateT *> results) = 0;

/// Visit a call to an external function. This function is expected to set
/// lattice values of the call operands. By default, calls `visitCallOperand`
Expand Down Expand Up @@ -510,10 +512,10 @@ class SparseBackwardDataFlowAnalysis
private:
/// Type-erased wrappers that convert the abstract lattice operands to derived
/// lattices and invoke the virtual hooks operating on the derived lattices.
void visitOperationImpl(
LogicalResult visitOperationImpl(
Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
ArrayRef<const AbstractSparseLattice *> resultLattices) override {
visitOperation(
return visitOperation(
op,
{reinterpret_cast<StateT *const *>(operandLattices.begin()),
operandLattices.size()},
Expand Down
11 changes: 6 additions & 5 deletions mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ void ConstantValue::print(raw_ostream &os) const {
// SparseConstantPropagation
//===----------------------------------------------------------------------===//

void SparseConstantPropagation::visitOperation(
LogicalResult SparseConstantPropagation::visitOperation(
Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
ArrayRef<Lattice<ConstantValue> *> results) {
LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
Expand All @@ -54,14 +54,14 @@ void SparseConstantPropagation::visitOperation(
// folding.
if (op->getNumRegions()) {
setAllToEntryStates(results);
return;
return success();
}

SmallVector<Attribute, 8> constantOperands;
constantOperands.reserve(op->getNumOperands());
for (auto *operandLattice : operands) {
if (operandLattice->getValue().isUninitialized())
return;
return success();
constantOperands.push_back(operandLattice->getValue().getConstantValue());
}

Expand All @@ -77,7 +77,7 @@ void SparseConstantPropagation::visitOperation(
foldResults.reserve(op->getNumResults());
if (failed(op->fold(constantOperands, foldResults))) {
setAllToEntryStates(results);
return;
return success();
}

// If the folding was in-place, mark the results as overdefined and reset
Expand All @@ -87,7 +87,7 @@ void SparseConstantPropagation::visitOperation(
op->setOperands(originalOperands);
op->setAttrs(originalAttrs);
setAllToEntryStates(results);
return;
return success();
}

// Merge the fold results into the lattice for this operation.
Expand All @@ -108,6 +108,7 @@ void SparseConstantPropagation::visitOperation(
lattice, *getLatticeElement(foldResult.get<Value>()));
}
}
return success();
}

void SparseConstantPropagation::setToEntryState(
Expand Down
Loading

0 comments on commit 15e915a

Please sign in to comment.