From befcbd218e8453b35f9eeefdc1bebc7f5072ac9c Mon Sep 17 00:00:00 2001
From: Mergen Imeev <imeevma@gmail.com>
Date: Fri, 19 Nov 2021 18:26:10 +0300
Subject: [PATCH] sql: introduce binding for ARRAY

After this patch, ARRAY values can be used as bind variables. However,
due to the current syntax for binding in Lua, the only possible way is
to use ARRAY values as the named bind variable.

Part of #4762
---
 src/box/bind.c              |  7 +++----
 src/box/lua/execute.c       | 32 +++++++++++++++++++++++++++-----
 src/box/sql/sqlInt.h        |  3 +++
 src/box/sql/vdbeapi.c       |  8 ++++++++
 test/sql-tap/array.test.lua | 25 ++++++++++++++++++++++++-
 test/sql/bind.result        |  7 -------
 test/sql/bind.test.lua      |  2 --
 7 files changed, 65 insertions(+), 19 deletions(-)

diff --git a/src/box/bind.c b/src/box/bind.c
index e75e362835..af9f9eac54 100644
--- a/src/box/bind.c
+++ b/src/box/bind.c
@@ -99,15 +99,12 @@ sql_bind_decode(struct sql_bind *bind, int i, const char **packet)
 	case MP_BIN:
 		bind->s = mp_decode_bin(packet, &bind->bytes);
 		break;
+	case MP_ARRAY:
 	case MP_EXT:
 		bind->s = *packet;
 		mp_next(packet);
 		bind->bytes = *packet - bind->s;
 		break;
-	case MP_ARRAY:
-		diag_set(ClientError, ER_SQL_BIND_TYPE, "ARRAY",
-			 sql_bind_name(bind));
-		return -1;
 	case MP_MAP:
 		diag_set(ClientError, ER_SQL_BIND_TYPE, "MAP",
 			 sql_bind_name(bind));
@@ -190,6 +187,8 @@ sql_bind_column(struct sql_stmt *stmt, const struct sql_bind *p,
 		return sql_bind_null(stmt, pos);
 	case MP_BIN:
 		return sql_bind_bin_static(stmt, pos, p->s, p->bytes);
+	case MP_ARRAY:
+		return sql_bind_array_static(stmt, pos, p->s, p->bytes);
 	case MP_EXT:
 		assert(p->ext_type == MP_UUID || p->ext_type == MP_DECIMAL);
 		if (p->ext_type == MP_UUID)
diff --git a/src/box/lua/execute.c b/src/box/lua/execute.c
index 18a45a5d5b..71d4d7faed 100644
--- a/src/box/lua/execute.c
+++ b/src/box/lua/execute.c
@@ -8,6 +8,8 @@
 #include "box/bind.h"
 #include "box/sql_stmt_cache.h"
 #include "box/schema.h"
+#include "mpstream/mpstream.h"
+#include "box/sql/vdbeInt.h"
 
 /**
  * Serialize a description of the prepared statement.
@@ -331,6 +333,8 @@ lua_sql_bind_decode(struct lua_State *L, struct sql_bind *bind, int idx, int i)
 	}
 	if (luaL_tofield(L, luaL_msgpack_default, -1, &field) < 0)
 		return -1;
+	bind->type = field.type;
+	bind->ext_type = field.ext_type;
 	switch (field.type) {
 	case MP_UINT:
 		bind->u64 = field.ival;
@@ -382,10 +386,30 @@ lua_sql_bind_decode(struct lua_State *L, struct sql_bind *bind, int idx, int i)
 		diag_set(ClientError, ER_SQL_BIND_TYPE, "USERDATA",
 			 sql_bind_name(bind));
 		return -1;
-	case MP_ARRAY:
-		diag_set(ClientError, ER_SQL_BIND_TYPE, "ARRAY",
-			 sql_bind_name(bind));
+	case MP_ARRAY: {
+		size_t used = region_used(region);
+		struct mpstream stream;
+		bool is_error = false;
+		mpstream_init(&stream, region, region_reserve_cb,
+			      region_alloc_cb, set_encode_error, &is_error);
+		lua_pushvalue(L, -1);
+		luamp_encode_r(L, luaL_msgpack_default, &stream, &field, 0);
+		lua_pop(L, 1);
+		mpstream_flush(&stream);
+		if (is_error) {
+			region_truncate(region, used);
+			diag_set(OutOfMemory, stream.pos - stream.buf,
+				 "mpstream_flush", "stream");
+			return -1;
+		}
+		bind->bytes = region_used(region) - used;
+		bind->s = region_join(region, bind->bytes);
+		if (bind->s != NULL)
+			break;
+		region_truncate(region, used);
+		diag_set(OutOfMemory, bind->bytes, "region_join", "bind->s");
 		return -1;
+	}
 	case MP_MAP:
 		diag_set(ClientError, ER_SQL_BIND_TYPE, "MAP",
 			 sql_bind_name(bind));
@@ -393,8 +417,6 @@ lua_sql_bind_decode(struct lua_State *L, struct sql_bind *bind, int idx, int i)
 	default:
 		unreachable();
 	}
-	bind->type = field.type;
-	bind->ext_type = field.ext_type;
 	lua_pop(L, lua_gettop(L) - idx);
 	return 0;
 }
diff --git a/src/box/sql/sqlInt.h b/src/box/sql/sqlInt.h
index dcd71e5bd8..716110edc6 100644
--- a/src/box/sql/sqlInt.h
+++ b/src/box/sql/sqlInt.h
@@ -556,6 +556,9 @@ sql_bind_str_static(sql_stmt *stmt, int i, const char *str, uint32_t len);
 int
 sql_bind_bin_static(sql_stmt *stmt, int i, const char *str, uint32_t size);
 
+int
+sql_bind_array_static(sql_stmt *stmt, int i, const char *str, uint32_t size);
+
 int
 sql_bind_uuid(struct sql_stmt *stmt, int i, const struct tt_uuid *uuid);
 
diff --git a/src/box/sql/vdbeapi.c b/src/box/sql/vdbeapi.c
index 4ce5feeae9..3ea155d17b 100644
--- a/src/box/sql/vdbeapi.c
+++ b/src/box/sql/vdbeapi.c
@@ -524,6 +524,14 @@ sql_bind_bin_static(sql_stmt *stmt, int i, const char *str, uint32_t size)
 	return sql_bind_type(vdbe, i, "text");
 }
 
+int
+sql_bind_array_static(sql_stmt *stmt, int i, const char *str, uint32_t size)
+{
+	struct Vdbe *vdbe = (struct Vdbe *)stmt;
+	mem_set_array_static(&vdbe->aVar[i - 1], (char *)str, size);
+	return sql_bind_type(vdbe, i, "array");
+}
+
 int
 sql_bind_uuid(struct sql_stmt *stmt, int i, const struct tt_uuid *uuid)
 {
diff --git a/test/sql-tap/array.test.lua b/test/sql-tap/array.test.lua
index 79a1c831df..3387742bf9 100755
--- a/test/sql-tap/array.test.lua
+++ b/test/sql-tap/array.test.lua
@@ -1,6 +1,6 @@
 #!/usr/bin/env tarantool
 local test = require("sqltester")
-test:plan(115)
+test:plan(117)
 
 box.schema.func.create('A1', {
     language = 'Lua',
@@ -1024,6 +1024,29 @@ test:do_execsql_test(
         "array"
     })
 
+-- Make sure that ARRAY values can be used as bound variable.
+test:do_test(
+    "builtins-14.1",
+    function()
+        local res = box.execute([[SELECT #a;]], {{['#a'] = {1, 2, 3}}})
+        return {res.rows[1][1]}
+    end, {
+        {1, 2, 3}
+    })
+
+local remote = require('net.box')
+box.cfg{listen = os.getenv('LISTEN')}
+local cn = remote.connect(box.cfg.listen)
+test:do_test(
+    "builtins-14.2",
+    function()
+        local res = cn:execute([[SELECT #a;]], {{['#a'] = {1, 2, 3}}})
+        return {res.rows[1][1]}
+    end, {
+        {1, 2, 3}
+    })
+cn:close()
+
 box.execute([[DROP TABLE t1;]])
 box.execute([[DROP TABLE t;]])
 
diff --git a/test/sql/bind.result b/test/sql/bind.result
index cb03028854..f269056e22 100644
--- a/test/sql/bind.result
+++ b/test/sql/bind.result
@@ -249,13 +249,6 @@ execute('SELECT ? AS big_uint', {0xefffffffffffffff})
   - [17293822569102704640]
 ...
 -- Bind incorrect parameters.
-ok, err = pcall(execute, 'SELECT ?', { {1, 2, 3} })
----
-...
-ok
----
-- false
-...
 parameters = {}
 ---
 ...
diff --git a/test/sql/bind.test.lua b/test/sql/bind.test.lua
index 2ced7775a2..4ad227d95e 100644
--- a/test/sql/bind.test.lua
+++ b/test/sql/bind.test.lua
@@ -82,8 +82,6 @@ execute(sql, parameters)
 -- suitable method in its bind API.
 execute('SELECT ? AS big_uint', {0xefffffffffffffff})
 -- Bind incorrect parameters.
-ok, err = pcall(execute, 'SELECT ?', { {1, 2, 3} })
-ok
 parameters = {}
 parameters[1] = {}
 parameters[1][100] = 200
-- 
GitLab