From ec0ce302efe21d51c599eef2962e2b7f6798ba8a Mon Sep 17 00:00:00 2001
From: Georgy Moshkin <gmoshkin@picodata.io>
Date: Thu, 20 Oct 2022 15:11:00 +0300
Subject: [PATCH] refactor(governor): support passing different requests in
 call_all

---
 src/traft/node.rs         | 77 +++++++++++++++++++--------------------
 src/traft/rpc/sharding.rs |  2 +-
 2 files changed, 38 insertions(+), 41 deletions(-)

diff --git a/src/traft/node.rs b/src/traft/node.rs
index 3647629422..a6be8aac09 100644
--- a/src/traft/node.rs
+++ b/src/traft/node.rs
@@ -20,6 +20,7 @@ use std::cell::Cell;
 use std::collections::HashMap;
 use std::collections::HashSet;
 use std::convert::TryFrom;
+use std::iter::repeat;
 use std::rc::Rc;
 use std::time::Duration;
 use std::time::Instant;
@@ -1004,7 +1005,7 @@ fn raft_conf_change_loop(status: Rc<Cell<Status>>, storage: Storage) {
             continue 'governor;
         }
 
-        let raft_id = status.get().id;
+        let leader_id = status.get().id;
         let peers = storage.peers.all_peers().unwrap();
         let term = storage.raft.term().unwrap().unwrap_or(0);
         let node = global().expect("must be initialized");
@@ -1128,18 +1129,17 @@ fn raft_conf_change_loop(status: Rc<Cell<Status>>, storage: Storage) {
             );
 
             let replicaset_size = replicaset_iids.len();
-            let res = call_all(
-                &mut pool,
-                replicaset_iids.clone(),
-                replication::Request {
-                    replicaset_instances: replicaset_iids,
+            let reqs = replicaset_iids
+                .iter()
+                .cloned()
+                .zip(repeat(replication::Request {
+                    replicaset_instances: replicaset_iids.clone(),
                     replicaset_id: replicaset_id.clone(),
                     // TODO: what if someone goes offline/expelled?
                     promote: replicaset_size == 1,
-                },
-                // TODO: don't hard code timeout
-                Duration::from_secs(3),
-            );
+                }));
+            // TODO: don't hard code timeout
+            let res = call_all(&mut pool, reqs, Duration::from_secs(3));
             let res = unwrap_ok_or!(res,
                 Err(e) => {
                     tlog!(Warning, "failed to configure replication: {e}");
@@ -1226,18 +1226,18 @@ fn raft_conf_change_loop(status: Rc<Cell<Status>>, storage: Storage) {
         if need_sharding {
             let res = (|| -> Result<(), Error> {
                 // TODO: filter out Offline & Expelled peers
-                let peer_ids = peers.iter().map(|peer| peer.instance_id.clone());
-                let res = call_all(
-                    &mut pool,
-                    peer_ids,
-                    sharding::Request {
-                        leader_id: raft_id,
-                        term,
-                        weights: None,
-                    },
-                    // TODO: don't hard code timeout
-                    Duration::from_secs(3),
-                )?;
+                let reqs = peers.iter().map(|peer| {
+                    (
+                        peer.instance_id.clone(),
+                        sharding::Request {
+                            leader_id,
+                            term,
+                            ..Default::default()
+                        },
+                    )
+                });
+                // TODO: don't hard code timeout
+                let res = call_all(&mut pool, reqs, Duration::from_secs(3))?;
 
                 let cluster_id = storage
                     .raft
@@ -1276,17 +1276,15 @@ fn raft_conf_change_loop(status: Rc<Cell<Status>>, storage: Storage) {
         if maybe_need_weights_update {
             let res = if let Some(new_weights) = get_new_weights(&peers, &storage.state) {
                 (|| -> Result<(), Error> {
-                    let res = call_all(
-                        &mut pool,
-                        peers.iter().map(|peer| peer.instance_id.clone()),
-                        sharding::Request {
-                            leader_id: raft_id,
-                            term,
-                            weights: Some(new_weights.clone()),
-                        },
-                        // TODO: don't hard code timeout
-                        Duration::from_secs(3),
-                    )?;
+                    let peer_ids = peers.iter().map(|peer| peer.instance_id.clone());
+                    let reqs = peer_ids.zip(repeat(sharding::Request {
+                        leader_id,
+                        term,
+                        weights: Some(new_weights.clone()),
+                        ..Default::default()
+                    }));
+                    // TODO: don't hard code timeout
+                    let res = call_all(&mut pool, reqs, Duration::from_secs(3))?;
 
                     let cluster_id = storage.raft.cluster_id()?.unwrap();
                     for (peer_iid, resp) in res {
@@ -1348,12 +1346,11 @@ fn raft_conf_change_loop(status: Rc<Cell<Status>>, storage: Storage) {
     #[allow(clippy::type_complexity)]
     fn call_all<R, I>(
         pool: &mut ConnectionPool,
-        ids: impl IntoIterator<Item = I>,
-        req: R,
+        reqs: impl IntoIterator<Item = (I, R)>,
         timeout: Duration,
     ) -> Result<Vec<(I, Result<R::Response, Error>)>, Error>
     where
-        R: traft::rpc::Request + Clone,
+        R: traft::rpc::Request,
         I: traft::network::PeerId + Clone + std::fmt::Debug + 'static,
     {
         // TODO: this crap is only needed to wait until results of all
@@ -1364,17 +1361,17 @@ fn raft_conf_change_loop(status: Rc<Cell<Status>>, storage: Storage) {
         // - using the std Futures we could use futures::join!
         //
         // Those things aren't implemented yet, so this is what we do
-        let ids = ids.into_iter().collect::<Vec<_>>();
+        let reqs = reqs.into_iter().collect::<Vec<_>>();
         static mut SENT_COUNT: usize = 0;
         unsafe { SENT_COUNT = 0 };
         let (cond_rx, cond_tx) = Rc::new(fiber::Cond::new()).into_clones();
-        let peer_count = ids.len();
+        let peer_count = reqs.len();
         let (rx, tx) = fiber::Channel::new(peer_count as _).into_clones();
-        for id in &ids {
+        for (id, req) in reqs {
             let tx = tx.clone();
             let cond_tx = cond_tx.clone();
             let id_copy = id.clone();
-            pool.call(id, req.clone(), move |res| {
+            pool.call(&id, req, move |res| {
                 tx.send((id_copy, res)).expect("mustn't fail");
                 unsafe { SENT_COUNT += 1 };
                 if unsafe { SENT_COUNT } == peer_count {
diff --git a/src/traft/rpc/sharding.rs b/src/traft/rpc/sharding.rs
index aabcd75991..133c303a3f 100644
--- a/src/traft/rpc/sharding.rs
+++ b/src/traft/rpc/sharding.rs
@@ -48,7 +48,7 @@ fn proc_sharding(req: Request) -> Result<Response, Error> {
 }
 
 /// Request to configure vshard.
-#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
+#[derive(Clone, Default, Debug, serde::Serialize, serde::Deserialize)]
 pub struct Request {
     pub leader_id: RaftId,
     pub term: RaftTerm,
-- 
GitLab