Skip to content

Commit

Permalink
feat: align List.replicate/Array.mkArray/Vector.mkVector lemmas (lean…
Browse files Browse the repository at this point in the history
…prover#6667)

This PR aligns `List.replicate`/`Array.mkArray`/`Vector.mkVector`
lemmas.
kim-em authored and JovanGerb committed Jan 21, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 4381c34 commit 8639c94
Showing 3 changed files with 262 additions and 51 deletions.
202 changes: 154 additions & 48 deletions src/Init/Data/Array/Lemmas.lean
Original file line number Diff line number Diff line change
@@ -93,7 +93,7 @@ theorem size_eq_one {l : Array α} : l.size = 1 ↔ ∃ a, l = #[a] := by

/-! ### push -/

theorem push_ne_empty {a : α} {xs : Array α} : xs.push a ≠ #[] := by
@[simp] theorem push_ne_empty {a : α} {xs : Array α} : xs.push a ≠ #[] := by
cases xs
simp

@@ -174,10 +174,6 @@ theorem mkArray_succ : mkArray (n + 1) a = (mkArray n a).push a := by
apply toList_inj.1
simp [List.replicate_succ']

theorem mkArray_inj : mkArray n a = mkArray m b ↔ n = m ∧ (n = 0 ∨ a = b) := by
rw [← List.replicate_inj, ← toList_inj]
simp

@[simp] theorem getElem_mkArray (n : Nat) (v : α) (h : i < (mkArray n v).size) :
(mkArray n v)[i] = v := by simp [← getElem_toList]

@@ -2050,6 +2046,11 @@ theorem flatMap_toList (l : Array α) (f : α → List β) :
rcases l with ⟨l⟩
simp

@[simp] theorem toList_flatMap (l : Array α) (f : α → Array β) :
(l.flatMap f).toList = l.toList.flatMap fun a => (f a).toList := by
rcases l with ⟨l⟩
simp

@[simp] theorem flatMap_id (l : Array (Array α)) : l.flatMap id = l.flatten := by simp [flatMap_def]

@[simp] theorem flatMap_id' (l : Array (Array α)) : l.flatMap (fun a => a) = l.flatten := by simp [flatMap_def]
@@ -2096,7 +2097,7 @@ theorem flatMap_singleton (f : α → Array β) (x : α) : #[x].flatMap f = f x
theorem flatMap_assoc {α β} (l : Array α) (f : α → Array β) (g : β → Array γ) :
(l.flatMap f).flatMap g = l.flatMap fun x => (f x).flatMap g := by
rcases l with ⟨l⟩
simp [List.flatMap_assoc, flatMap_toList]
simp [List.flatMap_assoc, ← toList_flatMap]

theorem map_flatMap (f : β → γ) (g : α → Array β) (l : Array α) :
(l.flatMap g).map f = l.flatMap fun a => (g a).map f := by
@@ -2135,8 +2136,155 @@ theorem flatMap_eq_foldl (f : α → Array β) (l : Array α) :
intro l'
simp [ih ((l' ++ (f a).toList)), toArray_append]

/-! ### mkArray -/

@[simp] theorem mkArray_one : mkArray 1 a = #[a] := rfl

/-- Variant of `mkArray_succ` that prepends `a` at the beginning of the array. -/
theorem mkArray_succ' : mkArray (n + 1) a = #[a] ++ mkArray n a := by
apply Array.ext'
simp [List.replicate_succ]

@[simp] theorem mem_mkArray {a b : α} {n} : b ∈ mkArray n a ↔ n ≠ 0 ∧ b = a := by
unfold mkArray
simp only [mem_toArray, List.mem_replicate]

theorem eq_of_mem_mkArray {a b : α} {n} (h : b ∈ mkArray n a) : b = a := (mem_mkArray.1 h).2

theorem forall_mem_mkArray {p : α → Prop} {a : α} {n} :
(∀ b, b ∈ mkArray n a → p b) ↔ n = 0 ∨ p a := by
cases n <;> simp [mem_mkArray]

@[simp] theorem mkArray_succ_ne_empty (n : Nat) (a : α) : mkArray (n+1) a ≠ #[] := by
simp [mkArray_succ]

@[simp] theorem mkArray_eq_empty_iff {n : Nat} (a : α) : mkArray n a = #[] ↔ n = 0 := by
cases n <;> simp

@[simp] theorem getElem?_mkArray_of_lt {n : Nat} {m : Nat} (h : m < n) : (mkArray n a)[m]? = some a := by
simp [getElem?_mkArray, h]

@[simp] theorem mkArray_inj : mkArray n a = mkArray m b ↔ n = m ∧ (n = 0 ∨ a = b) := by
rw [← toList_inj]
simp

theorem eq_mkArray_of_mem {a : α} {l : Array α} (h : ∀ (b) (_ : b ∈ l), b = a) : l = mkArray l.size a := by
rw [← toList_inj]
simpa using List.eq_replicate_of_mem (by simpa using h)

theorem eq_mkArray_iff {a : α} {n} {l : Array α} :
l = mkArray n a ↔ l.size = n ∧ ∀ (b) (_ : b ∈ l), b = a := by
rw [← toList_inj]
simpa using List.eq_replicate_iff (l := l.toList)

theorem map_eq_mkArray_iff {l : Array α} {f : α → β} {b : β} :
l.map f = mkArray l.size b ↔ ∀ x ∈ l, f x = b := by
simp [eq_mkArray_iff]

@[simp] theorem map_const (l : Array α) (b : β) : map (Function.const α b) l = mkArray l.size b :=
map_eq_mkArray_iff.mpr fun _ _ => rfl

@[simp] theorem map_const_fun (x : β) : map (Function.const α x) = (mkArray ·.size x) := by
funext l
simp

/-- Variant of `map_const` using a lambda rather than `Function.const`. -/
-- This can not be a `@[simp]` lemma because it would fire on every `List.map`.
theorem map_const' (l : Array α) (b : β) : map (fun _ => b) l = mkArray l.size b :=
map_const l b

@[simp] theorem set_mkArray_self : (mkArray n a).set i a h = mkArray n a := by
apply Array.ext'
simp

@[simp] theorem setIfInBounds_mkArray_self : (mkArray n a).setIfInBounds i a = mkArray n a := by
apply Array.ext'
simp

@[simp] theorem mkArray_append_mkArray : mkArray n a ++ mkArray m a = mkArray (n + m) a := by
apply Array.ext'
simp

theorem append_eq_mkArray_iff {l₁ l₂ : Array α} {a : α} :
l₁ ++ l₂ = mkArray n a ↔
l₁.size + l₂.size = n ∧ l₁ = mkArray l₁.size a ∧ l₂ = mkArray l₂.size a := by
simp [← toList_inj, List.append_eq_replicate_iff]

theorem mkArray_eq_append_iff {l₁ l₂ : Array α} {a : α} :
mkArray n a = l₁ ++ l₂ ↔
l₁.size + l₂.size = n ∧ l₁ = mkArray l₁.size a ∧ l₂ = mkArray l₂.size a := by
rw [eq_comm, append_eq_mkArray_iff]

@[simp] theorem map_mkArray : (mkArray n a).map f = mkArray n (f a) := by
apply Array.ext'
simp

theorem filter_mkArray (w : stop = n) :
(mkArray n a).filter p 0 stop = if p a then mkArray n a else #[] := by
apply Array.ext'
simp only [w, toList_filter', toList_mkArray, List.filter_replicate]
split <;> simp_all

@[simp] theorem filter_mkArray_of_pos (w : stop = n) (h : p a) :
(mkArray n a).filter p 0 stop = mkArray n a := by
simp [filter_mkArray, h, w]

@[simp] theorem filter_mkArray_of_neg (w : stop = n) (h : ¬ p a) :
(mkArray n a).filter p 0 stop = #[] := by
simp [filter_mkArray, h, w]

theorem filterMap_mkArray {f : α → Option β} (w : stop = n := by simp) :
(mkArray n a).filterMap f 0 stop = match f a with | none => #[] | .some b => mkArray n b := by
apply Array.ext'
simp only [w, size_mkArray, toList_filterMap', toList_mkArray, List.filterMap_replicate]
split <;> simp_all

-- This is not a useful `simp` lemma because `b` is unknown.
theorem filterMap_mkArray_of_some {f : α → Option β} (h : f a = some b) :
(mkArray n a).filterMap f = mkArray n b := by
simp [filterMap_mkArray, h]

@[simp] theorem filterMap_mkArray_of_isSome {f : α → Option β} (h : (f a).isSome) :
(mkArray n a).filterMap f = mkArray n (Option.get _ h) := by
match w : f a, h with
| some b, _ => simp [filterMap_mkArray, h, w]

@[simp] theorem filterMap_mkArray_of_none {f : α → Option β} (h : f a = none) :
(mkArray n a).filterMap f = #[] := by
simp [filterMap_mkArray, h]

@[simp] theorem flatten_mkArray_empty : (mkArray n (#[] : Array α)).flatten = #[] := by
rw [← toList_inj]
simp

@[simp] theorem flatten_mkArray_singleton : (mkArray n #[a]).flatten = mkArray n a := by
rw [← toList_inj]
simp

@[simp] theorem flatten_mkArray_mkArray : (mkArray n (mkArray m a)).flatten = mkArray (n * m) a := by
rw [← toList_inj]
simp

theorem flatMap_mkArray {β} (f : α → Array β) : (mkArray n a).flatMap f = (mkArray n (f a)).flatten := by
rw [← toList_inj]
simp [flatMap_toList, List.flatMap_replicate]

@[simp] theorem isEmpty_replicate : (mkArray n a).isEmpty = decide (n = 0) := by
rw [← List.toArray_replicate, List.isEmpty_toArray]
simp

@[simp] theorem sum_mkArray_nat (n : Nat) (a : Nat) : (mkArray n a).sum = n * a := by
rw [← List.toArray_replicate, List.sum_toArray]
simp

/-! Content below this point has not yet been aligned with `List`. -/

/-! ### sum -/

theorem sum_eq_sum_toList [Add α] [Zero α] (as : Array α) : as.toList.sum = as.sum := by
cases as
simp [Array.sum, List.sum]

-- This is a duplicate of `List.toArray_toList`.
-- It's confusing to guess which namespace this theorem should live in,
-- so we provide both.
@@ -3176,48 +3324,6 @@ theorem foldr_map' (g : α → β) (f : α → α → α) (f' : β → β → β
| nil => simp
| cons xs xss ih => simp [ih]

/-! ### sum -/

theorem sum_eq_sum_toList [Add α] [Zero α] (as : Array α) : as.sum = as.toList.sum := by
cases as
simp [Array.sum, List.sum]

/-! ### mkArray -/

theorem eq_mkArray_of_mem {a : α} {l : Array α} (h : ∀ (b) (_ : b ∈ l), b = a) : l = mkArray l.size a := by
rcases l with ⟨l⟩
have := List.eq_replicate_of_mem (by simpa using h)
rw [this]
simp

theorem eq_mkArray_iff {a : α} {n} {l : Array α} :
l = mkArray n a ↔ l.size = n ∧ ∀ (b) (_ : b ∈ l), b = a := by
rcases l with ⟨l⟩
simp [← List.eq_replicate_iff, toArray_eq]

theorem map_eq_mkArray_iff {l : Array α} {f : α → β} {b : β} :
l.map f = mkArray l.size b ↔ ∀ x ∈ l, f x = b := by
simp [eq_mkArray_iff]

@[simp] theorem mem_mkArray (a : α) (n : Nat) : b ∈ mkArray n a ↔ n ≠ 0 ∧ b = a := by
rw [mkArray, mem_toArray]
simp

@[simp] theorem map_const (l : Array α) (b : β) : map (Function.const α b) l = mkArray l.size b :=
map_eq_mkArray_iff.mpr fun _ _ => rfl

@[simp] theorem map_const_fun (x : β) : map (Function.const α x) = (mkArray ·.size x) := by
funext l
simp

/-- Variant of `map_const` using a lambda rather than `Function.const`. -/
-- This can not be a `@[simp]` lemma because it would fire on every `Array.map`.
theorem map_const' (l : Array α) (b : β) : map (fun _ => b) l = mkArray l.size b :=
map_const l b

@[simp] theorem sum_mkArray_nat (n : Nat) (a : Nat) : (mkArray n a).sum = n * a := by
simp [sum_eq_sum_toList, List.sum_replicate_nat]

/-! ### reverse -/

@[simp] theorem mem_reverse {x : α} {as : Array α} : x ∈ as.reverse ↔ x ∈ as := by
12 changes: 10 additions & 2 deletions src/Init/Data/List/Lemmas.lean
Original file line number Diff line number Diff line change
@@ -2274,14 +2274,17 @@ theorem map_const' (l : List α) (b : β) : map (fun _ => b) l = replicate l.len
· intro i h₁ h₂
simp [getElem_set]

@[simp] theorem append_replicate_replicate : replicate n a ++ replicate m a = replicate (n + m) a := by
@[simp] theorem replicate_append_replicate : replicate n a ++ replicate m a = replicate (n + m) a := by
rw [eq_replicate_iff]
constructor
· simp
· intro b
simp only [mem_append, mem_replicate, ne_eq]
rintro (⟨-, rfl⟩ | ⟨_, rfl⟩) <;> rfl

@[deprecated replicate_append_replicate (since := "2025-01-16")]
abbrev append_replicate_replicate := @replicate_append_replicate

theorem append_eq_replicate_iff {l₁ l₂ : List α} {a : α} :
l₁ ++ l₂ = replicate n a ↔
l₁.length + l₂.length = n ∧ l₁ = replicate l₁.length a ∧ l₂ = replicate l₂.length a := by
@@ -2292,6 +2295,11 @@ theorem append_eq_replicate_iff {l₁ l₂ : List α} {a : α} :

@[deprecated append_eq_replicate_iff (since := "2024-09-05")] abbrev append_eq_replicate := @append_eq_replicate_iff

theorem replicate_eq_append_iff {l₁ l₂ : List α} {a : α} :
replicate n a = l₁ ++ l₂ ↔
l₁.length + l₂.length = n ∧ l₁ = replicate l₁.length a ∧ l₂ = replicate l₂.length a := by
rw [eq_comm, append_eq_replicate_iff]

@[simp] theorem map_replicate : (replicate n a).map f = replicate n (f a) := by
ext1 n
simp only [getElem?_map, getElem?_replicate]
@@ -2343,7 +2351,7 @@ theorem filterMap_replicate_of_some {f : α → Option β} (h : f a = some b) :
induction n with
| zero => simp
| succ n ih =>
simp only [replicate_succ, flatten_cons, ih, append_replicate_replicate, replicate_inj, or_true,
simp only [replicate_succ, flatten_cons, ih, replicate_append_replicate, replicate_inj, or_true,
and_true, add_one_mul, Nat.add_comm]

theorem flatMap_replicate {β} (f : α → List β) : (replicate n a).flatMap f = (replicate n (f a)).flatten := by
99 changes: 98 additions & 1 deletion src/Init/Data/Vector/Lemmas.lean
Original file line number Diff line number Diff line change
@@ -475,7 +475,7 @@ theorem singleton_inj : #v[a] = #v[b] ↔ a = b := by
theorem mkVector_succ : mkVector (n + 1) a = (mkVector n a).push a := by
simp [mkVector, Array.mkArray_succ]

theorem mkVector_inj : mkVector n a = mkVector n b ↔ n = 0 ∨ a = b := by
@[simp] theorem mkVector_inj : mkVector n a = mkVector n b ↔ n = 0 ∨ a = b := by
simp [← toArray_inj, toArray_mkVector, Array.mkArray_inj]

/-! ## L[i] and L[i]? -/
@@ -1649,6 +1649,103 @@ theorem map_eq_flatMap {α β} (f : α → β) (l : Vector α n) :
rcases l with ⟨l, rfl⟩
simp [Array.map_eq_flatMap]

/-! ### mkVector -/

@[simp] theorem mkVector_one : mkVector 1 a = #v[a] := rfl

/-- Variant of `mkVector_succ` that prepends `a` at the beginning of the vector. -/
theorem mkVector_succ' : mkVector (n + 1) a = (#v[a] ++ mkVector n a).cast (by omega) := by
rw [← toArray_inj]
simp [Array.mkArray_succ']

@[simp] theorem mem_mkVector {a b : α} {n} : b ∈ mkVector n a ↔ n ≠ 0 ∧ b = a := by
unfold mkVector
simp

theorem eq_of_mem_mkVector {a b : α} {n} (h : b ∈ mkVector n a) : b = a := (mem_mkVector.1 h).2

theorem forall_mem_mkVector {p : α → Prop} {a : α} {n} :
(∀ b, b ∈ mkVector n a → p b) ↔ n = 0 ∨ p a := by
cases n <;> simp [mem_mkVector]

@[simp] theorem getElem_mkVector (a : α) (n i : Nat) (h : i < n) : (mkVector n a)[i] = a := by
simp [mkVector]

theorem getElem?_mkVector (a : α) (n i : Nat) : (mkVector n a)[i]? = if i < n then some a else none := by
simp [getElem?_def]

@[simp] theorem getElem?_mkVector_of_lt {n : Nat} {m : Nat} (h : m < n) : (mkVector n a)[m]? = some a := by
simp [getElem?_mkVector, h]

theorem eq_mkVector_of_mem {a : α} {l : Vector α n} (h : ∀ (b) (_ : b ∈ l), b = a) : l = mkVector n a := by
rw [← toArray_inj]
simpa using Array.eq_mkArray_of_mem (l := l.toArray) (by simpa using h)

theorem eq_mkVector_iff {a : α} {n} {l : Vector α n} :
l = mkVector n a ↔ ∀ (b) (_ : b ∈ l), b = a := by
rw [← toArray_inj]
simpa using Array.eq_mkArray_iff (l := l.toArray) (n := n)

theorem map_eq_mkVector_iff {l : Vector α n} {f : α → β} {b : β} :
l.map f = mkVector n b ↔ ∀ x ∈ l, f x = b := by
simp [eq_mkVector_iff]

@[simp] theorem map_const (l : Vector α n) (b : β) : map (Function.const α b) l = mkVector n b :=
map_eq_mkVector_iff.mpr fun _ _ => rfl

@[simp] theorem map_const_fun (x : β) : map (n := n) (Function.const α x) = fun _ => mkVector n x := by
funext l
simp

/-- Variant of `map_const` using a lambda rather than `Function.const`. -/
-- This can not be a `@[simp]` lemma because it would fire on every `List.map`.
theorem map_const' (l : Vector α n) (b : β) : map (fun _ => b) l = mkVector n b :=
map_const l b

@[simp] theorem set_mkVector_self : (mkVector n a).set i a h = mkVector n a := by
rw [← toArray_inj]
simp

@[simp] theorem setIfInBounds_mkVector_self : (mkVector n a).setIfInBounds i a = mkVector n a := by
rw [← toArray_inj]
simp

@[simp] theorem mkVector_append_mkVector : mkVector n a ++ mkVector m a = mkVector (n + m) a := by
rw [← toArray_inj]
simp

theorem append_eq_mkVector_iff {l₁ : Vector α n} {l₂ : Vector α m} {a : α} :
l₁ ++ l₂ = mkVector (n + m) a ↔ l₁ = mkVector n a ∧ l₂ = mkVector m a := by
simp [← toArray_inj, Array.append_eq_mkArray_iff]

theorem mkVector_eq_append_iff {l₁ : Vector α n} {l₂ : Vector α m} {a : α} :
mkVector (n + m) a = l₁ ++ l₂ ↔ l₁ = mkVector n a ∧ l₂ = mkVector m a := by
rw [eq_comm, append_eq_mkVector_iff]

@[simp] theorem map_mkVector : (mkVector n a).map f = mkVector n (f a) := by
rw [← toArray_inj]
simp


@[simp] theorem flatten_mkVector_empty : (mkVector n (#v[] : Vector α 0)).flatten = #v[] := by
rw [← toArray_inj]
simp

@[simp] theorem flatten_mkVector_singleton : (mkVector n #v[a]).flatten = (mkVector n a).cast (by simp) := by
ext i h
simp [h]

@[simp] theorem flatten_mkVector_mkVector : (mkVector n (mkVector m a)).flatten = mkVector (n * m) a := by
ext i h
simp [h]

theorem flatMap_mkArray {β} (f : α → Vector β m) : (mkVector n a).flatMap f = (mkVector n (f a)).flatten := by
ext i h
simp [h]

@[simp] theorem sum_mkArray_nat (n : Nat) (a : Nat) : (mkVector n a).sum = n * a := by
simp [toArray_mkVector]

/-! Content below this point has not yet been aligned with `List` and `Array`. -/

@[simp] theorem getElem_ofFn {α n} (f : Fin n → α) (i : Nat) (h : i < n) :

0 comments on commit 8639c94

Please sign in to comment.