Skip to content

Commit

Permalink
Merge branch 'tickets/DM-43357'
Browse files Browse the repository at this point in the history
  • Loading branch information
taranu committed May 15, 2024
2 parents 8f95826 + 7c25e83 commit 2631d95
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 18 deletions.
9 changes: 9 additions & 0 deletions python/lsst/multiprofit/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ def column_name(cls) -> str:
"""Return the standard column name for this error."""


class NoDataError(CatalogError):
"""RuntimeError for when there is no data to fit."""

@classmethod
@abstractmethod
def column_name(cls) -> str:
return "no_data_flag"


class PsfRebuildFitFlagError(RuntimeError):
"""RuntimeError for when a PSF can't be rebuilt because the fit failed."""

Expand Down
19 changes: 13 additions & 6 deletions python/lsst/multiprofit/fit_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,14 @@ def fit(
if len(flux_total) != 1:
raise RuntimeError(f"len({flux_total=}) != 1; PSF model is badly-formed")
flux_total = flux_total[0]
gaussians_linear = None
if config.fit_linear_init:
# The total flux must be freed first or else LinearGaussians.make
# will fail to find the required number of free linear params
flux_total.fixed = False
gaussians_linear = LinearGaussians.make(model_source, is_psf=True)
flux_total.fixed = True

# TODO: Remove isinstance when channel filtering is fixed
fluxfracs = tuple(
param
Expand Down Expand Up @@ -506,6 +514,7 @@ def fit(
# dummy size for first iteration
size, size_new = 0, 0
fitInputs = FitInputsDummy()
catexp_get_psf_image = catexp.get_psf_image

for idx in range_idx:
time_init = time.process_time()
Expand All @@ -516,22 +525,19 @@ def fit(

try:
self.check_source(source, config=config)
img_psf = catexp.get_psf_image(source)
img_psf = catexp_get_psf_image(source)
data = self._get_data(img_psf)
model = g2f.Model(data=data, psfmodels=[model_psf], sources=[model_source], priors=priors)
self.initialize_model(model=model, config_data=config_data)

# Caches the jacobian residual if the kernel size is unchanged
if img_psf.size != size:
fitInputs = None
size = np.copy(img_psf.size)
size = int(img_psf.size)
else:
fitInputs = fitInputs if not fitInputs.validate_for_model(model) else None

if config.fit_linear_init:
flux_total.fixed = False
gaussians_linear = LinearGaussians.make(model_source, is_psf=True)
flux_total.fixed = True
result = self.modeller.fit_gaussians_linear(gaussians_linear, data[0])
result = list(result.values())[0]
# Re-normalize fluxes (hopefully close already)
Expand All @@ -544,7 +550,7 @@ def fit(
result /= np.sum(result)
for idx_param, param in enumerate(fluxfracs):
param.value = result[idx_param]
result /= np.sum(result[idx_param + 1 :])
result /= np.sum(result[idx_param + 1:])

result_full = self.modeller.fit_model(model, fitinputs=fitInputs, **kwargs)
fitInputs = result_full.inputs
Expand Down Expand Up @@ -611,6 +617,7 @@ def initialize_model(
centroid.x_param.limits = limits_x
centroid.y_param.value = cen_y
centroid.y_param.limits = limits_y
centroids.add(centroid)
ellipse = component.ellipse
ellipse.size_x_param.limits = limits_x
ellipse.size_x = config_comp.size_x.value_initial
Expand Down
18 changes: 12 additions & 6 deletions python/lsst/multiprofit/fit_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from pydantic.dataclasses import dataclass

from .componentconfig import Fluxes
from .errors import NoDataError
from .fit_catalog import CatalogExposureABC, CatalogFitterConfig, ColumnInfo
from .modelconfig import ModelConfig
from .modeller import FitInputsDummy, Modeller
Expand All @@ -61,7 +62,7 @@ def channel(self) -> g2f.Channel:
"""Return the exposure's associated channel object."""

@abstractmethod
def get_psf_model(self, params: Mapping[str, Any]) -> g2f.PsfModel:
def get_psf_model(self, params: Mapping[str, Any]) -> g2f.PsfModel | None:
"""Get the PSF model for a given source row.
Parameters
Expand All @@ -73,7 +74,9 @@ def get_psf_model(self, params: Mapping[str, Any]) -> g2f.PsfModel:
Returns
-------
psf_model : `gauss2d.fit.PsfModel`
A PsfModel object initialized with the best-fit parameters.
A PsfModel object initialized with the best-fit parameters, or None
if PSF rebuilding failed for an expected reason (i.e. the input PSF
fit table has a flag set).
"""

@abstractmethod
Expand Down Expand Up @@ -142,11 +145,12 @@ def make_model_data(
source = catexp.get_catalog()[idx_row]
observation = catexp.get_source_observation(source)
if observation is not None:
observations.append(observation)
psf_model = catexp.get_psf_model(source)
for param in get_params_uniq(psf_model):
param.fixed = True
psf_models.append(psf_model)
if psf_model is not None:
observations.append(observation)
for param in get_params_uniq(psf_model):
param.fixed = True
psf_models.append(psf_model)

data = g2f.Data(observations)
return data, psf_models
Expand Down Expand Up @@ -553,6 +557,8 @@ def fit(

try:
data, psf_models = config.make_model_data(idx_row=idx, catexps=catexps)
if data.size == 0:
raise NoDataError("make_model_data returned empty data")
model = g2f.Model(data=data, psfmodels=psf_models, sources=model_sources, priors=priors)
self.initialize_model(
model, source_multi, catexps, values_init,
Expand Down
4 changes: 3 additions & 1 deletion python/lsst/multiprofit/modeller.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,17 +582,19 @@ def residual_func(params_new, model_jac, model_ll, params, result, jac):
time_init = time.process_time()
if config_fit.eval_residual:
model_ll.evaluate()
result.n_eval_resid += 1
else:
model_jac.evaluate()
result.n_eval_jac += 1
result.time_eval += time.process_time() - time_init
result.n_eval_resid += 1
return -result.inputs.residual

def jacobian_func(params_new, model_jac, model_ll, params, result, jac):
if result.config.eval_residual:
time_init = time.process_time()
model_jac.evaluate()
result.time_eval += time.process_time() - time_init
result.n_eval_jac += 1
return jac

if config.eval_residual:
Expand Down
44 changes: 41 additions & 3 deletions python/lsst/multiprofit/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from collections import defaultdict
from dataclasses import dataclass, field
from itertools import cycle
import math
from typing import Any, Iterable, Type, TypeAlias

import astropy.table
Expand Down Expand Up @@ -387,6 +388,7 @@ def plot_model_rgb(
weights: dict[str, float] | None = None,
high_sn_threshold: float | None = None,
plot_singleband: bool = True,
plot_chi_hist: bool = True,
chi_max: float = 5.,
rgb_min_auto: bool = False,
rgb_stretch_auto: bool = False,
Expand All @@ -405,6 +407,8 @@ def plot_model_rgb(
pixels having a model S/N above this threshold in every band.
plot_singleband
Whether to make grayscale plots for each band.
plot_chi_hist
Whether to plot histograms of the chi (scaled residual) values.
chi_max
The maximum absolute value of chi in residual plots. Values of 3-5 are
suitable for good models while inadequate ones may need larger values.
Expand Down Expand Up @@ -462,6 +466,8 @@ def plot_model_rgb(
if has_model:
observations = {}
else:
if plot_chi_hist:
raise ValueError("Cannot plot chi histograms without a model")
obs_kwarg = kwargs.pop("observations")
observations = {band: obs_kwarg[band] for band in bands}

Expand Down Expand Up @@ -533,11 +539,18 @@ def add_if_not_none(array, index, arg):
else:
array[index] = arg

chis_unweighted = {}

for idx_band, (band, weight) in enumerate(weights.items()):
observation = observations[band]
if has_model:
model_band = models[band]
variance_band = observation.sigma_inv.data ** -2
sigma_inv = observation.sigma_inv.data
variance_band = sigma_inv ** -2
if plot_chi_hist:
chi_good = (sigma_inv > 0) & np.isfinite(sigma_inv)
chi_unweighted = (observation.image.data[chi_good] - model_band[chi_good])*sigma_inv[chi_good]
chis_unweighted[band] = chi_unweighted
weight_channel_new = weights_channel[idx_band]
idx_channel_new = int(weight_channel_new // 1)
if idx_channel_new == idx_channel:
Expand Down Expand Up @@ -593,9 +606,13 @@ def add_if_not_none(array, index, arg):
img_model_rgb = apVis.make_lupton_rgb(*images_model, **kwargs)
aspect = np.clip((y_max - y_min) / (x_max - x_min), 0.25, 4)

fig_rgb, ax_rgb = plt.subplots(1 + has_model, 1 + has_model, figsize=(16, 16 * aspect))
n_rows = 1 + has_model
n_cols = 1 + has_model * (1 + plot_chi_hist)
figsize = (8 * n_cols, 8 * n_rows * aspect)

fig_rgb, ax_rgb = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=figsize)
fig_gs, ax_gs = (None, None) if not plot_singleband else plt.subplots(
nrows=n_bands, ncols=1 + has_model, figsize=(8 * (1 + has_model), 8 * aspect * n_bands)
nrows=n_bands, ncols=n_cols, figsize=(figsize[0], 8*aspect*n_bands),
)
(ax_rgb[0][0] if has_model else ax_rgb).imshow(img_rgb, extent=extent, origin="lower")
(ax_rgb[0][0] if has_model else ax_rgb).set_title("Data")
Expand Down Expand Up @@ -658,6 +675,27 @@ def add_if_not_none(array, index, arg):
ax_rgb[0][1].set_title(f"Residual (abs., += {resid_max:.3e})")
ax_rgb[0][1].tick_params(labelleft=False)

if plot_chi_hist:
cmap = mpl.colormaps["coolwarm"]
residuals_rgb = np.concatenate(tuple(chis_unweighted.values()))
residuals_abs = np.abs(residuals_rgb)
n_resid = len(residuals_abs)
chi_max = 5 + 2.5*((np.sum(residuals_abs > 5)/n_resid > 0.1)
+ (np.sum(residuals_abs > 7.5)/n_resid > 0.1))
n_bins = int(math.ceil(np.clip(n_resid/50, 2, 20))*chi_max)
# ax_rgb[0][2].set_adjustable('box')
ax_rgb[0][2].hist(
np.clip(residuals_rgb, -chi_max, chi_max),
bins=n_bins, histtype="step", label="all",
)
band_colors = cmap(np.linspace(0, 1, n_bands))
for band, band_color in zip(bands, band_colors):
ax_rgb[0][2].hist(
np.clip(residuals_rgb, -chi_max, chi_max),
bins=n_bins, histtype="step", label=band,
)
ax_rgb[0][2].legend()

residual_rgb = np.stack(
[
(np.clip(residuals[idx] * images_sigma_inv[idx], -chi_max, chi_max) + chi_max) / (2*chi_max)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def model(channels, data, psf_models):

def test_plot_model_rgb(model):
fig, ax, fig_gs, ax_gs, *_ = plot_model_rgb(
model, minimum=0, stretch=0.15, Q=4, weights=bands_weights_lsst,
model, minimum=0, stretch=0.15, Q=4, weights=bands_weights_lsst, plot_chi_hist=True,
)
assert fig is not None
assert ax is not None
Expand All @@ -121,7 +121,7 @@ def test_plot_model_rgb(model):
def test_plot_model_rgb_auto(model):
fig, ax, *_ = plot_model_rgb(
model, Q=6, weights=bands_weights_lsst, rgb_min_auto=True, rgb_stretch_auto=True,
plot_singleband=False,
plot_singleband=False, plot_chi_hist=False,
)
assert fig is not None
assert ax is not None

0 comments on commit 2631d95

Please sign in to comment.