Skip to content

Commit

Permalink
Avoid tuples (#52)
Browse files Browse the repository at this point in the history
* use array not tuple (outside of tests)

* refactor to use arrays and vmap

* need one more vmap

* arrays in tests

* too much as a default

* more stringent test

* better
  • Loading branch information
ismael-mendoza authored Nov 30, 2024
1 parent 238d4b9 commit 539b8b3
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 52 deletions.
7 changes: 4 additions & 3 deletions bpd/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@

_grad_fnc1 = vmap(vmap(grad(inv_shear_func1), in_axes=(0, None)), in_axes=(0, None))
_grad_fnc2 = vmap(vmap(grad(inv_shear_func2), in_axes=(0, None)), in_axes=(0, None))
_inv_shear_trans = vmap(inv_shear_transformation, in_axes=(0, None))


def shear_loglikelihood_unreduced(
g: tuple[float, float], e_post: Array, prior: Callable, interim_prior: Callable
g: Array, e_post: Array, prior: Callable, interim_prior: Callable
) -> ArrayLike:
# Given by the inference procedure in Schneider et al. 2014
# assume single shear g
Expand All @@ -34,7 +35,7 @@ def shear_loglikelihood_unreduced(
grad2 = _grad_fnc2(e_post, g)
absjacdet = jnp.abs(grad1[..., 0] * grad2[..., 1] - grad1[..., 1] * grad2[..., 0])

e_post_unsheared = inv_shear_transformation(e_post, g)
e_post_unsheared = _inv_shear_trans(e_post, g)
e_post_unsheared_mag = norm(e_post_unsheared, axis=-1)
num = prior(e_post_unsheared_mag) * absjacdet # (N, K)

Expand All @@ -43,7 +44,7 @@ def shear_loglikelihood_unreduced(


def shear_loglikelihood(
g: tuple[float, float], e_post, prior: Callable, interim_prior: Callable
g: Array, e_post: Array, prior: Callable, interim_prior: Callable
) -> float:
"""Reduce with sum"""
return shear_loglikelihood_unreduced(g, e_post, prior, interim_prior).sum()
4 changes: 3 additions & 1 deletion bpd/pipelines/image_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ def get_true_params_from_galaxy_params(galaxy_params: dict[str, Array]):
e1, e2 = true_params.pop("e1"), true_params.pop("e2")
g1, g2 = true_params.pop("g1"), true_params.pop("g2")

e1_prime, e2_prime = scalar_shear_transformation((e1, e2), (g1, g2))
e1_prime, e2_prime = scalar_shear_transformation(
jnp.array([e1, e2]), jnp.array([g1, g2])
)
true_params["e1"] = e1_prime
true_params["e2"] = e2_prime

Expand Down
52 changes: 16 additions & 36 deletions bpd/prior.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import jax.numpy as jnp
from jax import Array, random
from jax import Array, random, vmap
from jax._src.prng import PRNGKeyArray
from jax.numpy.linalg import norm
from jaxtyping import ArrayLike
Expand Down Expand Up @@ -39,71 +39,51 @@ def sample_ellip_prior(rng_key: PRNGKeyArray, sigma: float, n: int = 1):
return jnp.stack((e1, e2), axis=1)


def scalar_shear_transformation(e: tuple[float, float], g: tuple[float, float]):
def scalar_shear_transformation(e: Array, g: Array):
"""Transform elliptiticies by a fixed shear (scalar version).
The transformation we used is equation 3.4b in Seitz & Schneider (1997).
NOTE: This function is meant to be vmapped later.
"""
assert e.shape == (2,) and g.shape == (2,)

e1, e2 = e
g1, g2 = g

e_comp = e1 + e2 * 1j
g_comp = g1 + g2 * 1j

e_prime = (e_comp + g_comp) / (1 + g_comp.conjugate() * e_comp)
return e_prime.real, e_prime.imag
return jnp.array([e_prime.real, e_prime.imag])


def scalar_inv_shear_transformation(e: tuple[float, float], g: tuple[float, float]):
def scalar_inv_shear_transformation(e: Array, g: Array):
"""Same as above but the inverse."""
assert e.shape == (2,) and g.shape == (2,)
e1, e2 = e
g1, g2 = g

e_comp = e1 + e2 * 1j
g_comp = g1 + g2 * 1j

e_prime = (e_comp - g_comp) / (1 - g_comp.conjugate() * e_comp)
return e_prime.real, e_prime.imag
return jnp.array([e_prime.real, e_prime.imag])


# batched
shear_transformation = vmap(scalar_shear_transformation, in_axes=(0, None))
inv_shear_transformation = vmap(scalar_inv_shear_transformation, in_axes=(0, None))

# useful for jacobian later, only need 2 grads really
# useful for jacobian later
inv_shear_func1 = lambda e, g: scalar_inv_shear_transformation(e, g)[0]
inv_shear_func2 = lambda e, g: scalar_inv_shear_transformation(e, g)[1]


def shear_transformation(e: Array, g: tuple[float, float]):
"""Transform elliptiticies by a fixed shear.
The transformation we used is equation 3.4b in Seitz & Schneider (1997).
"""
e1, e2 = e[..., 0], e[..., 1]
g1, g2 = g

e_comp = e1 + e2 * 1j
g_comp = g1 + g2 * 1j

e_prime = (e_comp + g_comp) / (1 + g_comp.conjugate() * e_comp)
return jnp.stack([e_prime.real, e_prime.imag], axis=-1)


def inv_shear_transformation(e: Array, g: tuple[float, float]):
"""Same as above but the inverse."""
e1, e2 = e[..., 0], e[..., 1]
g1, g2 = g

e_comp = e1 + e2 * 1j
g_comp = g1 + g2 * 1j

e_prime = (e_comp - g_comp) / (1 - g_comp.conjugate() * e_comp)
return jnp.stack([e_prime.real, e_prime.imag], axis=-1)


# get synthetic measured sheared ellipticities
def sample_synthetic_sheared_ellips_unclipped(
rng_key: PRNGKeyArray,
g: tuple[float, float],
g: Array,
n: int,
sigma_m: float,
sigma_e: float,
Expand All @@ -119,7 +99,7 @@ def sample_synthetic_sheared_ellips_unclipped(

def sample_synthetic_sheared_ellips_clipped(
rng_key: PRNGKeyArray,
g: tuple[float, float],
g: Array,
sigma_m: float,
sigma_e: float,
n: int = 1,
Expand All @@ -140,7 +120,7 @@ def sample_synthetic_sheared_ellips_clipped(

# clip magnitude to < 1
# preserve angle after noise added when clipping
beta = jnp.arctan2(e_obs[:, :, 1], e_obs[:, :, 0]) / 2
beta = jnp.arctan2(e_obs[:, :, 1], e_obs[:, :, 0]) * 0.5
e_obs_mag = norm(e_obs, axis=-1)
e_obs_mag = jnp.clip(e_obs_mag, 0, e_tol) # otherwise likelihood explodes

Expand Down
2 changes: 1 addition & 1 deletion experiments/exp2/run_inference_galaxy_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def main(
_run_sampling = vmap(vmap(jjit(_run_sampling1), in_axes=(0, 0, 0, None)))

results = {}
for n_gals in (1, 1, 5, 10, 20, 25, 50, 100, 250, 500): # repeat 1 == compilation
for n_gals in (1, 1, 5, 10, 20, 25, 50, 100, 250): # repeat 1 == compilation
print("n_gals:", n_gals)

# generate data and parameters
Expand Down
21 changes: 19 additions & 2 deletions tests/test_shear_inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test shear inference reaches desired accuracy for low-noise regime."""

import jax.numpy as jnp
import numpy as np
import pytest
from jax import random

Expand Down Expand Up @@ -44,5 +45,21 @@ def test_shear_inference_toy_ellipticities(seed):
assert shear_samples.shape == (1000, 2)
assert jnp.abs((jnp.mean(shear_samples[:, 0]) - g1) / g1) <= 3e-3
assert jnp.abs(jnp.mean(shear_samples[:, 1])) <= 3e-3
assert jnp.std(shear_samples[:, 0]) > 0
assert jnp.std(shear_samples[:, 1]) > 0
assert np.allclose(
jnp.std(shear_samples[:, 0]), sigma_e / jnp.sqrt(1000), rtol=0.1, atol=0
)
assert np.allclose(
jnp.std(shear_samples[:, 1]), sigma_e / jnp.sqrt(1000), rtol=0.1, atol=0
)
assert not np.allclose(
jnp.std(shear_samples[:, 0]),
sigma_e / jnp.sqrt(1000) / jnp.sqrt(2),
rtol=0.1,
atol=0,
)
assert not np.allclose(
jnp.std(shear_samples[:, 1]),
sigma_e / jnp.sqrt(1000) / jnp.sqrt(2),
rtol=0.1,
atol=0,
)
21 changes: 12 additions & 9 deletions tests/test_shear_trans.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import partial

import jax.numpy as jnp
import numpy as np
import pytest
from jax import jit as jjit
Expand All @@ -23,12 +24,11 @@ def test_scalar_inverse():
for e2 in ellips:
for g1 in shears:
for g2 in shears:
e_trans = scalar_shear_transformation((e1, e2), (g1, g2))
e1_new, e2_new = scalar_inv_shear_transformation(e_trans, (g1, g2))

e_array = np.array([e1, e2])
e_new_array = np.array([e1_new, e2_new])
np.testing.assert_allclose(e_new_array, e_array, atol=1e-15)
e = jnp.array([e1, e2])
g = jnp.array([g1, g2])
e_trans = scalar_shear_transformation(e, g)
e_new = scalar_inv_shear_transformation(e_trans, g)
np.testing.assert_allclose(e_new, e, atol=1e-15)


@pytest.mark.parametrize("seed", [1234, 4567])
Expand All @@ -41,8 +41,9 @@ def test_transformation(seed):

for g1 in shears:
for g2 in shears:
e_trans_samples = shear_transformation(e_samples, (g1, g2))
e_new = inv_shear_transformation(e_trans_samples, (g1, g2))
g = jnp.array([g1, g2])
e_trans_samples = shear_transformation(e_samples, g)
e_new = inv_shear_transformation(e_trans_samples, g)
assert e_new.shape == (100, 2)
np.testing.assert_allclose(e_new, e_samples)

Expand All @@ -60,7 +61,9 @@ def test_image_shear_commute():
for e2 in ellips:
for g1 in shears:
for g2 in shears:
(e1_p, e2_p) = scalar_shear_transformation((e1, e2), (g1, g2))
e = jnp.array([e1, e2])
g = jnp.array([g1, g2])
(e1_p, e2_p) = scalar_shear_transformation(e, g)
im1 = draw_jitted(
f=f, hlr=hlr, e1=e1, e2=e2, g1=g1, g2=g2, x=x, y=y
)
Expand Down

0 comments on commit 539b8b3

Please sign in to comment.