[AMP] Add multi_precision for sgd #38231

Merged · 7 commits · Dec 24, 2021
11 changes: 11 additions & 0 deletions paddle/fluid/operators/optimizers/sgd_op.cc
@@ -126,13 +126,24 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Param", "(Tensor or SelectedRows) Input parameter");
AddInput("LearningRate", "(Tensor) Learning rate of SGD");
AddInput("Grad", "(Tensor or SelectedRows) Input gradient");
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
AddOutput("ParamOut",
"(Tensor or SelectedRows, same with Param) "
"Output parameter, should share the same memory with Param");
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable();

AddAttr<bool>(
"use_mkldnn",
"(bool, default false) Indicates if MKL-DNN kernel will be used")
.SetDefault(false);
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);

AddComment(R"DOC(

SGD operator
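The new MasterParam/MasterParamOut pair and the multi_precision attribute exist because FP16 weights cannot absorb small updates directly: the FP16 machine epsilon is roughly 1e-3, so a typical lr * grad step near 1.0 is rounded away. A minimal NumPy sketch (illustrative only, not part of this PR) of the rounding problem that the FP32 master copy avoids:

```python
import numpy as np

# An SGD step of 1e-4 is below FP16's resolution near 1.0, so the FP16
# weight does not move at all; the FP32 master weight keeps the update.
w16, w32 = np.float16(1.0), np.float32(1.0)
step = 1e-4  # lr * grad

print(w16 - np.float16(step))  # 1.0     -> update rounded away in FP16
print(w32 - np.float32(step))  # 0.9999  -> preserved in the master copy
```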
60 changes: 42 additions & 18 deletions paddle/fluid/operators/optimizers/sgd_op.cu
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/sgd_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

@@ -21,14 +22,19 @@ namespace operators {

namespace {

template <typename T>
__global__ void SGDKernel(const T* g, const T* p, const T* learning_rate,
const int num, T* p_out) {
T lr = learning_rate[0];
template <typename T, typename MT>
__global__ void SGDKernelMT(const T* param, const T* grad,
const T* learning_rate, const int num, T* param_out,
const MT* master_param, MT* master_param_out) {
MT lr = static_cast<MT>(learning_rate[0]);
CUDA_KERNEL_LOOP(i, num) {
T g_data = g[i];
T p_data = p[i];
p_out[i] = p_data - lr * g_data;
MT p_data = master_param ? master_param[i] : static_cast<MT>(param[i]);
MT g_data = static_cast<MT>(grad[i]);
p_data = p_data - lr * g_data;
param_out[i] = static_cast<T>(p_data);
if (master_param_out) {
master_param_out[i] = p_data;
}
}
}

@@ -63,30 +69,48 @@ class SGDOpKernel<platform::CUDADeviceContext, T>
"but the received is %s",
ctx.InputNames("Param").front(),
paddle::framework::ToTypeName(param_var->Type())));
using paddle::framework::Tensor;
using MPDType = typename details::MPTypeTrait<T>::Type;

auto* param = ctx.Input<framework::Tensor>("Param");
auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");

auto* grad_var = ctx.InputVar("Grad");

const bool multi_precision = ctx.Attr<bool>("multi_precision");
const Tensor* master_param = nullptr;
Tensor* master_param_out = nullptr;
if (multi_precision) {
bool has_master =
ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
PADDLE_ENFORCE_EQ(has_master, true,
platform::errors::InvalidArgument(
"The Input(MasterParam) and Output(MasterParamOut) "
"should not be null when "
"the attr `multi_precision` is true"));
master_param = ctx.Input<framework::Tensor>("MasterParam");
master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
}
const MPDType* master_in_data =
multi_precision ? master_param->data<MPDType>() : nullptr;
MPDType* master_out_data =
multi_precision
? master_param_out->mutable_data<MPDType>(ctx.GetPlace())
: nullptr;

// Actually, all tensors are LoDTensor except SelectedRows.
if (grad_var->IsType<framework::LoDTensor>()) {
param_out->mutable_data<T>(ctx.GetPlace());
auto* grad = ctx.Input<framework::Tensor>("Grad");
// LOG(ERROR) << "grad";
// LOG(ERROR) << ctx.op().Input("Grad");
auto* grad_data = grad->data<T>();
// LOG(ERROR) << "param";
auto* param_data = param->data<T>();
// LOG(ERROR) << "fin";
auto* param_out_data = param_out->data<T>();

int block = 512;
int grid = (param->numel() + block - 1) / block;

SGDKernel<T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
grad_data, param_data, learning_rate->data<T>(), param->numel(),
param_out_data);
SGDKernelMT<
T, MPDType><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
param->data<T>(), grad->data<T>(), learning_rate->data<T>(),
param->numel(), param_out->mutable_data<T>(ctx.GetPlace()),
master_in_data, master_out_data);

} else if (grad_var->IsType<framework::SelectedRows>()) {
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
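For reference, a hedged NumPy sketch (illustrative only, not Paddle's API) of the per-element update SGDKernelMT performs: the arithmetic runs in the master precision MT, ParamOut receives the cast-back value, and MasterParamOut keeps the full-precision result.

```python
import numpy as np

def sgd_step_mt(param, grad, lr, master_param=None):
    """Hypothetical NumPy stand-in for SGDKernelMT, with FP32 as the MT type."""
    mt = np.float32
    # Mirror `master_param ? master_param[i] : static_cast<MT>(param[i])`.
    p = master_param.astype(mt) if master_param is not None else param.astype(mt)
    p = p - mt(lr) * grad.astype(mt)
    param_out = p.astype(param.dtype)                      # ParamOut
    master_out = p if master_param is not None else None   # MasterParamOut
    return param_out, master_out
```

For example, `sgd_step_mt(w16, g16, 1e-3, master_param=w32)` reproduces the multi_precision path, while omitting master_param reproduces the plain single-precision update.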
4 changes: 3 additions & 1 deletion paddle/fluid/pybind/op_function_generator.h
@@ -79,6 +79,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
"Beta2Pow", "MasterParam"}},
{"sparse_attention",
{"Q", "K", "V", "Offset", "Columns", "KeyPaddingMask", "AttnMask"}},
{"sgd", {"Param", "LearningRate", "Grad", "MasterParam"}},
};

// NOTE(zhiqiu): Like op_ins_map.
@@ -125,6 +126,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"adamw",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
{"sgd", {"ParamOut", "MasterParamOut"}},
{"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
@@ -142,7 +144,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
// especially in declarative mode.
// For those OPs, we need to manually specify the outs need to pass in this map.
std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"sgd", {"ParamOut"}},
{"sgd", {"ParamOut", "MasterParamOut"}},
{"adam",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
79 changes: 70 additions & 9 deletions python/paddle/fluid/optimizer.py
@@ -1296,6 +1296,7 @@ def __init__(self,
parameter_list=None,
regularization=None,
grad_clip=None,
multi_precision=False,
name=None):
assert learning_rate is not None
super(SGDOptimizer, self).__init__(
@@ -1306,26 +1307,86 @@
name=name)
self.type = "sgd"
self._use_mkldnn = False
self._multi_precision = multi_precision
self._master_weights = {}

def _create_master_weight(self, param):
if param.name in self._master_weights:
var = self._master_weights[param.name]
else:
assert isinstance(self.helper, LayerHelper)

var_name = param.name + "_fp32_master"
var_name = unique_name.generate(var_name)
var = layers.create_global_var(
name=var_name,
shape=param.shape,
value=0,
dtype='float32',
persistable=True)
block = self.helper.startup_program.global_block()
block.append_op(
type="cast",
inputs={"X": [param]},
outputs={"Out": [var]},
attrs={
"in_dtype": param.dtype,
"out_dtype": core.VarDesc.VarType.FP32
})
self._master_weights[param.name] = var
return var

def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block)
if isinstance(parameters, dict):
parameters = self._update_param_group(parameters)

# Create FP32 master weights for FP16 parameters when multi_precision is enabled
for p in parameters:
if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
master_p = self._create_master_weight(p)
continue
if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision:
warnings.warn(
"Accumulating with FP16 in the optimizer can lead to poor accuracy or slow convergence. "
"Consider using the multi_precision=True option of the SGD optimizer."
)

@no_grad
def _append_optimize_op(self, block, param_and_grad):

find_master = self._multi_precision and param_and_grad[
0].dtype == core.VarDesc.VarType.FP16
master_weight = (self._master_weights[param_and_grad[0].name]
if find_master else None)

lr = self._create_param_lr(param_and_grad)
if framework.in_dygraph_mode():
_C_ops.sgd(param_and_grad[0], lr, param_and_grad[1],
param_and_grad[0])
_C_ops.sgd(param_and_grad[0], lr, param_and_grad[1], master_weight,
param_and_grad[0], master_weight)
return None

assert isinstance(block, framework.Block)
# create the optimize op
inputs = {
"Param": param_and_grad[0],
"Grad": param_and_grad[1],
"LearningRate": lr
}

outputs = {"ParamOut": param_and_grad[0]}

attrs = {"multi_precision": find_master}

if find_master:
inputs["MasterParam"] = master_weight
outputs["MasterParamOut"] = master_weight

sgd_op = block.append_op(
type=self.type,
inputs={
"Param": param_and_grad[0],
"Grad": param_and_grad[1],
"LearningRate": lr
},
attrs={"use_mkldnn": self._use_mkldnn},
outputs={"ParamOut": param_and_grad[0]},
inputs=inputs,
outputs=outputs,
attrs=attrs,
stop_gradient=True)

return sgd_op
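
A hedged usage sketch of the new Python-side flag (assumes a Paddle build that contains this PR; the tiny static-graph program is illustrative, not taken from the diff):

```python
import paddle
import paddle.fluid as fluid

paddle.enable_static()
main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[None, 8], dtype='float32')
    loss = fluid.layers.reduce_mean(fluid.layers.fc(input=x, size=1))
    # multi_precision=True keeps FP32 master weights for any FP16 parameter
    # the optimizer updates (a no-op for this all-FP32 toy network).
    opt = fluid.optimizer.SGDOptimizer(learning_rate=0.01, multi_precision=True)
    opt.minimize(loss)
```

In a real AMP setup the FP16 parameters would come from the AMP rewrite or FP16 layers, and the flag then routes their update through the MasterParam/MasterParamOut path added above.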