Skip to content

Commit

Permalink
Experiment 4.0 realistic noise gaussian galaxies (#81)
Browse files Browse the repository at this point in the history
* script to submit batch jobs (temporary)

* rename

* what distribution of fluxes to use for tests?

* nothing

* setup experiment 40

* forgot true flux might be useful

* simple script to collate samples from individual files into one

* ordering

* simplify script

* figures with higher noise

* higher jackknife default and more explicit variables

* changes to ensure correct ordering

* run in two modes

* get jackknife estimate of mean shear posterior and reduce error

* draft figure file

* slurm file update to run in two modes

* useful for mode

* update import

* preliminary figures look great

* check

* modes

* update shell script

* update with flags

* vmap version of jackknife

* more plots

* all figures in this experiment

* draft make file for exp40

* add multiplicative bias contours

* no jack

* trying to understand

* simplify makefile works really well

* slight improvements

* redo figures

* move figures around

* add tag for generality

* script to make figures for exp 40

* make file absorbs this

* instructions for make files

* tag option

* fix makefile

* example plots for another experiment

* return both samples to be summarized later

* use new vectorized version of JK

* consistent tag

* add tag to figures

* add jackknife as an option

* more options to propagate for more general usage

* add summary of bias results

* make file for low noise experiment

* readme

* bug in jackknife

* looks nicer

* more samples just in case

* fix makefile

* no updates

* fix quote

* doesn't matter

* more careful so it runs faster

* less samples is probably OK

* debug

* all new results

* don't need as many resources

* a little more

* another seed

* seed
  • Loading branch information
ismael-mendoza authored Jan 22, 2025
1 parent 3d48d44 commit 50bbbf3
Show file tree
Hide file tree
Showing 53 changed files with 1,919 additions and 25 deletions.
54 changes: 48 additions & 6 deletions bpd/jackknife.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@
from typing import Callable

import jax.numpy as jnp
from jax import Array, random
from jax import Array, jit, random, vmap
from tqdm import tqdm


def run_jackknife_shear_pipeline(
rng_key,
*,
init_g: Array,
post_params_pos: dict,
post_params_neg: dict,
shear_pipeline: Callable,
n_jacks: int = 10,
n_gals: int,
n_jacks: int = 50,
disable_bar: bool = True,
):
"""Use jackknife+shape noise cancellation to estimate the mean and std of the shear posterior.
Expand All @@ -28,11 +30,9 @@ def run_jackknife_shear_pipeline(
n_jacks: Number of jackknife batches.
Returns:
Jackknife
Jackknife samples of shear posterior mean combined with shape noise cancellation trick.
"""
N, _ = post_params_pos["e1"].shape # N = n_gals, K = n_samples_per_gal
batch_size = ceil(N / n_jacks)
batch_size = ceil(n_gals / n_jacks)

g_best_list = []
keys = random.split(rng_key, n_jacks)
Expand All @@ -57,3 +57,45 @@ def run_jackknife_shear_pipeline(

g_best_means = jnp.array(g_best_list)
return g_best_means


def _leave_one_block_out(params: dict, batch_size: int, n_jacks: int) -> dict:
    """Stack the `n_jacks` leave-one-block-out subsets of every array in `params`.

    Subset `ii` of an array is the array with the contiguous block
    `[ii * batch_size, (ii + 1) * batch_size)` removed along its leading axis.

    Raises:
        ValueError: If an array is too short for all `n_jacks` blocks to have
            `batch_size` entries; the subsets would then have different
            lengths and could not be stacked.
    """
    stacked = {}
    for name, values in params.items():
        n_rows = len(values)
        # With more than one block, the subsets only share a shape when every
        # block is full, i.e. when the array has at least n_jacks * batch_size rows.
        if n_jacks > 1 and n_rows < n_jacks * batch_size:
            raise ValueError(
                f"Cannot build equal-sized jackknife subsets for '{name}': "
                f"n_jacks={n_jacks} blocks of size {batch_size} need at least "
                f"{n_jacks * batch_size} entries, but only {n_rows} are available. "
                "Choose `n_gals` as a multiple of `n_jacks`."
            )
        subsets = [
            jnp.concatenate(
                [values[: ii * batch_size], values[(ii + 1) * batch_size :]]
            )
            for ii in range(n_jacks)
        ]
        stacked[name] = jnp.stack(subsets, axis=0)
    return stacked


def run_jackknife_vectorized(
    rng_key,
    *,
    init_g: Array,
    post_params_pos: dict,
    post_params_neg: dict,
    shear_pipeline: Callable,
    n_gals: int,
    n_jacks: int = 50,
):
    """Run the shear pipeline on all jackknife subsets at once using `vmap`.

    The arrays in `post_params_pos` and `post_params_neg` are split along their
    leading axis into `n_jacks` contiguous blocks of `ceil(n_gals / n_jacks)`
    entries. For each block, a subset with that block removed is built, and the
    pipeline is evaluated on every subset in a single vectorized, jitted call.

    Args:
        rng_key: Random key; it is split into one key per jackknife subset. The
            same per-subset keys are used for the plus and the minus runs.
        init_g: Initial shear passed to the pipeline for the plus run; the
            minus run receives `-init_g`.
        post_params_pos: Dictionary of arrays for the plus run.
        post_params_neg: Dictionary of arrays for the minus run. Only the keys
            present in `post_params_pos` are used.
        shear_pipeline: Callable invoked as `shear_pipeline(key, params, init_g)`.
        n_gals: Number of galaxies, used to determine the block size.
        n_jacks: Number of jackknife blocks.

    Returns:
        Tuple `(g_pos_samples, g_neg_samples)` with the pipeline outputs for
        the plus and minus runs; the leading axis indexes the jackknife subset.

    Raises:
        ValueError: If `n_jacks` is not positive, or if the arrays cannot be
            split into `n_jacks` equal-sized leave-one-out subsets.
        KeyError: If a key of `post_params_pos` is missing in `post_params_neg`.
    """
    if n_jacks < 1:
        raise ValueError(f"`n_jacks` must be a positive integer, got {n_jacks}.")

    batch_size = ceil(n_gals / n_jacks)
    keys = random.split(rng_key, n_jacks)

    # prepare dictionaries of jackknife samples
    params_jack_pos = _leave_one_block_out(post_params_pos, batch_size, n_jacks)
    params_jack_neg = _leave_one_block_out(
        {k: post_params_neg[k] for k in post_params_pos}, batch_size, n_jacks
    )

    # run on a single example for compilation purposes
    vec_shear_pipeline = jit(vmap(shear_pipeline, in_axes=(0, 0, None)))
    _ = vec_shear_pipeline(
        keys[0, None], {k: v[0, None] for k, v in params_jack_pos.items()}, init_g
    )

    # run on full dataset
    g_pos_samples = vec_shear_pipeline(keys, params_jack_pos, init_g)
    g_neg_samples = vec_shear_pipeline(keys, params_jack_neg, -init_g)

    return g_pos_samples, g_neg_samples
3 changes: 1 addition & 2 deletions scripts/slurm/slurm_job.py → bpd/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ def setup_sbatch_job_gpu(

# prepare files and directories
jobfile_name = f"{jobname}_{JOB_SEED}.sbatch"
job_dir = Path(JOB_DIR)
jobfile = job_dir.joinpath(jobfile_name)
jobfile = JOB_DIR.joinpath(jobfile_name)

with open(jobfile, "w", encoding="utf-8") as f:
f.writelines(
Expand Down
2 changes: 1 addition & 1 deletion experiments/exp30/get_posteriors.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ export JAX_ENABLE_X64="True"
SEED="43"

./get_image_interim_samples_fixed.py $SEED
../../scripts/get_shear_from_shapes.py $SEED exp30_$SEED "e_post_${SEED}.npz" --overwrite
../../scripts/get_shear_from_shapes.py $SEED --old-seed $SEED --interim-samples-fname "e_post_${SEED}.npz" --tag exp30_$SEED --overwrite
2 changes: 1 addition & 1 deletion experiments/exp31/get_posteriors.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ export JAX_ENABLE_X64="True"
SEED="43"

./get_interim_samples.py $SEED
../../scripts/get_shear_from_shapes.py $SEED exp31_$SEED "e_post_${SEED}.npz" --overwrite
# Tag the outputs with the experiment name and the seed (exp31_<seed>).
../../scripts/get_shear_from_shapes.py $SEED --old-seed $SEED --interim-samples-fname "e_post_${SEED}.npz" --tag exp31_$SEED --overwrite
31 changes: 31 additions & 0 deletions experiments/exp40/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
SHELL=/bin/bash
SEED := 42
# Directory holding the cached outputs of this experiment; only used by `clean`.
DIR := /pscratch/sd/i/imendoza/data/cache_chains/exp40_${SEED}
TAG := exp40_${SEED}
export JAX_ENABLE_X64=True

# None of these targets produces a file named after itself.
.PHONY: figures jackknife shear collate samples clean

# Every recipe line runs in its own shell, so an `export` written as a recipe
# line never reaches the commands after it. A target-specific exported variable
# puts CUDA_VISIBLE_DEVICES in the environment of every command of these targets.
jackknife shear: export CUDA_VISIBLE_DEVICES := 0

# Produce the figures for this seed/tag.
figures:
	./get_figures.py ${SEED} --tag ${TAG}

# Run the jackknife shear script on the collated plus/minus interim samples.
jackknife:
	./get_shear_jackknife.py ${SEED} --old-seed ${SEED} \
		--samples-plus-fname interim_samples_${SEED}_plus.npz \
		--samples-minus-fname interim_samples_${SEED}_minus.npz \
		--tag ${TAG} --overwrite

# Run the shear script separately on the plus and minus interim samples.
shear:
	../../scripts/get_shear_from_shapes.py ${SEED} --old-seed ${SEED} --interim-samples-fname "interim_samples_${SEED}_plus.npz" --tag ${TAG} --overwrite --extra-tag "plus"
	../../scripts/get_shear_from_shapes.py ${SEED} --old-seed ${SEED} --interim-samples-fname "interim_samples_${SEED}_minus.npz" --tag ${TAG} --overwrite --extra-tag "minus"

# Merge the per-file interim samples (4 files per mode) into one file per mode.
collate:
	./collate_samples.py ${SEED} --tag ${TAG} --mode "plus" --n-files 4
	./collate_samples.py ${SEED} --tag ${TAG} --mode "minus" --n-files 4

# Launch the interim-sample scripts for both modes (g1 = +0.02 and g1 = -0.02).
samples:
	./slurm_get_interim_samples.py ${SEED} --tag ${TAG} --g1 0.02 --g2 0.0 --mode "plus"
	./slurm_get_interim_samples.py ${SEED} --tag ${TAG} --g1 -0.02 --g2 0.0 --mode "minus"

# Remove cached shear samples and interim samples for this seed.
clean:
	rm -f ${DIR}/g_samples_*.npy ${DIR}/interim_samples_*.npz
17 changes: 17 additions & 0 deletions experiments/exp40/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Experiment 4.0

Shear inference on images with more or less realistic noise conditions; only the ellipticities are free parameters.

- SNR is centered around ~15
- 10^4 galaxies
- Use jackknife + shape noise cancellation to estimate mean and error on mean shear.
- HLR is fixed to `0.8`, for which images of size `63` are sufficient.

## Reproducing results

```bash
make samples # wait for slurm job to finish
make collate
make shear
make figures
```
41 changes: 41 additions & 0 deletions experiments/exp40/collate_samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python3
import os

import jax.numpy as jnp
import typer

from bpd import DATA_DIR
from bpd.io import load_dataset, save_dataset


def main(
    seed: int, tag: str = typer.Option(), mode: str = typer.Option(), n_files: int = 4
):
    """Collate the per-file interim samples of one run into a single dataset.

    Reads `interim_samples_{seed}{ii}{_mode}.npz` for `ii` in `range(n_files)`
    from `DATA_DIR / "cache_chains" / tag` and writes the combined dataset to
    `interim_samples_{seed}{_mode}.npz` in the same directory, replacing any
    existing file.

    The entries `e_post`, `e1_true`, `e2_true` and `f` are concatenated in file
    order; for every other entry the value from the last file read is kept.

    Args:
        seed: Seed used in the file names.
        tag: Name of the sub-directory of `cache_chains` holding the files.
        mode: One of "plus", "minus" or "" (no mode suffix in the file names).
        n_files: Number of per-file datasets to collate.
    """
    assert mode in ("plus", "minus", "")
    suffix = f"_{mode}" if mode else ""
    chain_dir = DATA_DIR / "cache_chains" / tag
    out_path = chain_dir / f"interim_samples_{seed}{suffix}.npz"

    if out_path.exists():
        os.remove(out_path)

    # entries that grow with every file; everything else is simply overwritten
    concat_keys = ("e_post", "e1_true", "e2_true", "f")
    collated = {}

    for idx in range(n_files):
        part = load_dataset(chain_dir / f"interim_samples_{seed}{idx}{suffix}.npz")

        for name, values in part.items():
            if name in concat_keys and name in collated:
                collated[name] = jnp.concatenate([collated[name], values])
            else:
                collated[name] = part[name]

    save_dataset(collated, out_path, overwrite=True)

if __name__ == "__main__":
typer.run(main)
Binary file added experiments/exp40/figs/42/contours_bias.pdf
Binary file not shown.
Binary file added experiments/exp40/figs/42/contours_minus.pdf
Binary file not shown.
Binary file added experiments/exp40/figs/42/contours_plus.pdf
Binary file not shown.
Binary file added experiments/exp40/figs/42/hists_minus.pdf
Binary file not shown.
Binary file added experiments/exp40/figs/42/hists_plus.pdf
Binary file not shown.
Binary file added experiments/exp40/figs/42/scatter_shapes.pdf
Binary file not shown.
18 changes: 18 additions & 0 deletions experiments/exp40/figs/42/summary.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#### Full results ####
Units: 1e-3

m_mean: -10.47
m_std: 29.65

c_mean: -0.03942
c_std: 1.104

#### Jackknife results ####
Units: 1e-3

m_mean: -10.77
m_std: 4.564

c_mean: -0.05475
c_std: 1.003

Binary file added experiments/exp40/figs/42/traces_minus.pdf
Binary file not shown.
Binary file added experiments/exp40/figs/42/traces_plus.pdf
Binary file not shown.
Binary file added experiments/exp40/figs/43/contours_bias.pdf
Binary file not shown.
Binary file added experiments/exp40/figs/43/contours_minus.pdf
Binary file not shown.
Binary file added experiments/exp40/figs/43/contours_plus.pdf
Binary file not shown.
Binary file added experiments/exp40/figs/43/hists_minus.pdf
Binary file not shown.
Binary file added experiments/exp40/figs/43/hists_plus.pdf
Binary file not shown.
Binary file added experiments/exp40/figs/43/scatter_shapes.pdf
Binary file not shown.
18 changes: 18 additions & 0 deletions experiments/exp40/figs/43/summary.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#### Full results ####
Units: 1e-3

m_mean: -18.61
m_std: 3.958

c_mean: -1.391
c_std: 1.225

#### Jackknife results ####
Units: 1e-3

m_mean: -18.92
m_std: 3.841

c_mean: -1.399
c_std: 1.298

Binary file added experiments/exp40/figs/43/traces_minus.pdf
Binary file not shown.
Binary file added experiments/exp40/figs/43/traces_plus.pdf
Binary file not shown.
Loading

0 comments on commit 50bbbf3

Please sign in to comment.