commit 3851430dec41c204d6219a84c47ca6885622a98e
Author: kolapapa
Date:   Sun Dec 20 14:02:52 2020 +0100

    Add Socket::(bind_)device

    Co-authored-by: Thomas de Zeeuw

diff --git a/src/sys/unix.rs b/src/sys/unix.rs
index 72097ae..1a0f24e 100644
--- a/src/sys/unix.rs
+++ b/src/sys/unix.rs
@@ -7,6 +7,8 @@
 // except according to those terms.
 
 use std::cmp::min;
+#[cfg(all(feature = "all", target_os = "linux"))]
+use std::ffi::{CStr, CString};
 #[cfg(not(target_os = "redox"))]
 use std::io::{IoSlice, IoSliceMut};
 use std::mem::{self, size_of, MaybeUninit};
@@ -19,6 +21,8 @@
 use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream};
 #[cfg(feature = "all")]
 use std::path::Path;
+#[cfg(all(feature = "all", target_os = "linux"))]
+use std::slice;
 use std::time::Duration;
 use std::{io, ptr};
 
@@ -867,6 +871,76 @@ pub fn set_mark(&self, mark: u32) -> io::Result<()> {
         unsafe { setsockopt::<c_int>(self.inner, libc::SOL_SOCKET, libc::SO_MARK, mark as c_int) }
     }
 
+    /// Gets the value for the `SO_BINDTODEVICE` option on this socket.
+    ///
+    /// Returns the name of the interface this socket is bound to, or `None`
+    /// if the socket is not bound to a device.
+    ///
+    /// This function is only available on Linux.
+    #[cfg(all(feature = "all", target_os = "linux"))]
+    pub fn device(&self) -> io::Result<Option<CString>> {
+        // TODO: replace with `MaybeUninit::uninit_array` once stable.
+        // Safety: an array of `MaybeUninit` does not require initialisation.
+        let mut buf: [MaybeUninit<u8>; libc::IFNAMSIZ] =
+            unsafe { MaybeUninit::<[MaybeUninit<u8>; libc::IFNAMSIZ]>::uninit().assume_init() };
+        let mut len = buf.len() as libc::socklen_t;
+        unsafe {
+            syscall!(getsockopt(
+                self.inner,
+                libc::SOL_SOCKET,
+                libc::SO_BINDTODEVICE,
+                buf.as_mut_ptr().cast(),
+                &mut len,
+            ))?;
+        }
+        if len == 0 {
+            // A length of zero means the socket is not bound to a device.
+            Ok(None)
+        } else {
+            // Allocate a buffer for `CString` with the length including the
+            // null terminator.
+            let len = len as usize;
+            let mut name = Vec::with_capacity(len);
+
+            // TODO: use `MaybeUninit::slice_assume_init_ref` once stable.
+            // Safety: `len` bytes are written by the OS, this includes a null
+            // terminator. However we don't copy the null terminator because
+            // `CString::from_vec_unchecked` adds its own null terminator.
+            let buf = unsafe { slice::from_raw_parts(buf.as_ptr().cast(), len - 1) };
+            name.extend_from_slice(buf);
+
+            // Safety: the OS initialised the string for us, which shouldn't
+            // include any null bytes.
+            Ok(Some(unsafe { CString::from_vec_unchecked(name) }))
+        }
+    }
+
+    /// Sets the value for the `SO_BINDTODEVICE` option on this socket.
+    ///
+    /// If a socket is bound to an interface, only packets received from that
+    /// particular interface are processed by the socket. Note that this only
+    /// works for some socket types, particularly `AF_INET` sockets.
+    ///
+    /// If `interface` is `None` or an empty string it removes the binding.
+    ///
+    /// This function is only available on Linux.
+    #[cfg(all(feature = "all", target_os = "linux"))]
+    pub fn bind_device(&self, interface: Option<&CStr>) -> io::Result<()> {
+        let (value, len) = if let Some(interface) = interface {
+            (interface.as_ptr(), interface.to_bytes_with_nul().len())
+        } else {
+            (ptr::null(), 0)
+        };
+        syscall!(setsockopt(
+            self.inner,
+            libc::SOL_SOCKET,
+            libc::SO_BINDTODEVICE,
+            value.cast(),
+            len as libc::socklen_t,
+        ))
+        .map(|_| ())
+    }
+
     /// Get the value of the `SO_REUSEPORT` option on this socket.
     ///
     /// For more information about this option, see [`set_reuse_port`].
diff --git a/tests/socket.rs b/tests/socket.rs
index 11a3d9c..75890af 100644
--- a/tests/socket.rs
+++ b/tests/socket.rs
@@ -1,3 +1,5 @@
+#[cfg(all(feature = "all", target_os = "linux"))]
+use std::ffi::CStr;
 #[cfg(any(windows, target_vendor = "apple"))]
 use std::io;
 #[cfg(unix)]
@@ -271,3 +273,21 @@ fn keepalive() {
     ))]
     assert_eq!(socket.keepalive_retries().unwrap(), 10);
 }
+
+#[cfg(all(feature = "all", target_os = "linux"))]
+#[test]
+fn device() {
+    // The loopback interface is named "lo" on Linux; "lo0" is the BSD/macOS
+    // name and binding to it fails when no such interface exists.
+    const INTERFACE: &str = "lo\0";
+    let interface = CStr::from_bytes_with_nul(INTERFACE.as_bytes()).unwrap();
+    let socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap();
+
+    assert_eq!(socket.device().unwrap(), None);
+
+    socket.bind_device(Some(interface)).unwrap();
+    assert_eq!(socket.device().unwrap().as_deref(), Some(interface));
+
+    socket.bind_device(None).unwrap();
+    assert_eq!(socket.device().unwrap(), None);
+}