feat: Feature match udp sockets

fix: fixed compile proto-ipv4/v6 edge cases in the ping module
This commit is contained in:
skkeye 2025-02-10 00:51:59 -05:00 committed by Ulf Lilleengen
parent 7d2ffa76e5
commit 7b35265465

View File

@ -1,8 +1,8 @@
//! ICMP sockets. //! ICMP sockets.
use core::future::poll_fn; use core::future::{poll_fn, Future};
use core::mem; use core::mem;
use core::task::Poll; use core::task::{Context, Poll};
use smoltcp::iface::{Interface, SocketHandle}; use smoltcp::iface::{Interface, SocketHandle};
pub use smoltcp::phy::ChecksumCapabilities; pub use smoltcp::phy::ChecksumCapabilities;
@ -36,6 +36,8 @@ pub enum SendError {
NoRoute, NoRoute,
/// Socket not bound to an outgoing port. /// Socket not bound to an outgoing port.
SocketNotBound, SocketNotBound,
/// There is not enough transmit buffer capacity to ever send this packet.
PacketTooLarge,
} }
/// Error returned by [`IcmpSocket::recv_from`]. /// Error returned by [`IcmpSocket::recv_from`].
@ -109,25 +111,61 @@ impl<'a> IcmpSocket<'a> {
}) })
} }
/// Dequeue a packet received from a remote endpoint, copy the payload into the given slice, /// Wait until the socket becomes readable.
/// and return the amount of octets copied as well as the `IpAddress`
/// ///
/// **Note**: when the size of the provided buffer is smaller than the size of the payload, /// A socket is readable when a packet has been received, or when there are queued packets in
/// the packet is dropped and a `RecvError::Truncated` error is returned. /// the buffer.
pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, IpAddress), RecvError> { pub fn wait_recv_ready(&self) -> impl Future<Output = ()> + '_ {
poll_fn(move |cx| { poll_fn(move |cx| self.poll_recv_ready(cx))
self.with_mut(|s, _| match s.recv_slice(buf) { }
Ok(x) => Poll::Ready(Ok(x)),
// No data ready /// Wait until a datagram can be read.
Err(icmp::RecvError::Exhausted) => { ///
//s.register_recv_waker(cx.waker()); /// When no datagram is readable, this method will return `Poll::Pending` and
cx.waker().wake_by_ref(); /// register the current task to be notified when a datagram is received.
///
/// When a datagram is received, this method will return `Poll::Ready`.
pub fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
self.with_mut(|s, _| {
if s.can_recv() {
Poll::Ready(())
} else {
// socket buffer is empty wait until at least one byte has arrived
s.register_recv_waker(cx.waker());
Poll::Pending Poll::Pending
} }
})
}
/// Receive a datagram.
///
/// This method will wait until a datagram is received.
///
/// Returns the number of bytes received and the remote endpoint.
pub fn recv_from<'s>(
&'s self,
buf: &'s mut [u8],
) -> impl Future<Output = Result<(usize, IpAddress), RecvError>> + 's {
poll_fn(|cx| self.poll_recv_from(buf, cx))
}
/// Receive a datagram.
///
/// When no datagram is available, this method will return `Poll::Pending` and
/// register the current task to be notified when a datagram is received.
///
/// When a datagram is received, this method will return `Poll::Ready` with the
/// number of bytes received and the remote endpoint.
pub fn poll_recv_from(&self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<Result<(usize, IpAddress), RecvError>> {
self.with_mut(|s, _| match s.recv_slice(buf) {
Ok((n, meta)) => Poll::Ready(Ok((n, meta))),
// No data ready
Err(icmp::RecvError::Truncated) => Poll::Ready(Err(RecvError::Truncated)), Err(icmp::RecvError::Truncated) => Poll::Ready(Err(RecvError::Truncated)),
Err(icmp::RecvError::Exhausted) => {
s.register_recv_waker(cx.waker());
Poll::Pending
}
}) })
})
.await
} }
/// Dequeue a packet received from a remote endpoint and calls the provided function with the /// Dequeue a packet received from a remote endpoint and calls the provided function with the
@ -136,7 +174,7 @@ impl<'a> IcmpSocket<'a> {
/// ///
/// **Note**: when the size of the provided buffer is smaller than the size of the payload, /// **Note**: when the size of the provided buffer is smaller than the size of the payload,
/// the packet is dropped and a `RecvError::Truncated` error is returned. /// the packet is dropped and a `RecvError::Truncated` error is returned.
pub async fn recv_with<F, R>(&self, f: F) -> Result<R, RecvError> pub async fn recv_from_with<F, R>(&self, f: F) -> Result<R, RecvError>
where where
F: FnOnce((&[u8], IpAddress)) -> R, F: FnOnce((&[u8], IpAddress)) -> R,
{ {
@ -154,48 +192,130 @@ impl<'a> IcmpSocket<'a> {
.await .await
} }
/// Enqueue a packet to be sent to a given remote address, and fill it from a slice. /// Wait until the socket becomes writable.
///
/// A socket becomes writable when there is space in the buffer, from initial memory or after
/// dispatching datagrams on a full buffer.
pub fn wait_send_ready(&self) -> impl Future<Output = ()> + '_ {
poll_fn(|cx| self.poll_send_ready(cx))
}
/// Wait until a datagram can be sent.
///
/// When no datagram can be sent (i.e. the buffer is full), this method will return
/// `Poll::Pending` and register the current task to be notified when
/// space is freed in the buffer after a datagram has been dispatched.
///
/// When a datagram can be sent, this method will return `Poll::Ready`.
pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
self.with_mut(|s, _| {
if s.can_send() {
Poll::Ready(())
} else {
// socket buffer is full wait until a datagram has been dispatched
s.register_send_waker(cx.waker());
Poll::Pending
}
})
}
/// Send a datagram to the specified remote endpoint.
///
/// This method will wait until the datagram has been sent.
///
/// If the socket's send buffer is too small to fit `buf`, this method will return `Err(SendError::PacketTooLarge)`
///
/// When the remote endpoint is not reachable, this method will return `Err(SendError::NoRoute)`
pub async fn send_to<T>(&self, buf: &[u8], remote_endpoint: T) -> Result<(), SendError> pub async fn send_to<T>(&self, buf: &[u8], remote_endpoint: T) -> Result<(), SendError>
where where
T: Into<IpAddress>, T: Into<IpAddress>,
{ {
let remote_endpoint = remote_endpoint.into(); let remote_endpoint: IpAddress = remote_endpoint.into();
poll_fn(move |cx| { poll_fn(move |cx| self.poll_send_to(buf, remote_endpoint, cx)).await
self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint) { }
/// Send a datagram to the specified remote endpoint.
///
/// When the datagram has been sent, this method will return `Poll::Ready(Ok())`.
///
/// When the socket's send buffer is full, this method will return `Poll::Pending`
/// and register the current task to be notified when the buffer has space available.
///
/// If the socket's send buffer is too small to fit `buf`, this method will return `Poll::Ready(Err(SendError::PacketTooLarge))`
///
/// When the remote endpoint is not reachable, this method will return `Poll::Ready(Err(Error::NoRoute))`.
pub fn poll_send_to<T>(&self, buf: &[u8], remote_endpoint: T, cx: &mut Context<'_>) -> Poll<Result<(), SendError>>
where
T: Into<IpAddress>,
{
// Don't need to wake waker in `with_mut` if the buffer will never fit the icmp tx_buffer.
let send_capacity_too_small = self.with(|s, _| s.payload_send_capacity() < buf.len());
if send_capacity_too_small {
return Poll::Ready(Err(SendError::PacketTooLarge));
}
self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint.into()) {
// Entire datagram has been sent // Entire datagram has been sent
Ok(()) => Poll::Ready(Ok(())), Ok(()) => Poll::Ready(Ok(())),
Err(icmp::SendError::BufferFull) => { Err(icmp::SendError::BufferFull) => {
s.register_send_waker(cx.waker()); s.register_send_waker(cx.waker());
Poll::Pending Poll::Pending
} }
Err(icmp::SendError::Unaddressable) => {
// If no sender/outgoing port is specified, there is not really "no route"
if s.is_open() {
Poll::Ready(Err(SendError::NoRoute))
} else {
Poll::Ready(Err(SendError::SocketNotBound))
}
}
})
}
/// Enqueue a packet to be sent to a given remote address with a zero-copy function.
///
/// This method will wait until the buffer can fit the requested size before
/// calling the function to fill its contents.
pub async fn send_to_with<T, F, R>(&mut self, size: usize, remote_endpoint: T, f: F) -> Result<R, SendError>
where
T: Into<IpAddress>,
F: FnOnce(&mut [u8]) -> R,
{
// Don't need to wake waker in `with_mut` if the buffer will never fit the icmp tx_buffer.
let send_capacity_too_small = self.with(|s, _| s.payload_send_capacity() < size);
if send_capacity_too_small {
return Err(SendError::PacketTooLarge);
}
let mut f = Some(f);
let remote_endpoint = remote_endpoint.into();
poll_fn(move |cx| {
self.with_mut(|s, _| match s.send(size, remote_endpoint) {
Ok(buf) => Poll::Ready(Ok({ unwrap!(f.take())(buf) })),
Err(icmp::SendError::BufferFull) => {
s.register_send_waker(cx.waker());
Poll::Pending
}
Err(icmp::SendError::Unaddressable) => Poll::Ready(Err(SendError::NoRoute)), Err(icmp::SendError::Unaddressable) => Poll::Ready(Err(SendError::NoRoute)),
}) })
}) })
.await .await
} }
/// Enqueue a packet to be sent to a given remote address with a zero-copy function. /// Flush the socket.
/// ///
/// This method will wait until the buffer can fit the requested size before /// This method will wait until the socket is flushed.
/// calling the function to fill its contents. pub fn flush(&mut self) -> impl Future<Output = ()> + '_ {
pub async fn send_to_with<T, F, R>(&self, size: usize, remote_endpoint: T, f: F) -> Result<R, SendError> poll_fn(|cx| {
where self.with_mut(|s, _| {
T: Into<IpAddress>, if s.send_queue() == 0 {
F: FnOnce(&mut [u8]) -> R, Poll::Ready(())
{ } else {
let mut f = Some(f);
let remote_endpoint = remote_endpoint.into();
poll_fn(move |cx| {
self.with_mut(|s, _| match s.send(size, remote_endpoint) {
Ok(buf) => Poll::Ready(Ok(unwrap!(f.take())(buf))),
Err(icmp::SendError::BufferFull) => {
s.register_send_waker(cx.waker()); s.register_send_waker(cx.waker());
Poll::Pending Poll::Pending
} }
Err(icmp::SendError::Unaddressable) => Poll::Ready(Err(SendError::NoRoute)),
}) })
}) })
.await
} }
/// Check whether the socket is open. /// Check whether the socket is open.
@ -280,9 +400,15 @@ pub mod ping {
//! }; //! };
//! ``` //! ```
use core::net::{IpAddr, Ipv6Addr}; use core::net::IpAddr;
#[cfg(feature = "proto-ipv6")]
use core::net::Ipv6Addr;
use embassy_time::{Duration, Instant, Timer, WithTimeout}; use embassy_time::{Duration, Instant, Timer, WithTimeout};
#[cfg(feature = "proto-ipv6")]
use smoltcp::wire::IpAddress;
#[cfg(feature = "proto-ipv6")]
use smoltcp::wire::Ipv6Address;
use super::*; use super::*;
@ -392,11 +518,11 @@ pub mod ping {
// make a single ping // make a single ping
// - shorts out errors // - shorts out errors
// - select the ip version // - select the ip version
let ping_duration = match params.target().unwrap() { let ping_duration = match params.target.unwrap() {
#[cfg(feature = "proto-ipv4")] #[cfg(feature = "proto-ipv4")]
IpAddr::V4(_) => self.single_ping_v4(params, seq_no).await?, IpAddress::Ipv4(_) => self.single_ping_v4(params, seq_no).await?,
#[cfg(feature = "proto-ipv6")] #[cfg(feature = "proto-ipv6")]
IpAddr::V6(_) => self.single_ping_v6(params, seq_no).await?, IpAddress::Ipv6(_) => self.single_ping_v6(params, seq_no).await?,
}; };
// safely add up the durations of each ping // safely add up the durations of each ping
@ -478,7 +604,7 @@ pub mod ping {
// Helper function to recieve and return the correct echo reply when it finds it // Helper function to recieve and return the correct echo reply when it finds it
async fn recv_pong(socket: &IcmpSocket<'_>, seq_no: u16) -> Result<(), PingError> { async fn recv_pong(socket: &IcmpSocket<'_>, seq_no: u16) -> Result<(), PingError> {
while match socket.recv_with(|(buf, _)| filter_pong(buf, seq_no)).await { while match socket.recv_from_with(|(buf, _)| filter_pong(buf, seq_no)).await {
Ok(b) => !b, Ok(b) => !b,
Err(e) => return Err(PingError::SocketRecvError(e)), Err(e) => return Err(PingError::SocketRecvError(e)),
} {} } {}
@ -548,7 +674,7 @@ pub mod ping {
// Helper function to recieve and return the correct echo reply when it finds it // Helper function to recieve and return the correct echo reply when it finds it
async fn recv_pong(socket: &IcmpSocket<'_>, seq_no: u16) -> Result<(), PingError> { async fn recv_pong(socket: &IcmpSocket<'_>, seq_no: u16) -> Result<(), PingError> {
while match socket.recv_with(|(buf, _)| filter_pong(buf, seq_no)).await { while match socket.recv_from_with(|(buf, _)| filter_pong(buf, seq_no)).await {
Ok(b) => !b, Ok(b) => !b,
Err(e) => return Err(PingError::SocketRecvError(e)), Err(e) => return Err(PingError::SocketRecvError(e)),
} {} } {}
@ -581,9 +707,9 @@ pub mod ping {
/// * `timeout` - The timeout duration before returning a [`PingError::DestinationHostUnreachable`] error. /// * `timeout` - The timeout duration before returning a [`PingError::DestinationHostUnreachable`] error.
/// * `rate_limit` - The minimum time per echo request. /// * `rate_limit` - The minimum time per echo request.
pub struct PingParams<'a> { pub struct PingParams<'a> {
target: Option<IpAddr>, target: Option<IpAddress>,
#[cfg(feature = "proto-ipv6")] #[cfg(feature = "proto-ipv6")]
source: Option<Ipv6Addr>, source: Option<Ipv6Address>,
payload: &'a [u8], payload: &'a [u8],
hop_limit: Option<u8>, hop_limit: Option<u8>,
count: u16, count: u16,
@ -610,7 +736,7 @@ pub mod ping {
/// Creates a new instance of [`PingParams`] with the specified target IP address. /// Creates a new instance of [`PingParams`] with the specified target IP address.
pub fn new<T: Into<IpAddr>>(target: T) -> Self { pub fn new<T: Into<IpAddr>>(target: T) -> Self {
Self { Self {
target: Some(target.into()), target: Some(PingParams::ip_addr_to_smoltcp(target)),
#[cfg(feature = "proto-ipv6")] #[cfg(feature = "proto-ipv6")]
source: None, source: None,
payload: b"embassy-net", payload: b"embassy-net",
@ -621,21 +747,34 @@ pub mod ping {
} }
} }
fn ip_addr_to_smoltcp<T: Into<IpAddr>>(ip_addr: T) -> IpAddress {
match ip_addr.into() {
#[cfg(feature = "proto-ipv4")]
IpAddr::V4(v4) => IpAddress::Ipv4(v4),
#[cfg(not(feature = "proto-ipv4"))]
IpAddr::V4(_) => unreachable!(),
#[cfg(feature = "proto-ipv6")]
IpAddr::V6(v6) => IpAddress::Ipv6(v6),
#[cfg(not(feature = "proto-ipv6"))]
IpAddr::V6(_) => unreachable!(),
}
}
/// Sets the target IP address for the ping. /// Sets the target IP address for the ping.
pub fn set_target<T: Into<IpAddr>>(&mut self, target: T) -> &mut Self { pub fn set_target<T: Into<IpAddr>>(&mut self, target: T) -> &mut Self {
self.target = Some(target.into()); self.target = Some(PingParams::ip_addr_to_smoltcp(target));
self self
} }
/// Retrieves the target IP address for the ping. /// Retrieves the target IP address for the ping.
pub fn target(&self) -> Option<IpAddr> { pub fn target(&self) -> Option<IpAddr> {
self.target self.target.map(|t| t.into())
} }
/// Sets the source IP address for the ping (IPv6 only). /// Sets the source IP address for the ping (IPv6 only).
#[cfg(feature = "proto-ipv6")] #[cfg(feature = "proto-ipv6")]
pub fn set_source(&mut self, source: Ipv6Addr) -> &mut Self { pub fn set_source<T: Into<Ipv6Address>>(&mut self, source: T) -> &mut Self {
self.source = Some(source); self.source = Some(source.into());
self self
} }