support optional and invoke (#57084)
Charles-hit authored Sep 8, 2023
1 parent f9a86e3 commit 430a657
Showing 10 changed files with 142 additions and 74 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
@@ -62,7 +62,7 @@
res[i].resize(tensor_res[i].size());
for (size_t j = 0; j < tensor_res[i].size(); ++j) {
if(tensor_res[i][j].defined()){
res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(tensor_res[i][j].impl())->getValue().dyn_cast<ir::OpResult>();
res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(tensor_res[i][j].impl())->value().dyn_cast<ir::OpResult>();
}
}
}"""
(next changed file; filename not captured)
@@ -45,7 +45,7 @@ std::vector<std::vector<ir::OpResult>> SumOp::Vjp(
if (tensor_res[0][0].defined()) {
res[0][0] =
std::static_pointer_cast<primitive::LazyTensor>(tensor_res[0][0].impl())
->getValue()
->value()
.dyn_cast<ir::OpResult>();
}
return res;
50 changes: 33 additions & 17 deletions paddle/fluid/primitive/codegen/gen.py
@@ -16,7 +16,6 @@
import hashlib
import pathlib
import sys
from typing import Dict, List

import jinja2
import yaml
@@ -28,6 +27,7 @@
)
import filters as op_gen_filters
import tests_utils as op_gen_tests
from parse_utils import to_named_dict

# import from paddle/fluid/ir/dialect/op_generator/api_gen.py
sys.path.append(
@@ -62,6 +62,9 @@
'slice_grad',
'transpose_grad',
'dropout_grad',
'cast_grad',
'slice_double_grad',
'layer_norm_grad',
]
VJP_COMPS = ['divide_grad', 'sum_grad', 'gelu_grad']
BACKENDS = [
@@ -130,6 +133,8 @@
'scatter',
'scatter_nd_add',
'dropout_grad',
'slice',
'layer_norm_grad',
]


@@ -219,21 +224,6 @@ def save(content: str, path: pathlib.Path):
print(f"Generate source file {path}")


def to_compat_dict(items: List[Dict]) -> Dict[str, Dict]:
compat_dict = {}
for item in items:
name = item["op"]
compat_dict[name] = item
return compat_dict


def to_apis_dict(apis):
apis_dict = {}
for api in apis:
apis_dict[api['name']] = api
return apis_dict


def get_inplace_api(apis):
inplace_apis = []
for api in apis:
@@ -271,7 +261,7 @@ def extend_compat_info(apis, compats):
attr['typename']
) or op_gen_tests.is_intarray(attr['typename']):
attr["support_tensor"] = False
apis_dict = to_apis_dict(apis)
apis_dict = to_named_dict(apis)
for compat_item in compats:
fwd_op_name = compat_item["op"]
if fwd_op_name not in apis_dict:
@@ -322,6 +312,31 @@ def extend_compat_info(apis, compats):
return apis


def process_backward_invoke_info(apis):
apis_dict = to_named_dict(apis)
for api in apis:
if api['is_fwd']:
continue
if 'invoke' in api and api['invoke']['func'] in apis_dict:
args = api['invoke']['args'].split(',')
args = [arg.strip() for arg in args]
attrs_dict = to_named_dict(api['attrs'])
inputs_dict = to_named_dict(api['inputs'])
arg_inputs = []
arg_attrs = []
for arg in args:
if arg in inputs_dict:
arg_inputs.append(arg)
elif arg in attrs_dict and attrs_dict[arg].get(
"support_tensor", False
):
arg_inputs.append(arg + '_')
else:
arg_attrs.append(arg)
args = arg_inputs + arg_attrs
api['invoke']['args'] = ', '.join(args)


def gen(
prim_path: pathlib.Path,
fwd_path: pathlib.Path,
@@ -369,6 +384,7 @@ def gen(
]
apis = extend_compat_info(apis, compats)
apis = apis + get_inplace_api(apis)
process_backward_invoke_info(apis)
render(
templates_dir,
destination_dir,
(next changed file; filename not captured)
@@ -7,6 +7,7 @@
#include <vector>

#include "paddle/phi/api/include/tensor.h"
#include "paddle/utils/optional.h"


namespace paddle {
(next changed file; filename not captured)
@@ -17,22 +17,69 @@ template <>
{{common.ret(outputs)}} {{name}}<LazyTensor>({{common.params(inputs, attrs, mutable_attribute_as_inputs, False)}})
{%- endmacro -%}

{% macro body(name, inputs, outputs, attrs, mutable_attribute_as_inputs=False) %}
{%- set output_names = [] -%}
{%- for o in outputs -%} {%- do output_names.append(o.name) -%} {%-endfor-%}
{%- macro prepare_ir_api_inputs(inputs)-%}
{%- for input in inputs -%}
{% if input.typename=='Tensor[]' %}
std::vector<ir::OpResult> {{input.name}}_res({{input.name}}.size());
std::transform({{input.name}}.begin(), {{input.name}}.end(), {{input.name}}_res.begin(), [](const Tensor& t) {
return std::static_pointer_cast<LazyTensor>(t.impl())->getValue().dyn_cast<ir::OpResult>();
{% if input.typename=='Tensor[]' and not input.optional %}
std::vector<ir::OpResult> {{input.name}}_res({{input.name}}.size());
std::transform({{input.name}}.begin(), {{input.name}}.end(), {{input.name}}_res.begin(), [](const Tensor& t) {
return std::static_pointer_cast<LazyTensor>(t.impl())->value().dyn_cast<ir::OpResult>();
});
{% elif input.typename=='Tensor[]' and input.optional %}
std::vector<ir::OpResult> {{input.name}}_res({{input.name}}.size());
if({{input.name}}) {
std::transform({{input.name}}.get().begin(), {{input.name}}.get().end(), {{input.name}}_res.begin(), [](const Tensor& t) {
return std::static_pointer_cast<LazyTensor>(t.impl())->value().dyn_cast<ir::OpResult>();
});
}
{% elif input.typename=='Tensor' and not input.optional %}
ir::OpResult {{input.name}}_res = std::static_pointer_cast<LazyTensor>({{input.name}}.impl())->value().dyn_cast<ir::OpResult>();
{% else %}
ir::OpResult {{input.name}}_res = std::static_pointer_cast<LazyTensor>({{input.name}}.impl())->getValue().dyn_cast<ir::OpResult>();
ir::OpResult {{input.name}}_res;
if({{input.name}}) {
{{input.name}}_res = std::static_pointer_cast<LazyTensor>({{input.name}}.get().impl())->value().dyn_cast<ir::OpResult>();
}
{% endif %}
{% endfor %}
{%- for attr in attrs -%}
{%- endmacro -%}

{%- macro get_static_backend_outputs(outputs)-%}
{%- if outputs|length == 1 -%}
{%- if outputs[0].typename == 'Tensor' -%}
Tensor {{outputs[0].name}}(std::make_shared<LazyTensor>(op_res));
return {{outputs[0].name}};
{%- elif outputs[0].typename == 'Tensor[]' -%}
std::vector<Tensor> {{outputs[0].name}}(op_res.size());
std::transform(op_res.begin(), op_res.end(), {{outputs[0].name}}.begin(), [](const ir::OpResult& res) {
return Tensor(std::make_shared<LazyTensor>(res));
});
return {{outputs[0].name}};
{%- else -%} {#- render nothing -#}
{%- endif -%}
{%- elif outputs|length > 1 -%}
{%- for i in range(outputs|length) %}
auto op_res_{{i}} = std::get<{{i}}>(op_res);
{% if outputs[i].typename == 'Tensor' %}
Tensor {{outputs[i].name}}(std::make_shared<LazyTensor>(op_res_{{i}}));
{% elif outputs[i].typename == 'Tensor[]' %}
std::vector<Tensor> {{outputs[i].name}}(op_res_{{i}}.size());
std::transform(op_res_{{i}}.begin(), op_res_{{i}}.end(), {{outputs[i].name}}.begin(), [](const ir::OpResult& res) {
return Tensor(std::make_shared<LazyTensor>(res));
});
{% else %} {#- render nothing -#}
{% endif %}
{% endfor -%}
return std::make_tuple({%- for i in range(outputs|length) -%}{{outputs[i].name}}{%- if i!=outputs|length - 1 -%}, {% endif -%}{%- endfor -%});
{%- else -%} {#- render nothing -#}
{%- endif -%}
{%- endmacro -%}

{% macro body(name, inputs, outputs, attrs, mutable_attribute_as_inputs=False) %}
{%- set output_names = [] -%}
{%- for o in outputs -%} {%- do output_names.append(o.name) -%} {%-endfor-%}
{{prepare_ir_api_inputs(inputs)}}
{%- for attr in attrs %}
{% if mutable_attribute_as_inputs and attr is mutable_attribute %}
ir::OpResult {{attr.name}}_res = std::static_pointer_cast<LazyTensor>({{attr.name~'_'}}.impl())->getValue().dyn_cast<ir::OpResult>();
ir::OpResult {{attr.name}}_res = std::static_pointer_cast<LazyTensor>({{attr.name~'_'}}.impl())->value().dyn_cast<ir::OpResult>();
{% endif %}
{% endfor %}
{%- set input_names = [] -%}
@@ -52,48 +99,25 @@ template <>
{%- do attr_names.append(common.phi2ir_attr(i)) -%}
{%- endif -%}
{% endfor %}
auto op_res = paddle::dialect::{{name}}({{common.args(input_names, attr_names)}});
{% if outputs|length == 1 %}
{% if outputs[0].typename == 'Tensor' %}
Tensor {{outputs[0].name}}(std::make_shared<LazyTensor>(op_res));
return {{outputs[0].name}};
{% elif outputs[0].typename == 'Tensor[]' %}
std::vector<Tensor> {{outputs[0].name}}(op_res.size());
std::transform(op_res.begin(), op_res.end(), {{outputs[0].name}}.begin(), [](const ir::OpResult& res) {
return Tensor(std::make_shared<LazyTensor>(res));
});
return {{outputs[0].name}};
{% else %} {#- render nothing -#}
{% endif %}
{% elif outputs|length > 1 %}
{% for i in range(outputs|length) %}
auto op_res_{{i}} = std::get<{{i}}>(op_res);
{% if outputs[i].typename == 'Tensor' %}
Tensor {{outputs[i].name}}(std::make_shared<LazyTensor>(op_res_{{i}}));
{% elif outputs[i].typename == 'Tensor[]' %}
std::vector<Tensor> {{outputs[i].name}}(op_res_{{i}}.size());
std::transform(op_res_{{i}}.begin(), op_res_{{i}}.end(), {{outputs[i].name}}.begin(), [](const ir::OpResult& res) {
return Tensor(std::make_shared<LazyTensor>(res));
});
{% else %} {#- render nothing -#}
{% endif %}
{% endfor %}
return std::make_tuple({% for i in range(outputs|length) %}{{outputs[i].name}}{%- if i!=outputs|length - 1 -%}, {% endif %}{% endfor %});
{% else %} {#- render nothing -#}
{% endif %}
{% endmacro %}
auto op_res = paddle::dialect::{{name}}({{common.args(input_names, attr_names)}});
{{get_static_backend_outputs(outputs)}}
{%- endmacro %}


{% for api in apis %}
{% if api.name in backend_white_list %}
{% set api_outputs = api.outputs | trip_intermediate %}
{{sig(api.name, api.inputs, api_outputs, api.attrs)}} {
{% filter indent(2, True) %}
{{body(api.name, api.inputs, api_outputs, api.attrs)}}
{% endfilter %}
}

{% if api.attrs is exist_mutable_attribute %}
{{sig(api.name, api.inputs, api_outputs, api.attrs, True)}} {
{% filter indent(2, True) %}
{{body(api.name, api.inputs, api_outputs, api.attrs, True)}}
{% endfilter %}
}

{% endif %}
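
For reference, a rough sketch of the C++ this template emits for a hypothetical op foo with one required Tensor input, one optional Tensor input, and a single Tensor output. foo and its argument names are illustrative and not part of this commit; the exact signature comes from common.params, which is assumed here to render optional inputs as paddle::optional<Tensor> (the header include added above).

// Sketch only: illustrates the new optional-input branch of prepare_ir_api_inputs.
template <>
Tensor foo<LazyTensor>(const Tensor& x, const paddle::optional<Tensor>& mask) {
  ir::OpResult x_res =
      std::static_pointer_cast<LazyTensor>(x.impl())->value().dyn_cast<ir::OpResult>();
  ir::OpResult mask_res;  // stays null when the optional input is absent
  if (mask) {
    mask_res = std::static_pointer_cast<LazyTensor>(mask.get().impl())
                   ->value()
                   .dyn_cast<ir::OpResult>();
  }
  auto op_res = paddle::dialect::foo(x_res, mask_res);
  Tensor out(std::make_shared<LazyTensor>(op_res));
  return out;
}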
(next changed file; filename not captured)
@@ -11,6 +11,7 @@
#include "paddle/fluid/primitive/utils/utils.h"
#include "paddle/ir/core/operation.h"
#include "paddle/phi/core/flags.h"
#include "paddle/utils/optional.h"

PHI_DECLARE_string(tensor_operants_mode);

@@ -33,14 +34,14 @@ if (paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled()) {
}
{% else %}
{{body_unprim(api)}}
{% endif %}
{%- endif %}
return vjp_res;
{% endmacro %}
{%- endmacro -%}

{% macro get_mutable_attribute(attrs, api_name) %}
{% for i in attrs %}
{%- if i is mutable_attribute -%}
auto* {{i.name}}_define_op = std::static_pointer_cast<primitive::LazyTensor>({{i.name~'_'}}.impl())->getValue().dyn_cast<ir::OpResult>().GetDefiningOp();
auto* {{i.name}}_define_op = std::static_pointer_cast<primitive::LazyTensor>({{i.name~'_'}}.impl())->value().dyn_cast<ir::OpResult>().GetDefiningOp();
{% if i.typename is scalar %}
if({{i.name}}_define_op->name() != "pd.full") {
PADDLE_THROW(platform::errors::Unimplemented(
@@ -62,6 +63,7 @@ auto {{i.name}} = {{i.name}}_define_op->attribute("value").dyn_cast<paddle::dial

{% macro body_unprim(api) %}
{%- set input_names=[] -%}
{%- for api in apis -%} {%- do api_map.update({api.name: api}) -%} {%- endfor -%}
{%- for i in api.inputs -%} {%- do input_names.append(i.name) -%} {%- endfor -%}
{%- set attr_names=[] -%}
{%- for i in api.attrs -%}
@@ -71,29 +73,29 @@ auto {{i.name}} = {{i.name}}_define_op->attribute("value").dyn_cast<paddle::dial
{%- do attr_names.append(i.name) -%}
{%- endif -%}
{%- endfor %}
{% if 'invoke' in api and api.invoke.func in api_map %}
auto op_res = backend::{{api.invoke.func}}<LazyTensor>({{api.invoke.args}});
{% else %}
auto op_res = backend::{{api.name}}<LazyTensor>({{common.args(input_names, attr_names)}});
{% endif %}
{% set outputs = api.outputs|trip_intermediate %} {#- ignore intermediate output -#}
{% if outputs|length > 1 %}
{% for i in range(outputs|length) %}
auto out{{i}} = std::get<{{i}}>(op_res);
{% if outputs[i].typename=='Tensor' %}
vjp_res[{{i}}][0] = !stop_gradients[{{i}}][0] ? out{{i}} : vjp_res[{{i}}][0];
vjp_res[{{i}}][0] = std::get<{{i}}>(op_res);
{% else %}
for (size_t i=0; i< stop_gradients[{{i}}].size(); i++ ) {
vjp_res[{{i}}][i] = !stop_gradients[{{i}}][i] ? out{{i}}[i] : vjp_res[{{i}}][i];
}
vjp_res[{{i}}] = std::get<{{i}}>(op_res);
{% endif %}
{% endfor %}
{% elif outputs|length == 1 %}
{% if outputs[0].typename=='Tensor' %}
vjp_res[0][0] = !stop_gradients[0][0] ? op_res : vjp_res[0][0];
vjp_res[0][0] = op_res;
{% else %}
for (size_t i=0; i< stop_gradients[0].size(); i++ ) {
vjp_res[0][i] = !stop_gradients[0][i] ? op_res[i] : vjp_res[0][i];
}
vjp_res[0] = op_res;
{% endif %}
{% else %} {#- render nothing -#}
{% endif %}
vjp_res = ConstructVjpResultByStopGradients(vjp_res, stop_gradients);
{% endmacro %}

{% macro body_prim(api) %}
@@ -120,7 +122,7 @@ details::{{api.composite.func_name}}<LazyTensor>({{api.composite.func_args}});
{{sig(api.name, backward_api.name, backward_api.inputs, backward_api.attrs, backward_api.outputs)}} {
{% filter indent(2, True) %}
{{body(backward_api)}}
{% endfilter %}
{% endfilter -%}
}

{% endif %}
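
Taken together with process_backward_invoke_info in gen.py above (which moves tensor-backed arguments to the front of the invoke argument list and appends '_' to attributes declared with support_tensor), the invoke branch lets a backward op reuse an existing forward backend API instead of a dedicated grad kernel. A minimal sketch of the shape of a generated rule on this path, with hypothetical names and a simplified signature:

// Sketch only: a backward rule whose yaml entry is implemented via `invoke`.
std::vector<std::vector<paddle::Tensor>> foo_grad_vjp(
    const paddle::Tensor& out_grad,
    const paddle::Tensor& axis_,  // support_tensor attribute, promoted to an input
    const std::vector<std::vector<bool>>& stop_gradients) {
  std::vector<std::vector<paddle::Tensor>> vjp_res(1, std::vector<paddle::Tensor>(1));
  // Reuse the forward backend API named by invoke.func.
  auto op_res = backend::foo<LazyTensor>(out_grad, axis_);
  vjp_res[0][0] = op_res;
  // Stop-gradient masking is now centralized in the helper defined in static_utils.cc below.
  vjp_res = ConstructVjpResultByStopGradients(vjp_res, stop_gradients);
  return vjp_res;
}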
4 changes: 3 additions & 1 deletion paddle/fluid/primitive/type/lazy_tensor.h
@@ -41,12 +41,14 @@ class LazyTensor : public phi::ExtendedTensor,
value_.type().dyn_cast<paddle::dialect::DenseTensorType>().dtype());
}

ir::Value getValue() const { return value_; }
ir::Value value() const { return value_; }

const phi::Place& place() const override { return place_; }

bool initialized() const override { return value_.impl() != nullptr; }

void set_empty_type() { value_.set_type(ir::Type()); }

private:
ir::Value value_;
mutable phi::DDim dims_;
18 changes: 18 additions & 0 deletions paddle/fluid/primitive/utils/static_utils.cc
@@ -21,5 +21,23 @@ void set_output<LazyTensor>(const paddle::Tensor& x_tmp, paddle::Tensor* x) {
x->set_impl(x_tmp.impl());
}

std::vector<std::vector<Tensor>> ConstructVjpResultByStopGradients(
const std::vector<std::vector<Tensor>>& outputs,
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<Tensor>> vjp_results(outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
vjp_results[i].reserve(outputs[i].size());
for (size_t j = 0; j < outputs[i].size(); ++j) {
if (stop_gradients[i][j]) {
        // A Tensor whose impl is nullptr indicates that it carries no gradient.
vjp_results[i].emplace_back(Tensor());
} else {
vjp_results[i].emplace_back(outputs[i][j]);
}
}
}
return vjp_results;
}

} // namespace primitive
} // namespace paddle
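
This helper centralizes the stop_gradient masking that the generated VJP rules previously did inline (compare the per-output loops removed from the vjp template above). A minimal usage sketch, where x_grad and y_grad are hypothetical gradient Tensors:

// Two single-tensor output slots; the second is marked stop_gradient.
std::vector<std::vector<paddle::Tensor>> vjp_res = {{x_grad}, {y_grad}};
std::vector<std::vector<bool>> stop_gradients = {{false}, {true}};
vjp_res = paddle::primitive::ConstructVjpResultByStopGradients(vjp_res, stop_gradients);
// vjp_res[0][0] is x_grad; vjp_res[1][0] is an empty Tensor (impl() == nullptr).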