Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-38616: Update Sersic index/profile usage #25

Merged
merged 5 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions examples/plot_sersic_mix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import gauss2d.fit as g2f
from lsst.multiprofit.plots import plot_sersicmix_interp
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import CubicSpline

interps = {
"lin": (g2f.LinearSersicMixInterpolator(), "-"),
"gsl-csp": (g2f.GSLSersicMixInterpolator(interp_type=g2f.InterpType.cspline), (0, (8, 8))),
"scipy-csp": ((CubicSpline, {}), (0, (4, 4))),
}

for n_low, n_hi in ((0.5, 0.7), (2.2, 4.4)):
n_ser = 10 ** np.linspace(np.log10(n_low), np.log10(n_hi), 100)
plot_sersicmix_interp(interps=interps, n_ser=n_ser, figsize=(10, 8))
plt.tight_layout()
plt.show()
14 changes: 11 additions & 3 deletions python/lsst/multiprofit/fit_bootstrap_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class CatalogExposurePsfBootstrap(CatalogExposurePsfABC, SourceCatalogBootstrap)
sigma_x: float
sigma_y: float
rho: float
nser: float

@cached_property
def centroid(self) -> g2.Centroid:
Expand Down Expand Up @@ -142,9 +143,10 @@ class CatalogSourceFitterBootstrap(CatalogSourceFitterABC):

background: float = 1e2
flux: float = 1e4
sigma_x: float
sigma_y: float
reff_x: float
reff_y: float
rho: float
nser: float

def get_model_radec(self, source: Mapping[str, Any], cen_x: float, cen_y: float) -> tuple[float, float]:
return float(cen_x), float(cen_y)
Expand All @@ -162,7 +164,13 @@ def initialize_model(
limits_y.max = float(observation.image.n_rows)
init_component(comp1, cen_x=cenx, cen_y=ceny)
init_component(
comp2, cen_x=cenx, cen_y=ceny, sigma_x=self.sigma_x, sigma_y=self.sigma_y, rho=self.rho
comp2,
cen_x=cenx,
cen_y=ceny,
reff_x=self.reff_x,
reff_y=self.reff_y,
rho=self.rho,
nser=self.nser,
)
params_free = get_params_uniq(model, fixed=False)
for param in params_free:
Expand Down
11 changes: 7 additions & 4 deletions python/lsst/multiprofit/fit_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,11 @@ def schema(
for band in bands:
columns.append(ColumnInfo(key=f"{prefix_comp}{band}_flux{suffix}", dtype="f8"))

for idx, name_comp in enumerate(self.sersics.keys()):
for idx, (name_comp, comp) in enumerate(self.sersics.items()):
prefix_comp = f"{name_comp}_"
columns_comp = [
ColumnInfo(key=f"{prefix_comp}sigma_x{suffix}", dtype="f8", unit=u.pix),
ColumnInfo(key=f"{prefix_comp}sigma_y{suffix}", dtype="f8", unit=u.pix),
ColumnInfo(key=f"{prefix_comp}reff_x{suffix}", dtype="f8", unit=u.pix),
ColumnInfo(key=f"{prefix_comp}reff_y{suffix}", dtype="f8", unit=u.pix),
ColumnInfo(key=f"{prefix_comp}rho{suffix}", dtype="f8", unit=u.pix),
]
for band in bands:
Expand All @@ -222,6 +222,8 @@ def schema(
unit=u.Unit(self.unit_flux) if self.unit_flux else None,
)
)
if not comp.sersicindex.fixed:
columns_comp.append(ColumnInfo(key=f"{prefix_comp}nser{suffix}", dtype="f8"))
columns.extend(columns_comp)
if self.convert_cen_xy_to_radec:
columns.append(ColumnInfo(key=f"cen_ra{suffix}", dtype="f8", unit=u.deg))
Expand Down Expand Up @@ -539,7 +541,8 @@ def fit(
logger.info(f"{id_source=} ({idx=}/{n_rows}) fit failed with known exception={e}")
else:
row[f"{prefix}unknown_flag"] = True
logger.info(f"{id_source=} ({idx=}/{n_rows}) fit failed with unexpected exception={e}")
logger.info(f"{id_source=} ({idx=}/{n_rows}) fit failed with unexpected exception={e}",
exc_info=1)
return results

def get_channels(
Expand Down
138 changes: 136 additions & 2 deletions python/lsst/multiprofit/plots.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import astropy
import astropy.table
from collections import defaultdict
from dataclasses import dataclass, field
import gauss2d.fit as g2f
from itertools import cycle
import numpy as np
import matplotlib as mpl
import matplotlib.figure
import matplotlib.pyplot as plt
import numpy as np
import typing

from .utils import get_params_uniq

Expand Down Expand Up @@ -313,3 +316,134 @@ def plot_loglike(
param.fixed = False

return fig, ax


Interpolator: typing.TypeAlias = g2f.SersicMixInterpolator | tuple[typing.Type, dict[str, typing.Any]]


def plot_sersicmix_interp(
interps: dict[str, tuple[Interpolator, str | tuple]], n_ser: np.ndarray, **kwargs
) -> matplotlib.figure.Figure:
"""Plot Gaussian mixture Sersic profile interpolated values.

Parameters
----------
interps
Dict of interpolators by name.
n_ser
Array of Sersic index values to plot interpolated quantities for.
kwargs
Keyword arguments to pass to matplotlib.pyplot.subplots.

Returns
-------
The resulting figure.
"""
orders = {
name: interp.order
for name, (interp, _) in interps.items()
if isinstance(interp, g2f.SersicMixInterpolator)
}
order = set(orders.values())
if not len(order) == 1:
raise ValueError(f"len(set({orders})) != 1; all interpolators must have the same order")
order = tuple(order)[0]

cmap = mpl.cm.get_cmap("tab20b")
colors_ord = [None] * order
for i_ord in range(order):
colors_ord[i_ord] = cmap(i_ord / (order - 1.0))

n_ser_min = np.min(n_ser)
n_ser_max = np.max(n_ser)
knots = g2f.sersic_mix_knots(order=order)
n_knots = len(knots)
integrals_knots = np.empty((n_knots, order))
sigmas_knots = np.empty((n_knots, order))
n_ser_knots = np.empty(n_knots)

i_knot_first = None
i_knot_last = n_knots
for i_knot, knot in enumerate(knots):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like you assume that knots is ordered by increasing sersic index. Is that something that you should check?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ones from gauss2dfit are, and I didn't provide a way to override them here so it would take some effort for users to thwart this.

if i_knot_first is None:
if knot.sersicindex > n_ser_min:
i_knot_first = i_knot
else:
continue
if knot.sersicindex > n_ser_max:
i_knot_last = i_knot
break
n_ser_knots[i_knot] = knot.sersicindex
for i_ord in range(order):
values = knot.values[i_ord]
integrals_knots[i_knot, i_ord] = values.integral
sigmas_knots[i_knot, i_ord] = values.sigma
range_knots = range(i_knot_first, i_knot_last)
integrals_knots = integrals_knots[range_knots, :]
sigmas_knots = sigmas_knots[range_knots, :]
n_ser_knots = n_ser_knots[range_knots]

n_values = len(n_ser)
integrals, dintegrals, sigmas, dsigmas = (
{name: np.empty((n_values, order)) for name in interps} for _ in range(4)
)

for name, (interp, _) in interps.items():
if not isinstance(interp, g2f.SersicMixInterpolator):
kwargs = interp[1] if interp[1] is not None else {}
interp = interp[0]
x = [knot.sersicindex for knot in knots]
for i_ord in range(order):
integrals_i = np.empty(n_knots, dtype=float)
sigmas_i = np.empty(n_knots, dtype=float)
for i_knot, knot in enumerate(knots):
integrals_i[i_knot] = knot.values[i_ord].integral
sigmas_i[i_knot] = knot.values[i_ord].sigma
interp_int = interp(x, integrals_i, **kwargs)
dinterp_int = interp_int.derivative()
interp_sigma = interp(x, sigmas_i, **kwargs)
dinterp_sigma = interp_sigma.derivative()
for i_val, value in enumerate(n_ser):
integrals[name][i_val, i_ord] = interp_int(value)
sigmas[name][i_val, i_ord] = interp_sigma(value)
dintegrals[name][i_val, i_ord] = dinterp_int(value)
dsigmas[name][i_val, i_ord] = dinterp_sigma(value)

for i_val, value in enumerate(n_ser):
for name, (interp, _) in interps.items():
if isinstance(interp, g2f.SersicMixInterpolator):
values = interp.integralsizes(value)
derivs = interp.integralsizes_derivs(value)
for i_ord in range(order):
integrals[name][i_val, i_ord] = values[i_ord].integral
sigmas[name][i_val, i_ord] = values[i_ord].sigma
dintegrals[name][i_val, i_ord] = derivs[i_ord].integral
dsigmas[name][i_val, i_ord] = derivs[i_ord].sigma

fig, axes = plt.subplots(2, 2, **kwargs)
for idx_row, (yv, yd, yk, y_label) in (
(0, (integrals, dintegrals, integrals_knots, "integral")),
(1, (sigmas, dsigmas, sigmas_knots, "sigma")),
):
is_label_row = idx_row == 1
for idx_col, y_i, y_prefix in ((0, yv, ""), (1, yd, "d")):
is_label_col = idx_col == 0
make_label = is_label_col and is_label_row
axis = axes[idx_row, idx_col]
if is_label_col:
for i_ord in range(order):
axis.plot(
n_ser_knots,
yk[:, i_ord],
"kx",
label="knots" if make_label and (i_ord == 0) else None,
)
for name, (_, lstyle) in interps.items():
for i_ord in range(order):
label = f"{name}" if make_label and (i_ord == 0) else None
axis.plot(n_ser, y_i[name][:, i_ord], c=colors_ord[i_ord], label=label, linestyle=lstyle)
axis.set_xlim((n_ser_min, n_ser_max))
axis.set_ylabel(f"{y_prefix}{y_label}")
if make_label:
axis.legend(loc="upper left")
return fig
Loading