wasmtime: Fix resetting stack-walking registers when entering/exiting Wasm

Fixes a regression from bytecodealliance#6262, originally reported in
bytecodealliance/wasmtime-dotnet#245

The issue was that we would enter Wasm and save the stack-walking registers but
never clear them after Wasm returns. Then if a host-to-host call tried to
capture a stack, we would mistakenly attempt to use those stale registers to
start the stack walk. This mistake would be caught by an assertion, triggering a
panic.

This commit fixes the issue by managing the save/restore of these registers in
`CallThreadState`'s constructor and `Drop` implementation, rather than in the
old `set_prev` method.
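
To illustrate the shape of the fix, here is a minimal, self-contained sketch of
the save-on-construct / restore-on-drop pattern (the `Limits` and `Activation`
names below are hypothetical stand-ins, not the actual Wasmtime definitions):

    use std::cell::Cell;

    struct Limits {
        last_wasm_exit_fp: Cell<usize>,
    }

    struct Activation<'a> {
        limits: &'a Limits,
        old_last_wasm_exit_fp: usize,
    }

    impl<'a> Activation<'a> {
        // Save the previous register value when the activation is created...
        fn new(limits: &'a Limits) -> Self {
            Self {
                limits,
                old_last_wasm_exit_fp: limits.last_wasm_exit_fp.get(),
            }
        }
    }

    impl Drop for Activation<'_> {
        // ...and unconditionally restore it when the activation ends.
        fn drop(&mut self) {
            self.limits.last_wasm_exit_fp.set(self.old_last_wasm_exit_fp);
        }
    }

    fn main() {
        let limits = Limits { last_wasm_exit_fp: Cell::new(0) };
        {
            let _activation = Activation::new(&limits);
            // "Wasm" runs and updates the register...
            limits.last_wasm_exit_fp.set(0xdead_beef);
        }
        // ...but once the activation is dropped, the stale value is gone.
        assert_eq!(limits.last_wasm_exit_fp.get(), 0);
    }

Tying the restore to `Drop` means the previous values are put back on every
exit from the activation, so a later host-only stack capture never starts from
stale registers.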

Co-Authored-By: Alex Crichton <[email protected]>
fitzgen and alexcrichton committed May 2, 2023
1 parent 85cbaa5 commit b82d804
Showing 3 changed files with 224 additions and 108 deletions.
145 changes: 37 additions & 108 deletions crates/runtime/src/traphandlers.rs
@@ -248,7 +248,6 @@ where
// usage of its accessor methods.
mod call_thread_state {
use super::*;
use std::mem;

/// Temporary state stored on the stack which is registered in the `tls` module
/// below for calls into wasm.
@@ -262,18 +261,29 @@ mod call_thread_state {

prev: Cell<tls::Ptr>,

// The values of `VMRuntimeLimits::last_wasm_{exit_{pc,fp},entry_sp}` for
// the *previous* `CallThreadState`. Our *current* last wasm PC/FP/SP are
// saved in `self.limits`. We save a copy of the old registers here because
// the `VMRuntimeLimits` typically doesn't change across nested calls into
// Wasm (i.e. they are typically calls back into the same store and
// `self.limits == self.prev.limits`) and we must maintain the list of
// The values of `VMRuntimeLimits::last_wasm_{exit_{pc,fp},entry_sp}`
// for the *previous* `CallThreadState` for this same store/limits. Our
// *current* last wasm PC/FP/SP are saved in `self.limits`. We save a
// copy of the old registers here because the `VMRuntimeLimits`
// typically doesn't change across nested calls into Wasm (i.e. they are
// typically calls back into the same store and `self.limits ==
// self.prev.limits`) and we must maintain the list of
// contiguous-Wasm-frames stack regions for backtracing purposes.
old_last_wasm_exit_fp: Cell<usize>,
old_last_wasm_exit_pc: Cell<usize>,
old_last_wasm_entry_sp: Cell<usize>,
}

impl Drop for CallThreadState {
fn drop(&mut self) {
unsafe {
*(*self.limits).last_wasm_exit_fp.get() = self.old_last_wasm_exit_fp.get();
*(*self.limits).last_wasm_exit_pc.get() = self.old_last_wasm_exit_pc.get();
*(*self.limits).last_wasm_entry_sp.get() = self.old_last_wasm_entry_sp.get();
}
}
}

impl CallThreadState {
#[inline]
pub(super) fn new(
@@ -288,9 +298,9 @@ mod call_thread_state {
capture_backtrace,
limits,
prev: Cell::new(ptr::null()),
old_last_wasm_exit_fp: Cell::new(0),
old_last_wasm_exit_pc: Cell::new(0),
old_last_wasm_entry_sp: Cell::new(0),
old_last_wasm_exit_fp: Cell::new(unsafe { *(*limits).last_wasm_exit_fp.get() }),
old_last_wasm_exit_pc: Cell::new(unsafe { *(*limits).last_wasm_exit_pc.get() }),
old_last_wasm_entry_sp: Cell::new(unsafe { *(*limits).last_wasm_entry_sp.get() }),
}
}

@@ -314,83 +324,15 @@ mod call_thread_state {
self.prev.get()
}

/// Connect the link to the previous `CallThreadState`.
///
/// Synchronizes the last wasm FP, PC, and SP on `self` and the old
/// `self.prev` for the given new `prev`, and returns the old
/// `self.prev`.
pub unsafe fn set_prev(&self, prev: tls::Ptr) -> tls::Ptr {
let old_prev = self.prev.get();

// Restore the old `prev`'s saved registers in its
// `VMRuntimeLimits`. This is necessary for when we are async
// suspending the top `CallThreadState` and doing `set_prev(null)`
// on it, and so any stack walking we do subsequently will start at
// the old `prev` and look at its `VMRuntimeLimits` to get the
// initial saved registers.
if let Some(old_prev) = old_prev.as_ref() {
*(*old_prev.limits).last_wasm_exit_fp.get() = self.old_last_wasm_exit_fp();
*(*old_prev.limits).last_wasm_exit_pc.get() = self.old_last_wasm_exit_pc();
*(*old_prev.limits).last_wasm_entry_sp.get() = self.old_last_wasm_entry_sp();
}

self.prev.set(prev);

let mut old_last_wasm_exit_fp = 0;
let mut old_last_wasm_exit_pc = 0;
let mut old_last_wasm_entry_sp = 0;
if let Some(prev) = prev.as_ref() {
// We are entering a new `CallThreadState` or resuming a
// previously suspended one. This means we will push new Wasm
// frames that save the new Wasm FP/SP/PC registers into
// `VMRuntimeLimits`, we need to first save the old Wasm
// FP/SP/PC registers into this new `CallThreadState` to
// maintain our list of contiguous Wasm frame regions that we
// use when capturing stack traces.
//
// NB: the Wasm<--->host trampolines saved the Wasm FP/SP/PC
// registers in the active-at-that-time store's
// `VMRuntimeLimits`. For the most recent FP/PC/SP that is the
// `state.prev.limits` (since we haven't entered this
// `CallThreadState` yet). And that can be a different
// `VMRuntimeLimits` instance from the currently active
// `state.limits`, which will be used by the upcoming call into
// Wasm! Consider the case where we have multiple, nested calls
// across stores (with host code in between, by necessity, since
// only things in the same store can be linked directly
// together):
//
// | ... |
// | Host | |
// +-----------------+ | stack
// | Wasm in store A | | grows
// +-----------------+ | down
// | Host | |
// +-----------------+ |
// | Wasm in store B | V
// +-----------------+
//
// In this scenario `state.limits != state.prev.limits`,
// i.e. `B.limits != A.limits`! Therefore we must take care to
// read the old FP/SP/PC from `state.prev.limits`, rather than
// `state.limits`, and store those saved registers into the
// current `state`.
//
// See also the comment above the
// `CallThreadState::old_last_wasm_*` fields.
old_last_wasm_exit_fp =
mem::replace(&mut *(*prev.limits).last_wasm_exit_fp.get(), 0);
old_last_wasm_exit_pc =
mem::replace(&mut *(*prev.limits).last_wasm_exit_pc.get(), 0);
old_last_wasm_entry_sp =
mem::replace(&mut *(*prev.limits).last_wasm_entry_sp.get(), 0);
}

self.old_last_wasm_exit_fp.set(old_last_wasm_exit_fp);
self.old_last_wasm_exit_pc.set(old_last_wasm_exit_pc);
self.old_last_wasm_entry_sp.set(old_last_wasm_entry_sp);
pub(crate) unsafe fn push(&self) {
assert!(self.prev.get().is_null());
self.prev.set(tls::raw::replace(self));
}

old_prev
pub(crate) unsafe fn pop(&self) {
let prev = self.prev.replace(ptr::null());
let head = tls::raw::replace(prev);
assert!(std::ptr::eq(head, self));
}
}
}
@@ -533,7 +475,6 @@ impl<T: Copy> Drop for ResetCell<'_, T> {
// the caller to the trap site.
mod tls {
use super::CallThreadState;
use std::ptr;

pub use raw::Ptr;

@@ -551,7 +492,7 @@ mod tls {
//
// Note, though, that if async support is disabled at compile time then
// these functions are free to be inlined.
mod raw {
pub(super) mod raw {
use super::CallThreadState;
use std::cell::Cell;
use std::ptr;
@@ -625,8 +566,7 @@ mod tls {
// accidentally used later.
let state = raw::get();
if let Some(state) = state.as_ref() {
let prev_state = state.set_prev(ptr::null());
raw::replace(prev_state);
state.pop();
} else {
// Null case: we aren't in a wasm context, so there's no tls to
// save for restoration.
@@ -640,18 +580,12 @@ mod tls {
/// This is unsafe because it's intended to only be used within the
/// context of stack switching within wasmtime.
pub unsafe fn replace(self) {
// Null case: we aren't in a wasm context, so there's no tls
// to restore.
if self.state.is_null() {
return;
if let Some(state) = self.state.as_ref() {
state.push();
} else {
// Null case: we aren't in a wasm context, so there's no tls
// to restore.
}

// We need to configure our previous TLS pointer to whatever is in
// TLS at this time, and then we set the current state to ourselves.
let prev = raw::get();
assert!((*self.state).prev().is_null());
(*self.state).set_prev(prev);
raw::replace(self.state);
}
}

@@ -668,18 +602,13 @@ mod tls {
#[inline]
fn drop(&mut self) {
unsafe {
let prev = self.state.set_prev(ptr::null());
let old_state = raw::replace(prev);
debug_assert!(std::ptr::eq(old_state, self.state));
self.state.pop();
}
}
}

let prev = raw::replace(state);

unsafe {
state.set_prev(prev);

state.push();
let reset = Reset { state };
closure(reset.state)
}
137 changes: 137 additions & 0 deletions tests/all/async_functions.rs
@@ -1,6 +1,7 @@
use anyhow::{anyhow, bail, Result};
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use wasmtime::*;

@@ -706,3 +707,139 @@ fn noop_waker() -> Waker {
const RAW: RawWaker = RawWaker::new(0 as *const (), &VTABLE);
unsafe { Waker::from_raw(RAW) }
}

#[tokio::test]
async fn non_stacky_async_activations() -> Result<()> {
let mut config = Config::new();
config.async_support(true);
let engine = Engine::new(&config)?;
let mut store1: Store<Option<Pin<Box<dyn Future<Output = Result<()>> + Send>>>> =
Store::new(&engine, None);
let mut linker1 = Linker::new(&engine);

let module1 = Module::new(
&engine,
r#"
(module $m1
(import "" "host_capture_stack" (func $host_capture_stack))
(import "" "start_async_instance" (func $start_async_instance))
(func $capture_stack (export "capture_stack")
call $host_capture_stack
)
(func $run_sync (export "run_sync")
call $start_async_instance
)
)
"#,
)?;

let module2 = Module::new(
&engine,
r#"
(module $m2
(import "" "yield" (func $yield))
(func $run_async (export "run_async")
call $yield
)
)
"#,
)?;

let stacks = Arc::new(Mutex::new(vec![]));
fn capture_stack(stacks: &Arc<Mutex<Vec<WasmBacktrace>>>, store: impl AsContext) {
let mut stacks = stacks.lock().unwrap();
stacks.push(wasmtime::WasmBacktrace::force_capture(store));
}

linker1.func_wrap0_async("", "host_capture_stack", {
let stacks = stacks.clone();
move |caller| {
capture_stack(&stacks, &caller);
Box::new(async { Ok(()) })
}
})?;

linker1.func_wrap0_async("", "start_async_instance", {
let stacks = stacks.clone();
move |mut caller| {
let stacks = stacks.clone();
capture_stack(&stacks, &caller);

let module2 = module2.clone();
let mut store2 = Store::new(caller.engine(), ());
let mut linker2 = Linker::new(caller.engine());
linker2
.func_wrap0_async("", "yield", {
let stacks = stacks.clone();
move |caller| {
let stacks = stacks.clone();
Box::new(async move {
capture_stack(&stacks, &caller);
tokio::task::yield_now().await;
capture_stack(&stacks, &caller);
Ok(())
})
}
})
.unwrap();

Box::new(async move {
let future = PollOnce::new(Box::pin({
let stacks = stacks.clone();
async move {
let instance2 = linker2.instantiate_async(&mut store2, &module2).await?;

instance2
.get_func(&mut store2, "run_async")
.unwrap()
.call_async(&mut store2, &[], &mut [])
.await?;

capture_stack(&stacks, &store2);
Ok(())
}
}) as _)
.await;
capture_stack(&stacks, &caller);
*caller.data_mut() = Some(future);
Ok(())
})
}
})?;

let instance1 = linker1.instantiate_async(&mut store1, &module1).await?;
instance1
.get_typed_func::<(), ()>(&mut store1, "run_sync")?
.call_async(&mut store1, ())
.await?;
let future = store1.data_mut().take().unwrap();
future.await?;

instance1
.get_typed_func::<(), ()>(&mut store1, "capture_stack")?
.call_async(&mut store1, ())
.await?;

let stacks = stacks.lock().unwrap();
eprintln!("stacks = {stacks:#?}");

assert_eq!(stacks.len(), 6);
for (actual, expected) in stacks.iter().zip(vec![
vec!["run_sync"],
vec!["run_async"],
vec!["run_sync"],
vec!["run_async"],
vec![],
vec!["capture_stack"],
]) {
eprintln!("expected = {expected:?}");
eprintln!("actual = {actual:?}");
assert_eq!(actual.frames().len(), expected.len());
for (actual, expected) in actual.frames().iter().zip(expected) {
assert_eq!(actual.func_name(), Some(expected));
}
}

Ok(())
}