Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix error message of multinomial op #27946

Merged
merged 13 commits into from
Oct 19, 2020
15 changes: 15 additions & 0 deletions paddle/fluid/operators/multinomial_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,27 @@ class MultinomialOp : public framework::OperatorWithKernel {

auto x_dim = ctx->GetInputDim("X");
int64_t x_rank = x_dim.size();
PADDLE_ENFORCE_GT(x_rank, 0,
platform::errors::InvalidArgument(
"The number of dimensions of the input probability "
"distribution should be > 0, but got %d.",
x_rank));
PADDLE_ENFORCE_LE(x_rank, 2,
platform::errors::InvalidArgument(
"The number of dimensions of the input probability "
"distribution should be <= 2, but got %d.",
x_rank));

std::vector<int64_t> out_dims(x_rank);
for (int64_t i = 0; i < x_rank - 1; i++) {
out_dims[i] = x_dim[i];
}

int64_t num_samples = ctx->Attrs().Get<int>("num_samples");
PADDLE_ENFORCE_GT(
num_samples, 0,
platform::errors::InvalidArgument(
"The number of samples should be > 0, but got %d", num_samples));
out_dims[x_rank - 1] = num_samples;

ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/operators/multinomial_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/multinomial_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/transform.h"

namespace paddle {
Expand All @@ -31,6 +32,14 @@ __global__ void NormalizeProbability(T* norm_probs, const T* in_data,
T* sum_rows) {
int id = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x;
PADDLE_ENFORCE(
in_data[id] >= 0.0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议报错信息统一加句点,PR里有的加了,有的没加

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"The input of multinomial distribution should be >= 0, but got %f",
in_data[id]);
PADDLE_ENFORCE(sum_rows[blockIdx.y] > 0.0,
"The sum of one multinomial distribution probability should "
"be > 0, but got %f",
sum_rows[blockIdx.y]);
norm_probs[id] = in_data[id] / sum_rows[blockIdx.y];
}

Expand Down
29 changes: 15 additions & 14 deletions paddle/fluid/operators/multinomial_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,28 +44,29 @@ void MultinomialFunctor(int64_t* out_data, const T* in_data,
int64_t num_zeros = 0;
for (int64_t j = 0; j < num_categories; j++) {
prob_value = in_data[i * num_categories + j];
PADDLE_ENFORCE_GE(
prob_value, 0.0,
platform::errors::OutOfRange(
"The input of multinomial distribution should be >= 0"));
PADDLE_ENFORCE_EQ((std::isinf(static_cast<double>(prob_value)) ||
std::isnan(static_cast<double>(prob_value))),
false, platform::errors::OutOfRange(
"The input of multinomial distribution "
"shoud not be infinity or NaN"));
PADDLE_ENFORCE_GE(prob_value, 0.0,
platform::errors::InvalidArgument(
"The input of multinomial distribution "
"should be >= 0, but got %f",
prob_value));

probs_sum += prob_value;
if (prob_value == 0) {
num_zeros += 1;
}
cumulative_probs[j] = probs_sum;
}
PADDLE_ENFORCE_GT(probs_sum, 0.0, platform::errors::OutOfRange(
"The sum of input should not be 0"));
PADDLE_ENFORCE_GT(probs_sum, 0.0,
platform::errors::InvalidArgument(
"The sum of one multinomial distribution "
"probability should be > 0, but got %f",
probs_sum));
PADDLE_ENFORCE_EQ(
(replacement || (num_categories - num_zeros >= num_samples)), true,
platform::errors::OutOfRange("When replacement is False, number of "
"samples should be less than non-zero "
"categories"));
platform::errors::InvalidArgument(
"When replacement is False, number of "
"samples should be less than non-zero "
"categories"));

for (int64_t j = 0; j < num_categories; j++) {
cumulative_probs[j] /= probs_sum;
Expand Down
195 changes: 103 additions & 92 deletions python/paddle/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,48 +662,53 @@ class Categorical(Distribution):

Args:
logits(list|numpy.ndarray|Tensor): The logits input of categorical distribution. The data type is float32 or float64.
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Examples:
.. code-block:: python

import paddle
from paddle.distribution import Categorical
import paddle
from paddle.distribution import Categorical

x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
y = paddle.rand([6])
print(y.numpy())
# [0.6365463 , 0.7278677 , 0.90260243,
# 0.5226815 , 0.35837543, 0.13981032]
paddle.manual_seed(100)
x = paddle.rand([6])
print(x.numpy())
# [0.00224779 0.50324494 0.13526054
# 0.1611277 0.7955702 0.96897715]

cat = Categorical(x)
cat2 = Categorical(y)
paddle.manual_seed(200)
y = paddle.rand([6])
print(y.numpy())
# [0.00449559 0.00648983 0.27052107
# 0.3222554 0.5911404 0.93795437]

cat.sample([2,3])
# [[5, 1, 1],
# [0, 1, 2]]
cat = Categorical(x)
cat2 = Categorical(y)

cat.entropy()
# [1.71887]
cat.sample([2,3])
# [[4, 5, 5],
# [4, 2, 3]]

cat.kl_divergence(cat2)
# [0.0278455]
cat.entropy()
# [1.72595]

value = paddle.to_tensor([2,1,3])
cat.probs(value)
# [0.341613 0.342648 0.03123]
cat.kl_divergence(cat2)
# [0.0218145]

cat.log_prob(value)
# [-1.07408 -1.07105 -3.46638]
value = paddle.to_tensor([2,1,3])
cat.probs(value)
# [0.0527038 0.196088 0.0627829]

cat.log_prob(value)
# [-2.94307 -1.62919 -2.76807]

"""

def __init__(self, logits, name=None):
"""
Args:
logits(list|numpy.ndarray|Variable): The logits input of categorical distribution. The data type is float32 or float64.
logits(list|numpy.ndarray|Tensor): The logits input of categorical distribution. The data type is float32 or float64.
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
"""
if not in_dygraph_mode():
check_type(logits, 'logits', (np.ndarray, tensor.Variable, list),
Expand All @@ -727,27 +732,28 @@ def sample(self, shape):
"""Generate samples of the specified shape.

Args:
shape (list): Shape of the generated samples.
shape (list): Shape of the generated samples.

Returns:
Tensor: A tensor with prepended dimensions shape.
Tensor: A tensor with prepended dimensions shape.

Examples:
.. code-block:: python
.. code-block:: python

import paddle
from paddle.distribution import Categorical
import paddle
from paddle.distribution import Categorical

x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
paddle.manual_seed(100)
x = paddle.rand([6])
print(x.numpy())
# [0.00224779 0.50324494 0.13526054
# 0.1611277 0.7955702 0.96897715]

cat = Categorical(x)
cat = Categorical(x)

cat.sample([2,3])
# [[5, 1, 1],
# [0, 1, 2]]
cat.sample([2,3])
# [[4, 5, 5],
# [4, 2, 3]]

"""
name = self.name + '_sample'
Expand Down Expand Up @@ -775,28 +781,31 @@ def kl_divergence(self, other):
other (Categorical): instance of Categorical. The data type is float32.

Returns:
Variable: kl-divergence between two Categorical distributions.
Tensor: kl-divergence between two Categorical distributions.

Examples:
.. code-block:: python
.. code-block:: python

import paddle
from paddle.distribution import Categorical
import paddle
from paddle.distribution import Categorical

paddle.manual_seed(100)
x = paddle.rand([6])
print(x.numpy())
# [0.00224779 0.50324494 0.13526054
# 0.1611277 0.7955702 0.96897715]

x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
y = paddle.rand([6])
print(y.numpy())
# [0.6365463 , 0.7278677 , 0.90260243,
# 0.5226815 , 0.35837543, 0.13981032]
paddle.manual_seed(200)
y = paddle.rand([6])
print(y.numpy())
# [0.00449559 0.00648983 0.27052107
# 0.3222554 0.5911404 0.93795437]

cat = Categorical(x)
cat2 = Categorical(y)
cat = Categorical(x)
cat2 = Categorical(y)

cat.kl_divergence(cat2)
# [0.0278455]
cat.kl_divergence(cat2)
# [0.0218145]

"""
name = self.name + '_kl_divergence'
Expand All @@ -823,23 +832,24 @@ def entropy(self):
"""Shannon entropy in nats.

Returns:
Variable: Shannon entropy of Categorical distribution. The data type is float32.
Tensor: Shannon entropy of Categorical distribution. The data type is float32.

Examples:
.. code-block:: python
.. code-block:: python

import paddle
from paddle.distribution import Categorical
import paddle
from paddle.distribution import Categorical

x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
paddle.manual_seed(100)
x = paddle.rand([6])
print(x.numpy())
# [0.00224779 0.50324494 0.13526054
# 0.1611277 0.7955702 0.96897715]

cat = Categorical(x)
cat = Categorical(x)

cat.entropy()
# [1.71887]
cat.entropy()
# [1.72595]

"""
name = self.name + '_entropy'
Expand All @@ -864,27 +874,28 @@ def probs(self, value):
with ``logits``. That is, ``value[:-1] = logits[:-1]``.

Args:
value (Tensor): The input tensor represents the selected category index.
value (Tensor): The input tensor represents the selected category index.

Returns:
Tensor: probability according to the category index.
Tensor: probability according to the category index.

Examples:
.. code-block:: python
.. code-block:: python

import paddle
from paddle.distribution import Categorical
import paddle
from paddle.distribution import Categorical

x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
paddle.manual_seed(100)
x = paddle.rand([6])
print(x.numpy())
# [0.00224779 0.50324494 0.13526054
# 0.1611277 0.7955702 0.96897715]

cat = Categorical(x)
cat = Categorical(x)

value = paddle.to_tensor([2,1,3])
cat.probs(value)
# [0.341613 0.342648 0.03123]
value = paddle.to_tensor([2,1,3])
cat.probs(value)
# [0.0527038 0.196088 0.0627829]

"""
name = self.name + '_probs'
Expand Down Expand Up @@ -929,28 +940,28 @@ def log_prob(self, value):
"""Log probabilities of the given category. Refer to ``probs`` method.

Args:
value (Tensor): The input tensor represents the selected category index.
value (Tensor): The input tensor represents the selected category index.

Returns:
Tensor: Log probability.
Tensor: Log probability.

Examples:
.. code-block:: python

import paddle
from paddle.distribution import Categorical
.. code-block:: python

x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
import paddle
from paddle.distribution import Categorical

cat = Categorical(x)
paddle.manual_seed(100)
x = paddle.rand([6])
print(x.numpy())
# [0.00224779 0.50324494 0.13526054
# 0.1611277 0.7955702 0.96897715]

value = paddle.to_tensor([2,1,3])
cat = Categorical(x)

cat.log_prob(value)
# [-1.07408 -1.07105 -3.46638]
value = paddle.to_tensor([2,1,3])
cat.log_prob(value)
# [-2.94307 -1.62919 -2.76807]

"""
name = self.name + '_log_prob'
Expand Down
Loading