Skip to content

Commit

Permalink
Teach the histogram optimization (aka "summarize") to perform Kroneck…
Browse files Browse the repository at this point in the history
…er expansion

Per discussion at #104 ,
and after waiting a long time for an error about properties on
the kronecker_in_simplify branch (e65bda5)
  • Loading branch information
ccshan committed Aug 3, 2017
1 parent c3dbaf0 commit 27de978
Showing 1 changed file with 28 additions and 12 deletions.
40 changes: 28 additions & 12 deletions maple/Summary.mpl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ Summary := module ()
rng := map(SummarizeKB, rng, kb);
e1 := eval(op(1,e), op([2,1],e)=x);
if op(0, e) in '{sum, Sum}' and has(e1, 'piecewise') then
mr, f := summarize(e1, kb, x, summary);
mr, f := summarize(e1, kb, x=rng, summary);
if hasfun(mr, '{Fanout, Index}') then
mr := SummarizeKB(mr, kb1);
return Let(Bucket(mr, x=rng),
Expand Down Expand Up @@ -146,14 +146,12 @@ Summary := module ()
end if;
end proc;

summarize := proc(ee,
kb :: t_kb,
i :: {name,list(name),set(name)},
summary, $)
local e, r, s,
summarize := proc(ee, kb :: t_kb, ii :: name=range, summary, $)
local i, rng, e, r, s,
e1, mr1, f1, e2, mr2, f2,
variables, var_outerness, outermost, thunk, consider,
o, t, lo, hi, b, a;
i, rng := op(ii);

# Try to ensure termination
e := simplify_assuming(ee, kb);
Expand All @@ -162,15 +160,33 @@ Summary := module ()
# Nop rule
if Testzero(e) then return Nop(), 0 end if;

# Kronecker expansion (https://github.com/hakaru-dev/hakaru/issues/104):
# rewrite sum(f(i = b), i=rng) to
# sum(f(false), i=rng) +
# piecewise(..., eval(f(true) - f(false), i=b), 0)
for r in select(depends, indets(e, '{`=`, `<>`}'), i) do
if not ispoly(`-`(op(r)), 'linear', i, 'b', 'a') then next end if;
b := Normalizer(-b/a);
e2 := eval(e, {subsop(0=`<>`,r)=true, subsop(0=`=` ,r)=false});
if length(e2) >= length(e) then next end if;
e1 := eval(e, {subsop(0=`=` ,r)=true, subsop(0=`<>`,r)=false});
f1 := 'piecewise'(And(b::integer, lhs(rng)<=b, b<=rhs(rng)),
eval(e1-e2, i=b),
0);
if has(f1, '{undefined, infinity, FAIL}') then next end if;
mr2, f2 := summarize(e2, kb, ii, summary);
return mr2, f1 + f2;
end do;

r := indets(e, '{relation, logical, specfunc({And,Or})}');
s, r := selectremove(depends, r, i);
# Fanout rule
r := sort(convert(r, 'list'), 'length');
while nops(r) > 0 do
e1 := eval(e, r[-1]=true); if e = e1 then r := r[1..-2]; next end if;
e2 := eval(e, r[-1]=false); if e = e2 then r := r[1..-2]; next end if;
mr1, f1 := summarize(e1, kb, i, 'fst(summary)');
mr2, f2 := summarize(e2, kb, i, 'snd(summary)');
mr1, f1 := summarize(e1, kb, ii, 'fst(summary)');
mr2, f2 := summarize(e2, kb, ii, 'snd(summary)');
return Fanout(mr1, mr2), 'piecewise'(r[-1], f1, f2);
end do;

Expand All @@ -194,7 +210,7 @@ Summary := module ()
e2 := eval(e, r=false); if e = e2 then next end if;
# Index rule
if r :: `=` and Testzero(e2) then
for o in indets(r, 'name') minus convert(i, 'set') do
for o in indets(r, 'name') minus {i} do
if not (var_outerness[o] :: integer) then next end if;
t := getType(kb, o);
if not (t :: specfunc(HInt)) then next end if;
Expand All @@ -206,7 +222,7 @@ Summary := module ()
b := Normalizer(-b/a);
consider(b, ((e1, o, lo, hi, b) -> proc($)
local mr, f;
mr, f := summarize(e1, kb, i, 'idx'(summary, o-lo));
mr, f := summarize(e1, kb, ii, 'idx'(summary, o-lo));
Index(hi-lo+1, o, b, mr),
'piecewise'(And(o::integer, lo<=o, o<=hi), f, 0);
end proc)(e1, o, lo, hi, b));
Expand All @@ -215,8 +231,8 @@ Summary := module ()
# Split rule
consider(r, ((e1, e2, r) -> proc($)
local mr1, f1, mr2, f2;
mr1, f1 := summarize(e1, kb, i, 'fst(summary)');
mr2, f2 := summarize(e2, kb, i, 'snd(summary)');
mr1, f1 := summarize(e1, kb, ii, 'fst(summary)');
mr2, f2 := summarize(e2, kb, ii, 'snd(summary)');
Split(r, mr1, mr2), f1 + f2;
end proc)(e1, e2, r));
end do;
Expand Down

0 comments on commit 27de978

Please sign in to comment.