diff --git a/changelogs/unreleased/gh-6990-type-of-CASE-operation.md b/changelogs/unreleased/gh-6990-type-of-CASE-operation.md new file mode 100644 index 0000000000000000000000000000000000000000..ad811658685f713a7ecbb3fd79d2722890b2b341 --- /dev/null +++ b/changelogs/unreleased/gh-6990-type-of-CASE-operation.md @@ -0,0 +1,3 @@ +## feature/sql + +* New rules are applied to determine the type of CASE operation (gh-6990). diff --git a/src/box/sql/expr.c b/src/box/sql/expr.c index c7b2820feae0161f7f96aa1931944374be076864..e32da63b1b44b245c79c3de7bc6e6d52d0f34ce4 100644 --- a/src/box/sql/expr.c +++ b/src/box/sql/expr.c @@ -45,6 +45,37 @@ static void exprCodeBetween(Parse *, Expr *, int, void (*)(Parse *, Expr *, int, int), int); static int exprCodeVector(Parse * pParse, Expr * p, int *piToFree); +/** + * Determine the highest type between the given type and the type of the given + * expression. + */ +static enum field_type +sql_highest_type(enum field_type a, struct Expr *expr) +{ + if (a == FIELD_TYPE_ANY || expr->op == TK_VARIABLE) + return FIELD_TYPE_ANY; + if (expr->op == TK_NULL) + return a; + enum field_type b = sql_expr_type(expr); + if (a == b) + return a; + if (b == FIELD_TYPE_ANY || a == FIELD_TYPE_MAP || b == FIELD_TYPE_MAP || + a == FIELD_TYPE_ARRAY || b == FIELD_TYPE_ARRAY || + a == FIELD_TYPE_INTERVAL || b == FIELD_TYPE_INTERVAL) + return FIELD_TYPE_ANY; + if (!sql_type_is_numeric(a) || !sql_type_is_numeric(b)) + return FIELD_TYPE_SCALAR; + if (a == FIELD_TYPE_NUMBER || b == FIELD_TYPE_NUMBER) + return FIELD_TYPE_NUMBER; + if (a == FIELD_TYPE_DECIMAL || b == FIELD_TYPE_DECIMAL) + return FIELD_TYPE_DECIMAL; + if (a == FIELD_TYPE_DOUBLE || b == FIELD_TYPE_DOUBLE) + return FIELD_TYPE_DOUBLE; + assert((a == FIELD_TYPE_INTEGER || a == FIELD_TYPE_UNSIGNED)); + assert((b == FIELD_TYPE_INTEGER || b == FIELD_TYPE_UNSIGNED)); + return FIELD_TYPE_INTEGER; +} + enum field_type sql_expr_type(struct Expr *pExpr) { @@ -95,24 +126,25 @@ sql_expr_type(struct Expr *pExpr) * WHEN and one THEN clauses. So, first * expression always represents WHEN * argument, and the second one - THEN. - * In case at least one type of THEN argument - * is different from others then we can't - * determine type of returning value at compiling - * stage and set SCALAR (i.e. most general) type. */ - enum field_type ref_type = sql_expr_type(cs->a[1].pExpr); - for (int i = 3; i < cs->nExpr; i += 2) { - if (ref_type != sql_expr_type(cs->a[i].pExpr)) - return FIELD_TYPE_SCALAR; - } + uint32_t i = 1; + uint32_t count = cs->nExpr; + while (i < count && cs->a[i].pExpr->op == TK_NULL) + i += 2; + if (i >= count) + return FIELD_TYPE_ANY; + enum field_type res_type = sql_expr_type(cs->a[i].pExpr); + if (cs->a[i].pExpr->op == TK_VARIABLE) + res_type = FIELD_TYPE_ANY; + for (i += 2; i < count; i += 2) + res_type = sql_highest_type(res_type, cs->a[i].pExpr); /* * ELSE clause is optional but we should check * its type as well. */ - if (cs->nExpr % 2 == 1 && - ref_type != sql_expr_type(cs->a[cs->nExpr - 1].pExpr)) - return FIELD_TYPE_SCALAR; - return ref_type; + if (count % 2 == 0) + return res_type; + return sql_highest_type(res_type, cs->a[count - 1].pExpr); } case TK_LT: case TK_GT: @@ -4517,6 +4549,15 @@ sqlExprCodeTarget(Parse * pParse, Expr * pExpr, int target) assert(pParse->db->mallocFailed || pParse->is_aborted || pParse->iCacheLevel == iCacheLevel); sqlVdbeResolveLabel(v, endLabel); + enum field_type *type = + sqlDbMallocZero(pParse->db, sizeof(*type)); + if (type == NULL) { + pParse->is_aborted = true; + break; + } + type[0] = sql_expr_type(pExpr); + sqlVdbeAddOp4(v, OP_ApplyType, target, 1, 0, + (char *)type, P4_DYNAMIC); break; } case TK_RAISE: diff --git a/test/sql-luatest/gh_6990_case_operation_type_test.lua b/test/sql-luatest/gh_6990_case_operation_type_test.lua new file mode 100644 index 0000000000000000000000000000000000000000..b2c9c03a117c69d56fbbc6adcfbf666516a0ee92 --- /dev/null +++ b/test/sql-luatest/gh_6990_case_operation_type_test.lua @@ -0,0 +1,69 @@ +local server = require('test.luatest_helpers.server') +local t = require('luatest') +local g = t.group() + +g.before_all(function() + g.server = server:new({alias = 'test_case_operation_type'}) + g.server:start() +end) + +g.after_all(function() + g.server:stop() +end) + +g.test_case_operation_type = function() + g.server:exec(function() + local t = require('luatest') + local sql = [[SELECT CASE 1 WHEN 1 THEN NULL ELSE NULL END;]] + local res = "any" + t.assert_equals(box.execute(sql).metadata[1].type, res) + + sql = [[SELECT CASE 1 WHEN 1 THEN 1 ELSE ? END;]] + res = "any" + t.assert_equals(box.execute(sql, {1}).metadata[1].type, res) + + sql = [[SELECT CASE 1 WHEN 1 THEN [1] ELSE [2, 2] END;]] + res = "array" + t.assert_equals(box.execute(sql).metadata[1].type, res) + + sql = [[SELECT CASE 1 WHEN 1 THEN 1 ELSE {1 : 1} END;]] + res = "any" + t.assert_equals(box.execute(sql).metadata[1].type, res) + + sql = [[SELECT CASE 1 WHEN 1 THEN 1 ELSE {1 : 1} END;]] + res = "any" + t.assert_equals(box.execute(sql).metadata[1].type, res) + + sql = [[SELECT CASE 1 WHEN 1 THEN 1 ELSE 'asd' END;]] + res = "scalar" + t.assert_equals(box.execute(sql).metadata[1].type, res) + + sql = [[SELECT CASE 1 WHEN 1 THEN 1 ELSE CAST(1 AS NUMBER) END;]] + res = "number" + t.assert_equals(box.execute(sql).metadata[1].type, res) + + sql = [[SELECT CASE 1 WHEN 1 THEN -1 ELSE CAST(1 AS UNSIGNED) END;]] + res = "integer" + t.assert_equals(box.execute(sql).metadata[1].type, res) + + sql = [[SELECT CASE 1 WHEN 1 THEN -1 ELSE 1.5e0 END;]] + res = "double" + t.assert_equals(box.execute(sql).metadata[1].type, res) + + sql = [[SELECT CASE 1 WHEN 1 THEN -1 WHEN 2 THEN 1.5 ELSE 2e0 END;]] + res = "decimal" + t.assert_equals(box.execute(sql).metadata[1].type, res) + + sql = [[SELECT typeof(CASE 1 WHEN 1 THEN 1 ELSE {1 : 1} END);]] + res = "any" + t.assert_equals(box.execute(sql).rows[1][1], res) + + sql = [[SELECT typeof(CASE 1 WHEN 1 THEN 1 ELSE 'asd' END);]] + res = "scalar" + t.assert_equals(box.execute(sql).rows[1][1], res) + + sql = [[SELECT typeof(CASE 1 WHEN 1 THEN -1 ELSE 1.5e0 END);]] + res = "double" + t.assert_equals(box.execute(sql).rows[1][1], res) + end) +end