Skip to content

Commit

Permalink
Disallow re-entering wasm while off main stack (#117)
Browse files Browse the repository at this point in the history
This PR addresses an issue discussed in #109:
Correctly handling the case where we re-enter wasm while already on a
continuation stack is difficult. For the time being, we therefore
disallow this. This PR adds the necessary logic to detect this.

Concretely, in `invoke_wasm_and_catch_traps`, we inspect the chain of
nested wasm (+ host) invocations, represented by the linked list of
`CallThreadState` objects maintained in `wasmtime_runtime:traphandlers`.
To this end, for those `CallThreadState` objects that represent
execution of wasm, we store a pointer to the corresponding `Store`'s
`StackChainCell`.

Please note that the diff of the test file `typed_continuations.rs`
looks unnecessarily scary: I moved the existing tests into a module
`wasi`, and added modules `test_utils` and `host`.

---------

Co-authored-by: Daniel Hillerström <[email protected]>
  • Loading branch information
frank-emrich and dhil committed Feb 29, 2024
1 parent 846522e commit ae380fd
Show file tree
Hide file tree
Showing 7 changed files with 643 additions and 142 deletions.
7 changes: 7 additions & 0 deletions crates/continuations/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ impl StackChain {
pub const ABSENT_DISCRIMINANT: usize = STACK_CHAIN_ABSENT_DISCRIMINANT;
pub const MAIN_STACK_DISCRIMINANT: usize = STACK_CHAIN_MAIN_STACK_DISCRIMINANT;
pub const CONTINUATION_DISCRIMINANT: usize = STACK_CHAIN_CONTINUATION_DISCRIMINANT;

pub fn is_main_stack(&self) -> bool {
match self {
StackChain::MainStack(_) => true,
_ => false,
}
}
}

#[repr(transparent)]
Expand Down
63 changes: 52 additions & 11 deletions crates/runtime/src/traphandlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ mod backtrace;
mod coredump;

use crate::sys::traphandlers;
use crate::{Instance, VMContext, VMRuntimeLimits};
use crate::{Instance, VMContext, VMOpaqueContext, VMRuntimeLimits};
use anyhow::Error;
use std::any::Any;
use std::cell::{Cell, UnsafeCell};
use std::mem::MaybeUninit;
use std::ptr;
use std::sync::Once;
use wasmtime_continuations::StackChainCell;

pub use self::backtrace::{Backtrace, Frame};
pub use self::coredump::CoreDumpStack;
Expand Down Expand Up @@ -207,22 +208,31 @@ pub unsafe fn catch_traps<'a, F>(
capture_backtrace: bool,
capture_coredump: bool,
caller: *mut VMContext,
callee: *mut VMOpaqueContext,
mut closure: F,
) -> Result<(), Box<Trap>>
where
F: FnMut(*mut VMContext),
{
let limits = Instance::from_vmctx(caller, |i| i.runtime_limits());

let result = CallThreadState::new(signal_handler, capture_backtrace, capture_coredump, *limits)
.with(|cx| {
traphandlers::wasmtime_setjmp(
cx.jmp_buf.as_ptr(),
call_closure::<F>,
&mut closure as *mut F as *mut u8,
caller,
)
});
let callee_stack_chain = VMContext::try_from_opaque(callee)
.map(|vmctx| Instance::from_vmctx(vmctx, |i| *i.stack_chain() as *const StackChainCell));

let result = CallThreadState::new(
signal_handler,
capture_backtrace,
capture_coredump,
*limits,
callee_stack_chain,
)
.with(|cx| {
traphandlers::wasmtime_setjmp(
cx.jmp_buf.as_ptr(),
call_closure::<F>,
&mut closure as *mut F as *mut u8,
caller,
)
});

return match result {
Ok(x) => Ok(x),
Expand All @@ -242,6 +252,31 @@ where
}
}

/// Returns true if the first `CallThreadState` in this thread's chain that
/// actually executes wasm is doing so inside a continuation. Returns false
/// if there is no `CallThreadState` executing wasm.
pub fn first_wasm_state_on_fiber_stack() -> bool {
tls::with(|head_state| {
// Iterate this threads' CallThreadState chain starting at `head_state`
// (if chain is non-empty), skipping those CTSs whose
// `callee_stack_chain` is None. This means that if `first_wasm_state`
// is Some, it is the first entry in the call thread state chain
// actually executin wasm.
let first_wasm_state = head_state
.iter()
.flat_map(|head| head.iter())
.skip_while(|state| state.callee_stack_chain.is_none())
.next();

first_wasm_state.map_or(false, |state| unsafe {
let stack_chain = &*state
.callee_stack_chain
.expect("must be Some according to filtering above");
!(*stack_chain.0.get()).is_main_stack()
})
})
}

// Module to hide visibility of the `CallThreadState::prev` field and force
// usage of its accessor methods.
mod call_thread_state {
Expand All @@ -259,6 +294,10 @@ mod call_thread_state {

pub(crate) limits: *const VMRuntimeLimits,

/// `Some(ptr)` iff this CallThreadState is for the execution of wasm.
/// In that case, `ptr` is the executing `Store`'s stack chain.
pub(crate) callee_stack_chain: Option<*const StackChainCell>,

pub(super) prev: Cell<tls::Ptr>,

// The values of `VMRuntimeLimits::last_wasm_{exit_{pc,fp},entry_sp}`
Expand Down Expand Up @@ -291,6 +330,7 @@ mod call_thread_state {
capture_backtrace: bool,
capture_coredump: bool,
limits: *const VMRuntimeLimits,
callee_stack_chain: Option<*const StackChainCell>,
) -> CallThreadState {
CallThreadState {
unwind: UnsafeCell::new(MaybeUninit::uninit()),
Expand All @@ -299,6 +339,7 @@ mod call_thread_state {
capture_backtrace,
capture_coredump,
limits,
callee_stack_chain,
prev: Cell::new(ptr::null()),
old_last_wasm_exit_fp: Cell::new(unsafe { *(*limits).last_wasm_exit_fp.get() }),
old_last_wasm_exit_pc: Cell::new(unsafe { *(*limits).last_wasm_exit_pc.get() }),
Expand Down
10 changes: 10 additions & 0 deletions crates/runtime/src/vmcontext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,16 @@ impl VMContext {
debug_assert_eq!((*opaque).magic, VMCONTEXT_MAGIC);
opaque.cast()
}

/// Alternative to `from_opaque` that returns `None` if the given opaque
/// context is not actually a `VMContext`.
pub unsafe fn try_from_opaque(opaque: *mut VMOpaqueContext) -> Option<*mut VMContext> {
if (*opaque).magic == VMCONTEXT_MAGIC {
Some(Self::from_opaque(opaque))
} else {
None
}
}
}

/// A "raw" and unsafe representation of a WebAssembly value.
Expand Down
58 changes: 49 additions & 9 deletions crates/wasmtime/src/runtime/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1033,15 +1033,19 @@ impl Func {
params_and_returns: *mut ValRaw,
params_and_returns_capacity: usize,
) -> Result<()> {
invoke_wasm_and_catch_traps(store, |caller| {
let func_ref = func_ref.as_ref();
(func_ref.array_call)(
func_ref.vmctx,
caller.cast::<VMOpaqueContext>(),
params_and_returns,
params_and_returns_capacity,
)
})
invoke_wasm_and_catch_traps(
store,
|caller| {
let func_ref = func_ref.as_ref();
(func_ref.array_call)(
func_ref.vmctx,
caller.cast::<VMOpaqueContext>(),
params_and_returns,
params_and_returns_capacity,
)
},
func_ref.as_ref().vmctx,
)
}

/// Converts the raw representation of a `funcref` into an `Option<Func>`
Expand Down Expand Up @@ -1533,8 +1537,43 @@ impl Func {
pub(crate) fn invoke_wasm_and_catch_traps<T>(
store: &mut StoreContextMut<'_, T>,
closure: impl FnMut(*mut VMContext),
callee: *mut VMOpaqueContext,
) -> Result<()> {
unsafe {
if VMContext::try_from_opaque(callee).is_some() {
// If we get here, the callee is a "proper" `VMContext`, and we are
// indeed calling into wasm.
//
// We now ensure that the following invariant holds (see
// wasmfx/wasmfxtime#109): Since we know that we are (re)-entering
// wasm, it must not be the case that we weren't still running
// inside a continuation when reaching this point. In other words,
// we must currently be on the main stack.
//
// We check this by inspecting this thread's chain of
// `CallThreadState`s, which is a linked list of all (nested)
// invocations of wasm (and certain host calls). If any of them are
// executing wasm, we raise an error.
// Since we are doing this check every time we enter wasm, it is
// sufficient to only look at the most recent previous invocation of
// wasm (i.e., we do not need to walk the entire `CallTheadState`
// chain, but only walk to the first such state corresponding to an
// execution of wasm).
//
// As a result, the call below is O(n), where n is the number of
// `CallThreadState`s at the beginning in this thread's CTS chain before
// the first such state that corresponds to wasm execution.
// In other words, n is the nesting level of calls to wrapped host
// functions from within a host function (e.g., calling `f.call()`
// while within a host call, where `f` is the result from wrapping a
// Rust function inside a `Func`).
if wasmtime_runtime::first_wasm_state_on_fiber_stack() {
return Err(anyhow::anyhow!(
"Re-entering wasm while already executing on a continuation stack"
));
}
}

let exit = enter_wasm(store);

if let Err(trap) = store.0.call_hook(CallHook::CallingWasm) {
Expand All @@ -1546,6 +1585,7 @@ pub(crate) fn invoke_wasm_and_catch_traps<T>(
store.0.engine().config().wasm_backtrace,
store.0.engine().config().coredump_on_trap,
store.0.default_caller(),
callee,
closure,
);
exit_wasm(store, exit);
Expand Down
25 changes: 17 additions & 8 deletions crates/wasmtime/src/runtime/func/typed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,16 +187,25 @@ where
// efficient to move in memory. This closure is actually invoked on the
// other side of a C++ shim, so it can never be inlined enough to make
// the memory go away, so the size matters here for performance.
let vmctx = unsafe { func.as_ref().vmctx };
let mut captures = (func, MaybeUninit::uninit(), params, false);

let result = invoke_wasm_and_catch_traps(store, |caller| {
let (func_ref, ret, params, returned) = &mut captures;
let func_ref = func_ref.as_ref();
let result =
Params::invoke::<Results>(func_ref.native_call, func_ref.vmctx, caller, *params);
ptr::write(ret.as_mut_ptr(), result);
*returned = true
});
let result = invoke_wasm_and_catch_traps(
store,
|caller| {
let (func_ref, ret, params, returned) = &mut captures;
let func_ref = func_ref.as_ref();
let result = Params::invoke::<Results>(
func_ref.native_call,
func_ref.vmctx,
caller,
*params,
);
ptr::write(ret.as_mut_ptr(), result);
*returned = true
},
vmctx,
);
let (_, ret, _, returned) = captures;
debug_assert_eq!(result.is_ok(), returned);
result?;
Expand Down
19 changes: 12 additions & 7 deletions crates/wasmtime/src/runtime/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,14 +354,19 @@ impl Instance {
let instance = store.0.instance_mut(id);
let f = instance.get_exported_func(start);
let caller_vmctx = instance.vmctx();
let callee_vmctx = unsafe { f.func_ref.as_ref().vmctx };
unsafe {
super::func::invoke_wasm_and_catch_traps(store, |_default_caller| {
let func = mem::transmute::<
NonNull<VMNativeCallFunction>,
extern "C" fn(*mut VMOpaqueContext, *mut VMContext),
>(f.func_ref.as_ref().native_call);
func(f.func_ref.as_ref().vmctx, caller_vmctx)
})?;
super::func::invoke_wasm_and_catch_traps(
store,
|_default_caller| {
let func = mem::transmute::<
NonNull<VMNativeCallFunction>,
extern "C" fn(*mut VMOpaqueContext, *mut VMContext),
>(f.func_ref.as_ref().native_call);
func(f.func_ref.as_ref().vmctx, caller_vmctx)
},
callee_vmctx,
)?;
}
Ok(())
}
Expand Down
Loading

0 comments on commit ae380fd

Please sign in to comment.