Skip to content

Commit

Permalink
Merge pull request #152 from danielward27/clip_spline
Browse files Browse the repository at this point in the history
Add clipping for numerical stability
  • Loading branch information
danielward27 authored Apr 22, 2024
2 parents 4ae7063 + 142c4b9 commit ee6a185
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions flowjax/bijections/rational_quadratic_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ def transform(self, x, condition=None):
num = (yk1 - yk) * (sk * xi**2 + dk * xi * (1 - xi))
den = sk + (dk1 + dk - 2 * sk) * xi * (1 - xi)
y = yk + num / den # eq. 4

# avoid numerical precision issues transforming from in -> out of bounds
y = jnp.clip(y, -self.interval, self.interval)
return jnp.where(in_bounds, y, x)

def transform_and_log_det(self, x, condition=None):
Expand All @@ -129,6 +132,9 @@ def inverse(self, y, condition=None):
sqrt_term = jnp.sqrt(b**2 - 4 * a * c)
xi = (2 * c) / (-b - sqrt_term)
x = xi * (xk1 - xk) + xk

# avoid numerical precision issues transforming from in -> out of bounds
x = jnp.clip(x, -self.interval, self.interval)
return jnp.where(in_bounds, x, y)

def inverse_and_log_det(self, y, condition=None):
Expand Down

0 comments on commit ee6a185

Please sign in to comment.