diff --git a/changelogs/unreleased/gh-8054-lua-iproto-encoder.md b/changelogs/unreleased/gh-8054-lua-iproto-encoder.md new file mode 100644 index 0000000000000000000000000000000000000000..dba4900d8822740bb9460b28b2ff63ff7cdbe1c7 --- /dev/null +++ b/changelogs/unreleased/gh-8054-lua-iproto-encoder.md @@ -0,0 +1,3 @@ +## feature/lua + +* Introduced helpers for encoding and decoding IPROTO packets in Lua (gh-8054). diff --git a/src/box/iproto.cc b/src/box/iproto.cc index b5dc2ffbe2246b35b8e720bb7370bfae90d663b8..d1a8944e1c3e47107de24c4c404252a96b3ac33e 100644 --- a/src/box/iproto.cc +++ b/src/box/iproto.cc @@ -79,7 +79,6 @@ #include "mpstream/mpstream.h" enum { - IPROTO_SALT_SIZE = 32, IPROTO_PACKET_SIZE_MAX = 2UL * 1024 * 1024 * 1024, }; diff --git a/src/box/iproto_constants.h b/src/box/iproto_constants.h index af7f432c77cfde2722d5d7ab24c3f258ac7c7e15..cf9792a48305641e8bdd6604f40be17da96b82b6 100644 --- a/src/box/iproto_constants.h +++ b/src/box/iproto_constants.h @@ -43,8 +43,10 @@ extern "C" { enum { /** Maximal iproto package body length (2GiB) */ IPROTO_BODY_LEN_MAX = 2147483648UL, - /* Maximal length of text handshake (greeting) */ + /** Size of iproto greeting message. */ IPROTO_GREETING_SIZE = 128, + /** Size of salt sent in iproto greeting message. */ + IPROTO_SALT_SIZE = 32, /** marker + len + prev crc32 + cur crc32 + (padding) */ XLOG_FIXHEADER_SIZE = 19 }; diff --git a/src/box/lua/iproto.c b/src/box/lua/iproto.c index b7ad78252b356973d43d9ab6e99562186ef355df..c6fd800790e88e255dac10d049ad659b23893b84 100644 --- a/src/box/lua/iproto.c +++ b/src/box/lua/iproto.c @@ -12,12 +12,18 @@ #include "box/iproto_constants.h" #include "box/iproto_features.h" #include "box/user.h" +#include "box/xrow.h" + +#include "version.h" #include "core/assoc.h" #include "core/iostream.h" #include "core/fiber.h" #include "core/mp_ctx.h" +#include "core/random.h" + #include "core/tt_static.h" +#include "core/tt_uuid.h" #include "lua/msgpack.h" #include "lua/utils.h" @@ -150,6 +156,12 @@ push_iproto_raft_keys_enum(struct lua_State *L) static void push_iproto_constants(struct lua_State *L) { + lua_pushinteger(L, IPROTO_GREETING_SIZE); + lua_setfield(L, -2, "GREETING_SIZE"); + lua_pushinteger(L, GREETING_PROTOCOL_LEN_MAX); + lua_setfield(L, -2, "GREETING_PROTOCOL_LEN_MAX"); + lua_pushinteger(L, GREETING_SALT_LEN_MAX); + lua_setfield(L, -2, "GREETING_SALT_LEN_MAX"); push_iproto_flag_enum(L); push_iproto_key_enum(L); push_iproto_metadata_key_enum(L); @@ -380,6 +392,283 @@ lbox_iproto_override(struct lua_State *L) return 0; } +/** + * Encodes Tarantool greeting message. + * + * Takes a table with the following fields that will be used in + * the greeting (all fields are optional): + * - version: Tarantool version string in the form 'X.Y.Z'. + * Default: current Tarantool version. + * - uuid: Instance UUID string. Default: Some random UUID. + * (We don't use INSTANCE_UUID because it may be uninitialized.) + * - salt: Salt string (used for authentication). + * Default: Some random salt string. + * + * Returns the encoded greeting message string on success. + * Raises an error on invalid arguments. + */ +static int +lbox_iproto_encode_greeting(struct lua_State *L) +{ + int n_args = lua_gettop(L); + if (n_args == 0) { + lua_newtable(L); + } else if (n_args != 1 || lua_type(L, 1) != LUA_TTABLE) { + return luaL_error(L, "Usage: box.iproto.encode_greeting({" + "version = x, uuid = x, salt = x})"); + } + + uint32_t version; + lua_getfield(L, 1, "version"); + if (lua_isnil(L, -1)) { + version = tarantool_version_id(); + } else if (lua_type(L, -1) == LUA_TSTRING) { + const char *str = lua_tostring(L, -1); + unsigned major, minor, patch; + if (sscanf(str, "%u.%u.%u", &major, &minor, &patch) != 3) + return luaL_error(L, "cannot parse version string"); + version = version_id(major, minor, patch); + } else { + return luaL_error(L, "version must be a string"); + } + lua_pop(L, 1); /* version */ + + struct tt_uuid uuid; + lua_getfield(L, 1, "uuid"); + if (lua_isnil(L, -1)) { + tt_uuid_create(&uuid); + } else if (lua_type(L, -1) == LUA_TSTRING) { + const char *uuid_str = lua_tostring(L, -1); + if (tt_uuid_from_string(uuid_str, &uuid) != 0) + return luaL_error(L, "cannot parse uuid string"); + } else { + return luaL_error(L, "uuid must be a string"); + } + lua_pop(L, 1); /* uuid */ + + uint32_t salt_len; + char salt[GREETING_SALT_LEN_MAX]; + lua_getfield(L, 1, "salt"); + if (lua_isnil(L, -1)) { + salt_len = IPROTO_SALT_SIZE; + random_bytes(salt, IPROTO_SALT_SIZE); + } else if (lua_type(L, -1) == LUA_TSTRING) { + size_t len; + const char *str = lua_tolstring(L, -1, &len); + if (len > GREETING_SALT_LEN_MAX) + return luaL_error(L, "salt string length " + "cannot be greater than %d", + GREETING_SALT_LEN_MAX); + salt_len = len; + memcpy(salt, str, len); + } else { + return luaL_error(L, "salt must be a string"); + } + lua_pop(L, 1); /* salt */ + + char greeting_str[IPROTO_GREETING_SIZE]; + greeting_encode(greeting_str, version, &uuid, salt, salt_len); + + lua_pushlstring(L, greeting_str, sizeof(greeting_str)); + return 1; +} + +/** + * Decodes Tarantool greeting message. + * + * Takes a greeting message string and returns a table with the following + * fields on success: + * - version: Tarantool version string in the form 'X.Y.Z'. + * - protocol: Tarantool protocol string ("Binary" for IPROTO). + * - uuid: Instance UUID string. + * - salt: Salt string (used for authentication). + * + * Raises an error on invalid input. + */ +static int +lbox_iproto_decode_greeting(struct lua_State *L) +{ + int n_args = lua_gettop(L); + if (n_args != 1 || lua_type(L, 1) != LUA_TSTRING) { + return luaL_error( + L, "Usage: box.iproto.decode_greeting(string)"); + } + + size_t len; + const char *greeting_str = lua_tolstring(L, 1, &len); + if (len != IPROTO_GREETING_SIZE) { + return luaL_error(L, "greeting length must equal %d", + IPROTO_GREETING_SIZE); + } + struct greeting greeting; + if (greeting_decode(greeting_str, &greeting) != 0) + return luaL_error(L, "cannot parse greeting string"); + + lua_newtable(L); + lua_pushfstring(L, "%u.%u.%u", + version_id_major(greeting.version_id), + version_id_minor(greeting.version_id), + version_id_patch(greeting.version_id)); + lua_setfield(L, -2, "version"); + lua_pushstring(L, greeting.protocol); + lua_setfield(L, -2, "protocol"); + luaT_pushuuidstr(L, &greeting.uuid); + lua_setfield(L, -2, "uuid"); + lua_pushlstring(L, greeting.salt, greeting.salt_len); + lua_setfield(L, -2, "salt"); + return 1; +} + +/** + * Encodes IPROTO packet. + * + * Takes a packet header and optionally a body given as a string or a table. + * If an argument is a table, it will be encoded in MsgPack using the IPROTO + * key translation table. If an argument is a string, it's supposed to store + * valid MsgPack data and will be copied as is. + * + * On success, returns a string storing the encoded IPROTO packet. + * On failure, raises a Lua error. + */ +static int +lbox_iproto_encode_packet(struct lua_State *L) +{ + int n_args = lua_gettop(L); + if (n_args != 1 && n_args != 2) + return luaL_error( + L, "Usage: box.iproto.encode_packet(header[, body])"); + int header_type = lua_type(L, 1); + if (header_type != LUA_TSTRING && header_type != LUA_TTABLE) + return luaL_error(L, "header must be a string or a table"); + int body_type = lua_type(L, 2); + if (body_type != LUA_TSTRING && body_type != LUA_TTABLE && + body_type != LUA_TNONE && body_type != LUA_TNIL) + return luaL_error(L, "body must be a string or a table"); + struct region *region = &fiber()->gc; + size_t region_svp = region_used(region); + struct mpstream stream; + mpstream_init(&stream, region, region_reserve_cb, region_alloc_cb, + mpstream_panic_cb, NULL); + size_t fixheader_size = mp_sizeof_uint(UINT32_MAX); + char *fixheader = mpstream_reserve(&stream, fixheader_size); + mpstream_advance(&stream, fixheader_size); + struct mp_ctx ctx; + mp_ctx_create_default(&ctx, iproto_key_translation); + if (header_type == LUA_TTABLE) { + int rc = luamp_encode_with_ctx(L, luaL_msgpack_default, + &stream, 1, &ctx, NULL); + if (rc != 0) + goto error; + } else if (header_type == LUA_TSTRING) { + size_t size; + const char *data = lua_tolstring(L, 1, &size); + mpstream_memcpy(&stream, data, size); + } + if (body_type == LUA_TTABLE) { + int rc = luamp_encode_with_ctx(L, luaL_msgpack_default, + &stream, 2, &ctx, NULL); + if (rc != 0) + goto error; + } else if (body_type == LUA_TSTRING) { + size_t size; + const char *data = lua_tolstring(L, 2, &size); + mpstream_memcpy(&stream, data, size); + } + mpstream_flush(&stream); + size_t data_size = region_used(region) - region_svp; + *fixheader = 0xce; + mp_store_u32(fixheader + 1, data_size - fixheader_size); + char *data = xregion_join(region, data_size); + lua_pushlstring(L, data, data_size); + region_truncate(region, region_svp); + return 1; +error: + region_truncate(region, region_svp); + return luaT_error(L); +} + +/** + * Decodes IPROTO packet. + * + * Takes a string that contains an encoded IPROTO packet and optionally + * the position in the string to start decoding from (if the position is + * omitted, the function will start decoding from the beginning of the + * input string, i.e. assume that the position equals 1). + * + * On success returns three values: the decoded packet header (never nil), + * the decoded packet body (may be nil), and the position of the following + * packet in the string. The header and body are returned as MsgPack objects. + * + * If the packet is truncated, returns nil and the minimal number of bytes + * necessary to decode the packet. + * + * On failure, raises a Lua error. + */ +static int +lbox_iproto_decode_packet(struct lua_State *L) +{ + int n_args = lua_gettop(L); + if (n_args == 0 || n_args > 2 || + lua_type(L, 1) != LUA_TSTRING || + (n_args == 2 && lua_type(L, 2) != LUA_TNUMBER)) + return luaL_error( + L, "Usage: box.iproto.decode_packet(string[, pos])"); + + size_t data_size; + const char *data = lua_tolstring(L, 1, &data_size); + const char *data_end = data + data_size; + const char *p = data; + if (n_args == 2) { + int pos = lua_tointeger(L, 2); + if (pos <= 0) + return luaL_error(L, "position must be greater than 0"); + p += pos - 1; + } + ptrdiff_t n = p - data_end + 1; + if (n > 0) + goto truncated_input; + if (mp_typeof(*p) != MP_UINT) { + diag_set(ClientError, ER_PROTOCOL, "invalid fixheader"); + return luaT_error(L); + } + n = mp_check_uint(p, data_end); + if (n > 0) + goto truncated_input; + size_t packet_size = mp_decode_uint(&p); + if (packet_size == 0) { + diag_set(ClientError, ER_PROTOCOL, "invalid fixheader"); + return luaT_error(L); + } + const char *packet_end = p + packet_size; + n = packet_end - data_end; + if (n > 0) + goto truncated_input; + const char *header = p; + if (mp_check(&p, packet_end) != 0) + return luaT_error(L); + const char *header_end = p; + const char *body = p; + if (p != packet_end && mp_check_exact(&p, packet_end) != 0) + return luaT_error(L); + const char *body_end = p; + struct mp_ctx ctx; + mp_ctx_create_default(&ctx, iproto_key_translation); + luamp_push_with_ctx(L, header, header_end, &ctx); + if (body != body_end) { + mp_ctx_create_default(&ctx, iproto_key_translation); + luamp_push_with_ctx(L, body, body_end, &ctx); + } else { + lua_pushnil(L); + } + lua_pushnumber(L, packet_end - data + 1); + return 3; +truncated_input: + assert(n > 0); + lua_pushnil(L); + lua_pushnumber(L, n); + return 2; +} + /** * Initializes module for working with Tarantool's network subsystem. */ @@ -393,6 +682,10 @@ box_lua_iproto_init(struct lua_State *L) static const struct luaL_Reg funcs[] = { {"send", lbox_iproto_send}, {"override", lbox_iproto_override}, + {"encode_greeting", lbox_iproto_encode_greeting}, + {"decode_greeting", lbox_iproto_decode_greeting}, + {"encode_packet", lbox_iproto_encode_packet}, + {"decode_packet", lbox_iproto_decode_packet}, {NULL, NULL} }; luaL_setfuncs(L, funcs, 0); diff --git a/test/app-luatest/iproto_encoder_test.lua b/test/app-luatest/iproto_encoder_test.lua new file mode 100644 index 0000000000000000000000000000000000000000..0bd93331decc936bdf676724200b7d481984de88 --- /dev/null +++ b/test/app-luatest/iproto_encoder_test.lua @@ -0,0 +1,295 @@ +local msgpack = require('msgpack') +local t = require('luatest') +local tarantool = require('tarantool') +local uuid = require('uuid') + +local g = t.group() + +-- +-- Checks exported constant values. +-- +g.test_constants = function() + t.assert_equals(box.iproto.GREETING_SIZE, 128) + t.assert_equals(box.iproto.GREETING_PROTOCOL_LEN_MAX, 32) + t.assert_equals(box.iproto.GREETING_SALT_LEN_MAX, 44) +end + +-- +-- Checks errors raised on invalid arguments passed to +-- box.iproto.encode_greeting() and box.iproto.decode_greeting(). +-- +g.test_encode_decode_greeting_invalid_args = function() + local encode = box.iproto.encode_greeting + local decode = box.iproto.decode_greeting + + local errmsg = 'Usage: box.iproto.encode_greeting({' .. + 'version = x, uuid = x, salt = x})' + t.assert_error_msg_equals(errmsg, encode, 123) + t.assert_error_msg_equals(errmsg, encode, 'foo') + t.assert_error_msg_equals(errmsg, encode, {}, 123) + + t.assert_error_msg_equals('version must be a string', + encode, {version = 123}) + t.assert_error_msg_equals('cannot parse version string', + encode, {version = 'foo'}) + t.assert_error_msg_equals('uuid must be a string', + encode, {uuid = 123}) + t.assert_error_msg_equals('cannot parse uuid string', + encode, {uuid = 'foo'}) + t.assert_error_msg_equals('salt must be a string', + encode, {salt = 123}) + t.assert_error_msg_equals('salt string length cannot be greater than 44', + encode, {salt = string.rep('x', 45)}) + + errmsg = 'Usage: box.iproto.decode_greeting(string)' + t.assert_error_msg_equals(errmsg, decode, 123) + t.assert_error_msg_equals(errmsg, decode, {}) + t.assert_error_msg_equals(errmsg, decode, 'foo', 123) + + t.assert_error_msg_equals('greeting length must equal 128', decode, 'foo') +end + +-- +-- Checks box.iproto.encode_greeting() and box.iproto.decode_greeting() output. +-- +g.test_encode_decode_greeting = function() + local encode = box.iproto.encode_greeting + local decode = box.iproto.decode_greeting + + local pattern = + 'Tarantool%s+%d+%.%d+%.%d+%s+%(Binary%)%s+' .. + string.rep('%x', 8) .. '%-' .. string.rep('%x', 4) .. '%-' .. + string.rep('%x', 4) .. '%-' .. string.rep('%x', 4) .. '%-' .. + string.rep('%x', 12) .. '%s*\n[%w%p]+%s*$' + + local str = encode() + t.assert_equals(#str, box.iproto.GREETING_SIZE) + t.assert_str_matches(str, pattern) + + str = encode({}) + t.assert_equals(#str, box.iproto.GREETING_SIZE) + t.assert_str_matches(str, pattern) + + local greeting = decode(str) + t.assert_type(greeting, 'table') + t.assert_equals(greeting.version, tarantool.version:match('%d%.%d%.%d')) + t.assert_equals(greeting.protocol, 'Binary') + t.assert(uuid.fromstr(greeting.uuid)) + t.assert_not_equals(uuid.fromstr(greeting.uuid), uuid.NULL) + t.assert_type(greeting.salt, 'string') + t.assert_equals(#greeting.salt, 32) + t.assert_equals(encode(greeting), str) + + greeting = { + version = '2.3.4', + protocol = 'Binary', + uuid = uuid.str(), + salt = string.rep('x', 40), + } + str = encode(greeting) + t.assert_equals(#str, box.iproto.GREETING_SIZE) + t.assert_str_matches(str, pattern) + t.assert_equals(decode(str), greeting) +end + +-- +-- Checks errors raised on invalid arguments passed to +-- box.iproto.encode_packet() and box.iproto.decode_packet(). +-- +g.test_encode_decode_packet_invalid_args = function() + local encode = box.iproto.encode_packet + local decode = box.iproto.decode_packet + + local errmsg = 'Usage: box.iproto.encode_packet(header[, body])' + t.assert_error_msg_equals(errmsg, encode) + t.assert_error_msg_equals(errmsg, encode, {}, {}, {}) + + t.assert_error_msg_equals('header must be a string or a table', + encode, 123) + t.assert_error_msg_equals('body must be a string or a table', + encode, {}, 123) + t.assert_error_msg_equals("unsupported Lua type 'function'", + encode, {function() end}) + t.assert_error_msg_equals("unsupported Lua type 'function'", + encode, {}, {function() end}) + + errmsg = 'Usage: box.iproto.decode_packet(string[, pos])' + t.assert_error_msg_equals(errmsg, decode) + t.assert_error_msg_equals(errmsg, decode, {}) + t.assert_error_msg_equals(errmsg, decode, 123) + t.assert_error_msg_equals(errmsg, decode, '', '1') + t.assert_error_msg_equals(errmsg, decode, '', 1, '') + + errmsg = 'position must be greater than 0' + t.assert_error_msg_equals(errmsg, decode, '', 0) + t.assert_error_msg_equals(errmsg, decode, '', -1) +end + +-- +-- Checks errors raised by box.iproto.decode_packet() on bad input. +-- +g.test_decode_packet_bad_input = function() + local decode = box.iproto.decode_packet + + local errmsg = 'invalid fixheader' + t.assert_error_msg_equals(errmsg, decode, string.fromhex('00')) + t.assert_error_msg_equals(errmsg, decode, string.fromhex('ff')) + t.assert_error_msg_equals(errmsg, decode, string.fromhex('80')) + + t.assert_error_msg_equals('Invalid MsgPack - illegal code', + decode, string.fromhex('01c1')) + t.assert_error_msg_equals('Invalid MsgPack - truncated input', + decode, string.fromhex('0281c0')) + t.assert_error_msg_equals('Invalid MsgPack - truncated input', + decode, string.fromhex('0281c0c0')) + t.assert_error_msg_equals('Invalid MsgPack - truncated input', + decode, string.fromhex('0581c0c081c0')) + t.assert_error_msg_equals('Invalid MsgPack - truncated input', + decode, string.fromhex('0581c0c081c0c0')) + t.assert_error_msg_equals('Invalid MsgPack - junk after input', + decode, string.fromhex('0781c0c081c0c0c0')) +end + +-- +-- Checks output of box.iproto.decode_packet() on truncated input. +-- +g.test_decode_packet_truncated_input = function() + local decode = box.iproto.decode_packet + + t.assert_equals({decode('')}, {nil, 1}) + t.assert_equals({decode(string.fromhex('ce'))}, {nil, 4}) + t.assert_equals({decode(string.fromhex('ce0000'))}, {nil, 2}) + t.assert_equals({decode(string.fromhex('05'))}, {nil, 5}) + t.assert_equals({decode(string.fromhex('ce00000005'))}, {nil, 5}) + t.assert_equals({decode(string.fromhex('ce0000000581'))}, {nil, 4}) +end + +-- +-- Checks box.iproto.encode_packet() and box.iproto.decode_packet() output +-- on input containing a single packet. +-- +g.test_encode_decode_packet_one = function() + local encode = box.iproto.encode_packet + local decode = box.iproto.decode_packet + + local data = encode({ + sync = 123, + request_type = box.iproto.type.INSERT, + }, { + space_id = 512, + tuple = {1, 2, 3}, + }) + t.assert_equals(string.hex(data), + 'ce0000000f820002017b8210cd02002193010203') + local header, body, pos = decode(data) + t.assert(msgpack.is_object(header)) + t.assert_equals(header:decode(), { + [box.iproto.key.SYNC] = 123, + [box.iproto.key.REQUEST_TYPE] = box.iproto.type.INSERT, + }) + t.assert_equals(header.sync, 123) + t.assert_equals(header.request_type, box.iproto.type.INSERT) + t.assert(msgpack.is_object(body)) + t.assert_equals(body:decode(), { + [box.iproto.key.SPACE_ID] = 512, + [box.iproto.key.TUPLE] = {1, 2, 3}, + }) + t.assert_equals(body.space_id, 512) + t.assert_equals(body.tuple, {1, 2, 3}) + t.assert_equals(pos, #data + 1) + + data = encode({ + sync = 123, + request_type = box.iproto.type.NOP, + }) + t.assert_equals(string.hex(data), 'ce0000000582000c017b') + header, body, pos = decode(data) + t.assert(msgpack.is_object(header)) + t.assert_equals(header:decode(), { + [box.iproto.key.SYNC] = 123, + [box.iproto.key.REQUEST_TYPE] = box.iproto.type.NOP, + }) + t.assert_equals(header.sync, 123) + t.assert_equals(header.request_type, box.iproto.type.NOP) + t.assert_is(body, nil) + t.assert_equals(pos, #data + 1) +end + +-- +-- Checks box.iproto.encode_packet() and box.iproto.decode_packet() output +-- on input containing multiple packets. +-- +g.test_encode_decode_packet_many = function() + local encode = box.iproto.encode_packet + local decode = box.iproto.decode_packet + + local data = encode({ + sync = 1, + request_type = box.iproto.type.INSERT, + }, { + space_id = 512, + tuple = {'a', 'b', 'c'}, + }) .. encode({ + sync = 2, + request_type = box.iproto.type.NOP, + }) .. encode({ + sync = 3, + request_type = box.iproto.type.REPLACE, + }, { + space_name = 'test', + tuple = {1, 2, 3}, + }) + + local header, body, pos = decode(data) + t.assert(msgpack.is_object(header)) + t.assert_equals(header:decode(), { + [box.iproto.key.SYNC] = 1, + [box.iproto.key.REQUEST_TYPE] = box.iproto.type.INSERT, + }) + t.assert(msgpack.is_object(body)) + t.assert_equals(body:decode(), { + [box.iproto.key.SPACE_ID] = 512, + [box.iproto.key.TUPLE] = {'a', 'b', 'c'}, + }) + + header, body, pos = decode(data, pos) + t.assert(msgpack.is_object(header)) + t.assert_equals(header:decode(), { + [box.iproto.key.SYNC] = 2, + [box.iproto.key.REQUEST_TYPE] = box.iproto.type.NOP, + }) + t.assert_is(body, nil) + + header, body, pos = decode(data, pos) + t.assert(msgpack.is_object(header)) + t.assert_equals(header:decode(), { + [box.iproto.key.SYNC] = 3, + [box.iproto.key.REQUEST_TYPE] = box.iproto.type.REPLACE, + }) + t.assert(msgpack.is_object(body)) + t.assert_equals(body:decode(), { + [box.iproto.key.SPACE_NAME] = 'test', + [box.iproto.key.TUPLE] = {1, 2, 3}, + }) + + t.assert_equals(pos, #data + 1) + t.assert_equals({decode(data, pos)}, {nil, 1}) +end + +-- +-- Checks box.iproto.encode_packet() output on binary input. +-- +g.test_encode_packet_bin = function() + local encode = box.iproto.encode_packet + + local data = encode(string.fromhex('820002017b'), + string.fromhex('8210cd02002193010203')) + t.assert_equals(string.hex(data), + 'ce0000000f820002017b8210cd02002193010203') + data = encode(string.fromhex('82000c017b')) + t.assert_equals(string.hex(data), 'ce0000000582000c017b') + + -- box.iproto.encode_packet() doesn't check binary input. + data = encode(string.fromhex('c1'), string.fromhex('82')) + t.assert_equals(string.hex(data), 'ce00000002c182') +end diff --git a/test/box-luatest/ghs_16_user_enumeration_test.lua b/test/box-luatest/ghs_16_user_enumeration_test.lua index c32f5c15fb4f1f1f0fc21b639a985d96ed016cfa..6876e1f5fcc6485c76713fa1fc3fb8f5f4e32f5a 100644 --- a/test/box-luatest/ghs_16_user_enumeration_test.lua +++ b/test/box-luatest/ghs_16_user_enumeration_test.lua @@ -1,4 +1,3 @@ -local msgpack = require('msgpack') local net = require('net.box') local server = require('luatest.server') local socket = require('socket') @@ -7,39 +6,37 @@ local t = require('luatest') local g = t.group() -local IPROTO_REQUEST_TYPE = 0 -local IPROTO_TYPE_ERROR = bit.lshift(1, 15) -local IPROTO_AUTH = 7 -local IPROTO_TUPLE = 33 -local IPROTO_USER = 35 -local IPROTO_ERROR = 49 - -- Opens a new connection and sends IPROTO_AUTH request. -- Returns {code, error} local function auth(sock_path, user, tuple) - local hdr = msgpack.encode({[IPROTO_REQUEST_TYPE] = IPROTO_AUTH}) - local body = msgpack.encode({ - [IPROTO_USER] = user, - [IPROTO_TUPLE] = tuple, - }) - local len = hdr:len() + body:len() - t.assert_lt(len, 256) local s = socket.tcp_connect('unix/', sock_path) - local data = s:read(128) -- greeting - t.assert_equals(#data, 128) - data = '\xce\00\00\00' .. string.char(len) .. hdr .. body - t.assert_equals(s:write(data), #data) -- request - data = s:read(5) -- fixheader - t.assert_equals(#data, 5) - len = msgpack.decode(data) - data = s:read(len) -- response - t.assert_equals(#data, len) + local greeting = s:read(box.iproto.GREETING_SIZE) + greeting = box.iproto.decode_greeting(greeting) + t.assert_covers(greeting, {protocol = 'Binary'}) + local request = box.iproto.encode_packet({ + sync = 123, + request_type = box.iproto.type.AUTH, + }, { + user_name = user, + tuple = tuple, + }) + t.assert_equals(s:write(request), #request) + local response = '' + local header, body + repeat + header, body = box.iproto.decode_packet(response) + if header == nil then + local size = body + local data = s:read(size) + t.assert_is_not(data) + response = response .. data + end + until header ~= nil s:close() - hdr, len = msgpack.decode(data) - body = msgpack.decode(data, len) + t.assert_equals(header.sync, 123) return { - bit.band(hdr[IPROTO_REQUEST_TYPE], bit.bnot(IPROTO_TYPE_ERROR)), - body[IPROTO_ERROR], + bit.band(header.request_type, bit.bnot(box.iproto.type.TYPE_ERROR)), + body.error_24, } end