From 707f67148c4708cdc4604e93526bc04bb649c4ae Mon Sep 17 00:00:00 2001
From: Georgy Moshkin <gmoshkin@picodata.io>
Date: Wed, 19 Oct 2022 14:23:50 +0300
Subject: [PATCH] refactor(governor): extract call_all function
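
The replication and sharding configuration steps of the governor loop
used to open-code the same fan-out pattern twice: a static SENT_COUNT,
a fiber::Cond and a fiber::Channel used to send one request to every
peer and then block until all responses arrive or a timeout expires.

Move that pattern into a call_all helper and make both steps use it.
A rough sketch of the helper's shape (see the diff below for the exact
trait bounds and the implementation):

    fn call_all<R, I>(
        pool: &mut ConnectionPool,
        ids: impl IntoIterator<Item = I>,
        req: R,
        timeout: Duration,
    ) -> Result<Vec<(I, Result<R::Response, Error>)>, Error>

Callers now only pass the connection pool, the peer ids, the request
and a timeout, and get back the per-peer results, or Error::Timeout if
not every peer answered in time.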

---
 src/traft/node.rs | 159 ++++++++++++++++++++++++----------------------
 1 file changed, 83 insertions(+), 76 deletions(-)

diff --git a/src/traft/node.rs b/src/traft/node.rs
index 19927e6d3f..ef333676f8 100644
--- a/src/traft/node.rs
+++ b/src/traft/node.rs
@@ -1118,49 +1118,28 @@ fn raft_conf_change_loop(status: Rc<Cell<Status>>, storage: Storage) {
                 }
             );
 
-            // TODO: this crap is only needed to wait until results of all
-            // the calls are ready. There are several ways to rafactor this:
-            // - we could use a std-style channel that unblocks the reading end
-            //   once all the writing ends have dropped
-            //   (fiber::Channel cannot do that for now)
-            // - using the std Futures we could use futures::join!
-            //
-            // Those things aren't implemented yet, so this is what we do
-            static mut SENT_COUNT: usize = 0;
-            unsafe { SENT_COUNT = 0 };
-            let (cond_rx, cond_tx) = Rc::new(fiber::Cond::new()).into_clones();
             let replicaset_size = replicaset_iids.len();
-            let (rx, tx) = fiber::Channel::new(replicaset_size as _).into_clones();
-            for peer_instance_id in &replicaset_iids {
-                let tx = tx.clone();
-                let cond_tx = cond_tx.clone();
-                let peer_iid = peer_instance_id.clone();
-                pool.call(
-                    peer_instance_id,
-                    replication::Request {
-                        replicaset_instances: replicaset_iids.clone(),
-                        replicaset_id: replicaset_id.clone(),
-                        // TODO: what if someone goes offline/expelled?
-                        promote: replicaset_size == 1,
-                    },
-                    move |res| {
-                        tx.send((peer_iid, res)).expect("mustn't fail");
-                        unsafe { SENT_COUNT += 1 };
-                        if unsafe { SENT_COUNT } == replicaset_size {
-                            cond_tx.signal()
-                        }
-                    },
-                )
-                .expect("shouldn't fail");
-            }
-            // TODO: don't hard code timeout
-            if !cond_rx.wait_timeout(Duration::from_secs(3)) {
-                tlog!(Warning, "failed to configure replication: timed out");
-                continue 'governor;
-            }
+            let res = call_all(
+                &mut pool,
+                replicaset_iids.clone(),
+                replication::Request {
+                    replicaset_instances: replicaset_iids,
+                    replicaset_id: replicaset_id.clone(),
+                    // TODO: what if someone goes offline/expelled?
+                    promote: replicaset_size == 1,
+                },
+                // TODO: don't hard code timeout
+                Duration::from_secs(3),
+            );
+            let res = unwrap_ok_or!(res,
+                Err(e) => {
+                    tlog!(Warning, "failed to configure replication: {e}");
+                    continue 'governor;
+                }
+            );
 
             let cluster_id = storage.raft.cluster_id().unwrap().unwrap();
-            for (peer_iid, resp) in rx.into_iter().take(replicaset_size) {
+            for (peer_iid, resp) in res {
                 let cluster_id = cluster_id.clone();
                 let peer_iid_2 = peer_iid.clone();
                 let res = resp.and_then(move |replication::Response { lsn }| {
@@ -1217,44 +1196,26 @@ fn raft_conf_change_loop(status: Rc<Cell<Status>>, storage: Storage) {
                     continue 'governor;
                 }
             );
-            let peer_ids = peer_ids.collect::<Vec<_>>();
-
-            // TODO: good api needed
-            static mut SENT_COUNT: usize = 0;
-            unsafe { SENT_COUNT = 0 };
-            let (cond_rx, cond_tx) = Rc::new(fiber::Cond::new()).into_clones();
-            let peer_count = peer_ids.len();
-            let (rx, tx) = fiber::Channel::new(peer_count as _).into_clones();
-            // Send rpc request to all peers in the cluster
-            for peer_instance_id in &peer_ids {
-                let tx = tx.clone();
-                let cond_tx = cond_tx.clone();
-                let peer_iid = peer_instance_id.clone();
-                pool.call(
-                    peer_instance_id,
-                    sharding::Request {
-                        leader_id: raft_id,
-                        term,
-                    },
-                    move |res| {
-                        tx.send((peer_iid, res)).expect("mustn't fail");
-                        unsafe { SENT_COUNT += 1 };
-                        if unsafe { SENT_COUNT } == peer_count {
-                            cond_tx.signal()
-                        }
-                    },
-                )
-                .expect("shouldn't fail");
-            }
-            // TODO: don't hard code timeout
-            if !cond_rx.wait_timeout(Duration::from_secs(3)) {
-                tlog!(Warning, "failed to configure sharding: timed out");
-                continue 'governor;
-            }
+
+            let res = call_all(
+                &mut pool,
+                peer_ids,
+                sharding::Request {
+                    leader_id: raft_id,
+                    term,
+                },
+                // TODO: don't hard code timeout
+                Duration::from_secs(3),
+            );
+            let res = unwrap_ok_or!(res,
+                Err(e) => {
+                    tlog!(Warning, "failed to configure sharding: {e}");
+                    continue 'governor;
+                }
+            );
 
             let cluster_id = storage.raft.cluster_id().unwrap().unwrap();
-            // Process all rpc responses
-            for (peer_iid, resp) in rx.into_iter().take(peer_count) {
+            for (peer_iid, resp) in res {
                 let cluster_id = cluster_id.clone();
                 let peer_iid_2 = peer_iid.clone();
                 let res = resp.and_then(move |sharding::Response {}| {
@@ -1310,6 +1271,52 @@ fn raft_conf_change_loop(status: Rc<Cell<Status>>, storage: Storage) {
 
         event::wait(Event::TopologyChanged).expect("Events system must be initialized");
     }
+
+    #[allow(clippy::type_complexity)]
+    fn call_all<R, I>(
+        pool: &mut ConnectionPool,
+        ids: impl IntoIterator<Item = I>,
+        req: R,
+        timeout: Duration,
+    ) -> Result<Vec<(I, Result<R::Response, Error>)>, Error>
+    where
+        R: traft::rpc::Request + Clone,
+        I: traft::network::PeerId + Clone + std::fmt::Debug + 'static,
+    {
+        // TODO: this crap is only needed to wait until results of all
+        // the calls are ready. There are several ways to refactor this:
+        // - we could use a std-style channel that unblocks the reading end
+        //   once all the writing ends have dropped
+        //   (fiber::Channel cannot do that for now)
+        // - using the std Futures we could use futures::join!
+        //
+        // Those things aren't implemented yet, so this is what we do
+        let ids = ids.into_iter().collect::<Vec<_>>();
+        static mut SENT_COUNT: usize = 0;
+        unsafe { SENT_COUNT = 0 };
+        let (cond_rx, cond_tx) = Rc::new(fiber::Cond::new()).into_clones();
+        let peer_count = ids.len();
+        let (rx, tx) = fiber::Channel::new(peer_count as _).into_clones();
+        for id in &ids {
+            let tx = tx.clone();
+            let cond_tx = cond_tx.clone();
+            let id_copy = id.clone();
+            pool.call(id, req.clone(), move |res| {
+                tx.send((id_copy, res)).expect("mustn't fail");
+                unsafe { SENT_COUNT += 1 };
+                if unsafe { SENT_COUNT } == peer_count {
+                    cond_tx.signal()
+                }
+            })
+            .expect("shouldn't fail");
+        }
+        // Block until every peer's callback has fired or the timeout expires.
+        if !cond_rx.wait_timeout(timeout) {
+            return Err(Error::Timeout);
+        }
+
+        Ok(rx.into_iter().take(peer_count).collect())
+    }
 }
 
 fn conf_change_single(node_id: RaftId, is_voter: bool) -> raft::ConfChangeSingle {
-- 
GitLab