-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path07-basic-use.R
84 lines (73 loc) · 2.58 KB
/
07-basic-use.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
### Basic Use example ###
# Output location for figures. To write into the paper's figure directory
# instead, swap in the commented-out paths below.
# globalpath <- "/storage/phd/projects/aghq-softwarepaper/paper/"
# plotpath <- paste0(globalpath,"/figures/basic-use/")
globalpath <- tempdir()
plotpath <- file.path(globalpath, "basic-use")
# recursive = TRUE so missing intermediate directories (e.g. "figures/") are
# created as well; fail loudly here rather than later when pdf() cannot open
# its output file (dir.create() only warns and returns FALSE on failure).
if (!dir.exists(plotpath) && !dir.create(plotpath, recursive = TRUE)) {
  stop("Could not create plot directory: ", plotpath, call. = FALSE)
}
library(aghq)
# Simulated data: n = 10 Poisson counts with true rate lambda = 5.
set.seed(84343124)
y <- rpois(n = 10, lambda = 5)
# Log of the unnormalized posterior log(pi(theta, y)) for theta = log(lambda):
# the Poisson log-likelihood of y, an Exponential(1) prior on lambda
# (contributing -lambda), and the log-Jacobian theta of lambda = exp(theta).
# Vectorized over theta.
logpithetay <- function(theta, y) {
  lambda <- exp(theta)
  n_obs <- length(y)
  kernel <- sum(y) * theta - (n_obs + 1) * lambda
  kernel - sum(lfactorial(y)) + theta
}
# Objective handed to aghq(): the log posterior evaluated at the data vector
# y, which is looked up from the enclosing (global) environment at call time.
# Its gradient and Hessian are approximated numerically with numDeriv.
objfunc <- function(x) {
  logpithetay(x, y)
}
objfuncgrad <- function(x) {
  numDeriv::grad(func = objfunc, x = x)
}
objfunchess <- function(x) {
  numDeriv::hessian(func = objfunc, x = x)
}
# aghq() takes the function, gradient and Hessian as list elements fn, gr, he.
funlist <- list(fn = objfunc, gr = objfuncgrad, he = objfunchess)
# Fit AGHQ with k = 3 quadrature points, starting the optimization
# at theta = 0.
thequadrature <- aghq::aghq(
  ff = funlist,
  k = 3,
  startingvalue = 0
)
summary(thequadrature)
plot(thequadrature)
# Nodes and weights of the normalized posterior:
thequadrature$normalized_posterior$nodesandweights
# Log normalizing constant as estimated by AGHQ ...
thequadrature$normalized_posterior$lognormconst
# ... against the exact value
# lgamma(sum(y) + 1) - (sum(y) + 1) * log(n + 1) - sum(log(y_i!)):
lfactorial(sum(y)) - (sum(y) + 1) * log(length(y) + 1) - sum(lfactorial(y))
# Quite accurate with only n = 10 and k = 3; this example is very simple.
# Mode located by the optimization ...
thequadrature$optresults$mode
# ... against the exact mode log((sum(y) + 1) / (n + 1)):
log((sum(y) + 1) / (length(y) + 1))
# Marginal posterior on the lambda = exp(theta) scale: density evaluated by
# compute_pdf_and_cdf(), plus 10,000 posterior draws from sample_marginal().
transformation <- list(
  totheta = log,
  fromtheta = exp
)
pdfwithlambda <- compute_pdf_and_cdf(
  thequadrature,
  transformation = transformation
)[[1]]
head(pdfwithlambda, n = 2)
lambdapostsamps <- sample_marginal(
  thequadrature,
  1e4,
  transformation = transformation
)[[1]]
# Plot the posterior of lambda. The histogram shows the sampled draws, the
# solid line the aghq density, and the dashed line the exact
# Gamma(1 + sum(y), rate = 1 + n) posterior.
# local() + on.exit() guarantee the pdf device is closed even if a plotting
# call fails; otherwise every later plot would silently go to this file.
local({
  pdf(file = file.path(plotpath, "lambda-post-plot.pdf"))
  plot_dev <- dev.cur()
  on.exit(dev.off(plot_dev), add = TRUE)
  with(pdfwithlambda, {
    hist(lambdapostsamps, breaks = 50, freq = FALSE, main = "",
         xlab = expression(lambda))
    lines(transparam, pdf_transparam)
    lines(transparam, dgamma(transparam, 1 + sum(y), 1 + length(y)),
          lty = "dashed")
  })
  invisible(NULL)
})
# Posterior moments via compute_moment(). The "moment" of the constant
# function 1 checks that the normalized posterior integrates to 1:
compute_moment(thequadrature$normalized_posterior, ff = function(x) 1)
# Posterior mean of theta:
compute_moment(thequadrature$normalized_posterior, ff = function(x) x)
# Posterior mean of lambda = exp(theta):
compute_moment(thequadrature$normalized_posterior, ff = function(x) exp(x))
# Exact posterior mean of lambda, (sum(y) + 1) / (n + 1):
(sum(y) + 1) / (length(y) + 1)
# Posterior quantiles of lambda from aghq ...
compute_quantiles(thequadrature, q = c(0.01, 0.25, 0.5, 0.75, 0.99),
                  transformation = transformation)[[1]]
# ... against the exact Gamma(1 + sum(y), rate = 1 + n) quantiles:
qgamma(c(0.01, 0.25, 0.5, 0.75, 0.99),
       shape = 1 + sum(y), rate = 1 + length(y))