mirror of
https://github.com/mii443/wasmer.git
synced 2025-12-07 13:18:20 +00:00
Fixed for a futex race conditon and infinite polling
This commit is contained in:
@@ -177,7 +177,7 @@ pub trait VirtualNetworking: fmt::Debug + Send + Sync + 'static {
|
||||
|
||||
pub type DynVirtualNetworking = Arc<dyn VirtualNetworking>;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SocketReceive {
|
||||
/// Data that was received
|
||||
pub data: Bytes,
|
||||
@@ -185,7 +185,7 @@ pub struct SocketReceive {
|
||||
pub truncated: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SocketReceiveFrom {
|
||||
/// Data that was received
|
||||
pub data: Bytes,
|
||||
|
||||
@@ -255,7 +255,10 @@ pub struct LocalTcpStream {
|
||||
connect_timeout: Option<Duration>,
|
||||
linger_timeout: Option<Duration>,
|
||||
nonblocking: bool,
|
||||
sent_eof: bool,
|
||||
shutdown: Option<Shutdown>,
|
||||
tx_recv: mpsc::UnboundedSender<Result<SocketReceive>>,
|
||||
rx_recv: mpsc::UnboundedReceiver<Result<SocketReceive>>,
|
||||
tx_write_ready: mpsc::Sender<()>,
|
||||
rx_write_ready: mpsc::Receiver<()>,
|
||||
tx_write_poll_ready: mpsc::Sender<()>,
|
||||
@@ -264,6 +267,7 @@ pub struct LocalTcpStream {
|
||||
|
||||
impl LocalTcpStream {
|
||||
pub fn new(stream: tokio::net::TcpStream, addr: SocketAddr, nonblocking: bool) -> Self {
|
||||
let (tx_recv, rx_recv) = mpsc::unbounded_channel();
|
||||
let (tx_write_ready, rx_write_ready) = mpsc::channel(1);
|
||||
let (tx_write_poll_ready, rx_write_poll_ready) = mpsc::channel(1);
|
||||
Self {
|
||||
@@ -275,10 +279,13 @@ impl LocalTcpStream {
|
||||
linger_timeout: None,
|
||||
nonblocking,
|
||||
shutdown: None,
|
||||
sent_eof: false,
|
||||
tx_write_ready,
|
||||
rx_write_ready,
|
||||
tx_write_poll_ready,
|
||||
rx_write_poll_ready,
|
||||
tx_recv,
|
||||
rx_recv,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -355,6 +362,80 @@ impl VirtualTcpSocket for LocalTcpStream {
|
||||
}
|
||||
}
|
||||
|
||||
impl LocalTcpStream {
|
||||
async fn recv_now_ext(
|
||||
nonblocking: bool,
|
||||
stream: &mut tokio::net::TcpStream,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<SocketReceive> {
|
||||
if nonblocking {
|
||||
let max_buf_size = 8192;
|
||||
let mut buf = Vec::with_capacity(max_buf_size);
|
||||
unsafe {
|
||||
buf.set_len(max_buf_size);
|
||||
}
|
||||
|
||||
let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE)) };
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let stream = Pin::new(stream);
|
||||
let mut read_buf = tokio::io::ReadBuf::new(&mut buf);
|
||||
return match stream.poll_read(&mut cx, &mut read_buf) {
|
||||
Poll::Ready(Ok(read)) => {
|
||||
let read = read_buf.remaining();
|
||||
unsafe {
|
||||
buf.set_len(read);
|
||||
}
|
||||
if read == 0 {
|
||||
return Err(NetworkError::WouldBlock);
|
||||
}
|
||||
let buf = Bytes::from(buf);
|
||||
Ok(SocketReceive {
|
||||
data: buf,
|
||||
truncated: read == max_buf_size,
|
||||
})
|
||||
}
|
||||
Poll::Ready(Err(err)) => Err(io_err_into_net_error(err)),
|
||||
Poll::Pending => Err(NetworkError::WouldBlock),
|
||||
};
|
||||
} else {
|
||||
Self::recv_now(stream, timeout).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn recv_now(
|
||||
stream: &mut tokio::net::TcpStream,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<SocketReceive> {
|
||||
use tokio::io::AsyncReadExt;
|
||||
let max_buf_size = 8192;
|
||||
let mut buf = Vec::with_capacity(max_buf_size);
|
||||
unsafe {
|
||||
buf.set_len(max_buf_size);
|
||||
}
|
||||
|
||||
let work = async move {
|
||||
match timeout {
|
||||
Some(timeout) => tokio::time::timeout(timeout, stream.read(&mut buf[..]))
|
||||
.await
|
||||
.map_err(|_| Into::<std::io::Error>::into(std::io::ErrorKind::TimedOut))?,
|
||||
None => stream.read(&mut buf[..]).await,
|
||||
}
|
||||
.map(|read| {
|
||||
unsafe {
|
||||
buf.set_len(read);
|
||||
}
|
||||
Bytes::from(buf)
|
||||
})
|
||||
};
|
||||
|
||||
let buf = work.await.map_err(io_err_into_net_error)?;
|
||||
Ok(SocketReceive {
|
||||
truncated: buf.len() == max_buf_size,
|
||||
data: buf,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl VirtualConnectedSocket for LocalTcpStream {
|
||||
fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
|
||||
@@ -432,155 +513,40 @@ impl VirtualConnectedSocket for LocalTcpStream {
|
||||
}
|
||||
|
||||
async fn recv(&mut self) -> Result<SocketReceive> {
|
||||
use tokio::io::AsyncReadExt;
|
||||
let max_buf_size = 8192;
|
||||
let mut buf = Vec::with_capacity(max_buf_size);
|
||||
unsafe {
|
||||
buf.set_len(max_buf_size);
|
||||
if let Ok(ret) = self.rx_recv.try_recv() {
|
||||
return ret;
|
||||
}
|
||||
|
||||
let nonblocking = self.nonblocking;
|
||||
if nonblocking {
|
||||
let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE)) };
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let stream = Pin::new(&mut self.stream);
|
||||
let mut read_buf = tokio::io::ReadBuf::new(&mut buf);
|
||||
return match stream.poll_read(&mut cx, &mut read_buf) {
|
||||
Poll::Ready(Ok(read)) => {
|
||||
let read = read_buf.remaining();
|
||||
unsafe {
|
||||
buf.set_len(read);
|
||||
tokio::select! {
|
||||
ret = Self::recv_now_ext(
|
||||
self.nonblocking,
|
||||
&mut self.stream,
|
||||
self.read_timeout.clone(),
|
||||
) => ret,
|
||||
ret = self.rx_recv.recv() => ret.unwrap_or(Err(NetworkError::ConnectionAborted))
|
||||
}
|
||||
if read == 0 {
|
||||
return Err(NetworkError::WouldBlock);
|
||||
}
|
||||
let buf = Bytes::from(buf);
|
||||
Ok(SocketReceive {
|
||||
data: buf,
|
||||
truncated: read == max_buf_size,
|
||||
})
|
||||
}
|
||||
Poll::Ready(Err(err)) => Err(io_err_into_net_error(err)),
|
||||
Poll::Pending => Err(NetworkError::WouldBlock),
|
||||
};
|
||||
}
|
||||
|
||||
let timeout = self.write_timeout.clone();
|
||||
let work = async move {
|
||||
match timeout {
|
||||
Some(timeout) => tokio::time::timeout(timeout, self.stream.read(&mut buf[..]))
|
||||
.await
|
||||
.map_err(|_| Into::<std::io::Error>::into(std::io::ErrorKind::WouldBlock))?,
|
||||
None => self.stream.read(&mut buf[..]).await,
|
||||
}
|
||||
.map(|read| {
|
||||
unsafe {
|
||||
buf.set_len(read);
|
||||
}
|
||||
Bytes::from(buf)
|
||||
})
|
||||
};
|
||||
|
||||
let buf = work.await.map_err(io_err_into_net_error)?;
|
||||
if buf.is_empty() {
|
||||
if nonblocking {
|
||||
return Err(NetworkError::WouldBlock);
|
||||
} else {
|
||||
return Err(NetworkError::BrokenPipe);
|
||||
}
|
||||
}
|
||||
Ok(SocketReceive {
|
||||
truncated: buf.len() == max_buf_size,
|
||||
data: buf,
|
||||
})
|
||||
}
|
||||
|
||||
fn try_recv(&mut self) -> Result<Option<SocketReceive>> {
|
||||
let max_buf_size = 8192;
|
||||
let mut buf = Vec::with_capacity(max_buf_size);
|
||||
unsafe {
|
||||
buf.set_len(max_buf_size);
|
||||
}
|
||||
|
||||
let mut work = self.recv();
|
||||
let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE)) };
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let stream = Pin::new(&mut self.stream);
|
||||
let mut read_buf = tokio::io::ReadBuf::new(&mut buf);
|
||||
match stream.poll_read(&mut cx, &mut read_buf) {
|
||||
Poll::Ready(Ok(read)) => {
|
||||
let read = read_buf.remaining();
|
||||
unsafe {
|
||||
buf.set_len(read);
|
||||
}
|
||||
if read == 0 {
|
||||
return Err(NetworkError::WouldBlock);
|
||||
}
|
||||
let buf = Bytes::from(buf);
|
||||
Ok(Some(SocketReceive {
|
||||
data: buf,
|
||||
truncated: read == max_buf_size,
|
||||
}))
|
||||
}
|
||||
Poll::Ready(Err(err)) => Err(io_err_into_net_error(err)),
|
||||
match work.as_mut().poll(&mut cx) {
|
||||
Poll::Ready(Ok(ret)) => Ok(Some(ret)),
|
||||
Poll::Ready(Err(err)) => Err(err),
|
||||
Poll::Pending => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
async fn peek(&mut self) -> Result<SocketReceive> {
|
||||
let max_buf_size = 8192;
|
||||
let mut buf = Vec::with_capacity(max_buf_size);
|
||||
unsafe {
|
||||
buf.set_len(max_buf_size);
|
||||
}
|
||||
|
||||
if self.nonblocking {
|
||||
let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE)) };
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let stream = Pin::new(&mut self.stream);
|
||||
let mut read_buf = tokio::io::ReadBuf::new(&mut buf);
|
||||
return match stream.poll_peek(&mut cx, &mut read_buf) {
|
||||
Poll::Ready(Ok(read)) => {
|
||||
unsafe {
|
||||
buf.set_len(read);
|
||||
}
|
||||
if read == 0 {
|
||||
return Err(NetworkError::WouldBlock);
|
||||
}
|
||||
let buf = Bytes::from(buf);
|
||||
Ok(SocketReceive {
|
||||
data: buf,
|
||||
truncated: read == max_buf_size,
|
||||
})
|
||||
}
|
||||
Poll::Ready(Err(err)) => Err(io_err_into_net_error(err)),
|
||||
Poll::Pending => Err(NetworkError::WouldBlock),
|
||||
};
|
||||
}
|
||||
|
||||
let timeout = self.write_timeout.clone();
|
||||
let work = async move {
|
||||
match timeout {
|
||||
Some(timeout) => tokio::time::timeout(timeout, self.stream.peek(&mut buf[..]))
|
||||
.await
|
||||
.map_err(|_| Into::<std::io::Error>::into(std::io::ErrorKind::WouldBlock))?,
|
||||
None => self.stream.peek(&mut buf[..]).await,
|
||||
}
|
||||
.map(|read| {
|
||||
unsafe {
|
||||
buf.set_len(read);
|
||||
}
|
||||
Bytes::from(buf)
|
||||
})
|
||||
};
|
||||
|
||||
let buf = work.await.map_err(io_err_into_net_error)?;
|
||||
if buf.len() == 0 {
|
||||
return Err(NetworkError::BrokenPipe);
|
||||
}
|
||||
Ok(SocketReceive {
|
||||
truncated: buf.len() == max_buf_size,
|
||||
data: buf,
|
||||
})
|
||||
let ret = Self::recv_now_ext(
|
||||
self.nonblocking,
|
||||
&mut self.stream,
|
||||
self.read_timeout.clone(),
|
||||
)
|
||||
.await;
|
||||
self.tx_recv.send(ret.clone()).ok();
|
||||
ret
|
||||
}
|
||||
}
|
||||
|
||||
@@ -615,10 +581,23 @@ impl VirtualSocket for LocalTcpStream {
|
||||
&mut self,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<usize>> {
|
||||
self.stream
|
||||
.poll_read_ready(cx)
|
||||
.map_ok(|a| 1usize)
|
||||
.map_err(io_err_into_net_error)
|
||||
let ret = {
|
||||
let mut work = Box::pin(Self::recv_now(&mut self.stream, self.read_timeout.clone()));
|
||||
match work.as_mut().poll(cx) {
|
||||
Poll::Ready(ret) => ret,
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
};
|
||||
if let Ok(ret) = ret.as_ref() {
|
||||
if ret.data.len() == 0 {
|
||||
if self.sent_eof == true {
|
||||
return Poll::Pending;
|
||||
}
|
||||
self.sent_eof = true;
|
||||
}
|
||||
}
|
||||
self.tx_recv.send(ret).ok();
|
||||
Poll::Ready(Ok(1))
|
||||
}
|
||||
|
||||
fn poll_write_ready(
|
||||
|
||||
@@ -154,10 +154,10 @@ pub(crate) struct WasiStateThreading {
|
||||
|
||||
/// Represents a futex which will make threads wait for completion in a more
|
||||
/// CPU efficient manner
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug)]
|
||||
pub struct WasiFutex {
|
||||
pub(crate) refcnt: Arc<AtomicU32>,
|
||||
pub(crate) inner: Arc<Mutex<tokio::sync::broadcast::Sender<()>>>,
|
||||
pub(crate) refcnt: AtomicU32,
|
||||
pub(crate) waker: tokio::sync::broadcast::Sender<()>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -251,7 +251,7 @@ pub struct WasiState {
|
||||
// TODO: review allow...
|
||||
#[allow(dead_code)]
|
||||
pub(crate) threading: RwLock<WasiStateThreading>,
|
||||
pub(crate) futexs: Mutex<HashMap<u64, WasiFutex>>,
|
||||
pub(crate) futexs: RwLock<HashMap<u64, WasiFutex>>,
|
||||
pub(crate) clock_offset: Mutex<HashMap<Snapshot0Clockid, i64>>,
|
||||
pub(crate) bus: WasiBusState,
|
||||
pub args: Vec<String>,
|
||||
|
||||
@@ -31,23 +31,6 @@ pub fn futex_wait<M: MemorySize>(
|
||||
|
||||
let pointer: u64 = wasi_try_ok!(futex_ptr.offset().try_into().map_err(|_| Errno::Overflow));
|
||||
|
||||
// Register the waiting futex (if its not already registered)
|
||||
let futex = {
|
||||
use std::collections::hash_map::Entry;
|
||||
let mut guard = state.futexs.lock().unwrap();
|
||||
match guard.entry(pointer) {
|
||||
Entry::Occupied(entry) => entry.get().clone(),
|
||||
Entry::Vacant(entry) => {
|
||||
let futex = WasiFutex {
|
||||
refcnt: Arc::new(AtomicU32::new(1)),
|
||||
inner: Arc::new(Mutex::new(tokio::sync::broadcast::channel(1).0)),
|
||||
};
|
||||
entry.insert(futex.clone());
|
||||
futex
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Determine the timeout
|
||||
let timeout = {
|
||||
let memory = env.memory_view(&ctx);
|
||||
@@ -62,11 +45,23 @@ pub fn futex_wait<M: MemorySize>(
|
||||
let mut woken = Bool::False;
|
||||
let start = platform_clock_time_get(Snapshot0Clockid::Monotonic, 1).unwrap() as u128;
|
||||
loop {
|
||||
// Register the waiting futex (if its not already registered)
|
||||
let mut rx = {
|
||||
let futex_lock = futex.inner.lock().unwrap();
|
||||
use std::collections::hash_map::Entry;
|
||||
let mut guard = state.futexs.write().unwrap();
|
||||
if guard.contains_key(&pointer) == false {
|
||||
let futex = WasiFutex {
|
||||
refcnt: AtomicU32::new(1),
|
||||
waker: tokio::sync::broadcast::channel(1).0,
|
||||
};
|
||||
guard.insert(pointer, futex);
|
||||
}
|
||||
let futex = guard.get_mut(&pointer).unwrap();
|
||||
|
||||
// If the value of the memory is no longer the expected value
|
||||
// then terminate from the loop (we do this under a futex lock
|
||||
// so that its protected)
|
||||
let rx = futex.waker.subscribe();
|
||||
{
|
||||
let view = env.memory_view(&ctx);
|
||||
let val = wasi_try_mem_ok!(futex_ptr.read(&view));
|
||||
@@ -75,7 +70,7 @@ pub fn futex_wait<M: MemorySize>(
|
||||
break;
|
||||
}
|
||||
}
|
||||
futex_lock.subscribe()
|
||||
rx
|
||||
};
|
||||
|
||||
// Check if we have timed out
|
||||
@@ -91,16 +86,16 @@ pub fn futex_wait<M: MemorySize>(
|
||||
}
|
||||
|
||||
// Now wait for it to be triggered
|
||||
wasi_try_ok!(__asyncify(&mut ctx, sub_timeout, async move {
|
||||
let _ = rx.recv().await;
|
||||
__asyncify(&mut ctx, sub_timeout, async move {
|
||||
rx.recv().await.ok();
|
||||
Ok(())
|
||||
})?);
|
||||
})?;
|
||||
env = ctx.data();
|
||||
}
|
||||
|
||||
// Drop the reference count to the futex (and remove it if the refcnt hits zero)
|
||||
{
|
||||
let mut guard = state.futexs.lock().unwrap();
|
||||
let mut guard = state.futexs.write().unwrap();
|
||||
if guard
|
||||
.get(&pointer)
|
||||
.map(|futex| futex.refcnt.fetch_sub(1, Ordering::AcqRel) == 1)
|
||||
|
||||
@@ -26,11 +26,10 @@ pub fn futex_wake<M: MemorySize>(
|
||||
let pointer: u64 = wasi_try!(futex.offset().try_into().map_err(|_| Errno::Overflow));
|
||||
let mut woken = false;
|
||||
|
||||
let mut guard = state.futexs.lock().unwrap();
|
||||
let mut guard = state.futexs.read().unwrap();
|
||||
if let Some(futex) = guard.get(&pointer) {
|
||||
let inner = futex.inner.lock().unwrap();
|
||||
woken = inner.receiver_count() > 0;
|
||||
let _ = inner.send(());
|
||||
woken = futex.waker.receiver_count() > 0;
|
||||
let _ = futex.waker.send(());
|
||||
} else {
|
||||
trace!(
|
||||
"wasi[{}:{}]::futex_wake - nothing waiting!",
|
||||
|
||||
@@ -24,11 +24,16 @@ pub fn futex_wake_all<M: MemorySize>(
|
||||
let pointer: u64 = wasi_try!(futex.offset().try_into().map_err(|_| Errno::Overflow));
|
||||
let mut woken = false;
|
||||
|
||||
let mut guard = state.futexs.lock().unwrap();
|
||||
if let Some(futex) = guard.remove(&pointer) {
|
||||
let inner = futex.inner.lock().unwrap();
|
||||
woken = inner.receiver_count() > 0;
|
||||
let _ = inner.send(());
|
||||
let mut guard = state.futexs.read().unwrap();
|
||||
if let Some(futex) = guard.get(&pointer) {
|
||||
woken = futex.waker.receiver_count() > 0;
|
||||
let _ = futex.waker.send(());
|
||||
} else {
|
||||
trace!(
|
||||
"wasi[{}:{}]::futex_wake_all - nothing waiting!",
|
||||
ctx.data().pid(),
|
||||
ctx.data().tid()
|
||||
);
|
||||
}
|
||||
|
||||
let woken = match woken {
|
||||
|
||||
Reference in New Issue
Block a user