From b124ea2032d72308d0900b7f589784607e08e770 Mon Sep 17 00:00:00 2001 From: Mark McCaskey Date: Thu, 22 Oct 2020 11:46:42 -0700 Subject: [PATCH] Use a union for VMContext for functions This makes the code more self documenting and correctly uses `unsafe` to communicate that it's the user's responsibility to ensure that the code paths that can lead to the access can only write the value they expect to be there. --- lib/api/src/externals/function.rs | 23 ++++++++++----- lib/api/src/native.rs | 20 ++++++++----- lib/api/src/types.rs | 4 ++- lib/deprecated/runtime-core/src/module.rs | 8 ++--- lib/engine/src/resolver.rs | 2 +- lib/vm/src/export.rs | 4 +-- lib/vm/src/instance.rs | 36 ++++++++++++++++------- lib/vm/src/lib.rs | 6 ++-- lib/vm/src/trap/traphandlers.rs | 16 +++++----- lib/vm/src/vmcontext.rs | 36 +++++++++++++++++++---- 10 files changed, 105 insertions(+), 50 deletions(-) diff --git a/lib/api/src/externals/function.rs b/lib/api/src/externals/function.rs index 32af4b6cc..e8e3ec3d1 100644 --- a/lib/api/src/externals/function.rs +++ b/lib/api/src/externals/function.rs @@ -11,8 +11,8 @@ use std::cmp::max; use std::fmt; use wasmer_vm::{ raise_user_trap, resume_panic, wasmer_call_trampoline, Export, ExportFunction, - VMCallerCheckedAnyfunc, VMContext, VMDynamicFunctionContext, VMFunctionBody, VMFunctionKind, - VMTrampoline, + VMCallerCheckedAnyfunc, VMDynamicFunctionContext, VMFunctionBody, VMFunctionExtraData, + VMFunctionKind, VMTrampoline, }; /// A function defined in the Wasm module @@ -85,7 +85,9 @@ impl Function { // The engine linker will replace the address with one pointing to a // generated dynamic trampoline. let address = std::ptr::null() as *const VMFunctionBody; - let vmctx = Box::into_raw(Box::new(dynamic_ctx)) as *mut VMContext; + let vmctx = VMFunctionExtraData { + host_env: Box::into_raw(Box::new(dynamic_ctx)) as *mut _, + }; Self { store: store.clone(), @@ -135,7 +137,9 @@ impl Function { // The engine linker will replace the address with one pointing to a // generated dynamic trampoline. let address = std::ptr::null() as *const VMFunctionBody; - let vmctx = Box::into_raw(Box::new(dynamic_ctx)) as *mut VMContext; + let vmctx = VMFunctionExtraData { + host_env: Box::into_raw(Box::new(dynamic_ctx)) as *mut _, + }; Self { store: store.clone(), @@ -176,7 +180,9 @@ impl Function { { let function = inner::Function::::new(func); let address = function.address() as *const VMFunctionBody; - let vmctx = std::ptr::null_mut() as *mut _ as *mut VMContext; + let vmctx = VMFunctionExtraData { + host_env: std::ptr::null_mut() as *mut _, + }; let signature = function.ty(); Self { @@ -230,7 +236,9 @@ impl Function { // In the case of Host-defined functions `VMContext` is whatever environment // the user want to attach to the function. let box_env = Box::new(env); - let vmctx = Box::into_raw(box_env) as *mut _ as *mut VMContext; + let vmctx = VMFunctionExtraData { + host_env: Box::into_raw(box_env) as *mut _, + }; let signature = function.ty(); Self { @@ -365,7 +373,8 @@ impl Function { Self { store: store.clone(), definition: FunctionDefinition::Host(HostFunctionDefinition { - has_env: !wasmer_export.vmctx.is_null(), + // TOOD: make safe function on this union to check for null + has_env: !unsafe { wasmer_export.vmctx.host_env.is_null() }, }), exported: wasmer_export, } diff --git a/lib/api/src/native.rs b/lib/api/src/native.rs index 733bfd67e..e3dc1f517 100644 --- a/lib/api/src/native.rs +++ b/lib/api/src/native.rs @@ -17,7 +17,7 @@ use crate::{FromToNativeWasmType, Function, FunctionType, RuntimeError, Store, W use std::panic::{catch_unwind, AssertUnwindSafe}; use wasmer_types::NativeWasmType; use wasmer_vm::{ - ExportFunction, VMContext, VMDynamicFunctionContext, VMFunctionBody, VMFunctionKind, + ExportFunction, VMDynamicFunctionContext, VMFunctionBody, VMFunctionExtraData, VMFunctionKind, }; /// A WebAssembly function that can be called natively @@ -26,7 +26,7 @@ pub struct NativeFunc<'a, Args = (), Rets = ()> { definition: FunctionDefinition, store: Store, address: *const VMFunctionBody, - vmctx: *mut VMContext, + vmctx: VMFunctionExtraData, arg_kind: VMFunctionKind, // exported: ExportFunction, _phantom: PhantomData<(&'a (), Args, Rets)>, @@ -42,7 +42,7 @@ where pub(crate) fn new( store: Store, address: *const VMFunctionBody, - vmctx: *mut VMContext, + vmctx: VMFunctionExtraData, arg_kind: VMFunctionKind, definition: FunctionDefinition, ) -> Self { @@ -165,7 +165,7 @@ macro_rules! impl_native_traits { match self.arg_kind { VMFunctionKind::Static => { let results = catch_unwind(AssertUnwindSafe(|| unsafe { - let f = std::mem::transmute::<_, unsafe extern "C" fn( *mut VMContext, $( $x, )*) -> Rets::CStruct>(self.address); + let f = std::mem::transmute::<_, unsafe extern "C" fn( VMFunctionExtraData, $( $x, )*) -> Rets::CStruct>(self.address); // We always pass the vmctx f( self.vmctx, $( $x, )* ) })).map_err(|e| RuntimeError::new(format!("{:?}", e)))?; @@ -175,12 +175,16 @@ macro_rules! impl_native_traits { let params_list = [ $( $x.to_native().to_value() ),* ]; let results = if !has_env { type VMContextWithoutEnv = VMDynamicFunctionContext; - let ctx = self.vmctx as *mut VMContextWithoutEnv; - unsafe { (*ctx).ctx.call(¶ms_list)? } + unsafe { + let ctx = self.vmctx.host_env as *mut VMContextWithoutEnv; + (*ctx).ctx.call(¶ms_list)? + } } else { type VMContextWithEnv = VMDynamicFunctionContext>; - let ctx = self.vmctx as *mut VMContextWithEnv; - unsafe { (*ctx).ctx.call(¶ms_list)? } + unsafe { + let ctx = self.vmctx.host_env as *mut VMContextWithEnv; + (*ctx).ctx.call(¶ms_list)? + } }; let mut rets_list_array = Rets::empty_array(); let mut_rets = rets_list_array.as_mut() as *mut [i128] as *mut i128; diff --git a/lib/api/src/types.rs b/lib/api/src/types.rs index 5eac35f3b..ba61eed07 100644 --- a/lib/api/src/types.rs +++ b/lib/api/src/types.rs @@ -56,7 +56,9 @@ impl ValFuncRef for Val { Self::ExternRef(ExternRef::Null) => wasmer_vm::VMCallerCheckedAnyfunc { func_ptr: ptr::null(), type_index: wasmer_vm::VMSharedSignatureIndex::default(), - vmctx: ptr::null_mut(), + vmctx: wasmer_vm::VMFunctionExtraData { + host_env: ptr::null_mut(), + }, }, Self::FuncRef(f) => f.checked_anyfunc(), _ => return Err(RuntimeError::new("val is not funcref")), diff --git a/lib/deprecated/runtime-core/src/module.rs b/lib/deprecated/runtime-core/src/module.rs index ac9b606d1..7ef483dcb 100644 --- a/lib/deprecated/runtime-core/src/module.rs +++ b/lib/deprecated/runtime-core/src/module.rs @@ -102,14 +102,14 @@ impl Module { // Properly drop the empty `vm::Ctx` // created by the host function. unsafe { - ptr::drop_in_place::(function.vmctx as _); + ptr::drop_in_place::(function.vmctx.host_env as _); } // Update the pointer to `VMContext`, // which is actually a `vm::Ctx` // pointer, to fallback on the // environment hack. - function.vmctx = pre_instance.vmctx_ptr() as _; + function.vmctx.host_env = pre_instance.vmctx_ptr() as _; } // `function` is a dynamic host function // constructed with @@ -147,13 +147,13 @@ impl Module { new::wasmer_vm::VMDynamicFunctionContext< VMDynamicFunctionWithEnv, >, - > = unsafe { Box::from_raw(function.vmctx as *mut _) }; + > = unsafe { Box::from_raw(function.vmctx.host_env as *mut _) }; // Replace the environment by ours. vmctx.ctx.env.borrow_mut().vmctx = pre_instance.vmctx(); // … without anyone noticing… - function.vmctx = Box::into_raw(vmctx) as _; + function.vmctx.host_env = Box::into_raw(vmctx) as _; } } diff --git a/lib/engine/src/resolver.rs b/lib/engine/src/resolver.rs index 82b5df625..76f212381 100644 --- a/lib/engine/src/resolver.rs +++ b/lib/engine/src/resolver.rs @@ -167,7 +167,7 @@ pub fn resolve_imports( }; function_imports.push(VMFunctionImport { body: address, - vmctx: f.vmctx, + extra_data: f.vmctx, }); } Export::Table(ref t) => { diff --git a/lib/vm/src/export.rs b/lib/vm/src/export.rs index 32eb3ca64..ecdaf1f04 100644 --- a/lib/vm/src/export.rs +++ b/lib/vm/src/export.rs @@ -4,7 +4,7 @@ use crate::global::Global; use crate::memory::{Memory, MemoryStyle}; use crate::table::{Table, TableStyle}; -use crate::vmcontext::{VMContext, VMFunctionBody, VMFunctionKind, VMTrampoline}; +use crate::vmcontext::{VMFunctionBody, VMFunctionExtraData, VMFunctionKind, VMTrampoline}; use std::sync::Arc; use wasmer_types::{FunctionType, MemoryType, TableType}; @@ -30,7 +30,7 @@ pub struct ExportFunction { /// The address of the native-code function. pub address: *const VMFunctionBody, /// Pointer to the containing `VMContext`. - pub vmctx: *mut VMContext, + pub vmctx: VMFunctionExtraData, /// The function type, used for compatibility checking. pub signature: FunctionType, /// The function kind (it defines how it's the signature that provided `address` have) diff --git a/lib/vm/src/instance.rs b/lib/vm/src/instance.rs index 9777d51e4..e5267382a 100644 --- a/lib/vm/src/instance.rs +++ b/lib/vm/src/instance.rs @@ -11,9 +11,10 @@ use crate::memory::{Memory, MemoryError}; use crate::table::Table; use crate::trap::{catch_traps, init_traps, Trap, TrapCode}; use crate::vmcontext::{ - VMBuiltinFunctionsArray, VMCallerCheckedAnyfunc, VMContext, VMFunctionBody, VMFunctionImport, - VMFunctionKind, VMGlobalDefinition, VMGlobalImport, VMMemoryDefinition, VMMemoryImport, - VMSharedSignatureIndex, VMTableDefinition, VMTableImport, VMTrampoline, + VMBuiltinFunctionsArray, VMCallerCheckedAnyfunc, VMContext, VMFunctionBody, + VMFunctionExtraData, VMFunctionImport, VMFunctionKind, VMGlobalDefinition, VMGlobalImport, + VMMemoryDefinition, VMMemoryImport, VMSharedSignatureIndex, VMTableDefinition, VMTableImport, + VMTrampoline, }; use crate::{ExportFunction, ExportGlobal, ExportMemory, ExportTable}; use crate::{FunctionBodyPtr, ModuleInfo, VMOffsets}; @@ -296,10 +297,15 @@ impl Instance { let sig_index = &self.module.functions[*index]; let (address, vmctx) = if let Some(def_index) = self.module.local_func_index(*index) { - (self.functions[def_index].0 as *const _, self.vmctx_ptr()) + ( + self.functions[def_index].0 as *const _, + VMFunctionExtraData { + vmctx: self.vmctx_ptr(), + }, + ) } else { let import = self.imported_function(*index); - (import.body, import.vmctx) + (import.body, import.extra_data) }; let call_trampoline = Some(self.function_call_trampolines[*sig_index]); let signature = self.module.signatures[*sig_index].clone(); @@ -377,19 +383,24 @@ impl Instance { .get(local_index) .expect("function index is out of bounds") .0; - (body as *const _, self.vmctx_ptr()) + ( + body as *const _, + VMFunctionExtraData { + vmctx: self.vmctx_ptr(), + }, + ) } None => { assert_lt!(start_index.index(), self.module.num_imported_functions); let import = self.imported_function(start_index); - (import.body, import.vmctx) + (import.body, import.extra_data) } }; // Make the call. unsafe { catch_traps(callee_vmctx, || { - mem::transmute::<*const VMFunctionBody, unsafe extern "C" fn(*mut VMContext)>( + mem::transmute::<*const VMFunctionBody, unsafe extern "C" fn(VMFunctionExtraData)>( callee_address, )(callee_vmctx) }) @@ -561,10 +572,15 @@ impl Instance { let type_index = self.signature_id(sig); let (func_ptr, vmctx) = if let Some(def_index) = self.module.local_func_index(index) { - (self.functions[def_index].0 as *const _, self.vmctx_ptr()) + ( + self.functions[def_index].0 as *const _, + VMFunctionExtraData { + vmctx: self.vmctx_ptr(), + }, + ) } else { let import = self.imported_function(index); - (import.body, import.vmctx) + (import.body, import.extra_data) }; VMCallerCheckedAnyfunc { func_ptr, diff --git a/lib/vm/src/lib.rs b/lib/vm/src/lib.rs index f8e649f59..eb427ebff 100644 --- a/lib/vm/src/lib.rs +++ b/lib/vm/src/lib.rs @@ -49,9 +49,9 @@ pub use crate::table::{LinearTable, Table, TableStyle}; pub use crate::trap::*; pub use crate::vmcontext::{ VMBuiltinFunctionIndex, VMCallerCheckedAnyfunc, VMContext, VMDynamicFunctionContext, - VMFunctionBody, VMFunctionImport, VMFunctionKind, VMGlobalDefinition, VMGlobalImport, - VMMemoryDefinition, VMMemoryImport, VMSharedSignatureIndex, VMTableDefinition, VMTableImport, - VMTrampoline, + VMFunctionBody, VMFunctionExtraData, VMFunctionImport, VMFunctionKind, VMGlobalDefinition, + VMGlobalImport, VMMemoryDefinition, VMMemoryImport, VMSharedSignatureIndex, VMTableDefinition, + VMTableImport, VMTrampoline, }; pub use crate::vmoffsets::{TargetSharedSignatureIndex, VMOffsets}; diff --git a/lib/vm/src/trap/traphandlers.rs b/lib/vm/src/trap/traphandlers.rs index 613e8ec10..9d4237e36 100644 --- a/lib/vm/src/trap/traphandlers.rs +++ b/lib/vm/src/trap/traphandlers.rs @@ -6,7 +6,7 @@ use super::trapcode::TrapCode; use crate::instance::{InstanceHandle, SignalHandler}; -use crate::vmcontext::{VMContext, VMFunctionBody, VMTrampoline}; +use crate::vmcontext::{VMFunctionBody, VMFunctionExtraData, VMTrampoline}; use backtrace::Backtrace; use std::any::Any; use std::cell::Cell; @@ -429,13 +429,13 @@ impl Trap { /// Wildly unsafe because it calls raw function pointers and reads/writes raw /// function pointers. pub unsafe fn wasmer_call_trampoline( - vmctx: *mut VMContext, + vmctx: VMFunctionExtraData, trampoline: VMTrampoline, callee: *const VMFunctionBody, values_vec: *mut u8, ) -> Result<(), Trap> { catch_traps(vmctx, || { - mem::transmute::<_, extern "C" fn(*mut VMContext, *const VMFunctionBody, *mut u8)>( + mem::transmute::<_, extern "C" fn(VMFunctionExtraData, *const VMFunctionBody, *mut u8)>( trampoline, )(vmctx, callee, values_vec) }) @@ -447,7 +447,7 @@ pub unsafe fn wasmer_call_trampoline( /// # Safety /// /// Highly unsafe since `closure` won't have any destructors run. -pub unsafe fn catch_traps(vmctx: *mut VMContext, mut closure: F) -> Result<(), Trap> +pub unsafe fn catch_traps(vmctx: VMFunctionExtraData, mut closure: F) -> Result<(), Trap> where F: FnMut(), { @@ -481,7 +481,7 @@ where /// /// Check [`catch_traps`]. pub unsafe fn catch_traps_with_result( - vmctx: *mut VMContext, + vmctx: VMFunctionExtraData, mut closure: F, ) -> Result where @@ -501,7 +501,7 @@ pub struct CallThreadState { jmp_buf: Cell<*const u8>, reset_guard_page: Cell, prev: Option<*const CallThreadState>, - vmctx: *mut VMContext, + vmctx: VMFunctionExtraData, handling_trap: Cell, } @@ -518,7 +518,7 @@ enum UnwindReason { } impl CallThreadState { - fn new(vmctx: *mut VMContext) -> Self { + fn new(vmctx: VMFunctionExtraData) -> Self { Self { unwind: Cell::new(UnwindReason::None), vmctx, @@ -561,7 +561,7 @@ impl CallThreadState { fn any_instance(&self, func: impl Fn(&InstanceHandle) -> bool) -> bool { unsafe { - if func(&InstanceHandle::from_vmctx(self.vmctx)) { + if func(&InstanceHandle::from_vmctx(self.vmctx.vmctx)) { return true; } match self.prev { diff --git a/lib/vm/src/vmcontext.rs b/lib/vm/src/vmcontext.rs index 367a20762..719987b05 100644 --- a/lib/vm/src/vmcontext.rs +++ b/lib/vm/src/vmcontext.rs @@ -15,6 +15,28 @@ use std::ptr::{self, NonNull}; use std::sync::Arc; use std::u32; +/// We stop lying about what this daat is +/// TODO: +#[derive(Copy, Clone)] +pub union VMFunctionExtraData { + /// Wasm function, it has a real VMContext: + pub vmctx: *mut VMContext, + /// Host functions can have custom environments + pub host_env: *mut std::ffi::c_void, +} + +impl std::fmt::Debug for VMFunctionExtraData { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "FunctionExtarData union") + } +} + +impl std::cmp::PartialEq for VMFunctionExtraData { + fn eq(&self, rhs: &Self) -> bool { + unsafe { self.host_env as usize == rhs.host_env as usize } + } +} + /// An imported function. #[derive(Debug, Copy, Clone)] #[repr(C)] @@ -22,8 +44,8 @@ pub struct VMFunctionImport { /// A pointer to the imported function body. pub body: *const VMFunctionBody, - /// A pointer to the `VMContext` that owns the function. - pub vmctx: *mut VMContext, + /// A pointer to the `VMContext` that owns the function or host env data. + pub extra_data: VMFunctionExtraData, } #[cfg(test)] @@ -46,7 +68,7 @@ mod test_vmfunction_import { usize::from(offsets.vmfunction_import_body()) ); assert_eq!( - offset_of!(VMFunctionImport, vmctx), + offset_of!(VMFunctionImport, extra_data), usize::from(offsets.vmfunction_import_vmctx()) ); } @@ -728,8 +750,8 @@ pub struct VMCallerCheckedAnyfunc { pub func_ptr: *const VMFunctionBody, /// Function signature id. pub type_index: VMSharedSignatureIndex, - /// Function `VMContext`. - pub vmctx: *mut VMContext, + /// Function `VMContext` or host env. + pub vmctx: VMFunctionExtraData, // If more elements are added here, remember to add offset_of tests below! } @@ -768,7 +790,9 @@ impl Default for VMCallerCheckedAnyfunc { Self { func_ptr: ptr::null_mut(), type_index: Default::default(), - vmctx: ptr::null_mut(), + vmctx: VMFunctionExtraData { + vmctx: ptr::null_mut(), + }, } } }