Skip to content

Commit

Permalink
revert black changes
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Feb 21, 2021
1 parent 274841d commit e7dc9d1
Show file tree
Hide file tree
Showing 33 changed files with 1,172 additions and 194 deletions.
6 changes: 5 additions & 1 deletion funsor/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,8 @@ def extract_affine(fn):
return const, coeffs


__all__ = ["affine_inputs", "extract_affine", "is_affine"]
__all__ = [
"affine_inputs",
"extract_affine",
"is_affine",
]
6 changes: 5 additions & 1 deletion funsor/cnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,11 @@ def unary_contract(op, arg):
)


BACKEND_TO_EINSUM_BACKEND = {"numpy": "numpy", "torch": "torch", "jax": "jax.numpy"}
BACKEND_TO_EINSUM_BACKEND = {
"numpy": "numpy",
"torch": "torch",
"jax": "jax.numpy",
}
# NB: numpy_log, numpy_map is backend-agnostic so they also work for torch backend;
# however, we might need to profile to make a switch
BACKEND_TO_LOGSUMEXP_BACKEND = {
Expand Down
5 changes: 4 additions & 1 deletion funsor/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,4 +248,7 @@ def eager_independent_delta(delta, reals_var, bint_var, diag_var):
return None


__all__ = ["Delta", "solve"]
__all__ = [
"Delta",
"solve",
]
7 changes: 5 additions & 2 deletions funsor/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def eager_log_prob(cls, *params):
params, value = params[:-1], params[-1]
params = params + (Variable("value", value.output),)
instance = reflect.interpret(cls, *params)
(raw_dist, value_name, value_output, dim_to_name) = instance._get_raw_dist()
raw_dist, value_name, value_output, dim_to_name = instance._get_raw_dist()
assert value.output == value_output
name_to_dim = {v: k for k, v in dim_to_name.items()}
dim_to_name.update(
Expand Down Expand Up @@ -379,7 +379,10 @@ def dist_init(self, **kwargs):
dist_class = DistributionMeta(
backend_dist_class.__name__.split("Wrapper_")[-1],
(Distribution,),
{"dist_class": backend_dist_class, "__init__": dist_init},
{
"dist_class": backend_dist_class,
"__init__": dist_init,
},
)

if generate_eager:
Expand Down
7 changes: 6 additions & 1 deletion funsor/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,4 +779,9 @@ def eager_neg(op, arg):
return Gaussian(info_vec, precision, arg.inputs)


__all__ = ["BlockMatrix", "BlockVector", "Gaussian", "align_gaussian"]
__all__ = [
"BlockMatrix",
"BlockVector",
"Gaussian",
"align_gaussian",
]
9 changes: 8 additions & 1 deletion funsor/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,11 @@ def print_counters():
print("-" * 80)


__all__ = ["DEBUG", "PROFILE", "STACK_SIZE", "debug_logged", "get_indent", "profile"]
__all__ = [
"DEBUG",
"PROFILE",
"STACK_SIZE",
"debug_logged",
"get_indent",
"profile",
]
4 changes: 3 additions & 1 deletion funsor/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,4 +230,6 @@ def eager_integrate(log_measure, integrand, reduced_vars):
return None # defer to default implementation


__all__ = ["Integrate"]
__all__ = [
"Integrate",
]
3 changes: 2 additions & 1 deletion funsor/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def interpret(cls, *args):

def interpretation(new):
warnings.warn(
"'with interpretation(x)' should be replaced by 'with x'", DeprecationWarning
"'with interpretation(x)' should be replaced by 'with x'",
DeprecationWarning,
)
return new

Expand Down
8 changes: 7 additions & 1 deletion funsor/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@


@adjoint_ops.register(
Tensor, AssociativeOp, AssociativeOp, Funsor, (DeviceArray, Tracer), tuple, object
Tensor,
AssociativeOp,
AssociativeOp,
Funsor,
(DeviceArray, Tracer),
tuple,
object,
)
def adjoint_tensor(adj_redop, adj_binop, out_adj, data, inputs, dtype):
return {}
Expand Down
5 changes: 4 additions & 1 deletion funsor/jax/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,10 @@ def _triangular_solve(x, y, upper=False, transpose=False):
x_new_shape = batch_shape[:prepend_ndim]
for (sy, sx) in zip(y.shape[:-2], batch_shape[prepend_ndim:]):
x_new_shape += (sx // sy, sy)
x_new_shape += (n, m)
x_new_shape += (
n,
m,
)
x = np.reshape(x, x_new_shape)
# Permute y to make it have shape (..., 1, j, m, i, 1, n)
batch_ndim = x.ndim - 2
Expand Down
4 changes: 3 additions & 1 deletion funsor/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gauss
discrete += gaussian.log_normalizer
new_discrete = discrete.reduce(ops.logaddexp, approx_vars & discrete.input_vars)
num_elements = reduce(
ops.mul, [v.output.num_elements for v in approx_vars - discrete.input_vars], 1
ops.mul,
[v.output.num_elements for v in approx_vars - discrete.input_vars],
1,
)
if num_elements != 1:
new_discrete -= math.log(num_elements)
Expand Down
4 changes: 3 additions & 1 deletion funsor/memoize.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,6 @@ def interpret(self, cls, *args):
return value


__all__ = ["memoize"]
__all__ = [
"memoize",
]
4 changes: 3 additions & 1 deletion funsor/montecarlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,6 @@ def monte_carlo_integrate(state, log_measure, integrand, reduced_vars):
return Integrate(sample, integrand, reduced_vars)


__all__ = ["MonteCarlo"]
__all__ = [
"MonteCarlo",
]
4 changes: 3 additions & 1 deletion funsor/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,6 @@ def dispatch(self, key, *args):
return self[key].partial_call(*args)


__all__ = ["KeyedRegistry"]
__all__ = [
"KeyedRegistry",
]
23 changes: 19 additions & 4 deletions funsor/syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@ def visit_UnaryOp(self, node):
var = self.prefix.get(type(node.op))
if var is not None:
node = ast.Call(
func=ast.Name(id=var, ctx=ast.Load()), args=[node.operand], keywords=[]
func=ast.Name(
id=var,
ctx=ast.Load(),
),
args=[node.operand],
keywords=[],
)
return node

Expand All @@ -68,7 +73,10 @@ def visit_BinOp(self, node):
var = self.infix.get(type(node.op))
if var is not None:
node = ast.Call(
func=ast.Name(id=var, ctx=ast.Load()),
func=ast.Name(
id=var,
ctx=ast.Load(),
),
args=[node.left, node.right],
keywords=[],
)
Expand All @@ -90,7 +98,10 @@ def visit_Compare(self, node):
var = self.infix.get(type(node_op))
if var is not None:
node = ast.Call(
func=ast.Name(id=var, ctx=ast.Load()),
func=ast.Name(
id=var,
ctx=ast.Load(),
),
args=[node.left, node_right],
keywords=[],
)
Expand Down Expand Up @@ -161,4 +172,8 @@ def decorator(fn):
return decorator


__all__ = ["INFIX_OPERATORS", "PREFIX_OPERATORS", "rewrite_ops"]
__all__ = [
"INFIX_OPERATORS",
"PREFIX_OPERATORS",
"rewrite_ops",
]
2 changes: 1 addition & 1 deletion funsor/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1520,7 +1520,7 @@ def eager_subs(self, subs):
n -= size
assert False
elif isinstance(value, Slice):
start, stop, step = (value.slice.start, value.slice.stop, value.slice.step)
start, stop, step = value.slice.start, value.slice.stop, value.slice.step
new_parts = []
pos = 0
for part in self.parts:
Expand Down
6 changes: 3 additions & 3 deletions funsor/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6):
n for n, p in expected.terms
)
actual = actual.align(tuple(n for n, p in expected.terms))
for (
(actual_name, (actual_point, actual_log_density)),
(expected_name, (expected_point, expected_log_density)),
for (actual_name, (actual_point, actual_log_density)), (
expected_name,
(expected_point, expected_log_density),
) in zip(actual.terms, expected.terms):
assert actual_name == expected_name
assert_close(actual_point, expected_point, atol=atol, rtol=rtol)
Expand Down
92 changes: 76 additions & 16 deletions test/examples/test_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ def unpack_gate_rate(gate_rate):

@pytest.mark.parametrize(
"analytic_kl",
[False, xfail_param(True, reason="missing pattern")],
[
False,
xfail_param(True, reason="missing pattern"),
],
ids=["monte-carlo-kl", "analytic-kl"],
)
def test_bart(analytic_kl):
Expand Down Expand Up @@ -93,7 +96,16 @@ def test_bart(analytic_kl):
],
dtype=torch.float32,
), # noqa
(("time_b4", Bint[2]), ("_event_1_b2", Bint[8])),
(
(
"time_b4",
Bint[2],
),
(
"_event_1_b2",
Bint[8],
),
),
"real",
),
Gaussian(
Expand Down Expand Up @@ -148,9 +160,18 @@ def test_bart(analytic_kl):
dtype=torch.float32,
), # noqa
(
("time_b4", Bint[2]),
("_event_1_b2", Bint[8]),
("value_b1", Real),
(
"time_b4",
Bint[2],
),
(
"_event_1_b2",
Bint[8],
),
(
"value_b1",
Real,
),
),
),
),
Expand Down Expand Up @@ -220,8 +241,14 @@ def test_bart(analytic_kl):
dtype=torch.float32,
), # noqa
(
("state_b7", Reals[2]),
("state(time=1)_b8", Reals[2]),
(
"state_b7",
Reals[2],
),
(
"state(time=1)_b8",
Reals[2],
),
),
),
Subs(
Expand Down Expand Up @@ -281,7 +308,12 @@ def test_bart(analytic_kl):
],
dtype=torch.float32,
), # noqa
(("time_b9", Bint[2]),),
(
(
"time_b9",
Bint[2],
),
),
"real",
),
Tensor(
Expand Down Expand Up @@ -310,7 +342,12 @@ def test_bart(analytic_kl):
],
dtype=torch.float32,
), # noqa
(("time_b9", Bint[2]),),
(
(
"time_b9",
Bint[2],
),
),
"real",
),
Variable("state(time=1)_b8", Reals[2]),
Expand Down Expand Up @@ -352,7 +389,12 @@ def test_bart(analytic_kl):
),
Variable("value_b5", Reals[2]),
),
(("value_b5", Variable("state_b10", Reals[2])),),
(
(
"value_b5",
Variable("state_b10", Reals[2]),
),
),
),
),
)
Expand Down Expand Up @@ -449,9 +491,18 @@ def test_bart(analytic_kl):
dtype=torch.float32,
), # noqa
(
("time_b17", Bint[2]),
("origin_b15", Bint[2]),
("destin_b16", Bint[2]),
(
"time_b17",
Bint[2],
),
(
"origin_b15",
Bint[2],
),
(
"destin_b16",
Bint[2],
),
),
"real",
),
Expand All @@ -476,9 +527,18 @@ def test_bart(analytic_kl):
dtype=torch.float32,
), # noqa
(
("time_b17", Bint[2]),
("origin_b15", Bint[2]),
("destin_b16", Bint[2]),
(
"time_b17",
Bint[2],
),
(
"origin_b15",
Bint[2],
),
(
"destin_b16",
Bint[2],
),
),
"real",
),
Expand Down
11 changes: 10 additions & 1 deletion test/examples/test_sensor_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,16 @@ def test_affine_subs():
],
dtype=torch.float32,
), # noqa
(("state_1_b6", Reals[3]), ("obs_b2", Reals[2])),
(
(
"state_1_b6",
Reals[3],
),
(
"obs_b2",
Reals[2],
),
),
),
(
(
Expand Down
Loading

0 comments on commit e7dc9d1

Please sign in to comment.