diff --git a/src/box/func.c b/src/box/func.c index ad0cab3abac1bc4aeaf8cbfaeb0b82470d647087..63cadaea4a890f11b2fd37b0a3595535a2436ff8 100644 --- a/src/box/func.c +++ b/src/box/func.c @@ -36,6 +36,7 @@ #include "port.h" #include "schema.h" #include "session.h" +#include "sql/func.h" /** * Parsed symbol and package names. @@ -393,6 +394,9 @@ func_new(struct func_def *def) case FUNC_LANGUAGE_LUA: func = func_lua_new(def); break; + case FUNC_LANGUAGE_SQL_EXPR: + func = func_sql_expr_new(def); + break; default: unreachable(); } diff --git a/src/box/func_def.c b/src/box/func_def.c index 55591d383d5b0eb768914deac60698cb4968990a..41814350b244949175c3b24c8eac5d250fad8340 100644 --- a/src/box/func_def.c +++ b/src/box/func_def.c @@ -34,7 +34,9 @@ #include "diag.h" #include "error.h" -const char *func_language_strs[] = {"LUA", "C", "SQL", "SQL_BUILTIN"}; +const char *func_language_strs[] = { + "LUA", "C", "SQL", "SQL_BUILTIN", "SQL_EXPR" +}; const char *func_aggregate_strs[] = {"none", "group"}; diff --git a/src/box/func_def.h b/src/box/func_def.h index a884c35ab525055875516d63e5d77804f58d04ad..22905618a709b4dbabb6f84649080951a7bdfa70 100644 --- a/src/box/func_def.h +++ b/src/box/func_def.h @@ -48,6 +48,7 @@ enum func_language { FUNC_LANGUAGE_C, FUNC_LANGUAGE_SQL, FUNC_LANGUAGE_SQL_BUILTIN, + FUNC_LANGUAGE_SQL_EXPR, func_language_MAX, }; diff --git a/src/box/sql.c b/src/box/sql.c index 2a37eae13305f9349b43d91bbac93945f7eaeeb9..4213dfa2248a19eb13e72fb6b23959033c3283e1 100644 --- a/src/box/sql.c +++ b/src/box/sql.c @@ -1317,6 +1317,7 @@ vdbe_field_ref_create(struct vdbe_field_ref *field_ref, struct tuple *tuple, const char *field0 = data; field_ref->field_count = mp_decode_array((const char **) &field0); + field_ref->format = NULL; field_ref->slots[0] = (uint32_t)(field0 - data); memset(&field_ref->slots[1], 0, field_ref->field_count * sizeof(field_ref->slots[0])); diff --git a/src/box/sql.h b/src/box/sql.h index bf8d7b7c5b03f1d849e5aa14f5f24723a6ba4f03..dc40c32e365940e7ecb62a96e5ef1bcf2d55c04a 100644 --- a/src/box/sql.h +++ b/src/box/sql.h @@ -376,6 +376,8 @@ struct vdbe_field_ref { uint32_t data_sz; /** Count of fields in tuple. */ uint32_t field_count; + /** Format that match data in field data. */ + struct tuple_format *format; /** * Bitmask of initialized slots. The fieldno == 0 slot * must be initialized in vdbe_field_ref constructor. diff --git a/src/box/sql/expr.c b/src/box/sql/expr.c index c97cdd724ead1743ce7273be39464cb25585eea5..cf5220b0203fb10db5603ff236684c84f7cfb8bd 100644 --- a/src/box/sql/expr.c +++ b/src/box/sql/expr.c @@ -3816,6 +3816,19 @@ sqlExprCodeTarget(Parse * pParse, Expr * pExpr, int target) return sqlExprCodeGetColumn(pParse, col, iTab, target, pExpr->op2); } + case TK_ID: + assert(pParse->vdbe_field_ref_reg > 0); + int reg = pParse->vdbe_field_ref_reg; + sqlVdbeAddOp4(v, OP_Fetch, reg, 0, target, + sqlDbStrDup(pParse->db, pExpr->u.zToken), + P4_DYNAMIC); + return target; + case TK_DOT: + assert(pParse->vdbe_field_ref_reg > 0); + diag_set(ClientError, ER_UNSUPPORTED, "SQL expressions", + "reference to spaces"); + pParse->is_aborted = true; + return target; case TK_INTEGER:{ expr_code_int(pParse, pExpr, false, target); return target; @@ -4206,6 +4219,12 @@ sqlExprCodeTarget(Parse * pParse, Expr * pExpr, int target) } case TK_EXISTS: case TK_SELECT:{ + if (pParse->vdbe_field_ref_reg > 0) { + diag_set(ClientError, ER_UNSUPPORTED, + "SQL expressions", "subselects"); + pParse->is_aborted = true; + return target; + } int nCol; if (op == TK_SELECT && (nCol = pExpr->x.pSelect->pEList->nExpr) != 1) { diff --git a/src/box/sql/func.c b/src/box/sql/func.c index 163ae992cb8254a3587e8fd4a78dfcaa14c3bf29..42d825b0b1f453f1ea3a8b7bfb6a4b80122e3773 100644 --- a/src/box/sql/func.c +++ b/src/box/sql/func.c @@ -36,6 +36,8 @@ */ #include "sqlInt.h" #include "mem.h" +#include "port.h" +#include "func.h" #include "vdbeInt.h" #include "version.h" #include "coll/coll.h" @@ -47,6 +49,9 @@ #include <unicode/ucol.h> #include "box/coll_id_cache.h" #include "box/func_cache.h" +#include "box/execute.h" +#include "box/session.h" +#include "box/tuple_format.h" #include "box/user.h" #include "assoc.h" @@ -2097,7 +2102,7 @@ is_upcast(int op, enum field_type a, enum field_type b) static inline bool is_castable(int op, enum field_type a, enum field_type b) { - return is_upcast(op, a, b) || op == TK_VARIABLE || + return is_upcast(op, a, b) || op == TK_VARIABLE || op == TK_ID || (sql_type_is_numeric(a) && sql_type_is_numeric(b)) || b == FIELD_TYPE_ANY; } @@ -2336,3 +2341,134 @@ static struct func_vtab func_sql_builtin_vtab = { .call = func_sql_builtin_call_stub, .destroy = func_sql_builtin_destroy, }; + +/** Table of methods of SQL user-defined functions. */ +static struct func_vtab func_sql_expr_vtab; + +/** SQL user-defined function. */ +struct func_sql_expr { + /** Function object base class. */ + struct func base; + /** Prepared SQL statement. */ + struct Vdbe *stmt; +}; + +struct func * +func_sql_expr_new(struct func_def *def) +{ + struct sql *db = sql_get(); + const char *body = def->body; + uint32_t body_len = body == NULL ? 0 : strlen(body); + struct Expr *expr = sql_expr_compile(db, body, body_len); + if (expr == NULL) + return NULL; + + struct Parse parser; + sql_parser_create(&parser, db, default_flags); + struct Vdbe *v = sqlGetVdbe(&parser); + if (v == NULL) { + sql_parser_destroy(&parser); + sql_expr_delete(db, expr); + return NULL; + } + int ref_reg = ++parser.nMem; + sqlVdbeAddOp2(v, OP_Variable, ++parser.nVar, ref_reg); + parser.vdbe_field_ref_reg = ref_reg; + + sqlVdbeSetNumCols(v, 1); + vdbe_metadata_set_col_name(v, 0, def->name); + vdbe_metadata_set_col_type(v, 0, field_type_strs[def->returns]); + int res_reg = sqlExprCodeTarget(&parser, expr, ++parser.nMem); + sqlVdbeAddOp2(v, OP_ResultRow, res_reg, 1); + + bool is_error = parser.is_aborted; + sql_finish_coding(&parser); + sql_parser_destroy(&parser); + sql_expr_delete(db, expr); + + if (is_error) { + sql_stmt_finalize((struct sql_stmt *)v); + return NULL; + } + struct func_sql_expr *func = xmalloc(sizeof(*func)); + func->stmt = v; + func->base.vtab = &func_sql_expr_vtab; + return &func->base; +} + +int +func_sql_expr_call(struct func *func, struct port *args, struct port *ret) +{ + struct func_sql_expr *func_sql = (struct func_sql_expr *)func; + struct sql_stmt *stmt = (struct sql_stmt *)func_sql->stmt; + const char *data; + uint32_t mp_size; + struct tuple_format *format; + if (args->vtab == &port_c_vtab && ((struct port_c *)args)->size == 2 && + ((struct port_c *)args)->first_entry.mp_format != NULL) { + /* + * The only case where mp_format is not NULL is when the + * function is used as a CHECK constraint. + */ + struct port_c_entry *pe = ((struct port_c *)args)->first; + data = pe->mp; + mp_size = pe->mp_size; + format = pe->mp_format; + } else { + diag_set(ClientError, ER_UNSUPPORTED, "Tarantool", + "SQL functions"); + return -1; + } + + struct region *region = &fiber()->gc; + size_t svp = region_used(region); + port_sql_create(ret, stmt, DQL_EXECUTE, false); + /* + * In SQL, we can only retrieve fields that have names. There is no + * point to prepare slots for nameless fields. + */ + uint32_t count = format->min_field_count; + struct vdbe_field_ref *ref; + size_t size = sizeof(ref->slots[0]) * count + sizeof(*ref); + ref = region_aligned_alloc(region, size, alignof(*ref)); + vdbe_field_ref_prepare_data(ref, data, mp_size); + ref->format = format; + if (sql_bind_ptr(stmt, 1, ref) != 0) + goto error; + + if (sql_step(stmt) != SQL_ROW) + goto error; + + uint32_t res_size; + char *pos = sql_stmt_func_result_to_msgpack(stmt, &res_size, region); + if (pos == NULL) + goto error; + int rc = port_c_add_mp(ret, pos, pos + res_size); + if (rc != 0) + goto error; + + if (sql_step(stmt) != SQL_DONE) + goto error; + + sql_stmt_reset(stmt); + region_truncate(region, svp); + return 0; +error: + sql_stmt_reset(stmt); + region_truncate(region, svp); + port_destroy(ret); + return -1; +} + +void +func_sql_expr_destroy(struct func *base) +{ + struct func_sql_expr *func = (struct func_sql_expr *)base; + sql_stmt_finalize((struct sql_stmt *)func->stmt); + free(func); +} + +static struct func_vtab func_sql_expr_vtab = { + .call = func_sql_expr_call, + .destroy = func_sql_expr_destroy, +}; diff --git a/src/box/sql/func.h b/src/box/sql/func.h new file mode 100644 index 0000000000000000000000000000000000000000..977a584e5f978a6dab4843e0dee8ba942d6830ad --- /dev/null +++ b/src/box/sql/func.h @@ -0,0 +1,21 @@ +/* + * SPDX-License-Identifier: BSD-2-Clause + * + * Copyright 2010-2022, Tarantool AUTHORS, please see AUTHORS file. + */ +#pragma once + +#if defined(__cplusplus) +extern "C" { +#endif /* defined(__cplusplus) */ + +struct func; +struct func_def; + +/** Create new SQL user-defined function. */ +struct func * +func_sql_expr_new(struct func_def *def); + +#if defined(__cplusplus) +} /* extern "C" */ +#endif /* defined __cplusplus */ diff --git a/src/box/sql/mem.c b/src/box/sql/mem.c index 54b9205908e8babd31a5a2e10d9bac697601d59d..95672260d7b833441c96465f7f6f4b55de4c46cd 100644 --- a/src/box/sql/mem.c +++ b/src/box/sql/mem.c @@ -3255,6 +3255,32 @@ mem_to_mpstream(const struct Mem *var, struct mpstream *stream) } } +char * +mem_to_mp(const struct Mem *mem, uint32_t *size, struct region *region) +{ + size_t used = region_used(region); + bool is_error = false; + struct mpstream stream; + mpstream_init(&stream, region, region_reserve_cb, region_alloc_cb, + set_encode_error, &is_error); + mem_to_mpstream(mem, &stream); + mpstream_flush(&stream); + if (is_error) { + region_truncate(region, used); + diag_set(OutOfMemory, stream.pos - stream.buf, + "mpstream_flush", "stream"); + return NULL; + } + *size = region_used(region) - used; + char *data = region_join(region, *size); + if (data == NULL) { + region_truncate(region, used); + diag_set(OutOfMemory, *size, "region_join", "data"); + return NULL; + } + return data; +} + char * mem_encode_array(const struct Mem *mems, uint32_t count, uint32_t *size, struct region *region) diff --git a/src/box/sql/mem.h b/src/box/sql/mem.h index c48ec076be0f26642de621c32d89f51b462bb285..7bcb76a788152fcc631dd08386c8b49bda16a92a 100644 --- a/src/box/sql/mem.h +++ b/src/box/sql/mem.h @@ -918,6 +918,10 @@ mem_from_mp(struct Mem *mem, const char *buf, uint32_t *len); void mem_to_mpstream(const struct Mem *var, struct mpstream *stream); +/** Encode MEM as msgpack value on region. */ +char * +mem_to_mp(const struct Mem *mem, uint32_t *size, struct region *region); + /** * Encode array of MEMs as msgpack array on region. * diff --git a/src/box/sql/port.c b/src/box/sql/port.c index 38c744ea12c4145947d0f87a606483ca07827326..9e05a4ce8115f6e88d47bf804ef249ecc48a9de4 100644 --- a/src/box/sql/port.c +++ b/src/box/sql/port.c @@ -354,6 +354,12 @@ port_sql_dump_msgpack(struct port *port, struct obuf *out) return 0; } +static const char * +port_sql_get_msgpack(struct port *base, uint32_t *size) +{ + return port_c_vtab.get_msgpack(base, size); +} + static void port_sql_destroy(struct port *base) { @@ -368,7 +374,7 @@ const struct port_vtab port_sql_vtab = { .dump_msgpack_16 = NULL, .dump_lua = port_sql_dump_lua, .dump_plain = NULL, - .get_msgpack = NULL, + .get_msgpack = port_sql_get_msgpack, .get_vdbemem = NULL, .destroy = port_sql_destroy, }; diff --git a/src/box/sql/sqlInt.h b/src/box/sql/sqlInt.h index 4b95045d791ccfed079c4b90a7f37fc09522e313..ef849896f13d567460648ca1761098f93498bfb2 100644 --- a/src/box/sql/sqlInt.h +++ b/src/box/sql/sqlInt.h @@ -368,6 +368,14 @@ char * sql_stmt_result_to_msgpack(struct sql_stmt *stmt, uint32_t *tuple_size, struct region *region); +/** + * Encode SQL function result in msgpack. The result is not packed into + * MP_ARRAY. + */ +char * +sql_stmt_func_result_to_msgpack(struct sql_stmt *stmt, uint32_t *tuple_size, + struct region *region); + /* * Terminate the current execution of an SQL statement and reset * it back to its starting state so that it can be reused. diff --git a/src/box/sql/vdbe.c b/src/box/sql/vdbe.c index 01623a7e246e0f72e67a01c6c7d8f9d6cd918caf..6d3f95d2d684106184e3cd9042c00f9a1467c924 100644 --- a/src/box/sql/vdbe.c +++ b/src/box/sql/vdbe.c @@ -1971,18 +1971,27 @@ case OP_Column: { * Interpret data P1 points at as an initialized vdbe_field_ref * object. * - * Fetch the P2-th column from its tuple. The value extracted - * is stored in register P3. If the column contains fewer than - * P2 fields, then extract a NULL. + * If P4 is not a field name, extract the P2th field from the tuple. Otherwise, + * get the field with the given name. The retrieved value is stored in + * register P3. */ case OP_Fetch: { - struct vdbe_field_ref *field_ref = - (struct vdbe_field_ref *) p->aMem[pOp->p1].u.p; - uint32_t field_idx = pOp->p2; - struct Mem *dest_mem = vdbe_prepare_null_out(p, pOp->p3); - if (vdbe_field_ref_fetch(field_ref, field_idx, dest_mem) != 0) + struct vdbe_field_ref *ref = p->aMem[pOp->p1].u.p; + uint32_t id = pOp->p2; + if (pOp->p4type == P4_DYNAMIC) { + const char *name = pOp->p4.z; + uint32_t len = strlen(name); + uint32_t hash = field_name_hash(name, len); + struct tuple_dictionary *dict = ref->format->dict; + if (tuple_fieldno_by_name(dict, name, len, hash, &id) != 0) { + diag_set(ClientError, ER_SQL_CANT_RESOLVE_FIELD, name); + goto abort_due_to_error; + } + } + struct Mem *res = vdbe_prepare_null_out(p, pOp->p3); + if (vdbe_field_ref_fetch(ref, id, res) != 0) goto abort_due_to_error; - REGISTER_TRACE(p, pOp->p3, dest_mem); + REGISTER_TRACE(p, pOp->p3, res); break; } diff --git a/src/box/sql/vdbeapi.c b/src/box/sql/vdbeapi.c index f9fd2524f8fd3c712e65b6789793cb4b91fa96eb..48ea3a54cf92b0a4eb5babb0a0bb99f92ab9fe25 100644 --- a/src/box/sql/vdbeapi.c +++ b/src/box/sql/vdbeapi.c @@ -216,6 +216,15 @@ sql_stmt_result_to_msgpack(struct sql_stmt *stmt, uint32_t *tuple_size, region); } +char * +sql_stmt_func_result_to_msgpack(struct sql_stmt *stmt, uint32_t *size, + struct region *region) +{ + struct Vdbe *vdbe = (struct Vdbe *)stmt; + assert(vdbe->nResColumn == 1); + return mem_to_mp(vdbe->pResultSet, size, region); +} + /* * Return the name of the Nth column of the result set returned by SQL * statement pStmt. diff --git a/test/sql-luatest/sql_func_expr_test.lua b/test/sql-luatest/sql_func_expr_test.lua new file mode 100644 index 0000000000000000000000000000000000000000..c84498d23ed441c4f0e9d08362ac15af08b72dc2 --- /dev/null +++ b/test/sql-luatest/sql_func_expr_test.lua @@ -0,0 +1,97 @@ +local server = require('test.luatest_helpers.server') +local t = require('luatest') +local g = t.group() + +g.before_all(function() + g.server = server:new({alias = 'sql_func_expr'}) + g.server:start() +end) + +g.after_all(function() + g.server:stop() +end) + +-- Make sure CHECK constraint works as intended. +g.test_sql_func_expr_1 = function() + g.server:exec(function() + local t = require('luatest') + local def = {language = 'SQL_EXPR', is_deterministic = true, + body = 'a * b > 10'} + box.schema.func.create('abc', def) + local format = {{'A', 'integer'}, {'B', 'integer'}} + local s = box.schema.space.create('test', {format = format}) + s:create_index('i') + s:alter{constraint='abc'} + t.assert_equals(s:insert{3, 4}, {3, 4}) + t.assert_error_msg_content_equals( + "Check constraint 'abc' failed for tuple", + function() s:insert{1, 2} end + ) + t.assert_error_msg_content_equals( + "Check constraint 'abc' failed for tuple", + function() s:insert{true, 2} end + ) + box.space.test:drop() + box.schema.func.drop('abc') + end) +end + +-- Make sure SQL_EXPRESSION function parsed properly. +g.test_sql_func_expr_2 = function() + g.server:exec(function() + local t = require('luatest') + local def = {language = 'SQL_EXPR', is_deterministic = true, body = ''} + t.assert_error_msg_content_equals( + "Syntax error at line 1 near ' '", + function() box.schema.func.create('a1', def) end + ) + + def.body = ' ' + t.assert_error_msg_content_equals( + "Syntax error at line 1 near ' '", + function() box.schema.func.create('a1', def) end + ) + + def.body = '1, 1 ' + t.assert_error_msg_content_equals( + "Syntax error at line 1 near ','", + function() box.schema.func.create('a1', def) end + ) + + def.body = 'a + (SELECT "id" AS a FROM "_space" LIMIT 1);' + t.assert_error_msg_content_equals( + "SQL expressions does not support subselects", + function() box.schema.func.create('a1', def) end + ) + end) +end + +-- Make sure SQL EXPR recovers properly after restart. +g.test_sql_func_expr_3 = function() + g.server:exec(function() + local t = require('luatest') + local def = {language = 'SQL_EXPR', is_deterministic = true, + body = 'a * b > 10'} + box.schema.func.create('abc', def) + local format = {{'A', 'integer'}, {'B', 'integer'}} + local s = box.schema.create_space('test', {format = format}) + s:create_index('i') + s:alter{constraint='abc'} + t.assert_error_msg_content_equals( + "Check constraint 'abc' failed for tuple", + function() s:insert{1, 1} end + ) + end) + g.server:restart() + g.server:exec(function() + local t = require('luatest') + t.assert_equals(box.func.abc.language, 'SQL_EXPR') + t.assert_error_msg_content_equals( + "Check constraint 'abc' failed for tuple", + function() box.space.test:insert{2, 2} end + ) + t.assert_equals(box.space.test:insert{7, 7}, {7, 7}) + box.space.test:drop() + box.schema.func.drop('abc') + end) +end