Make Embedding layer support more int ids type #39381

Merged

30 changes: 28 additions & 2 deletions paddle/fluid/framework/data_type.h
@@ -27,6 +27,9 @@ limitations under the License. */
namespace paddle {
namespace framework {

extern std::string DataTypeToString(const proto::VarType::Type type);
extern size_t SizeOfType(proto::VarType::Type type);

template <typename T>
struct IsComplex : public std::false_type {};

@@ -63,6 +66,13 @@ struct DataTypeTrait<void> {
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
COMPLEX128);

#define _ForEachIntDataType_(callback) \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8);

#define _ForEachDataTypeSmall_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, double, FP64); \
@@ -138,6 +148,24 @@ inline void VisitDataTypeSmall(proto::VarType::Type type, Visitor visitor) {
#undef VisitDataTypeCallbackSmall
}

template <typename Visitor>
inline void VisitIntDataType(proto::VarType::Type type, Visitor visitor) {
#define VisitIntDataTypeCallback(cpp_type, proto_type) \
do { \
if (type == proto_type) { \
visitor.template apply<cpp_type>(); \
return; \
} \
} while (0)

_ForEachIntDataType_(VisitIntDataTypeCallback);

PADDLE_THROW(platform::errors::Unimplemented(
"Expected integral data type, but got %s", DataTypeToString(type)));

#undef VisitIntDataTypeCallback
}

template <typename Visitor>
inline void VisitDataTypeTiny(proto::VarType::Type type, Visitor visitor) {
#define VisitDataTypeCallbackTiny(cpp_type, proto_type) \
Expand Down Expand Up @@ -166,8 +194,6 @@ inline void VisitDataTypeForHIP(proto::VarType::Type type, Visitor visitor) {
#undef VisitDataTypeCallbackHIP
}

extern std::string DataTypeToString(const proto::VarType::Type type);
extern size_t SizeOfType(proto::VarType::Type type);
inline std::ostream& operator<<(std::ostream& out,
const proto::VarType::Type& type) {
out << DataTypeToString(type);
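The data_type.h change adds a _ForEachIntDataType_ macro covering int, int64_t, uint8_t, int16_t, and int8_t, plus a VisitIntDataType helper that calls a functor's templated apply<cpp_type>() for whichever integer dtype is seen at runtime and throws Unimplemented otherwise. The following is a minimal, self-contained sketch of that visitor-dispatch pattern, not the Paddle code itself: the enum, functor, and switch are simplified stand-ins for proto::VarType::Type, the callback macro, and PADDLE_THROW.

// Standalone illustration of the VisitIntDataType dispatch pattern
// (simplified; the real code uses proto::VarType::Type and PADDLE_THROW).
#include <cstdint>
#include <iostream>
#include <stdexcept>

enum class VarType { INT32, INT64, UINT8, INT16, INT8 };

struct PrintIdWidthFunctor {
  // VisitIntDataType calls apply<IdT>() with the concrete C++ type
  // that matches the runtime dtype tag.
  template <typename IdT>
  void apply() {
    std::cout << "id element size: " << sizeof(IdT) << " bytes\n";
  }
};

template <typename Visitor>
void VisitIntDataType(VarType type, Visitor visitor) {
  switch (type) {
    case VarType::INT32: visitor.template apply<int>(); return;
    case VarType::INT64: visitor.template apply<int64_t>(); return;
    case VarType::UINT8: visitor.template apply<uint8_t>(); return;
    case VarType::INT16: visitor.template apply<int16_t>(); return;
    case VarType::INT8:  visitor.template apply<int8_t>(); return;
  }
  throw std::runtime_error("Expected integral data type");
}

int main() {
  PrintIdWidthFunctor functor;
  VisitIntDataType(VarType::INT32, functor);  // prints 4 bytes
  VisitIntDataType(VarType::INT64, functor);  // prints 8 bytes
  return 0;
}
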
202 changes: 94 additions & 108 deletions paddle/fluid/operators/lookup_table_v2_op.cu
@@ -21,16 +21,16 @@ limitations under the License. */
namespace paddle {
namespace operators {

template <typename T, int BlockDimX, int BlockDimY, int GridDimX,
template <typename T, typename IdT, int BlockDimX, int BlockDimY, int GridDimX,
bool PaddingFlag>
__global__ void LookupTableV2(T *output, const T *table, const int64_t *ids,
__global__ void LookupTableV2(T *output, const T *table, const IdT *ids,
const int64_t N, const int64_t K, const int64_t D,
const int64_t padding_idx) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX;

while (idy < K) {
int64_t id = ids[idy];
auto id = static_cast<int64_t>(ids[idy]);
T *out = output + idy * D;
const T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) {
@@ -47,15 +47,15 @@ __global__ void LookupTableV2(T *output, const T *table, const int64_t *ids,
}
}

template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableV2Grad(T *table, const T *output, const int64_t *ids,
template <typename T, typename IdT, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableV2Grad(T *table, const T *output, const IdT *ids,
const int64_t N, const int64_t K,
const int64_t D) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX;

while (idy < K) {
int64_t id = ids[idy];
auto id = static_cast<int64_t>(ids[idy]);
const T *out = output + idy * D;
T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) {
@@ -66,123 +66,107 @@ __global__ void LookupTableV2Grad(T *table, const T *output, const int64_t *ids,
}

template <typename T>
__global__ void InputTypeCovert(const T *in_ids, const int64_t K,
int64_t *out_ids) {
for (int i = 0; i < K; i++) {
out_ids[i] = (int64_t)(in_ids[i]);
}
}

template <typename T>
class LookupTableV2CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *table_t = context.Input<LoDTensor>("W");
auto *ids_t = context.Input<LoDTensor>("Ids");
auto *output_t = context.Output<LoDTensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
struct LookupTableV2CUDAFunctor {
LookupTableV2CUDAFunctor(const framework::ExecutionContext &context,
const framework::Tensor *ids_t)
: context_(context), ids_t_(ids_t) {}

auto id_name = context.InputNames("Ids").front();
auto out_name = context.OutputNames("Out").front();
template <typename IdT>
void apply() {
auto *table_t = context_.Input<framework::Tensor>("W");
auto *output_t = context_.Output<framework::Tensor>("Out");
int64_t padding_idx = context_.Attr<int64_t>("padding_idx");

size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
size_t K = ids_t_->numel();

dim3 threads(256, 4);
dim3 grids(80, 1);

// copy GPU memory to CPU pinned memory
framework::Vector<int64_t> ids;
ids.resize(K);
const auto *table = table_t->template data<T>();
const auto *ids = ids_t_->template data<IdT>();
auto *output = output_t->template mutable_data<T>(context_.GetPlace());
auto stream = context_.cuda_device_context().stream();

const int64_t *ids_p = nullptr;

if (ids_t->type() == framework::proto::VarType::INT32) {
InputTypeCovert<
int><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
ids_t->data<int>(), K, ids.MutableData(context.GetPlace()));
ids_p = ids.MutableData(context.GetPlace());
if (padding_idx == -1) {
LookupTableV2<T, IdT, 256, 4, 80, false><<<grids, threads, 0, stream>>>(
output, table, ids, N, K, D, padding_idx);
} else {
ids_p = ids_t->data<int64_t>();
LookupTableV2<T, IdT, 256, 4, 80, true><<<grids, threads, 0, stream>>>(
output, table, ids, N, K, D, padding_idx);
}

for (int64_t i = 0; i < K; ++i) {
PADDLE_ENFORCE_GE(
ids[i], 0,
platform::errors::InvalidArgument(
"Variable value (input) of OP(paddle.nn.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input value.",
N, ids[i]));
PADDLE_ENFORCE_LT(
ids[i], N,
platform::errors::InvalidArgument(
"Variable value (input) of OP(paddle.nn.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input value.",
N, ids[i]));
}

auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());

if (padding_idx == -1)
LookupTableV2<
T, 256, 4, 80,
false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids_p, N, K, D, padding_idx);
else
LookupTableV2<
T, 256, 4, 80,
true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids_p, N, K, D, padding_idx);
}

private:
const framework::ExecutionContext &context_;
const framework::Tensor *ids_t_;
};

template <typename T>
class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
class LookupTableV2CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const auto *ids_t = context.Input<framework::Tensor>("Ids");
LookupTableV2CUDAFunctor<T> functor(context, ids_t);
framework::VisitIntDataType(ids_t->type(), functor);
}
};

template <typename InT, typename OutT>
__global__ void InputTypeConvert(const InT *in_ids, const int64_t K,
OutT *out_ids) {
for (int i = 0; i < K; i++) {
out_ids[i] = static_cast<OutT>(in_ids[i]);
}
}

template <typename T>
struct LookupTableV2GradCUDAFunctor {
LookupTableV2GradCUDAFunctor(const framework::ExecutionContext &context,
const framework::Tensor *ids_t)
: context_(context), ids_t_(ids_t) {}

template <typename IdT>
void apply() {
auto &dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
bool is_sparse = context.Attr<bool>("is_sparse");
context_.template device_context<platform::CUDADeviceContext>();
bool is_sparse = context_.Attr<bool>("is_sparse");

// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if (is_sparse) {
auto *ids = context.Input<LoDTensor>("Ids");
auto *table = context.Input<LoDTensor>("W");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *table = context_.Input<framework::Tensor>("W");
auto *d_output =
context_.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *d_table =
context.Output<pten::SelectedRows>(framework::GradVarName("W"));
context_.Output<pten::SelectedRows>(framework::GradVarName("W"));

auto *ids_data = ids->data<int64_t>();
int64_t ids_num = ids->numel();
const auto *ids_data = ids_t_->template data<IdT>();
int64_t ids_num = ids_t_->numel();
dim3 threads(128, 8);
dim3 grids(8, 1);
auto stream = dev_ctx.stream();
// copy GPU memory to CPU pinned memory
framework::Vector<int64_t> new_rows;
new_rows.resize(ids_num);
auto gpu_place = context.GetPlace();
auto gpu_place = context_.GetPlace();

if (ids->type() == framework::proto::VarType::INT32) {
InputTypeCovert<
int><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
ids->data<int>(), ids_num,
new_rows.MutableData(context.GetPlace()));
if (!std::is_same<IdT, int64_t>::value) {
InputTypeConvert<<<grids, threads, 0, stream>>>(
ids_data, ids_num, new_rows.MutableData(gpu_place));
} else {
memory::Copy(gpu_place, new_rows.CUDAMutableData(context.GetPlace()),
gpu_place, ids_data, ids_num * sizeof(int64_t), stream);
memory::Copy(gpu_place, new_rows.CUDAMutableData(gpu_place), gpu_place,
ids_data, ids_num * sizeof(int64_t), stream);
}

d_table->set_rows(new_rows);

auto *d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table->dims()[1]});
d_table_value->mutable_data<T>(context.GetPlace());
d_table_value->template mutable_data<T>(gpu_place);

auto *d_table_data = d_table_value->data<T>();
auto *d_output_data = d_output->data<T>();
auto *d_table_data = d_table_value->template data<T>();
auto *d_output_data = d_output->template data<T>();
auto d_output_dims = d_output->dims();
auto d_output_dims_2d =
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
@@ -197,41 +181,43 @@ class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
d_output->numel() * sizeof(T), stream);

} else {
auto ids_t = context.Input<LoDTensor>("Ids");
auto d_output_t = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));
auto d_output_t =
context_.Input<framework::Tensor>(framework::GradVarName("Out"));
auto d_table_t =
context_.Output<framework::Tensor>(framework::GradVarName("W"));

int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1];
int K = ids_t->numel();
int K = ids_t_->numel();

dim3 threads(128, 8);
dim3 grids(8, 1);
// copy GPU memory to CPU pinned memory
framework::Vector<int64_t> ids;
ids.resize(K);

const int64_t *ids_p = nullptr;

if (ids_t->type() == framework::proto::VarType::INT32) {
InputTypeCovert<
int><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
ids_t->data<int>(), K, ids.MutableData(context.GetPlace()));
ids_p = ids.MutableData(context.GetPlace());
} else {
ids_p = ids_t->data<int64_t>();
}

const T *d_output = d_output_t->data<T>();
T *d_table = d_table_t->mutable_data<T>(context.GetPlace());
const T *d_output = d_output_t->template data<T>();
const auto *ids = ids_t_->template data<IdT>();
T *d_table = d_table_t->mutable_data<T>(context_.GetPlace());

auto t = framework::EigenVector<T>::Flatten(*d_table_t);
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));

LookupTableV2Grad<T, 128, 8, 8><<<grids, threads, 0, dev_ctx.stream()>>>(
d_table, d_output, ids_p, N, K, D);
LookupTableV2Grad<T, IdT, 128, 8,
8><<<grids, threads, 0, dev_ctx.stream()>>>(
d_table, d_output, ids, N, K, D);
}
}

private:
const framework::ExecutionContext &context_;
const framework::Tensor *ids_t_;
};

template <typename T>
class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const auto *ids_t = context.Input<framework::Tensor>("Ids");
LookupTableV2GradCUDAFunctor<T> functor(context, ids_t);
framework::VisitIntDataType(ids_t->type(), functor);
}
};

} // namespace operators
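
The rewritten lookup_table_v2_op.cu templates both kernels on the id type (IdT) and widens each id with static_cast inside the kernel, so the old InputTypeCovert pass that staged int32 ids into an int64 buffer before launch is no longer needed in the forward path. The sketch below is independent of Paddle and only illustrates that templated-gather idea; the Lookup kernel, launch configuration, and sizes are hypothetical, not the operator's actual implementation.

// Minimal CUDA sketch: the kernel takes the id pointer as IdT and widens
// per element, so no separate conversion kernel or staging buffer is needed.
#include <cstdint>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

template <typename T, typename IdT>
__global__ void Lookup(T *out, const T *table, const IdT *ids,
                       int64_t K, int64_t D) {
  int64_t row = blockIdx.x;                   // one block per id
  if (row >= K) return;
  auto id = static_cast<int64_t>(ids[row]);   // widen int32/int16/... ids
  for (int64_t col = threadIdx.x; col < D; col += blockDim.x) {
    out[row * D + col] = table[id * D + col];
  }
}

int main() {
  const int64_t N = 4, D = 3, K = 2;
  std::vector<float> h_table(N * D);
  for (int64_t i = 0; i < N * D; ++i) h_table[i] = static_cast<float>(i);
  std::vector<int32_t> h_ids = {3, 1};        // int32 ids, no conversion pass

  float *d_table, *d_out;
  int32_t *d_ids;
  cudaMalloc(&d_table, N * D * sizeof(float));
  cudaMalloc(&d_out, K * D * sizeof(float));
  cudaMalloc(&d_ids, K * sizeof(int32_t));
  cudaMemcpy(d_table, h_table.data(), N * D * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_ids, h_ids.data(), K * sizeof(int32_t), cudaMemcpyHostToDevice);

  Lookup<float, int32_t><<<K, 32>>>(d_out, d_table, d_ids, K, D);

  std::vector<float> h_out(K * D);
  cudaMemcpy(h_out.data(), d_out, K * D * sizeof(float),
             cudaMemcpyDeviceToHost);
  for (float v : h_out) printf("%g ", v);     // rows 3 and 1 of the table
  printf("\n");

  cudaFree(d_table);
  cudaFree(d_out);
  cudaFree(d_ids);
  return 0;
}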