Refactor ZIP and ZINB #17

Merged
merged 17 commits into from
Dec 7, 2024

Contributor

@hrlai hrlai commented Nov 30, 2022

This pull request builds on #15. The changes in 08c82fd and 093ae0b address the question raised in #15 (comment), where the log likelihood didn't quite match the output of extraDistr::dzip and extraDistr::dzinb. In the refactored code, the R equivalent for ZIP now gives output identical to extraDistr:

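pi_var <- 0.2  # zero-inflation probability
lambda <- 1    # Poisson rate

# R equivalent of the refactored ZIP log_prob
r_log_prob <- function(x) {
  log(
    # point mass at zero: 1 - sign(abs(x)) is 1 when x == 0, else 0
    (pi_var * (1 - sign(abs(x))) +
       # Poisson component, built in log space and then exponentiated
       exp(
         log1p(-pi_var) - lambda +
           x * log(lambda) - lgamma(x + 1))
    )
  )
}

x <- seq(-10, 170, by = 5)
z <- extraDistr::dzip(x, lambda = 1, pi = 0.2)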
log(z)
#>  [1]         -Inf         -Inf   -0.7046055   -6.0106353  -16.3275561
#>  [6]  -29.1224149  -43.5587600  -59.2267488  -75.8813799  -93.3593192
#> [11] -111.5437833 -130.3470772 -149.7009105 -169.5505890 -189.8513170
#> [16] -210.5657303 -231.6621871 -253.1135458 -274.8962678 -296.9897449
#> [21] -319.3757832 -342.0382024 -364.9625191 -388.1356927 -411.5459201
#> [26] -435.1824675 -459.0355315 -483.0961228 -507.3559689 -531.8074318
#> [31] -556.4434377 -581.2574163 -606.2432494 -631.3952254 -656.7080003
#> [36] -682.1765631 -707.7962058
r_log_prob(x)
#>  [1]         -Inf         -Inf   -0.7046055   -6.0106353  -16.3275561
#>  [6]  -29.1224149  -43.5587600  -59.2267488  -75.8813799  -93.3593192
#> [11] -111.5437833 -130.3470772 -149.7009105 -169.5505890 -189.8513170
#> [16] -210.5657303 -231.6621871 -253.1135458 -274.8962678 -296.9897449
#> [21] -319.3757832 -342.0382024 -364.9625191 -388.1356927 -411.5459201
#> [26] -435.1824675 -459.0355315 -483.0961228 -507.3559689 -531.8074318
#> [31] -556.4434377 -581.2574163 -606.2432494 -631.3952254 -656.7080003
#> [36] -682.1765631 -707.7962058

Created on 2022-11-30 with reprex v2.0.2
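
The exp-then-log form above returns -Inf for sufficiently large x, because the Poisson term underflows to zero before the outer log is taken. As a rough sketch only (base R, not the greta/TensorFlow code in this PR; `zip_log_prob` is a hypothetical name), the same quantity can be kept on the log scale for x > 0:

```r
# Illustrative log-space ZIP log-probability; not the package implementation.
zip_log_prob <- function(x, lambda, pi_var) {
  # Poisson component stays on the log scale, so it cannot underflow for x > 0.
  log_pois <- log1p(-pi_var) + dpois(x, lambda, log = TRUE)
  # At x == 0 the point mass and the Poisson component are summed.
  log_zero <- log(pi_var + (1 - pi_var) * exp(-lambda))
  ifelse(x == 0, log_zero, log_pois)
}

x <- c(-5, 0, 5, 170, 500)
zip_log_prob(x, lambda = 1, pi_var = 0.2)
```

For the x values in the reprex this should agree with `r_log_prob`; the difference only shows up far in the tail.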

And the R equivalent for ZINB also gives output identical to extraDistr:

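pi_var <- 0.2  # zero-inflation probability
p <- 0.1       # negative binomial success probability
q <- 1 - p     # failure probability (not used below; log1p(-p) is used instead)
size <- 10     # negative binomial size

# R equivalent of the refactored ZINB log_prob
r_log_prob <- function(x) {
  log(
    # point mass at zero: 1 - sign(abs(x)) is 1 when x == 0, else 0
    (pi_var * (1 - sign(abs(x))) +
       # negative binomial component, built in log space and then exponentiated
       exp(
         log1p(-pi_var) +
           lchoose(x + size - 1, x) +
           size * log(p) +
           x * log1p(-p)
       )
    )
  )
}

x <- seq(-10, 170, by = 5)
z <- extraDistr::dzinb(x, size = 10, prob = 0.1, pi = 0.2)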
log(z)
#>  [1]       -Inf       -Inf  -1.609438 -16.173895 -12.868956 -10.745772
#>  [7]  -9.236610  -8.107613  -7.238114  -6.557344  -6.020138  -5.595867
#> [13]  -5.262786  -5.004896  -4.810078  -4.668928  -4.574004  -4.519313
#> [19]  -4.499958  -4.511885  -4.551703  -4.616541  -4.703953  -4.811833
#> [25]  -4.938359  -5.081940  -5.241181  -5.414852  -5.601863  -5.801241
#> [31]  -6.012118  -6.233713  -6.465322  -6.706310  -6.956097  -7.214159
#> [37]  -7.480016
r_log_prob(x)
#>  [1]       -Inf       -Inf  -1.609438 -16.173895 -12.868956 -10.745772
#>  [7]  -9.236610  -8.107613  -7.238114  -6.557344  -6.020138  -5.595867
#> [13]  -5.262786  -5.004896  -4.810078  -4.668928  -4.574004  -4.519313
#> [19]  -4.499958  -4.511885  -4.551703  -4.616541  -4.703953  -4.811833
#> [25]  -4.938359  -5.081940  -5.241181  -5.414852  -5.601863  -5.801241
#> [31]  -6.012118  -6.233713  -6.465322  -6.706310  -6.956097  -7.214159
#> [37]  -7.480016

Created on 2022-11-30 with reprex v2.0.2
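
As a quick sanity check on the lchoose term, the negative binomial part of the expression above can be compared with stats::dnbinom on the log scale. This is a base-R sketch only (`nb_kernel` is a hypothetical name); it says nothing about how greta's tf_lchoose behaves, in particular for non-integer size:

```r
size <- 10
p <- 0.1
x <- seq(0, 170, by = 5)

# Negative binomial log-pmf written with lchoose, as in r_log_prob above.
nb_kernel <- lchoose(x + size - 1, x) + size * log(p) + x * log1p(-p)

# Reference values from base R.
nb_reference <- dnbinom(x, size = size, prob = p, log = TRUE)

all.equal(nb_kernel, nb_reference)
```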

Two things that I am not sure are fully efficient or correct:

  1. I replaced relu(x) with 1 - sign(abs(x)) because I am not comfortable with relu giving negative x increasing values. I feel it should behave more like a step function that turns the pi_var component "on" or "off" depending on whether x == 0. There might be a better way of writing an ifelse-like condition in TensorFlow, but I don't know how...
  2. I hope I'm using tf_lchoose correctly...
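
To illustrate the first point, here is a small base-R sketch (variable names are hypothetical) showing that 1 - sign(abs(x)) is exactly the indicator of x == 0, whereas a relu is not a step function. In TensorFlow itself, an element-wise select is usually written with tf.where on an equality test; whether that is preferable inside greta is not verified here.

```r
x <- c(-3, -1, 0, 1, 2, 10)

step_indicator <- 1 - sign(abs(x))    # form used in this PR
equality_mask  <- as.numeric(x == 0)  # explicit ifelse-like form
relu           <- pmax(x, 0)          # for contrast: grows with positive x

# The two indicator forms agree element by element.
stopifnot(identical(step_indicator, equality_mask))

data.frame(x, step_indicator, equality_mask, relu)
```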

Earlier comments here and here only showed that sample works, not log_prob. I think we need to use ZIP and ZINB with distribution or model and then run mcmc to test them properly, so here you go. For ZIP (sorry, reprex keeps crashing on the code below and I don't know why):

library(greta.distributions)
y <- extraDistr::rzip(1e3, lambda = 1, pi= 0.2)
lambda <- normal(0, 1, truncation = c(0, Inf))
zi_pi <- beta(2, 2)
distribution(y) <- zero_inflated_poisson(lambda, zi_pi)
m <- model(lambda, zi_pi)
draws <- mcmc(m)
plot(draws)

[image: plot(draws) output for the ZIP model]

And for ZINB:

library(greta.distributions)
y <- extraDistr::rzinb(1e3, size = 10, prob = 0.1, pi = 0.2)
size <- normal(0, 10, truncation = c(0, Inf))
prob <- beta(2, 2)
zi_pi <- beta(2, 2)
distribution(y) <- zero_inflated_negative_binomial(size, prob, zi_pi)
m <- model(size, prob, zi_pi)
draws <- mcmc(m)
plot(draws)

[image: plot(draws) output for the ZINB model]

Both the ZIP and ZINB models seem to recover the true values well, though the ZINB model seems to struggle a bit when inferring the proportion of zeros.
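
For a greta-free cross-check of parameter recovery, a maximum-likelihood fit of the ZIP model can be run in base R. This is only a sketch under assumed settings (the same true values as above, data simulated with rbinom and rpois rather than extraDistr::rzip; `zip_nll` is a hypothetical name), not part of the package tests:

```r
set.seed(1)
n <- 1e3
true_lambda <- 1
true_pi <- 0.2

# Simulate ZIP data: structural zero with probability true_pi, else Poisson.
y <- ifelse(rbinom(n, 1, true_pi) == 1, 0, rpois(n, true_lambda))

# Negative log-likelihood on unconstrained scales (log lambda, logit pi).
zip_nll <- function(par, y) {
  lambda <- exp(par[1])
  pi_var <- plogis(par[2])
  ll <- ifelse(
    y == 0,
    log(pi_var + (1 - pi_var) * exp(-lambda)),
    log1p(-pi_var) + dpois(y, lambda, log = TRUE)
  )
  -sum(ll)
}

fit <- optim(c(0, 0), zip_nll, y = y)
estimates <- c(lambda = exp(fit$par[1]), pi = plogis(fit$par[2]))
estimates
```

The estimates can then be compared against the posterior summaries from the greta model fitted to the same kind of data.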

Hope this makes sense. I will use it on real data and report back any peculiarities. Not 100% sure everything's correct 😉

olivuntu and others added 11 commits July 22, 2022 06:57
…, need to explore how the distribution is defined and why this could be happening from the TF end
…latter is less foolproof because I am not sure how greta handles negative integers for ZINB that only support non-negative integers), then (2) leverage on tf_lchoose that's already being used in some greta distributions, and finally (3) calculate things in log space before converting them back to regular space
* add coda, cramer, distributional, extraDistr, likelihoodExplore, mvtnorm, and truncdist to Suggests
* Update to Roxygen 7.3.1
* use globalVariables to capture NOTE
* use new sentinel package structure - "greta-distributions-packate.R"
* import some greta internals directly into helpers.R
@njtierney
Collaborator

Hi @hrlai! Thanks so much for this.

I've made some changes to get the checks to pass. I'm currently running into issues verifying that the distributions match the ones they reference, so I'll check back in on this, but I'm keen to get this merged soon :)

@njtierney njtierney merged commit dee60f0 into greta-dev:main Dec 7, 2024