From 95b9c87b7a4f059c7a842fe82f8f23f1b020191c Mon Sep 17 00:00:00 2001 From: Mergen Imeev <imeevma@gmail.com> Date: Tue, 19 Apr 2022 15:31:59 +0300 Subject: [PATCH] sql: allow to bind INTERVAL values Part of #6773 NO_DOC=Will be added later. NO_CHANGELOG=Will be added later. --- src/box/bind.c | 22 ++++++++++++++--- src/box/bind.h | 2 ++ src/box/lua/execute.c | 4 +++ src/box/sql/sqlInt.h | 4 +++ src/box/sql/vdbeapi.c | 10 ++++++++ test/sql-luatest/interval_test.lua | 39 ++++++++++++++++++++++++++++++ 6 files changed, 77 insertions(+), 4 deletions(-) diff --git a/src/box/bind.c b/src/box/bind.c index 6b7f9365dc..bfb8f73903 100644 --- a/src/box/bind.c +++ b/src/box/bind.c @@ -34,6 +34,7 @@ #include "sql/sqlInt.h" #include "sql/sqlLimit.h" #include "sql/vdbe.h" +#include "mp_interval.h" #include "mp_datetime.h" #include "mp_decimal.h" #include "mp_uuid.h" @@ -127,6 +128,13 @@ sql_bind_decode(struct sql_bind *bind, int i, const char **packet) return -1; } break; + case MP_INTERVAL: + if (interval_unpack(packet, &bind->itv) == NULL) { + diag_set(ClientError, ER_INVALID_MSGPACK, + "Invalid MP_INTERVAL MsgPack format"); + return -1; + } + break; default: diag_set(ClientError, ER_SQL_BIND_TYPE, "USERDATA", sql_bind_name(bind)); @@ -224,12 +232,18 @@ sql_bind_column(struct sql_stmt *stmt, const struct sql_bind *p, case MP_MAP: return sql_bind_map_static(stmt, pos, p->s, p->bytes); case MP_EXT: - if (p->ext_type == MP_UUID) + switch (p->ext_type) { + case MP_UUID: return sql_bind_uuid(stmt, pos, &p->uuid); - else if (p->ext_type == MP_DECIMAL) + case MP_DECIMAL: return sql_bind_dec(stmt, pos, &p->dec); - assert(p->ext_type == MP_DATETIME); - return sql_bind_datetime(stmt, pos, &p->dt); + case MP_DATETIME: + return sql_bind_datetime(stmt, pos, &p->dt); + case MP_INTERVAL: + return sql_bind_interval(stmt, pos, &p->itv); + default: + unreachable(); + } default: unreachable(); } diff --git a/src/box/bind.h b/src/box/bind.h index be02b7f6ef..414d327dfe 100644 --- a/src/box/bind.h +++ b/src/box/bind.h @@ -77,6 +77,8 @@ struct sql_bind { decimal_t dec; /** DATETIME value. */ struct datetime dt; + /** INTERVAL value. */ + struct interval itv; }; }; diff --git a/src/box/lua/execute.c b/src/box/lua/execute.c index ad6d606d78..25dc7a0f90 100644 --- a/src/box/lua/execute.c +++ b/src/box/lua/execute.c @@ -387,6 +387,10 @@ lua_sql_bind_decode(struct lua_State *L, struct sql_bind *bind, int idx, int i) bind->dt = *field.dateval; break; } + if (field.ext_type == MP_INTERVAL) { + bind->itv = *field.interval; + break; + } diag_set(ClientError, ER_SQL_BIND_TYPE, "USERDATA", sql_bind_name(bind)); return -1; diff --git a/src/box/sql/sqlInt.h b/src/box/sql/sqlInt.h index 3b9f09ba3a..393b5e9a79 100644 --- a/src/box/sql/sqlInt.h +++ b/src/box/sql/sqlInt.h @@ -573,6 +573,10 @@ sql_bind_dec(struct sql_stmt *stmt, int i, const decimal_t *dec); int sql_bind_datetime(struct sql_stmt *stmt, int i, const struct datetime *dt); +/** Perform INTERVAL parameter binding for the SQL statement. */ +int +sql_bind_interval(struct sql_stmt *stmt, int i, const struct interval *itv); + /** * Return the number of wildcards that should be bound to. */ diff --git a/src/box/sql/vdbeapi.c b/src/box/sql/vdbeapi.c index 2316769e19..f9fd2524f8 100644 --- a/src/box/sql/vdbeapi.c +++ b/src/box/sql/vdbeapi.c @@ -570,6 +570,16 @@ sql_bind_datetime(struct sql_stmt *stmt, int i, const struct datetime *dt) return 0; } +int +sql_bind_interval(struct sql_stmt *stmt, int i, const struct interval *itv) +{ + struct Vdbe *p = (struct Vdbe *)stmt; + if (vdbeUnbind(p, i) != 0 || sql_bind_type(p, i, "interval") != 0) + return -1; + mem_set_interval(&p->aVar[i - 1], itv); + return 0; +} + int sql_bind_parameter_count(const struct sql_stmt *stmt) { diff --git a/test/sql-luatest/interval_test.lua b/test/sql-luatest/interval_test.lua index 1b6cc6608b..0891e3614e 100644 --- a/test/sql-luatest/interval_test.lua +++ b/test/sql-luatest/interval_test.lua @@ -1506,3 +1506,42 @@ g.test_interval_26_9 = function() t.assert_equals(err.message, res) end) end + +-- Make sure that DATETIME value can be bound. +g.test_datetime_27_1 = function() + g.server:exec(function() + local t = require('luatest') + local itv = require('datetime').interval + local itv1 = itv.new({year = 1, month = 2, day = 3, hour = 4}) + local rows = box.execute([[SELECT ?;]], {itv1}).rows + t.assert_equals(rows, {{itv1}}) + end) +end + +g.test_datetime_27_2 = function() + g.server:exec(function() + local t = require('luatest') + local itv = require('datetime').interval + local itv1 = itv.new({year = 1, month = 2, day = 3, hour = 4}) + local rows = box.execute([[SELECT $1;]], {itv1}).rows + t.assert_equals(rows, {{itv1}}) + end) +end + +g.test_datetime_27_3 = function() + g.server:exec(function() + local t = require('luatest') + local itv = require('datetime').interval + local itv1 = itv.new({year = 1, month = 2, day = 3, hour = 4}) + local rows = box.execute([[SELECT #a;]], {{['#a'] = itv1}}).rows + t.assert_equals(rows, {{itv1}}) + end) +end + +g.test_datetime_27_4 = function() + local conn = g.server.net_box + local itv = require('datetime').interval + local itv1 = itv.new({year = 1, month = 2, day = 3, hour = 4}) + local rows = conn:execute([[SELECT ?;]], {itv1}).rows + t.assert_equals(rows, {{itv1}}) +end -- GitLab