Skip to content
Snippets Groups Projects

fix: don't ignore all but the first resolved socket addresses

Merged Georgy Moshkin requested to merge gmoshkin/reduce-tcp-connect-logic-dubiousness into master
Files
3
@@ -34,6 +34,7 @@ use futures::{AsyncRead, AsyncWrite};
use crate::ffi::tarantool as ffi;
use crate::fiber;
use crate::fiber::r#async::context::ContextExt;
use crate::time::Instant;
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
@@ -101,66 +102,80 @@ impl TcpStream {
// SAFETY: it is just simple sys call
let (v4_addrs, v6_addrs) = unsafe { resolve_addr(url, port, timeout.as_secs_f64())? };
let timeout = deadline.duration_since(fiber::clock());
// Take the first address, prefer ipv4
let (addr, addr_len) = if let Some(v4) = v4_addrs.first() {
(
v4 as *const _ as *const libc::sockaddr,
mem::size_of::<libc::sockaddr_in>() as libc::socklen_t,
)
} else if let Some(v6) = v6_addrs.first() {
(
v6 as *const _ as *const libc::sockaddr,
mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t,
)
} else {
return Err(Error::ResolveAddress(url.into()));
};
// SAFETY: it is just sequential sys calls, so it's safe
let fd = unsafe { nonblocking_socket() }.map_err(|e| Error::Connect {
error: e,
address: format!("{url}:{port}"),
})?;
// Put the fd into a `TcpStream` immediately, so it's closed in case of error.
let result = Self {
fd: Rc::new(Cell::new(Some(fd))),
};
let mut last_error = None;
// SAFETY: it is just simple sys call
let mut io_error = match cvt(unsafe { libc::connect(fd, addr, addr_len) }) {
Ok(_) => {
return Ok(result);
for v4_addr in v4_addrs {
match Self::connect_single(LibcSocketAddr::V4(v4_addr), deadline) {
Ok(stream) => {
return Ok(stream);
}
Err(e) => last_error = Some(e),
}
Err(e) => e,
};
if io_error.raw_os_error() != Some(libc::EINPROGRESS) {
return Err(Error::Connect {
error: io_error,
address: format!("{url}:{port}"),
});
}
io_error = match crate::coio::coio_wait(fd, ffi::CoIOFlags::WRITE, timeout.as_secs_f64()) {
Ok(_) => match unsafe { check_socket_error(fd) } {
Ok(_) => {
return Ok(result);
for v6_addr in v6_addrs {
match Self::connect_single(LibcSocketAddr::V6(v6_addr), deadline) {
Ok(stream) => {
return Ok(stream);
}
Err(e) => e,
},
Err(e) => e,
};
Err(e) => last_error = Some(e),
}
}
if let Some(error) = last_error {
if let io::ErrorKind::TimedOut = error.kind() {
Err(Error::Timeout)
} else {
Err(Error::Connect {
error,
address: format!("{url}:{port}"),
})
}
} else {
Err(Error::ResolveAddress(url.into()))
}
}
if let io::ErrorKind::TimedOut = io_error.kind() {
return Err(Error::Timeout);
/// Tries to establish a TCP connection to one already-resolved socket
/// address, giving up once `deadline` is reached.
///
/// A fresh non-blocking socket of the matching address family is created
/// for the attempt. If `connect` reports `EINPROGRESS`, the current fiber is
/// blocked until the socket becomes writable or the remaining time runs
/// out, after which the pending socket error is checked.
///
/// # Errors
///
/// Returns the underlying `io::Error` from socket creation, `connect`,
/// the wait, or the socket error check. The descriptor is closed
/// automatically on every error path, because it stays wrapped in
/// `AutoCloseFd` until the connection has succeeded.
fn connect_single(socket_addr: LibcSocketAddr, deadline: Instant) -> io::Result<Self> {
    // Borrow `socket_addr` instead of moving out of it, so that the raw
    // pointer passed to `libc::connect` stays valid for the whole function.
    let (kind, addr, addr_len) = match &socket_addr {
        LibcSocketAddr::V4(v4) => (
            libc::AF_INET,
            v4 as *const libc::sockaddr_in as *const libc::sockaddr,
            mem::size_of::<libc::sockaddr_in>(),
        ),
        LibcSocketAddr::V6(v6) => (
            libc::AF_INET6,
            v6 as *const libc::sockaddr_in6 as *const libc::sockaddr,
            mem::size_of::<libc::sockaddr_in6>(),
        ),
    };

    let fd = nonblocking_socket(kind)?;

    // SAFETY: `fd` is an open socket, `addr` points into `socket_addr`, which
    // outlives this call, and `addr_len` is the size of the pointed-to struct.
    let res = cvt(unsafe { libc::connect(fd.0, addr, addr_len as _) });
    if let Err(io_error) = res {
        if io_error.raw_os_error() != Some(libc::EINPROGRESS) {
            return Err(io_error);
        }

        // The connection is in progress: block the fiber until its result
        // is known or the remaining time until `deadline` elapses.
        let timeout = deadline.duration_since(fiber::clock());
        crate::coio::coio_wait(fd.0, ffi::CoIOFlags::WRITE, timeout.as_secs_f64())?;

        // SAFETY: `fd` is still open, it is owned by the `AutoCloseFd` guard.
        unsafe { check_socket_error(fd.0)? };
        // No error means the connection is established.
    }

    // If this allocation panics the fd will still be closed by the guard.
    let result = Self {
        fd: Rc::new(Cell::new(None)),
    };
    // From here on the `TcpStream` owns the fd and is responsible for closing it.
    result.fd.set(Some(fd.into_inner()));
    Ok(result)
}
#[inline(always)]
@@ -195,30 +210,82 @@ fn cvt(t: libc::c_int) -> io::Result<libc::c_int> {
}
}
unsafe fn nonblocking_socket() -> io::Result<RawFd> {
#[cfg(target_os = "linux")]
let fd: RawFd = cvt(libc::socket(
libc::AF_INET,
libc::SOCK_STREAM | libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK,
0,
))?;
#[cfg(target_os = "macos")]
let fd: RawFd = {
let fd = cvt(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0))?;
cvt(libc::ioctl(fd, libc::FIOCLEX))?;
/// Opens a stream socket of address family `kind` that is created
/// non-blocking and close-on-exec in a single `socket` call.
///
/// The descriptor is returned wrapped in [`AutoCloseFd`], so it is closed
/// automatically unless the caller takes ownership of it.
#[cfg(target_os = "linux")]
#[inline(always)]
fn nonblocking_socket(kind: libc::c_int) -> io::Result<AutoCloseFd> {
    let socket_type = libc::SOCK_STREAM | libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK;
    // SAFETY: `libc::socket` only receives plain integer arguments.
    let raw_fd = cvt(unsafe { libc::socket(kind, socket_type, 0) })?;
    Ok(AutoCloseFd(raw_fd))
}
/// Opens a stream socket of address family `kind` and configures it as
/// close-on-exec, `SIGPIPE`-free and non-blocking.
///
/// Each property is set with a separate call after the socket is created.
/// The descriptor is wrapped in [`AutoCloseFd`] right after creation, so it
/// is closed if any of the follow-up calls fails.
#[cfg(target_os = "macos")]
fn nonblocking_socket(kind: libc::c_int) -> io::Result<AutoCloseFd> {
    // SAFETY: `libc::socket` only receives plain integer arguments.
    let fd = unsafe { cvt(libc::socket(kind, libc::SOCK_STREAM, 0))? };
    let fd = AutoCloseFd(fd);

    // SAFETY: `fd` is open.
    unsafe {
        cvt(libc::ioctl(fd.0, libc::FIOCLEX))?;
    }

    // SAFETY: `fd` is open and the option buffer pointer/length pair
    // describes the live local `opt_value`.
    unsafe {
        let opt_value: libc::c_int = 1;
        cvt(libc::setsockopt(
            fd.0,
            libc::SOL_SOCKET,
            libc::SO_NOSIGPIPE,
            &opt_value as *const _ as *const libc::c_void,
            mem::size_of_val(&opt_value) as _,
        ))?;
    }

    // SAFETY: `fd` is open.
    unsafe {
        cvt(libc::ioctl(fd.0, libc::FIONBIO, &mut 1))?;
    }

    Ok(fd)
}
/// Owning guard for a raw file descriptor: the descriptor is closed when the
/// guard is dropped.
/// Call [`Self::into_inner`] to take the descriptor out and skip the
/// automatic close.
///
/// TODO: consider using [`std::os::fd::OwnedFd`] instead
struct AutoCloseFd(RawFd);

impl AutoCloseFd {
    /// Consumes the guard and hands back the raw descriptor without closing
    /// it; the caller becomes responsible for closing it.
    #[inline(always)]
    fn into_inner(self) -> RawFd {
        // Suppress the destructor so the descriptor stays open.
        let guard = std::mem::ManuallyDrop::new(self);
        guard.0
    }
}
impl Drop for AutoCloseFd {
    /// Closes the wrapped descriptor; a failing `close` is logged rather
    /// than propagated.
    fn drop(&mut self) {
        // SAFETY: sound as long as the wrapper only ever holds an open
        // file descriptor.
        let status = unsafe { libc::close(self.0) };
        if status == 0 {
            return;
        }
        crate::say_error!(
            "failed closing socket descriptor: {}",
            io::Error::last_os_error()
        );
    }
}
unsafe fn check_socket_error(fd: RawFd) -> io::Result<()> {
let mut val: libc::c_int = mem::zeroed();
let mut val_len = mem::size_of::<libc::c_int>() as libc::socklen_t;
@@ -299,6 +366,11 @@ unsafe fn resolve_addr(
Ok((ipv4_addresses, ipv6_addresses))
}
/// A socket address in its raw `libc` representation, tagged by address
/// family so the matching `AF_*` constant and struct size can be selected
/// when connecting.
enum LibcSocketAddr {
    /// An IPv4 address (`sockaddr_in`).
    V4(libc::sockaddr_in),
    /// An IPv6 address (`sockaddr_in6`).
    V6(libc::sockaddr_in6),
}
impl AsyncWrite for TcpStream {
fn poll_write(
self: Pin<&mut Self>,
@@ -591,7 +663,7 @@ mod tests {
async fn read_timeout() {
let mut stream = TcpStream::connect_timeout("localhost", listen_port(), _10_SEC).unwrap();
// Read greeting
let mut buf = vec![0; 128];
let mut buf = vec![0; 4096];
assert_eq!(
stream
.read_exact(&mut buf)
@@ -775,10 +847,10 @@ mod tests {
unsafe { libc::close(fd) };
}
fn get_socket_fds() -> Vec<u32> {
fn get_socket_fds() -> HashSet<u32> {
use std::os::unix::fs::FileTypeExt;
let mut res = vec![];
let mut res = HashSet::new();
for entry in std::fs::read_dir("/dev/fd/").unwrap() {
let Ok(entry) = entry else {
continue;
@@ -794,9 +866,8 @@ mod tests {
// Yay rust!
let fd_str = fd_path.file_name().unwrap();
let fd: u32 = fd_str.to_str().unwrap().parse().unwrap();
res.push(fd);
res.insert(fd);
}
res.sort_unstable();
res
}
@@ -813,6 +884,7 @@ mod tests {
// XXX: this is a bit unreliable, because tarantool is spawning a bunch
// of other threads which may or may not be creating and closing fds,
// so we may want to remove this test at some point
assert_eq!(fds_before, fds_after)
let new_fds: Vec<_> = fds_after.difference(&fds_before).copied().collect();
assert!(dbg!(new_fds.is_empty()));
}
}
Loading