diff --git a/examples/plot_sersic_mix.py b/examples/plot_sersic_mix.py new file mode 100644 index 0000000..7b99b21 --- /dev/null +++ b/examples/plot_sersic_mix.py @@ -0,0 +1,17 @@ +import gauss2d.fit as g2f +from lsst.multiprofit.plots import plot_sersicmix_interp +import matplotlib.pyplot as plt +import numpy as np +from scipy.interpolate import CubicSpline + +interps = { + "lin": (g2f.LinearSersicMixInterpolator(), "-"), + "gsl-csp": (g2f.GSLSersicMixInterpolator(interp_type=g2f.InterpType.cspline), (0, (8, 8))), + "scipy-csp": ((CubicSpline, {}), (0, (4, 4))), +} + +for n_low, n_hi in ((0.5, 0.7), (2.2, 4.4)): + n_ser = 10 ** np.linspace(np.log10(n_low), np.log10(n_hi), 100) + plot_sersicmix_interp(interps=interps, n_ser=n_ser, figsize=(10, 8)) + plt.tight_layout() + plt.show() diff --git a/python/lsst/multiprofit/fit_bootstrap_model.py b/python/lsst/multiprofit/fit_bootstrap_model.py index c26671d..2be09ab 100644 --- a/python/lsst/multiprofit/fit_bootstrap_model.py +++ b/python/lsst/multiprofit/fit_bootstrap_model.py @@ -61,6 +61,7 @@ class CatalogExposurePsfBootstrap(CatalogExposurePsfABC, SourceCatalogBootstrap) sigma_x: float sigma_y: float rho: float + nser: float @cached_property def centroid(self) -> g2.Centroid: @@ -142,9 +143,10 @@ class CatalogSourceFitterBootstrap(CatalogSourceFitterABC): background: float = 1e2 flux: float = 1e4 - sigma_x: float - sigma_y: float + reff_x: float + reff_y: float rho: float + nser: float def get_model_radec(self, source: Mapping[str, Any], cen_x: float, cen_y: float) -> tuple[float, float]: return float(cen_x), float(cen_y) @@ -162,7 +164,13 @@ def initialize_model( limits_y.max = float(observation.image.n_rows) init_component(comp1, cen_x=cenx, cen_y=ceny) init_component( - comp2, cen_x=cenx, cen_y=ceny, sigma_x=self.sigma_x, sigma_y=self.sigma_y, rho=self.rho + comp2, + cen_x=cenx, + cen_y=ceny, + reff_x=self.reff_x, + reff_y=self.reff_y, + rho=self.rho, + nser=self.nser, ) params_free = get_params_uniq(model, fixed=False) for param in params_free: diff --git a/python/lsst/multiprofit/fit_source.py b/python/lsst/multiprofit/fit_source.py index b54d5f9..52e0cd4 100644 --- a/python/lsst/multiprofit/fit_source.py +++ b/python/lsst/multiprofit/fit_source.py @@ -208,11 +208,11 @@ def schema( for band in bands: columns.append(ColumnInfo(key=f"{prefix_comp}{band}_flux{suffix}", dtype="f8")) - for idx, name_comp in enumerate(self.sersics.keys()): + for idx, (name_comp, comp) in enumerate(self.sersics.items()): prefix_comp = f"{name_comp}_" columns_comp = [ - ColumnInfo(key=f"{prefix_comp}sigma_x{suffix}", dtype="f8", unit=u.pix), - ColumnInfo(key=f"{prefix_comp}sigma_y{suffix}", dtype="f8", unit=u.pix), + ColumnInfo(key=f"{prefix_comp}reff_x{suffix}", dtype="f8", unit=u.pix), + ColumnInfo(key=f"{prefix_comp}reff_y{suffix}", dtype="f8", unit=u.pix), ColumnInfo(key=f"{prefix_comp}rho{suffix}", dtype="f8", unit=u.pix), ] for band in bands: @@ -222,6 +222,8 @@ def schema( unit=u.Unit(self.unit_flux) if self.unit_flux else None, ) ) + if not comp.sersicindex.fixed: + columns_comp.append(ColumnInfo(key=f"{prefix_comp}nser{suffix}", dtype="f8")) columns.extend(columns_comp) if self.convert_cen_xy_to_radec: columns.append(ColumnInfo(key=f"cen_ra{suffix}", dtype="f8", unit=u.deg)) @@ -539,7 +541,8 @@ def fit( logger.info(f"{id_source=} ({idx=}/{n_rows}) fit failed with known exception={e}") else: row[f"{prefix}unknown_flag"] = True - logger.info(f"{id_source=} ({idx=}/{n_rows}) fit failed with unexpected exception={e}") + logger.info(f"{id_source=} ({idx=}/{n_rows}) fit failed with unexpected exception={e}", + exc_info=1) return results def get_channels( diff --git a/python/lsst/multiprofit/plots.py b/python/lsst/multiprofit/plots.py index 037d70f..5639bc9 100644 --- a/python/lsst/multiprofit/plots.py +++ b/python/lsst/multiprofit/plots.py @@ -1,10 +1,13 @@ -import astropy +import astropy.table from collections import defaultdict from dataclasses import dataclass, field import gauss2d.fit as g2f from itertools import cycle -import numpy as np +import matplotlib as mpl +import matplotlib.figure import matplotlib.pyplot as plt +import numpy as np +import typing from .utils import get_params_uniq @@ -313,3 +316,134 @@ def plot_loglike( param.fixed = False return fig, ax + + +Interpolator: typing.TypeAlias = g2f.SersicMixInterpolator | tuple[typing.Type, dict[str, typing.Any]] + + +def plot_sersicmix_interp( + interps: dict[str, tuple[Interpolator, str | tuple]], n_ser: np.ndarray, **kwargs +) -> matplotlib.figure.Figure: + """Plot Gaussian mixture Sersic profile interpolated values. + + Parameters + ---------- + interps + Dict of interpolators by name. + n_ser + Array of Sersic index values to plot interpolated quantities for. + kwargs + Keyword arguments to pass to matplotlib.pyplot.subplots. + + Returns + ------- + The resulting figure. + """ + orders = { + name: interp.order + for name, (interp, _) in interps.items() + if isinstance(interp, g2f.SersicMixInterpolator) + } + order = set(orders.values()) + if not len(order) == 1: + raise ValueError(f"len(set({orders})) != 1; all interpolators must have the same order") + order = tuple(order)[0] + + cmap = mpl.cm.get_cmap("tab20b") + colors_ord = [None] * order + for i_ord in range(order): + colors_ord[i_ord] = cmap(i_ord / (order - 1.0)) + + n_ser_min = np.min(n_ser) + n_ser_max = np.max(n_ser) + knots = g2f.sersic_mix_knots(order=order) + n_knots = len(knots) + integrals_knots = np.empty((n_knots, order)) + sigmas_knots = np.empty((n_knots, order)) + n_ser_knots = np.empty(n_knots) + + i_knot_first = None + i_knot_last = n_knots + for i_knot, knot in enumerate(knots): + if i_knot_first is None: + if knot.sersicindex > n_ser_min: + i_knot_first = i_knot + else: + continue + if knot.sersicindex > n_ser_max: + i_knot_last = i_knot + break + n_ser_knots[i_knot] = knot.sersicindex + for i_ord in range(order): + values = knot.values[i_ord] + integrals_knots[i_knot, i_ord] = values.integral + sigmas_knots[i_knot, i_ord] = values.sigma + range_knots = range(i_knot_first, i_knot_last) + integrals_knots = integrals_knots[range_knots, :] + sigmas_knots = sigmas_knots[range_knots, :] + n_ser_knots = n_ser_knots[range_knots] + + n_values = len(n_ser) + integrals, dintegrals, sigmas, dsigmas = ( + {name: np.empty((n_values, order)) for name in interps} for _ in range(4) + ) + + for name, (interp, _) in interps.items(): + if not isinstance(interp, g2f.SersicMixInterpolator): + kwargs = interp[1] if interp[1] is not None else {} + interp = interp[0] + x = [knot.sersicindex for knot in knots] + for i_ord in range(order): + integrals_i = np.empty(n_knots, dtype=float) + sigmas_i = np.empty(n_knots, dtype=float) + for i_knot, knot in enumerate(knots): + integrals_i[i_knot] = knot.values[i_ord].integral + sigmas_i[i_knot] = knot.values[i_ord].sigma + interp_int = interp(x, integrals_i, **kwargs) + dinterp_int = interp_int.derivative() + interp_sigma = interp(x, sigmas_i, **kwargs) + dinterp_sigma = interp_sigma.derivative() + for i_val, value in enumerate(n_ser): + integrals[name][i_val, i_ord] = interp_int(value) + sigmas[name][i_val, i_ord] = interp_sigma(value) + dintegrals[name][i_val, i_ord] = dinterp_int(value) + dsigmas[name][i_val, i_ord] = dinterp_sigma(value) + + for i_val, value in enumerate(n_ser): + for name, (interp, _) in interps.items(): + if isinstance(interp, g2f.SersicMixInterpolator): + values = interp.integralsizes(value) + derivs = interp.integralsizes_derivs(value) + for i_ord in range(order): + integrals[name][i_val, i_ord] = values[i_ord].integral + sigmas[name][i_val, i_ord] = values[i_ord].sigma + dintegrals[name][i_val, i_ord] = derivs[i_ord].integral + dsigmas[name][i_val, i_ord] = derivs[i_ord].sigma + + fig, axes = plt.subplots(2, 2, **kwargs) + for idx_row, (yv, yd, yk, y_label) in ( + (0, (integrals, dintegrals, integrals_knots, "integral")), + (1, (sigmas, dsigmas, sigmas_knots, "sigma")), + ): + is_label_row = idx_row == 1 + for idx_col, y_i, y_prefix in ((0, yv, ""), (1, yd, "d")): + is_label_col = idx_col == 0 + make_label = is_label_col and is_label_row + axis = axes[idx_row, idx_col] + if is_label_col: + for i_ord in range(order): + axis.plot( + n_ser_knots, + yk[:, i_ord], + "kx", + label="knots" if make_label and (i_ord == 0) else None, + ) + for name, (_, lstyle) in interps.items(): + for i_ord in range(order): + label = f"{name}" if make_label and (i_ord == 0) else None + axis.plot(n_ser, y_i[name][:, i_ord], c=colors_ord[i_ord], label=label, linestyle=lstyle) + axis.set_xlim((n_ser_min, n_ser_max)) + axis.set_ylabel(f"{y_prefix}{y_label}") + if make_label: + axis.legend(loc="upper left") + return fig diff --git a/tests/test_fit_bootstrap_model.py b/tests/test_fit_bootstrap_model.py index 2e3c832..d2a9ef4 100644 --- a/tests/test_fit_bootstrap_model.py +++ b/tests/test_fit_bootstrap_model.py @@ -21,10 +21,16 @@ import gauss2d.fit as g2f from lsst.multiprofit.componentconfig import ( - GaussianConfig, init_component, ParameterConfig, SersicConfig, SersicIndexConfig, + GaussianConfig, + init_component, + ParameterConfig, + SersicConfig, + SersicIndexConfig, ) from lsst.multiprofit.fit_bootstrap_model import ( - CatalogExposurePsfBootstrap, CatalogExposureSourcesBootstrap, CatalogSourceFitterBootstrap, + CatalogExposurePsfBootstrap, + CatalogExposureSourcesBootstrap, + CatalogSourceFitterBootstrap, ) from lsst.multiprofit.fit_psf import CatalogPsfFitter, CatalogPsfFitterConfig from lsst.multiprofit.fit_source import CatalogSourceFitterConfig @@ -37,7 +43,7 @@ channels = (g2f.Channel.get("g"), g2f.Channel.get("r"), g2f.Channel.get("i")) shape_img = (23, 27) sigma_psf = 2.1 -sigma_x_src, sigma_y_src, rho_src = 2.5, 3.6, -0.25 +reff_x_src, reff_y_src, rho_src, nser_src = 2.5, 3.6, -0.25, 2.0 # TODO: These can be parameterized; should they be? compute_errors_no_covar = True @@ -46,14 +52,14 @@ plot = False -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def config_psf(): return CatalogPsfFitterConfig( - gaussians={'comp1': GaussianConfig(size=ParameterConfig(value_initial=sigma_psf))}, + gaussians={"comp1": GaussianConfig(size=ParameterConfig(value_initial=sigma_psf))}, ) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def config_source_fit(): # TODO: Separately test n_pointsources=0 and sersics={} return CatalogSourceFitterConfig( @@ -61,11 +67,11 @@ def config_source_fit(): n_pointsources=1, sersics={ "comp1": SersicConfig( - prior_size_mean=sigma_y_src, + prior_size_mean=reff_y_src, prior_size_stddev=1.0, - prior_axrat_mean=sigma_x_src/sigma_y_src, + prior_axrat_mean=reff_x_src / reff_y_src, prior_axrat_stddev=0.2, - sersicindex=SersicIndexConfig(fixed=True), + sersicindex=SersicIndexConfig(fixed=False, value_initial=1.0), ) }, convert_cen_xy_to_radec=False, @@ -74,13 +80,17 @@ def config_source_fit(): ) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def table_psf_fits(config_psf): fitter = CatalogPsfFitter() fits = { channel.name: fitter.fit( CatalogExposurePsfBootstrap( - sigma_x=sigma_x_src, sigma_y=sigma_y_src, rho=rho_src, n_sources=n_sources, + sigma_x=reff_x_src, + sigma_y=reff_y_src, + rho=rho_src, + nser=nser_src, + n_sources=n_sources, ), config_psf, ) @@ -92,7 +102,7 @@ def table_psf_fits(config_psf): def test_fit_psf(config_psf, table_psf_fits): for results in table_psf_fits.values(): assert len(results) == n_sources - assert np.sum(results['mpf_psf_unknown_flag']) == 0 + assert np.sum(results["mpf_psf_unknown_flag"]) == 0 assert all(np.isfinite(list(results[0].values()))) psfmodel = config_psf.rebuild_psfmodel(results[0]) assert len(psfmodel.components) == len(config_psf.gaussians) @@ -105,37 +115,43 @@ def test_fit_source(config_source_fit, table_psf_fits): init_component(model_source.components[1], sigma_x=sigma_psf, sigma_y=sigma_psf, rho=0) catexps = tuple( CatalogExposureSourcesBootstrap( - channel=channel, config_fit=config_source_fit, - model_source=model_source, table_psf_fits=table_psf_fits[channel.name], + channel=channel, + config_fit=config_source_fit, + model_source=model_source, + table_psf_fits=table_psf_fits[channel.name], n_sources=n_sources, ) for channel in channels ) - fitter = CatalogSourceFitterBootstrap(sigma_x=sigma_x_src, sigma_y=sigma_y_src, rho=rho_src) + fitter = CatalogSourceFitterBootstrap(reff_x=reff_x_src, reff_y=reff_y_src, rho=rho_src, nser=nser_src) catalog_multi = catexps[0].get_catalog() results = fitter.fit(catalog_multi=catalog_multi, catexps=catexps, config=config_source_fit) assert len(results) == n_sources - model = fitter.get_model(0, catalog_multi=catalog_multi, catexps=catexps, config=config_source_fit, - results=results) + model = fitter.get_model( + 0, catalog_multi=catalog_multi, catexps=catexps, config=config_source_fit, results=results + ) model_true = g2f.Model(data=model.data, psfmodels=model.psfmodels, sources=[model_source]) fitter.initialize_model(model_true, catalog_multi[0]) params_true = tuple(param.value for param in get_params_uniq(model_true, fixed=False)) - plot_catalog_bootstrap(results, histtype='step', paramvals_ref=params_true, - plot_total_fluxes=True, plot_colors=True) + plot_catalog_bootstrap( + results, histtype="step", paramvals_ref=params_true, plot_total_fluxes=True, plot_colors=True + ) if plot: import matplotlib.pyplot as plt + plt.show() - assert np.sum(results['mpf_unknown_flag']) == 0 + assert np.sum(results["mpf_unknown_flag"]) == 0 assert all(np.isfinite(list(results[0].values()))) variances = [] for return_negative in (False, True): variances.append( - fitter.modeller.compute_variances(model, transformed=False, - options=g2f.HessianOptions(return_negative=return_negative)) + fitter.modeller.compute_variances( + model, transformed=False, options=g2f.HessianOptions(return_negative=return_negative) + ) ) if return_negative: variances = np.array(variances) @@ -153,28 +169,28 @@ def test_fit_source(config_source_fit, table_psf_fits): img = obs.image.data img.flat = output.data.flat options_hessian = g2f.HessianOptions(return_negative=return_negative) - variances_bootstrap = fitter.modeller.compute_variances( - model, transformed=False, options=options_hessian) + variances_bootstrap = fitter.modeller.compute_variances(model, transformed=False, options=options_hessian) variances_bootstrap_diag = fitter.modeller.compute_variances( - model, transformed=False, options=options_hessian, use_diag_only=True) + model, transformed=False, options=options_hessian, use_diag_only=True + ) for obs, img_datum_old in zip(model.data, img_data_old): obs.image.data.flat = img_datum_old.flat variances_jac = fitter.modeller.compute_variances(model, transformed=False) variances_jac_diag = fitter.modeller.compute_variances(model, transformed=False, use_diag_only=True) errors_plot = { - 'inv_hess': ErrorValues(values=np.sqrt(variances[0]), - kwargs_plot={'linestyle': '-', 'color': 'r'}), - '-inv_hess': ErrorValues(values=np.sqrt(variances[1]), - kwargs_plot={'linestyle': '--', 'color': 'r'}), - 'inv_jac': ErrorValues(values=np.sqrt(variances_jac), - kwargs_plot={'linestyle': '-.', 'color': 'r'}), - 'boot_hess': ErrorValues(values=np.sqrt(variances_bootstrap), - kwargs_plot={'linestyle': '-', 'color': 'b'}), - 'boot_diag': ErrorValues(values=np.sqrt(variances_bootstrap_diag), - kwargs_plot={'linestyle': '--', 'color': 'b'}), - 'boot_jac_diag': ErrorValues(values=np.sqrt(variances_jac_diag), - kwargs_plot={'linestyle': '-.', 'color': 'm'}), + "inv_hess": ErrorValues(values=np.sqrt(variances[0]), kwargs_plot={"linestyle": "-", "color": "r"}), + "-inv_hess": ErrorValues(values=np.sqrt(variances[1]), kwargs_plot={"linestyle": "--", "color": "r"}), + "inv_jac": ErrorValues(values=np.sqrt(variances_jac), kwargs_plot={"linestyle": "-.", "color": "r"}), + "boot_hess": ErrorValues( + values=np.sqrt(variances_bootstrap), kwargs_plot={"linestyle": "-", "color": "b"} + ), + "boot_diag": ErrorValues( + values=np.sqrt(variances_bootstrap_diag), kwargs_plot={"linestyle": "--", "color": "b"} + ), + "boot_jac_diag": ErrorValues( + values=np.sqrt(variances_jac_diag), kwargs_plot={"linestyle": "-.", "color": "m"} + ), } fig, ax = plot_loglike(model, errors=errors_plot, values_reference=fitter.params_values_init) if plot: diff --git a/tests/test_modeller.py b/tests/test_modeller.py index 6ed642e..a7581ac 100644 --- a/tests/test_modeller.py +++ b/tests/test_modeller.py @@ -239,11 +239,11 @@ def get_sources(channels, config, limits: Limits, transforms: Transforms): (channel, g2f.IntegralParameterD(flux, label=channel.name)) for channel in channels.values() ] size = compconf.size_base + c*compconf.size_increment - # Needs to not be exactly 1.0 because of linear interpolation - # (this breaks finite differencing at any delta +/- a knot) - # ... and 1.0 and 4.0 are both knots - # TODO: Test more after DM-38616 - sersicindex = g2f.SersicMixComponentIndexParameterD(1.01 + 3*c) + sersicindex = g2f.SersicMixComponentIndexParameterD(1.0 + 3*c) + # Add a small offset if using linear interpolation + # n=1.0 should always be a knot and finite differencing breaks + # for linear interpolators right at knots + sersicindex.value += 1e-3*(sersicindex.interptype == g2f.InterpType.linear) ellipse = g2f.SersicParametricEllipse( g2f.ReffXParameterD(size, transform=transforms.x), g2f.ReffYParameterD(size, transform=transforms.y),