[Prim][PIR] Support composite rules of Llama ops #58018

Merged · 11 commits · Oct 13, 2023
@@ -82,6 +82,7 @@
'maximum',
'argsort',
'min',
'max',
'batch_norm',
'max_pool2d_with_index',
'pool2d',
@@ -183,6 +184,7 @@
'maximum',
'argsort',
'min',
'max',
'batch_norm',
'max_pool2d_with_index',
'pool2d',
12 changes: 12 additions & 0 deletions paddle/fluid/primitive/codegen/gen.py
@@ -57,11 +57,23 @@
'tanh_grad',
'transpose_grad',
'concat_grad',
'erf_grad',
'exp_grad',
'expand_grad',
'log_grad',
'gather_nd_grad',
'pad_grad',
'max_grad',
'slice_grad',
'tile_grad',
] # vjp list of primitive op
CUSTOM_VJP = [
'gelu_grad',
'layer_norm_grad',
'dropout_grad',
'silu_grad',
'softmax_grad',
'sqrt_grad',
] # custom vjp list of composite op
VJP_COMPS = PRIM_VJP + CUSTOM_VJP

1 change: 1 addition & 0 deletions paddle/fluid/primitive/primitive.yaml
@@ -51,3 +51,4 @@
- full
- cast
- sign
- slice
300 changes: 300 additions & 0 deletions paddle/fluid/primitive/rule/vjp/details.h
@@ -536,6 +536,306 @@ void dropout_grad(const Tensor& mask,
}
}

template <typename T>
void erf_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
auto m_2_sqrt_pi = full<T>(phi::vectorize(x.dims()), M_2_SQRTPI, x.dtype());
auto neg_one = full<T>(phi::vectorize(x.dims()), -1.0, x.dtype());
auto neg_tmp = neg_one * x * x;
auto mul_tmp = m_2_sqrt_pi * exp<T>(neg_tmp);
set_output<T>(out_grad * mul_tmp, x_grad);
}
}
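
The rule relies on d/dx erf(x) = (2/sqrt(pi)) * exp(-x^2); M_2_SQRTPI is the constant 2/sqrt(pi). A minimal standalone check of that identity, outside Paddle:

import math

x = 0.7
eps = 1e-6
numeric = (math.erf(x + eps) - math.erf(x - eps)) / (2 * eps)  # central difference
analytic = (2.0 / math.sqrt(math.pi)) * math.exp(-x * x)       # M_2_SQRTPI * exp(-x^2)
assert abs(numeric - analytic) < 1e-8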

template <typename T>
void expand_grad(const Tensor& x,
const Tensor& out_grad,
const IntArray& shape,
Tensor* x_grad) {
if (x_grad) {
auto out_dims = phi::make_ddim(shape.GetData());
if (out_dims != x.dims()) {
auto axes = get_reduce_dims(x.dims(), out_dims);
if (!axes.size()) {
by_pass<T>(out_grad, x_grad);
} else {
auto reduced = out_grad.sum(phi::vectorize(axes), x.dtype(), false);
if (reduced.dims().size() != x.dims().size()) {
reduced = reshape<T>(reduced, x.shape());
}
set_output<T>(reduced, x_grad);
}
} else {
by_pass<T>(out_grad, x_grad);
}
}
}
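
get_reduce_dims collects the axes along which x was broadcast; the rule sums out_grad over those axes and reshapes back if the ranks still differ. The same computation in plain NumPy (shapes here are illustrative):

import numpy as np

x = np.ones((3, 1))                                   # input shape (3, 1)
out_grad = np.ones((2, 3, 4))                         # grad at expanded shape (2, 3, 4)
x_grad = out_grad.sum(axis=(0, 2)).reshape(x.shape)   # reduce the broadcast axes
assert x_grad.shape == (3, 1) and np.all(x_grad == 8.0)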

template <typename T>
void log_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
// dx = dout / x
set_output<T>(out_grad / x, x_grad);
}
}

template <typename T>
void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
if (out.dtype() == phi::DataType::FLOAT16 ||
out.dtype() == phi::DataType::BFLOAT16) {
Tensor out_promote = cast<T>(out, phi::DataType::FLOAT32);
Tensor out_grad_promote = cast<T>(out_grad, phi::DataType::FLOAT32);
set_output<T>(cast<T>(out_promote * out_grad_promote, out.dtype()),
x_grad);
} else {
set_output<T>(out_grad * out, x_grad);
}
}
}
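
Since d/dx exp(x) = exp(x) = out, the vjp is just out_grad * out; the fp16/bf16 branch only promotes to fp32 for the multiply and casts back. A quick NumPy check of the identity:

import numpy as np

x = np.array([0.1, -0.5, 2.0])
out = np.exp(x)
out_grad = np.array([1.0, 2.0, 3.0])
eps = 1e-6
numeric = out_grad * (np.exp(x + eps) - np.exp(x - eps)) / (2 * eps)
assert np.allclose(out_grad * out, numeric)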

template <typename T>
void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
// This calculation is important for resnet.
auto x_grad_tmp = (0.5 / out) * out_grad;
set_output<T>(x_grad_tmp, x_grad);
}
}
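
Both one-line rules follow from elementary derivatives: d/dx log(x) = 1/x and, since out = sqrt(x), d/dx sqrt(x) = 1/(2*sqrt(x)) = 0.5/out. Checking both numerically:

import numpy as np

x = np.array([0.5, 1.0, 4.0])
out_grad = np.ones_like(x)
eps = 1e-7
d_log = (np.log(x + eps) - np.log(x - eps)) / (2 * eps)
d_sqrt = (np.sqrt(x + eps) - np.sqrt(x - eps)) / (2 * eps)
assert np.allclose(out_grad / x, d_log)                    # log:  dx = dout / x
assert np.allclose(out_grad * 0.5 / np.sqrt(x), d_sqrt)    # sqrt: dx = (0.5 / out) * dout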

template <typename T>
void silu_grad(const Tensor& x,
const Tensor& out,
const Tensor& out_grad,
Tensor* x_grad) {
if (x_grad) {
auto org_dtype = x.dtype();
bool need_cast = org_dtype == phi::DataType::FLOAT16 ||
org_dtype == phi::DataType::BFLOAT16;
if (need_cast) {
auto x_cast = cast<T>(x, phi::DataType::FLOAT32);
auto out_cast = cast<T>(out, phi::DataType::FLOAT32);
auto out_grad_cast = cast<T>(out_grad, phi::DataType::FLOAT32);
auto sigmoid = 1.0 / (1.0 + exp<T>(-x_cast));
auto res = out_grad_cast * sigmoid * (1.0 + x_cast - out_cast);
set_output<T>(cast<T>(res, org_dtype), x_grad);
} else {
auto sigmoid = 1.0 / (1.0 + exp<T>(-x));
auto res = out_grad * sigmoid * (1.0 + x - out);
set_output<T>(res, x_grad);
}
}
}
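
The silu identity used here: with s = sigmoid(x) and out = x * s, d/dx (x * s) = s + x*s*(1-s) = s * (1 + x - out). A NumPy verification:

import numpy as np

def silu(v):
    return v / (1.0 + np.exp(-v))

x = np.linspace(-3.0, 3.0, 7)
s = 1.0 / (1.0 + np.exp(-x))          # sigmoid
out = x * s
eps = 1e-6
numeric = (silu(x + eps) - silu(x - eps)) / (2 * eps)
assert np.allclose(s * (1.0 + x - out), numeric)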

template <typename T>
void softmax_grad(const Tensor& out,
const Tensor& out_grad,
int axis,
Tensor* x_grad) {
if (x_grad) {
if (out_grad.dims().size() > 0) {
if (axis >= 0) {
auto new_out_grad = out_grad * out;
auto tmp_x_grad = new_out_grad -
out * sum<T>(new_out_grad, {axis}, out.dtype(), true);
set_output<T>(tmp_x_grad, x_grad);
} else {
auto new_out_grad = out_grad * out;
auto tmp_x_grad =
new_out_grad - out * sum<T>(new_out_grad,
{out.dims().size() + axis},
out.dtype(),
true);
set_output<T>(tmp_x_grad, x_grad);
}
} else {
set_output<T>(
full<T>(phi::vectorize(out_grad.dims()), 0.0, out_grad.dtype()),
x_grad);
}
}
}
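
The two branches differ only in axis normalization; both compute g = out * out_grad and then x_grad = g - out * sum(g, axis, keepdim=True), the standard softmax vjp. In NumPy, with a constant out_grad the gradient must vanish:

import numpy as np

def softmax(v, axis=-1):
    e = np.exp(v - v.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

out = softmax(np.random.default_rng(0).normal(size=(2, 5)))
g = out * np.ones_like(out)                      # out * out_grad with out_grad == 1
x_grad = g - out * g.sum(axis=-1, keepdims=True)
assert np.allclose(x_grad, 0.0)                  # a constant out_grad yields a zero gradient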

template <typename T>
void gather_nd_grad(const Tensor& x,
const Tensor& index,
const Tensor& out_grad,
Tensor* x_grad) {
if (x_grad) {
auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
auto x_grad_tmp = scatter_nd_add<T>(zero_tensor, index, out_grad);
set_output<T>(x_grad_tmp, x_grad);
}
}
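
gather_nd's vjp scatters the incoming gradient back into a zero tensor of x's shape, accumulating where indices repeat. The NumPy analogue for the 1-D index case:

import numpy as np

x = np.zeros(5)
index = np.array([1, 3, 1])                    # gathered positions
out_grad = np.array([10.0, 20.0, 30.0])
x_grad = np.zeros_like(x)
np.add.at(x_grad, index, out_grad)             # duplicate indices accumulate
assert np.array_equal(x_grad, [0.0, 40.0, 0.0, 20.0, 0.0])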

template <typename T>
void pad_grad(const Tensor& input,
const Tensor& out_grad,
const std::vector<int>& paddings,
const Scalar& pad_value,
Tensor* input_grad) {
if (input_grad) {
size_t rank = input.dims().size();
auto out_dims = out_grad.dims();

std::vector<int64_t> starts(rank, 0);
std::vector<int64_t> ends(rank, 0);
std::vector<int64_t> axes(rank, 0);
std::vector<int64_t> infer_flags(rank, 1);
std::vector<int64_t> decrease_axis({});
for (size_t i = 0; i < rank; ++i) {
starts[i] = static_cast<int64_t>(paddings[2 * i]);
ends[i] = static_cast<int64_t>(out_dims[i] - paddings[2 * i + 1]);
axes[i] = i;
}
auto out_tmp =
slice<T>(out_grad, axes, starts, ends, infer_flags, decrease_axis);
set_output<T>(out_tmp, input_grad);
}
}
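
Since padding only inserts a border, its vjp is a slice that cuts the border back off: each axis starts at paddings[2*i] and stops paddings[2*i+1] elements before the end. Illustrated in NumPy:

import numpy as np

paddings = [1, 2, 0, 3]                        # flattened (before, after) per dim
x = np.ones((2, 4))
out_grad = np.arange(5 * 7, dtype=float).reshape(5, 7)   # grad at the padded shape
x_grad = out_grad[1:5 - 2, 0:7 - 3]            # starts = before_i, ends = dim_i - after_i
assert x_grad.shape == x.shape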

template <typename T>
void max_grad(const Tensor& x,
const Tensor& out,
const Tensor& out_grad,
const IntArray& axis,
bool keepdim,
bool reduce_all,
Tensor* x_grad) {
if (!x_grad) {
return;
}
auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
std::vector<int64_t> x_dim = phi::vectorize<int64_t>(x.dims());
int64_t axis_size = axis.size();
int64_t x_dim_size = x_dim.size();
// the incoming reduce_all flag is overridden here: reduce over all
// dimensions exactly when axis is empty or covers every dim of x
reduce_all = axis_size == 0 || axis_size == x_dim_size;
auto x_grad_tmp = Tensor();
if (x_dim_size == 0 || x_dim_size == 1 || keepdim) {
auto out_grad_tmp = out_grad.expand(IntArray(x_dim));
auto out_tmp = out.expand(IntArray(x_dim));
auto mask = equal<T>(x, out_tmp);
x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
} else {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 0; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
axis_ = axis.GetData();
for (int64_t i = 0; i < axis_size; i++) {
if (axis[i] < 0) {
axis_[i] = axis[i] + x_dim_size;
}
}
}
auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
auto out_grad_ = reshape<T>(out_grad, out_grad_shape);
auto out_ = reshape<T>(out, out_grad_shape);
auto out_grad_tmp = out_grad_.expand(IntArray(x_dim));
auto out_tmp = out_.expand(IntArray(x_dim));
auto mask = equal<T>(x, out_tmp);
x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
}
set_output<T>(x_grad_tmp, x_grad);
}
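
max is piecewise linear, so the rule broadcasts out and out_grad back to x's shape and routes the gradient to every position where x equals the max (on ties, all maximal positions receive it, per the equal/where formulation). A reduced NumPy version of the non-keepdim path:

import numpy as np

x = np.array([[1.0, 3.0], [3.0, 2.0]])
out = x.max(axis=1)                            # reduced along axis 1
out_grad = np.array([1.0, 1.0])
out_k = np.expand_dims(out, 1)                 # unsqueeze, as get_unsqueeze_dims does
grad_k = np.expand_dims(out_grad, 1)
x_grad = np.where(x == out_k, grad_k, 0.0)
assert np.array_equal(x_grad, [[0.0, 1.0], [1.0, 0.0]])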

template <typename T>
void slice_grad(const Tensor& input,
const Tensor& out_grad,
const std::vector<int64_t>& axes,
const IntArray& starts,
const IntArray& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
Tensor* input_grad) {
if (input_grad) {
size_t rank = input.dims().size();
auto out_dims = out_grad.dims();
std::vector<int64_t> origin_out_shape;
auto in_dims = input.dims();

auto decrease_size = decrease_axis.size();
if (decrease_size > 0) {
if (decrease_size == static_cast<size_t>(in_dims.size())) {
// all dims decrease
out_dims = phi::make_ddim(std::vector<int>(decrease_size, 1));
} else {
origin_out_shape.resize(out_dims.size() + decrease_size, -1);
for (size_t i = 0; i < decrease_size; ++i) {
origin_out_shape[decrease_axis[i]] = 1;
}

int index = 0;
for (size_t i = 0; i < origin_out_shape.size(); ++i) {
if (origin_out_shape[i] == -1) {
origin_out_shape[i] = out_dims[index];
++index;
}
}
out_dims = phi::make_ddim(origin_out_shape);
}
}

std::vector<int> offsets(rank, 0);
std::vector<int> extents(rank, 0);
for (size_t i = 0; i < rank; ++i) {
offsets[i] = 0;
extents[i] = out_dims[i];
}
for (size_t i = 0; i < axes.size(); ++i) {
int axis = axes[i];
int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i];
start = std::max(start, static_cast<int64_t>(0));
offsets[axis] = start;
}

std::vector<int> paddings;
for (size_t i = 0; i < rank; ++i) {
paddings.push_back(offsets[i]);
paddings.push_back((in_dims[i] - out_dims[i]) - offsets[i]);
}
if (decrease_size > 0 &&
(decrease_size != static_cast<size_t>(in_dims.size()))) {
auto out_tmp =
pad<T>(reshape<T>(out_grad, origin_out_shape), paddings, 0.0);
set_output<T>(out_tmp, input_grad);
} else {
auto out_tmp = pad<T>(out_grad, paddings, 0.0);
set_output<T>(out_tmp, input_grad);
}
}
}
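
slice_grad is the mirror image of pad_grad: the gradient of a slice is out_grad padded with zeros back to the input shape, offset by each axis's normalized start. In NumPy:

import numpy as np

x = np.ones((4, 5))
out_grad = np.ones((2, 5))                     # grad of x[1:3, :]
pad_width = [(1, 4 - 3), (0, 0)]               # (offset, in_dim - end) per dim
x_grad = np.pad(out_grad, pad_width)           # zero padding by default
assert x_grad.shape == x.shape and x_grad.sum() == out_grad.sum()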

template <typename T>
void tile_grad(const Tensor& x,
const Tensor& out_grad,
const IntArray& repeat_times,
Tensor* x_grad) {
if (x_grad) {
auto repeat_times_data = repeat_times.GetData();
auto out_grad_shape = phi::vectorize<int>(out_grad.dims());
auto result = out_grad;
for (int i = 0; i < static_cast<int>(repeat_times_data.size()); i++) {
int size = out_grad_shape[i] / repeat_times_data[i];
std::vector<int> sections(repeat_times_data[i], size);
auto split_arr = split<T>(result, IntArray(sections), i);
result = full<T>(phi::vectorize(split_arr[0].dims()), 0.0, x.dtype());
for (int j = 0; j < static_cast<int>(split_arr.size()); j++) {
result = split_arr[j] + result;
}
}
result = reshape<T>(result, x.shape());
set_output<T>(result, x_grad);
}
}
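
For tile, each repeat along an axis contributes an identical copy of x's gradient, so the rule splits out_grad into repeat_times chunks per axis and sums them, then reshapes to x. The NumPy equivalent for a single tiled axis:

import numpy as np

x = np.ones((2, 3))
out_grad = np.arange(12, dtype=float).reshape(4, 3)   # x tiled by (2, 1)
chunks = np.split(out_grad, 2, axis=0)         # one chunk per repeat along axis 0
x_grad = sum(chunks).reshape(x.shape)          # chunks already have x's shape
assert np.array_equal(x_grad, out_grad[:2] + out_grad[2:])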

} // namespace details
} // namespace primitive
} // namespace paddle
8 changes: 8 additions & 0 deletions paddle/fluid/pybind/pir.cc
@@ -363,6 +363,10 @@ void BindValue(py::module *m) {
.def("first_use", &Value::first_use, return_value_policy::reference)
.def("has_one_use", &Value::HasOneUse)
.def("use_empty", &Value::use_empty)
.def("replace_all_uses_with",
[](Value &self, Value &op_value) {
self.ReplaceAllUsesWith(op_value);
})
.def("__eq__", &Value::operator==)
.def("__eq__",
[](Value &self, OpResult &other) {
@@ -610,6 +614,10 @@ void BindOpResult(py::module *m) {
return false;
}
})
.def("replace_all_uses_with",
[](OpResult &self, OpResult &op_result) {
self.ReplaceAllUsesWith(op_result);
})
.def_property(
"stop_gradient",
[](OpResult &self) {
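
The two bindings expose ReplaceAllUsesWith on both Value and OpResult. A hedged usage sketch (old_val and new_val stand for values taken from a PIR program; how they are constructed is out of scope here):

def rewire(old_val, new_val):
    # every consumer of old_val is rewired to read new_val instead
    old_val.replace_all_uses_with(new_val)
    assert old_val.use_empty()    # matches the use_empty binding above
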
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/legacy_backward.yaml
@@ -747,7 +747,7 @@
kernel :
func : tile_grad
no_need_buffer : x
composite : tile_grad(x, outgrad, repeat_times, x_grad)
composite : tile_grad(x, out_grad, repeat_times, x_grad)
backward : tile_double_grad

- backward_op : trans_layout_grad
3 changes: 2 additions & 1 deletion python/paddle/autograd/ir_backward.py
@@ -13,6 +13,7 @@
# limitations under the License.

import collections
import logging
from collections.abc import Sequence

import paddle.pir
@@ -556,7 +557,7 @@ def create_backward_prune_set(inputs, outputs, no_grad_set, state):
if state.value_to_valuegrad[item] != []:
outputs_set.add(state.value_to_valuegrad[item][0][0])
else:
raise ValueError("input privided by inputs has no use")
logging.warning("input provided by inputs has no use")

inputs_set = set()
for output in outputs: