Skip to content

Commit

Permalink
Bernoulli distribution returns an int instead of a boolean (#796)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored Nov 3, 2020
1 parent 90ace43 commit b7af180
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def __init__(self, probs, validate_args=None):
super(BernoulliProbs, self).__init__(batch_shape=jnp.shape(self.probs), validate_args=validate_args)

def sample(self, key, sample_shape=()):
return random.bernoulli(key, self.probs, shape=sample_shape + self.batch_shape)
samples = random.bernoulli(key, self.probs, shape=sample_shape + self.batch_shape)
return samples.astype(jnp.result_type(samples, int))

@validate_sample
def log_prob(self, value):
Expand Down Expand Up @@ -114,7 +115,8 @@ def __init__(self, logits=None, validate_args=None):
super(BernoulliLogits, self).__init__(batch_shape=jnp.shape(self.logits), validate_args=validate_args)

def sample(self, key, sample_shape=()):
return random.bernoulli(key, self.probs, shape=sample_shape + self.batch_shape)
samples = random.bernoulli(key, self.probs, shape=sample_shape + self.batch_shape)
return samples.astype(jnp.result_type(samples, int))

@validate_sample
def log_prob(self, value):
Expand Down

0 comments on commit b7af180

Please sign in to comment.