From c4f93449eefa8248d485f40f7f8c9a4516948e9c Mon Sep 17 00:00:00 2001 From: mii443 Date: Sun, 1 Sep 2024 00:12:23 +0900 Subject: [PATCH] add close process --- src/bin.rs | 26 ++++++++++----- src/p2p.rs | 97 ++++++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 91 insertions(+), 32 deletions(-) diff --git a/src/bin.rs b/src/bin.rs index f659bbe..be3ff79 100644 --- a/src/bin.rs +++ b/src/bin.rs @@ -1,6 +1,6 @@ use std::io::Write; -use easyp2p::p2p::P2P; +use easyp2p::p2p::{Receive, P2P}; use anyhow::Result; fn read_line() -> String { @@ -50,20 +50,30 @@ async fn client_b(code: &str) -> Result<()> { } async fn chat(mut p2p: P2P) -> Result<()> { - tokio::spawn({ - let receive = p2p.receive_data.clone(); + let receive_thread = tokio::spawn({ + let receive = p2p.receive_data_rx.clone(); async move { let mut receive = receive.lock().await; - while let Some(data) = receive.recv().await { + while let Some(Receive::Data(data)) = receive.recv().await { print!("Received: {}", String::from_utf8(data.to_vec()).unwrap()); std::io::stdout().flush().unwrap(); } + println!("P2P Closed!"); } }); - loop { - let line = read_line(); - p2p.send_text(&format!("{line}\n")).await?; + let send_thread = tokio::spawn(async move { + loop { + let line = tokio::task::spawn_blocking(|| read_line()).await.unwrap(); + p2p.send_text(&format!("{line}\n")).await.unwrap(); + } + }); + + tokio::select! { + _ = receive_thread => {}, + _ = send_thread => {} } -} \ No newline at end of file + + Ok(()) +} diff --git a/src/p2p.rs b/src/p2p.rs index e199da4..690c433 100644 --- a/src/p2p.rs +++ b/src/p2p.rs @@ -8,7 +8,7 @@ use futures::StreamExt; use reqwest::Client; use serde::{Deserialize, Serialize}; use tokio::sync::{mpsc::{Receiver, Sender}, Mutex}; -use webrtc::{api::{interceptor_registry::register_default_interceptors, media_engine::MediaEngine, APIBuilder}, data_channel::{data_channel_message::DataChannelMessage, RTCDataChannel}, ice_transport::{ice_candidate::RTCIceCandidate, ice_gatherer::RTCIceGatherer, ice_server::RTCIceServer}, interceptor::registry::Registry, peer_connection::{configuration::RTCConfiguration, peer_connection_state::RTCPeerConnectionState, sdp::session_description::RTCSessionDescription, RTCPeerConnection}}; +use webrtc::{api::{interceptor_registry::register_default_interceptors, media_engine::MediaEngine, APIBuilder}, data_channel::{data_channel_message::DataChannelMessage, RTCDataChannel}, ice_transport::{ice_candidate::RTCIceCandidate, ice_gatherer::RTCIceGatherer, ice_server::RTCIceServer}, interceptor::registry::Registry, peer_connection::{self, configuration::RTCConfiguration, peer_connection_state::RTCPeerConnectionState, sdp::session_description::RTCSessionDescription, RTCPeerConnection}}; #[derive(Clone, Serialize, Deserialize)] pub struct SessionDescription { @@ -22,11 +22,11 @@ pub struct ConnectionCode { pub struct P2P { peer_connection: Arc>, - send_data: Arc>>, + send_data_tx: Arc>>, send_data_rx: Arc>>, - pub receive_data: Arc>>, - receive_data_tx: Arc>>, - done_rx: Receiver<()>, + pub receive_data_rx: Arc>>, + receive_data_tx: Arc>>, + done_tx: Arc>>, on_open: Receiver<()>, on_open_tx: Arc>> } @@ -36,6 +36,11 @@ enum SendData { String(String) } +pub enum Receive { + Data(Bytes), + Close +} + impl P2P { pub async fn connect_with_code(&mut self, signaling_server: &str, code: &str) -> Result<()> { let client = Client::new(); @@ -121,19 +126,22 @@ impl P2P { } pub async fn send(&mut self, data: Bytes) -> Result<()> { - self.send_data.lock().await.send(SendData::Bytes(data)).await.context("Failed to send data") + self.send_data_tx.lock().await.send(SendData::Bytes(data)).await.context("Failed to send data") } pub async fn send_text(&mut self, data: &str) -> Result<()> { - self.send_data.lock().await.send(SendData::String(data.to_string())).await.context("Failed to send data") + self.send_data_tx.lock().await.send(SendData::String(data.to_string())).await.context("Failed to send data") } - pub async fn receive(&mut self) -> Option { - self.receive_data.lock().await.recv().await + pub async fn receive(&mut self) -> Option { + self.receive_data_rx.lock().await.recv().await } pub async fn receive_text(&mut self) -> Result { - String::from_utf8(self.receive_data.lock().await.recv().await.context("Failed to receive data")?.to_vec()).context("Failed to convert") + match self.receive().await.context("Failed to receive data")? { + Receive::Data(data) => String::from_utf8(data.to_vec()).context("Failed to convert"), + Receive::Close => Err(anyhow!("Session closed")) + } } pub async fn set_answer(&mut self, answer: &str, compress: bool) -> Result<()> { @@ -148,20 +156,30 @@ impl P2P { let receive_data_tx = self.receive_data_tx.clone(); let on_open_tx = self.on_open_tx.clone(); let send_data_rx = self.send_data_rx.clone(); + let done_tx = self.done_tx.clone(); peer_connection .on_data_channel(Box::new(move |d: Arc| { let receive_data_tx = receive_data_tx.clone(); let on_open_tx = on_open_tx.clone(); let send_data_rx = send_data_rx.clone(); + let done_tx = done_tx.clone(); Box::pin(async move { let d2 = Arc::clone(&d); + d.on_close(Box::new(move || { + let done_tx = done_tx.clone(); + + Box::pin(async move { + done_tx.lock().await.send(()).await.unwrap(); + }) + })); + d.on_message(Box::new(move |msg: DataChannelMessage| { let receive_data_tx = receive_data_tx.clone(); Box::pin(async move { - receive_data_tx.lock().await.send(msg.data).await.unwrap(); + receive_data_tx.lock().await.send(Receive::Data(msg.data)).await.unwrap(); }) })); @@ -219,7 +237,16 @@ impl P2P { let receive_data_tx = receive_data_tx.clone(); Box::pin(async move { - receive_data_tx.lock().await.send(msg.data).await.unwrap(); + receive_data_tx.lock().await.send(Receive::Data(msg.data)).await.unwrap(); + }) + })); + + let done_tx = self.done_tx.clone(); + data_channel.on_close(Box::new(move || { + let done_tx = done_tx.clone(); + + Box::pin(async move { + done_tx.lock().await.send(()).await.unwrap(); }) })); @@ -314,28 +341,50 @@ impl P2P { let peer_connection = Arc::new(Mutex::new(api.new_peer_connection(config.rtc_configuration).await?)); - let (done_tx, done_rx) = tokio::sync::mpsc::channel::<()>(1); + let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - peer_connection.lock().await.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - if s == RTCPeerConnectionState::Failed || s == RTCPeerConnectionState::Disconnected { - let _ = done_tx.try_send(()); + let done_tx = Arc::new(Mutex::new(done_tx)); + peer_connection.lock().await.on_peer_connection_state_change(Box::new({ + let done_tx = done_tx.clone(); + move |s: RTCPeerConnectionState| { + let done_tx = done_tx.clone(); + Box::pin(async move { + if s == RTCPeerConnectionState::Failed || s == RTCPeerConnectionState::Disconnected || s == RTCPeerConnectionState::Closed { + let _ = done_tx.lock().await.try_send(()); + } + }) } - - Box::pin(async {}) })); let (send_data_tx, send_data_rx) = tokio::sync::mpsc::channel::(128); let (on_open_tx, on_open_rx) = tokio::sync::mpsc::channel::<()>(1); - let (receive_data_tx, receive_data_rx) = tokio::sync::mpsc::channel::(128); + let (receive_data_tx, receive_data_rx) = tokio::sync::mpsc::channel::(128); + + let send_data_tx = Arc::new(Mutex::new(send_data_tx)); + let send_data_rx = Arc::new(Mutex::new(send_data_rx)); + let receive_data_rx = Arc::new(Mutex::new(receive_data_rx)); + let receive_data_tx = Arc::new(Mutex::new(receive_data_tx)); + + tokio::spawn({ + let peer_connection = peer_connection.clone(); + let receive_data_tx = receive_data_tx.clone(); + async move { + if let Some(()) = done_rx.recv().await { + peer_connection.lock().await.close().await.unwrap(); + receive_data_tx.lock().await.send(Receive::Close).await.unwrap(); + receive_data_tx.lock().await.closed().await; + } + } + }); Ok( Self { peer_connection, - send_data: Arc::new(Mutex::new(send_data_tx)), - send_data_rx: Arc::new(Mutex::new(send_data_rx)), - receive_data: Arc::new(Mutex::new(receive_data_rx)), - receive_data_tx: Arc::new(Mutex::new(receive_data_tx)), - done_rx, + send_data_tx, + send_data_rx, + receive_data_rx, + receive_data_tx, + done_tx, on_open: on_open_rx, on_open_tx: Arc::new(Mutex::new(on_open_tx)) }