From 3d2fa3a400adbd8ab2d3a585fc5a5171148ac5fb Mon Sep 17 00:00:00 2001 From: Georgy Moshkin <gmoshkin@picodata.io> Date: Fri, 28 Jul 2023 18:23:54 +0300 Subject: [PATCH] refactor: hide ConnectionPool's mutability for ease of use --- src/sync.rs | 6 +-- src/traft/network.rs | 125 +++++++++++++++++++++---------------------- src/traft/node.rs | 3 +- 3 files changed, 66 insertions(+), 68 deletions(-) diff --git a/src/sync.rs b/src/sync.rs index fabb9d458e..e088ed5a01 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -31,7 +31,7 @@ fn proc_get_vclock() -> traft::Result<Vclock> { /// Calls [`proc_get_vclock`] on instance with `instance_id`. pub async fn call_get_vclock( - pool: &mut ConnectionPool, + pool: &ConnectionPool, instance_id: &impl IdOfInstance, ) -> traft::Result<Vclock> { let (vclock,): (Vclock,) = pool @@ -178,7 +178,7 @@ mod tests { async fn vclock_proc() { let storage = Clusterwide::new().unwrap(); // Connect to the current Tarantool instance - let mut pool = ConnectionPool::new(storage.clone(), Default::default()); + let pool = ConnectionPool::new(storage.clone(), Default::default()); let l = ::tarantool::lua_state(); let listen: String = l.eval("return box.info.listen").unwrap(); @@ -193,7 +193,7 @@ mod tests { .unwrap(); crate::init_handlers(); - let result = call_get_vclock(&mut pool, &instance.raft_id).await.unwrap(); + let result = call_get_vclock(&pool, &instance.raft_id).await.unwrap(); assert_eq!(result, Vclock::current()); pool.call( diff --git a/src/traft/network.rs b/src/traft/network.rs index 4904322846..311fd48242 100644 --- a/src/traft/network.rs +++ b/src/traft/network.rs @@ -11,8 +11,9 @@ use ::tarantool::util::IntoClones; use futures::future::poll_fn; use futures::Future; use futures::FutureExt as _; +use std::cell::UnsafeCell; +use std::collections::HashMap; use std::collections::VecDeque; -use std::collections::{hash_map::Entry, HashMap}; use std::pin::Pin; use std::task::Poll; use std::time::Duration; @@ -345,8 +346,8 @@ impl std::fmt::Debug for PoolWorker { #[derive(Debug)] pub struct ConnectionPool { worker_options: WorkerOptions, - workers: HashMap<RaftId, PoolWorker>, - raft_ids: HashMap<InstanceId, RaftId>, + workers: UnsafeCell<HashMap<RaftId, PoolWorker>>, + raft_ids: UnsafeCell<HashMap<InstanceId, RaftId>>, peer_addresses: PeerAddresses, instances: Instances, } @@ -356,8 +357,8 @@ impl ConnectionPool { pub fn new(storage: Clusterwide, worker_options: WorkerOptions) -> Self { Self { worker_options, - workers: HashMap::new(), - raft_ids: HashMap::new(), + workers: Default::default(), + raft_ids: Default::default(), peer_addresses: storage.peer_addresses, instances: storage.instances, } @@ -369,57 +370,55 @@ impl ConnectionPool { todo!(); } - fn get_or_create_by_raft_id(&mut self, raft_id: RaftId) -> Result<&mut PoolWorker> { - match self.workers.entry(raft_id) { - Entry::Occupied(entry) => Ok(entry.into_mut()), - Entry::Vacant(entry) => { - let instance_id = self - .instances - .field::<instance_field::InstanceId>(&raft_id) - .map_err(|_| Error::NoInstanceWithRaftId(raft_id)) - .ok(); - // Check if address of this peer is known. - // No need to store the result, - // because it will be updated in the loop - let _ = self.peer_addresses.try_get(raft_id)?; - let worker = PoolWorker::run( - raft_id, - instance_id.clone(), - self.peer_addresses.clone(), - self.worker_options.clone(), - )?; - if let Some(instance_id) = instance_id { - self.raft_ids.insert(instance_id, raft_id); - } - Ok(entry.insert(worker)) + fn get_or_create_by_raft_id(&self, raft_id: RaftId) -> Result<&PoolWorker> { + // SAFETY: everything here is safe, because we only use ConnectionPool + // from tx thread + let workers = unsafe { &*self.workers.get() }; + if let Some(worker) = workers.get(&raft_id) { + Ok(worker) + } else { + let instance_id = self + .instances + .field::<instance_field::InstanceId>(&raft_id) + .map_err(|_| Error::NoInstanceWithRaftId(raft_id)) + .ok(); + // Check if address of this peer is known. + // No need to store the result, + // because it will be updated in the loop + let _ = self.peer_addresses.try_get(raft_id)?; + let worker = PoolWorker::run( + raft_id, + instance_id.clone(), + self.peer_addresses.clone(), + self.worker_options.clone(), + )?; + if let Some(instance_id) = instance_id { + let raft_ids = unsafe { &mut *self.raft_ids.get() }; + raft_ids.insert(instance_id, raft_id); } + + let workers = unsafe { &mut *self.workers.get() }; + Ok(workers.entry(raft_id).or_insert(worker)) } } - fn get_or_create_by_instance_id(&mut self, instance_id: &str) -> Result<&mut PoolWorker> { - match self.raft_ids.entry(InstanceId(instance_id.into())) { - Entry::Occupied(entry) => { - let worker = self - .workers - .get_mut(entry.get()) - .expect("instance_id is present, but the worker isn't"); - Ok(worker) - } - Entry::Vacant(entry) => { - let instance_id = entry.key(); - let raft_id = self - .instances - .field::<instance_field::RaftId>(instance_id) - .map_err(|_| Error::NoInstanceWithInstanceId(instance_id.clone()))?; - let worker = PoolWorker::run( - raft_id, - instance_id.clone(), - self.peer_addresses.clone(), - self.worker_options.clone(), - )?; - entry.insert(raft_id); - Ok(self.workers.entry(raft_id).or_insert(worker)) - } + fn get_or_create_by_instance_id(&self, instance_id: &str) -> Result<&PoolWorker> { + // SAFETY: everything here is safe, because we only use ConnectionPool + // from tx thread + let raft_ids = unsafe { &*self.raft_ids.get() }; + if let Some(raft_id) = raft_ids.get(instance_id) { + let workers = unsafe { &*self.workers.get() }; + let worker = workers + .get(raft_id) + .expect("instance_id is present, but the worker isn't"); + Ok(worker) + } else { + let instance_id = InstanceId::from(instance_id); + let raft_id = self + .instances + .field::<instance_field::RaftId>(&instance_id) + .map_err(|_| Error::NoInstanceWithInstanceId(instance_id.clone()))?; + self.get_or_create_by_raft_id(raft_id) } } @@ -431,7 +430,7 @@ impl ConnectionPool { /// it's not appropriate for use inside a transaction. Anyway, /// sending a message inside a transaction is always a bad idea. #[inline] - pub fn send(&mut self, msg: raft::Message) -> Result<()> { + pub fn send(&self, msg: raft::Message) -> Result<()> { self.get_or_create_by_raft_id(msg.to)?.send(msg) } @@ -441,7 +440,7 @@ impl ConnectionPool { /// If the request failed, it's a responsibility of the caller /// to re-send it later. pub fn call<R>( - &mut self, + &self, id: &impl IdOfInstance, req: &R, ) -> Result<impl Future<Output = Result<R::Response>>> @@ -477,7 +476,7 @@ impl ConnectionPool { /// If the request failed, it's a responsibility of the caller /// to re-send it later. pub fn call_raw<Args, Response>( - &mut self, + &self, id: &impl IdOfInstance, proc: &'static str, args: &Args, @@ -509,7 +508,7 @@ impl ConnectionPool { impl Drop for ConnectionPool { fn drop(&mut self) { - for (_, worker) in self.workers.drain() { + for (_, worker) in self.workers.get_mut().drain() { worker.stop(); } } @@ -522,19 +521,19 @@ impl Drop for ConnectionPool { /// Types implementing this trait can be used to identify a `Instance` when /// accessing ConnectionPool. pub trait IdOfInstance: std::hash::Hash + Clone + std::fmt::Debug { - fn get_or_create_in<'p>(&self, pool: &'p mut ConnectionPool) -> Result<&'p mut PoolWorker>; + fn get_or_create_in<'p>(&self, pool: &'p ConnectionPool) -> Result<&'p PoolWorker>; } impl IdOfInstance for RaftId { #[inline(always)] - fn get_or_create_in<'p>(&self, pool: &'p mut ConnectionPool) -> Result<&'p mut PoolWorker> { + fn get_or_create_in<'p>(&self, pool: &'p ConnectionPool) -> Result<&'p PoolWorker> { pool.get_or_create_by_raft_id(*self) } } impl IdOfInstance for InstanceId { #[inline(always)] - fn get_or_create_in<'p>(&self, pool: &'p mut ConnectionPool) -> Result<&'p mut PoolWorker> { + fn get_or_create_in<'p>(&self, pool: &'p ConnectionPool) -> Result<&'p PoolWorker> { pool.get_or_create_by_instance_id(self) } } @@ -578,7 +577,7 @@ mod tests { let storage = Clusterwide::new().unwrap(); // Connect to the current Tarantool instance - let mut pool = ConnectionPool::new(storage.clone(), Default::default()); + let pool = ConnectionPool::new(storage.clone(), Default::default()); let listen: String = l.eval("return box.info.listen").unwrap(); let instance = traft::Instance { @@ -628,7 +627,7 @@ mod tests { call_timeout: Duration::from_millis(50), ..Default::default() }; - let mut pool = ConnectionPool::new(storage.clone(), opts); + let pool = ConnectionPool::new(storage.clone(), opts); let listen: String = l.eval("return box.info.listen").unwrap(); let instance = traft::Instance { @@ -714,7 +713,7 @@ mod tests { call_timeout: Duration::from_millis(50), ..Default::default() }; - let mut pool = ConnectionPool::new(storage.clone(), opts); + let pool = ConnectionPool::new(storage.clone(), opts); let listen: String = l.eval("return box.info.listen").unwrap(); let instance = traft::Instance { @@ -792,7 +791,7 @@ mod tests { call_timeout: Duration::from_secs(3), ..Default::default() }; - let mut pool = ConnectionPool::new(storage.clone(), opts); + let pool = ConnectionPool::new(storage.clone(), opts); let listen: String = l.eval("return box.info.listen").unwrap(); let instance = traft::Instance { diff --git a/src/traft/node.rs b/src/traft/node.rs index 0f5969c9c4..5ab2eabe30 100644 --- a/src/traft/node.rs +++ b/src/traft/node.rs @@ -1570,8 +1570,7 @@ impl NodeImpl { )); } let master = self.storage.instances.get(&replicaset.master_id)?; - let master_vclock = - fiber::block_on(sync::call_get_vclock(&mut self.pool, &master.raft_id))?; + let master_vclock = fiber::block_on(sync::call_get_vclock(&self.pool, &master.raft_id))?; let local_vclock = Vclock::current(); if matches!( local_vclock.partial_cmp(&master_vclock), -- GitLab