From 201704c5d185187905d0d8c0a963acd5593b19fe Mon Sep 17 00:00:00 2001
From: Lukas Wirth
Date: Thu, 23 Jan 2025 09:20:27 +0100
Subject: [PATCH 1/2] LRU eviction at revision bump

This change makes LRU eviction happen once a new revision starts,
instead of at the very moment we reach the limit while computing
tracked functions. Notably, the previous behavior was unsound: we
cleared the value immediately (without going through the delayed
deletion buffer) while there could still have been outstanding
references to it.
---
 src/accumulator.rs                  |  3 +-
 src/function.rs                     |  5 +-
 src/function/fetch.rs               |  6 +--
 src/function/lru.rs                 | 26 +++++-----
 src/function/memo.rs                | 76 ++++++++++++++---------
 src/ingredient.rs                   |  3 +-
 src/input.rs                        |  6 ++-
 src/input/input_field.rs            |  3 +-
 src/interned.rs                     |  8 ++-
 src/runtime.rs                      |  4 ++
 src/table.rs                        | 40 ++++++++++++---
 src/table/memo.rs                   | 10 ++--
 src/tracked_struct.rs               |  8 ++-
 src/tracked_struct/tracked_field.rs |  3 +-
 src/zalsa.rs                        |  2 +-
 tests/lru.rs                        | 37 +++-----------
 16 files changed, 137 insertions(+), 103 deletions(-)

diff --git a/src/accumulator.rs b/src/accumulator.rs
index 3518c910..9f67bdb3 100644
--- a/src/accumulator.rs
+++ b/src/accumulator.rs
@@ -13,6 +13,7 @@ use crate::{
     cycle::CycleRecoveryStrategy,
     ingredient::{fmt_index, Ingredient, Jar, MaybeChangedAfter},
     plumbing::JarAux,
+    table::Table,
     zalsa::IngredientIndex,
     zalsa_local::QueryOrigin,
     Database, DatabaseKeyIndex, Id, Revision,
@@ -137,7 +138,7 @@ impl Ingredient for IngredientImpl {
         false
     }
 
-    fn reset_for_new_revision(&mut self) {
+    fn reset_for_new_revision(&mut self, _: &mut Table) {
         panic!("unexpected reset on accumulator")
     }
 
diff --git a/src/function.rs b/src/function.rs
index 639b1819..fcbe87d9 100644
--- a/src/function.rs
+++ b/src/function.rs
@@ -7,6 +7,7 @@ use crate::{
     key::DatabaseKeyIndex,
     plumbing::JarAux,
     salsa_struct::SalsaStructInDb,
+    table::Table,
     zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa},
     zalsa_local::QueryOrigin,
     Cycle, Database, Id, Revision,
@@ -231,7 +232,9 @@ where
         true
     }
 
-    fn reset_for_new_revision(&mut self) {
+    fn reset_for_new_revision(&mut self, table: &mut Table) {
+        self.lru
+            .for_each_evicted(|evict| self.evict_value_from_memo_for(table.memos_mut(evict)));
         std::mem::take(&mut self.deleted_entries);
     }
 
diff --git a/src/function/fetch.rs b/src/function/fetch.rs
index 8ffb4639..7828f33b 100644
--- a/src/function/fetch.rs
+++ b/src/function/fetch.rs
@@ -9,7 +9,7 @@ where
     C: Configuration,
 {
     pub fn fetch<'db>(&'db self, db: &'db C::DbView, id: Id) -> &'db C::Output<'db> {
-        let (zalsa, zalsa_local) = db.zalsas();
+        let zalsa_local = db.zalsa_local();
         zalsa_local.unwind_if_revision_cancelled(db.as_dyn_database());
 
         let memo = self.refresh_memo(db, id);
@@ -19,9 +19,7 @@
             changed_at,
         } = memo.revisions.stamped_value(memo.value.as_ref().unwrap());
 
-        if let Some(evicted) = self.lru.record_use(id) {
-            self.evict_value_from_memo_for(zalsa, evicted);
-        }
+        self.lru.record_use(id);
 
         zalsa_local.report_tracked_read(
             self.database_key_index(id).into(),
diff --git a/src/function/lru.rs b/src/function/lru.rs
index 6322be0e..5fd5b27e 100644
--- a/src/function/lru.rs
+++ b/src/function/lru.rs
@@ -11,31 +11,35 @@ pub(super) struct Lru {
 }
 
 impl Lru {
-    pub(super) fn record_use(&self, index: Id) -> Option<Id> {
+    pub(super) fn record_use(&self, index: Id) {
         // Relaxed should be fine, we don't need to synchronize on this.
let capacity = self.capacity.load(Ordering::Relaxed); - if capacity == 0 { // LRU is disabled - return None; + return; } let mut set = self.set.lock(); set.insert(index); - if set.len() > capacity { - return set.pop_front(); - } - - None } pub(super) fn set_capacity(&self, capacity: usize) { // Relaxed should be fine, we don't need to synchronize on this. self.capacity.store(capacity, Ordering::Relaxed); + } - if capacity == 0 { - let mut set = self.set.lock(); - *set = FxLinkedHashSet::default(); + pub(super) fn for_each_evicted(&self, mut cb: impl FnMut(Id)) { + let mut set = self.set.lock(); + // Relaxed should be fine, we don't need to synchronize on this. + let cap = self.capacity.load(Ordering::Relaxed); + if set.len() <= cap || cap == 0 { + return; + } + while let Some(id) = set.pop_front() { + cb(id); + if set.len() <= cap { + break; + } } } } diff --git a/src/function/memo.rs b/src/function/memo.rs index 1e4c4cf9..72e7a852 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use crate::accumulator::accumulated_map::InputAccumulatedValues; use crate::revision::AtomicRevision; +use crate::table::memo::MemoTable; use crate::zalsa_local::QueryOrigin; use crate::{ key::DatabaseKeyIndex, zalsa::Zalsa, zalsa_local::QueryRevisions, Event, EventKind, Id, @@ -60,47 +61,46 @@ impl IngredientImpl { /// Evicts the existing memo for the given key, replacing it /// with an equivalent memo that has no value. If the memo is untracked, BaseInput, /// or has values assigned as output of another query, this has no effect. - pub(super) fn evict_value_from_memo_for<'db>(&'db self, zalsa: &'db Zalsa, id: Id) { - let old = zalsa.memo_table_for(id).map_memo::>>( - self.memo_ingredient_index, - |memo| { - match memo.revisions.origin { - QueryOrigin::Assigned(_) - | QueryOrigin::DerivedUntracked(_) - | QueryOrigin::BaseInput => { - // Careful: Cannot evict memos whose values were - // assigned as output of another query - // or those with untracked inputs - // as their values cannot be reconstructed. - memo - } - QueryOrigin::Derived(_) => { - // QueryRevisions: !Clone to discourage cloning, we need it here though - let &QueryRevisions { + pub(super) fn evict_value_from_memo_for(&self, table: &mut MemoTable) { + let old = table.map_memo::>>(self.memo_ingredient_index, |memo| { + match &memo.revisions.origin { + QueryOrigin::Assigned(_) + | QueryOrigin::DerivedUntracked(_) + | QueryOrigin::BaseInput => { + // Careful: Cannot evict memos whose values were + // assigned as output of another query + // or those with untracked inputs + // as their values cannot be reconstructed. 
+ memo + } + QueryOrigin::Derived(_) => { + // Note that we cannot use `Arc::get_mut` here as the use of `ArcSwap` makes it + // impossible to get unique access to the interior Arc + // QueryRevisions: !Clone to discourage cloning, we need it here though + let &QueryRevisions { + changed_at, + durability, + ref origin, + ref tracked_struct_ids, + ref accumulated, + ref accumulated_inputs, + } = &memo.revisions; + // Re-assemble the memo but with the value set to `None` + Arc::new(Memo::new( + None, + memo.verified_at.load(), + QueryRevisions { changed_at, durability, - ref origin, - ref tracked_struct_ids, - ref accumulated, - ref accumulated_inputs, - } = &memo.revisions; - // Re-assemble the memo but with the value set to `None` - Arc::new(Memo::new( - None, - memo.verified_at.load(), - QueryRevisions { - changed_at, - durability, - origin: origin.clone(), - tracked_struct_ids: tracked_struct_ids.clone(), - accumulated: accumulated.clone(), - accumulated_inputs: accumulated_inputs.clone(), - }, - )) - } + origin: origin.clone(), + tracked_struct_ids: tracked_struct_ids.clone(), + accumulated: accumulated.clone(), + accumulated_inputs: accumulated_inputs.clone(), + }, + )) } - }, - ); + } + }); if let Some(old) = old { // In case there is a reference to the old memo out there, we have to store it // in the deleted entries. This will get cleared when a new revision starts. diff --git a/src/ingredient.rs b/src/ingredient.rs index 84f6d3e3..a063d981 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -6,6 +6,7 @@ use std::{ use crate::{ accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}, cycle::CycleRecoveryStrategy, + table::Table, zalsa::{IngredientIndex, MemoIngredientIndex}, zalsa_local::QueryOrigin, Database, DatabaseKeyIndex, Id, @@ -122,7 +123,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// /// **Important:** to actually receive resets, the ingredient must set /// [`IngredientRequiresReset::RESET_ON_NEW_REVISION`] to true. 
- fn reset_for_new_revision(&mut self); + fn reset_for_new_revision(&mut self, table: &mut Table); fn fmt_index(&self, index: Option, fmt: &mut fmt::Formatter<'_>) -> fmt::Result; } diff --git a/src/input.rs b/src/input.rs index e671c325..43429330 100644 --- a/src/input.rs +++ b/src/input.rs @@ -260,7 +260,7 @@ impl Ingredient for IngredientImpl { false } - fn reset_for_new_revision(&mut self) { + fn reset_for_new_revision(&mut self, _: &mut Table) { panic!("unexpected call to `reset_for_new_revision`") } @@ -328,6 +328,10 @@ where &self.memos } + fn memos_mut(&mut self) -> &mut crate::table::memo::MemoTable { + &mut self.memos + } + unsafe fn syncs(&self, _current_revision: Revision) -> &SyncTable { &self.syncs } diff --git a/src/input/input_field.rs b/src/input/input_field.rs index cbb51190..b22d03b2 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -1,6 +1,7 @@ use crate::cycle::CycleRecoveryStrategy; use crate::ingredient::{fmt_index, Ingredient, MaybeChangedAfter}; use crate::input::Configuration; +use crate::table::Table; use crate::zalsa::IngredientIndex; use crate::zalsa_local::QueryOrigin; use crate::{Database, DatabaseKeyIndex, Id, Revision}; @@ -85,7 +86,7 @@ where false } - fn reset_for_new_revision(&mut self) { + fn reset_for_new_revision(&mut self, _: &mut Table) { panic!("unexpected call: input fields don't register for resets"); } diff --git a/src/interned.rs b/src/interned.rs index 92c358c8..c31adeac 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -7,7 +7,7 @@ use crate::key::InputDependencyIndex; use crate::plumbing::{Jar, JarAux}; use crate::table::memo::MemoTable; use crate::table::sync::SyncTable; -use crate::table::Slot; +use crate::table::{Slot, Table}; use crate::zalsa::IngredientIndex; use crate::zalsa_local::QueryOrigin; use crate::{Database, DatabaseKeyIndex, Id}; @@ -327,7 +327,7 @@ where false } - fn reset_for_new_revision(&mut self) { + fn reset_for_new_revision(&mut self, _: &mut Table) { // Interned ingredients do not, normally, get deleted except when they are "reset" en masse. // There ARE methods (e.g., `clear_deleted_entries` and `remove`) for deleting individual // items, but those are only used for tracked struct ingredients. @@ -362,6 +362,10 @@ where &self.memos } + fn memos_mut(&mut self) -> &mut MemoTable { + &mut self.memos + } + unsafe fn syncs(&self, _current_revision: Revision) -> &crate::table::sync::SyncTable { &self.syncs } diff --git a/src/runtime.rs b/src/runtime.rs index ae07afa4..aa2f69ef 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -138,6 +138,10 @@ impl Runtime { &self.table } + pub(crate) fn table_mut(&mut self) -> &mut Table { + &mut self.table + } + /// Increments the "current revision" counter and clears /// the cancellation flag. /// diff --git a/src/table.rs b/src/table.rs index 1b88320e..16b75130 100644 --- a/src/table.rs +++ b/src/table.rs @@ -23,7 +23,7 @@ const PAGE_LEN_MASK: usize = PAGE_LEN - 1; const PAGE_LEN: usize = 1 << PAGE_LEN_BITS; const MAX_PAGES: usize = 1 << (32 - PAGE_LEN_BITS); -pub(crate) struct Table { +pub struct Table { pub(crate) pages: AppendOnlyVec>, } @@ -37,6 +37,9 @@ pub(crate) trait TablePage: Any + Send + Sync { /// The `current_revision` MUST be the current revision of the database owning this table page. unsafe fn memos(&self, slot: SlotIndex, current_revision: Revision) -> &MemoTable; + /// Access the memos attached to `slot`. + fn memos_mut(&mut self, slot: SlotIndex) -> &mut MemoTable; + /// Access the syncs attached to `slot`. 
 ///
 /// # Safety condition
@@ -75,6 +78,9 @@ pub(crate) trait Slot: Any + Send + Sync {
     /// The current revision MUST be the current revision of the database containing this slot.
     unsafe fn memos(&self, current_revision: Revision) -> &MemoTable;
 
+    /// Mutably access the [`MemoTable`] for this slot.
+    fn memos_mut(&mut self) -> &mut MemoTable;
+
     /// Access the [`SyncTable`][] for this slot.
     ///
     /// # Safety condition
@@ -123,7 +129,7 @@ impl Table {
     /// # Panics
     ///
     /// If `id` is out of bounds or the does not have the type `T`.
-    pub fn get<T: Slot>(&self, id: Id) -> &T {
+    pub(crate) fn get<T: Slot>(&self, id: Id) -> &T {
         let (page, slot) = split_id(id);
         let page_ref = self.page::<T>(page);
         page_ref.get(slot)
@@ -138,7 +144,7 @@ impl Table {
     /// # Safety
     ///
     /// See [`Page::get_raw`][].
-    pub fn get_raw<T: Slot>(&self, id: Id) -> *mut T {
+    pub(crate) fn get_raw<T: Slot>(&self, id: Id) -> *mut T {
         let (page, slot) = split_id(id);
         let page_ref = self.page::<T>(page);
         page_ref.get_raw(slot)
@@ -149,12 +155,12 @@ impl Table {
     /// # Panics
     ///
     /// If `page` is out of bounds or the type `T` is incorrect.
-    pub fn page<T: Slot>(&self, page: PageIndex) -> &Page<T> {
+    pub(crate) fn page<T: Slot>(&self, page: PageIndex) -> &Page<T> {
         self.pages[page.0].assert_type::<Page<T>>()
     }
 
     /// Allocate a new page for the given ingredient and with slots of type `T`
-    pub fn push_page<T: Slot>(&self, ingredient: IngredientIndex) -> PageIndex {
+    pub(crate) fn push_page<T: Slot>(&self, ingredient: IngredientIndex) -> PageIndex {
         let page = Box::new(<Page<T>>::new(ingredient));
         PageIndex::new(self.pages.push(page))
     }
@@ -165,18 +171,24 @@ impl Table {
     ///
     /// The parameter `current_revision` MUST be the current revision
     /// of the owner of database owning this table.
-    pub unsafe fn memos(&self, id: Id, current_revision: Revision) -> &MemoTable {
+    pub(crate) unsafe fn memos(&self, id: Id, current_revision: Revision) -> &MemoTable {
         let (page, slot) = split_id(id);
         self.pages[page.0].memos(slot, current_revision)
     }
 
+    /// Get the memo table associated with `id`
+    pub(crate) fn memos_mut(&mut self, id: Id) -> &mut MemoTable {
+        let (page, slot) = split_id(id);
+        self.pages[page.0].memos_mut(slot)
+    }
+
     /// Get the sync table associated with `id`
     ///
     /// # Safety condition
     ///
     /// The parameter `current_revision` MUST be the current revision
     /// of the owner of database owning this table.
-    pub unsafe fn syncs(&self, id: Id, current_revision: Revision) -> &SyncTable {
+    pub(crate) unsafe fn syncs(&self, id: Id, current_revision: Revision) -> &SyncTable {
         let (page, slot) = split_id(id);
         self.pages[page.0].syncs(slot, current_revision)
     }
@@ -211,6 +223,16 @@ impl<T: Slot> Page<T> {
         unsafe { (*self.data[slot.0].get()).assume_init_ref() }
     }
 
+    /// Returns a mutable reference to the given slot.
+ /// + /// # Panics + /// + /// If slot is out of bounds + pub(crate) fn get_mut(&mut self, slot: SlotIndex) -> &mut T { + self.check_bounds(slot); + unsafe { (*self.data[slot.0].get()).assume_init_mut() } + } + pub(crate) fn slots(&self) -> impl Iterator { let len = self.allocated.load(std::sync::atomic::Ordering::Acquire); let mut idx = 0; @@ -274,6 +296,10 @@ impl TablePage for Page { self.get(slot).memos(current_revision) } + fn memos_mut(&mut self, slot: SlotIndex) -> &mut MemoTable { + self.get_mut(slot).memos_mut() + } + unsafe fn syncs(&self, slot: SlotIndex, current_revision: Revision) -> &SyncTable { self.get(slot).syncs(current_revision) } diff --git a/src/table/memo.rs b/src/table/memo.rs index e06fe853..0aac21c3 100644 --- a/src/table/memo.rs +++ b/src/table/memo.rs @@ -166,13 +166,11 @@ impl MemoTable { /// Calls `f` on the memo at `memo_ingredient_index` and replaces the memo with the result of `f`. /// If the memo is not present, `f` is not called. pub(crate) fn map_memo( - &self, + &mut self, memo_ingredient_index: MemoIngredientIndex, f: impl FnOnce(Arc) -> Arc, ) -> Option> { - // If the memo slot is already occupied, it must already have the - // right type info etc, and we only need the read-lock. - let memos = self.memos.read(); + let memos = self.memos.get_mut(); let Some(MemoEntry { data: Some(MemoEntryData { @@ -189,6 +187,10 @@ impl MemoTable { TypeId::of::(), "inconsistent type-id for `{memo_ingredient_index:?}`" ); + // arc-swap does not expose accessing the interior mutably at all unfortunately + // https://github.com/vorner/arc-swap/issues/131 + // so we are required to allocate a nwe arc within `f` instead of being able + // to swap out the interior // SAFETY: type_id check asserted above let memo = f(unsafe { Self::from_dummy(arc_swap.load_full()) }); Some(unsafe { Self::from_dummy::(arc_swap.swap(Self::to_dummy(memo))) }) diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 9c11d1a1..ac65f716 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -759,7 +759,9 @@ where false } - fn reset_for_new_revision(&mut self) {} + fn reset_for_new_revision(&mut self, _: &mut Table) { + panic!("tracked struct ingredients do not require reset") + } } impl std::fmt::Debug for IngredientImpl @@ -831,6 +833,10 @@ where &self.memos } + fn memos_mut(&mut self) -> &mut crate::table::memo::MemoTable { + &mut self.memos + } + unsafe fn syncs(&self, current_revision: Revision) -> &crate::table::sync::SyncTable { // Acquiring the read lock here with the current revision // ensures that there is no danger of a race diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index a9453359..e3d1e07a 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -2,6 +2,7 @@ use std::marker::PhantomData; use crate::{ ingredient::{Ingredient, MaybeChangedAfter}, + table::Table, zalsa::IngredientIndex, Database, Id, }; @@ -97,7 +98,7 @@ where false } - fn reset_for_new_revision(&mut self) { + fn reset_for_new_revision(&mut self, _: &mut Table) { panic!("tracked field ingredients do not require reset") } diff --git a/src/zalsa.rs b/src/zalsa.rs index cfd775d0..6c9f7c20 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -272,7 +272,7 @@ impl Zalsa { let new_revision = self.runtime.new_revision(); for index in self.ingredients_requiring_reset.iter() { - self.ingredients_vec[index.as_usize()].reset_for_new_revision(); + self.ingredients_vec[index.as_usize()].reset_for_new_revision(self.runtime.table_mut()); 
         }
 
         new_revision
diff --git a/tests/lru.rs b/tests/lru.rs
index e8ad930d..704670cb 100644
--- a/tests/lru.rs
+++ b/tests/lru.rs
@@ -48,13 +48,6 @@ fn get_hot_potato2(db: &dyn LogDatabase, input: MyInput) -> u32 {
     get_hot_potato(db, input).0
 }
 
-#[salsa::tracked(lru = 32)]
-fn get_volatile(db: &dyn LogDatabase, _input: MyInput) -> usize {
-    static COUNTER: AtomicUsize = AtomicUsize::new(0);
-    db.report_untracked_read();
-    COUNTER.fetch_add(1, Ordering::SeqCst)
-}
-
 fn load_n_potatoes() -> usize {
     N_POTATOES.with(|n| n.load(Ordering::SeqCst))
 }
@@ -67,32 +60,15 @@ fn lru_works() {
     for i in 0..128u32 {
         let input = MyInput::new(&db, i);
         let p = get_hot_potato(&db, input);
-        assert_eq!(p.0, i)
+        assert_eq!(p.0, i);
     }
 
+    assert_eq!(load_n_potatoes(), 128);
     // trigger the GC
     db.synthetic_write(salsa::Durability::HIGH);
     assert_eq!(load_n_potatoes(), 32);
 }
 
-#[test]
-fn lru_doesnt_break_volatile_queries() {
-    let db = common::LoggerDatabase::default();
-
-    // Create all inputs first, so that there are no revision changes among calls to `get_volatile`
-    let inputs: Vec<MyInput> = (0..128usize).map(|i| MyInput::new(&db, i as u32)).collect();
-
-    // Here, we check that we execute each volatile query at most once, despite
-    // LRU. That does mean that we have more values in DB than the LRU capacity,
-    // but it's much better than inconsistent results from volatile queries!
-    for _ in 0..3 {
-        for (i, input) in inputs.iter().enumerate() {
-            let x = get_volatile(&db, *input);
-            assert_eq!(x, i);
-        }
-    }
-}
-
 #[test]
 fn lru_can_be_changed_at_runtime() {
     let mut db = common::LoggerDatabase::default();
@@ -102,9 +78,10 @@ fn lru_can_be_changed_at_runtime() {
 
     for &(i, input) in inputs.iter() {
         let p = get_hot_potato(&db, input);
-        assert_eq!(p.0, i)
+        assert_eq!(p.0, i);
     }
 
+    assert_eq!(load_n_potatoes(), 128);
     // trigger the GC
     db.synthetic_write(salsa::Durability::HIGH);
     assert_eq!(load_n_potatoes(), 32);
@@ -113,9 +90,10 @@
     assert_eq!(load_n_potatoes(), 32);
     for &(i, input) in inputs.iter() {
         let p = get_hot_potato(&db, input);
-        assert_eq!(p.0, i)
+        assert_eq!(p.0, i);
     }
 
+    assert_eq!(load_n_potatoes(), 128);
     // trigger the GC
     db.synthetic_write(salsa::Durability::HIGH);
     assert_eq!(load_n_potatoes(), 64);
@@ -125,9 +103,10 @@
     assert_eq!(load_n_potatoes(), 64);
     for &(i, input) in inputs.iter() {
         let p = get_hot_potato(&db, input);
-        assert_eq!(p.0, i)
+        assert_eq!(p.0, i);
     }
 
+    assert_eq!(load_n_potatoes(), 128);
     // trigger the GC
     db.synthetic_write(salsa::Durability::HIGH);
     assert_eq!(load_n_potatoes(), 128);

From f2c9cb5129e58f9883cf08cbe0fe9340494522b9 Mon Sep 17 00:00:00 2001
From: Lukas Wirth
Date: Mon, 27 Jan 2025 09:18:00 +0100
Subject: [PATCH 2/2] Mark `MemoTable` methods that evict entries unsafe

Callers of these functions need to ensure that the returned value is not
dropped until the revision ends.
---
 src/function.rs       | 17 ++++++------
 src/function/memo.rs  | 31 ++++++++++++++-------
 src/table/memo.rs     | 63 ++++++++++++++++++++++++++++-------------
 src/tracked_struct.rs | 18 ++++++-------
 4 files changed, 82 insertions(+), 47 deletions(-)

diff --git a/src/function.rs b/src/function.rs
index fcbe87d9..1073496c 100644
--- a/src/function.rs
+++ b/src/function.rs
@@ -1,4 +1,4 @@
-use std::{any::Any, fmt, sync::Arc};
+use std::{any::Any, fmt, mem::ManuallyDrop, sync::Arc};
 
 use crate::{
     accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues},
@@ -168,15 +168,16 @@
         memo: memo::Memo<C::Output<'db>>,
     ) -> &'db memo::Memo<C::Output<'db>> {
         let memo = Arc::new(memo);
- let db_memo = unsafe { - // Unsafety conditions: memo must be in the map (it's not yet, but it will be by the time this - // value is returned) and anything removed from map is added to deleted entries (ensured elsewhere). - self.extend_memo_lifetime(&memo) - }; - if let Some(old_value) = self.insert_memo_into_table_for(zalsa, id, memo) { + // Unsafety conditions: memo must be in the map (it's not yet, but it will be by the time this + // value is returned) and anything removed from map is added to deleted entries (ensured elsewhere). + let db_memo = unsafe { self.extend_memo_lifetime(&memo) }; + // Safety: We delay the drop of `old_value` until a new revision starts which ensures no + // references will exist for the memo contents. + if let Some(old_value) = unsafe { self.insert_memo_into_table_for(zalsa, id, memo) } { // In case there is a reference to the old memo out there, we have to store it // in the deleted entries. This will get cleared when a new revision starts. - self.deleted_entries.push(old_value); + self.deleted_entries + .push(ManuallyDrop::into_inner(old_value)); } db_memo } diff --git a/src/function/memo.rs b/src/function/memo.rs index 72e7a852..d2415175 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -1,6 +1,7 @@ use std::any::Any; use std::fmt::Debug; use std::fmt::Formatter; +use std::mem::ManuallyDrop; use std::sync::Arc; use crate::accumulator::accumulated_map::InputAccumulatedValues; @@ -32,18 +33,26 @@ impl IngredientImpl { unsafe { std::mem::transmute(memo) } } - /// Inserts the memo for the given key; (atomically) overwrites any previously existing memo. - pub(super) fn insert_memo_into_table_for<'db>( + /// Inserts the memo for the given key; (atomically) overwrites and returns any previously existing memo + /// + /// # Safety + /// + /// The caller needs to make sure to not drop the returned value until no more references into + /// the database exist as there may be outstanding borrows into the `Arc` contents. + pub(super) unsafe fn insert_memo_into_table_for<'db>( &'db self, zalsa: &'db Zalsa, id: Id, memo: ArcMemo<'db, C>, - ) -> Option> { + ) -> Option>> { let static_memo = unsafe { self.to_static(memo) }; - let old_static_memo = zalsa - .memo_table_for(id) - .insert(self.memo_ingredient_index, static_memo)?; - unsafe { Some(self.to_self(old_static_memo)) } + let old_static_memo = unsafe { + zalsa + .memo_table_for(id) + .insert(self.memo_ingredient_index, static_memo) + }?; + let old_static_memo = ManuallyDrop::into_inner(old_static_memo); + Some(ManuallyDrop::new(unsafe { self.to_self(old_static_memo) })) } /// Loads the current memo for `key_index`. This does not hold any sort of @@ -62,7 +71,7 @@ impl IngredientImpl { /// with an equivalent memo that has no value. If the memo is untracked, BaseInput, /// or has values assigned as output of another query, this has no effect. pub(super) fn evict_value_from_memo_for(&self, table: &mut MemoTable) { - let old = table.map_memo::>>(self.memo_ingredient_index, |memo| { + let map = |memo: ArcMemo<'static, C>| -> ArcMemo<'static, C> { match &memo.revisions.origin { QueryOrigin::Assigned(_) | QueryOrigin::DerivedUntracked(_) @@ -100,11 +109,13 @@ impl IngredientImpl { )) } } - }); + }; + // SAFETY: We queue the old value for deletion, delaying its drop until the next revision bump. + let old = unsafe { table.map_memo(self.memo_ingredient_index, map) }; if let Some(old) = old { // In case there is a reference to the old memo out there, we have to store it // in the deleted entries. 
This will get cleared when a new revision starts.
-            self.deleted_entries.push(old);
+            self.deleted_entries.push(ManuallyDrop::into_inner(old));
         }
     }
 }
diff --git a/src/table/memo.rs b/src/table/memo.rs
index 0aac21c3..c14e1337 100644
--- a/src/table/memo.rs
+++ b/src/table/memo.rs
@@ -1,6 +1,7 @@
 use std::{
     any::{Any, TypeId},
     fmt::Debug,
+    mem::ManuallyDrop,
     sync::Arc,
 };
 
@@ -79,11 +80,15 @@ impl MemoTable {
         }
     }
 
-    pub(crate) fn insert<M: Memo>(
+    /// # Safety
+    ///
+    /// The caller needs to make sure to not drop the returned value until no more references into
+    /// the database exist as there may be outstanding borrows into the `Arc` contents.
+    pub(crate) unsafe fn insert<M: Memo>(
         &self,
         memo_ingredient_index: MemoIngredientIndex,
         memo: Arc<M>,
-    ) -> Option<Arc<M>> {
+    ) -> Option<ManuallyDrop<Arc<M>>> {
         // If the memo slot is already occupied, it must already have the
         // right type info etc, and we only need the read-lock.
        if let Some(MemoEntry {
@@ -101,18 +106,23 @@ impl MemoTable {
                "inconsistent type-id for `{memo_ingredient_index:?}`"
            );
            let old_memo = arc_swap.swap(Self::to_dummy(memo));
-            return unsafe { Some(Self::from_dummy(old_memo)) };
+            return Some(ManuallyDrop::new(unsafe { Self::from_dummy(old_memo) }));
         }
 
         // Otherwise we need the write lock.
-        self.insert_cold(memo_ingredient_index, memo)
+        // SAFETY: The caller is responsible for not dropping the returned value until the revision ends.
+        unsafe { self.insert_cold(memo_ingredient_index, memo) }
     }
 
-    fn insert_cold<M: Memo>(
+    /// # Safety
+    ///
+    /// The caller needs to make sure to not drop the returned value until no more references into
+    /// the database exist as there may be outstanding borrows into the `Arc` contents.
+    unsafe fn insert_cold<M: Memo>(
         &self,
         memo_ingredient_index: MemoIngredientIndex,
         memo: Arc<M>,
-    ) -> Option<Arc<M>> {
+    ) -> Option<ManuallyDrop<Arc<M>>> {
         let mut memos = self.memos.write();
         let memo_ingredient_index = memo_ingredient_index.as_usize();
         if memos.len() < memo_ingredient_index + 1 {
@@ -126,13 +136,15 @@ impl MemoTable {
                 arc_swap: ArcSwap::new(Self::to_dummy(memo)),
             }),
         );
-        old_entry.map(
-            |MemoEntryData {
-                 type_id: _,
-                 to_dyn_fn: _,
-                 arc_swap,
-             }| unsafe { Self::from_dummy(arc_swap.into_inner()) },
-        )
+        old_entry
+            .map(
+                |MemoEntryData {
+                     type_id: _,
+                     to_dyn_fn: _,
+                     arc_swap,
+                 }| unsafe { Self::from_dummy(arc_swap.into_inner()) },
+            )
+            .map(ManuallyDrop::new)
     }
 
     pub(crate) fn get(
@@ -165,11 +177,16 @@ impl MemoTable {
     /// Calls `f` on the memo at `memo_ingredient_index` and replaces the memo with the result of `f`.
     /// If the memo is not present, `f` is not called.
-    pub(crate) fn map_memo(
+    ///
+    /// # Safety
+    ///
+    /// The caller needs to make sure to not drop the returned value until no more references into
+    /// the database exist as there may be outstanding borrows into the `Arc` contents.
+ pub(crate) unsafe fn map_memo( &mut self, memo_ingredient_index: MemoIngredientIndex, f: impl FnOnce(Arc) -> Arc, - ) -> Option> { + ) -> Option>> { let memos = self.memos.get_mut(); let Some(MemoEntry { data: @@ -189,14 +206,22 @@ impl MemoTable { ); // arc-swap does not expose accessing the interior mutably at all unfortunately // https://github.com/vorner/arc-swap/issues/131 - // so we are required to allocate a nwe arc within `f` instead of being able + // so we are required to allocate a new arc within `f` instead of being able // to swap out the interior // SAFETY: type_id check asserted above let memo = f(unsafe { Self::from_dummy(arc_swap.load_full()) }); - Some(unsafe { Self::from_dummy::(arc_swap.swap(Self::to_dummy(memo))) }) + Some(ManuallyDrop::new(unsafe { + Self::from_dummy::(arc_swap.swap(Self::to_dummy(memo))) + })) } - pub(crate) fn into_memos(self) -> impl Iterator)> { + /// # Safety + /// + /// The caller needs to make sure to not drop the returned value until no more references into + /// the database exist as there may be outstanding borrows into the `Arc` contents. + pub(crate) unsafe fn into_memos( + self, + ) -> impl Iterator>)> { self.memos .into_inner() .into_iter() @@ -213,7 +238,7 @@ impl MemoTable { )| { ( MemoIngredientIndex::from_usize(index), - to_dyn_fn(arc_swap.into_inner()), + ManuallyDrop::new(to_dyn_fn(arc_swap.into_inner())), ) }, ) diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index ac65f716..cff5e2ee 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -1,4 +1,4 @@ -use std::{any::TypeId, fmt, hash::Hash, marker::PhantomData, ops::DerefMut}; +use std::{any::TypeId, fmt, hash::Hash, marker::PhantomData, mem::ManuallyDrop, ops::DerefMut}; use crossbeam_queue::SegQueue; use tracked_field::FieldIngredientImpl; @@ -580,15 +580,10 @@ where None => { panic!("cannot delete write-locked id `{id:?}`; value leaked across threads"); } - + Some(r) if r == current_revision => panic!( + "cannot delete read-locked id `{id:?}`; value leaked across threads or user functions not deterministic" + ), Some(r) => { - if r == current_revision { - panic!( - "cannot delete read-locked id `{id:?}`; \ - value leaked across threads or user functions not deterministic" - ) - } - if data_ref.updated_at.compare_exchange(Some(r), None).is_err() { panic!("race occurred when deleting value `{id:?}`") } @@ -598,7 +593,10 @@ where // Take the memo table. This is safe because we have modified `data_ref.updated_at` to `None` // and the code that references the memo-table has a read-lock. let memo_table = unsafe { (*data).take_memo_table() }; - for (memo_ingredient_index, memo) in memo_table.into_memos() { + // SAFETY: We have verified that no more references to these memos exist and so we are good + // to drop them. + for (memo_ingredient_index, memo) in unsafe { memo_table.into_memos() } { + let memo = ManuallyDrop::into_inner(memo); let ingredient_index = zalsa.ingredient_index_for_memo(self.ingredient_index, memo_ingredient_index);
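
The lifecycle the two patches implement, end to end: uses are only recorded while a
revision is in flight, eviction happens solely at the revision bump, and evicted or
replaced memos are parked in a delayed-deletion buffer until it is provably safe to
drop them. The sketch below illustrates that shape under stated assumptions: `ToyLru`,
`ToyMemo`, and `ToyDb` are invented names, a plain `VecDeque` plus `HashSet` stands in
for salsa's `FxLinkedHashSet`, and a `Vec` stands in for the `deleted_entries` queue;
this is not salsa's actual API.

use std::collections::{HashSet, VecDeque};
use std::sync::Arc;

type Id = u32;

struct ToyMemo {
    value: Option<String>, // `None` once evicted: metadata survives, the value is gone
}

struct ToyLru {
    capacity: usize,
    order: VecDeque<Id>, // oldest at the front; stands in for FxLinkedHashSet
    seen: HashSet<Id>,
}

impl ToyLru {
    // Mirrors the patched `Lru::record_use`: only record the use, never evict here.
    fn record_use(&mut self, id: Id) {
        if self.capacity == 0 {
            return; // LRU is disabled
        }
        if !self.seen.insert(id) {
            // Already present: move it to the back (most recently used).
            self.order.retain(|&x| x != id); // O(n), fine for a sketch
        }
        self.order.push_back(id);
    }

    // Mirrors the patched `Lru::for_each_evicted`: drain down to capacity,
    // but only when explicitly asked, i.e. at the revision bump.
    fn for_each_evicted(&mut self, mut cb: impl FnMut(Id)) {
        if self.capacity == 0 {
            return;
        }
        while self.order.len() > self.capacity {
            let id = self.order.pop_front().expect("len > capacity");
            self.seen.remove(&id);
            cb(id);
        }
    }
}

struct ToyDb {
    lru: ToyLru,
    memos: Vec<Arc<ToyMemo>>, // indexed by Id; stands in for the memo table
    // Replaced memos are parked here instead of being dropped, so any
    // outstanding references derived from the Arc stay valid for the rest
    // of the revision. This is the "delayed deletion buffer".
    deleted_entries: Vec<Arc<ToyMemo>>,
}

impl ToyDb {
    // Mirrors `Zalsa::new_revision` -> `reset_for_new_revision`.
    fn new_revision(&mut self) {
        let ToyDb { lru, memos, deleted_entries } = &mut *self;
        lru.for_each_evicted(|id| {
            let empty = Arc::new(ToyMemo { value: None });
            let old = std::mem::replace(&mut memos[id as usize], empty);
            deleted_entries.push(old); // park, don't drop
        });
        // `&mut self` proves there are no outstanding references, so the
        // parked memos can finally be dropped.
        deleted_entries.clear();
    }
}

fn main() {
    let mut db = ToyDb {
        lru: ToyLru { capacity: 2, order: VecDeque::new(), seen: HashSet::new() },
        memos: (0..4)
            .map(|i| Arc::new(ToyMemo { value: Some(format!("value {i}")) }))
            .collect(),
        deleted_entries: Vec::new(),
    };
    for id in 0..4u32 {
        db.lru.record_use(id); // never evicts mid-revision anymore
    }
    db.new_revision(); // ids 0 and 1 are evicted only now
    assert!(db.memos[0].value.is_none());
    assert!(db.memos[1].value.is_none());
    assert!(db.memos[2].value.is_some());
    assert!(db.memos[3].value.is_some());
}

The ordering inside `new_revision` mirrors `reset_for_new_revision` above: evict first,
so the evicted memos flow through the same buffer as any memo replaced mid-revision,
and only then clear the buffer. That final drop is sound only because bumping the
revision requires `&mut` access to the database, which guarantees no references into
the old memos can still be alive.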