From ec0ce302efe21d51c599eef2962e2b7f6798ba8a Mon Sep 17 00:00:00 2001 From: Georgy Moshkin <gmoshkin@picodata.io> Date: Thu, 20 Oct 2022 15:11:00 +0300 Subject: [PATCH] refactor(governor): support passing different requests in call_all --- src/traft/node.rs | 77 +++++++++++++++++++-------------------- src/traft/rpc/sharding.rs | 2 +- 2 files changed, 38 insertions(+), 41 deletions(-) diff --git a/src/traft/node.rs b/src/traft/node.rs index 3647629422..a6be8aac09 100644 --- a/src/traft/node.rs +++ b/src/traft/node.rs @@ -20,6 +20,7 @@ use std::cell::Cell; use std::collections::HashMap; use std::collections::HashSet; use std::convert::TryFrom; +use std::iter::repeat; use std::rc::Rc; use std::time::Duration; use std::time::Instant; @@ -1004,7 +1005,7 @@ fn raft_conf_change_loop(status: Rc<Cell<Status>>, storage: Storage) { continue 'governor; } - let raft_id = status.get().id; + let leader_id = status.get().id; let peers = storage.peers.all_peers().unwrap(); let term = storage.raft.term().unwrap().unwrap_or(0); let node = global().expect("must be initialized"); @@ -1128,18 +1129,17 @@ fn raft_conf_change_loop(status: Rc<Cell<Status>>, storage: Storage) { ); let replicaset_size = replicaset_iids.len(); - let res = call_all( - &mut pool, - replicaset_iids.clone(), - replication::Request { - replicaset_instances: replicaset_iids, + let reqs = replicaset_iids + .iter() + .cloned() + .zip(repeat(replication::Request { + replicaset_instances: replicaset_iids.clone(), replicaset_id: replicaset_id.clone(), // TODO: what if someone goes offline/expelled? promote: replicaset_size == 1, - }, - // TODO: don't hard code timeout - Duration::from_secs(3), - ); + })); + // TODO: don't hard code timeout + let res = call_all(&mut pool, reqs, Duration::from_secs(3)); let res = unwrap_ok_or!(res, Err(e) => { tlog!(Warning, "failed to configure replication: {e}"); @@ -1226,18 +1226,18 @@ fn raft_conf_change_loop(status: Rc<Cell<Status>>, storage: Storage) { if need_sharding { let res = (|| -> Result<(), Error> { // TODO: filter out Offline & Expelled peers - let peer_ids = peers.iter().map(|peer| peer.instance_id.clone()); - let res = call_all( - &mut pool, - peer_ids, - sharding::Request { - leader_id: raft_id, - term, - weights: None, - }, - // TODO: don't hard code timeout - Duration::from_secs(3), - )?; + let reqs = peers.iter().map(|peer| { + ( + peer.instance_id.clone(), + sharding::Request { + leader_id, + term, + ..Default::default() + }, + ) + }); + // TODO: don't hard code timeout + let res = call_all(&mut pool, reqs, Duration::from_secs(3))?; let cluster_id = storage .raft @@ -1276,17 +1276,15 @@ fn raft_conf_change_loop(status: Rc<Cell<Status>>, storage: Storage) { if maybe_need_weights_update { let res = if let Some(new_weights) = get_new_weights(&peers, &storage.state) { (|| -> Result<(), Error> { - let res = call_all( - &mut pool, - peers.iter().map(|peer| peer.instance_id.clone()), - sharding::Request { - leader_id: raft_id, - term, - weights: Some(new_weights.clone()), - }, - // TODO: don't hard code timeout - Duration::from_secs(3), - )?; + let peer_ids = peers.iter().map(|peer| peer.instance_id.clone()); + let reqs = peer_ids.zip(repeat(sharding::Request { + leader_id, + term, + weights: Some(new_weights.clone()), + ..Default::default() + })); + // TODO: don't hard code timeout + let res = call_all(&mut pool, reqs, Duration::from_secs(3))?; let cluster_id = storage.raft.cluster_id()?.unwrap(); for (peer_iid, resp) in res { @@ -1348,12 +1346,11 @@ fn raft_conf_change_loop(status: Rc<Cell<Status>>, storage: Storage) { #[allow(clippy::type_complexity)] fn call_all<R, I>( pool: &mut ConnectionPool, - ids: impl IntoIterator<Item = I>, - req: R, + reqs: impl IntoIterator<Item = (I, R)>, timeout: Duration, ) -> Result<Vec<(I, Result<R::Response, Error>)>, Error> where - R: traft::rpc::Request + Clone, + R: traft::rpc::Request, I: traft::network::PeerId + Clone + std::fmt::Debug + 'static, { // TODO: this crap is only needed to wait until results of all @@ -1364,17 +1361,17 @@ fn raft_conf_change_loop(status: Rc<Cell<Status>>, storage: Storage) { // - using the std Futures we could use futures::join! // // Those things aren't implemented yet, so this is what we do - let ids = ids.into_iter().collect::<Vec<_>>(); + let reqs = reqs.into_iter().collect::<Vec<_>>(); static mut SENT_COUNT: usize = 0; unsafe { SENT_COUNT = 0 }; let (cond_rx, cond_tx) = Rc::new(fiber::Cond::new()).into_clones(); - let peer_count = ids.len(); + let peer_count = reqs.len(); let (rx, tx) = fiber::Channel::new(peer_count as _).into_clones(); - for id in &ids { + for (id, req) in reqs { let tx = tx.clone(); let cond_tx = cond_tx.clone(); let id_copy = id.clone(); - pool.call(id, req.clone(), move |res| { + pool.call(&id, req, move |res| { tx.send((id_copy, res)).expect("mustn't fail"); unsafe { SENT_COUNT += 1 }; if unsafe { SENT_COUNT } == peer_count { diff --git a/src/traft/rpc/sharding.rs b/src/traft/rpc/sharding.rs index aabcd75991..133c303a3f 100644 --- a/src/traft/rpc/sharding.rs +++ b/src/traft/rpc/sharding.rs @@ -48,7 +48,7 @@ fn proc_sharding(req: Request) -> Result<Response, Error> { } /// Request to configure vshard. -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Default, Debug, serde::Serialize, serde::Deserialize)] pub struct Request { pub leader_id: RaftId, pub term: RaftTerm, -- GitLab