forked from spacetelescope/jwst
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
use scipy directly instead of astropy modeling for residual fringe fit (
- Loading branch information
1 parent
c96f2e7
commit 9615771
Showing
3 changed files
with
43 additions
and
95 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,92 +1,47 @@ | ||
import numpy as np | ||
|
||
from astropy.modeling.fitting import model_to_fit_params | ||
import scipy.interpolate | ||
|
||
|
||
class ChiSqOutlierRejectionFitter: | ||
"""Chi squared statistic for outlier rejection""" | ||
def _lsq_spline(x, y, weights, knots, degree): | ||
return scipy.interpolate.LSQUnivariateSpline(x, y, knots, w=weights, k=degree) | ||
|
||
def __init__(self, fitter, domain=None, tolerance=0.0001): | ||
self.fitter = fitter | ||
self.tolerance = tolerance | ||
|
||
if domain is None: | ||
self.domain = 10 | ||
else: | ||
self.domain = domain | ||
def spline_fitter(x, y, weights, knots, degree, reject_outliers=False, domain=10, tolerance=0.0001): | ||
if not reject_outliers: | ||
return _lsq_spline(x, y, weights, knots, degree) | ||
|
||
@staticmethod | ||
def kernel(x, weights=None): | ||
""" | ||
Weighting function dependent only on provided value (usualy a residual) | ||
""" | ||
# fit with chi sq outlier rejection | ||
# helpers | ||
def chi_sq(spline, weights): | ||
return np.nansum((y - spline(x)) ** 2 * weights) | ||
|
||
kernal = (np.where(x**2 <= 1, 1 - x**2, 0.))**2 | ||
if weights is not None: | ||
kernal *= weights | ||
|
||
return kernal | ||
# initial fit | ||
spline = _lsq_spline(x, y, weights, knots, degree) | ||
chi = chi_sq(spline, weights) | ||
|
||
@staticmethod | ||
def _chi(model, x, y, weights=None): | ||
# astropy code used the model params which pad the knots based on degree | ||
nparams = len(knots) + (degree + 1) * 2 | ||
deg_of_freedom = np.sum(weights) - nparams | ||
|
||
resid = (y - model(x))**2 | ||
if weights is not None: | ||
resid *= weights | ||
for _ in range(1000 * nparams): | ||
scale = np.sqrt(chi / deg_of_freedom) | ||
|
||
return np.nansum(resid) | ||
# Calculate new weights | ||
resid = (y - spline(x)) / (scale * domain) | ||
new_w = (np.where(resid**2 <= 1, 1 - resid**2, 0.))**2 * weights | ||
|
||
@staticmethod | ||
def _params(model): | ||
return model_to_fit_params(model)[0] | ||
# Fit new model and find chi | ||
spline = _lsq_spline(x, y, new_w, knots, degree) | ||
new_chi = chi_sq(spline, new_w) | ||
|
||
@staticmethod | ||
def _sum_weights(x, weights=None): | ||
if weights is None: | ||
return len(x) | ||
else: | ||
return np.sum(weights) | ||
# Check if fit has converged | ||
tol = tolerance if new_chi < 1 else tolerance * new_chi | ||
if np.abs(chi - new_chi) < tol: | ||
break | ||
chi = new_chi | ||
else: | ||
raise RuntimeError("Bad fit, method should have converged") | ||
|
||
def _deg_of_freedom(self, model, x, weights=None): | ||
nparams = len(self._params(model)) | ||
sum_weights = self._sum_weights(x, weights) | ||
|
||
return nparams, sum_weights - nparams | ||
|
||
@staticmethod | ||
def _scale(chi, deg_of_freedom): | ||
return np.sqrt(chi / deg_of_freedom) | ||
|
||
def __call__(self, model, x, y, weights=None, **kwargs): | ||
# Assume equal weights if none are provided | ||
|
||
new_model = model.copy() | ||
|
||
# perform the initial fit | ||
new_model = self.fitter(new_model, x, y, weights=weights, **kwargs) | ||
chi = self._chi(new_model, x, y, weights) | ||
|
||
# calculate degrees of freedom | ||
nparams, deg_of_freedom = self._deg_of_freedom(new_model, x, weights) | ||
|
||
# Iteratively adjust the weights until fit converges | ||
for _ in range(1000 * nparams): | ||
scale = self._scale(chi, deg_of_freedom) | ||
|
||
# Calculate new weights | ||
resid = (y - new_model(x)) / (scale * self.domain) | ||
new_w = self.kernel(resid, weights) | ||
|
||
# Fit new model and find chi | ||
new_model = self.fitter(new_model, x, y, weights=new_w, **kwargs) | ||
new_chi = self._chi(new_model, x, y, new_w) | ||
|
||
# Check if fit has converged | ||
tol = self.tolerance if new_chi < 1 else self.tolerance * new_chi | ||
if np.abs(chi - new_chi) < tol: | ||
break | ||
chi = new_chi | ||
else: | ||
raise RuntimeError("Bad fit, method should have converged") | ||
|
||
return new_model | ||
return spline |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters