Skip to content

Commit

Permalink
[PIR] add control flow op backward components. (PaddlePaddle#58729)
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang authored Nov 8, 2023
1 parent 86f501b commit 4cb2a0c
Show file tree
Hide file tree
Showing 18 changed files with 493 additions and 30 deletions.
10 changes: 5 additions & 5 deletions paddle/fluid/pir/dialect/kernel/ir/kernel_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,39 +29,39 @@ const phi::DDim& AllocatedDenseTensorType::dims() const {
return storage()->dense_tensor_type_.dims();
}

const phi::DataLayout& AllocatedDenseTensorType::data_layout() const {
phi::DataLayout AllocatedDenseTensorType::data_layout() const {
return storage()->dense_tensor_type_.data_layout();
}

const phi::LoD& AllocatedDenseTensorType::lod() const {
return storage()->dense_tensor_type_.lod();
}

const size_t& AllocatedDenseTensorType::offset() const {
size_t AllocatedDenseTensorType::offset() const {
return storage()->dense_tensor_type_.offset();
}

const phi::Place& AllocatedSelectedRowsType::place() const {
return storage()->place_;
}

const pir::Type& AllocatedSelectedRowsType::dtype() const {
pir::Type AllocatedSelectedRowsType::dtype() const {
return storage()->selected_rows_type_.dtype();
}

const phi::DDim& AllocatedSelectedRowsType::dims() const {
return storage()->selected_rows_type_.dims();
}

const phi::DataLayout& AllocatedSelectedRowsType::data_layout() const {
phi::DataLayout AllocatedSelectedRowsType::data_layout() const {
return storage()->selected_rows_type_.data_layout();
}

const phi::LoD& AllocatedSelectedRowsType::lod() const {
return storage()->selected_rows_type_.lod();
}

const size_t& AllocatedSelectedRowsType::offset() const {
size_t AllocatedSelectedRowsType::offset() const {
return storage()->selected_rows_type_.offset();
}

Expand Down
10 changes: 5 additions & 5 deletions paddle/fluid/pir/dialect/kernel/ir/kernel_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ class AllocatedDenseTensorType

const phi::DDim &dims() const;

const phi::DataLayout &data_layout() const;
phi::DataLayout data_layout() const;

const phi::LoD &lod() const;

const size_t &offset() const;
size_t offset() const;
};

class AllocatedSelectedRowsType
Expand Down Expand Up @@ -92,15 +92,15 @@ class AllocatedSelectedRowsType

const phi::Place &place() const;

const pir::Type &dtype() const;
pir::Type dtype() const;

const phi::DDim &dims() const;

const phi::DataLayout &data_layout() const;
phi::DataLayout data_layout() const;

const phi::LoD &lod() const;

const size_t &offset() const;
size_t offset() const;
};

} // namespace dialect
Expand Down
8 changes: 4 additions & 4 deletions paddle/pir/core/builtin_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@ std::vector<Type> VectorType::data() const { return storage()->GetAsKey(); }

pir::Type DenseTensorType::dtype() const { return storage()->dtype_; }

const DenseTensorTypeStorage::Dim& DenseTensorType::dims() const {
const DenseTensorType::Dim& DenseTensorType::dims() const {
return storage()->dims_;
}

const DenseTensorTypeStorage::DataLayout& DenseTensorType::data_layout() const {
DenseTensorType::DataLayout DenseTensorType::data_layout() const {
return storage()->layout_;
}

const DenseTensorTypeStorage::LoD& DenseTensorType::lod() const {
const DenseTensorType::LoD& DenseTensorType::lod() const {
return storage()->lod_;
}

const size_t& DenseTensorType::offset() const { return storage()->offset_; }
size_t DenseTensorType::offset() const { return storage()->offset_; }
} // namespace pir

IR_DEFINE_EXPLICIT_TYPE_ID(pir::UInt8Type)
Expand Down
23 changes: 15 additions & 8 deletions paddle/pir/core/builtin_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,23 @@ class DenseTensorType : public Type::TypeBase<DenseTensorType,
ShapedTypeInterface> {
public:
using Base::Base;
using Dim = DenseTensorTypeStorage::Dim;
using DataLayout = DenseTensorTypeStorage::DataLayout;
using LoD = DenseTensorTypeStorage::LoD;

Type dtype() const;

const DenseTensorTypeStorage::Dim &dims() const;

const DenseTensorTypeStorage::DataLayout &data_layout() const;

const DenseTensorTypeStorage::LoD &lod() const;

const size_t &offset() const;
const Dim &dims() const;
DataLayout data_layout() const;
const LoD &lod() const;
size_t offset() const;
static DenseTensorType get(IrContext *ctx,
Type dtype,
const Dim &dims,
DataLayout layout = DataLayout::kNCHW,
const LoD &lod = {},
size_t offset = 0u) {
return Base::get(ctx, dtype, dims, layout, lod, offset);
}
};

#define DECLARE_BUILTIN_TYPE(__name) \
Expand Down
6 changes: 3 additions & 3 deletions paddle/pir/core/builtin_type_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ struct DenseTensorTypeStorage : public pir::TypeStorage {
using DataLayout = phi::DataLayout;
using Dim = phi::DDim;
using LoD = std::vector<std::vector<size_t>>;
using ParamKey = std::tuple<pir::Type, Dim, DataLayout, LoD, size_t>;
using ParamKey = std::tuple<Type, Dim, DataLayout, LoD, size_t>;

DenseTensorTypeStorage(const pir::Type& dtype,
DenseTensorTypeStorage(Type dtype,
const Dim& dims,
const DataLayout& layout,
DataLayout layout,
const LoD& lod,
size_t offset)
: dtype_(dtype),
Expand Down
7 changes: 6 additions & 1 deletion paddle/pir/core/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ void IrPrinter::PrintValue(Value v) {
os << "<<NULL VALUE>>";
return;
}
const void* key = static_cast<const void*>(v.impl());
const void* key = v.impl();
auto ret = aliases_.find(key);
if (ret != aliases_.end()) {
os << ret->second;
Expand Down Expand Up @@ -310,6 +310,11 @@ void IrPrinter::PrintOpReturnType(Operation* op) {
[this]() { this->os << ", "; });
}

void IrPrinter::AddValueAlias(Value v, const std::string& alias) {
const void* key = v.impl();
IR_ENFORCE(aliases_.find(key) == aliases_.end(), "Value already has alias");
aliases_[key] = alias;
}
void Dialect::PrintOperation(Operation* op, IrPrinter& printer) const {
printer.PrintGeneralOperation(op);
}
Expand Down
2 changes: 2 additions & 0 deletions paddle/pir/core/ir_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class IR_API IrPrinter : public BasicIrPrinter {

void PrintOpReturnType(Operation* op);

void AddValueAlias(Value value, const std::string& alias);

private:
size_t cur_result_number_{0};
size_t cur_block_argument_number_{0};
Expand Down
3 changes: 1 addition & 2 deletions paddle/pir/core/op_result.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ bool OpResult::classof(Value value) {
}

Operation *OpResult::owner() const {
CHECK_OPRESULT_NULL_IMPL(owner);
return IMPL_->owner();
return impl_ ? static_cast<detail::OpResultImpl *>(impl_)->owner() : nullptr;
}

uint32_t OpResult::index() const {
Expand Down
3 changes: 3 additions & 0 deletions paddle/pir/core/operation_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ struct OperationArgument {
template <class InputIt>
void AddOutputs(InputIt first, InputIt last);

void AddOutputs(std::initializer_list<Type> type_list) {
AddOutputs(std::begin(type_list), std::end(type_list));
}
template <class TypeContainer>
void AddOutputs(const TypeContainer& type_container) {
AddOutputs(std::begin(type_container), std::end(type_container));
Expand Down
2 changes: 2 additions & 0 deletions paddle/pir/core/value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ Value::operator bool() const { return impl_; }

pir::Type Value::type() const { return impl_ ? impl_->type() : nullptr; }

Operation *Value::defining_op() const { return dyn_cast<OpResult>().owner(); }

void Value::set_type(pir::Type type) {
CHECK_VALUE_NULL_IMPL(set_type);
impl_->set_type(type);
Expand Down
10 changes: 10 additions & 0 deletions paddle/pir/core/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ class IR_API Value {

Type type() const;

/// If this value is the result of an operation, return the operation that
/// defines it, else return nullptr;
Operation *defining_op() const;

template <typename OpTy>
OpTy defining_op() const {
/// It is safety even if defining_op() return nullptr.
return OpTy::dyn_cast(defining_op());
}

void set_type(Type type);

std::string PrintUdChain();
Expand Down
30 changes: 29 additions & 1 deletion paddle/pir/dialect/control_flow/ir/cf_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,37 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pir/dialect/control_flow/ir/cf_dialect.h"
#include "paddle/pir/core/ir_printer.h"
#include "paddle/pir/dialect/control_flow/ir/cf_op.h"
#include "paddle/pir/dialect/control_flow/ir/cf_type.h"

namespace pir {
void ControlFlowDialect::initialize() { RegisterOps<YieldOp>(); }
void ControlFlowDialect::initialize() {
RegisterTypes<StackType, InletType, OutletType>();
RegisterOps<YieldOp, CreateStackOp, PushBackOp, PopBackOp, HasElementsOp>();
}

void ControlFlowDialect::PrintType(pir::Type type, std::ostream &os) const {
os << name();
os << '.';
if (type.isa<StackType>()) {
os << "stack";
} else if (type.isa<InletType>()) {
os << "inlet";
} else if (type.isa<OutletType>()) {
os << "outlet";
} else {
os << "unknown type";
}
}

void ControlFlowDialect::PrintOperation(pir::Operation *op,
pir::IrPrinter &printer) const {
if (auto create_op = op->dyn_cast<CreateStackOp>()) {
create_op.Print(printer);
} else {
printer.PrintGeneralOperation(op);
}
}
} // namespace pir
IR_DEFINE_EXPLICIT_TYPE_ID(pir::ControlFlowDialect)
4 changes: 3 additions & 1 deletion paddle/pir/dialect/control_flow/ir/cf_dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ class ControlFlowDialect : public Dialect {
initialize();
}
static const char *name() { return "cf"; }

void PrintType(pir::Type type, std::ostream &os) const override;
void PrintOperation(pir::Operation *op,
pir::IrPrinter &printer) const override; // NOLINT
private:
void initialize();
};
Expand Down
Loading

0 comments on commit 4cb2a0c

Please sign in to comment.