From 575a9f29bd5854ffffe864aa4909767f0611449d Mon Sep 17 00:00:00 2001
From: Georgy Moshkin <gmoshkin@picodata.io>
Date: Tue, 27 Aug 2024 23:30:13 +0300
Subject: [PATCH] fix: unify which term is used for cas predicates

CAS predicates should always contain the current raft term, because this
is what's explicitly checked in proc_cas.

Note that this will sometimes result in a EntryTermMismatch error
because the latest applied entry may have a different term in case the
election has started but not finished yet.

We could add this check to all the client cas precondition checks, but
it's not a big deal, because we handle the situation correctly anyways.
---
 src/cas.rs                 | 34 ++++++++++++++++++++++++++++++++--
 src/luamod.rs              |  9 +--------
 src/plugin/mod.rs          | 16 ++--------------
 src/rpc/join.rs            | 12 ++----------
 src/rpc/update_instance.rs | 12 ++----------
 src/schema.rs              | 20 ++++++++------------
 src/sql.rs                 | 16 ++--------------
 7 files changed, 49 insertions(+), 70 deletions(-)

diff --git a/src/cas.rs b/src/cas.rs
index c8953e9051..07152c3152 100644
--- a/src/cas.rs
+++ b/src/cas.rs
@@ -514,18 +514,48 @@ pub struct Predicate {
 }
 
 impl Predicate {
+    /// Constructs a predicate consisting of:
+    /// - current raft term
+    /// - provided `index`
+    /// - provided `ranges`
+    #[inline]
+    pub fn new(index: RaftIndex, ranges: impl Into<Vec<Range>>) -> Self {
+        let node = traft::node::global().expect("shouldn't be called before node is initialized");
+
+        Self {
+            index,
+            term: node.status().term,
+            ranges: ranges.into(),
+        }
+    }
+
+    /// Constructs a predicate consisting of:
+    /// - current raft term
+    /// - current applied index
+    /// - provided `ranges`
+    #[inline]
+    pub fn with_applied_index(ranges: impl Into<Vec<Range>>) -> Self {
+        let node = traft::node::global().expect("shouldn't be called before node is initialized");
+
+        Self {
+            index: node.get_index(),
+            term: node.status().term,
+            ranges: ranges.into(),
+        }
+    }
+
     pub fn from_lua_args(predicate: PredicateInLua) -> traft::Result<Self> {
         let node = traft::node::global()?;
         let (index, term) = if let Some(index) = predicate.index {
             if let Some(term) = predicate.term {
                 (index, term)
             } else {
-                let term = raft::Storage::term(&node.raft_storage, index)?;
+                let term = node.status().term;
                 (index, term)
             }
         } else {
             let index = node.get_index();
-            let term = raft::Storage::term(&node.raft_storage, index)?;
+            let term = node.status().term;
             (index, term)
         };
         Ok(Self {
diff --git a/src/luamod.rs b/src/luamod.rs
index d5c22604eb..60da8b8dcc 100644
--- a/src/luamod.rs
+++ b/src/luamod.rs
@@ -679,14 +679,7 @@ pub(crate) fn setup() {
                 let timeout = duration_from_secs_f64_clamped(timeout);
                 let deadline = fiber::clock().saturating_add(timeout);
 
-                let node = node::global()?;
-                let term = raft::Storage::term(&node.raft_storage, index)?;
-                let predicate = cas::Predicate {
-                    index,
-                    term,
-                    ranges: cas::schema_change_ranges().into(),
-                };
-
+                let predicate = cas::Predicate::new(index, cas::schema_change_ranges());
                 let req = crate::cas::Request::new(op, predicate, effective_user_id())?;
                 let res = cas::compare_and_swap(&req, false, deadline)?;
                 let res = res.no_retries()?;
diff --git a/src/plugin/mod.rs b/src/plugin/mod.rs
index b91ebcc862..6612394773 100644
--- a/src/plugin/mod.rs
+++ b/src/plugin/mod.rs
@@ -400,20 +400,13 @@ fn do_routing_table_cas(
     ranges: Vec<Range>,
     timeout: Duration,
 ) -> traft::Result<()> {
-    let node = node::global()?;
-    let raft_storage = &node.raft_storage;
-
     let deadline = fiber::clock().saturating_add(timeout);
     loop {
         let op = Op::BatchDml {
             ops: dml_ops.clone(),
         };
 
-        let predicate = cas::Predicate {
-            index: raft_storage.applied()?,
-            term: raft_storage.term()?,
-            ranges: ranges.clone(),
-        };
+        let predicate = cas::Predicate::with_applied_index(ranges.clone());
         let req = crate::cas::Request::new(op.clone(), predicate, ADMIN_ID)?;
         let res = cas::compare_and_swap(&req, true, deadline)?;
         if res.is_retriable_error() {
@@ -512,12 +505,7 @@ fn reenterable_plugin_cas_request(
             AlreadyApplied => return Ok(index),
         };
 
-        let term = raft::Storage::term(&node.raft_storage, index)?;
-        let predicate = cas::Predicate {
-            index,
-            term,
-            ranges: ranges.clone(),
-        };
+        let predicate = cas::Predicate::new(index, ranges.clone());
         // FIXME: access rules will be implemented in future release
         let current_user = effective_user_id();
         let req = crate::cas::Request::new(op.clone(), predicate, current_user)?;
diff --git a/src/rpc/join.rs b/src/rpc/join.rs
index d081f7dc9e..82f8809675 100644
--- a/src/rpc/join.rs
+++ b/src/rpc/join.rs
@@ -70,7 +70,6 @@ pub fn handle_join_request_and_wait(req: Request, timeout: Duration) -> Result<R
     let node = node::global()?;
     let cluster_id = node.raft_storage.cluster_id()?;
     let storage = &node.storage;
-    let raft_storage = &node.raft_storage;
     let guard = node.instances_update.lock();
 
     if req.cluster_id != cluster_id {
@@ -118,15 +117,8 @@ pub fn handle_join_request_and_wait(req: Request, timeout: Duration) -> Result<R
             cas::Range::new(ClusterwideTable::Tier),
             cas::Range::new(ClusterwideTable::Replicaset),
         ];
-        let cas_req = crate::cas::Request::new(
-            Op::BatchDml { ops },
-            cas::Predicate {
-                index: raft_storage.applied()?,
-                term: raft_storage.term()?,
-                ranges,
-            },
-            ADMIN_ID,
-        )?;
+        let predicate = cas::Predicate::with_applied_index(ranges);
+        let cas_req = crate::cas::Request::new(Op::BatchDml { ops }, predicate, ADMIN_ID)?;
         let res = cas::compare_and_swap(&cas_req, true, deadline)?;
         if let Some(e) = res.into_retriable_error() {
             crate::tlog!(Debug, "local CaS rejected: {e}");
diff --git a/src/rpc/update_instance.rs b/src/rpc/update_instance.rs
index 1ab8351f48..152b1cce87 100644
--- a/src/rpc/update_instance.rs
+++ b/src/rpc/update_instance.rs
@@ -121,7 +121,6 @@ pub fn handle_update_instance_request_in_governor_and_also_wait_too(
     let node = node::global()?;
     let cluster_id = node.raft_storage.cluster_id()?;
     let storage = &node.storage;
-    let raft_storage = &node.raft_storage;
     let guard = node.instances_update.lock();
 
     if req.cluster_id != cluster_id {
@@ -211,15 +210,8 @@ pub fn handle_update_instance_request_in_governor_and_also_wait_too(
             cas::Range::new(ClusterwideTable::Tier),
         ];
 
-        let cas_req = crate::cas::Request::new(
-            op,
-            cas::Predicate {
-                index: raft_storage.applied()?,
-                term: raft_storage.term()?,
-                ranges,
-            },
-            ADMIN_ID,
-        )?;
+        let predicate = cas::Predicate::with_applied_index(ranges);
+        let cas_req = crate::cas::Request::new(op, predicate, ADMIN_ID)?;
         let res = cas::compare_and_swap(&cas_req, true, deadline)?;
         if req.dont_retry {
             res.no_retries()?;
diff --git a/src/schema.rs b/src/schema.rs
index b9b35c5273..dfb2a44d55 100644
--- a/src/schema.rs
+++ b/src/schema.rs
@@ -2460,18 +2460,6 @@ pub fn abort_ddl(deadline: Instant) -> traft::Result<RaftIndex> {
         if node.storage.properties.pending_schema_change()?.is_none() {
             return Err(DdlError::NoPendingDdl.into());
         }
-        let index = node.get_index();
-        let term = raft::Storage::term(&node.raft_storage, index)?;
-        #[rustfmt::skip]
-        let predicate = cas::Predicate {
-            index,
-            term,
-            ranges: vec![
-                cas::Range::new(ClusterwideTable::Property).eq([PropertyName::PendingSchemaChange]),
-                cas::Range::new(ClusterwideTable::Property).eq([PropertyName::GlobalSchemaVersion]),
-                cas::Range::new(ClusterwideTable::Property).eq([PropertyName::NextSchemaVersion]),
-            ],
-        };
 
         let instance_id = node
             .raft_storage
@@ -2482,6 +2470,14 @@ pub fn abort_ddl(deadline: Instant) -> traft::Result<RaftIndex> {
             message: "explicit abort by user".into(),
             instance_id,
         };
+
+        #[rustfmt::skip]
+        let ranges = vec![
+            cas::Range::new(ClusterwideTable::Property).eq([PropertyName::PendingSchemaChange]),
+            cas::Range::new(ClusterwideTable::Property).eq([PropertyName::GlobalSchemaVersion]),
+            cas::Range::new(ClusterwideTable::Property).eq([PropertyName::NextSchemaVersion]),
+        ];
+        let predicate = cas::Predicate::with_applied_index(ranges);
         let req = cas::Request::new(Op::DdlAbort { cause }, predicate, effective_user_id())?;
         let res = cas::compare_and_swap(&req, true, deadline)?;
         match res {
diff --git a/src/sql.rs b/src/sql.rs
index 7e3f174268..03b9faa250 100644
--- a/src/sql.rs
+++ b/src/sql.rs
@@ -1374,13 +1374,8 @@ pub(crate) fn reenterable_schema_change_request(
         };
         let is_ddl_prepare = matches!(op, Op::DdlPrepare { .. });
 
-        let term = raft::Storage::term(&node.raft_storage, index)?;
         // TODO: Should look at https://git.picodata.io/picodata/picodata/picodata/-/issues/866.
-        let predicate = cas::Predicate {
-            index,
-            term,
-            ranges: cas::schema_change_ranges().into(),
-        };
+        let predicate = cas::Predicate::new(index, cas::schema_change_ranges());
         let req = crate::cas::Request::new(op, predicate, current_user)?;
         let res = cas::compare_and_swap(&req, true, deadline)?;
         let index = match res {
@@ -1472,9 +1467,6 @@ fn do_dml_on_global_tbl(mut query: Query<RouterRuntime>) -> traft::Result<Consum
 
     let raft_node = node::global()?;
     let raft_index = raft_node.get_index();
-    let raft_term = with_su(ADMIN_ID, || -> traft::Result<u64> {
-        Ok(raft::Storage::term(&raft_node.raft_storage, raft_index)?)
-    })??;
 
     // Materialize reading subtree and extract some needed data from Plan
     let (table_id, dml_kind, vtable) = {
@@ -1579,11 +1571,7 @@ fn do_dml_on_global_tbl(mut query: Query<RouterRuntime>) -> traft::Result<Consum
         let ops_count = ops.len();
         let op = crate::traft::op::Op::BatchDml { ops };
 
-        let predicate = Predicate {
-            index: raft_index,
-            term: raft_term,
-            ranges: vec![],
-        };
+        let predicate = Predicate::new(raft_index, []);
         let cas_req = crate::cas::Request::new(op, predicate, current_user)?;
         let res = crate::cas::compare_and_swap(&cas_req, true, deadline)?;
         res.no_retries()?;
-- 
GitLab