From cabc6f205231b611e5c91b1684a1557066779216 Mon Sep 17 00:00:00 2001 From: Ulf Lilleengen Date: Mon, 22 Apr 2024 20:08:58 +0200 Subject: [PATCH] reduce size of channel manager --- host/src/channel_manager.rs | 425 ++++++++++++++---------------------- host/src/l2cap.rs | 11 +- 2 files changed, 174 insertions(+), 262 deletions(-) diff --git a/host/src/channel_manager.rs b/host/src/channel_manager.rs index eb1c339..0241cea 100644 --- a/host/src/channel_manager.rs +++ b/host/src/channel_manager.rs @@ -24,7 +24,7 @@ const BASE_ID: u16 = 0x40; struct State { next_req_id: u8, - channels: [ChannelState; CHANNELS], + channels: [ChannelStorage; CHANNELS], accept_waker: WakerRegistration, create_waker: WakerRegistration, credit_wakers: [WakerRegistration; CHANNELS], @@ -55,14 +55,13 @@ impl< { const TX_CHANNEL: Channel, L2CAP_TXQ> = Channel::new(); const RX_CHANNEL: Channel>, L2CAP_RXQ> = Channel::new(); - const DISCONNECTED: ChannelState = ChannelState::Disconnected; const CREDIT_WAKER: WakerRegistration = WakerRegistration::new(); pub fn new(pool: &'d dyn DynamicPacketPool<'d>) -> Self { Self { pool, state: Mutex::new(RefCell::new(State { next_req_id: 0, - channels: [Self::DISCONNECTED; CHANNELS], + channels: [ChannelStorage::DISCONNECTED; CHANNELS], accept_waker: WakerRegistration::new(), create_waker: WakerRegistration::new(), credit_wakers: [Self::CREDIT_WAKER; CHANNELS], @@ -84,55 +83,57 @@ impl< }) } - pub(crate) fn disconnect(&self, cid: u16) -> Result<(), Error> { - let idx = self.state.lock(|state| { + pub(crate) fn disconnect(&self, cid: u16) -> Result { + let handle = self.state.lock(|state| { let mut state = state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage { - ChannelState::Disconnecting(state) if cid == state.cid => { - *storage = ChannelState::Disconnected; - return Ok(idx); + match storage.state { + ChannelState::Disconnecting if cid == storage.cid => { + storage.state = ChannelState::Disconnected; + let _ = self.inbound[idx].try_send(None); + return Ok(storage.conn); } - ChannelState::PeerConnecting(state) if cid == state.cid => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn, cid }); - return Ok(idx); + ChannelState::PeerConnecting(_) if cid == storage.cid => { + storage.state = ChannelState::Disconnecting; + let _ = self.inbound[idx].try_send(None); + return Ok(storage.conn); } - ChannelState::Connecting(state) if cid == state.cid => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn, cid }); - return Ok(idx); + ChannelState::Connecting(_) if cid == storage.cid => { + storage.state = ChannelState::Disconnecting; + let _ = self.inbound[idx].try_send(None); + return Ok(storage.conn); } - ChannelState::Connected(state) if cid == state.cid => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn, cid }); - return Ok(idx); + ChannelState::Connected if cid == storage.cid => { + let _ = self.inbound[idx].try_send(None); + return Ok(storage.conn); } _ => {} } } Err(Error::NotFound) })?; - let _ = self.inbound[idx].try_send(None); - Ok(()) + Ok(ConnHandle::new(handle)) } fn disconnected(&self, cid: u16) -> Result<(), Error> { self.state.lock(|state| { let mut state = state.borrow_mut(); for storage in state.channels.iter_mut() { - match storage { - ChannelState::Disconnecting(state) if cid == state.cid => { - *storage = ChannelState::Disconnected; + match storage.state { + ChannelState::Disconnecting if cid == storage.cid => { + storage.state = ChannelState::Disconnected; break; } - ChannelState::PeerConnecting(state) if cid == state.cid => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn, cid }); + ChannelState::PeerConnecting(_) if cid == storage.cid => { + storage.state = ChannelState::Disconnecting; break; } - ChannelState::Connecting(state) if cid == state.cid => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn, cid }); + ChannelState::Connecting(_) if cid == storage.cid => { + storage.state = ChannelState::Disconnecting; break; } - ChannelState::Connected(state) if cid == state.cid => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn, cid }); + ChannelState::Connected if cid == storage.cid => { + storage.state = ChannelState::Disconnecting; break; } _ => {} @@ -146,17 +147,17 @@ impl< self.state.lock(|state| { let mut state = state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage { - ChannelState::PeerConnecting(state) if conn == state.conn => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn, cid: state.cid }); + match storage.state { + ChannelState::PeerConnecting(_) if conn.raw() == storage.conn => { + storage.state = ChannelState::Disconnecting; let _ = self.inbound[idx].try_send(None); } - ChannelState::Connecting(state) if conn == state.conn => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn, cid: state.cid }); + ChannelState::Connecting(_) if conn.raw() == storage.conn => { + storage.state = ChannelState::Disconnecting; let _ = self.inbound[idx].try_send(None); } - ChannelState::Connected(state) if conn == state.conn => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn, cid: state.cid }); + ChannelState::Connected if conn.raw() == storage.conn => { + storage.state = ChannelState::Disconnecting; let _ = self.inbound[idx].try_send(None); } _ => {} @@ -171,32 +172,14 @@ impl< Ok(()) } - fn peer_connect PeerConnectingState>(&self, f: F) -> Result<(), Error> { - self.state.lock(|state| { - let mut state = state.borrow_mut(); - for (idx, storage) in state.channels.iter_mut().enumerate() { - if let ChannelState::Disconnected = storage { - let cid: u16 = BASE_ID + idx as u16; - let mut req = f(idx, cid); - req.cid = cid; - *storage = ChannelState::PeerConnecting(req); - state.accept_waker.wake(); - return Ok(()); - } - } - Err(Error::NoChannelAvailable) - }) - } - - fn connect ConnectingState>(&self, f: F) -> Result<(), Error> { + fn alloc(&self, f: F) -> Result<(), Error> { self.state.lock(|state| { let mut state = state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { - if let ChannelState::Disconnected = storage { + if let ChannelState::Disconnected = storage.state { let cid: u16 = BASE_ID + idx as u16; - let mut req = f(idx, cid); - req.cid = cid; - *storage = ChannelState::Connecting(req); + storage.cid = cid; + f(storage); return Ok(()); } } @@ -204,19 +187,14 @@ impl< }) } - fn connected ConnectedState>( - &self, - request_id: u8, - f: F, - ) -> Result<(), Error> { + fn connected(&self, request_id: u8, f: F) -> Result<(), Error> { self.state.lock(|state| { let mut state = state.borrow_mut(); - for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage { - ChannelState::Connecting(req) if request_id == req.request_id => { - let res = f(idx, req); - // info!("Connection created, properties: {:?}", res); - *storage = ChannelState::Connected(res); + for storage in state.channels.iter_mut() { + match storage.state { + ChannelState::Connecting(req_id) if request_id == req_id => { + f(storage); + storage.state = ChannelState::Connected; state.create_waker.wake(); return Ok(()); } @@ -231,9 +209,9 @@ impl< self.state.lock(|state| { let mut state = state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage { - ChannelState::Connected(s) if s.peer_cid == cid => { - s.peer_credits += credits; + match storage.state { + ChannelState::Connected if storage.peer_cid == cid => { + storage.peer_credits += credits; state.credit_wakers[idx].wake(); return Ok(()); } @@ -244,59 +222,38 @@ impl< }) } - fn poll_accept ConnectedState>( - &self, - conn: ConnHandle, - psm: &[u16], - cx: &mut Context<'_>, - f: F, - ) -> Poll<(usize, ConnectedState)> { - self.state.lock(|state| { - let mut state = state.borrow_mut(); - for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage { - ChannelState::PeerConnecting(req) if req.conn == conn && psm.contains(&req.psm) => { - let state = f(idx, req); - let cid = state.cid; - *storage = ChannelState::Connected(state.clone()); - return Poll::Ready((idx, state)); - } - _ => {} - } - } - state.accept_waker.register(cx.waker()); - Poll::Pending - }) - } - pub(crate) async fn accept( &self, conn: ConnHandle, psm: &[u16], - mut mtu: u16, + mtu: u16, credit_flow: CreditFlowPolicy, initial_credits: Option, controller: &HciController<'_, T>, ) -> Result> { - let mut req_id = 0; - let (idx, state) = poll_fn(|cx| { - self.poll_accept(conn, psm, cx, |idx, req| { - req_id = req.request_id; - let mps = req.mps.min(self.pool.mtu() as u16 - 4); - mtu = req.mtu.min(mtu); - let credits = initial_credits.unwrap_or(self.pool.min_available(AllocId::dynamic(idx)) as u16); - // info!("Accept L2CAP, initial credits: {}", credits); - ConnectedState { - conn: req.conn, - cid: req.cid, - psm: req.psm, - flow_control: CreditFlowControl::new(credit_flow, credits), - peer_credits: req.offered_credits, - peer_cid: req.peer_cid, - pool_id: AllocId::dynamic(idx), - mps, - mtu, + let (req_id, mps, mtu, cid, credits) = poll_fn(|cx| { + self.state.lock(|state| { + let mut state = state.borrow_mut(); + for chan in state.channels.iter_mut() { + match chan.state { + ChannelState::PeerConnecting(req_id) if chan.conn == conn.raw() && psm.contains(&chan.psm) => { + chan.mps = chan.mps.min(self.pool.mtu() as u16 - 4); + chan.mtu = chan.mtu.min(mtu); + chan.mtu = mtu; + chan.flow_control = CreditFlowControl::new( + credit_flow, + initial_credits + .unwrap_or(self.pool.min_available(AllocId::from_channel(chan.cid)) as u16), + ); + chan.state = ChannelState::Connected; + + return Poll::Ready((req_id, chan.mps, chan.mtu, chan.cid, chan.flow_control.available())); + } + _ => {} + } } + state.accept_waker.register(cx.waker()); + Poll::Pending }) }) .await; @@ -307,8 +264,8 @@ impl< conn, req_id, &LeCreditConnRes { - mps: state.mps, - dcid: state.cid, + mps, + dcid: cid, mtu, credits: 0, result: LeCreditConnResultCode::Success, @@ -319,20 +276,11 @@ impl< // Send initial credits let next_req_id = self.next_request_id(); - controller - .signal( - conn, - next_req_id, - &LeCreditFlowInd { - cid: state.cid, - credits: state.flow_control.available(), - }, - &mut tx[..], - ) + .signal(conn, next_req_id, &LeCreditFlowInd { cid, credits }, &mut tx[..]) .await?; - Ok(state.cid) + Ok(cid) } pub(crate) async fn create( @@ -347,27 +295,24 @@ impl< let req_id = self.next_request_id(); let mut credits = 0; let mut cid: u16 = 0; - self.connect(|i, c| { - cid = c; - credits = initial_credits.unwrap_or(self.pool.min_available(AllocId::dynamic(i)) as u16); - ConnectingState { - conn, - cid, - request_id: req_id, - psm, - initial_credits: credits, - flow_control_policy: credit_flow, - mps: self.pool.mtu() as u16 - 4, - mtu, - } + let mps = self.pool.mtu() as u16 - 4; + + self.alloc(|storage| { + cid = storage.cid; + credits = initial_credits.unwrap_or(self.pool.min_available(AllocId::from_channel(storage.cid)) as u16); + storage.mps = mps; + storage.mtu = mtu; + storage.flow_control = CreditFlowControl::new(credit_flow, credits); + + storage.state = ChannelState::Connecting(req_id); })?; + //info!("Created connect state with idx cid {}", cid); // let mut tx = [0; 18]; - let command = LeCreditConnReq { psm, - mps: self.pool.mtu() as u16 - 4, + mps, scid: cid, mtu, credits: 0, @@ -377,16 +322,16 @@ impl< controller.signal(conn, req_id, &command, &mut tx[..]).await?; // info!("Sent signal packet to remote, awaiting response"); - let (idx, cid) = poll_fn(|cx| { + poll_fn(|cx| { self.state.lock(|state| { let mut state = state.borrow_mut(); - for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage { - ChannelState::Disconnecting(req) if req.conn == conn && req.cid == cid => { + for storage in state.channels.iter_mut() { + match storage.state { + ChannelState::Disconnecting if storage.conn == conn.raw() && storage.cid == cid => { return Poll::Ready(Err(Error::Disconnected)); } - ChannelState::Connected(req) if req.conn == conn && req.cid == cid => { - return Poll::Ready(Ok((idx, req.cid))); + ChannelState::Connected if storage.conn == conn.raw() && storage.cid == cid => { + return Poll::Ready(Ok(())); } _ => {} } @@ -422,13 +367,12 @@ impl< self.state.lock(|state| { let mut state = state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage { - ChannelState::Connected(state) if header.channel == state.cid => { - if state.flow_control.available() == 0 { - // info!("No credits available on channel {}", state.cid); + match storage.state { + ChannelState::Connected if header.channel == storage.cid => { + if storage.flow_control.available() == 0 { return Err(Error::OutOfMemory); } - state.flow_control.received(1); + storage.flow_control.received(1); } _ => {} } @@ -448,16 +392,18 @@ impl< match header.code { L2capSignalCode::LeCreditConnReq => { let req = LeCreditConnReq::from_hci_bytes_complete(data)?; - self.peer_connect(|i, c| PeerConnectingState { - conn, - cid: c, - psm: req.psm, - request_id: header.identifier, - peer_cid: req.scid, - offered_credits: req.credits, - mps: req.mps, - mtu: req.mtu, + self.alloc(|storage| { + storage.conn = conn.raw(); + storage.psm = req.psm; + storage.peer_cid = req.scid; + storage.peer_credits = req.credits; + storage.mps = req.mps; + storage.mtu = req.mtu; + storage.state = ChannelState::PeerConnecting(header.identifier); })?; + self.state.lock(|state| { + state.borrow_mut().accept_waker.wake(); + }); Ok(()) } L2capSignalCode::LeCreditConnRes => { @@ -466,16 +412,12 @@ impl< match res.result { LeCreditConnResultCode::Success => { // Must be a response of a previous request which should already by allocated a channel for - self.connected(header.identifier, |idx, req| ConnectedState { - conn: req.conn, - cid: req.cid, - psm: req.psm, - flow_control: CreditFlowControl::new(req.flow_control_policy, req.initial_credits), - peer_credits: res.credits, - peer_cid: res.dcid, - pool_id: AllocId::dynamic(idx), - mps: req.mps.min(res.mps), - mtu: req.mtu.min(res.mtu), + self.connected(header.identifier, |state| { + assert_eq!(conn.raw(), state.conn); + state.peer_cid = res.dcid; + state.peer_credits = res.credits; + state.mps = state.mps.min(res.mps); + state.mtu = state.mtu.min(res.mtu); })?; Ok(()) } @@ -511,17 +453,13 @@ impl< } } - fn with_connected_channel R, R>( - &self, - cid: u16, - f: F, - ) -> Result { + fn connected_io_params(&self, cid: u16) -> Result<(usize, ConnHandle, u16, u16), Error> { self.state.lock(|state| { let mut state = state.borrow_mut(); for (idx, chan) in state.channels.iter_mut().enumerate() { - match chan { - ChannelState::Connected(state) if state.cid == cid => { - return Ok(f(idx, state)); + match chan.state { + ChannelState::Connected if chan.cid == cid => { + return Ok((idx, ConnHandle::new(chan.conn), chan.mps, chan.peer_cid)); } _ => {} } @@ -546,7 +484,7 @@ impl< buf: &mut [u8], hci: &HciController<'_, T>, ) -> Result> { - let idx = self.with_connected_channel(cid, |idx, _state| idx)?; + let (idx, _, _, _) = self.connected_io_params(cid)?; let mut n_received = 1; let packet = self.receive_pdu(cid, idx).await?; let len = packet.len; @@ -595,8 +533,7 @@ impl< hci: &HciController<'_, T>, ) -> Result<(), AdapterError> { let mut p_buf = [0u8; L2CAP_MTU]; - let (conn, mps, peer_cid) = - self.with_connected_channel(cid, |_, state| (state.conn, state.mps, state.peer_cid))?; + let (_, conn, mps, peer_cid) = self.connected_io_params(cid)?; // The number of packets we'll need to send for this payload let n_packets = 1 + ((buf.len() as u16).saturating_sub(mps - 2)).div_ceil(mps); // info!("Sending data of len {} into {} packets", buf.len(), n_packets); @@ -627,8 +564,7 @@ impl< hci: &HciController<'_, T>, ) -> Result<(), AdapterError> { let mut p_buf = [0u8; L2CAP_MTU]; - let (conn, mps, peer_cid) = - self.with_connected_channel(cid, |_, state| (state.conn, state.mps, state.peer_cid))?; + let (_, conn, mps, peer_cid) = self.connected_io_params(cid)?; // The number of packets we'll need to send for this payload let n_packets = 1 + ((buf.len() as u16).saturating_sub(mps - 2)).div_ceil(mps); // info!("Sending data of len {} into {} packets", buf.len(), n_packets); @@ -666,10 +602,10 @@ impl< ) -> Result<(), AdapterError> { let (conn, credits) = self.state.lock(|state| { let mut state = state.borrow_mut(); - for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage { - ChannelState::Connected(state) if cid == state.cid => { - return Ok((state.conn, state.flow_control.process())); + for storage in state.channels.iter_mut() { + match storage.state { + ChannelState::Connected if cid == storage.cid => { + return Ok((storage.conn, storage.flow_control.process())); } _ => {} } @@ -682,7 +618,8 @@ impl< let signal = LeCreditFlowInd { cid, credits }; // Reuse packet buffer for signalling data to save the extra TX buffer - hci.signal(conn, identifier, &signal, packet.as_mut()).await?; + hci.signal(ConnHandle::new(conn), identifier, &signal, packet.as_mut()) + .await?; } Ok(()) } @@ -691,9 +628,9 @@ impl< self.state.lock(|state| { let mut state = state.borrow_mut(); for storage in state.channels.iter_mut() { - match storage { - ChannelState::Disconnecting(state) if cid == state.cid => { - *storage = ChannelState::Disconnected; + match storage.state { + ChannelState::Disconnecting if cid == storage.cid => { + storage.state = ChannelState::Disconnected; return Ok(()); } _ => {} @@ -707,10 +644,10 @@ impl< self.state.lock(|state| { let mut state = state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage { - ChannelState::Connected(s) if cid == s.cid => { - if credits <= s.peer_credits { - s.peer_credits -= credits; + match storage.state { + ChannelState::Connected if cid == storage.cid => { + if credits <= storage.peer_credits { + storage.peer_credits -= credits; return Poll::Ready(Ok(())); } else { if let Some(cx) = cx { @@ -744,12 +681,40 @@ fn encode(data: &[u8], packet: &mut [u8], peer_cid: u16, header: Option) -> Ok(w.len()) } +pub struct ChannelStorage { + state: ChannelState, + conn: u16, + cid: u16, + psm: u16, + mps: u16, + mtu: u16, + flow_control: CreditFlowControl, + + peer_cid: u16, + peer_credits: u16, +} + +impl ChannelStorage { + const DISCONNECTED: ChannelStorage = ChannelStorage { + state: ChannelState::Disconnected, + conn: 0, + cid: 0, + mps: 0, + mtu: 0, + psm: 0, + + flow_control: CreditFlowControl::new(CreditFlowPolicy::Every(1), 0), + peer_cid: 0, + peer_credits: 0, + }; +} + pub enum ChannelState { Disconnected, - Connecting(ConnectingState), - PeerConnecting(PeerConnectingState), - Connected(ConnectedState), - Disconnecting(DisconnectingState), + Connecting(u8), + PeerConnecting(u8), + Connected, + Disconnecting, } /// Control how credits are issued by the receiving end. @@ -777,14 +742,13 @@ pub(crate) struct CreditFlowControl { } impl CreditFlowControl { - fn new(policy: CreditFlowPolicy, initial_credits: u16) -> Self { + const fn new(policy: CreditFlowPolicy, initial_credits: u16) -> Self { Self { policy, credits: initial_credits, received: 0, } } - fn available(&self) -> u16 { self.credits } @@ -819,54 +783,3 @@ impl CreditFlowControl { } } } - -#[derive(Debug)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct ConnectingState { - pub(crate) conn: ConnHandle, - pub(crate) cid: u16, - pub(crate) request_id: u8, - pub(crate) flow_control_policy: CreditFlowPolicy, - - pub(crate) psm: u16, - pub(crate) initial_credits: u16, - pub(crate) mps: u16, - pub(crate) mtu: u16, -} - -#[derive(Debug)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct PeerConnectingState { - pub(crate) conn: ConnHandle, - pub(crate) cid: u16, - pub(crate) request_id: u8, - - pub(crate) psm: u16, - pub(crate) peer_cid: u16, - pub(crate) offered_credits: u16, - pub(crate) mps: u16, - pub(crate) mtu: u16, -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct ConnectedState { - pub(crate) conn: ConnHandle, - pub(crate) cid: u16, - pub(crate) psm: u16, - pub(crate) mps: u16, - pub(crate) mtu: u16, - pub(crate) flow_control: CreditFlowControl, - - pub(crate) peer_cid: u16, - pub(crate) peer_credits: u16, - - pub(crate) pool_id: AllocId, -} - -#[derive(Debug)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct DisconnectingState { - pub(crate) conn: ConnHandle, - pub(crate) cid: u16, -} diff --git a/host/src/l2cap.rs b/host/src/l2cap.rs index 43cde87..2f131d3 100644 --- a/host/src/l2cap.rs +++ b/host/src/l2cap.rs @@ -1,6 +1,6 @@ use bt_hci::cmd::link_control::Disconnect; use bt_hci::controller::{Controller, ControllerCmdSync}; -use bt_hci::param::{ConnHandle, DisconnectReason}; +use bt_hci::param::DisconnectReason; use embassy_sync::blocking_mutex::raw::RawMutex; use crate::adapter::Adapter; @@ -13,7 +13,6 @@ pub(crate) mod sar; /// Handle representing an L2CAP channel. #[derive(Clone)] pub struct L2capChannel { - handle: ConnHandle, cid: u16, } @@ -114,7 +113,7 @@ impl L2capChannel { ) .await?; - Ok(Self { cid, handle }) + Ok(Self { cid }) } pub fn disconnect< @@ -130,9 +129,9 @@ impl L2capChannel { adapter: &Adapter<'_, M, T, CONNS, CHANNELS, L2CAP_MTU, L2CAP_TXQ, L2CAP_RXQ>, close_connection: bool, ) -> Result<(), AdapterError> { - adapter.channels.disconnect(self.cid)?; + let handle = adapter.channels.disconnect(self.cid)?; if close_connection { - adapter.try_command(Disconnect::new(self.handle, DisconnectReason::RemoteUserTerminatedConn))?; + adapter.try_command(Disconnect::new(handle, DisconnectReason::RemoteUserTerminatedConn))?; } Ok(()) } @@ -165,6 +164,6 @@ where { ) .await?; - Ok(Self { handle, cid }) + Ok(Self { cid }) } }