Use a union for VMContext for functions

This makes the code more self documenting and correctly uses `unsafe`
to communicate that it's the user's responsibility to ensure that the
code paths that can lead to the access can only write the value they
expect to be there.
This commit is contained in:
Mark McCaskey
2020-10-22 11:46:42 -07:00
parent 36a713a649
commit b124ea2032
10 changed files with 105 additions and 50 deletions

View File

@@ -11,8 +11,8 @@ use std::cmp::max;
use std::fmt; use std::fmt;
use wasmer_vm::{ use wasmer_vm::{
raise_user_trap, resume_panic, wasmer_call_trampoline, Export, ExportFunction, raise_user_trap, resume_panic, wasmer_call_trampoline, Export, ExportFunction,
VMCallerCheckedAnyfunc, VMContext, VMDynamicFunctionContext, VMFunctionBody, VMFunctionKind, VMCallerCheckedAnyfunc, VMDynamicFunctionContext, VMFunctionBody, VMFunctionExtraData,
VMTrampoline, VMFunctionKind, VMTrampoline,
}; };
/// A function defined in the Wasm module /// A function defined in the Wasm module
@@ -85,7 +85,9 @@ impl Function {
// The engine linker will replace the address with one pointing to a // The engine linker will replace the address with one pointing to a
// generated dynamic trampoline. // generated dynamic trampoline.
let address = std::ptr::null() as *const VMFunctionBody; let address = std::ptr::null() as *const VMFunctionBody;
let vmctx = Box::into_raw(Box::new(dynamic_ctx)) as *mut VMContext; let vmctx = VMFunctionExtraData {
host_env: Box::into_raw(Box::new(dynamic_ctx)) as *mut _,
};
Self { Self {
store: store.clone(), store: store.clone(),
@@ -135,7 +137,9 @@ impl Function {
// The engine linker will replace the address with one pointing to a // The engine linker will replace the address with one pointing to a
// generated dynamic trampoline. // generated dynamic trampoline.
let address = std::ptr::null() as *const VMFunctionBody; let address = std::ptr::null() as *const VMFunctionBody;
let vmctx = Box::into_raw(Box::new(dynamic_ctx)) as *mut VMContext; let vmctx = VMFunctionExtraData {
host_env: Box::into_raw(Box::new(dynamic_ctx)) as *mut _,
};
Self { Self {
store: store.clone(), store: store.clone(),
@@ -176,7 +180,9 @@ impl Function {
{ {
let function = inner::Function::<Args, Rets>::new(func); let function = inner::Function::<Args, Rets>::new(func);
let address = function.address() as *const VMFunctionBody; let address = function.address() as *const VMFunctionBody;
let vmctx = std::ptr::null_mut() as *mut _ as *mut VMContext; let vmctx = VMFunctionExtraData {
host_env: std::ptr::null_mut() as *mut _,
};
let signature = function.ty(); let signature = function.ty();
Self { Self {
@@ -230,7 +236,9 @@ impl Function {
// In the case of Host-defined functions `VMContext` is whatever environment // In the case of Host-defined functions `VMContext` is whatever environment
// the user want to attach to the function. // the user want to attach to the function.
let box_env = Box::new(env); let box_env = Box::new(env);
let vmctx = Box::into_raw(box_env) as *mut _ as *mut VMContext; let vmctx = VMFunctionExtraData {
host_env: Box::into_raw(box_env) as *mut _,
};
let signature = function.ty(); let signature = function.ty();
Self { Self {
@@ -365,7 +373,8 @@ impl Function {
Self { Self {
store: store.clone(), store: store.clone(),
definition: FunctionDefinition::Host(HostFunctionDefinition { definition: FunctionDefinition::Host(HostFunctionDefinition {
has_env: !wasmer_export.vmctx.is_null(), // TOOD: make safe function on this union to check for null
has_env: !unsafe { wasmer_export.vmctx.host_env.is_null() },
}), }),
exported: wasmer_export, exported: wasmer_export,
} }

View File

@@ -17,7 +17,7 @@ use crate::{FromToNativeWasmType, Function, FunctionType, RuntimeError, Store, W
use std::panic::{catch_unwind, AssertUnwindSafe}; use std::panic::{catch_unwind, AssertUnwindSafe};
use wasmer_types::NativeWasmType; use wasmer_types::NativeWasmType;
use wasmer_vm::{ use wasmer_vm::{
ExportFunction, VMContext, VMDynamicFunctionContext, VMFunctionBody, VMFunctionKind, ExportFunction, VMDynamicFunctionContext, VMFunctionBody, VMFunctionExtraData, VMFunctionKind,
}; };
/// A WebAssembly function that can be called natively /// A WebAssembly function that can be called natively
@@ -26,7 +26,7 @@ pub struct NativeFunc<'a, Args = (), Rets = ()> {
definition: FunctionDefinition, definition: FunctionDefinition,
store: Store, store: Store,
address: *const VMFunctionBody, address: *const VMFunctionBody,
vmctx: *mut VMContext, vmctx: VMFunctionExtraData,
arg_kind: VMFunctionKind, arg_kind: VMFunctionKind,
// exported: ExportFunction, // exported: ExportFunction,
_phantom: PhantomData<(&'a (), Args, Rets)>, _phantom: PhantomData<(&'a (), Args, Rets)>,
@@ -42,7 +42,7 @@ where
pub(crate) fn new( pub(crate) fn new(
store: Store, store: Store,
address: *const VMFunctionBody, address: *const VMFunctionBody,
vmctx: *mut VMContext, vmctx: VMFunctionExtraData,
arg_kind: VMFunctionKind, arg_kind: VMFunctionKind,
definition: FunctionDefinition, definition: FunctionDefinition,
) -> Self { ) -> Self {
@@ -165,7 +165,7 @@ macro_rules! impl_native_traits {
match self.arg_kind { match self.arg_kind {
VMFunctionKind::Static => { VMFunctionKind::Static => {
let results = catch_unwind(AssertUnwindSafe(|| unsafe { let results = catch_unwind(AssertUnwindSafe(|| unsafe {
let f = std::mem::transmute::<_, unsafe extern "C" fn( *mut VMContext, $( $x, )*) -> Rets::CStruct>(self.address); let f = std::mem::transmute::<_, unsafe extern "C" fn( VMFunctionExtraData, $( $x, )*) -> Rets::CStruct>(self.address);
// We always pass the vmctx // We always pass the vmctx
f( self.vmctx, $( $x, )* ) f( self.vmctx, $( $x, )* )
})).map_err(|e| RuntimeError::new(format!("{:?}", e)))?; })).map_err(|e| RuntimeError::new(format!("{:?}", e)))?;
@@ -175,12 +175,16 @@ macro_rules! impl_native_traits {
let params_list = [ $( $x.to_native().to_value() ),* ]; let params_list = [ $( $x.to_native().to_value() ),* ];
let results = if !has_env { let results = if !has_env {
type VMContextWithoutEnv = VMDynamicFunctionContext<VMDynamicFunctionWithoutEnv>; type VMContextWithoutEnv = VMDynamicFunctionContext<VMDynamicFunctionWithoutEnv>;
let ctx = self.vmctx as *mut VMContextWithoutEnv; unsafe {
unsafe { (*ctx).ctx.call(&params_list)? } let ctx = self.vmctx.host_env as *mut VMContextWithoutEnv;
(*ctx).ctx.call(&params_list)?
}
} else { } else {
type VMContextWithEnv = VMDynamicFunctionContext<VMDynamicFunctionWithEnv<std::ffi::c_void>>; type VMContextWithEnv = VMDynamicFunctionContext<VMDynamicFunctionWithEnv<std::ffi::c_void>>;
let ctx = self.vmctx as *mut VMContextWithEnv; unsafe {
unsafe { (*ctx).ctx.call(&params_list)? } let ctx = self.vmctx.host_env as *mut VMContextWithEnv;
(*ctx).ctx.call(&params_list)?
}
}; };
let mut rets_list_array = Rets::empty_array(); let mut rets_list_array = Rets::empty_array();
let mut_rets = rets_list_array.as_mut() as *mut [i128] as *mut i128; let mut_rets = rets_list_array.as_mut() as *mut [i128] as *mut i128;

View File

@@ -56,7 +56,9 @@ impl ValFuncRef for Val {
Self::ExternRef(ExternRef::Null) => wasmer_vm::VMCallerCheckedAnyfunc { Self::ExternRef(ExternRef::Null) => wasmer_vm::VMCallerCheckedAnyfunc {
func_ptr: ptr::null(), func_ptr: ptr::null(),
type_index: wasmer_vm::VMSharedSignatureIndex::default(), type_index: wasmer_vm::VMSharedSignatureIndex::default(),
vmctx: ptr::null_mut(), vmctx: wasmer_vm::VMFunctionExtraData {
host_env: ptr::null_mut(),
},
}, },
Self::FuncRef(f) => f.checked_anyfunc(), Self::FuncRef(f) => f.checked_anyfunc(),
_ => return Err(RuntimeError::new("val is not funcref")), _ => return Err(RuntimeError::new("val is not funcref")),

View File

@@ -102,14 +102,14 @@ impl Module {
// Properly drop the empty `vm::Ctx` // Properly drop the empty `vm::Ctx`
// created by the host function. // created by the host function.
unsafe { unsafe {
ptr::drop_in_place::<vm::Ctx>(function.vmctx as _); ptr::drop_in_place::<vm::Ctx>(function.vmctx.host_env as _);
} }
// Update the pointer to `VMContext`, // Update the pointer to `VMContext`,
// which is actually a `vm::Ctx` // which is actually a `vm::Ctx`
// pointer, to fallback on the // pointer, to fallback on the
// environment hack. // environment hack.
function.vmctx = pre_instance.vmctx_ptr() as _; function.vmctx.host_env = pre_instance.vmctx_ptr() as _;
} }
// `function` is a dynamic host function // `function` is a dynamic host function
// constructed with // constructed with
@@ -147,13 +147,13 @@ impl Module {
new::wasmer_vm::VMDynamicFunctionContext< new::wasmer_vm::VMDynamicFunctionContext<
VMDynamicFunctionWithEnv<DynamicCtx>, VMDynamicFunctionWithEnv<DynamicCtx>,
>, >,
> = unsafe { Box::from_raw(function.vmctx as *mut _) }; > = unsafe { Box::from_raw(function.vmctx.host_env as *mut _) };
// Replace the environment by ours. // Replace the environment by ours.
vmctx.ctx.env.borrow_mut().vmctx = pre_instance.vmctx(); vmctx.ctx.env.borrow_mut().vmctx = pre_instance.vmctx();
// … without anyone noticing… // … without anyone noticing…
function.vmctx = Box::into_raw(vmctx) as _; function.vmctx.host_env = Box::into_raw(vmctx) as _;
} }
} }

View File

@@ -167,7 +167,7 @@ pub fn resolve_imports(
}; };
function_imports.push(VMFunctionImport { function_imports.push(VMFunctionImport {
body: address, body: address,
vmctx: f.vmctx, extra_data: f.vmctx,
}); });
} }
Export::Table(ref t) => { Export::Table(ref t) => {

View File

@@ -4,7 +4,7 @@
use crate::global::Global; use crate::global::Global;
use crate::memory::{Memory, MemoryStyle}; use crate::memory::{Memory, MemoryStyle};
use crate::table::{Table, TableStyle}; use crate::table::{Table, TableStyle};
use crate::vmcontext::{VMContext, VMFunctionBody, VMFunctionKind, VMTrampoline}; use crate::vmcontext::{VMFunctionBody, VMFunctionExtraData, VMFunctionKind, VMTrampoline};
use std::sync::Arc; use std::sync::Arc;
use wasmer_types::{FunctionType, MemoryType, TableType}; use wasmer_types::{FunctionType, MemoryType, TableType};
@@ -30,7 +30,7 @@ pub struct ExportFunction {
/// The address of the native-code function. /// The address of the native-code function.
pub address: *const VMFunctionBody, pub address: *const VMFunctionBody,
/// Pointer to the containing `VMContext`. /// Pointer to the containing `VMContext`.
pub vmctx: *mut VMContext, pub vmctx: VMFunctionExtraData,
/// The function type, used for compatibility checking. /// The function type, used for compatibility checking.
pub signature: FunctionType, pub signature: FunctionType,
/// The function kind (it defines how it's the signature that provided `address` have) /// The function kind (it defines how it's the signature that provided `address` have)

View File

@@ -11,9 +11,10 @@ use crate::memory::{Memory, MemoryError};
use crate::table::Table; use crate::table::Table;
use crate::trap::{catch_traps, init_traps, Trap, TrapCode}; use crate::trap::{catch_traps, init_traps, Trap, TrapCode};
use crate::vmcontext::{ use crate::vmcontext::{
VMBuiltinFunctionsArray, VMCallerCheckedAnyfunc, VMContext, VMFunctionBody, VMFunctionImport, VMBuiltinFunctionsArray, VMCallerCheckedAnyfunc, VMContext, VMFunctionBody,
VMFunctionKind, VMGlobalDefinition, VMGlobalImport, VMMemoryDefinition, VMMemoryImport, VMFunctionExtraData, VMFunctionImport, VMFunctionKind, VMGlobalDefinition, VMGlobalImport,
VMSharedSignatureIndex, VMTableDefinition, VMTableImport, VMTrampoline, VMMemoryDefinition, VMMemoryImport, VMSharedSignatureIndex, VMTableDefinition, VMTableImport,
VMTrampoline,
}; };
use crate::{ExportFunction, ExportGlobal, ExportMemory, ExportTable}; use crate::{ExportFunction, ExportGlobal, ExportMemory, ExportTable};
use crate::{FunctionBodyPtr, ModuleInfo, VMOffsets}; use crate::{FunctionBodyPtr, ModuleInfo, VMOffsets};
@@ -296,10 +297,15 @@ impl Instance {
let sig_index = &self.module.functions[*index]; let sig_index = &self.module.functions[*index];
let (address, vmctx) = if let Some(def_index) = self.module.local_func_index(*index) let (address, vmctx) = if let Some(def_index) = self.module.local_func_index(*index)
{ {
(self.functions[def_index].0 as *const _, self.vmctx_ptr()) (
self.functions[def_index].0 as *const _,
VMFunctionExtraData {
vmctx: self.vmctx_ptr(),
},
)
} else { } else {
let import = self.imported_function(*index); let import = self.imported_function(*index);
(import.body, import.vmctx) (import.body, import.extra_data)
}; };
let call_trampoline = Some(self.function_call_trampolines[*sig_index]); let call_trampoline = Some(self.function_call_trampolines[*sig_index]);
let signature = self.module.signatures[*sig_index].clone(); let signature = self.module.signatures[*sig_index].clone();
@@ -377,19 +383,24 @@ impl Instance {
.get(local_index) .get(local_index)
.expect("function index is out of bounds") .expect("function index is out of bounds")
.0; .0;
(body as *const _, self.vmctx_ptr()) (
body as *const _,
VMFunctionExtraData {
vmctx: self.vmctx_ptr(),
},
)
} }
None => { None => {
assert_lt!(start_index.index(), self.module.num_imported_functions); assert_lt!(start_index.index(), self.module.num_imported_functions);
let import = self.imported_function(start_index); let import = self.imported_function(start_index);
(import.body, import.vmctx) (import.body, import.extra_data)
} }
}; };
// Make the call. // Make the call.
unsafe { unsafe {
catch_traps(callee_vmctx, || { catch_traps(callee_vmctx, || {
mem::transmute::<*const VMFunctionBody, unsafe extern "C" fn(*mut VMContext)>( mem::transmute::<*const VMFunctionBody, unsafe extern "C" fn(VMFunctionExtraData)>(
callee_address, callee_address,
)(callee_vmctx) )(callee_vmctx)
}) })
@@ -561,10 +572,15 @@ impl Instance {
let type_index = self.signature_id(sig); let type_index = self.signature_id(sig);
let (func_ptr, vmctx) = if let Some(def_index) = self.module.local_func_index(index) { let (func_ptr, vmctx) = if let Some(def_index) = self.module.local_func_index(index) {
(self.functions[def_index].0 as *const _, self.vmctx_ptr()) (
self.functions[def_index].0 as *const _,
VMFunctionExtraData {
vmctx: self.vmctx_ptr(),
},
)
} else { } else {
let import = self.imported_function(index); let import = self.imported_function(index);
(import.body, import.vmctx) (import.body, import.extra_data)
}; };
VMCallerCheckedAnyfunc { VMCallerCheckedAnyfunc {
func_ptr, func_ptr,

View File

@@ -49,9 +49,9 @@ pub use crate::table::{LinearTable, Table, TableStyle};
pub use crate::trap::*; pub use crate::trap::*;
pub use crate::vmcontext::{ pub use crate::vmcontext::{
VMBuiltinFunctionIndex, VMCallerCheckedAnyfunc, VMContext, VMDynamicFunctionContext, VMBuiltinFunctionIndex, VMCallerCheckedAnyfunc, VMContext, VMDynamicFunctionContext,
VMFunctionBody, VMFunctionImport, VMFunctionKind, VMGlobalDefinition, VMGlobalImport, VMFunctionBody, VMFunctionExtraData, VMFunctionImport, VMFunctionKind, VMGlobalDefinition,
VMMemoryDefinition, VMMemoryImport, VMSharedSignatureIndex, VMTableDefinition, VMTableImport, VMGlobalImport, VMMemoryDefinition, VMMemoryImport, VMSharedSignatureIndex, VMTableDefinition,
VMTrampoline, VMTableImport, VMTrampoline,
}; };
pub use crate::vmoffsets::{TargetSharedSignatureIndex, VMOffsets}; pub use crate::vmoffsets::{TargetSharedSignatureIndex, VMOffsets};

View File

@@ -6,7 +6,7 @@
use super::trapcode::TrapCode; use super::trapcode::TrapCode;
use crate::instance::{InstanceHandle, SignalHandler}; use crate::instance::{InstanceHandle, SignalHandler};
use crate::vmcontext::{VMContext, VMFunctionBody, VMTrampoline}; use crate::vmcontext::{VMFunctionBody, VMFunctionExtraData, VMTrampoline};
use backtrace::Backtrace; use backtrace::Backtrace;
use std::any::Any; use std::any::Any;
use std::cell::Cell; use std::cell::Cell;
@@ -429,13 +429,13 @@ impl Trap {
/// Wildly unsafe because it calls raw function pointers and reads/writes raw /// Wildly unsafe because it calls raw function pointers and reads/writes raw
/// function pointers. /// function pointers.
pub unsafe fn wasmer_call_trampoline( pub unsafe fn wasmer_call_trampoline(
vmctx: *mut VMContext, vmctx: VMFunctionExtraData,
trampoline: VMTrampoline, trampoline: VMTrampoline,
callee: *const VMFunctionBody, callee: *const VMFunctionBody,
values_vec: *mut u8, values_vec: *mut u8,
) -> Result<(), Trap> { ) -> Result<(), Trap> {
catch_traps(vmctx, || { catch_traps(vmctx, || {
mem::transmute::<_, extern "C" fn(*mut VMContext, *const VMFunctionBody, *mut u8)>( mem::transmute::<_, extern "C" fn(VMFunctionExtraData, *const VMFunctionBody, *mut u8)>(
trampoline, trampoline,
)(vmctx, callee, values_vec) )(vmctx, callee, values_vec)
}) })
@@ -447,7 +447,7 @@ pub unsafe fn wasmer_call_trampoline(
/// # Safety /// # Safety
/// ///
/// Highly unsafe since `closure` won't have any destructors run. /// Highly unsafe since `closure` won't have any destructors run.
pub unsafe fn catch_traps<F>(vmctx: *mut VMContext, mut closure: F) -> Result<(), Trap> pub unsafe fn catch_traps<F>(vmctx: VMFunctionExtraData, mut closure: F) -> Result<(), Trap>
where where
F: FnMut(), F: FnMut(),
{ {
@@ -481,7 +481,7 @@ where
/// ///
/// Check [`catch_traps`]. /// Check [`catch_traps`].
pub unsafe fn catch_traps_with_result<F, R>( pub unsafe fn catch_traps_with_result<F, R>(
vmctx: *mut VMContext, vmctx: VMFunctionExtraData,
mut closure: F, mut closure: F,
) -> Result<R, Trap> ) -> Result<R, Trap>
where where
@@ -501,7 +501,7 @@ pub struct CallThreadState {
jmp_buf: Cell<*const u8>, jmp_buf: Cell<*const u8>,
reset_guard_page: Cell<bool>, reset_guard_page: Cell<bool>,
prev: Option<*const CallThreadState>, prev: Option<*const CallThreadState>,
vmctx: *mut VMContext, vmctx: VMFunctionExtraData,
handling_trap: Cell<bool>, handling_trap: Cell<bool>,
} }
@@ -518,7 +518,7 @@ enum UnwindReason {
} }
impl CallThreadState { impl CallThreadState {
fn new(vmctx: *mut VMContext) -> Self { fn new(vmctx: VMFunctionExtraData) -> Self {
Self { Self {
unwind: Cell::new(UnwindReason::None), unwind: Cell::new(UnwindReason::None),
vmctx, vmctx,
@@ -561,7 +561,7 @@ impl CallThreadState {
fn any_instance(&self, func: impl Fn(&InstanceHandle) -> bool) -> bool { fn any_instance(&self, func: impl Fn(&InstanceHandle) -> bool) -> bool {
unsafe { unsafe {
if func(&InstanceHandle::from_vmctx(self.vmctx)) { if func(&InstanceHandle::from_vmctx(self.vmctx.vmctx)) {
return true; return true;
} }
match self.prev { match self.prev {

View File

@@ -15,6 +15,28 @@ use std::ptr::{self, NonNull};
use std::sync::Arc; use std::sync::Arc;
use std::u32; use std::u32;
/// We stop lying about what this daat is
/// TODO:
#[derive(Copy, Clone)]
pub union VMFunctionExtraData {
/// Wasm function, it has a real VMContext:
pub vmctx: *mut VMContext,
/// Host functions can have custom environments
pub host_env: *mut std::ffi::c_void,
}
impl std::fmt::Debug for VMFunctionExtraData {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "FunctionExtarData union")
}
}
impl std::cmp::PartialEq for VMFunctionExtraData {
fn eq(&self, rhs: &Self) -> bool {
unsafe { self.host_env as usize == rhs.host_env as usize }
}
}
/// An imported function. /// An imported function.
#[derive(Debug, Copy, Clone)] #[derive(Debug, Copy, Clone)]
#[repr(C)] #[repr(C)]
@@ -22,8 +44,8 @@ pub struct VMFunctionImport {
/// A pointer to the imported function body. /// A pointer to the imported function body.
pub body: *const VMFunctionBody, pub body: *const VMFunctionBody,
/// A pointer to the `VMContext` that owns the function. /// A pointer to the `VMContext` that owns the function or host env data.
pub vmctx: *mut VMContext, pub extra_data: VMFunctionExtraData,
} }
#[cfg(test)] #[cfg(test)]
@@ -46,7 +68,7 @@ mod test_vmfunction_import {
usize::from(offsets.vmfunction_import_body()) usize::from(offsets.vmfunction_import_body())
); );
assert_eq!( assert_eq!(
offset_of!(VMFunctionImport, vmctx), offset_of!(VMFunctionImport, extra_data),
usize::from(offsets.vmfunction_import_vmctx()) usize::from(offsets.vmfunction_import_vmctx())
); );
} }
@@ -728,8 +750,8 @@ pub struct VMCallerCheckedAnyfunc {
pub func_ptr: *const VMFunctionBody, pub func_ptr: *const VMFunctionBody,
/// Function signature id. /// Function signature id.
pub type_index: VMSharedSignatureIndex, pub type_index: VMSharedSignatureIndex,
/// Function `VMContext`. /// Function `VMContext` or host env.
pub vmctx: *mut VMContext, pub vmctx: VMFunctionExtraData,
// If more elements are added here, remember to add offset_of tests below! // If more elements are added here, remember to add offset_of tests below!
} }
@@ -768,7 +790,9 @@ impl Default for VMCallerCheckedAnyfunc {
Self { Self {
func_ptr: ptr::null_mut(), func_ptr: ptr::null_mut(),
type_index: Default::default(), type_index: Default::default(),
vmctx: ptr::null_mut(), vmctx: VMFunctionExtraData {
vmctx: ptr::null_mut(),
},
} }
} }
} }