diff --git a/changelogs/unreleased/protobuf_encoder_module.md b/changelogs/unreleased/protobuf_encoder_module.md new file mode 100644 index 0000000000000000000000000000000000000000..767dc342d50827f78d1f24fe52fbd9241fcfe19d --- /dev/null +++ b/changelogs/unreleased/protobuf_encoder_module.md @@ -0,0 +1,3 @@ +## feature/lua + +* Introduced Lua implementation of protobuf encoder (gh-9844). diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b40b48ee1e3a82e622924010efcb00901f78040f..9beffc1e397a93c1a6c030777c391116fc56bbae 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -71,6 +71,8 @@ lua_source(lua_sources lua/print.lua print_lua) lua_source(lua_sources lua/pairs.lua pairs_lua) lua_source(lua_sources lua/compat.lua compat_lua) lua_source(lua_sources lua/varbinary.lua varbinary_lua) +lua_source(lua_sources lua/protobuf_wireformat.lua protobuf_wireformat_lua) +lua_source(lua_sources lua/protobuf.lua protobuf_lua) if (ENABLE_COMPRESS_MODULE) lua_source(lua_sources ${COMPRESS_MODULE_LUA_SOURCE} compress_lua) endif() diff --git a/src/lua/init.c b/src/lua/init.c index c6787d39d369c102aed5de92f68dd6c42d6b8df3..8df359a88898ea3defc5401d1479ec6577a9e340 100644 --- a/src/lua/init.c +++ b/src/lua/init.c @@ -136,6 +136,8 @@ extern char minifio_lua[], utils_lua[], argparse_lua[], iconv_lua[], + protobuf_wireformat_lua[], + protobuf_lua[], /* jit.* library */ jit_vmdef_lua[], jit_bc_lua[], @@ -211,6 +213,8 @@ static const char *lua_modules[] = { "http.client", httpc_lua, "iconv", iconv_lua, "swim", swim_lua, + "internal.protobuf.wireformat", protobuf_wireformat_lua, + "protobuf", protobuf_lua, COMPRESS_LUA_MODULES /* jit.* library */ "jit.vmdef", jit_vmdef_lua, diff --git a/src/lua/protobuf.lua b/src/lua/protobuf.lua new file mode 100644 index 0000000000000000000000000000000000000000..7f29ef81e25d1eb9d9fdfaf106e0a06003ba3d57 --- /dev/null +++ b/src/lua/protobuf.lua @@ -0,0 +1,670 @@ +local ffi = require('ffi') +local wireformat = 
require('internal.protobuf.wireformat') +local protocol_mt +-- These constants are used to define the boundaries of valid field ids. +-- Described in more detail here: +-- https://protobuf.dev/programming-guides/proto3/#assigning +local MIN_FIELD_ID = 1 +local RESERVED_FIELD_ID_MIN = 19000 +local RESERVED_FIELD_ID_MAX = 19999 +local MAX_FIELD_ID = 2^29 - 1 + +-- Number limits for int32 and int64 +local MAX_FLOAT = 0x1.fffffep+127 +local MAX_UINT32 = 2^32 - 1 +-- Actual uint64 limit is 2^64 - 1. Because of lua number limited precision +-- numbers from [2^64 - 1024, 2^64 + 2048] represent as 2^64. So the correct +-- number limit for uint64 is 2^64 - 1025. +local MAX_UINT64 = 0xfffffffffffffbff +local MIN_SINT32 = -2^31 +local MAX_SINT32 = 2^31 - 1 +-- Same problem with lua number limited precision. +-- Numbers from [2^63 - 512, 2^63 + 1024] represent as 2^63. So the correct +-- number limit for int64 is 2^63 - 513. +local MAX_INT64 = 0x7ffffffffffffdff +local MIN_INT64 = -0x8000000000000000 -- 2^63 + +-- Cdata limits for int32_t and int64_t +local MAX_UINT32_LL = 2LL^32 - 1 +local MAX_UINT32_ULL = 2ULL^32 - 1 +local MIN_SINT32_LL = -2LL^31 +local MAX_SINT32_LL = 2LL^31 - 1 +local MAX_SINT32_ULL = 2ULL^31 - 1 +local MAX_SINT64_ULL = 2ULL^63 - 1 + + +local int64_t = ffi.typeof('int64_t') +local uint64_t = ffi.typeof('uint64_t') + +-- Forward declarations +local encode +local encode_field +local validate_scalar + +local scalars = {} + +-- {{{ Scalar type definitions + +scalars.float = { + accept_type = 'number', + encode_as_packed = true, + integral_only = false, + limits = { + number = {-MAX_FLOAT, MAX_FLOAT}, + }, + encode = wireformat.encode_float, +} + +scalars.fixed32 = { + accept_type = {'number', 'cdata'}, + encode_as_packed = true, + integral_only = true, + limits = { + number = {0, MAX_UINT32}, + int64 = {0LL, MAX_UINT32_LL}, + uint64 = {nil, MAX_UINT32_ULL}, + }, + encode = wireformat.encode_fixed32, +} + +scalars.sfixed32 = { + accept_type = {'number', 
'cdata'}, + encode_as_packed = true, + integral_only = true, + limits = { + number = {MIN_SINT32, MAX_SINT32}, + int64 = {MIN_SINT32_LL, MAX_SINT32_LL}, + uint64 = {nil, MAX_SINT32_ULL}, + }, + encode = wireformat.encode_fixed32, +} + +scalars.double = { + accept_type = 'number', + encode_as_packed = true, + integral_only = false, + encode = wireformat.encode_double, +} + +scalars.fixed64 = { + accept_type = {'number', 'cdata'}, + encode_as_packed = true, + integral_only = true, + limits = { + number = {0, MAX_UINT64}, + int64 = {0LL, nil}, + }, + encode = wireformat.encode_fixed64, +} + +scalars.sfixed64 = { + accept_type = {'number', 'cdata'}, + encode_as_packed = true, + integral_only = true, + limits = { + number = {MIN_INT64, MAX_INT64}, + uint64 = {nil, MAX_SINT64_ULL}, + }, + encode = wireformat.encode_fixed64, +} + +scalars.string = { + accept_type = 'string', + encode_as_packed = false, + encode = wireformat.encode_len, +} + +scalars.bytes = scalars.string + +scalars.int32 = { + accept_type = {'number', 'cdata'}, + encode_as_packed = true, + integral_only = true, + limits = { + number = {MIN_SINT32, MAX_SINT32}, + int64 = {MIN_SINT32_LL, MAX_SINT32_LL}, + uint64 = {nil, MAX_SINT32_ULL}, + }, + encode = wireformat.encode_int, +} + +scalars.sint32 = { + accept_type = {'number', 'cdata'}, + encode_as_packed = true, + integral_only = true, + limits = { + number = {MIN_SINT32, MAX_SINT32}, + int64 = {MIN_SINT32_LL, MAX_SINT32_LL}, + uint64 = {nil, MAX_SINT32_ULL}, + }, + encode = wireformat.encode_sint, +} + +scalars.uint32 = { + accept_type = {'number', 'cdata'}, + encode_as_packed = true, + integral_only = true, + limits = { + number = {0, MAX_UINT32}, + int64 = {0LL, MAX_UINT32_LL}, + uint64 = {nil, MAX_UINT32_ULL}, + }, + encode = wireformat.encode_int, +} + +scalars.int64 = { + accept_type = {'number', 'cdata'}, + encode_as_packed = true, + integral_only = true, + limits = { + number = {MIN_INT64, MAX_INT64}, + uint64 = {nil, MAX_SINT64_ULL}, + }, + encode 
-- Create a message object suitable to pass
-- into the protobuf.protocol function.
--
-- Accepts a name of the message and a message
-- definition in the following format.
--
-- message_def = {
--    <field_name> = {<field_type>, <field_id>},
--    <...>
-- }
--
-- <field_type> is a string with a type name. The 'repeated ' prefix marks
-- the field as repeated. <field_id> is an integral number that belongs to
-- [MIN_FIELD_ID; MAX_FIELD_ID] and lies outside of
-- [RESERVED_FIELD_ID_MIN; RESERVED_FIELD_ID_MAX].
--
-- Returns a table with the 'type', 'name', 'field_by_name' and
-- 'field_by_id' keys. Raises an error if a field definition is malformed,
-- if an id is used twice or if an id does not belong to the allowed ranges.
local function message(message_name, message_def)
    local field_by_name = {}
    local field_by_id = {}
    for field_name, def in pairs(message_def) do
        -- Report malformed definitions explicitly instead of failing
        -- below with a generic 'attempt to index/compare' error.
        if type(def) ~= 'table' then
            error(('Definition of field %q must be a table, got %s'):format(
                tostring(field_name), type(def)))
        end
        local declared_type = def[1]
        local field_id = def[2]
        if type(declared_type) ~= 'string' then
            error(('Type of field %q must be a string, got %s'):format(
                tostring(field_name), type(declared_type)))
        end
        -- A fractional id would pass the range checks and then would be
        -- silently rounded when the tag is built, so reject it here.
        if type(field_id) ~= 'number' or
           math.floor(field_id) ~= field_id then
            error(('Id of field %q must be an integer, got %s'):format(
                tostring(field_name), tostring(field_id)))
        end
        -- The pattern is anchored: only a leading 'repeated' keyword is
        -- a modifier, the rest of the string is the type name.
        local field_type, rep = string.gsub(declared_type,
                                            '^repeated%s+', '')
        if field_by_id[field_id] ~= nil then
            error(('Id %d in field %q was already used'):format(field_id,
                field_name))
        end
        if field_id < MIN_FIELD_ID or field_id > MAX_FIELD_ID then
            error(('Id %d in field %q is out of range [%d; %d]'):format(
                field_id, field_name, MIN_FIELD_ID, MAX_FIELD_ID))
        end
        if field_id >= RESERVED_FIELD_ID_MIN and
           field_id <= RESERVED_FIELD_ID_MAX then
            error(('Id %d in field %q is in reserved ' ..
                   'id range [%d, %d]'):format(field_id, field_name,
                   RESERVED_FIELD_ID_MIN, RESERVED_FIELD_ID_MAX))
        end
        local field_def = {
            type = field_type,
            name = field_name,
            id = field_id,
        }
        if rep ~= 0 then
            field_def['repeated'] = true
        end
        field_by_name[field_name] = field_def
        field_by_id[field_id] = field_def
    end
    return {
        type = 'message',
        name = message_name,
        field_by_name = field_by_name,
        field_by_id = field_by_id
    }
end
-- Build a protocol object from a list of definitions created by the
-- message and enum constructors.
--
-- Every definition is stored in the result under its name. A type used by
-- a message field may be defined after the message that uses it, but it
-- must be defined somewhere in the list unless it is a scalar type.
--
-- Raises an error if a name is defined twice or if a field refers to
-- a type that is neither scalar nor defined in the list.
local function protocol(protocol_def)
    local defs = {}
    -- known[name] is true for a defined name and false for a name that
    -- was only seen as a field type so far.
    local known = {}
    for _, def in pairs(protocol_def) do
        local name = def.name
        if known[name] then
            error(('Double definition of name %q'):format(name))
        end
        if def.type == 'message' then
            for _, field in pairs(def.field_by_id) do
                local type_name = field.type
                if scalars[type_name] == nil and not known[type_name] then
                    known[type_name] = false
                end
            end
        end
        known[name] = true
        defs[name] = def
    end
    -- Whatever is still marked as false was used but never defined.
    for type_name, is_defined in pairs(known) do
        if not is_defined then
            error(('Type %q is not declared'):format(type_name))
        end
    end
    return setmetatable(defs, protocol_mt)
end
-- Check that a table given for a repeated field is a plain array.
--
-- The table must not hold box.NULL values, all its keys must be integral
-- numbers and they must form the 1..N sequence without gaps. An empty
-- table is accepted. Raises an error otherwise.
local function validate_table_is_array(field_def, value)
    assert(type(value) == 'table')
    local count = 0
    local lowest = math.huge
    local highest = -math.huge
    for key, item in pairs(value) do
        if item == box.NULL then
            error(('Input array for %q repeated field contains box.NULL ' ..
                   'value which leads to ambiguous behaviour'):format(
                   field_def.name))
        end
        if type(key) ~= 'number' then
            error(('Input array for %q repeated field ' ..
                   'contains non-numeric key: %q'):format(
                   field_def.name, key))
        end
        local fraction = key - math.floor(key)
        if fraction ~= 0 then
            error(('Input array for %q repeated field contains ' ..
                   'non-integer numeric key: %q'):format(
                   field_def.name, key))
        end
        count = count + 1
        if key < lowest then
            lowest = key
        end
        if key > highest then
            highest = key
        end
    end
    if count == 0 then
        return
    end
    if lowest ~= 1 then
        error(('Input array for %q repeated field got min index %d. ' ..
               'Must be 1'):format(field_def.name, lowest))
    end
    -- With integral keys starting from 1 the array has no gaps only if
    -- the biggest key equals the number of keys.
    if highest ~= count then
        error(('Input array for %q repeated field has inconsistent keys. ' ..
               'Got table with %d fields and max index of %d'):format(
               field_def.name, count, highest))
    end
end
-- Encode a value of a repeated field.
--
-- The value must be an array. Items of scalar types marked with
-- encode_as_packed are encoded without tags and wrapped into a single LEN
-- record. Items of other types are encoded one by one with their own tags
-- and concatenated.
--
-- Raises an error if the value is not an array or if an item is encoded
-- into an empty string.
local function encode_repeated(protocol, field_def, value)
    if type(value) ~= 'table' then
        error('For repeated fields table data are needed')
    end
    validate_table_is_array(field_def, value)
    local scalar_def = scalars[field_def.type]
    local packed = false
    if scalar_def then
        packed = scalar_def.encode_as_packed
    end
    local chunks = {}
    for _, item in ipairs(value) do
        local chunk = encode_field(protocol, field_def, item, true)
        if chunk == '' then
            error(('Input for %q repeated field contains default value ' ..
                   'can`t be encoded correctly'):format(field_def.name))
        end
        if packed then
            chunk = remove_tag(chunk)
        end
        chunks[#chunks + 1] = chunk
    end
    local payload = table.concat(chunks)
    if packed then
        return wireformat.encode_len(field_def.id, payload)
    end
    return payload
end
-- Encode a single field value according to its definition.
--
-- Dispatches to the repeated, scalar, enum or nested message encoder.
-- The ignore_repeated flag is set when an item of a repeated field is
-- encoded, so the 'repeated' mark of the definition is not applied twice.
encode_field = function(protocol, field_def, value, ignore_repeated)
    if field_def.repeated and not ignore_repeated then
        return encode_repeated(protocol, field_def, value)
    end
    local scalar_def = scalars[field_def.type]
    if scalar_def then
        validate_scalar(field_def, value)
        return scalar_def.encode(field_def.id, value)
    end
    local kind = protocol[field_def.type].type
    if kind == 'enum' then
        return encode_enum(protocol, field_def, value)
    end
    if kind == 'message' then
        local payload = encode(protocol, field_def.type, value)
        validate_length(payload)
        -- A nested message is written even if its payload is empty.
        return wireformat.encode_len(field_def.id, payload, true)
    end
    assert(false)
end
-- Encodes the entered data in accordance with the
-- selected protocol into binary format.
--
-- Accepts a protocol created by protobuf.protocol function,
-- a name of a message selected for encoding and
-- the data that needs to be encoded in the following format.
--
-- data = {
--    <field_name> = <value>,
--    <...>
-- }
--
-- Fields set to box.NULL are skipped. The value of the '_unknown_fields'
-- key must be an array of strings, they are appended to the result as is.
--
-- Returns a string with the encoded message. Raises an error if the
-- message is not defined in the protocol, if the name refers to an enum,
-- if the data is not a table or if it contains an undefined field name.
encode = function(protocol, message_name, data)
    -- Use rawget: the protocol object has a metatable whose __index
    -- provides methods, so a plain lookup of a name like 'encode' would
    -- return a function instead of nil.
    local message_def = rawget(protocol, message_name)
    if message_def == nil then
        error(('There is no message or enum named %q in the given protocol')
            :format(message_name))
    end
    if message_def.type ~= 'message' then
        assert(message_def.type == 'enum')
        error(('Attempt to encode enum %q as a top level message'):format(
            message_name))
    end
    if type(data) ~= 'table' then
        error(('Input data for %q message must be a table, got %q'):format(
            message_name, type(data)))
    end
    local buf = {}
    local field_by_name = message_def.field_by_name
    for field_name, value in pairs(data) do
        if value ~= box.NULL then
            if field_name == '_unknown_fields' then
                if type(value) ~= 'table' then
                    error(('Unknown fields of %q message must be given ' ..
                           'as a table, got %q'):format(message_name,
                           type(value)))
                end
                table.insert(buf, table.concat(value))
            else
                local field_def = field_by_name[field_name]
                if field_def == nil then
                    error(('Wrong field name %q for %q message'):
                        format(field_name, message_name))
                end
                table.insert(buf, encode_field(protocol, field_def,
                                               value, false))
            end
        end
    end
    return table.concat(buf)
end
-- Return 4 bytes holding the 32-bit two's complement representation of
-- the given integral value in the byte order of the host.
--
-- Accepts a number, cdata<int64_t> or cdata<uint64_t>. Negative numbers
-- and int64_t values go through a signed 32-bit storage, everything else
-- goes through an unsigned one.
local function as_int32(value)
    local value_type = type(value)
    local signed
    if value_type == 'number' then
        signed = value < 0
    elseif value_type == 'cdata' and ffi.istype(int64_t, value) then
        signed = true
    elseif value_type == 'cdata' and ffi.istype(uint64_t, value) then
        signed = false
    else
        assert(false)
    end
    local storage = ffi.new(signed and 'int32_t[1]' or 'uint32_t[1]')
    storage[0] = value
    return ffi.string(storage, 4)
end
-- Encode a tag.
--
-- The tag is a VARINT that holds the given field_id in its upper bits and
-- the given Protocol Buffers wire type in its 3 lower bits. It is the
-- first part of a Tag-Length-Value record and may take several bytes.
local function encode_tag(field_id, wire_type)
    assert(wire_type >= 0 and wire_type <= 5)
    -- Do not use bit.lshift() here: the bit module returns signed 32-bit
    -- results, so the shift gives a negative number for field_id >= 2^28
    -- and such a tag is written as a sign-extended 10 byte varint.
    -- The wire type occupies 3 bits, so the plain arithmetic below builds
    -- the same value without the sign problem.
    return encode_varint(field_id * 8 + wire_type)
end
-- Encode an integral value as VARINT using the "ZigZag" encoding.
--
-- A non-negative value v is mapped to 2 * v, a negative one is mapped to
-- 2 * |v| - 1, then the result is encoded as a regular VARINT field.
--
-- Input value types: number (integral), cdata<int64_t>, cdata<uint64_t>.
--
-- Used for Protocol Buffers types: sint32, sint64.
local function encode_sint(field_id, value)
    local zigzag
    if value >= 0 then
        zigzag = 2 * value
    else
        -- The magnitude is kept in an unsigned 64-bit integer, so the
        -- arithmetic below is done on cdata.
        local magnitude = ffi.cast('uint64_t', -value)
        zigzag = 2 * magnitude - 1
    end
    return encode_int(field_id, zigzag)
end
-- Encode a string value as a LEN record: tag, length as VARINT, raw bytes.
--
-- Input value type: string. The string contains raw bytes to encode.
--
-- An empty string is encoded into an empty result unless ex_presence is
-- set, in this case a record with zero length is written.
--
-- Used for Protocol Buffers types: string, bytes, embedded message, packed
-- repeated fields.
local function encode_len(field_id, value, ex_presence)
    if not ex_presence and value == STRING_DEFAULT then
        return ''
    end
    local tag = encode_tag(field_id, WIRE_TYPE_LEN)
    local length = encode_varint(string.len(value))
    return tag .. length .. value
end
-- A message with several scalar fields: every field must be present in
-- the output. Fields are looked up one by one, so the check does not
-- depend on their order in the encoded message.
g.test_module_multiple_fields = function()
    local key_value = protobuf.message('KeyValue', {
        key = {'bytes', 1},
        index = {'int64', 2},
        number = {'int32', 3},
        version = {'fixed32', 4},
        available = {'bool', 5},
    })
    local protocol = protobuf.protocol({key_value})
    local data = {
        key = 'abc',
        number = 25,
        version = 1,
        index = 15,
        available = true,
    }
    local encoded = string.hex(protocol:encode('KeyValue', data))
    local expected_parts = {
        '2801',
        '0a03616263',
        '100f',
        '2501000000',
        '1819',
    }
    for _, part in ipairs(expected_parts) do
        t.assert_str_contains(encoded, part)
    end
end
-- Two independent messages in one protocol: each of them is encoded by
-- its own name.
g.test_module_multiple_messages = function()
    local key_value = protobuf.message('KeyValue', {
        key = {'bytes', 1},
        index = {'int64', 2},
    })
    local storage = protobuf.message('Storage', {
        number = {'int32', 3},
        version = {'fixed32', 4},
        available = {'bool', 5},
    })
    local protocol = protobuf.protocol({key_value, storage})

    local key_value_hex = string.hex(protocol:encode('KeyValue', {
        index = 25,
        key = 'abc',
    }))
    t.assert_str_contains(key_value_hex, '0a03616263')
    t.assert_str_contains(key_value_hex, '1019')

    local storage_hex = string.hex(protocol:encode('Storage', {
        number = 15,
        available = true,
    }))
    t.assert_str_contains(storage_hex, '2801')
    t.assert_str_contains(storage_hex, '180f')
end
+ }) + t.assert_str_contains(string.hex(result), '1a00') + local result = protocol:encode('KeyValue', { + index = 25, + key = 'abc', + }) + t.assert_str_matches(string.hex(result), '0a036162631019') +end + +g.test_module_enum_usage = function() + local protocol = protobuf.protocol({ + protobuf.message('KeyValue', { + key = {'bytes', 1}, + index = {'int64', 2}, + ret_val = {'ReturnValue', 3}, + }), + protobuf.enum('ReturnValue', { + ok = 0, + error1 = 1, + error2 = 2, + error3 = 3, + }) + }) + local result = protocol:encode('KeyValue', { + index = 25, + key = 'abc', + ret_val = 'error2', + }) + t.assert_str_contains(string.hex(result), '1019') + t.assert_str_contains(string.hex(result), '0a03616263') + t.assert_str_contains(string.hex(result), '1802') +end + +g.test_module_enum_default_value_encoding = function() + local protocol = protobuf.protocol({ + protobuf.message('test', { + str_ret_val = {'ReturnValue', 1}, + num_ret_val = {'ReturnValue', 2}, + }), + protobuf.enum('ReturnValue', { + ok = 0, + error1 = 1, + }) + }) + local result = protocol:encode('test', { + str_ret_val = 'ok', + num_ret_val = 0, + }) + t.assert_str_matches(string.hex(result), '') +end + +g.test_module_exception_enum_on_top_level = function() + local protocol = protobuf.protocol({ + protobuf.message('MyMessage', { + myfield = {'MyEnum', 3}, + }), + protobuf.enum('MyEnum', { + ok = 0, + }) + }) + local msg = 'Attempt to encode enum "MyEnum" as a top level message' + local data = {myfield = 'ok'} + t.assert_error_msg_contains(msg, protocol.encode, protocol, 'MyEnum', data) +end + +g.test_module_exception_value_in_enum_out_of_int32_range = function() + local msg = 'Input data for "ok" field is "4294967297" and ' .. 
+ 'do not fit in "int32"' + local enum_def = {def = 0, ok = 2^32 + 1} + t.assert_error_msg_contains(msg, protobuf.enum, 'MyEnum', enum_def) +end + +g.test_module_exception_enum_value_out_of_int32_range = function() + local protocol = protobuf.protocol({ + protobuf.message('KeyValue', { + key = {'bytes', 1}, + index = {'int64', 2}, + ret_val = {'ReturnValue', 3}, + }), + protobuf.enum('ReturnValue', { + ok = 0, + error1 = 1, + error2 = 2, + error3 = 3, + }) + }) + local msg = 'Input data for "nil" field is "4294967297" ' .. + 'and do not fit in "int32"' + local data = {ret_val = 2^32 + 1} + t.assert_error_msg_contains(msg, protocol.encode, protocol, + 'KeyValue', data) +end + +g.test_module_exception_enum_wrong_value = function() + local protocol = protobuf.protocol({ + protobuf.message('test', { + key = {'bytes', 1}, + index = {'int64', 2}, + ret_val = {'ReturnValue', 3}, + }), + protobuf.enum('ReturnValue', { + ok = 0, + error1 = 1, + error2 = 2, + error3 = 3, + }) + }) + local msg = '"error4" is not defined in "ReturnValue" enum' + local data = {ret_val = 'error4'} + t.assert_error_msg_contains(msg, protocol.encode, protocol, 'test', data) +end + +g.test_module_enum_number_encoding = function() + local protocol = protobuf.protocol({ + protobuf.message('KeyValue', { + key = {'bytes', 1}, + index = {'int64', 2}, + ret_val = {'ReturnValue', 3}, + }), + protobuf.enum('ReturnValue', { + ok = 0, + error1 = 1, + error2 = 2, + error3 = 3, + }) + }) + local result = protocol:encode('KeyValue', {ret_val = 15}) + t.assert_str_matches(string.hex(result), '180f') +end + +g.test_module_exception_empty_protocol = function() + local protocol = protobuf.protocol({}) + local msg = 'There is no message or enum named "test" ' .. 
+ 'in the given protocol' + local data = {val = 1.5} + t.assert_error_msg_contains(msg, protocol.encode, protocol, 'test', data) +end + +g.test_module_exception_undefined_field = function() + local protocol = protobuf.protocol({ + protobuf.message('test', { + val = {'int32', 1} + }), + }) + local msg = 'Wrong field name "res" for "test" message' + local data = {res = 1} + t.assert_error_msg_contains(msg, protocol.encode, protocol, 'test', data) +end + +g.test_module_unknown_fields = function() + local protocol = protobuf.protocol({ + protobuf.message('test', { + key = {'bytes', 1}, + index = {'int64', 2}, + }) + }) + local result = protocol:encode('test', { + index = 10, + _unknown_fields = {'\x1a\x03\x61\x62\x63'}, + }) + t.assert_str_contains(string.hex(result), '100a') + t.assert_str_contains(string.hex(result), '1a03616263') +end + +g.test_module_unknown_fields_multiple = function() + local protocol = protobuf.protocol({ + protobuf.message('test', { + key = {'bytes', 1}, + index = {'int64', 2}, + }) + }) + local result = protocol:encode('test', { + index = 10, + _unknown_fields = {'\x1a\x01\x61', '\x22\x01\x62'}, + }) + t.assert_str_contains(string.hex(result), '100a') + t.assert_str_contains(string.hex(result), '1a0161220162') +end + +g.test_module_exception_name_reusage = function() + local protocol_def = { + protobuf.message('test', { + val = {'int32', 1} + }), + protobuf.message('test', { + res = {'int32', 1} + }) + } + local msg = 'Double definition of name "test"' + t.assert_error_msg_contains(msg, protobuf.protocol, protocol_def) +end + +g.test_module_recursive = function() + local protocol = protobuf.protocol({ + protobuf.message('test', { + val = {'int32', 1}, + recursion = {'test', 2} + }) + }) + local result = protocol:encode('test', { + val = 15, + recursion = {val = 15}, + }) + t.assert_str_contains(string.hex(result), '1202080f080f') +end + +g.test_module_exception_not_declared_type = function() + local protocol_def = { + protobuf.message('test', 
{ + val = {'int32', 1}, + recursion = {'value', 2} + }) + } + local msg = 'Type "value" is not declared' + t.assert_error_msg_contains(msg, protobuf.protocol, protocol_def) +end + +g.test_module_exception_repeated_id = function() + local message_name = 'test' + local message_def = { + val1 = {'int32', 1}, + val2 = {'int32', 1} + } + local msg = 'Id 1 in field "val1" was already used' + t.assert_error_msg_contains(msg, protobuf.message, + message_name, message_def) +end + +g.test_module_exception_id_out_of_range = function() + local message_name = 'test' + local message_def = { + val = {'int32', 2^32} + } + local msg = 'Id 4294967296 in field "val" is out of range [1; 536870911]' + t.assert_error_msg_contains(msg, protobuf.message, + message_name, message_def) +end + +g.test_module_exception_id_in_prohibited_range = function() + local message_name = 'test' + local message_def = { + val = {'int32', 19000} + } + local msg = 'Id 19000 in field "val" is in reserved ' .. + 'id range [19000, 19999]' + t.assert_error_msg_contains(msg, protobuf.message, + message_name, message_def) +end + +g.test_module_exception_enum_double_definition = function() + local enum_name = 'test' + local enum_def = { + ok = 0, + error1 = 2, + error2 = 2, + error3 = 3, + } + local msg = 'Double definition of enum field "error2" by 2' + t.assert_error_msg_contains(msg, protobuf.enum, enum_name, enum_def) +end + +g.test_module_exception_enum_missing_zero = function() + local enum_name = 'test' + local enum_def = { + error1 = 1, + error2 = 2, + error3 = 3, + } + local msg = '"test" definition does not contain a field with id = 0' + t.assert_error_msg_contains(msg, protobuf.enum, enum_name, enum_def) +end + +-- Previous encoding implementation had a bug of signed integer overflow +-- which led to different behaviour of interpreted and JITted code. +-- This test repeates encoding process enough times to JIT the code and +-- checks result at each iteration. 
+g.test_repetitive_int64_encoding = function()
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {
+            index = {'int64', 1},
+        })
+    })
+    -- The same negative value is encoded many times in a row and the
+    -- output is checked on every pass, not only on the last one.
+    for _ = 1, 300 do
+        local result = protocol:encode('test', {
+            index = -770,
+        })
+        t.assert_str_contains(string.hex(result), '08fef9ffffffffffffff01')
+    end
+end
diff --git a/test/app-luatest/protobuf_numeric_test.lua b/test/app-luatest/protobuf_numeric_test.lua
new file mode 100644
index 0000000000000000000000000000000000000000..628d4a986e7eb73c2f44074f40d039d9a91c136d
--- /dev/null
+++ b/test/app-luatest/protobuf_numeric_test.lua
@@ -0,0 +1,310 @@
+local ffi = require('ffi')
+local t = require('luatest')
+local protobuf = require('protobuf')
+
+-- Largest accepted value of every scalar type, given as a Lua number
+-- and as int64/uint64 cdata, together with the expected hex dump of the
+-- encoded message. The Lua number rows of 64-bit types stop at
+-- 2^63 - 513 and 2^64 - 1025: the encoder uses these limits for doubles
+-- because of their limited precision.
+local p = t.group('upper_limit', {
+    {type = 'int32', value = 2^31 - 1, res = '08ffffffff07'},
+    {type = 'int32', value = 2LL^31 - 1, res = '08ffffffff07'},
+    {type = 'int32', value = 2ULL^31 - 1, res = '08ffffffff07'},
+    {type = 'sint32', value = 2^31 - 1, res = '08feffffff0f'},
+    {type = 'sint32', value = 2LL^31 - 1, res = '08feffffff0f'},
+    {type = 'sint32', value = 2ULL^31 - 1, res = '08feffffff0f'},
+    {type = 'uint32', value = 2^32 - 1, res = '08ffffffff0f'},
+    {type = 'uint32', value = 2LL^32 - 1, res = '08ffffffff0f'},
+    {type = 'uint32', value = 2ULL^32 - 1, res = '08ffffffff0f'},
+    {type = 'int64', value = 2^63 - 513, res = '0880f8ffffffffffff7f'},
+    {type = 'int64', value = 2LL^63 - 1, res = '08ffffffffffffffff7f'},
+    {type = 'int64', value = 2ULL^63 - 1, res = '08ffffffffffffffff7f'},
+    {type = 'sint64', value = 2^63 - 513, res = '0880f0ffffffffffffff01'},
+    {type = 'sint64', value = 2LL^63 - 1, res = '08feffffffffffffffff01'},
+    {type = 'sint64', value = 2ULL^63 - 1, res = '08feffffffffffffffff01'},
+    {type = 'uint64', value = 2^64 - 1025, res = '0880f0ffffffffffffff01'},
+    {type = 'uint64', value = 2ULL^64 - 1, res = '08ffffffffffffffffff01'},
+    {type = 'bool', value = true, res = '0801'},
+    {type = 'float', value = 0x1.fffffep+127, res = '0dffff7f7f'},
+    {type = 'double', value = 0x1.fffffffffffffp+1023,
+        res = '09ffffffffffffef7f'},
+    {type = 'fixed32', value = 2^32 - 1, res = '0dffffffff'},
+    {type = 'fixed32', value = 2LL^32 - 1, res = '0dffffffff'},
+    {type = 'fixed32', value = 2ULL^32 - 1, res = '0dffffffff'},
+    {type = 'sfixed32', value = 2^31 - 1, res = '0dffffff7f'},
+    {type = 'sfixed32', value = 2LL^31 - 1, res = '0dffffff7f'},
+    {type = 'sfixed32', value = 2ULL^31 - 1, res = '0dffffff7f'},
+    {type = 'fixed64', value = 2^64 - 1025, res = '0900f8ffffffffffff'},
+    {type = 'fixed64', value = 2ULL^64 - 1, res = '09ffffffffffffffff'},
+    {type = 'sfixed64', value = 2^63 - 513, res = '0900fcffffffffff7f'},
+    {type = 'sfixed64', value = 2LL^63 - 1, res = '09ffffffffffffff7f'},
+    {type = 'sfixed64', value = 2ULL^63 - 1, res = '09ffffffffffffff7f'},
+})
+
+p.test_upper_limit = function(cg)
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {cg.params.type, 1}})
+    })
+    local result = protocol:encode('test', {val = cg.params.value})
+    t.assert_equals(string.hex(result), cg.params.res)
+end
+
+-- Smallest accepted value of every scalar type. For unsigned types and
+-- bool the smallest value is also the default one, so an empty output
+-- is expected for them.
+p = t.group('lower_limit', {
+    {type = 'int32', value = -2^31, res = '0880808080f8ffffffff01'},
+    {type = 'int32', value = -2LL^31, res = '0880808080f8ffffffff01'},
+    {type = 'sint32', value = -2^31, res = '08ffffffff0f'},
+    {type = 'sint32', value = -2LL^31, res = '08ffffffff0f'},
+    {type = 'uint32', value = 0, res = ''},
+    {type = 'uint32', value = 0LL, res = ''},
+    {type = 'uint32', value = 0ULL, res = ''},
+    {type = 'int64', value = -2^63, res = '0880808080808080808001'},
+    {type = 'int64', value = -2LL^63, res = '0880808080808080808001'},
+    {type = 'sint64', value = -2^63, res = '08ffffffffffffffffff01'},
+    {type = 'sint64', value = -2LL^63, res = '08ffffffffffffffffff01'},
+    {type = 'uint64', value = 0, res = ''},
+    {type = 'uint64', value = 0LL, res = ''},
+    {type = 'uint64', value = 0ULL, res = ''},
+    {type = 'bool', value = false, res = ''},
+    {type = 'float', value = -0x1.fffffep+127, res = '0dffff7fff'},
+    {type = 'double', value = -0x1.fffffffffffffp+1023,
+        res = '09ffffffffffffefff'},
+    {type = 'fixed32', value = 0, res = ''},
+    {type = 'fixed32', value = 0LL, res = ''},
+    {type = 'fixed32', value = 0ULL, res = ''},
+    {type = 'sfixed32', value = -2^31, res = '0d00000080'},
+    {type = 'sfixed32', value = -2LL^31, res = '0d00000080'},
+    {type = 'fixed64', value = 0, res = ''},
+    {type = 'fixed64', value = 0LL, res = ''},
+    {type = 'fixed64', value = 0ULL, res = ''},
+    {type = 'sfixed64', value = -2^63, res = '090000000000000080'},
+    {type = 'sfixed64', value = -2LL^63, res = '090000000000000080'},
+})
+
+p.test_lower_limit = function(cg)
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {cg.params.type, 1}})
+    })
+    local result = protocol:encode('test', {val = cg.params.value})
+    t.assert_equals(string.hex(result), cg.params.res)
+end
+
+-- Every integer type is combined with a fractional Lua number and with
+-- a float cdata value; both must be rejected as non-integer input.
+p = t.group('exception_input_data_float', t.helpers.matrix({
+    type = {'int32', 'sint32', 'uint32', 'int64', 'sint64', 'uint64',
+            'fixed32', 'sfixed32', 'fixed64', 'sfixed64'},
+    arg = {
+        {value = 1.5, msg = 'Input number value 1.500000 for ' ..
+            '"val" is not integer'},
+        {value = ffi.cast('float', 1.5), msg = 'Input cdata value ' ..
+            '"ctype<float>" for "val" field is not integer'},
+    }
+}))
+
+p.test_exception_input_data_float = function(cg)
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {cg.params.type, 1}})
+    })
+    local data = {val = cg.params.arg.value}
+    t.assert_error_msg_contains(cg.params.arg.msg, protocol.encode,
+                                protocol, 'test', data)
+end
+
+-- A string passed for a numeric or bool field must be rejected.
+p = t.group('exception_input_data_wrong_type', t.helpers.matrix({
+    type = {'int32', 'sint32', 'uint32', 'int64', 'sint64', 'uint64',
+            'fixed32', 'sfixed32', 'fixed64', 'sfixed64', 'bool', 'float',
+            'double'},
+}))
+
+p.test_exception_input_data_wrong_type = function(cg)
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {cg.params.type, 1}})
+    })
+    local msg = 'Field "val" of "' .. cg.params.type .. '" type gets ' ..
+        '"string" type value.'
+    local data = {val = 'str'}
+    t.assert_error_msg_contains(msg, protocol.encode, protocol, 'test', data)
+end
+
+-- Floating point fields must reject cdata input, even a float cdata.
+p = t.group('exception_cdata_input_for_float_field', t.helpers.matrix({
+    type = {'float', 'double'},
+}))
+
+p.test_exception_cdata_input_for_float_field = function(cg)
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {cg.params.type, 1}})
+    })
+    local msg = 'Field "val" of "' .. cg.params.type .. '" type gets ' ..
+        '"cdata" type value.'
+    local data = {val = ffi.cast('float', 1.5)}
+    t.assert_error_msg_contains(msg, protocol.encode, protocol, 'test', data)
+end
+
+-- NaN and both infinities must be rejected by every numeric type.
+p = t.group('exception_input_data_Nan', t.helpers.matrix({
+    type = {'int32', 'sint32', 'uint32', 'int64', 'sint64', 'uint64',
+            'fixed32', 'sfixed32', 'fixed64', 'sfixed64', 'float', 'double'},
+    arg = {
+        {value = 0/0, msg = 'Input data for "val" field is NaN'},
+        {value = 1/0, msg = 'Input data for "val" field is inf'},
+        {value = -1/0, msg = 'Input data for "val" field is inf'},
+    }
+}))
+
+p.test_exception_input_data_Nan = function(cg)
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {cg.params.type, 1}})
+    })
+    local data = {val = cg.params.arg.value}
+    t.assert_error_msg_contains(cg.params.arg.msg, protocol.encode,
+                                protocol, 'test', data)
+end
+
+-- Values just outside the accepted range of each type. 'res' is the
+-- text of the value as it is expected to appear in the error message.
+p = t.group('exception_input_data_out_of_range', {
+    {type = 'int32', value = 2^31, res = '2147483648'},
+    {type = 'int32', value = 2LL^31, res = '2147483648LL'},
+    {type = 'int32', value = 2ULL^31, res = '2147483648ULL'},
+    {type = 'int32', value = -2^31 - 1, res = '-2147483649'},
+    {type = 'int32', value = -2LL^31 - 1, res = '-2147483649LL'},
+    {type = 'sint32', value = 2^31, res = '2147483648'},
+    {type = 'sint32', value = 2LL^31, res = '2147483648LL'},
+    {type = 'sint32', value = 2ULL^31, res = '2147483648ULL'},
+    {type = 'sint32', value = -2^31 - 1, res = '-2147483649'},
+    {type = 'sint32', value = -2LL^31 - 1, res = '-2147483649LL'},
+    {type = 'uint32', value = 2^32, res = '4294967296'},
+    {type = 'uint32', value = 2LL^32, res = '4294967296LL'},
+    {type = 'uint32', value = 2ULL^32, res = '4294967296ULL'},
+    {type = 'uint32', value = -1, res = '-1'},
+    {type = 'uint32', value = -1LL, res = '-1LL'},
+    {type = 'int64', value = 2^63 - 512, res = '9.2233720368548e+18'},
+    {type = 'int64', value = 2ULL^63, res = '9223372036854775808ULL'},
+    {type = 'int64', value = -2^63 - 1025, res = '-9.2233720368548e+18'},
+    {type = 'sint64', value = 2^63 - 512, res = '9.2233720368548e+18'},
+    {type = 'sint64', value = 2ULL^63, res = '9223372036854775808ULL'},
+    {type = 'sint64', value = -2^63 - 1025, res = '-9.2233720368548e+18'},
+    {type = 'uint64', value = -1, res = '-1'},
+    {type = 'uint64', value = -1LL, res = '-1LL'},
+    {type = 'float', value = 0x1.fffffe018d3f8p+127, res = '3.402823467e+38'},
+    {type = 'fixed32', value = 2^32, res = '4294967296'},
+    {type = 'fixed32', value = 2LL^32, res = '4294967296LL'},
+    {type = 'fixed32', value = 2ULL^32, res = '4294967296ULL'},
+    {type = 'fixed32', value = -1, res = '-1'},
+    {type = 'fixed32', value = -1LL, res = '-1LL'},
+    {type = 'sfixed32', value = 2^31, res = '2147483648'},
+    {type = 'sfixed32', value = 2LL^31, res = '2147483648LL'},
+    {type = 'sfixed32', value = 2ULL^31, res = '2147483648ULL'},
+    {type = 'sfixed32', value = -2^31 - 1, res = '-2147483649'},
+    {type = 'sfixed32', value = -2LL^31 - 1, res = '-2147483649LL'},
+    {type = 'fixed64', value = 2^64, res = '1.844674407371e+19'},
+    {type = 'fixed64', value = -1, res = '-1'},
+    {type = 'fixed64', value = -1LL, res = '-1LL'},
+    {type = 'sfixed64', value = 2^63 - 512, res = '9.2233720368548e+18'},
+    {type = 'sfixed64', value = 2ULL^63, res = '9223372036854775808ULL'},
+    {type = 'sfixed64', value = -2^63 - 1025, res = '-9.2233720368548e+18'},
+})
+
+p.test_exception_input_data_out_of_range = function(cg)
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {cg.params.type, 1}})
+    })
+    local msg = 'Input data for "val" field is "' .. cg.params.res ..
+        '" and do not fit in "' .. cg.params.type .. '"'
+    local data = {val = cg.params.value}
+    t.assert_error_msg_contains(msg, protocol.encode, protocol, 'test', data)
+end
+
+-- Ordinary values for signed types. Each 'res' row carries a map from
+-- tostring(value) to the expected hex dump, so one row serves all the
+-- values of the matrix.
+p = t.group('regular_signed_values', t.helpers.matrix({
+    value = {1540, -770, -10LL, 10ULL},
+    res = {
+        {type = 'int32', code = {['1540'] = '08840c',
+            ['-770'] = '08fef9ffffffffffffff01', ['10ULL'] = '080a',
+            ['-10LL'] = '08f6ffffffffffffffff01'}},
+        {type = 'sint32', code = {['1540'] = '088818', ['-770'] = '08830c',
+            ['10ULL'] = '0814', ['-10LL'] = '0813'}},
+        {type = 'int64', code = {['1540'] = '08840c',
+            ['-770'] = '08fef9ffffffffffffff01', ['10ULL'] = '080a',
+            ['-10LL'] = '08f6ffffffffffffffff01'}},
+        {type = 'sint64', code = {['1540'] = '088818', ['-770'] = '08830c',
+            ['10ULL'] = '0814', ['-10LL'] = '0813'}},
+        {type = 'sfixed32', code = {['1540'] = '0d04060000',
+            ['-770'] = '0dfefcffff', ['10ULL'] = '0d0a000000',
+            ['-10LL'] = '0df6ffffff'}},
+        {type = 'sfixed64', code = {['1540'] = '090406000000000000',
+            ['-770'] = '09fefcffffffffffff', ['10ULL'] = '090a00000000000000',
+            ['-10LL'] = '09f6ffffffffffffff'}},
+    },
+}))
+
+p.test_regular_signed_values = function(cg)
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {cg.params.res.type, 1}})
+    })
+    local result = protocol:encode('test', {val = cg.params.value})
+    t.assert_equals(string.hex(result),
+                    cg.params.res.code[tostring(cg.params.value)])
+end
+
+-- Ordinary values for unsigned types, same layout as the signed group.
+p = t.group('regular_unsigned_values', t.helpers.matrix({
+    value = {1540, 10LL, 15ULL},
+    res = {
+        {type = 'uint32', code = {['1540'] = '08840c', ['10LL'] = '080a',
+            ['15ULL'] = '080f'}},
+        {type = 'uint64', code = {['1540'] = '08840c', ['10LL'] = '080a',
+            ['15ULL'] = '080f'}},
+        {type = 'fixed32', code = {['1540'] = '0d04060000',
+            ['10LL'] = '0d0a000000', ['15ULL'] = '0d0f000000'}},
+        {type = 'fixed64', code = {['1540'] = '090406000000000000',
+            ['10LL'] = '090a00000000000000',
+            ['15ULL'] = '090f00000000000000'}},
+    },
+}))
+
+p.test_regular_unsigned_values = function(cg)
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {cg.params.res.type, 1}})
+    })
+    local result = protocol:encode('test', {val = cg.params.value})
+    t.assert_equals(string.hex(result),
+                    cg.params.res.code[tostring(cg.params.value)])
+end
+
+-- Ordinary values for floating point types, same layout as above.
+p = t.group('regular_floating_point_values', t.helpers.matrix({
+    value = {1.5, -1.5},
+    res = {
+        {type = 'float', code = {['1.5'] = '0d0000c03f',
+            ['-1.5'] = '0d0000c0bf'}},
+        {type = 'double', code = {['1.5'] = '09000000000000f83f',
+            ['-1.5'] = '09000000000000f8bf'}},
+    },
+}))
+
+p.test_regular_floating_point_values = function(cg)
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {cg.params.res.type, 1}})
+    })
+    local result = protocol:encode('test', {val = cg.params.value})
+    t.assert_equals(string.hex(result),
+                    cg.params.res.code[tostring(cg.params.value)])
+end
+
+-- Zero is the default value of every numeric type and is expected to
+-- produce an empty output.
+p = t.group('numeric_types_default_value_encoding', t.helpers.matrix({
+    type = {'int32', 'sint32', 'uint32', 'int64', 'sint64', 'uint64',
+            'fixed32', 'sfixed32', 'fixed64', 'sfixed64', 'float', 'double'},
+    value = {0}
+}))
+
+p.test_numeric_types_default_value_encoding = function(cg)
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {cg.params.type, 1}})
+    })
+    local result = protocol:encode('test', {val = cg.params.value})
+    t.assert_equals(string.hex(result), '')
+end
+
+-- Default values of the non-numeric scalar types (false and the empty
+-- string) are expected to produce an empty output as well.
+p = t.group('other_types_default_value_encoding', {
+    {type = 'bool', value = false},
+    {type = 'string', value = ''},
+    {type = 'bytes', value = ''},
+})
+
+p.test_other_types_default_value_encoding = function(cg)
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {cg.params.type, 1}})
+    })
+    local result = protocol:encode('test', {val = cg.params.value})
+    t.assert_equals(string.hex(result), '')
+end
diff --git a/test/app-luatest/protobuf_repeated_test.lua
b/test/app-luatest/protobuf_repeated_test.lua
new file mode 100644
index 0000000000000000000000000000000000000000..c5572f4b98219139ac2843c60fb5947a8d1d1b07
--- /dev/null
+++ b/test/app-luatest/protobuf_repeated_test.lua
@@ -0,0 +1,115 @@
+local t = require('luatest')
+local protobuf = require('protobuf')
+local g = t.group()
+
+-- Repeated fields of numeric and bool types. The whole array is
+-- expected as a single length-delimited field: tag '0a', payload
+-- length, then the values one after another.
+local p = t.group('packed_repeated_encoding', {
+    {type = 'repeated float', value = {0.5, 1.5},
+        res = '0a080000003f0000c03f'},
+    {type = 'repeated double', value = {0.5, 1.5},
+        res = '0a10000000000000e03f000000000000f83f'},
+    {type = 'repeated int32', value = {1, 2, 3, 4}, res = '0a0401020304'},
+    {type = 'repeated sint32', value = {1, 2, 3, 4}, res = '0a0402040608'},
+    {type = 'repeated uint32', value = {1, 2, 3, 4}, res = '0a0401020304'},
+    {type = 'repeated int64', value = {1, 2, 3, 4}, res = '0a0401020304'},
+    {type = 'repeated sint64', value = {1, 2, 3, 4}, res = '0a0402040608'},
+    {type = 'repeated uint64', value = {1, 2, 3, 4}, res = '0a0401020304'},
+    {type = 'repeated fixed32', value = {1, 2}, res = '0a080100000002000000'},
+    {type = 'repeated sfixed32', value = {1, 2}, res = '0a080100000002000000'},
+    {type = 'repeated fixed64', value = {1, 2},
+        res = '0a1001000000000000000200000000000000'},
+    {type = 'repeated sfixed64', value = {1, 2},
+        res = '0a1001000000000000000200000000000000'},
+    {type = 'repeated bool', value = {true, true}, res = '0a020101'},
+})
+
+p.test_packed_repeated_encoding = function(cg)
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {cg.params.type, 1}})
+    })
+    local result = protocol:encode('test', {val = cg.params.value})
+    t.assert_equals(string.hex(result), cg.params.res)
+end
+
+-- Same as above, but with a field id (1000) whose tag does not fit in
+-- one byte ('c23e').
+g.test_packed_repeated_int32_long_tag = function()
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {'repeated int32', 1000}})
+    })
+    local result = protocol:encode('test', {val = {1, 2, 3, 4}})
+    t.assert_equals(string.hex(result), 'c23e0401020304')
+end
+
+-- A default value (0) inside a repeated field must be rejected.
+g.test_exception_repeated_int32_with_default_value = function()
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {'repeated int32', 1}})
+    })
+    local msg = 'Input for "val" repeated field contains default value ' ..
+        'can`t be encoded correctly'
+    local data = {val = {1, 0, 0, 4}}
+    t.assert_error_msg_contains(msg, protocol.encode,
+                                protocol, 'test', data)
+end
+
+-- Repeated bytes/string fields: every element is expected as a
+-- separate length-delimited field with its own '0a' tag.
+p = t.group('non_packed_repeated_encoding', {
+    {type = 'repeated bytes', value = {'fuz', 'buz'},
+        res = '0a0366757a0a0362757a'},
+    {type = 'repeated string', value = {'fuz', 'buz'},
+        res = '0a0366757a0a0362757a'},
+})
+
+p.test_non_packed_repeated_encoding = function(cg)
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {cg.params.type, 1}})
+    })
+    local result = protocol:encode('test', {val = cg.params.value})
+    t.assert_equals(string.hex(result), cg.params.res)
+end
+
+-- Repeated message field: every element is expected as a separate
+-- nested message.
+-- NOTE(review): the expected dump has "name" (id 2) before "id" (id 1)
+-- inside each element, so it presumably depends on the traversal order
+-- of the input table - confirm that this order is stable.
+g.test_repeated_message = function()
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {'repeated field', 1}}),
+        protobuf.message('field', {id = {'int32', 1}, name = {'string', 2}})
+    })
+    local data = {val = {{id = 1, name = 'fuz'}, {id = 2, name = 'buz'}}}
+    local proto_res = '0a07120366757a08010a07120362757a0802'
+    local result = protocol:encode('test', data)
+    t.assert_equals(string.hex(result), proto_res)
+end
+
+-- Repeated enum field: unlike the packed numeric cases above, every
+-- element is expected with its own tag ('0801', '0801', '0802').
+g.test_repeated_enum = function()
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {'repeated field', 1}}),
+        protobuf.enum('field', {Default = 0, True = 1, False = 2})
+    })
+    local data = {val = {'True', 'True', 'False'}}
+    local proto_res = '080108010802'
+    local result = protocol:encode('test', data)
+    t.assert_equals(string.hex(result), proto_res)
+end
+
+-- Malformed input for a 'repeated int32' field: a wrong element type,
+-- a non-table value and tables that are not plain arrays starting at 1
+-- (string or fractional keys, holes, box.NULL elements).
+p = t.group('exceptions_repeated_encoding', {
+    {value = {1, 'fuz'}, msg = 'Field "val" of "int32" type ' ..
+        'gets "string" type value.'},
+    {value = 12, msg = 'For repeated fields table data are needed'},
+    {value = {1, fuz = 2, 3}, msg = 'Input array for "val" repeated ' ..
+        'field contains non-numeric key: "fuz"'},
+    {value = {1, [0.5] = 2, 3}, msg = 'Input array for "val" repeated ' ..
+        'field contains non-integer numeric key: "0.5"'},
+    {value = {[2] = 2, [3] = 3, [4] = 4}, msg = 'Input array for "val" ' ..
+        'repeated field got min index 2. Must be 1'},
+    {value = {[1] = 1, [3] = 3, [4] = 4}, msg = 'Input array for "val" ' ..
+        'repeated field has inconsistent keys. Got table with 3 fields ' ..
+        'and max index of 4'},
+    {value = {1, nil, 2}, msg = 'Input array for "val" repeated field ' ..
+        'has inconsistent keys. Got table with 2 fields and max index of 3'},
+    {value = {1, box.NULL, 2}, msg = 'Input array for "val" repeated ' ..
+        'field contains box.NULL value which leads to ambiguous behaviour'},
+})
+
+p.test_exceptions_repeated_encoding = function(cg)
+    local protocol = protobuf.protocol({
+        protobuf.message('test', {val = {'repeated int32', 1}})
+    })
+    local data = {val = cg.params.value}
+    t.assert_error_msg_contains(cg.params.msg, protocol.encode,
+                                protocol, 'test', data)
+end