From 216f3dad8e46e4869cbd63044cb27ab8f98ee939 Mon Sep 17 00:00:00 2001
From: Eli
Date: Fri, 28 Jun 2019 11:07:32 -0700
Subject: [PATCH] Squashed commit of the following:

commit 3a80b2952dbd3716b53049d731499965afd40ee7
Author: eb8680
Date: Thu Jun 27 14:17:26 2019 -0700

    Make substitution an interpretation (#155)

    * make substitute into an interpretation
    * simplify interpretation and add numpy+tco test to Makefile
    * tweak Independent
    * remove substitute_funsor
    * uncomment long examples
    * conditionally xfail some numpy tests
    * update semantics of Independent to be compatible with substitution changes
    * add Module to ground types

commit ce5cdeb834ff10aacfaf57dc696e4c60309f07a9
Author: eb8680
Date: Wed Jun 26 10:18:32 2019 -0700

    Remove eager_subs method from Joint (#154)

commit ba0dfcb4ff0fbe7a8a14866bf178a2ef2b290feb
Author: eb8680
Date: Tue Jun 25 15:04:38 2019 -0700

    Separate generic and fresh variables in substitute (#153)

commit 14766efd8d5da091cb78e92a788f9fd25301de0b
Author: eb8680
Date: Tue Jun 25 14:52:19 2019 -0700

    Remove most eager_subs methods from terms with no fresh variables (#152)

commit 7e456e60fbcabae1db262787119489f603e7a1d5
Author: eb8680
Date: Tue Jun 25 13:34:58 2019 -0700

    Alpha renaming of bound variables (#148)

commit 3edfb0f998ebb26021d6d2c2f6eb666daaa5ab89
Author: Fritz Obermeyer
Date: Wed Jun 5 09:58:30 2019 -0700

    Use classifiers to specify the license (#150)

commit d1cd1d2d4147b0b4f4e4de3adc84a10b3377172f
Author: eb8680
Date: Fri May 24 11:10:14 2019 -0700

    Add an explicit stack-based interpreter (#147)

commit ac02dcce71ea77ac17e603b2aab6d761d191ca34
Author: Fritz Obermeyer
Date: Mon May 6 11:02:20 2019 -0700

    Update badge to point to travis-ci.com (#143)

commit 416cff6dff9cec6f5a21d769edf90917e7c66bc0
Author: Fritz Obermeyer
Date: Thu May 2 14:45:10 2019 -0700

    Promote Einsum to a funsor for pattern matching (#142)

commit 30e4255c4fdf137fb4168746b7b525a9cd0c8c37
Author: Fritz Obermeyer
Date: Thu May 2 10:46:24 2019 -0700

    Update CI to use PyTorch 1.1.0 (#141)

    * Update CI to use PyTorch 1.1.0
    * Fix travis version
    * Accommodate change to torch.max return type

commit 00c467d2f33c0810877fc422845557e62fc5632a
Author: eb8680
Date: Wed May 1 22:57:54 2019 -0700

    Add model and guide auxiliary variable elimination to minipyro.elbo (#126)

    * add auxiliary variable elimination to minipyro.elbo to mimic TraceEnum_ELBO
    * optional partial_sum_product
    * Common subexpression elimination
    * add a test for traceenum_elbo
    * add two more tests, one of which is xfailing...
    * Simplify test and minipyro (#140)
    * Add smoke test for gaussian-probit model
    * Remove bogus assertion
    * Fix TraceEnum_ELBO to use EnumerateMessenger
    * remove EnumerateMessenger
    * add optional optimization to traceenum_elbo and mark monte carlo test xfail

commit c77dc99a0c44f90b1541f82e9906ef944cae8d12
Author: JP
Date: Wed May 1 21:26:07 2019 -0700

    rtd fix to not install torch 1.0 (#139)

commit 2e5fd404e44deb0b39300986e1f2dab5e3343502
Author: JP
Date: Wed May 1 14:10:27 2019 -0700

    deploy documentation (#138)

commit cb09325d3812ef3c2ac11860717b0a6cbca73703
Author: Fritz Obermeyer
Date: Wed May 1 12:19:14 2019 -0700

    Add generic Bernoulli wrapper around probs,logits version (#137)

commit 45e0434c65fb4e47b4c319d748e856202e847a72
Author: Fritz Obermeyer
Date: Wed May 1 12:11:03 2019 -0700

    Refactor elbos (#136)

    * Refactor elbos
    * Fix jit classes

commit 1f5f60271bc86ca76dca33fa0f60a3560eab9d1a
Author: Fritz Obermeyer
Date: Tue Apr 30 23:57:13 2019 -0700

    Add BernoulliLogits distribution (#135)

commit df256073e9fb5142de1529e2eebe70e22ed8c2bd
Author: Fritz Obermeyer
Date: Tue Apr 30 15:51:38 2019 -0700

    Avoid legacy constructors (#134)

commit 92e85c3b56dc452b91de415be04c38d181264304
Author: Fritz Obermeyer
Date: Mon Apr 29 11:14:31 2019 -0700

    Add JitTrace_ELBO class to minipyro (#133)

    * Add JitTrace_ELBO class to minipyro
    * Enable minipyro --jit test
    * Simplify JitTrace_ELBO
    * Make Gaussian math jit compatible
    * Add unit tests for BlockVector and BlockMatrix
    * Fix docs
    * Tweak docs

commit 5053f554a1a375479ac84bec1e2639bdbd156389
Author: Fritz Obermeyer
Date: Fri Apr 12 14:09:23 2019 -0400

    Remove dependency on pyro ParamStoreDict (#129)

commit 79a8621c06498797d7f1879ea42b9c462de21d0f
Author: Fritz Obermeyer
Date: Thu Apr 11 21:15:00 2019 -0700

    Support constraints in funsor.minipyro (#128)

    * Support constraints in funsor.minipyro
    * Make minipyro tests a little stronger
    * Save metadata to value.unconstrained()

commit a7f3b8a84d86cce8565b66ffa22d3fbf01353f45
Author: Fritz Obermeyer
Date: Tue Apr 9 21:29:29 2019 -0700

    Reinstate examples/minipyro.py, forked from pyro/examples (#125)

    * Reinstate examples/minipyro.py, forked from pyro/examples
    * Add a rule for eager evaluation of (p.exp() * f).reduce(ops.add) and a failing test showing numerical instability
    * compute most of expectation in log-space in minipyro
    * handle plates?
* add commit 48fa3f782fb6e80902992e442dd5bf4225998075 Author: Neeraj Pradhan Date: Mon Apr 8 11:17:29 2019 -0700 Add Beta, Dirichlet, and Binomial distributions (#120) commit f74df0aad3217a96e60dbc561147eb151e19065c Author: eb8680 Date: Sat Apr 6 19:28:39 2019 -0700 Add rule for creating Gaussians from Affine inputs (#119) commit 8fa1c324e6ca274dbc2e8626d871b903dcb25146 Author: eb8680 Date: Sat Apr 6 15:21:33 2019 -0700 Add an Affine term to represent multilinear functions of real Variables (#116) commit 627711ef66d5780b2537c920e45e0976d49f4af6 Author: Fritz Obermeyer Date: Sat Apr 6 10:09:42 2019 -0700 Support wrapping PyTorch builtin functions (#118) commit 216626c90bb0702bf2ad1e718bf490ae151ce2de Author: Fritz Obermeyer Date: Sat Apr 6 09:06:51 2019 -0700 Add a torch_tensordot operation (#117) * Add a torch_tensordot op * Fix docstrings to avoid sphinx warning commit fcb4670aba586cff54444c786589a458332a9843 Author: Fritz Obermeyer Date: Fri Apr 5 10:30:03 2019 -0700 Add moment_matching interpretation + SLDS example (#115) * Sketch imm example (no inference yet) * Sketch Gaussian.moment_matching_reduce * Use moment-matching interpretation in IMM example * Rename example to slds.py * Make SLDS params more interesting * Relax assumptions regarding reduced_vars commit fd0dfe4b2f75c64076e9eb62c02cdceed441fae3 Author: Fritz Obermeyer Date: Tue Apr 2 15:04:43 2019 -0700 Add initial sphinx docs (#114) commit 51b54893d7fd1bb72eafc0b4be42bc629d41f560 Author: Fritz Obermeyer Date: Sun Mar 31 23:14:38 2019 -0700 Add a VAE example using the monte_carlo interpretation (#95) * Sketch Monte Carlo interpretation of logaddexp reduction * Use AssociativeOp in patterns * Fix op pattern matcher * Try eager before monte_carlo * Drop ops.sample, ops.marginal * Sketch VAE example using monte carlo interpretation * Refactor, focusing on .sample() and .monte_carlo_logsumexp() methods * Fix vae example * Sketch Tensor.sample() (untested) * Fix cyclic import * Sketch Gaussian.sample() (untested) * Implement Delta.sample() * Sketch Expectation class * Sketch sampler implementations * Delete Expectation in favor of Integrate in a separate PR * Revert .sample() sketch * Update VAE example to use multi-output Functions * Fix reductions in VAE * Sketch support for multiple args in __getitem__ * Fix bugs in getitem_tensor_tensor * Add stronger tests for tensor getitem * Add support for string indexing * Simplify vae example using multi-getitem * Add stub for Integrate * Fix typo * Sketch monte_carlo registration of Gaussian-Gaussian things * Add stubs for Joint integration * Fix typos * Sketch support for multiple samples * Fix test usage of registry * Fix bugs in gaussian integral * Handle scale factors in Funsor.sample() * Use Integrate in test_samplers.py * Fix bug in Integrate; be less clever * Add implementations of gaussian-linear integrals * Add interpretation logging controlled by FUNSOR_DEBUG * Simplify debug printing * Fix lazy reduction for Joint.reduce() * Fix recursion bug * Get univariate Gaussian sampling to mostly work * Fix bug in Tensor.eager_reduce with nontrivial output * Fix output shape broadcasting in Tensor * Fix assert_close in test_samplers.py * Fix cholesky bugs * Fix bug in _trace_mm() * Fixes for examples/vae.py * Remove examples/vae.py * Revert "Remove examples/vae.py" This reverts commit bee75b9f2efc5b09ac62fabe05b146f2b988b847. 
* Use funsor.Lambda in VAE example * Add a Lambda funsor, inverse to getitem * Use lazy substitution rather than Lambda * Add --pdb argument to examples/vae.py * Add function logging and filename logging when FUNSOR_DEBUG=1 * Enable more functions to be logged * Little fixes * Fixes to support vae example * Fix comment * Revert product sample rules * WIP sketch plates * Simplify vae example (still not working) * Sketch Uncurry funsor * Sketch Uncurry-Delta-Lambda pattern * Sketch Joint-Uncurry-Delta rule * Sketch uncurry-distribution test * Change to_funsor second arg from dtype to Domain * Add Funsor.__contains__ * Use Uncurry in VAE example * Fix test_normal_uncurry * Rename Uncurry to Independent * Revert irrelevant changes * Support sampling from mixtures * Update VAE example to use Independent * Get VAE example to start working * Flake8 * Revert nan validation * Add torchvision to setup.py commit 306aca680d86d1761c17b9254204f3b40cc84c86 Author: Fritz Obermeyer Date: Sun Mar 31 22:51:42 2019 -0700 Add Lambda and Independent funsors (#97) * Add a Lambda funsor, inverse to getitem * Sketch Uncurry funsor * Sketch Uncurry-Delta-Lambda pattern * Sketch Joint-Uncurry-Delta rule * Sketch uncurry-distribution test * Change to_funsor second arg from dtype to Domain * Add Funsor.__contains__ * Fix test_normal_uncurry * Rename Uncurry to Independent commit 89313efef9ec96c7ad3a4bcbd05cf72d37697b32 Author: Fritz Obermeyer Date: Sun Mar 31 13:40:43 2019 -0700 Refactor .unscaled_sample() interface (#113) commit 366e9e3728e2fc692fdd9405ab70f01f49acbb8f Author: Fritz Obermeyer Date: Sun Mar 31 13:39:35 2019 -0700 Change to_funsor second arg from dtype to Domain (#112) commit bddcc2c04ba9970ee4a316dcf77582b527c6d89b Author: Fritz Obermeyer Date: Sat Mar 30 18:15:54 2019 -0700 Support event_dim kwarg in pyro.param (#111) * Support event_dim kwarg in pyro.param * Use local-param branch of Pyro * Use pyro dev for CI commit 36e57aeaa67c7df9e944485df34b9e285af05378 Author: Fritz Obermeyer Date: Fri Mar 29 13:00:36 2019 -0700 Add a working minipyro with tests (#100) * Add pristine copy of pyro/contrib/minipyro.py * First pass at fixing minipyro * First pass at transformed distributions * Add test for sampling transformed Gaussian * Add test for renaming a Gaussian variable * Add dist.LogNormal distribution and density test * Add xfailing test for LogNormal sampler * Attempt to fix minipyro.elbo * Implement negation and subtraction ops for R-N derivatives * Fix bug in log_joint.process of sample * Implement reduction along a plate dimension * Implement correct but non-monte-carlo elbo * Fix bugs in Gaussian * Implement plate reductions for Gaussian, Joint * Use Expectation(...) 
in elbo computation * Fix typos in elbo * Add tests for plate reduction * Use Expectation interface by default * Add initial test_minipyro.py * Add more minipyro tests * Add shape assertions to gaussian math * Refactor minipyro * Add an observation to plate test * Add more plates tests * Support sampling in funsor.minipyro commit b9cdbac5a7286131693d065ff681951272ae006e Author: Fritz Obermeyer Date: Fri Mar 29 11:02:00 2019 -0700 Implement plate reductions for Gaussian, Joint (#108) commit 37ce962fc42d021baf0ab17693624fec8b39c5cc Author: Fritz Obermeyer Date: Thu Mar 28 14:43:45 2019 -0700 Make distributions lazy when used Pyro-style (#107) commit 16a2fbb7eed92a5bff16bda2a08eccb97e0412f2 Author: eb8680 Date: Wed Mar 27 18:04:50 2019 -0700 Add Contract to optimizer (#105) commit c339f1a854c02b6c9fa0d9e3063237fd85fdd64c Author: Fritz Obermeyer Date: Wed Mar 27 17:57:22 2019 -0700 Implement negation and subtraction ops for R-N derivatives (#104) commit 830e43076389f806a834f82388935017afc1f62d Author: Fritz Obermeyer Date: Wed Mar 27 14:58:14 2019 -0700 Implement basic transformed distributions (#103) * First pass at transformed distributions * Add test for sampling transformed Gaussian * Add test for renaming a Gaussian variable * Add dist.LogNormal distribution and density test * Add xfailing test for LogNormal sampler commit a54defa1ec916f5d6823ca15dd9c104c2afcdf32 Author: eb8680 Date: Wed Mar 27 14:02:34 2019 -0700 Support for pattern matching with the unification library (#78) commit 364745e8d1208f4dddbca277a9b2631ec7985dfd Author: Fritz Obermeyer Date: Tue Mar 26 19:08:23 2019 -0700 Remove modified version of minipyro (#102) commit ea7feb2e217c5631b9b2c17e659dd2a0aa70b1d4 Author: Fritz Obermeyer Date: Tue Mar 26 13:29:12 2019 -0700 Log function and filename when FUNSOR_DEBUG=1 (#101) * Add function logging and filename logging when FUNSOR_DEBUG=1 * Enable more functions to be logged commit 951630c09a52c8a70f6e088e514079333448097e Author: Fritz Obermeyer Date: Tue Mar 26 00:14:21 2019 -0700 Resurrect lazy Subs funsor (again) (#99) * Sketch Monte Carlo interpretation of logaddexp reduction * Use AssociativeOp in patterns * Fix op pattern matcher * Try eager before monte_carlo * Drop ops.sample, ops.marginal * Sketch VAE example using monte carlo interpretation * Refactor, focusing on .sample() and .monte_carlo_logsumexp() methods * Fix vae example * Sketch Tensor.sample() (untested) * Fix cyclic import * Sketch Gaussian.sample() (untested) * Implement Delta.sample() * Sketch Expectation class * Sketch sampler implementations * Delete Expectation in favor of Integrate in a separate PR * Revert .sample() sketch * Update VAE example to use multi-output Functions * Fix reductions in VAE * Sketch support for multiple args in __getitem__ * Fix bugs in getitem_tensor_tensor * Add stronger tests for tensor getitem * Add support for string indexing * Simplify vae example using multi-getitem * Add stub for Integrate * Fix typo * Sketch monte_carlo registration of Gaussian-Gaussian things * Add stubs for Joint integration * Fix typos * Sketch support for multiple samples * Fix test usage of registry * Fix bugs in gaussian integral * Handle scale factors in Funsor.sample() * Use Integrate in test_samplers.py * Fix bug in Integrate; be less clever * Add implementations of gaussian-linear integrals * Add interpretation logging controlled by FUNSOR_DEBUG * Simplify debug printing * Fix lazy reduction for Joint.reduce() * Fix recursion bug * Get univariate Gaussian sampling to mostly work * Fix bug in 
Tensor.eager_reduce with nontrivial output * Fix output shape broadcasting in Tensor * Fix assert_close in test_samplers.py * Fix cholesky bugs * Fix bug in _trace_mm() * Fixes for examples/vae.py * Remove examples/vae.py * Add docstrings * Resurrect lazy Subs funsor (again) * Fix typo * Allow completely lazy eager_subs method commit 15b0c7366dc913b33ce2b79e911a26773a446353 Author: Fritz Obermeyer Date: Mon Mar 25 17:04:45 2019 -0700 Implement Monte Carlo interpretation of Integrate (#54) * Sketch Monte Carlo interpretation of logaddexp reduction * Use AssociativeOp in patterns * Fix op pattern matcher * Try eager before monte_carlo * Drop ops.sample, ops.marginal * Sketch VAE example using monte carlo interpretation * Refactor, focusing on .sample() and .monte_carlo_logsumexp() methods * Fix vae example * Sketch Tensor.sample() (untested) * Fix cyclic import * Sketch Gaussian.sample() (untested) * Implement Delta.sample() * Sketch Expectation class * Sketch sampler implementations * Delete Expectation in favor of Integrate in a separate PR * Revert .sample() sketch * Update VAE example to use multi-output Functions * Fix reductions in VAE * Sketch support for multiple args in __getitem__ * Fix bugs in getitem_tensor_tensor * Add stronger tests for tensor getitem * Add support for string indexing * Simplify vae example using multi-getitem * Add stub for Integrate * Fix typo * Sketch monte_carlo registration of Gaussian-Gaussian things * Add stubs for Joint integration * Fix typos * Sketch support for multiple samples * Fix test usage of registry * Fix bugs in gaussian integral * Handle scale factors in Funsor.sample() * Use Integrate in test_samplers.py * Fix bug in Integrate; be less clever * Add implementations of gaussian-linear integrals * Add interpretation logging controlled by FUNSOR_DEBUG * Simplify debug printing * Fix lazy reduction for Joint.reduce() * Fix recursion bug * Get univariate Gaussian sampling to mostly work * Fix bug in Tensor.eager_reduce with nontrivial output * Fix output shape broadcasting in Tensor * Fix assert_close in test_samplers.py * Fix cholesky bugs * Fix bug in _trace_mm() * Fixes for examples/vae.py * Remove examples/vae.py * Add docstrings * Updates per review * Revert accidental change commit d2d4c4ab3b8ba39690bd1a9b814881488170d651 Author: Fritz Obermeyer Date: Sat Mar 23 13:44:16 2019 -0700 Fix lazy reduction for Joint.reduce() (#94) * Fix lazy reduction for Joint.reduce() * Fix recursion bug commit 270d1680f7243d3499ed9c113173e8226ff9f747 Author: Fritz Obermeyer Date: Fri Mar 22 16:44:01 2019 -0700 Add interpretation logging controlled by FUNSOR_DEBUG=1 (#93) * Add interpretation logging controlled by FUNSOR_DEBUG * Simplify debug printing * Improve pretty printing of Stack and Joint commit 5d0a291a5cb3f3a638218b14d6ddfadca465beb1 Author: Fritz Obermeyer Date: Fri Mar 22 12:55:34 2019 -0700 Implement advanced indexing in Funsor.__getitem__() (#88) * Sketch support for multiple args in __getitem__ * Fix bugs in getitem_tensor_tensor * Add stronger tests for tensor getitem * Add support for string indexing commit 028f640c0d2476de7c5ea043133dd13ad92f13c2 Author: Fritz Obermeyer Date: Thu Mar 21 15:45:29 2019 -0700 Add isort command to Makefile (#92) commit 7bbfb3ae9ea9844a89f72c787079dba96468b76e Author: Fritz Obermeyer Date: Thu Mar 21 13:39:40 2019 -0700 Sketched a simple PCFG example (#87) commit 3fe2581e28fc3e1e3b82cd63b5a16e9c4dd8c570 Author: Fritz Obermeyer Date: Thu Mar 21 13:39:14 2019 -0700 Refactor contract dependencies (#91) * Refactor 
contract dependencies * Fix optimize(Contract, Tensor, Tensor) * Fix dtype computation * Address review comment * Fix typo commit 7cca0f8c7f107d46b8d2c8a9645a5fb4213fa3fc Author: eb8680 Date: Wed Mar 20 12:22:44 2019 -0700 Implement Contract term (#77) commit 876070fc6e9b82d71cb70fd6669fe5ffa68f38af Author: Fritz Obermeyer Date: Wed Mar 20 00:19:26 2019 -0700 Add a to_data() helper (#84) * Add a to_nonfunsor() helper * Fix typo; add more tests * Rename to_nonfunsor to to_data commit 05649aea3f76a0acdbdfb48b530cf429b8cccec4 Author: Fritz Obermeyer Date: Tue Mar 19 20:50:04 2019 -0700 Support torch functions that return nested tuples of tensors (#82) commit 2b8c0e5611b5a120ca8559bfd8adac799f1cae95 Author: eb8680 Date: Tue Mar 19 14:39:51 2019 -0700 Rename funsor.contract module to funsor.sum_product (#81) commit 0139a70193cb981da1e65298f752689c708eb7da Author: Fritz Obermeyer Date: Fri Mar 15 15:58:11 2019 -0700 Implement monte carlo .sample() methods (#75) commit b5ea615746e20bf445c460615988e40ad4c47007 Author: Fritz Obermeyer Date: Fri Mar 15 15:40:08 2019 -0700 Implement sequential interpretation (#76) commit 5fa7fa63b5fdd49122b376a8e6b9658f2d34b0b2 Author: Fritz Obermeyer Date: Wed Mar 13 11:29:21 2019 -0700 Implement a Joint normal form funsor (#69) * Add a simple delta distribution * Add tests for nontrivial event dim * Simplify unit test * Sketch a general Delta funsor class * Simplify to binding a single name in Delta * Add some tests for Delta * Add test for ground substitution * Add tests for reduction * Add test for conversion from dist.Delta to Delta * Sketch JointNormalForm funsor * Settle on Joint interface * Add more + handling * Remove .log_density field from Delta funsor * Drop handling of .log_density from Joint * Add logic promoting various Binary(-,-) to Joint * Revert "Remove .log_density field from Delta funsor" This reverts commit 897f5237b65c402eee79cc105a81590f8a3f81c8. * Revert "Drop handling of .log_density from Joint" This reverts commit a7d008244c070de161415c8dcf0eba790fd98f10. * Simplify Gaussian funsor * WIP Refactor Joint patterns * Get Gaussian working with Joint * Add a smoke test for Joint * Add test for reduction * Make xfail more targeted * Update docstring on Joint * Remove unnecessary handling of Binary(ops.add,...) commit 9dbb231cf2c21810039c36ee352a837bcdc0a739 Author: Neeraj Pradhan Date: Tue Mar 12 18:46:00 2019 -0700 Separate out ops implementations based on backend (#74) * Separate out ops implementations based on backend * rebase with master * fix failing tests * fix error in safediv; get tests working * use object for binary ops * address comment * address comments commit b130f1715976bdd32f1e1b3d78228b0eced31939 Author: Fritz Obermeyer Date: Mon Mar 11 16:21:14 2019 -0700 Implement a general Delta funsor (#65) * Add a simple delta distribution * Add tests for nontrivial event dim * Simplify unit test * Sketch a general Delta funsor class * Simplify to binding a single name in Delta * Add some tests for Delta * Add test for ground substitution * Add tests for reduction * Add test for conversion from dist.Delta to Delta * Remove .log_density field from Delta funsor * Revert "Remove .log_density field from Delta funsor" This reverts commit 897f5237b65c402eee79cc105a81590f8a3f81c8. 
* Fix failing test * Add tests for log_density * Update __init__.py commit 34a1d9c45feaed2f2e9218a273efbcbb34bc7599 Author: Fritz Obermeyer Date: Mon Mar 11 15:53:09 2019 -0700 Implement MultivariateNormal distribution (#73) * Add multivariate distribution * Reduce test tolerance to fix CI build commit 8108b788b23bb05e491f3cb66c56c83bfee1b250 Author: Neeraj Pradhan Date: Mon Mar 11 12:48:37 2019 -0700 Add numpy backend for funsor (#58) commit 89540d6cb82a1be47a1a100bf1742f8a5cb9e1f8 Author: eb8680 Date: Mon Mar 11 12:43:22 2019 -0700 Add prototype funsor.adjoint module (#64) commit 27b4a767f0a7a24df93859e353c86b2beacd5700 Author: Fritz Obermeyer Date: Sun Mar 10 19:38:57 2019 -0700 Make reduction methods operate over events (#66) * Make reduction ops operate over events * Update existing tests * Update examples * Update README.md * Add tests for event reduction ops * Fix shape bug * Update README.md commit db497c465366b0e3dfe7d320f1336021f12e1912 Author: eb8680 Date: Sat Mar 9 19:43:27 2019 -0800 Add lazy option to HMM examples (#67) commit e46c2df51941f22b9d2f41ab06cf35ad71e41707 Author: Fritz Obermeyer Date: Sat Mar 9 01:56:27 2019 -0800 Refactor contract (#60) * Sketch funsor.contract module * Add unit test for _partition * Add tests for partial_sum_product() * Fix test in python 3 commit 97153c744ee66090aa6d34d1320d727b8e159116 Author: Fritz Obermeyer Date: Fri Mar 8 16:37:47 2019 -0800 Match Op type rather than object (#63) commit d96018265dd46f74b96d37899466ad2ff9d553b3 Author: Fritz Obermeyer Date: Fri Mar 8 15:56:34 2019 -0800 Add type check in Funsor.__call__() (#62) commit 72fedac169e14fa09290393f42a9b9273b09fb9f Author: Fritz Obermeyer Date: Fri Mar 8 15:14:15 2019 -0800 Update minipyro storyboard to correctly handle plates (#56) * Update minipyro storyboard to correctly handle plates * Revert changes to einsum.py commit 52f4366df6f1de64047d98ead7a174181fe816da Author: Fritz Obermeyer Date: Fri Mar 8 14:57:21 2019 -0800 Implement Op wrapper class to enable pattern matching (#55) commit e64a73045401882758c1767d37e21cb26f569316 Author: Fritz Obermeyer Date: Wed Mar 6 17:41:20 2019 -0800 Add a simple Delta distribution (#49) * Add a simple delta distribution * Add tests for nontrivial event dim * Simplify unit test commit c47a17f9db2f97d9cda8fc2155cafc73058977eb Author: eb8680 Date: Wed Mar 6 09:12:29 2019 -0800 Fix path evaluation order in optimizer (#47) commit 4537e903b4788c6c01d2fd3884485fc7e8dc74e4 Author: eb8680 Date: Tue Mar 5 18:27:36 2019 -0800 Add plated einsum implementation (#46) commit 448c1fa069337e5ec0a140cb843dd041010935f5 Author: Fritz Obermeyer Date: Mon Mar 4 17:12:54 2019 -0800 Implement general Gaussian funsor (#37) * WIP sketch Gaussian funsor * Partially implement binary_gaussian_gaussian * Implement marginalization along a dimension; add smoke test * Add more comments * Implement binary_gaussian_gaussian * Sketch to_affine() and Affine funsor * Add xfailing test for to_affine() * Sketch more of eager_subs * Remove affine stuff * WIP fix align_gaussian() using align_tensor() * Refactor and simplify align_tensor() * Fix eager_subs * Get smoke tests working * Switch from scale_tril to precision representation * Implement basic Normal -> Gaussian transform * Rename normal conversions * Fix filling in of defaults for distribution classes * Add test for binary_gaussian_number * Add test for binary_gaussian_tensor * Add xfailing test for gaussian + gaussian * Add more tests * Add test of Normal vs Gaussian * Fix math error in Gaussian .logsumexp() * Fix bugs in 
Gaussian+Gaussian, align_gaussian() * Fix kalman_filter.py, add to make test * Add more distribution tests commit d33341808f795011796f63be013c0bbdb0c96df0 Author: eb8680 Date: Sun Mar 3 10:35:23 2019 -0800 Reinstate opt_einsum-based optimizer (#40) --- .travis.yml | 8 +- Makefile | 30 +- README.md | 10 +- docs/Makefile | 20 + docs/make.bat | 36 ++ docs/requirements.txt | 6 + docs/source/adjoint.rst | 7 + docs/source/conf.py | 199 +++++++++ docs/source/distributions.rst | 7 + docs/source/domains.rst | 7 + docs/source/einsum.rst | 7 + docs/source/funsors.rst | 58 +++ docs/source/index.rst | 36 ++ docs/source/interpretations.rst | 18 + docs/source/minipyro.rst | 7 + docs/source/ops.rst | 6 + docs/source/optimizer.rst | 7 + examples/discrete_hmm.py | 22 +- examples/kalman_filter.py | 57 ++- examples/minipyro.py | 131 +++--- examples/pcfg.py | 63 +++ examples/slds.py | 81 ++++ examples/ss_vae_delayed.py | 8 +- examples/vae.py | 110 +++++ funsor/__init__.py | 44 +- funsor/adjoint.py | 102 +++++ funsor/affine.py | 194 ++++++++ funsor/contract.py | 74 ++++ funsor/delta.py | 199 +++++++++ funsor/distributions.py | 430 ++++++++++++++++-- funsor/domains.py | 16 +- funsor/einsum.py | 132 ++++++ funsor/gaussian.py | 615 ++++++++++++++++++++++++++ funsor/handlers.py | 144 ------ funsor/integrate.py | 84 ++++ funsor/interpreter.py | 277 ++++++++++-- funsor/joint.py | 471 ++++++++++++++++++++ funsor/minipyro.py | 760 ++++++++++++++++++------------- funsor/montecarlo.py | 58 +++ funsor/numpy.py | 291 ++++++++++++ funsor/ops.py | 244 ++++++++-- funsor/optimizer.py | 293 ++++++++++++ funsor/pattern.py | 69 +++ funsor/six.py | 30 +- funsor/sum_product.py | 98 ++++ funsor/terms.py | 762 ++++++++++++++++++++++++++------ funsor/testing.py | 201 ++++++++- funsor/torch.py | 590 +++++++++++++++++++++---- setup.cfg | 1 + setup.py | 15 +- test/conftest.py | 1 + test/test_adjoint.py | 164 +++++++ test/test_affine.py | 77 ++++ test/test_alpha_conversion.py | 98 ++++ test/test_contract.py | 83 ++++ test/test_delta.py | 69 +++ test/test_distributions.py | 425 +++++++++++++++++- test/test_einsum.py | 159 ++++--- test/test_gaussian.py | 427 ++++++++++++++++++ test/test_joint.py | 286 ++++++++++++ test/test_minipyro.py | 531 ++++++++++++++++++++++ test/test_numpy.py | 220 +++++++++ test/test_optimizer.py | 120 +++++ test/test_pattern.py | 42 ++ test/test_samplers.py | 276 ++++++++++++ test/test_sum_product.py | 81 ++++ test/test_terms.py | 82 +++- test/test_torch.py | 336 ++++++++++++-- 68 files changed, 9535 insertions(+), 1077 deletions(-) create mode 100644 docs/Makefile create mode 100644 docs/make.bat create mode 100644 docs/requirements.txt create mode 100644 docs/source/adjoint.rst create mode 100644 docs/source/conf.py create mode 100644 docs/source/distributions.rst create mode 100644 docs/source/domains.rst create mode 100644 docs/source/einsum.rst create mode 100644 docs/source/funsors.rst create mode 100644 docs/source/index.rst create mode 100644 docs/source/interpretations.rst create mode 100644 docs/source/minipyro.rst create mode 100644 docs/source/ops.rst create mode 100644 docs/source/optimizer.rst create mode 100644 examples/pcfg.py create mode 100644 examples/slds.py create mode 100644 examples/vae.py create mode 100644 funsor/adjoint.py create mode 100644 funsor/affine.py create mode 100644 funsor/contract.py create mode 100644 funsor/delta.py create mode 100644 funsor/einsum.py create mode 100644 funsor/gaussian.py delete mode 100644 funsor/handlers.py create mode 100644 funsor/integrate.py create mode 
100644 funsor/joint.py create mode 100644 funsor/montecarlo.py create mode 100644 funsor/numpy.py create mode 100644 funsor/optimizer.py create mode 100644 funsor/pattern.py create mode 100644 funsor/sum_product.py create mode 100644 test/test_adjoint.py create mode 100644 test/test_affine.py create mode 100644 test/test_alpha_conversion.py create mode 100644 test/test_contract.py create mode 100644 test/test_delta.py create mode 100644 test/test_gaussian.py create mode 100644 test/test_joint.py create mode 100644 test/test_minipyro.py create mode 100644 test/test_numpy.py create mode 100644 test/test_optimizer.py create mode 100644 test/test_pattern.py create mode 100644 test/test_samplers.py create mode 100644 test/test_sum_product.py diff --git a/.travis.yml b/.travis.yml index 47c25007c..a090b88f6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,10 +13,14 @@ cache: install: - pip install -U pip - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then - pip install https://download.pytorch.org/whl/cpu/torch-1.0.1.post2-cp27-cp27mu-linux_x86_64.whl; + pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp27-cp27mu-linux_x86_64.whl; else - pip install https://download.pytorch.org/whl/cpu/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl; + pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp36-cp36m-linux_x86_64.whl; fi + + # Keep track of Pyro dev branch + - pip install https://github.com/pyro-ppl/pyro/archive/dev.zip + - pip install .[test] - pip freeze diff --git a/Makefile b/Makefile index 228d2118b..d3c907936 100644 --- a/Makefile +++ b/Makefile @@ -1,20 +1,42 @@ -all: test +.PHONY: all install docs lint format test clean FORCE + +all: docs test install: pip install -e .[dev] +docs: FORCE + $(MAKE) -C docs html + lint: FORCE flake8 +format: FORCE + isort -y + test: lint FORCE pytest -v test + FUNSOR_DEBUG=1 pytest -v test/test_gaussian.py + FUNSOR_USE_TCO=1 pytest -v test/test_terms.py + FUNSOR_USE_TCO=1 pytest -v test/test_einsum.py + FUNSOR_USE_TCO=1 pytest -v test/test_numpy.py python examples/discrete_hmm.py -n 2 - @#python examples/kalman_filter.py --xfail-if-not-implemented + python examples/discrete_hmm.py -n 2 -t 50 --lazy + FUNSOR_USE_TCO=1 python examples/discrete_hmm.py -n 1 -t 50 --lazy + FUNSOR_USE_TCO=1 python examples/discrete_hmm.py -n 1 -t 500 --lazy + python examples/kalman_filter.py -n 2 + python examples/kalman_filter.py -n 2 -t 50 --lazy + FUNSOR_USE_TCO=1 python examples/kalman_filter.py -n 1 -t 50 --lazy + FUNSOR_USE_TCO=1 python examples/kalman_filter.py -n 1 -t 500 --lazy + python examples/minipyro.py + python examples/minipyro.py --jit + python examples/slds.py -n 2 -t 50 + python examples/pcfg.py --size 3 + python examples/vae.py --smoke-test @#python examples/ss_vae_delayed.py --xfail-if-not-implemented - @#python examples/minipyro.py --xfail-if-not-implemented @echo PASS clean: FORCE - git clean -dfx -e pyro-egg.info + git clean -dfx -e funsor-egg.info FORCE: diff --git a/README.md b/README.md index 0eecd50af..790bc4512 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,8 @@ -# Funsor ![unstable](https://img.shields.io/badge/status-unstable-red.svg) +![unstable](https://img.shields.io/badge/status-unstable-red.svg) +[![Build Status](https://travis-ci.com/pyro-ppl/funsor.svg?branch=master)](https://travis-ci.com/pyro-ppl/funsor) +[![Documentation Status](https://readthedocs.org/projects/funsor/badge)](http://funsor.readthedocs.io) + +# Funsor Functional analysis + tensors + symbolic algebra. 
@@ -59,8 +63,8 @@ def pyro_sample(name, dist, obs=None): return value # ...later during inference... -log_prob = trace_log_prob.logsumexp() # collapses delayed variables -loss = -funsor.eval(log_prob) # performs variable elimination +log_prob = trace_log_prob.reduce(logaddexp) # collapses delayed variables +loss = -funsor.eval(log_prob) # performs variable elimination ``` See [examples/minipyro.py](examples/minipyro.py) for a more complete example. diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 000000000..45893adae --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SPHINXPROJ = funsor +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 000000000..f6cee192e --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,36 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build +set SPHINXPROJ=funsor + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% + +:end +popd diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 000000000..11bfc5cb9 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,6 @@ +contextlib2 +multipledispatch +numpy>=1.7 +opt_einsum>=2.3.2 +six>=1.10.0 +unification diff --git a/docs/source/adjoint.rst b/docs/source/adjoint.rst new file mode 100644 index 000000000..c4f0d61d5 --- /dev/null +++ b/docs/source/adjoint.rst @@ -0,0 +1,7 @@ +Adjoint Algorithms +------------------ +.. automodule:: funsor.adjoint + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 000000000..42c22f761 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,199 @@ +import os +import sys + +import sphinx_rtd_theme + +# import pkg_resources + +# -*- coding: utf-8 -*- +# +# Configuration file for the Sphinx documentation builder. +# +# This file does only contain a selection of the most common options. For a +# full list see the documentation: +# http://www.sphinx-doc.org/en/master/config + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. 
If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +sys.path.insert(0, os.path.abspath('../..')) + +# -- Project information ----------------------------------------------------- + +project = u'Funsor' +copyright = u'2019, Uber Technologies, Inc' +author = u'Uber AI Labs' + +# The short X.Y version +version = u'0.0' +# The full version, including alpha/beta/rc tags +release = u'0.0' + + +# -- General configuration --------------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.doctest', + 'sphinx.ext.intersphinx', + 'sphinx.ext.mathjax', + 'sphinx.ext.viewcode', +] + +# Disable documentation inheritance so as to avoid inheriting +# docstrings in a different format, e.g. when the parent class +# is a PyTorch class. + +autodoc_inherit_docstrings = False + +autodoc_default_options = { + 'member-order': 'bysource', + 'show-inheritance': True, + 'special-members': True, + 'undoc-members': True, + 'exclude-members': '__dict__,__module__,__weakref__', +} + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +# source_suffix = ['.rst', '.md'] +source_suffix = '.rst' + +# The master toctree document. +master_doc = 'index' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path . +exclude_patterns = [] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + + +# do not prepend module name to functions +add_module_names = False + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = "sphinx_rtd_theme" +html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +# html_theme_options = {} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# Custom sidebar templates, must be a dictionary that maps document names +# to template names. +# +# The default sidebars (for documents that don't match any pattern) are +# defined by theme itself. Builtin themes are using these templates by +# default: ``['localtoc.html', 'relations.html', 'sourcelink.html', +# 'searchbox.html']``. 
+# +# html_sidebars = {} + + +# -- Options for HTMLHelp output --------------------------------------------- + +# Output file base name for HTML help builder. +htmlhelp_basename = 'funsordoc' + + +# -- Options for LaTeX output ------------------------------------------------ + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, 'Funsor.tex', u'Funsor Documentation', u'Uber AI Labs', 'manual'), +] + +# -- Options for manual page output ------------------------------------------ + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + (master_doc, 'Funsor', u'Funsor Documentation', + [author], 1) +] + +# -- Options for Texinfo output ---------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'Funsor', u'Funsor Documentation', + author, 'Funsor', 'Functional analysis + tensors + symbolic algebra.', + 'Miscellaneous'), +] + + +# -- Extension configuration ------------------------------------------------- + +# -- Options for intersphinx extension --------------------------------------- + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = { + 'python': ('https://docs.python.org/3/', None), + 'numpy': ('http://docs.scipy.org/doc/numpy/', None), + 'torch': ('http://pytorch.org/docs/master/', None), + 'pyro': ('http://docs.pyro.ai/en/stable/', None), + 'opt_einsum': ('https://optimized-einsum.readthedocs.io/en/stable/', None) +} + +# @jpchen's hack to get rtd builder to install latest pytorch +if 'READTHEDOCS' in os.environ: + os.system('pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp37-cp37m-linux_x86_64.whl') + # pyro needs to be installed after torch so pyro doesnt install the bloated torch-1.0 wheel + os.system('pip install pyro-ppl') diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst new file mode 100644 index 000000000..7e7ef1d8d --- /dev/null +++ b/docs/source/distributions.rst @@ -0,0 +1,7 @@ +Distributions +------------- +.. automodule:: funsor.distributions + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource diff --git a/docs/source/domains.rst b/docs/source/domains.rst new file mode 100644 index 000000000..580ca825d --- /dev/null +++ b/docs/source/domains.rst @@ -0,0 +1,7 @@ +Domains +------- +.. automodule:: funsor.domains + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource diff --git a/docs/source/einsum.rst b/docs/source/einsum.rst new file mode 100644 index 000000000..519fcf6e1 --- /dev/null +++ b/docs/source/einsum.rst @@ -0,0 +1,7 @@ +Einsum Interface +---------------- +.. 
automodule:: funsor.einsum + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource diff --git a/docs/source/funsors.rst b/docs/source/funsors.rst new file mode 100644 index 000000000..0e3e4480b --- /dev/null +++ b/docs/source/funsors.rst @@ -0,0 +1,58 @@ +Funsors +======= + +Basic Funsors +------------- +.. automodule:: funsor.terms + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + +Delta +----- +.. automodule:: funsor.delta + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + +PyTorch +------- +.. automodule:: funsor.torch + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + +NumPy +----- +.. automodule:: funsor.numpy + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + +Gaussian +-------- +.. automodule:: funsor.gaussian + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + +Contract +-------- +.. automodule:: funsor.contract + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + +Integrate +--------- +.. automodule:: funsor.integrate + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 000000000..1578c092d --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,36 @@ +.. funsor documentation master file, created by + sphinx-quickstart on Tue Apr 2 13:33:44 2019. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to Funsor's documentation! +================================== + +.. toctree:: + :glob: + :maxdepth: 2 + :caption: Funsor Core: + + funsors + domains + ops + interpretations + optimizer + adjoint + +.. toctree:: + :glob: + :maxdepth: 2 + :caption: Interfaces: + + distributions + minipyro + einsum + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/source/interpretations.rst b/docs/source/interpretations.rst new file mode 100644 index 000000000..e7331f376 --- /dev/null +++ b/docs/source/interpretations.rst @@ -0,0 +1,18 @@ +Interpretations +=============== + +Interpreter +----------- +.. automodule:: funsor.interpreter + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + +Monte Carlo +----------- +.. automodule:: funsor.montecarlo + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource diff --git a/docs/source/minipyro.rst b/docs/source/minipyro.rst new file mode 100644 index 000000000..85ff184b4 --- /dev/null +++ b/docs/source/minipyro.rst @@ -0,0 +1,7 @@ +Mini-Pyro Interface +------------------- +.. automodule:: funsor.minipyro + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource diff --git a/docs/source/ops.rst b/docs/source/ops.rst new file mode 100644 index 000000000..6e4e2578b --- /dev/null +++ b/docs/source/ops.rst @@ -0,0 +1,6 @@ +Operations +---------- +.. automodule:: funsor.ops + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/optimizer.rst b/docs/source/optimizer.rst new file mode 100644 index 000000000..89a16b217 --- /dev/null +++ b/docs/source/optimizer.rst @@ -0,0 +1,7 @@ +Optimizer +--------- +.. 
automodule:: funsor.optimizer + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource diff --git a/examples/discrete_hmm.py b/examples/discrete_hmm.py index 51fed837a..1f2593f89 100644 --- a/examples/discrete_hmm.py +++ b/examples/discrete_hmm.py @@ -7,6 +7,10 @@ import funsor import funsor.distributions as dist +import funsor.ops as ops +from funsor.interpreter import interpretation, reinterpret +from funsor.optimizer import apply_optimizer +from funsor.terms import lazy def main(args): @@ -37,21 +41,25 @@ def model(data): x_curr = funsor.Variable('x_{}'.format(t), funsor.bint(args.hidden_dim)) log_prob += trans(prev=x_prev, value=x_curr) - if isinstance(x_prev, funsor.Variable): - log_prob = log_prob.logsumexp(x_prev.name) + if not args.lazy and isinstance(x_prev, funsor.Variable): + log_prob = log_prob.reduce(ops.logaddexp, x_prev.name) - log_prob += emit(latent=x_curr, value=y) + log_prob += emit(latent=x_curr, value=funsor.Tensor(y, dtype=2)) - log_prob = log_prob.logsumexp() + log_prob = log_prob.reduce(ops.logaddexp) return log_prob # Train model parameters. - print('---- training ----') data = torch.ones(args.time_steps, dtype=torch.long) optim = torch.optim.Adam(params, lr=args.learning_rate) for step in range(args.train_steps): optim.zero_grad() - log_prob = model(data) + if args.lazy: + with interpretation(lazy): + log_prob = apply_optimizer(model(data)) + log_prob = reinterpret(log_prob) + else: + log_prob = model(data) assert not log_prob.inputs, 'free variables remain' loss = -log_prob.data loss.backward() @@ -64,7 +72,7 @@ def model(data): parser.add_argument("-n", "--train-steps", default=101, type=int) parser.add_argument("-lr", "--learning-rate", default=0.05, type=float) parser.add_argument("-d", "--hidden-dim", default=2, type=int) - parser.add_argument("--eager", action='store_true') + parser.add_argument("--lazy", action='store_true') parser.add_argument("--filter", action='store_true') parser.add_argument("--xfail-if-not-implemented", action='store_true') args = parser.parse_args() diff --git a/examples/kalman_filter.py b/examples/kalman_filter.py index 9f6464502..a46778a98 100644 --- a/examples/kalman_filter.py +++ b/examples/kalman_filter.py @@ -6,6 +6,10 @@ import funsor import funsor.distributions as dist +import funsor.ops as ops +from funsor.interpreter import interpretation, reinterpret +from funsor.optimizer import apply_optimizer +from funsor.terms import lazy def main(args): @@ -16,43 +20,45 @@ def main(args): # A Gaussian HMM model. def model(data): - prob = 1. + log_prob = funsor.to_funsor(0.) - x_curr = 0. + x_curr = funsor.Tensor(torch.tensor(0.)) for t, y in enumerate(data): x_prev = x_curr # A delayed sample statement. x_curr = funsor.Variable('x_{}'.format(t), funsor.reals()) - prob *= dist.Normal(loc=x_prev, scale=trans_noise, value=x_curr) + log_prob += dist.Normal(1 + x_prev / 2., trans_noise, value=x_curr) - # If we want, we can immediately marginalize out previous sample sites. - prob = prob.sum('x_{}'.format(t - 1)) - # TODO prob = Clever(funsor.eval)(prob) + # Optionally marginalize out the previous state. + if t > 0 and not args.lazy: + log_prob = log_prob.reduce(ops.logaddexp, x_prev.name) # An observe statement. - prob *= dist.Normal(loc=x_curr, scale=emit_noise, value=y) + log_prob += dist.Normal(0.5 + 3 * x_curr, emit_noise, value=y) - return prob + # Marginalize out all remaining delayed variables. + log_prob = log_prob.reduce(ops.logaddexp) + return log_prob # Train model parameters. 
- print('---- training ----') + torch.manual_seed(0) data = torch.randn(args.time_steps) optim = torch.optim.Adam(params, lr=args.learning_rate) for step in range(args.train_steps): optim.zero_grad() - prob = model(data) - # TODO prob = Clever(funsor.eval)(prob) - loss = -prob.sum().log() # Integrates out delayed variables. + if args.lazy: + with interpretation(lazy): + log_prob = apply_optimizer(model(data)) + log_prob = reinterpret(log_prob) + else: + log_prob = model(data) + assert not log_prob.inputs, 'free variables remain' + loss = -log_prob.data loss.backward() optim.step() - - # Serve by drawing a posterior sample. - print('---- serving ----') - prob = model(data) - prob = funsor.eval(prob.sum()) # Forward filter. - samples = prob.backward(prob.log()) # Bakward sample. - print(samples) + if args.verbose and step % 10 == 0: + print('step {} loss = {}'.format(step, loss.item())) if __name__ == '__main__': @@ -60,15 +66,8 @@ def model(data): parser.add_argument("-t", "--time-steps", default=10, type=int) parser.add_argument("-n", "--train-steps", default=101, type=int) parser.add_argument("-lr", "--learning-rate", default=0.05, type=float) - parser.add_argument("--eager", action='store_true') + parser.add_argument("--lazy", action='store_true') parser.add_argument("--filter", action='store_true') - parser.add_argument("--xfail-if-not-implemented", action='store_true') + parser.add_argument("-v", "--verbose", action="store_true") args = parser.parse_args() - - if args.xfail_if_not_implemented: - try: - main(args) - except NotImplementedError: - print('XFAIL') - else: - main(args) + main(args) diff --git a/examples/minipyro.py b/examples/minipyro.py index 60470a995..97f48bcb5 100644 --- a/examples/minipyro.py +++ b/examples/minipyro.py @@ -3,85 +3,70 @@ import argparse import torch +from pyro.generic import distributions as dist +from pyro.generic import infer, optim, pyro, pyro_backend +from torch.distributions import constraints -import funsor -import funsor.distributions as dist -import funsor.minipyro as pyro +from funsor.interpreter import interpretation +from funsor.montecarlo import monte_carlo def main(args): - """ - minipyro version of Gaussian HMM example - """ - - # a Gaussian HMM + # Define a basic model with a single Normal latent random variable `loc` + # and a batch of Normally distributed observations. def model(data): - - trans_noise = pyro.param(name="trans_noise") - emit_noise = pyro.param(name="emit_noise") - - x_curr = 0. 
- for t, y in enumerate(data): - x_prev = x_curr - - # a sample statement - x_curr = pyro.sample( - dist.Normal(loc=x_prev, scale=trans_noise), - name='x_{}'.format(t)) - - # an observe statement - pyro.sample( - dist.Normal(loc=x_curr, scale=emit_noise), - obs=y, - name='y_{}'.format(t)) - - return x_curr - - trans_noise = pyro.param(torch.tensor(0.1, requires_grad=True), name="trans_noise") # noqa: F841 - emit_noise = pyro.param(torch.tensor(0.5, requires_grad=True), name="emit_noise") # noqa: F841 - data = torch.randn(args.time_steps) - - params = [node["value"] for node in pyro.trace(model).get_trace(data).values() - if node["type"] == "param"] - - # training loop - print('---- training ----') - optim = torch.optim.Adam(params, lr=args.learning_rate) - for step in range(args.train_steps): - optim.zero_grad() - - tr = pyro.trace(pyro.deferred(model)).get_trace(data) - - log_prob = sum(node["fn"](node["value"]) - for node in tr.values() - if node["type"] == "sample") - - # integrate out deferred variables - log_prob = log_prob.logsumexp() - - loss = -funsor.eval(log_prob) # does all the work - - if step % 10 == 0: - print('step {} loss = {}'.format(step, loss.item())) - loss.backward() - optim.step() + loc = pyro.sample("loc", dist.Normal(0., 1.)) + with pyro.plate("data", len(data), dim=-1): + pyro.sample("obs", dist.Normal(loc, 1.), obs=data) + + # Define a guide (i.e. variational distribution) with a Normal + # distribution over the latent random variable `loc`. + def guide(data): + guide_loc = pyro.param("guide_loc", torch.tensor(0.)) + guide_scale = pyro.param("guide_scale", torch.tensor(1.), + constraint=constraints.positive) + pyro.sample("loc", dist.Normal(guide_loc, guide_scale)) + + # Generate some data. + torch.manual_seed(0) + data = torch.randn(100) + 3.0 + + # Because the API in minipyro matches that of Pyro proper, + # training code works with generic Pyro implementations. + with pyro_backend(args.backend), interpretation(monte_carlo): + # Construct an SVI object so we can do variational inference on our + # model/guide pair. + Elbo = infer.JitTrace_ELBO if args.jit else infer.Trace_ELBO + elbo = Elbo() + adam = optim.Adam({"lr": args.learning_rate}) + svi = infer.SVI(model, guide, adam, elbo) + + # Basic training loop + pyro.get_param_store().clear() + for step in range(args.num_steps): + loss = svi.step(data) + if args.verbose and step % 100 == 0: + print("step {} loss = {}".format(step, loss)) + + # Report the final values of the variational parameters + # in the guide after training. + if args.verbose: + for name in pyro.get_param_store(): + value = pyro.param(name).data + print("{} = {}".format(name, value.detach().cpu().numpy())) + + # For this simple (conjugate) model we know the exact posterior. In + # particular we know that the variational distribution should be + # centered near 3.0. So let's check this explicitly. 
+ assert (pyro.param("guide_loc") - 3.0).abs() < 0.1 if __name__ == "__main__": - - parser = argparse.ArgumentParser(description="Gaussian HMM example") - parser.add_argument("-t", "--time-steps", default=10, type=int) - parser.add_argument("-n", "--train-steps", default=101, type=int) - parser.add_argument("-lr", "--learning-rate", default=0.05, type=float) - parser.add_argument("--eager", action='store_true') - parser.add_argument("--filter", action='store_true') - parser.add_argument("--xfail-if-not-implemented", action='store_true') + parser = argparse.ArgumentParser(description="Minipyro demo") + parser.add_argument("-b", "--backend", default="funsor") + parser.add_argument("-n", "--num-steps", default=1001, type=int) + parser.add_argument("-lr", "--learning-rate", default=0.02, type=float) + parser.add_argument("--jit", action="store_true") + parser.add_argument("-v", "--verbose", action="store_true") args = parser.parse_args() - - if args.xfail_if_not_implemented: - try: - main(args) - except NotImplementedError: - print('XFAIL') - else: - main(args) + main(args) diff --git a/examples/pcfg.py b/examples/pcfg.py new file mode 100644 index 000000000..332fdb189 --- /dev/null +++ b/examples/pcfg.py @@ -0,0 +1,63 @@ +from __future__ import absolute_import, division, print_function + +import argparse +import math +from collections import OrderedDict + +import torch + +import funsor.ops as ops +from funsor.delta import Delta +from funsor.domains import bint +from funsor.terms import Number, Stack, Variable +from funsor.torch import Tensor + + +def Uniform(components): + components = tuple(components) + size = len(components) + if size == 1: + return components[0] + var = Variable('v', bint(size)) + return (Stack(components, var.name).reduce(ops.logaddexp, var.name) + - math.log(size)) + + +# @of_shape(*([bint(2)] * size)) +def model(size, position=0): + if size == 1: + name = str(position) + return Uniform((Delta(name, Number(0, 2)), + Delta(name, Number(1, 2)))) + return Uniform(model(t, position) + + model(size - t, t + position) + for t in range(1, size)) + + +def main(args): + torch.manual_seed(args.seed) + + print_ = print if args.verbose else lambda msg: None + print_('Data:') + data = torch.distributions.Categorical(torch.ones(2)).sample((args.size,)) + assert data.shape == (args.size,) + data = Tensor(data, OrderedDict(i=bint(args.size)), dtype=2) + print_(data) + + print_('Model:') + m = model(args.size) + print_(m.pretty()) + + print_('Eager log_prob:') + obs = {str(i): data(i) for i in range(args.size)} + log_prob = m(**obs) + print_(log_prob) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="PCFG example") + parser.add_argument("-s", "--size", default=3, type=int) + parser.add_argument("--seed", default=0, type=int) + parser.add_argument("-v", "--verbose", action='store_true') + args = parser.parse_args() + main(args) diff --git a/examples/slds.py b/examples/slds.py new file mode 100644 index 000000000..e26d275af --- /dev/null +++ b/examples/slds.py @@ -0,0 +1,81 @@ +from __future__ import absolute_import, division, print_function + +import argparse + +import torch + +import funsor +import funsor.distributions as dist +import funsor.ops as ops + + +def main(args): + # Declare parameters. 
+ trans_probs = funsor.Tensor(torch.tensor([[0.9, 0.1], + [0.1, 0.9]], requires_grad=True)) + trans_noise = funsor.Tensor(torch.tensor([ + 0.1, # low noise component + 1.0, # high noisy component + ], requires_grad=True)) + emit_noise = funsor.Tensor(torch.tensor(0.5, requires_grad=True)) + params = [trans_probs.data, + trans_noise.data, + emit_noise.data] + + # A Gaussian HMM model. + @funsor.interpreter.interpretation(funsor.terms.moment_matching) + def model(data): + log_prob = funsor.Number(0.) + + # s is the discrete latent state, + # x is the continuous latent state, + # y is the observed state. + s_curr = funsor.Tensor(torch.tensor(0), dtype=2) + x_curr = funsor.Tensor(torch.tensor(0.)) + for t, y in enumerate(data): + s_prev = s_curr + x_prev = x_curr + + # A delayed sample statement. + s_curr = funsor.Variable('s_{}'.format(t), funsor.bint(2)) + log_prob += dist.Categorical(trans_probs[s_prev], value=s_curr) + + # A delayed sample statement. + x_curr = funsor.Variable('x_{}'.format(t), funsor.reals()) + log_prob += dist.Normal(x_prev, trans_noise[s_curr], value=x_curr) + + # Marginalize out previous delayed sample statements. + if t > 0: + log_prob = log_prob.reduce( + ops.logaddexp, frozenset([s_prev.name, x_prev.name])) + + # An observe statement. + log_prob += dist.Normal(x_curr, emit_noise, value=y) + + log_prob = log_prob.reduce(ops.logaddexp) + return log_prob + + # Train model parameters. + torch.manual_seed(0) + data = torch.randn(args.time_steps) + optim = torch.optim.Adam(params, lr=args.learning_rate) + for step in range(args.train_steps): + optim.zero_grad() + log_prob = model(data) + assert not log_prob.inputs, 'free variables remain' + loss = -log_prob.data + loss.backward() + optim.step() + if args.verbose and step % 10 == 0: + print('step {} loss = {}'.format(step, loss.item())) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Switching linear dynamical system") + parser.add_argument("-t", "--time-steps", default=10, type=int) + parser.add_argument("-n", "--train-steps", default=101, type=int) + parser.add_argument("-lr", "--learning-rate", default=0.05, type=float) + parser.add_argument("--filter", action='store_true') + parser.add_argument("-v", "--verbose", action="store_true") + args = parser.parse_args() + main(args) diff --git a/examples/ss_vae_delayed.py b/examples/ss_vae_delayed.py index 654897913..cc21bee4a 100644 --- a/examples/ss_vae_delayed.py +++ b/examples/ss_vae_delayed.py @@ -3,12 +3,12 @@ import argparse from collections import OrderedDict +import pyro.distributions as dist import torch import torch.nn as nn import funsor import funsor.minipyro as pyro -import pyro.distributions as dist class Decoder(nn.Module): @@ -23,9 +23,9 @@ class SalientEncoder(nn.Module): pass # TODO -decoder = funsor.function((), (), ())(Decoder()) -nuisance_encoder = funsor.function((), ('loc_scale',))(NuisanceEncoder()) -salient_encoder = funsor.function((), (), ())(SalientEncoder()) +decoder = funsor.torch.function((), (), ())(Decoder()) +nuisance_encoder = funsor.torch.function((), ('loc_scale',))(NuisanceEncoder()) +salient_encoder = funsor.torch.function((), (), ())(SalientEncoder()) def model(image=None): diff --git a/examples/vae.py b/examples/vae.py new file mode 100644 index 000000000..469e6aa17 --- /dev/null +++ b/examples/vae.py @@ -0,0 +1,110 @@ +from __future__ import absolute_import, division, print_function + +import argparse +import os +from collections import OrderedDict + +import torch +import torch.utils.data +from torch 
import nn, optim +from torch.nn import functional as F +from torchvision import datasets, transforms + +import funsor +import funsor.distributions as dist +import funsor.ops as ops +from funsor.domains import bint, reals + +REPO_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +DATA_PATH = os.path.join(REPO_PATH, 'data') + + +class Encoder(nn.Module): + def __init__(self): + super(Encoder, self).__init__() + self.fc1 = nn.Linear(784, 400) + self.fc21 = nn.Linear(400, 20) + self.fc22 = nn.Linear(400, 20) + + def forward(self, image): + image = image.reshape(image.shape[:-2] + (-1,)) + h1 = F.relu(self.fc1(image)) + loc = self.fc21(h1) + scale = self.fc22(h1).exp() + return loc, scale + + +class Decoder(nn.Module): + def __init__(self): + super(Decoder, self).__init__() + self.fc3 = nn.Linear(20, 400) + self.fc4 = nn.Linear(400, 784) + + def forward(self, z): + h3 = F.relu(self.fc3(z)) + out = torch.sigmoid(self.fc4(h3)) + return out.reshape(out.shape[:-1] + (28, 28)) + + +def main(args): + encoder = Encoder() + decoder = Decoder() + + encode = funsor.torch.function(reals(28, 28), (reals(20), reals(20)))(encoder) + decode = funsor.torch.function(reals(20), reals(28, 28))(decoder) + + @funsor.interpreter.interpretation(funsor.montecarlo.monte_carlo) + def loss_function(data, subsample_scale): + # Lazily sample from the guide. + loc, scale = encode(data) + q = funsor.Independent( + dist.Normal(loc['i'], scale['i'], value='z'), + 'z', 'i') + + # Evaluate the model likelihood at the lazy value z. + probs = decode('z') + p = dist.Bernoulli(probs['x', 'y'], value=data['x', 'y']) + p = p.reduce(ops.add, frozenset(['x', 'y'])) + + # Construct an elbo. This is where sampling happens. + elbo = funsor.Integrate(q, p - q, frozenset(['z'])) + elbo = elbo.reduce(ops.add, 'batch') * subsample_scale + loss = -elbo + return loss + + train_loader = torch.utils.data.DataLoader( + datasets.MNIST(DATA_PATH, train=True, download=True, + transform=transforms.ToTensor()), + batch_size=args.batch_size, shuffle=True) + + encoder.train() + decoder.train() + optimizer = optim.Adam(list(encoder.parameters()) + + list(decoder.parameters()), lr=1e-3) + for epoch in range(args.num_epochs): + train_loss = 0 + for batch_idx, (data, _) in enumerate(train_loader): + subsample_scale = float(len(train_loader.dataset) / len(data)) + data = data[:, 0, :, :] + data = funsor.Tensor(data, OrderedDict(batch=bint(len(data)))) + + optimizer.zero_grad() + loss = loss_function(data, subsample_scale) + assert isinstance(loss, funsor.torch.Tensor), loss.pretty() + loss.data.backward() + train_loss += loss.item() + optimizer.step() + if batch_idx % 50 == 0: + print(' loss = {}'.format(loss.item())) + if batch_idx and args.smoke_test: + return + print('epoch {} train_loss = {}'.format(epoch, train_loss)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='VAE MNIST Example') + parser.add_argument('-n', '--num-epochs', type=int, default=10) + parser.add_argument('--batch-size', type=int, default=8) + parser.add_argument('--smoke-test', action='store_true') + args = parser.parse_args() + main(args) diff --git a/funsor/__init__.py b/funsor/__init__.py index b9d4e16e6..e5b27f6e9 100644 --- a/funsor/__init__.py +++ b/funsor/__init__.py @@ -2,37 +2,69 @@ from funsor.domains import Domain, bint, find_domain, reals from funsor.fixpoints import fix +from funsor.integrate import Integrate from funsor.interpreter import reinterpret -from funsor.terms import Funsor, Number, Variable, of_shape, to_funsor -from 
funsor.torch import Function, Tensor, arange, einsum, function +from funsor.terms import Funsor, Independent, Lambda, Number, Variable, of_shape, to_data, to_funsor +from funsor.torch import Tensor, arange -from . import distributions, domains, fixpoints, handlers, interpreter, minipyro, ops, terms, torch +from . import ( + adjoint, + affine, + contract, + delta, + distributions, + domains, + einsum, + fixpoints, + gaussian, + integrate, + interpreter, + joint, + minipyro, + montecarlo, + ops, + pattern, + sum_product, + terms, + torch +) __all__ = [ 'Domain', - 'Function', 'Funsor', + 'Independent', + 'Integrate', + 'Lambda', 'Number', 'Tensor', 'Variable', + 'adjoint', + 'affine', 'arange', 'backward', 'bint', + 'contract', + 'delta', 'distributions', 'domains', 'einsum', 'find_domain', 'fix', 'fixpoints', - 'function', - 'handlers', + 'gaussian', + 'integrate', 'interpreter', + 'joint', 'minipyro', + 'montecarlo', 'of_shape', 'ops', + 'pattern', 'reals', 'reinterpret', + 'sum_product', 'terms', + 'to_data', 'to_funsor', 'torch', ] diff --git a/funsor/adjoint.py b/funsor/adjoint.py new file mode 100644 index 000000000..c9619af9f --- /dev/null +++ b/funsor/adjoint.py @@ -0,0 +1,102 @@ +from __future__ import absolute_import, division, print_function + +from collections import defaultdict + +import torch + +import funsor.ops as ops +from funsor.contract import Contract +from funsor.interpreter import interpretation, reinterpret +from funsor.ops import AssociativeOp +from funsor.registry import KeyedRegistry +from funsor.terms import Binary, Funsor, Number, Reduce, Variable, eager +from funsor.torch import Tensor + + +class AdjointTape(object): + + def __init__(self): + self.tape = [] + + def __call__(self, cls, *args): + result = eager(cls, *args) + if cls in (Reduce, Contract, Binary, Tensor): + self.tape.append((result, cls, args)) + return result + + +def adjoint(expr, targets, start=Number(0.)): + + adjoint_values = defaultdict(lambda: Number(0.)) # 1 in logspace + multiplicities = defaultdict(lambda: 0) + + tape_recorder = AdjointTape() + with interpretation(tape_recorder): + adjoint_values[reinterpret(expr)] = start + + while tape_recorder.tape: + output, fn, inputs = tape_recorder.tape.pop() + in_adjs = adjoint_ops(fn, adjoint_values[output], output, *inputs) + for v, adjv in in_adjs.items(): + multiplicities[v] += 1 + adjoint_values[v] = adjoint_values[v] + adjv # product in logspace + + target_adjs = {} + for v in targets: + target_adjs[v] = adjoint_values[v] / multiplicities[v] + if not isinstance(v, Variable): + target_adjs[v] = target_adjs[v] + v + return target_adjs + + +# logaddexp/add +def _fail_default(*args): + raise NotImplementedError("Should not be here! 
{}".format(args)) + + +adjoint_ops = KeyedRegistry(default=_fail_default) + + +@adjoint_ops.register(Tensor, Funsor, Funsor, torch.Tensor, tuple, object) +def adjoint_tensor(out_adj, out, data, inputs, dtype): + all_vars = frozenset(k for (k, v) in inputs) + in_adjs = {} + for (k, v) in inputs: + in_adj = (out_adj + out).reduce(ops.logaddexp, all_vars - {k}) + in_adjs[Variable(k, v)] = in_adj + return in_adjs + + +@adjoint_ops.register(Binary, Funsor, Funsor, AssociativeOp, Funsor, Funsor) +def adjoint_binary(out_adj, out, op, lhs, rhs): + assert op is ops.add + + lhs_reduced_vars = frozenset(rhs.inputs) - frozenset(lhs.inputs) + lhs_adj = (out_adj + rhs).reduce(ops.logaddexp, lhs_reduced_vars) + + rhs_reduced_vars = frozenset(lhs.inputs) - frozenset(rhs.inputs) + rhs_adj = (out_adj + lhs).reduce(ops.logaddexp, rhs_reduced_vars) + + return {lhs: lhs_adj, rhs: rhs_adj} + + +@adjoint_ops.register(Reduce, Funsor, Funsor, AssociativeOp, Funsor, frozenset) +def adjoint_reduce(out_adj, out, op, arg, reduced_vars): + assert op in (ops.logaddexp, ops.add) + + if op is ops.logaddexp: + return {arg: out_adj + (arg * 0.)} # XXX hack to simulate "expand" + elif op is ops.add: # plate! + return {arg: out_adj + Binary(ops.safesub, out, arg)} + + +@adjoint_ops.register(Contract, Funsor, Funsor, AssociativeOp, AssociativeOp, Funsor, Funsor, frozenset) +def adjoint_contract(out_adj, out, sum_op, prod_op, lhs, rhs, reduced_vars): + + lhs_reduced_vars = frozenset(rhs.inputs) - frozenset(lhs.inputs) + lhs_adj = Contract(sum_op, prod_op, out_adj, rhs, lhs_reduced_vars) + + rhs_reduced_vars = frozenset(lhs.inputs) - frozenset(rhs.inputs) + rhs_adj = Contract(sum_op, prod_op, out_adj, lhs, rhs_reduced_vars) + + return {lhs: lhs_adj, rhs: rhs_adj} diff --git a/funsor/affine.py b/funsor/affine.py new file mode 100644 index 000000000..8b5591c4c --- /dev/null +++ b/funsor/affine.py @@ -0,0 +1,194 @@ +from __future__ import absolute_import, division, print_function + +from collections import OrderedDict + +import funsor.ops as ops +from funsor.domains import find_domain +from funsor.ops import NegOp, Op +from funsor.terms import Binary, Funsor, Number, Unary, Variable, eager +from funsor.torch import Tensor + + +class Affine(Funsor): + """ + Pattern representing multilinear function of input variables + """ + def __init__(self, const, coeffs): + assert isinstance(const, (Number, Tensor)) + assert not any(d.dtype == "real" for d in const.inputs.values()) + assert isinstance(coeffs, tuple) + inputs = const.inputs.copy() + output = const.output + assert output.dtype == "real" + for var, coeff in coeffs: + assert isinstance(var, Variable) + assert isinstance(coeff, (Number, Tensor)) + assert not any(d.dtype == "real" for d in coeff.inputs.values()) + inputs.update(coeff.inputs) + inputs.update(var.inputs) + output = find_domain(ops.add, output, find_domain(ops.mul, var.output, coeff.output)) + assert var.dtype == "real" + assert coeff.dtype == "real" + assert output.dtype == "real" + + super(Affine, self).__init__(inputs, output) + self.coeffs = OrderedDict(coeffs) + self.const = const + + +############################################### +# patterns for merging Affine with other terms +############################################### + +@eager.register(Affine, (Number, Tensor), tuple) +def eager_affine(const, coeffs): + if not coeffs: + return const + if not all(isinstance(var, Variable) for var, coeff in coeffs): + result = Affine(const, tuple((var, coeff) for var, coeff in coeffs if isinstance(var, Variable))) + for 
var, coeff in coeffs: + if not isinstance(var, Variable): + result += var * coeff + return result + return None + + +@eager.register(Binary, Op, Affine, (Number, Tensor)) +def eager_binary_affine(op, lhs, rhs): + if op is ops.add or op is ops.sub: + const = op(lhs.const, rhs) + return Affine(const, tuple(lhs.coeffs.items())) + if op is ops.mul or op is ops.truediv: + const = op(lhs.const, rhs) + coeffs = tuple((var, op(coeff, rhs)) for var, coeff in lhs.coeffs.items()) + return Affine(const, coeffs) + return None + + +@eager.register(Binary, Op, (Number, Tensor), Affine) +def eager_binary_affine(op, lhs, rhs): + if op is ops.add: + const = lhs + rhs.const + return Affine(const, tuple(rhs.coeffs.items())) + elif op is ops.sub: + return lhs + -rhs + if op is ops.mul: + const = lhs * rhs.const + coeffs = tuple((var, lhs * coeff) for var, coeff in rhs.coeffs.items()) + return Affine(const, coeffs) + return None + + +@eager.register(Binary, Op, Affine, Affine) +def eager_binary_affine_affine(op, lhs, rhs): + if op is ops.add: + const = lhs.const + rhs.const + coeffs = lhs.coeffs.copy() + for var, coeff in rhs.coeffs.items(): + if var in coeffs: + coeffs[var] += coeff + else: + coeffs[var] = coeff + return Affine(const, tuple(coeffs.items())) + + if op is ops.sub: + return lhs + -rhs + + return None + + +@eager.register(Binary, Op, Affine, Variable) +def eager_binary_affine_variable(op, affine, other): + if op is ops.add: + const = affine.const + coeffs = affine.coeffs.copy() + if other in affine.inputs: + coeffs[other] += 1 + else: + coeffs[other] = Number(1.) + return Affine(const, tuple(coeffs.items())) + + if op is ops.sub: + return affine + -other + + return None + + +@eager.register(Binary, Op, Variable, Affine) +def eager_binary_variable_affine(op, other, affine): + if op is ops.add: + return affine + other + + if op is ops.sub: + return -affine + other + + return None + + +@eager.register(Unary, NegOp, Affine) +def eager_negate_affine(op, affine): + const = -affine.const + coeffs = affine.coeffs.copy() + for var, coeff in coeffs.items(): + coeffs[var] = -coeff + return Affine(const, tuple(coeffs.items())) + + +######################################### +# patterns for creating new Affine terms +######################################### + +@eager.register(Binary, Op, Variable, (Number, Tensor)) +def eager_binary(op, var, other): + if var.dtype != "real" or other.dtype != "real": + return None + + if op is ops.add: + const = other + coeffs = ((var, Number(1.)),) + return Affine(const, coeffs) + elif op is ops.mul: + const = Number(0.) + coeffs = ((var, other),) + return Affine(const, coeffs) + elif op is ops.sub: + return var + -other + elif op is ops.truediv: + return var * (1. / other) + return None + + +@eager.register(Binary, Op, Variable, Variable) +def eager_binary(op, lhs, rhs): + if lhs.dtype != "real" or rhs.dtype != "real": + return None + + if op is ops.add: + const = Number(0.) + coeffs = ((lhs, Number(1.)), (rhs, Number(1.))) + return Affine(const, coeffs) + elif op is ops.sub: + return lhs + -rhs + return None + + +@eager.register(Binary, Op, (Number, Tensor), Variable) +def eager_binary(op, other, var): + if other.dtype != "real" or var.dtype != "real": + return None + + if op is ops.add or op is ops.mul: + return op(var, other) + elif op is ops.sub: + return -var + other + return None + + +@eager.register(Unary, NegOp, Variable) +def eager_negate_variable(op, var): + if var.dtype != "real": + return None + + const = Number(0.) 
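+ # Negating a real variable yields the affine function 0 + (-1) * var.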
+ coeffs = ((var, Number(-1, "real")),) + return Affine(const, coeffs) diff --git a/funsor/contract.py b/funsor/contract.py new file mode 100644 index 000000000..80a906574 --- /dev/null +++ b/funsor/contract.py @@ -0,0 +1,74 @@ +from __future__ import absolute_import, division, print_function + +import functools +from collections import OrderedDict + +import funsor.interpreter as interpreter +import funsor.ops as ops +from funsor.terms import Funsor, eager + + +def _simplify_contract(fn, sum_op, prod_op, lhs, rhs, reduced_vars): + """ + Reduce free variables that do not appear explicitly in the lhs + """ + if not reduced_vars: + return prod_op(lhs, rhs) + + lhs_vars = frozenset(lhs.inputs) + rhs_vars = frozenset(rhs.inputs) + assert reduced_vars <= lhs_vars | rhs_vars + progress = False + if not reduced_vars <= lhs_vars: + rhs = rhs.reduce(sum_op, reduced_vars - lhs_vars) + reduced_vars = reduced_vars & lhs_vars + progress = True + if not reduced_vars <= rhs_vars: + lhs = lhs.reduce(sum_op, reduced_vars - rhs_vars) + reduced_vars = reduced_vars & rhs_vars + progress = True + if progress: + return Contract(sum_op, prod_op, lhs, rhs, reduced_vars) + + return fn(sum_op, prod_op, lhs, rhs, reduced_vars) + + +def contractor(fn): + """ + Decorator for contract implementations to simplify inputs. + """ + fn = interpreter.debug_logged(fn) + return functools.partial(_simplify_contract, fn) + + +class Contract(Funsor): + + def __init__(self, sum_op, prod_op, lhs, rhs, reduced_vars): + assert isinstance(sum_op, ops.AssociativeOp) + assert isinstance(prod_op, ops.AssociativeOp) + assert isinstance(lhs, Funsor) + assert isinstance(rhs, Funsor) + assert isinstance(reduced_vars, frozenset) + inputs = OrderedDict([(k, d) for t in (lhs, rhs) + for k, d in t.inputs.items() if k not in reduced_vars]) + output = rhs.output + fresh = frozenset() + bound = reduced_vars + super(Contract, self).__init__(inputs, output, fresh, bound) + self.sum_op = sum_op + self.prod_op = prod_op + self.lhs = lhs + self.rhs = rhs + self.reduced_vars = reduced_vars + + +@eager.register(Contract, ops.AssociativeOp, ops.AssociativeOp, Funsor, Funsor, frozenset) +@contractor +def eager_contract(sum_op, prod_op, lhs, rhs, reduced_vars): + return prod_op(lhs, rhs).reduce(sum_op, reduced_vars) + + +__all__ = [ + 'Contract', + 'contractor', +] diff --git a/funsor/delta.py b/funsor/delta.py new file mode 100644 index 000000000..506e5ccdc --- /dev/null +++ b/funsor/delta.py @@ -0,0 +1,199 @@ +from __future__ import absolute_import, division, print_function + +from collections import OrderedDict + +from six import add_metaclass + +import funsor.ops as ops +import funsor.terms +from funsor.domains import Domain, reals +from funsor.integrate import Integrate, integrator +from funsor.interpreter import debug_logged +from funsor.ops import AddOp, SubOp, TransformOp +from funsor.registry import KeyedRegistry +from funsor.terms import ( + Align, + Binary, + Funsor, + FunsorMeta, + Independent, + Lambda, + Number, + Reduce, + Subs, + Unary, + Variable, + eager, + to_funsor +) + + +class DeltaMeta(FunsorMeta): + """ + Wrapper to fill in defaults. + """ + def __call__(cls, name, point, log_density=0): + point = to_funsor(point) + log_density = to_funsor(log_density) + return super(DeltaMeta, cls).__call__(name, point, log_density) + + +@add_metaclass(DeltaMeta) +class Delta(Funsor): + """ + Normalized delta distribution binding a single variable. + + :param str name: Name of the bound variable. + :param Funsor point: Value of the bound variable. 
+ :param Funsor log_density: Optional log density to be added when evaluating + at a point. This is needed to make :class:`Delta` closed under + differentiable substitution. + """ + def __init__(self, name, point, log_density=0): + assert isinstance(name, str) + assert isinstance(point, Funsor) + assert isinstance(log_density, Funsor) + assert log_density.output == reals() + inputs = OrderedDict([(name, point.output)]) + inputs.update(point.inputs) + inputs.update(log_density.inputs) + output = reals() + fresh = frozenset({name}) + bound = frozenset() + super(Delta, self).__init__(inputs, output, fresh, bound) + self.name = name + self.point = point + self.log_density = log_density + + def eager_subs(self, subs): + assert len(subs) == 1 and subs[0][0] == self.name + value = subs[0][1] + + if isinstance(value, Variable): + return Delta(value.name, self.point, self.log_density) + + if not any(d.dtype == 'real' for side in (value, self.point) + for d in side.inputs.values()): + return (value == self.point).all().log() + self.log_density + + # Try to invert the substitution. + soln = solve(value, self.point) + if soln is None: + return None # lazily substitute + name, point, log_density = soln + log_density += self.log_density + return Delta(name, point, log_density) + + def eager_reduce(self, op, reduced_vars): + if op is ops.logaddexp: + if self.name in reduced_vars: + return Number(0) # Deltas are normalized. + + # TODO Implement ops.add to simulate .to_event(). + + return None # defer to default implementation + + +@eager.register(Binary, AddOp, Delta, (Funsor, Align)) +def eager_add(op, lhs, rhs): + if lhs.name in rhs.inputs: + rhs = rhs(**{lhs.name: lhs.point}) + return op(lhs, rhs) + + return None # defer to default implementation + + +@eager.register(Binary, SubOp, Delta, (Funsor, Align)) +def eager_sub(op, lhs, rhs): + if lhs.name in rhs.inputs: + rhs = rhs(**{lhs.name: lhs.point}) + return op(lhs, rhs) + + return None # defer to default implementation + + +@eager.register(Binary, AddOp, (Funsor, Align), Delta) +def eager_add(op, lhs, rhs): + if rhs.name in lhs.inputs: + lhs = lhs(**{rhs.name: rhs.point}) + return op(lhs, rhs) + + return None # defer to default implementation + + +eager.register(Binary, AddOp, Delta, Reduce)( + funsor.terms.eager_distribute_other_reduce) +eager.register(Binary, AddOp, Reduce, Delta)( + funsor.terms.eager_distribute_reduce_other) + + +@eager.register(Independent, Delta, str, str) +def eager_independent(delta, reals_var, bint_var): + if delta.name == reals_var or delta.name.startswith(reals_var + "__BOUND"): + i = Variable(bint_var, delta.inputs[bint_var]) + point = Lambda(i, delta.point) + if bint_var in delta.log_density.inputs: + log_density = delta.log_density.reduce(ops.add, bint_var) + else: + log_density = delta.log_density * delta.inputs[bint_var].dtype + return Delta(reals_var, point, log_density) + + return None # defer to default implementation + + +@eager.register(Integrate, Delta, Funsor, frozenset) +@integrator +def eager_integrate(delta, integrand, reduced_vars): + assert delta.name in reduced_vars + integrand = Subs(integrand, ((delta.name, delta.point),)) + log_measure = delta.log_density + reduced_vars -= frozenset([delta.name]) + return Integrate(log_measure, integrand, reduced_vars) + + +def solve(expr, value): + """ + Tries to solve for free inputs of an ``expr`` such that ``expr == value``, + and computes the log-abs-det-Jacobian of the resulting substitution. + + :param Funsor expr: An expression with a free variable. 
+ :param Funsor value: A target value. + :return: A tuple ``(name, point, log_abs_det_jacobian)`` + :rtype: tuple + :raises: ValueError + """ + assert isinstance(expr, Funsor) + assert isinstance(value, Funsor) + result = solve.dispatch(type(expr), *(expr._ast_values + (value,))) + if result is None: + raise ValueError("Cannot substitute into a Delta: {}".format(value)) + return result + + +_solve = KeyedRegistry(lambda *args: None) +solve.dispatch = _solve.__call__ +solve.register = _solve.register + + +@solve.register(Variable, str, Domain, Funsor) +@debug_logged +def solve_variable(name, output, y): + assert y.output == output + point = y + log_density = Number(0) + return name, point, log_density + + +@solve.register(Unary, TransformOp, Funsor, Funsor) +@debug_logged +def solve_unary(op, arg, y): + x = op.inv(y) + name, point, log_density = solve(arg, x) + log_density += op.log_abs_det_jacobian(x, y) + return name, point, log_density + + +__all__ = [ + 'Delta', + 'solve', +] diff --git a/funsor/distributions.py b/funsor/distributions.py index 7ed8afddd..bd2ba806f 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -1,16 +1,61 @@ from __future__ import absolute_import, division, print_function +import math from collections import OrderedDict import pyro.distributions as dist +import torch +from pyro.distributions.util import broadcast_shape from six import add_metaclass +import funsor.delta import funsor.ops as ops +from funsor.affine import Affine from funsor.domains import bint, reals -from funsor.terms import Funsor, FunsorMeta, Number, Variable, eager, to_funsor -from funsor.torch import Tensor, align_tensors, materialize +from funsor.gaussian import BlockMatrix, BlockVector, Gaussian +from funsor.interpreter import interpretation +from funsor.terms import Funsor, FunsorMeta, Number, Variable, eager, lazy, to_funsor +from funsor.torch import Tensor, align_tensors, ignore_jit_warnings, materialize, torch_stack +def numbers_to_tensors(*args): + """ + Convert :class:`~funsor.terms.Number`s to :class:`funsor.torch.Tensor`s, + using any provided tensor as a prototype, if available. + """ + if any(isinstance(x, Number) for x in args): + options = dict(dtype=torch.get_default_dtype()) + for x in args: + if isinstance(x, Tensor): + options = dict(dtype=x.data.dtype, device=x.data.device) + break + with ignore_jit_warnings(): + args = tuple(Tensor(torch.tensor(x.data, **options), dtype=x.dtype) + if isinstance(x, Number) else x + for x in args) + return args + + +class DistributionMeta(FunsorMeta): + """ + Wrapper to fill in default values and convert Numbers to Tensors. + """ + def __call__(cls, *args, **kwargs): + kwargs.update(zip(cls._ast_fields, args)) + args = cls._fill_defaults(**kwargs) + args = numbers_to_tensors(*args) + + # If value was explicitly specified, evaluate under current interpretation. + if 'value' in kwargs: + return super(DistributionMeta, cls).__call__(*args) + + # Otherwise lazily construct a distribution instance. + # This makes it cheaper to construct observations in minipyro. + with interpretation(lazy): + return super(DistributionMeta, cls).__call__(*args) + + +@add_metaclass(DistributionMeta) class Distribution(Funsor): """ Funsor backed by a PyTorch distribution object. 
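A minimal usage sketch of the construction behaviour provided by the DistributionMeta wrapper above, assuming the torch backend (the numeric values are illustrative, not taken from this patch):

    import funsor.distributions as dist

    # No explicit value: construction stays lazy, and the result is a funsor
    # whose free inputs include a real input named 'value'.
    lazy_normal = dist.Normal(0., 1.)

    # An explicit value triggers eager evaluation to a Tensor of log-densities,
    # here log N(0.5; 0, 1) = -0.5 * log(2 * pi) - 0.125, about -1.04.
    log_prob = dist.Normal(0., 1., value=0.5)
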
@@ -34,17 +79,10 @@ def __repr__(self): return '{}({})'.format(type(self).__name__, ', '.join('{}={}'.format(*kv) for kv in self.params)) - def eager_subs(self, subs): - assert isinstance(subs, tuple) - if not any(k in self.inputs for k, v in subs): - return self - params = OrderedDict((k, v.eager_subs(subs)) for k, v in self.params) - return type(self)(**params) - def eager_reduce(self, op, reduced_vars): if op is ops.logaddexp and isinstance(self.value, Variable) and self.value.name in reduced_vars: return Number(0.) # distributions are normalized - return super(Distribution, self).reduce(op, reduced_vars) + return super(Distribution, self).eager_reduce(op, reduced_vars) @classmethod def eager_log_prob(cls, **params): @@ -59,73 +97,383 @@ def eager_log_prob(cls, **params): # Distribution Wrappers ################################################################################ -class CategoricalMeta(FunsorMeta): - """ - Wrapper to fill in default params. - """ - def __call__(cls, probs, value=None): +class BernoulliProbs(Distribution): + dist_class = dist.Bernoulli + + @staticmethod + def _fill_defaults(probs, value='value'): + probs = to_funsor(probs) + assert probs.dtype == "real" + value = to_funsor(value, reals()) + return probs, value + + def __init__(self, probs, value=None): + super(BernoulliProbs, self).__init__(probs, value) + + +@eager.register(BernoulliProbs, Tensor, Tensor) +def eager_bernoulli(probs, value): + return BernoulliProbs.eager_log_prob(probs=probs, value=value) + + +class BernoulliLogits(Distribution): + dist_class = dist.Bernoulli + + @staticmethod + def _fill_defaults(logits, value='value'): + logits = to_funsor(logits) + assert logits.dtype == "real" + value = to_funsor(value, reals()) + return logits, value + + def __init__(self, logits, value=None): + super(BernoulliLogits, self).__init__(logits, value) + + +@eager.register(BernoulliLogits, Tensor, Tensor) +def eager_bernoulli_logits(logits, value): + return BernoulliLogits.eager_log_prob(logits=logits, value=value) + + +def Bernoulli(probs=None, logits=None, value='value'): + if probs is not None: + return BernoulliProbs(probs, value) + if logits is not None: + return BernoulliLogits(logits, value) + raise ValueError('Either probs or logits must be specified') + + +class Beta(Distribution): + dist_class = dist.Beta + + @staticmethod + def _fill_defaults(concentration1, concentration0, value='value'): + concentration1 = to_funsor(concentration1, reals()) + concentration0 = to_funsor(concentration0, reals()) + value = to_funsor(value, reals()) + return concentration1, concentration0, value + + def __init__(self, concentration1, concentration0, value=None): + super(Beta, self).__init__(concentration1, concentration0, value) + + +@eager.register(Beta, Tensor, Tensor, Tensor) +def eager_beta(concentration1, concentration0, value): + return Beta.eager_log_prob(concentration1=concentration1, + concentration0=concentration0, + value=value) + + +@eager.register(Beta, Funsor, Funsor, Funsor) +def eager_beta(concentration1, concentration0, value): + concentration = torch_stack((concentration0, concentration1)) + value = torch_stack((1 - value, value)) + return Dirichlet(concentration, value=value) + + +class Binomial(Distribution): + dist_class = dist.Binomial + + @staticmethod + def _fill_defaults(total_count, probs, value='value'): + total_count = to_funsor(total_count, reals()) probs = to_funsor(probs) - if value is None: - size = probs.output.shape[0] - value = Variable('value', bint(size)) - else: - value = 
to_funsor(value) - return super(CategoricalMeta, cls).__call__(probs, value) + assert probs.dtype == "real" + value = to_funsor(value, reals()) + return total_count, probs, value + + def __init__(self, total_count, probs, value=None): + super(Binomial, self).__init__(total_count, probs, value) + + +@eager.register(Binomial, Tensor, Tensor, Tensor) +def eager_binomial(total_count, probs, value): + return Binomial.eager_log_prob(total_count=total_count, probs=probs, value=value) + + +@eager.register(Binomial, Funsor, Funsor, Funsor) +def eager_binomial(total_count, probs, value): + probs = torch_stack((1 - probs, probs)) + value = torch_stack((total_count - value, value)) + return Multinomial(total_count, probs, value=value) -@add_metaclass(CategoricalMeta) class Categorical(Distribution): dist_class = dist.Categorical - def __init__(self, probs, value=None): + @staticmethod + def _fill_defaults(probs, value='value'): + probs = to_funsor(probs) + assert probs.dtype == "real" + value = to_funsor(value, bint(probs.output.shape[0])) + return probs, value + + def __init__(self, probs, value='value'): super(Categorical, self).__init__(probs, value) -@eager.register(Categorical, Funsor, Number) +@eager.register(Categorical, Funsor, Tensor) def eager_categorical(probs, value): return probs[value].log() -@eager.register(Categorical, (Number, Tensor), (Number, Tensor)) +@eager.register(Categorical, Tensor, Tensor) def eager_categorical(probs, value): return Categorical.eager_log_prob(probs=probs, value=value) -@eager.register(Categorical, (Number, Tensor), Variable) +@eager.register(Categorical, Tensor, Variable) def eager_categorical(probs, value): value = materialize(value) return Categorical.eager_log_prob(probs=probs, value=value) -class NormalMeta(FunsorMeta): - """ - Wrapper to fill in default params. - """ - def __call__(cls, loc, scale, value=None): - loc = to_funsor(loc) - scale = to_funsor(scale) - if value is None: - value = Variable('value', reals()) - else: - value = to_funsor(value) - return super(NormalMeta, cls).__call__(loc, scale, value) +class Delta(Distribution): + dist_class = dist.Delta + + @staticmethod + def _fill_defaults(v, log_density=0, value='value'): + v = to_funsor(v) + log_density = to_funsor(log_density, reals()) + value = to_funsor(value, v.output) + return v, log_density, value + + def __init__(self, v, log_density=0, value='value'): + return super(Delta, self).__init__(v, log_density, value) + + +@eager.register(Delta, Tensor, Tensor, Tensor) +def eager_delta(v, log_density, value): + # This handles event_dim specially, and hence cannot use the + # generic Delta.eager_log_prob() method. 
+ assert v.output == value.output + event_dim = len(v.output.shape) + inputs, (v, log_density, value) = align_tensors(v, log_density, value) + data = dist.Delta(v, log_density, event_dim).log_prob(value) + return Tensor(data, inputs) + + +@eager.register(Delta, Funsor, Funsor, Variable) +@eager.register(Delta, Variable, Funsor, Variable) +def eager_delta(v, log_density, value): + assert v.output == value.output + return funsor.delta.Delta(value.name, v, log_density) + + +@eager.register(Delta, Variable, Funsor, Funsor) +def eager_delta(v, log_density, value): + assert v.output == value.output + return funsor.delta.Delta(v.name, value, log_density) + + +class Dirichlet(Distribution): + dist_class = dist.Dirichlet + + @staticmethod + def _fill_defaults(concentration, value='value'): + concentration = to_funsor(concentration) + assert concentration.dtype == "real" + assert len(concentration.output.shape) == 1 + dim = concentration.output.shape[0] + value = to_funsor(value, reals(dim)) + return concentration, value + + def __init__(self, concentration, value='value'): + super(Dirichlet, self).__init__(concentration, value) + + +@eager.register(Dirichlet, Tensor, Tensor) +def eager_dirichlet(concentration, value): + return Dirichlet.eager_log_prob(concentration=concentration, value=value) + + +class DirichletMultinomial(Distribution): + dist_class = dist.DirichletMultinomial + + @staticmethod + def _fill_defaults(concentration, total_count=1, value='value'): + concentration = to_funsor(concentration) + assert concentration.dtype == "real" + assert len(concentration.output.shape) == 1 + total_count = to_funsor(total_count, reals()) + dim = concentration.output.shape[0] + value = to_funsor(value, reals(dim)) # Should this be bint(total_count)? + return concentration, total_count, value + + def __init__(self, concentration, total_count, value='value'): + super(DirichletMultinomial, self).__init__(concentration, total_count, value) + + +@eager.register(DirichletMultinomial, Tensor, Tensor, Tensor) +def eager_dirichlet_multinomial(concentration, total_count, value): + return DirichletMultinomial.eager_log_prob( + concentration=concentration, total_count=total_count, value=value) + + +def LogNormal(loc, scale, value='value'): + loc, scale, y = Normal._fill_defaults(loc, scale, value) + t = ops.exp + x = t.inv(y) + log_abs_det_jacobian = t.log_abs_det_jacobian(x, y) + return Normal(loc, scale, x) - log_abs_det_jacobian + + +class Multinomial(Distribution): + dist_class = dist.Multinomial + + @staticmethod + def _fill_defaults(total_count, probs, value='value'): + total_count = to_funsor(total_count, reals()) + probs = to_funsor(probs) + assert probs.dtype == "real" + assert len(probs.output.shape) == 1 + value = to_funsor(value, probs.output) + return total_count, probs, value + + def __init__(self, total_count, probs, value=None): + super(Multinomial, self).__init__(total_count, probs, value) + + +@eager.register(Multinomial, Tensor, Tensor, Tensor) +def eager_multinomial(total_count, probs, value): + # Multinomial.log_prob() supports inhomogeneous total_count only by + # avoiding passing total_count to the constructor. + inputs, (total_count, probs, value) = align_tensors(total_count, probs, value) + shape = broadcast_shape(total_count.shape + (1,), probs.shape, value.shape) + probs = Tensor(probs.expand(shape), inputs) + value = Tensor(value.expand(shape), inputs) + total_count = Number(total_count.max().item()) # Used by distributions validation code. 
+ return Multinomial.eager_log_prob(total_count=total_count, probs=probs, value=value) -@add_metaclass(NormalMeta) class Normal(Distribution): dist_class = dist.Normal - def __init__(self, loc, scale, value=None): + @staticmethod + def _fill_defaults(loc, scale, value='value'): + loc = to_funsor(loc, reals()) + scale = to_funsor(scale, reals()) + value = to_funsor(value, reals()) + return loc, scale, value + + def __init__(self, loc, scale, value='value'): super(Normal, self).__init__(loc, scale, value) -@eager.register(Normal, (Number, Tensor), (Number, Tensor), (Number, Tensor)) +@eager.register(Normal, Tensor, Tensor, Tensor) def eager_normal(loc, scale, value): return Normal.eager_log_prob(loc=loc, scale=scale, value=value) +# Create a Gaussian from a ground prior or ground likelihood. +@eager.register(Normal, Tensor, Tensor, Variable) +@eager.register(Normal, Variable, Tensor, Tensor) +def eager_normal(loc, scale, value): + if isinstance(loc, Variable): + loc, value = value, loc + + inputs, (loc, scale) = align_tensors(loc, scale) + loc, scale = torch.broadcast_tensors(loc, scale) + inputs.update(value.inputs) + int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real') + + log_prob = -0.5 * math.log(2 * math.pi) - scale.log() + loc = loc.unsqueeze(-1) + precision = scale.pow(-2).unsqueeze(-1).unsqueeze(-1) + return Tensor(log_prob, int_inputs) + Gaussian(loc, precision, inputs) + + +# Create a transformed Gaussian from a ground prior or ground likelihood. +@eager.register(Normal, Tensor, Tensor, Funsor) +@eager.register(Normal, Funsor, Tensor, Tensor) +def eager_normal(loc, scale, value): + if not isinstance(loc, Tensor): + loc, value = value, loc + return Normal(loc, scale, 'value')(value=value) + + +@eager.register(Normal, (Variable, Affine), Tensor, (Variable, Affine)) +@eager.register(Normal, (Variable, Affine), Tensor, Tensor) +@eager.register(Normal, Tensor, Tensor, (Variable, Affine)) +def eager_normal(loc, scale, value): + affine = (loc - value) / scale + assert isinstance(affine, Affine) + real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items() if v.dtype == 'real') + assert not any(v.shape for v in real_inputs.values()) + + tensors = [affine.const] + [c for v, c in affine.coeffs.items()] + inputs, tensors = align_tensors(*tensors) + tensors = torch.broadcast_tensors(*tensors) + const, coeffs = tensors[0], tensors[1:] + + dim = sum(d.num_elements for d in real_inputs.values()) + loc = BlockVector(const.shape + (dim,)) + loc[..., 0] = -const / coeffs[0] + precision = BlockMatrix(const.shape + (dim, dim)) + for i, (v1, c1) in enumerate(zip(real_inputs, coeffs)): + for j, (v2, c2) in enumerate(zip(real_inputs, coeffs)): + precision[..., i, j] = c1 * c2 + loc = loc.as_tensor() + precision = precision.as_tensor() + + log_prob = -0.5 * math.log(2 * math.pi) - scale.log() + return log_prob + Gaussian(loc, precision, affine.inputs) + + +class MultivariateNormal(Distribution): + dist_class = dist.MultivariateNormal + + @staticmethod + def _fill_defaults(loc, scale_tril, value='value'): + loc = to_funsor(loc) + scale_tril = to_funsor(scale_tril) + assert loc.dtype == 'real' + assert scale_tril.dtype == 'real' + assert len(loc.output.shape) == 1 + dim = loc.output.shape[0] + assert scale_tril.output.shape == (dim, dim) + value = to_funsor(value, loc.output) + return loc, scale_tril, value + + def __init__(self, loc, scale_tril, value='value'): + super(MultivariateNormal, self).__init__(loc, scale_tril, value) + + +@eager.register(MultivariateNormal, 
Tensor, Tensor, Tensor) +def eager_mvn(loc, scale_tril, value): + return MultivariateNormal.eager_log_prob(loc=loc, scale_tril=scale_tril, value=value) + + +# Create a Gaussian from a ground observation. +@eager.register(MultivariateNormal, Tensor, Tensor, Variable) +@eager.register(MultivariateNormal, Variable, Tensor, Tensor) +def eager_mvn(loc, scale_tril, value): + if isinstance(loc, Variable): + loc, value = value, loc + + dim, = loc.output.shape + inputs, (loc, scale_tril) = align_tensors(loc, scale_tril) + inputs.update(value.inputs) + int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real') + + log_prob = -0.5 * dim * math.log(2 * math.pi) - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1) + inv_scale_tril = torch.inverse(scale_tril) + precision = torch.matmul(inv_scale_tril.transpose(-1, -2), inv_scale_tril) + return Tensor(log_prob, int_inputs) + Gaussian(loc, precision, inputs) + + __all__ = [ + 'Bernoulli', + 'BernoulliLogits', + 'Beta', + 'Binomial', 'Categorical', + 'Delta', + 'Dirichlet', + 'DirichletMultinomial', 'Distribution', + 'LogNormal', + 'Multinomial', + 'MultivariateNormal', 'Normal', ] diff --git a/funsor/domains.py b/funsor/domains.py index 53846401e..e97391bfc 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -9,6 +9,7 @@ import funsor.ops as ops from funsor.util import lazy_property +import torch class Domain(namedtuple('Domain', ['shape', 'dtype'])): @@ -18,6 +19,9 @@ class Domain(namedtuple('Domain', ['shape', 'dtype'])): """ def __new__(cls, shape, dtype): assert isinstance(shape, tuple) + if torch._C._get_tracing_state(): + shape = tuple(map(int, shape)) + assert all(isinstance(size, integer_types) for size in shape), shape if isinstance(dtype, integer_types): assert not shape elif isinstance(dtype, str): @@ -58,6 +62,8 @@ def bint(size): """ Construct a bounded integer domain of scalar shape. 
""" + if torch._C._get_tracing_state(): + size = int(size) assert isinstance(size, integer_types) and size >= 0 return Domain((), size) @@ -71,12 +77,16 @@ def find_domain(op, *domains): assert callable(op), op assert all(isinstance(arg, Domain) for arg in domains) if len(domains) == 1: - return domains[0] + dtype = domains[0].dtype + shape = domains[0].shape + if op is ops.log or op is ops.exp: + dtype = 'real' + return Domain(shape, dtype) lhs, rhs = domains - if op is ops.getitem: + if isinstance(op, ops.GetitemOp): dtype = lhs.dtype - shape = lhs.shape[rhs.num_elements:] + shape = lhs.shape[:op.offset] + lhs.shape[1 + op.offset:] return Domain(shape, dtype) if lhs.dtype == 'real' or rhs.dtype == 'real': diff --git a/funsor/einsum.py b/funsor/einsum.py new file mode 100644 index 000000000..9f7ca18fd --- /dev/null +++ b/funsor/einsum.py @@ -0,0 +1,132 @@ +from __future__ import absolute_import, division, print_function + +from collections import OrderedDict + +import torch +from six import integer_types +from six.moves import reduce + +import funsor.ops as ops +from funsor.contract import Contract +from funsor.interpreter import interpretation, reinterpret +from funsor.optimizer import Finitary, apply_optimizer, optimize +from funsor.sum_product import sum_product +from funsor.terms import Funsor, reflect +from funsor.torch import Tensor + + +def _make_base_lhs(prod_op, arg, reduced_vars, normalized=False): + if not all(isinstance(d.dtype, integer_types) for d in arg.inputs.values()): + raise NotImplementedError("TODO implement continuous base lhss") + + if prod_op not in (ops.add, ops.mul): + raise NotImplementedError("{} not supported product op".format(prod_op)) + + make_unit = torch.ones if prod_op is ops.mul else torch.zeros + + sizes = OrderedDict(set((var, dtype) for var, dtype in arg.inputs.items())) + terms = tuple( + Tensor(make_unit((d.dtype,)) / float(d.dtype), OrderedDict([(var, d)])) + if normalized else + Tensor(make_unit((d.dtype,)), OrderedDict([(var, d)])) + for var, d in sizes.items() if var in reduced_vars + ) + return Finitary(prod_op, terms) if len(terms) > 1 else terms[0] + + +def naive_contract_einsum(eqn, *terms, **kwargs): + """ + Use for testing Contract against einsum + """ + assert "plates" not in kwargs + + backend = kwargs.pop('backend', 'torch') + if backend == 'torch': + sum_op, prod_op = ops.add, ops.mul + elif backend in ('pyro.ops.einsum.torch_log', 'pyro.ops.einsum.torch_marginal'): + sum_op, prod_op = ops.logaddexp, ops.add + else: + raise ValueError("{} backend not implemented".format(backend)) + + assert isinstance(eqn, str) + assert all(isinstance(term, Funsor) for term in terms) + inputs, output = eqn.split('->') + inputs = inputs.split(',') + assert len(inputs) == len(terms) + assert len(output.split(',')) == 1 + input_dims = frozenset(d for inp in inputs for d in inp) + output_dims = frozenset(d for d in output) + reduced_vars = input_dims - output_dims + + with interpretation(optimize): + rhs = Finitary(prod_op, tuple(terms)) + lhs = _make_base_lhs(prod_op, rhs, reduced_vars, normalized=False) + assert frozenset(lhs.inputs) == reduced_vars + result = Contract(sum_op, prod_op, lhs, rhs, reduced_vars) + + return reinterpret(result) + + +def naive_einsum(eqn, *terms, **kwargs): + backend = kwargs.pop('backend', 'torch') + if backend == 'torch': + sum_op, prod_op = ops.add, ops.mul + elif backend in ('pyro.ops.einsum.torch_log', 'pyro.ops.einsum.torch_marginal'): + sum_op, prod_op = ops.logaddexp, ops.add + else: + raise ValueError("{} backend not 
implemented".format(backend)) + + assert isinstance(eqn, str) + assert all(isinstance(term, Funsor) for term in terms) + inputs, output = eqn.split('->') + assert len(output.split(',')) == 1 + input_dims = frozenset(d for inp in inputs.split(',') for d in inp) + output_dims = frozenset(output) + reduce_dims = input_dims - output_dims + return reduce(prod_op, terms).reduce(sum_op, reduce_dims) + + +def naive_plated_einsum(eqn, *terms, **kwargs): + """ + Implements Tensor Variable Elimination (Algorithm 1 in [Obermeyer et al 2019]) + + [Obermeyer et al 2019] Obermeyer, F., Bingham, E., Jankowiak, M., Chiu, J., + Pradhan, N., Rush, A., and Goodman, N. Tensor Variable Elimination for + Plated Factor Graphs, 2019 + """ + plates = kwargs.pop('plates', '') + if not plates: + return naive_einsum(eqn, *terms, **kwargs) + + backend = kwargs.pop('backend', 'torch') + if backend == 'torch': + sum_op, prod_op = ops.add, ops.mul + elif backend in ('pyro.ops.einsum.torch_log', 'pyro.ops.einsum.torch_marginal'): + sum_op, prod_op = ops.logaddexp, ops.add + else: + raise ValueError("{} backend not implemented".format(backend)) + + assert isinstance(eqn, str) + assert all(isinstance(term, Funsor) for term in terms) + inputs, output = eqn.split('->') + inputs = inputs.split(',') + assert len(inputs) == len(terms) + assert len(output.split(',')) == 1 + input_dims = frozenset(d for inp in inputs for d in inp) + output_dims = frozenset(d for d in output) + plate_dims = frozenset(plates) - output_dims + reduce_vars = input_dims - output_dims - frozenset(plates) + + output_plates = output_dims & frozenset(plates) + if not all(output_plates.issubset(inp) for inp in inputs): + raise NotImplementedError("TODO") + + eliminate = plate_dims | reduce_vars + return sum_product(sum_op, prod_op, terms, eliminate, frozenset(plates)) + + +def einsum(eqn, *terms, **kwargs): + with interpretation(reflect): + naive_ast = naive_plated_einsum(eqn, *terms, **kwargs) + optimized_ast = apply_optimizer(naive_ast) + return reinterpret(optimized_ast) # eager by default diff --git a/funsor/gaussian.py b/funsor/gaussian.py new file mode 100644 index 000000000..ddd40e0dc --- /dev/null +++ b/funsor/gaussian.py @@ -0,0 +1,615 @@ +from __future__ import absolute_import, division, print_function + +import math +import warnings +from collections import OrderedDict, defaultdict + +import torch +from pyro.distributions.util import broadcast_shape +from six import add_metaclass, integer_types +from six.moves import reduce + +import funsor.ops as ops +from funsor.delta import Delta +from funsor.domains import reals +from funsor.integrate import Integrate, integrator +from funsor.ops import AddOp, NegOp, SubOp +from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Subs, Unary, Variable, eager, reflect, to_funsor +from funsor.torch import Tensor, align_tensor, align_tensors, materialize +from funsor.util import lazy_property + + +def _issubshape(subshape, supershape): + if len(subshape) > len(supershape): + return False + for sub, sup in zip(reversed(subshape), reversed(supershape)): + if sub not in (1, sup): + return False + return True + + +def _log_det_tri(x): + return x.diagonal(dim1=-1, dim2=-2).log().sum(-1) + + +def _det_tri(x): + return x.diagonal(dim1=-1, dim2=-2).prod(-1) + + +def _mv(mat, vec): + return torch.matmul(mat, vec.unsqueeze(-1)).squeeze(-1) + + +def _vmv(mat, vec): + """ + Computes the inner product ````. 
+ """ + vt = vec.unsqueeze(-2) + v = vec.unsqueeze(-1) + result = torch.matmul(vt, torch.matmul(mat, v)) + return result.squeeze(-1).squeeze(-1) + + +def _trace_mm(x, y): + """ + Computes ``trace(x @ y)``. + """ + assert x.dim() >= 2 + assert y.dim() >= 2 + xy = x * y + return xy.reshape(xy.shape[:-2] + (-1,)).sum(-1) + + +def sym_inverse(mat): + r""" + Computes ``inverse(mat)`` assuming mat is symmetric and usually positive + definite, but falling back to general pseudoinverse if positive + definiteness fails. + """ + try: + # Attempt to use stable positive definite math. + tri = torch.inverse(torch.cholesky(mat)) + return torch.matmul(tri.transpose(-1, -2), tri) + except RuntimeError as e: + warnings.warn(e.message, RuntimeWarning) + + # Try masked reciprocal. + if mat.size(-1) == 1: + result = mat.reciprocal() + result[(mat != 0) == 0] = 0 + return result + + # Fall back to pseudoinverse. + return torch.pinverse(mat) + + +def sym_solve_mv(mat, vec): + r""" + Computes ``mat \ vec`` assuming mat is symmetric and usually positive definite, + but falling back to general pseudoinverse if positive definiteness fails. + """ + try: + # Attempt to use stable positive definite math. + tri = torch.inverse(torch.cholesky(mat)) + return _mv(tri.transpose(-1, -2), _mv(tri, vec)) + except RuntimeError as e: + warnings.warn(e.message, RuntimeWarning) + + # Fall back to pseudoinverse. + if mat.size(-1) == 1: + mat = mat.squeeze(-1) + mat, vec = torch.broadcast_tensors(mat, vec) + result = vec / mat + result[(mat != 0) == 0] = 0 + return result + return _mv(torch.pinverse(mat), vec) + + +def _compute_offsets(inputs): + """ + Compute offsets of real inputs into the concatenated Gaussian dims. + This ignores all int inputs. + + :param OrderedDict inputs: A schema mapping variable name to domain. + :return: a pair ``(offsets, total)``. + :rtype: tuple + """ + assert isinstance(inputs, OrderedDict) + offsets = {} + total = 0 + for key, domain in inputs.items(): + if domain.dtype == 'real': + offsets[key] = total + total += domain.num_elements + return offsets, total + + +def _find_gaps(intervals, end): + intervals = list(sorted(intervals)) + stops = [0] + [stop for start, stop in intervals] + starts = [start for start, stop in intervals] + [end] + return [(stop, start) for stop, start in zip(stops, starts) if stop != start] + + +def _parse_slices(index, value): + if not isinstance(index, tuple): + index = (index,) + if index[0] is Ellipsis: + index = index[1:] + start_stops = [] + for pos, i in reversed(list(enumerate(index))): + if isinstance(i, slice): + start_stops.append((i.start, i.stop)) + elif isinstance(i, integer_types): + start_stops.append((i, i + 1)) + value = value.unsqueeze(pos - len(index)) + else: + raise ValueError("invalid index: {}".format(i)) + start_stops.reverse() + return start_stops, value + + +class BlockVector(object): + """ + Jit-compatible helper to build blockwise vectors. + Syntax is similar to :func:`torch.zeros` :: + + x = BlockVector((100, 20)) + x[..., 0:4] = x1 + x[..., 6:10] = x2 + x = x.as_tensor() + assert x.shape == (100, 20) + """ + def __init__(self, shape): + self.shape = shape + self.parts = {} + + def __setitem__(self, index, value): + (i,), value = _parse_slices(index, value) + self.parts[i] = value + + def as_tensor(self): + # Fill gaps with zeros. 
+ prototype = next(iter(self.parts.values())) + options = dict(dtype=prototype.dtype, device=prototype.device) + for i in _find_gaps(self.parts.keys(), self.shape[-1]): + self.parts[i] = torch.zeros(self.shape[:-1] + (i[1] - i[0],), **options) + + # Concatenate parts. + parts = [v for k, v in sorted(self.parts.items())] + result = torch.cat(parts, dim=-1) + if not torch._C._get_tracing_state(): + assert result.shape == self.shape + return result + + +class BlockMatrix(object): + """ + Jit-compatible helper to build blockwise matrices. + Syntax is similar to :func:`torch.zeros` :: + + x = BlockMatrix((100, 20, 20)) + x[..., 0:4, 0:4] = x11 + x[..., 0:4, 6:10] = x12 + x[..., 6:10, 0:4] = x12.transpose(-1, -2) + x[..., 6:10, 6:10] = x22 + x = x.as_tensor() + assert x.shape == (100, 20, 20) + """ + def __init__(self, shape): + self.shape = shape + self.parts = defaultdict(dict) + + def __setitem__(self, index, value): + (i, j), value = _parse_slices(index, value) + self.parts[i][j] = value + + def as_tensor(self): + # Fill gaps with zeros. + arbitrary_row = next(iter(self.parts.values())) + prototype = next(iter(arbitrary_row.values())) + options = dict(dtype=prototype.dtype, device=prototype.device) + i_gaps = _find_gaps(self.parts.keys(), self.shape[-2]) + j_gaps = _find_gaps(arbitrary_row.keys(), self.shape[-1]) + rows = set().union(i_gaps, self.parts) + cols = set().union(j_gaps, arbitrary_row) + for i in rows: + for j in cols: + if j not in self.parts[i]: + shape = self.shape[:-2] + (i[1] - i[0], j[1] - j[0]) + self.parts[i][j] = torch.zeros(shape, **options) + + # Concatenate parts. + columns = {i: torch.cat([v for j, v in sorted(part.items())], dim=-1) + for i, part in self.parts.items()} + result = torch.cat([v for i, v in sorted(columns.items())], dim=-2) + if not torch._C._get_tracing_state(): + assert result.shape == self.shape + return result + + +def align_gaussian(new_inputs, old): + """ + Align data of a Gaussian distribution to a new ``inputs`` shape. + """ + assert isinstance(new_inputs, OrderedDict) + assert isinstance(old, Gaussian) + loc = old.loc + precision = old.precision + + # Align int inputs. + # Since these are are managed as in Tensor, we can defer to align_tensor(). + new_ints = OrderedDict((k, d) for k, d in new_inputs.items() if d.dtype != 'real') + old_ints = OrderedDict((k, d) for k, d in old.inputs.items() if d.dtype != 'real') + if new_ints != old_ints: + loc = align_tensor(new_ints, Tensor(loc, old_ints)) + precision = align_tensor(new_ints, Tensor(precision, old_ints)) + + # Align real inputs, which are all concatenated in the rightmost dims. 
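+ # Offsets are recomputed for both layouts; any real input present only in the
+ # new layout ends up as a zero block via BlockVector/BlockMatrix.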
+ new_offsets, new_dim = _compute_offsets(new_inputs) + old_offsets, old_dim = _compute_offsets(old.inputs) + assert loc.shape[-1:] == (old_dim,) + assert precision.shape[-2:] == (old_dim, old_dim) + if new_offsets != old_offsets: + old_loc = loc + old_precision = precision + loc = BlockVector(old_loc.shape[:-1] + (new_dim,)) + precision = BlockMatrix(old_loc.shape[:-1] + (new_dim, new_dim)) + for k1, new_offset1 in new_offsets.items(): + if k1 not in old_offsets: + continue + offset1 = old_offsets[k1] + num_elements1 = old.inputs[k1].num_elements + old_slice1 = slice(offset1, offset1 + num_elements1) + new_slice1 = slice(new_offset1, new_offset1 + num_elements1) + loc[..., new_slice1] = old_loc[..., old_slice1] + for k2, new_offset2 in new_offsets.items(): + if k2 not in old_offsets: + continue + offset2 = old_offsets[k2] + num_elements2 = old.inputs[k2].num_elements + old_slice2 = slice(offset2, offset2 + num_elements2) + new_slice2 = slice(new_offset2, new_offset2 + num_elements2) + precision[..., new_slice1, new_slice2] = old_precision[..., old_slice1, old_slice2] + loc = loc.as_tensor() + precision = precision.as_tensor() + + return loc, precision + + +class GaussianMeta(FunsorMeta): + """ + Wrapper to convert between OrderedDict and tuple. + """ + def __call__(cls, loc, precision, inputs): + if isinstance(inputs, OrderedDict): + inputs = tuple(inputs.items()) + assert isinstance(inputs, tuple) + return super(GaussianMeta, cls).__call__(loc, precision, inputs) + + +@add_metaclass(GaussianMeta) +class Gaussian(Funsor): + """ + Funsor representing a batched joint Gaussian distribution as a log-density + function. + + Note that :class:`Gaussian` s are not normalized, rather they are + canonicalized to evaluate to zero at their maximum value (at ``loc``). This + canonical form is useful because it allows :class:`Gaussian` s with + incomplete information, i.e. zero eigenvalues in the precision matrix. + These incomplete distributions arise when making low-dimensional + observations on higher dimensional hidden state. + """ + def __init__(self, loc, precision, inputs): + assert isinstance(loc, torch.Tensor) + assert isinstance(precision, torch.Tensor) + assert isinstance(inputs, tuple) + inputs = OrderedDict(inputs) + + # Compute total dimension of all real inputs. + dim = sum(d.num_elements for d in inputs.values() if d.dtype == 'real') + if not torch._C._get_tracing_state(): + assert dim + assert loc.dim() >= 1 and loc.size(-1) == dim + assert precision.dim() >= 2 and precision.shape[-2:] == (dim, dim) + + # Compute total shape of all bint inputs. 
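+ # loc and precision need only broadcast against this batch shape, hence the
+ # subshape checks below rather than exact shape equality.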
+ batch_shape = tuple(d.dtype for d in inputs.values() + if isinstance(d.dtype, integer_types)) + if not torch._C._get_tracing_state(): + assert _issubshape(loc.shape, batch_shape + (dim,)) + assert _issubshape(precision.shape, batch_shape + (dim, dim)) + + output = reals() + fresh = frozenset(inputs.keys()) + bound = frozenset() + super(Gaussian, self).__init__(inputs, output, fresh, bound) + self.loc = loc + self.precision = precision + self.batch_shape = batch_shape + self.event_shape = (dim,) + + def __repr__(self): + return 'Gaussian(..., ({}))'.format(' '.join( + '({}, {}),'.format(*kv) for kv in self.inputs.items())) + + def align(self, names): + assert isinstance(names, tuple) + assert all(name in self.inputs for name in names) + if not names or names == tuple(self.inputs): + return self + + inputs = OrderedDict((name, self.inputs[name]) for name in names) + inputs.update(self.inputs) + loc, precision = align_gaussian(inputs, self) + return Gaussian(loc, precision, inputs) + + def eager_subs(self, subs): + assert isinstance(subs, tuple) + subs = tuple((k, materialize(to_funsor(v, self.inputs[k]))) + for k, v in subs if k in self.inputs) + if not subs: + return self + + # Constants and Variables are eagerly substituted; + # everything else is lazily substituted. + lazy_subs = tuple((k, v) for k, v in subs + if not isinstance(v, (Number, Tensor, Variable))) + var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable)) + int_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor)) + if v.dtype != 'real') + real_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor)) + if v.dtype == 'real') + if not (var_subs or int_subs or real_subs): + return reflect(Subs, self, lazy_subs) + + # First perform any variable substitutions. + if var_subs: + rename = {k: v.name for k, v in var_subs} + inputs = OrderedDict((rename.get(k, k), d) for k, d in self.inputs.items()) + if len(inputs) != len(self.inputs): + raise ValueError("Variable substitution name conflict") + var_result = Gaussian(self.loc, self.precision, inputs) + return Subs(var_result, int_subs + real_subs + lazy_subs) + + # Next perform any integer substitution, i.e. slicing into a batch. + if int_subs: + int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real') + real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real') + tensors = [self.loc, self.precision] + funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors] + inputs = funsors[0].inputs.copy() + inputs.update(real_inputs) + int_result = Gaussian(funsors[0].data, funsors[1].data, inputs) + return Subs(int_result, real_subs + lazy_subs) + + # Try to perform a complete substitution of all real variables, resulting in a Tensor. + real_subs = OrderedDict(subs) + assert real_subs and not int_subs + if all(k in real_subs for k, d in self.inputs.items() if d.dtype == 'real'): + # Broadcast all component tensors. + int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real') + tensors = [Tensor(self.loc, int_inputs), + Tensor(self.precision, int_inputs)] + tensors.extend(real_subs.values()) + inputs, tensors = align_tensors(*tensors) + batch_dim = tensors[0].dim() - 1 + batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors)) + (loc, precision), values = tensors[:2], tensors[2:] + + # Form the concatenated value. 
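+            # The substituted real values are packed into one flat vector in input
+            # order, mirroring how loc and precision are laid out; e.g. (hypothetically)
+            # x:reals(2) and y:reals(3) occupy slices 0:2 and 2:5 of a length-5 vector.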
+ offsets, event_size = _compute_offsets(self.inputs) + value = BlockVector(batch_shape + (event_size,)) + for k, value_k in zip(real_subs, values): + offset = offsets[k] + value_k = value_k.reshape(value_k.shape[:batch_dim] + (-1,)) + if not torch._C._get_tracing_state(): + assert value_k.size(-1) == self.inputs[k].num_elements + value_k = value_k.expand(batch_shape + value_k.shape[-1:]) + value[..., offset: offset + self.inputs[k].num_elements] = value_k + value = value.as_tensor() + + # Evaluate the non-normalized log density. + result = -0.5 * _vmv(precision, value - loc) + result = Tensor(result, inputs) + assert result.output == reals() + return Subs(result, lazy_subs) + + # Perform a partial substution of a subset of real variables, resulting in a Joint. + # See "The Matrix Cookbook" (November 15, 2012) ss. 8.1.3 eq. 353. + # http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf + raise NotImplementedError('TODO implement partial substitution of real variables') + + @lazy_property + def _log_normalizer(self): + dim = self.loc.size(-1) + log_det_term = _log_det_tri(torch.cholesky(self.precision)) + data = -log_det_term + 0.5 * math.log(2 * math.pi) * dim + inputs = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype != 'real') + return Tensor(data, inputs) + + def eager_reduce(self, op, reduced_vars): + if op is ops.logaddexp: + # Marginalize out real variables, but keep mixtures lazy. + assert all(v in self.inputs for v in reduced_vars) + real_vars = frozenset(k for k, d in self.inputs.items() if d.dtype == "real") + reduced_reals = reduced_vars & real_vars + reduced_ints = reduced_vars - real_vars + if not reduced_reals: + return None # defer to default implementation + + inputs = OrderedDict((k, d) for k, d in self.inputs.items() if k not in reduced_reals) + if reduced_reals == real_vars: + result = self._log_normalizer + else: + int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real') + offsets, _ = _compute_offsets(self.inputs) + index = [] + for key, domain in inputs.items(): + if domain.dtype == 'real': + index.extend(range(offsets[key], offsets[key] + domain.num_elements)) + index = torch.tensor(index) + + loc = self.loc[..., index] + self_scale_tri = torch.inverse(torch.cholesky(self.precision)).transpose(-1, -2) + self_covariance = torch.matmul(self_scale_tri, self_scale_tri.transpose(-1, -2)) + covariance = self_covariance[..., index.unsqueeze(-1), index] + scale_tri = torch.cholesky(covariance) + inv_scale_tri = torch.inverse(scale_tri) + precision = torch.matmul(inv_scale_tri.transpose(-1, -2), inv_scale_tri) + reduced_dim = sum(self.inputs[k].num_elements for k in reduced_reals) + log_det_term = _log_det_tri(self_scale_tri) - _log_det_tri(scale_tri) + log_prob = Tensor(log_det_term + 0.5 * math.log(2 * math.pi) * reduced_dim, int_inputs) + result = log_prob + Gaussian(loc, precision, inputs) + + return result.reduce(ops.logaddexp, reduced_ints) + + elif op is ops.add: + for v in reduced_vars: + if self.inputs[v].dtype == 'real': + raise ValueError("Cannot sum along a real dimension: {}".format(repr(v))) + + # Fuse Gaussians along a plate. Compare to eager_add_gaussian_gaussian(). 
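+            # In canonical form, summing over a plate i is a product of Gaussians:
+            #   precision = sum_i precision_i
+            #   loc = precision^{-1} @ sum_i (precision_i @ loc_i)
+            # and the leftover quadratic mismatch at the new loc is returned separately
+            # as `likelihood` so the result stays canonicalized (zero at its mode).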
+ old_ints = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype != 'real') + new_ints = OrderedDict((k, v) for k, v in old_ints.items() if k not in reduced_vars) + inputs = OrderedDict((k, v) for k, v in self.inputs.items() if k not in reduced_vars) + + precision = Tensor(self.precision, old_ints).reduce(ops.add, reduced_vars) + precision_loc = Tensor(_mv(self.precision, self.loc), + old_ints).reduce(ops.add, reduced_vars) + assert precision.inputs == new_ints + assert precision_loc.inputs == new_ints + loc = Tensor(sym_solve_mv(precision.data, precision_loc.data), new_ints) + expanded_loc = align_tensor(old_ints, loc) + quadratic_term = Tensor(_vmv(self.precision, expanded_loc - self.loc), + old_ints).reduce(ops.add, reduced_vars) + assert quadratic_term.inputs == new_ints + likelihood = -0.5 * quadratic_term + return likelihood + Gaussian(loc.data, precision.data, inputs) + + return None # defer to default implementation + + def unscaled_sample(self, sampled_vars, sample_inputs): + # Sample only the real variables. + sampled_vars = frozenset(k for k, v in self.inputs.items() + if k in sampled_vars if v.dtype == 'real') + if not sampled_vars: + return self + + # Partition inputs into sample_inputs + int_inputs + real_inputs. + sample_inputs = OrderedDict((k, d) for k, d in sample_inputs.items() + if k not in self.inputs) + sample_shape = tuple(int(d.dtype) for d in sample_inputs.values()) + int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real') + real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real') + inputs = sample_inputs.copy() + inputs.update(int_inputs) + + if sampled_vars == frozenset(real_inputs): + scale_tri = torch.inverse(torch.cholesky(self.precision)).transpose(-1, -2) + if not torch._C._get_tracing_state(): + assert self.loc.shape == scale_tri.shape[:-1] + shape = sample_shape + self.loc.shape + white_noise = torch.randn(shape) + sample = self.loc + _mv(scale_tri, white_noise) + offsets, _ = _compute_offsets(real_inputs) + results = [] + for key, domain in real_inputs.items(): + data = sample[..., offsets[key]: offsets[key] + domain.num_elements] + data = data.reshape(shape[:-1] + domain.shape) + point = Tensor(data, inputs) + assert point.output == domain + results.append(Delta(key, point)) + results.append(self._log_normalizer) + return reduce(ops.add, results) + + raise NotImplementedError('TODO implement partial sampling of real variables') + + +@eager.register(Binary, AddOp, Gaussian, Gaussian) +def eager_add_gaussian_gaussian(op, lhs, rhs): + # Fuse two Gaussians by adding their log-densities pointwise. + # This is similar to a Kalman filter update, but also keeps track of + # the marginal likelihood which accumulates into a Tensor. + + # Align data. + inputs = lhs.inputs.copy() + inputs.update(rhs.inputs) + int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real') + lhs_loc, lhs_precision = align_gaussian(inputs, lhs) + rhs_loc, rhs_precision = align_gaussian(inputs, rhs) + + # Fuse aligned Gaussians. 
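+    # Information-form (precision-weighted) fusion, as in a Kalman update:
+    #   precision = lhs_precision + rhs_precision
+    #   precision @ loc = lhs_precision @ lhs_loc + rhs_precision @ rhs_loc
+    # The remaining quadratic term is the value of the summed log-densities at the
+    # new mode and is returned as a Tensor so the Gaussian part stays canonical.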
+ precision_loc = _mv(lhs_precision, lhs_loc) + _mv(rhs_precision, rhs_loc) + precision = lhs_precision + rhs_precision + loc = sym_solve_mv(precision, precision_loc) + quadratic_term = _vmv(lhs_precision, loc - lhs_loc) + _vmv(rhs_precision, loc - rhs_loc) + likelihood = Tensor(-0.5 * quadratic_term, int_inputs) + return likelihood + Gaussian(loc, precision, inputs) + + +@eager.register(Binary, SubOp, Gaussian, (Funsor, Align, Gaussian)) +@eager.register(Binary, SubOp, (Funsor, Align), Gaussian) +def eager_sub(op, lhs, rhs): + return lhs + -rhs + + +@eager.register(Unary, NegOp, Gaussian) +def eager_neg(op, arg): + precision = -arg.precision + return Gaussian(arg.loc, precision, arg.inputs) + + +@eager.register(Integrate, Gaussian, Variable, frozenset) +@integrator +def eager_integrate(log_measure, integrand, reduced_vars): + real_vars = frozenset(k for k in reduced_vars if log_measure.inputs[k].dtype == 'real') + if real_vars: + assert real_vars == frozenset([integrand.name]) + data = log_measure.loc * log_measure._log_normalizer.data.exp().unsqueeze(-1) + data = data.reshape(log_measure.loc.shape[:-1] + integrand.output.shape) + inputs = OrderedDict((k, d) for k, d in log_measure.inputs.items() if d.dtype != 'real') + return Tensor(data, inputs) + + return None # defer to default implementation + + +@eager.register(Integrate, Gaussian, Gaussian, frozenset) +@integrator +def eager_integrate(log_measure, integrand, reduced_vars): + real_vars = frozenset(k for k in reduced_vars if log_measure.inputs[k].dtype == 'real') + if real_vars: + + lhs_reals = frozenset(k for k, d in log_measure.inputs.items() if d.dtype == 'real') + rhs_reals = frozenset(k for k, d in integrand.inputs.items() if d.dtype == 'real') + if lhs_reals == real_vars and rhs_reals <= real_vars: + inputs = OrderedDict((k, d) for t in (log_measure, integrand) + for k, d in t.inputs.items()) + lhs_loc, lhs_precision = align_gaussian(inputs, log_measure) + rhs_loc, rhs_precision = align_gaussian(inputs, integrand) + + # Compute the expectation of a non-normalized quadratic form. + # See "The Matrix Cookbook" (November 15, 2012) ss. 8.2.2 eq. 380. 
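+            # A sketch of the algebra used below: for x ~ N(m, S) carrying total mass
+            # Z = _det_tri(lhs_scale_tri) * (2*pi)**(dim/2), eq. 380 gives
+            #   E[(x - b)' A (x - b)] = (m - b)' A (m - b) + tr(A S),
+            # so the integral of the canonical-form integrand evaluates to
+            #   -0.5 * Z * (_vmv(rhs_precision, lhs_loc - rhs_loc) + _trace_mm(rhs_precision, lhs_covariance)).
+            # Reference: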
+ # http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf + lhs_scale_tri = torch.inverse(torch.cholesky(lhs_precision)).transpose(-1, -2) + lhs_covariance = torch.matmul(lhs_scale_tri, lhs_scale_tri.transpose(-1, -2)) + dim = lhs_loc.size(-1) + norm = _det_tri(lhs_scale_tri) * (2 * math.pi) ** (0.5 * dim) + data = -0.5 * norm * (_vmv(rhs_precision, lhs_loc - rhs_loc) + + _trace_mm(rhs_precision, lhs_covariance)) + inputs = OrderedDict((k, d) for k, d in inputs.items() if k not in reduced_vars) + result = Tensor(data, inputs) + return result.reduce(ops.add, reduced_vars - real_vars) + + raise NotImplementedError('TODO implement partial integration') + + return None # defer to default implementation + + +__all__ = [ + 'BlockMatrix', + 'BlockVector', + 'Gaussian', + 'align_gaussian', +] diff --git a/funsor/handlers.py b/funsor/handlers.py deleted file mode 100644 index 2e58ace6b..000000000 --- a/funsor/handlers.py +++ /dev/null @@ -1,144 +0,0 @@ -from __future__ import absolute_import, division, print_function - -import functools - -from multipledispatch import Dispatcher, dispatch -from six import add_metaclass - - -class Message(dict): - # TODO use defaultdict - - _fields = ("name", "fn", "args", "kwargs", "value", "stop") - - def __init__(self, **fields): - super(Message, self).__init__(**fields) - for field in self._fields: - if field not in self: - self[field] = None - - -class FunsorOp(Message): - pass - - -HANDLER_STACK = [] -STACK_POINTER = {"ptr": -1} - - -def set_default_handlers(*args): - assert not args or all(isinstance(arg, Handler) for arg in args) - while HANDLER_STACK: - HANDLER_STACK[-1].__exit__(None, None, None) - for arg in args: - arg.__enter__() - - -class Handler(object): - def __init__(self, fn=None): - self.fn = fn - - def __enter__(self): - HANDLER_STACK.append(self) - - def __exit__(self, exc_type, exc_value, traceback): - if exc_type is None: - assert HANDLER_STACK[-1] is self - HANDLER_STACK.pop() - else: - if self in HANDLER_STACK: - loc = HANDLER_STACK.index(self) - for i in range(loc, len(HANDLER_STACK)): - HANDLER_STACK.pop() - - @dispatch(Message) - def process(self, msg): - return msg - - @dispatch(Message) - def postprocess(self, msg): - return msg - - def __call__(self, *args, **kwargs): - with self: - return self.fn(*args, **kwargs) - - -class OpRegistryMeta(type): - def __init__(cls, name, bases, dct): - super(OpRegistryMeta, cls).__init__(name, bases, dct) - cls.dispatcher = Dispatcher(cls.__name__) - - -@add_metaclass(OpRegistryMeta) -class OpRegistry(Handler): - """ - Handler with convenient op registry functionality - """ - - @dispatch(object) - def process(self, msg): - return super(OpRegistry, self).process(msg) - - @dispatch(FunsorOp) - def process(self, msg): - impl = self.dispatcher.dispatch(msg["label"]) - if impl is not None: - msg["value"] = impl(*msg["args"], **msg["kwargs"]) - return msg - - @classmethod - def register(cls, *term_types, **kwargs): - return cls.dispatcher.register(tuple(term_types)) - - -def apply_stack(msg): - - for pointer, handler in enumerate(reversed(HANDLER_STACK)): - STACK_POINTER["ptr"] -= 1 - handler.process(msg) - if msg["stop"]: - break - - if msg["value"] is None: - msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"]) - - for handler in HANDLER_STACK[-pointer-1:]: - STACK_POINTER["ptr"] += 1 - handler.postprocess(msg) - - return msg - - -def effectful(term_type, fn=None): - - if fn is None: - return functools.partial(effectful, term_type) - - term_label = None - if not issubclass(term_type, Message): - # 
XXX hack to make OpRegistry work - term_label = term_type - term_type = FunsorOp - - assert issubclass(term_type, Message) - - def _fn(*args, **kwargs): - - if not HANDLER_STACK: - value = fn(*args, **kwargs) - else: - initial_msg = term_type( - name=kwargs.pop("name", None), - fn=fn, - args=args, - kwargs=kwargs, - value=None, - label=term_label, - ) - - value = apply_stack(initial_msg)["value"] - - return value - - return _fn diff --git a/funsor/integrate.py b/funsor/integrate.py new file mode 100644 index 000000000..75c4fe04a --- /dev/null +++ b/funsor/integrate.py @@ -0,0 +1,84 @@ +from __future__ import absolute_import, division, print_function + +import functools +from collections import OrderedDict + +import funsor.interpreter as interpreter +import funsor.ops as ops +from funsor.contract import Contract +from funsor.terms import Funsor, Reduce, eager + + +class Integrate(Funsor): + """ + Funsor representing an integral wrt a log density funsor. + """ + def __init__(self, log_measure, integrand, reduced_vars): + assert isinstance(log_measure, Funsor) + assert isinstance(integrand, Funsor) + assert isinstance(reduced_vars, frozenset) + inputs = OrderedDict((k, d) for term in (log_measure, integrand) + for (k, d) in term.inputs.items() + if k not in reduced_vars) + output = integrand.output + fresh = frozenset() + bound = reduced_vars + super(Integrate, self).__init__(inputs, output, fresh, bound) + self.log_measure = log_measure + self.integrand = integrand + self.reduced_vars = reduced_vars + + +def _simplify_integrate(fn, log_measure, integrand, reduced_vars): + """ + Reduce free variables that do not appear in both inputs. + """ + if not reduced_vars: + return log_measure.exp() * integrand + + log_measure_vars = frozenset(log_measure.inputs) + integrand_vars = frozenset(integrand.inputs) + assert reduced_vars <= log_measure_vars | integrand_vars + progress = False + if not reduced_vars <= log_measure_vars: + integrand = integrand.reduce(ops.add, reduced_vars - log_measure_vars) + reduced_vars = reduced_vars & log_measure_vars + progress = True + if not reduced_vars <= integrand_vars: + log_measure = log_measure.reduce(ops.logaddexp, reduced_vars - integrand_vars) + reduced_vars = reduced_vars & integrand_vars + progress = True + if progress: + return Integrate(log_measure, integrand, reduced_vars) + + return fn(log_measure, integrand, reduced_vars) + + +def integrator(fn): + """ + Decorator for integration implementations. 
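+    Implementations take ``(log_measure, integrand, reduced_vars)`` and are wrapped
+    by ``_simplify_integrate`` above, so variables appearing in only one of the two
+    arguments are reduced away before the wrapped implementation is called.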
+ """ + fn = interpreter.debug_logged(fn) + return functools.partial(_simplify_integrate, fn) + + +@eager.register(Integrate, Funsor, Funsor, frozenset) +@integrator +def eager_integrate(log_measure, integrand, reduced_vars): + return Contract(ops.add, ops.mul, log_measure.exp(), integrand, reduced_vars) + + +@eager.register(Integrate, Reduce, Funsor, frozenset) +@integrator +def eager_integrate(log_measure, integrand, reduced_vars): + if log_measure.op is ops.logaddexp: + arg = Integrate(log_measure.arg, integrand, reduced_vars) + return arg.reduce(ops.add, log_measure.reduced_vars) + + return Contract(ops.add, ops.mul, log_measure.exp(), integrand, reduced_vars) + + +__all__ = [ + 'Integrate', + 'integrator', +] diff --git a/funsor/interpreter.py b/funsor/interpreter.py index 3d00885a5..442bb882b 100644 --- a/funsor/interpreter.py +++ b/funsor/interpreter.py @@ -1,19 +1,53 @@ from __future__ import absolute_import, division, print_function +import functools +import inspect +import os +import re import types from collections import OrderedDict +import numpy import torch from contextlib2 import contextmanager from funsor.domains import Domain +from funsor.ops import Op +from funsor.registry import KeyedRegistry from funsor.six import singledispatch +_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +_DEBUG = int(os.environ.get("FUNSOR_DEBUG", 0)) +_STACK_SIZE = 0 + _INTERPRETATION = None # To be set later in funsor.terms +_USE_TCO = int(os.environ.get("FUNSOR_USE_TCO", 0)) + +_GENSYM_COUNTER = 0 + + +if _DEBUG: + def interpret(cls, *args): + global _STACK_SIZE + indent = ' ' * _STACK_SIZE + typenames = [cls.__name__] + [type(arg).__name__ for arg in args] + print(indent + ' '.join(typenames)) + _STACK_SIZE += 1 + try: + result = _INTERPRETATION(cls, *args) + finally: + _STACK_SIZE -= 1 -def interpret(cls, *args): - return _INTERPRETATION(cls, *args) + if _DEBUG > 1: + result_str = re.sub('\n', '\n ' + indent, str(result)) + else: + result_str = type(result).__name__ + print(indent + '-> ' + result_str) + return result +else: + def interpret(cls, *args): + return _INTERPRETATION(cls, *args) def set_interpretation(new): @@ -35,15 +69,16 @@ def interpretation(new): @singledispatch -def reinterpret(x): +def recursion_reinterpret(x): r""" Overloaded reinterpretation of a deferred expression. + This interpreter uses the Python stack and is subject to the recursion limit. This handles a limited class of expressions, raising ``ValueError`` in unhandled cases. :param x: An input, typically involving deferred - :class:`~funsor.terms.Funsor`s. + :class:`~funsor.terms.Funsor` s. :type x: A funsor or data structure holding funsors. :return: A reinterpreted version of the input. :raises: ValueError @@ -54,42 +89,228 @@ def reinterpret(x): # We need to register this later in terms.py after declaring Funsor. 
# reinterpret.register(Funsor) def reinterpret_funsor(x): - return _INTERPRETATION(type(x), *map(reinterpret, x._ast_values)) - - -@reinterpret.register(str) -@reinterpret.register(int) -@reinterpret.register(float) -@reinterpret.register(type) -@reinterpret.register(types.FunctionType) -@reinterpret.register(types.BuiltinFunctionType) -@reinterpret.register(torch.Tensor) -@reinterpret.register(Domain) -def _reinterpret_ground(x): + return _INTERPRETATION(type(x), *map(recursion_reinterpret, x._ast_values)) + + +@recursion_reinterpret.register(str) +@recursion_reinterpret.register(int) +@recursion_reinterpret.register(float) +@recursion_reinterpret.register(type) +@recursion_reinterpret.register(functools.partial) +@recursion_reinterpret.register(types.FunctionType) +@recursion_reinterpret.register(types.BuiltinFunctionType) +@recursion_reinterpret.register(numpy.ndarray) +@recursion_reinterpret.register(torch.Tensor) +@recursion_reinterpret.register(torch.nn.Module) +@recursion_reinterpret.register(Domain) +@recursion_reinterpret.register(Op) +def recursion_reinterpret_ground(x): return x -@reinterpret.register(tuple) -def _reinterpret_tuple(x): - return tuple(map(reinterpret, x)) +@recursion_reinterpret.register(tuple) +def recursion_reinterpret_tuple(x): + return tuple(map(recursion_reinterpret, x)) + + +@recursion_reinterpret.register(frozenset) +def recursion_reinterpret_frozenset(x): + return frozenset(map(recursion_reinterpret, x)) + + +@recursion_reinterpret.register(dict) +def recursion_reinterpret_dict(x): + return {key: recursion_reinterpret(value) for key, value in x.items()} + + +@recursion_reinterpret.register(OrderedDict) +def recursion_reinterpret_ordereddict(x): + return OrderedDict((key, recursion_reinterpret(value)) for key, value in x.items()) + + +@singledispatch +def children(x): + raise ValueError(type(x)) + + +# has to be registered in terms.py +def children_funsor(x): + return x._ast_values + + +@children.register(tuple) +@children.register(frozenset) +def _children_tuple(x): + return x + + +@children.register(dict) +@children.register(OrderedDict) +def _children_tuple(x): + return x.values() + + +@children.register(str) +@children.register(int) +@children.register(float) +@children.register(type) +@children.register(functools.partial) +@children.register(types.FunctionType) +@children.register(types.BuiltinFunctionType) +@children.register(numpy.ndarray) +@children.register(torch.Tensor) +@children.register(torch.nn.Module) +@children.register(Domain) +@children.register(Op) +def _children_ground(x): + return () + + +def is_atom(x): + if isinstance(x, (tuple, frozenset)) and not isinstance(x, Domain): + return len(x) == 0 or all(is_atom(c) for c in x) + return isinstance(x, ( + int, + str, + float, + type, + functools.partial, + types.FunctionType, + types.BuiltinFunctionType, + torch.Tensor, + torch.nn.Module, + numpy.ndarray, + Domain, + Op + )) + + +def gensym(x=None): + global _GENSYM_COUNTER + _GENSYM_COUNTER += 1 + sym = _GENSYM_COUNTER + if x is not None: + if isinstance(x, str): + return x + "_" + str(sym) + return id(x) + return "V" + str(sym) + +def stack_reinterpret(x): + r""" + Overloaded reinterpretation of a deferred expression. + This interpreter uses an explicit stack and no recursion but is much slower. + + This handles a limited class of expressions, raising + ``ValueError`` in unhandled cases. + + :param x: An input, typically involving deferred + :class:`~funsor.terms.Funsor` s. + :type x: A funsor or data structure holding funsors. 
+ :return: A reinterpreted version of the input. + :raises: ValueError + """ + x_name = gensym(x) + node_vars = {x_name: x} + node_names = {x: x_name} + env = {} + stack = [(x_name, x)] + parent_to_children = OrderedDict() + child_to_parents = OrderedDict() + while stack: + h_name, h = stack.pop(0) + parent_to_children[h_name] = [] + for c in children(h): + if c in node_names: + c_name = node_names[c] + else: + c_name = gensym(c) + node_names[c] = c_name + node_vars[c_name] = c + stack.append((c_name, c)) + parent_to_children.setdefault(h_name, []).append(c_name) + child_to_parents.setdefault(c_name, []).append(h_name) + + children_counts = OrderedDict((k, len(v)) for k, v in parent_to_children.items()) + leaves = [name for name, count in children_counts.items() if count == 0] + while leaves: + h_name = leaves.pop(0) + if h_name in child_to_parents: + for parent in child_to_parents[h_name]: + children_counts[parent] -= 1 + if children_counts[parent] == 0: + leaves.append(parent) -@reinterpret.register(frozenset) -def _reinterpret_frozenset(x): - return frozenset(map(reinterpret, x)) + h = node_vars[h_name] + if is_atom(h): + env[h_name] = h + elif isinstance(h, (tuple, frozenset)): + env[h_name] = type(h)( + env[c_name] for c_name in parent_to_children[h_name]) + else: + env[h_name] = _INTERPRETATION( + type(h), *(env[c_name] for c_name in parent_to_children[h_name])) + return env[x_name] -@reinterpret.register(dict) -def _reinterpret_dict(x): - return {key: reinterpret(value) for key, value in x.items()} +def reinterpret(x): + r""" + Overloaded reinterpretation of a deferred expression. -@reinterpret.register(OrderedDict) -def _reinterpret_ordereddict(x): - return OrderedDict((key, reinterpret(value)) for key, value in x.items()) + This handles a limited class of expressions, raising + ``ValueError`` in unhandled cases. + + :param x: An input, typically involving deferred + :class:`~funsor.terms.Funsor` s. + :type x: A funsor or data structure holding funsors. + :return: A reinterpreted version of the input. + :raises: ValueError + """ + if _USE_TCO: + return stack_reinterpret(x) + else: + return recursion_reinterpret(x) + + +if _DEBUG: + class DebugLogged(object): + def __init__(self, fn): + self.fn = fn + while isinstance(fn, functools.partial): + fn = fn.func + path = inspect.getabsfile(fn) + lineno = inspect.getsourcelines(fn)[1] + self._message = "{} file://{} {}".format(fn.__name__, path, lineno) + + def __call__(self, *args, **kwargs): + print(' ' * _STACK_SIZE + self._message) + return self.fn(*args, **kwargs) + + def debug_logged(fn): + if isinstance(fn, DebugLogged): + return fn + return DebugLogged(fn) +else: + def debug_logged(fn): + return fn + + +def dispatched_interpretation(fn): + """ + Decorator to create a dispatched interpretation function. 
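+    A rough usage sketch, modeled on ``eager`` in :mod:`funsor.terms`::
+
+        @dispatched_interpretation
+        def eager(cls, *args):
+            result = eager.dispatch(cls, *args)(*args)
+            if result is None:
+                result = reflect(cls, *args)
+            return result
+
+        @eager.register(Binary, AddOp, Tensor, Tensor)
+        def eager_add(op, lhs, rhs):
+            ...  # return a concrete Funsor, or None to defer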
+ """ + registry = KeyedRegistry(default=lambda *args: None) + if _DEBUG: + fn.register = lambda *args: lambda fn: registry.register(*args)(debug_logged(fn)) + else: + fn.register = registry.register + fn.dispatch = registry.__call__ + return fn __all__ = [ + 'dispatched_interpretation', 'interpret', 'interpretation', 'reinterpret', diff --git a/funsor/joint.py b/funsor/joint.py new file mode 100644 index 000000000..2d0a0a953 --- /dev/null +++ b/funsor/joint.py @@ -0,0 +1,471 @@ +from __future__ import absolute_import, division, print_function + +import functools +import math +from collections import OrderedDict + +from six import add_metaclass +from six.moves import reduce + +import funsor.interpreter as interpreter +import funsor.ops as ops +import funsor.terms +from funsor.delta import Delta +from funsor.domains import reals +from funsor.gaussian import Gaussian, sym_inverse +from funsor.integrate import Integrate, integrator +from funsor.montecarlo import monte_carlo +from funsor.ops import AddOp, NegOp, SubOp +from funsor.terms import ( + Align, + Binary, + Funsor, + FunsorMeta, + Independent, + Number, + Reduce, + Subs, + Unary, + Variable, + eager, + to_funsor +) +from funsor.torch import Tensor, arange + + +class JointMeta(FunsorMeta): + """ + Wrapper to fill in defaults and convert to funsor. + """ + def __call__(cls, deltas=(), discrete=0, gaussian=0): + discrete = to_funsor(discrete) + gaussian = to_funsor(gaussian) + return super(JointMeta, cls).__call__(deltas, discrete, gaussian) + + +@add_metaclass(JointMeta) +class Joint(Funsor): + """ + Normal form for a joint log probability density funsor. + + The primary purpose of Joint is to handle substitution of + :class:`~funsor.delta.Delta` funsors into other funsors. + + Joint is closed under Bayesian fusion, i.e. under ``ops.add`` operations. + Joint is not closed under mixtures, i.e. ``ops.logaddexp`` operations, + hence mixtures will be represented as lazy ``ops.logaddexp`` of Joints. + + :param tuple deltas: A possibly-empty tuple of degenerate distributions + represented as :class:`~funsor.delta.Delta` funsors. + :param Funsor discrete: A joint discrete log mass function represented as + a :class:`~funsor.terms.Number` or `~funsor.terms.Tensor`. + :param Funsor gaussian: An optional joint multivariate normal distribution + a represented as :class:`~funsor.gaussian.Gaussian` or ``Number(0)`` if + absent. + """ + def __init__(self, deltas, discrete, gaussian): + assert isinstance(deltas, tuple) + assert isinstance(discrete, (Number, Tensor)) + assert discrete.output == reals() + assert gaussian is Number(0) or isinstance(gaussian, Gaussian) + inputs = OrderedDict() + for x in deltas: + assert isinstance(x, Delta) + assert x.name not in inputs + assert x.name not in discrete.inputs + assert x.name not in gaussian.inputs + inputs.update(x.inputs) + inputs.update(discrete.inputs) + inputs.update(gaussian.inputs) + output = reals() + super(Joint, self).__init__(inputs, output) + self.deltas = deltas + self.discrete = discrete + self.gaussian = gaussian + + def eager_reduce(self, op, reduced_vars): + if op is ops.logaddexp: + # Keep mixture parameters lazy. + mixture_vars = frozenset(k for k, d in self.gaussian.inputs.items() if d.dtype != 'real') + mixture_vars = mixture_vars.union(*(x.point.inputs for x in self.deltas)) + lazy_vars = reduced_vars & mixture_vars + reduced_vars -= lazy_vars + + # Integrate out degenerate variables, i.e. drop selected delta. 
+ deltas = [] + remaining_vars = set(reduced_vars) + for d in self.deltas: + if d.name in reduced_vars: + remaining_vars.remove(d.name) + else: + deltas.append(d) + deltas = tuple(deltas) + reduced_vars = frozenset(remaining_vars) + + # Integrate out delayed discrete variables. + discrete_vars = reduced_vars.intersection(self.discrete.inputs) + discrete = self.discrete.reduce(op, discrete_vars) + reduced_vars -= discrete_vars + + # Integrate out delayed gaussian variables. + gaussian_vars = reduced_vars.intersection(self.gaussian.inputs) + gaussian = self.gaussian.reduce(ops.logaddexp, gaussian_vars) + reduced_vars -= gaussian_vars + + # Scale to account for remaining reduced_vars that were inputs to dropped deltas. + eager_result = Joint(deltas, discrete) + if gaussian is not Number(0): + eager_result += gaussian + reduced_vars |= lazy_vars.difference(eager_result.inputs) + lazy_vars = lazy_vars.intersection(eager_result.inputs) + if reduced_vars: + eager_result += ops.log(reduce(ops.mul, [self.inputs[v].dtype for v in reduced_vars])) + + # Return a value only if progress has been made. + if eager_result is self: + return None # defer to default implementation + else: + return eager_result.reduce(ops.logaddexp, lazy_vars) + + if op is ops.add: + terms = list(self.deltas) + [self.discrete, self.gaussian] + for i, term in enumerate(terms): + terms[i] = term.reduce(ops.add, reduced_vars.intersection(term.inputs)) + return reduce(ops.add, terms) + + return None # defer to default implementation + + def moment_matching_reduce(self, op, reduced_vars): + if not reduced_vars: + return self + if op is ops.logaddexp: + if not all(reduced_vars.isdisjoint(d.inputs) for d in self.deltas): + raise NotImplementedError('TODO handle moment_matching with Deltas') + lazy_vars = frozenset().union(*(d.inputs for d in self.deltas)).intersection(reduced_vars) + approx_vars = frozenset(k for k in reduced_vars - lazy_vars + if self.inputs[k].dtype != 'real' + if k in self.gaussian.inputs) + exact_vars = reduced_vars - lazy_vars - approx_vars + if exact_vars: + return self.eager_reduce(op, exact_vars).reduce(op, approx_vars | lazy_vars) + + # Moment-matching approximation. 
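+            # The discrete mixture over approx_vars is collapsed to a single Gaussian
+            # with matched moments:
+            #   new_loc = sum_k p_k * loc_k
+            #   new_cov = sum_k p_k * (cov_k + (loc_k - new_loc)(loc_k - new_loc)^T)
+            # where p_k = exp(discrete_k - new_discrete) are normalized mixture weights.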
+ assert approx_vars and not exact_vars + discrete = self.discrete + + new_discrete = discrete.reduce(ops.logaddexp, approx_vars.intersection(discrete.inputs)) + num_elements = reduce(ops.mul, [ + self.inputs[k].num_elements for k in approx_vars.difference(discrete.inputs)], 1) + if num_elements != 1: + new_discrete -= math.log(num_elements) + + gaussian = self.gaussian + int_inputs = OrderedDict((k, d) for k, d in gaussian.inputs.items() if d.dtype != 'real') + probs = (discrete - new_discrete).exp() + old_loc = Tensor(gaussian.loc, int_inputs) + new_loc = (probs * old_loc).reduce(ops.add, approx_vars) + old_cov = Tensor(sym_inverse(gaussian.precision), int_inputs) + diff = old_loc - new_loc + outers = Tensor(diff.data.unsqueeze(-1) * diff.data.unsqueeze(-2), diff.inputs) + new_cov = ((probs * old_cov).reduce(ops.add, approx_vars) + + (probs * outers).reduce(ops.add, approx_vars)) + new_precision = Tensor(sym_inverse(new_cov.data), new_cov.inputs) + new_inputs = new_loc.inputs.copy() + new_inputs.update((k, d) for k, d in self.gaussian.inputs.items() if d.dtype == 'real') + new_gaussian = Gaussian(new_loc.data, new_precision.data, new_inputs) + result = Joint(self.deltas, new_discrete, new_gaussian) + return result.reduce(ops.logaddexp, lazy_vars) + + return None # defer to default implementation + + def unscaled_sample(self, sampled_vars, sample_inputs): + discrete_vars = sampled_vars.intersection(self.discrete.inputs) + gaussian_vars = frozenset(k for k, v in self.gaussian.inputs.items() + if k in sampled_vars if v.dtype == 'real') + result = self + if discrete_vars: + discrete = result.discrete.unscaled_sample(discrete_vars, sample_inputs) + result = Joint(result.deltas, gaussian=result.gaussian) + discrete + if gaussian_vars: + # Draw an expanded sample. + gaussian_sample_inputs = sample_inputs.copy() + for k, d in self.inputs.items(): + if k not in result.gaussian.inputs and d.dtype != 'real': + gaussian_sample_inputs[k] = d + gaussian = result.gaussian.unscaled_sample(gaussian_vars, gaussian_sample_inputs) + result = Joint(result.deltas, result.discrete) + gaussian + return result + + +@eager.register(Joint, tuple, Funsor, Funsor) +def eager_joint(deltas, discrete, gaussian): + + if not isinstance(gaussian, (Number, Tensor, Gaussian)): + return Joint(deltas, discrete) + gaussian + + if any(not isinstance(d, Delta) for d in deltas): + new_deltas = [] + for d in deltas: + if isinstance(d, Delta): + new_deltas.append(d) + elif isinstance(d, (Number, Tensor)): + discrete += d + else: + raise ValueError("Invalid component for Joint: {}".format(d)) + return Joint(tuple(new_deltas), discrete) + gaussian + + if isinstance(gaussian, (Number, Tensor)) and gaussian is not Number(0): + discrete += gaussian + return Joint(deltas, discrete, Number(0)) + + # Demote a Joint to a simpler elementary funsor. 
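+    # e.g. Joint((), Number(0), g) is simply g, and Joint((d,), Number(0), Number(0))
+    # is simply the single Delta d.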
+ if not deltas: + if gaussian is Number(0): + return discrete + elif discrete is Number(0): + return gaussian + elif len(deltas) == 1: + if discrete is Number(0) and gaussian is Number(0): + return deltas[0] + + return None # defer to default implementation + + +@eager.register(Independent, Joint, str, str) +def eager_independent(joint, reals_var, bint_var): + for i, delta in enumerate(joint.deltas): + if delta.name == reals_var or delta.name.startswith(reals_var + "__BOUND"): + delta = Independent(delta, reals_var, bint_var) + deltas = joint.deltas[:i] + (delta,) + joint.deltas[1+i:] + discrete = joint.discrete + if bint_var in discrete.inputs: + discrete = discrete.reduce(ops.add, bint_var) + gaussian = joint.gaussian + if bint_var in gaussian.inputs: + gaussian = gaussian.reduce(ops.add, bint_var) + return Joint(deltas, discrete, gaussian) + + return None # defer to default implementation + + +################################################################################ +# Patterns to update a Joint with other funsors +################################################################################ + +@eager.register(Binary, AddOp, Joint, Joint) +def eager_add(op, joint, other): + # Fuse two joint distributions. + for d in other.deltas: + joint += d + joint += other.discrete + joint += other.gaussian + return joint + + +@eager.register(Binary, AddOp, Joint, Delta) +def eager_add(op, joint, delta): + # Update with a degenerate distribution, typically a monte carlo sample. + if delta.name in joint.inputs: + joint = Subs(joint, ((delta.name, delta.point),)) + if not isinstance(joint, Joint): + return joint + delta + for d in joint.deltas: + if d.name in delta.inputs: + delta = Subs(delta, ((d.name, d.point),)) + deltas = joint.deltas + (delta,) + return Joint(deltas, joint.discrete, joint.gaussian) + + +@eager.register(Binary, AddOp, Joint, (Number, Tensor)) +def eager_add(op, joint, other): + # Update with a delayed discrete random variable. + subs = tuple((d.name, d.point) for d in joint.deltas if d in other.inputs) + if subs: + return joint + Subs(other, subs) + return Joint(joint.deltas, joint.discrete + other, joint.gaussian) + + +@eager.register(Binary, AddOp, Joint, Gaussian) +def eager_add(op, joint, other): + # Update with a delayed gaussian random variable. 
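+    # Substitute any delta points the Gaussian mentions, then fuse it with the existing
+    # Gaussian part (if any); the `joint.gaussian + other` below dispatches to
+    # eager_add_gaussian_gaussian in funsor.gaussian.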
+ subs = tuple((d.name, d.point) for d in joint.deltas if d.name in other.inputs) + if subs: + other = Subs(other, subs) + if joint.gaussian is not Number(0): + other = joint.gaussian + other + if not isinstance(other, Gaussian): + return Joint(joint.deltas, joint.discrete) + other + return Joint(joint.deltas, joint.discrete, other) + + +eager.register(Binary, AddOp, Reduce, Joint)( + funsor.terms.eager_distribute_reduce_other) + + +@eager.register(Binary, AddOp, (Funsor, Align, Delta), Joint) +def eager_add(op, other, joint): + return joint + other + + +################################################################################ +# Patterns to create a Joint from elementary funsors +################################################################################ + +@eager.register(Binary, AddOp, Delta, Delta) +def eager_add(op, lhs, rhs): + if lhs.name == rhs.name: + raise NotImplementedError + if rhs.name in lhs.inputs: + assert lhs.name not in rhs.inputs + lhs = lhs(**{rhs.name: rhs.point}) + elif lhs.name in rhs.inputs: + rhs = rhs(**{lhs.name: lhs.point}) + return Joint(deltas=(lhs, rhs)) + + +@eager.register(Binary, AddOp, Delta, (Number, Tensor, Gaussian)) +def eager_add(op, delta, other): + if delta.name in other.inputs: + other = Subs(other, ((delta.name, delta.point),)) + assert isinstance(other, (Number, Tensor, Gaussian)) + if isinstance(other, (Number, Tensor)): + return Joint((delta,), discrete=other) + else: + return Joint((delta,), gaussian=other) + + +@eager.register(Binary, AddOp, (Number, Tensor, Gaussian), Delta) +def eager_add(op, other, delta): + return delta + other + + +@eager.register(Binary, AddOp, Gaussian, (Number, Tensor)) +def eager_add(op, gaussian, discrete): + return Joint(discrete=discrete, gaussian=gaussian) + + +@eager.register(Binary, AddOp, (Number, Tensor), Gaussian) +def eager_add(op, discrete, gaussian): + return Joint(discrete=discrete, gaussian=gaussian) + + +################################################################################ +# Patterns to compute Radon-Nikodym derivatives +################################################################################ + +@eager.register(Binary, SubOp, Joint, (Funsor, Align, Gaussian, Joint)) +def eager_sub(op, joint, other): + return joint + -other + + +@eager.register(Binary, SubOp, (Funsor, Align), Joint) +def eager_sub(op, other, joint): + return -joint + other + + +@eager.register(Binary, SubOp, Delta, (Number, Tensor, Gaussian, Joint)) +@eager.register(Binary, SubOp, (Number, Tensor), Gaussian) +@eager.register(Binary, SubOp, Gaussian, (Number, Tensor, Joint)) +def eager_sub(op, lhs, rhs): + return lhs + -rhs + + +@eager.register(Unary, NegOp, Joint) +def eager_neg(op, joint): + if joint.deltas: + raise ValueError("Cannot negate deltas") + discrete = -joint.discrete + gaussian = -joint.gaussian + return Joint(discrete=discrete, gaussian=gaussian) + + +################################################################################ +# Patterns for integration +################################################################################ + +def _simplify_integrate(fn, joint, integrand, reduced_vars): + if any(d.name in reduced_vars for d in joint.deltas): + subs = tuple((d.name, d.point) for d in joint.deltas if d.name in reduced_vars) + deltas = tuple(d for d in joint.deltas if d.name not in reduced_vars) + log_measure = Joint(deltas, joint.discrete, joint.gaussian) + integrand = Subs(integrand, subs) + reduced_vars = reduced_vars - frozenset(name for name, point in subs) + return 
Integrate(log_measure, integrand, reduced_vars) + + return fn(joint, integrand, reduced_vars) + + +def _joint_integrator(fn): + """ + Decorator for Integrate(Joint(...), ...) patterns. + """ + fn = interpreter.debug_logged(fn) + return integrator(functools.partial(_simplify_integrate, fn)) + + +@eager.register(Integrate, Joint, Funsor, frozenset) +@_joint_integrator +def eager_integrate(log_measure, integrand, reduced_vars): + return None # defer to default implementation + + +@eager.register(Integrate, Joint, Delta, frozenset) +@_joint_integrator +def eager_integrate(log_measure, integrand, reduced_vars): + raise NotImplementedError('TODO') + + +@eager.register(Integrate, Joint, Tensor, frozenset) +@_joint_integrator +def eager_integrate(log_measure, integrand, reduced_vars): + raise NotImplementedError('TODO') + + +@eager.register(Integrate, Joint, Gaussian, frozenset) +@_joint_integrator +def eager_integrate(log_measure, integrand, reduced_vars): + raise NotImplementedError('TODO') + + +@eager.register(Integrate, Joint, Joint, frozenset) +@_joint_integrator +def eager_integrate(log_measure, integrand, reduced_vars): + raise NotImplementedError('TODO') + + +@eager.register(Integrate, Joint, Variable, frozenset) +@integrator +def eager_integrate(log_measure, integrand, reduced_vars): + name = integrand.name + assert reduced_vars == frozenset([name]) + if any(d.name == name for d in log_measure.deltas): + deltas = tuple(d for d in log_measure.deltas if d.name != name) + log_norm = Joint(deltas, log_measure.discrete, log_measure.gaussian) + for d in log_measure.deltas: + if d.name == name: + mean = d.point + break + return mean * log_norm.exp() + elif name in log_measure.discrete.inputs: + integrand = arange(name, integrand.inputs[name].dtype) + return Integrate(log_measure, integrand, reduced_vars) + else: + assert name in log_measure.gaussian.inputs + gaussian = Integrate(log_measure.gaussian, integrand, reduced_vars) + return Joint(log_measure.deltas, log_measure.discrete).exp() * gaussian + + +@monte_carlo.register(Integrate, Joint, Funsor, frozenset) +@integrator +def monte_carlo_integrate(log_measure, integrand, reduced_vars): + sampled_log_measure = log_measure.sample(reduced_vars, monte_carlo.sample_inputs) + if sampled_log_measure is not log_measure: + reduced_vars = reduced_vars | frozenset(monte_carlo.sample_inputs) + return Integrate(sampled_log_measure, integrand, reduced_vars) + + return None # defer to default implementation + + +__all__ = [ + 'Joint', +] diff --git a/funsor/minipyro.py b/funsor/minipyro.py index 51b0845ae..73d8ceb86 100644 --- a/funsor/minipyro.py +++ b/funsor/minipyro.py @@ -4,63 +4,113 @@ This file contains a minimal implementation of the Pyro Probabilistic Programming Language. The API (method signatures, etc.) match that of -the full implementation as closely as possible. +the full implementation as closely as possible. This file is independent +of the rest of Pyro, with the exception of the :mod:`pyro.distributions` +module. An accompanying example that makes use of this implementation can be found at examples/minipyro.py. 
""" from __future__ import absolute_import, division, print_function -from collections import OrderedDict +import functools +import warnings +import weakref +from collections import OrderedDict, namedtuple import torch -from multipledispatch import dispatch +from pyro.distributions import validation_enabled import funsor -import funsor.ops as ops -from funsor.terms import Funsor, Variable -from .handlers import HANDLER_STACK, Handler, Message, apply_stack, effectful - -class Sample(Message): - pass +# Funsor repreresents distributions in a fundamentally different way from +# torch.Distributions and Pyro: funsor distributions are densities whereas +# torch Distributions are samplers. This class is a compatibility wrapper +# between the two. It is used only internally in the sample() function. +class Distribution(object): + def __init__(self, funsor_dist): + assert isinstance(funsor_dist, funsor.Funsor) + self.funsor_dist = funsor_dist + self.output = self.funsor_dist.inputs["value"] + + def log_prob(self, value): + return self.funsor_dist(value=value) + + # Draw a sample. + def __call__(self): + with funsor.interpreter.interpretation(funsor.terms.eager): + dist = self.funsor_dist(value='value') + delta = dist.sample(frozenset(['value'])) + if isinstance(delta, funsor.joint.Joint): + delta, = delta.deltas + return delta.point + + # Similar to torch.distributions.Distribution.expand(). + def expand_inputs(self, name, size): + if name in self.funsor_dist.inputs: + assert self.funsor_dist.inputs[name] == funsor.bint(int(size)) + return self + inputs = OrderedDict([(name, funsor.bint(int(size)))]) + funsor_dist = self.funsor_dist + funsor.torch.Tensor(torch.zeros(size), inputs) + return Distribution(funsor_dist) + + +# Pyro keeps track of two kinds of global state: +# i) The effect handler stack, which enables non-standard interpretations of +# Pyro primitives like sample(); +# See http://docs.pyro.ai/en/0.3.1/poutine.html +# ii) Trainable parameters in the Pyro ParamStore; +# See http://docs.pyro.ai/en/0.3.1/parameters.html + +PYRO_STACK = [] +PARAM_STORE = {} # maps name -> (unconstrained_value, constraint) -class Param(Message): - pass +def get_param_store(): + return PARAM_STORE -class Markov(Message): - pass +# The base effect handler class (called Messenger here for consistency with Pyro). +class Messenger(object): + def __init__(self, fn=None): + self.fn = fn + # Effect handlers push themselves onto the PYRO_STACK. + # Handlers earlier in the PYRO_STACK are applied first. + def __enter__(self): + PYRO_STACK.append(self) -class Ground(Message): - pass - + def __exit__(self, *args, **kwargs): + assert PYRO_STACK[-1] is self + PYRO_STACK.pop() -PARAM_STORE = {} + def process_message(self, msg): + pass + def postprocess_message(self, msg): + pass -def get_param_store(): - return PARAM_STORE + def __call__(self, *args, **kwargs): + with self: + return self.fn(*args, **kwargs) -class trace(Handler): +# A first useful example of an effect handler. +# trace records the inputs and outputs of any primitive site it encloses, +# and returns a dictionary containing that data to the user. 
+class trace(Messenger): def __enter__(self): super(trace, self).__enter__() self.trace = OrderedDict() return self.trace - # trace illustrates why we need postprocess in addition to process: + # trace illustrates why we need postprocess_message in addition to process_message: # We only want to record a value after all other effects have been applied - @dispatch(Sample) - def postprocess(self, msg): - assert msg["name"] is not None and \ - msg["name"] not in self.trace, \ - "all sites must have unique names" + def postprocess_message(self, msg): + assert msg["type"] != "sample" or msg["name"] not in self.trace, \ + "sample sites must have unique names" self.trace[msg["name"]] = msg.copy() - return msg def get_trace(self, *args, **kwargs): self(*args, **kwargs) @@ -72,293 +122,229 @@ def get_trace(self, *args, **kwargs): # We can compose trace and replay to replace values but preserve distributions, # allowing us to compute the joint probability density of samples under a model. # See the definition of elbo(...) below for an example of this pattern. -class replay(Handler): - def __init__(self, fn=None, guide_trace=None): +class replay(Messenger): + def __init__(self, fn, guide_trace): self.guide_trace = guide_trace super(replay, self).__init__(fn) - @dispatch(object) - def process(self, msg): - return super(replay, self).process(msg) - - @dispatch(Sample) - def process(self, msg): + def process_message(self, msg): if msg["name"] in self.guide_trace: msg["value"] = self.guide_trace[msg["name"]]["value"] - return msg -class block(Handler): - """ - This allows the selective application of effect handlers to different parts - of a model. Sites hidden by block will only have the handlers below block - on the ``HANDLER_STACK`` applied, allowing inference or other effectful - computations to be nested inside models. - """ +# block allows the selective application of effect handlers to different parts of a model. +# Sites hidden by block will only have the handlers below block on the PYRO_STACK applied, +# allowing inference or other effectful computations to be nested inside models. +class block(Messenger): def __init__(self, fn=None, hide_fn=lambda msg: True): self.hide_fn = hide_fn super(block, self).__init__(fn) - @dispatch(Message) - def process(self, msg): + def process_message(self, msg): if self.hide_fn(msg): msg["stop"] = True - return msg -class plate(Handler): - """ - This limited implementation of ``PlateHandler`` only implements broadcasting. - """ - def __init__(self, fn, size, dim, name): +# Conditional independence is recorded as a plate context at each site. +CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "size", "dim"]) + + +# This implementation of vectorized PlateMessenger broadcasts and +# records a cond_indep_stack which is later used to convert +# torch.Tensors to funsor.torch.Tensors. 
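+# For example, under plate("data", size=100, dim=-1) each enclosed sample site's
+# distribution is expanded along a batch dimension of size 100 and the site is
+# tagged with CondIndepStackFrame("data", 100, -1), so that tensor_to_funsor can
+# later turn that batch dim into a funsor input "data": bint(100).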
+class PlateMessenger(Messenger): + def __init__(self, fn, name, size, dim): assert dim < 0 - self.size = size - self.dim = dim - super(plate, self).__init__(fn) - - @dispatch(object) - def process(self, msg): - return super(plate, self).process(msg) - - @dispatch(Sample) - def process(self, msg): - batch_shape = msg["fn"].batch_shape - if len(batch_shape) < -self.dim or batch_shape[self.dim] != self.size: - batch_shape = [1] * (-self.dim - len(batch_shape)) + list(batch_shape) - batch_shape[self.dim] = self.size - msg["fn"] = msg["fn"].expand(tuple(batch_shape)) - return msg - - def __iter__(self): - return range(self.size) - - -def sample(fn, obs=None, name=None): - """ - This is an effectful version of ``Distribution.sample(...)``. When any - effect handlers are active, it constructs an initial message and calls - ``apply_stack``. - """ - assert isinstance(fn, Funsor) - - # if there are no active Handlers, we just create a lazy compute graph. - if not HANDLER_STACK: - return Variable(name, fn.output) + self.frame = CondIndepStackFrame(name, size, dim) + super(PlateMessenger, self).__init__(fn) + + def process_message(self, msg): + if msg["type"] in ("sample", "param"): + assert self.frame.dim not in msg["cond_indep_stack"] + msg["cond_indep_stack"][self.frame.dim] = self.frame + if msg["type"] == "sample": + msg["fn"] = msg["fn"].expand_inputs(self.frame.name, self.frame.size) + + +# This converts raw torch.Tensors to funsor.Funsors with .inputs and .output +# based on information in msg["cond_indep_stack"] and msg["fn"]. +def tensor_to_funsor(value, cond_indep_stack, output): + assert isinstance(value, torch.Tensor) + event_shape = output.shape + batch_shape = value.shape[:value.dim() - len(event_shape)] + if torch._C._get_tracing_state(): + with funsor.torch.ignore_jit_warnings(): + batch_shape = tuple(map(int, batch_shape)) + inputs = OrderedDict() + data = value + for dim, size in enumerate(batch_shape): + if size == 1: + data = data.squeeze(dim - value.dim()) + else: + frame = cond_indep_stack[dim - len(batch_shape)] + assert size == frame.size, (size, frame) + inputs[frame.name] = funsor.bint(int(size)) + value = funsor.torch.Tensor(data, inputs, output.dtype) + assert value.output == output + return value + + +# The log_joint messenger is the main way of recording log probabilities. +# This is roughly the Funsor equivalent to pyro.poutine.trace. +class log_joint(Messenger): + def __enter__(self): + super(log_joint, self).__enter__() + self.log_factors = OrderedDict() # maps site name to log_prob factor + self.plates = set() + return self + + def process_message(self, msg): + if msg["type"] == "sample": + if msg["value"] is None: + # Create a delayed sample. + msg["value"] = funsor.Variable(msg["name"], msg["fn"].output) + + def postprocess_message(self, msg): + if msg["type"] == "sample": + assert msg["name"] not in self.log_factors, "all sites must have unique names" + log_prob = msg["fn"].log_prob(msg["value"]) + self.log_factors[msg["name"]] = log_prob + self.plates.update(f.name for f in msg["cond_indep_stack"].values()) + + +# apply_stack is called by pyro.sample and pyro.param. +# It is responsible for applying each Messenger to each effectful operation. +def apply_stack(msg): + for pointer, handler in enumerate(reversed(PYRO_STACK)): + handler.process_message(msg) + # When a Messenger sets the "stop" field of a message, + # it prevents any Messengers above it on the stack from being applied. 
+ if msg.get("stop"): + break + if msg["value"] is None: + msg["value"] = msg["fn"](*msg["args"]) + if isinstance(msg["value"], torch.Tensor): + msg["value"] = tensor_to_funsor(msg["value"], msg["cond_indep_stack"], msg["output"]) + + # A Messenger that sets msg["stop"] == True also prevents application + # of postprocess_message by Messengers above it on the stack + # via the pointer variable from the process_message loop + for handler in PYRO_STACK[-pointer-1:]: + handler.postprocess_message(msg) + return msg + + +# sample is an effectful version of Distribution.sample(...) +# When any effect handlers are active, it constructs an initial message and calls apply_stack. +def sample(name, fn, obs=None, infer=None): + # Wrap the funsor distribution in a Pyro-compatible way. + fn = Distribution(fn) + + # if there are no active Messengers, we just draw a sample and return it as expected: + if not PYRO_STACK: + return fn() # Otherwise, we initialize a message... - initial_msg = Sample(**{ + initial_msg = { + "type": "sample", "name": name, "fn": fn, "args": (), - "kwargs": {}, "value": obs, - }) + "cond_indep_stack": {}, # maps dim to CondIndepStackFrame + "output": fn.output, + "infer": {} if infer is None else infer, + } - # ...and use apply_stack to send it to the Handlers + # ...and use apply_stack to send it to the Messengers msg = apply_stack(initial_msg) + assert isinstance(msg["value"], funsor.Funsor) return msg["value"] -def param(init_value=None, name=None): - """ - This is an effectful version of ``PARAM_STORE.setdefault``. When any effect - handlers are active, it constructs an initial message and calls - ``apply_stack``. - """ - - if init_value is None and name is None: - raise ValueError("empty args to param") - - def fn(init_value): - value = PARAM_STORE.setdefault(name, init_value) - value.requires_grad_() - return value - - # if there are no active Handlers, we just draw a sample and return it as expected: - if not HANDLER_STACK: - return fn(init_value) +# param is an effectful version of PARAM_STORE.setdefault that also handles constraints. +# When any effect handlers are active, it constructs an initial message and calls apply_stack. +def param(name, init_value=None, constraint=torch.distributions.constraints.real, event_dim=None): + cond_indep_stack = {} + output = None + if init_value is not None: + if event_dim is None: + event_dim = init_value.dim() + output = funsor.reals(*init_value.shape[init_value.dim() - event_dim:]) + + def fn(init_value, constraint): + if name in PARAM_STORE: + unconstrained_value, constraint = PARAM_STORE[name] + else: + # Initialize with a constrained value. + assert init_value is not None + with torch.no_grad(): + constrained_value = init_value.detach() + unconstrained_value = torch.distributions.transform_to(constraint).inv(constrained_value) + unconstrained_value.requires_grad_() + unconstrained_value._funsor_metadata = (cond_indep_stack, output) + PARAM_STORE[name] = unconstrained_value, constraint + + # Transform from unconstrained space to constrained space. + constrained_value = torch.distributions.transform_to(constraint)(unconstrained_value) + constrained_value.unconstrained = weakref.ref(unconstrained_value) + return tensor_to_funsor(constrained_value, *unconstrained_value._funsor_metadata) + + # if there are no active Messengers, we just draw a sample and return it as expected: + if not PYRO_STACK: + return fn(init_value, constraint) # Otherwise, we initialize a message... 
- initial_msg = Param(**{ + initial_msg = { + "type": "param", + "name": name, "fn": fn, - "args": (init_value,), + "args": (init_value, constraint), "value": None, - }) + "cond_indep_stack": cond_indep_stack, # maps dim to CondIndepStackFrame + "output": output, + } - # ...and use apply_stack to send it to the Handlers + # ...and use apply_stack to send it to the Messengers msg = apply_stack(initial_msg) + assert isinstance(msg["value"], funsor.Funsor) return msg["value"] -class SelectiveHandler(Handler): - def __init__(self, fn=None, match_fn=None): - self.match_fn = (lambda msg: True) if match_fn is None else match_fn - super(SelectiveHandler, self).__init__(fn=fn) - - -class deferred(SelectiveHandler): - - @dispatch(object) - def process(self, msg): - return super(deferred, self).process(msg) - - @dispatch(Sample) - def process(self, msg): - if msg["value"] is not None and self.match_fn(msg): - msg["value"] = Variable(msg["name"], msg["fn"].output) - return msg - - -class monte_carlo(SelectiveHandler): - - @dispatch(object) - def process(self, msg): - return super(monte_carlo, self).process(msg) +# boilerplate to match the syntax of actual pyro.plate: +def plate(name, size, dim): + return PlateMessenger(fn=None, name=name, size=size, dim=dim) - @dispatch(Sample) - def process(self, msg): - if msg["value"] is not None and self.match_fn(msg): - msg["value"] = msg["fn"].sample() - return msg +# This is a thin wrapper around the `torch.optim.Adam` class that +# dynamically generates optimizers for dynamically generated parameters. +# See http://docs.pyro.ai/en/0.3.1/optimization.html +class Adam(object): + def __init__(self, optim_args): + self.optim_args = optim_args + # Each parameter will get its own optimizer, which we keep track + # of using this dictionary keyed on parameters. + self.optim_objs = {} -class log_joint(Handler): - """ - Tracks log joint density during delayed sampling. - """ - - def __enter__(self): - self.log_prob = funsor.to_funsor(0.) 
- self.samples = OrderedDict() - return self - - @dispatch(object) - def process(self, msg): - return super(log_joint, self).process(msg) - - @dispatch(Sample) - def process(self, msg): - assert msg["value"] is not None - self.samples[msg["name"]] = msg["value"] - self.log_prob += msg["fn"].log_prob(msg["value"]) - return msg - - @dispatch(Markov) - def process(self, msg): - funsors = [] - _recursive_map(funsors.append, msg["value"]) - hidden_dims = (frozenset(self.samples) - frozenset(funsors) - ).intersection(self.log_prob.dims) - if hidden_dims: - marginal = self.log_prob.reduce(ops.sample, hidden_dims) - self.log_prob = funsor.eval(marginal) - subs = funsor.backward(ops.sample, self.log_prob, hidden_dims) - msg["value"] = _recursive_map(lambda x: x(**subs), msg["value"]) - return msg - - @dispatch(Ground) - def process(self, msg): - value = msg["value"] - if not isinstance(value, (funsor.Number, funsor.Tensor)): - log_prob = self.log_prob.reduce(ops.sample, value.dims) - self.log_prob = funsor.eval(log_prob) - subs = funsor.backward(ops.sample, self.log_prob, value.dims) - self.samples.update(subs) - msg["value"] = value(**subs) - context = msg["context"] - for key, value in list(context.items()): - if isinstance(value, Funsor): - context[key] = value(**subs) - - return msg - - -@effectful(Markov) -def markov(state): - """ - Declaration that behavior after this point in a program depends on behavior - before this point in a program only through the passed ``state`` object, - which can be a :class:`~funsor.Funsor` or recursive structure built from - funsors via ``tuple`` or non-funsor keyed ``dict``. - - Example:: - - x = 0 - for t in range(100): - x = pyro.sample("x_{}".format(t), trans(x)) - x = pyro.markov(x) # it is now safe to marginalize past xs - pyro.sample("y_{}".format(t), emit(x), obs=data[t]) - """ - # if there are no active Handlers, we just return the state - return state - - -def _recursive_map(fn, x): - if isinstance(x, funsor.Funsor): - return fn(x) - if isinstance(x, tuple): - return tuple(fn(y) for y in x) - if isinstance(x, dict): - return {k: fn(v) for k, v in x.items()} - - -@effectful(Ground) -def ground(value, context): - """ - Sample enough deferred random variables so that ``value`` is ground, - and update ``context`` with the new samples. Typically ``context`` is - ``locals()`` as called in a small model, or a global dict storing all - random state in a larger model. This is typically used in a - :class:`deferred` context. - - This is like ``value()`` in the Birch probabilistic programming language. - - Example:: - - with pyro.deferred(): - # ...draw deferred samples... - x = pyro.ground(x, locals()) - if x > 0: # requires x to be a ground value - # ...do stuff... - - :param Funsor value: A funsor possibly depending on delayed sample sites. - :param dict context: A dict containing all other random state. - :return: A version of ``value`` with all deferred variables sampled. - :rtype: Funsor - """ - assert isinstance(value, Funsor) - assert isinstance(context, dict) - return value - - -def elbo(model, guide, *args, **kwargs): - """ - This is an attempt to compute a deferred elbo. - """ - # sample guide - with log_joint() as guide_joint: - guide(*args, **kwargs) - # FIXME This is only correct for reparametrized sites. - # FIXME do not marginalize; instead sample. - q = guide_joint.log_prob.logsumexp() - tr = guide_joint.samples - tr.update(funsor.backward(ops.sample, q)) # force deferred samples? 
- - # replay model against guide - with log_joint() as model_joint, replay(guide_trace=tr): - model(*args, **kwargs) - p = funsor.eval(model_joint.log_prob.logsumexp()) - - elbo = p - q - return -elbo # negate, for use as loss + def __call__(self, params): + for param in params: + # If we've seen this parameter before, use the previously + # constructed optimizer. + if param in self.optim_objs: + optim = self.optim_objs[param] + # If we've never seen this parameter before, construct + # an Adam optimizer and keep track of it. + else: + optim = torch.optim.Adam([param], **self.optim_args) + self.optim_objs[param] = optim + # Take a gradient step for the parameter param. + optim.step() +# This is a unified interface for stochastic variational inference in Pyro. +# The actual construction of the loss is taken care of by `loss`. +# See http://docs.pyro.ai/en/0.3.1/inference_algos.html class SVI(object): - """ - This is a unified interface for stochastic variational inference in Pyro. - The actual construction of the loss is taken care of by `loss`. - See http://docs.pyro.ai/en/stable/inference_algos.html - """ def __init__(self, model, guide, optim, loss): self.model = model self.guide = guide @@ -373,58 +359,204 @@ def step(self, *args, **kwargs): # further tracing occurs inside of `loss`. with trace() as param_capture: # We use block here to allow tracing to record parameters only. - with block(hide_fn=lambda msg: msg["type"] == "sample"): + with block(hide_fn=lambda msg: msg["type"] != "param"): loss = self.loss(self.model, self.guide, *args, **kwargs) # Differentiate the loss. - loss.backward() + loss.data.backward() # Grab all the parameters from the trace. - params = [site["value"] for site in param_capture.values()] + params = [site["value"].data.unconstrained() + for site in param_capture.values()] # Take a step w.r.t. each parameter in params. self.optim(params) # Zero out the gradients so that they don't accumulate. for p in params: - p.grad = p.new_zeros(p.shape) + p.grad = torch.zeros_like(p.grad) return loss.item() -class Adam(object): - """ - This is a thin wrapper around the `torch.optim.Adam` class that - dynamically generates optimizers for dynamically generated parameters. - See http://docs.pyro.ai/en/stable/optimization.html - """ - def __init__(self, optim_args): - self.optim_args = optim_args - # Each parameter will get its own optimizer, which we keep track - # of using this dictionary keyed on parameters. - self.optim_objs = {} +# TODO(eb8680) Replace this with funsor.Expectation. +def Expectation(log_probs, costs, sum_vars, prod_vars): + result = 0 + for cost in costs: + log_prob = funsor.sum_product.sum_product( + sum_op=funsor.ops.logaddexp, + prod_op=funsor.ops.add, + factors=log_probs, + plates=prod_vars, + eliminate=(prod_vars | sum_vars) - frozenset(cost.inputs) + ) + term = funsor.Integrate(log_prob, cost, sum_vars & frozenset(cost.inputs)) + term = term.reduce(funsor.ops.add, prod_vars & frozenset(cost.inputs)) + result += term + return result + + +# This is a basic implementation of the Evidence Lower Bound, which is the +# fundamental objective in Variational Inference. +# See http://pyro.ai/examples/svi_part_i.html for details. +# This implementation uses a Dice estimator similar to TraceEnum_ELBO. 
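+# Informal sketch of the quantity computed below: with guide log-density
+# factors q_j and model log-density factors p_i, the estimator is roughly
+# elbo ~= E_q[ sum_i p_i - sum_j q_j ], where auxiliary (enumerated)
+# variables are summed out exactly and plate dims act as product dims;
+# the function returns loss = -elbo.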
+def elbo(model, guide, *args, **kwargs): + with log_joint() as guide_log_joint: + guide(*args, **kwargs) + with log_joint() as model_log_joint: + model(*args, **kwargs) - def __call__(self, params): - for param in params: - # If we've seen this parameter before, use the previously - # constructed optimizer. - if param in self.optim_objs: - optim = self.optim_objs[param] - # If we've never seen this parameter before, construct - # an Adam optimizer and keep track of it. - else: - optim = torch.optim.Adam([param.data], **self.optim_args) - self.optim_objs[param] = optim - # Take a gradient step for the parameter param. - optim.step() + # contract out auxiliary variables in the guide + guide_log_probs = list(guide_log_joint.log_factors.values()) + guide_aux_vars = frozenset().union(*(f.inputs for f in guide_log_probs)) - \ + frozenset(guide_log_joint.plates) - \ + frozenset(model_log_joint.log_factors) + if guide_aux_vars: + guide_log_probs = funsor.sum_product.partial_sum_product( + funsor.ops.logaddexp, + funsor.ops.add, + guide_log_probs, + plates=frozenset(guide_log_joint.plates), + eliminate=guide_aux_vars) + + # contract out auxiliary variables in the model + model_log_probs = list(model_log_joint.log_factors.values()) + model_aux_vars = frozenset().union(*(f.inputs for f in model_log_probs)) - \ + frozenset(model_log_joint.plates) - \ + frozenset(guide_log_joint.log_factors) + if model_aux_vars: + model_log_probs = funsor.sum_product.partial_sum_product( + funsor.ops.logaddexp, + funsor.ops.add, + model_log_probs, + plates=frozenset(model_log_joint.plates), + eliminate=model_aux_vars) + + # compute remaining plates and sum_dims + plates = frozenset().union( + *(model_log_joint.plates.intersection(f.inputs) for f in model_log_probs)) + plates = plates | frozenset().union( + *(guide_log_joint.plates.intersection(f.inputs) for f in guide_log_probs)) + sum_vars = frozenset().union(model_log_joint.log_factors, guide_log_joint.log_factors) - \ + frozenset(model_aux_vars | guide_aux_vars) + + # Accumulate costs from model and guide and log_probs from guide. + # Cf. pyro.infer.traceenum_elbo._compute_dice_elbo() + # https://github.com/pyro-ppl/pyro/blob/0.3.0/pyro/infer/traceenum_elbo.py#L119 + costs = [] + log_probs = [] + for p in model_log_probs: + costs.append(p) + for q in guide_log_probs: + costs.append(-q) + log_probs.append(q) + + # Compute expected cost. + # Cf. pyro.infer.util.Dice.compute_expectation() + # https://github.com/pyro-ppl/pyro/blob/0.3.0/pyro/infer/util.py#L212 + elbo = Expectation(tuple(log_probs), + tuple(costs), + sum_vars=sum_vars, + prod_vars=plates) + + loss = -elbo + assert not loss.inputs + return loss + + +# Base class for elbo implementations. +class ELBO(object): + def __init__(self, **kwargs): + self.options = kwargs + + def __call__(self, model, guide, *args, **kwargs): + return elbo(model, guide, *args, **kwargs) + + +# This is a wrapper for compatibility with full Pyro. +class Trace_ELBO(ELBO): + def __call__(self, model, guide, *args, **kwargs): + with funsor.montecarlo.monte_carlo_interpretation(): + return elbo(model, guide, *args, **kwargs) + + +class TraceMeanField_ELBO(ELBO): + # TODO Use exact KLs where possible. 
+ pass -__all__ = [ - 'Adam', - 'block', - 'deferred', - 'elbo', - 'get_param_store', - 'ground', - 'markov', - 'param', - 'plate', - 'replay', - 'sample', - 'trace', -] +class TraceEnum_ELBO(ELBO): + # TODO allow mixing of sampling and exact integration + def __call__(self, model, guide, *args, **kwargs): + if self.options.get("optimize", None): + with funsor.interpreter.interpretation(funsor.optimizer.optimize): + elbo_expr = elbo(model, guide, *args, **kwargs) + return funsor.reinterpret(elbo_expr) + return elbo(model, guide, *args, **kwargs) + + +# This is a PyTorch jit wrapper that (1) delays tracing until the first +# invocation, and (2) registers pyro.param() statements with torch.jit.trace. +# This version does not support variable number of args or non-tensor kwargs. +class Jit(object): + def __init__(self, fn, **kwargs): + self.fn = fn + self.ignore_jit_warnings = kwargs.get("ignore_jit_warnings", False) + self._compiled = None + self._param_trace = None + + def __call__(self, *args): + # On first call, initialize params and save their names. + if self._param_trace is None: + with block(), trace() as tr, block(hide_fn=lambda m: m["type"] != "param"): + self.fn(*args) + self._param_trace = tr + + # Augment args with reads from the global param store. + unconstrained_params = tuple(param(name).data.unconstrained() + for name in self._param_trace) + params_and_args = unconstrained_params + args + + # On first call, create a compiled elbo. + if self._compiled is None: + + def compiled(*params_and_args): + unconstrained_params = params_and_args[:len(self._param_trace)] + args = params_and_args[len(self._param_trace):] + for name, unconstrained_param in zip(self._param_trace, unconstrained_params): + constrained_param = param(name) # assume param has been initialized + assert constrained_param.data.unconstrained() is unconstrained_param + self._param_trace[name]["value"] = constrained_param + result = replay(self.fn, guide_trace=self._param_trace)(*args) + assert not result.inputs + assert result.output == funsor.reals() + return result.data + + with validation_enabled(False), warnings.catch_warnings(): + if self.ignore_jit_warnings: + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + self._compiled = torch.jit.trace(compiled, params_and_args, check_trace=False) + + data = self._compiled(*params_and_args) + return funsor.torch.Tensor(data) + + +# This is a jit wrapper for ELBO implementations. 
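+# Possible usage (illustrative sketch): SVI(model, guide, Adam({"lr": 1e-3}),
+# JitTrace_ELBO()) behaves like Trace_ELBO-based SVI, except that the loss
+# computation is traced once with torch.jit.trace and then replayed.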
+class Jit_ELBO(ELBO): + def __init__(self, elbo, **kwargs): + super(Jit_ELBO, self).__init__(**kwargs) + self._elbo = elbo(**kwargs) + self._compiled = {} # maps (model,guide) -> Jit instances + + def __call__(self, model, guide, *args): + if (model, guide) not in self._compiled: + elbo = functools.partial(self._elbo, model, guide) + self._compiled[model, guide] = Jit(elbo, **self.options) + return self._compiled[model, guide](*args) + + +def JitTrace_ELBO(**kwargs): + return Jit_ELBO(Trace_ELBO, **kwargs) + + +def JitTraceMeanField_ELBO(**kwargs): + return Jit_ELBO(TraceMeanField_ELBO, **kwargs) + + +def JitTraceEnum_ELBO(**kwargs): + return Jit_ELBO(TraceEnum_ELBO, **kwargs) diff --git a/funsor/montecarlo.py b/funsor/montecarlo.py new file mode 100644 index 000000000..f528b4611 --- /dev/null +++ b/funsor/montecarlo.py @@ -0,0 +1,58 @@ +from __future__ import absolute_import, division, print_function + +from collections import OrderedDict + +from contextlib2 import contextmanager + +from funsor.integrate import Integrate, integrator +from funsor.interpreter import dispatched_interpretation, interpretation +from funsor.terms import Funsor, eager + + +@dispatched_interpretation +def monte_carlo(cls, *args): + """ + A Monte Carlo interpretation of :class:`~funsor.integrate.Integrate` + expressions. This falls back to :class:`~funsor.terms.eager` in other + cases. + """ + # TODO Memoize sample statements in a context manager. + result = monte_carlo.dispatch(cls, *args) + if result is None: + result = eager(cls, *args) + return result + + +# This is a globally configurable parameter to draw multiple samples. +monte_carlo.sample_inputs = OrderedDict() + + +@contextmanager +def monte_carlo_interpretation(**sample_inputs): + """ + Context manager to set ``monte_carlo.sample_inputs`` and + install the :func:`monte_carlo` interpretation. + """ + old = monte_carlo.sample_inputs + monte_carlo.sample_inputs = OrderedDict(sample_inputs) + try: + with interpretation(monte_carlo): + yield + finally: + monte_carlo.sample_inputs = old + + +@monte_carlo.register(Integrate, Funsor, Funsor, frozenset) +@integrator +def monte_carlo_integrate(log_measure, integrand, reduced_vars): + sample = log_measure.sample(reduced_vars, monte_carlo.sample_inputs) + if sample is log_measure: + return None # cannot progress + reduced_vars |= frozenset(monte_carlo.sample_inputs).intersection(sample.inputs) + return Integrate(sample, integrand, reduced_vars) + + +__all__ = [ + 'monte_carlo', + 'monte_carlo_interpretation' +] diff --git a/funsor/numpy.py b/funsor/numpy.py new file mode 100644 index 000000000..892e5147f --- /dev/null +++ b/funsor/numpy.py @@ -0,0 +1,291 @@ +from __future__ import absolute_import, division, print_function + +from collections import OrderedDict + +import numpy as np +from multipledispatch import dispatch +from six import add_metaclass, integer_types + +import funsor.ops as ops +from funsor.domains import Domain, bint, find_domain +from funsor.terms import Binary, Funsor, FunsorMeta, Number, eager, substitute, to_data, to_funsor + + +def align_array(new_inputs, x): + r""" + Permute and expand an array to match desired ``new_inputs``. + + :param OrderedDict new_inputs: A target set of inputs. + :param funsor.terms.Funsor x: A :class:`Array` s or + or :class:`~funsor.terms.Number` . + :return: a number or :class:`numpy.ndarray` that can be broadcast to other + array with inputs ``new_inputs``. 
+ :rtype: numbers.Number or numpy.ndarray
+ """
+ assert isinstance(new_inputs, OrderedDict)
+ assert isinstance(x, (Number, Array))
+ assert all(isinstance(d.dtype, integer_types) for d in x.inputs.values())
+
+ data = x.data
+ if isinstance(x, Number):
+ return data
+
+ old_inputs = x.inputs
+ if old_inputs == new_inputs:
+ return data
+
+ # Permute squashed input dims.
+ x_keys = tuple(old_inputs)
+ data = np.transpose(data, (tuple(x_keys.index(k) for k in new_inputs if k in old_inputs) +
+ tuple(range(len(old_inputs), data.ndim))))
+
+ # Unsquash multivariate input dims by filling in ones.
+ data = np.reshape(data, tuple(old_inputs[k].dtype if k in old_inputs else 1 for k in new_inputs) +
+ x.output.shape)
+ return data
+
+
+def align_arrays(*args):
+ r"""
+ Permute multiple arrays before applying a broadcasted op.
+
+ This is mainly useful for implementing eager funsor operations.
+
+ :param funsor.terms.Funsor \*args: Multiple :class:`Array` s and
+ :class:`~funsor.terms.Number` s.
+ :return: a pair ``(inputs, arrays)`` where arrays are all
+ :class:`numpy.ndarray` s that can be broadcast together to a single array
+ with given ``inputs``.
+ :rtype: tuple
+ """
+ inputs = OrderedDict()
+ for x in args:
+ inputs.update(x.inputs)
+ arrays = [align_array(inputs, x) for x in args]
+ return inputs, arrays
+
+
+class ArrayMeta(FunsorMeta):
+ """
+ Wrapper to fill in default args and convert between OrderedDict and tuple.
+ """
+ def __call__(cls, data, inputs=None, dtype="real"):
+ if inputs is None:
+ inputs = tuple()
+ elif isinstance(inputs, OrderedDict):
+ inputs = tuple(inputs.items())
+ return super(ArrayMeta, cls).__call__(data, inputs, dtype)
+
+
+@add_metaclass(ArrayMeta)
+class Array(Funsor):
+ """
+ Funsor backed by a numpy ndarray.
+
+ :param numpy.ndarray data: A numpy array of appropriate shape.
+ :param OrderedDict inputs: An optional mapping from input name to domain.
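+
+ Example (informal sketch)::
+
+ x = Array(np.zeros((3, 4)), OrderedDict(i=bint(3)))
+ assert x.output.shape == (4,)
+ assert x.inputs["i"] == bint(3)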
+ """ + def __init__(self, data, inputs=None, dtype="real"): + assert isinstance(data, np.ndarray) or np.isscalar(data) + assert isinstance(inputs, tuple) + assert all(isinstance(d.dtype, integer_types) for k, d in inputs) + inputs = OrderedDict(inputs) + output = Domain(data.shape[len(inputs):], dtype) + fresh = frozenset(inputs.keys()) + bound = frozenset() + super(Array, self).__init__(inputs, output, fresh, bound) + self.data = data + + def __repr__(self): + if self.output != "real": + return 'Array({}, {}, {})'.format(self.data, self.inputs, repr(self.dtype)) + elif self.inputs: + return 'Array({}, {})'.format(self.data, self.inputs) + else: + return 'Array({})'.format(self.data) + + def __str__(self): + if self.dtype != "real": + return 'Array({}, {}, {})'.format(self.data, self.inputs, repr(self.dtype)) + elif self.inputs: + return 'Array({}, {})'.format(self.data, self.inputs) + else: + return str(self.data) + + def __int__(self): + return int(self.data) + + def __float__(self): + return float(self.data) + + def __bool__(self): + return bool(self.data) + + def item(self): + return self.data.item() + + def align(self, names): + assert isinstance(names, tuple) + assert all(name in self.inputs for name in names) + if not names or names == tuple(self.inputs): + return self + inputs = OrderedDict((name, self.inputs[name]) for name in names) + inputs.update(self.inputs) + + if any(d.shape for d in self.inputs.values()): + raise NotImplementedError("TODO: Implement align with vector indices.") + old_dims = tuple(self.inputs) + new_dims = tuple(inputs) + data = np.transpose(self.data, (tuple(old_dims.index(d) for d in new_dims))) + return Array(data, inputs, self.dtype) + + def eager_subs(self, subs): + assert isinstance(subs, tuple) + subs = {k: materialize(v) for k, v in subs if k in self.inputs} + if not subs: + return self + + # Compute result shapes. + inputs = OrderedDict() + for k, domain in self.inputs.items(): + if k in subs: + inputs.update(subs[k].inputs) + else: + inputs[k] = domain + + # Construct a dict with each input's positional dim, + # counting from the right so as to support broadcasting. + total_size = len(inputs) + len(self.output.shape) # Assumes only scalar indices. + new_dims = {} + for k, domain in inputs.items(): + assert not domain.shape + new_dims[k] = len(new_dims) - total_size + + # Use advanced indexing to construct a simultaneous substitution. + index = [] + for k, domain in self.inputs.items(): + if k in subs: + v = subs.get(k) + if isinstance(v, Number): + index.append(int(v.data)) + else: + # Permute and expand v.data to end up at new_dims. + assert isinstance(v, Array) + v = v.align(tuple(k2 for k2 in inputs if k2 in v.inputs)) + assert isinstance(v, Array) + v_shape = [1] * total_size + for k2, size in zip(v.inputs, v.data.shape): + v_shape[new_dims[k2]] = size + index.append(v.data.reshape(tuple(v_shape))) + else: + # Construct a [:] slice for this preserved input. + offset_from_right = -1 - new_dims[k] + index.append(np.arange(domain.dtype).reshape( + (-1,) + (1,) * offset_from_right)) + + # Construct a [:] slice for the output. 
+ for i, size in enumerate(self.output.shape): + offset_from_right = len(self.output.shape) - i - 1 + index.append(np.arange(size).reshape( + (-1,) + (1,) * offset_from_right)) + + data = self.data[tuple(index)] + return Array(data, inputs, self.dtype) + + +@dispatch(np.ndarray) +def to_funsor(x): + return Array(x) + + +@dispatch(np.ndarray, Domain) +def to_funsor(x, output): + result = Array(x, dtype=output.dtype) + if result.output != output: + raise ValueError("Invalid shape: expected {}, actual {}" + .format(output.shape, result.output.shape)) + return result + + +@to_data.register(Array) +def _to_data_array(x): + if x.inputs: + raise ValueError("cannot convert Array to a data due to lazy inputs: {}" + .format(set(x.inputs))) + return x.data + + +@eager.register(Binary, object, Array, Number) +def eager_binary_array_number(op, lhs, rhs): + if op is ops.getitem: + # Shift by that Funsor is using for inputs. + index = [slice(None)] * len(lhs.inputs) + index.append(rhs.data) + index = tuple(index) + data = lhs.data[index] + else: + data = op(lhs.data, rhs.data) + return Array(data, lhs.inputs, lhs.dtype) + + +@eager.register(Binary, object, Number, Array) +def eager_binary_number_array(op, lhs, rhs): + data = op(lhs.data, rhs.data) + return Array(data, rhs.inputs, rhs.dtype) + + +@eager.register(Binary, object, Array, Array) +def eager_binary_array_array(op, lhs, rhs): + # Compute inputs and outputs. + dtype = find_domain(op, lhs.output, rhs.output).dtype + if lhs.inputs == rhs.inputs: + inputs = lhs.inputs + lhs_data, rhs_data = lhs.data, rhs.data + else: + inputs, (lhs_data, rhs_data) = align_arrays(lhs, rhs) + + if op is ops.getitem: + # getitem has special shape semantics. + if rhs.output.shape: + raise NotImplementedError('TODO support vector indexing') + assert lhs.output.shape == (rhs.dtype,) + index = [np.arange(size).reshape((-1,) + (1,) * (lhs_data.ndim - pos - 2)) + for pos, size in enumerate(lhs_data.shape)] + index[-1] = rhs_data + data = lhs_data[tuple(index)] + else: + data = op(lhs_data, rhs_data) + + return Array(data, inputs, dtype) + + +def arange(name, size): + """ + Helper to create a named :func:`numpy.arange` funsor. + + :param str name: A variable name. + :param int size: A size. + :rtype: Array + """ + data = np.arange(size) + inputs = OrderedDict([(name, bint(size))]) + return Array(data, inputs, dtype=size) + + +def materialize(x): + """ + Attempt to convert a Funsor to a :class:`~funsor.terms.Number` or + :class:`numpy.ndarray` by substituting :func:`arange` s into its free variables. 
+ """ + assert isinstance(x, Funsor) + if isinstance(x, (Number, Array)): + return x + subs = [] + for name, domain in x.inputs.items(): + if not isinstance(domain.dtype, integer_types): + raise ValueError('materialize() requires integer free variables but found ' + '"{}" of domain {}'.format(name, domain)) + assert not domain.shape + subs.append((name, arange(name, domain.dtype))) + subs = tuple(subs) + return substitute(x, subs) diff --git a/funsor/ops.py b/funsor/ops.py index 5a430b0f5..8e0d8b143 100644 --- a/funsor/ops.py +++ b/funsor/ops.py @@ -1,10 +1,11 @@ from __future__ import absolute_import, division, print_function +import operator from numbers import Number -from operator import add, and_, eq, ge, getitem, gt, invert, le, lt, mul, ne, neg, or_, sub, truediv, xor import numpy as np -import torch +from multipledispatch import Dispatcher +from six import add_metaclass _builtin_abs = abs _builtin_max = max @@ -12,101 +13,256 @@ _builtin_pow = pow +class Op(Dispatcher): + def __init__(self, fn): + super(Op, self).__init__(fn.__name__) + # register as default operation + for nargs in (1, 2): + default_signature = (object,) * nargs + self.add(default_signature, fn) + + def __repr__(self): + return self.__name__ + + def __str__(self): + return self.__name__ + + +class TransformOp(Op): + def set_inv(self, fn): + """ + :param callable fn: A function that inputs an arg ``y`` and outputs a + value ``x`` such that ``y=self(x)``. + """ + assert callable(fn) + self.inv = fn + return fn + + def set_log_abs_det_jacobian(self, fn): + """ + :param callable fn: A function that inputs two args ``x, y``, where + ``y=self(x)``, and returns ``log(abs(det(dy/dx)))``. + """ + assert callable(fn) + self.log_abs_det_jacobian = fn + return fn + + @staticmethod + def inv(x): + raise NotImplementedError + + @staticmethod + def log_abs_det_jacobian(x, y): + raise NotImplementedError + + +class AssociativeOp(Op): + pass + + +class AddOp(AssociativeOp): + pass + + +class SubOp(Op): + pass + + +class NegOp(Op): + pass + + +class GetitemMeta(type): + _cache = {} + + def __call__(cls, offset): + try: + return GetitemMeta._cache[offset] + except KeyError: + instance = super(GetitemMeta, cls).__call__(offset) + GetitemMeta._cache[offset] = instance + return instance + + +@add_metaclass(GetitemMeta) +class GetitemOp(Op): + """ + Op encoding an index into one dime, e.g. ``x[:,:,:,y]`` for offset of 3. 
+ """ + def __init__(self, offset): + assert isinstance(offset, int) + assert offset >= 0 + self.offset = offset + self._prefix = (slice(None),) * offset + super(GetitemOp, self).__init__(self._default) + self.__name__ = 'GetitemOp({})'.format(offset) + + def _default(self, x, y): + return x[self._prefix + (y,)] if self.offset else x[y] + + +getitem = GetitemOp(0) + +eq = Op(operator.eq) +ge = Op(operator.ge) +gt = Op(operator.gt) +invert = Op(operator.invert) +le = Op(operator.le) +lt = Op(operator.lt) +ne = Op(operator.ne) +neg = NegOp(operator.neg) +sub = SubOp(operator.sub) +truediv = Op(operator.truediv) + +add = AddOp(operator.add) +and_ = AssociativeOp(operator.and_) +mul = AssociativeOp(operator.mul) +or_ = AssociativeOp(operator.or_) +xor = AssociativeOp(operator.xor) + + +@add.register(object) +def _unary_add(x): + return x.sum() + + +@Op def abs(x): - return _builtin_abs(x) if isinstance(x, Number) else x.abs() + return x.abs() + +@abs.register(Number) +def _abs(x): + return _builtin_abs(x) + +@Op def sqrt(x): - return np.sqrt(x) if isinstance(x, Number) else x.sqrt() + return np.sqrt(x) +@TransformOp def exp(x): - return np.exp(x) if isinstance(x, Number) else x.exp() + return np.exp(x) + + +@exp.set_log_abs_det_jacobian +def log_abs_det_jacobian(x, y): + return add(x) +@TransformOp def log(x): - return np.log(x) if isinstance(x, Number) else x.log() + return np.log(x) +@log.set_log_abs_det_jacobian +def log_abs_det_jacobian(x, y): + return -add(y) + + +exp.set_inv(log) +log.set_inv(exp) + + +@Op def log1p(x): - return np.log1p(x) if isinstance(x, Number) else x.log1p() + return np.log1p(x) +@Op def pow(x, y): - result = x ** y - # work around shape bug https://github.com/pytorch/pytorch/issues/16685 - if isinstance(x, Number) and isinstance(y, torch.Tensor): - result = result.reshape(y.shape) - return result + return x ** y +@AssociativeOp def min(x, y): if hasattr(x, '__min__'): return x.__min__(y) if hasattr(y, '__min__'): return y.__min__(x) - if isinstance(x, torch.Tensor): - if isinstance(y, torch.Tensor): - return torch.min(x, y) - return x.clamp(max=y) - if isinstance(y, torch.Tensor): - return y.clamp(max=x) return _builtin_min(x, y) +@AssociativeOp def max(x, y): if hasattr(x, '__max__'): return x.__max__(y) if hasattr(y, '__max__'): return y.__max__(x) - if isinstance(x, torch.Tensor): - if isinstance(y, torch.Tensor): - return torch.max(x, y) - return x.clamp(min=y) - if isinstance(y, torch.Tensor): - return y.clamp(min=x) return _builtin_max(x, y) +@AssociativeOp def logaddexp(x, y): shift = max(x, y) return log(exp(x - shift) + exp(y - shift)) + shift -# just a placeholder -def marginal(x, y): - raise ValueError +@Op +def safesub(x, y): + if isinstance(y, Number): + return sub(x, y) + + +@Op +def safediv(x, y): + if isinstance(y, Number): + return truediv(x, y) -# just a placeholder -def sample(x, y): - raise ValueError +@Op +def reciprocal(x): + if isinstance(x, Number): + return 1. 
/ x + raise ValueError("No reciprocal for type {}".format(type(x))) -REDUCE_OP_TO_TORCH = { - add: torch.sum, - mul: torch.prod, - and_: torch.all, - or_: torch.any, - logaddexp: torch.logsumexp, - min: torch.min, - max: torch.max, +DISTRIBUTIVE_OPS = frozenset([ + (logaddexp, add), + (add, mul), + (max, mul), + (min, mul), + (max, add), + (min, add), +]) + + +UNITS = { + mul: 1., + add: 0., +} + + +PRODUCT_INVERSES = { + mul: safediv, + add: safesub, } __all__ = [ - 'REDUCE_OP_TO_TORCH', + 'AddOp', + 'AssociativeOp', + 'DISTRIBUTIVE_OPS', + 'GetitemOp', + 'NegOp', + 'Op', + 'PRODUCT_INVERSES', + 'UNITS', + 'SubOp', 'abs', 'add', 'and_', 'eq', + 'exp', 'ge', 'getitem', 'gt', 'invert', 'le', + 'log', + 'log1p', 'lt', - 'marginal', 'max', 'min', 'mul', @@ -114,12 +270,10 @@ def sample(x, y): 'neg', 'or_', 'pow', - 'sample', + 'safediv', + 'safesub', + 'sqrt', 'sub', 'truediv', 'xor', - 'sqrt', - 'exp', - 'log', - 'log1p', ] diff --git a/funsor/optimizer.py b/funsor/optimizer.py new file mode 100644 index 000000000..961e38071 --- /dev/null +++ b/funsor/optimizer.py @@ -0,0 +1,293 @@ +from __future__ import absolute_import, division, print_function + +import collections + +from opt_einsum.paths import greedy +from six.moves import reduce + +import funsor.ops as ops +from funsor.contract import Contract, contractor +from funsor.delta import Delta +from funsor.domains import find_domain +from funsor.gaussian import Gaussian +from funsor.integrate import Integrate +from funsor.interpreter import dispatched_interpretation, interpretation, reinterpret +from funsor.joint import Joint +from funsor.ops import DISTRIBUTIVE_OPS, UNITS, AssociativeOp +from funsor.terms import Binary, Funsor, Reduce, Unary, eager, lazy, to_funsor +from funsor.torch import Tensor + + +class Finitary(Funsor): + """ + Lazy finitary operation. Used internally in the optimizer. 
+ Finitary(op, operands) == six.moves.reduce(op, operands) + """ + def __init__(self, op, operands): + assert callable(op) + assert isinstance(operands, tuple) + assert all(isinstance(operand, Funsor) for operand in operands) + inputs = collections.OrderedDict() + for operand in operands: + inputs.update(operand.inputs) + + output = reduce(lambda lhs, rhs: find_domain(op, lhs, rhs), + [operand.output for operand in reversed(operands)]) + + super(Finitary, self).__init__(inputs, output) + self.op = op + self.operands = operands + + def __repr__(self): + return 'Finitary({}, {})'.format(self.op.__name__, self.operands) + + +@eager.register(Finitary, AssociativeOp, tuple) +def eager_finitary(op, operands): + return reduce(op, operands) + + +@dispatched_interpretation +def associate(cls, *args): + result = associate.dispatch(cls, *args) + if result is None: + result = lazy(cls, *args) + return result + + +@associate.register(Binary, AssociativeOp, Funsor, Funsor) +def binary_to_finitary(op, lhs, rhs): + """convert Binary to Finitary""" + return Finitary(op, (lhs, rhs)) + + +@associate.register(Finitary, AssociativeOp, tuple) +def associate_finitary(op, operands): + # Finitary(Finitary) -> Finitary + new_operands = [] + for term in operands: + if isinstance(term, Finitary) and term.op is op: + new_operands.extend(term.operands) + else: + new_operands.append(term) + + with interpretation(lazy): + return Finitary(op, tuple(new_operands)) + + +@associate.register(Reduce, AssociativeOp, Reduce, frozenset) +def associate_reduce(op, arg, reduced_vars): + """ + Rewrite to the largest possible Reduce(Finitary) by combining Reduces + Assumes that all input Reduce/Finitary ops have been rewritten + """ + if arg.op is op: + # Reduce(Reduce) -> Reduce + new_reduced_vars = reduced_vars.union(arg.reduced_vars) + return Reduce(op, arg.arg, new_reduced_vars) + return None + + +@dispatched_interpretation +def distribute(cls, *args): + result = distribute.dispatch(cls, *args) + if result is None: + result = lazy(cls, *args) + return result + + +@distribute.register(Finitary, AssociativeOp, tuple) +def distribute_finitary(op, operands): + # TODO raise an error or warning on name collision + if len(operands) == 1: + return operands[0] + + reduce_op = None + reduce_terms, remaining_terms, reduced_vars = [], [], frozenset() + for term in operands: + # term is not reduce -> do nothing + # term is reduce but does not distribute -> do nothing + # term is reduce and distributes -> put into reduce_terms + if isinstance(term, Reduce): + if not reduce_op or not (reduce_op, op) in DISTRIBUTIVE_OPS: + reduce_op = term.op + if term.op == reduce_op and (reduce_op, op) in DISTRIBUTIVE_OPS: + reduce_terms.append(term) + reduced_vars = reduced_vars | term.reduced_vars + else: + remaining_terms.append(term) + else: + remaining_terms.append(term) + + if len(reduce_terms) > 1: + new_finitary_term = Finitary(op, tuple(term.arg for term in reduce_terms)) + remaining_terms.append(Reduce(reduce_op, new_finitary_term, reduced_vars)) + return Finitary(op, tuple(remaining_terms)) + + return None + + +@dispatched_interpretation +def optimize(cls, *args): + result = optimize.dispatch(cls, *args) + if result is None: + result = lazy(cls, *args) + return result + + +# TODO set a better value for this +REAL_SIZE = 3 # the "size" of a real-valued dimension passed to the path optimizer + + +@optimize.register(Reduce, AssociativeOp, Funsor, frozenset) +def optimize_reduction_trivial(op, arg, reduced_vars): + if not reduced_vars: + return arg + 
return None + + +@optimize.register(Reduce, AssociativeOp, Binary, frozenset) +@eager.register(Reduce, AssociativeOp, Binary, frozenset) +def optimize_reduce_binary_exp(op, arg, reduced_vars): + if op is not ops.add or arg.op is not ops.mul or \ + not isinstance(arg.lhs, Unary) or arg.lhs.op is not ops.exp: + return None + return Integrate(arg.lhs.arg, arg.rhs, reduced_vars) + + +@optimize.register(Reduce, AssociativeOp, Finitary, frozenset) +def optimize_reduction(op, arg, reduced_vars): + r""" + Recursively convert large Reduce(Finitary) ops to many smaller versions + by reordering execution with a modified opt_einsum optimizer + """ + if not reduced_vars: + return arg + + if not (op, arg.op) in DISTRIBUTIVE_OPS: + return None + + return Contract(op, arg.op, arg, to_funsor(UNITS[arg.op]), reduced_vars) + + +@optimize.register(Contract, AssociativeOp, AssociativeOp, Finitary, (Finitary, Funsor, Unary), frozenset) +def optimize_contract_finitary_funsor(sum_op, prod_op, lhs, rhs, reduced_vars): + + if prod_op is not lhs.op: + return None + + # build opt_einsum optimizer IR + inputs = [frozenset(t.inputs) for t in lhs.operands] + [frozenset(rhs.inputs)] + size_dict = {k: ((REAL_SIZE * v.num_elements) if v.dtype == 'real' else v.dtype) + for arg in (lhs, rhs) for k, v in arg.inputs.items()} + outputs = frozenset().union(*inputs) - reduced_vars + + # optimize path with greedy opt_einsum optimizer + # TODO switch to new 'auto' strategy when it's released + path = greedy(inputs, outputs, size_dict) + + # convert path IR back to sequence of Reduce(Finitary(...)) + + # first prepare a reduce_dim counter to avoid early reduction + reduce_dim_counter = collections.Counter() + for input in inputs: + reduce_dim_counter.update({d: 1 for d in input}) + + operands = list(lhs.operands) + [rhs] + for (a, b) in path: + b, a = tuple(sorted((a, b), reverse=True)) + tb = operands.pop(b) + ta = operands.pop(a) + + # don't reduce a dimension too early - keep a collections.Counter + # and only reduce when the dimension is removed from all lhs terms in path + reduce_dim_counter.subtract({d: 1 for d in reduced_vars & frozenset(ta.inputs.keys())}) + reduce_dim_counter.subtract({d: 1 for d in reduced_vars & frozenset(tb.inputs.keys())}) + + # reduce variables that don't appear in other terms + both_vars = frozenset(ta.inputs.keys()) | frozenset(tb.inputs.keys()) + path_end_reduced_vars = frozenset(d for d in reduced_vars & both_vars + if reduce_dim_counter[d] == 0) + + # count new appearance of variables that aren't reduced + reduce_dim_counter.update({d: 1 for d in reduced_vars & (both_vars - path_end_reduced_vars)}) + + path_end = Contract(sum_op, prod_op, ta, tb, path_end_reduced_vars) + operands.append(path_end) + + # reduce any remaining dims, if necessary + final_reduced_vars = frozenset(d for (d, count) in reduce_dim_counter.items() + if count > 0) & reduced_vars + if final_reduced_vars: + path_end = Reduce(sum_op, path_end, final_reduced_vars) + return path_end + + +@optimize.register(Finitary, AssociativeOp, tuple) +def remove_single_finitary(op, operands): + if len(operands) == 1: + return operands[0] + return None + + +@optimize.register(Unary, ops.Op, Finitary) +def optimize_exp_finitary(op, arg): + # useful for handling Integrate... 
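+ # (Informal) rewrites exp(a + b + ...) into exp(a) * exp(b) * ..., so that
+ # downstream Contract/Integrate patterns can match the individual factors.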
+ if op is not ops.exp or arg.op is not ops.add: + return None + return Finitary(ops.mul, tuple(operand.exp() for operand in arg.operands)) + + +@optimize.register(Contract, AssociativeOp, AssociativeOp, Unary, Unary, frozenset) +@optimize.register(Contract, AssociativeOp, AssociativeOp, Funsor, Funsor, frozenset) +@contractor +def optimize_contract(sum_op, prod_op, lhs, rhs, reduced_vars): + return None + + +@optimize.register(Contract, AssociativeOp, AssociativeOp, (Unary, Funsor), Finitary, frozenset) +def optimize_contract_funsor_finitary(sum_op, prod_op, lhs, rhs, reduced_vars): + return Contract(sum_op, prod_op, rhs, lhs, reduced_vars) + + +@optimize.register(Contract, AssociativeOp, AssociativeOp, Unary, Funsor, frozenset) +def optimize_contract_exp_funsor(sum_op, prod_op, lhs, rhs, reduced_vars): + if lhs.op is ops.exp and isinstance(lhs.arg, (Gaussian, Tensor, Delta, Joint)) and \ + sum_op is ops.add and prod_op is ops.mul: + return Integrate(lhs.arg, rhs, reduced_vars) + return None + + +@optimize.register(Contract, AssociativeOp, AssociativeOp, Funsor, Unary, frozenset) +def optimize_contract_funsor_exp(sum_op, prod_op, lhs, rhs, reduced_vars): + return Contract(sum_op, prod_op, rhs, lhs, reduced_vars) + + +@dispatched_interpretation +def desugar(cls, *args): + result = desugar.dispatch(cls, *args) + if result is None: + result = lazy(cls, *args) + return result + + +@desugar.register(Finitary, AssociativeOp, tuple) +def desugar_finitary(op, operands): + return reduce(op, operands) + + +def apply_optimizer(x): + + with interpretation(associate): + x = reinterpret(x) + + with interpretation(distribute): + x = reinterpret(x) + + with interpretation(optimize): + x = reinterpret(x) + + with interpretation(desugar): + x = reinterpret(x) + + return reinterpret(x) # use previous interpretation diff --git a/funsor/pattern.py b/funsor/pattern.py new file mode 100644 index 000000000..fbb06d3e9 --- /dev/null +++ b/funsor/pattern.py @@ -0,0 +1,69 @@ +from __future__ import absolute_import, division, print_function + +import functools + +import unification.match +from unification import unify +from unification.variable import isvar + +import funsor.ops as ops +from funsor.interpreter import dispatched_interpretation, interpretation +from funsor.terms import Binary, Funsor, Variable, lazy + + +@dispatched_interpretation +def unify_interpreter(cls, *args): + result = unify_interpreter.dispatch(cls, *args) + if result is None: + result = lazy(cls, *args) + return result + + +@unify_interpreter.register(Binary, ops.Op, Funsor, Funsor) +def unify_eq(op, lhs, rhs): + if op is ops.eq: + return lhs is rhs # equality via cons-hashing + return None + + +class EqDispatcher(unification.match.Dispatcher): + + resolve = interpretation(unify_interpreter)(unification.match.Dispatcher.resolve) + + +class EqVarDispatcher(EqDispatcher): + + def __call__(self, *args, **kwargs): + func, s = self.resolve(args) + d = dict((k.name if isinstance(k, Variable) else k.token, v) for k, v in s.items()) + return func(**d) + + +@isvar.register(Variable) +def _isvar_funsor_variable(v): + return True + + +@unify.register(Funsor, Funsor, dict) +@interpretation(unify_interpreter) +def unify_funsor(pattern, expr, subs): + if type(pattern) is not type(expr): + return False + return unify(pattern._ast_values, expr._ast_values, subs) + + +@unify.register(Variable, (Variable, Funsor), dict) +@unify.register((Variable, Funsor), Variable, dict) +def unify_patternvar(pattern, expr, subs): + subs.update({pattern: expr} if 
isinstance(pattern, Variable) else {expr: pattern}) + return subs + + +match_vars = functools.partial(unification.match.match, Dispatcher=EqVarDispatcher) +match = functools.partial(unification.match.match, Dispatcher=EqDispatcher) + + +__all__ = [ + "match", + "match_vars", +] diff --git a/funsor/six.py b/funsor/six.py index eef789240..9428b0d4e 100644 --- a/funsor/six.py +++ b/funsor/six.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function import inspect +import re import six @@ -44,11 +45,30 @@ def decorator(fn): def getargspec(fn): - """wrapper to remove annoying DeprecationWarning for inspect.getargspec in Py3""" - if six.PY3: - args, vargs, kwargs, defaults, _, _, _ = inspect.getfullargspec(fn) - else: - args, vargs, kwargs, defaults = inspect.getargspec(fn) + """ + Similar to Python 2's :py:func:`inspect.getargspec` but: + - In Python 3 uses ``getfullargspec`` to avoid ``DeprecationWarning``. + - For builtin functions like ``torch.matmul``, falls back to attmpting + to parse the function docstring, assuming torch-style. + """ + assert callable(fn) + try: + if six.PY3: + args, vargs, kwargs, defaults, _, _, _ = inspect.getfullargspec(fn) + else: + args, vargs, kwargs, defaults = inspect.getargspec(fn) + except TypeError: + # Fall back to attmpting to parse a PyTorch-style docstring. + match = re.match(r"\s{}\(([^)]*)\)".format(fn.__name__), fn.__doc__) + if match is None: + raise + parts = match.group(1).split(", ") + args = [a.split("=")[0] for a in parts] + if not all(re.match(r"^[^\d\W]\w*\Z", arg) for arg in args): + raise + vargs = None + kwargs = None + defaults = () # Ignore defaults. return args, vargs, kwargs, defaults diff --git a/funsor/sum_product.py b/funsor/sum_product.py new file mode 100644 index 000000000..a41910d6e --- /dev/null +++ b/funsor/sum_product.py @@ -0,0 +1,98 @@ +from __future__ import absolute_import, division, print_function + +from collections import OrderedDict, defaultdict + +from six.moves import reduce + +from funsor.ops import UNITS +from funsor.terms import Funsor, Number + + +def _partition(terms, sum_vars): + # Construct a bipartite graph between terms and the vars + neighbors = OrderedDict([(t, []) for t in terms]) + for term in terms: + for dim in term.inputs.keys(): + if dim in sum_vars: + neighbors[term].append(dim) + neighbors.setdefault(dim, []).append(term) + + # Partition the bipartite graph into connected components for contraction. + components = [] + while neighbors: + v, pending = neighbors.popitem() + component = OrderedDict([(v, None)]) # used as an OrderedSet + for v in pending: + component[v] = None + while pending: + v = pending.pop() + for v in neighbors.pop(v): + if v not in component: + component[v] = None + pending.append(v) + + # Split this connected component into tensors and dims. + component_terms = tuple(v for v in component if isinstance(v, Funsor)) + if component_terms: + component_dims = frozenset(v for v in component if not isinstance(v, Funsor)) + components.append((component_terms, component_dims)) + return components + + +def partial_sum_product(sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset()): + """ + Performs partial sum-product contraction of a collection of factors. + + :return: a list of partially contracted Funsors. 
+ :rtype: list + """ + assert callable(sum_op) + assert callable(prod_op) + assert isinstance(factors, (tuple, list)) + assert all(isinstance(f, Funsor) for f in factors) + assert isinstance(eliminate, frozenset) + assert isinstance(plates, frozenset) + sum_vars = eliminate - plates + + var_to_ordinal = {} + ordinal_to_factors = defaultdict(list) + for f in factors: + ordinal = plates.intersection(f.inputs) + ordinal_to_factors[ordinal].append(f) + for var in sum_vars.intersection(f.inputs): + var_to_ordinal[var] = var_to_ordinal.get(var, ordinal) & ordinal + + ordinal_to_vars = defaultdict(set) + for var, ordinal in var_to_ordinal.items(): + ordinal_to_vars[ordinal].add(var) + + results = [] + while ordinal_to_factors: + leaf = max(ordinal_to_factors, key=len) + leaf_factors = ordinal_to_factors.pop(leaf) + leaf_reduce_vars = ordinal_to_vars[leaf] + for (group_factors, group_vars) in _partition(leaf_factors, leaf_reduce_vars): + f = reduce(prod_op, group_factors).reduce(sum_op, group_vars) + remaining_sum_vars = sum_vars.intersection(f.inputs) + if not remaining_sum_vars: + results.append(f.reduce(prod_op, leaf & eliminate)) + else: + new_plates = frozenset().union( + *(var_to_ordinal[v] for v in remaining_sum_vars)) + if new_plates == leaf: + raise ValueError("intractable!") + f = f.reduce(prod_op, leaf - new_plates) + ordinal_to_factors[new_plates].append(f) + + return results + + +def sum_product(sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset()): + """ + Performs sum-product contraction of a collection of factors. + + :return: a single contracted Funsor. + :rtype: :class:`~funsor.terms.Funsor` + """ + factors = partial_sum_product(sum_op, prod_op, factors, eliminate, plates) + return reduce(prod_op, factors, Number(UNITS[prod_op])) diff --git a/funsor/terms.py b/funsor/terms.py index c6cfbc228..b242d1f70 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1,72 +1,136 @@ -r""" -Funsor interpretations ----------------------- - -Funsor provides three basic interpretations. - -- ``reflect`` is completely lazy, even with respect to substitution. -- ``lazy`` substitutes eagerly but performs ops lazily. -- ``eager`` does everything eagerly. 
- -""" - from __future__ import absolute_import, division, print_function import functools import itertools +import math import numbers -from abc import ABCMeta, abstractmethod -from collections import OrderedDict +import re +from collections import Hashable, OrderedDict from weakref import WeakValueDictionary +from multipledispatch import dispatch from six import add_metaclass, integer_types from six.moves import reduce import funsor.interpreter as interpreter import funsor.ops as ops -from funsor.domains import Domain, find_domain, bint -from funsor.interpreter import interpret -from funsor.registry import KeyedRegistry +from funsor.domains import Domain, bint, find_domain, reals +from funsor.interpreter import dispatched_interpretation, interpret +from funsor.ops import AssociativeOp, GetitemOp, Op from funsor.six import getargspec, singledispatch +def substitute(expr, subs): + if isinstance(subs, (dict, OrderedDict)): + subs = tuple(subs.items()) + assert isinstance(subs, tuple) + + @interpreter.interpretation(interpreter._INTERPRETATION) # use base + def subs_interpreter(cls, *args): + expr = cls(*args) + fresh_subs = tuple((k, v) for k, v in subs if k in expr.fresh) + if fresh_subs: + expr = interpreter.debug_logged(expr.eager_subs)(fresh_subs) + return expr + + with interpreter.interpretation(subs_interpreter): + return interpreter.reinterpret(expr) + + +def alpha_convert(expr): + alpha_subs = {name: interpreter.gensym(name + "__BOUND") + for name in expr.bound if "__BOUND" not in name} + if not alpha_subs: + return expr + + new_values = [] + for v in expr._ast_values: + v = substitute(v, alpha_subs) + if isinstance(v, str) and v not in expr.fresh: + v = alpha_subs.get(v, v) + elif isinstance(v, frozenset): + swapped = v & frozenset(alpha_subs.keys()) + v |= frozenset(alpha_subs[k] for k in swapped) + v -= swapped + elif isinstance(v, tuple) and isinstance(v[0], tuple) and len(v[0]) == 2 and \ + isinstance(v[0][0], str) and isinstance(v[0][1], Funsor): + v = tuple((alpha_subs[k] if k in alpha_subs else k, vv) for k, vv in v) + elif isinstance(v, OrderedDict): # XXX is this case ever actually triggered? + v = OrderedDict([(alpha_subs[k] if k in alpha_subs else k, vv) for k, vv in v.items()]) + new_values.append(v) + + return reflect(type(expr), *new_values) + + def reflect(cls, *args): """ Construct a funsor, populate ``._ast_values``, and cons hash. + This is the only interpretation allowed to construct funsors. """ - if args in cls._cons_cache: - return cls._cons_cache[args] + cache_key = tuple(id(arg) if not isinstance(arg, Hashable) else arg for arg in args) + if cache_key in cls._cons_cache: + return cls._cons_cache[cache_key] + result = super(FunsorMeta, cls).__call__(*args) result._ast_values = args - cls._cons_cache[args] = result - return result + # alpha-convert eagerly upon binding any variable + result = alpha_convert(result) -_lazy = KeyedRegistry(default=lambda *args: None) -_eager = KeyedRegistry(default=lambda *args: None) + cls._cons_cache[cache_key] = result + return result +@dispatched_interpretation def lazy(cls, *args): - result = _lazy(cls, *args) + """ + Substitute eagerly but perform ops lazily. + """ + result = lazy.dispatch(cls, *args) if result is None: result = reflect(cls, *args) return result +@dispatched_interpretation def eager(cls, *args): - result = _eager(cls, *args) + """ + Eagerly execute ops with known implementations. 
+ """ + result = eager.dispatch(cls, *args) if result is None: result = reflect(cls, *args) return result -lazy.register = _lazy.register -eager.register = _eager.register +@dispatched_interpretation +def sequential(cls, *args): + """ + Eagerly execute ops with known implementations; additonally execute + vectorized ops sequentially if no known vectorized implementation exists. + """ + result = sequential.dispatch(cls, *args) + if result is None: + result = eager(cls, *args) + return result + + +@dispatched_interpretation +def moment_matching(cls, *args): + """ + A moment matching interpretation of :class:`Reduce` expressions. This falls + back to :class:`eager` in other cases. + """ + result = moment_matching.dispatch(cls, *args) + if result is None: + result = eager(cls, *args) + return result + interpreter.set_interpretation(eager) # Use eager interpretation by default. -class FunsorMeta(ABCMeta): +class FunsorMeta(type): """ Metaclass for Funsors to perform three independent tasks: @@ -111,22 +175,28 @@ class Funsor(object): Concrete derived classes must implement ``__init__()`` methods taking hashable ``*args`` and no optional ``**kwargs`` so as to support cons - hashing. Derived classes must implement an :meth:`eager_subs` method. + hashing. :param OrderedDict inputs: A mapping from input name to domain. This can be viewed as a typed context or a mapping from free variables to domains. :param Domain output: An output domain. """ - def __init__(self, inputs, output): + def __init__(self, inputs, output, fresh=None, bound=None): + fresh = frozenset() if fresh is None else fresh + bound = frozenset() if bound is None else bound assert isinstance(inputs, OrderedDict) for name, input_ in inputs.items(): assert isinstance(name, str) assert isinstance(input_, Domain) assert isinstance(output, Domain) + assert isinstance(fresh, frozenset) + assert isinstance(bound, frozenset) super(Funsor, self).__init__() self.inputs = inputs self.output = output + self.fresh = fresh + self.bound = bound @property def dtype(self): @@ -135,6 +205,32 @@ def dtype(self): def __hash__(self): return id(self) + def __repr__(self): + return '{}({})'.format(type(self).__name__, ', '.join(map(repr, self._ast_values))) + + def __str__(self): + return '{}({})'.format(type(self).__name__, ', '.join(map(str, self._ast_values))) + + def _pretty(self, lines, indent=0): + lines.append((indent, type(self).__name__)) + for arg in self._ast_values: + if isinstance(arg, Funsor): + arg._pretty(lines, indent + 1) + elif type(arg) is tuple and all(isinstance(x, Funsor) for x in arg): + lines.append((indent + 1, 'tuple')) + for x in arg: + x._pretty(lines, indent + 2) + else: + lines.append((indent + 1, re.sub('\n\\s*', ' ', str(arg)))) + + def pretty(self): + lines = [] + self._pretty(lines) + return '\n'.join('| ' * indent + text for indent, text in lines) + + def __contains__(self, item): + raise TypeError + def __call__(self, *args, **kwargs): """ Partially evaluates this funsor by substituting dimensions. @@ -145,12 +241,12 @@ def __call__(self, *args, **kwargs): if k in kwargs: subs[k] = kwargs[k] for k, v in subs.items(): - if isinstance(v, str): - # Allow renaming of inputs via syntax x(y="z"). 
- subs[k] = Variable(v, self.inputs[k]) - else: - subs[k] = to_funsor(v) - return self.eager_subs(tuple(subs.items())) + v = to_funsor(v, self.inputs[k]) + if v.output != self.inputs[k]: + raise ValueError("Expected substitution of {} to have type {}, but got {}" + .format(repr(k), v.output, self.inputs[k])) + subs[k] = v + return Subs(self, tuple(subs.items())) def __bool__(self): if self.inputs or self.output.shape: @@ -161,6 +257,15 @@ def __bool__(self): def __nonzero__(self): return self.__bool__() + def __len__(self): + if not self.output.shape: + raise ValueError('Funsor with empty shape has no len()') + return self.output.shape[0] + + def __iter__(self): + for i in range(len(self)): + yield self[i] + def item(self): if self.inputs or self.output.shape: raise ValueError( @@ -176,6 +281,7 @@ def reduce(self, op, reduced_vars=None): If unspecified, all inputs will be reduced. :type reduced_vars: str or frozenset """ + assert isinstance(op, AssociativeOp) # Eagerly convert reduced_vars to appropriate things. if reduced_vars is None: # Empty reduced_vars means "reduce over everything". @@ -184,9 +290,65 @@ def reduce(self, op, reduced_vars=None): # A single name means "reduce over this one variable". reduced_vars = frozenset([reduced_vars]) assert isinstance(reduced_vars, frozenset), reduced_vars + if not reduced_vars: + return self assert reduced_vars.issubset(self.inputs) return Reduce(op, self, reduced_vars) + def sample(self, sampled_vars, sample_inputs=None): + """ + Create a Monte Carlo approximation to this funsor by replacing + functions of ``sampled_vars`` with :class:`~funsor.delta.Delta` s. + + The result is a :class:`Funsor` with the same ``.inputs`` and + ``.output`` as the original funsor (plus ``sample_inputs`` if + provided), so that self can be replaced by the sample in expectation + computations:: + + y = x.sample(sampled_vars) + assert y.inputs == x.inputs + assert y.output == x.output + exact = (x.exp() * integrand).reduce(ops.add) + approx = (y.exp() * integrand).reduce(ops.add) + + If ``sample_inputs`` is provided, this creates a batch of samples + scaled samples. + + :param frozenset sampled_vars: A set of input variables to sample. + :param OrderedDict sample_inputs: An optional mapping from variable + name to :class:`~funsor.domains.Domain` over which samples will + be batched. + """ + assert self.output == reals() + assert isinstance(sampled_vars, frozenset) + if sample_inputs is None: + sample_inputs = OrderedDict() + assert isinstance(sample_inputs, OrderedDict) + if sampled_vars.isdisjoint(self.inputs): + return self + + result = interpreter.debug_logged(self.unscaled_sample)(sampled_vars, sample_inputs) + if sample_inputs is not None: + log_scale = 0 + for var, domain in sample_inputs.items(): + if var in result.inputs and var not in self.inputs: + log_scale -= math.log(domain.dtype) + if log_scale != 0: + result += log_scale + return result + + def unscaled_sample(self, sampled_vars, sample_inputs): + """ + Internal method to draw an unscaled sample. + This should be overridden by subclasses. + """ + assert self.output == reals() + assert isinstance(sampled_vars, frozenset) + assert isinstance(sample_inputs, OrderedDict) + if sampled_vars.isdisjoint(self.inputs): + return self + raise ValueError("Cannot sample from a {}".format(type(self).__name__)) + def align(self, names): """ Align this funsor to match given ``names``. 
@@ -203,15 +365,14 @@ def align(self, names): return self return Align(self, names) - @abstractmethod def eager_subs(self, subs): """ Internal substitution function. This relies on the user-facing :meth:`__call__` method to coerce non-Funsors to Funsors. Once all inputs are Funsors, :meth:`eager_subs` implementations can recurse to - call other :meth:`eager_subs` methods. + call :class:`Subs`. """ - raise NotImplementedError + return None # defer to default implementation def eager_unary(self, op): return None # defer to default implementation @@ -221,6 +382,13 @@ def eager_reduce(self, op, reduced_vars): if not reduced_vars: return self + return None # defer to default implementation + + def sequential_reduce(self, op, reduced_vars): + assert reduced_vars.issubset(self.inputs) # FIXME Is this valid? + if not reduced_vars: + return self + # Try to sum out integer scalars. This is mainly useful for testing, # since reduction is more efficiently implemented by Tensor. eager_vars = [] @@ -241,6 +409,15 @@ def eager_reduce(self, op, reduced_vars): return None # defer to default implementation + def moment_matching_reduce(self, op, reduced_vars): + assert reduced_vars.issubset(self.inputs) # FIXME Is this valid? + if not reduced_vars: + return self + + return None # defer to default implementation + + # The following methods conform to a standard array/tensor interface. + def __invert__(self): return Unary(ops.invert, self) @@ -262,6 +439,31 @@ def log(self): def log1p(self): return Unary(ops.log1p, self) + # The following reductions are treated as Unary ops because they + # reduce over output shape while preserving all inputs. + # To reduce over inputs, instead call .reduce(op, reduced_vars). + + def sum(self): + return Unary(ops.add, self) + + def prod(self): + return Unary(ops.mul, self) + + def logsumexp(self): + return Unary(ops.logaddexp, self) + + def all(self): + return Unary(ops.and_, self) + + def any(self): + return Unary(ops.or_, self) + + def min(self): + return Unary(ops.min, self) + + def max(self): + return Unary(ops.max, self) + def __add__(self, other): return Binary(ops.add, self, to_funsor(other)) @@ -332,50 +534,100 @@ def __max__(self, other): return Binary(ops.max, self, to_funsor(other)) def __getitem__(self, other): - return Binary(ops.getitem, self, to_funsor(other)) + if type(other) is not tuple: + other = to_funsor(other, bint(self.output.shape[0])) + return Binary(ops.getitem, self, other) + + # Handle Ellipsis slicing. + if any(part is Ellipsis for part in other): + left = [] + for part in other: + if part is Ellipsis: + break + left.append(part) + right = [] + for part in reversed(other): + if part is Ellipsis: + break + right.append(part) + right.reverse() + missing = len(self.output.shape) - len(left) - len(right) + assert missing >= 0 + middle = [slice(None)] * missing + other = tuple(left + middle + right) + + # Handle each slice separately. 
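# --- Editor's note: illustrative sketch, not part of the patch.  The unary
# reductions added above (.sum(), .max(), ...) reduce over the *output* shape
# and preserve named inputs; reducing over inputs still goes through
# .reduce(op, reduced_vars).  Assumes the accompanying find_domain support for
# reductions over the output shape (funsor/domains.py, not shown in this hunk).
from collections import OrderedDict

import torch

import funsor.ops as ops
from funsor.domains import bint, reals
from funsor.torch import Tensor

x = Tensor(torch.randn(3, 4), OrderedDict([('i', bint(3))]))  # output reals(4)
assert x.sum().output == reals()                    # reduces the length-4 output dim
assert 'i' in x.sum().inputs                        # named inputs are preserved
assert x.reduce(ops.add, 'i').output == reals(4)    # reduces the named input instead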
+ result = self + offset = 0 + for part in other: + if isinstance(part, slice): + if part != slice(None): + raise NotImplementedError('TODO support nontrivial slicing') + offset += 1 + else: + part = to_funsor(part, bint(result.output.shape[offset])) + result = Binary(GetitemOp(offset), result, part) + return result - def sum(self, reduced_vars=None): - return self.reduce(ops.add, reduced_vars) - def prod(self, reduced_vars=None): - return self.reduce(ops.mul, reduced_vars) +interpreter.recursion_reinterpret.register(Funsor)(interpreter.reinterpret_funsor) +interpreter.children.register(Funsor)(interpreter.children_funsor) - def logsumexp(self, reduced_vars=None): - return self.reduce(ops.logaddexp, reduced_vars) - def all(self, reduced_vars=None): - return self.reduce(ops.and_, reduced_vars) +@dispatch(object) +def to_funsor(x): + """ + Convert to a :class:`Funsor`. + Only :class:`Funsor`s and scalars are accepted. - def any(self, reduced_vars=None): - return self.reduce(ops.or_, reduced_vars) + :param x: An object. + :param funsor.domains.Domain output: An optional output hint. + :return: A Funsor equivalent to ``x``. + :rtype: Funsor + :raises: ValueError + """ + raise ValueError("Cannot convert to Funsor: {}".format(repr(x))) - def min(self, reduced_vars=None): - return self.reduce(ops.min, reduced_vars) - def max(self, reduced_vars=None): - return self.reduce(ops.max, reduced_vars) +@dispatch(object, Domain) +def to_funsor(x, output): + raise ValueError("Cannot convert to Funsor: {}".format(repr(x))) -interpreter.reinterpret.register(Funsor)(interpreter.reinterpret_funsor) +@dispatch(object, object) +def to_funsor(x, output): + raise TypeError("Invalid Domain: {}".format(repr(output))) -@singledispatch +@dispatch(Funsor) def to_funsor(x): + return x + + +@dispatch(Funsor, Domain) +def to_funsor(x, output): + if x.output != output: + raise ValueError("Output mismatch: {} vs {}".format(x.output, output)) + return x + + +@singledispatch +def to_data(x): """ - Convert to a :class:`Funsor`. - Only :class:`Funsor`s and scalars are accepted. + Extract a python object from a :class:`Funsor`. - :param x: An object. - :return: A Funsor equivalent to ``x``. - :rtype: Funsor + Raises a ``ValueError`` if free variables remain or if the funsor is lazy. + + :param x: An object, possibly a :class:`Funsor`. + :return: A non-funsor equivalent to ``x``. :raises: ValueError """ - raise ValueError("cannot convert to Funsor: {}".format(x)) + return x -@to_funsor.register(Funsor) -def _to_funsor_funsor(x): - return x +@to_data.register(Funsor) +def _to_data_funsor(x): + raise ValueError("cannot convert to a non-Funsor: {}".format(repr(x))) class Variable(Funsor): @@ -387,7 +639,8 @@ class Variable(Funsor): """ def __init__(self, name, output): inputs = OrderedDict([(name, output)]) - super(Variable, self).__init__(inputs, output) + fresh = frozenset({name}) + super(Variable, self).__init__(inputs, output, fresh) self.name = name def __repr__(self): @@ -397,11 +650,66 @@ def __str__(self): return self.name def eager_subs(self, subs): + assert len(subs) == 1 and subs[0][0] == self.name + v = subs[0][1] + return v if isinstance(v, Funsor) else to_funsor(v, self.output) + + +@dispatch(str, Domain) +def to_funsor(name, output): + return Variable(name, output) + + +class Subs(Funsor): + """ + Lazy substitution of the form ``x(u=y, v=z)``. 
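# --- Editor's note: illustrative sketch of the Ellipsis handling added to
# Funsor.__getitem__ above; not part of the patch.
from collections import OrderedDict

import torch

from funsor.domains import bint, reals
from funsor.torch import Tensor

x = Tensor(torch.randn(3, 4, 5), OrderedDict([('i', bint(3))]))  # output reals(4, 5)
y = x[..., 2]       # Ellipsis expands to slice(None) over the leading output dim
assert y.output == reals(4)
assert 'i' in y.inputs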
+ """ + def __init__(self, arg, subs): + assert isinstance(arg, Funsor) assert isinstance(subs, tuple) - for k, v in subs: - if k == self.name: - return v - return self + for key, value in subs: + assert isinstance(key, str) + assert key in arg.inputs + assert isinstance(value, Funsor) + inputs = arg.inputs.copy() + for key, value in subs: + del inputs[key] + for key, value in subs: + inputs.update(value.inputs) + fresh = frozenset() + bound = frozenset(key for key, value in subs if key not in inputs) + super(Subs, self).__init__(inputs, arg.output, fresh, bound) + self.arg = arg + self.subs = OrderedDict(subs) + + def __repr__(self): + return 'Subs({}, {})'.format(self.arg, self.subs) + + def unscaled_sample(self, sampled_vars, sample_inputs): + if any(k in sample_inputs for k, v in self.subs.items()): + raise NotImplementedError('TODO alpha-convert') + subs_sampled_vars = set() + for name in sampled_vars: + if name in self.arg.inputs: + if any(name in v.inputs for k, v in self.subs.items()): + raise ValueError("Cannot sample") + subs_sampled_vars.add(name) + else: + for k, v in self.subs.items(): + if name in v.inputs: + subs_sampled_vars.add(k) + subs_sampled_vars = frozenset(subs_sampled_vars) + arg = self.arg.unscaled_sample(subs_sampled_vars, sample_inputs) + return Subs(arg, tuple(self.subs.items())) + + +@lazy.register(Subs, Funsor, object) +@eager.register(Subs, Funsor, object) +def eager_subs(arg, subs): + assert isinstance(subs, tuple) + if not any(k in arg.inputs for k, v in subs): + return arg + return substitute(arg, subs) _PREFIX = { @@ -427,16 +735,17 @@ def __repr__(self): return '{}{}'.format(_PREFIX[self.op], self.arg) return 'Unary({}, {})'.format(self.op.__name__, self.arg) - def eager_subs(self, subs): - if not any(k in self.inputs for k, v in subs): - return self - arg = self.arg.eager_subs(subs) - return Unary(self.op, arg) + +@eager.register(Unary, Op, Funsor) +def eager_unary(op, arg): + return interpreter.debug_logged(arg.eager_unary)(op) -@eager.register(Unary, object, Funsor) +@eager.register(Unary, AssociativeOp, Funsor) def eager_unary(op, arg): - return arg.eager_unary(op) + if not arg.output.shape: + return arg + return interpreter.debug_logged(arg.eager_unary)(op) _INFIX = { @@ -469,12 +778,20 @@ def __repr__(self): return '({} {} {})'.format(self.lhs, _INFIX[self.op], self.rhs) return 'Binary({}, {}, {})'.format(self.op.__name__, self.lhs, self.rhs) - def eager_subs(self, subs): - if not any(k in self.inputs for k, v in subs): - return self - lhs = self.lhs.eager_subs(subs) - rhs = self.rhs.eager_subs(subs) - return Binary(self.op, lhs, rhs) + def eager_reduce(self, op, reduced_vars): + if op is self.op: + lhs = self.lhs.reduce(op, reduced_vars) + rhs = self.rhs.reduce(op, reduced_vars) + return op(lhs, rhs) + return interpreter.debug_logged(super(Binary, self).eager_reduce)(op, reduced_vars) + + def unscaled_sample(self, sampled_vars, sample_inputs=None): + if self.op is ops.logaddexp: + # Sample mixture components independently. 
+ lhs = self.lhs.unscaled_sample(sampled_vars, sample_inputs) + rhs = self.rhs.unscaled_sample(sampled_vars, sample_inputs) + return Binary(ops.logaddexp, lhs, rhs) + raise TypeError("Cannot sample from Binary({}, ...)".format(self.op)) class Reduce(Funsor): @@ -487,7 +804,9 @@ def __init__(self, op, arg, reduced_vars): assert isinstance(reduced_vars, frozenset) inputs = OrderedDict((k, v) for k, v in arg.inputs.items() if k not in reduced_vars) output = arg.output - super(Reduce, self).__init__(inputs, output) + fresh = frozenset() + bound = reduced_vars + super(Reduce, self).__init__(inputs, output, fresh, bound) self.op = op self.arg = arg self.reduced_vars = reduced_vars @@ -496,26 +815,54 @@ def __repr__(self): return 'Reduce({}, {}, {})'.format( self.op.__name__, self.arg, self.reduced_vars) - def eager_subs(self, subs): - subs = tuple((k, v) for k, v in subs if k not in self.reduced_vars) - if not any(k in self.inputs for k, v in subs): - return self - if not all(self.reduced_vars.isdisjoint(v.inputs) for k, v in subs): - raise NotImplementedError('TODO alpha-convert to avoid conflict') - return self.arg.eager_subs(subs).reduce(self.op, self.reduced_vars) - def eager_reduce(self, op, reduced_vars): if op is self.op: # Eagerly fuse reductions. assert isinstance(reduced_vars, frozenset) reduced_vars = reduced_vars.intersection(self.inputs) | self.reduced_vars return Reduce(op, self.arg, reduced_vars) - return super(Reduce, self).reduce(op, reduced_vars) + return super(Reduce, self).eager_reduce(op, reduced_vars) + + def unscaled_sample(self, sampled_vars, sample_inputs=None): + if self.op is ops.logaddexp: + arg = self.arg.unscaled_sample(sampled_vars, sample_inputs) + return Reduce(ops.logaddexp, arg, self.reduced_vars) + raise TypeError("Cannot sample from Reduce({}, ...)".format(self.op)) -@eager.register(Reduce, object, Funsor, frozenset) +@eager.register(Reduce, AssociativeOp, Funsor, frozenset) def eager_reduce(op, arg, reduced_vars): - return arg.eager_reduce(op, reduced_vars) + return interpreter.debug_logged(arg.eager_reduce)(op, reduced_vars) + + +@eager.register(Binary, AssociativeOp, Reduce, (Funsor, Reduce)) +def eager_distribute_reduce_other(op, red, other): + if (red.op, op) in ops.DISTRIBUTIVE_OPS: + # Use distributive law. + arg = op(red.arg, other) + return arg.reduce(red.op, red.reduced_vars) + + return None # defer to default implementation + + +@eager.register(Binary, AssociativeOp, Funsor, Reduce) +def eager_distribute_other_reduce(op, other, red): + if (red.op, op) in ops.DISTRIBUTIVE_OPS: + # Use distributive law. 
+ arg = op(other, red.arg) + return arg.reduce(red.op, red.reduced_vars) + + return None # defer to default implementation + + +@sequential.register(Reduce, AssociativeOp, Funsor, frozenset) +def sequential_reduce(op, arg, reduced_vars): + return interpreter.debug_logged(arg.sequential_reduce)(op, reduced_vars) + + +@moment_matching.register(Reduce, AssociativeOp, Funsor, frozenset) +def moment_matching_reduce(op, arg, reduced_vars): + return interpreter.debug_logged(arg.moment_matching_reduce)(op, reduced_vars) class NumberMeta(FunsorMeta): @@ -528,7 +875,6 @@ def __call__(cls, data, dtype=None): return super(NumberMeta, cls).__call__(data, dtype) -@to_funsor.register(numbers.Number) @add_metaclass(NumberMeta) class Number(Funsor): """ @@ -541,6 +887,8 @@ def __init__(self, data, dtype=None): assert isinstance(data, numbers.Number) if isinstance(dtype, integer_types): data = type(dtype)(data) + if dtype != 2: # booleans have bitwise interpretation + assert 0 <= data and data < dtype else: assert isinstance(dtype, str) and dtype == "real" data = float(data) @@ -570,14 +918,29 @@ def __bool__(self): def item(self): return self.data - def eager_subs(self, subs): - return self - def eager_unary(self, op): - return Number(op(self.data), self.dtype) + dtype = find_domain(op, self.output).dtype + return Number(op(self.data), dtype) + + +@dispatch(numbers.Number) +def to_funsor(x): + return Number(x) + +@dispatch(numbers.Number, Domain) +def to_funsor(x, output): + if output.shape: + raise ValueError("Cannot create Number with shape {}".format(output.shape)) + return Number(x, output.dtype) -@eager.register(Binary, object, Number, Number) + +@to_data.register(Number) +def _to_data_number(x): + return x.data + + +@eager.register(Binary, Op, Number, Number) def eager_binary_number_number(op, lhs, rhs): data = op(lhs.data, rhs.data) output = find_domain(op, lhs.output, rhs.output) @@ -597,38 +960,48 @@ def __init__(self, arg, names): inputs = OrderedDict((name, arg.inputs[name]) for name in names) inputs.update(arg.inputs) output = arg.output - super(Align, self).__init__(inputs, output) + fresh = frozenset() # TODO get this right + bound = frozenset() + super(Align, self).__init__(inputs, output, fresh, bound) self.arg = arg def align(self, names): return self.arg.align(names) - def eager_subs(self, subs): - assert isinstance(subs, tuple) - return self.arg.eager_subs(subs) - def eager_unary(self, op): - return self.arg.eager_unary(op) + return Unary(op, self.arg) def eager_reduce(self, op, reduced_vars): - return self.arg.eager_reduce(op, reduced_vars) + return self.arg.reduce(op, reduced_vars) -@eager.register(Binary, object, Align, Funsor) +@eager.register(Align, Funsor, tuple) +def eager_align(arg, names): + if not frozenset(names) == frozenset(arg.inputs.keys()): + # assume there's been a substitution and this align is no longer valid + return arg + return None + + +@eager.register(Binary, Op, Align, Funsor) def eager_binary_align_funsor(op, lhs, rhs): return Binary(op, lhs.arg, rhs) -@eager.register(Binary, object, Funsor, Align) +@eager.register(Binary, Op, Funsor, Align) def eager_binary_funsor_align(op, lhs, rhs): return Binary(op, lhs, rhs.arg) -@eager.register(Binary, object, Align, Align) +@eager.register(Binary, Op, Align, Align) def eager_binary_align_align(op, lhs, rhs): return Binary(op, lhs.arg, rhs.arg) +eager.register(Binary, AssociativeOp, Reduce, Align)(eager_distribute_reduce_other) +eager.register(Binary, AssociativeOp, Align, Reduce)(eager_distribute_other_reduce) + + class 
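# --- Editor's note: heavily hedged sketch, not part of the patch.  The
# moment_matching hook registered above is typically driven by reducing a
# discrete mixture of Gaussians over its mixture variable; this assumes the
# Joint/Gaussian terms defined elsewhere in funsor and the random_gaussian /
# random_tensor helpers defined later in this diff.
from collections import OrderedDict

import funsor.ops as ops
from funsor.domains import bint, reals
from funsor.interpreter import interpretation
from funsor.terms import moment_matching
from funsor.testing import random_gaussian, random_tensor

mixture_logits = random_tensor(OrderedDict([('i', bint(3))]))
components = random_gaussian(OrderedDict([('i', bint(3)), ('x', reals(2))]))
with interpretation(moment_matching):
    # Collapse the mixture over 'i'; if no moment-matching rule applies, this
    # simply falls back to eager (and ultimately to a lazy Reduce term).
    approx = (mixture_logits + components).reduce(ops.logaddexp, 'i')
assert 'i' not in approx.inputs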
Stack(Funsor): """ Stack of funsors along a new input dimension. @@ -646,50 +1019,28 @@ def __init__(self, components, name): inputs = OrderedDict([(name, domain)]) for x in components: inputs.update(x.inputs) - super(Stack, self).__init__(inputs, output) + fresh = frozenset({name}) + super(Stack, self).__init__(inputs, output, fresh) self.components = components self.name = name def eager_subs(self, subs): - assert isinstance(subs, tuple) - if not any(k in self.inputs for k, v in subs): - return self - pos = None - for i, (k, index) in enumerate(subs): - if k == self.name: - pos = i - break - - if pos is None: - # Eagerly recurse into components. - assert not any(self.name in v.inputs for k, v in subs) - components = tuple(x.eager_subs(subs) for x in self.components) - return Stack(components, self.name) + assert isinstance(subs, tuple) and len(subs) == 1 and subs[0][0] == self.name + index = subs[0][1] # Try to eagerly select an index. assert index.output == bint(len(self.components)) - subs = subs[:pos] + subs[1 + pos:] if isinstance(index, Number): # Select a single component. - result = self.components[index.data] - return result.eager_subs(subs) - - if isinstance(index, Variable): + return self.components[index.data] + elif isinstance(index, Variable): # Rename the stacking dimension. components = self.components - if subs: - components = tuple(x.eager_subs(subs) for x in components) return Stack(components, index.name) - - if not subs: + else: raise NotImplementedError('TODO support advanced indexing in Stack') - # Eagerly recurse into components but lazily substitute. - components = tuple(x.eager_subs(subs) for x in self.components) - result = Stack(components, self.name) - return result.eager_subs(((self.name, index),)) - def eager_reduce(self, op, reduced_vars): components = self.components if self.name in reduced_vars: @@ -701,6 +1052,102 @@ def eager_reduce(self, op, reduced_vars): return Stack(components, self.name) +class Lambda(Funsor): + """ + Lazy inverse to ``ops.getitem``. + + This is useful to simulate higher-order functions of integers + by representing those functions as arrays. + """ + def __init__(self, var, expr): + assert isinstance(var, Variable) + assert isinstance(var.dtype, integer_types) + assert isinstance(expr, Funsor) + inputs = expr.inputs.copy() + inputs.pop(var.name, None) + shape = (var.dtype,) + expr.output.shape + output = Domain(shape, expr.dtype) + fresh = frozenset() + bound = frozenset({var.name}) # TODO make sure this is correct + super(Lambda, self).__init__(inputs, output, fresh, bound) + self.var = var + self.expr = expr + + +@eager.register(Binary, GetitemOp, Lambda, (Funsor, Align)) +def eager_getitem_lambda(op, lhs, rhs): + if op.offset == 0: + return Subs(lhs.expr, ((lhs.var.name, rhs),)) + expr = GetitemOp(op.offset - 1)(lhs.expr, rhs) + return Lambda(lhs.var, expr) + + +class Independent(Funsor): + """ + Creates an independent diagonal distribution. + + This is equivalent to substitution followed by reduction:: + + f = ... 
+ assert f.inputs['x'] == reals(4, 5) + assert f.inputs['i'] == bint(3) + + g = Independent(f, 'x', 'i') + assert g.inputs['x'] == reals(3, 4, 5) + assert 'i' not in g.inputs + + x = Variable('x', reals(3, 4, 5)) + g == f(x=x['i']).reduce(ops.logaddexp, 'i') + """ + def __init__(self, fn, reals_var, bint_var): + assert isinstance(fn, Funsor) + assert isinstance(reals_var, str) + for k in fn.inputs: + if k == reals_var or k.startswith(reals_var + "__BOUND"): + reals_var_bound = k + break + assert reals_var_bound in fn.inputs + assert fn.inputs[reals_var_bound].dtype == 'real' + assert isinstance(bint_var, str) + assert bint_var in fn.inputs + assert isinstance(fn.inputs[bint_var].dtype, int) + inputs = fn.inputs.copy() + shape = (inputs.pop(bint_var).dtype,) + inputs.pop(reals_var_bound).shape + inputs[reals_var] = reals(*shape) + fresh = frozenset({reals_var}) + bound = frozenset({bint_var, reals_var_bound}) + super(Independent, self).__init__(inputs, fn.output, fresh, bound) + self.fn = fn + self.reals_var = reals_var + self.bint_var = bint_var + self.reals_var_bound = reals_var_bound + + def unscaled_sample(self, sampled_vars, sample_inputs): + if self.bint_var in sampled_vars or self.bint_var in sample_inputs: + raise NotImplementedError('TODO alpha-convert') + sampled_vars = frozenset(self.reals_var_bound if v == self.reals_var else v + for v in sampled_vars) + fn = self.fn.unscaled_sample(sampled_vars, sample_inputs) + return Independent(fn, self.reals_var, self.bint_var) + + def eager_subs(self, subs): + subs = tuple((self.reals_var_bound, v[self.bint_var]) + if k == self.reals_var + else (k, v) + for k, v in subs) + new_fn = substitute(self.fn, subs) + new_fn = new_fn.reduce(ops.add, self.bint_var) + return new_fn + + +@eager.register(Independent, Funsor, str, str) +def eager_independent_trivial(fn, reals_var, bint_var): + # compare to Independent.eager_subs + if not any(k.startswith(reals_var + "__BOUND") or k == reals_var for k in fn.inputs): + return fn.reduce(ops.add, bint_var) + return None + + def _of_shape(fn, shape): args, vargs, kwargs, defaults = getargspec(fn) assert not vargs @@ -718,17 +1165,52 @@ def of_shape(*shape): return functools.partial(_of_shape, shape=shape) +################################################################################ +# Register Ops +################################################################################ + +@ops.abs.register(Funsor) +def _abs(x): + return Unary(ops.abs, x) + + +@ops.sqrt.register(Funsor) +def _sqrt(x): + return Unary(ops.sqrt, x) + + +@ops.exp.register(Funsor) +def _exp(x): + return Unary(ops.exp, x) + + +@ops.log.register(Funsor) +def _log(x): + return Unary(ops.log, x) + + +@ops.log1p.register(Funsor) +def _log1p(x): + return Unary(ops.log1p, x) + + __all__ = [ 'Binary', 'Funsor', + 'Independent', + 'Lambda', 'Number', 'Reduce', 'Stack', + 'Subs', 'Unary', 'Variable', 'eager', 'lazy', + 'moment_matching', 'of_shape', 'reflect', + 'sequential', + 'to_data', 'to_funsor', ] diff --git a/funsor/testing.py b/funsor/testing.py index c8681e15d..90a27532f 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -1,27 +1,98 @@ from __future__ import absolute_import, division, print_function +import contextlib +import itertools +import numbers import operator +from collections import OrderedDict, namedtuple +import numpy as np +import opt_einsum +import pytest import torch -from six import integer_types from six.moves import reduce -from funsor.terms import Funsor +from funsor.delta import Delta +from funsor.domains 
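# --- Editor's note: illustrative sketch of Lambda, not part of the patch.
# Lambda turns a free bounded-integer variable into a positional output
# dimension, and getitem is its inverse (see eager_getitem_lambda above).
from collections import OrderedDict

from funsor.domains import bint, reals
from funsor.terms import Lambda, Variable
from funsor.testing import assert_close, random_tensor

i = Variable('i', bint(3))
f = random_tensor(OrderedDict([('i', bint(3))]), reals())
g = Lambda(i, f)
assert 'i' not in g.inputs and g.output == reals(3)
assert_close(g[i], f)          # indexing undoes the abstraction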
import Domain, bint, reals +from funsor.gaussian import Gaussian +from funsor.joint import Joint +from funsor.numpy import Array +from funsor.terms import Funsor, Number from funsor.torch import Tensor +@contextlib.contextmanager +def xfail_if_not_implemented(msg="Not implemented"): + try: + yield + except NotImplementedError as e: + pytest.xfail(reason='{}:\n{}'.format(msg, e)) + + +class ActualExpected(namedtuple('LazyComparison', ['actual', 'expected'])): + """ + Lazy string formatter for test assertions. + """ + def __repr__(self): + return '\n'.join(['Expected:', str(self.expected), 'Actual:', str(self.actual)]) + + +def id_from_inputs(inputs): + if isinstance(inputs, (dict, OrderedDict)): + inputs = inputs.items() + if not inputs: + return '()' + return ','.join(k + ''.join(map(str, d.shape)) for k, d in inputs) + + def assert_close(actual, expected, atol=1e-6, rtol=1e-6): - assert isinstance(actual, Funsor) - assert isinstance(expected, Funsor) - assert actual.inputs == expected.inputs, (actual.inputs, expected.inputs) - assert actual.output == expected.output - if isinstance(actual, Tensor): - if actual.data.dtype in (torch.long, torch.uint8): - assert (actual.data == expected.data).all() + msg = ActualExpected(actual, expected) + assert type(actual) == type(expected), msg + if isinstance(actual, Funsor): + assert isinstance(actual, Funsor) + assert isinstance(expected, Funsor) + assert actual.inputs == expected.inputs, (actual.inputs, expected.inputs) + assert actual.output == expected.output, (actual.output, expected.output) + + if isinstance(actual, (Number, Tensor)): + assert_close(actual.data, expected.data, atol=atol, rtol=rtol) + elif isinstance(actual, Delta): + assert actual.name == expected.name + assert_close(actual.point, expected.point, atol=atol, rtol=rtol) + assert_close(actual.log_density, expected.log_density, atol=atol, rtol=rtol) + elif isinstance(actual, Gaussian): + assert_close(actual.loc, expected.loc, atol=atol, rtol=rtol) + assert_close(actual.precision, expected.precision, atol=atol, rtol=rtol) + elif isinstance(actual, Joint): + actual_deltas = {d.name: d for d in actual.deltas} + expected_deltas = {d.name: d for d in expected.deltas} + assert set(actual_deltas) == set(expected_deltas) + for name, actual_delta in actual_deltas.items(): + assert_close(actual_delta, expected_deltas[name]) + assert_close(actual.discrete, expected.discrete, atol=atol, rtol=rtol) + assert_close(actual.gaussian, expected.gaussian, atol=atol, rtol=rtol) + elif isinstance(actual, torch.Tensor): + assert actual.dtype == expected.dtype, msg + if actual.dtype in (torch.long, torch.uint8): + assert (actual == expected).all(), msg else: - diff = (actual.data.detach() - expected.data.detach()).abs() - assert diff.max() < atol - assert (diff / (atol + expected.data.detach().abs())).max() < rtol + eq = (actual == expected) + if eq.all(): + return + if eq.any(): + actual = actual[~eq] + expected = expected[~eq] + diff = (actual.detach() - expected.detach()).abs() + if rtol is not None: + assert (diff / (atol + expected.detach().abs())).max() < rtol, msg + elif atol is not None: + assert diff.max() < atol, msg + elif isinstance(actual, numbers.Number): + diff = abs(actual - expected) + if rtol is not None: + assert diff < (atol + expected) * rtol, msg + elif atol is not None: + assert diff < atol, msg else: raise ValueError('cannot compare objects of type {}'.format(type(actual))) @@ -45,6 +116,30 @@ def check_funsor(x, inputs, output, data=None): assert x_data == data +def 
xfail_param(*args, **kwargs): + return pytest.param(*args, marks=[pytest.mark.xfail(**kwargs)]) + + +def make_einsum_example(equation, fill=None, sizes=(2, 3)): + symbols = sorted(set(equation) - set(',->')) + sizes = {dim: size for dim, size in zip(symbols, itertools.cycle(sizes))} + inputs, outputs = equation.split('->') + inputs = inputs.split(',') + outputs = outputs.split(',') + operands = [] + for dims in inputs: + shape = tuple(sizes[dim] for dim in dims) + operands.append(torch.randn(shape) if fill is None else torch.full(shape, fill)) + funsor_operands = [ + Tensor(operand, OrderedDict([(d, bint(sizes[d])) for d in inp])) + for inp, operand in zip(inputs, operands) + ] + + assert equation == \ + ",".join(["".join(operand.inputs.keys()) for operand in funsor_operands]) + "->" + ",".join(outputs) + return inputs, outputs, sizes, operands, funsor_operands + + def assert_equiv(x, y): """ Check that two funsors are equivalent up to permutation of inputs. @@ -52,18 +147,82 @@ def assert_equiv(x, y): check_funsor(x, y.inputs, y.output, y.data) -def random_tensor(dtype, shape): +def random_tensor(inputs, output=reals()): """ - Creates a random :class:`torch.Tensor` suitable for a given - :class:`~funsor.domains.Domain`. + Creates a random :class:`funsor.torch.Tensor` with given inputs and output. """ - assert isinstance(shape, tuple) - if isinstance(dtype, integer_types): + assert isinstance(inputs, OrderedDict) + assert isinstance(output, Domain) + shape = tuple(d.dtype for d in inputs.values()) + output.shape + if output.dtype == 'real': + data = torch.randn(shape) + else: num_elements = reduce(operator.mul, shape, 1) - return torch.multinomial(torch.ones(dtype), + data = torch.multinomial(torch.ones(output.dtype), num_elements, replacement=True).reshape(shape) - elif dtype == "real": - return torch.randn(shape) + return Tensor(data, inputs, output.dtype) + + +def random_array(inputs, output): + """ + Creates a random :class:`funsor.numpy.Array` with given inputs and output. + """ + assert isinstance(inputs, OrderedDict) + assert isinstance(output, Domain) + shape = tuple(d.dtype for d in inputs.values()) + output.shape + if output.dtype == 'real': + data = np.random.normal(size=shape) else: - raise ValueError('unknown dtype: {}'.format(repr(dtype))) + num_elements = reduce(operator.mul, shape, 1) + data = np.random.choice(np.arange(output.dtype), + size=num_elements, + replace=True).reshape(shape) + return Array(data, inputs, output.dtype) + + +def random_gaussian(inputs): + """ + Creates a random :class:`funsor.gaussian.Gaussian` with given inputs. 
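# --- Editor's note: quick usage sketch of the updated testing helpers above;
# not part of the patch.
from collections import OrderedDict

from funsor.domains import bint, reals
from funsor.testing import make_einsum_example, random_tensor

inputs, outputs, sizes, operands, funsor_operands = make_einsum_example("ab,bc->")
assert [tuple(t.shape) for t in operands] == [(2, 3), (3, 2)]   # sizes cycle through (2, 3)

x = random_tensor(OrderedDict([('i', bint(3))]), reals(2))      # new (inputs, output) signature
assert x.inputs['i'] == bint(3) and x.output == reals(2)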
+ """ + assert isinstance(inputs, OrderedDict) + batch_shape = tuple(d.dtype for d in inputs.values() if d.dtype != 'real') + event_shape = (sum(d.num_elements for d in inputs.values() if d.dtype == 'real'),) + loc = torch.randn(batch_shape + event_shape) + prec_sqrt = torch.randn(batch_shape + event_shape + event_shape) + precision = torch.matmul(prec_sqrt, prec_sqrt.transpose(-1, -2)) + precision = precision + 0.05 * torch.eye(event_shape[0]) + return Gaussian(loc, precision, inputs) + + +def make_plated_hmm_einsum(num_steps, num_obs_plates=1, num_hidden_plates=0): + + assert num_obs_plates >= num_hidden_plates + t0 = num_obs_plates + 1 + + obs_plates = ''.join(opt_einsum.get_symbol(i) for i in range(num_obs_plates)) + hidden_plates = ''.join(opt_einsum.get_symbol(i) for i in range(num_hidden_plates)) + + inputs = [str(opt_einsum.get_symbol(t0))] + for t in range(t0, num_steps+t0): + inputs.append(str(opt_einsum.get_symbol(t)) + str(opt_einsum.get_symbol(t+1)) + hidden_plates) + inputs.append(str(opt_einsum.get_symbol(t+1)) + obs_plates) + equation = ",".join(inputs) + "->" + return (equation, ''.join(set(obs_plates + hidden_plates))) + + +def make_chain_einsum(num_steps): + inputs = [str(opt_einsum.get_symbol(0))] + for t in range(num_steps): + inputs.append(str(opt_einsum.get_symbol(t)) + str(opt_einsum.get_symbol(t+1))) + equation = ",".join(inputs) + "->" + return equation + + +def make_hmm_einsum(num_steps): + inputs = [str(opt_einsum.get_symbol(0))] + for t in range(num_steps): + inputs.append(str(opt_einsum.get_symbol(t)) + str(opt_einsum.get_symbol(t+1))) + inputs.append(str(opt_einsum.get_symbol(t+1))) + equation = ",".join(inputs) + "->" + return equation diff --git a/funsor/torch.py b/funsor/torch.py index d8d999907..5079905a9 100644 --- a/funsor/torch.py +++ b/funsor/torch.py @@ -1,15 +1,31 @@ from __future__ import absolute_import, division, print_function import functools +import warnings from collections import OrderedDict +import opt_einsum import torch +from contextlib2 import contextmanager +from multipledispatch import dispatch from six import add_metaclass, integer_types +from six.moves import reduce import funsor.ops as ops +from funsor.contract import Contract, contractor +from funsor.delta import Delta from funsor.domains import Domain, bint, find_domain, reals +from funsor.ops import AssociativeOp, GetitemOp, Op from funsor.six import getargspec -from funsor.terms import Binary, Funsor, FunsorMeta, Number, Variable, eager, to_funsor +from funsor.terms import Binary, Funsor, FunsorMeta, Lambda, Number, Variable, \ + eager, substitute, to_data, to_funsor + + +@contextmanager +def ignore_jit_warnings(): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + yield def align_tensor(new_inputs, x): @@ -17,8 +33,8 @@ def align_tensor(new_inputs, x): Permute and expand a tensor to match desired ``new_inputs``. :param OrderedDict new_inputs: A target set of inputs. - :param funsor.terms.Funsor x: A :class:`Tensor`s or - :class:`~funsor.terms.Number`. + :param funsor.terms.Funsor x: A :class:`Tensor` or + :class:`~funsor.terms.Number` . :return: a number or :class:`torch.Tensor` that can be broadcast to other tensors with inputs ``new_inputs``. :rtype: tuple @@ -35,7 +51,7 @@ def align_tensor(new_inputs, x): if old_inputs == new_inputs: return data - # Pemute squashed input dims. + # Permute squashed input dims. 
x_keys = tuple(old_inputs) data = data.permute(tuple(x_keys.index(k) for k in new_inputs if k in old_inputs) + tuple(range(len(old_inputs), data.dim()))) @@ -52,10 +68,10 @@ def align_tensors(*args): This is mainly useful for implementing eager funsor operations. - :param funsor.terms.Funsor \*args: Multiple :class:`Tensor`s and - :class:`~funsor.terms.Number`s. + :param funsor.terms.Funsor \*args: Multiple :class:`Tensor` s and + :class:`~funsor.terms.Number` s. :return: a pair ``(inputs, tensors)`` where tensors are all - :class:`torch.Tensor`s that can be broadcast together to a single data + :class:`torch.Tensor` s that can be broadcast together to a single data with given ``inputs``. :rtype: tuple """ @@ -78,7 +94,6 @@ def __call__(cls, data, inputs=None, dtype="real"): return super(TensorMeta, cls).__call__(data, inputs, dtype) -@to_funsor.register(torch.Tensor) @add_metaclass(TensorMeta) class Tensor(Funsor): """ @@ -90,10 +105,15 @@ class Tensor(Funsor): def __init__(self, data, inputs=None, dtype="real"): assert isinstance(data, torch.Tensor) assert isinstance(inputs, tuple) - assert all(isinstance(d.dtype, integer_types) for k, d in inputs) + if not torch._C._get_tracing_state(): + assert len(inputs) <= data.dim() + for (k, d), size in zip(inputs, data.shape): + assert d.dtype == size inputs = OrderedDict(inputs) output = Domain(data.shape[len(inputs):], dtype) - super(Tensor, self).__init__(inputs, output) + fresh = frozenset(inputs.keys()) + bound = frozenset() + super(Tensor, self).__init__(inputs, output, fresh, bound) self.data = data def __repr__(self): @@ -129,11 +149,9 @@ def align(self, names): assert all(name in self.inputs for name in names) if not names or names == tuple(self.inputs): return self + inputs = OrderedDict((name, self.inputs[name]) for name in names) inputs.update(self.inputs) - - if any(d.shape for d in self.inputs.values()): - raise NotImplementedError("TODO: Implement align with vector indices.") old_dims = tuple(self.inputs) new_dims = tuple(inputs) data = self.data.permute(tuple(old_dims.index(d) for d in new_dims)) @@ -141,7 +159,8 @@ def align(self, names): def eager_subs(self, subs): assert isinstance(subs, tuple) - subs = {k: materialize(v) for k, v in subs if k in self.inputs} + subs = {k: materialize(to_funsor(v, self.inputs[k])) + for k, v in subs if k in self.inputs} if not subs: return self @@ -193,15 +212,23 @@ def eager_subs(self, subs): return Tensor(data, inputs, self.dtype) def eager_unary(self, op): - return Tensor(op(self.data), self.inputs, self.dtype) + dtype = find_domain(op, self.output).dtype + if op in REDUCE_OP_TO_TORCH: + batch_dim = len(self.data.shape) - len(self.output.shape) + data = self.data.reshape(self.data.shape[:batch_dim] + (-1,)) + data = REDUCE_OP_TO_TORCH[op](data, -1) + if op is ops.min or op is ops.max: + data = data[0] + return Tensor(data, self.inputs, dtype) + return Tensor(op(self.data), self.inputs, dtype) def eager_reduce(self, op, reduced_vars): - if op in ops.REDUCE_OP_TO_TORCH: - torch_op = ops.REDUCE_OP_TO_TORCH[op] + if op in REDUCE_OP_TO_TORCH: + torch_op = REDUCE_OP_TO_TORCH[op] assert isinstance(reduced_vars, frozenset) self_vars = frozenset(self.inputs) reduced_vars = reduced_vars & self_vars - if reduced_vars == self_vars: + if reduced_vars == self_vars and not self.output.shape: # Reduce all dims at once. 
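# --- Editor's note: illustrative sketch of align_tensor/align_tensors above;
# not part of the patch.
from collections import OrderedDict

import torch

from funsor.domains import bint
from funsor.torch import Tensor, align_tensors

x = Tensor(torch.randn(2), OrderedDict([('i', bint(2))]))
y = Tensor(torch.randn(3), OrderedDict([('j', bint(3))]))
inputs, (x_data, y_data) = align_tensors(x, y)
assert list(inputs) == ['i', 'j']
assert (x_data + y_data).shape == (2, 3)   # raw data broadcast against the joint inputs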
if op is ops.logaddexp: # work around missing torch.Tensor.logsumexp() @@ -225,27 +252,102 @@ def eager_reduce(self, op, reduced_vars): return Tensor(data, inputs, self.dtype) return super(Tensor, self).eager_reduce(op, reduced_vars) + def unscaled_sample(self, sampled_vars, sample_inputs): + assert self.output == reals() + sampled_vars = sampled_vars.intersection(self.inputs) + if not sampled_vars: + return self + + # Partition inputs into sample_inputs + batch_inputs + event_inputs. + sample_inputs = OrderedDict((k, d) for k, d in sample_inputs.items() + if k not in self.inputs) + sample_shape = tuple(int(d.dtype) for d in sample_inputs.values()) + batch_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if k not in sampled_vars) + event_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if k in sampled_vars) + be_inputs = batch_inputs.copy() + be_inputs.update(event_inputs) + sb_inputs = sample_inputs.copy() + sb_inputs.update(batch_inputs) + + # Sample all variables in a single Categorical call. + logits = align_tensor(be_inputs, self) + batch_shape = logits.shape[:len(batch_inputs)] + flat_logits = logits.reshape(batch_shape + (-1,)) + sample_shape = tuple(d.dtype for d in sample_inputs.values()) + flat_sample = torch.distributions.Categorical(logits=flat_logits).sample(sample_shape) + assert flat_sample.shape == sample_shape + batch_shape + results = [] + mod_sample = flat_sample + for name, domain in reversed(list(event_inputs.items())): + size = domain.dtype + point = Tensor(mod_sample % size, sb_inputs, size) + mod_sample = mod_sample / size + results.append(Delta(name, point)) + + # Account for the log normalizer factor. + # Derivation: Let f be a nonnormalized distribution (a funsor), and + # consider operations in linear space (source code is in log space). + # Let x0 ~ f/|f| be a monte carlo sample from a normalized f/|f|. + # f(x0) / |f| # dice numerator + # Let g = delta(x=x0) |f| ----------------- + # detach(f(x0)/|f|) # dice denominator + # |detach(f)| f(x0) + # = delta(x=x0) ----------------- be a dice approximation of f. + # detach(f(x0)) + # Then g is an unbiased estimator of f in value and all derivatives. + # In the special case f = detach(f), we can simplify to + # g = delta(x=x0) |f|. + if flat_logits.requires_grad: + # Apply a dice factor to preserve differentiability. + index = [torch.arange(n).reshape((n,) + (1,) * (flat_logits.dim() - i - 2)) + for i, n in enumerate(flat_logits.shape[:-1])] + index.append(flat_sample) + log_prob = flat_logits[index] + assert log_prob.shape == flat_sample.shape + results.append(Tensor(flat_logits.detach().logsumexp(-1) + + (log_prob - log_prob.detach()), sb_inputs)) + else: + # This is the special case f = detach(f). 
+ results.append(Tensor(flat_logits.logsumexp(-1), batch_inputs)) + + return reduce(ops.add, results) + + +@dispatch(torch.Tensor) +def to_funsor(x): + return Tensor(x) + + +@dispatch(torch.Tensor, Domain) +def to_funsor(x, output): + result = Tensor(x, dtype=output.dtype) + if result.output != output: + raise ValueError("Invalid shape: expected {}, actual {}" + .format(output.shape, result.output.shape)) + return result + -@eager.register(Binary, object, Tensor, Number) +@to_data.register(Tensor) +def _to_data_tensor(x): + if x.inputs: + raise ValueError("cannot convert Tensor to a data due to lazy inputs: {}" + .format(set(x.inputs))) + return x.data + + +@eager.register(Binary, Op, Tensor, Number) def eager_binary_tensor_number(op, lhs, rhs): - if op is ops.getitem: - # Shift by that Funsor is using for inputs. - index = [slice(None)] * len(lhs.inputs) - index.append(rhs.data) - index = tuple(index) - data = lhs.data[index] - else: - data = op(lhs.data, rhs.data) + data = op(lhs.data, rhs.data) return Tensor(data, lhs.inputs, lhs.dtype) -@eager.register(Binary, object, Number, Tensor) +@eager.register(Binary, Op, Number, Tensor) def eager_binary_number_tensor(op, lhs, rhs): data = op(lhs.data, rhs.data) return Tensor(data, rhs.inputs, rhs.dtype) -@eager.register(Binary, object, Tensor, Tensor) +@eager.register(Binary, Op, Tensor, Tensor) def eager_binary_tensor_tensor(op, lhs, rhs): # Compute inputs and outputs. dtype = find_domain(op, lhs.output, rhs.output).dtype @@ -255,18 +357,113 @@ def eager_binary_tensor_tensor(op, lhs, rhs): else: inputs, (lhs_data, rhs_data) = align_tensors(lhs, rhs) - if op is ops.getitem: - # getitem has special shape semantics. - if rhs.output.shape: - raise NotImplementedError('TODO support vector indexing') - assert lhs.output.shape == (rhs.dtype,) - index = [torch.arange(size).reshape((-1,) + (1,) * (lhs_data.dim() - pos - 2)) - for pos, size in enumerate(lhs_data.shape)] - index[-1] = rhs_data - data = lhs_data[tuple(index)] + # Reshape to support broadcasting of output shape. + if inputs: + lhs_dim = len(lhs.output.shape) + rhs_dim = len(rhs.output.shape) + if lhs_dim < rhs_dim: + cut = lhs_data.dim() - lhs_dim + shape = lhs_data.shape + shape = shape[:cut] + (1,) * (rhs_dim - lhs_dim) + shape[cut:] + lhs_data = lhs_data.reshape(shape) + elif rhs_dim < lhs_dim: + cut = rhs_data.dim() - rhs_dim + shape = rhs_data.shape + shape = shape[:cut] + (1,) * (lhs_dim - rhs_dim) + shape[cut:] + rhs_data = rhs_data.reshape(shape) + + data = op(lhs_data, rhs_data) + return Tensor(data, inputs, dtype) + + +@eager.register(Binary, GetitemOp, Tensor, Number) +def eager_getitem_tensor_number(op, lhs, rhs): + index = [slice(None)] * (len(lhs.inputs) + op.offset) + index.append(rhs.data) + index = tuple(index) + data = lhs.data[index] + return Tensor(data, lhs.inputs, lhs.dtype) + + +@eager.register(Binary, GetitemOp, Tensor, Variable) +def eager_getitem_tensor_variable(op, lhs, rhs): + assert op.offset < len(lhs.output.shape) + assert rhs.output == bint(lhs.output.shape[op.offset]) + assert rhs.name not in lhs.inputs + + # Convert a positional event dimension to a named batch dimension. 
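# --- Editor's note: illustrative sketch of the torch.Tensor <-> funsor
# conversions registered above; not part of the patch.
import torch

from funsor.domains import reals
from funsor.terms import to_data, to_funsor
from funsor.torch import Tensor   # importing funsor.torch loads these registrations

t = torch.randn(3, 4)
f = to_funsor(t, reals(3, 4))     # the optional output hint is checked against the data
assert isinstance(f, Tensor) and not f.inputs and f.output == reals(3, 4)
assert to_data(f) is t            # round-trips back to the raw tensor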
+ inputs = lhs.inputs.copy() + inputs[rhs.name] = rhs.output + data = lhs.data + target_dim = len(lhs.inputs) + source_dim = target_dim + op.offset + if target_dim != source_dim: + perm = list(range(data.dim())) + del perm[source_dim] + perm.insert(target_dim, source_dim) + data = data.permute(*perm) + return Tensor(data, inputs, lhs.dtype) + + +@eager.register(Binary, GetitemOp, Tensor, Tensor) +def eager_getitem_tensor_tensor(op, lhs, rhs): + assert op.offset < len(lhs.output.shape) + assert rhs.output == bint(lhs.output.shape[op.offset]) + + # Compute inputs and outputs. + if lhs.inputs == rhs.inputs: + inputs, lhs_data, rhs_data = lhs.inputs, lhs.data, rhs.data + else: + inputs, (lhs_data, rhs_data) = align_tensors(lhs, rhs) + if len(lhs.output.shape) > 1: + rhs_data = rhs_data.reshape(rhs_data.shape + (1,) * (len(lhs.output.shape) - 1)) + + # Perform advanced indexing. + target_dim = len(lhs.inputs) + op.offset + index = [None] * lhs_data.dim() + for i in range(target_dim): + index[i] = torch.arange(lhs_data.size(i)).reshape((-1,) + (1,) * (lhs_data.dim() - i - 2)) + index[target_dim] = rhs_data + for i in range(1 + target_dim, lhs_data.dim()): + index[i] = torch.arange(lhs_data.size(i)).reshape((-1,) + (1,) * (lhs_data.dim() - i - 1)) + data = lhs_data[tuple(index)] + return Tensor(data, inputs, lhs.dtype) + + +@eager.register(Lambda, Variable, Tensor) +def eager_lambda(var, expr): + inputs = expr.inputs.copy() + if var.name in inputs: + inputs.pop(var.name) + inputs[var.name] = var.output + data = align_tensor(inputs, expr) + inputs.pop(var.name) else: - data = op(lhs_data, rhs_data) + data = expr.data + shape = data.shape + dim = len(shape) - len(expr.output.shape) + data = data.reshape(shape[:dim] + (1,) + shape[dim:]) + data = data.expand(shape[:dim] + (var.dtype,) + shape[dim:]) + return Tensor(data, inputs, expr.dtype) + + +@eager.register(Contract, AssociativeOp, AssociativeOp, Tensor, Tensor, frozenset) +@contractor +def eager_contract(sum_op, prod_op, lhs, rhs, reduced_vars): + if (sum_op, prod_op) == (ops.add, ops.mul): + backend = "torch" + elif (sum_op, prod_op) == (ops.logaddexp, ops.add): + backend = "pyro.ops.einsum.torch_log" + else: + return prod_op(lhs, rhs).reduce(sum_op, reduced_vars) + + inputs = OrderedDict((k, d) for t in (lhs, rhs) + for k, d in t.inputs.items() if k not in reduced_vars) + data = opt_einsum.contract(lhs.data, list(lhs.inputs), + rhs.data, list(rhs.inputs), + list(inputs), backend=backend) + dtype = find_domain(prod_op, lhs.output, rhs.output).dtype return Tensor(data, inputs, dtype) @@ -286,34 +483,36 @@ def arange(name, size): def materialize(x): """ Attempt to convert a Funsor to a :class:`~funsor.terms.Number` or - :class:`Tensor` by substituting :func:`arange`s into its free variables. + :class:`Tensor` by substituting :func:`arange` s into its free variables. 
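# --- Editor's note: illustrative sketch of the GetitemOp patterns above;
# not part of the patch.
from collections import OrderedDict

import torch

from funsor.domains import bint, reals
from funsor.terms import Variable
from funsor.torch import Tensor

x = Tensor(torch.randn(3, 4), OrderedDict([('i', bint(3))]))   # output reals(4)

j = Variable('j', bint(4))
assert x[j].inputs == OrderedDict([('i', bint(3)), ('j', bint(4))])  # positional dim -> named input

idx = Tensor(torch.tensor([0, 3, 1]), OrderedDict([('i', bint(3))]), 4)
assert x[idx].output == reals()   # advanced indexing; idx shares the 'i' input with x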
""" assert isinstance(x, Funsor) if isinstance(x, (Number, Tensor)): return x subs = [] for name, domain in x.inputs.items(): - if not isinstance(domain.dtype, integer_types): - raise ValueError('materialize() requires integer free variables but found ' - '"{}" of domain {}'.format(name, domain)) - assert not domain.shape - subs.append((name, arange(name, domain.dtype))) + if isinstance(domain.dtype, integer_types): + subs.append((name, arange(name, domain.dtype))) subs = tuple(subs) - return x.eager_subs(subs) + return substitute(x, subs) + + +class LazyTuple(tuple): + def __call__(self, *args, **kwargs): + return LazyTuple(x(*args, **kwargs) for x in self) class Function(Funsor): r""" Funsor wrapped by a PyTorch function. - Functions are support broadcasting and can be eagerly evaluated on funsors - with free variables of int type (i.e. batch dimensions). + Functions are assumed to support broadcasting and can be eagerly evaluated + on funsors with free variables of int type (i.e. batch dimensions). - :class:`Function`s are often created via the :func:`function` decorator. + :class:`Function` s are usually created via the :func:`function` decorator. :param callable fn: A PyTorch function to wrap. :param funsor.domains.Domain output: An output domain. - :param Funsor \*args: Funsor arguments. + :param Funsor args: Funsor arguments. """ def __init__(self, fn, output, args): assert callable(fn) @@ -327,18 +526,14 @@ def __init__(self, fn, output, args): self.args = args def __repr__(self): - return 'Function({})'.format(', '.join( - [type(self).__name__, repr(self.output)] + list(map(repr, self.args)))) + name = getattr(self.fn, '__name__', type(self.fn).__name__) + return '{}({}, {}, {})'.format(type(self).__name__, name, + repr(self.output), repr(self.args)) def __str__(self): - return 'Function({})'.format(', '.join( - [type(self).__name__, str(self.output)] + list(map(str, self.args)))) - - def eager_subs(self, subs): - if not any(k in self.inputs for k, v in subs): - return self - args = tuple(arg.eager_subs(subs) for arg in self.args) - return Function(self.fn, self.output, args) + name = getattr(self.fn, '__name__', type(self.fn).__name__) + return '{}({}, {}, {})'.format(type(self).__name__, name, + str(self.output), str(self.args)) @eager.register(Function, object, Domain, tuple) @@ -352,11 +547,57 @@ def eager_function(fn, output, args): return result +def _select(fn, i, *args): + result = fn(*args) + assert isinstance(result, tuple) + return result[i] + + +def _nested_function(fn, args, output): + if isinstance(output, Domain): + return Function(fn, output, args) + elif isinstance(output, tuple): + result = [] + for i, output_i in enumerate(output): + fn_i = functools.partial(_select, fn, i) + fn_i.__name__ = "{}_{}".format(fn_i, i) + result.append(_nested_function(fn_i, args, output_i)) + return LazyTuple(result) + raise ValueError("Invalid output: {}".format(output)) + + +class _Memoized(object): + def __init__(self, fn): + self.fn = fn + self._cache = None + + def __call__(self, *args): + if self._cache is not None: + old_args, old_result = self._cache + if all(x is y for x, y in zip(args, old_args)): + return old_result + result = self.fn(*args) + self._cache = args, result + return result + + @property + def __name__(self): + return self.fn.__name__ + + def _function(inputs, output, fn): - names = getargspec(fn)[0] + if isinstance(fn, torch.nn.Module): + names = getargspec(fn.forward)[0][1:] + else: + names = getargspec(fn)[0] args = tuple(Variable(name, domain) for (name, 
domain) in zip(names, inputs)) assert len(args) == len(inputs) - return Function(fn, output, args) + if not isinstance(output, Domain): + assert isinstance(output, tuple) + # Memoize multiple-output functions so that invocations can be shared among + # all outputs. This is not foolproof, but does work in simple situations. + fn = _Memoized(fn) + return _nested_function(fn, args, output) def function(*signature): @@ -365,54 +606,243 @@ def function(*signature): Example:: - @funsor.function(reals(3,4), reals(4,5), reals(3,5)) + @funsor.torch.function(reals(3,4), reals(4,5), reals(3,5)) def matmul(x, y): return torch.matmul(x, y) - @funsor.function(reals(10), reals(10, 10), reals()) + @funsor.torch.function(reals(10), reals(10, 10), reals()) def mvn_log_prob(loc, scale_tril, x): d = torch.distributions.MultivariateNormal(loc, scale_tril) return d.log_prob(x) + To support functions that output nested tuples of tensors, specify a nested + tuple of output types, for example:: + + @funsor.torch.function(reals(8), (reals(), bint(8))) + def max_and_argmax(x): + return torch.max(x, dim=-1) + :param \*signature: A sequence if input domains followed by a final output domain. """ assert signature - assert all(isinstance(d, Domain) for d in signature) inputs, output = signature[:-1], signature[-1] + assert all(isinstance(d, Domain) for d in inputs) + assert isinstance(output, (Domain, tuple)) return functools.partial(_function, inputs, output) -def einsum(equation, *operands): +class Einsum(Funsor): """ Wrapper around :func:`torch.einsum` to operate on real-valued Funsors. Note this operates only on the ``output`` tensor. To perform sum-product contractions on named dimensions, instead use ``+`` and :class:`~funsor.terms.Reduce`. + + :param str equation: An einsum equation. + :param tuple operands: A tuple of input funsors. 
""" - assert isinstance(equation, str) - assert isinstance(operands, tuple) - for x in operands: - assert isinstance(x, Funsor) - assert x.dtype == 'real' - inputs, output = equation.split('->') - inputs = inputs.split(',') - sizes = {dim: size - for input_, operand in zip(inputs, operands) - for dim, size in zip(input_, operand.output.shape)} - output = reals(*(sizes[dim] for dim in output)) - fn = functools.partial(torch.einsum, equation) - return Function(fn, output, operands) + def __init__(self, equation, operands): + assert isinstance(equation, str) + assert isinstance(operands, tuple) + assert all(isinstance(x, Funsor) for x in operands) + ein_inputs, ein_output = equation.split('->') + ein_inputs = ein_inputs.split(',') + size_dict = {} + inputs = OrderedDict() + assert len(ein_inputs) == len(operands) + for ein_input, x in zip(ein_inputs, operands): + assert x.dtype == 'real' + inputs.update(x.inputs) + assert len(ein_inputs) == len(x.output.shape) + for name, size in zip(ein_inputs, x.output.shape): + other_size = size_dict.setdefault(name, size) + if other_size != size: + raise ValueError("Size mismatch at {}: {} vs {}" + .format(name, size, other_size)) + output = reals(*(size_dict[d] for d in ein_output)) + super(Einsum, self).__init__(inputs, output) + self.equation = equation + self.operands = operands + + def __repr__(self): + return 'Einsum({}, {})'.format(repr(self.equation), repr(self.operands)) + + def __str__(self): + return 'Einsum({}, {})'.format(repr(self.equation), str(self.operands)) + + +@eager.register(Einsum, str, tuple) +def eager_einsum(equation, operands): + if all(isinstance(x, Tensor) for x in operands): + inputs, tensors = align_tensors(*operands) + data = torch.einsum(equation, tensors) + return Tensor(data, inputs) + + return None # defer to default implementation + + +def torch_tensordot(x, y, dims): + """ + Wrapper around :func:`torch.tensordot` to operate on real-valued Funsors. + + Note this operates only on the ``output`` tensor. To perform sum-product + contractions on named dimensions, instead use ``+`` and + :class:`~funsor.terms.Reduce`. + """ + x_start, x_end = 0, len(x.output.shape) + y_start = x_end - dims + y_end = y_start + len(y.output.shape) + symbols = 'abcdefghijklmnopqrstuvwxyz' + equation = '{},{}->{}'.format(symbols[x_start:x_end], + symbols[y_start:y_end], + symbols[x_start:y_start] + symbols[x_end:y_end]) + return Einsum(equation, (x, y)) + + +def _torch_stack(dim, *parts): + return torch.stack(parts, dim=dim) + + +def torch_stack(parts, dim=0): + """ + Wrapper around :func:`torch.stack` to operate on real-valued Funsors. + + Note this operates only on the ``output`` tensor. To stack funsors in a + new named dim, instead use :class:`~funsor.terms.Stack`. 
+ """ + assert isinstance(dim, int) + assert isinstance(parts, tuple) + assert len(set(x.output for x in parts)) == 1 + shape = parts[0].output.shape + if dim >= 0: + dim = dim - len(shape) - 1 + assert dim < 0 + split = dim + len(shape) + 1 + shape = shape[:split] + (len(parts),) + shape[split:] + output = Domain(shape, parts[0].dtype) + fn = functools.partial(_torch_stack, dim) + return Function(fn, output, parts) + + +################################################################################ +# Register Ops +################################################################################ + +@ops.abs.register(torch.Tensor) +def _abs(x): + return x.abs() + + +@ops.sqrt.register(torch.Tensor) +def _sqrt(x): + return x.sqrt() + + +@ops.exp.register(torch.Tensor) +def _exp(x): + return x.exp() + + +@ops.log.register(torch.Tensor) +def _log(x): + if x.dtype in (torch.uint8, torch.long): + x = x.float() + return x.log() + + +@ops.log1p.register(torch.Tensor) +def _log1p(x): + return x.log1p() + + +@ops.pow.register(object, torch.Tensor) +def _pow(x, y): + result = x ** y + # work around shape bug https://github.com/pytorch/pytorch/issues/16685 + return result.reshape(y.shape) + + +@ops.pow.register(torch.Tensor, (object, torch.Tensor)) +def _pow(x, y): + return x ** y + + +@ops.min.register(torch.Tensor, torch.Tensor) +def _min(x, y): + return torch.min(x, y) + + +@ops.min.register(object, torch.Tensor) +def _min(x, y): + return y.clamp(max=x) + + +@ops.min.register(torch.Tensor, object) +def _min(x, y): + return x.clamp(max=y) + + +@ops.max.register(torch.Tensor, torch.Tensor) +def _max(x, y): + return torch.max(x, y) + + +@ops.max.register(object, torch.Tensor) +def _max(x, y): + return y.clamp(min=x) + + +@ops.max.register(torch.Tensor, object) +def _max(x, y): + return x.clamp(min=y) + + +@ops.reciprocal.register(torch.Tensor) +def _reciprocal(x): + result = x.reciprocal() + result.clamp_(max=torch.finfo(result.dtype).max) + return result + + +@ops.safesub.register(object, torch.Tensor) +def _safesub(x, y): + try: + return x + -y.clamp(max=torch.finfo(y.dtype).max) + except TypeError: + return x + -y.clamp(max=torch.iinfo(y.dtype).max) + + +@ops.safediv.register(object, torch.Tensor) +def _safediv(x, y): + try: + return x * y.reciprocal().clamp(max=torch.finfo(y.dtype).max) + except TypeError: + return x * y.reciprocal().clamp(max=torch.iinfo(y.dtype).max) + + +REDUCE_OP_TO_TORCH = { + ops.add: torch.sum, + ops.mul: torch.prod, + ops.and_: torch.all, + ops.or_: torch.any, + ops.logaddexp: torch.logsumexp, + ops.min: torch.min, + ops.max: torch.max, +} __all__ = [ + 'Einsum', 'Function', + 'REDUCE_OP_TO_TORCH', 'Tensor', 'align_tensor', 'align_tensors', 'arange', - 'einsum', 'function', + 'ignore_jit_warnings', 'materialize', + 'torch_tensordot', ] diff --git a/setup.cfg b/setup.cfg index 17068cd61..f6bc91e7a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,6 +5,7 @@ ignore = F811,E121,E123,E126,E226,E24,E704,W503,W504 [isort] line_length = 120 +multi_line_output=3 not_skip = __init__.py known_first_party = funsor, test known_third_party = opt_einsum, pyro, six, torch, torchvision diff --git a/setup.py b/setup.py index 779fe27d0..d2a551178 100644 --- a/setup.py +++ b/setup.py @@ -13,22 +13,31 @@ install_requires=[ 'contextlib2', 'multipledispatch', + 'numpy>=1.7', 'opt_einsum>=2.3.2', 'pyro-ppl>=0.3', 'six>=1.10.0', 'torch>=1.0.0', + 'unification', ], extras_require={ - 'test': ['flake8', 'pytest>=4.1'], - 'dev': ['flake8', 'pytest>=4.1', 'isort'], + 'test': ['flake8', 'pytest>=4.1', 
'torchvision==0.2.1'], + 'dev': [ + 'flake8', + 'isort', + 'pytest>=4.1', + 'sphinx>=2.0', + 'sphinx_rtd_theme', + 'torchvision==0.2.1', + ], }, tests_require=['flake8', 'pytest>=4.1'], keywords='probabilistic machine learning bayesian statistics pytorch', - license='MIT License', classifiers=[ 'Intended Audience :: Developers', 'Intended Audience :: Education', 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: MIT License', 'Operating System :: POSIX :: Linux', 'Operating System :: MacOS :: MacOS X', 'Programming Language :: Python :: 2.7', diff --git a/test/conftest.py b/test/conftest.py index 511326de8..a3243290b 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -5,3 +5,4 @@ def pytest_runtest_setup(item): pyro.set_rng_seed(0) + pyro.enable_validation(True) diff --git a/test/test_adjoint.py b/test/test_adjoint.py new file mode 100644 index 000000000..bc562c823 --- /dev/null +++ b/test/test_adjoint.py @@ -0,0 +1,164 @@ +from __future__ import absolute_import, division, print_function + +import opt_einsum +import pytest +import torch +from pyro.ops.contract import einsum as pyro_einsum +from pyro.ops.einsum.adjoint import require_backward as pyro_require_backward + +import funsor +from funsor.adjoint import adjoint +from funsor.domains import bint +from funsor.einsum import einsum, naive_einsum, naive_plated_einsum +from funsor.interpreter import interpretation +from funsor.terms import Variable, reflect +from funsor.testing import make_einsum_example, make_plated_hmm_einsum + + +# FIXME rewrite adjoint for compatibility with substitution changes +xfail_with_new_subs = pytest.mark.xfail(True, reason="fails w/ new subs") + + +EINSUM_EXAMPLES = [ + "a->", + "ab->", + ",->", + ",,->", + "a,a->a", + "a,a,a->a", + "a,b->", + "ab,a->", + "a,b,c->", + "a,a->", + "a,a,a,ab->", + "abc,bcd,cde->", + "ab,bc,cd->", + "ab,b,bc,c,cd,d->", +] + + +@xfail_with_new_subs +@pytest.mark.parametrize('einsum_impl', [naive_einsum, einsum]) +@pytest.mark.parametrize('equation', EINSUM_EXAMPLES) +@pytest.mark.parametrize('backend', ['pyro.ops.einsum.torch_marginal']) +def test_einsum_adjoint(einsum_impl, equation, backend): + inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation) + + with interpretation(reflect): + fwd_expr = einsum_impl(equation, *funsor_operands, backend=backend) + actuals = adjoint(fwd_expr, funsor_operands) + + for operand in operands: + pyro_require_backward(operand) + expected_out = pyro_einsum(equation, *operands, + modulo_total=True, + backend=backend)[0] + expected_out._pyro_backward() + + for i, (inp, tv, fv) in enumerate(zip(inputs, operands, funsor_operands)): + actual = actuals[fv] + expected = tv._pyro_backward_result + if inp: + actual = actual.align(tuple(inp)) + assert isinstance(actual, funsor.Tensor) + assert expected.shape == actual.data.shape + assert torch.allclose(expected, actual.data, atol=1e-7) + + +@xfail_with_new_subs +@pytest.mark.parametrize('einsum_impl', [naive_einsum, einsum]) +@pytest.mark.parametrize('equation', EINSUM_EXAMPLES) +@pytest.mark.parametrize('backend', ['pyro.ops.einsum.torch_marginal']) +def test_einsum_adjoint_unary_marginals(einsum_impl, equation, backend): + inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation) + equation = ",".join(inputs) + "->" + + targets = [Variable(k, bint(sizes[k])) for k in set(sizes)] + with interpretation(reflect): + fwd_expr = einsum_impl(equation, *funsor_operands, backend=backend) + actuals = adjoint(fwd_expr, targets) + + for target 
in targets: + actual = actuals[target] + + expected = opt_einsum.contract(equation + target.name, *operands, + backend=backend) + assert isinstance(actual, funsor.Tensor) + assert expected.shape == actual.data.shape + assert torch.allclose(expected, actual.data, atol=1e-7) + + +PLATED_EINSUM_EXAMPLES = [ + ('i->', 'i'), + (',i->', 'i'), + ('ai->', 'i'), + (',ai,abij->', 'ij'), + ('a,ai,bij->', 'ij'), + ('ai,abi,bci,cdi->', 'i'), + ('aij,abij,bcij->', 'ij'), + ('a,abi,bcij,cdij->', 'ij'), +] + + +@xfail_with_new_subs +@pytest.mark.parametrize('einsum_impl', [naive_plated_einsum, einsum]) +@pytest.mark.parametrize('equation,plates', PLATED_EINSUM_EXAMPLES) +@pytest.mark.parametrize('backend', ['pyro.ops.einsum.torch_marginal']) +def test_plated_einsum_adjoint(einsum_impl, equation, plates, backend): + inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation) + + with interpretation(reflect): + fwd_expr = einsum_impl(equation, *funsor_operands, plates=plates, backend=backend) + actuals = adjoint(fwd_expr, funsor_operands) + + for operand in operands: + pyro_require_backward(operand) + expected_out = pyro_einsum(equation, *operands, + modulo_total=False, + plates=plates, + backend=backend)[0] + expected_out._pyro_backward() + + for i, (inp, tv, fv) in enumerate(zip(inputs, operands, funsor_operands)): + actual = actuals[fv] + expected = tv._pyro_backward_result + if inp: + actual = actual.align(tuple(inp)) + assert isinstance(actual, funsor.Tensor) + assert expected.shape == actual.data.shape + assert torch.allclose(expected, actual.data, atol=1e-7) + + +OPTIMIZED_PLATED_EINSUM_EXAMPLES = [ + make_plated_hmm_einsum(num_steps, num_obs_plates=b, num_hidden_plates=a) + for num_steps in range(20, 50, 6) + for (a, b) in [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2)] +] + + +@xfail_with_new_subs +@pytest.mark.parametrize('equation,plates', OPTIMIZED_PLATED_EINSUM_EXAMPLES) +@pytest.mark.parametrize('backend', ['pyro.ops.einsum.torch_marginal']) +def test_optimized_plated_einsum_adjoint(equation, plates, backend): + inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation) + + with interpretation(reflect): + fwd_expr = einsum(equation, *funsor_operands, plates=plates, backend=backend) + actuals = adjoint(fwd_expr, funsor_operands) + + for operand in operands: + pyro_require_backward(operand) + expected_out = pyro_einsum(equation, *operands, + modulo_total=False, + plates=plates, + backend=backend)[0] + expected_out._pyro_backward() + + for i, (inp, tv, fv) in enumerate(zip(inputs, operands, funsor_operands)): + actual = actuals[fv] + expected = tv._pyro_backward_result + if inp: + actual = actual.align(tuple(inp)) + assert isinstance(actual, funsor.Tensor) + assert expected.shape == actual.data.shape + assert torch.allclose(expected, actual.data, atol=1e-7) diff --git a/test/test_affine.py b/test/test_affine.py new file mode 100644 index 000000000..796e39234 --- /dev/null +++ b/test/test_affine.py @@ -0,0 +1,77 @@ +from __future__ import absolute_import, division, print_function + +from collections import OrderedDict + +import pytest +import torch + +from funsor.affine import Affine +from funsor.domains import bint, reals +from funsor.terms import Number, Variable +from funsor.testing import check_funsor +from funsor.torch import Tensor + +SMOKE_TESTS = [ + ('t+x', Affine), + ('x+t', Affine), + ('n+x', Affine), + ('n*x', Affine), + ('t*x', Affine), + ('x*t', Affine), + ('-x', Affine), + ('t-x', Affine), +] + + +@pytest.mark.parametrize('expr,expected_type', 
SMOKE_TESTS) +def test_smoke(expr, expected_type): + + t = Tensor(torch.randn(2, 3), OrderedDict([('i', bint(2)), ('j', bint(3))])) + assert isinstance(t, Tensor) + + n = Number(2.) + assert isinstance(n, Number) + + x = Variable('x', reals()) + assert isinstance(x, Variable) + + y = Variable('y', reals()) + assert isinstance(y, Variable) + + result = eval(expr) + assert isinstance(result, expected_type) + + +SUBS_TESTS = [ + ("(t * x)(i=1)", Affine, {"j": bint(3), "x": reals()}), + ("(t * x)(i=1, x=y)", Affine, {"j": bint(3), "y": reals()}), + ("(t * x + n)(x=y)", Affine, {"y": reals(), "i": bint(2), "j": bint(3)}), + ("(x + y)(y=z)", Affine, {"x": reals(), "z": reals()}), + ("(-x)(x=y+z)", Affine, {"y": reals(), "z": reals()}), + ("(t * x + t * y)(x=z)", Affine, {"y": reals(), "z": reals(), "i": bint(2), "j": bint(3)}), +] + + +@pytest.mark.parametrize("expr,expected_type,expected_inputs", SUBS_TESTS) +def test_affine_subs(expr, expected_type, expected_inputs): + + expected_output = reals() + + t = Tensor(torch.randn(2, 3), OrderedDict([('i', bint(2)), ('j', bint(3))])) + assert isinstance(t, Tensor) + + n = Number(2.) + assert isinstance(n, Number) + + x = Variable('x', reals()) + assert isinstance(x, Variable) + + y = Variable('y', reals()) + assert isinstance(y, Variable) + + z = Variable('z', reals()) + assert isinstance(z, Variable) + + result = eval(expr) + assert isinstance(result, expected_type) + check_funsor(result, expected_inputs, expected_output) diff --git a/test/test_alpha_conversion.py b/test/test_alpha_conversion.py new file mode 100644 index 000000000..2669a4c8c --- /dev/null +++ b/test/test_alpha_conversion.py @@ -0,0 +1,98 @@ +from __future__ import absolute_import, division, print_function + +from collections import OrderedDict + +import pytest + +import funsor.ops as ops +from funsor.domains import bint, reals +from funsor.interpreter import gensym, interpretation +from funsor.terms import Independent, Lambda, Variable, reflect +from funsor.testing import assert_close, check_funsor, random_tensor + + +def test_sample_subs_smoke(): + x = random_tensor(OrderedDict([('i', bint(3)), ('j', bint(2))]), reals()) + with interpretation(reflect): + z = x(i=1) + actual = z.sample(frozenset({"j"}), OrderedDict({"i": bint(4)})) + check_funsor(actual, {"j": bint(2), "i": bint(4)}, reals()) + + +def test_subs_reduce(): + x = random_tensor(OrderedDict([('i', bint(3)), ('j', bint(2))]), reals()) + ix = random_tensor(OrderedDict([('i', bint(3))]), bint(2)) + ix2 = ix(i='i2') + with interpretation(reflect): + actual = x.reduce(ops.add, frozenset({"i"})) + actual = actual(j=ix) + expected = x(j=ix2).reduce(ops.add, frozenset({"i"}))(i2='i') + assert_close(actual, expected) + + +@pytest.mark.parametrize('lhs_vars', [(), ('i',), ('j',), ('i', 'j')]) +@pytest.mark.parametrize('rhs_vars', [(), ('i',), ('j',), ('i', 'j')]) +def test_distribute_reduce(lhs_vars, rhs_vars): + + lhs_vars, rhs_vars = frozenset(lhs_vars), frozenset(rhs_vars) + lhs = random_tensor(OrderedDict([('i', bint(3)), ('j', bint(2))]), reals()) + rhs = random_tensor(OrderedDict([('i', bint(3)), ('j', bint(2))]), reals()) + + with interpretation(reflect): + actual_lhs = lhs.reduce(ops.add, lhs_vars) if lhs_vars else lhs + actual_rhs = rhs.reduce(ops.add, rhs_vars) if rhs_vars else rhs + + actual = actual_lhs * actual_rhs + + lhs_subs = {v: gensym(v) for v in lhs_vars} + rhs_subs = {v: gensym(v) for v in rhs_vars} + expected = (lhs(**lhs_subs) * rhs(**rhs_subs)).reduce( + ops.add, frozenset(lhs_subs.values()) | 
frozenset(rhs_subs.values())) + + assert_close(actual, expected) + + +def test_subs_lambda(): + z = Variable('z', reals()) + i = Variable('i', bint(5)) + ix = random_tensor(OrderedDict([('i', bint(5))]), reals()) + actual = Lambda(i, z)(z=ix) + expected = Lambda(i(i='j'), z(z=ix)) + check_funsor(actual, expected.inputs, expected.output) + assert_close(actual, expected) + + +def test_slice_lambda(): + z = Variable('z', reals()) + i = Variable('i', bint(5)) + j = Variable('j', bint(7)) + zi = Lambda(i, z) + zj = Lambda(j, z) + zij = Lambda(j, zi) + zj2 = zij[:, i] + check_funsor(zj2, zj.inputs, zj.output) + + +def test_subs_independent(): + f = Variable('x', reals(4, 5)) + random_tensor(OrderedDict(i=bint(3))) + + actual = Independent(f, 'x', 'i') + assert 'i' not in actual.inputs + + y = Variable('y', reals(3, 4, 5)) + fsub = y + (0. * random_tensor(OrderedDict(i=bint(7)))) + actual = actual(x=fsub) + assert actual.inputs['i'] == bint(7) + + expected = f(x=y['i']).reduce(ops.add, 'i') + + data = random_tensor(OrderedDict(i=bint(7)), y.output) + assert_close(actual(y=data), expected(y=data)) + + +@pytest.mark.xfail(reason="Independent not quite compatible with sample") +def test_sample_independent(): + f = Variable('x', reals(4, 5)) + random_tensor(OrderedDict(i=bint(3))) + actual = Independent(f, 'x', 'i') + assert actual.sample('i') + assert actual.sample('j', {'i': 2}) diff --git a/test/test_contract.py b/test/test_contract.py new file mode 100644 index 000000000..64d599751 --- /dev/null +++ b/test/test_contract.py @@ -0,0 +1,83 @@ +from __future__ import absolute_import, division, print_function + +from collections import OrderedDict # noqa: F401 + +import pytest +import torch # noqa: F401 + +import funsor +import funsor.ops as ops +from funsor.contract import Contract +from funsor.domains import bint # noqa: F401 +from funsor.einsum import einsum, naive_contract_einsum +from funsor.interpreter import interpretation, reinterpret +from funsor.optimizer import Finitary, optimize +from funsor.terms import reflect +from funsor.testing import assert_close, make_einsum_example +from funsor.torch import Tensor # noqa: F401 + +EINSUM_EXAMPLES = [ + "a,b->", + "ab,a->", + "a,a->", + "a,a,a,ab->", + "ab->", + "ab,bc,cd->", + "abc,bcd,def->", + "abc,abc,bcd,bcd,def,def->", + "ab,bc,cd,de->", + "ab,ab,bc,bc,cd,cd->", +] + + +@pytest.mark.parametrize('equation', EINSUM_EXAMPLES) +@pytest.mark.parametrize('backend,fill', [ + ('torch', None), + ('torch', 1.), + ('pyro.ops.einsum.torch_log', None), + ('pyro.ops.einsum.torch_marginal', None) +]) +def test_contract_einsum_product_lhs(equation, backend, fill): + inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation, fill=fill) + + with interpretation(reflect): + expected = einsum(equation, *funsor_operands, backend=backend) + expected = reinterpret(expected) + actual = naive_contract_einsum(equation, *funsor_operands, backend=backend) + + assert isinstance(actual, funsor.Tensor) and len(outputs) == 1 + print(expected / actual, actual / expected) + assert_close(expected, actual, atol=1e-4) + for output in outputs: + for i, output_dim in enumerate(output): + assert output_dim in actual.inputs + assert actual.inputs[output_dim].dtype == sizes[output_dim] + + +@pytest.mark.parametrize('equation1', EINSUM_EXAMPLES) +@pytest.mark.parametrize('equation2', EINSUM_EXAMPLES) +def test_contract_naive_pair(equation1, equation2): + + # identical structure + case1 = make_einsum_example(equation1) + case2 = make_einsum_example(equation2) + 
sizes1, funsor_operands1 = case1[2], case1[-1] + sizes2, funsor_operands2 = case2[2], case2[-1] + + assert all(sizes1[k] == sizes2[k] for k in set(sizes1.keys()) & set(sizes2.keys())) + + with interpretation(optimize): + lhs = Finitary(ops.mul, tuple(funsor_operands1)) + rhs = Finitary(ops.mul, tuple(funsor_operands2)) + + expected = (lhs * rhs).reduce(ops.add) + + actual1 = Contract(ops.add, ops.mul, lhs, rhs, frozenset(lhs.inputs) | frozenset(rhs.inputs)) + actual2 = Contract(ops.add, ops.mul, rhs, lhs, frozenset(lhs.inputs) | frozenset(rhs.inputs)) + + actual1 = reinterpret(actual1) + actual2 = reinterpret(actual2) + expected = reinterpret(expected) + + assert_close(actual1, expected, atol=1e-4, rtol=1e-4) + assert_close(actual2, expected, atol=1e-4, rtol=1e-4) diff --git a/test/test_delta.py b/test/test_delta.py new file mode 100644 index 000000000..11b7ceb07 --- /dev/null +++ b/test/test_delta.py @@ -0,0 +1,69 @@ +from __future__ import absolute_import, division, print_function + +import pytest +import torch + +import funsor.ops as ops +from funsor.delta import Delta +from funsor.domains import reals +from funsor.terms import Number, Variable +from funsor.testing import assert_close, check_funsor +from funsor.torch import Tensor + + +def test_eager_subs_variable(): + v = Variable('v', reals(3)) + point = Tensor(torch.randn(3)) + d = Delta('foo', v) + assert d(v=point) is Delta('foo', point) + + +@pytest.mark.parametrize('log_density', [0, 1.234]) +def test_eager_subs_ground(log_density): + point1 = Tensor(torch.randn(3)) + point2 = Tensor(torch.randn(3)) + d = Delta('foo', point1, log_density) + check_funsor(d(foo=point1), {}, reals(), torch.tensor(float(log_density))) + check_funsor(d(foo=point2), {}, reals(), torch.tensor(float('-inf'))) + + +def test_add_delta_funsor(): + x = Variable('x', reals(3)) + y = Variable('y', reals(3)) + d = Delta('x', y) + + expr = -(1 + x ** 2).log() + assert d + expr is d + expr(x=y) + assert expr + d is expr(x=y) + d + + +def test_reduce(): + point = Tensor(torch.randn(3)) + d = Delta('foo', point) + assert d.reduce(ops.logaddexp, frozenset(['foo'])) is Number(0) + + +@pytest.mark.parametrize('log_density', [0, 1.234]) +def test_reduce_density(log_density): + point = Tensor(torch.randn(3)) + d = Delta('foo', point, log_density) + # Note that log_density affects ground substitution but does not affect reduction. 
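+    # (Deltas are normalized, so logaddexp-reducing over the Delta's own variable
+    # yields Number(0) either way; the log_density term only appears when a concrete
+    # value is substituted, as in test_eager_subs_ground above.)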
+ assert d.reduce(ops.logaddexp, frozenset(['foo'])) is Number(0) + + +@pytest.mark.parametrize('shape', [(), (4,), (2, 3)], ids=str) +def test_transform_exp(shape): + point = Tensor(torch.randn(shape).abs()) + x = Variable('x', reals(*shape)) + actual = Delta('y', point)(y=ops.exp(x)) + expected = Delta('x', point.log(), point.log().sum()) + assert_close(actual, expected) + + +@pytest.mark.parametrize('shape', [(), (4,), (2, 3)], ids=str) +def test_transform_log(shape): + point = Tensor(torch.randn(shape)) + x = Variable('x', reals(*shape)) + actual = Delta('y', point)(y=ops.log(x)) + expected = Delta('x', point.exp(), -point.sum()) + assert_close(actual, expected) diff --git a/test/test_distributions.py b/test/test_distributions.py index 623616453..075ec6820 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -4,16 +4,128 @@ from collections import OrderedDict import pytest +import pyro import torch import funsor import funsor.distributions as dist +from funsor.delta import Delta from funsor.domains import bint, reals -from funsor.terms import Variable +from funsor.joint import Joint +from funsor.terms import Independent, Variable from funsor.testing import assert_close, check_funsor, random_tensor from funsor.torch import Tensor +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +@pytest.mark.parametrize('eager', [False, True]) +def test_beta_density(batch_shape, eager): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + @funsor.torch.function(reals(), reals(), reals(), reals()) + def beta(concentration1, concentration0, value): + return torch.distributions.Beta(concentration1, concentration0).log_prob(value) + + check_funsor(beta, {'concentration1': reals(), 'concentration0': reals(), 'value': reals()}, reals()) + + concentration1 = Tensor(torch.randn(batch_shape).exp(), inputs) + concentration0 = Tensor(torch.randn(batch_shape).exp(), inputs) + value = Tensor(torch.rand(batch_shape), inputs) + expected = beta(concentration1, concentration0, value) + check_funsor(expected, inputs, reals()) + + d = Variable('value', reals()) + actual = dist.Beta(concentration1, concentration0, value) if eager else \ + dist.Beta(concentration1, concentration0, d)(value=value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +@pytest.mark.parametrize('syntax', ['eager', 'lazy', 'generic']) +def test_bernoulli_probs_density(batch_shape, syntax): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + @funsor.torch.function(reals(), reals(), reals()) + def bernoulli(probs, value): + return torch.distributions.Bernoulli(probs).log_prob(value) + + check_funsor(bernoulli, {'probs': reals(), 'value': reals()}, reals()) + + probs = Tensor(torch.rand(batch_shape), inputs) + value = Tensor(torch.rand(batch_shape).round(), inputs) + expected = bernoulli(probs, value) + check_funsor(expected, inputs, reals()) + + d = Variable('value', reals()) + if syntax == 'eager': + actual = dist.BernoulliProbs(probs, value) + elif syntax == 'lazy': + actual = dist.BernoulliProbs(probs, d)(value=value) + elif syntax == 'generic': + actual = dist.Bernoulli(probs=probs)(value=value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) 
+@pytest.mark.parametrize('syntax', ['eager', 'lazy', 'generic']) +def test_bernoulli_logits_density(batch_shape, syntax): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + @funsor.torch.function(reals(), reals(), reals()) + def bernoulli(logits, value): + return torch.distributions.Bernoulli(logits=logits).log_prob(value) + + check_funsor(bernoulli, {'logits': reals(), 'value': reals()}, reals()) + + logits = Tensor(torch.rand(batch_shape), inputs) + value = Tensor(torch.rand(batch_shape).round(), inputs) + expected = bernoulli(logits, value) + check_funsor(expected, inputs, reals()) + + d = Variable('value', reals()) + if syntax == 'eager': + actual = dist.BernoulliLogits(logits, value) + elif syntax == 'lazy': + actual = dist.BernoulliLogits(logits, d)(value=value) + elif syntax == 'generic': + actual = dist.Bernoulli(logits=logits)(value=value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +@pytest.mark.parametrize('eager', [False, True]) +def test_binomial_density(batch_shape, eager): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + max_count = 10 + + @funsor.torch.function(reals(), reals(), reals(), reals()) + def binomial(total_count, probs, value): + return torch.distributions.Binomial(total_count, probs).log_prob(value) + + check_funsor(binomial, {'total_count': reals(), 'probs': reals(), 'value': reals()}, reals()) + + value_data = random_tensor(inputs, bint(max_count)).data.float() + total_count_data = value_data + random_tensor(inputs, bint(max_count)).data.float() + value = Tensor(value_data, inputs) + total_count = Tensor(total_count_data, inputs) + probs = Tensor(torch.rand(batch_shape), inputs) + expected = binomial(total_count, probs, value) + check_funsor(expected, inputs, reals()) + + m = Variable('value', reals()) + actual = dist.Binomial(total_count, probs, value) if eager else \ + dist.Binomial(total_count, probs, m)(value=value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + def test_categorical_defaults(): probs = Variable('probs', reals(3)) value = Variable('value', bint(3)) @@ -21,7 +133,7 @@ def test_categorical_defaults(): @pytest.mark.parametrize('size', [4]) -@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)]) +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) def test_categorical_density(size, batch_shape): batch_dims = ('i', 'j', 'k')[:len(batch_shape)] inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) @@ -35,7 +147,7 @@ def categorical(probs, value): probs_data = torch.randn(batch_shape + (size,)).exp() probs_data /= probs_data.sum(-1, keepdim=True) probs = Tensor(probs_data, inputs) - value = Tensor(random_tensor(size, batch_shape), inputs, size) + value = random_tensor(inputs, bint(size)) expected = categorical(probs, value) check_funsor(expected, inputs, reals()) @@ -44,6 +156,152 @@ def categorical(probs, value): assert_close(actual, expected) +def test_delta_defaults(): + v = Variable('v', reals()) + log_density = Variable('log_density', reals()) + assert isinstance(dist.Delta(v, log_density), dist.Delta) + value = Variable('value', reals()) + assert dist.Delta(v, log_density, 'value') is dist.Delta(v, log_density, value) + + +@pytest.mark.parametrize('event_shape', [(), (4,), (3, 2)], ids=str) 
+@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +def test_delta_density(batch_shape, event_shape): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + @funsor.torch.function(reals(*event_shape), reals(), reals(*event_shape), reals()) + def delta(v, log_density, value): + eq = (v == value) + for _ in range(len(event_shape)): + eq = eq.all(dim=-1) + return eq.type(v.dtype).log() + log_density + + check_funsor(delta, {'v': reals(*event_shape), + 'log_density': reals(), + 'value': reals(*event_shape)}, reals()) + + v = Tensor(torch.randn(batch_shape + event_shape), inputs) + log_density = Tensor(torch.randn(batch_shape).exp(), inputs) + for value in [v, Tensor(torch.randn(batch_shape + event_shape), inputs)]: + expected = delta(v, log_density, value) + check_funsor(expected, inputs, reals()) + + actual = dist.Delta(v, log_density, value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + +def test_delta_delta(): + v = Variable('v', reals(2)) + point = Tensor(torch.randn(2)) + log_density = Tensor(torch.tensor(0.5)) + d = dist.Delta(point, log_density, v) + assert d is Delta('v', point, log_density) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +@pytest.mark.parametrize('event_shape', [(1,), (4,), (5,)], ids=str) +def test_dirichlet_density(batch_shape, event_shape): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + @funsor.torch.function(reals(*event_shape), reals(*event_shape), reals()) + def dirichlet(concentration, value): + return torch.distributions.Dirichlet(concentration).log_prob(value) + + check_funsor(dirichlet, {'concentration': reals(*event_shape), 'value': reals(*event_shape)}, reals()) + + concentration = Tensor(torch.randn(batch_shape + event_shape).exp(), inputs) + value_data = torch.rand(batch_shape + event_shape) + value_data = value_data / value_data.sum(-1, keepdim=True) + value = Tensor(value_data, inputs) + expected = dirichlet(concentration, value) + check_funsor(expected, inputs, reals()) + actual = dist.Dirichlet(concentration, value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +@pytest.mark.parametrize('event_shape', [(1,), (4,), (5,)], ids=str) +def test_dirichlet_multinomial_density(batch_shape, event_shape): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + max_count = 10 + + @funsor.torch.function(reals(*event_shape), reals(), reals(*event_shape), reals()) + def dirichlet_multinomial(concentration, total_count, value): + return pyro.distributions.DirichletMultinomial(concentration, total_count).log_prob(value) + + check_funsor(dirichlet_multinomial, {'concentration': reals(*event_shape), + 'total_count': reals(), + 'value': reals(*event_shape)}, + reals()) + + concentration = Tensor(torch.randn(batch_shape + event_shape).exp(), inputs) + value_data = torch.randint(0, max_count, size=batch_shape + event_shape).float() + total_count_data = value_data.sum(-1) + torch.randint(0, max_count, size=batch_shape).float() + value = Tensor(value_data, inputs) + total_count = Tensor(total_count_data, inputs) + expected = dirichlet_multinomial(concentration, total_count, value) + check_funsor(expected, inputs, reals()) + actual = 
dist.DirichletMultinomial(concentration, total_count, value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +def test_lognormal_density(batch_shape): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + @funsor.torch.function(reals(), reals(), reals(), reals()) + def log_normal(loc, scale, value): + return torch.distributions.LogNormal(loc, scale).log_prob(value) + + check_funsor(log_normal, {'loc': reals(), 'scale': reals(), 'value': reals()}, reals()) + + loc = Tensor(torch.randn(batch_shape), inputs) + scale = Tensor(torch.randn(batch_shape).exp(), inputs) + value = Tensor(torch.randn(batch_shape).exp(), inputs) + expected = log_normal(loc, scale, value) + check_funsor(expected, inputs, reals()) + + actual = dist.LogNormal(loc, scale, value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +@pytest.mark.parametrize('event_shape', [(1,), (4,), (5,)], ids=str) +def test_multinomial_density(batch_shape, event_shape): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + max_count = 10 + + @funsor.torch.function(reals(), reals(*event_shape), reals(*event_shape), reals()) + def multinomial(total_count, probs, value): + total_count = total_count.max().item() + return torch.distributions.Multinomial(total_count, probs).log_prob(value) + + check_funsor(multinomial, {'total_count': reals(), 'probs': reals(*event_shape), 'value': reals(*event_shape)}, + reals()) + + probs_data = torch.rand(batch_shape + event_shape) + probs_data = probs_data / probs_data.sum(-1, keepdim=True) + probs = Tensor(probs_data, inputs) + value_data = torch.randint(0, max_count, size=batch_shape + event_shape).float() + total_count_data = value_data.sum(-1) + torch.randint(0, max_count, size=batch_shape).float() + value = Tensor(value_data, inputs) + total_count = Tensor(total_count_data, inputs) + expected = multinomial(total_count, probs, value) + check_funsor(expected, inputs, reals()) + actual = dist.Multinomial(total_count, probs, value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + def test_normal_defaults(): loc = Variable('loc', reals()) scale = Variable('scale', reals()) @@ -51,7 +309,7 @@ def test_normal_defaults(): assert dist.Normal(loc, scale) is dist.Normal(loc, scale, value) -@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)]) +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) def test_normal_density(batch_shape): batch_dims = ('i', 'j', 'k')[:len(batch_shape)] inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) @@ -71,3 +329,162 @@ def normal(loc, scale, value): actual = dist.Normal(loc, scale, value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +def test_normal_gaussian_1(batch_shape): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + loc = Tensor(torch.randn(batch_shape), inputs) + scale = Tensor(torch.randn(batch_shape).exp(), inputs) + value = Tensor(torch.randn(batch_shape), inputs) + + expected = dist.Normal(loc, scale, value) + assert isinstance(expected, Tensor) + check_funsor(expected, inputs, 
reals()) + + g = dist.Normal(loc, scale, 'value') + assert isinstance(g, Joint) + actual = g(value=value) + check_funsor(actual, inputs, reals()) + + assert_close(actual, expected, atol=1e-4) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +def test_normal_gaussian_2(batch_shape): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + loc = Tensor(torch.randn(batch_shape), inputs) + scale = Tensor(torch.randn(batch_shape).exp(), inputs) + value = Tensor(torch.randn(batch_shape), inputs) + + expected = dist.Normal(loc, scale, value) + assert isinstance(expected, Tensor) + check_funsor(expected, inputs, reals()) + + g = dist.Normal(Variable('value', reals()), scale, loc) + assert isinstance(g, Joint) + actual = g(value=value) + check_funsor(actual, inputs, reals()) + + assert_close(actual, expected, atol=1e-4) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +def test_normal_gaussian_3(batch_shape): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + loc = Tensor(torch.randn(batch_shape), inputs) + scale = Tensor(torch.randn(batch_shape).exp(), inputs) + value = Tensor(torch.randn(batch_shape), inputs) + + expected = dist.Normal(loc, scale, value) + assert isinstance(expected, Tensor) + check_funsor(expected, inputs, reals()) + + g = dist.Normal(Variable('loc', reals()), scale, 'value') + assert isinstance(g, Joint) + actual = g(loc=loc, value=value) + check_funsor(actual, inputs, reals()) + + assert_close(actual, expected, atol=1e-4) + + +NORMAL_AFFINE_TESTS = [ + 'dist.Normal(x+2, scale, y+2)', + 'dist.Normal(y, scale, x)', + 'dist.Normal(x - y, scale, 0)', + 'dist.Normal(0, scale, y - x)', + 'dist.Normal(2 * x - y, scale, x)', + # TODO should we expect these to work without correction terms? 
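+    # (the two cases below rescale the value, so each adds an explicit
+    # log-Jacobian correction term by hand)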
+ 'dist.Normal(0, 1, (x - y) / scale) - scale.log()', + 'dist.Normal(2 * y, 2 * scale, 2 * x) + math.log(2)', +] + + +@pytest.mark.parametrize('expr', NORMAL_AFFINE_TESTS) +def test_normal_affine(expr): + + scale = Tensor(torch.tensor(0.3), OrderedDict()) + x = Variable('x', reals()) + y = Variable('y', reals()) + + expected = dist.Normal(x, scale, y) + actual = eval(expr) + + assert isinstance(actual, Joint) + assert dict(actual.inputs) == dict(expected.inputs), (actual.inputs, expected.inputs) + + assert_close(actual.gaussian.align(tuple(expected.gaussian.inputs)), expected.gaussian) + assert_close(actual.discrete.align(tuple(expected.discrete.inputs)), expected.discrete) + + +def test_normal_independent(): + loc = random_tensor(OrderedDict(), reals(2)) + scale = random_tensor(OrderedDict(), reals(2)).exp() + fn = dist.Normal(loc['i'], scale['i'], value='z') + assert fn.inputs['z'] == reals() + d = Independent(fn, 'z', 'i') + assert d.inputs['z'] == reals(2) + sample = d.sample(frozenset(['z'])) + assert isinstance(sample, Joint) + assert sample.inputs['z'] == reals(2) + + +def test_mvn_defaults(): + loc = Variable('loc', reals(3)) + scale_tril = Variable('scale', reals(3, 3)) + value = Variable('value', reals(3)) + assert dist.MultivariateNormal(loc, scale_tril) is dist.MultivariateNormal(loc, scale_tril, value) + + +def _random_scale_tril(shape): + data = torch.randn(shape) + return torch.distributions.transform_to(torch.distributions.constraints.lower_cholesky)(data) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +def test_mvn_density(batch_shape): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + @funsor.torch.function(reals(3), reals(3, 3), reals(3), reals()) + def mvn(loc, scale_tril, value): + return torch.distributions.MultivariateNormal(loc, scale_tril=scale_tril).log_prob(value) + + check_funsor(mvn, {'loc': reals(3), 'scale_tril': reals(3, 3), 'value': reals(3)}, reals()) + + loc = Tensor(torch.randn(batch_shape + (3,)), inputs) + scale_tril = Tensor(_random_scale_tril(batch_shape + (3, 3)), inputs) + value = Tensor(torch.randn(batch_shape + (3,)), inputs) + expected = mvn(loc, scale_tril, value) + check_funsor(expected, inputs, reals()) + + actual = dist.MultivariateNormal(loc, scale_tril, value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +def test_mvn_gaussian(batch_shape): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + loc = Tensor(torch.randn(batch_shape + (3,)), inputs) + scale_tril = Tensor(_random_scale_tril(batch_shape + (3, 3)), inputs) + value = Tensor(torch.randn(batch_shape + (3,)), inputs) + + expected = dist.MultivariateNormal(loc, scale_tril, value) + assert isinstance(expected, Tensor) + check_funsor(expected, inputs, reals()) + + g = dist.MultivariateNormal(loc, scale_tril, 'value') + assert isinstance(g, Joint) + actual = g(value=value) + check_funsor(actual, inputs, reals()) + + assert_close(actual, expected, atol=1e-3, rtol=1e-4) diff --git a/test/test_einsum.py b/test/test_einsum.py index 0769113cb..7d2d92992 100644 --- a/test/test_einsum.py +++ b/test/test_einsum.py @@ -1,77 +1,92 @@ from __future__ import absolute_import, division, print_function -import itertools -import pytest from collections import OrderedDict +import opt_einsum +import pytest import torch 
from pyro.ops.contract import naive_ubersum import funsor +from funsor.distributions import Categorical +from funsor.domains import bint +from funsor.einsum import naive_einsum, naive_plated_einsum +from funsor.interpreter import interpretation, reinterpret +from funsor.optimizer import apply_optimizer +from funsor.terms import Variable, reflect +from funsor.testing import assert_close, make_einsum_example +from funsor.torch import Tensor + +EINSUM_EXAMPLES = [ + "a,b->", + "ab,a->", + "a,a->", + "a,a->a", + "a,a,a,ab->ab", + "ab->ba", + "ab,bc,cd->da", +] -def xfail_param(*args, **kwargs): - return pytest.param(*args, marks=[pytest.mark.xfail(**kwargs)]) +@pytest.mark.parametrize('equation', EINSUM_EXAMPLES) +@pytest.mark.parametrize('backend', ['torch', 'pyro.ops.einsum.torch_log']) +def test_einsum(equation, backend): + inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation) + expected = opt_einsum.contract(equation, *operands, backend=backend) + with interpretation(reflect): + naive_ast = naive_einsum(equation, *funsor_operands, backend=backend) + optimized_ast = apply_optimizer(naive_ast) + print("Naive expression: {}".format(naive_ast)) + print("Optimized expression: {}".format(optimized_ast)) + actual_optimized = reinterpret(optimized_ast) # eager by default + actual = naive_einsum(equation, *funsor_operands, backend=backend) -def make_example(equation, fill=None, sizes=(2, 3)): - symbols = sorted(set(equation) - set(',->')) - sizes = {dim: size for dim, size in zip(symbols, itertools.cycle(sizes))} - inputs, outputs = equation.split('->') - inputs = inputs.split(',') - outputs = outputs.split(',') - operands = [] - for dims in inputs: - shape = tuple(sizes[dim] for dim in dims) - operands.append(torch.randn(shape) if fill is None else torch.full(shape, fill)) - return inputs, outputs, operands, sizes + assert isinstance(actual, funsor.Tensor) and len(outputs) == 1 + if len(outputs[0]) > 0: + actual = actual.align(tuple(outputs[0])) + actual_optimized = actual_optimized.align(tuple(outputs[0])) + + assert_close(actual, actual_optimized, atol=1e-4) + assert expected.shape == actual.data.shape + assert torch.allclose(expected, actual.data) + for output in outputs: + for i, output_dim in enumerate(output): + assert output_dim in actual.inputs + assert actual.inputs[output_dim].dtype == sizes[output_dim] -def naive_einsum(eqn, *terms): - assert isinstance(eqn, str) - assert all(isinstance(term, funsor.Funsor) for term in terms) - inputs, output = eqn.split('->') - input_dims = frozenset(d for inp in inputs.split(',') for d in inp) - output_dims = frozenset(d for d in output) - reduce_dims = tuple(d for d in input_dims - output_dims) - prod = terms[0] - for term in terms[1:]: - prod = prod * term - for reduce_dim in reduce_dims: - prod = prod.sum(reduce_dim) - return prod +@pytest.mark.parametrize('equation', EINSUM_EXAMPLES) +def test_einsum_categorical(equation): + inputs, outputs, sizes, operands, _ = make_einsum_example(equation) + operands = [operand.abs() / operand.abs().sum(-1, keepdim=True) + for operand in operands] + expected = opt_einsum.contract(equation, *operands, backend='torch') -def naive_plated_einsum(eqn, *terms, **kwargs): - assert isinstance(eqn, str) - assert all(isinstance(term, funsor.Funsor) for term in terms) - # ... 
- raise NotImplementedError("TODO implement naive plated einsum") + with interpretation(reflect): + funsor_operands = [ + Categorical(probs=Tensor( + operand, + inputs=OrderedDict([(d, bint(sizes[d])) for d in inp[:-1]]) + ))(value=Variable(inp[-1], bint(sizes[inp[-1]]))).exp() + for inp, operand in zip(inputs, operands) + ] + naive_ast = naive_einsum(equation, *funsor_operands) + optimized_ast = apply_optimizer(naive_ast) -EINSUM_EXAMPLES = [ - "a,b->", - "ab,a->", - "a,a->", - "a,a->a", - "a,a,a,ab->ab", - "a,ab,bc,cd->", -] + print("Naive expression: {}".format(naive_ast)) + print("Optimized expression: {}".format(optimized_ast)) + actual_optimized = reinterpret(optimized_ast) # eager by default + actual = naive_einsum(equation, *map(reinterpret, funsor_operands)) -XFAIL_EINSUM_EXAMPLES = [ - xfail_param("ab->ba", reason="align not implemented"), # see pyro-ppl/funsor#26 -] + if len(outputs[0]) > 0: + actual = actual.align(tuple(outputs[0])) + actual_optimized = actual_optimized.align(tuple(outputs[0])) + assert_close(actual, actual_optimized, atol=1e-4) -@pytest.mark.parametrize('equation', EINSUM_EXAMPLES + XFAIL_EINSUM_EXAMPLES) -def test_einsum(equation): - inputs, outputs, operands, sizes = make_example(equation) - funsor_operands = [ - funsor.Tensor(operand, OrderedDict([(d, funsor.bint(sizes[d])) for d in inp])) - for inp, operand in zip(inputs, operands) - ] - expected = torch.einsum(equation, operands) - actual = naive_einsum(equation, *funsor_operands) assert expected.shape == actual.data.shape assert torch.allclose(expected, actual.data) for output in outputs: @@ -80,29 +95,35 @@ def test_einsum(equation): assert actual.inputs[output_dim].dtype == sizes[output_dim] -PLATED_EINSUM_EXAMPLES = [(ex, '') for ex in EINSUM_EXAMPLES] + [ +PLATED_EINSUM_EXAMPLES = [ ('i->', 'i'), - ('i->i', 'i'), (',i->', 'i'), - (',i->i', 'i'), ('ai->', 'i'), - ('ai->i', 'i'), - ('ai->ai', 'i'), - (',ai,abij->aij', 'ij'), - ('a,ai,bij->bij', 'ij'), + (',ai,abij->', 'ij'), + ('a,ai,bij->', 'ij'), + ('ai,abi,bci,cdi->', 'i'), + ('aij,abij,bcij->', 'ij'), + ('a,abi,bcij,cdij->', 'ij'), ] -@pytest.mark.xfail(reason="naive plated einsum not implemented") @pytest.mark.parametrize('equation,plates', PLATED_EINSUM_EXAMPLES) -def test_plated_einsum(equation, plates): - inputs, outputs, operands, sizes = make_example(equation) - funsor_operands = [ - funsor.Tensor(operand, OrderedDict([(d, funsor.bint(sizes[d])) for d in inp])) - for inp, operand in zip(inputs, operands) - ] - expected = naive_ubersum(equation, *operands, plates=plates, backend='torch', modulo_total=False)[0] - actual = naive_plated_einsum(equation, *funsor_operands, plates=plates) +@pytest.mark.parametrize('backend', ['torch', 'pyro.ops.einsum.torch_log']) +def test_plated_einsum(equation, plates, backend): + inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation) + expected = naive_ubersum(equation, *operands, plates=plates, backend=backend, modulo_total=False)[0] + with interpretation(reflect): + naive_ast = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend) + optimized_ast = apply_optimizer(naive_ast) + actual_optimized = reinterpret(optimized_ast) # eager by default + actual = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend) + + if len(outputs[0]) > 0: + actual = actual.align(tuple(outputs[0])) + actual_optimized = actual_optimized.align(tuple(outputs[0])) + + assert_close(actual, actual_optimized, atol=1e-3 if backend == 'torch' else 1e-4) + assert 
expected.shape == actual.data.shape assert torch.allclose(expected, actual.data) for output in outputs: diff --git a/test/test_gaussian.py b/test/test_gaussian.py new file mode 100644 index 000000000..b3bbfb4d9 --- /dev/null +++ b/test/test_gaussian.py @@ -0,0 +1,427 @@ +from __future__ import absolute_import, division, print_function + +import itertools +from collections import OrderedDict + +import pytest +import torch +from six.moves import reduce + +import funsor.ops as ops +from funsor.domains import bint, reals +from funsor.gaussian import BlockMatrix, BlockVector, Gaussian +from funsor.integrate import Integrate +from funsor.interpreter import interpretation +from funsor.joint import Joint +from funsor.montecarlo import monte_carlo, monte_carlo_interpretation +from funsor.terms import Number, Variable +from funsor.testing import assert_close, id_from_inputs, random_gaussian, random_tensor, xfail_if_not_implemented +from funsor.torch import Tensor + + +def test_block_vector(): + shape = (10,) + expected = torch.zeros(shape) + actual = BlockVector(shape) + + expected[1] = torch.randn(()) + actual[1] = expected[1] + + expected[3:5] = torch.randn(2) + actual[3:5] = expected[3:5] + + assert_close(actual.as_tensor(), expected) + + +@pytest.mark.parametrize('batch_shape', [(), (4,), (3, 2)]) +def test_block_vector_batched(batch_shape): + shape = batch_shape + (10,) + expected = torch.zeros(shape) + actual = BlockVector(shape) + + expected[..., 1] = torch.randn(batch_shape) + actual[..., 1] = expected[..., 1] + + expected[..., 3:5] = torch.randn(batch_shape + (2,)) + actual[..., 3:5] = expected[..., 3:5] + + assert_close(actual.as_tensor(), expected) + + +def test_block_matrix(): + shape = (10, 10) + expected = torch.zeros(shape) + actual = BlockMatrix(shape) + + expected[1, 1] = torch.randn(()) + actual[1, 1] = expected[1, 1] + + expected[1, 3:5] = torch.randn(2) + actual[1, 3:5] = expected[1, 3:5] + + expected[3:5, 1] = torch.randn(2) + actual[3:5, 1] = expected[3:5, 1] + + expected[3:5, 3:5] = torch.randn(2, 2) + actual[3:5, 3:5] = expected[3:5, 3:5] + + assert_close(actual.as_tensor(), expected) + + +@pytest.mark.parametrize('batch_shape', [(), (4,), (3, 2)]) +def test_block_matrix_batched(batch_shape): + shape = batch_shape + (10, 10) + expected = torch.zeros(shape) + actual = BlockMatrix(shape) + + expected[..., 1, 1] = torch.randn(batch_shape) + actual[..., 1, 1] = expected[..., 1, 1] + + expected[..., 1, 3:5] = torch.randn(batch_shape + (2,)) + actual[..., 1, 3:5] = expected[..., 1, 3:5] + + expected[..., 3:5, 1] = torch.randn(batch_shape + (2,)) + actual[..., 3:5, 1] = expected[..., 3:5, 1] + + expected[..., 3:5, 3:5] = torch.randn(batch_shape + (2, 2)) + actual[..., 3:5, 3:5] = expected[..., 3:5, 3:5] + + assert_close(actual.as_tensor(), expected) + + +@pytest.mark.parametrize('expr,expected_type', [ + ('-g1', Gaussian), + ('g1 + 1', Joint), + ('g1 - 1', Joint), + ('1 + g1', Joint), + ('g1 + shift', Joint), + ('g1 + shift', Joint), + ('shift + g1', Joint), + ('shift - g1', Joint), + ('g1 + g1', Joint), + ('(g1 + g2 + g2) - g2', Joint), + ('g1(i=i0)', Gaussian), + ('g2(i=i0)', Gaussian), + ('g1(i=i0) + g2(i=i0)', Joint), + ('g1(i=i0) + g2', Joint), + ('g1(x=x0)', Tensor), + ('g2(y=y0)', Tensor), + ('(g1 + g2)(i=i0)', Joint), + ('(g1 + g2)(x=x0, y=y0)', Tensor), + ('(g2 + g1)(x=x0, y=y0)', Tensor), + ('g1.reduce(ops.logaddexp, "x")', Tensor), + ('(g1 + g2).reduce(ops.logaddexp, "x")', Joint), + ('(g1 + g2).reduce(ops.logaddexp, "y")', Joint), + ('(g1 + g2).reduce(ops.logaddexp, 
frozenset(["x", "y"]))', Tensor), +]) +def test_smoke(expr, expected_type): + g1 = Gaussian( + loc=torch.tensor([[0.0, 0.1, 0.2], + [2.0, 3.0, 4.0]]), + precision=torch.tensor([[[1.0, 0.1, 0.2], + [0.1, 1.0, 0.3], + [0.2, 0.3, 1.0]], + [[1.0, 0.1, 0.2], + [0.1, 1.0, 0.3], + [0.2, 0.3, 1.0]]]), + inputs=OrderedDict([('i', bint(2)), ('x', reals(3))])) + assert isinstance(g1, Gaussian) + + g2 = Gaussian( + loc=torch.tensor([[0.0, 0.1], + [2.0, 3.0]]), + precision=torch.tensor([[[1.0, 0.2], + [0.2, 1.0]], + [[1.0, 0.2], + [0.2, 1.0]]]), + inputs=OrderedDict([('i', bint(2)), ('y', reals(2))])) + assert isinstance(g2, Gaussian) + + shift = Tensor(torch.tensor([-1., 1.]), OrderedDict([('i', bint(2))])) + assert isinstance(shift, Tensor) + + i0 = Number(1, 2) + assert isinstance(i0, Number) + + x0 = Tensor(torch.tensor([0.5, 0.6, 0.7])) + assert isinstance(x0, Tensor) + + y0 = Tensor(torch.tensor([[0.2, 0.3], + [0.8, 0.9]]), + inputs=OrderedDict([('i', bint(2))])) + assert isinstance(y0, Tensor) + + result = eval(expr) + assert isinstance(result, expected_type) + + +@pytest.mark.parametrize('int_inputs', [ + {}, + {'i': bint(2)}, + {'i': bint(2), 'j': bint(3)}, +], ids=id_from_inputs) +@pytest.mark.parametrize('real_inputs', [ + {'x': reals()}, + {'x': reals(4)}, + {'x': reals(2, 3)}, + {'x': reals(), 'y': reals()}, + {'x': reals(2), 'y': reals(3)}, + {'x': reals(4), 'y': reals(2, 3), 'z': reals()}, +], ids=id_from_inputs) +def test_align(int_inputs, real_inputs): + inputs1 = OrderedDict(list(sorted(int_inputs.items())) + + list(sorted(real_inputs.items()))) + inputs2 = OrderedDict(reversed(inputs1.items())) + g1 = random_gaussian(inputs1) + g2 = g1.align(tuple(inputs2)) + assert g2.inputs == inputs2 + g3 = g2.align(tuple(inputs1)) + assert_close(g3, g1) + + +@pytest.mark.parametrize('int_inputs', [ + {}, + {'i': bint(2)}, + {'i': bint(2), 'j': bint(3)}, +], ids=id_from_inputs) +@pytest.mark.parametrize('real_inputs', [ + {'x': reals()}, + {'x': reals(4)}, + {'x': reals(2, 3)}, + {'x': reals(), 'y': reals()}, + {'x': reals(2), 'y': reals(3)}, + {'x': reals(4), 'y': reals(2, 3), 'z': reals()}, +], ids=id_from_inputs) +def test_eager_subs(int_inputs, real_inputs): + int_inputs = OrderedDict(sorted(int_inputs.items())) + real_inputs = OrderedDict(sorted(real_inputs.items())) + inputs = int_inputs.copy() + inputs.update(real_inputs) + + g = random_gaussian(inputs) + + for order in itertools.permutations(inputs): + ground_values = {} + dependent_values = {} + for i, name in enumerate(order): + upstream = OrderedDict([(k, inputs[k]) for k in order[:i] if k in int_inputs]) + value = random_tensor(upstream, inputs[name]) + ground_values[name] = value(**ground_values) + dependent_values[name] = value + + expected = g(**ground_values) + actual = g + for k in reversed(order): + with xfail_if_not_implemented(): + actual = actual(**{k: dependent_values[k]}) + assert_close(actual, expected, atol=1e-4) + + +def test_eager_subs_variable(): + inputs = OrderedDict([('i', bint(2)), ('x', reals()), ('y', reals(2))]) + g1 = random_gaussian(inputs) + + g2 = g1(x='z') + assert set(g2.inputs) == {'i', 'y', 'z'} + + g2 = g1(x='y', y='x') + assert set(g2.inputs) == {'i', 'x', 'y'} + assert g2.inputs['x'] == reals(2) + + +@pytest.mark.parametrize('int_inputs', [ + {}, + {'i': bint(2)}, + {'i': bint(2), 'j': bint(3)}, +], ids=id_from_inputs) +@pytest.mark.parametrize('real_inputs', [ + {'x': reals()}, + {'x': reals(4)}, + {'x': reals(2, 3)}, + {'x': reals(), 'y': reals()}, + {'x': reals(2), 'y': reals(3)}, + {'x': reals(4), 
'y': reals(2, 3), 'z': reals()}, +], ids=id_from_inputs) +def test_add_gaussian_number(int_inputs, real_inputs): + int_inputs = OrderedDict(sorted(int_inputs.items())) + real_inputs = OrderedDict(sorted(real_inputs.items())) + inputs = int_inputs.copy() + inputs.update(real_inputs) + + g = random_gaussian(inputs) + n = Number(1.234) + values = {name: random_tensor(int_inputs, domain) + for name, domain in real_inputs.items()} + + assert_close((g + n)(**values), g(**values) + n, atol=1e-4) + assert_close((n + g)(**values), n + g(**values), atol=1e-4) + assert_close((g - n)(**values), g(**values) - n, atol=1e-4) + + +@pytest.mark.parametrize('int_inputs', [ + {}, + {'i': bint(2)}, + {'i': bint(2), 'j': bint(3)}, +], ids=id_from_inputs) +@pytest.mark.parametrize('real_inputs', [ + {'x': reals()}, + {'x': reals(4)}, + {'x': reals(2, 3)}, + {'x': reals(), 'y': reals()}, + {'x': reals(2), 'y': reals(3)}, + {'x': reals(4), 'y': reals(2, 3), 'z': reals()}, +], ids=id_from_inputs) +def test_add_gaussian_tensor(int_inputs, real_inputs): + int_inputs = OrderedDict(sorted(int_inputs.items())) + real_inputs = OrderedDict(sorted(real_inputs.items())) + inputs = int_inputs.copy() + inputs.update(real_inputs) + + g = random_gaussian(inputs) + t = random_tensor(int_inputs, reals()) + values = {name: random_tensor(int_inputs, domain) + for name, domain in real_inputs.items()} + + assert_close((g + t)(**values), g(**values) + t, atol=1e-4) + assert_close((t + g)(**values), t + g(**values), atol=1e-4) + assert_close((g - t)(**values), g(**values) - t, atol=1e-4) + + +@pytest.mark.parametrize('lhs_inputs', [ + {'x': reals()}, + {'y': reals(4)}, + {'z': reals(2, 3)}, + {'x': reals(), 'y': reals(4)}, + {'y': reals(4), 'z': reals(2, 3)}, +], ids=id_from_inputs) +@pytest.mark.parametrize('rhs_inputs', [ + {'x': reals()}, + {'y': reals(4)}, + {'z': reals(2, 3)}, + {'x': reals(), 'y': reals(4)}, + {'y': reals(4), 'z': reals(2, 3)}, +], ids=id_from_inputs) +def test_add_gaussian_gaussian(lhs_inputs, rhs_inputs): + lhs_inputs = OrderedDict(sorted(lhs_inputs.items())) + rhs_inputs = OrderedDict(sorted(rhs_inputs.items())) + inputs = lhs_inputs.copy() + inputs.update(rhs_inputs) + int_inputs = OrderedDict((k, d) for k, d in inputs.items() if d.dtype != 'real') + real_inputs = OrderedDict((k, d) for k, d in inputs.items() if d.dtype == 'real') + + g1 = random_gaussian(lhs_inputs) + g2 = random_gaussian(rhs_inputs) + values = {name: random_tensor(int_inputs, domain) + for name, domain in real_inputs.items()} + + assert_close((g1 + g2)(**values), g1(**values) + g2(**values), atol=1e-4, rtol=None) + + +@pytest.mark.parametrize('inputs', [ + OrderedDict([('i', bint(2)), ('x', reals())]), + OrderedDict([('i', bint(3)), ('x', reals())]), + OrderedDict([('i', bint(2)), ('x', reals(2))]), + OrderedDict([('i', bint(2)), ('x', reals()), ('y', reals())]), + OrderedDict([('i', bint(3)), ('j', bint(4)), ('x', reals(2))]), +], ids=id_from_inputs) +def test_reduce_add(inputs): + g = random_gaussian(inputs) + actual = g.reduce(ops.add, 'i') + + gs = [g(i=i) for i in range(g.inputs['i'].dtype)] + expected = reduce(ops.add, gs) + assert_close(actual, expected) + + +@pytest.mark.parametrize('int_inputs', [ + {}, + {'i': bint(2)}, + {'i': bint(2), 'j': bint(3)}, +], ids=id_from_inputs) +@pytest.mark.parametrize('real_inputs', [ + {'x': reals(), 'y': reals()}, + {'x': reals(2), 'y': reals(3)}, + {'x': reals(4), 'y': reals(2, 3), 'z': reals()}, + {'w': reals(5), 'x': reals(4), 'y': reals(2, 3), 'z': reals()}, +], ids=id_from_inputs) +def 
test_reduce_logsumexp(int_inputs, real_inputs): + int_inputs = OrderedDict(sorted(int_inputs.items())) + real_inputs = OrderedDict(sorted(real_inputs.items())) + inputs = int_inputs.copy() + inputs.update(real_inputs) + + g = random_gaussian(inputs) + g_xy = g.reduce(ops.logaddexp, frozenset(['x', 'y'])) + assert_close(g_xy, g.reduce(ops.logaddexp, 'x').reduce(ops.logaddexp, 'y'), atol=1e-3, rtol=None) + assert_close(g_xy, g.reduce(ops.logaddexp, 'y').reduce(ops.logaddexp, 'x'), atol=1e-3, rtol=None) + + +@pytest.mark.parametrize('int_inputs', [ + {}, + {'i': bint(2)}, +], ids=id_from_inputs) +@pytest.mark.parametrize('real_inputs', [ + {'x': reals()}, + {'x': reals(4)}, + {'x': reals(2, 3)}, +], ids=id_from_inputs) +def test_integrate_variable(int_inputs, real_inputs): + int_inputs = OrderedDict(sorted(int_inputs.items())) + real_inputs = OrderedDict(sorted(real_inputs.items())) + inputs = int_inputs.copy() + inputs.update(real_inputs) + + log_measure = random_gaussian(inputs) + integrand = reduce(ops.add, [Variable(k, d) for k, d in real_inputs.items()]) + reduced_vars = frozenset(real_inputs) + + with monte_carlo_interpretation(particle=bint(100000)): + approx = Integrate(log_measure, integrand, reduced_vars) + assert isinstance(approx, Tensor) + + exact = Integrate(log_measure, integrand, reduced_vars) + assert isinstance(exact, Tensor) + assert_close(approx, exact, atol=0.1, rtol=0.1) + + +@pytest.mark.parametrize('int_inputs', [ + {}, + {'i': bint(2)}, + {'i': bint(2), 'j': bint(3)}, +], ids=id_from_inputs) +@pytest.mark.parametrize('real_inputs', [ + {'x': reals()}, + {'x': reals(2)}, + {'x': reals(), 'y': reals()}, + {'x': reals(2), 'y': reals(3)}, + {'x': reals(4), 'y': reals(2, 3)}, +], ids=id_from_inputs) +def test_integrate_gaussian(int_inputs, real_inputs): + int_inputs = OrderedDict(sorted(int_inputs.items())) + real_inputs = OrderedDict(sorted(real_inputs.items())) + inputs = int_inputs.copy() + inputs.update(real_inputs) + + log_measure = random_gaussian(inputs) + integrand = random_gaussian(inputs) + reduced_vars = frozenset(real_inputs) + + with monte_carlo_interpretation(particle=bint(10000)): + approx = Integrate(log_measure, integrand, reduced_vars) + assert isinstance(approx, Tensor) + + exact = Integrate(log_measure, integrand, reduced_vars) + assert isinstance(exact, Tensor) + assert_close(approx, exact, atol=0.1, rtol=0.1) + + +@pytest.mark.xfail(reason="numerically unstable") +def test_mc_plate_gaussian(): + log_measure = Gaussian(torch.tensor([0.]), torch.tensor([[1.]]), + (('loc', reals()),)) + torch.tensor(-0.9189) + integrand = Gaussian(torch.randn((100, 1)) + 3., torch.ones((100, 1, 1)), + (('data', bint(100)), ('loc', reals()))) + with interpretation(monte_carlo): + res = Integrate(log_measure, integrand, frozenset({'loc'})) + res = res.reduce(ops.mul, frozenset({'data'})) + assert not torch.isinf(res).any() diff --git a/test/test_joint.py b/test/test_joint.py new file mode 100644 index 000000000..82750c811 --- /dev/null +++ b/test/test_joint.py @@ -0,0 +1,286 @@ +from __future__ import absolute_import, division, print_function + +from collections import OrderedDict + +import pytest +import torch +from six.moves import reduce + +import funsor.ops as ops +from funsor.delta import Delta +from funsor.domains import bint, reals +from funsor.gaussian import Gaussian +from funsor.interpreter import interpretation +from funsor.joint import Joint +from funsor.terms import Number, Reduce, eager, moment_matching +from funsor.testing import assert_close, 
random_gaussian, random_tensor, xfail_if_not_implemented +from funsor.torch import Tensor + + +def id_from_inputs(inputs): + if not inputs: + return '()' + return ','.join(k + ''.join(map(str, d.shape)) for k, d in inputs.items()) + + +SMOKE_TESTS = [ + ('dx + dy', Joint), + ('dx + g', Joint), + ('dy + g', Joint), + ('g + dx', Joint), + ('g + dy', Joint), + ('dx + t', Joint), + ('dy + t', Joint), + ('dx - t', Joint), + ('dy - t', Joint), + ('t + dx', Joint), + ('t + dy', Joint), + ('g + 1', Joint), + ('g - 1', Joint), + ('1 + g', Joint), + ('g + t', Joint), + ('g - t', Joint), + ('t + g', Joint), + ('t - g', Joint), + ('g + g', Joint), + ('-(g + g)', Joint), + ('(dx + dy)(i=i0)', Joint), + ('(dx + g)(i=i0)', Joint), + ('(dy + g)(i=i0)', Joint), + ('(g + dx)(i=i0)', Joint), + ('(g + dy)(i=i0)', Joint), + ('(dx + t)(i=i0)', Joint), + ('(dy + t)(i=i0)', Joint), + ('(dx - t)(i=i0)', Joint), + ('(dy - t)(i=i0)', Joint), + ('(t + dx)(i=i0)', Joint), + ('(t + dy)(i=i0)', Joint), + ('(g + 1)(i=i0)', Joint), + ('(g - 1)(i=i0)', Joint), + ('(1 + g)(i=i0)', Joint), + ('(g + t)(i=i0)', Joint), + ('(g - t)(i=i0)', Joint), + ('(t + g)(i=i0)', Joint), + ('(g + g)(i=i0)', Joint), + ('(dx + dy)(x=x0)', Joint), + ('(dx + g)(x=x0)', Tensor), + ('(dy + g)(x=x0)', Joint), + ('(g + dx)(x=x0)', Tensor), + ('(g + dy)(x=x0)', Joint), + ('(dx + t)(x=x0)', Tensor), + ('(dy + t)(x=x0)', Joint), + ('(dx - t)(x=x0)', Tensor), + ('(dy - t)(x=x0)', Joint), + ('(t + dx)(x=x0)', Tensor), + ('(t + dy)(x=x0)', Joint), + ('(g + 1)(x=x0)', Tensor), + ('(g - 1)(x=x0)', Tensor), + ('(1 + g)(x=x0)', Tensor), + ('(g + t)(x=x0)', Tensor), + ('(g - t)(x=x0)', Tensor), + ('(t + g)(x=x0)', Tensor), + ('(g + g)(x=x0)', Tensor), + ('(g + dy).reduce(ops.logaddexp, "x")', Joint), + ('(g + dy).reduce(ops.logaddexp, "y")', Gaussian), + ('(t + g + dy).reduce(ops.logaddexp, "x")', Joint), + ('(t + g + dy).reduce(ops.logaddexp, "y")', Joint), + ('(t + g).reduce(ops.logaddexp, "x")', Tensor), +] + + +@pytest.mark.parametrize('expr,expected_type', SMOKE_TESTS) +def test_smoke(expr, expected_type): + dx = Delta('x', Tensor(torch.randn(2, 3), OrderedDict([('i', bint(2))]))) + assert isinstance(dx, Delta) + + dy = Delta('y', Tensor(torch.randn(3, 4), OrderedDict([('j', bint(3))]))) + assert isinstance(dy, Delta) + + t = Tensor(torch.randn(2, 3), OrderedDict([('i', bint(2)), ('j', bint(3))])) + assert isinstance(t, Tensor) + + g = Gaussian( + loc=torch.tensor([[0.0, 0.1, 0.2], + [2.0, 3.0, 4.0]]), + precision=torch.tensor([[[1.0, 0.1, 0.2], + [0.1, 1.0, 0.3], + [0.2, 0.3, 1.0]], + [[1.0, 0.1, 0.2], + [0.1, 1.0, 0.3], + [0.2, 0.3, 1.0]]]), + inputs=OrderedDict([('i', bint(2)), ('x', reals(3))])) + assert isinstance(g, Gaussian) + + i0 = Number(1, 2) + assert isinstance(i0, Number) + + x0 = Tensor(torch.tensor([0.5, 0.6, 0.7])) + assert isinstance(x0, Tensor) + + result = eval(expr) + assert isinstance(result, expected_type) + + +@pytest.mark.parametrize('int_inputs', [ + {}, + {'i': bint(2)}, + {'i': bint(2), 'j': bint(3)}, +], ids=id_from_inputs) +@pytest.mark.parametrize('real_inputs', [ + {'x': reals()}, + {'x': reals(4)}, + {'x': reals(2, 3)}, + {'x': reals(), 'y': reals()}, + {'x': reals(2), 'y': reals(3)}, + {'x': reals(4), 'y': reals(2, 3), 'z': reals()}, +], ids=id_from_inputs) +def test_reduce_logaddexp(int_inputs, real_inputs): + int_inputs = OrderedDict(sorted(int_inputs.items())) + real_inputs = OrderedDict(sorted(real_inputs.items())) + inputs = int_inputs.copy() + inputs.update(real_inputs) + + t = random_tensor(int_inputs) + g = 
random_gaussian(inputs) + truth = {name: random_tensor(int_inputs, domain) for name, domain in real_inputs.items()} + + state = 0 + state += g + state += t + for name, point in truth.items(): + with xfail_if_not_implemented(): + state += Delta(name, point) + actual = state.reduce(ops.logaddexp, frozenset(truth)) + + expected = t + g(**truth) + assert_close(actual, expected) + + +def test_reduce_logaddexp_deltas_lazy(): + a = Delta('a', Tensor(torch.randn(3, 2), OrderedDict(i=bint(3)))) + b = Delta('b', Tensor(torch.randn(3), OrderedDict(i=bint(3)))) + x = a + b + assert isinstance(x, Joint) + assert set(x.inputs) == {'a', 'b', 'i'} + + y = x.reduce(ops.logaddexp, 'i') + assert isinstance(y, Reduce) + assert set(y.inputs) == {'a', 'b'} + assert_close(x.reduce(ops.logaddexp), y.reduce(ops.logaddexp)) + + +def test_reduce_logaddexp_deltas_discrete_lazy(): + a = Delta('a', Tensor(torch.randn(3, 2), OrderedDict(i=bint(3)))) + b = Delta('b', Tensor(torch.randn(3), OrderedDict(i=bint(3)))) + c = Tensor(torch.randn(3), OrderedDict(i=bint(3))) + x = a + b + c + assert isinstance(x, Joint) + assert set(x.inputs) == {'a', 'b', 'i'} + + y = x.reduce(ops.logaddexp, 'i') + assert isinstance(y, Reduce) + assert set(y.inputs) == {'a', 'b'} + assert_close(x.reduce(ops.logaddexp), y.reduce(ops.logaddexp)) + + +def test_reduce_logaddexp_gaussian_lazy(): + a = random_gaussian(OrderedDict(i=bint(3), a=reals(2))) + b = random_tensor(OrderedDict(i=bint(3), b=bint(2))) + x = a + b + assert isinstance(x, Joint) + assert set(x.inputs) == {'a', 'b', 'i'} + + y = x.reduce(ops.logaddexp, 'i') + assert isinstance(y, Reduce) + assert set(y.inputs) == {'a', 'b'} + assert_close(x.reduce(ops.logaddexp), y.reduce(ops.logaddexp)) + + +@pytest.mark.parametrize('inputs', [ + OrderedDict([('i', bint(2)), ('x', reals())]), + OrderedDict([('i', bint(3)), ('x', reals())]), + OrderedDict([('i', bint(2)), ('x', reals(2))]), + OrderedDict([('i', bint(2)), ('x', reals()), ('y', reals())]), + OrderedDict([('i', bint(3)), ('j', bint(4)), ('x', reals(2))]), + OrderedDict([('j', bint(2)), ('i', bint(3)), ('k', bint(2)), ('x', reals(2))]), +], ids=id_from_inputs) +def test_reduce_add(inputs): + int_inputs = OrderedDict((k, d) for k, d in inputs.items() if d.dtype != 'real') + x = random_gaussian(inputs) + random_tensor(int_inputs) + assert isinstance(x, Joint) + actual = x.reduce(ops.add, 'i') + + xs = [x(i=i) for i in range(x.inputs['i'].dtype)] + expected = reduce(ops.add, xs) + assert_close(actual, expected, atol=1e-3, rtol=1e-4) + + +def test_reduce_moment_matching_univariate(): + int_inputs = [('i', bint(2))] + real_inputs = [('x', reals())] + inputs = OrderedDict(int_inputs + real_inputs) + int_inputs = OrderedDict(int_inputs) + real_inputs = OrderedDict(real_inputs) + + p = 0.8 + t = 1.234 + s1, s2, s3 = 2.0, 3.0, 4.0 + loc = torch.tensor([[-s1], [s1]]) + precision = torch.tensor([[[s2 ** -2]], [[s3 ** -2]]]) + discrete = Tensor(torch.tensor([1 - p, p]).log() + t, int_inputs) + gaussian = Gaussian(loc, precision, inputs) + joint = discrete + gaussian + with interpretation(moment_matching): + actual = joint.reduce(ops.logaddexp, 'i') + + expected_loc = torch.tensor([(2 * p - 1) * s1]) + expected_variance = (4 * p * (1 - p) * s1 ** 2 + + (1 - p) * s2 ** 2 + + p * s3 ** 2) + expected_precision = torch.tensor([[1 / expected_variance]]) + expected_gaussian = Gaussian(expected_loc, expected_precision, real_inputs) + expected_discrete = Tensor(torch.tensor(t)) + expected = expected_discrete + expected_gaussian + assert_close(actual, 
expected, atol=1e-5, rtol=None) + + +def test_reduce_moment_matching_multivariate(): + int_inputs = [('i', bint(4))] + real_inputs = [('x', reals(2))] + inputs = OrderedDict(int_inputs + real_inputs) + int_inputs = OrderedDict(int_inputs) + real_inputs = OrderedDict(real_inputs) + + loc = torch.tensor([[-10., -1.], + [+10., -1.], + [+10., +1.], + [-10., +1.]]) + precision = torch.zeros(4, 1, 1) + torch.eye(2, 2) + discrete = Tensor(torch.zeros(4), int_inputs) + gaussian = Gaussian(loc, precision, inputs) + joint = discrete + gaussian + with interpretation(moment_matching): + actual = joint.reduce(ops.logaddexp, 'i') + + expected_loc = torch.zeros(2) + expected_covariance = torch.tensor([[101., 0.], [0., 2.]]) + expected_precision = torch.inverse(expected_covariance) + expected_gaussian = Gaussian(expected_loc, expected_precision, real_inputs) + expected_discrete = Tensor(torch.tensor(4.).log()) + expected = expected_discrete + expected_gaussian + assert_close(actual, expected, atol=1e-5, rtol=None) + + +@pytest.mark.parametrize('interp', [eager, moment_matching], + ids=lambda f: f.__name__) +def test_reduce_moment_matching_shape(interp): + delta = Delta('x', random_tensor(OrderedDict([('h', bint(7))]))) + discrete = random_tensor(OrderedDict( + [('h', bint(7)), ('i', bint(6)), ('j', bint(5)), ('k', bint(4))])) + gaussian = random_gaussian(OrderedDict( + [('k', bint(4)), ('l', bint(3)), ('m', bint(2)), ('y', reals()), ('z', reals(2))])) + reduced_vars = frozenset(['i', 'k', 'l']) + joint = delta + discrete + gaussian + with interpretation(interp): + actual = joint.reduce(ops.logaddexp, reduced_vars) + assert set(actual.inputs) == set(joint.inputs) - reduced_vars diff --git a/test/test_minipyro.py b/test/test_minipyro.py new file mode 100644 index 000000000..e27affd70 --- /dev/null +++ b/test/test_minipyro.py @@ -0,0 +1,531 @@ +from __future__ import absolute_import, division, print_function + +import warnings + +import pytest +import torch + +from torch.autograd import grad +from torch.distributions import constraints, kl_divergence + +from pyro.ops.indexing import Vindex as _Vindex +from pyro.generic import distributions as dist +from pyro.generic import infer, optim, pyro, pyro_backend + +import funsor +from funsor.testing import assert_close, xfail_param + +# This file tests a variety of model,guide pairs with valid and invalid structure. +# See https://github.com/pyro-ppl/pyro/blob/0.3.1/tests/infer/test_valid_models.py + + +def Vindex(x): + if isinstance(x, funsor.Funsor): + return x + return _Vindex(x) + + +def _check_loss_and_grads(expected_loss, actual_loss, atol=1e-4, rtol=1e-4): + # copied from pyro + expected_loss, actual_loss = funsor.to_data(expected_loss), funsor.to_data(actual_loss) + assert_close(actual_loss, expected_loss, atol=atol, rtol=rtol) + names = pyro.get_param_store().keys() + params = [] + for name in names: + params.append(funsor.to_data(pyro.param(name)).unconstrained()) + actual_grads = grad(actual_loss, params, allow_unused=True, retain_graph=True) + expected_grads = grad(expected_loss, params, allow_unused=True, retain_graph=True) + for name, actual_grad, expected_grad in zip(names, actual_grads, expected_grads): + if actual_grad is None or expected_grad is None: + continue + assert_close(actual_grad, expected_grad, atol=atol, rtol=rtol) + + +def assert_ok(model, guide, elbo, *args, **kwargs): + """ + Assert that inference works without warnings or errors. 
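+ Runs two steps of infer.SVI with Adam(lr=1e-6) on the given model/guide pair.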
+ """ + pyro.get_param_store().clear() + adam = optim.Adam({"lr": 1e-6}) + inference = infer.SVI(model, guide, adam, elbo) + for i in range(2): + inference.step(*args, **kwargs) + + +def assert_error(model, guide, elbo, match=None): + """ + Assert that inference fails with an error. + """ + pyro.get_param_store().clear() + adam = optim.Adam({"lr": 1e-6}) + inference = infer.SVI(model, guide, adam, elbo) + with pytest.raises((NotImplementedError, UserWarning, KeyError, ValueError, RuntimeError), + match=match): + inference.step() + + +def assert_warning(model, guide, elbo): + """ + Assert that inference works but with a warning. + """ + pyro.get_param_store().clear() + adam = optim.Adam({"lr": 1e-6}) + inference = infer.SVI(model, guide, adam, elbo) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + inference.step() + assert len(w), 'No warnings were raised' + for warning in w: + print(warning) + + +@pytest.mark.parametrize("backend", ["pyro", "minipyro", "funsor"]) +def test_generate_data(backend): + + def model(data=None): + loc = pyro.param("loc", torch.tensor(2.0)) + scale = pyro.param("scale", torch.tensor(1.0)) + x = pyro.sample("x", dist.Normal(loc, scale), obs=data) + return x + + with pyro_backend(backend): + data = model().data + assert data.shape == () + + +@pytest.mark.parametrize("backend", ["pyro", "minipyro", "funsor"]) +def test_generate_data_plate(backend): + num_points = 1000 + + def model(data=None): + loc = pyro.param("loc", torch.tensor(2.0)) + scale = pyro.param("scale", torch.tensor(1.0)) + with pyro.plate("data", 1000, dim=-1): + x = pyro.sample("x", dist.Normal(loc, scale), obs=data) + return x + + with pyro_backend(backend): + data = model().data + assert data.shape == (num_points,) + mean = data.sum().item() / num_points + assert 1.9 <= mean <= 2.1 + + +@pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"]) +@pytest.mark.parametrize("backend", ["pyro", "minipyro", "funsor"]) +def test_nonempty_model_empty_guide_ok(backend, jit): + + def model(data): + loc = pyro.param("loc", torch.tensor(0.0)) + pyro.sample("x", dist.Normal(loc, 1.), obs=data) + + def guide(data): + pass + + data = torch.tensor(2.) 
+ with pyro_backend(backend): + Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO + elbo = Elbo(ignore_jit_warnings=True) + assert_ok(model, guide, elbo, data) + + +@pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"]) +@pytest.mark.parametrize("backend", ["pyro", "minipyro", "funsor"]) +def test_plate_ok(backend, jit): + data = torch.randn(10) + + def model(): + locs = pyro.param("locs", torch.tensor([0.2, 0.3, 0.5])) + p = torch.tensor([0.2, 0.3, 0.5]) + with pyro.plate("plate", len(data), dim=-1): + x = pyro.sample("x", dist.Categorical(p)) + pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data) + + def guide(): + p = pyro.param("p", torch.tensor([0.5, 0.3, 0.2])) + with pyro.plate("plate", len(data), dim=-1): + pyro.sample("x", dist.Categorical(p)) + + with pyro_backend(backend): + Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO + elbo = Elbo(ignore_jit_warnings=True) + assert_ok(model, guide, elbo) + + +@pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"]) +@pytest.mark.parametrize("backend", ["pyro", "minipyro", "funsor"]) +def test_nested_plate_plate_ok(backend, jit): + data = torch.randn(2, 3) + + def model(): + loc = torch.tensor(3.0) + with pyro.plate("plate_outer", data.size(-1), dim=-1): + x = pyro.sample("x", dist.Normal(loc, 1.)) + with pyro.plate("plate_inner", data.size(-2), dim=-2): + pyro.sample("y", dist.Normal(x, 1.), obs=data) + + def guide(): + loc = pyro.param("loc", torch.tensor(0.)) + scale = pyro.param("scale", torch.tensor(1.)) + with pyro.plate("plate_outer", data.size(-1), dim=-1): + pyro.sample("x", dist.Normal(loc, scale)) + + with pyro_backend(backend): + Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO + elbo = Elbo(ignore_jit_warnings=True) + assert_ok(model, guide, elbo) + + +@pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"]) +@pytest.mark.parametrize("backend", ["pyro", "funsor"]) +def test_local_param_ok(backend, jit): + data = torch.randn(10) + + def model(): + locs = pyro.param("locs", torch.tensor([-1., 0., 1.])) + with pyro.plate("plate", len(data), dim=-1): + x = pyro.sample("x", dist.Categorical(torch.ones(3) / 3)) + pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data) + + def guide(): + with pyro.plate("plate", len(data), dim=-1): + p = pyro.param("p", torch.ones(len(data), 3) / 3, event_dim=1) + pyro.sample("x", dist.Categorical(p)) + return p + + with pyro_backend(backend): + Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO + elbo = Elbo(ignore_jit_warnings=True) + assert_ok(model, guide, elbo) + + # Check that pyro.param() can be called without init_value. 
+ expected = guide() + actual = pyro.param("p") + assert_close(actual, expected) + + +@pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"]) +@pytest.mark.parametrize("backend", ["pyro", "minipyro", "funsor"]) +def test_constraints(backend, jit): + data = torch.tensor(0.5) + + def model(): + locs = pyro.param("locs", torch.randn(3), constraint=constraints.real) + scales = pyro.param("scales", torch.randn(3).exp(), constraint=constraints.positive) + p = torch.tensor([0.5, 0.3, 0.2]) + x = pyro.sample("x", dist.Categorical(p)) + pyro.sample("obs", dist.Normal(locs[x], scales[x]), obs=data) + + def guide(): + q = pyro.param("q", torch.randn(3).exp(), constraint=constraints.simplex) + pyro.sample("x", dist.Categorical(q)) + + with pyro_backend(backend): + Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO + elbo = Elbo(ignore_jit_warnings=True) + assert_ok(model, guide, elbo) + + +@pytest.mark.parametrize("backend", [ + "pyro", + xfail_param("funsor", reason="missing patterns"), +]) +def test_mean_field_ok(backend): + + def model(): + x = pyro.sample("x", dist.Normal(0., 1.)) + pyro.sample("y", dist.Normal(x, 1.)) + + def guide(): + loc = pyro.param("loc", torch.tensor(0.)) + x = pyro.sample("x", dist.Normal(loc, 1.)) + pyro.sample("y", dist.Normal(x, 1.)) + + with pyro_backend(backend): + elbo = infer.TraceMeanField_ELBO() + assert_ok(model, guide, elbo) + + +@pytest.mark.parametrize("backend", [ + "pyro", + xfail_param("funsor", reason="missing patterns"), +]) +def test_mean_field_warn(backend): + + def model(): + x = pyro.sample("x", dist.Normal(0., 1.)) + pyro.sample("y", dist.Normal(x, 1.)) + + def guide(): + loc = pyro.param("loc", torch.tensor(0.)) + y = pyro.sample("y", dist.Normal(loc, 1.)) + pyro.sample("x", dist.Normal(y, 1.)) + + with pyro_backend(backend): + elbo = infer.TraceMeanField_ELBO() + assert_warning(model, guide, elbo) + + +@pytest.mark.parametrize("backend", ["pyro", "funsor"]) +@pytest.mark.parametrize("inner_dim", [2]) +@pytest.mark.parametrize("outer_dim", [2]) +def test_elbo_plate_plate(backend, outer_dim, inner_dim): + with pyro_backend(backend): + pyro.get_param_store().clear() + num_particles = 1 + q = pyro.param("q", torch.tensor([0.75, 0.25], requires_grad=True)) + p = 0.2693204236205713 # for which kl(Categorical(q), Categorical(p)) = 0.5 + p = torch.tensor([p, 1-p]) + + def model(): + d = dist.Categorical(p) + context1 = pyro.plate("outer", outer_dim, dim=-1) + context2 = pyro.plate("inner", inner_dim, dim=-2) + pyro.sample("w", d) + with context1: + pyro.sample("x", d) + with context2: + pyro.sample("y", d) + with context1, context2: + pyro.sample("z", d) + + def guide(): + d = dist.Categorical(pyro.param("q")) + context1 = pyro.plate("outer", outer_dim, dim=-1) + context2 = pyro.plate("inner", inner_dim, dim=-2) + pyro.sample("w", d, infer={"enumerate": "parallel"}) + with context1: + pyro.sample("x", d, infer={"enumerate": "parallel"}) + with context2: + pyro.sample("y", d, infer={"enumerate": "parallel"}) + with context1, context2: + pyro.sample("z", d, infer={"enumerate": "parallel"}) + + kl_node = kl_divergence(torch.distributions.Categorical(funsor.to_data(q)), + torch.distributions.Categorical(funsor.to_data(p))) + kl = (1 + outer_dim + inner_dim + outer_dim * inner_dim) * kl_node + expected_loss = kl + expected_grad = grad(kl, [funsor.to_data(q)])[0] + + elbo = infer.TraceEnum_ELBO(num_particles=num_particles, + vectorize_particles=True, + strict_enumeration_warning=True) + elbo = elbo.differentiable_loss if backend == "pyro" else 
elbo + actual_loss = funsor.to_data(elbo(model, guide)) + actual_loss.backward() + actual_grad = funsor.to_data(pyro.param('q')).grad + + assert_close(actual_loss, expected_loss, atol=1e-5) + assert_close(actual_grad, expected_grad, atol=1e-5) + + +@pytest.mark.parametrize('backend', ["pyro", "funsor"]) +def test_elbo_enumerate_plates_1(backend): + # +-----------------+ + # | a ----> b M=2 | + # +-----------------+ + # +-----------------+ + # | c ----> d N=3 | + # +-----------------+ + # This tests two unrelated plates. + # Each should remain uncontracted. + with pyro_backend(backend): + pyro.param("probs_a", + torch.tensor([0.45, 0.55]), + constraint=constraints.simplex) + pyro.param("probs_b", + torch.tensor([[0.6, 0.4], [0.4, 0.6]]), + constraint=constraints.simplex) + pyro.param("probs_c", + torch.tensor([0.75, 0.25]), + constraint=constraints.simplex) + pyro.param("probs_d", + torch.tensor([[0.4, 0.6], [0.3, 0.7]]), + constraint=constraints.simplex) + b_data = torch.tensor([0, 1]) + d_data = torch.tensor([0, 0, 1]) + + def auto_model(): + probs_a = pyro.param("probs_a") + probs_b = pyro.param("probs_b") + probs_c = pyro.param("probs_c") + probs_d = pyro.param("probs_d") + with pyro.plate("a_axis", 2, dim=-1): + a = pyro.sample("a", dist.Categorical(probs_a), + infer={"enumerate": "parallel"}) + pyro.sample("b", dist.Categorical(probs_b[a]), obs=b_data) + with pyro.plate("c_axis", 3, dim=-1): + c = pyro.sample("c", dist.Categorical(probs_c), + infer={"enumerate": "parallel"}) + pyro.sample("d", dist.Categorical(probs_d[c]), obs=d_data) + + def hand_model(): + probs_a = pyro.param("probs_a") + probs_b = pyro.param("probs_b") + probs_c = pyro.param("probs_c") + probs_d = pyro.param("probs_d") + for i in range(2): + a = pyro.sample("a_{}".format(i), dist.Categorical(probs_a), + infer={"enumerate": "parallel"}) + pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a]), obs=b_data[i]) + for j in range(3): + c = pyro.sample("c_{}".format(j), dist.Categorical(probs_c), + infer={"enumerate": "parallel"}) + pyro.sample("d_{}".format(j), dist.Categorical(probs_d[c]), obs=d_data[j]) + + def guide(): + pass + + elbo = infer.TraceEnum_ELBO(max_plate_nesting=1) + elbo = elbo.differentiable_loss if backend == "pyro" else elbo + auto_loss = elbo(auto_model, guide) + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + elbo = elbo.differentiable_loss if backend == "pyro" else elbo + hand_loss = elbo(hand_model, guide) + _check_loss_and_grads(hand_loss, auto_loss) + + +@pytest.mark.parametrize('backend', ["pyro", "funsor"]) +def test_elbo_enumerate_plate_7(backend): + # Guide Model + # a -----> b + # | | + # +-|--------|----------------+ + # | V V | + # | c -----> d -----> e N=2 | + # +---------------------------+ + # This tests a mixture of model and guide enumeration. 
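+ # Here "a" is enumerated in the guide while "b" and "d" are enumerated in the
+ # model; the plated (auto_*) and unrolled (hand_*) versions below should agree
+ # in loss and gradients, which _check_loss_and_grads verifies.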
+ with pyro_backend(backend): + pyro.param("model_probs_a", + torch.tensor([0.45, 0.55]), + constraint=constraints.simplex) + pyro.param("model_probs_b", + torch.tensor([[0.6, 0.4], [0.4, 0.6]]), + constraint=constraints.simplex) + pyro.param("model_probs_c", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex) + pyro.param("model_probs_d", + torch.tensor([[[0.4, 0.6], [0.3, 0.7]], [[0.3, 0.7], [0.2, 0.8]]]), + constraint=constraints.simplex) + pyro.param("model_probs_e", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex) + pyro.param("guide_probs_a", + torch.tensor([0.35, 0.64]), + constraint=constraints.simplex) + pyro.param("guide_probs_c", + torch.tensor([[0., 1.], [1., 0.]]), # deterministic + constraint=constraints.simplex) + + def auto_model(data): + probs_a = pyro.param("model_probs_a") + probs_b = pyro.param("model_probs_b") + probs_c = pyro.param("model_probs_c") + probs_d = pyro.param("model_probs_d") + probs_e = pyro.param("model_probs_e") + a = pyro.sample("a", dist.Categorical(probs_a)) + b = pyro.sample("b", dist.Categorical(probs_b[a]), + infer={"enumerate": "parallel"}) + with pyro.plate("data", 2, dim=-1): + c = pyro.sample("c", dist.Categorical(probs_c[a])) + d = pyro.sample("d", dist.Categorical(Vindex(probs_d)[b, c]), + infer={"enumerate": "parallel"}) + pyro.sample("obs", dist.Categorical(probs_e[d]), obs=data) + + def auto_guide(data): + probs_a = pyro.param("guide_probs_a") + probs_c = pyro.param("guide_probs_c") + a = pyro.sample("a", dist.Categorical(probs_a), + infer={"enumerate": "parallel"}) + with pyro.plate("data", 2, dim=-1): + pyro.sample("c", dist.Categorical(probs_c[a])) + + def hand_model(data): + probs_a = pyro.param("model_probs_a") + probs_b = pyro.param("model_probs_b") + probs_c = pyro.param("model_probs_c") + probs_d = pyro.param("model_probs_d") + probs_e = pyro.param("model_probs_e") + a = pyro.sample("a", dist.Categorical(probs_a)) + b = pyro.sample("b", dist.Categorical(probs_b[a]), + infer={"enumerate": "parallel"}) + for i in range(2): + c = pyro.sample("c_{}".format(i), dist.Categorical(probs_c[a])) + d = pyro.sample("d_{}".format(i), + dist.Categorical(Vindex(probs_d)[b, c]), + infer={"enumerate": "parallel"}) + pyro.sample("obs_{}".format(i), dist.Categorical(probs_e[d]), obs=data[i]) + + def hand_guide(data): + probs_a = pyro.param("guide_probs_a") + probs_c = pyro.param("guide_probs_c") + a = pyro.sample("a", dist.Categorical(probs_a), + infer={"enumerate": "parallel"}) + for i in range(2): + pyro.sample("c_{}".format(i), dist.Categorical(probs_c[a])) + + data = torch.tensor([0, 0]) + elbo = infer.TraceEnum_ELBO(max_plate_nesting=1) + elbo = elbo.differentiable_loss if backend == "pyro" else elbo + auto_loss = elbo(auto_model, auto_guide, data) + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + elbo = elbo.differentiable_loss if backend == "pyro" else elbo + hand_loss = elbo(hand_model, hand_guide, data) + _check_loss_and_grads(hand_loss, auto_loss) + + +@pytest.mark.xfail(reason="missing patterns") +@pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"]) +@pytest.mark.parametrize("exact", [ + True, + xfail_param(False, reason="mixed sampling and exact not implemented yet") +], ids=["exact", "monte-carlo"]) +def test_gaussian_probit_hmm_smoke(exact, jit): + + def model(data): + T, N, D = data.shape # time steps, individuals, features + + # Gaussian initial distribution. 
+ init_loc = pyro.param("init_loc", torch.zeros(D)) + init_scale = pyro.param("init_scale", 1e-2 * torch.eye(D), + constraint=constraints.lower_cholesky) + + # Linear dynamics with Gaussian noise. + trans_const = pyro.param("trans_const", torch.zeros(D)) + trans_coeff = pyro.param("trans_coeff", torch.eye(D)) + noise = pyro.param("noise", 1e-2 * torch.eye(D), + constraint=constraints.lower_cholesky) + + obs_plate = pyro.plate("channel", D, dim=-1) + with pyro.plate("data", N, dim=-2): + state = None + for t in range(T): + # Transition. + if t == 0: + loc = init_loc + scale_tril = init_scale + else: + loc = trans_const + funsor.torch.torch_tensordot(trans_coeff, state, 1) + scale_tril = noise + state = pyro.sample("state_{}".format(t), + dist.MultivariateNormal(loc, scale_tril), + infer={"exact": exact}) + + # Factorial probit likelihood model. + with obs_plate: + pyro.sample("obs_{}".format(t), + dist.Bernoulli(logits=state["channel"]), + obs=data[t]) + + def guide(data): + pass + + data = torch.distributions.Bernoulli(0.5).sample((3, 4, 2)) + + with pyro_backend("funsor"): + Elbo = infer.JitTraceEnum_ELBO if jit else infer.TraceEnum_ELBO + elbo = Elbo() + adam = optim.Adam({"lr": 1e-3}) + svi = infer.SVI(model, guide, adam, elbo) + svi.step(data) diff --git a/test/test_numpy.py b/test/test_numpy.py new file mode 100644 index 000000000..4c777ac6c --- /dev/null +++ b/test/test_numpy.py @@ -0,0 +1,220 @@ +from collections import OrderedDict + +import numpy as np +import pytest + +import funsor +from funsor import Number, Variable, bint, reals +from funsor.interpreter import _USE_TCO +from funsor.numpy import Array +from funsor.testing import assert_equiv, check_funsor, random_array + + +# FIXME rewrite stack-based interpreter to be compatible with unhashable data +xfail_with_tco = pytest.mark.xfail( + _USE_TCO, + reason="fails w/ TCO because numpy arrays can't be hashed" +) + + +@pytest.mark.parametrize('shape', [(), (4,), (3, 2)]) +@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64, np.uint8]) +def test_to_funsor(shape, dtype): + t = np.random.normal(size=shape).astype(dtype) + f = funsor.to_funsor(t) + assert isinstance(f, Array) + assert funsor.to_funsor(t, reals(*shape)) is f + with pytest.raises(ValueError): + funsor.to_funsor(t, reals(5, *shape)) + + +def test_to_data(): + data = np.zeros((3, 3)) + x = Array(data) + assert funsor.to_data(x) is data + + +def test_to_data_error(): + data = np.zeros((3, 3)) + x = Array(data, OrderedDict(i=bint(3))) + with pytest.raises(ValueError): + funsor.to_data(x) + + +def test_cons_hash(): + x = np.random.randn(3, 3) + assert Array(x) is Array(x) + + +@xfail_with_tco +def test_indexing(): + data = np.random.normal(size=(4, 5)) + inputs = OrderedDict([('i', bint(4)), + ('j', bint(5))]) + x = Array(data, inputs) + check_funsor(x, inputs, reals(), data) + + assert x() is x + assert x(k=3) is x + check_funsor(x(1), {'j': bint(5)}, reals(), data[1]) + check_funsor(x(1, 2), {}, reals(), data[1, 2]) + check_funsor(x(1, 2, k=3), {}, reals(), data[1, 2]) + check_funsor(x(1, j=2), {}, reals(), data[1, 2]) + check_funsor(x(1, j=2, k=3), (), reals(), data[1, 2]) + check_funsor(x(1, k=3), {'j': bint(5)}, reals(), data[1]) + check_funsor(x(i=1), {'j': bint(5)}, reals(), data[1]) + check_funsor(x(i=1, j=2), (), reals(), data[1, 2]) + check_funsor(x(i=1, j=2, k=3), (), reals(), data[1, 2]) + check_funsor(x(i=1, k=3), {'j': bint(5)}, reals(), data[1]) + check_funsor(x(j=2), {'i': bint(4)}, reals(), data[:, 2]) + check_funsor(x(j=2, k=3), 
{'i': bint(4)}, reals(), data[:, 2]) + + +@xfail_with_tco +def test_advanced_indexing_shape(): + I, J, M, N = 4, 4, 2, 3 + x = Array(np.random.normal(size=(I, J)), OrderedDict([ + ('i', bint(I)), + ('j', bint(J)), + ])) + m = Array(np.array([2, 3]), OrderedDict([('m', bint(M))]), I) + n = Array(np.array([0, 1, 1]), OrderedDict([('n', bint(N))]), J) + assert x.data.shape == (I, J) + + check_funsor(x(i=m), {'j': bint(J), 'm': bint(M)}, reals()) + check_funsor(x(i=m, j=n), {'m': bint(M), 'n': bint(N)}, reals()) + check_funsor(x(i=m, j=n, k=m), {'m': bint(M), 'n': bint(N)}, reals()) + check_funsor(x(i=m, k=m), {'j': bint(J), 'm': bint(M)}, reals()) + check_funsor(x(i=n), {'j': bint(J), 'n': bint(N)}, reals()) + check_funsor(x(i=n, k=m), {'j': bint(J), 'n': bint(N)}, reals()) + check_funsor(x(j=m), {'i': bint(I), 'm': bint(M)}, reals()) + check_funsor(x(j=m, i=n), {'m': bint(M), 'n': bint(N)}, reals()) + check_funsor(x(j=m, i=n, k=m), {'m': bint(M), 'n': bint(N)}, reals()) + check_funsor(x(j=m, k=m), {'i': bint(I), 'm': bint(M)}, reals()) + check_funsor(x(j=n), {'i': bint(I), 'n': bint(N)}, reals()) + check_funsor(x(j=n, k=m), {'i': bint(I), 'n': bint(N)}, reals()) + check_funsor(x(m), {'j': bint(J), 'm': bint(M)}, reals()) + check_funsor(x(m, j=n), {'m': bint(M), 'n': bint(N)}, reals()) + check_funsor(x(m, j=n, k=m), {'m': bint(M), 'n': bint(N)}, reals()) + check_funsor(x(m, k=m), {'j': bint(J), 'm': bint(M)}, reals()) + check_funsor(x(m, n), {'m': bint(M), 'n': bint(N)}, reals()) + check_funsor(x(m, n, k=m), {'m': bint(M), 'n': bint(N)}, reals()) + check_funsor(x(n), {'j': bint(J), 'n': bint(N)}, reals()) + check_funsor(x(n, k=m), {'j': bint(J), 'n': bint(N)}, reals()) + check_funsor(x(n, m), {'m': bint(M), 'n': bint(N)}, reals()) + check_funsor(x(n, m, k=m), {'m': bint(M), 'n': bint(N)}, reals()) + + +@xfail_with_tco +@pytest.mark.parametrize('output_shape', [(), (7,), (3, 2)]) +def test_advanced_indexing_array(output_shape): + # u v + # / \ / \ + # i j k + # \ | / + # \ | / + # x + output = reals(*output_shape) + x = random_array(OrderedDict([ + ('i', bint(2)), + ('j', bint(3)), + ('k', bint(4)), + ]), output) + i = random_array(OrderedDict([ + ('u', bint(5)), + ]), bint(2)) + j = random_array(OrderedDict([ + ('v', bint(6)), + ('u', bint(5)), + ]), bint(3)) + k = random_array(OrderedDict([ + ('v', bint(6)), + ]), bint(4)) + + expected_data = np.empty((5, 6) + output_shape) + for u in range(5): + for v in range(6): + expected_data[u, v] = x.data[i.data[u], j.data[v, u], k.data[v]] + expected = Array(expected_data, OrderedDict([ + ('u', bint(5)), + ('v', bint(6)), + ])) + + assert_equiv(expected, x(i, j, k)) + assert_equiv(expected, x(i=i, j=j, k=k)) + + assert_equiv(expected, x(i=i, j=j)(k=k)) + assert_equiv(expected, x(j=j, k=k)(i=i)) + assert_equiv(expected, x(k=k, i=i)(j=j)) + + assert_equiv(expected, x(i=i)(j=j, k=k)) + assert_equiv(expected, x(j=j)(k=k, i=i)) + assert_equiv(expected, x(k=k)(i=i, j=j)) + + assert_equiv(expected, x(i=i)(j=j)(k=k)) + assert_equiv(expected, x(i=i)(k=k)(j=j)) + assert_equiv(expected, x(j=j)(i=i)(k=k)) + assert_equiv(expected, x(j=j)(k=k)(i=i)) + assert_equiv(expected, x(k=k)(i=i)(j=j)) + assert_equiv(expected, x(k=k)(j=j)(i=i)) + + +@xfail_with_tco +@pytest.mark.parametrize('output_shape', [(), (7,), (3, 2)]) +def test_advanced_indexing_lazy(output_shape): + x = Array(np.random.normal(size=(2, 3, 4) + output_shape), OrderedDict([ + ('i', bint(2)), + ('j', bint(3)), + ('k', bint(4)), + ])) + u = Variable('u', bint(2)) + v = Variable('v', bint(3)) + i = 
Number(1, 2) - u + j = Number(2, 3) - v + k = u + v + + expected_data = np.empty((2, 3) + output_shape) + i_data = funsor.numpy.materialize(i).data.astype(np.int64) + j_data = funsor.numpy.materialize(j).data.astype(np.int64) + k_data = funsor.numpy.materialize(k).data.astype(np.int64) + for u in range(2): + for v in range(3): + expected_data[u, v] = x.data[i_data[u], j_data[v], k_data[u, v]] + expected = Array(expected_data, OrderedDict([ + ('u', bint(2)), + ('v', bint(3)), + ])) + + assert_equiv(expected, x(i, j, k)) + assert_equiv(expected, x(i=i, j=j, k=k)) + + assert_equiv(expected, x(i=i, j=j)(k=k)) + assert_equiv(expected, x(j=j, k=k)(i=i)) + assert_equiv(expected, x(k=k, i=i)(j=j)) + + assert_equiv(expected, x(i=i)(j=j, k=k)) + assert_equiv(expected, x(j=j)(k=k, i=i)) + assert_equiv(expected, x(k=k)(i=i, j=j)) + + assert_equiv(expected, x(i=i)(j=j)(k=k)) + assert_equiv(expected, x(i=i)(k=k)(j=j)) + assert_equiv(expected, x(j=j)(i=i)(k=k)) + assert_equiv(expected, x(j=j)(k=k)(i=i)) + assert_equiv(expected, x(k=k)(i=i)(j=j)) + assert_equiv(expected, x(k=k)(j=j)(i=i)) + + +@xfail_with_tco +def test_align(): + x = Array(np.random.randn(2, 3, 4), OrderedDict([ + ('i', bint(2)), + ('j', bint(3)), + ('k', bint(4)), + ])) + y = x.align(('j', 'k', 'i')) + assert isinstance(y, Array) + assert tuple(y.inputs) == ('j', 'k', 'i') + for i in range(2): + for j in range(3): + for k in range(4): + assert x(i=i, j=j, k=k) == y(i=i, j=j, k=k) diff --git a/test/test_optimizer.py b/test/test_optimizer.py new file mode 100644 index 000000000..4b4ed247e --- /dev/null +++ b/test/test_optimizer.py @@ -0,0 +1,120 @@ +from __future__ import absolute_import, division, print_function + +from collections import OrderedDict + +import opt_einsum +import pyro.ops.contract as pyro_einsum +import pytest +import torch + +import funsor +from funsor.distributions import Categorical +from funsor.domains import bint +from funsor.einsum import einsum, naive_contract_einsum, naive_einsum, naive_plated_einsum +from funsor.interpreter import interpretation, reinterpret +from funsor.optimizer import apply_optimizer +from funsor.terms import Variable, lazy +from funsor.testing import assert_close, make_chain_einsum, make_einsum_example, make_hmm_einsum, make_plated_hmm_einsum +from funsor.torch import Tensor + +OPTIMIZED_EINSUM_EXAMPLES = [ + make_chain_einsum(t) for t in range(2, 50, 10) +] + [ + make_hmm_einsum(t) for t in range(2, 50, 10) +] + + +@pytest.mark.parametrize('equation', OPTIMIZED_EINSUM_EXAMPLES) +@pytest.mark.parametrize('backend', ['pyro.ops.einsum.torch_log']) +@pytest.mark.parametrize("einsum_impl", [ + naive_einsum, + # naive_contract_einsum, # XXX not working, probably issue with canonicalization +]) +def test_optimized_einsum(equation, backend, einsum_impl): + inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation) + expected = opt_einsum.contract(equation, *operands, backend=backend) + with interpretation(lazy): + naive_ast = einsum_impl(equation, *funsor_operands, backend=backend) + optimized_ast = apply_optimizer(naive_ast) + actual = reinterpret(optimized_ast) # eager by default + + assert isinstance(actual, funsor.Tensor) and len(outputs) == 1 + if len(outputs[0]) > 0: + actual = actual.align(tuple(outputs[0])) + + assert expected.shape == actual.data.shape + assert torch.allclose(expected, actual.data) + for output in outputs: + for i, output_dim in enumerate(output): + assert output_dim in actual.inputs + assert actual.inputs[output_dim].dtype == sizes[output_dim] + + 
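+# The pattern exercised above, in brief: build the einsum expression and optimize
+# it under the lazy interpretation, then reinterpret the optimized AST eagerly
+# (naive_einsum stands in for the parametrized einsum_impl):
+#
+#     with interpretation(lazy):
+#         naive_ast = naive_einsum(equation, *funsor_operands, backend=backend)
+#         optimized_ast = apply_optimizer(naive_ast)
+#     actual = reinterpret(optimized_ast)  # eager by default
+
+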
+@pytest.mark.parametrize("eqn1,eqn2", [ + ("ab,bc,cd->d", "de,ef,fg->"), +]) +@pytest.mark.parametrize("optimize1", [False, True]) +@pytest.mark.parametrize("optimize2", [False, True]) +@pytest.mark.parametrize("backend1", ['torch', 'pyro.ops.einsum.torch_log']) +@pytest.mark.parametrize("backend2", ['torch', 'pyro.ops.einsum.torch_log']) +@pytest.mark.parametrize("einsum_impl", [naive_einsum, naive_contract_einsum]) +def test_nested_einsum(eqn1, eqn2, optimize1, optimize2, backend1, backend2, einsum_impl): + inputs1, outputs1, sizes1, operands1, _ = make_einsum_example(eqn1, sizes=(3,)) + inputs2, outputs2, sizes2, operands2, funsor_operands2 = make_einsum_example(eqn2, sizes=(3,)) + + # normalize the probs for ground-truth comparison + operands1 = [operand.abs() / operand.abs().sum(-1, keepdim=True) + for operand in operands1] + + expected1 = opt_einsum.contract(eqn1, *operands1, backend=backend1) + expected2 = opt_einsum.contract(outputs1[0] + "," + eqn2, *([expected1] + operands2), backend=backend2) + + with interpretation(lazy): + funsor_operands1 = [ + Categorical(probs=Tensor( + operand, + inputs=OrderedDict([(d, bint(sizes1[d])) for d in inp[:-1]]) + ))(value=Variable(inp[-1], bint(sizes1[inp[-1]]))).exp() + for inp, operand in zip(inputs1, operands1) + ] + + output1_naive = einsum_impl(eqn1, *funsor_operands1, backend=backend1) + output1 = apply_optimizer(output1_naive) if optimize1 else output1_naive + output2_naive = einsum_impl(outputs1[0] + "," + eqn2, *([output1] + funsor_operands2), backend=backend2) + output2 = apply_optimizer(output2_naive) if optimize2 else output2_naive + + actual1 = reinterpret(output1) + actual2 = reinterpret(output2) + + assert torch.allclose(expected1, actual1.data) + assert torch.allclose(expected2, actual2.data) + + +PLATED_EINSUM_EXAMPLES = [ + make_plated_hmm_einsum(num_steps, num_obs_plates=b, num_hidden_plates=a) + for num_steps in range(3, 50, 6) + for (a, b) in [(0, 1), (0, 2), (0, 0), (1, 1), (1, 2)] +] + + +@pytest.mark.parametrize('equation,plates', PLATED_EINSUM_EXAMPLES) +@pytest.mark.parametrize('backend', ['pyro.ops.einsum.torch_log']) +def test_optimized_plated_einsum(equation, plates, backend): + inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation) + expected = pyro_einsum.einsum(equation, *operands, plates=plates, backend=backend)[0] + actual = einsum(equation, *funsor_operands, plates=plates, backend=backend) + + if len(equation) < 10: + actual_naive = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend) + assert_close(actual, actual_naive) + + assert isinstance(actual, funsor.Tensor) and len(outputs) == 1 + if len(outputs[0]) > 0: + actual = actual.align(tuple(outputs[0])) + + assert expected.shape == actual.data.shape + assert torch.allclose(expected, actual.data) + for output in outputs: + for i, output_dim in enumerate(output): + assert output_dim in actual.inputs + assert actual.inputs[output_dim].dtype == sizes[output_dim] diff --git a/test/test_pattern.py b/test/test_pattern.py new file mode 100644 index 000000000..1ee5503dc --- /dev/null +++ b/test/test_pattern.py @@ -0,0 +1,42 @@ +from __future__ import absolute_import, division, print_function + +from unification import unify + +from funsor.domains import reals +from funsor.interpreter import interpretation, reinterpret +from funsor.pattern import match, match_vars, unify_interpreter +from funsor.terms import Number, Variable, lazy + + +def test_unify_binary(): + with interpretation(lazy): + pattern = Variable('a', 
reals()) + Number(2.) * Variable('b', reals()) + expr = Number(1.) + Number(2.) * (Number(3.) - Number(4.)) + + subs = unify(pattern, expr) + print(subs, pattern(**{k.name: v for k, v in subs.items()})) + assert subs is not False + + with interpretation(unify_interpreter): + assert unify((pattern,), (expr,)) is not False + + +def test_match_binary(): + with interpretation(lazy): + pattern = Variable('a', reals()) + Number(2.) * Variable('b', reals()) + expr = Number(1.) + Number(2.) * (Number(3.) - Number(4.)) + + @match_vars(pattern) + def expand_2_vars(a, b): + return a + b + b + + @match(pattern) + def expand_2_walk(x): + return x.lhs + x.rhs.rhs + x.rhs.rhs + + eager_val = reinterpret(expr) + lazy_val = expand_2_vars(expr) + assert eager_val == reinterpret(lazy_val) + + lazy_val_2 = expand_2_walk(expr) + assert eager_val == reinterpret(lazy_val_2) diff --git a/test/test_samplers.py b/test/test_samplers.py new file mode 100644 index 000000000..ad90501a3 --- /dev/null +++ b/test/test_samplers.py @@ -0,0 +1,276 @@ +from __future__ import absolute_import, division, print_function + +import itertools +from collections import OrderedDict + +import pytest +import torch +from torch.autograd import grad + +import funsor.distributions as dist +import funsor.ops as ops +from funsor.delta import Delta +from funsor.domains import bint, reals +from funsor.integrate import Integrate +from funsor.joint import Joint +from funsor.montecarlo import monte_carlo_interpretation +from funsor.terms import Variable +from funsor.testing import assert_close, id_from_inputs, random_gaussian, random_tensor, xfail_if_not_implemented +from funsor.torch import align_tensors, materialize + + +@pytest.mark.parametrize('sample_inputs', [ + (), + (('s', bint(6)),), + (('s', bint(6)), ('t', bint(7))), +], ids=id_from_inputs) +@pytest.mark.parametrize('batch_inputs', [ + (), + (('b', bint(4)),), + (('b', bint(4)), ('c', bint(5))), +], ids=id_from_inputs) +@pytest.mark.parametrize('event_inputs', [ + (('e', bint(2)),), + (('e', bint(2)), ('f', bint(3))), +], ids=id_from_inputs) +def test_tensor_shape(sample_inputs, batch_inputs, event_inputs): + be_inputs = OrderedDict(batch_inputs + event_inputs) + expected_inputs = OrderedDict(sample_inputs + batch_inputs + event_inputs) + sample_inputs = OrderedDict(sample_inputs) + batch_inputs = OrderedDict(batch_inputs) + event_inputs = OrderedDict(event_inputs) + x = random_tensor(be_inputs) + + for num_sampled in range(len(event_inputs) + 1): + for sampled_vars in itertools.combinations(list(event_inputs), num_sampled): + sampled_vars = frozenset(sampled_vars) + print('sampled_vars: {}'.format(', '.join(sampled_vars))) + y = x.sample(sampled_vars, sample_inputs) + if num_sampled == len(event_inputs): + assert isinstance(y, (Delta, Joint)) + if sampled_vars: + assert dict(y.inputs) == dict(expected_inputs), sampled_vars + else: + assert y is x + + +@pytest.mark.parametrize('sample_inputs', [ + (), + (('s', bint(3)),), + (('s', bint(3)), ('t', bint(4))), +], ids=id_from_inputs) +@pytest.mark.parametrize('batch_inputs', [ + (), + (('b', bint(2)),), + (('c', reals()),), + (('b', bint(2)), ('c', reals())), +], ids=id_from_inputs) +@pytest.mark.parametrize('event_inputs', [ + (('e', reals()),), + (('e', reals()), ('f', reals(2))), +], ids=id_from_inputs) +def test_gaussian_shape(sample_inputs, batch_inputs, event_inputs): + be_inputs = OrderedDict(batch_inputs + event_inputs) + expected_inputs = OrderedDict(sample_inputs + batch_inputs + event_inputs) + sample_inputs = 
OrderedDict(sample_inputs) + batch_inputs = OrderedDict(batch_inputs) + event_inputs = OrderedDict(event_inputs) + x = random_gaussian(be_inputs) + + xfail = False + for num_sampled in range(len(event_inputs) + 1): + for sampled_vars in itertools.combinations(list(event_inputs), num_sampled): + sampled_vars = frozenset(sampled_vars) + print('sampled_vars: {}'.format(', '.join(sampled_vars))) + try: + y = x.sample(sampled_vars, sample_inputs) + except NotImplementedError: + xfail = True + continue + if num_sampled == len(event_inputs): + assert isinstance(y, (Delta, Joint)) + if sampled_vars: + assert dict(y.inputs) == dict(expected_inputs), sampled_vars + else: + assert y is x + if xfail: + pytest.xfail(reason='Not implemented') + + +@pytest.mark.parametrize('sample_inputs', [ + (), + (('s', bint(3)),), + (('s', bint(3)), ('t', bint(4))), +], ids=id_from_inputs) +@pytest.mark.parametrize('batch_inputs', [ + (), + (('b', bint(2)),), + (('c', reals()),), + (('b', bint(2)), ('c', reals())), +], ids=id_from_inputs) +@pytest.mark.parametrize('event_inputs', [ + (('e', reals()),), + (('e', reals()), ('f', reals(2))), +], ids=id_from_inputs) +def test_transformed_gaussian_shape(sample_inputs, batch_inputs, event_inputs): + be_inputs = OrderedDict(batch_inputs + event_inputs) + expected_inputs = OrderedDict(sample_inputs + batch_inputs + event_inputs) + sample_inputs = OrderedDict(sample_inputs) + batch_inputs = OrderedDict(batch_inputs) + event_inputs = OrderedDict(event_inputs) + + x = random_gaussian(be_inputs) + x = x(**{name: name + '_' for name, domain in event_inputs.items()}) + x = x(**{name + '_': Variable(name, domain).log() + for name, domain in event_inputs.items()}) + + xfail = False + for num_sampled in range(len(event_inputs) + 1): + for sampled_vars in itertools.combinations(list(event_inputs), num_sampled): + sampled_vars = frozenset(sampled_vars) + print('sampled_vars: {}'.format(', '.join(sampled_vars))) + try: + y = x.sample(sampled_vars, sample_inputs) + except NotImplementedError: + xfail = True + continue + if num_sampled == len(event_inputs): + assert isinstance(y, (Delta, Joint)) + if sampled_vars: + assert dict(y.inputs) == dict(expected_inputs), sampled_vars + else: + assert y is x + if xfail: + pytest.xfail(reason='Not implemented') + + +@pytest.mark.parametrize('sample_inputs', [ + (), + (('s', bint(6)),), + (('s', bint(6)), ('t', bint(7))), +], ids=id_from_inputs) +@pytest.mark.parametrize('int_event_inputs', [ + (), + (('d', bint(2)),), + (('d', bint(2)), ('e', bint(3))), +], ids=id_from_inputs) +@pytest.mark.parametrize('real_event_inputs', [ + (('g', reals()),), + (('g', reals()), ('h', reals(4))), +], ids=id_from_inputs) +def test_joint_shape(sample_inputs, int_event_inputs, real_event_inputs): + event_inputs = int_event_inputs + real_event_inputs + discrete_inputs = OrderedDict(int_event_inputs) + gaussian_inputs = OrderedDict(event_inputs) + expected_inputs = OrderedDict(sample_inputs + event_inputs) + sample_inputs = OrderedDict(sample_inputs) + event_inputs = OrderedDict(event_inputs) + t = random_tensor(discrete_inputs) + g = random_gaussian(gaussian_inputs) + x = Joint(discrete=t, gaussian=g) + + xfail = False + for num_sampled in range(len(event_inputs)): + for sampled_vars in itertools.combinations(list(event_inputs), num_sampled): + sampled_vars = frozenset(sampled_vars) + print('sampled_vars: {}'.format(', '.join(sampled_vars))) + try: + y = x.sample(sampled_vars, sample_inputs) + except NotImplementedError: + xfail = True + continue + if sampled_vars: + 
assert dict(y.inputs) == dict(expected_inputs), sampled_vars + else: + assert y is x + if xfail: + pytest.xfail(reason='Not implemented') + + +@pytest.mark.parametrize('batch_inputs', [ + (), + (('b', bint(4)),), + (('b', bint(2)), ('c', bint(2))), +], ids=id_from_inputs) +@pytest.mark.parametrize('event_inputs', [ + (('e', bint(3)),), + (('e', bint(2)), ('f', bint(2))), +], ids=id_from_inputs) +@pytest.mark.parametrize('test_grad', [False, True], ids=['value', 'grad']) +def test_tensor_distribution(event_inputs, batch_inputs, test_grad): + num_samples = 50000 + sample_inputs = OrderedDict(n=bint(num_samples)) + be_inputs = OrderedDict(batch_inputs + event_inputs) + batch_inputs = OrderedDict(batch_inputs) + event_inputs = OrderedDict(event_inputs) + sampled_vars = frozenset(event_inputs) + p = random_tensor(be_inputs) + p.data.requires_grad_(test_grad) + + q = p.sample(sampled_vars, sample_inputs) + mq = materialize(q).reduce(ops.logaddexp, 'n') + mq = mq.align(tuple(p.inputs)) + assert_close(mq, p, atol=0.1, rtol=None) + + if test_grad: + _, (p_data, mq_data) = align_tensors(p, mq) + assert p_data.shape == mq_data.shape + probe = torch.randn(p_data.shape) + expected = grad((p_data.exp() * probe).sum(), [p.data])[0] + actual = grad((mq_data.exp() * probe).sum(), [p.data])[0] + assert_close(actual, expected, atol=0.1, rtol=None) + + +@pytest.mark.parametrize('batch_inputs', [ + (), + (('b', bint(3)),), + (('b', bint(3)), ('c', bint(4))), +], ids=id_from_inputs) +@pytest.mark.parametrize('event_inputs', [ + (('e', reals()),), + (('e', reals()), ('f', reals(2))), +], ids=id_from_inputs) +def test_gaussian_distribution(event_inputs, batch_inputs): + num_samples = 100000 + sample_inputs = OrderedDict(particle=bint(num_samples)) + be_inputs = OrderedDict(batch_inputs + event_inputs) + batch_inputs = OrderedDict(batch_inputs) + event_inputs = OrderedDict(event_inputs) + sampled_vars = frozenset(event_inputs) + p = random_gaussian(be_inputs) + + q = p.sample(sampled_vars, sample_inputs) + p_vars = sampled_vars + q_vars = sampled_vars | frozenset(['particle']) + # Check zeroth moment. + assert_close(q.reduce(ops.logaddexp, q_vars), + p.reduce(ops.logaddexp, p_vars), atol=1e-6) + for k1, d1 in event_inputs.items(): + x = Variable(k1, d1) + # Check first moments. + assert_close(Integrate(q, x, q_vars), + Integrate(p, x, p_vars), atol=0.5, rtol=0.2) + for k2, d2 in event_inputs.items(): + y = Variable(k2, d2) + # Check second moments. 
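+ # (the asserts below are skipped for now; see the FIXME about quadratic integration)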
+ continue # FIXME: Quadratic integration is not supported: + assert_close(Integrate(q, x * y, q_vars), + Integrate(p, x * y, p_vars), atol=1e-2) + + +@pytest.mark.parametrize('moment', [0, 1, 2, 3]) +def test_lognormal_distribution(moment): + num_samples = 100000 + inputs = OrderedDict(batch=bint(10)) + loc = random_tensor(inputs) + scale = random_tensor(inputs).exp() + + log_measure = dist.LogNormal(loc, scale) + probe = Variable('x', reals()) ** moment + with monte_carlo_interpretation(particle=bint(num_samples)): + with xfail_if_not_implemented(): + actual = Integrate(log_measure, probe) + + samples = torch.distributions.LogNormal(loc, scale).sample((num_samples,)) + expected = (samples ** moment).mean(0) + assert_close(actual.data, expected, atol=1e-2, rtol=1e-2) diff --git a/test/test_sum_product.py b/test/test_sum_product.py new file mode 100644 index 000000000..44cad3807 --- /dev/null +++ b/test/test_sum_product.py @@ -0,0 +1,81 @@ +from __future__ import absolute_import, division, print_function + +from collections import OrderedDict + +import pytest +from six.moves import reduce + +import funsor.ops as ops +from funsor.domains import bint +from funsor.sum_product import _partition, partial_sum_product, sum_product +from funsor.testing import assert_close, random_tensor + + +@pytest.mark.parametrize('inputs,dims,expected_num_components', [ + ([''], set(), 1), + (['a'], set(), 1), + (['a'], set('a'), 1), + (['a', 'a'], set(), 2), + (['a', 'a'], set('a'), 1), + (['a', 'a', 'b', 'b'], set(), 4), + (['a', 'a', 'b', 'b'], set('a'), 3), + (['a', 'a', 'b', 'b'], set('b'), 3), + (['a', 'a', 'b', 'b'], set('ab'), 2), + (['a', 'ab', 'b'], set(), 3), + (['a', 'ab', 'b'], set('a'), 2), + (['a', 'ab', 'b'], set('b'), 2), + (['a', 'ab', 'b'], set('ab'), 1), + (['a', 'ab', 'bc', 'c'], set(), 4), + (['a', 'ab', 'bc', 'c'], set('c'), 3), + (['a', 'ab', 'bc', 'c'], set('b'), 3), + (['a', 'ab', 'bc', 'c'], set('a'), 3), + (['a', 'ab', 'bc', 'c'], set('ac'), 2), + (['a', 'ab', 'bc', 'c'], set('abc'), 1), +]) +def test_partition(inputs, dims, expected_num_components): + sizes = dict(zip('abc', [2, 3, 4])) + terms = [random_tensor(OrderedDict((s, bint(sizes[s])) for s in input_)) + for input_ in inputs] + components = list(_partition(terms, dims)) + + # Check that result is a partition. + expected_terms = sorted(terms, key=id) + actual_terms = sorted((x for c in components for x in c[0]), key=id) + assert actual_terms == expected_terms + assert dims == set.union(set(), *(c[1] for c in components)) + + # Check that the partition is not too coarse. + assert len(components) == expected_num_components + + # Check that partition is not too fine. 
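+ # (any two terms that share one of the summed dims must land in the same component)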
+ component_dict = {x: i for i, (terms, _) in enumerate(components) for x in terms} + for x in terms: + for y in terms: + if x is not y: + if dims.intersection(x.inputs, y.inputs): + assert component_dict[x] == component_dict[y] + + +@pytest.mark.parametrize('sum_op,prod_op', [(ops.add, ops.mul), (ops.logaddexp, ops.add)]) +@pytest.mark.parametrize('inputs,plates', [('a,abi,bcij', 'ij')]) +@pytest.mark.parametrize('vars1,vars2', [ + ('', 'abcij'), + ('c', 'abij'), + ('cj', 'abi'), + ('bcj', 'ai'), + ('bcij', 'a'), + ('abcij', ''), +]) +def test_partial_sum_product(sum_op, prod_op, inputs, plates, vars1, vars2): + inputs = inputs.split(',') + factors = [random_tensor(OrderedDict((d, bint(2)) for d in ds)) for ds in inputs] + plates = frozenset(plates) + vars1 = frozenset(vars1) + vars2 = frozenset(vars2) + + factors1 = partial_sum_product(sum_op, prod_op, factors, vars1, plates) + factors2 = partial_sum_product(sum_op, prod_op, factors1, vars2, plates) + actual = reduce(prod_op, factors2) + + expected = sum_product(sum_op, prod_op, factors, vars1 | vars2, plates) + assert_close(actual, expected) diff --git a/test/test_terms.py b/test/test_terms.py index d7608ae70..527156823 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function import itertools +from collections import OrderedDict import numpy as np import pytest @@ -9,8 +10,10 @@ import funsor import funsor.ops as ops from funsor.domains import Domain, bint, reals -from funsor.terms import Binary, Number, Stack, Variable, to_funsor -from funsor.testing import check_funsor +from funsor.interpreter import interpretation +from funsor.terms import Binary, Independent, Lambda, Number, Stack, Variable, sequential, to_data, to_funsor +from funsor.testing import assert_close, check_funsor, random_tensor +from funsor.torch import REDUCE_OP_TO_TORCH np.seterr(all='ignore') @@ -20,11 +23,25 @@ def test_to_funsor(): @pytest.mark.parametrize('x', ["foo", list(), tuple(), set(), dict()]) -def test_to_funsor_undefined(x): +def test_to_funsor_error(x): with pytest.raises(ValueError): to_funsor(x) +def test_to_data(): + actual = to_data(Number(0.)) + expected = 0. + assert type(actual) == type(expected) + assert actual == expected + + +def test_to_data_error(): + with pytest.raises(ValueError): + to_data(Variable('x', reals())) + with pytest.raises(ValueError): + to_data(Variable('y', bint(12))) + + def test_cons_hash(): assert Variable('x', bint(3)) is Variable('x', bint(3)) assert Variable('x', reals()) is Variable('x', reals()) @@ -61,7 +78,6 @@ def test_variable(domain): x4 = Variable('x', bint(4)) assert x4 is not x assert x4('x') is x4 - assert x(x=x4) is x4 assert x(y=x4) is x xp1 = x + 1. 
@@ -142,8 +158,8 @@ def test_binary(symbol, data1, data2): check_funsor(actual, {}, Domain((), dtype), expected_data) -@pytest.mark.parametrize('op', ops.REDUCE_OP_TO_TORCH, - ids=[op.__name__ for op in ops.REDUCE_OP_TO_TORCH]) +@pytest.mark.parametrize('op', REDUCE_OP_TO_TORCH, + ids=[op.__name__ for op in REDUCE_OP_TO_TORCH]) def test_reduce_all(op): x = Variable('x', bint(2)) y = Variable('y', bint(3)) @@ -154,7 +170,8 @@ def test_reduce_all(op): if op is ops.logaddexp: pytest.skip() - actual = f.reduce(op) + with interpretation(sequential): + actual = f.reduce(op) values = [f(x=i, y=j, z=k) for i in x.output @@ -169,8 +186,8 @@ def test_reduce_all(op): for num_reduced in range(3 + 1) for reduced_vars in itertools.combinations('xyz', num_reduced) ]) -@pytest.mark.parametrize('op', ops.REDUCE_OP_TO_TORCH, - ids=[op.__name__ for op in ops.REDUCE_OP_TO_TORCH]) +@pytest.mark.parametrize('op', REDUCE_OP_TO_TORCH, + ids=[op.__name__ for op in REDUCE_OP_TO_TORCH]) def test_reduce_subset(op, reduced_vars): reduced_vars = frozenset(reduced_vars) x = Variable('x', bint(2)) @@ -182,7 +199,8 @@ def test_reduce_subset(op, reduced_vars): if op is ops.logaddexp: pytest.skip() - actual = f.reduce(op, reduced_vars) + with interpretation(sequential): + actual = f.reduce(op, reduced_vars) expected = f for v in [x, y, z]: @@ -195,6 +213,46 @@ def test_reduce_subset(op, reduced_vars): assert actual is f +@pytest.mark.parametrize('base_shape', [(), (4,), (3, 2)], ids=str) +def test_lambda(base_shape): + z = Variable('z', reals(*base_shape)) + i = Variable('i', bint(5)) + j = Variable('j', bint(7)) + + zi = Lambda(i, z) + assert zi.output.shape == (5,) + base_shape + assert zi[i] is z + + zj = Lambda(j, z) + assert zj.output.shape == (7,) + base_shape + assert zj[j] is z + + zij = Lambda(j, zi) + assert zij.output.shape == (7, 5) + base_shape + assert zij[j] is zi + assert zij[j, i] is z + # assert zij[:, i] is zj # XXX this was disabled by alpha-renaming + check_funsor(zij[:, i], zj.inputs, zj.output) + + +def test_independent(): + f = Variable('x', reals(4, 5)) + random_tensor(OrderedDict(i=bint(3))) + assert f.inputs['x'] == reals(4, 5) + assert f.inputs['i'] == bint(3) + + actual = Independent(f, 'x', 'i') + assert actual.inputs['x'] == reals(3, 4, 5) + assert 'i' not in actual.inputs + + x = Variable('x', reals(3, 4, 5)) + expected = f(x=x['i']).reduce(ops.add, 'i') + assert actual.inputs == expected.inputs + assert actual.output == expected.output + + data = random_tensor(OrderedDict(), x.output) + assert_close(actual(data), expected(data), atol=1e-5, rtol=1e-5) + + def test_stack_simple(): x = Number(0.) y = Number(1.) @@ -206,7 +264,7 @@ def test_stack_simple(): assert xyz(i=Number(0, 3)) is x assert xyz(i=Number(1, 3)) is y assert xyz(i=Number(2, 3)) is z - assert xyz.sum('i') == 5. + assert xyz.reduce(ops.add, 'i') == 5. 
def test_stack_subs(): @@ -224,7 +282,7 @@ def test_stack_subs(): assert f(i=Number(2, 3)) is y * z assert f(i=j) is Stack((Number(0), x, y * z), 'j') assert f(i='j') is Stack((Number(0), x, y * z), 'j') - assert f.sum('i') is Number(0) + x + (y * z) + assert f.reduce(ops.add, 'i') is Number(0) + x + (y * z) assert f(x=0) is Stack((Number(0), Number(0), y * z), 'i') assert f(y=x) is Stack((Number(0), x, x * z), 'i') diff --git a/test/test_torch.py b/test/test_torch.py index 2aaa65ab9..9268e617e 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -7,10 +7,11 @@ import torch import funsor -from funsor.domains import Domain, bint, reals -from funsor.terms import Variable +import funsor.ops as ops +from funsor.domains import Domain, bint, find_domain, reals +from funsor.terms import Lambda, Number, Variable from funsor.testing import assert_close, assert_equiv, check_funsor, random_tensor -from funsor.torch import Tensor, align_tensors +from funsor.torch import REDUCE_OP_TO_TORCH, Einsum, Tensor, align_tensors, torch_stack, torch_tensordot @pytest.mark.parametrize('shape', [(), (4,), (3, 2)]) @@ -19,6 +20,22 @@ def test_to_funsor(shape, dtype): t = torch.randn(shape).type(dtype) f = funsor.to_funsor(t) assert isinstance(f, Tensor) + assert funsor.to_funsor(t, reals(*shape)) is f + with pytest.raises(ValueError): + funsor.to_funsor(t, reals(5, *shape)) + + +def test_to_data(): + data = torch.zeros(3, 3) + x = Tensor(data) + assert funsor.to_data(x) is data + + +def test_to_data_error(): + data = torch.zeros(3, 3) + x = Tensor(data, OrderedDict(i=bint(3))) + with pytest.raises(ValueError): + funsor.to_data(x) def test_cons_hash(): @@ -50,14 +67,14 @@ def test_indexing(): def test_advanced_indexing_shape(): - I, J, M, N = 4, 5, 2, 3 - x = Tensor(torch.randn(4, 5), OrderedDict([ + I, J, M, N = 4, 4, 2, 3 + x = Tensor(torch.randn(I, J), OrderedDict([ ('i', bint(I)), ('j', bint(J)), ])) - m = Tensor(torch.tensor([2, 3]), OrderedDict([('m', bint(M))])) - n = Tensor(torch.tensor([0, 1, 1]), OrderedDict([('n', bint(N))])) - assert x.data.shape == (4, 5) + m = Tensor(torch.tensor([2, 3]), OrderedDict([('m', bint(M))]), I) + n = Tensor(torch.tensor([0, 1, 1]), OrderedDict([('n', bint(N))]), J) + assert x.data.shape == (I, J) check_funsor(x(i=m), {'j': bint(J), 'm': bint(M)}, reals()) check_funsor(x(i=m, j=n), {'m': bint(M), 'n': bint(N)}, reals()) @@ -91,21 +108,22 @@ def test_advanced_indexing_tensor(output_shape): # \ | / # \ | / # x - x = Tensor(torch.randn((2, 3, 4) + output_shape), OrderedDict([ + output = reals(*output_shape) + x = random_tensor(OrderedDict([ ('i', bint(2)), ('j', bint(3)), ('k', bint(4)), - ])) - i = Tensor(random_tensor(2, (5,)), OrderedDict([ + ]), output) + i = random_tensor(OrderedDict([ ('u', bint(5)), - ])) - j = Tensor(random_tensor(3, (6, 5)), OrderedDict([ + ]), bint(2)) + j = random_tensor(OrderedDict([ ('v', bint(6)), ('u', bint(5)), - ])) - k = Tensor(random_tensor(4, (6,)), OrderedDict([ + ]), bint(3)) + k = random_tensor(OrderedDict([ ('v', bint(6)), - ])) + ]), bint(4)) expected_data = torch.empty((5, 6) + output_shape) for u in range(5): @@ -144,8 +162,8 @@ def test_advanced_indexing_lazy(output_shape): ])) u = Variable('u', bint(2)) v = Variable('v', bint(3)) - i = 1 - u - j = 2 - v + i = Number(1, 2) - u + j = Number(2, 3) - v k = u + v expected_data = torch.empty((2, 3) + output_shape) @@ -245,6 +263,26 @@ def test_binary_funsor_funsor(symbol, dims1, dims2): check_funsor(actual, inputs, Domain((), dtype), expected_data) 
+@pytest.mark.parametrize('output_shape2', [(), (2,), (3, 2)], ids=str)
+@pytest.mark.parametrize('output_shape1', [(), (2,), (3, 2)], ids=str)
+@pytest.mark.parametrize('inputs2', [(), ('a',), ('b', 'a'), ('b', 'c', 'a')], ids=str)
+@pytest.mark.parametrize('inputs1', [(), ('a',), ('a', 'b'), ('b', 'a', 'c')], ids=str)
+def test_binary_broadcast(inputs1, inputs2, output_shape1, output_shape2):
+    sizes = {'a': 4, 'b': 5, 'c': 6}
+    inputs1 = OrderedDict((k, bint(sizes[k])) for k in inputs1)
+    inputs2 = OrderedDict((k, bint(sizes[k])) for k in inputs2)
+    x1 = random_tensor(inputs1, reals(*output_shape1))
+    x2 = random_tensor(inputs2, reals(*output_shape2))
+
+    actual = x1 + x2
+    assert actual.output == find_domain(ops.add, x1.output, x2.output)
+
+    block = {'a': 1, 'b': 2, 'c': 3}
+    actual_block = actual(**block)
+    expected_block = Tensor(x1(**block).data + x2(**block).data)
+    assert_close(actual_block, expected_block)
+
+
 @pytest.mark.parametrize('scalar', [0.5])
 @pytest.mark.parametrize('dims', [(), ('a',), ('a', 'b'), ('b', 'a', 'c')])
 @pytest.mark.parametrize('symbol', BINARY_OPS)
@@ -275,26 +313,124 @@ def test_binary_scalar_funsor(symbol, dims, scalar):
     check_funsor(actual, inputs, reals(), expected_data)
 
 
-REDUCE_OPS = ['sum', 'prod', 'logsumexp', 'all', 'any', 'min', 'max']
+def test_getitem_number_0_inputs():
+    data = torch.randn((5, 4, 3, 2))
+    x = Tensor(data)
+    assert_close(x[2], Tensor(data[2]))
+    assert_close(x[:, 1], Tensor(data[:, 1]))
+    assert_close(x[2, 1], Tensor(data[2, 1]))
+    assert_close(x[2, :, 1], Tensor(data[2, :, 1]))
+    assert_close(x[3, ...], Tensor(data[3, ...]))
+    assert_close(x[3, 2, ...], Tensor(data[3, 2, ...]))
+    assert_close(x[..., 1], Tensor(data[..., 1]))
+    assert_close(x[..., 2, 1], Tensor(data[..., 2, 1]))
+    assert_close(x[3, ..., 1], Tensor(data[3, ..., 1]))
+
+
+def test_getitem_number_1_inputs():
+    data = torch.randn((3, 5, 4, 3, 2))
+    inputs = OrderedDict([('i', bint(3))])
+    x = Tensor(data, inputs)
+    assert_close(x[2], Tensor(data[:, 2], inputs))
+    assert_close(x[:, 1], Tensor(data[:, :, 1], inputs))
+    assert_close(x[2, 1], Tensor(data[:, 2, 1], inputs))
+    assert_close(x[2, :, 1], Tensor(data[:, 2, :, 1], inputs))
+    assert_close(x[3, ...], Tensor(data[:, 3, ...], inputs))
+    assert_close(x[3, 2, ...], Tensor(data[:, 3, 2, ...], inputs))
+    assert_close(x[..., 1], Tensor(data[..., 1], inputs))
+    assert_close(x[..., 2, 1], Tensor(data[..., 2, 1], inputs))
+    assert_close(x[3, ..., 1], Tensor(data[:, 3, ..., 1], inputs))
+
+
+def test_getitem_number_2_inputs():
+    data = torch.randn((3, 4, 5, 4, 3, 2))
+    inputs = OrderedDict([('i', bint(3)), ('j', bint(4))])
+    x = Tensor(data, inputs)
+    assert_close(x[2], Tensor(data[:, :, 2], inputs))
+    assert_close(x[:, 1], Tensor(data[:, :, :, 1], inputs))
+    assert_close(x[2, 1], Tensor(data[:, :, 2, 1], inputs))
+    assert_close(x[2, :, 1], Tensor(data[:, :, 2, :, 1], inputs))
+    assert_close(x[3, ...], Tensor(data[:, :, 3, ...], inputs))
+    assert_close(x[3, 2, ...], Tensor(data[:, :, 3, 2, ...], inputs))
+    assert_close(x[..., 1], Tensor(data[..., 1], inputs))
+    assert_close(x[..., 2, 1], Tensor(data[..., 2, 1], inputs))
+    assert_close(x[3, ..., 1], Tensor(data[:, :, 3, ..., 1], inputs))
+
+
+def test_getitem_variable():
+    data = torch.randn((5, 4, 3, 2))
+    x = Tensor(data)
+    i = Variable('i', bint(5))
+    j = Variable('j', bint(4))
+    assert x[i] is Tensor(data, OrderedDict([('i', bint(5))]))
+    assert x[i, j] is Tensor(data, OrderedDict([('i', bint(5)), ('j', bint(4))]))
+
+
+def test_getitem_string():
+    data = 
torch.randn((5, 4, 3, 2)) + x = Tensor(data) + assert x['i'] is Tensor(data, OrderedDict([('i', bint(5))])) + assert x['i', 'j'] is Tensor(data, OrderedDict([('i', bint(5)), ('j', bint(4))])) + + +def test_getitem_tensor(): + data = torch.randn((5, 4, 3, 2)) + x = Tensor(data) + i = Variable('i', bint(5)) + j = Variable('j', bint(4)) + k = Variable('k', bint(3)) + m = Variable('m', bint(2)) + + y = random_tensor(OrderedDict(), bint(5)) + assert_close(x[i](i=y), x[y]) + + y = random_tensor(OrderedDict(), bint(4)) + assert_close(x[:, j](j=y), x[:, y]) + + y = random_tensor(OrderedDict(), bint(3)) + assert_close(x[:, :, k](k=y), x[:, :, y]) + + y = random_tensor(OrderedDict(), bint(2)) + assert_close(x[:, :, :, m](m=y), x[:, :, :, y]) + + y = random_tensor(OrderedDict([('i', i.output)]), + bint(j.dtype)) + assert_close(x[i, j](j=y), x[i, y]) + + y = random_tensor(OrderedDict([('i', i.output), ('j', j.output)]), + bint(k.dtype)) + assert_close(x[i, j, k](k=y), x[i, j, y]) + + +def test_lambda_getitem(): + data = torch.randn(2) + x = Tensor(data) + y = Tensor(data, OrderedDict(i=bint(2))) + i = Variable('i', bint(2)) + assert x[i] is y + assert Lambda(i, y) is x + + +REDUCE_OPS = [ops.add, ops.mul, ops.and_, ops.or_, ops.logaddexp, ops.min, ops.max] @pytest.mark.parametrize('dims', [(), ('a',), ('a', 'b'), ('b', 'a', 'c')]) -@pytest.mark.parametrize('op_name', REDUCE_OPS) -def test_reduce_all(dims, op_name): +@pytest.mark.parametrize('op', REDUCE_OPS, ids=str) +def test_reduce_all(dims, op): sizes = {'a': 3, 'b': 4, 'c': 5} shape = tuple(sizes[d] for d in dims) inputs = OrderedDict((d, bint(sizes[d])) for d in dims) data = torch.rand(shape) + 0.5 - if op_name in ['all', 'any']: + if op in [ops.and_, ops.or_]: data = data.byte() - if op_name == 'logsumexp': + if op is ops.logaddexp: # work around missing torch.Tensor.logsumexp() expected_data = data.reshape(-1).logsumexp(0) else: - expected_data = getattr(data, op_name)() + expected_data = REDUCE_OP_TO_TORCH[op](data) x = Tensor(data, inputs) - actual = getattr(x, op_name)() + actual = x.reduce(op) check_funsor(actual, {}, reals(), expected_data) @@ -304,19 +440,19 @@ def test_reduce_all(dims, op_name): for num_reduced in range(len(dims) + 2) for reduced_vars in itertools.combinations(dims, num_reduced) ]) -@pytest.mark.parametrize('op_name', REDUCE_OPS) -def test_reduce_subset(dims, reduced_vars, op_name): +@pytest.mark.parametrize('op', REDUCE_OPS) +def test_reduce_subset(dims, reduced_vars, op): reduced_vars = frozenset(reduced_vars) sizes = {'a': 3, 'b': 4, 'c': 5} shape = tuple(sizes[d] for d in dims) inputs = OrderedDict((d, bint(sizes[d])) for d in dims) data = torch.rand(shape) + 0.5 dtype = 'real' - if op_name in ['all', 'any']: + if op in [ops.and_, ops.or_]: data = data.byte() dtype = 2 x = Tensor(data, inputs, dtype) - actual = getattr(x, op_name)(reduced_vars) + actual = x.reduce(op, reduced_vars) expected_inputs = OrderedDict( (d, bint(sizes[d])) for d in dims if d not in reduced_vars) @@ -325,25 +461,63 @@ def test_reduce_subset(dims, reduced_vars, op_name): assert actual is x else: if reduced_vars == frozenset(dims): - if op_name == 'logsumexp': + if op is ops.logaddexp: # work around missing torch.Tensor.logsumexp() data = data.reshape(-1).logsumexp(0) else: - data = getattr(data, op_name)() + data = REDUCE_OP_TO_TORCH[op](data) else: for pos in reversed(sorted(map(dims.index, reduced_vars))): - if op_name in ('min', 'max'): - data = getattr(data, op_name)(pos)[0] - else: - data = getattr(data, op_name)(pos) + data = 
REDUCE_OP_TO_TORCH[op](data, pos) + if op in (ops.min, ops.max): + data = data[0] check_funsor(actual, expected_inputs, Domain((), dtype)) assert_close(actual, Tensor(data, expected_inputs, dtype), atol=1e-5, rtol=1e-5) +@pytest.mark.parametrize('dims', [(), ('a',), ('a', 'b'), ('b', 'a', 'c')]) +@pytest.mark.parametrize('event_shape', [(), (4,), (2, 3)]) +@pytest.mark.parametrize('op', REDUCE_OPS, ids=str) +def test_reduce_event(op, event_shape, dims): + sizes = {'a': 3, 'b': 4, 'c': 5} + batch_shape = tuple(sizes[d] for d in dims) + shape = batch_shape + event_shape + inputs = OrderedDict((d, bint(sizes[d])) for d in dims) + torch_op = REDUCE_OP_TO_TORCH[op] + data = torch.rand(shape) + 0.5 + dtype = 'real' + if op in [ops.and_, ops.or_]: + data = data.byte() + expected_data = torch_op(data.reshape(batch_shape + (-1,)), -1) + if op in [ops.min, ops.max]: + expected_data = expected_data[0] + + x = Tensor(data, inputs, dtype=dtype) + actual = getattr(x, torch_op.__name__)() + check_funsor(actual, inputs, Domain((), dtype), expected_data) + + +@pytest.mark.parametrize('shape', [(), (4,), (2, 3)]) +def test_all_equal(shape): + inputs = OrderedDict() + data1 = torch.rand(shape) + 0.5 + data2 = torch.rand(shape) + 0.5 + dtype = 'real' + + x1 = Tensor(data1, inputs, dtype=dtype) + x2 = Tensor(data2, inputs, dtype=dtype) + assert (x1 == x1).all() + assert (x2 == x2).all() + assert not (x1 == x2).all() + assert not (x1 != x1).any() + assert not (x2 != x2).any() + assert (x1 != x2).any() + + def test_function_matmul(): - @funsor.function(reals(3, 4), reals(4, 5), reals(3, 5)) + @funsor.torch.function(reals(3, 4), reals(4, 5), reals(3, 5)) def matmul(x, y): return torch.matmul(x, y) @@ -358,15 +532,15 @@ def matmul(x, y): def test_function_lazy_matmul(): - @funsor.function(reals(3, 4), reals(4, 5), reals(3, 5)) + @funsor.torch.function(reals(3, 4), reals(4, 5), reals(3, 5)) def matmul(x, y): return torch.matmul(x, y) - x_lazy = funsor.Variable('x', reals(3, 4)) + x_lazy = Variable('x', reals(3, 4)) y = Tensor(torch.randn(4, 5)) actual_lazy = matmul(x_lazy, y) check_funsor(actual_lazy, {'x': reals(3, 4)}, reals(3, 5)) - assert isinstance(actual_lazy, funsor.Function) + assert isinstance(actual_lazy, funsor.torch.Function) x = Tensor(torch.randn(3, 4)) actual = actual_lazy(x=x) @@ -374,6 +548,54 @@ def matmul(x, y): check_funsor(actual, {}, reals(3, 5), expected_data) +def test_function_nested_eager(): + + @funsor.torch.function(reals(8), (reals(), bint(8))) + def max_and_argmax(x): + return tuple(torch.max(x, dim=-1)) + + inputs = OrderedDict([('i', bint(2)), ('j', bint(3))]) + x = Tensor(torch.randn(2, 3, 8), inputs) + m, a = x.data.max(dim=-1) + expected_max = Tensor(m, inputs, 'real') + expected_argmax = Tensor(a, inputs, 8) + + actual_max, actual_argmax = max_and_argmax(x) + assert_close(actual_max, expected_max) + assert_close(actual_argmax, expected_argmax) + + +def test_function_nested_lazy(): + + @funsor.torch.function(reals(8), (reals(), bint(8))) + def max_and_argmax(x): + return tuple(torch.max(x, dim=-1)) + + x_lazy = Variable('x', reals(8)) + lazy_max, lazy_argmax = max_and_argmax(x_lazy) + assert isinstance(lazy_max, funsor.torch.Function) + assert isinstance(lazy_argmax, funsor.torch.Function) + check_funsor(lazy_max, {'x': reals(8)}, reals()) + check_funsor(lazy_argmax, {'x': reals(8)}, bint(8)) + + inputs = OrderedDict([('i', bint(2)), ('j', bint(3))]) + y = Tensor(torch.randn(2, 3, 8), inputs) + actual_max = lazy_max(x=y) + actual_argmax = lazy_argmax(x=y) + expected_max, 
expected_argmax = max_and_argmax(y) + assert_close(actual_max, expected_max) + assert_close(actual_argmax, expected_argmax) + + +def test_function_of_torch_tensor(): + x = torch.randn(4, 3) + y = torch.randn(3, 2) + f = funsor.torch.function(reals(4, 3), reals(3, 2), reals(4, 2))(torch.matmul) + actual = f(x, y) + expected = f(Tensor(x), Tensor(y)) + assert_close(actual, expected) + + def test_align(): x = Tensor(torch.randn(2, 3, 4), OrderedDict([ ('i', bint(2)), @@ -411,5 +633,39 @@ def test_einsum(equation): tensors = [torch.randn(tuple(sizes[d] for d in dims)) for dims in inputs] funsors = [Tensor(x) for x in tensors] expected = Tensor(torch.einsum(equation, *tensors)) - actual = funsor.einsum(equation, *funsors) + actual = Einsum(equation, tuple(funsors)) + assert_close(actual, expected, atol=1e-5, rtol=None) + + +@pytest.mark.parametrize('y_shape', [(), (4,), (4, 5)], ids=str) +@pytest.mark.parametrize('xy_shape', [(), (6,), (6, 7)], ids=str) +@pytest.mark.parametrize('x_shape', [(), (2,), (2, 3)], ids=str) +def test_tensordot(x_shape, xy_shape, y_shape): + x = torch.randn(x_shape + xy_shape) + y = torch.randn(xy_shape + y_shape) + dim = len(xy_shape) + actual = torch_tensordot(Tensor(x), Tensor(y), dim) + expected = Tensor(torch.tensordot(x, y, dim)) + assert_close(actual, expected, atol=1e-5, rtol=None) + + +@pytest.mark.parametrize('n', [1, 2, 5]) +@pytest.mark.parametrize('shape,dim', [ + ((), 0), + ((), -1), + ((1,), 0), + ((1,), 1), + ((1,), -1), + ((1,), -2), + ((2, 3), 0), + ((2, 3), 1), + ((2, 3), 2), + ((2, 3), -1), + ((2, 3), -2), + ((2, 3), -3), +], ids=str) +def test_stack(n, shape, dim): + tensors = [torch.randn(shape) for _ in range(n)] + actual = torch_stack(tuple(Tensor(t) for t in tensors), dim=dim) + expected = Tensor(torch.stack(tensors, dim=dim)) assert_close(actual, expected)
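
Usage sketch for the torch helpers exercised by the new tests above (Einsum, torch_tensordot, torch_stack). The calls mirror test_einsum, test_tensordot, and test_stack; the commented output domains are inferred from those tests rather than from separate documentation.

import torch

from funsor.torch import Einsum, Tensor, torch_stack, torch_tensordot

x = Tensor(torch.randn(2, 3))
y = Tensor(torch.randn(3, 4))

# Einsum promotes a torch.einsum contraction to a funsor term.
xy = Einsum('ab,bc->ac', (x, y))        # output reals(2, 4)

# torch_tensordot contracts the trailing dim(s) of x with the leading dim(s) of y.
xy2 = torch_tensordot(x, y, 1)          # output reals(2, 4)

# torch_stack stacks funsor Tensors along a new output dim, like torch.stack.
xs = torch_stack((x, x), dim=0)         # output reals(2, 2, 3)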