From 15327a9ff2edb21858122f9d1d274a5ff94fe7c1 Mon Sep 17 00:00:00 2001
From: Georgy Moshkin <gmoshkin@picodata.io>
Date: Tue, 27 Aug 2024 20:11:39 +0300
Subject: [PATCH] refactor: move wait_index into compare_and_swap

---
 src/cas.rs                 | 60 ++++++++++++++++++++++++++++++++++++--
 src/luamod.rs              | 11 ++++---
 src/plugin/ffi.rs          | 23 +++++++--------
 src/plugin/mod.rs          | 35 +++-------------------
 src/rpc/join.rs            | 21 +++----------
 src/rpc/update_instance.rs | 26 ++++-------------
 src/schema.rs              | 25 ++++------------
 src/sql.rs                 | 39 +++++--------------------
 8 files changed, 102 insertions(+), 138 deletions(-)

diff --git a/src/cas.rs b/src/cas.rs
index ae6d52669f..c8953e9051 100644
--- a/src/cas.rs
+++ b/src/cas.rs
@@ -66,8 +66,9 @@ pub fn check_admin_dml_prohibited(dml: &Dml, as_user: UserId) -> traft::Result<(
 /// It can also return general picodata errors in cases of faulty network or storage.
 pub fn compare_and_swap(
     request: &Request,
+    wait_index: bool,
     deadline: Instant,
-) -> traft::Result<(RaftIndex, RaftTerm)> {
+) -> traft::Result<CasResult> {
     let node = node::global()?;
 
     if let Op::BatchDml { ops } = &request.op {
@@ -100,8 +101,61 @@ pub fn compare_and_swap(
         res = fiber::block_on(future);
     }
 
-    let response = res?;
-    return Ok((response.index, response.term));
+    let response = crate::unwrap_ok_or!(res,
+        Err(e) => {
+            if e.is_retriable() {
+                return Ok(CasResult::RetriableError(e));
+            } else {
+                return Err(e);
+            }
+        }
+    );
+
+    if wait_index {
+        node.wait_index(response.index, deadline.duration_since(fiber::clock()))?;
+
+        let actual_term = raft::Storage::term(&node.raft_storage, response.index)?;
+        if response.term != actual_term {
+            // Leader has changed and the entry got rolled back, ok to retry.
+            return Ok(CasResult::RetriableError(TraftError::TermMismatch {
+                requested: response.term,
+                current: actual_term,
+            }));
+        }
+    }
+
+    return Ok(CasResult::Ok((response.index, response.term)));
+}
+
+#[must_use = "You must decide if you're retrying the error or returning it to user"]
+pub enum CasResult {
+    Ok((RaftIndex, RaftTerm)),
+    RetriableError(TraftError),
+}
+
+impl CasResult {
+    #[inline(always)]
+    pub fn into_retriable_error(self) -> Option<TraftError> {
+        match self {
+            Self::RetriableError(e) => Some(e),
+            Self::Ok(_) => None,
+        }
+    }
+
+    #[inline(always)]
+    pub fn is_retriable_error(&self) -> bool {
+        matches!(self, Self::RetriableError { .. })
+    }
+
+    /// Converts the result into `std::result::Result` for your convenience if
+    /// you want to return the retriable error to the user.
+    #[inline(always)]
+    pub fn no_retries(self) -> traft::Result<(RaftIndex, RaftTerm)> {
+        match self {
+            Self::Ok(v) => Ok(v),
+            Self::RetriableError(e) => Err(e),
+        }
+    }
 }
 
 fn proc_cas_local(req: &Request) -> Result<Response> {
diff --git a/src/luamod.rs b/src/luamod.rs
index 904933f069..d5c22604eb 100644
--- a/src/luamod.rs
+++ b/src/luamod.rs
@@ -1,7 +1,7 @@
 //! Lua API exported as `_G.pico`
 //!
 
-use crate::cas::{self, compare_and_swap};
+use crate::cas;
 use crate::config::PicodataConfig;
 use crate::instance::InstanceId;
 use crate::plugin::PluginIdentifier;
@@ -688,7 +688,8 @@ pub(crate) fn setup() {
                 };
 
                 let req = crate::cas::Request::new(op, predicate, effective_user_id())?;
-                let res = compare_and_swap(&req, deadline)?;
+                let res = cas::compare_and_swap(&req, false, deadline)?;
+                let res = res.no_retries()?;
                 Ok(res)
             },
         ),
@@ -1183,7 +1184,8 @@ pub(crate) fn setup() {
                 let predicate = cas::Predicate::from_lua_args(predicate.unwrap_or_default())?;
                 let req = crate::cas::Request::new(op.into(), predicate, su.original_user_id)?;
                 let deadline = fiber::clock().saturating_add(Duration::from_secs(3));
-                let (index, _) = compare_and_swap(&req, deadline)?;
+                let res = cas::compare_and_swap(&req, false, deadline)?;
+                let (index, _) = res.no_retries()?;
                 Ok(index)
             },
         ),
@@ -1219,7 +1221,8 @@ pub(crate) fn setup() {
                     su.original_user_id,
                 )?;
                 let deadline = fiber::clock().saturating_add(Duration::from_secs(3));
-                let (index, _) = compare_and_swap(&req, deadline)?;
+                let res = cas::compare_and_swap(&req, false, deadline)?;
+                let (index, _) = res.no_retries()?;
                 Ok(index)
             },
         ),
diff --git a/src/plugin/ffi.rs b/src/plugin/ffi.rs
index 0029aebda4..d20e5231b3 100644
--- a/src/plugin/ffi.rs
+++ b/src/plugin/ffi.rs
@@ -1,4 +1,4 @@
-use crate::cas::{compare_and_swap, Bound, Range, Request};
+use crate::cas::{Bound, Range};
 use crate::info::{InstanceInfo, RaftInfo, VersionInfo};
 use crate::instance::StateVariant;
 use crate::plugin::{rpc, PluginIdentifier};
@@ -17,7 +17,6 @@ use picoplugin::transport::rpc::server::FfiRpcHandler;
 use picoplugin::util::FfiSafeBytes;
 use sbroad::ir::value::double::Double;
 use sbroad::ir::value::{LuaValue, Tuple, Value};
-use std::time::Duration;
 use std::{mem, slice};
 use tarantool::datetime::Datetime;
 use tarantool::error::IntoBoxError;
@@ -212,20 +211,18 @@ extern "C" fn pico_ffi_cas(
     predicate: types::Predicate,
     timeout: RDuration,
 ) -> RResult<ROption<RTuple!(u64, u64)>, ()> {
+    let deadline = fiber::clock().saturating_add(timeout.into());
     let op = Op::from(op);
     let pred = cas::Predicate::from(predicate);
-    let timeout = Duration::from(timeout);
     let user_id = effective_user_id();
-    let request = match Request::new(op, pred, user_id) {
-        Ok(req) => req,
-        Err(e) => {
-            return error_into_tt_error(e);
-        }
-    };
-
-    let deadline = fiber::clock().saturating_add(timeout);
-    match compare_and_swap(&request, deadline) {
-        Ok((index, term)) => ROk(RSome(Tuple2(index, term))),
+    let res = (|| -> Result<_, _> {
+        let request = cas::Request::new(op, pred, user_id)?;
+        cas::compare_and_swap(&request, false, deadline)
+    })();
+    match res {
+        Ok(cas::CasResult::Ok((index, term))) => ROk(RSome(Tuple2(index, term))),
+        Ok(cas::CasResult::RetriableError(e)) => error_into_tt_error(e),
+        // FIXME: this is wrong, just return an error instead
         Err(traft::error::Error::Timeout) => ROk(RNone),
         Err(e) => error_into_tt_error(e),
     }
diff --git a/src/plugin/mod.rs b/src/plugin/mod.rs
index 4f536fa18a..b91ebcc862 100644
--- a/src/plugin/mod.rs
+++ b/src/plugin/mod.rs
@@ -31,7 +31,6 @@ use crate::traft::node::Node;
 use crate::traft::op::PluginRaftOp;
 use crate::traft::op::{Dml, Op};
 use crate::traft::{node, RaftIndex};
-use crate::unwrap_ok_or;
 use crate::util::effective_user_id;
 use crate::{cas, tlog, traft};
 
@@ -416,21 +415,8 @@ fn do_routing_table_cas(
             ranges: ranges.clone(),
         };
         let req = crate::cas::Request::new(op.clone(), predicate, ADMIN_ID)?;
-        let res = cas::compare_and_swap(&req, deadline);
-        let (index, term) = unwrap_ok_or!(res,
-            Err(e) => {
-                if e.is_retriable() {
-                    continue;
-                } else {
-                    return Err(e);
-                }
-            }
-        );
-
-        node.wait_index(index, deadline.duration_since(Instant::now_fiber()))?;
-
-        if term != raft::Storage::term(&node.raft_storage, index)? {
-            // Leader has changed and the entry got rolled back, retry.
+        let res = cas::compare_and_swap(&req, true, deadline)?;
+        if res.is_retriable_error() {
             continue;
         }
 
@@ -535,21 +521,8 @@ fn reenterable_plugin_cas_request(
         // FIXME: access rules will be implemented in future release
         let current_user = effective_user_id();
         let req = crate::cas::Request::new(op.clone(), predicate, current_user)?;
-        let res = cas::compare_and_swap(&req, deadline);
-        let (index, term) = unwrap_ok_or!(res,
-            Err(e) => {
-                if e.is_retriable() {
-                    continue;
-                } else {
-                    return Err(e);
-                }
-            }
-        );
-
-        node.wait_index(index, deadline.duration_since(Instant::now_fiber()))?;
-
-        if term != raft::Storage::term(&node.raft_storage, index)? {
-            // Leader has changed and the entry got rolled back, retry.
+        let res = cas::compare_and_swap(&req, true, deadline)?;
+        if res.is_retriable_error() {
             continue;
         }
 
diff --git a/src/rpc/join.rs b/src/rpc/join.rs
index 1d9c78c96d..d081f7dc9e 100644
--- a/src/rpc/join.rs
+++ b/src/rpc/join.rs
@@ -127,23 +127,10 @@ pub fn handle_join_request_and_wait(req: Request, timeout: Duration) -> Result<R
             },
             ADMIN_ID,
         )?;
-        let res = cas::compare_and_swap(&cas_req, deadline);
-        let (index, term) = crate::unwrap_ok_or!(res,
-            Err(e) => {
-                if e.is_retriable() {
-                    crate::tlog!(Debug, "local CaS rejected: {e}");
-                    fiber::sleep(Duration::from_millis(250));
-                    continue;
-                } else {
-                    return Err(e);
-                }
-            }
-        );
-
-        node.wait_index(index, deadline.duration_since(fiber::clock()))?;
-
-        if term != raft::Storage::term(&node.raft_storage, index)? {
-            // Leader has changed and the entry got rolled back, retry.
+        let res = cas::compare_and_swap(&cas_req, true, deadline)?;
+        if let Some(e) = res.into_retriable_error() {
+            crate::tlog!(Debug, "local CaS rejected: {e}");
+            fiber::sleep(Duration::from_millis(250));
             continue;
         }
 
diff --git a/src/rpc/update_instance.rs b/src/rpc/update_instance.rs
index bedae9763a..1ab8351f48 100644
--- a/src/rpc/update_instance.rs
+++ b/src/rpc/update_instance.rs
@@ -220,26 +220,12 @@ pub fn handle_update_instance_request_in_governor_and_also_wait_too(
             },
             ADMIN_ID,
         )?;
-        let res = cas::compare_and_swap(&cas_req, deadline);
-        let (index, term) = crate::unwrap_ok_or!(res,
-            Err(e) => {
-                if req.dont_retry {
-                    return Err(e);
-                }
-                if e.is_retriable() {
-                    crate::tlog!(Debug, "local CaS rejected: {e}");
-                    fiber::sleep(Duration::from_millis(250));
-                    continue;
-                } else {
-                    return Err(e);
-                }
-            }
-        );
-
-        node.wait_index(index, deadline.duration_since(fiber::clock()))?;
-
-        if term != raft::Storage::term(raft_storage, index)? {
-            // Leader has changed and the entry got rolled back, retry.
+        let res = cas::compare_and_swap(&cas_req, true, deadline)?;
+        if req.dont_retry {
+            res.no_retries()?;
+        } else if let Some(e) = res.into_retriable_error() {
+            crate::tlog!(Debug, "local CaS rejected: {e}");
+            fiber::sleep(Duration::from_millis(250));
             continue;
         }
 
diff --git a/src/schema.rs b/src/schema.rs
index aa54fc0c8a..b9b35c5273 100644
--- a/src/schema.rs
+++ b/src/schema.rs
@@ -1,5 +1,5 @@
 use crate::access_control::UserMetadataKind;
-use crate::cas::{self, compare_and_swap, Request};
+use crate::cas;
 use crate::config::DEFAULT_USERNAME;
 use crate::instance::InstanceId;
 use crate::pico_service::pico_service_password;
@@ -2482,25 +2482,12 @@ pub fn abort_ddl(deadline: Instant) -> traft::Result<RaftIndex> {
             message: "explicit abort by user".into(),
             instance_id,
         };
-        let req = Request::new(Op::DdlAbort { cause }, predicate, effective_user_id())?;
-        let res = compare_and_swap(&req, deadline);
-        let (index, term) = crate::unwrap_ok_or!(res,
-            Err(e) => {
-                if e.is_retriable() {
-                    continue;
-                } else {
-                    return Err(e);
-                }
-            }
-        );
-
-        node.wait_index(index, deadline.duration_since(fiber::clock()))?;
-
-        if raft::Storage::term(&node.raft_storage, index)? != term {
-            // leader switched - retry
-            continue;
+        let req = cas::Request::new(Op::DdlAbort { cause }, predicate, effective_user_id())?;
+        let res = cas::compare_and_swap(&req, true, deadline)?;
+        match res {
+            cas::CasResult::RetriableError(_) => continue,
+            cas::CasResult::Ok((index, _)) => return Ok(index),
         }
-        return Ok(index);
     }
 }
 
diff --git a/src/sql.rs b/src/sql.rs
index 970a3140a5..7e3f174268 100644
--- a/src/sql.rs
+++ b/src/sql.rs
@@ -16,7 +16,7 @@ use crate::traft::node::Node as TraftNode;
 use crate::traft::op::{Acl as OpAcl, Ddl as OpDdl, Dml, DmlKind, Op};
 use crate::traft::{self, node};
 use crate::util::{duration_from_secs_f64_clamped, effective_user_id};
-use crate::{cas, tlog, unwrap_ok_or};
+use crate::{cas, tlog};
 
 use opentelemetry::{baggage::BaggageExt, Context, KeyValue};
 use sbroad::debug;
@@ -1382,23 +1382,11 @@ pub(crate) fn reenterable_schema_change_request(
             ranges: cas::schema_change_ranges().into(),
         };
         let req = crate::cas::Request::new(op, predicate, current_user)?;
-        let res = cas::compare_and_swap(&req, deadline);
-        let (index, term) = unwrap_ok_or!(res,
-            Err(e) => {
-                if e.is_retriable() {
-                    continue 'retry;
-                } else {
-                    return Err(e);
-                }
-            }
-        );
-
-        node.wait_index(index, deadline.duration_since(Instant::now_fiber()))?;
-
-        if term != raft::Storage::term(&node.raft_storage, index)? {
-            // Leader has changed and the entry got rolled back, retry.
-            continue 'retry;
-        }
+        let res = cas::compare_and_swap(&req, true, deadline)?;
+        let index = match res {
+            cas::CasResult::Ok((index, _)) => index,
+            cas::CasResult::RetriableError(_) => continue,
+        };
 
         if is_ddl_prepare {
             wait_for_ddl_commit(index, deadline.duration_since(Instant::now_fiber()))?;
@@ -1586,7 +1574,6 @@ fn do_dml_on_global_tbl(mut query: Query<RouterRuntime>) -> traft::Result<Consum
     // there.
     with_su(ADMIN_ID, || -> traft::Result<ConsumerResult> {
         let timeout = Duration::from_secs(DEFAULT_QUERY_TIMEOUT);
-        let node = node::global()?;
         let deadline = Instant::now_fiber().saturating_add(timeout);
 
         let ops_count = ops.len();
@@ -1598,18 +1585,8 @@ fn do_dml_on_global_tbl(mut query: Query<RouterRuntime>) -> traft::Result<Consum
             ranges: vec![],
         };
         let cas_req = crate::cas::Request::new(op, predicate, current_user)?;
-        let (index, term) = crate::cas::compare_and_swap(&cas_req, deadline)?;
-
-        node.wait_index(index, deadline.duration_since(Instant::now_fiber()))?;
-
-        let actual_term = raft::Storage::term(&raft_node.raft_storage, index)?;
-        if term != actual_term {
-            // Leader has changed and the entry got rolled back.
-            return Err(Error::TermMismatch {
-                requested: term,
-                current: actual_term,
-            });
-        }
+        let res = crate::cas::compare_and_swap(&cas_req, true, deadline)?;
+        res.no_retries()?;
 
         Ok(ConsumerResult {
             row_count: ops_count as u64,
-- 
GitLab