diff --git a/src/util.rs b/src/util.rs index 0054cb9e0175f1a5b92104eefa125540edb062a6..52f66df7e5b77da5c70c75f7ff43839ba035c808 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,13 +1,14 @@ -use nix::sys::termios::{tcgetattr, tcsetattr, LocalFlags, SetArg::TCSADRAIN}; -use tarantool::session::{self, UserId}; - use crate::traft::error::Error; +use nix::sys::termios::{tcgetattr, tcsetattr, LocalFlags, SetArg::TCSADRAIN}; use std::any::{Any, TypeId}; +use std::cell::Cell; use std::io::BufRead as _; use std::io::BufReader; use std::io::Write as _; use std::os::fd::AsRawFd; +use std::panic::Location; use std::time::Duration; +use tarantool::session::{self, UserId}; pub use Either::{Left, Right}; pub const INFINITY: Duration = Duration::from_secs(30 * 365 * 24 * 60 * 60); @@ -566,6 +567,7 @@ pub type CheckIsSameType<L, R> = <L as IsSameType<L, R>>::Void; /// A helper struct to enforce that a function must not yield. Will cause a /// panic if fiber yields are detected when drop is called for it. pub struct NoYieldsGuard { + message: &'static str, csw: u64, } @@ -574,6 +576,15 @@ impl NoYieldsGuard { #[inline(always)] pub fn new() -> Self { Self { + message: "fiber yielded when it wasn't supposed to", + csw: tarantool::fiber::csw(), + } + } + + #[inline(always)] + pub fn with_message(message: &'static str) -> Self { + Self { + message, csw: tarantool::fiber::csw(), } } @@ -588,11 +599,113 @@ impl Drop for NoYieldsGuard { #[inline(always)] fn drop(&mut self) { if self.has_yielded() { - panic!("NoYieldsGuard: fiber yielded when it wasn't supposed to"); + panic!("NoYieldsGuard: {}", self.message); + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +// NoYieldsRefCell +//////////////////////////////////////////////////////////////////////////////// + +/// A `RefCell` wrapper which also enforces that the wrapped value is never +/// borrowed across fiber yields. +#[derive(Debug)] +pub struct NoYieldsRefCell<T> { + inner: std::cell::RefCell<T>, + loc: Cell<&'static Location<'static>>, +} + +impl<T> Default for NoYieldsRefCell<T> +where + T: Default, +{ + #[inline(always)] + #[track_caller] + fn default() -> Self { + Self { + inner: Default::default(), + loc: Cell::new(Location::caller()), + } + } +} + +impl<T> NoYieldsRefCell<T> { + #[inline(always)] + #[track_caller] + pub fn new(inner: T) -> Self { + Self { + inner: std::cell::RefCell::new(inner), + loc: Cell::new(Location::caller()), } } + + #[inline(always)] + #[track_caller] + pub fn borrow(&self) -> NoYieldsRef<'_, T> { + self.loc.set(Location::caller()); + let inner = self.inner.borrow(); + let guard = + NoYieldsGuard::with_message("yield detected while NoYieldsRefCell was borrowed"); + NoYieldsRef { inner, guard } + } + + #[inline(always)] + #[track_caller] + pub fn borrow_mut(&self) -> NoYieldsRefMut<'_, T> { + let Ok(inner) = self.inner.try_borrow_mut() else { + panic!("already borrowed at {}", self.loc.get()); + }; + self.loc.set(Location::caller()); + let guard = + NoYieldsGuard::with_message("yield detected while NoYieldsRefCell was borrowed"); + NoYieldsRefMut { inner, guard } + } +} + +pub struct NoYieldsRef<'a, T> { + inner: std::cell::Ref<'a, T>, + /// This is only needed for it's `Drop` implementation. + #[allow(unused)] + guard: NoYieldsGuard, +} + +impl<T> std::ops::Deref for NoYieldsRef<'_, T> { + type Target = T; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + &self.inner + } } +pub struct NoYieldsRefMut<'a, T> { + inner: std::cell::RefMut<'a, T>, + /// This is only needed for it's `Drop` implementation. + #[allow(unused)] + guard: NoYieldsGuard, +} + +impl<T> std::ops::Deref for NoYieldsRefMut<'_, T> { + type Target = T; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl<T> std::ops::DerefMut for NoYieldsRefMut<'_, T> { + #[inline(always)] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +//////////////////////////////////////////////////////////////////////////////// +// ... +//////////////////////////////////////////////////////////////////////////////// + #[inline] pub(crate) fn effective_user_id() -> UserId { session::euid().expect("infallible in picodata")