diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py index d86aec7d6..1e01e4559 100644 --- a/blackjax/mcmc/metrics.py +++ b/blackjax/mcmc/metrics.py @@ -65,8 +65,9 @@ def __call__( self, position: ArrayLikeTree, element: ArrayLikeTree, - inv: ArrayLikeTree, - trans: ArrayLikeTree, + *, + inv: bool, + trans: bool, ) -> ArrayLikeTree: ...