Skip to content

Commit

Permalink
Mark MemoTable methods that evict entries unsafe
Browse files Browse the repository at this point in the history
Callers of these functions need to ensure that the value will be dropped when the revision ends
  • Loading branch information
Veykril committed Feb 12, 2025
1 parent 201704c commit f2c9cb5
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 47 deletions.
17 changes: 9 additions & 8 deletions src/function.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{any::Any, fmt, sync::Arc};
use std::{any::Any, fmt, mem::ManuallyDrop, sync::Arc};

use crate::{
accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues},
Expand Down Expand Up @@ -168,15 +168,16 @@ where
memo: memo::Memo<C::Output<'db>>,
) -> &'db memo::Memo<C::Output<'db>> {
let memo = Arc::new(memo);
let db_memo = unsafe {
// Unsafety conditions: memo must be in the map (it's not yet, but it will be by the time this
// value is returned) and anything removed from map is added to deleted entries (ensured elsewhere).
self.extend_memo_lifetime(&memo)
};
if let Some(old_value) = self.insert_memo_into_table_for(zalsa, id, memo) {
// Unsafety conditions: memo must be in the map (it's not yet, but it will be by the time this
// value is returned) and anything removed from map is added to deleted entries (ensured elsewhere).
let db_memo = unsafe { self.extend_memo_lifetime(&memo) };
// Safety: We delay the drop of `old_value` until a new revision starts which ensures no
// references will exist for the memo contents.
if let Some(old_value) = unsafe { self.insert_memo_into_table_for(zalsa, id, memo) } {
// In case there is a reference to the old memo out there, we have to store it
// in the deleted entries. This will get cleared when a new revision starts.
self.deleted_entries.push(old_value);
self.deleted_entries
.push(ManuallyDrop::into_inner(old_value));
}
db_memo
}
Expand Down
31 changes: 21 additions & 10 deletions src/function/memo.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::any::Any;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::mem::ManuallyDrop;
use std::sync::Arc;

use crate::accumulator::accumulated_map::InputAccumulatedValues;
Expand Down Expand Up @@ -32,18 +33,26 @@ impl<C: Configuration> IngredientImpl<C> {
unsafe { std::mem::transmute(memo) }
}

/// Inserts the memo for the given key; (atomically) overwrites any previously existing memo.
pub(super) fn insert_memo_into_table_for<'db>(
/// Inserts the memo for the given key; (atomically) overwrites and returns any previously existing memo
///
/// # Safety
///
/// The caller needs to make sure to not drop the returned value until no more references into
/// the database exist as there may be outstanding borrows into the `Arc` contents.
pub(super) unsafe fn insert_memo_into_table_for<'db>(
&'db self,
zalsa: &'db Zalsa,
id: Id,
memo: ArcMemo<'db, C>,
) -> Option<ArcMemo<'db, C>> {
) -> Option<ManuallyDrop<ArcMemo<'db, C>>> {
let static_memo = unsafe { self.to_static(memo) };
let old_static_memo = zalsa
.memo_table_for(id)
.insert(self.memo_ingredient_index, static_memo)?;
unsafe { Some(self.to_self(old_static_memo)) }
let old_static_memo = unsafe {
zalsa
.memo_table_for(id)
.insert(self.memo_ingredient_index, static_memo)
}?;
let old_static_memo = ManuallyDrop::into_inner(old_static_memo);
Some(ManuallyDrop::new(unsafe { self.to_self(old_static_memo) }))
}

/// Loads the current memo for `key_index`. This does not hold any sort of
Expand All @@ -62,7 +71,7 @@ impl<C: Configuration> IngredientImpl<C> {
/// with an equivalent memo that has no value. If the memo is untracked, BaseInput,
/// or has values assigned as output of another query, this has no effect.
pub(super) fn evict_value_from_memo_for(&self, table: &mut MemoTable) {
let old = table.map_memo::<Memo<C::Output<'_>>>(self.memo_ingredient_index, |memo| {
let map = |memo: ArcMemo<'static, C>| -> ArcMemo<'static, C> {
match &memo.revisions.origin {
QueryOrigin::Assigned(_)
| QueryOrigin::DerivedUntracked(_)
Expand Down Expand Up @@ -100,11 +109,13 @@ impl<C: Configuration> IngredientImpl<C> {
))
}
}
});
};
// SAFETY: We queue the old value for deletion, delaying its drop until the next revision bump.
let old = unsafe { table.map_memo(self.memo_ingredient_index, map) };
if let Some(old) = old {
// In case there is a reference to the old memo out there, we have to store it
// in the deleted entries. This will get cleared when a new revision starts.
self.deleted_entries.push(old);
self.deleted_entries.push(ManuallyDrop::into_inner(old));
}
}
}
Expand Down
63 changes: 44 additions & 19 deletions src/table/memo.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{
any::{Any, TypeId},
fmt::Debug,
mem::ManuallyDrop,
sync::Arc,
};

Expand Down Expand Up @@ -79,11 +80,15 @@ impl MemoTable {
}
}

pub(crate) fn insert<M: Memo>(
/// # Safety
///
/// The caller needs to make sure to not drop the returned value until no more references into
/// the database exist as there may be outstanding borrows into the `Arc` contents.
pub(crate) unsafe fn insert<M: Memo>(
&self,
memo_ingredient_index: MemoIngredientIndex,
memo: Arc<M>,
) -> Option<Arc<M>> {
) -> Option<ManuallyDrop<Arc<M>>> {
// If the memo slot is already occupied, it must already have the
// right type info etc, and we only need the read-lock.
if let Some(MemoEntry {
Expand All @@ -101,18 +106,23 @@ impl MemoTable {
"inconsistent type-id for `{memo_ingredient_index:?}`"
);
let old_memo = arc_swap.swap(Self::to_dummy(memo));
return unsafe { Some(Self::from_dummy(old_memo)) };
return Some(ManuallyDrop::new(unsafe { Self::from_dummy(old_memo) }));
}

// Otherwise we need the write lock.
self.insert_cold(memo_ingredient_index, memo)
// SAFETY: The caller is responsible for dropping
unsafe { self.insert_cold(memo_ingredient_index, memo) }
}

fn insert_cold<M: Memo>(
/// # Safety
///
/// The caller needs to make sure to not drop the returned value until no more references into
/// the database exist as there may be outstanding borrows into the `Arc` contents.
unsafe fn insert_cold<M: Memo>(
&self,
memo_ingredient_index: MemoIngredientIndex,
memo: Arc<M>,
) -> Option<Arc<M>> {
) -> Option<ManuallyDrop<Arc<M>>> {
let mut memos = self.memos.write();
let memo_ingredient_index = memo_ingredient_index.as_usize();
if memos.len() < memo_ingredient_index + 1 {
Expand All @@ -126,13 +136,15 @@ impl MemoTable {
arc_swap: ArcSwap::new(Self::to_dummy(memo)),
}),
);
old_entry.map(
|MemoEntryData {
type_id: _,
to_dyn_fn: _,
arc_swap,
}| unsafe { Self::from_dummy(arc_swap.into_inner()) },
)
old_entry
.map(
|MemoEntryData {
type_id: _,
to_dyn_fn: _,
arc_swap,
}| unsafe { Self::from_dummy(arc_swap.into_inner()) },
)
.map(ManuallyDrop::new)
}

pub(crate) fn get<M: Memo>(
Expand Down Expand Up @@ -165,11 +177,16 @@ impl MemoTable {

/// Calls `f` on the memo at `memo_ingredient_index` and replaces the memo with the result of `f`.
/// If the memo is not present, `f` is not called.
pub(crate) fn map_memo<M: Memo>(
///
/// # Safety
///
/// The caller needs to make sure to not drop the returned value until no more references into
/// the database exist as there may be outstanding borrows into the `Arc` contents.
pub(crate) unsafe fn map_memo<M: Memo>(
&mut self,
memo_ingredient_index: MemoIngredientIndex,
f: impl FnOnce(Arc<M>) -> Arc<M>,
) -> Option<Arc<M>> {
) -> Option<ManuallyDrop<Arc<M>>> {
let memos = self.memos.get_mut();
let Some(MemoEntry {
data:
Expand All @@ -189,14 +206,22 @@ impl MemoTable {
);
// arc-swap does not expose accessing the interior mutably at all unfortunately
// https://github.com/vorner/arc-swap/issues/131
// so we are required to allocate a nwe arc within `f` instead of being able
// so we are required to allocate a new arc within `f` instead of being able
// to swap out the interior
// SAFETY: type_id check asserted above
let memo = f(unsafe { Self::from_dummy(arc_swap.load_full()) });
Some(unsafe { Self::from_dummy::<M>(arc_swap.swap(Self::to_dummy(memo))) })
Some(ManuallyDrop::new(unsafe {
Self::from_dummy::<M>(arc_swap.swap(Self::to_dummy(memo)))
}))
}

pub(crate) fn into_memos(self) -> impl Iterator<Item = (MemoIngredientIndex, Arc<dyn Memo>)> {
/// # Safety
///
/// The caller needs to make sure to not drop the returned value until no more references into
/// the database exist as there may be outstanding borrows into the `Arc` contents.
pub(crate) unsafe fn into_memos(
self,
) -> impl Iterator<Item = (MemoIngredientIndex, ManuallyDrop<Arc<dyn Memo>>)> {
self.memos
.into_inner()
.into_iter()
Expand All @@ -213,7 +238,7 @@ impl MemoTable {
)| {
(
MemoIngredientIndex::from_usize(index),
to_dyn_fn(arc_swap.into_inner()),
ManuallyDrop::new(to_dyn_fn(arc_swap.into_inner())),
)
},
)
Expand Down
18 changes: 8 additions & 10 deletions src/tracked_struct.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{any::TypeId, fmt, hash::Hash, marker::PhantomData, ops::DerefMut};
use std::{any::TypeId, fmt, hash::Hash, marker::PhantomData, mem::ManuallyDrop, ops::DerefMut};

use crossbeam_queue::SegQueue;
use tracked_field::FieldIngredientImpl;
Expand Down Expand Up @@ -580,15 +580,10 @@ where
None => {
panic!("cannot delete write-locked id `{id:?}`; value leaked across threads");
}

Some(r) if r == current_revision => panic!(
"cannot delete read-locked id `{id:?}`; value leaked across threads or user functions not deterministic"
),
Some(r) => {
if r == current_revision {
panic!(
"cannot delete read-locked id `{id:?}`; \
value leaked across threads or user functions not deterministic"
)
}

if data_ref.updated_at.compare_exchange(Some(r), None).is_err() {
panic!("race occurred when deleting value `{id:?}`")
}
Expand All @@ -598,7 +593,10 @@ where
// Take the memo table. This is safe because we have modified `data_ref.updated_at` to `None`
// and the code that references the memo-table has a read-lock.
let memo_table = unsafe { (*data).take_memo_table() };
for (memo_ingredient_index, memo) in memo_table.into_memos() {
// SAFETY: We have verified that no more references to these memos exist and so we are good
// to drop them.
for (memo_ingredient_index, memo) in unsafe { memo_table.into_memos() } {
let memo = ManuallyDrop::into_inner(memo);
let ingredient_index =
zalsa.ingredient_index_for_memo(self.ingredient_index, memo_ingredient_index);

Expand Down

0 comments on commit f2c9cb5

Please sign in to comment.