Add pattern for MultivariateNormal(affine) #245
Conversation
OK, I've fixed most of the shape errors and we're down to numerical errors only 😄
Yay, all tests pass 😄 Thanks for all the help @fehiepsi!
@@ -13,8 +14,49 @@
from funsor.torch import Tensor


# This version constructs factors using funsor.distributions.
@pytest.mark.parametrize('state_dim,obs_dim', [(3, 2), (2, 3)])
def test_distributions(state_dim, obs_dim):
@jpchen you can follow the idioms of this test in your experiment.
Looks beautiful to me! There is just one small point that I don't understand yet.
# Compute log_prob using funsors.
scale_diag = Tensor(scale_tril.data.diagonal(dim1=-1, dim2=-2), scale_tril.inputs)
log_prob = (-0.5 * scale_diag.shape[0] * math.log(2 * math.pi)
Just curious: why do we use shape[0] instead of shape[-1], scale_diag.log().sum() instead of scale_diag.log().sum(-1), and 0.5 * (const ** 2).sum() instead of 0.5 * (const ** 2).sum(-1) here?
They are equivalent: scale_diag.shape[0] == scale_diag.shape[-1]. Here scale_diag is a funsor, and it separates "batch" .inputs from "event" .shape. In fact scale_diag.shape == (dim,) regardless of batching. This also allows us to call .sum() below rather than .sum(-1), since there is only one tensor dimension.
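To make the batch/event split concrete, here is a minimal sketch (not from the PR) assuming the funsor API used in this diff (funsor.torch.Tensor, funsor.domains.bint) and a hypothetical batch input named "i":

```python
from collections import OrderedDict

import torch

from funsor.domains import bint
from funsor.torch import Tensor

data = torch.rand(4, 3) + 0.1               # 4 batch elements, event size 3
x = Tensor(data, OrderedDict(i=bint(4)))    # the batch dim lives in .inputs

assert x.shape == (3,)                      # only the "event" shape remains

# .log().sum() reduces the single event dimension; the batch input "i"
# survives, so the result has empty shape but still depends on "i".
y = x.log().sum()
assert y.shape == ()
assert set(y.inputs) == {"i"}
```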
That's nice! Thanks for explaining.
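For reference, the log density being assembled here matches the standard MultivariateNormal formula. A plain-PyTorch sketch of that formula, under the assumption that const in the diff is the whitened residual scale_tril⁻¹ (x − loc), and using the modern torch.linalg API:

```python
import math

import torch

def mvn_log_prob(x, loc, scale_tril):
    dim = scale_tril.shape[-1]
    scale_diag = scale_tril.diagonal(dim1=-1, dim2=-2)
    # Whitened residual: scale_tril^{-1} (x - loc).
    const = torch.linalg.solve_triangular(
        scale_tril, (x - loc).unsqueeze(-1), upper=False
    ).squeeze(-1)
    return (-0.5 * dim * math.log(2 * math.pi)
            - scale_diag.log().sum(-1)
            - 0.5 * (const ** 2).sum(-1))

# Sanity check against torch.distributions.
x = torch.randn(5, 3)
loc = torch.randn(3)
scale_tril = torch.eye(3) + torch.tril(0.1 * torch.randn(3, 3))
expected = torch.distributions.MultivariateNormal(loc, scale_tril=scale_tril).log_prob(x)
assert torch.allclose(mvn_log_prob(x, loc, scale_tril), expected, atol=1e-5)
```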
Addresses #72, pair coded with @eb8680.

Tested:
- MultivariateNormal(x + y, ...) etc.
- dist.MultivariateNormal directly