From 8abfed02b508e117e515376ce4cab3dd72fce061 Mon Sep 17 00:00:00 2001 From: MDM988 Date: Wed, 17 Jan 2024 15:19:46 -0500 Subject: [PATCH] bugfix for concatenate array splitting --- flowjax/bijections/concatenate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flowjax/bijections/concatenate.py b/flowjax/bijections/concatenate.py index 2c7d3ad2..21947aa3 100644 --- a/flowjax/bijections/concatenate.py +++ b/flowjax/bijections/concatenate.py @@ -35,7 +35,7 @@ def __init__(self, bijections: Sequence[AbstractBijection], axis: int = 0): self.shape = ( shapes[0][:axis] + (sum(s[axis] for s in shapes),) + shapes[0][axis + 1 :] ) - self.split_idxs = jnp.array([s[axis] for s in shapes[:-1]]) + self.split_idxs = jnp.cumsum(jnp.array([s[axis] for s in shapes[:-1]])) self.cond_shape = merge_cond_shapes([b.cond_shape for b in bijections]) def transform(self, x, condition=None):