diff --git a/stheno/model/measure.py b/stheno/model/measure.py index 043d79f..5eb652f 100644 --- a/stheno/model/measure.py +++ b/stheno/model/measure.py @@ -428,7 +428,7 @@ def sample(self, state: B.RandomState, n: B.Int, *fdds: FDD): lengths = [num_elements(fdd) for fdd in fdds] i, samples = 0, [] for length in lengths: - samples.append(sample[i : i + length, :]) + samples.append(sample[..., i : i + length, :]) i += length return (state,) + tuple(samples) diff --git a/stheno/model/observations.py b/stheno/model/observations.py index 3d548bb..42f4f56 100644 --- a/stheno/model/observations.py +++ b/stheno/model/observations.py @@ -42,7 +42,7 @@ def combine(*fdds: FDD): def combine(*pairs: tuple): fdds, ys = zip(*pairs) combined_fdd = combine(*fdds) - combined_y = B.concat(*[B.uprank(y) for y in ys], axis=0) + combined_y = B.concat(*[B.uprank(y) for y in ys], axis=-2) return combined_fdd, combined_y diff --git a/tests/model/test_cases.py b/tests/model/test_cases.py index 8482a7e..8169763 100644 --- a/tests/model/test_cases.py +++ b/tests/model/test_cases.py @@ -132,23 +132,27 @@ def test_blr(): def test_batched(): - x = B.randn(16, 10, 1) + x1 = B.randn(16, 10, 1) + x2 = B.randn(16, 5, 1) p = GP(2 * EQ().stretch(0.5)) - y = p(x).sample() - logpdf = p(x, 0.1).logpdf(y) + y1, y2 = p.measure.sample(p(x1), p(x2)) + logpdf = p.measure.logpdf((p(x1, 0.1), y1), (p(x2, 0.1), y2)) + assert B.shape(y1) == (16, 10, 1) + assert B.shape(y2) == (16, 5, 1) assert B.shape(logpdf) == (16,) - assert B.shape(y) == (16, 10, 1) - p = p | (p(x), y) - y2 = p(x).sample() - logpdf2 = p(x, 0.1).logpdf(y) + p = p | ((p(x1), y1), (p(x2), y2)) + y1_2, y2_2 = p.measure.sample(p(x1), p(x2)) + logpdf2 = p.measure.logpdf((p(x1, 0.1), y1), (p(x2, 0.1), y2)) + assert B.shape(y1_2) == (16, 10, 1) + assert B.shape(y2_2) == (16, 5, 1) + approx(y1, y1_2, atol=1e-5) + approx(y2, y2_2, atol=1e-5) assert B.shape(logpdf2) == (16,) assert B.all(logpdf2 > logpdf) - assert B.shape(y2) == (16, 10, 1) - approx(y, y2, atol=1e-5) def test_mo_batched():