cherry-pick from (PaddlePaddle#55621): add check for cembedding
ForFishes authored and wentaoyu committed Nov 8, 2023
1 parent 4807847 commit f15405c
Showing 4 changed files with 27 additions and 12 deletions.
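In short, this commit threads a new vocab_size attribute from the Python VocabParallelEmbedding layer down to the c_embedding GPU kernel, so ids outside the vocabulary now raise an error instead of silently producing zero rows. The sketch below mirrors the contract of the new attribute in plain Python (illustrative only; check_ids is not a Paddle API): the default -1 disables the upper bound, otherwise every id must lie in [0, vocab_size).

```python
def check_ids(ids, vocab_size=-1):
    # Mirrors the kernel's new PADDLE_ENFORCE: ids must be non-negative,
    # and when vocab_size >= 0 they must also be strictly below it.
    for i in ids:
        if i < 0 or (vocab_size >= 0 and i >= vocab_size):
            raise ValueError(
                f"index {i} is out of bounds; expected 0 <= id < {vocab_size}"
            )

check_ids([0, 5, 7], vocab_size=8)    # passes
# check_ids([0, 9], vocab_size=8)     # would raise ValueError
# check_ids([3, 42], vocab_size=-1)   # -1 keeps the old behaviour: no upper bound
```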
4 changes: 4 additions & 0 deletions paddle/fluid/operators/collective/c_embedding_op.cc
@@ -87,6 +87,10 @@ class CEmbeddingOpMaker : public framework::OpProtoAndCheckerMaker {
"(int64, default 0), The starting index is indeed, "
"and the out-of-bounds will be set to 0 ")
.SetDefault(0);
AddAttr<int64_t>("vocab_size",
"(int64, default -1), The total vocabulary size to check"
"the out-of-bounds ids. If it is -1, no check will be ")
.SetDefault(-1);
AddComment(R"DOC(
c_embedding Operator.
21 changes: 13 additions & 8 deletions paddle/phi/kernels/gpu/c_embedding_kernel.cu
@@ -36,21 +36,23 @@ __global__ void CEmbedding(T* out,
const int64_t N,
const int64_t start_idx,
const int64_t end_idx,
const int64_t limit) {
const int64_t limit,
const int64_t vocab_size) {
CUDA_KERNEL_LOOP(i, limit) {
size_t row = i / columns;
size_t col = i % columns;
auto id = ids[row];

PADDLE_ENFORCE(
id >= 0 && (vocab_size < 0 || id < vocab_size),
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be less than [%d] and greater than or equal to 0, but received [%d]",
vocab_size,
id);
if (id >= start_idx && id < end_idx) {
auto real_idx = id - start_idx;
PADDLE_ENFORCE(real_idx < N,
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be less than [%d], but received [%d]",
N,
real_idx);
out[i] = table[real_idx * columns + col];
} else {
out[i] = static_cast<T>(0);
@@ -63,6 +65,7 @@ void CEmbeddingKernel(const Context& ctx,
const DenseTensor& w,
const DenseTensor& ids,
int64_t start_index,
int64_t vocab_size,
DenseTensor* out) {
size_t N = w.dims()[0];
size_t D = w.dims()[1];
@@ -88,6 +91,7 @@ void CEmbeddingKernel(const Context& ctx,
N,
start_index,
end_idx,
vocab_size,
limit);

} else if (index_type == phi::DataType::INT64) {
@@ -100,6 +104,7 @@
N,
start_index,
end_idx,
vocab_size,
limit);
} else {
PADDLE_THROW(phi::errors::Unavailable(
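For orientation, here is a NumPy reference of what the updated kernel computes on one rank (a sketch of the semantics, not the CUDA implementation): the rank owns N rows of the table starting at start_index, ids inside that shard are gathered, ids owned by other ranks produce zero rows that the later allreduce fills in, and ids outside [0, vocab_size) now fail the check. Function and variable names below are illustrative.

```python
import numpy as np

def c_embedding_reference(local_table, ids, start_index, vocab_size=-1):
    N, D = local_table.shape     # this rank's shard: rows [start_index, start_index + N)
    end_index = start_index + N
    flat_ids = ids.reshape(-1)
    out = np.zeros((flat_ids.size, D), dtype=local_table.dtype)
    for i, idx in enumerate(flat_ids):
        # new check: reject ids outside [0, vocab_size) when vocab_size >= 0
        if idx < 0 or (vocab_size >= 0 and idx >= vocab_size):
            raise ValueError(f"id {idx} out of range for vocab_size {vocab_size}")
        if start_index <= idx < end_index:
            out[i] = local_table[idx - start_index]   # row owned by this rank
        # else: leave zeros; another rank contributes this row
    return out.reshape(*ids.shape, D)

# Example: rank 1 of 2, a 16-word vocabulary split into two 8-row shards.
shard = np.arange(8 * 4, dtype=np.float32).reshape(8, 4)
ids = np.array([[3, 9], [15, 2]])
print(c_embedding_reference(shard, ids, start_index=8, vocab_size=16).shape)  # (2, 2, 4)
```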
2 changes: 2 additions & 0 deletions python/paddle/distributed/fleet/layers/mpu/mp_layers.py
@@ -134,6 +134,7 @@ def __init__(
self._size = [per_part_size, embedding_dim]
self._weight_attr = weight_attr
self._name = name
self.num_embeddings = num_embeddings

if self.is_mp and paddle.in_dynamic_mode():
with get_rng_state_tracker().rng_state():
@@ -161,6 +162,7 @@ def forward(self, x):
self.weight,
x,
start_index=self.vocab_start_index,
vocab_size=self.num_embeddings,
name=self._name,
)
output = mp_ops._mp_allreduce(
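The layer-side change above is small: VocabParallelEmbedding now records the global num_embeddings at construction and hands it to the lookup in forward. A condensed, illustrative sketch of that pattern (the real class also creates the sharded weight, handles the model-parallel group, and allreduces the partial outputs):

```python
class ShardedEmbeddingSketch:
    def __init__(self, num_embeddings, embedding_dim, rank, world_size):
        per_part_size = num_embeddings // world_size
        self.vocab_start_index = rank * per_part_size   # first row owned by this rank
        self._size = [per_part_size, embedding_dim]
        self.num_embeddings = num_embeddings            # kept so forward can pass vocab_size

    def lookup_kwargs(self, x):
        # Keyword arguments that forward() would hand to the c_embedding lookup.
        return {
            "index": x,
            "start_index": self.vocab_start_index,
            "vocab_size": self.num_embeddings,
        }
```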
12 changes: 8 additions & 4 deletions python/paddle/distributed/fleet/layers/mpu/mp_ops.py
@@ -330,7 +330,7 @@ def _mp_allreduce(
return out


def _c_lookup_table(table, index, start_index=0, name=None):
def _c_lookup_table(table, index, start_index=0, vocab_size=-1, name=None):
"""
Lookup table according to index.
@@ -346,7 +346,7 @@
"""
if in_dynamic_mode():
return _legacy_C_ops.c_embedding(
table, index, "start_index", start_index
table, index, "start_index", start_index, "vocab_size", vocab_size
)
else:
op_type = 'c_embedding'
@@ -358,7 +358,7 @@
type='c_embedding',
inputs={'Ids': index, 'W': table},
outputs={'Out': tmp},
attrs={"start_index": start_index},
attrs={"start_index": start_index, "vocab_size": vocab_size},
)
return tmp

@@ -684,7 +684,11 @@ def _parallel_embedding(
main_block.vars[weight.name].is_distributed = True

output_parallel = _c_lookup_table(
weight, x, start_index=vocab_start_index, name=name
weight,
x,
start_index=vocab_start_index,
vocab_size=origin_size[0],
name=name,
)
out = _mp_allreduce(
output_parallel,
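Finally, a minimal dynamic-mode usage sketch of the updated helper. It assumes a single process holding the whole 8-row table (start_index=0), so it exercises only the new bounds check rather than real model-parallel sharding, and it presumes a Paddle build that already contains this change and provides a c_embedding kernel for the current device.

```python
import paddle
from paddle.distributed.fleet.layers.mpu import mp_ops

table = paddle.uniform([8, 4])                            # full table on one rank
ids = paddle.to_tensor([[1, 3], [5, 7]], dtype='int64')

out = mp_ops._c_lookup_table(table, ids, start_index=0, vocab_size=8)
print(out.shape)   # [2, 2, 4]

# An id >= vocab_size (e.g. 8) would now trip the kernel's check
# instead of silently coming back as a zero vector.
```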
