From 41af865e8db52e3d4dcd6739d4d2373acd37861b Mon Sep 17 00:00:00 2001
From: Arseniy Volynets <vol0ncar@yandex.ru>
Date: Fri, 2 Aug 2024 14:10:24 +0300
Subject: [PATCH] refactor!: add session id to prepare funcs

BREAKING CHANGE!:
1. add session id argument to sql_prepare_ext
2. introduce sql_unprepare_ext function.
This function removes prepared stmt using
given session id.

In picodata SQL, we may prepare stmt in one
session and unprepare it in some
other session, which does not know in
what session the statement was prepared. Now
sql_prepare_ext returns not only statement id,
but also a session id. This way statement can
be unprepared from other session using
sql_unprepare_ext.

NO_DOC=picodata internal patch
NO_CHANGELOG=picodata internal patch
---
 extra/exports                     |  2 +-
 src/box/execute.c                 | 22 +++++++-
 src/box/execute.h                 | 13 ++++-
 test/sql-luatest/sql_api_test.lua | 87 +++++++++++++++++++++++++++++--
 4 files changed, 117 insertions(+), 7 deletions(-)

diff --git a/extra/exports b/extra/exports
index 10a7fc52b3..d679db1bfd 100644
--- a/extra/exports
+++ b/extra/exports
@@ -595,7 +595,7 @@ obuf_create
 obuf_destroy
 sql_prepare_ext
 sql_stmt_finalize
-sql_unprepare
+sql_unprepare_ext
 sql_stmt_query_str
 sql_stmt_calculate_id
 sql_execute_prepared_ext
diff --git a/src/box/execute.c b/src/box/execute.c
index 60f5a667d3..f711cde406 100644
--- a/src/box/execute.c
+++ b/src/box/execute.c
@@ -171,14 +171,32 @@ sql_prepare(const char *sql, int len, struct port *port)
 
 /**
  * Find or create prepared statement by its SQL query.
- * Returns compiled statement ID.
+ * Returns compiled statement ID and session ID.
  */
 int
-sql_prepare_ext(const char *sql, int len, uint32_t *stmt_id)
+sql_prepare_ext(const char *sql, int len, uint32_t *stmt_id, uint64_t *session_id)
 {
 	struct sql_stmt *stmt = NULL;
 	if (sql_stmt_find_or_create(sql, len, stmt_id, &stmt) != 0)
 		return -1;
+	*session_id = current_session()->id;
+	return 0;
+}
+
+int
+sql_unprepare_ext(uint32_t stmt_id, uint64_t session_id)
+{
+	struct session *session = session_find(session_id);
+	if (session == NULL) {
+		diag_set(ClientError, ER_NO_SUCH_SESSION, session_id);
+		return -1;
+	}
+	if (!session_check_stmt_id(session, stmt_id)) {
+		diag_set(ClientError, ER_WRONG_QUERY_ID, stmt_id);
+		return -1;
+	}
+	session_remove_stmt_id(session, stmt_id);
+	sql_stmt_unref(stmt_id);
 	return 0;
 }
 
diff --git a/src/box/execute.h b/src/box/execute.h
index e1089b987e..d9ba694bbc 100644
--- a/src/box/execute.h
+++ b/src/box/execute.h
@@ -136,15 +136,26 @@ sql_stmt_busy(const struct sql_stmt *stmt);
 int
 sql_prepare(const char *sql, int len, struct port *port);
 
+/**
+ * Unprepare statement from the session (exported version).
+ *
+ * @param stmt_id ID of prepared stmt.
+ * @param sid session ID.
+ */
+int
+sql_unprepare_ext(uint32_t stmt_id, uint64_t sid);
+
 /**
  * Prepare statement (exported version).
  *
  * @param sql UTF-8 encoded SQL query.
  * @param len Length of @param sql in bytes.
  * @param[out] stmt_id Prepared statement ID.
+ * @param[out] session_id session ID.
  */
 int
-sql_prepare_ext(const char *sql, int len, uint32_t *stmt_id);
+sql_prepare_ext(const char *sql, int len, uint32_t *stmt_id,
+		uint64_t *session_id);
 
 #if defined(__cplusplus)
 } /* extern "C" { */
diff --git a/test/sql-luatest/sql_api_test.lua b/test/sql-luatest/sql_api_test.lua
index f85fb2a444..7134325e5b 100644
--- a/test/sql-luatest/sql_api_test.lua
+++ b/test/sql-luatest/sql_api_test.lua
@@ -38,8 +38,10 @@ g.before_all(function()
             uint32_t stmt_id, const char *mp_params,
             uint64_t vdbe_max_steps, struct obuf *out_buf);
 
-        int sql_prepare_ext(const char *sql, int len, uint32_t *stmt_id);
-
+        int sql_prepare_ext(
+            const char *sql, int len,
+            uint32_t *stmt_id, uint64_t *sid);
+        int sql_unprepare_ext(uint32_t stmt_id, uint64_t sid);
     ]]
     g.server = server:new({alias = 'sql_api'})
     g.server:start()
@@ -159,13 +161,92 @@ g.test_stmt_prepare = function()
 
         local ffi = require('ffi')
         local stmt_id = ffi.new('uint32_t[1]')
+        local session_id = ffi.new('uint64_t[1]')
+
+        -- Prepare the statement.
+        res = ffi.C.sql_prepare_ext('VALUES (?)', 10, stmt_id, session_id)
+        t.assert_equals(res, 0)
+
+        -- Check the prepared statement.
+        res = box.execute(tonumber(stmt_id[0]), {'ABC'})
+        t.assert_equals(res.rows[1][1], 'ABC')
+
+        -- Unprepare the statement
+        res = ffi.C.sql_unprepare_ext(
+            tonumber(stmt_id[0]), tonumber(session_id[0]))
+        t.assert_equals(res, 0)
+
+        -- Calling unprepare again returns error
+        res = ffi.C.sql_unprepare_ext(
+            tonumber(stmt_id[0]), tonumber(session_id[0]))
+        t.assert_not_equals(res, 0)
+        local err = tostring(box.error.last())
+        local s = string.format(
+            "Prepared statement with id %u does not exist",
+            tonumber(stmt_id[0]))
+        t.assert_str_contains(err, s)
+
+        -- Check unprepare from invalid session
+        -- returns error
+        local wrong_sid = tonumber(session_id[0]) + 1
+        res = ffi.C.sql_unprepare_ext(tonumber(stmt_id[0]), wrong_sid)
+        t.assert_not_equals(res, 0)
+        err = tostring(box.error.last())
+        s = string.format("Session %u does not exist", wrong_sid)
+        t.assert_str_contains(err, s)
+
+        -- Check statement was deleted
+        res = box.execute(tonumber(stmt_id[0]), {'ABC'})
+        t.assert_not_equals(res, 0)
+    end)
+end
+
+g.test_unprepare_from_other_session = function()
+    g.server:exec(function()
+        local fiber = require('fiber')
+
+        local res, err, ok
+
+        local ffi = require('ffi')
+        local stmt_id = ffi.new('uint32_t[1]')
+        local session_id = ffi.new('uint64_t[1]')
 
         -- Prepare the statement.
-        res = ffi.C.sql_prepare_ext('VALUES (?)', 10, stmt_id)
+        res = ffi.C.sql_prepare_ext('VALUES (?)', 10, stmt_id, session_id)
         t.assert_equals(res, 0)
 
         -- Check the prepared statement.
         res = box.execute(tonumber(stmt_id[0]), {'ABC'})
         t.assert_equals(res.rows[1][1], 'ABC')
+
+        local new_session = function(stmt_id, session_id)
+            local ffi = require('ffi')
+            local id = tonumber(stmt_id[0])
+            local sid = tonumber(session_id[0])
+
+            -- Unprepare the statement
+            res = ffi.C.sql_unprepare_ext(id, sid)
+            t.assert_equals(res, 0)
+
+            -- Check we get an error when trying to unprepare
+            -- invalid statement from other session
+            res = ffi.C.sql_unprepare_ext(id, sid)
+            err = tostring(box.error.last())
+            local s = string.format(
+                "Prepared statement with id %u does not exist", id)
+            t.assert_str_contains(err, s)
+        end
+
+        local f = fiber.new(new_session, stmt_id, session_id)
+        f:set_joinable(true)
+        ok, res = f:join()
+
+        if not ok then
+            error(res:unpack())
+        end
+
+        -- Check statement was deleted
+        res = box.execute(tonumber(stmt_id[0]), {'ABC'})
+        t.assert_not_equals(res, 0)
     end)
 end
-- 
GitLab