From 13dc8c86b93a2fdb5f8e1f9f8be5693688ce2333 Mon Sep 17 00:00:00 2001
From: Georgy Moshkin <gmoshkin@picodata.io>
Date: Fri, 30 Aug 2024 16:19:06 +0300
Subject: [PATCH] fix: automatic cas ranges for global sql dml

---
 src/luamod.rs        | 107 +++++++++++++++++++++++++++++++++++++++++++
 src/sql.rs           |   6 +++++-
 test/int/test_sql.py |  94 ++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 206 insertions(+), 1 deletion(-)

diff --git a/src/luamod.rs b/src/luamod.rs
index 40575e3f09..b7ef96c767 100644
--- a/src/luamod.rs
+++ b/src/luamod.rs
@@ -337,6 +337,113 @@ pub(crate) fn setup() {
         }),
     );
 
+    ///////////////////////////////////////////////////////////////////////////
+    // raft index
+    ///////////////////////////////////////////////////////////////////////////
+
+    luamod_set(
+        &l,
+        "raft_get_index",
+        indoc! {"
+        pico.raft_get_index()
+        =====================
+
+        Returns the current applied raft index.
+
+        Returns:
+
+            (number)
+            or
+            (nil, string) in case of an error, i.e. if the raft node
+                is not initialized yet
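+
+        Example (illustrative):
+
+            pico.raft_get_index()
+            -- 42  (the exact value depends on the cluster state)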
+        "},
+        tlua::function0(|| traft::node::global().map(|n| n.get_index())),
+    );
+    luamod_set(
+        &l,
+        "raft_read_index",
+        indoc! {"
+        pico.raft_read_index(timeout)
+        =============================
+
+        Performs a quorum read operation.
+
+        It works as follows:
+
+        1. The instance forwards a request (`MsgReadIndex`) to the raft
+           leader. If there is no leader at the moment, the function
+           returns the error 'raft: proposal dropped'.
+        2. The raft leader notes its current `commit_index` and
+           broadcasts a heartbeat to the followers to make sure that
+           it is still the leader.
+        3. As soon as the heartbeat is acknowledged by the quorum, the
+           leader returns that index to the instance.
+        4. The instance waits until that index is applied locally. If the
+           timeout expires beforehand, the function returns the error 'timeout'.
+
+        Params:
+
+            1. timeout (number), in seconds
+
+        Returns:
+
+            (number)
+            or
+            (nil, string) in case of an error
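+
+        Example (illustrative):
+
+            -- Quorum read with a 3 second timeout:
+            local index, err = pico.raft_read_index(3)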
+        "},
+        tlua::function1(|timeout: f64| -> traft::Result<RaftIndex> {
+            traft::node::global()?.read_index(duration_from_secs_f64_clamped(timeout))
+        }),
+    );
+    luamod_set(
+        &l,
+        "raft_wait_index",
+        indoc! {"
+        pico.raft_wait_index(target, timeout)
+        =====================================
+
+        Waits for the `target` raft index to be applied locally.
+
+        Returns the current applied raft index, which can be equal to
+        or greater than the requested one. If the timeout expires
+        beforehand, the function returns an error.
+
+        Params:
+
+            1. target (number)
+            2. timeout (number), in seconds
+
+        Returns:
+
+            (number)
+            or
+            (nil, string) in case of an error
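+
+        Example (illustrative):
+
+            -- Wait up to 3 seconds for at least one more entry
+            -- to be applied locally:
+            pico.raft_wait_index(pico.raft_get_index() + 1, 3)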
+        "},
+        tlua::function2(
+            |target: RaftIndex, timeout: f64| -> traft::Result<RaftIndex> {
+                let node = traft::node::global()?;
+                node.wait_index(target, duration_from_secs_f64_clamped(timeout))
+            },
+        ),
+    );
+
+    ///////////////////////////////////////////////////////////////////////////
     // sql
     ///////////////////////////////////////////////////////////////////////////
     luamod_set_help_only(
diff --git a/src/sql.rs b/src/sql.rs
index 03b9faa250..a6c181094a 100644
--- a/src/sql.rs
+++ b/src/sql.rs
@@ -1570,8 +1570,12 @@ fn do_dml_on_global_tbl(mut query: Query<RouterRuntime>) -> traft::Result<Consum
 
         let ops_count = ops.len();
         let op = crate::traft::op::Op::BatchDml { ops };
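+        // Derive the CaS ranges from the operations in the batch, so
+        // that concurrent DML touching the same keys of a global table
+        // is detected as a conflict (previously the predicate had no
+        // ranges at all).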
+        let ranges = cas::Range::for_op(&op)?;
 
-        let predicate = Predicate::new(raft_index, []);
+        let predicate = Predicate::new(raft_index, ranges);
         let cas_req = crate::cas::Request::new(op, predicate, current_user)?;
         let res = crate::cas::compare_and_swap(&cas_req, true, deadline)?;
         res.no_retries()?;
diff --git a/test/int/test_sql.py b/test/int/test_sql.py
index 2b5016759d..f630d26eeb 100644
--- a/test/int/test_sql.py
+++ b/test/int/test_sql.py
@@ -5186,3 +5186,97 @@ def test_alter_system_property_errors(cluster: Cluster):
             alter system set "auto_offline_timeout" to true for tier foo
             """
         )
+
+
+def test_global_dml_cas_conflict(cluster: Cluster):
+    # Number of update operations per worker
+    N = 100
+    # Number of parallel workers running update operations
+    K = 4
+    # Add one for the raft leader (it is not going to be a worker)
+    instance_count = K + 1
+    [i1, *_] = cluster.deploy(instance_count=instance_count)
+    workers = [i for i in cluster.instances if i != i1]
+
+    i1.sql(
+        """
+        CREATE TABLE test_table (id UNSIGNED PRIMARY KEY, counter UNSIGNED) DISTRIBUTED GLOBALLY
+        """
+    )
+    i1.sql(""" INSERT INTO test_table VALUES (0, 0) """)
+
+    test_sql = """ UPDATE test_table SET counter = counter + 1 WHERE id = 0 """
+    prepare = """
+        local N, test_sql = ...
+        local fiber = require 'fiber'
+        local log = require 'log'
+        function test_body()
+            while not box.space.test_table do
+                fiber.sleep(.1)
+            end
+
+            local i = 0
+            local stats = { n_retries = 0 }
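+            -- Retry each UPDATE until it succeeds: concurrent updates
+            -- to the same row may fail with a CaS conflict, in which
+            -- case we wait for the next raft index before retrying.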
+            while i < N do
+                log.info("UPDATE #%d running...", i)
+                while true do
+                    local ok, err = pico.sql(test_sql)
+                    if err == nil then break end
+                    log.error("UPDATE #%d failed: %s, retry", i, err)
+                    stats.n_retries = stats.n_retries + 1
+                    pico.raft_wait_index(pico.raft_get_index() + 1, 3)
+                end
+                log.info("UPDATE #%d OK", i)
+                i = i + 1
+            end
+
+            log.info("DONE: n_retries = %d", stats.n_retries)
+            return stats
+        end
+
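+        -- Blocks until the test_body fiber publishes its result, then
+        -- returns the stats on success or raises the captured error.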
+        function wait_result()
+            while true do
+                local result = rawget(_G, 'result')
+                if result ~= nil then
+                    if not result[1] then
+                        error(result[2])
+                    end
+                    return result[2]
+                end
+                fiber.sleep(.1)
+            end
+        end
+
+        function start_test()
+            fiber.create(function()
+                rawset(_G, 'result', { pcall(test_body) })
+            end)
+        end
+    """  # noqa: E501
+    for i in workers:
+        i.eval(prepare, N, test_sql)
+
+    #
+    # Run parallel updates to the same table row from several instances simultaneously
+    #
+    for i in workers:
+        i.call("start_test")
+
+    #
+    # Wait for the test results
+    #
+    for i in workers:
+        stats = i.call("wait_result", timeout=20)
+        # There were conflicts
+        assert stats["n_retries"] > 0
+
+    #
+    # All operations were successful
+    #
+    rows = i1.sql(""" SELECT * FROM test_table """)
+    assert rows == [[0, N * K]]
-- 
GitLab