Skip to content

Commit

Permalink
loan_cell: add primitive for lending thread-local data (#743)
Browse files Browse the repository at this point in the history
Add a safe abstraction for temporarily lending on-stack data into thread
local storage. Use it in various places across the stack.

This fixes a use-after-free in `pal_async`, and it reduces the overhead
of TLS in `pal_async` and `underhill_threadpool`.
  • Loading branch information
jstarks authored Jan 30, 2025
1 parent 1b9ba4e commit c9b59a1
Show file tree
Hide file tree
Showing 14 changed files with 374 additions and 204 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -3567,6 +3567,13 @@ dependencies = [
"zerocopy",
]

[[package]]
name = "loan_cell"
version = "0.0.0"
dependencies = [
"static_assertions",
]

[[package]]
name = "local_clock"
version = "0.0.0"
Expand Down Expand Up @@ -4876,6 +4883,7 @@ dependencies = [
"futures",
"getrandom",
"libc",
"loan_cell",
"once_cell",
"pal",
"pal_async_test",
Expand Down Expand Up @@ -4919,6 +4927,7 @@ dependencies = [
"inspect",
"io-uring",
"libc",
"loan_cell",
"once_cell",
"pal",
"pal_async",
Expand Down Expand Up @@ -7204,6 +7213,7 @@ version = "0.0.0"
dependencies = [
"fs-err",
"inspect",
"loan_cell",
"pal",
"pal_async",
"pal_uring",
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ inspect_proto = { path = "support/inspect_proto" }
inspect_rlimit = { path = "support/inspect_rlimit" }
inspect_task = { path = "support/inspect_task" }
kmsg = { path = "support/kmsg" }
loan_cell = { path = "support/loan_cell" }
local_clock = { path = "support/local_clock" }
mesh = { path = "support/mesh" }
mesh_build = { path = "support/mesh/mesh_build" }
Expand Down
11 changes: 6 additions & 5 deletions openhcl/underhill_core/src/vp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ impl VpSpawner {
let thread = underhill_threadpool::Thread::current().unwrap();
// TODO propagate this error back earlier. This is easiest if
// set_idle_task is fixed to take a non-Send fn.
let mut vp = self
.vp
.bind_processor::<T>(thread.driver(), control)
.context("failed to initialize VP")?;
let mut vp = thread.with_driver(|driver| {
self.vp
.bind_processor::<T>(driver, control)
.context("failed to initialize VP")
})?;

if let Some(saved_state) = saved_state {
vmcore::save_restore::ProtobufSaveRestore::restore(&mut vp, saved_state)
Expand Down Expand Up @@ -166,7 +167,7 @@ impl VpSpawner {
self.vp.set_sidecar_exit_due_to_task(
thread
.first_task()
.map_or_else(|| "<unknown>".into(), |t| t.name.clone()),
.map_or_else(|| "<unknown>".into(), |t| t.name),
);
}

Expand Down
3 changes: 1 addition & 2 deletions openhcl/underhill_mem/src/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,7 @@ async fn apply_vtl2_protections(
tracing::debug!(
cpu = underhill_threadpool::Thread::current()
.unwrap()
.driver()
.target_cpu(),
.with_driver(|driver| driver.target_cpu()),
%range,
"applying protections"
);
Expand Down
1 change: 1 addition & 0 deletions openhcl/underhill_threadpool/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ rust-version.workspace = true

[target.'cfg(target_os = "linux")'.dependencies]
inspect = { workspace = true, features = ["std"] }
loan_cell.workspace = true
pal.workspace = true
pal_async.workspace = true
pal_uring.workspace = true
Expand Down
120 changes: 53 additions & 67 deletions openhcl/underhill_threadpool/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
//! This is built on top of [`pal_uring`] and [`pal_async`].
#![warn(missing_docs)]
// UNSAFETY: needed for saving per-thread state.
#![expect(unsafe_code)]
#![forbid(unsafe_code)]

use inspect::Inspect;
use loan_cell::LoanCell;
use pal::unix::affinity::CpuSet;
use pal_async::fd::FdReadyDriver;
use pal_async::task::Runnable;
Expand All @@ -30,7 +30,6 @@ use pal_uring::IoUringPool;
use pal_uring::PoolClient;
use pal_uring::Timer;
use parking_lot::Mutex;
use std::cell::Cell;
use std::future::poll_fn;
use std::io;
use std::marker::PhantomData;
Expand Down Expand Up @@ -215,15 +214,12 @@ impl ThreadpoolBuilder {
send.send(Ok(pool.client().clone())).ok();

// Store the current thread's driver so that spawned tasks can
// find it via `Thread::current()`.
CURRENT_THREADPOOL_CPU.with(|current| {
current.set(std::ptr::from_ref(&driver));
// find it via `Thread::current()`. Do this via a loan instead
// of storing it directly in TLS to avoid the overhead of
// registering a destructor.
CURRENT_THREAD_DRIVER.with(|current| {
current.lend(&driver, || pool.run());
});
pool.run();
CURRENT_THREADPOOL_CPU.with(|current| {
current.set(std::ptr::null());
});
drop(driver);
})?;

// Wait for the pool to be initialized.
Expand Down Expand Up @@ -360,33 +356,27 @@ impl Initiate for AffinitizedThreadpool {
/// The state for the thread pool thread for the currently running CPU.
#[derive(Debug, Copy, Clone)]
pub struct Thread {
driver: &'static ThreadpoolDriver,
_not_send_sync: PhantomData<*const ()>,
}

impl Thread {
/// Returns a new driver for the current CPU.
/// Returns an instance for the current CPU.
pub fn current() -> Option<Self> {
let inner = CURRENT_THREADPOOL_CPU.with(|current| {
let p = current.get();
// SAFETY: the `ThreadpoolDriver` is on the current thread's stack
// and so is guaranteed to be valid. And since `Thread` is not
// `Send` or `Sync`, this reference cannot be accessed after the
// driver has been dropped, since any task that can construct a
// `Thread` will have been completed by that time. So it's OK for
// this reference to live as long as `Thread`.
(!p.is_null()).then(|| unsafe { &*p })
})?;
if !CURRENT_THREAD_DRIVER.with(|current| current.is_lent()) {
return None;
}
Some(Self {
driver: inner,
_not_send_sync: PhantomData,
})
}

fn once(&self) -> &ThreadpoolDriverOnce {
// Since we are on the thread, the thread is guaranteed to have been
// initialized.
self.driver.inner.once.get().unwrap()
/// Calls `f` with the driver for the current thread.
pub fn with_driver<R>(&self, f: impl FnOnce(&ThreadpoolDriver) -> R) -> R {
CURRENT_THREAD_DRIVER.with(|current| current.borrow(|driver| f(driver.unwrap())))
}

fn with_once<R>(&self, f: impl FnOnce(&ThreadpoolDriver, &ThreadpoolDriverOnce) -> R) -> R {
self.with_driver(|driver| f(driver, driver.inner.once.get().unwrap()))
}

/// Sets the idle task to run. The task is returned by `f`, which receives
Expand All @@ -400,56 +390,52 @@ impl Thread {
F: 'static + Send + FnOnce(IdleControl) -> Fut,
Fut: std::future::Future<Output = ()>,
{
self.once().client.set_idle_task(f)
}

/// Returns the driver for the current thread.
pub fn driver(&self) -> &ThreadpoolDriver {
self.driver
self.with_once(|_, once| once.client.set_idle_task(f))
}

/// Tries to set the affinity to this thread's intended CPU, if it has not
/// already been set. Returns `Ok(false)` if the intended CPU is still
/// offline.
pub fn try_set_affinity(&self) -> Result<bool, SetAffinityError> {
let mut state = self.driver.inner.state.lock();
if matches!(state.affinity, AffinityState::Set) {
return Ok(true);
}
if !is_cpu_online(self.driver.inner.cpu).map_err(SetAffinityError::Online)? {
return Ok(false);
}
self.with_once(|driver, once| {
let mut state = driver.inner.state.lock();
if matches!(state.affinity, AffinityState::Set) {
return Ok(true);
}
if !is_cpu_online(driver.inner.cpu).map_err(SetAffinityError::Online)? {
return Ok(false);
}

let mut affinity = CpuSet::new();
affinity.set(self.driver.inner.cpu);

pal::unix::affinity::set_current_thread_affinity(&affinity)
.map_err(SetAffinityError::Thread)?;
self.once()
.client
.set_iowq_affinity(&affinity)
.map_err(SetAffinityError::Ring)?;

let old_affinity_state = std::mem::replace(&mut state.affinity, AffinityState::Set);
self.driver.inner.affinity_set.store(true, Relaxed);
drop(state);

match old_affinity_state {
AffinityState::Waiting(wakers) => {
for waker in wakers {
waker.wake();
let mut affinity = CpuSet::new();
affinity.set(driver.inner.cpu);

pal::unix::affinity::set_current_thread_affinity(&affinity)
.map_err(SetAffinityError::Thread)?;
once.client
.set_iowq_affinity(&affinity)
.map_err(SetAffinityError::Ring)?;

let old_affinity_state = std::mem::replace(&mut state.affinity, AffinityState::Set);
driver.inner.affinity_set.store(true, Relaxed);
drop(state);

match old_affinity_state {
AffinityState::Waiting(wakers) => {
for waker in wakers {
waker.wake();
}
}
AffinityState::Set => unreachable!(),
}
AffinityState::Set => unreachable!(),
}
Ok(true)
Ok(true)
})
}

/// Returns the task that caused this thread to spawn.
///
/// Returns `None` if the thread was spawned to issue IO.
pub fn first_task(&self) -> Option<&TaskInfo> {
self.once().first_task.as_ref()
pub fn first_task(&self) -> Option<TaskInfo> {
self.with_once(|_, once| once.first_task.clone())
}
}

Expand All @@ -468,12 +454,12 @@ pub enum SetAffinityError {
}

thread_local! {
static CURRENT_THREADPOOL_CPU: Cell<*const ThreadpoolDriver> = const { Cell::new(std::ptr::null()) };
static CURRENT_THREAD_DRIVER: LoanCell<ThreadpoolDriver> = const { LoanCell::new() };
}

impl SpawnLocal for Thread {
fn scheduler_local(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule> {
self.driver.scheduler(metadata).clone()
self.with_driver(|driver| driver.scheduler(metadata).clone())
}
}

Expand Down Expand Up @@ -506,7 +492,7 @@ struct ThreadpoolDriverOnce {
}

/// Information about a task that caused a thread to spawn.
#[derive(Debug, Inspect)]
#[derive(Debug, Clone, Inspect)]
pub struct TaskInfo {
/// The name of the task.
pub name: Arc<str>,
Expand Down
15 changes: 15 additions & 0 deletions support/loan_cell/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

[package]
name = "loan_cell"
rust-version.workspace = true
edition.workspace = true

[dependencies]

[dev-dependencies]
static_assertions.workspace = true

[lints]
workspace = true
Loading

0 comments on commit c9b59a1

Please sign in to comment.