Skip to content

Commit

Permalink
feat: teach bv_normalize to rewrite subtractions to additions (#6890)
Browse files Browse the repository at this point in the history
This PR teaches bv_normalize to replace subtractions on one side of an
equality with an addition on the other side, this re-write eliminates a
not + addition in the normalized form so it is easier on the solver.

Note that I also make a point to normalize (1 + ~~~x) to (~~~x + 1) to
limit the amount of boilerplate symmetry theorems we require.
  • Loading branch information
vlad902 authored Feb 1, 2025
1 parent 66471ba commit ca96ea3
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 31 deletions.
5 changes: 5 additions & 0 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2566,6 +2566,11 @@ theorem eq_sub_iff_add_eq {x y z : BitVec w} : x = z - y ↔ x + y = z := by
· simp [h, sub_add_cancel]
· simp [←h, add_sub_cancel]

theorem sub_eq_iff_eq_add {x y z : BitVec w} : x - y = z ↔ x = z + y := by
apply Iff.intro <;> intro h
· simp [← h, sub_add_cancel]
· simp [h, add_sub_cancel]

theorem negOne_eq_allOnes : -1#w = allOnes w := by
apply eq_of_toNat_eq
if g : w = 0 then
Expand Down
35 changes: 9 additions & 26 deletions src/Std/Tactic/BVDecide/Normalize/BitVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@ theorem BitVec.and_ones (a : BitVec w) : a &&& (-1#w) = a := by
ext
simp [BitVec.negOne_eq_allOnes]

-- Normalize (1#w + ~~~x) to (~~~x + 1#w) to limit the number of symmetries we need for theorems
-- related to negative BitVecs.
@[bv_normalize]
theorem BitVec.one_plus_not_eq_not_plus_one (x : BitVec w) : (1#w + ~~~x) = (~~~x + 1#w) := by
rw [BitVec.add_comm]

attribute [bv_normalize] BitVec.and_self

@[bv_normalize]
Expand Down Expand Up @@ -207,43 +213,28 @@ theorem BitVec.add_neg (a : BitVec w) : a + (~~~a + 1#w) = 0#w := by
rw [← BitVec.sub_toAdd]
rw [BitVec.sub_self]

@[bv_normalize]
theorem BitVec.add_neg' (a : BitVec w) : a + (1#w + ~~~a) = 0#w := by
rw [BitVec.add_comm 1#w (~~~a)]
rw [BitVec.add_neg]

@[bv_normalize]
theorem BitVec.neg_add (a : BitVec w) : (~~~a + 1#w) + a = 0#w := by
rw [← BitVec.neg_eq_not_add]
rw [BitVec.add_comm]
rw [← BitVec.sub_toAdd]
rw [BitVec.sub_self]

@[bv_normalize]
theorem BitVec.neg_add' (a : BitVec w) : (1#w + ~~~a) + a = 0#w := by
rw [BitVec.add_comm 1#w (~~~a)]
rw [BitVec.neg_add]

@[bv_normalize]
theorem BitVec.not_neg (x : BitVec w) : ~~~(~~~x + 1#w) = x + -1#w := by
rw [← BitVec.neg_eq_not_add x]
rw [_root_.BitVec.not_neg]

@[bv_normalize]
theorem BitVec.not_neg' (x : BitVec w) : ~~~(1#w + ~~~x) = x + -1#w := by
rw [BitVec.add_comm 1#w (~~~x)]
rw [BitVec.not_neg]

@[bv_normalize]
theorem BitVec.not_neg'' (x : BitVec w) : ~~~(x + 1#w) = ~~~x + -1#w := by
theorem BitVec.not_neg' (x : BitVec w) : ~~~(x + 1#w) = ~~~x + -1#w := by
rw [← BitVec.not_not (b := x)]
rw [BitVec.not_neg]
simp

@[bv_normalize]
theorem BitVec.not_neg''' (x : BitVec w) : ~~~(1#w + x) = ~~~x + -1#w := by
theorem BitVec.not_neg'' (x : BitVec w) : ~~~(1#w + x) = ~~~x + -1#w := by
rw [BitVec.add_comm 1#w x]
rw [BitVec.not_neg'']
rw [BitVec.not_neg']

@[bv_normalize]
theorem BitVec.add_same (a : BitVec w) : a + a = a * 2#w := by
Expand All @@ -266,18 +257,10 @@ attribute [bv_normalize] BitVec.sshiftRight'_ofNat_eq_sshiftRight
theorem BitVec.neg_mul (x y : BitVec w) : (~~~x + 1#w) * y = ~~~(x * y) + 1#w := by
rw [← BitVec.neg_eq_not_add, ← BitVec.neg_eq_not_add, _root_.BitVec.neg_mul]

@[bv_normalize]
theorem BitVec.neg_mul' (x y : BitVec w) : (1#w + ~~~x) * y = ~~~(x * y) + 1#w := by
rw [BitVec.add_comm, BitVec.neg_mul]

@[bv_normalize]
theorem BitVec.mul_neg (x y : BitVec w) : x * (~~~y + 1#w) = ~~~(x * y) + 1#w := by
rw [← BitVec.neg_eq_not_add, ← BitVec.neg_eq_not_add, _root_.BitVec.mul_neg]

@[bv_normalize]
theorem BitVec.mul_neg' (x y : BitVec w) : x * (1#w + ~~~y) = ~~~(x * y) + 1#w := by
rw [BitVec.add_comm, BitVec.mul_neg]

attribute [bv_normalize] BitVec.shiftLeft_zero
attribute [bv_normalize] BitVec.zero_shiftLeft

Expand Down
20 changes: 20 additions & 0 deletions src/Std/Tactic/BVDecide/Normalize/Equal.lean
Original file line number Diff line number Diff line change
Expand Up @@ -105,5 +105,25 @@ theorem BitVec.self_eq_add_left (a b : BitVec w) : (a == b + a) = (b == 0#w) :=
rw [Bool.eq_iff_iff]
simp

@[bv_normalize]
theorem BitVec.eq_sub_iff_add_eq (a b c : BitVec w) : (a == c + (~~~b + 1#w)) = (a + b == c) := by
rw [Bool.eq_iff_iff, beq_iff_eq, beq_iff_eq, ← BitVec.neg_eq_not_add, ← @BitVec.sub_toAdd]
exact _root_.BitVec.eq_sub_iff_add_eq

@[bv_normalize]
theorem BitVec.eq_neg_add_iff_add_eq (a b c : BitVec w) : (a == (~~~b + 1#w) + c) = (a + b == c) := by
rw [BitVec.add_comm]
exact BitVec.eq_sub_iff_add_eq _ _ _

@[bv_normalize]
theorem BitVec.sub_eq_iff_eq_add (a b c : BitVec w) : (a + (~~~b + 1#w) == c) = (a == c + b) := by
rw [Bool.eq_iff_iff, beq_iff_eq, beq_iff_eq, ← BitVec.neg_eq_not_add, ← @BitVec.sub_toAdd]
exact _root_.BitVec.sub_eq_iff_eq_add

@[bv_normalize]
theorem BitVec.neg_add_eq_iff_eq_add (a b c : BitVec w) : ((~~~a + 1#w) + b == c) = (b == c + a) := by
rw [BitVec.add_comm]
exact BitVec.sub_eq_iff_eq_add _ _ _

end Frontend.Normalize
end Std.Tactic.BVDecide
37 changes: 32 additions & 5 deletions tests/lean/run/bv_decide_rewriter.lean
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,6 @@ example {x y : BitVec 1} : x + y = x ^^^ y := by bv_normalize
example {x y : BitVec 1} : x * y = x &&& y := by bv_normalize
example {x : BitVec 16} : x / 0 = 0 := by bv_normalize
example {x : BitVec 16} : x % 0 = x := by bv_normalize
example {x : BitVec 16} : ~~~(-x) = x + (-1#16) := by bv_normalize
example {x : BitVec 16} : ~~~(~~~x + 1#16) = x + (-1#16) := by bv_normalize
example {x : BitVec 16} : ~~~(x + 1#16) = ~~~x + (-1#16) := by bv_normalize
example {x : BitVec 16} : ~~~(1#16 + ~~~x) = x + (-1#16) := by bv_normalize
example {x : BitVec 16} : ~~~(1#16 + x) = ~~~x + (-1#16) := by bv_normalize
example {x : BitVec 16} : (10 + x) + 2 = 12 + x := by bv_normalize
example {x : BitVec 16} : (x + 10) + 2 = 12 + x := by bv_normalize
example {x : BitVec 16} : 2 + (x + 10) = 12 + x := by bv_normalize
Expand All @@ -90,6 +85,22 @@ example {x : BitVec 16} : x / (BitVec.ofNat 16 8) = x >>> 3 := by bv_normalize
example {x y : Bool} (h1 : x && y) : x || y := by bv_normalize
example (a b c: Bool) : (if a then b else c) = (if !a then c else b) := by bv_normalize

-- not_neg
example {x : BitVec 16} : ~~~(-x) = x + (-1#16) := by bv_normalize
example {x : BitVec 16} : ~~~(~~~x + 1#16) = x + (-1#16) := by bv_normalize
example {x : BitVec 16} : ~~~(x + 1#16) = ~~~x + (-1#16) := by bv_normalize
example {x : BitVec 16} : ~~~(1#16 + ~~~x) = x + (-1#16) := by bv_normalize
example {x : BitVec 16} : ~~~(1#16 + x) = ~~~x + (-1#16) := by bv_normalize

-- add_neg / neg_add
example (x : BitVec 16) : x + -x = 0 := by bv_normalize
example (x : BitVec 16) : x - x = 0 := by bv_normalize
example (x : BitVec 16) : x + (~~~x + 1) = 0 := by bv_normalize
example (x : BitVec 16) : x + (1 + ~~~x) = 0 := by bv_normalize
example (x : BitVec 16) : -x + x = 0 := by bv_normalize
example (x : BitVec 16) : (~~~x + 1) + x = 0 := by bv_normalize
example (x : BitVec 16) : (1 + ~~~x) + x = 0 := by bv_normalize

-- neg_mul / mul_neg
example (x y : BitVec 16) : (-x) * y = -(x * y) := by bv_normalize
example (x y : BitVec 16) : x * (-y) = -(x * y) := by bv_normalize
Expand Down Expand Up @@ -131,6 +142,22 @@ example (x y : BitVec 16) : (x + y == y) = (x == 0) := by bv_normalize
example (x y : BitVec 16) : (x == x + y) = (y == 0) := by bv_normalize
example (x y : BitVec 16) : (x == y + x) = (y == 0) := by bv_normalize

-- eq_sub_iff_add_eq / sub_eq_iff_eq_add
example (x y z : BitVec 16) : (x + -y == z) = (x == z + y) := by bv_normalize
example (x y z : BitVec 16) : (x - y == z) = (x == z + y) := by bv_normalize
example (x y z : BitVec 16) : (x + (~~~y + 1) == z) = (x == z + y) := by bv_normalize
example (x y z : BitVec 16) : (x + (1 + ~~~y) == z) = (x == z + y) := by bv_normalize
example (x y z : BitVec 16) : (-x + y == z) = (y == z + x) := by bv_normalize
example (x y z : BitVec 16) : ((~~~x + 1) + y == z) = (y == z + x) := by bv_normalize
example (x y z : BitVec 16) : ((1 + ~~~x) + y == z) = (y == z + x) := by bv_normalize
example (x y z : BitVec 16) : (z == x + -y) = (z + y == x) := by bv_normalize
example (x y z : BitVec 16) : (z == x - y) = (z + y == x) := by bv_normalize
example (x y z : BitVec 16) : (z == x + (~~~y + 1)) = (z + y == x) := by bv_normalize
example (x y z : BitVec 16) : (z == x + (1 + ~~~y)) = (z + y == x) := by bv_normalize
example (x y z : BitVec 16) : (z == -x + y) = (z + x == y) := by bv_normalize
example (x y z : BitVec 16) : (z == (~~~x + 1) + y) = (z + x == y) := by bv_normalize
example (x y z : BitVec 16) : (z == (1 + ~~~x) + y) = (z + x == y) := by bv_normalize

-- or_beq_zero_iff
example (x y : BitVec 16) : (x ||| y == 0) = (x == 0 && y == 0) := by bv_normalize
example (x y : BitVec 16) : (0 == x ||| y) = (x == 0 && y == 0) := by bv_normalize
Expand Down

0 comments on commit ca96ea3

Please sign in to comment.