Fixed a futex race condition and infinite polling

Johnathan Sharratt
2023-01-19 01:02:51 +11:00
parent 7e1e9fa06b
commit 30ec91d489
6 changed files with 156 additions and 178 deletions

View File

@@ -177,7 +177,7 @@ pub trait VirtualNetworking: fmt::Debug + Send + Sync + 'static {
 pub type DynVirtualNetworking = Arc<dyn VirtualNetworking>;
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct SocketReceive {
     /// Data that was received
     pub data: Bytes,
@@ -185,7 +185,7 @@ pub struct SocketReceive {
     pub truncated: bool,
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct SocketReceiveFrom {
     /// Data that was received
     pub data: Bytes,

View File

@@ -255,7 +255,10 @@ pub struct LocalTcpStream {
     connect_timeout: Option<Duration>,
     linger_timeout: Option<Duration>,
     nonblocking: bool,
+    sent_eof: bool,
     shutdown: Option<Shutdown>,
+    tx_recv: mpsc::UnboundedSender<Result<SocketReceive>>,
+    rx_recv: mpsc::UnboundedReceiver<Result<SocketReceive>>,
     tx_write_ready: mpsc::Sender<()>,
     rx_write_ready: mpsc::Receiver<()>,
     tx_write_poll_ready: mpsc::Sender<()>,
@@ -264,6 +267,7 @@
 impl LocalTcpStream {
     pub fn new(stream: tokio::net::TcpStream, addr: SocketAddr, nonblocking: bool) -> Self {
+        let (tx_recv, rx_recv) = mpsc::unbounded_channel();
         let (tx_write_ready, rx_write_ready) = mpsc::channel(1);
         let (tx_write_poll_ready, rx_write_poll_ready) = mpsc::channel(1);
         Self {
@@ -275,10 +279,13 @@
             linger_timeout: None,
             nonblocking,
             shutdown: None,
+            sent_eof: false,
             tx_write_ready,
             rx_write_ready,
             tx_write_poll_ready,
             rx_write_poll_ready,
+            tx_recv,
+            rx_recv,
         }
     }
 }
@@ -355,6 +362,80 @@ impl VirtualTcpSocket for LocalTcpStream {
     }
 }
 
+impl LocalTcpStream {
+    async fn recv_now_ext(
+        nonblocking: bool,
+        stream: &mut tokio::net::TcpStream,
+        timeout: Option<Duration>,
+    ) -> Result<SocketReceive> {
+        if nonblocking {
+            let max_buf_size = 8192;
+            let mut buf = Vec::with_capacity(max_buf_size);
+            unsafe {
+                buf.set_len(max_buf_size);
+            }
+            let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE)) };
+            let mut cx = Context::from_waker(&waker);
+            let stream = Pin::new(stream);
+            let mut read_buf = tokio::io::ReadBuf::new(&mut buf);
+            return match stream.poll_read(&mut cx, &mut read_buf) {
+                Poll::Ready(Ok(read)) => {
+                    let read = read_buf.remaining();
+                    unsafe {
+                        buf.set_len(read);
+                    }
+                    if read == 0 {
+                        return Err(NetworkError::WouldBlock);
+                    }
+                    let buf = Bytes::from(buf);
+                    Ok(SocketReceive {
+                        data: buf,
+                        truncated: read == max_buf_size,
+                    })
+                }
+                Poll::Ready(Err(err)) => Err(io_err_into_net_error(err)),
+                Poll::Pending => Err(NetworkError::WouldBlock),
+            };
+        } else {
+            Self::recv_now(stream, timeout).await
+        }
+    }
+
+    async fn recv_now(
+        stream: &mut tokio::net::TcpStream,
+        timeout: Option<Duration>,
+    ) -> Result<SocketReceive> {
+        use tokio::io::AsyncReadExt;
+        let max_buf_size = 8192;
+        let mut buf = Vec::with_capacity(max_buf_size);
+        unsafe {
+            buf.set_len(max_buf_size);
+        }
+        let work = async move {
+            match timeout {
+                Some(timeout) => tokio::time::timeout(timeout, stream.read(&mut buf[..]))
+                    .await
+                    .map_err(|_| Into::<std::io::Error>::into(std::io::ErrorKind::TimedOut))?,
+                None => stream.read(&mut buf[..]).await,
+            }
+            .map(|read| {
+                unsafe {
+                    buf.set_len(read);
+                }
+                Bytes::from(buf)
+            })
+        };
+        let buf = work.await.map_err(io_err_into_net_error)?;
+        Ok(SocketReceive {
+            truncated: buf.len() == max_buf_size,
+            data: buf,
+        })
+    }
+}
+
 #[async_trait::async_trait]
 impl VirtualConnectedSocket for LocalTcpStream {
     fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
@@ -432,155 +513,40 @@ impl VirtualConnectedSocket for LocalTcpStream {
     }
 
     async fn recv(&mut self) -> Result<SocketReceive> {
-        use tokio::io::AsyncReadExt;
-        let max_buf_size = 8192;
-        let mut buf = Vec::with_capacity(max_buf_size);
-        unsafe {
-            buf.set_len(max_buf_size);
-        }
-        let nonblocking = self.nonblocking;
-        if nonblocking {
-            let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE)) };
-            let mut cx = Context::from_waker(&waker);
-            let stream = Pin::new(&mut self.stream);
-            let mut read_buf = tokio::io::ReadBuf::new(&mut buf);
-            return match stream.poll_read(&mut cx, &mut read_buf) {
-                Poll::Ready(Ok(read)) => {
-                    let read = read_buf.remaining();
-                    unsafe {
-                        buf.set_len(read);
-                    }
-                    if read == 0 {
-                        return Err(NetworkError::WouldBlock);
-                    }
-                    let buf = Bytes::from(buf);
-                    Ok(SocketReceive {
-                        data: buf,
-                        truncated: read == max_buf_size,
-                    })
-                }
-                Poll::Ready(Err(err)) => Err(io_err_into_net_error(err)),
-                Poll::Pending => Err(NetworkError::WouldBlock),
-            };
-        }
-        let timeout = self.write_timeout.clone();
-        let work = async move {
-            match timeout {
-                Some(timeout) => tokio::time::timeout(timeout, self.stream.read(&mut buf[..]))
-                    .await
-                    .map_err(|_| Into::<std::io::Error>::into(std::io::ErrorKind::WouldBlock))?,
-                None => self.stream.read(&mut buf[..]).await,
-            }
-            .map(|read| {
-                unsafe {
-                    buf.set_len(read);
-                }
-                Bytes::from(buf)
-            })
-        };
-        let buf = work.await.map_err(io_err_into_net_error)?;
-        if buf.is_empty() {
-            if nonblocking {
-                return Err(NetworkError::WouldBlock);
-            } else {
-                return Err(NetworkError::BrokenPipe);
-            }
-        }
-        Ok(SocketReceive {
-            truncated: buf.len() == max_buf_size,
-            data: buf,
-        })
+        if let Ok(ret) = self.rx_recv.try_recv() {
+            return ret;
+        }
+        tokio::select! {
+            ret = Self::recv_now_ext(
+                self.nonblocking,
+                &mut self.stream,
+                self.read_timeout.clone(),
+            ) => ret,
+            ret = self.rx_recv.recv() => ret.unwrap_or(Err(NetworkError::ConnectionAborted))
+        }
     }
 
     fn try_recv(&mut self) -> Result<Option<SocketReceive>> {
-        let max_buf_size = 8192;
-        let mut buf = Vec::with_capacity(max_buf_size);
-        unsafe {
-            buf.set_len(max_buf_size);
-        }
+        let mut work = self.recv();
         let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE)) };
         let mut cx = Context::from_waker(&waker);
-        let stream = Pin::new(&mut self.stream);
-        let mut read_buf = tokio::io::ReadBuf::new(&mut buf);
-        match stream.poll_read(&mut cx, &mut read_buf) {
-            Poll::Ready(Ok(read)) => {
-                let read = read_buf.remaining();
-                unsafe {
-                    buf.set_len(read);
-                }
-                if read == 0 {
-                    return Err(NetworkError::WouldBlock);
-                }
-                let buf = Bytes::from(buf);
-                Ok(Some(SocketReceive {
-                    data: buf,
-                    truncated: read == max_buf_size,
-                }))
-            }
-            Poll::Ready(Err(err)) => Err(io_err_into_net_error(err)),
+        match work.as_mut().poll(&mut cx) {
+            Poll::Ready(Ok(ret)) => Ok(Some(ret)),
+            Poll::Ready(Err(err)) => Err(err),
             Poll::Pending => Ok(None),
         }
     }
 
     async fn peek(&mut self) -> Result<SocketReceive> {
-        let max_buf_size = 8192;
-        let mut buf = Vec::with_capacity(max_buf_size);
-        unsafe {
-            buf.set_len(max_buf_size);
-        }
-        if self.nonblocking {
-            let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE)) };
-            let mut cx = Context::from_waker(&waker);
-            let stream = Pin::new(&mut self.stream);
-            let mut read_buf = tokio::io::ReadBuf::new(&mut buf);
-            return match stream.poll_peek(&mut cx, &mut read_buf) {
-                Poll::Ready(Ok(read)) => {
-                    unsafe {
-                        buf.set_len(read);
-                    }
-                    if read == 0 {
-                        return Err(NetworkError::WouldBlock);
-                    }
-                    let buf = Bytes::from(buf);
-                    Ok(SocketReceive {
-                        data: buf,
-                        truncated: read == max_buf_size,
-                    })
-                }
-                Poll::Ready(Err(err)) => Err(io_err_into_net_error(err)),
-                Poll::Pending => Err(NetworkError::WouldBlock),
-            };
-        }
-        let timeout = self.write_timeout.clone();
-        let work = async move {
-            match timeout {
-                Some(timeout) => tokio::time::timeout(timeout, self.stream.peek(&mut buf[..]))
-                    .await
-                    .map_err(|_| Into::<std::io::Error>::into(std::io::ErrorKind::WouldBlock))?,
-                None => self.stream.peek(&mut buf[..]).await,
-            }
-            .map(|read| {
-                unsafe {
-                    buf.set_len(read);
-                }
-                Bytes::from(buf)
-            })
-        };
-        let buf = work.await.map_err(io_err_into_net_error)?;
-        if buf.len() == 0 {
-            return Err(NetworkError::BrokenPipe);
-        }
-        Ok(SocketReceive {
-            truncated: buf.len() == max_buf_size,
-            data: buf,
-        })
+        let ret = Self::recv_now_ext(
+            self.nonblocking,
+            &mut self.stream,
+            self.read_timeout.clone(),
+        )
+        .await;
+        self.tx_recv.send(ret.clone()).ok();
+        ret
     }
 }
@@ -615,10 +581,23 @@ impl VirtualSocket for LocalTcpStream {
         &mut self,
         cx: &mut std::task::Context<'_>,
     ) -> std::task::Poll<Result<usize>> {
-        self.stream
-            .poll_read_ready(cx)
-            .map_ok(|a| 1usize)
-            .map_err(io_err_into_net_error)
+        let ret = {
+            let mut work = Box::pin(Self::recv_now(&mut self.stream, self.read_timeout.clone()));
+            match work.as_mut().poll(cx) {
+                Poll::Ready(ret) => ret,
+                Poll::Pending => return Poll::Pending,
+            }
+        };
+        if let Ok(ret) = ret.as_ref() {
+            if ret.data.len() == 0 {
+                if self.sent_eof == true {
+                    return Poll::Pending;
+                }
+                self.sent_eof = true;
+            }
+        }
+        self.tx_recv.send(ret).ok();
+        Poll::Ready(Ok(1))
     }
 
     fn poll_write_ready(
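
The rewritten recv/try_recv above leans on one recurring trick: drive an async read exactly once with a no-op waker and treat Pending as "no data yet". A minimal stand-alone sketch of that pattern, assuming the futures crate's noop_waker rather than the NOOP_WAKER_VTABLE used in the diff (the helper name is illustrative, not this crate's API):

    // Poll a future a single time with a waker that ignores wake-ups, so an
    // async operation can back a synchronous, non-blocking call.
    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    fn poll_once<F: Future>(fut: F) -> Option<F::Output> {
        let mut fut: Pin<Box<F>> = Box::pin(fut); // pin it so it can be polled
        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => Some(out), // completed synchronously
            Poll::Pending => None,         // would block; the caller retries later
        }
    }

Dropping the future after a Pending poll is acceptable in this style because the caller simply calls again later, which is how the diff maps Poll::Pending to Ok(None) in try_recv.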

View File

@@ -154,10 +154,10 @@ pub(crate) struct WasiStateThreading {
 /// Represents a futex which will make threads wait for completion in a more
 /// CPU efficient manner
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 pub struct WasiFutex {
-    pub(crate) refcnt: Arc<AtomicU32>,
-    pub(crate) inner: Arc<Mutex<tokio::sync::broadcast::Sender<()>>>,
+    pub(crate) refcnt: AtomicU32,
+    pub(crate) waker: tokio::sync::broadcast::Sender<()>,
 }
 
 #[derive(Debug)]
@@ -251,7 +251,7 @@ pub struct WasiState {
     // TODO: review allow...
     #[allow(dead_code)]
     pub(crate) threading: RwLock<WasiStateThreading>,
-    pub(crate) futexs: Mutex<HashMap<u64, WasiFutex>>,
+    pub(crate) futexs: RwLock<HashMap<u64, WasiFutex>>,
     pub(crate) clock_offset: Mutex<HashMap<Snapshot0Clockid, i64>>,
     pub(crate) bus: WasiBusState,
     pub args: Vec<String>,

View File

@@ -31,23 +31,6 @@ pub fn futex_wait<M: MemorySize>(
     let pointer: u64 = wasi_try_ok!(futex_ptr.offset().try_into().map_err(|_| Errno::Overflow));
 
-    // Register the waiting futex (if its not already registered)
-    let futex = {
-        use std::collections::hash_map::Entry;
-        let mut guard = state.futexs.lock().unwrap();
-        match guard.entry(pointer) {
-            Entry::Occupied(entry) => entry.get().clone(),
-            Entry::Vacant(entry) => {
-                let futex = WasiFutex {
-                    refcnt: Arc::new(AtomicU32::new(1)),
-                    inner: Arc::new(Mutex::new(tokio::sync::broadcast::channel(1).0)),
-                };
-                entry.insert(futex.clone());
-                futex
-            }
-        }
-    };
-
     // Determine the timeout
     let timeout = {
         let memory = env.memory_view(&ctx);
@@ -62,11 +45,23 @@
     let mut woken = Bool::False;
     let start = platform_clock_time_get(Snapshot0Clockid::Monotonic, 1).unwrap() as u128;
     loop {
+        // Register the waiting futex (if its not already registered)
         let mut rx = {
-            let futex_lock = futex.inner.lock().unwrap();
+            use std::collections::hash_map::Entry;
+            let mut guard = state.futexs.write().unwrap();
+            if guard.contains_key(&pointer) == false {
+                let futex = WasiFutex {
+                    refcnt: AtomicU32::new(1),
+                    waker: tokio::sync::broadcast::channel(1).0,
+                };
+                guard.insert(pointer, futex);
+            }
+            let futex = guard.get_mut(&pointer).unwrap();
+
             // If the value of the memory is no longer the expected value
             // then terminate from the loop (we do this under a futex lock
             // so that its protected)
+            let rx = futex.waker.subscribe();
             {
                 let view = env.memory_view(&ctx);
                 let val = wasi_try_mem_ok!(futex_ptr.read(&view));
@@ -75,7 +70,7 @@
                     break;
                 }
             }
-            futex_lock.subscribe()
+            rx
         };
 
         // Check if we have timed out
@@ -91,16 +86,16 @@
        }

         // Now wait for it to be triggered
-        wasi_try_ok!(__asyncify(&mut ctx, sub_timeout, async move {
-            let _ = rx.recv().await;
+        __asyncify(&mut ctx, sub_timeout, async move {
+            rx.recv().await.ok();
             Ok(())
-        })?);
+        })?;

         env = ctx.data();
     }

     // Drop the reference count to the futex (and remove it if the refcnt hits zero)
     {
-        let mut guard = state.futexs.lock().unwrap();
+        let mut guard = state.futexs.write().unwrap();
         if guard
             .get(&pointer)
             .map(|futex| futex.refcnt.fetch_sub(1, Ordering::AcqRel) == 1)
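
The comments kept in this hunk state the invariant the fix depends on: the waiter must register the futex and subscribe to its broadcast channel while holding the futexs lock, and only then re-check the memory value, so a wake delivered in between can no longer be lost. A stand-alone sketch of that ordering, with illustrative types rather than wasmer-wasi's own (the Futex here is reduced to just the broadcast sender):

    use std::collections::HashMap;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, RwLock};

    struct Futex {
        waker: tokio::sync::broadcast::Sender<()>,
    }

    async fn futex_wait(
        futexes: Arc<RwLock<HashMap<u64, Futex>>>,
        key: u64,
        cell: &AtomicU32,
        expected: u32,
    ) {
        loop {
            let mut rx = {
                // Register the futex if needed and subscribe *before* checking
                // the value, all under the same lock the wake side takes.
                let mut guard = futexes.write().unwrap();
                let futex = guard.entry(key).or_insert_with(|| Futex {
                    waker: tokio::sync::broadcast::channel(1).0,
                });
                let rx = futex.waker.subscribe();
                if cell.load(Ordering::Acquire) != expected {
                    return; // value already changed; nothing to wait for
                }
                rx
            };
            // Any wake sent after the subscribe above is guaranteed to land here.
            rx.recv().await.ok();
        }
    }

Re-registering on every loop iteration, as the diff now does, also means an entry removed by the refcount cleanup can safely be recreated the next time around.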

View File

@@ -26,11 +26,10 @@ pub fn futex_wake<M: MemorySize>(
     let pointer: u64 = wasi_try!(futex.offset().try_into().map_err(|_| Errno::Overflow));
 
     let mut woken = false;
-    let mut guard = state.futexs.lock().unwrap();
+    let mut guard = state.futexs.read().unwrap();
     if let Some(futex) = guard.get(&pointer) {
-        let inner = futex.inner.lock().unwrap();
-        woken = inner.receiver_count() > 0;
-        let _ = inner.send(());
+        woken = futex.waker.receiver_count() > 0;
+        let _ = futex.waker.send(());
     } else {
         trace!(
             "wasi[{}:{}]::futex_wake - nothing waiting!",

View File

@@ -24,11 +24,16 @@ pub fn futex_wake_all<M: MemorySize>(
     let pointer: u64 = wasi_try!(futex.offset().try_into().map_err(|_| Errno::Overflow));
 
     let mut woken = false;
-    let mut guard = state.futexs.lock().unwrap();
-    if let Some(futex) = guard.remove(&pointer) {
-        let inner = futex.inner.lock().unwrap();
-        woken = inner.receiver_count() > 0;
-        let _ = inner.send(());
+    let mut guard = state.futexs.read().unwrap();
+    if let Some(futex) = guard.get(&pointer) {
+        woken = futex.waker.receiver_count() > 0;
+        let _ = futex.waker.send(());
+    } else {
+        trace!(
+            "wasi[{}:{}]::futex_wake_all - nothing waiting!",
+            ctx.data().pid(),
+            ctx.data().tid()
+        );
     }
 
     let woken = match woken {
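
futex_wake_all no longer removes the map entry; it broadcasts on the waker and reports whether anyone was actually subscribed, leaving cleanup to the waiter's refcount. Its counterpart in the same illustrative terms as the wait sketch above (an assumption-laden sketch, not this crate's code):

    use std::collections::HashMap;
    use std::sync::RwLock;

    struct Futex {
        waker: tokio::sync::broadcast::Sender<()>,
    }

    fn futex_wake_all(futexes: &RwLock<HashMap<u64, Futex>>, key: u64) -> bool {
        let guard = futexes.read().unwrap();
        if let Some(futex) = guard.get(&key) {
            // receiver_count() reveals whether any waiter is currently subscribed.
            let woken = futex.waker.receiver_count() > 0;
            // send() only errors when there are no receivers, which is harmless here.
            let _ = futex.waker.send(());
            woken
        } else {
            false // nothing was waiting on this address
        }
    }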