fix_repeat_interleave_op (#58379)
* fix_repeat_interleave_op

* fix repeat_interleave

* fix

* fix

* fix codestyle

* fix codestyle

* fix codestyle

* fix codestyle

* fix codestyle

* fix codestyle

* fix codestyle

* fix codestyle

* fix style

* fix style
xingmingyyj authored Oct 31, 2023
1 parent 721f834 commit ffd3392
Showing 5 changed files with 101 additions and 5 deletions.
82 changes: 79 additions & 3 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -2492,6 +2492,79 @@ struct ShareBufferOpTranscriber : public OpTranscriber {
  }
};

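// Transcribes the legacy repeat_interleave op into PIR: when the op carries a
// non-empty RepeatsTensor input it targets
// pd_op.repeat_interleave_with_tensor_index; otherwise it targets
// pd_op.repeat_interleave, which takes an integer `Repeats` attribute.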
struct RepeatInterLeaveOpTranscriber : public OpTranscriber {
  pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx,
                            const OpDesc& op_desc) override {
    std::string target_op_name;
    if (op_desc.HasInput("RepeatsTensor") &&
        !op_desc.Input("RepeatsTensor").empty()) {
      target_op_name = "pd_op.repeat_interleave_with_tensor_index";
    } else {
      target_op_name = "pd_op.repeat_interleave";
    }
    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    return op_info;
  }

  std::vector<pir::Value> GenerateOperationInput(
      pir::IrContext* ctx,
      TranslationContext* param_map,
      const OpDesc& op_desc,
      const std::string& normalized_op_name,
      const OpInputInfoList& input_infos,
      pir::Block* block) override {
    std::vector<pir::Value> op_inputs;
    auto x_names = op_desc.Input("X", true);
    auto input = param_map->at(x_names[0]).value;
    op_inputs.push_back(input);
    if (op_desc.HasInput("RepeatsTensor") &&
        !op_desc.Input("RepeatsTensor").empty()) {
      auto repeats_names = op_desc.Input("RepeatsTensor", true);
      input = param_map->at(repeats_names[0]).value;
      op_inputs.push_back(input);
    }
    return op_inputs;
  }
};

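// Same dispatch for the backward op; it additionally forwards the Out@GRAD
// input as the last operand.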
struct RepeatInterLeaveGradOpTranscriber : public OpTranscriber {
  pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx,
                            const OpDesc& op_desc) override {
    std::string target_op_name;
    if (op_desc.HasInput("RepeatsTensor") &&
        !op_desc.Input("RepeatsTensor").empty()) {
      target_op_name = "pd_op.repeat_interleave_with_tensor_index_grad";
    } else {
      target_op_name = "pd_op.repeat_interleave_grad";
    }
    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    return op_info;
  }

  std::vector<pir::Value> GenerateOperationInput(
      pir::IrContext* ctx,
      TranslationContext* param_map,
      const OpDesc& op_desc,
      const std::string& normalized_op_name,
      const OpInputInfoList& input_infos,
      pir::Block* block) override {
    std::vector<pir::Value> op_inputs;
    auto x_names = op_desc.Input("X", true);
    auto input = param_map->at(x_names[0]).value;
    op_inputs.push_back(input);
    if (op_desc.HasInput("RepeatsTensor") &&
        !op_desc.Input("RepeatsTensor").empty()) {
      auto repeats_names = op_desc.Input("RepeatsTensor", true);
      input = param_map->at(repeats_names[0]).value;
      op_inputs.push_back(input);
    }
    auto out_grad_names = op_desc.Input("Out@GRAD", true);
    input = param_map->at(out_grad_names[0]).value;
    op_inputs.push_back(input);

    return op_inputs;
  }
};

OpTranslator::OpTranslator() {
  pir::IrContext* ctx = pir::IrContext::Instance();
  ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
@@ -2501,8 +2574,8 @@ OpTranslator::OpTranslator() {
  special_handlers["assign_value"] = AssignValueOpTranscriber();
  special_handlers["range"] = ArangeOpTranscriber();
  special_handlers["cast"] = CastOpTranscriber();
-  special_handlers["feed"] = FeedOpTranscriber();
  special_handlers["data"] = DataOpTranscriber();
+  special_handlers["feed"] = FeedOpTranscriber();
  special_handlers["fetch"] = FetchOpTranscriber();
  special_handlers["fetch_v2"] = FetchOpTranscriber();
  special_handlers["fill_constant"] = FillConstantTranscriber();
@@ -2514,11 +2587,14 @@ OpTranslator::OpTranslator() {
  special_handlers["one_hot_v2"] = OneHotTranscriber();
  special_handlers["reduce_all"] = ReduceOpTranscriber();
  special_handlers["reduce_any"] = ReduceOpTranscriber();
  special_handlers["repeat_interleave"] = RepeatInterLeaveOpTranscriber();
  special_handlers["repeat_interleave_grad"] =
      RepeatInterLeaveGradOpTranscriber();
  special_handlers["rnn"] = RnnOpTranscriber();
-  special_handlers["shadow_output"] = ShadowOutputOpTranscriber();
-  special_handlers["share_buffer"] = ShareBufferOpTranscriber();
  special_handlers["set_value"] = LegacySetValueDispatcher();
  special_handlers["set_value_grad"] = SetValueGradOpTranscriber();
+  special_handlers["shadow_output"] = ShadowOutputOpTranscriber();
+  special_handlers["share_buffer"] = ShareBufferOpTranscriber();
  special_handlers["split"] = SplitOpTranscriber();
  special_handlers["sum"] = AddNOpTranscriber();
  special_handlers["tril_triu"] = TrilAndTriuOpTranscriber();
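
Together, the two transcribers pick the PIR target from how `repeats` was supplied in the legacy program. A minimal Python-level sketch of the two call forms (illustrative only; the lowering targets are the ones named in LoopkUpOpInfo above):

import paddle

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])

# Integer repeats: this form lowers to pd_op.repeat_interleave.
y = paddle.repeat_interleave(x, repeats=2, axis=0)  # shape [4, 2]

# Tensor repeats: this form lowers to
# pd_op.repeat_interleave_with_tensor_index.
r = paddle.to_tensor([1, 2], dtype='int32')
z = paddle.repeat_interleave(x, repeats=r, axis=0)  # shape [3, 2]
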
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -905,6 +905,7 @@
    func : RepeatInterleaveInferMeta
  kernel :
    func : repeat_interleave
    data_type : x
  backward: repeat_interleave_grad

- op : repeat_interleave_with_tensor_index
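
The added `data_type : x` line keys kernel selection to the dtype of `x`. A small sketch of the observable behavior (an illustration, not part of this diff; assumes eager mode mirrors the static-graph dispatch):

import paddle

x = paddle.to_tensor([1.0, 2.0], dtype='float64')
out = paddle.repeat_interleave(x, repeats=2, axis=0)

# The kernel is chosen from x's dtype, so the output follows the input.
assert out.dtype == paddle.float64
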
18 changes: 18 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -2473,6 +2473,24 @@
  attrs :
    repeats : Repeats

- op : repeat_interleave
  backward : repeat_interleave_grad
  inputs :
    x : X
  outputs :
    out : Out
  attrs :
    {repeats : Repeats, axis : dim}

- op : repeat_interleave_with_tensor_index
  backward : repeat_interleave_with_tensor_index_grad
  inputs :
    {x : X, repeats : RepeatsTensor}
  outputs :
    out : Out
  attrs :
    axis : dim

- op : reshape (reshape2)
  backward : reshape_grad (reshape2_grad)
  inputs:
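
The two new op_compat entries above act as rename tables: the left-hand names are the PIR signature names and the right-hand names the legacy ProgramDesc ones. Restated as a hypothetical Python mapping for the plain variant (illustrative only, not a Paddle API):

# Legacy name -> PIR name, per the repeat_interleave entry above.
REPEAT_INTERLEAVE_COMPAT = {
    "inputs": {"X": "x"},
    "outputs": {"Out": "out"},
    "attrs": {"Repeats": "repeats", "dim": "axis"},
}
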
4 changes: 2 additions & 2 deletions test/legacy_test/test_repeat_interleave_op.py
@@ -104,14 +104,14 @@ def test_check_grad_normal(self):

class TestIndexSelectAPI(unittest.TestCase):
    def input_data(self):
-        self.data_zero_dim_x = np.array(0.5)
+        self.data_zero_dim_x = np.array(0.5).astype('float32')
        self.data_x = np.array(
            [
                [1.0, 2.0, 3.0, 4.0],
                [5.0, 6.0, 7.0, 8.0],
                [9.0, 10.0, 11.0, 12.0],
            ]
-        )
+        ).astype('float32')
        self.data_zero_dim_index = np.array(2)
        self.data_index = np.array([0, 1, 2, 1]).astype('int32')
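
The `.astype('float32')` casts matter because bare `np.array(...)` defaults to float64, while the test's static-graph inputs are float32; the mismatched dtype would trip the feed-time checks. A quick check of the defaults:

import numpy as np

# NumPy promotes plain float literals to float64 by default.
assert np.array(0.5).dtype == np.float64
assert np.array(0.5).astype('float32').dtype == np.float32
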
1 change: 1 addition & 0 deletions test/white_list/new_ir_op_test_white_list
@@ -168,6 +168,7 @@ test_put_along_axis_op
test_range
test_reduce_op
test_reduce_op_static_build
test_repeat_interleave_op
test_reshape_op
test_reverse_op
test_roi_align_op
