Fixed for a futex race conditon and infinite polling

This commit is contained in:
Johnathan Sharratt
2023-01-19 01:02:51 +11:00
parent 7e1e9fa06b
commit 30ec91d489
6 changed files with 156 additions and 178 deletions

View File

@@ -177,7 +177,7 @@ pub trait VirtualNetworking: fmt::Debug + Send + Sync + 'static {
pub type DynVirtualNetworking = Arc<dyn VirtualNetworking>;
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct SocketReceive {
/// Data that was received
pub data: Bytes,
@@ -185,7 +185,7 @@ pub struct SocketReceive {
pub truncated: bool,
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct SocketReceiveFrom {
/// Data that was received
pub data: Bytes,

View File

@@ -255,7 +255,10 @@ pub struct LocalTcpStream {
connect_timeout: Option<Duration>,
linger_timeout: Option<Duration>,
nonblocking: bool,
sent_eof: bool,
shutdown: Option<Shutdown>,
tx_recv: mpsc::UnboundedSender<Result<SocketReceive>>,
rx_recv: mpsc::UnboundedReceiver<Result<SocketReceive>>,
tx_write_ready: mpsc::Sender<()>,
rx_write_ready: mpsc::Receiver<()>,
tx_write_poll_ready: mpsc::Sender<()>,
@@ -264,6 +267,7 @@ pub struct LocalTcpStream {
impl LocalTcpStream {
pub fn new(stream: tokio::net::TcpStream, addr: SocketAddr, nonblocking: bool) -> Self {
let (tx_recv, rx_recv) = mpsc::unbounded_channel();
let (tx_write_ready, rx_write_ready) = mpsc::channel(1);
let (tx_write_poll_ready, rx_write_poll_ready) = mpsc::channel(1);
Self {
@@ -275,10 +279,13 @@ impl LocalTcpStream {
linger_timeout: None,
nonblocking,
shutdown: None,
sent_eof: false,
tx_write_ready,
rx_write_ready,
tx_write_poll_ready,
rx_write_poll_ready,
tx_recv,
rx_recv,
}
}
}
@@ -355,6 +362,80 @@ impl VirtualTcpSocket for LocalTcpStream {
}
}
impl LocalTcpStream {
async fn recv_now_ext(
nonblocking: bool,
stream: &mut tokio::net::TcpStream,
timeout: Option<Duration>,
) -> Result<SocketReceive> {
if nonblocking {
let max_buf_size = 8192;
let mut buf = Vec::with_capacity(max_buf_size);
unsafe {
buf.set_len(max_buf_size);
}
let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE)) };
let mut cx = Context::from_waker(&waker);
let stream = Pin::new(stream);
let mut read_buf = tokio::io::ReadBuf::new(&mut buf);
return match stream.poll_read(&mut cx, &mut read_buf) {
Poll::Ready(Ok(read)) => {
let read = read_buf.remaining();
unsafe {
buf.set_len(read);
}
if read == 0 {
return Err(NetworkError::WouldBlock);
}
let buf = Bytes::from(buf);
Ok(SocketReceive {
data: buf,
truncated: read == max_buf_size,
})
}
Poll::Ready(Err(err)) => Err(io_err_into_net_error(err)),
Poll::Pending => Err(NetworkError::WouldBlock),
};
} else {
Self::recv_now(stream, timeout).await
}
}
async fn recv_now(
stream: &mut tokio::net::TcpStream,
timeout: Option<Duration>,
) -> Result<SocketReceive> {
use tokio::io::AsyncReadExt;
let max_buf_size = 8192;
let mut buf = Vec::with_capacity(max_buf_size);
unsafe {
buf.set_len(max_buf_size);
}
let work = async move {
match timeout {
Some(timeout) => tokio::time::timeout(timeout, stream.read(&mut buf[..]))
.await
.map_err(|_| Into::<std::io::Error>::into(std::io::ErrorKind::TimedOut))?,
None => stream.read(&mut buf[..]).await,
}
.map(|read| {
unsafe {
buf.set_len(read);
}
Bytes::from(buf)
})
};
let buf = work.await.map_err(io_err_into_net_error)?;
Ok(SocketReceive {
truncated: buf.len() == max_buf_size,
data: buf,
})
}
}
#[async_trait::async_trait]
impl VirtualConnectedSocket for LocalTcpStream {
fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
@@ -432,155 +513,40 @@ impl VirtualConnectedSocket for LocalTcpStream {
}
async fn recv(&mut self) -> Result<SocketReceive> {
use tokio::io::AsyncReadExt;
let max_buf_size = 8192;
let mut buf = Vec::with_capacity(max_buf_size);
unsafe {
buf.set_len(max_buf_size);
if let Ok(ret) = self.rx_recv.try_recv() {
return ret;
}
let nonblocking = self.nonblocking;
if nonblocking {
let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE)) };
let mut cx = Context::from_waker(&waker);
let stream = Pin::new(&mut self.stream);
let mut read_buf = tokio::io::ReadBuf::new(&mut buf);
return match stream.poll_read(&mut cx, &mut read_buf) {
Poll::Ready(Ok(read)) => {
let read = read_buf.remaining();
unsafe {
buf.set_len(read);
}
if read == 0 {
return Err(NetworkError::WouldBlock);
}
let buf = Bytes::from(buf);
Ok(SocketReceive {
data: buf,
truncated: read == max_buf_size,
})
}
Poll::Ready(Err(err)) => Err(io_err_into_net_error(err)),
Poll::Pending => Err(NetworkError::WouldBlock),
};
tokio::select! {
ret = Self::recv_now_ext(
self.nonblocking,
&mut self.stream,
self.read_timeout.clone(),
) => ret,
ret = self.rx_recv.recv() => ret.unwrap_or(Err(NetworkError::ConnectionAborted))
}
let timeout = self.write_timeout.clone();
let work = async move {
match timeout {
Some(timeout) => tokio::time::timeout(timeout, self.stream.read(&mut buf[..]))
.await
.map_err(|_| Into::<std::io::Error>::into(std::io::ErrorKind::WouldBlock))?,
None => self.stream.read(&mut buf[..]).await,
}
.map(|read| {
unsafe {
buf.set_len(read);
}
Bytes::from(buf)
})
};
let buf = work.await.map_err(io_err_into_net_error)?;
if buf.is_empty() {
if nonblocking {
return Err(NetworkError::WouldBlock);
} else {
return Err(NetworkError::BrokenPipe);
}
}
Ok(SocketReceive {
truncated: buf.len() == max_buf_size,
data: buf,
})
}
fn try_recv(&mut self) -> Result<Option<SocketReceive>> {
let max_buf_size = 8192;
let mut buf = Vec::with_capacity(max_buf_size);
unsafe {
buf.set_len(max_buf_size);
}
let mut work = self.recv();
let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE)) };
let mut cx = Context::from_waker(&waker);
let stream = Pin::new(&mut self.stream);
let mut read_buf = tokio::io::ReadBuf::new(&mut buf);
match stream.poll_read(&mut cx, &mut read_buf) {
Poll::Ready(Ok(read)) => {
let read = read_buf.remaining();
unsafe {
buf.set_len(read);
}
if read == 0 {
return Err(NetworkError::WouldBlock);
}
let buf = Bytes::from(buf);
Ok(Some(SocketReceive {
data: buf,
truncated: read == max_buf_size,
}))
}
Poll::Ready(Err(err)) => Err(io_err_into_net_error(err)),
match work.as_mut().poll(&mut cx) {
Poll::Ready(Ok(ret)) => Ok(Some(ret)),
Poll::Ready(Err(err)) => Err(err),
Poll::Pending => Ok(None),
}
}
async fn peek(&mut self) -> Result<SocketReceive> {
let max_buf_size = 8192;
let mut buf = Vec::with_capacity(max_buf_size);
unsafe {
buf.set_len(max_buf_size);
}
if self.nonblocking {
let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE)) };
let mut cx = Context::from_waker(&waker);
let stream = Pin::new(&mut self.stream);
let mut read_buf = tokio::io::ReadBuf::new(&mut buf);
return match stream.poll_peek(&mut cx, &mut read_buf) {
Poll::Ready(Ok(read)) => {
unsafe {
buf.set_len(read);
}
if read == 0 {
return Err(NetworkError::WouldBlock);
}
let buf = Bytes::from(buf);
Ok(SocketReceive {
data: buf,
truncated: read == max_buf_size,
})
}
Poll::Ready(Err(err)) => Err(io_err_into_net_error(err)),
Poll::Pending => Err(NetworkError::WouldBlock),
};
}
let timeout = self.write_timeout.clone();
let work = async move {
match timeout {
Some(timeout) => tokio::time::timeout(timeout, self.stream.peek(&mut buf[..]))
.await
.map_err(|_| Into::<std::io::Error>::into(std::io::ErrorKind::WouldBlock))?,
None => self.stream.peek(&mut buf[..]).await,
}
.map(|read| {
unsafe {
buf.set_len(read);
}
Bytes::from(buf)
})
};
let buf = work.await.map_err(io_err_into_net_error)?;
if buf.len() == 0 {
return Err(NetworkError::BrokenPipe);
}
Ok(SocketReceive {
truncated: buf.len() == max_buf_size,
data: buf,
})
let ret = Self::recv_now_ext(
self.nonblocking,
&mut self.stream,
self.read_timeout.clone(),
)
.await;
self.tx_recv.send(ret.clone()).ok();
ret
}
}
@@ -615,10 +581,23 @@ impl VirtualSocket for LocalTcpStream {
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<usize>> {
self.stream
.poll_read_ready(cx)
.map_ok(|a| 1usize)
.map_err(io_err_into_net_error)
let ret = {
let mut work = Box::pin(Self::recv_now(&mut self.stream, self.read_timeout.clone()));
match work.as_mut().poll(cx) {
Poll::Ready(ret) => ret,
Poll::Pending => return Poll::Pending,
}
};
if let Ok(ret) = ret.as_ref() {
if ret.data.len() == 0 {
if self.sent_eof == true {
return Poll::Pending;
}
self.sent_eof = true;
}
}
self.tx_recv.send(ret).ok();
Poll::Ready(Ok(1))
}
fn poll_write_ready(

View File

@@ -154,10 +154,10 @@ pub(crate) struct WasiStateThreading {
/// Represents a futex which will make threads wait for completion in a more
/// CPU efficient manner
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct WasiFutex {
pub(crate) refcnt: Arc<AtomicU32>,
pub(crate) inner: Arc<Mutex<tokio::sync::broadcast::Sender<()>>>,
pub(crate) refcnt: AtomicU32,
pub(crate) waker: tokio::sync::broadcast::Sender<()>,
}
#[derive(Debug)]
@@ -251,7 +251,7 @@ pub struct WasiState {
// TODO: review allow...
#[allow(dead_code)]
pub(crate) threading: RwLock<WasiStateThreading>,
pub(crate) futexs: Mutex<HashMap<u64, WasiFutex>>,
pub(crate) futexs: RwLock<HashMap<u64, WasiFutex>>,
pub(crate) clock_offset: Mutex<HashMap<Snapshot0Clockid, i64>>,
pub(crate) bus: WasiBusState,
pub args: Vec<String>,

View File

@@ -31,23 +31,6 @@ pub fn futex_wait<M: MemorySize>(
let pointer: u64 = wasi_try_ok!(futex_ptr.offset().try_into().map_err(|_| Errno::Overflow));
// Register the waiting futex (if its not already registered)
let futex = {
use std::collections::hash_map::Entry;
let mut guard = state.futexs.lock().unwrap();
match guard.entry(pointer) {
Entry::Occupied(entry) => entry.get().clone(),
Entry::Vacant(entry) => {
let futex = WasiFutex {
refcnt: Arc::new(AtomicU32::new(1)),
inner: Arc::new(Mutex::new(tokio::sync::broadcast::channel(1).0)),
};
entry.insert(futex.clone());
futex
}
}
};
// Determine the timeout
let timeout = {
let memory = env.memory_view(&ctx);
@@ -62,11 +45,23 @@ pub fn futex_wait<M: MemorySize>(
let mut woken = Bool::False;
let start = platform_clock_time_get(Snapshot0Clockid::Monotonic, 1).unwrap() as u128;
loop {
// Register the waiting futex (if its not already registered)
let mut rx = {
let futex_lock = futex.inner.lock().unwrap();
use std::collections::hash_map::Entry;
let mut guard = state.futexs.write().unwrap();
if guard.contains_key(&pointer) == false {
let futex = WasiFutex {
refcnt: AtomicU32::new(1),
waker: tokio::sync::broadcast::channel(1).0,
};
guard.insert(pointer, futex);
}
let futex = guard.get_mut(&pointer).unwrap();
// If the value of the memory is no longer the expected value
// then terminate from the loop (we do this under a futex lock
// so that its protected)
let rx = futex.waker.subscribe();
{
let view = env.memory_view(&ctx);
let val = wasi_try_mem_ok!(futex_ptr.read(&view));
@@ -75,7 +70,7 @@ pub fn futex_wait<M: MemorySize>(
break;
}
}
futex_lock.subscribe()
rx
};
// Check if we have timed out
@@ -91,16 +86,16 @@ pub fn futex_wait<M: MemorySize>(
}
// Now wait for it to be triggered
wasi_try_ok!(__asyncify(&mut ctx, sub_timeout, async move {
let _ = rx.recv().await;
__asyncify(&mut ctx, sub_timeout, async move {
rx.recv().await.ok();
Ok(())
})?);
})?;
env = ctx.data();
}
// Drop the reference count to the futex (and remove it if the refcnt hits zero)
{
let mut guard = state.futexs.lock().unwrap();
let mut guard = state.futexs.write().unwrap();
if guard
.get(&pointer)
.map(|futex| futex.refcnt.fetch_sub(1, Ordering::AcqRel) == 1)

View File

@@ -26,11 +26,10 @@ pub fn futex_wake<M: MemorySize>(
let pointer: u64 = wasi_try!(futex.offset().try_into().map_err(|_| Errno::Overflow));
let mut woken = false;
let mut guard = state.futexs.lock().unwrap();
let mut guard = state.futexs.read().unwrap();
if let Some(futex) = guard.get(&pointer) {
let inner = futex.inner.lock().unwrap();
woken = inner.receiver_count() > 0;
let _ = inner.send(());
woken = futex.waker.receiver_count() > 0;
let _ = futex.waker.send(());
} else {
trace!(
"wasi[{}:{}]::futex_wake - nothing waiting!",

View File

@@ -24,11 +24,16 @@ pub fn futex_wake_all<M: MemorySize>(
let pointer: u64 = wasi_try!(futex.offset().try_into().map_err(|_| Errno::Overflow));
let mut woken = false;
let mut guard = state.futexs.lock().unwrap();
if let Some(futex) = guard.remove(&pointer) {
let inner = futex.inner.lock().unwrap();
woken = inner.receiver_count() > 0;
let _ = inner.send(());
let mut guard = state.futexs.read().unwrap();
if let Some(futex) = guard.get(&pointer) {
woken = futex.waker.receiver_count() > 0;
let _ = futex.waker.send(());
} else {
trace!(
"wasi[{}:{}]::futex_wake_all - nothing waiting!",
ctx.data().pid(),
ctx.data().tid()
);
}
let woken = match woken {