Commit

Merge branch 'develop' into remove_no_used_indexing_code
zoooo0820 committed Nov 23, 2023
2 parents 2a72fd8 + 6ca1b8e commit 6f5789b
Showing 28 changed files with 703 additions and 190 deletions.
3 changes: 3 additions & 0 deletions paddle/fluid/imperative/gradient_accumulator.cc
@@ -275,6 +275,9 @@ void TensorAdd(const VarType& src, VarType* dst) {
XPUTensorAddFunctor<platform::float16>(place, src_tensor, dst_tensor);
} else if (data_type == framework::DataTypeTrait<double>::DataType()) {
XPUTensorAddFunctor<double>(place, src_tensor, dst_tensor);
} else if (data_type ==
framework::DataTypeTrait<platform::bfloat16>::DataType()) {
XPUTensorAddFunctor<platform::bfloat16>(place, src_tensor, dst_tensor);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
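The hunk above extends the data-type dispatch in TensorAdd so XPU gradient accumulation also handles bfloat16. As a standalone illustration (not part of the diff, and using hypothetical stand-in names rather than the phi/platform classes), the dispatch pattern looks roughly like this:

    #include <iostream>
    #include <stdexcept>
    #include <string>

    enum class DataType { FLOAT16, FLOAT32, FLOAT64, BFLOAT16 };

    // Stand-in for XPUTensorAddFunctor<T>: the real functor performs dst += src
    // on the XPU device; here it only reports which instantiation would run.
    template <typename T>
    void XPUTensorAddSketch(const std::string& type_name) {
      std::cout << "accumulating gradients with the " << type_name << " kernel\n";
    }

    void TensorAddSketch(DataType dtype) {
      if (dtype == DataType::FLOAT32) {
        XPUTensorAddSketch<float>("float32");
      } else if (dtype == DataType::FLOAT64) {
        XPUTensorAddSketch<double>("float64");
      } else if (dtype == DataType::BFLOAT16) {
        XPUTensorAddSketch<unsigned short>("bfloat16");  // branch added by this commit
      } else {
        // Mirrors the PADDLE_THROW fallback: unsupported dtypes fail loudly.
        throw std::runtime_error("gradient accumulation for this data type is not implemented");
      }
    }

Each supported dtype gets its own functor instantiation, so adding a type means adding one branch; anything unmatched still falls through to the unimplemented-error path.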
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/generator/dist_api_gen.py
@@ -167,7 +167,7 @@
"""
MULTI_SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out_{idx} = SetKernelDistOutput(&{out}, spmd_info.second[{idx}]);
auto dense_out_{idx} = dist_out_{idx}->unsafe_mutable_value();
auto dense_out_{idx} = dist_out_{idx} ? dist_out_{idx}->unsafe_mutable_value() : nullptr;
if (!rank_is_in_current_mesh) {{
*dense_out_{idx} = phi::DenseTensor(
std::make_shared<phi::Allocation>(nullptr, 0, phi::distributed::GetDefaultPlace()),
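In the generator template above, the single changed line replaces the unconditional call dist_out_{idx}->unsafe_mutable_value() with a null-guarded version, so the generated API code no longer dereferences a null distributed output. A minimal standalone sketch of the guarded pattern (hypothetical types, not the phi classes):

    #include <iostream>

    struct DenseValue { int data = 0; };

    struct DistOutput {
      DenseValue value;
      DenseValue* unsafe_mutable_value() { return &value; }
    };

    // Old template: dist_out->unsafe_mutable_value() with no check.
    // New template: only dereference when dist_out is non-null.
    DenseValue* GetDenseOut(DistOutput* dist_out) {
      return dist_out ? dist_out->unsafe_mutable_value() : nullptr;
    }

    int main() {
      DistOutput out;
      std::cout << (GetDenseOut(&out) != nullptr) << "\n";     // 1: valid output
      std::cout << (GetDenseOut(nullptr) != nullptr) << "\n";  // 0: null passes through safely
      return 0;
    }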
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -85,6 +85,7 @@
output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_out)
infer_meta :
func : AdamwInferMeta
spmd_rule : AdamwInferSpmdDynamic
kernel :
func : adamw
data_type : param
33 changes: 26 additions & 7 deletions paddle/phi/backends/xpu/xpu3_op_list.cc
@@ -31,7 +31,10 @@ XPUOpMap& get_kl3_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"accuracy", XPUKernelSet({phi::DataType::FLOAT32})},
{"adadelta", XPUKernelSet({phi::DataType::FLOAT32})},
{"adamw", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"adamw",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"adam", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"adam_dense_param_sparse_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
@@ -176,10 +179,13 @@ XPUOpMap& get_kl3_ops() {
{"coalesce_tensor",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"concat_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"concat",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT64,
phi::DataType::BOOL,
phi::DataType::INT8,
@@ -243,10 +249,13 @@ XPUOpMap& get_kl3_ops() {
{"einsum", XPUKernelSet({phi::DataType::FLOAT32})},
{"einsum_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_add_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"elementwise_add",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT64,
phi::DataType::INT32})},
{"elementwise_div_grad",
@@ -271,6 +280,7 @@
{"elementwise_mul",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"elementwise_pow",
@@ -348,6 +358,7 @@
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT32})},
{"fill_diagonal_tensor",
XPUKernelSet({phi::DataType::INT64,
@@ -363,7 +374,8 @@
phi::DataType::UINT8,
phi::DataType::BOOL,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"flatten2_grad",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
@@ -424,7 +436,8 @@
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"gather",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
@@ -531,7 +544,9 @@
{"logical_xor", XPUKernelSet({phi::DataType::BOOL})},
{"lookup_table_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"lookup_table_v2",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"masked_select",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
@@ -670,13 +685,15 @@
{"reshape2_grad",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::FLOAT32})},
{"reshape2",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
@@ -702,6 +719,7 @@
{"scale",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT64,
phi::DataType::INT32})},
{"scatter",
Expand Down Expand Up @@ -917,7 +935,8 @@ XPUOpMap& get_kl3_ops() {
{"triu",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::FLOAT16})},
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"tril_triu_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
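The entries above extend the XPU3 (KL3) op list so that ops such as adamw, concat, elementwise_add, elementwise_mul, fill, flatten, lookup_table_v2, reshape2, scale, and triu also advertise BFLOAT16 kernels. As a rough standalone sketch (hypothetical names, not the phi implementation), an op map of this shape is typically consulted like so:

    #include <iostream>
    #include <map>
    #include <set>
    #include <string>

    enum class DataType { FLOAT16, FLOAT32, BFLOAT16, INT32, INT64 };
    using KernelSet = std::set<DataType>;            // stand-in for XPUKernelSet
    using OpMap = std::map<std::string, KernelSet>;  // stand-in for XPUOpMap

    OpMap GetKl3OpsSketch() {
      return {
          {"adamw", {DataType::FLOAT32, DataType::FLOAT16, DataType::BFLOAT16}},
          {"elementwise_add",
           {DataType::FLOAT32, DataType::FLOAT16, DataType::BFLOAT16,
            DataType::INT32, DataType::INT64}},
      };
    }

    // Before dispatching an op to the XPU3 backend, check whether the (op, dtype)
    // pair is listed; unlisted pairs fall back or raise an unsupported error.
    bool SupportedOnXpu3(const std::string& op, DataType dtype) {
      const OpMap ops = GetKl3OpsSketch();
      auto it = ops.find(op);
      return it != ops.end() && it->second.count(dtype) > 0;
    }

    int main() {
      std::cout << SupportedOnXpu3("adamw", DataType::BFLOAT16) << "\n";  // 1 with this commit
      std::cout << SupportedOnXpu3("adamw", DataType::INT32) << "\n";     // 0
      return 0;
    }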
2 changes: 1 addition & 1 deletion paddle/phi/backends/xpu/xpu_info.cc
@@ -219,7 +219,7 @@ int get_xpu_max_ptr_size(int dev_id) {
break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"Only support get max ptr size of XPU1 or XPU2."));
"Only support get max ptr size of XPU1, XPU2 or XPU3."));
break;
}
return max_ptr_size;