From 3d2fa3a400adbd8ab2d3a585fc5a5171148ac5fb Mon Sep 17 00:00:00 2001
From: Georgy Moshkin <gmoshkin@picodata.io>
Date: Fri, 28 Jul 2023 18:23:54 +0300
Subject: [PATCH] refactor: hide ConnectionPool's mutability for ease of use

---
 src/sync.rs          |   6 +--
 src/traft/network.rs | 125 +++++++++++++++++++++----------------------
 src/traft/node.rs    |   3 +-
 3 files changed, 66 insertions(+), 68 deletions(-)

diff --git a/src/sync.rs b/src/sync.rs
index fabb9d458e..e088ed5a01 100644
--- a/src/sync.rs
+++ b/src/sync.rs
@@ -31,7 +31,7 @@ fn proc_get_vclock() -> traft::Result<Vclock> {
 
 /// Calls [`proc_get_vclock`] on instance with `instance_id`.
 pub async fn call_get_vclock(
-    pool: &mut ConnectionPool,
+    pool: &ConnectionPool,
     instance_id: &impl IdOfInstance,
 ) -> traft::Result<Vclock> {
     let (vclock,): (Vclock,) = pool
@@ -178,7 +178,7 @@ mod tests {
     async fn vclock_proc() {
         let storage = Clusterwide::new().unwrap();
         // Connect to the current Tarantool instance
-        let mut pool = ConnectionPool::new(storage.clone(), Default::default());
+        let pool = ConnectionPool::new(storage.clone(), Default::default());
         let l = ::tarantool::lua_state();
         let listen: String = l.eval("return box.info.listen").unwrap();
 
@@ -193,7 +193,7 @@ mod tests {
             .unwrap();
         crate::init_handlers();
 
-        let result = call_get_vclock(&mut pool, &instance.raft_id).await.unwrap();
+        let result = call_get_vclock(&pool, &instance.raft_id).await.unwrap();
         assert_eq!(result, Vclock::current());
 
         pool.call(
diff --git a/src/traft/network.rs b/src/traft/network.rs
index 4904322846..311fd48242 100644
--- a/src/traft/network.rs
+++ b/src/traft/network.rs
@@ -11,8 +11,9 @@ use ::tarantool::util::IntoClones;
 use futures::future::poll_fn;
 use futures::Future;
 use futures::FutureExt as _;
+use std::cell::UnsafeCell;
+use std::collections::HashMap;
 use std::collections::VecDeque;
-use std::collections::{hash_map::Entry, HashMap};
 use std::pin::Pin;
 use std::task::Poll;
 use std::time::Duration;
@@ -345,8 +346,8 @@ impl std::fmt::Debug for PoolWorker {
 #[derive(Debug)]
 pub struct ConnectionPool {
     worker_options: WorkerOptions,
-    workers: HashMap<RaftId, PoolWorker>,
-    raft_ids: HashMap<InstanceId, RaftId>,
+    workers: UnsafeCell<HashMap<RaftId, PoolWorker>>,
+    raft_ids: UnsafeCell<HashMap<InstanceId, RaftId>>,
     peer_addresses: PeerAddresses,
     instances: Instances,
 }
@@ -356,8 +357,8 @@ impl ConnectionPool {
     pub fn new(storage: Clusterwide, worker_options: WorkerOptions) -> Self {
         Self {
             worker_options,
-            workers: HashMap::new(),
-            raft_ids: HashMap::new(),
+            workers: Default::default(),
+            raft_ids: Default::default(),
             peer_addresses: storage.peer_addresses,
             instances: storage.instances,
         }
@@ -369,57 +370,55 @@ impl ConnectionPool {
         todo!();
     }
 
-    fn get_or_create_by_raft_id(&mut self, raft_id: RaftId) -> Result<&mut PoolWorker> {
-        match self.workers.entry(raft_id) {
-            Entry::Occupied(entry) => Ok(entry.into_mut()),
-            Entry::Vacant(entry) => {
-                let instance_id = self
-                    .instances
-                    .field::<instance_field::InstanceId>(&raft_id)
-                    .map_err(|_| Error::NoInstanceWithRaftId(raft_id))
-                    .ok();
-                // Check if address of this peer is known.
-                // No need to store the result,
-                // because it will be updated in the loop
-                let _ = self.peer_addresses.try_get(raft_id)?;
-                let worker = PoolWorker::run(
-                    raft_id,
-                    instance_id.clone(),
-                    self.peer_addresses.clone(),
-                    self.worker_options.clone(),
-                )?;
-                if let Some(instance_id) = instance_id {
-                    self.raft_ids.insert(instance_id, raft_id);
-                }
-                Ok(entry.insert(worker))
+    fn get_or_create_by_raft_id(&self, raft_id: RaftId) -> Result<&PoolWorker> {
+        // SAFETY: everything here is safe, because we only use ConnectionPool
+        // from tx thread
+        let workers = unsafe { &*self.workers.get() };
+        if let Some(worker) = workers.get(&raft_id) {
+            Ok(worker)
+        } else {
+            let instance_id = self
+                .instances
+                .field::<instance_field::InstanceId>(&raft_id)
+                .map_err(|_| Error::NoInstanceWithRaftId(raft_id))
+                .ok();
+            // Check if address of this peer is known.
+            // No need to store the result,
+            // because it will be updated in the loop
+            let _ = self.peer_addresses.try_get(raft_id)?;
+            let worker = PoolWorker::run(
+                raft_id,
+                instance_id.clone(),
+                self.peer_addresses.clone(),
+                self.worker_options.clone(),
+            )?;
+            if let Some(instance_id) = instance_id {
+                let raft_ids = unsafe { &mut *self.raft_ids.get() };
+                raft_ids.insert(instance_id, raft_id);
             }
+
+            let workers = unsafe { &mut *self.workers.get() };
+            Ok(workers.entry(raft_id).or_insert(worker))
         }
     }
 
-    fn get_or_create_by_instance_id(&mut self, instance_id: &str) -> Result<&mut PoolWorker> {
-        match self.raft_ids.entry(InstanceId(instance_id.into())) {
-            Entry::Occupied(entry) => {
-                let worker = self
-                    .workers
-                    .get_mut(entry.get())
-                    .expect("instance_id is present, but the worker isn't");
-                Ok(worker)
-            }
-            Entry::Vacant(entry) => {
-                let instance_id = entry.key();
-                let raft_id = self
-                    .instances
-                    .field::<instance_field::RaftId>(instance_id)
-                    .map_err(|_| Error::NoInstanceWithInstanceId(instance_id.clone()))?;
-                let worker = PoolWorker::run(
-                    raft_id,
-                    instance_id.clone(),
-                    self.peer_addresses.clone(),
-                    self.worker_options.clone(),
-                )?;
-                entry.insert(raft_id);
-                Ok(self.workers.entry(raft_id).or_insert(worker))
-            }
+    fn get_or_create_by_instance_id(&self, instance_id: &str) -> Result<&PoolWorker> {
+        // SAFETY: everything here is safe, because we only use ConnectionPool
+        // from tx thread
+        let raft_ids = unsafe { &*self.raft_ids.get() };
+        if let Some(raft_id) = raft_ids.get(instance_id) {
+            let workers = unsafe { &*self.workers.get() };
+            let worker = workers
+                .get(raft_id)
+                .expect("instance_id is present, but the worker isn't");
+            Ok(worker)
+        } else {
+            let instance_id = InstanceId::from(instance_id);
+            let raft_id = self
+                .instances
+                .field::<instance_field::RaftId>(&instance_id)
+                .map_err(|_| Error::NoInstanceWithInstanceId(instance_id.clone()))?;
+            self.get_or_create_by_raft_id(raft_id)
         }
     }
 
@@ -431,7 +430,7 @@ impl ConnectionPool {
     /// it's not appropriate for use inside a transaction. Anyway,
     /// sending a message inside a transaction is always a bad idea.
     #[inline]
-    pub fn send(&mut self, msg: raft::Message) -> Result<()> {
+    pub fn send(&self, msg: raft::Message) -> Result<()> {
         self.get_or_create_by_raft_id(msg.to)?.send(msg)
     }
 
@@ -441,7 +440,7 @@ impl ConnectionPool {
     /// If the request failed, it's a responsibility of the caller
     /// to re-send it later.
     pub fn call<R>(
-        &mut self,
+        &self,
         id: &impl IdOfInstance,
         req: &R,
     ) -> Result<impl Future<Output = Result<R::Response>>>
@@ -477,7 +476,7 @@ impl ConnectionPool {
     /// If the request failed, it's a responsibility of the caller
     /// to re-send it later.
     pub fn call_raw<Args, Response>(
-        &mut self,
+        &self,
         id: &impl IdOfInstance,
         proc: &'static str,
         args: &Args,
@@ -509,7 +508,7 @@ impl ConnectionPool {
 
 impl Drop for ConnectionPool {
     fn drop(&mut self) {
-        for (_, worker) in self.workers.drain() {
+        for (_, worker) in self.workers.get_mut().drain() {
             worker.stop();
         }
     }
@@ -522,19 +521,19 @@ impl Drop for ConnectionPool {
 /// Types implementing this trait can be used to identify a `Instance` when
 /// accessing ConnectionPool.
 pub trait IdOfInstance: std::hash::Hash + Clone + std::fmt::Debug {
-    fn get_or_create_in<'p>(&self, pool: &'p mut ConnectionPool) -> Result<&'p mut PoolWorker>;
+    fn get_or_create_in<'p>(&self, pool: &'p ConnectionPool) -> Result<&'p PoolWorker>;
 }
 
 impl IdOfInstance for RaftId {
     #[inline(always)]
-    fn get_or_create_in<'p>(&self, pool: &'p mut ConnectionPool) -> Result<&'p mut PoolWorker> {
+    fn get_or_create_in<'p>(&self, pool: &'p ConnectionPool) -> Result<&'p PoolWorker> {
         pool.get_or_create_by_raft_id(*self)
     }
 }
 
 impl IdOfInstance for InstanceId {
     #[inline(always)]
-    fn get_or_create_in<'p>(&self, pool: &'p mut ConnectionPool) -> Result<&'p mut PoolWorker> {
+    fn get_or_create_in<'p>(&self, pool: &'p ConnectionPool) -> Result<&'p PoolWorker> {
         pool.get_or_create_by_instance_id(self)
     }
 }
@@ -578,7 +577,7 @@ mod tests {
 
         let storage = Clusterwide::new().unwrap();
         // Connect to the current Tarantool instance
-        let mut pool = ConnectionPool::new(storage.clone(), Default::default());
+        let pool = ConnectionPool::new(storage.clone(), Default::default());
         let listen: String = l.eval("return box.info.listen").unwrap();
 
         let instance = traft::Instance {
@@ -628,7 +627,7 @@ mod tests {
             call_timeout: Duration::from_millis(50),
             ..Default::default()
         };
-        let mut pool = ConnectionPool::new(storage.clone(), opts);
+        let pool = ConnectionPool::new(storage.clone(), opts);
         let listen: String = l.eval("return box.info.listen").unwrap();
 
         let instance = traft::Instance {
@@ -714,7 +713,7 @@ mod tests {
             call_timeout: Duration::from_millis(50),
             ..Default::default()
         };
-        let mut pool = ConnectionPool::new(storage.clone(), opts);
+        let pool = ConnectionPool::new(storage.clone(), opts);
         let listen: String = l.eval("return box.info.listen").unwrap();
 
         let instance = traft::Instance {
@@ -792,7 +791,7 @@ mod tests {
             call_timeout: Duration::from_secs(3),
             ..Default::default()
         };
-        let mut pool = ConnectionPool::new(storage.clone(), opts);
+        let pool = ConnectionPool::new(storage.clone(), opts);
         let listen: String = l.eval("return box.info.listen").unwrap();
 
         let instance = traft::Instance {
diff --git a/src/traft/node.rs b/src/traft/node.rs
index 0f5969c9c4..5ab2eabe30 100644
--- a/src/traft/node.rs
+++ b/src/traft/node.rs
@@ -1570,8 +1570,7 @@ impl NodeImpl {
             ));
         }
         let master = self.storage.instances.get(&replicaset.master_id)?;
-        let master_vclock =
-            fiber::block_on(sync::call_get_vclock(&mut self.pool, &master.raft_id))?;
+        let master_vclock = fiber::block_on(sync::call_get_vclock(&self.pool, &master.raft_id))?;
         let local_vclock = Vclock::current();
         if matches!(
             local_vclock.partial_cmp(&master_vclock),
-- 
GitLab