Skip to content

Commit

Permalink
Fix batching bug
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Mar 28, 2022
1 parent f5ce714 commit 613f590
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 11 deletions.
2 changes: 1 addition & 1 deletion stheno/model/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def sample(self, state: B.RandomState, n: B.Int, *fdds: FDD):
lengths = [num_elements(fdd) for fdd in fdds]
i, samples = 0, []
for length in lengths:
samples.append(sample[i : i + length, :])
samples.append(sample[..., i : i + length, :])
i += length
return (state,) + tuple(samples)

Expand Down
2 changes: 1 addition & 1 deletion stheno/model/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def combine(*fdds: FDD):
def combine(*pairs: tuple):
fdds, ys = zip(*pairs)
combined_fdd = combine(*fdds)
combined_y = B.concat(*[B.uprank(y) for y in ys], axis=0)
combined_y = B.concat(*[B.uprank(y) for y in ys], axis=-2)
return combined_fdd, combined_y


Expand Down
22 changes: 13 additions & 9 deletions tests/model/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,23 +132,27 @@ def test_blr():


def test_batched():
x = B.randn(16, 10, 1)
x1 = B.randn(16, 10, 1)
x2 = B.randn(16, 5, 1)

p = GP(2 * EQ().stretch(0.5))
y = p(x).sample()
logpdf = p(x, 0.1).logpdf(y)
y1, y2 = p.measure.sample(p(x1), p(x2))
logpdf = p.measure.logpdf((p(x1, 0.1), y1), (p(x2, 0.1), y2))

assert B.shape(y1) == (16, 10, 1)
assert B.shape(y2) == (16, 5, 1)
assert B.shape(logpdf) == (16,)
assert B.shape(y) == (16, 10, 1)

p = p | (p(x), y)
y2 = p(x).sample()
logpdf2 = p(x, 0.1).logpdf(y)
p = p | ((p(x1), y1), (p(x2), y2))
y1_2, y2_2 = p.measure.sample(p(x1), p(x2))
logpdf2 = p.measure.logpdf((p(x1, 0.1), y1), (p(x2, 0.1), y2))

assert B.shape(y1_2) == (16, 10, 1)
assert B.shape(y2_2) == (16, 5, 1)
approx(y1, y1_2, atol=1e-5)
approx(y2, y2_2, atol=1e-5)
assert B.shape(logpdf2) == (16,)
assert B.all(logpdf2 > logpdf)
assert B.shape(y2) == (16, 10, 1)
approx(y, y2, atol=1e-5)


def test_mo_batched():
Expand Down

0 comments on commit 613f590

Please sign in to comment.