diff --git a/bayesnewton/ops.py b/bayesnewton/ops.py index 747dadd..155cc2a 100644 --- a/bayesnewton/ops.py +++ b/bayesnewton/ops.py @@ -569,7 +569,7 @@ def _parallel_kf_mf(As, Qs, H, ys, noise_covs, m0, P0, masks, block_index): @vmap def build_block_diag(P_blocks): - P = Pzeros.at[0].add(P_blocks.flatten()) + P = Pzeros.at[block_index].add(P_blocks.flatten()) return P def build_mean(m):