From e213f073cdeaa8d2f9a4b577fd923d2e495bdbd4 Mon Sep 17 00:00:00 2001 From: Ulf Lilleengen Date: Wed, 3 Apr 2024 18:48:37 +0200 Subject: [PATCH] various improvements * Use random static address by default * Use device address as source on nRF * Add support for filter_accept_list when scanning * Expose connection info on connect * Add non-blocking variant of send --- .../nrf-sdc/src/bin/ble_bas_peripheral.rs | 14 +- examples/nrf-sdc/src/bin/ble_l2cap_central.rs | 13 +- .../nrf-sdc/src/bin/ble_l2cap_peripheral.rs | 14 +- examples/serial-hci/src/main.rs | 4 +- host/src/adapter.rs | 93 ++++++++--- host/src/advertise.rs | 3 +- host/src/channel_manager.rs | 14 +- host/src/connection.rs | 36 +++-- host/src/connection_manager.rs | 30 +++- host/src/l2cap.rs | 144 ++++++++++++++---- host/src/lib.rs | 2 + host/src/scan.rs | 9 +- host/tests/l2cap.rs | 8 +- 13 files changed, 284 insertions(+), 100 deletions(-) diff --git a/examples/nrf-sdc/src/bin/ble_bas_peripheral.rs b/examples/nrf-sdc/src/bin/ble_bas_peripheral.rs index 9599323..c728257 100644 --- a/examples/nrf-sdc/src/bin/ble_bas_peripheral.rs +++ b/examples/nrf-sdc/src/bin/ble_bas_peripheral.rs @@ -2,6 +2,7 @@ #![no_main] #![feature(impl_trait_in_assoc_type)] +use bt_hci::cmd::le::LeSetRandomAddr; use bt_hci::cmd::SyncCmd; use bt_hci::param::BdAddr; use defmt::{error, info, unwrap}; @@ -12,7 +13,6 @@ use embassy_sync::blocking_mutex::raw::NoopRawMutex; use embassy_time::{Duration, Timer}; use nrf_sdc::{self as sdc, mpsl, mpsl::MultiprotocolServiceLayer}; use sdc::rng_pool::RngPool; -use sdc::vendor::ZephyrWriteBdAddr; use static_cell::StaticCell; use trouble_host::{ adapter::{Adapter, HostResources}, @@ -40,8 +40,8 @@ async fn mpsl_task(mpsl: &'static MultiprotocolServiceLayer<'static>) -> ! { fn bd_addr() -> BdAddr { unsafe { let ficr = &*pac::FICR::ptr(); - let high = u64::from((ficr.deviceid[1].read().bits() & 0x0000ffff) | 0x0000c000); - let addr = high << 32 | u64::from(ficr.deviceid[0].read().bits()); + let high = u64::from((ficr.deviceaddr[1].read().bits() & 0x0000ffff) | 0x0000c000); + let addr = high << 32 | u64::from(ficr.deviceaddr[0].read().bits()); BdAddr::new(unwrap!(addr.to_le_bytes()[..6].try_into())) } } @@ -108,8 +108,8 @@ async fn main(spawner: Spawner) { let mut sdc_mem = sdc::Mem::<3312>::new(); let sdc = unwrap!(build_sdc(sdc_p, &rng, mpsl, &mut sdc_mem)); - info!("Advertising as {:02x}", bd_addr()); - unwrap!(ZephyrWriteBdAddr::new(bd_addr()).exec(&sdc).await); + info!("Our address = {:02x}", bd_addr()); + unwrap!(LeSetRandomAddr::new(bd_addr()).exec(&sdc).await); Timer::after(Duration::from_millis(200)).await; static HOST_RESOURCES: StaticCell> = @@ -119,11 +119,11 @@ async fn main(spawner: Spawner) { let adapter: Adapter<'_, NoopRawMutex, _, CONNECTIONS_MAX, L2CAP_CHANNELS_MAX> = Adapter::new(sdc, host_resources); let config = AdvertiseConfig { params: None, - data: &[ + adv_data: &[ AdStructure::Flags(LE_GENERAL_DISCOVERABLE | BR_EDR_NOT_SUPPORTED), AdStructure::ServiceUuids16(&[Uuid::Uuid16([0x0f, 0x18])]), - AdStructure::CompleteLocalName(b"Trouble"), ], + scan_data: &[AdStructure::CompleteLocalName(b"Trouble")], }; let mut table: AttributeTable<'_, NoopRawMutex, 10> = AttributeTable::new(); diff --git a/examples/nrf-sdc/src/bin/ble_l2cap_central.rs b/examples/nrf-sdc/src/bin/ble_l2cap_central.rs index 03e93a4..53a4536 100644 --- a/examples/nrf-sdc/src/bin/ble_l2cap_central.rs +++ b/examples/nrf-sdc/src/bin/ble_l2cap_central.rs @@ -2,6 +2,7 @@ #![no_main] #![feature(impl_trait_in_assoc_type)] +use bt_hci::cmd::le::LeSetRandomAddr; use bt_hci::cmd::SyncCmd; use bt_hci::param::BdAddr; use defmt::{info, unwrap}; @@ -12,7 +13,6 @@ use embassy_sync::blocking_mutex::raw::NoopRawMutex; use embassy_time::{Duration, Timer}; use nrf_sdc::{self as sdc, mpsl, mpsl::MultiprotocolServiceLayer}; use sdc::rng_pool::RngPool; -use sdc::vendor::ZephyrWriteBdAddr; use static_cell::StaticCell; use trouble_host::{ adapter::{Adapter, HostResources}, @@ -41,8 +41,8 @@ async fn mpsl_task(mpsl: &'static MultiprotocolServiceLayer<'static>) -> ! { fn bd_addr() -> BdAddr { unsafe { let ficr = &*pac::FICR::ptr(); - let high = u64::from((ficr.deviceid[1].read().bits() & 0x0000ffff) | 0x0000c000); - let addr = high << 32 | u64::from(ficr.deviceid[0].read().bits()); + let high = u64::from((ficr.deviceaddr[1].read().bits() & 0x0000ffff) | 0x0000c000); + let addr = high << 32 | u64::from(ficr.deviceaddr[0].read().bits()); BdAddr::new(unwrap!(addr.to_le_bytes()[..6].try_into())) } } @@ -117,7 +117,7 @@ async fn main(spawner: Spawner) { let sdc = unwrap!(build_sdc(sdc_p, &rng, mpsl, &mut sdc_mem)); info!("Our address = {:02x}", bd_addr()); - unwrap!(ZephyrWriteBdAddr::new(bd_addr()).exec(&sdc).await); + unwrap!(LeSetRandomAddr::new(bd_addr()).exec(&sdc).await); Timer::after(Duration::from_millis(200)).await; static HOST_RESOURCES: StaticCell> = @@ -126,7 +126,10 @@ async fn main(spawner: Spawner) { let adapter: Adapter<'_, NoopRawMutex, _, CONNECTIONS_MAX, L2CAP_CHANNELS_MAX> = Adapter::new(sdc, host_resources); - let config = ScanConfig { params: None }; + let config = ScanConfig { + params: None, + filter_accept_list: &[], + }; // NOTE: Modify this to match the address of the peripheral you want to connect to let target: BdAddr = BdAddr::new([0xf5, 0x9f, 0x1a, 0x05, 0xe4, 0xee]); diff --git a/examples/nrf-sdc/src/bin/ble_l2cap_peripheral.rs b/examples/nrf-sdc/src/bin/ble_l2cap_peripheral.rs index 37ab81b..c0a1d09 100644 --- a/examples/nrf-sdc/src/bin/ble_l2cap_peripheral.rs +++ b/examples/nrf-sdc/src/bin/ble_l2cap_peripheral.rs @@ -2,6 +2,7 @@ #![no_main] #![feature(impl_trait_in_assoc_type)] +use bt_hci::cmd::le::LeSetRandomAddr; use bt_hci::cmd::SyncCmd; use bt_hci::param::BdAddr; use defmt::{info, unwrap}; @@ -12,7 +13,6 @@ use embassy_sync::blocking_mutex::raw::NoopRawMutex; use embassy_time::{Duration, Timer}; use nrf_sdc::{self as sdc, mpsl, mpsl::MultiprotocolServiceLayer}; use sdc::rng_pool::RngPool; -use sdc::vendor::ZephyrWriteBdAddr; use static_cell::StaticCell; use trouble_host::{ adapter::{Adapter, HostResources}, @@ -40,8 +40,8 @@ async fn mpsl_task(mpsl: &'static MultiprotocolServiceLayer<'static>) -> ! { fn bd_addr() -> BdAddr { unsafe { let ficr = &*pac::FICR::ptr(); - let high = u64::from((ficr.deviceid[1].read().bits() & 0x0000ffff) | 0x0000c000); - let addr = high << 32 | u64::from(ficr.deviceid[0].read().bits()); + let high = u64::from((ficr.deviceaddr[1].read().bits() & 0x0000ffff) | 0x0000c000); + let addr = high << 32 | u64::from(ficr.deviceaddr[0].read().bits()); BdAddr::new(unwrap!(addr.to_le_bytes()[..6].try_into())) } } @@ -116,7 +116,7 @@ async fn main(spawner: Spawner) { let sdc = unwrap!(build_sdc(sdc_p, &rng, mpsl, &mut sdc_mem)); info!("Our address = {:02x}", bd_addr()); - unwrap!(ZephyrWriteBdAddr::new(bd_addr()).exec(&sdc).await); + unwrap!(LeSetRandomAddr::new(bd_addr()).exec(&sdc).await); Timer::after(Duration::from_millis(200)).await; static HOST_RESOURCES: StaticCell> = @@ -127,10 +127,8 @@ async fn main(spawner: Spawner) { let config = AdvertiseConfig { params: None, - data: &[ - AdStructure::Flags(LE_GENERAL_DISCOVERABLE | BR_EDR_NOT_SUPPORTED), - AdStructure::CompleteLocalName(b"Trouble"), - ], + adv_data: &[AdStructure::Flags(LE_GENERAL_DISCOVERABLE | BR_EDR_NOT_SUPPORTED)], + scan_data: &[AdStructure::CompleteLocalName(b"Trouble")], }; let _ = join(adapter.run(), async { diff --git a/examples/serial-hci/src/main.rs b/examples/serial-hci/src/main.rs index 34efe79..0a70584 100644 --- a/examples/serial-hci/src/main.rs +++ b/examples/serial-hci/src/main.rs @@ -64,11 +64,11 @@ async fn main() { let adapter: Adapter<'_, NoopRawMutex, _, 2, 4, 1, 1> = Adapter::new(controller, host_resources); let config = AdvertiseConfig { params: None, - data: &[ + adv_data: &[ AdStructure::Flags(LE_GENERAL_DISCOVERABLE | BR_EDR_NOT_SUPPORTED), AdStructure::ServiceUuids16(&[Uuid::Uuid16([0x0f, 0x18])]), - AdStructure::CompleteLocalName(b"Trouble HCI"), ], + scan_data: &[AdStructure::CompleteLocalName(b"Trouble HCI")], }; let mut table: AttributeTable<'_, NoopRawMutex, 10> = AttributeTable::new(); diff --git a/host/src/adapter.rs b/host/src/adapter.rs index 191a715..5ad0427 100644 --- a/host/src/adapter.rs +++ b/host/src/adapter.rs @@ -14,8 +14,8 @@ use crate::types::l2cap::L2capLeSignal; use crate::{AdapterError, Error}; use bt_hci::cmd::controller_baseband::{Reset, SetEventMask}; use bt_hci::cmd::le::{ - LeCreateConn, LeCreateConnParams, LeReadBufferSize, LeSetAdvData, LeSetAdvEnable, LeSetAdvParams, LeSetScanEnable, - LeSetScanParams, + LeAddDeviceToFilterAcceptList, LeClearFilterAcceptList, LeCreateConn, LeCreateConnParams, LeReadBufferSize, + LeSetAdvData, LeSetAdvEnable, LeSetAdvParams, LeSetScanEnable, LeSetScanParams, LeSetScanResponseData, }; use bt_hci::cmd::link_control::{Disconnect, DisconnectParams}; use bt_hci::cmd::{AsyncCmd, SyncCmd}; @@ -26,6 +26,7 @@ use bt_hci::event::le::LeEvent; use bt_hci::event::Event; use bt_hci::param::{BdAddr, ConnHandle, DisconnectReason, EventMask}; use bt_hci::ControllerToHostPacket; +use core::task::Poll; use embassy_futures::select::{select, Either}; use embassy_sync::blocking_mutex::raw::RawMutex; use embassy_sync::channel::Channel; @@ -104,16 +105,33 @@ where /// Performs a BLE scan, return a report for discovering peripherals. /// /// Scan is stopped when a report is received. Call this method repeatedly to continue scanning. - pub async fn scan(&self, config: &ScanConfig) -> Result> + pub async fn scan(&self, config: &ScanConfig<'_>) -> Result> where - T: ControllerCmdSync + ControllerCmdSync, + T: ControllerCmdSync + + ControllerCmdSync + + ControllerCmdSync + + ControllerCmdSync, { + LeClearFilterAcceptList::new().exec(&self.controller).await?; + + if !config.filter_accept_list.is_empty() { + for entry in config.filter_accept_list { + LeAddDeviceToFilterAcceptList::new(entry.0, *entry.1) + .exec(&self.controller) + .await?; + } + } + let params = config.params.unwrap_or(LeSetScanParams::new( bt_hci::param::LeScanKind::Active, bt_hci::param::Duration::from_millis(1_000), bt_hci::param::Duration::from_millis(1_000), - bt_hci::param::AddrKind::PUBLIC, - bt_hci::param::ScanningFilterPolicy::BasicUnfiltered, + bt_hci::param::AddrKind::RANDOM, + if config.filter_accept_list.is_empty() { + bt_hci::param::ScanningFilterPolicy::BasicUnfiltered + } else { + bt_hci::param::ScanningFilterPolicy::BasicFiltered + }, )); params.exec(&self.controller).await?; @@ -130,14 +148,17 @@ where /// in which case a handle for the connection is returned. pub async fn advertise<'m>(&'m self, config: &AdvertiseConfig<'_>) -> Result, AdapterError> where - T: ControllerCmdSync + ControllerCmdSync + ControllerCmdSync, + T: ControllerCmdSync + + ControllerCmdSync + + ControllerCmdSync + + ControllerCmdSync, { let params = &config.params.unwrap_or(LeSetAdvParams::new( bt_hci::param::Duration::from_millis(400), bt_hci::param::Duration::from_millis(400), bt_hci::param::AdvKind::AdvInd, - bt_hci::param::AddrKind::PUBLIC, - bt_hci::param::AddrKind::PUBLIC, + bt_hci::param::AddrKind::RANDOM, + bt_hci::param::AddrKind::RANDOM, BdAddr::default(), bt_hci::param::AdvChannelMap::ALL, bt_hci::param::AdvFilterPolicy::default(), @@ -145,13 +166,28 @@ where params.exec(&self.controller).await?; - let mut data = [0; 31]; - let mut w = WriteCursor::new(&mut data[..]); - for item in config.data.iter() { - item.encode(&mut w)?; + if !config.adv_data.is_empty() { + let mut data = [0; 31]; + let mut w = WriteCursor::new(&mut data[..]); + for item in config.adv_data.iter() { + item.encode(&mut w)?; + } + let len = w.len(); + LeSetAdvData::new(len as u8, data).exec(&self.controller).await?; } - let len = w.len(); - LeSetAdvData::new(len as u8, data).exec(&self.controller).await?; + + if !config.scan_data.is_empty() { + let mut data = [0; 31]; + let mut w = WriteCursor::new(&mut data[..]); + for item in config.scan_data.iter() { + item.encode(&mut w)?; + } + let len = w.len(); + LeSetScanResponseData::new(len as u8, data) + .exec(&self.controller) + .await?; + } + LeSetAdvEnable::new(true).exec(&self.controller).await?; let conn = Connection::accept(self).await; LeSetAdvEnable::new(false).exec(&self.controller).await?; @@ -197,7 +233,9 @@ where } other if other >= L2CAP_CID_DYN_START => match self.channels.dispatch(packet).await { - Ok(_) => {} + Ok(_) => { + info!("L2CAP packet dispatched!"); + } Err(e) => { warn!("Error dispatching l2cap packet to channel: {:?}", e); } @@ -235,6 +273,7 @@ where Ok(ControllerToHostPacket::Event(event)) => match event { Event::Le(event) => match event { LeEvent::LeConnectionComplete(e) => { + warn!("CONNECTION COMPLET!"); if let Err(err) = self.connections.connect( e.handle, ConnectionInfo { @@ -359,11 +398,29 @@ where } pub struct HciController<'d, T: Controller> { - controller: &'d T, - permits: &'d LocalSemaphore, + pub(crate) controller: &'d T, + pub(crate) permits: &'d LocalSemaphore, } impl<'d, T: Controller> HciController<'d, T> { + pub(crate) fn try_send(&self, handle: ConnHandle, pdu: &[u8]) -> Result<(), AdapterError> { + let permit = self + .permits + .try_acquire(1) + .ok_or::>(Error::Busy.into())?; + let acl = AclPacket::new( + handle, + AclPacketBoundary::FirstNonFlushable, + AclBroadcastFlag::PointToPoint, + pdu, + ); + let fut = self.controller.write_acl_data(&acl); + match embassy_futures::poll_once(fut) { + Poll::Ready(result) => result.map_err(AdapterError::Controller), + Poll::Pending => Err(Error::Busy.into()), + } + } + pub(crate) async fn send(&self, handle: ConnHandle, pdu: &[u8]) -> Result<(), AdapterError> { self.permits.acquire(1).await.disarm(); let acl = AclPacket::new( diff --git a/host/src/advertise.rs b/host/src/advertise.rs index ac823c4..0ec157f 100644 --- a/host/src/advertise.rs +++ b/host/src/advertise.rs @@ -7,7 +7,8 @@ use bt_hci::cmd::le::LeSetAdvParams; pub struct AdvertiseConfig<'d> { pub params: Option, - pub data: &'d [AdStructure<'d>], + pub adv_data: &'d [AdStructure<'d>], + pub scan_data: &'d [AdStructure<'d>], } pub const AD_FLAG_LE_LIMITED_DISCOVERABLE: u8 = 0b00000001; diff --git a/host/src/channel_manager.rs b/host/src/channel_manager.rs index d5abb2b..119f0ce 100644 --- a/host/src/channel_manager.rs +++ b/host/src/channel_manager.rs @@ -40,7 +40,7 @@ pub struct ChannelManager<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TX } pub trait DynamicChannelManager<'d> { - fn poll_request_to_send(&self, cid: u16, credits: usize, cx: &mut Context<'_>) -> Poll>; + fn poll_request_to_send(&self, cid: u16, credits: usize, cx: Option<&mut Context<'_>>) -> Poll>; fn confirm_received(&self, cid: u16, credits: usize) -> Result<(ConnHandle, L2capLeSignal), Error>; fn confirm_disconnected(&self, cid: u16) -> Result<(), Error>; } @@ -227,6 +227,8 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP }), ); + info!("Responding to create: {:?}", response); + controller.signal(conn, response).await?; Ok((state, self.inbound[idx].receiver().into())) } @@ -302,13 +304,13 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP self.inbound[chan].send(Some(Pdu::new(p, len))).await; Ok(()) } else { - warn!("No memory for channel {}", packet.channel); + warn!("No memory for channel {} (id {})", packet.channel, chan); Err(Error::OutOfMemory) } } pub async fn control(&self, conn: ConnHandle, signal: L2capLeSignal) -> Result<(), Error> { - // info!("Inbound signal: {:?}", signal); + info!("Inbound signal: {:?}", signal); match signal.data { L2capLeSignalData::LeCreditConnReq(req) => { self.connect(ConnectingState { @@ -414,7 +416,7 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP }) } - fn poll_request_to_send(&self, cid: u16, credits: usize, cx: &mut Context<'_>) -> Poll> { + fn poll_request_to_send(&self, cid: u16, credits: usize, cx: Option<&mut Context<'_>>) -> Poll> { self.state.lock(|state| { let mut state = state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { @@ -424,7 +426,9 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP s.peer_credits -= credits as u16; return Poll::Ready(Ok(())); } else { - state.credit_wakers[idx].register(cx.waker()); + if let Some(cx) = cx { + state.credit_wakers[idx].register(cx.waker()); + } return Poll::Pending; } } diff --git a/host/src/connection.rs b/host/src/connection.rs index a6ac069..39294d4 100644 --- a/host/src/connection.rs +++ b/host/src/connection.rs @@ -1,20 +1,22 @@ use bt_hci::{ cmd::{le::LeCreateConnParams, link_control::DisconnectParams}, - param::{AddrKind, BdAddr, ConnHandle, DisconnectReason, Duration}, + param::{AddrKind, BdAddr, ConnHandle, DisconnectReason, Duration, LeConnRole}, }; use embassy_sync::{blocking_mutex::raw::RawMutex, channel::DynamicSender}; use crate::adapter::{Adapter, ControlCommand}; +pub use crate::connection_manager::ConnectionInfo; + #[derive(Clone)] pub struct Connection<'d> { - handle: ConnHandle, + info: ConnectionInfo, control: DynamicSender<'d, ControlCommand>, } impl<'d> Connection<'d> { pub fn handle(&self) -> ConnHandle { - self.handle + self.info.handle } pub async fn accept< @@ -27,9 +29,9 @@ impl<'d> Connection<'d> { >( adapter: &'d Adapter<'_, M, T, CONNS, CHANNELS, L2CAP_TXQ, L2CAP_RXQ>, ) -> Self { - let handle = adapter.connections.accept(None).await; + let info = adapter.connections.accept(None).await; Connection { - handle, + info, control: adapter.control.sender().into(), } } @@ -37,12 +39,20 @@ impl<'d> Connection<'d> { pub async fn disconnect(&mut self) { self.control .send(ControlCommand::Disconnect(DisconnectParams { - handle: self.handle, + handle: self.info.handle, reason: DisconnectReason::RemoteUserTerminatedConn, })) .await; } + pub fn role(&self) -> LeConnRole { + self.info.role + } + + pub fn peer_address(&self) -> BdAddr { + self.info.peer_address + } + pub async fn connect< M: RawMutex, T, @@ -58,21 +68,21 @@ impl<'d> Connection<'d> { let params = LeCreateConnParams { le_scan_interval: Duration::from_micros(1707500), le_scan_window: Duration::from_micros(312500), - use_filter_accept_list: false, - peer_addr_kind: AddrKind::PUBLIC, + use_filter_accept_list: true, + peer_addr_kind: AddrKind::RANDOM, peer_addr, own_addr_kind: AddrKind::PUBLIC, - conn_interval_min: Duration::from_millis(25), - conn_interval_max: Duration::from_millis(50), + conn_interval_min: Duration::from_millis(80), + conn_interval_max: Duration::from_millis(80), max_latency: 0, - supervision_timeout: Duration::from_millis(250), + supervision_timeout: Duration::from_millis(8000), min_ce_length: Duration::from_millis(0), max_ce_length: Duration::from_millis(0), }; adapter.control.send(ControlCommand::Connect(params)).await; - let handle = adapter.connections.accept(Some(params.peer_addr)).await; + let info = adapter.connections.accept(Some(params.peer_addr)).await; Connection { - handle, + info, control: adapter.control.sender().into(), } } diff --git a/host/src/connection_manager.rs b/host/src/connection_manager.rs index 05bbc04..aaf8193 100644 --- a/host/src/connection_manager.rs +++ b/host/src/connection_manager.rs @@ -64,21 +64,21 @@ impl ConnectionManager { }) } - pub fn poll_accept(&self, peer: Option, cx: &mut Context<'_>) -> Poll { + pub fn poll_accept(&self, peer: Option, cx: &mut Context<'_>) -> Poll { self.state.lock(|state| { let mut state = state.borrow_mut(); for storage in state.connections.iter_mut() { if let ConnectionState::Connecting(handle, info) = storage { if let Some(peer) = peer { if info.peer_address == peer { - let handle = *handle; - *storage = ConnectionState::Connected(handle, *info); - return Poll::Ready(handle); + let i = *info; + *storage = ConnectionState::Connected(*handle, *info); + return Poll::Ready(i); } } else { - let handle = *handle; - *storage = ConnectionState::Connected(handle, *info); - return Poll::Ready(handle); + let i = *info; + *storage = ConnectionState::Connected(*handle, *info); + return Poll::Ready(i); } } } @@ -87,9 +87,23 @@ impl ConnectionManager { }) } - pub async fn accept(&self, peer: Option) -> ConnHandle { + pub async fn accept(&self, peer: Option) -> ConnectionInfo { poll_fn(move |cx| self.poll_accept(peer, cx)).await } + + pub fn info(&self, handle: ConnHandle) -> Result { + self.state.lock(|state| { + let mut state = state.borrow_mut(); + for storage in state.connections.iter_mut() { + if let ConnectionState::Connected(h, info) = storage { + if *h == handle { + return Ok(*info); + } + } + } + Err(Error::NotFound) + }) + } } pub enum ConnectionState { diff --git a/host/src/l2cap.rs b/host/src/l2cap.rs index bc9976b..47ff41c 100644 --- a/host/src/l2cap.rs +++ b/host/src/l2cap.rs @@ -1,6 +1,6 @@ use core::future::poll_fn; -use crate::adapter::{Adapter, HciController}; +use crate::adapter::{Adapter, ControlCommand, HciController}; use crate::channel_manager::DynamicChannelManager; use crate::codec; use crate::connection::Connection; @@ -8,11 +8,15 @@ use crate::cursor::{ReadCursor, WriteCursor}; use crate::packet_pool::{AllocId, DynamicPacketPool}; use crate::pdu::Pdu; use crate::{AdapterError, Error}; +use bt_hci::cmd::link_control::DisconnectParams; use bt_hci::controller::Controller; use bt_hci::data::AclPacket; use bt_hci::param::ConnHandle; +use bt_hci::param::DisconnectReason; +use core::task::Poll; use embassy_sync::blocking_mutex::raw::RawMutex; use embassy_sync::channel::DynamicReceiver; +use embassy_sync::channel::DynamicSender; pub(crate) const L2CAP_CID_ATT: u16 = 0x0004; pub(crate) const L2CAP_CID_LE_U_SIGNAL: u16 = 0x0005; @@ -29,6 +33,7 @@ impl<'d> L2capPacket<'d> { pub fn decode(packet: AclPacket<'_>) -> Result<(bt_hci::param::ConnHandle, L2capPacket), codec::Error> { let handle = packet.handle(); let data = packet.data(); + let mut r = ReadCursor::new(data); let length: u16 = r.read()?; let channel: u16 = r.read()?; @@ -55,10 +60,54 @@ pub struct L2capChannel<'a, 'd, T: Controller, const MTU: usize> { pool: &'d dyn DynamicPacketPool<'d>, manager: &'a dyn DynamicChannelManager<'d>, rx: DynamicReceiver<'a, Option>>, + control: DynamicSender<'a, ControlCommand>, tx: HciController<'a, T>, } +impl<'a, 'd, T: Controller, const MTU: usize> Clone for L2capChannel<'a, 'd, T, MTU> { + fn clone(&self) -> Self { + Self { + conn: self.conn, + pool_id: self.pool_id, + cid: self.cid, + peer_cid: self.peer_cid, + mps: self.mps, + pool: self.pool, + manager: self.manager, + rx: self.rx, + tx: HciController { + controller: self.tx.controller, + permits: self.tx.permits, + }, + control: self.control, + } + } +} + impl<'a, 'd, T: Controller, const MTU: usize> L2capChannel<'a, 'd, T, MTU> { + fn encode(&self, data: &[u8], header: Option) -> Result, Error> { + if let Some(mut packet) = self.pool.alloc(self.pool_id) { + let mut w = WriteCursor::new(packet.as_mut()); + if header.is_some() { + w.write(2 + data.len() as u16)?; + } else { + w.write(data.len() as u16)?; + } + w.write(self.peer_cid)?; + + if let Some(len) = header { + w.write(len)?; + } + + w.append(data)?; + let len = w.len(); + let pdu = Pdu::new(packet, len); + Ok(pdu) + } else { + Err(Error::OutOfMemory) + } + } + pub async fn send(&mut self, buf: &[u8]) -> Result<(), AdapterError> { // The number of packets we'll need to send for this payload let n_packets = 1 + (buf.len().saturating_sub(self.mps - 2)).div_ceil(self.mps); @@ -70,42 +119,53 @@ impl<'a, 'd, T: Controller, const MTU: usize> L2capChannel<'a, 'd, T, MTU> { return Err(Error::OutOfMemory.into()); } - poll_fn(|cx| self.manager.poll_request_to_send(self.cid, n_packets, cx)).await?; + poll_fn(|cx| self.manager.poll_request_to_send(self.cid, n_packets, Some(cx))).await?; // Segment using mps - let (first, remaining) = buf.split_at(self.mps - 2); - if let Some(mut packet) = self.pool.alloc(self.pool_id) { - let len = { - let mut w = WriteCursor::new(packet.as_mut()); - w.write(2 + first.len() as u16)?; - w.write(self.peer_cid)?; - let len = buf.len() as u16; - w.write(len)?; - w.append(first)?; - w.len() - }; - let pdu = Pdu::new(packet, len); + let (first, remaining) = buf.split_at(buf.len().min(self.mps - 2)); + + let pdu = self.encode(first, Some(buf.len() as u16))?; + self.tx.send(self.conn, pdu.as_ref()).await?; + + let chunks = remaining.chunks(self.mps); + let num_chunks = chunks.len(); + + for (i, chunk) in chunks.enumerate() { + let pdu = self.encode(chunk, None)?; self.tx.send(self.conn, pdu.as_ref()).await?; - } else { + } + + Ok(()) + } + + pub fn try_send(&mut self, buf: &[u8]) -> Result<(), AdapterError> { + // The number of packets we'll need to send for this payload + let n_packets = 1 + (buf.len().saturating_sub(self.mps - 2)).div_ceil(self.mps); + + // TODO: We could potentially make this more graceful by sending as much as we can, and wait + // for pool to get the available packets back, which would require some poll/async behavior + // support for the pool. + if self.pool.available(self.pool_id) < n_packets { return Err(Error::OutOfMemory.into()); } + match self.manager.poll_request_to_send(self.cid, n_packets, None) { + Poll::Ready(res) => res?, + Poll::Pending => return Err(Error::Busy.into()), + } + + // Segment using mps + let (first, remaining) = buf.split_at(buf.len().min(self.mps - 2)); + + let pdu = self.encode(first, Some(buf.len() as u16))?; + self.tx.try_send(self.conn, pdu.as_ref())?; + let chunks = remaining.chunks(self.mps); let num_chunks = chunks.len(); + for (i, chunk) in chunks.enumerate() { - if let Some(mut packet) = self.pool.alloc(self.pool_id) { - let len = { - let mut w = WriteCursor::new(packet.as_mut()); - w.write(chunk.len() as u16)?; - w.write(self.peer_cid)?; - w.append(chunk)?; - w.len() - }; - let pdu = Pdu::new(packet, len); - self.tx.send(self.conn, pdu.as_ref()).await?; - } else { - return Err(Error::OutOfMemory.into()); - } + let pdu = self.encode(chunk, None)?; + self.tx.try_send(self.conn, pdu.as_ref())?; } Ok(()) @@ -124,15 +184,25 @@ impl<'a, 'd, T: Controller, const MTU: usize> L2capChannel<'a, 'd, T, MTU> { pub async fn receive(&mut self, buf: &mut [u8]) -> Result> { let mut n_received = 1; let packet = self.receive_pdu().await?; + let len = packet.len; + let mut r = ReadCursor::new(packet.as_ref()); let remaining: u16 = r.read()?; - let data = r.remaining(); + info!("Total expected: {}", remaining); + let data = r.remaining(); let to_copy = data.len().min(buf.len()); buf[..to_copy].copy_from_slice(&data[..to_copy]); let mut pos = to_copy; + info!("Received {} bytes so far", pos); let mut remaining = remaining as usize - data.len(); + info!( + "Total size of PDU is {}, read buffer size is {} remaining; {}", + len, + buf.len(), + remaining + ); // We have some k-frames to reassemble while remaining > 0 { let packet = self.receive_pdu().await?; @@ -148,6 +218,7 @@ impl<'a, 'd, T: Controller, const MTU: usize> L2capChannel<'a, 'd, T, MTU> { let (handle, response) = self.manager.confirm_received(self.cid, n_received)?; self.tx.signal(handle, response).await?; + info!("Total reserved {} bytes", pos); Ok(pos) } @@ -178,10 +249,24 @@ impl<'a, 'd, T: Controller, const MTU: usize> L2capChannel<'a, 'd, T, MTU> { pool_id: state.pool_id, manager: &adapter.channels, tx: adapter.hci(), + control: adapter.control.sender().into(), rx, }) } + pub fn disconnect(&self, close_connection: bool) -> Result<(), AdapterError> { + self.manager.confirm_disconnected(self.cid)?; + if close_connection { + self.control + .try_send(ControlCommand::Disconnect(DisconnectParams { + handle: self.conn, + reason: DisconnectReason::RemoteUserTerminatedConn, + })) + .map_err(|_| Error::Busy)?; + } + Ok(()) + } + pub async fn create< M: RawMutex, const CONNS: usize, @@ -210,6 +295,7 @@ where { pool: adapter.pool, manager: &adapter.channels, tx: adapter.hci(), + control: adapter.control.sender().into(), rx, }) } diff --git a/host/src/lib.rs b/host/src/lib.rs index ddb75bd..656163b 100644 --- a/host/src/lib.rs +++ b/host/src/lib.rs @@ -50,6 +50,8 @@ pub enum Error { OutOfMemory, NotSupported, ChannelClosed, + Busy, + Disconnected, Other, } diff --git a/host/src/scan.rs b/host/src/scan.rs index e88f746..fe5d333 100644 --- a/host/src/scan.rs +++ b/host/src/scan.rs @@ -1,10 +1,15 @@ use core::iter::FusedIterator; -use bt_hci::{cmd::le::LeSetScanParams, param::LeAdvReport, FromHciBytes, FromHciBytesError}; +use bt_hci::{ + cmd::le::LeSetScanParams, + param::{AddrKind, BdAddr, LeAdvReport}, + FromHciBytes, FromHciBytesError, +}; use heapless::Vec; -pub struct ScanConfig { +pub struct ScanConfig<'d> { pub params: Option, + pub filter_accept_list: &'d [(AddrKind, &'d BdAddr)], } pub struct ScanReport { diff --git a/host/tests/l2cap.rs b/host/tests/l2cap.rs index e668724..ea6382d 100644 --- a/host/tests/l2cap.rs +++ b/host/tests/l2cap.rs @@ -77,10 +77,11 @@ async fn l2cap_connection_oriented_channels() { let config = AdvertiseConfig { params: None, - data: &[ + adv_data: &[ AdStructure::Flags(LE_GENERAL_DISCOVERABLE | BR_EDR_NOT_SUPPORTED), AdStructure::CompleteLocalName(b"trouble-l2cap-int"), ], + scan_data: &[], }; select! { @@ -132,7 +133,10 @@ async fn l2cap_connection_oriented_channels() { let adapter: Adapter<'_, NoopRawMutex, _, CONNECTIONS_MAX, L2CAP_CHANNELS_MAX> = Adapter::new(controller_central, &mut host_resources); - let config = ScanConfig { params: None }; + let config = ScanConfig { + params: None, + filter_accept_list: &[], + }; select! { r = adapter.run() => {