From 442d0dfd97385ec129b76a3923a490e3de36cc4f Mon Sep 17 00:00:00 2001 From: Egor Ivkov <e.o.ivkov@gmail.com> Date: Thu, 13 Jul 2023 13:38:50 +0300 Subject: [PATCH] feat: add async cas and allow local execution --- src/cas.rs | 311 +++++++++++++++++++++++++++++--------------------- src/schema.rs | 15 +-- 2 files changed, 191 insertions(+), 135 deletions(-) diff --git a/src/cas.rs b/src/cas.rs index 5a2ab9c38c..4ffe908bb2 100644 --- a/src/cas.rs +++ b/src/cas.rs @@ -12,6 +12,7 @@ use crate::traft::Result; use crate::traft::{EntryContext, EntryContextNormal}; use crate::traft::{RaftIndex, RaftTerm}; use crate::unwrap_ok_or; +use crate::util; use ::raft::prelude as raft; use ::raft::Error as RaftError; @@ -37,6 +38,7 @@ const PROHIBITED_SPACES: &[ClusterwideSpaceId] = &[ ClusterwideSpaceId::Privilege, ]; +// FIXME: cas::Error will be returned as a string when rpc is called /// Performs a clusterwide compare and swap operation. /// /// E.g. it checks the `predicate` on leader and if no conflicting @@ -44,9 +46,12 @@ const PROHIBITED_SPACES: &[ClusterwideSpaceId] = &[ /// index and term. /// /// # Errors -/// See [`cas::Error`][Error] for CaS-specific errors. It can also return -/// general picodata errors in cases of faulty network or storage. -pub fn compare_and_swap(op: Op, predicate: Predicate) -> traft::Result<(RaftIndex, RaftTerm)> { +/// See [`cas::Error`][Error] for CaS-specific errors. +/// It can also return general picodata errors in cases of faulty network or storage. +pub async fn compare_and_swap_async( + op: Op, + predicate: Predicate, +) -> traft::Result<(RaftIndex, RaftTerm)> { let node = node::global()?; let request = Request { cluster_id: node.raft_storage.cluster_id()?, @@ -63,157 +68,207 @@ pub fn compare_and_swap(op: Op, predicate: Predicate) -> traft::Result<(RaftInde node.storage.peer_addresses.try_get(leader_id), Err(e) => { tlog!(Warning, "failed getting leader address: {e}"); - tlog!(Info, "going to retry in while..."); - fiber::sleep(Duration::from_millis(250)); + tlog!(Info, "going to retry in a while..."); + util::sleep_async(Duration::from_millis(250)).await; continue; } ); - let resp = rpc::network_call(&leader_address, &request); - let resp = fiber::block_on(resp.timeout(Duration::from_secs(3))); + let resp = if leader_id == node.raft_id { + // cas has to be called locally in cases when listen ports are closed, + // for example on shutdown + proc_cas_local(request.clone()) + } else { + rpc::network_call(&leader_address, &request) + .await + .map_err(TraftError::from) + }; match resp { Ok(Response { index, term }) => return Ok((index, term)), Err(e) => { tlog!(Warning, "{e}"); - return Err(e.into()); + if e.is_not_leader_err() { + tlog!(Info, "going to retry in a while..."); + node.wait_status(); + continue; + } else { + return Err(e); + } } } } } -crate::define_rpc_request! { - fn proc_cas(req: Request) -> Result<Response> { - let node = node::global()?; - let raft_storage = &node.raft_storage; - let storage = &node.storage; - let cluster_id = raft_storage.cluster_id()?; - - if req.cluster_id != cluster_id { - return Err(TraftError::ClusterIdMismatch { - instance_cluster_id: req.cluster_id, - cluster_cluster_id: cluster_id, - }); - } +/// Performs a clusterwide compare and swap operation. +/// +/// E.g. it checks the `predicate` on leader and if no conflicting entries were found +/// appends the `op` to the raft log and returns its index and term. +/// +/// # Errors +/// See [`rpc::cas::Error`] for CaS-specific errors. 
+/// It can also return general picodata errors in cases of faulty network or storage. +pub fn compare_and_swap( + op: Op, + predicate: Predicate, + timeout: Duration, +) -> traft::Result<(RaftIndex, RaftTerm)> { + fiber::block_on(compare_and_swap_async(op, predicate).timeout(timeout)).map_err(Into::into) +} - let Predicate { index: requested, term: requested_term, .. } = req.predicate; - // N.B. lock the mutex before getting status - let mut node_impl = node.node_impl(); - let status = node.status(); +fn proc_cas_local(req: Request) -> Result<Response> { + let node = node::global()?; + let raft_storage = &node.raft_storage; + let storage = &node.storage; + let cluster_id = raft_storage.cluster_id()?; + + if req.cluster_id != cluster_id { + return Err(TraftError::ClusterIdMismatch { + instance_cluster_id: req.cluster_id, + cluster_cluster_id: cluster_id, + }); + } - if requested_term != status.term { - return Err(TraftError::TermMismatch { - requested: requested_term, - current: status.term, - }); - } - if status.leader_id != Some(node.raft_id()) { - // Nearly impossible error indicating invalid request. - return Err(TraftError::NotALeader); - } + let Predicate { + index: requested, + term: requested_term, + .. + } = req.predicate; + // N.B. lock the mutex before getting status + let mut node_impl = node.node_impl(); + let status = node.status(); + + if requested_term != status.term { + return Err(TraftError::TermMismatch { + requested: requested_term, + current: status.term, + }); + } + if status.leader_id != Some(node.raft_id()) { + // Nearly impossible error indicating invalid request. + return Err(TraftError::NotALeader); + } - let raft_log = &node_impl.raw_node.raft.raft_log; + let raft_log = &node_impl.raw_node.raft.raft_log; - let first = raft_log.first_index(); - let last = raft_log.last_index(); - assert!(first >= 1); - if requested > last { - return Err(Error::NoSuchIndex { requested, last_index: last }.into()); - } else if (requested + 1) < first { - return Err(Error::Compacted{ requested, compacted_index: first - 1 }.into()); + let first = raft_log.first_index(); + let last = raft_log.last_index(); + assert!(first >= 1); + if requested > last { + return Err(Error::NoSuchIndex { + requested, + last_index: last, + } + .into()); + } else if (requested + 1) < first { + return Err(Error::Compacted { + requested, + compacted_index: first - 1, } + .into()); + } - assert!(requested >= first - 1); - assert!(requested <= last); - assert_eq!(requested_term, status.term); - - // Also check that requested index actually belongs to the - // requested term. - let entry_term = raft_log.term(requested).unwrap_or(0); - if entry_term != status.term { - return Err(Error::EntryTermMismatch { - index: requested, - expected_term: status.term, - actual_term: entry_term, - }.into()); + assert!(requested >= first - 1); + assert!(requested <= last); + assert_eq!(requested_term, status.term); + + // Also check that requested index actually belongs to the + // requested term. + let entry_term = raft_log.term(requested).unwrap_or(0); + if entry_term != status.term { + return Err(Error::EntryTermMismatch { + index: requested, + expected_term: status.term, + actual_term: entry_term, } + .into()); + } - let last_persisted = raft::Storage::last_index(raft_storage)?; - assert!(last_persisted <= last); + let last_persisted = raft::Storage::last_index(raft_storage)?; + assert!(last_persisted <= last); - // Check if ranges in predicate contain prohibited spaces. 
- for range in &req.predicate.ranges { - let Ok(space) = ClusterwideSpaceId::try_from(range.space) else { continue; }; - if PROHIBITED_SPACES.contains(&space) - { - return Err(Error::SpaceNotAllowed { space: space.name().into() }.into()) + // Check if ranges in predicate contain prohibited spaces. + for range in &req.predicate.ranges { + let Ok(space) = ClusterwideSpaceId::try_from(range.space) else { continue; }; + if PROHIBITED_SPACES.contains(&space) { + return Err(Error::SpaceNotAllowed { + space: space.name().into(), } + .into()); } + } - // It's tempting to just use `raft_log.entries()` here and only - // write the body of the loop once, but this would mean - // converting entries from our storage representation to raft-rs - // and than back again for all the persisted entries, which we - // obviously don't want to do, so we instead write the body of - // the loop twice - - // General case: - // ,- last_persisted - // ,- first | ,- last - // entries: - - - - x x x x x x x x x x x x x x - // | [ persisted ] [ unstable ] - // | [ checked ] - // requested ^ - // - - // Corner case 1: - // ,- last_persisted - // | ,- first ,- last - // entries: - - - - x x x x x x x - // [ unstable ] - // - - if requested < last_persisted { // there's at least one persisted entry to check - let persisted = raft_storage.entries(requested + 1, last_persisted + 1)?; - if persisted.len() < (last_persisted - requested) as usize { - return Err(RaftError::Store(StorageError::Unavailable).into()); - } - - for entry in persisted { - assert_eq!(entry.term, status.term); - let entry_index = entry.index; - let Some(op) = entry.into_op() else { continue }; - req.predicate.check_entry(entry_index, &op, storage)?; - } + // It's tempting to just use `raft_log.entries()` here and only + // write the body of the loop once, but this would mean + // converting entries from our storage representation to raft-rs + // and than back again for all the persisted entries, which we + // obviously don't want to do, so we instead write the body of + // the loop twice + + // General case: + // ,- last_persisted + // ,- first | ,- last + // entries: - - - - x x x x x x x x x x x x x x + // | [ persisted ] [ unstable ] + // | [ checked ] + // requested ^ + // + + // Corner case 1: + // ,- last_persisted + // | ,- first ,- last + // entries: - - - - x x x x x x x + // [ unstable ] + // + + if requested < last_persisted { + // there's at least one persisted entry to check + let persisted = raft_storage.entries(requested + 1, last_persisted + 1)?; + if persisted.len() < (last_persisted - requested) as usize { + return Err(RaftError::Store(StorageError::Unavailable).into()); } - // Check remaining unstable entries. - let unstable = raft_log.entries(last_persisted + 1, u64::MAX)?; - for entry in unstable { + for entry in persisted { assert_eq!(entry.term, status.term); - let Ok(cx) = EntryContext::from_raft_entry(&entry) else { + let entry_index = entry.index; + let Some(op) = entry.into_op() else { continue }; + req.predicate.check_entry(entry_index, &op, storage)?; + } + } + + // Check remaining unstable entries. + let unstable = raft_log.entries(last_persisted + 1, u64::MAX)?; + for entry in unstable { + assert_eq!(entry.term, status.term); + let Ok(cx) = EntryContext::from_raft_entry(&entry) else { tlog!(Warning, "raft entry has invalid context"; "entry" => ?entry); continue; }; - let Some(EntryContext::Normal(EntryContextNormal { op, .. })) = cx else { + let Some(EntryContext::Normal(EntryContextNormal { op, .. 
})) = cx else { continue; }; - req.predicate.check_entry(entry.index, &op, storage)?; - } + req.predicate.check_entry(entry.index, &op, storage)?; + } - // TODO: apply to limbo first + // TODO: apply to limbo first - // Don't wait for the proposal to be accepted, instead return the index - // to the requestor, so that they can wait for it. + // Don't wait for the proposal to be accepted, instead return the index + // to the requestor, so that they can wait for it. - let notify = node_impl.propose_async(req.op)?; - let raft_log = &node_impl.raw_node.raft.raft_log; - let index = raft_log.last_index(); - let term = raft_log.term(index).unwrap(); - assert_eq!(index, last + 1); - assert_eq!(term, requested_term); - drop(node_impl); // unlock the mutex - drop(notify); // don't wait for commit + let notify = node_impl.propose_async(req.op)?; + let raft_log = &node_impl.raw_node.raft.raft_log; + let index = raft_log.last_index(); + let term = raft_log.term(index).unwrap(); + assert_eq!(index, last + 1); + assert_eq!(term, requested_term); + drop(node_impl); // unlock the mutex + drop(notify); // don't wait for commit - Ok(Response { index, term }) + Ok(Response { index, term }) +} + +crate::define_rpc_request! { + // TODO Result<Either<Response, Error>> + fn proc_cas(req: Request) -> Result<Response> { + proc_cas_local(req) } pub struct Request { @@ -439,13 +494,13 @@ impl Range { /// use picodata::cas::Range; /// /// // Creates a range for tuples with keys from 1 (excluding) to 10 (excluding) - /// let my_space_id = 2222; + /// let my_space_id: u32 = 2222; /// let range = Range::new(my_space_id).gt((1,)).lt((10,)); /// ``` #[inline(always)] - pub fn new(space: SpaceId) -> Self { + pub fn new(space: impl Into<SpaceId>) -> Self { Self { - space, + space: space.into(), key_min: Bound::unbounded(), key_max: Bound::unbounded(), } @@ -651,7 +706,7 @@ mod tests { assert!(t(&create_space, Range::new(props).eq(&pending_schema_version)).is_err()); assert!(t(&create_space, Range::new(props).eq(("another_key",))).is_ok()); - assert!(t(&create_space, Range::new(69105).eq(("any_key",))).is_ok()); + assert!(t(&create_space, Range::new(69105u32).eq(("any_key",))).is_ok()); assert!(t(&create_space, Range::new(space_id).eq(("any_key",))).is_err()); // drop_space @@ -671,7 +726,7 @@ mod tests { assert!(t(&drop_space, Range::new(props).eq(&pending_schema_version)).is_err()); assert!(t(&drop_space, Range::new(props).eq(("another_key",))).is_ok()); - assert!(t(&drop_space, Range::new(69105).eq(("any_key",))).is_ok()); + assert!(t(&drop_space, Range::new(69105u32).eq(("any_key",))).is_ok()); assert!(t(&drop_space, Range::new(space_id).eq(("any_key",))).is_err()); // create_index @@ -679,7 +734,7 @@ mod tests { assert!(t(&create_index, Range::new(props).eq(&pending_schema_version)).is_err()); assert!(t(&create_index, Range::new(props).eq(("another_key",))).is_ok()); - assert!(t(&create_index, Range::new(69105).eq(("any_key",))).is_ok()); + assert!(t(&create_index, Range::new(69105u32).eq(("any_key",))).is_ok()); assert!(t(&create_index, Range::new(space_id).eq(("any_key",))).is_ok()); // drop_index @@ -687,7 +742,7 @@ mod tests { assert!(t(&drop_index, Range::new(props).eq(&pending_schema_version)).is_err()); assert!(t(&drop_index, Range::new(props).eq(("another_key",))).is_ok()); - assert!(t(&drop_index, Range::new(69105).eq(("any_key",))).is_ok()); + assert!(t(&drop_index, Range::new(69105u32).eq(("any_key",))).is_ok()); assert!(t(&drop_index, Range::new(space_id).eq(("any_key",))).is_ok()); // Abort and Commit need a 
pending schema change to get space name @@ -704,7 +759,7 @@ mod tests { assert!(t(&commit, Range::new(props).eq(&global_schema_version)).is_err()); assert!(t(&commit, Range::new(props).eq(("another_key",))).is_ok()); - assert!(t(&commit, Range::new(69105).eq(("any_key",))).is_ok()); + assert!(t(&commit, Range::new(69105u32).eq(("any_key",))).is_ok()); assert!(t(&commit, Range::new(space_id).eq(("any_key",))).is_err()); // abort @@ -714,7 +769,7 @@ mod tests { assert!(t(&abort, Range::new(props).eq(&next_schema_version)).is_err()); assert!(t(&abort, Range::new(props).eq(("another_key",))).is_ok()); - assert!(t(&abort, Range::new(69105).eq(("any_key",))).is_ok()); + assert!(t(&abort, Range::new(69105u32).eq(("any_key",))).is_ok()); assert!(t(&abort, Range::new(space_id).eq(("any_key",))).is_err()); } @@ -767,7 +822,7 @@ mod tests { assert!(test(op, Range::new(space).gt((12,))).is_ok()); assert!(test(op, Range::new(space).ge((13,))).is_ok()); - assert!(test(op, Range::new(69105)).is_ok()); + assert!(test(op, Range::new(69105u32)).is_ok()); } } } diff --git a/src/schema.rs b/src/schema.rs index cf40b4ea3d..11c07fda52 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -605,6 +605,7 @@ fn wait_for_no_pending_schema_change( /// Waits for any pending schema change to finalize. /// /// If `timeout` is reached earlier returns an error. +// TODO: Use deadline instead of timeout pub fn prepare_schema_change(op: Op, timeout: Duration) -> traft::Result<RaftIndex> { debug_assert!(op.is_schema_change()); @@ -623,17 +624,17 @@ pub fn prepare_schema_change(op: Op, timeout: Duration) -> traft::Result<RaftInd index, term, ranges: vec![ - cas::Range::new(ClusterwideSpaceId::Property as _) + cas::Range::new(ClusterwideSpaceId::Property) .eq((PropertyName::PendingSchemaChange,)), cas::Range::new(ClusterwideSpaceId::Property as _) .eq((PropertyName::PendingSchemaVersion,)), cas::Range::new(ClusterwideSpaceId::Property as _) .eq((PropertyName::GlobalSchemaVersion,)), - cas::Range::new(ClusterwideSpaceId::Property as _) + cas::Range::new(ClusterwideSpaceId::Property) .eq((PropertyName::NextSchemaVersion,)), ], }; - let (index, term) = compare_and_swap(op, predicate)?; + let (index, term) = compare_and_swap(op, predicate, timeout)?; node.wait_index(index, timeout)?; if raft::Storage::term(raft_storage, index)? != term { // leader switched - retry @@ -660,15 +661,15 @@ pub fn abort_ddl(timeout: Duration) -> traft::Result<RaftIndex> { index, term, ranges: vec![ - cas::Range::new(ClusterwideSpaceId::Property as _) + cas::Range::new(ClusterwideSpaceId::Property) .eq((PropertyName::PendingSchemaChange,)), - cas::Range::new(ClusterwideSpaceId::Property as _) + cas::Range::new(ClusterwideSpaceId::Property) .eq((PropertyName::GlobalSchemaVersion,)), - cas::Range::new(ClusterwideSpaceId::Property as _) + cas::Range::new(ClusterwideSpaceId::Property) .eq((PropertyName::NextSchemaVersion,)), ], }; - let (index, term) = compare_and_swap(Op::DdlAbort, predicate)?; + let (index, term) = compare_and_swap(Op::DdlAbort, predicate, timeout)?; node.wait_index(index, timeout)?; if raft::Storage::term(&node.raft_storage, index)? != term { // leader switched - retry -- GitLab