Skip to content

Commit

Permalink
jax.numpy.clip: update use of deprecated arguments.
Browse files Browse the repository at this point in the history
- a is now positional-only
- a_min is now min
- a_max is now max

The old argument names have been deprecated since JAX v0.4.27.

PiperOrigin-RevId: 714108580
  • Loading branch information
Jake VanderPlas authored and The cascades Authors committed Jan 11, 2025
1 parent cd837a8 commit ff6c4d9
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion cascades/examples/notebooks/trice.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@
" # [mask, -mask] to mask out the contributions from weight=0 rationales (which\n",
" # has a very small chance of happening due to the clipping below).\n",
" flat_signs = jnp.concatenate([mask, -1 * mask])\n",
" flat_weights = jnp.clip(flat_weights, a_min=1e-10)\n",
" flat_weights = jnp.clip(flat_weights, min=1e-10)\n",
" weights_mean = flat_weights.sum() / (mask.sum() + 1e-10)\n",
" # Note: to compute the loss without subsampling, we can set\n",
" # subsampled_indices = jnp.arange(2 * len(questions))\n",
Expand Down

0 comments on commit ff6c4d9

Please sign in to comment.