From fb30acd1b8cf474c4333554587899d4d479b8feb Mon Sep 17 00:00:00 2001
From: Georgy Moshkin <gmoshkin@picodata.io>
Date: Mon, 31 Jul 2023 12:04:46 +0300
Subject: [PATCH] refactor: reuse a single ConnectionPool instance between
 raft node & governor

---
 src/governor/mod.rs  |  29 +++++------
 src/sync.rs          |  17 +++++--
 src/traft/network.rs | 117 +++++++++++++++++++++++-------------------
 src/traft/node.rs    |  26 ++++++----
 4 files changed, 108 insertions(+), 81 deletions(-)

diff --git a/src/governor/mod.rs b/src/governor/mod.rs
index c79f85803d..0f61d4cb03 100644
--- a/src/governor/mod.rs
+++ b/src/governor/mod.rs
@@ -1,4 +1,5 @@
 use std::collections::HashMap;
+use std::rc::Rc;
 use std::time::Duration;
 
 use ::tarantool::fiber;
@@ -14,7 +15,7 @@ use crate::storage::Clusterwide;
 use crate::storage::ToEntryIter as _;
 use crate::tlog;
 use crate::traft::error::Error;
-use crate::traft::network::{ConnectionPool, WorkerOptions};
+use crate::traft::network::ConnectionPool;
 use crate::traft::node::global;
 use crate::traft::node::Status;
 use crate::traft::raft_storage::RaftSpaceAccess;
@@ -30,6 +31,7 @@ use plan::action_plan;
 use plan::stage::*;
 
 impl Loop {
+    const RPC_TIMEOUT: Duration = Duration::from_secs(1);
     const SYNC_TIMEOUT: Duration = Duration::from_secs(10);
     const RETRY_TIMEOUT: Duration = Duration::from_millis(250);
     const UPDATE_INSTANCE_TIMEOUT: Duration = Duration::from_secs(3);
@@ -166,7 +168,7 @@ impl Loop {
                     "replicaset_id" => %replicaset_id,
                 ]
                 async {
-                    pool.call(instance_id, &rpc)?
+                    pool.call(instance_id, &rpc, Self::RPC_TIMEOUT)?
                         // TODO: don't hard code timeout
                         .timeout(Duration::from_secs(3))
                         .await?
@@ -197,7 +199,7 @@ impl Loop {
             let mut fs = vec![];
             for instance_id in targets {
                 tlog!(Info, "calling rpc::sharding"; "instance_id" => %instance_id);
-                let resp = pool.call(instance_id, &rpc)?;
+                let resp = pool.call(instance_id, &rpc, Self::RPC_TIMEOUT)?;
                 fs.push(async move {
                     resp.await.map_err(|e| {
                         tlog!(Warning, "failed calling rpc::sharding: {e}";
@@ -237,7 +239,7 @@ impl Loop {
                     "replicaset_id" => %replicaset_id,
                 ]
                 async {
-                    pool.call(master_id, &rpc)?
+                    pool.call(master_id, &rpc, Self::RPC_TIMEOUT)?
                         .timeout(Duration::from_secs(3))
                         .await?
                 }
@@ -270,7 +272,7 @@ impl Loop {
             for instance_id in targets {
                 tlog!(Info, "calling rpc::replication"; "instance_id" => %instance_id);
                 rpc.is_master = instance_id == master_id;
-                let resp = pool.call(instance_id, &rpc)?;
+                let resp = pool.call(instance_id, &rpc, Self::RPC_TIMEOUT)?;
                 fs.push(async move {
                     match resp.await {
                         Ok(resp) => {
@@ -314,7 +316,7 @@ impl Loop {
             let mut fs = vec![];
             for instance_id in targets {
                 tlog!(Info, "calling rpc::sharding"; "instance_id" => %instance_id);
-                let resp = pool.call(instance_id, &rpc)?;
+                let resp = pool.call(instance_id, &rpc, Self::RPC_TIMEOUT)?;
                 fs.push(async move {
                     match resp.await {
                         Ok(_) => {
@@ -357,7 +359,7 @@ impl Loop {
                 ]
                 async {
                     pool
-                        .call(target, &rpc)?
+                        .call(target, &rpc, Self::RPC_TIMEOUT)?
                         .timeout(Loop::SYNC_TIMEOUT)
                         .await?;
                     node.propose_and_wait(op, Duration::from_secs(3))??
@@ -395,7 +397,7 @@ impl Loop {
             let mut fs = vec![];
             for instance_id in targets {
                 tlog!(Info, "calling rpc::sharding"; "instance_id" => %instance_id);
-                let resp = pool.call(instance_id, &rpc)?;
+                let resp = pool.call(instance_id, &rpc, Self::RPC_TIMEOUT)?;
                 fs.push(async move {
                     match resp.await {
                         Ok(_) => {
@@ -450,7 +452,7 @@ impl Loop {
             let mut fs = vec![];
             for instance_id in targets {
                 tlog!(Info, "calling proc_apply_schema_change"; "instance_id" => %instance_id);
-                let resp = pool.call(instance_id, &rpc)?;
+                let resp = pool.call(instance_id, &rpc, Self::RPC_TIMEOUT)?;
                 fs.push(async move {
                     match resp.await {
                         Ok(rpc::ddl_apply::Response::Ok) => {
@@ -519,6 +521,7 @@ impl Loop {
     }
 
     pub fn start(
+        pool: Rc<ConnectionPool>,
        status: watch::Receiver<Status>,
        storage: Clusterwide,
        raft_storage: RaftSpaceAccess,
@@ -530,12 +533,6 @@ impl Loop {
 
         let (waker_tx, waker_rx) = watch::channel(());
 
-        let opts = WorkerOptions {
-            call_timeout: Duration::from_secs(1),
-            ..Default::default()
-        };
-        let pool = ConnectionPool::new(args.storage.clone(), opts);
-
         let state = State {
             status,
             waker: waker_rx,
@@ -566,5 +563,5 @@ struct Args {
 struct State {
     status: watch::Receiver<Status>,
     waker: watch::Receiver<()>,
-    pool: ConnectionPool,
+    pool: Rc<ConnectionPool>,
 }
diff --git a/src/sync.rs b/src/sync.rs
index e088ed5a01..9c2cf7753e 100644
--- a/src/sync.rs
+++ b/src/sync.rs
@@ -34,8 +34,13 @@ pub async fn call_get_vclock(
     pool: &ConnectionPool,
     instance_id: &impl IdOfInstance,
 ) -> traft::Result<Vclock> {
-    let (vclock,): (Vclock,) = pool
-        .call_raw(instance_id, crate::stringify_cfunc!(proc_get_vclock), &())?
+    let vclock: Vclock = pool
+        .call_raw(
+            instance_id,
+            crate::stringify_cfunc!(proc_get_vclock),
+            &(),
+            None,
+        )?
         .await?;
     Ok(vclock)
 }
@@ -109,7 +114,12 @@ pub async fn call_get_index(
     instance_id: &impl IdOfInstance,
 ) -> traft::Result<RaftIndex> {
-    let (index,): (RaftIndex,) = pool
-        .call_raw(instance_id, crate::stringify_cfunc!(proc_get_index), &())?
+    let index: RaftIndex = pool
+        .call_raw(
+            instance_id,
+            crate::stringify_cfunc!(proc_get_index),
+            &(),
+            None,
+        )?
         .await?;
     Ok(index)
 }
@@ -202,6 +212,7 @@ mod tests {
                 target: Vclock::current(),
                 timeout: 1.0,
             },
+            None,
         )
         .unwrap()
         .await
diff --git a/src/traft/network.rs b/src/traft/network.rs
index 311fd48242..f4c7a24c85 100644
--- a/src/traft/network.rs
+++ b/src/traft/network.rs
@@ -57,6 +57,7 @@ impl Default for WorkerOptions {
 struct Request {
     proc: &'static str,
     args: TupleBuffer,
+    timeout: Option<Duration>,
     on_result: Box<dyn FnOnce(Result<Tuple>)>,
 }
 
@@ -69,6 +70,7 @@ impl Request {
         Self {
             proc,
             args,
+            timeout: None,
             on_result: Box::new(on_result),
         }
     }
@@ -184,7 +186,10 @@ impl PoolWorker {
             Box::pin(async move {
                 client
                     .call(request.proc, &request.args)
-                    .timeout(call_timeout)
+                    // TODO: it would be better to take a deadline from the
+                    // caller instead of a timeout, so that we could bound
+                    // the duration of a given rpc request more accurately.
+                    .timeout(request.timeout.unwrap_or(call_timeout))
                     .await
             }),
         ));
@@ -265,28 +270,16 @@ impl PoolWorker {
     /// - in case peer was disconnected
     /// - in case response failed to deserialize
     /// - in case peer responded with an error
-    pub fn rpc<R>(&self, request: &R, cb: impl FnOnce(Result<R::Response>) + 'static)
-    where
+    #[inline(always)]
+    pub fn rpc<R>(
+        &self,
+        request: &R,
+        timeout: Option<Duration>,
+        cb: impl FnOnce(Result<R::Response>) + 'static,
+    ) where
         R: rpc::RequestArgs,
     {
-        let args = unwrap_ok_or!(request.to_tuple_buffer(),
-            Err(e) => { return cb(Err(e.into())) }
-        );
-        let convert_result = |bytes: Result<Tuple>| {
-            let tuple: Tuple = bytes?;
-            let ((res,),) = tuple.decode()?;
-            Ok(res)
-        };
-        self.inbox
-            .send(Request::new(R::PROC_NAME, args, move |res| {
-                cb(convert_result(res))
-            }));
-        if self.inbox_ready.send(()).is_err() {
-            tlog!(
-                Warning,
-                "failed sending request to peer, worker loop receiver dropped"
-            );
-        }
+        self.rpc_raw(R::PROC_NAME, request, timeout, cb)
     }
 
     /// Send an RPC `request` and invoke `cb` whenever the result is ready.
@@ -300,6 +293,7 @@ impl PoolWorker {
         &self,
         proc: &'static str,
         args: &Args,
+        timeout: Option<Duration>,
         cb: impl FnOnce(Result<Response>) + 'static,
     ) where
         Args: ToTupleBuffer,
@@ -310,11 +304,31 @@ impl PoolWorker {
         );
         let convert_result = |bytes: Result<Tuple>| {
             let tuple: Tuple = bytes?;
-            let (res,) = tuple.decode()?;
+            // NOTE: this double layer of single-element tuples is
+            // intentional. Tarantool wraps all values returned from stored
+            // procs into a msgpack array. This is true for both native and
+            // lua procs. However, in the case of lua this outermost array
+            // serves multiple return values: if a lua proc returns 2
+            // values, the msgpack representation is an array of 2 elements.
+            // But for some reason native procs get a different treatment:
+            // their return values are always wrapped in an outermost array
+            // of one element (!) which contains an array whose elements are
+            // the actual values returned by the proc. To be exact, each
+            // element of this inner array is a value passed to
+            // box_return_mp, so there are exactly as many elements as there
+            // were calls to box_return_mp during execution of the proc.
+            // The way #[tarantool::proc] is implemented, we always call
+            // box_return_mp exactly once, so procs defined this way always
+            // have their return value wrapped in 2 layers of single-element
+            // arrays.
+            // For that reason we unwrap those 2 layers here and pass to the
+            // user just the value they returned from their #[tarantool::proc].
+            let ((res,),) = tuple.decode()?;
             Ok(res)
         };
-        self.inbox
-            .send(Request::new(proc, args, move |res| cb(convert_result(res))));
+        let mut request = Request::new(proc, args, move |res| cb(convert_result(res)));
+        request.timeout = timeout;
+        self.inbox.send(request);
         if self.inbox_ready.send(()).is_err() {
             tlog!(
                 Warning,
@@ -437,34 +451,21 @@ impl ConnectionPool {
     /// Send a request to instance with `id` (see `IdOfInstance`) returning a
     /// future.
     ///
+    /// If `timeout` is `None`, `WorkerOptions::call_timeout` is used.
+    ///
     /// If the request failed, it's a responsibility of the caller
     /// to re-send it later.
+    #[inline(always)]
     pub fn call<R>(
         &self,
         id: &impl IdOfInstance,
         req: &R,
+        timeout: impl Into<Option<Duration>>,
     ) -> Result<impl Future<Output = Result<R::Response>>>
     where
         R: rpc::RequestArgs,
     {
-        let (tx, mut rx) = oneshot::channel();
-        id.get_or_create_in(self)?.rpc(req, move |res| {
-            if tx.send(res).is_err() {
-                tlog!(
-                    Debug,
-                    "rpc response ignored because caller dropped the future"
-                )
-            }
-        });
-
-        // We use an explicit type implementing Future instead of defining an
-        // async fn, because we need to tell rust explicitly that the `id` &
-        // `req` arguments are not borrowed by the returned future.
-        let f = poll_fn(move |cx| {
-            let rx = Pin::new(&mut rx);
-            Future::poll(rx, cx).map(|r| r.unwrap_or_else(|_| Err(Error::other("disconnected"))))
-        });
-        Ok(f)
+        self.call_raw(id, R::PROC_NAME, req, timeout.into())
     }
 
     /// Call an rpc on instance with `id` (see `IdOfInstance`) returning a
@@ -473,6 +474,8 @@ impl ConnectionPool {
     /// This method is similar to [`Self::call`] but allows to call rpcs
     /// without using [`rpc::Request`] trait.
     ///
+    /// If `timeout` is `None`, `WorkerOptions::call_timeout` is used.
+    ///
     /// If the request failed, it's a responsibility of the caller
     /// to re-send it later.
     pub fn call_raw<Args, Response>(
@@ -480,20 +483,22 @@ impl ConnectionPool {
         id: &impl IdOfInstance,
         proc: &'static str,
         args: &Args,
+        timeout: Option<Duration>,
     ) -> Result<impl Future<Output = Result<Response>>>
     where
         Response: serde::de::DeserializeOwned + 'static,
         Args: ToTupleBuffer,
     {
         let (tx, mut rx) = oneshot::channel();
-        id.get_or_create_in(self)?.rpc_raw(proc, args, move |res| {
-            if tx.send(res).is_err() {
-                tlog!(
-                    Debug,
-                    "rpc response ignored because caller dropped the future"
-                )
-            }
-        });
+        id.get_or_create_in(self)?
+            .rpc_raw(proc, args, timeout, move |res| {
+                if tx.send(res).is_err() {
+                    tlog!(
+                        Debug,
+                        "rpc response ignored because caller dropped the future"
+                    )
+                }
+            });
 
         // We use an explicit type implementing Future instead of defining an
         // async fn, because we need to tell rust explicitly that the `id` &
@@ -567,9 +572,15 @@ mod tests {
         l.exec(
             r#"
             function test_stored_proc(a, b)
-                return a + b
+                -- Tarantool always wraps return values from native stored
+                -- procs into an additional array, but doesn't do that for
+                -- lua procs. Our network module expects that additional
+                -- layer of array, and since this test proc is intended to
+                -- emulate one of our rust procs, we explicitly add a table
+                -- layer here.
+                return {a + b}
             end
-            
+
             box.schema.func.create('test_stored_proc')
         "#,
         )
@@ -591,7 +602,7 @@ mod tests {
         .unwrap();
 
         let result: u32 = fiber::block_on(
-            pool.call_raw(&instance.raft_id, "test_stored_proc", &(1u32, 2u32))
+            pool.call_raw(&instance.raft_id, "test_stored_proc", &(1u32, 2u32), None)
                 .unwrap(),
         )
         .unwrap();
diff --git a/src/traft/node.rs b/src/traft/node.rs
index 5ab2eabe30..38aa6bab71 100644
--- a/src/traft/node.rs
+++ b/src/traft/node.rs
@@ -161,7 +161,20 @@ impl Node {
     /// **This function yields**
     pub fn new(storage: Clusterwide, raft_storage: RaftSpaceAccess) -> Result<Self, RaftError> {
         let topology = Rc::new(RefCell::new(Topology::from(storage.clone())));
-        let node_impl = NodeImpl::new(storage.clone(), raft_storage.clone(), topology.clone())?;
+
+        let opts = WorkerOptions {
+            raft_msg_handler: stringify_cfunc!(proc_raft_interact),
+            call_timeout: MainLoop::TICK.saturating_mul(4),
+            ..Default::default()
+        };
+        let pool = Rc::new(ConnectionPool::new(storage.clone(), opts));
+
+        let node_impl = NodeImpl::new(
+            pool.clone(),
+            storage.clone(),
+            raft_storage.clone(),
+            topology.clone(),
+        )?;
 
         let raft_id = node_impl.raft_id();
         let status = node_impl.status.subscribe();
@@ -173,6 +186,7 @@ impl Node {
             raft_id,
             main_loop: MainLoop::start(node_impl.clone(), watchers.clone()), // yields
             governor_loop: governor::Loop::start(
+                pool,
                 status.clone(),
                 storage.clone(),
                 raft_storage.clone(),
@@ -535,13 +549,14 @@ pub(crate) struct NodeImpl {
     joint_state_latch: KVCell<RaftIndex, oneshot::Sender<Result<(), RaftError>>>,
     storage: Clusterwide,
     raft_storage: RaftSpaceAccess,
-    pool: ConnectionPool,
+    pool: Rc<ConnectionPool>,
     lc: LogicalClock,
     status: watch::Sender<Status>,
 }
 
 impl NodeImpl {
     fn new(
+        pool: Rc<ConnectionPool>,
         storage: Clusterwide,
         raft_storage: RaftSpaceAccess,
         topology: Rc<RefCell<Topology>>,
@@ -559,13 +574,6 @@ impl NodeImpl {
             LogicalClock::new(raft_id, gen)
         };
 
-        let opts = WorkerOptions {
-            raft_msg_handler: stringify_cfunc!(proc_raft_interact),
-            call_timeout: MainLoop::TICK.saturating_mul(4),
-            ..Default::default()
-        };
-        let pool = ConnectionPool::new(storage.clone(), opts);
-
         let cfg = raft::Config {
             id: raft_id,
             applied,
-- 
GitLab
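
Two standalone sketches follow (not part of the patch); names below are
illustrative, not from the picodata codebase.

The new `call`/`call_raw` signatures take the timeout as
`impl Into<Option<Duration>>`, which is why call sites can pass either a bare
`Self::RPC_TIMEOUT` or `None` (falling back to `WorkerOptions::call_timeout`).
A minimal model of that fallback, with `resolve_timeout` as a hypothetical
stand-in for the worker's `request.timeout.unwrap_or(call_timeout)`:

    use std::time::Duration;

    // Stand-in for WorkerOptions::call_timeout.
    const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_secs(1);

    // `Option<T>` implements `From<T>`, so callers may pass a bare
    // `Duration` or `None` without writing `Some(..)`.
    fn resolve_timeout(timeout: impl Into<Option<Duration>>) -> Duration {
        timeout.into().unwrap_or(DEFAULT_CALL_TIMEOUT)
    }

    fn main() {
        assert_eq!(resolve_timeout(Duration::from_secs(3)), Duration::from_secs(3));
        assert_eq!(resolve_timeout(None), DEFAULT_CALL_TIMEOUT);
    }

The double destructuring in `rpc_raw` (`let ((res,),) = tuple.decode()?;`)
unwraps the two array layers described in the NOTE comment: a native proc
returning `3` arrives on the wire as msgpack `[[3]]`. The same unwrapping,
sketched with the `rmp-serde` crate in place of tarantool's `Tuple::decode`:

    // Encode the wire format a native proc produces, then peel off both
    // single-element array layers, keeping only the inner value.
    fn main() {
        let wire = rmp_serde::to_vec(&((3u32,),)).unwrap();
        let ((res,),): ((u32,),) = rmp_serde::from_slice(&wire).unwrap();
        assert_eq!(res, 3);
    }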