-
-
Notifications
You must be signed in to change notification settings - Fork 984
/
Copy pathtorch_map.py
61 lines (49 loc) · 1.91 KB
/
torch_map.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import operator
from functools import reduce
from pyro.ops import packed
from pyro.ops.einsum.adjoint import (
Backward,
einsum_backward_sample,
transpose,
unflatten,
)
from pyro.ops.einsum.util import Tensordot
class _EinsumBackward(Backward):
def __init__(self, operands, argmax):
self.operands = operands
self.argmax = argmax
def process(self, message):
sample1 = self.argmax
sample2 = message
return einsum_backward_sample(self.operands, sample1, sample2)
def einsum(equation, *operands):
"""
Forward-max-sum backward-argmax implementation of einsum.
This assumes all operands have a ``._pyro_dims`` attribute set.
"""
equation = packed.rename_equation(equation, *operands)
inputs, output = equation.split("->")
any_requires_backward = any(hasattr(x, "_pyro_backward") for x in operands)
contract_dims = "".join(
sorted(set().union(*(x._pyro_dims for x in operands)) - set(output))
)
dims = output + contract_dims
result = reduce(operator.add, packed.broadcast_all(*operands, dims=dims))
argmax = None # work around lack of pytorch support for zero-sized tensors
if contract_dims:
output_shape = result.shape[: len(output)]
contract_shape = result.shape[len(output) :]
result, argmax = result.reshape(output_shape + (-1,)).max(-1)
if any_requires_backward:
argmax = unflatten(argmax, output, contract_dims, contract_shape)
elif result is operands[0]:
result = result[...] # create a new object
result._pyro_dims = output
assert result.dim() == len(result._pyro_dims)
if any_requires_backward:
result._pyro_backward = _EinsumBackward(operands, argmax)
return result
tensordot = Tensordot(einsum)
__all__ = ["transpose", "einsum", "tensordot"]