Add elementwise ops (PaddlePaddle#59)
* add elementwise_op_handler
* fix pow_handler
gglin001 authored Aug 13, 2021
1 parent 34e760f commit 82d4973
Showing 9 changed files with 247 additions and 42 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/ipu/CMakeLists.txt
@@ -2,6 +2,7 @@ set(POPART_CANONICALIZATION_SRC
"popart_canonicalization/canonicalization_utils.cc"
"popart_canonicalization/op_builder.cc"
"popart_canonicalization/activation_ops.cc"
"popart_canonicalization/elementwise_ops.cc"
"popart_canonicalization/logic_ops.cc"
"popart_canonicalization/math_ops.cc"
"popart_canonicalization/nn_ops.cc"
17 changes: 14 additions & 3 deletions paddle/fluid/framework/ipu/ipu_backend.cc
@@ -120,12 +120,18 @@ std::unique_ptr<popart::Optimizer> IpuBackend::GetPopartOptimizer() {
}

void IpuBackend::Prepare() {
VLOG(1) << "Save Model to file paddle_model.onnx ...\n";
builder_->saveModelProto("paddle_model.onnx");

VLOG(1) << "Get ModelProto ...\n";
auto proto = builder_->getModelProto();

+  // for onnx graph debug
+  // std::ofstream onnxfile("paddle_model_no_check.onnx",
+  // std::ios_base::binary);
+  // onnxfile.write(proto.data(), proto.size());
+  // onnxfile.close();

+  VLOG(1) << "Save Model to file paddle_model.onnx ...\n";
+  builder_->saveModelProto("paddle_model.onnx");

VLOG(1) << "Constructing DataFlow\n";
std::vector<popart::TensorId> anchor_ids;
for (popart::TensorId item : outputs_) {
@@ -343,6 +349,11 @@ void IpuBackend::LowerBody(const ir::Graph* graph) {
popart::TensorId result =
builder_->aiOnnxOpset11().reducemean(inputs, axes, keepdims);
tensors_.emplace(outputs[0], result);
} else if (op_type == "Reshape") {
auto inputs = GetOpInputs(op);
auto outputs = op->Output("__outputs__");
popart::TensorId result = builder_->aiOnnxOpset11().reshape(inputs);
tensors_.emplace(outputs[0], result);
} else if (op_type == "Softmax") {
auto inputs = GetOpInputs(op);
auto outputs = op->Output("__outputs__");
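For context on the Reshape lowering added above: from ONNX opset 5 onward, Reshape takes the target shape as a second INT64 input tensor rather than an attribute, so the canonicalized "Reshape" node arrives here with two inputs (data and shape) and aiOnnxOpset11().reshape(inputs) consumes both. A self-contained toy of just the shape semantics (not popart code; InferReshape is a hypothetical helper written for illustration):

#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Toy model of the ONNX Reshape contract: the target shape is data, not an
// attribute; a -1 entry means "infer this dimension from the element count".
std::vector<int64_t> InferReshape(const std::vector<int64_t> &data_shape,
                                  std::vector<int64_t> target) {
  int64_t total = std::accumulate(data_shape.begin(), data_shape.end(),
                                  int64_t{1}, std::multiplies<int64_t>());
  int64_t known = 1;
  int infer = -1;
  for (size_t i = 0; i < target.size(); ++i) {
    if (target[i] == -1) {
      infer = static_cast<int>(i);
    } else {
      known *= target[i];
    }
  }
  if (infer >= 0) target[infer] = total / known;
  // Reshape must conserve the element count.
  assert(std::accumulate(target.begin(), target.end(), int64_t{1},
                         std::multiplies<int64_t>()) == total);
  return target;
}

int main() {
  for (auto d : InferReshape({2, 3, 4}, {6, -1})) std::cout << d << ' ';
  std::cout << '\n';  // prints: 6 4
}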
paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.cc
@@ -63,7 +63,6 @@ void ReplaceNodeInputs(ir::Node *node, ir::Node *new_node) {
if (node->inputs.empty()) {
return;
}
-  new_node->inputs = node->inputs;
for (auto *node_in : node->inputs) {
for (size_t i = 0; i < node_in->outputs.size(); ++i) {
if (node_in->outputs[i] == node) {
@@ -152,6 +151,16 @@ const int ConvertDataType(const int &type) {
}
}

+Node *GetInputNode(const std::string &name, const Node *node) {
+  auto node_name = node->Op()->Input(name).front();
+  for (auto *n : node->inputs) {
+    if (n->Name() == node_name) {
+      return n;
+    }
+  }
+  return nullptr;
+}

} // namespace ipu
} // namespace framework
} // namespace paddle
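The new GetInputNode helper resolves a Paddle operand name (e.g. "X") to the ir::Node that feeds it, returning nullptr when the operand is not wired into the graph, so callers must check the result. A self-contained toy of the same lookup (ToyNode and FindInput are illustrative stand-ins, not Paddle types):

#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-in for ir::Node, just to show the lookup GetInputNode does.
struct ToyNode {
  std::string name;
  std::vector<ToyNode *> inputs;
};

// Same shape as GetInputNode above: scan the op node's graph inputs for the
// one whose name matches the operand's variable name; nullptr = not wired up.
ToyNode *FindInput(const std::string &var_name, const ToyNode &op) {
  for (auto *n : op.inputs) {
    if (n->name == var_name) return n;
  }
  return nullptr;
}

int main() {
  ToyNode x{"x0", {}}, y{"y0", {}};
  ToyNode add{"elementwise_add", {&x, &y}};
  std::cout << (FindInput("y0", add) ? "found" : "missing") << '\n';  // found
  std::cout << (FindInput("z0", add) ? "found" : "missing") << '\n';  // missing
}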
paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.h
@@ -23,6 +23,10 @@ namespace paddle {
namespace framework {
namespace ipu {

+// TODO(alleng) remove ir::
+using ir::Graph;
+using ir::Node;

#define REGISTER_HANDLER(name, func) \
static bool __UNUSED_##name = \
paddle::framework::ipu::RegisterHandler(#name, func)
@@ -47,6 +51,8 @@ void CopyOpAttr(const std::string &attr_name, OpDesc *op, OpDesc *new_op,

const int ConvertDataType(const int &type);

+Node *GetInputNode(const std::string &name, const Node *node);

} // namespace ipu
} // namespace framework
} // namespace paddle
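The REGISTER_HANDLER macro these declarations support uses the standard static-initializer registration idiom: each translation unit's handlers insert themselves into a global name-to-function map before main() runs, so the canonicalization pass can dispatch on op type by name. A minimal self-contained sketch of that pattern, assuming RegisterHandler stores into such a map (the void() handler signature here is a simplification; the real one takes a graph and a node):

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

using Handler = std::function<void()>;  // real signature: ir::Node *(ir::Graph *, ir::Node *)

// Function-local static avoids static-initialization-order problems.
std::unordered_map<std::string, Handler> &Registry() {
  static std::unordered_map<std::string, Handler> registry;
  return registry;
}

bool RegisterHandler(const std::string &name, Handler func) {
  Registry().emplace(name, std::move(func));
  return true;
}

// Same trick as the macro above: initializing the static bool forces
// registration to run during static initialization, before main().
#define REGISTER_HANDLER(name, func) \
  static bool __UNUSED_##name = RegisterHandler(#name, func)

void elementwise_add_demo() { std::cout << "handling elementwise_add\n"; }
REGISTER_HANDLER(elementwise_add, elementwise_add_demo);

int main() { Registry().at("elementwise_add")(); }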
109 changes: 109 additions & 0 deletions paddle/fluid/framework/ipu/popart_canonicalization/elementwise_ops.cc
@@ -0,0 +1,109 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.h"
#include "paddle/fluid/framework/ipu/popart_canonicalization/op_builder.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ipu {
namespace {

ir::Node *elementwise_op_handler(ir::Graph *graph, ir::Node *node,
const std::string &type) {
auto *op = node->Op();
auto x_shape = op->Block()->FindVar(op->Input("X").front())->GetShape();
int64_t x_rank = x_shape.size();
auto y_shape = op->Block()->FindVar(op->Input("Y").front())->GetShape();
int64_t y_rank = y_shape.size();

auto axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
if (axis == -1 || axis == x_rank - 1 || x_rank == y_rank) {
auto new_node = CreateBaseOp(
graph, type, {GetInputNode("X", node), GetInputNode("Y", node)},
node->outputs);
return new_node;
} else {
auto y_new_shape = std::vector<int64_t>(x_rank, 1);
for (int i = axis; i < axis + y_rank; ++i) {
y_new_shape[i] = y_shape[i - axis];
}
auto attrs = AttributeMap{
{"value", y_new_shape},
{"dims", std::vector<int64_t>{x_rank}},
{"dtype", ONNXDataType::INT64},
};
// constant
auto new_node_const = CreateConst(graph, {}, {}, attrs);
// reshape
auto new_node_reshape =
CreateBaseOp(graph, "Reshape",
{GetInputNode("Y", node), new_node_const->outputs[0]}, {});
ReplaceNodeInputs(node, new_node_reshape);
// elementwise_op
auto new_node = CreateBaseOp(
graph, type, {GetInputNode("X", node), new_node_reshape->outputs[0]},
{node->outputs[0]});
ReplaceNodeInputs(node, new_node);
ReplaceNodeOutputs(node, new_node);
return new_node;
}
}

ir::Node *elementwise_add_handler(ir::Graph *graph, ir::Node *node) {
return elementwise_op_handler(graph, node, "Add");
}

ir::Node *elementwise_sub_handler(ir::Graph *graph, ir::Node *node) {
return elementwise_op_handler(graph, node, "Sub");
}

ir::Node *elementwise_div_handler(ir::Graph *graph, ir::Node *node) {
return elementwise_op_handler(graph, node, "Div");
}

ir::Node *elementwise_mul_handler(ir::Graph *graph, ir::Node *node) {
return elementwise_op_handler(graph, node, "Mul");
}

ir::Node *elementwise_min_handler(ir::Graph *graph, ir::Node *node) {
return elementwise_op_handler(graph, node, "Min");
}

ir::Node *elementwise_max_handler(ir::Graph *graph, ir::Node *node) {
return elementwise_op_handler(graph, node, "Max");
}

ir::Node *elementwise_pow_handler(ir::Graph *graph, ir::Node *node) {
return elementwise_op_handler(graph, node, "Pow");
}

ir::Node *elementwise_mod_handler(ir::Graph *graph, ir::Node *node) {
return elementwise_op_handler(graph, node, "Mod");
}

REGISTER_HANDLER(elementwise_add, elementwise_add_handler);
REGISTER_HANDLER(elementwise_sub, elementwise_sub_handler);
REGISTER_HANDLER(elementwise_div, elementwise_div_handler);
REGISTER_HANDLER(elementwise_mul, elementwise_mul_handler);
REGISTER_HANDLER(elementwise_min, elementwise_min_handler);
REGISTER_HANDLER(elementwise_max, elementwise_max_handler);
REGISTER_HANDLER(elementwise_pow, elementwise_pow_handler);
REGISTER_HANDLER(elementwise_mod, elementwise_mod_handler);

} // namespace
} // namespace ipu
} // namespace framework
} // namespace paddle
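The axis handling in elementwise_op_handler is easiest to see with concrete shapes: when Y's rank is smaller than X's, Y's dimensions are aligned at `axis` and padded with 1s on either side so that standard ONNX broadcasting applies after the inserted Reshape. A standalone sketch of just that shape computation (AlignYShape is an illustrative helper, not part of the commit):

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the else-branch above: embed Y's y_rank dims into a rank-x_rank
// shape of 1s, starting at `axis`, so ONNX-style broadcasting lines up.
std::vector<int64_t> AlignYShape(const std::vector<int64_t> &x_shape,
                                 const std::vector<int64_t> &y_shape,
                                 int64_t axis) {
  int64_t x_rank = x_shape.size();
  int64_t y_rank = y_shape.size();
  if (axis == -1) axis = x_rank - y_rank;  // Paddle convention; the handler
                                           // above skips the reshape instead.
  assert(axis >= 0 && axis + y_rank <= x_rank);
  std::vector<int64_t> y_new_shape(x_rank, 1);
  for (int64_t i = axis; i < axis + y_rank; ++i) {
    y_new_shape[i] = y_shape[i - axis];
  }
  return y_new_shape;
}

int main() {
  // x: [2, 3, 4, 5], y: [3, 4], axis = 1  ->  y reshaped to [1, 3, 4, 1]
  for (auto d : AlignYShape({2, 3, 4, 5}, {3, 4}, 1)) std::cout << d << ' ';
  std::cout << '\n';
}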
28 changes: 4 additions & 24 deletions paddle/fluid/framework/ipu/popart_canonicalization/math_ops.cc
@@ -21,23 +21,6 @@ namespace framework {
namespace ipu {
namespace {

-ir::Node *elementwise_add_handler(ir::Graph *graph, ir::Node *node) {
-  auto *op = node->Op();
-  auto op_desc = std::make_unique<framework::OpDesc>();
-  op_desc->SetType("Add");
-
-  std::vector<std::string> inputs;
-  inputs.push_back(op->Input("X").front());
-  inputs.push_back(op->Input("Y").front());
-  op_desc->SetInput("__inputs__", inputs);
-  std::vector<std::string> outputs;
-  outputs.push_back(op->Output("Out").front());
-  op_desc->SetOutput("__outputs__", outputs);
-
-  op_desc->Flush();
-  return graph->CreateOpNode(op_desc.get());
-}

ir::Node *reduce_mean_handler(ir::Graph *graph, ir::Node *node) {
auto *op = node->Op();
auto op_desc = std::make_unique<framework::OpDesc>();
@@ -68,12 +51,10 @@ ir::Node *pow_handler(ir::Graph *graph, ir::Node *node) {
auto *op = node->Op();
auto value_ = BOOST_GET_CONST(float, op->GetAttr("factor"));
auto attrs = MakeConstAttributeMap(value_, {1}, ONNXDataType::FLOAT);
-  auto new_node_const = CreateConst(graph, {node}, {}, attrs);
-  ReplaceNodeOutputs(node, new_node_const);
-
-  auto new_node_pow =
-      CreatePow(graph, {node->inputs[0], new_node_const->outputs[0]},
-                {node->outputs[0]}, {});
+  auto new_node_const = CreateConst(graph, {}, {}, attrs);
+  auto new_node_pow = CreateBaseOp(
+      graph, "Pow", {GetInputNode("X", node), new_node_const->outputs[0]},
+      {node->outputs[0]});
ReplaceNodeInputs(node, new_node_pow);
return new_node_pow;
}
@@ -129,7 +110,6 @@ ir::Node *softmax_handler(ir::Graph *graph, ir::Node *node) {
return graph->CreateOpNode(op_desc.get());
}

-REGISTER_HANDLER(elementwise_add, elementwise_add_handler);
REGISTER_HANDLER(reduce_mean, reduce_mean_handler);
REGISTER_HANDLER(pow, pow_handler);
REGISTER_HANDLER(mul, mul_handler);
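Both the reworked pow_handler and elementwise_op_handler materialize scalar attributes as Constant nodes whose attribute map carries "value", "dims", and "dtype" keys. A toy mock-up of the packing a helper like MakeConstAttributeMap performs (ToyAttr and MakeConstAttrs are stand-ins, and the broadcast-to-prod(dims) detail is an assumption since the full body is collapsed below):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <variant>
#include <vector>

// Stand-in for Paddle's AttributeMap, which is a variant map of this flavor.
using ToyAttr = std::variant<std::vector<float>, std::vector<int64_t>, int>;

// Pack a scalar v, broadcast to prod(dims) elements, plus dims and the ONNX
// dtype tag, mirroring the {"value", "dims", "dtype"} keys used above.
std::map<std::string, ToyAttr> MakeConstAttrs(float v,
                                              std::vector<int64_t> dims,
                                              int dtype) {
  size_t size = 1;
  for (auto d : dims) size *= static_cast<size_t>(d);
  return {{"value", std::vector<float>(size, v)},
          {"dims", dims},
          {"dtype", dtype}};
}

int main() {
  auto attrs = MakeConstAttrs(2.0f, {1}, /*ONNX TensorProto FLOAT=*/1);
  std::cout << std::get<std::vector<float>>(attrs["value"]).size() << '\n';  // 1
}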
22 changes: 12 additions & 10 deletions paddle/fluid/framework/ipu/popart_canonicalization/op_builder.cc
@@ -70,6 +70,17 @@ ir::Node *MakeOpNode(ir::Graph *graph, const std::string &type,
return op;
}

+Node *CreateBaseOp(ir::Graph *graph, const std::string &type,
+                   const std::vector<ir::Node *> &inputs,
+                   const std::vector<ir::Node *> &outputs,
+                   const AttributeMap &attrs) {
+  auto node = MakeOpNode(graph, type, inputs, outputs);
+  if (!attrs.empty()) {
+    node->Op()->SetAttrMap(attrs);
+  }
+  return node;
+}

AttributeMap MakeConstAttributeMap(float v, std::vector<int64_t> dims,
int dtype) {
size_t size = 1;
@@ -85,16 +96,7 @@ AttributeMap MakeConstAttributeMap(float v, std::vector<int64_t> dims,
ir::Node *CreateConst(ir::Graph *graph, const std::vector<ir::Node *> &inputs,
const std::vector<ir::Node *> &outputs,
const AttributeMap &attrs) {
auto node = MakeOpNode(graph, "Constant", {}, {});
node->Op()->SetAttrMap(attrs);
return node;
}

ir::Node *CreatePow(ir::Graph *graph, const std::vector<ir::Node *> &inputs,
const std::vector<ir::Node *> &outputs,
const AttributeMap &attrs) {
auto node = MakeOpNode(graph, "Pow", inputs, outputs);
return node;
return CreateBaseOp(graph, "Constant", inputs, outputs, attrs);
}

} // namespace ipu
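The net effect in op_builder.cc is that every helper now funnels through CreateBaseOp, which only touches the op's attribute map when one is supplied, so per-op wrappers like CreateConst (and the deleted CreatePow) collapse into one-liners. A toy mock-up of the delegation (ToyNode and the string-keyed attrs are illustrative stand-ins for ir::Node and AttributeMap):

#include <iostream>
#include <map>
#include <memory>
#include <string>

// Illustrative stand-ins for ir::Node / AttributeMap.
struct ToyNode {
  std::string type;
  std::map<std::string, std::string> attrs;
};

std::unique_ptr<ToyNode> MakeOpNode(const std::string &type) {
  return std::make_unique<ToyNode>(ToyNode{type, {}});
}

// Mirrors CreateBaseOp above: build the node, set attrs only when given.
std::unique_ptr<ToyNode> CreateBaseOp(
    const std::string &type,
    const std::map<std::string, std::string> &attrs = {}) {
  auto node = MakeOpNode(type);
  if (!attrs.empty()) node->attrs = attrs;
  return node;
}

// After this commit, CreateConst is just a named entry point into the above.
std::unique_ptr<ToyNode> CreateConst(
    const std::map<std::string, std::string> &attrs) {
  return CreateBaseOp("Constant", attrs);
}

int main() {
  auto c = CreateConst({{"value", "2.0"}});
  std::cout << c->type << " value=" << c->attrs.at("value") << '\n';
}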
paddle/fluid/framework/ipu/popart_canonicalization/op_builder.h
@@ -27,17 +27,18 @@ ir::Node *MakeOpNode(ir::Graph *graph, const std::string &type,
const std::vector<ir::Node *> &inputs,
const std::vector<ir::Node *> &outputs);

+Node *CreateBaseOp(ir::Graph *graph, const std::string &type,
+                   const std::vector<ir::Node *> &inputs,
+                   const std::vector<ir::Node *> &outputs,
+                   const AttributeMap &attrs = {});

// TODO(alleng) make template
AttributeMap MakeConstAttributeMap(float v, std::vector<int64_t> dims,
int dtype);
ir::Node *CreateConst(ir::Graph *graph, const std::vector<ir::Node *> &inputs,
const std::vector<ir::Node *> &outputs,
const AttributeMap &attrs);

-ir::Node *CreatePow(ir::Graph *graph, const std::vector<ir::Node *> &inputs,
-                    const std::vector<ir::Node *> &outputs,
-                    const AttributeMap &attrs);

} // namespace ipu
} // namespace framework
} // namespace paddle
86 changes: 86 additions & 0 deletions python/paddle/fluid/tests/unittests/ipu/test_elemetwise_add_op.py
@@ -0,0 +1,86 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import paddle
import paddle.fluid
import paddle.static
import paddle.fluid.compiler as compiler

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestAdd(unittest.TestCase):
def _test_add(self, run_ipu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)

np_a = np.random.rand(3, 3, 3).astype(np.float32)
np_b = np.arange(1, 4).reshape([3]).astype(np.float32)
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(
name="a",
shape=[3, 3, 3],
dtype='float32', )
b = paddle.static.data(
name="b",
shape=[3],
dtype='float32', )
# out = paddle.fluid.layers.elementwise_add(a, b, axis=-1)
# out = paddle.fluid.layers.elementwise_add(a, b, axis=0)
# out = paddle.fluid.layers.elementwise_add(a, b, axis=1)
out = paddle.fluid.layers.elementwise_add(a, b, axis=2)

if run_ipu:
place = paddle.IPUPlace()
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)

if run_ipu:
feed_list = [a.name, b.name]
fetch_list = [out.name]
ipu_strategy = compiler.get_ipu_strategy()
ipu_strategy.is_training = False
program = compiler.IpuCompiler(
main_prog, ipu_strategy=ipu_strategy).compile(feed_list,
fetch_list)
else:
program = main_prog

result = exe.run(
program,
feed={'a': np_a,
'b': np_b},
fetch_list=[out], )
return result[0]

def test_add(self):
ipu_res = self._test_add(True)
cpu_res = self._test_add(False)
self.assertTrue(np.allclose(ipu_res, cpu_res, atol=1e-4))


if __name__ == "__main__":
unittest.main()
