-
-
Notifications
You must be signed in to change notification settings - Fork 984
/
Copy pathtest_valid_models.py
2782 lines (2296 loc) · 91.1 KB
/
test_valid_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import logging
import warnings
from collections import defaultdict
import pytest
import torch
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.distributions.testing import fakes
from pyro.infer import (
SVI,
EnergyDistance,
Trace_ELBO,
TraceEnum_ELBO,
TraceGraph_ELBO,
TraceMeanField_ELBO,
TraceTailAdaptive_ELBO,
config_enumerate,
)
from pyro.infer.reparam import LatentStableReparam
from pyro.infer.tracetmc_elbo import TraceTMC_ELBO
from pyro.infer.util import torch_item
from pyro.ops.indexing import Vindex
from pyro.optim import Adam
from pyro.poutine.plate_messenger import block_plate
from tests.common import assert_close
logger = logging.getLogger(__name__)
# This file tests a variety of (model, guide) pairs with valid and invalid structure.
def EnergyDistance_prior(**kwargs):
    """
    Build an ``EnergyDistance`` objective with ``prior_scale`` forced to 0.0.

    A ``strict_enumeration_warning`` keyword, if given, is dropped so this
    factory can be called with the same keywords as the other ELBO classes;
    all remaining keywords are forwarded to ``EnergyDistance``.
    """
    # NOTE(review): the name says "prior" but prior_scale is 0.0 here (and 1.0
    # in EnergyDistance_noprior) -- confirm the naming is intentional.
    kwargs.pop("strict_enumeration_warning", None)
    options = dict(kwargs, prior_scale=0.0)
    return EnergyDistance(**options)
def EnergyDistance_noprior(**kwargs):
    """
    Build an ``EnergyDistance`` objective with ``prior_scale`` forced to 1.0.

    A ``strict_enumeration_warning`` keyword, if given, is dropped so this
    factory can be called with the same keywords as the other ELBO classes;
    all remaining keywords are forwarded to ``EnergyDistance``.
    """
    # NOTE(review): the name says "noprior" but prior_scale is 1.0 here (and
    # 0.0 in EnergyDistance_prior) -- confirm the naming is intentional.
    kwargs.pop("strict_enumeration_warning", None)
    options = dict(kwargs, prior_scale=1.0)
    return EnergyDistance(**options)
def assert_ok(model, guide, elbo, **kwargs):
    """
    Assert that inference works without warnings or errors.

    Runs one SVI step with a freshly cleared param store, then checks that
    the objective's different entry points agree with each other:
    ``elbo.loss`` is compared (``atol=0.01``) against
    ``elbo.differentiable_loss`` and ``elbo.loss_and_grads`` whenever the
    objective provides them.

    :param model: the model callable.
    :param guide: the guide callable.
    :param elbo: the loss object passed to ``SVI``.
    :param kwargs: keyword arguments forwarded to the model and guide.
    """
    pyro.clear_param_store()
    inference = SVI(model, guide, Adam({"lr": 1e-6}), elbo)
    inference.step(**kwargs)
    try:
        # Reseed before every loss evaluation so each one sees the same
        # random draws and the values are comparable.
        pyro.set_rng_seed(0)
        loss = elbo.loss(model, guide, **kwargs)
        if hasattr(elbo, "differentiable_loss"):
            try:
                pyro.set_rng_seed(0)
                differentiable_loss = torch_item(
                    elbo.differentiable_loss(model, guide, **kwargs)
                )
            except ValueError:
                pass  # Ignore cases where elbo cannot be differentiated
            else:
                assert_close(differentiable_loss, loss, atol=0.01)
        if hasattr(elbo, "loss_and_grads"):
            pyro.set_rng_seed(0)
            loss_and_grads = elbo.loss_and_grads(model, guide, **kwargs)
            assert_close(loss_and_grads, loss, atol=0.01)
    except NotImplementedError:
        pass  # Ignore cases where loss isn't implemented, eg. TraceTailAdaptive_ELBO
def assert_error(model, guide, elbo, match=None):
    """
    Assert that inference fails with an error.

    A single SVI step is run on a freshly cleared param store and must raise
    one of the exception types listed below.

    :param model: the model callable.
    :param guide: the guide callable.
    :param elbo: the loss object passed to ``SVI``.
    :param match: optional pattern handed to ``pytest.raises(match=...)``.
    """
    expected_errors = (
        NotImplementedError,
        UserWarning,
        KeyError,
        ValueError,
        RuntimeError,
    )
    pyro.clear_param_store()
    svi = SVI(model, guide, Adam({"lr": 1e-6}), elbo)
    with pytest.raises(expected_errors, match=match):
        svi.step()
def assert_warning(model, guide, elbo):
    """
    Assert that inference works but with a warning.

    A single SVI step is run on a freshly cleared param store while all
    warnings are recorded; at least one must have been emitted. Every
    recorded warning is logged at INFO level.
    """
    pyro.clear_param_store()
    svi = SVI(model, guide, Adam({"lr": 1e-6}), elbo)
    with warnings.catch_warnings(record=True) as caught:
        # "always" prevents the once-per-location filter from hiding repeats.
        warnings.simplefilter("always")
        svi.step()
        assert len(caught), "No warnings were raised"
        for record in caught:
            logger.info(record)
@pytest.mark.parametrize(
    "Elbo",
    [
        Trace_ELBO,
        TraceGraph_ELBO,
        TraceEnum_ELBO,
        TraceTMC_ELBO,
        EnergyDistance_prior,
        EnergyDistance_noprior,
    ],
)
@pytest.mark.parametrize("strict_enumeration_warning", [True, False])
def test_nonempty_model_empty_guide_ok(Elbo, strict_enumeration_warning):
    """
    A model whose only site is observed is valid with an empty guide.

    With ``strict_enumeration_warning`` enabled, ``TraceEnum_ELBO`` and
    ``TraceTMC_ELBO`` are expected to warn rather than run silently.
    """

    def model():
        mean = torch.tensor([0.0, 0.0])
        stddev = torch.tensor([1.0, 1.0])
        pyro.sample("x", dist.Normal(mean, stddev).to_event(1), obs=mean)

    def guide():
        return None

    elbo = Elbo(strict_enumeration_warning=strict_enumeration_warning)
    is_enum_elbo = Elbo is TraceEnum_ELBO or Elbo is TraceTMC_ELBO
    if is_enum_elbo and strict_enumeration_warning:
        assert_warning(model, guide, elbo)
        return
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize(
    "Elbo",
    [
        Trace_ELBO,
        TraceGraph_ELBO,
        TraceEnum_ELBO,
        TraceTMC_ELBO,
        EnergyDistance_prior,
        EnergyDistance_noprior,
    ],
)
@pytest.mark.parametrize("strict_enumeration_warning", [True, False])
def test_nonempty_model_empty_guide_error(Elbo, strict_enumeration_warning):
    """
    A model with an unobserved sample site must fail when the guide is empty.
    """

    def model():
        pyro.sample("x", dist.Normal(0, 1))

    def guide():
        return None

    assert_error(
        model,
        guide,
        Elbo(strict_enumeration_warning=strict_enumeration_warning),
    )
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
@pytest.mark.parametrize("strict_enumeration_warning", [True, False])
def test_empty_model_empty_guide_ok(Elbo, strict_enumeration_warning):
    """
    An empty model with an empty guide is valid.

    With ``strict_enumeration_warning`` enabled, ``TraceEnum_ELBO`` and
    ``TraceTMC_ELBO`` are expected to warn rather than run silently.
    """

    def model():
        return None

    def guide():
        return None

    elbo = Elbo(strict_enumeration_warning=strict_enumeration_warning)
    is_enum_elbo = Elbo is TraceEnum_ELBO or Elbo is TraceTMC_ELBO
    if is_enum_elbo and strict_enumeration_warning:
        assert_warning(model, guide, elbo)
        return
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_variable_clash_in_model_error(Elbo):
    """
    Reusing a sample site name inside the model must raise an error.
    """

    def model():
        prob = torch.tensor(0.5)
        pyro.sample("x", dist.Bernoulli(prob))
        # Second use of the name "x": this is the statement expected to fail.
        pyro.sample("x", dist.Bernoulli(prob))

    def guide():
        prob = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        pyro.sample("x", dist.Bernoulli(prob))

    elbo = Elbo()
    assert_error(model, guide, elbo, match="Multiple sample sites named")
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_model_guide_dim_mismatch_error(Elbo):
    """
    Model and guide declaring different event dims for "x" must raise.

    The model uses ``to_event(1)`` on shape ``(2,)`` parameters while the
    guide uses ``to_event(2)`` on shape ``(2, 1)`` parameters.
    """

    def model():
        mean = torch.zeros(2)
        stddev = torch.ones(2)
        pyro.sample("x", dist.Normal(mean, stddev).to_event(1))

    def guide():
        mean = pyro.param("loc", torch.zeros(2, 1, requires_grad=True))
        stddev = pyro.param("scale", torch.ones(2, 1, requires_grad=True))
        pyro.sample("x", dist.Normal(mean, stddev).to_event(2))

    elbo = Elbo(strict_enumeration_warning=False)
    expected_message = "invalid log_prob shape|Model and guide event_dims disagree"
    assert_error(model, guide, elbo, match=expected_message)
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_model_guide_shape_mismatch_error(Elbo):
    """
    Model and guide sampling "x" with different shapes must raise.

    Both use ``to_event(2)``, but the model's parameters have shape
    ``(1, 2)`` and the guide's have shape ``(2, 1)``.
    """

    def model():
        mean = torch.zeros(1, 2)
        stddev = torch.ones(1, 2)
        pyro.sample("x", dist.Normal(mean, stddev).to_event(2))

    def guide():
        mean = pyro.param("loc", torch.zeros(2, 1, requires_grad=True))
        stddev = pyro.param("scale", torch.ones(2, 1, requires_grad=True))
        pyro.sample("x", dist.Normal(mean, stddev).to_event(2))

    elbo = Elbo(strict_enumeration_warning=False)
    assert_error(model, guide, elbo, match="Model and guide shapes disagree")
@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
def test_variable_clash_in_guide_error(Elbo):
    """
    Reusing a sample site name inside the guide must raise an error.
    """

    def model():
        prob = torch.tensor(0.5)
        pyro.sample("x", dist.Bernoulli(prob))

    def guide():
        prob = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        pyro.sample("x", dist.Bernoulli(prob))
        # Second use of the name "x": this is the statement expected to fail.
        pyro.sample("x", dist.Bernoulli(prob))

    elbo = Elbo()
    assert_error(model, guide, elbo, match="Multiple sample sites named")
@pytest.mark.parametrize("has_rsample", [False, True])
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_set_has_rsample_ok(has_rsample, Elbo):
    """
    Inference is valid whether or not reparametrized sampling is enabled.

    The model has sparse gradients, so users may want to switch off
    reparametrized sampling to reduce gradient-estimate variance. Both
    settings of ``has_rsample`` should nevertheless be correct.
    """

    def model():
        latent = pyro.sample("z", dist.Normal(0, 1))
        # Clamping makes the gradient zero almost everywhere (sparse gradients).
        mean = (latent * 100).clamp(min=0, max=1)
        pyro.sample("x", dist.Normal(mean, 1), obs=torch.tensor(0.0))

    def guide():
        mean = pyro.param("loc", torch.tensor(0.0))
        pyro.sample("z", dist.Normal(mean, 1).has_rsample_(has_rsample))

    if Elbo is TraceTMC_ELBO:
        guide = config_enumerate(guide, num_samples=2)
    elif Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    elbo = Elbo(strict_enumeration_warning=False)
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_not_has_rsample_ok(Elbo):
    """
    A guide with reparametrized sampling disabled is valid for a model that
    depends discontinuously on the latent variable.
    """

    def model():
        latent = pyro.sample("z", dist.Normal(0, 1))
        # Rounding makes the probability a discontinuous function of z.
        prob = latent.round().clamp(min=0.2, max=0.8)
        pyro.sample("x", dist.Bernoulli(prob), obs=torch.tensor(0.0))

    def guide():
        mean = pyro.param("loc", torch.tensor(0.0))
        pyro.sample("z", dist.Normal(mean, 1).has_rsample_(False))

    if Elbo is TraceTMC_ELBO:
        guide = config_enumerate(guide, num_samples=2)
    elif Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    elbo = Elbo(strict_enumeration_warning=False)
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize("subsample_size", [None, 2], ids=["full", "subsample"])
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_iplate_ok(subsample_size, Elbo):
    """
    Iterating over a plate with a distinct site name per index is valid,
    with or without subsampling.
    """

    def model():
        prob = torch.tensor(0.5)
        for idx in pyro.plate("plate", 4, subsample_size):
            pyro.sample("x_{}".format(idx), dist.Bernoulli(prob))

    def guide():
        prob = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        for idx in pyro.plate("plate", 4, subsample_size):
            pyro.sample("x_{}".format(idx), dist.Bernoulli(prob))

    if Elbo is TraceTMC_ELBO:
        guide = config_enumerate(guide, num_samples=2)
    elif Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    elbo = Elbo()
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_iplate_variable_clash_error(Elbo):
    """
    Using the same site name on every iteration of a plate loop must raise.
    """

    def model():
        prob = torch.tensor(0.5)
        for _ in pyro.plate("plate", 2):
            # The name should vary per iteration; here it deliberately does not.
            pyro.sample("x", dist.Bernoulli(prob))

    def guide():
        prob = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        for _ in pyro.plate("plate", 2):
            # The name should vary per iteration; here it deliberately does not.
            pyro.sample("x", dist.Bernoulli(prob))

    if Elbo is TraceTMC_ELBO:
        guide = config_enumerate(guide, num_samples=2)
    elif Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    elbo = Elbo()
    assert_error(model, guide, elbo, match="Multiple sample sites named")
@pytest.mark.parametrize("subsample_size", [None, 5], ids=["full", "subsample"])
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_plate_ok(subsample_size, Elbo):
    """
    A vectorized plate whose site is expanded to the plate's index count is
    valid, with or without subsampling.
    """

    def model():
        prob = torch.tensor(0.5)
        with pyro.plate("plate", 10, subsample_size) as indices:
            batch = [len(indices)]
            pyro.sample("x", dist.Bernoulli(prob).expand_by(batch))

    def guide():
        prob = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        with pyro.plate("plate", 10, subsample_size) as indices:
            batch = [len(indices)]
            pyro.sample("x", dist.Bernoulli(prob).expand_by(batch))

    if Elbo is TraceTMC_ELBO:
        guide = config_enumerate(guide, num_samples=2)
    elif Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    elbo = Elbo()
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize("subsample_size", [None, 5], ids=["full", "subsample"])
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_plate_subsample_param_ok(subsample_size, Elbo):
    """
    ``pyro.param`` with ``event_dim=0`` inside a plate is valid.

    The guide asserts that a scalar param keeps its scalar shape and that a
    length-10 param has as many entries as the plate's current indices.
    """

    def model():
        prob = torch.tensor(0.5)
        with pyro.plate("plate", 10, subsample_size):
            pyro.sample("x", dist.Bernoulli(prob))

    def guide():
        with pyro.plate("plate", 10, subsample_size) as indices:
            scalar_param = pyro.param("p0", torch.tensor(0.0), event_dim=0)
            assert scalar_param.shape == ()
            probs = pyro.param("p", 0.5 * torch.ones(10), event_dim=0)
            assert len(probs) == len(indices)
            pyro.sample("x", dist.Bernoulli(probs))

    if Elbo is TraceTMC_ELBO:
        guide = config_enumerate(guide, num_samples=2)
    elif Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    elbo = Elbo()
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize("subsample_size", [None, 5], ids=["full", "subsample"])
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_plate_subsample_primitive_ok(subsample_size, Elbo):
    """
    ``pyro.subsample`` with ``event_dim=0`` inside a plate is valid.

    The guide asserts that a scalar tensor keeps its scalar shape and that a
    length-10 tensor has as many entries as the plate's current indices.
    """

    def model():
        prob = torch.tensor(0.5)
        with pyro.plate("plate", 10, subsample_size):
            pyro.sample("x", dist.Bernoulli(prob))

    def guide():
        with pyro.plate("plate", 10, subsample_size) as indices:
            scalar = pyro.subsample(torch.tensor(0.0), event_dim=0)
            assert scalar.shape == ()
            probs = pyro.subsample(0.5 * torch.ones(10), event_dim=0)
            assert len(probs) == len(indices)
            pyro.sample("x", dist.Bernoulli(probs))

    if Elbo is TraceTMC_ELBO:
        guide = config_enumerate(guide, num_samples=2)
    elif Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    elbo = Elbo()
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize("subsample_size", [None, 5], ids=["full", "subsample"])
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
@pytest.mark.parametrize(
    "shape,ok",
    [
        ((), True),
        ((1,), True),
        ((10,), True),
        ((3, 1), True),
        ((3, 10), True),
        # Was ``(5)``, which is the int 5 rather than a one-element tuple.
        # ``torch.ones`` builds the same shape either way; the tuple form is
        # consistent with every other entry in this list.
        ((5,), False),
        ((3, 5), False),
    ],
)
def test_plate_param_size_mismatch_error(subsample_size, Elbo, shape, ok):
    """
    Check which ``pyro.param`` shapes are accepted inside a size-10 plate.

    :param subsample_size: subsample size passed to the plate (``None`` = full).
    :param Elbo: the ELBO class under test.
    :param shape: shape of the extra parameter ``p0`` created in the guide.
    :param ok: whether inference is expected to succeed for ``shape``; when
        false, an "invalid shape of pyro.param" error is expected instead.
    """

    def model():
        p = torch.tensor(0.5)
        with pyro.plate("plate", 10, subsample_size):
            pyro.sample("x", dist.Bernoulli(p))

    def guide():
        with pyro.plate("plate", 10, subsample_size):
            # Only the shape of p0 is under test; its value is unused.
            pyro.param("p0", torch.ones(shape), event_dim=0)
            p = pyro.param("p", torch.ones(10), event_dim=0)
            pyro.sample("x", dist.Bernoulli(p))

    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)
    elif Elbo is TraceTMC_ELBO:
        guide = config_enumerate(guide, num_samples=2)

    if ok:
        assert_ok(model, guide, Elbo())
    else:
        assert_error(model, guide, Elbo(), match="invalid shape of pyro.param")
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_plate_no_size_ok(Elbo):
    """
    A plate declared without an explicit size is valid when the site inside
    it is expanded to a batch of 10.
    """

    def model():
        prob = torch.tensor(0.5)
        with pyro.plate("plate"):
            pyro.sample("x", dist.Bernoulli(prob).expand_by([10]))

    def guide():
        prob = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        with pyro.plate("plate"):
            pyro.sample("x", dist.Bernoulli(prob).expand_by([10]))

    if Elbo is TraceTMC_ELBO:
        guide = config_enumerate(guide, default="parallel", num_samples=2)
    elif Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    elbo = Elbo()
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize("max_plate_nesting", [0, float("inf")])
@pytest.mark.parametrize("subsample_size", [None, 2], ids=["full", "subsample"])
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_iplate_iplate_ok(subsample_size, Elbo, max_plate_nesting):
    """
    Two nested sequential plates, iterated in the same order in model and
    guide, are valid.
    """

    def model():
        prob = torch.tensor(0.5)
        rows = pyro.plate("plate_0", 3, subsample_size)
        cols = pyro.plate("plate_1", 3, subsample_size)
        for i in rows:
            for j in cols:
                pyro.sample("x_{}_{}".format(i, j), dist.Bernoulli(prob))

    def guide():
        prob = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        rows = pyro.plate("plate_0", 3, subsample_size)
        cols = pyro.plate("plate_1", 3, subsample_size)
        for i in rows:
            for j in cols:
                pyro.sample("x_{}_{}".format(i, j), dist.Bernoulli(prob))

    if Elbo is TraceTMC_ELBO:
        guide = config_enumerate(guide, num_samples=2)
    elif Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide, "parallel")

    elbo = Elbo(max_plate_nesting=max_plate_nesting)
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize("max_plate_nesting", [0, float("inf")])
@pytest.mark.parametrize("subsample_size", [None, 2], ids=["full", "subsample"])
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_iplate_iplate_swap_ok(subsample_size, Elbo, max_plate_nesting):
    """
    Two nested sequential plates are valid even when the guide iterates them
    in the opposite nesting order from the model.
    """

    def model():
        prob = torch.tensor(0.5)
        rows = pyro.plate("plate_0", 3, subsample_size)
        cols = pyro.plate("plate_1", 3, subsample_size)
        for i in rows:
            for j in cols:
                pyro.sample("x_{}_{}".format(i, j), dist.Bernoulli(prob))

    def guide():
        prob = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        rows = pyro.plate("plate_0", 3, subsample_size)
        cols = pyro.plate("plate_1", 3, subsample_size)
        # Loop order is swapped relative to the model: "plate_1" is outermost.
        for j in cols:
            for i in rows:
                pyro.sample("x_{}_{}".format(i, j), dist.Bernoulli(prob))

    if Elbo is TraceTMC_ELBO:
        guide = config_enumerate(guide, default="parallel", num_samples=2)
    elif Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide, "parallel")

    elbo = Elbo(max_plate_nesting=max_plate_nesting)
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize("subsample_size", [None, 5], ids=["full", "subsample"])
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_iplate_in_model_not_guide_ok(subsample_size, Elbo):
    """
    A sequential plate that appears only in the model (and contains no
    sites) is valid.
    """

    def model():
        prob = torch.tensor(0.5)
        # The loop body is intentionally empty; only the plate itself matters.
        for _ in pyro.plate("plate", 10, subsample_size):
            pass
        pyro.sample("x", dist.Bernoulli(prob))

    def guide():
        prob = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        pyro.sample("x", dist.Bernoulli(prob))

    if Elbo is TraceTMC_ELBO:
        guide = config_enumerate(guide, num_samples=2)
    elif Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    elbo = Elbo()
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize("subsample_size", [None, 5], ids=["full", "subsample"])
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
@pytest.mark.parametrize("is_validate", [True, False])
def test_iplate_in_guide_not_model_error(subsample_size, Elbo, is_validate):
    """
    A plate that appears only in the guide is rejected when validation is
    enabled and accepted when validation is disabled.
    """

    def model():
        prob = torch.tensor(0.5)
        pyro.sample("x", dist.Bernoulli(prob))

    def guide():
        prob = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        # The loop body is intentionally empty; only the plate itself matters.
        for _ in pyro.plate("plate", 10, subsample_size):
            pass
        pyro.sample("x", dist.Bernoulli(prob))

    with pyro.validation_enabled(is_validate):
        if not is_validate:
            assert_ok(model, guide, Elbo())
        else:
            assert_error(
                model,
                guide,
                Elbo(),
                match="Found plate statements in guide but not model",
            )
@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
def test_plate_broadcast_error(Elbo):
    """
    A site expanded to size 2 inside a plate subsampled to size 5 must raise
    a shape-mismatch error. The model doubles as its own guide.
    """

    def model():
        prob = torch.tensor(0.5, requires_grad=True)
        with pyro.plate("plate", 10, 5):
            pyro.sample("x", dist.Bernoulli(prob).expand_by([2]))

    elbo = Elbo()
    assert_error(model, model, elbo, match="Shape mismatch inside plate")
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_plate_iplate_ok(Elbo):
    """
    A sequential plate nested inside a vectorized plate is valid.
    """

    def model():
        prob = torch.tensor(0.5)
        with pyro.plate("plate", 3, 2) as indices:
            for i in pyro.plate("iplate", 3, 2):
                batch = [len(indices)]
                pyro.sample("x_{}".format(i), dist.Bernoulli(prob).expand_by(batch))

    def guide():
        prob = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        with pyro.plate("plate", 3, 2) as indices:
            for i in pyro.plate("iplate", 3, 2):
                batch = [len(indices)]
                pyro.sample("x_{}".format(i), dist.Bernoulli(prob).expand_by(batch))

    if Elbo is TraceTMC_ELBO:
        guide = config_enumerate(guide, num_samples=2)
    elif Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    elbo = Elbo()
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_iplate_plate_ok(Elbo):
    """
    A vectorized plate re-entered on each iteration of a sequential plate is
    valid.
    """

    def model():
        prob = torch.tensor(0.5)
        vectorized = pyro.plate("plate", 3, 2)
        for i in pyro.plate("iplate", 3, 2):
            with vectorized as indices:
                batch = [len(indices)]
                pyro.sample("x_{}".format(i), dist.Bernoulli(prob).expand_by(batch))

    def guide():
        prob = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        vectorized = pyro.plate("plate", 3, 2)
        for i in pyro.plate("iplate", 3, 2):
            with vectorized as indices:
                batch = [len(indices)]
                pyro.sample("x_{}".format(i), dist.Bernoulli(prob).expand_by(batch))

    if Elbo is TraceTMC_ELBO:
        guide = config_enumerate(guide, num_samples=2)
    elif Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    elbo = Elbo()
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
@pytest.mark.parametrize("sizes", [(3,), (3, 4), (3, 4, 5)])
def test_plate_stack_ok(Elbo, sizes):
    """
    Sampling inside a ``pyro.plate_stack`` of one to three plates is valid.
    """

    def model():
        prob = torch.tensor(0.5)
        with pyro.plate_stack("plate_stack", sizes):
            pyro.sample("x", dist.Bernoulli(prob))

    def guide():
        prob = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        with pyro.plate_stack("plate_stack", sizes):
            pyro.sample("x", dist.Bernoulli(prob))

    if Elbo is TraceTMC_ELBO:
        guide = config_enumerate(guide, num_samples=2)
    elif Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    elbo = Elbo()
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
@pytest.mark.parametrize("sizes", [(3,), (3, 4), (3, 4, 5)])
def test_plate_stack_and_plate_ok(Elbo, sizes):
    """
    An ordinary plate nested inside a ``pyro.plate_stack`` is valid.
    """

    def model():
        prob = torch.tensor(0.5)
        with pyro.plate_stack("plate_stack", sizes), pyro.plate("plate", 7):
            pyro.sample("x", dist.Bernoulli(prob))

    def guide():
        prob = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        with pyro.plate_stack("plate_stack", sizes), pyro.plate("plate", 7):
            pyro.sample("x", dist.Bernoulli(prob))

    if Elbo is TraceTMC_ELBO:
        guide = config_enumerate(guide, num_samples=2)
    elif Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    elbo = Elbo()
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize("sizes", [(3,), (3, 4), (3, 4, 5)])
def test_plate_stack_sizes(sizes):
    """
    A sample drawn inside ``pyro.plate_stack(sizes)`` has batch shape
    ``sizes`` followed by the distribution's event shape ``(3,)``.
    """

    def model():
        probs = 0.5 * torch.ones(3)
        expected_shape = sizes + (3,)
        with pyro.plate_stack("plate_stack", sizes):
            draw = pyro.sample("x", dist.Bernoulli(probs).to_event(1))
            assert draw.shape == expected_shape

    model()
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_nested_plate_plate_ok(Elbo):
    """
    Two nested vectorized plates with correctly shaped sites are valid.
    The model doubles as its own guide.
    """

    def model():
        prob = torch.tensor(0.5, requires_grad=True)
        with pyro.plate("plate_outer", 10, 5) as outer_idx:
            n_outer = len(outer_idx)
            pyro.sample("x", dist.Bernoulli(prob).expand_by([n_outer]))
            with pyro.plate("plate_inner", 11, 6) as inner_idx:
                n_inner = len(inner_idx)
                pyro.sample("y", dist.Bernoulli(prob).expand_by([n_inner, n_outer]))

    guide = model
    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(model)
    elif Elbo is TraceTMC_ELBO:
        guide = config_enumerate(model, num_samples=2)

    elbo = Elbo()
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_plate_reuse_ok(Elbo):
    """
    Plate objects with explicit dims may be entered separately and then
    together. The model doubles as its own guide.
    """

    def model():
        prob = torch.tensor(0.5, requires_grad=True)
        outer = pyro.plate("plate_outer", 10, 5, dim=-1)
        inner = pyro.plate("plate_inner", 11, 6, dim=-2)
        with outer as outer_idx:
            pyro.sample("x", dist.Bernoulli(prob).expand_by([len(outer_idx)]))
        with inner as inner_idx:
            # The trailing 1 keeps the inner plate on dim -2.
            pyro.sample("y", dist.Bernoulli(prob).expand_by([len(inner_idx), 1]))
        with outer as outer_idx, inner as inner_idx:
            joint_shape = [len(inner_idx), len(outer_idx)]
            pyro.sample("z", dist.Bernoulli(prob).expand_by(joint_shape))

    guide = model
    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(model)
    elif Elbo is TraceTMC_ELBO:
        guide = config_enumerate(model, num_samples=2)

    elbo = Elbo()
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize(
    "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]
)
def test_nested_plate_plate_dim_error_1(Elbo):
    """
    Nested plates with a mis-shaped site "x" must raise
    "invalid log_prob shape". The model doubles as its own guide.
    """

    def model():
        prob = torch.tensor([0.5], requires_grad=True)
        with pyro.plate("plate_outer", 10, 5) as outer_idx:
            n_outer = len(outer_idx)
            # error here
            pyro.sample("x", dist.Bernoulli(prob).expand_by([n_outer]))
            with pyro.plate("plate_inner", 11, 6) as inner_idx:
                n_inner = len(inner_idx)
                pyro.sample("y", dist.Bernoulli(prob).expand_by([n_inner]))
                pyro.sample("z", dist.Bernoulli(prob).expand_by([n_outer, n_inner]))

    guide = model
    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(model)
    elif Elbo is TraceTMC_ELBO:
        guide = config_enumerate(model, num_samples=2)

    elbo = Elbo()
    assert_error(model, guide, elbo, match="invalid log_prob shape")
@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
def test_nested_plate_plate_dim_error_2(Elbo):
    """
    Nested plates with a mis-shaped site "y" must raise
    "Shape mismatch inside plate". The model doubles as its own guide.
    """

    def model():
        prob = torch.tensor([0.5], requires_grad=True)
        with pyro.plate("plate_outer", 10, 5) as outer_idx:
            n_outer = len(outer_idx)
            pyro.sample("x", dist.Bernoulli(prob).expand_by([n_outer, 1]))
            with pyro.plate("plate_inner", 11, 6) as inner_idx:
                n_inner = len(inner_idx)
                # error here
                pyro.sample("y", dist.Bernoulli(prob).expand_by([n_outer]))
                pyro.sample("z", dist.Bernoulli(prob).expand_by([n_outer, n_inner]))

    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(model)
    else:
        guide = model

    elbo = Elbo()
    assert_error(model, guide, elbo, match="Shape mismatch inside plate")
@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
def test_nested_plate_plate_dim_error_3(Elbo):
    """
    Nested plates with a mis-shaped site "z" must raise either
    "invalid log_prob shape" or "shape mismatch". The model doubles as its
    own guide.
    """

    def model():
        prob = torch.tensor([0.5], requires_grad=True)
        with pyro.plate("plate_outer", 10, 5) as outer_idx:
            n_outer = len(outer_idx)
            pyro.sample("x", dist.Bernoulli(prob).expand_by([n_outer, 1]))
            with pyro.plate("plate_inner", 11, 6) as inner_idx:
                n_inner = len(inner_idx)
                pyro.sample("y", dist.Bernoulli(prob).expand_by([n_inner]))
                # error here
                pyro.sample("z", dist.Bernoulli(prob).expand_by([n_inner, 1]))

    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(model)
    else:
        guide = model

    elbo = Elbo()
    assert_error(model, guide, elbo, match="invalid log_prob shape|shape mismatch")
@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
def test_nested_plate_plate_dim_error_4(Elbo):
    """
    Nested plates where site "z" uses the outer size on both dims must raise
    a shape-mismatch error. The model doubles as its own guide.
    """

    def model():
        prob = torch.tensor([0.5], requires_grad=True)
        with pyro.plate("plate_outer", 10, 5) as outer_idx:
            n_outer = len(outer_idx)
            pyro.sample("x", dist.Bernoulli(prob).expand_by([n_outer, 1]))
            with pyro.plate("plate_inner", 11, 6) as inner_idx:
                n_inner = len(inner_idx)
                pyro.sample("y", dist.Bernoulli(prob).expand_by([n_inner]))
                # error here
                pyro.sample("z", dist.Bernoulli(prob).expand_by([n_outer, n_outer]))

    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(model)
    else:
        guide = model

    elbo = Elbo()
    # NOTE(review): the pattern omits the leading "S"/"s", presumably to match
    # either capitalization of the message -- confirm before "fixing" it.
    assert_error(model, guide, elbo, match="hape mismatch inside plate")
@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
def test_nested_plate_plate_subsample_param_ok(Elbo):
    """
    ``pyro.param`` with various ``event_dim`` values is valid at each depth
    of two nested, subsampled plates.

    The guide asserts the shape returned by each ``pyro.param`` call.
    """

    def model():
        with pyro.plate("plate_outer", 10, 5):
            pyro.sample("x", dist.Bernoulli(0.2))
            with pyro.plate("plate_inner", 11, 6):
                pyro.sample("y", dist.Bernoulli(0.2))

    def guide():
        unplated = pyro.param("p0", 0.5 * torch.ones(4, 5), event_dim=2)
        assert unplated.shape == (4, 5)
        with pyro.plate("plate_outer", 10, 5):
            outer_event = pyro.param("p1", 0.5 * torch.ones(10, 3), event_dim=1)
            assert outer_event.shape == (5, 3)
            x_probs = pyro.param("px", 0.5 * torch.ones(10), event_dim=0)
            assert x_probs.shape == (5,)
            pyro.sample("x", dist.Bernoulli(x_probs))
            with pyro.plate("plate_inner", 11, 6):
                y_probs = pyro.param("py", 0.5 * torch.ones(11, 10), event_dim=0)
                assert y_probs.shape == (6, 5)
                pyro.sample("y", dist.Bernoulli(y_probs))

    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    elbo = Elbo()
    assert_ok(model, guide, elbo)
@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
def test_nonnested_plate_plate_ok(Elbo):
    """
    Two sibling (non-nested) vectorized plates are valid.
    The model doubles as its own guide.
    """

    def model():
        prob = torch.tensor(0.5, requires_grad=True)
        with pyro.plate("plate_0", 10, 5) as first_idx:
            pyro.sample("x0", dist.Bernoulli(prob).expand_by([len(first_idx)]))
        with pyro.plate("plate_1", 11, 6) as second_idx:
            pyro.sample("x1", dist.Bernoulli(prob).expand_by([len(second_idx)]))

    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(model)
    else:
        guide = model

    elbo = Elbo()
    assert_ok(model, guide, elbo)
def test_three_indep_plate_at_different_depths_ok():
    r"""
    The vectorized plate "plate2" (ia) is entered at different depths of the
    sequential-plate tree:

        /\
       /\ ia
      ia ia

    Model and guide have identical structure; the combination must be valid
    under ``TraceGraph_ELBO``.
    """

    def model():
        p = torch.tensor(0.5)
        # One plate object, re-entered from several branches below.
        inner_plate = pyro.plate("plate2", 10, 5)
        for i in pyro.plate("plate0", 2):
            pyro.sample("x_%d" % i, dist.Bernoulli(p))
            if i == 0:
                # Left branch: inner_plate sits under two sequential plates.
                for j in pyro.plate("plate1", 2):
                    with inner_plate as ind:
                        pyro.sample("y_%d" % j, dist.Bernoulli(p).expand_by([len(ind)]))
            elif i == 1:
                # Right branch: inner_plate sits under one sequential plate.
                with inner_plate as ind:
                    pyro.sample("z", dist.Bernoulli(p).expand_by([len(ind)]))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        # One plate object, re-entered from several branches below.
        inner_plate = pyro.plate("plate2", 10, 5)
        for i in pyro.plate("plate0", 2):
            pyro.sample("x_%d" % i, dist.Bernoulli(p))
            if i == 0:
                # Left branch: inner_plate sits under two sequential plates.
                for j in pyro.plate("plate1", 2):
                    with inner_plate as ind:
                        pyro.sample("y_%d" % j, dist.Bernoulli(p).expand_by([len(ind)]))
            elif i == 1:
                # Right branch: inner_plate sits under one sequential plate.
                with inner_plate as ind:
                    pyro.sample("z", dist.Bernoulli(p).expand_by([len(ind)]))

    assert_ok(model, guide, TraceGraph_ELBO())
def test_plate_wrong_size_error():
    """
    A site one element larger than its enclosing plate must raise
    "Shape mismatch inside plate".
    """

    def model():
        prob = torch.tensor(0.5)
        with pyro.plate("plate", 10, 5) as indices:
            too_big = 1 + len(indices)
            pyro.sample("x", dist.Bernoulli(prob).expand_by([too_big]))

    def guide():
        prob = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        with pyro.plate("plate", 10, 5) as indices:
            too_big = 1 + len(indices)
            pyro.sample("x", dist.Bernoulli(prob).expand_by([too_big]))

    elbo = TraceGraph_ELBO()
    assert_error(model, guide, elbo, match="Shape mismatch inside plate")
def test_block_plate_name_ok():
    """
    ``block_plate`` selected by name hides the enclosing plate.

    A site sampled under ``block_plate("plate")`` inside the plate has scalar
    shape, while a site sampled directly in the plate has shape ``(2,)``.
    Model and guide use opposite placements for sites "a" and "c".
    """

    def model():
        outside = pyro.sample("a", dist.Normal(0, 1))
        assert outside.shape == ()
        with pyro.plate("plate", 2):
            plated = pyro.sample("b", dist.Normal(0, 1))
            assert plated.shape == (2,)
            with block_plate("plate"):
                blocked = pyro.sample("c", dist.Normal(0, 1))
                assert blocked.shape == ()

    def guide():
        outside = pyro.sample("c", dist.Normal(0, 1))
        assert outside.shape == ()
        with pyro.plate("plate", 2):
            plated = pyro.sample("b", dist.Normal(0, 1))
            assert plated.shape == (2,)
            with block_plate("plate"):
                blocked = pyro.sample("a", dist.Normal(0, 1))
                assert blocked.shape == ()

    elbo = Trace_ELBO()
    assert_ok(model, guide, elbo)
def test_block_plate_dim_ok():
    """
    ``block_plate`` selected by ``dim=-1`` hides the enclosing plate.

    A site sampled under ``block_plate(dim=-1)`` inside the plate has scalar
    shape, while a site sampled directly in the plate has shape ``(2,)``.
    Model and guide use opposite placements for sites "a" and "c".
    """

    def model():
        outside = pyro.sample("a", dist.Normal(0, 1))
        assert outside.shape == ()
        with pyro.plate("plate", 2):
            plated = pyro.sample("b", dist.Normal(0, 1))
            assert plated.shape == (2,)
            with block_plate(dim=-1):
                blocked = pyro.sample("c", dist.Normal(0, 1))
                assert blocked.shape == ()

    def guide():
        outside = pyro.sample("c", dist.Normal(0, 1))
        assert outside.shape == ()
        with pyro.plate("plate", 2):
            plated = pyro.sample("b", dist.Normal(0, 1))
            assert plated.shape == (2,)
            with block_plate(dim=-1):
                blocked = pyro.sample("a", dist.Normal(0, 1))
                assert blocked.shape == ()

    elbo = Trace_ELBO()
    assert_ok(model, guide, elbo)
def test_block_plate_missing_error():
    """
    ``block_plate`` naming a plate that is not active must raise
    "block_plate matched 0 messengers".
    """

    def model():
        # No plate named "plate" is active here, so there is nothing to block.
        with block_plate("plate"):
            pyro.sample("a", dist.Normal(0, 1))

    def guide():
        pyro.sample("a", dist.Normal(0, 1))

    elbo = Trace_ELBO()
    assert_error(model, guide, elbo, match="block_plate matched 0 messengers")
@pytest.mark.parametrize("enumerate_", [None, "sequential", "parallel"])
@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
def test_enum_discrete_misuse_warning(Elbo, enumerate_):
    """
    Mismatched enumeration settings produce a warning.

    A warning is expected when ``TraceEnum_ELBO`` is used with a guide site
    that is not enumerated, or when another ELBO is used with a guide site
    that is enumerated. All other combinations must run cleanly.
    """

    def model():
        prob = torch.tensor(0.5)
        pyro.sample("x", dist.Bernoulli(prob))

    def guide():
        prob = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        pyro.sample("x", dist.Bernoulli(prob), infer={"enumerate": enumerate_})

    site_not_enumerated = enumerate_ is None
    uses_enum_elbo = Elbo is TraceEnum_ELBO
    elbo = Elbo(max_plate_nesting=0)
    if site_not_enumerated == uses_enum_elbo:
        assert_warning(model, guide, elbo)
    else:
        assert_ok(model, guide, elbo)