diff --git a/src/Init/Grind/Tactics.lean b/src/Init/Grind/Tactics.lean index 708d19ceaa21..59cfe83d16e6 100644 --- a/src/Init/Grind/Tactics.lean +++ b/src/Init/Grind/Tactics.lean @@ -16,14 +16,16 @@ namespace Attr syntax grindEq := "= " syntax grindEqBoth := atomic("_" "=" "_ ") syntax grindEqRhs := atomic("=" "_ ") -syntax grindEqBwd := atomic("←" "= ") -syntax grindBwd := "← " -syntax grindFwd := "→ " +syntax grindEqBwd := atomic("←" "= ") <|> atomic("<-" "= ") +syntax grindBwd := "← " <|> "-> " +syntax grindFwd := "→ " <|> "<- " +syntax grindRL := "⇐ " <|> "<= " +syntax grindLR := "⇒ " <|> "=> " syntax grindUsr := &"usr " syntax grindCases := &"cases " syntax grindCasesEager := atomic(&"cases" &"eager ") syntax grindIntro := &"intro " -syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro +syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro syntax (name := grind) "grind" (grindMod)? : attr end Attr end Lean.Parser diff --git a/src/Lean/Elab/Tactic/Grind.lean b/src/Lean/Elab/Tactic/Grind.lean index 91cd97085674..32955d3924c0 100644 --- a/src/Lean/Elab/Tactic/Grind.lean +++ b/src/Lean/Elab/Tactic/Grind.lean @@ -197,14 +197,16 @@ private def mkGrindOnly else let decl : Ident := mkIdent (← unresolveNameGlobalAvoidingLocals declName) let param ← match kind with - | .eqLhs => `(Parser.Tactic.grindParam| = $decl) - | .eqRhs => `(Parser.Tactic.grindParam| =_ $decl) - | .eqBoth => `(Parser.Tactic.grindParam| _=_ $decl) - | .eqBwd => `(Parser.Tactic.grindParam| ←= $decl) - | .bwd => `(Parser.Tactic.grindParam| ← $decl) - | .fwd => `(Parser.Tactic.grindParam| → $decl) - | .user => `(Parser.Tactic.grindParam| usr $decl) - | .default => `(Parser.Tactic.grindParam| $decl:ident) + | .eqLhs => `(Parser.Tactic.grindParam| = $decl) + | .eqRhs => `(Parser.Tactic.grindParam| =_ $decl) + | .eqBoth => `(Parser.Tactic.grindParam| _=_ $decl) + | .eqBwd => `(Parser.Tactic.grindParam| ←= $decl) + | .bwd => `(Parser.Tactic.grindParam| ← $decl) + | .fwd => `(Parser.Tactic.grindParam| → $decl) + | .leftRight => `(Parser.Tactic.grindParam| => $decl) + | .rightLeft => `(Parser.Tactic.grindParam| <= $decl) + | .user => `(Parser.Tactic.grindParam| usr $decl) + | .default => `(Parser.Tactic.grindParam| $decl:ident) params := params.push param for declName in trace.eagerCases.toList do unless Grind.isBuiltinEagerCases declName do diff --git a/src/Lean/Meta/Tactic/Grind/Attr.lean b/src/Lean/Meta/Tactic/Grind/Attr.lean index 505b5ae3dce4..60acdebd6aeb 100644 --- a/src/Lean/Meta/Tactic/Grind/Attr.lean +++ b/src/Lean/Meta/Tactic/Grind/Attr.lean @@ -19,11 +19,17 @@ inductive AttrKind where def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do match stx with | `(Parser.Attr.grindMod| =) => return .ematch .eqLhs - | `(Parser.Attr.grindMod| →) => return .ematch .fwd - | `(Parser.Attr.grindMod| ←) => return .ematch .bwd + | `(Parser.Attr.grindMod| →) + | `(Parser.Attr.grindMod| ->) => return .ematch .fwd + | `(Parser.Attr.grindMod| ←) + | `(Parser.Attr.grindMod| <-) => return .ematch .bwd | `(Parser.Attr.grindMod| =_) => return .ematch .eqRhs | `(Parser.Attr.grindMod| _=_) => return .ematch .eqBoth | `(Parser.Attr.grindMod| ←=) => return .ematch .eqBwd + | `(Parser.Attr.grindMod| ⇒) + | `(Parser.Attr.grindMod| =>) => return .ematch .leftRight + | `(Parser.Attr.grindMod| ⇐) + | `(Parser.Attr.grindMod| <=) => return .ematch .rightLeft | `(Parser.Attr.grindMod| usr) => return .ematch .user | `(Parser.Attr.grindMod| cases) => return .cases false | `(Parser.Attr.grindMod| cases eager) => return .cases true diff --git a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean index 26988727d034..c5219109cb03 100644 --- a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean +++ b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean @@ -93,28 +93,32 @@ instance : Hashable Origin where hash a := hash a.key inductive EMatchTheoremKind where - | eqLhs | eqRhs | eqBoth | eqBwd | fwd | bwd | default | user /- pattern specified using `grind_pattern` command -/ + | eqLhs | eqRhs | eqBoth | eqBwd | fwd | bwd | leftRight | rightLeft | default | user /- pattern specified using `grind_pattern` command -/ deriving Inhabited, BEq, Repr, Hashable private def EMatchTheoremKind.toAttribute : EMatchTheoremKind → String - | .eqLhs => "[grind =]" - | .eqRhs => "[grind =_]" - | .eqBoth => "[grind _=_]" - | .eqBwd => "[grind ←=]" - | .fwd => "[grind →]" - | .bwd => "[grind ←]" - | .default => "[grind]" - | .user => "[grind]" + | .eqLhs => "[grind =]" + | .eqRhs => "[grind =_]" + | .eqBoth => "[grind _=_]" + | .eqBwd => "[grind ←=]" + | .fwd => "[grind →]" + | .bwd => "[grind ←]" + | .leftRight => "[grind =>]" + | .rightLeft => "[grind <=]" + | .default => "[grind]" + | .user => "[grind]" private def EMatchTheoremKind.explainFailure : EMatchTheoremKind → String - | .eqLhs => "failed to find pattern in the left-hand side of the theorem's conclusion" - | .eqRhs => "failed to find pattern in the right-hand side of the theorem's conclusion" - | .eqBoth => unreachable! -- eqBoth is a macro - | .eqBwd => "failed to use theorem's conclusion as a pattern" - | .fwd => "failed to find patterns in the antecedents of the theorem" - | .bwd => "failed to find patterns in the theorem's conclusion" - | .default => "failed to find patterns" - | .user => unreachable! + | .eqLhs => "failed to find pattern in the left-hand side of the theorem's conclusion" + | .eqRhs => "failed to find pattern in the right-hand side of the theorem's conclusion" + | .eqBoth => unreachable! -- eqBoth is a macro + | .eqBwd => "failed to use theorem's conclusion as a pattern" + | .fwd => "failed to find patterns in the antecedents of the theorem" + | .bwd => "failed to find patterns in the theorem's conclusion" + | .leftRight => "failed to find patterns searching from left to right" + | .rightLeft => "failed to find patterns searching from right to left" + | .default => "failed to find patterns" + | .user => unreachable! /-- A theorem for heuristic instantiation based on E-matching. -/ structure EMatchTheorem where @@ -664,22 +668,24 @@ where | .bvar idx => modify fun s => if s.contains idx then s else idx :: s | _ => return () +private def diff (s : List Nat) (found : Std.HashSet Nat) : List Nat := + if found.isEmpty then s else s.filter fun x => !found.contains x + /-- Returns `true` if pattern `p` contains a child `c` such that -1- `p` and `c` have the same pattern variables. +1- `p` and `c` have the same new pattern variables. We say a pattern variable is new if it is not in `alreadyFound`. 2- `c` is not a support argument. See `NormalizePattern.getPatternSupportMask` for definition. 3- `c` is not an offset pattern. 4- `c` is not a bound variable. -/ -private def hasChildWithSameBVars (p : Expr) (supportMask : Array Bool) : CoreM Bool := do - let s := collectPatternBVars p +private def hasChildWithSameNewBVars (p : Expr) (supportMask : Array Bool) (alreadyFound : Std.HashSet Nat) : CoreM Bool := do + let s := diff (collectPatternBVars p) alreadyFound for arg in p.getAppArgs, support in supportMask do unless support do unless arg.isBVar do unless isOffsetPattern? arg |>.isSome do - let sArg := collectPatternBVars arg + let sArg := diff (collectPatternBVars arg) alreadyFound if s ⊆ sArg then - trace[Meta.debug] "SKIPPED: {p}, {arg}, {s}, {sArg}" return true return false @@ -699,7 +705,7 @@ private partial def collect (e : Expr) : CollectorM Unit := do return () let p ← NormalizePattern.normalizePattern p if saved.bvarsFound.size < (← getThe NormalizePattern.State).bvarsFound.size then - unless (← hasChildWithSameBVars p supportMask) do + unless (← hasChildWithSameNewBVars p supportMask saved.bvarsFound) do addNewPattern p return () trace[grind.ematch.pattern.search] "skip, no new variables covered" @@ -812,6 +818,8 @@ def mkEMatchTheoremWithKind? throwError "invalid `grind` forward theorem, theorem `{← origin.pp}` does not have propositional hypotheses" pure ps | .bwd => pure #[type] + | .leftRight => pure <| (← getPropTypes xs).push type + | .rightLeft => pure <| #[type] ++ (← getPropTypes xs).reverse | .default => pure <| #[type] ++ (← getPropTypes xs) | _ => unreachable! go xs searchPlaces diff --git a/src/Lean/Meta/Tactic/Grind/ForallProp.lean b/src/Lean/Meta/Tactic/Grind/ForallProp.lean index 3fada63a9faf..77a14c162b75 100644 --- a/src/Lean/Meta/Tactic/Grind/ForallProp.lean +++ b/src/Lean/Meta/Tactic/Grind/ForallProp.lean @@ -70,9 +70,9 @@ private def addLocalEMatchTheorems (e : Expr) : GoalM Unit := do let size := (← get).newThms.size let gen ← getGeneration e -- TODO: we should have a flag for collecting all unary patterns in a local theorem - if let some thm ← mkEMatchTheoremWithKind'? origin proof .fwd then + if let some thm ← mkEMatchTheoremWithKind'? origin proof .leftRight then activateTheorem thm gen - if let some thm ← mkEMatchTheoremWithKind'? origin proof .bwd then + if let some thm ← mkEMatchTheoremWithKind'? origin proof .rightLeft then activateTheorem thm gen if (← get).newThms.size == size then if let some thm ← mkEMatchTheoremWithKind'? origin proof .default then diff --git a/tests/lean/run/grind_attrs.lean b/tests/lean/run/grind_attrs.lean new file mode 100644 index 000000000000..67938973f005 --- /dev/null +++ b/tests/lean/run/grind_attrs.lean @@ -0,0 +1,52 @@ +opaque R : Nat → Nat → Prop + +@[grind ->] +axiom Rtrans {x y z : Nat} : R x y → R y z → R x z + +@[grind →] +axiom Rtrans' {x y z : Nat} : R x y → R y z → R x z + +@[grind <-] +axiom Rsymm {x y : Nat} : R x y → R y x + +@[grind ←] +axiom Rsymm' {x y : Nat} : R x y → R y x + +example : R a b → R b c → R d c → R a d := by + grind only [-> Rtrans, <- Rsymm] + +example : R a b → R b c → R d c → R a d := by + grind only [→ Rtrans, ← Rsymm] + + +opaque State : Type +opaque State.le (σ₁ σ₂ : State) : Prop +axiom State.update : State → Nat → Nat → State +opaque Expr : Type +opaque Expr.eval : Expr → State → Nat +axiom Expr.constProp : Expr → State → Expr + + +/-- +info: [grind.ematch.pattern] Expr.eval_constProp_of_sub: [State.le #3 #2, Expr.constProp #1 #3] +-/ +#guard_msgs (info) in +set_option trace.grind.ematch.pattern true in +@[grind =>] theorem Expr.eval_constProp_of_sub (e : Expr) (h : State.le σ' σ) : (e.constProp σ').eval σ = e.eval σ := + sorry + +/-- +info: [grind.ematch.pattern] Expr.eval_constProp_of_eq_of_sub: [State.le #3 #2, Expr.constProp #1 #3] +-/ +#guard_msgs (info) in +set_option trace.grind.ematch.pattern true in +@[grind =>] theorem Expr.eval_constProp_of_eq_of_sub {e : Expr} (h₂ : State.le σ' σ) : (e.constProp σ').eval σ = e.eval σ := + sorry + +/-- +info: [grind.ematch.pattern] State.update_le_update: [State.le #4 #3, State.update #4 #2 #1] +-/ +#guard_msgs (info) in +set_option trace.grind.ematch.pattern true in +@[grind =>] theorem State.update_le_update (h : State.le σ' σ) : State.le (σ'.update x v) (σ.update x v) := + sorry diff --git a/tests/lean/run/grind_constProp.lean b/tests/lean/run/grind_constProp.lean index fca4e6d1eda9..8c40908ee150 100644 --- a/tests/lean/run/grind_constProp.lean +++ b/tests/lean/run/grind_constProp.lean @@ -280,7 +280,7 @@ attribute [local grind] State.le State.erase State.find? State.update theorem State.le_trans : σ₁ ≼ σ₂ → σ₂ ≼ σ₃ → σ₁ ≼ σ₃ := by grind -theorem State.bot_le (σ : State) : ⊥ ≼ σ := by +@[grind] theorem State.bot_le (σ : State) : ⊥ ≼ σ := by grind theorem State.erase_le_cons (h : σ' ≼ σ) : σ'.erase x ≼ ((x, v) :: σ) := by @@ -320,41 +320,25 @@ theorem State.erase_le_of_le_cons (h : σ' ≼ (x, v) :: σ) : σ'.erase x ≼ @[grind] theorem State.erase_le_update (h : σ' ≼ σ) : σ'.erase x ≼ σ.update x v := by grind -@[grind] theorem State.update_le_update (h : σ' ≼ σ) : σ'.update x v ≼ σ.update x v := by +@[grind =>] theorem State.update_le_update (h : σ' ≼ σ) : σ'.update x v ≼ σ.update x v := by grind -grind_pattern State.update_le_update => σ' ≼ σ, σ'.update x v - -@[grind] theorem Expr.eval_constProp_of_sub (e : Expr) (h : σ' ≼ σ) : (e.constProp σ').eval σ = e.eval σ := by +@[grind =>] theorem Expr.eval_constProp_of_sub (e : Expr) (h : σ' ≼ σ) : (e.constProp σ').eval σ = e.eval σ := by induction e <;> grind --- TODO: better pattern selection heuristic. We want to avoid the following step. -grind_pattern Expr.eval_constProp_of_sub => σ' ≼ σ, e.constProp σ' - -theorem Expr.eval_constProp_of_eq_of_sub {e : Expr} (h₂ : σ' ≼ σ) : (e.constProp σ').eval σ = e.eval σ := by +@[grind =>] theorem Expr.eval_constProp_of_eq_of_sub {e : Expr} (h₂ : σ' ≼ σ) : (e.constProp σ').eval σ = e.eval σ := by grind -grind_pattern Expr.eval_constProp_of_eq_of_sub => σ' ≼ σ, e.constProp σ' - -theorem Stmt.constProp_sub (h₁ : (σ₁, s) ⇓ σ₂) (h₂ : σ₁' ≼ σ₁) : (s.constProp σ₁').2 ≼ σ₂ := by +@[grind =>] theorem Stmt.constProp_sub (h₁ : (σ₁, s) ⇓ σ₂) (h₂ : σ₁' ≼ σ₁) : (s.constProp σ₁').2 ≼ σ₂ := by induction h₁ generalizing σ₁' with grind [=_ Expr.eval_simplify] -grind_pattern Stmt.constProp_sub => (σ₁, s) ⇓ σ₂, s.constProp σ₁' - end -theorem Stmt.constProp_correct (h₁ : (σ₁, s) ⇓ σ₂) (h₂ : σ₁' ≼ σ₁) : (σ₁, (s.constProp σ₁').1) ⇓ σ₂ := by +@[grind] theorem Stmt.constProp_correct (h₁ : (σ₁, s) ⇓ σ₂) (h₂ : σ₁' ≼ σ₁) : (σ₁, (s.constProp σ₁').1) ⇓ σ₂ := by induction h₁ generalizing σ₁' <;> try grind [=_ Expr.eval_simplify, intro Bigstep] - next heq h₁ h₂ ih₁ ih₂ => - -- TODO: we need better heuristics for selecting patterns for local quantifiers. - -- both `ih₁` and `ih₂` are local, and the current pattern selection picks reall bad patterns. - have ih₁ := ih₁ (State.bot_le _) - have ih₂ := ih₂ (State.bot_le _) - grind [intro Bigstep, constProp] - -def Stmt.constPropagation (s : Stmt) : Stmt := + +@[grind] def Stmt.constPropagation (s : Stmt) : Stmt := (s.constProp ⊥).1 theorem Stmt.constPropagation_correct (h : (σ, s) ⇓ σ') : (σ, s.constPropagation) ⇓ σ' := by - -- TODO: grind [constProp_correct, State.bot_le] - exact constProp_correct h (State.bot_le _) + grind diff --git a/tests/lean/run/grind_ematch2.lean b/tests/lean/run/grind_ematch2.lean index c300622ac8ec..2d3d6f384738 100644 --- a/tests/lean/run/grind_ematch2.lean +++ b/tests/lean/run/grind_ematch2.lean @@ -47,19 +47,19 @@ info: [grind] Counters [thm] Array.size_set ↦ 3 --- info: [diag] Diagnostics - [reduction] unfolded declarations (max: 11839, num: 3): - [reduction] LT.lt ↦ 11839 - [reduction] getElem ↦ 64 - [reduction] Nat.lt ↦ 32 - [reduction] unfolded instances (max: 32, num: 1): - [reduction] Array.instGetElemNatLtSize ↦ 32 - [reduction] unfolded reducible declarations (max: 7079, num: 7): - [reduction] Array.size ↦ 7079 - [reduction] Array.toList ↦ 1885 - [reduction] autoParam ↦ 1715 - [reduction] outParam ↦ 124 - [reduction] Ne ↦ 57 - [reduction] GT.gt ↦ 40 + [reduction] unfolded declarations (max: 11842, num: 3): + [reduction] LT.lt ↦ 11842 + [reduction] getElem ↦ 76 + [reduction] Nat.lt ↦ 35 + [reduction] unfolded instances (max: 38, num: 1): + [reduction] Array.instGetElemNatLtSize ↦ 38 + [reduction] unfolded reducible declarations (max: 7091, num: 7): + [reduction] Array.size ↦ 7091 + [reduction] Array.toList ↦ 1897 + [reduction] autoParam ↦ 1724 + [reduction] outParam ↦ 172 + [reduction] Ne ↦ 60 + [reduction] GT.gt ↦ 46 [reduction] List.casesOn ↦ 24 [def_eq] heuristic for solving `f a =?= f b` (max: 5067, num: 2): [def_eq] Nat.lt ↦ 5067 diff --git a/tests/lean/run/grind_eq_bwd.lean b/tests/lean/run/grind_eq_bwd.lean index 16214c52b7b1..46b3f3db103e 100644 --- a/tests/lean/run/grind_eq_bwd.lean +++ b/tests/lean/run/grind_eq_bwd.lean @@ -26,6 +26,9 @@ attribute [grind ←=] inv_eq example {a b : α} (w : mul a b = one) : inv a = b := by grind +example {a b : α} (w : mul a b = one) : inv a = b := by + grind only [<-= inv_eq] + structure S where f : Bool → α h : mul (f true) (f false) = one @@ -36,6 +39,9 @@ attribute [grind =] S.h S.h' example (s : S) : inv (s.f true) = s.f false := by grind +example (s : S) : inv (s.f true) = s.f false := by + grind only [<-= inv_eq, = S.h] + example (s : S) : s.f false = inv (s.f true) := by grind diff --git a/tests/lean/run/grind_pattern2.lean b/tests/lean/run/grind_pattern2.lean index 8e0db05208be..5c0cee1e4d81 100644 --- a/tests/lean/run/grind_pattern2.lean +++ b/tests/lean/run/grind_pattern2.lean @@ -56,10 +56,8 @@ info: [grind.internalize] foo x y [grind.internalize] x [grind.internalize] y [grind.internalize] z ---- -warning: declaration uses 'sorry' -/ -#guard_msgs in +#guard_msgs (info) in set_option trace.grind.internalize true in example : foo x y = z → False := by fail_if_success grind diff --git a/tests/lean/run/grind_trace.lean b/tests/lean/run/grind_trace.lean index 41bbc4ce76fe..3c79c764f3f6 100644 --- a/tests/lean/run/grind_trace.lean +++ b/tests/lean/run/grind_trace.lean @@ -41,9 +41,7 @@ info: Try this: grind only [= List.getElem?_replicate, = List.getElem?_eq_some_i theorem map_replicate' : (List.replicate n a).map f = List.replicate n (f a) := by grind? -/-- -info: Try this: grind only [= List.getLast?_eq_some_iff, List.mem_concat_self] --/ +/-- info: Try this: grind only [List.mem_concat_self, = List.getLast?_eq_some_iff] -/ #guard_msgs (info) in theorem mem_of_getLast?_eq_some' {xs : List α} {a : α} (h : xs.getLast? = some a) : a ∈ xs := by grind? diff --git a/tests/lean/run/no_grind_constProp.lean b/tests/lean/run/no_grind_constProp.lean new file mode 100644 index 000000000000..df1035475e2d --- /dev/null +++ b/tests/lean/run/no_grind_constProp.lean @@ -0,0 +1,508 @@ +set_option profiler true + +abbrev Var := String + +inductive Val where + | int (i : Int) + | bool (b : Bool) + deriving DecidableEq, Repr + +instance : Coe Bool Val where + coe b := .bool b + +instance : OfNat Val n where + ofNat := .int n + +inductive BinOp where + | eq | and | lt | add | sub + deriving Repr + +inductive UnaryOp where + | not + deriving Repr + +inductive Expr where + | val (v : Val) + | var (x : Var) + | bin (lhs : Expr) (op : BinOp) (rhs : Expr) + | una (op : UnaryOp) (arg : Expr) + deriving Repr + +@[simp] def BinOp.eval : BinOp → Val → Val → Option Val + | .eq, v₁, v₂ => some (.bool (v₁ = v₂)) + | .and, .bool b₁, .bool b₂ => some (.bool (b₁ && b₂)) + | .lt, .int i₁, .int i₂ => some (.bool (i₁ < i₂)) + | .add, .int i₁, .int i₂ => some (.int (i₁ + i₂)) + | .sub, .int i₁, .int i₂ => some (.int (i₁ - i₂)) + | _, _, _ => none + +@[simp] def UnaryOp.eval : UnaryOp → Val → Option Val + | .not, .bool b => some (.bool !b) + | _, _ => none + +inductive Stmt where + | skip + | assign (x : Var) (e : Expr) + | seq (s₁ s₂ : Stmt) + | ite (c : Expr) (e t : Stmt) + | while (c : Expr) (b : Stmt) + deriving Repr + +infix:150 " ::= " => Stmt.assign +infixr:130 ";; " => Stmt.seq + +abbrev State := List (Var × Val) + +@[simp] def State.update (σ : State) (x : Var) (v : Val) : State := + match σ with + | [] => [(x, v)] + | (y, w)::σ => if x = y then (x, v)::σ else (y, w) :: update σ x v + +@[simp] def State.find? (σ : State) (x : Var) : Option Val := + match σ with + | [] => none + | (y, v) :: σ => if x = y then some v else find? σ x + +def State.get (σ : State) (x : Var) : Val := + σ.find? x |>.getD (.int 0) + +@[simp] def State.erase (σ : State) (x : Var) : State := + match σ with + | [] => [] + | (y, v) :: σ => if x = y then erase σ x else (y, v) :: erase σ x + +@[simp] theorem State.find?_update_self (σ : State) (x : Var) (v : Val) : (σ.update x v).find? x = some v := by + match σ with -- TODO: automate this proof + | [] => simp + | (y, w) :: s => + simp + split <;> simp [*] + apply find?_update_self + +@[simp] theorem State.find?_update (σ : State) (v : Val) (h : x ≠ z) : (σ.update x v).find? z = σ.find? z := by + match σ with -- TODO: automate this proof + | [] => simp [h.symm] + | (y, w) :: σ => + simp + split <;> simp [*] + next hc => split <;> simp_all + next => + split + next => rfl + next => exact find?_update σ v h + +-- TODO: remove after we add better automation +@[simp] theorem State.find?_update' (σ : State) (v : Val) (h : z ≠ x) : (σ.update x v).find? z = σ.find? z := + State.find?_update σ v h.symm + +theorem State.get_of_find? {σ : State} (h : σ.find? x = some v) : σ.get x = v := by + simp [State.get, h, Option.getD] + +@[simp] theorem State.find?_erase_self (σ : State) (x : Var) : (σ.erase x).find? x = none := by + match σ with + | [] => simp + | (y, w) :: σ => + simp + split <;> simp [*] + next => exact find?_erase_self σ y + next => exact find?_erase_self σ x + +@[simp] theorem State.find?_erase (σ : State) (h : x ≠ z) : (σ.erase x).find? z = σ.find? z := by + match σ with + | [] => simp + | (y, w) :: σ => + simp + split <;> simp [*] + next hxy => rw [hxy] at h; simp [h.symm]; exact find?_erase σ h + next => + split + next => rfl + next => exact find?_erase σ h + +-- TODO: remove after we add better automation +@[simp] theorem State.find?_erase' (σ : State) (h : z ≠ x) : (σ.erase x).find? z = σ.find? z := + State.find?_erase σ h.symm + +syntax ident " ↦ " term : term + +macro_rules + | `($id:ident ↦ $v:term) => `(($(Lean.quote id.getId.toString), $v:term)) + +example : State.get [x ↦ .int 10, y ↦ .int 20] "x" = .int 10 := rfl + +example : State.get [x ↦ 10, y ↦ 20] "x" = 10 := rfl +example : State.get [x ↦ 10, y ↦ true] "y" = true := rfl + +@[simp] def Expr.eval (σ : State) : Expr → Option Val + | val v => some v + | var x => σ.get x + | bin lhs op rhs => match lhs.eval σ, rhs.eval σ with + | some v₁, some v₂ => op.eval v₁ v₂ -- BinOp.eval : BinOp → Val → Val → Option Val + | _, _ => none + | una op arg => match arg.eval σ with + | some v => op.eval v + | _ => none + + +@[simp] def evalTrue (c : Expr) (σ : State) : Prop := c.eval σ = some (Val.bool true) +@[simp] def evalFalse (c : Expr) (σ : State) : Prop := c.eval σ = some (Val.bool false) + +section +set_option hygiene false -- HACK: allow forward reference in notation +local notation:60 "(" σ ", " s ")" " ⇓ " σ':60 => Bigstep σ s σ' + +inductive Bigstep : State → Stmt → State → Prop where + | skip : (σ, .skip) ⇓ σ + | assign : e.eval σ = some v → (σ, x ::= e) ⇓ σ.update x v + | seq : (σ₁, s₁) ⇓ σ₂ → (σ₂, s₂) ⇓ σ₃ → (σ₁, s₁ ;; s₂) ⇓ σ₃ + | ifTrue : evalTrue c σ₁ → (σ₁, t) ⇓ σ₂ → (σ₁, .ite c t e) ⇓ σ₂ + | ifFalse : evalFalse c σ₁ → (σ₁, e) ⇓ σ₂ → (σ₁, .ite c t e) ⇓ σ₂ + | whileTrue : evalTrue c σ₁ → (σ₁, b) ⇓ σ₂ → (σ₂, .while c b) ⇓ σ₃ → (σ₁, .while c b) ⇓ σ₃ + | whileFalse : evalFalse c σ → (σ, .while c b) ⇓ σ + +end + +notation:60 "(" σ ", " s ")" " ⇓ " σ':60 => Bigstep σ s σ' + +/- This proof can be automated using forward reasoning. -/ +theorem Bigstem.det (h₁ : (σ, s) ⇓ σ₁) (h₂ : (σ, s) ⇓ σ₂) : σ₁ = σ₂ := by + induction h₁ generalizing σ₂ <;> cases h₂ <;> try simp_all + -- The rest of this proof should be automatic with congruence closure and a bit of forward reasoning + case seq ih₁ ih₂ _ h₁ h₂ => + simp [ih₁ h₁] at ih₂ + simp [ih₂ h₂] + case ifTrue ih h => + simp [ih h] + case ifFalse ih h => + simp [ih h] + case whileTrue ih₁ ih₂ h₁ h₂ => + simp [ih₁ h₁] at ih₂ + simp [ih₂ h₂] + +abbrev EvalM := ExceptT String (StateM State) + +def evalExpr (e : Expr) : EvalM Val := do + match e.eval (← get) with + | some v => return v + | none => throw "failed to evaluate" + +@[simp] def Stmt.eval (stmt : Stmt) (fuel : Nat := 100) : EvalM Unit := do + match fuel with + | 0 => throw "out of fuel" + | fuel+1 => + match stmt with + | skip => return () + | assign x e => let v ← evalExpr e; modify fun s => s.update x v + | seq s₁ s₂ => s₁.eval fuel; s₂.eval fuel + | ite c e t => + match (← evalExpr c) with + | .bool true => e.eval fuel + | .bool false => t.eval fuel + | _ => throw "Boolean expected" + | .while c b => + match (← evalExpr c) with + | .bool true => b.eval fuel; stmt.eval fuel + | .bool false => return () + | _ => throw "Boolean expected" + + +@[simp] def BinOp.simplify : BinOp → Expr → Expr → Expr + | .eq, .val v₁, .val v₂ => .val (.bool (v₁ = v₂)) + | .and, .val (.bool a), .val (.bool b) => .val (.bool (a && b)) + | .lt, .val (.int a), .val (.int b) => .val (.bool (a < b)) + | .add, .val (.int a), .val (.int b) => .val (.int (a + b)) + | .sub, .val (.int a), .val (.int b) => .val (.int (a - b)) + | op, a, b => .bin a op b + +@[simp] def UnaryOp.simplify : UnaryOp → Expr → Expr + | .not, .val (.bool b) => .val (.bool !b) + | op, a => .una op a + +@[simp] def Expr.simplify : Expr → Expr + | bin lhs op rhs => op.simplify lhs.simplify rhs.simplify + | una op arg => op.simplify arg.simplify + | e => e + +@[simp] theorem Expr.eval_simplify (e : Expr) : e.simplify.eval σ = e.eval σ := by + induction e with + -- Due to fine-grained equational theorems we have to pass `eq_def` lemmas here + simp only [simplify, BinOp.simplify.eq_def, eval, UnaryOp.simplify.eq_def] + | bin lhs op rhs ih_lhs ih_rhs => + simp [← ih_lhs, ← ih_rhs] + split <;> simp [*] + | una op arg ih_arg => + simp [← ih_arg] + split <;> simp [*] + +@[simp] def Stmt.simplify : Stmt → Stmt + | skip => skip + | assign x e => assign x e.simplify + | seq s₁ s₂ => seq s₁.simplify s₂.simplify + | ite c e t => + match c.simplify with + | .val (.bool true) => e.simplify + | .val (.bool false) => t.simplify + | c' => ite c' e.simplify t.simplify + | .while c b => + match c.simplify with + | .val (.bool false) => skip + | c' => .while c' b.simplify + +theorem Stmt.simplify_correct (h : (σ, s) ⇓ σ') : (σ, s.simplify) ⇓ σ' := by + induction h with simp_all + | skip => exact Bigstep.skip + | seq h₁ h₂ ih₁ ih₂ => exact Bigstep.seq ih₁ ih₂ + | assign => apply Bigstep.assign; simp [*] + | whileTrue heq h₁ h₂ ih₁ ih₂ => + rw [← Expr.eval_simplify] at heq + split + next h => rw [h] at heq; simp at heq + next hnp => simp [hnp] at ih₂; apply Bigstep.whileTrue heq ih₁ ih₂ + | whileFalse heq => + split + next => exact Bigstep.skip + next => apply Bigstep.whileFalse; simp [heq] + | ifFalse heq h ih => + rw [← Expr.eval_simplify] at heq + split <;> simp_all + rw [← Expr.eval_simplify] at heq + apply Bigstep.ifFalse heq ih + | ifTrue heq h ih => + rw [← Expr.eval_simplify] at heq + split <;> simp_all + rw [← Expr.eval_simplify] at heq + apply Bigstep.ifTrue heq ih + +@[simp] def Expr.constProp (e : Expr) (σ : State) : Expr := + match e with + | val v => .val v + | var x => match σ.find? x with + | some v => val v + | none => var x + | bin lhs op rhs => bin (lhs.constProp σ) op (rhs.constProp σ) + | una op arg => una op (arg.constProp σ) + +@[simp] theorem Expr.constProp_nil (e : Expr) : e.constProp [] = e := by + induction e <;> simp [*] + +def State.length_erase_le (σ : State) (x : Var) : (σ.erase x).length ≤ σ.length := by + match σ with + | [] => simp + | (y, v) :: σ => + by_cases hxy : x = y <;> simp [hxy] + next => exact Nat.le_trans (length_erase_le σ y) (by simp_arith) + next => simp_arith [length_erase_le σ x] + +def State.length_erase_lt (σ : State) (x : Var) : (σ.erase x).length < σ.length.succ := + Nat.lt_of_le_of_lt (length_erase_le ..) (by simp_arith) + +@[simp] def State.join (σ₁ σ₂ : State) : State := + match σ₁ with + | [] => [] + | (x, v) :: σ₁ => + let σ₁' := erase σ₁ x -- Must remove duplicates. Alternative design: carry invariant that input state at constProp has no duplicates + have : (erase σ₁ x).length < σ₁.length.succ := length_erase_lt .. + match σ₂.find? x with + | some w => if v = w then (x, v) :: join σ₁' σ₂ else join σ₁' σ₂ + | none => join σ₁' σ₂ +termination_by σ₁.length + +local notation "⊥" => [] + +@[simp] def Stmt.constProp (s : Stmt) (σ : State) : Stmt × State := + match s with + | skip => (skip, σ) + | assign x e => match (e.constProp σ).simplify with + | (.val v) => (assign x (.val v), σ.update x v) + | e' => (assign x e', σ.erase x) + | seq s₁ s₂ => match s₁.constProp σ with + | (s₁', σ₁) => match s₂.constProp σ₁ with + | (s₂', σ₂) => (seq s₁' s₂', σ₂) + | ite c s₁ s₂ => + match s₁.constProp σ, s₂.constProp σ with + | (s₁', σ₁), (s₂', σ₂) => (ite (c.constProp σ) s₁' s₂', σ₁.join σ₂) + | .while c b => (.while (c.constProp ⊥) (b.constProp ⊥).1, ⊥) + +def State.le (σ₁ σ₂ : State) : Prop := + ∀ ⦃x : Var⦄ ⦃v : Val⦄, σ₁.find? x = some v → σ₂.find? x = some v + +infix:50 " ≼ " => State.le + +theorem State.le_refl (σ : State) : σ ≼ σ := + fun _ _ h => h + +theorem State.le_trans : σ₁ ≼ σ₂ → σ₂ ≼ σ₃ → σ₁ ≼ σ₃ := + fun h₁ h₂ x v h => h₂ (h₁ h) + +theorem State.bot_le (σ : State) : ⊥ ≼ σ := + fun _ _ h => by contradiction + +theorem State.erase_le_cons (h : σ' ≼ σ) : σ'.erase x ≼ ((x, v) :: σ) := by + intro y w hf' + by_cases hyx : y = x <;> simp [*] at hf' |- + exact h hf' + +theorem State.cons_le_cons (h : σ' ≼ σ) : (x, v) :: σ' ≼ (x, v) :: σ := by + intro y w hf' + by_cases hyx : y = x <;> simp [*] at hf' |- + next => assumption + next => exact h hf' + +theorem State.cons_le_of_eq (h₁ : σ' ≼ σ) (h₂ : σ.find? x = some v) : (x, v) :: σ' ≼ σ := by + intro y w hf' + by_cases hyx : y = x <;> simp [*] at hf' |- + next => assumption + next => exact h₁ hf' + +theorem State.erase_le (σ : State) : σ.erase x ≼ σ := by + match σ with + | [] => simp; apply le_refl + | (y, v) :: σ => + simp + split <;> try simp [*] + next => apply erase_le_cons; apply le_refl + next => apply cons_le_cons; apply erase_le + +theorem State.join_le_left (σ₁ σ₂ : State) : σ₁.join σ₂ ≼ σ₁ := by + match σ₁ with + | [] => simp; apply le_refl + | (x, v) :: σ₁ => + simp + have : (erase σ₁ x).length < σ₁.length.succ := length_erase_lt .. + have ih := join_le_left (State.erase σ₁ x) σ₂ + split + next y w h => + split + next => apply cons_le_cons; apply le_trans ih (erase_le _) + next => apply le_trans ih (erase_le_cons (le_refl _)) + next h => apply le_trans ih (erase_le_cons (le_refl _)) +termination_by σ₁.length + +theorem State.join_le_left_of (h : σ₁ ≼ σ₂) (σ₃ : State) : σ₁.join σ₃ ≼ σ₂ := + le_trans (join_le_left σ₁ σ₃) h + +theorem State.join_le_right (σ₁ σ₂ : State) : σ₁.join σ₂ ≼ σ₂ := by + match σ₁ with + | [] => simp; apply bot_le + | (x, v) :: σ₁ => + simp + have : (erase σ₁ x).length < σ₁.length.succ := length_erase_lt .. + have ih := join_le_right (erase σ₁ x) σ₂ + split + next y w h => + split <;> simp [*] + next => apply cons_le_of_eq ih h + next h => assumption +termination_by σ₁.length + +theorem State.join_le_right_of (h : σ₁ ≼ σ₂) (σ₃ : State) : σ₃.join σ₁ ≼ σ₂ := + le_trans (join_le_right σ₃ σ₁) h + +theorem State.eq_bot (h : σ ≼ ⊥) : σ = ⊥ := by + match σ with + | [] => simp + | (y, v) :: σ => + have : State.find? ((y, v) :: σ) y = some v := by simp + have := h this + contradiction + +theorem State.erase_le_of_le_cons (h : σ' ≼ (x, v) :: σ) : σ'.erase x ≼ σ := by + intro y w hf' + by_cases hxy : x = y <;> simp [*] at hf' + have hf := h hf' + simp [hxy, Ne.symm hxy] at hf + assumption + +theorem State.erase_le_update (h : σ' ≼ σ) : σ'.erase x ≼ σ.update x v := by + intro y w hf' + by_cases hxy : x = y <;> simp [*] at hf' |- + exact h hf' + +theorem State.update_le_update (h : σ' ≼ σ) : σ'.update x v ≼ σ.update x v := by + intro y w hf + induction σ generalizing σ' hf with + | nil => rw [eq_bot h] at hf; assumption + | cons zw' σ ih => + have (z, w') := zw'; simp + have : σ'.erase z ≼ σ := erase_le_of_le_cons h + have ih := ih this + revert ih hf + split <;> simp [*] <;> by_cases hyz : y = z <;> simp (config := { contextual := true }) [*] + next => + intro he' + have he := h he' + simp [*] at he + assumption + next => + by_cases hxy : x = y <;> simp_all + +theorem Expr.eval_constProp_of_sub (e : Expr) (h : σ' ≼ σ) : (e.constProp σ').eval σ = e.eval σ := by + induction e with simp [*] + | var x => + split <;> simp + next he => rw [State.get_of_find? (h he)] + +theorem Expr.eval_constProp_of_eq_of_sub {e : Expr} (h₁ : e.eval σ = v) (h₂ : σ' ≼ σ) : (e.constProp σ').eval σ = v := by + have := eval_constProp_of_sub e h₂ + simp [h₁] at this + assumption + +theorem Stmt.constProp_sub (h₁ : (σ₁, s) ⇓ σ₂) (h₂ : σ₁' ≼ σ₁) : (s.constProp σ₁').2 ≼ σ₂ := by + induction h₁ generalizing σ₁' with simp + | skip => assumption + | assign heq => + split <;> simp + next h => + have heq' := Expr.eval_constProp_of_eq_of_sub heq h₂ + rw [← Expr.eval_simplify, h] at heq' + simp at heq' + rw [heq'] + apply State.update_le_update h₂ + next h _ _ => + exact State.erase_le_update h₂ + | whileTrue heq h₃ h₄ ih₃ ih₄ => + have ih₃ := ih₃ h₂ + have ih₄ := ih₄ ih₃ + simp [heq] at ih₄ + exact ih₄ + | whileFalse heq => apply State.bot_le + | ifTrue heq h ih => + have ih := ih h₂ + apply State.join_le_left_of ih + | ifFalse heq h ih => + have ih := ih h₂ + apply State.join_le_right_of ih + | seq h₃ h₄ ih₃ ih₄ => exact ih₄ (ih₃ h₂) + +theorem Stmt.constProp_correct (h₁ : (σ₁, s) ⇓ σ₂) (h₂ : σ₁' ≼ σ₁) : (σ₁, (s.constProp σ₁').1) ⇓ σ₂ := by + induction h₁ generalizing σ₁' with simp_all + | skip => exact Bigstep.skip + | assign heq => + split <;> simp + next h => + have heq' := Expr.eval_constProp_of_eq_of_sub heq h₂ + rw [← Expr.eval_simplify, h] at heq' + simp at heq' + apply Bigstep.assign; simp [*] + next => + have heq' := Expr.eval_constProp_of_eq_of_sub heq h₂ + rw [← Expr.eval_simplify] at heq' + apply Bigstep.assign heq' + | seq h₁ h₂ ih₁ ih₂ => + apply Bigstep.seq (ih₁ h₂) (ih₂ (constProp_sub h₁ h₂)) + | whileTrue heq h₁ h₂ ih₁ ih₂ => + have ih₁ := ih₁ (State.bot_le _) + have ih₂ := ih₂ (State.bot_le _) + exact Bigstep.whileTrue heq ih₁ ih₂ + | whileFalse heq => + exact Bigstep.whileFalse heq + | ifTrue heq h ih => + exact Bigstep.ifTrue (Expr.eval_constProp_of_eq_of_sub heq h₂) (ih h₂) + | ifFalse heq h ih => + exact Bigstep.ifFalse (Expr.eval_constProp_of_eq_of_sub heq h₂) (ih h₂) + +def Stmt.constPropagation (s : Stmt) : Stmt := + (s.constProp ⊥).1 + +theorem Stmt.constPropagation_correct (h : (σ, s) ⇓ σ') : (σ, s.constPropagation) ⇓ σ' := + constProp_correct h (State.bot_le _)