Skip to content

Commit

Permalink
fix bug of group norm and layer norm for npu (#10609)
Browse files Browse the repository at this point in the history
Co-authored-by: oneflow-ci-bot <[email protected]>
  • Loading branch information
crazy-JiangDongHua and oneflow-ci-bot authored Dec 26, 2024
1 parent 4bdf038 commit f7fa76f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
3 changes: 3 additions & 0 deletions oneflow/user/ops/group_norm_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ limitations under the License.

namespace oneflow {

DEFINE_ENV_BOOL(ONEFLOW_GROUP_NORM_USE_FP16_DIRECTLY, false);

namespace {

oneflow::DataType InferGnParamDataType(const DataType x_data_type) {
if (EnvBool<ONEFLOW_GROUP_NORM_USE_FP16_DIRECTLY>()) { return x_data_type; }
return (x_data_type == DataType::kFloat16 || x_data_type == DataType::kBFloat16)
? DataType::kFloat
: x_data_type;
Expand Down
7 changes: 7 additions & 0 deletions oneflow/user/ops/layer_norm_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ limitations under the License.

namespace oneflow {

DEFINE_ENV_BOOL(ONEFLOW_LAYER_NORM_PARAM_KEEP_DIM, false);

namespace {

int64_t ShiftNegativeAxisIfNeed(const Shape& shape, int64_t axis) {
Expand All @@ -31,6 +33,11 @@ Shape InferBnParamShape(const Shape& x_shape, const int64_t begin_norm_axis) {
DimVector bn_param_shape_dim_vec;
bn_param_shape_dim_vec.insert(bn_param_shape_dim_vec.end(), x_shape.dim_vec().cbegin(),
x_shape.dim_vec().cbegin() + begin_norm_axis);
if (EnvBool<ONEFLOW_LAYER_NORM_PARAM_KEEP_DIM>()) {
while (bn_param_shape_dim_vec.size() < x_shape.dim_vec().size()) {
bn_param_shape_dim_vec.push_back(1);
}
}
const Shape bn_param_shape(bn_param_shape_dim_vec);
return bn_param_shape;
}
Expand Down

0 comments on commit f7fa76f

Please sign in to comment.