bones_framework/networking/
socket.rs

1// TODO
2#![allow(missing_docs)]
3
4use bones_matchmaker_proto::PLAY_ALPN;
5use bytes::Bytes;
6use iroh::NodeAddr;
7use tracing::{info, warn};
8
9use crate::networking::get_network_endpoint;
10
11use super::{GameMessage, NetworkSocket, SocketTarget, RUNTIME};
12
13/// The [`NetworkSocket`] implementation.
14#[derive(Debug, Clone)]
15pub struct Socket {
16    pub connections: Vec<(u32, iroh::endpoint::Connection)>,
17    pub ggrs_receiver: async_channel::Receiver<(u32, GameMessage)>,
18    pub reliable_receiver: async_channel::Receiver<(u32, Vec<u8>)>,
19    pub player_idx: u32,
20    pub player_count: u32,
21    /// ID for current match, messages received that do not match ID are dropped.
22    pub match_id: u8,
23}
24
25impl Socket {
26    pub fn new(player_idx: u32, connections: Vec<(u32, iroh::endpoint::Connection)>) -> Self {
27        let (ggrs_sender, ggrs_receiver) = async_channel::unbounded();
28        let (reliable_sender, reliable_receiver) = async_channel::unbounded();
29
30        // Spawn tasks to receive network messages from each peer
31        for (i, conn) in &connections {
32            let ggrs_sender = ggrs_sender.clone();
33            let i = *i;
34
35            // Unreliable message receiver
36            let conn_ = conn.clone();
37            RUNTIME.spawn(async move {
38                let conn = conn_;
39
40                #[cfg(feature = "debug-network-slowdown")]
41                use turborand::prelude::*;
42                #[cfg(feature = "debug-network-slowdown")]
43                let rng = AtomicRng::new();
44
45                loop {
46                    tokio::select! {
47                        closed = conn.closed() => {
48                            warn!("Connection error: {closed}");
49                            break;
50                        }
51                        datagram_result = conn.read_datagram() => match datagram_result {
52                            Ok(data) => {
53                                let message: GameMessage = postcard::from_bytes(&data)
54                                .expect("Could not deserialize net message");
55
56                                // Debugging code to introduce artificial latency
57                                #[cfg(feature = "debug-network-slowdown")]
58                                {
59                                    use async_timer::Oneshot;
60                                    async_timer::oneshot::Timer::new(
61                                        std::time::Duration::from_millis(
62                                            (rng.f32_normalized() * 100.0) as u64 + 1,
63                                        ),
64                                    )
65                                    .await;
66                                }
67                                if ggrs_sender.send((i, message)).await.is_err() {
68                                    break;
69                                }
70                            }
71                            Err(e) => {
72                                warn!("Connection error: {e}");
73                            }
74                        }
75                    }
76                }
77            });
78
79            // Reliable message receiver
80            let reliable_sender = reliable_sender.clone();
81            let conn = conn.clone();
82            RUNTIME.spawn(async move {
83                #[cfg(feature = "debug-network-slowdown")]
84                use turborand::prelude::*;
85                #[cfg(feature = "debug-network-slowdown")]
86                let rng = AtomicRng::new();
87
88                loop {
89                    tokio::select! {
90                        closed = conn.closed() => {
91                            warn!("Connection error: {closed}");
92                            break;
93                        }
94                        result = conn.accept_uni() => match result {
95                            Ok(mut stream) => {
96                                let data = stream.read_to_end(4096).await.expect("Network read error");
97
98                                // Debugging code to introduce artificial latency
99                                #[cfg(feature = "debug-network-slowdown")]
100                                {
101                                    use async_timer::Oneshot;
102                                    async_timer::oneshot::Timer::new(
103                                        std::time::Duration::from_millis(
104                                            (rng.f32_normalized() * 100.0) as u64 + 1,
105                                        ),
106                                    )
107                                    .await;
108                                }
109                                if reliable_sender.send((i, data)).await.is_err() {
110                                    break;
111                                }
112                            }
113                            Err(e) => {
114                                warn!("Connection error: {e}");
115                            }
116                        },
117                    }
118                }
119            });
120        }
121
122        Self {
123            player_idx,
124            player_count: (connections.len() + 1).try_into().unwrap(),
125            connections,
126            ggrs_receiver,
127            reliable_receiver,
128            match_id: 0,
129        }
130    }
131
132    fn get_connection(&self, idx: u32) -> &iroh::endpoint::Connection {
133        debug_assert!(idx < self.player_count);
134        // TODO: if this is too slow, optimize storage
135        self.connections
136            .iter()
137            .find(|(i, _)| *i == idx)
138            .map(|(_, c)| c)
139            .unwrap()
140    }
141}
142
143impl NetworkSocket for Socket {
144    fn send_reliable(&self, target: SocketTarget, message: &[u8]) {
145        let message = Bytes::copy_from_slice(message);
146
147        match target {
148            SocketTarget::Player(i) => {
149                let conn = self.get_connection(i).clone();
150
151                RUNTIME.spawn(async move {
152                    let result = async move {
153                        let mut stream = conn.open_uni().await?;
154                        stream.write_chunk(message).await?;
155                        stream.finish()?;
156                        stream.stopped().await?;
157                        anyhow::Ok(())
158                    };
159                    if let Err(err) = result.await {
160                        warn!("send reliable to {i} failed: {err:?}");
161                    }
162                });
163            }
164            SocketTarget::All => {
165                for (_, conn) in &self.connections {
166                    let message = message.clone();
167                    let conn = conn.clone();
168                    RUNTIME.spawn(async move {
169                        let result = async move {
170                            let mut stream = conn.open_uni().await?;
171                            stream.write_chunk(message).await?;
172                            stream.finish()?;
173                            stream.stopped().await?;
174                            anyhow::Ok(())
175                        };
176                        if let Err(err) = result.await {
177                            warn!("send reliable all failed: {err:?}");
178                        }
179                    });
180                }
181            }
182        }
183    }
184
185    fn recv_reliable(&self) -> Vec<(u32, Vec<u8>)> {
186        let mut messages = Vec::new();
187        while let Ok(message) = self.reliable_receiver.try_recv() {
188            messages.push(message);
189        }
190        messages
191    }
192
193    fn ggrs_socket(&self) -> Self {
194        self.clone()
195    }
196
197    fn close(&self) {
198        for (_, conn) in &self.connections {
199            conn.close(0u8.into(), &[]);
200        }
201    }
202
203    fn player_idx(&self) -> u32 {
204        self.player_idx
205    }
206
207    fn player_count(&self) -> u32 {
208        self.player_count
209    }
210
211    fn increment_match_id(&mut self) {
212        self.match_id = self.match_id.wrapping_add(1);
213    }
214}
215
216pub(super) async fn establish_peer_connections(
217    player_idx: u32,
218    player_count: u32,
219    peer_addrs: Vec<(u32, NodeAddr)>,
220    conn: Option<iroh::endpoint::Connection>,
221) -> anyhow::Result<Vec<(u32, iroh::endpoint::Connection)>> {
222    let mut peer_connections = Vec::new();
223    let had_og_conn = conn.is_some();
224    if let Some(conn) = conn {
225        // Set the connection to the matchmaker for player 0
226        peer_connections.push((0, conn));
227    }
228
229    let ep = get_network_endpoint().await;
230
231    // For every peer with a player index that is higher than ours, wait for
232    // them to connect to us.
233    let mut in_connections = Vec::new();
234    let range = (player_idx + 1)..player_count;
235    info!(players=?range, "Waiting for {} peer connections", range.len());
236    for i in range {
237        // Wait for connection
238        let conn = ep
239            .accept()
240            .await
241            .ok_or_else(|| anyhow::anyhow!("no connection for {}", i))?;
242        let mut connecting = conn.accept()?;
243        let alpn = connecting.alpn().await?;
244        anyhow::ensure!(
245            alpn == PLAY_ALPN,
246            "invalid ALPN: {:?}",
247            std::str::from_utf8(&alpn).unwrap_or("<bytes>")
248        );
249
250        let conn = connecting.await?;
251
252        // Receive the player index
253        let idx = {
254            let mut buf = [0; 4];
255            let mut channel = conn.accept_uni().await?;
256            channel.read_exact(&mut buf).await?;
257
258            u32::from_le_bytes(buf)
259        };
260
261        in_connections.push((idx, conn));
262    }
263
264    // For every peer with a player index lower than ours, connect to them.
265    let start_range = if had_og_conn { 1 } else { 0 };
266    let range = start_range..player_idx;
267    info!(players=?range, "Connecting to {} peers", range.len());
268
269    let mut out_connections = Vec::new();
270    for i in range {
271        let (_, addr) = peer_addrs.iter().find(|(idx, _)| *idx == i).unwrap();
272        let conn = ep.connect(addr.clone(), PLAY_ALPN).await?;
273
274        // Send player index
275        let mut channel = conn.open_uni().await?;
276        channel.write(&player_idx.to_le_bytes()).await?;
277        channel.finish()?;
278        channel.stopped().await?;
279
280        out_connections.push((i, conn));
281    }
282
283    peer_connections.extend(out_connections);
284    peer_connections.extend(in_connections);
285
286    Ok(peer_connections)
287}
288
289impl ggrs::NonBlockingSocket<usize> for Socket {
290    fn send_to(&mut self, msg: &ggrs::Message, addr: &usize) {
291        let msg = GameMessage {
292            // Consider a way we can send message by reference and avoid clone?
293            message: msg.clone(),
294            match_id: self.match_id,
295        };
296        let conn = self.get_connection((*addr).try_into().unwrap());
297
298        let msg_bytes = postcard::to_allocvec(&msg).unwrap();
299        conn.send_datagram(Bytes::copy_from_slice(&msg_bytes[..]))
300            .ok();
301    }
302
303    fn receive_all_messages(&mut self) -> Vec<(usize, ggrs::Message)> {
304        let mut messages = Vec::new();
305        while let Ok(message) = self.ggrs_receiver.try_recv() {
306            if message.1.match_id == self.match_id {
307                messages.push((message.0 as usize, message.1.message));
308            }
309        }
310        messages
311    }
312}