diff --git a/changelogs/unreleased/gh-8928-net-box-from-fd.md b/changelogs/unreleased/gh-8928-net-box-from-fd.md new file mode 100644 index 0000000000000000000000000000000000000000..cac6624c9d6eed17b88dc2f30d6ff63167629146 --- /dev/null +++ b/changelogs/unreleased/gh-8928-net-box-from-fd.md @@ -0,0 +1,4 @@ +## feature/lua/netbox + +* Introduced the new `net.box` module function `from_fd` for creating a new + connection from a socket file descriptor number (gh-8984). diff --git a/src/box/lua/net_box.c b/src/box/lua/net_box.c index c9c1ae61794ecbfb4fa307c3c17347a69d80d180..7cba024bfd6c14cbefa0fa3fd94bb602ec29f6d9 100644 --- a/src/box/lua/net_box.c +++ b/src/box/lua/net_box.c @@ -59,6 +59,7 @@ #include "lua/fiber_cond.h" #include "lua/msgpack.h" #include "lua/uri.h" +#include "lua/utils.h" #include "msgpuck.h" #include "small/ibuf.h" #include "small/region.h" @@ -151,8 +152,10 @@ static const char *netbox_state_str[] = { }; struct netbox_options { - /** Remote server URI. */ + /** Remote server URI. Nil if this connection was created from fd. */ struct uri uri; + /** Connection fd. -1 if this connection was created from URI. */ + int fd; /** Authentication method. NULL if unspecified. */ const struct auth_method *auth_method; /** User credentials. */ @@ -487,6 +490,7 @@ netbox_options_create(struct netbox_options *opts) { memset(opts, 0, sizeof(*opts)); uri_create(&opts->uri, NULL); + opts->fd = -1; opts->auth_method = NULL; opts->callback_ref = LUA_NOREF; opts->connect_timeout = NETBOX_DEFAULT_CONNECT_TIMEOUT; @@ -1043,16 +1047,23 @@ netbox_transport_connect(struct netbox_transport *transport) assert(!iostream_is_initialized(io)); ev_tstamp start, delay; coio_timeout_init(&start, &delay, transport->opts.connect_timeout); - int fd = coio_connect_timeout(transport->opts.uri.host, - transport->opts.uri.service, - transport->opts.uri.host_hint, - /*addr=*/NULL, /*addr_len=*/NULL, delay); - coio_timeout_update(&start, &delay); - if (fd < 0) - goto io_error; - if (iostream_create(io, fd, &transport->io_ctx) != 0) { - close(fd); - goto error; + int fd = transport->opts.fd; + if (fd >= 0) { + plain_iostream_create(io, fd); + } else { + assert(!uri_is_nil(&transport->opts.uri)); + fd = coio_connect_timeout(transport->opts.uri.host, + transport->opts.uri.service, + transport->opts.uri.host_hint, + /*addr=*/NULL, /*addr_len=*/NULL, + delay); + coio_timeout_update(&start, &delay); + if (fd < 0) + goto io_error; + if (iostream_create(io, fd, &transport->io_ctx) != 0) { + close(fd); + goto error; + } } char greetingbuf[IPROTO_GREETING_SIZE]; if (coio_readn_timeout(io, greetingbuf, IPROTO_GREETING_SIZE, @@ -2205,7 +2216,7 @@ luaT_netbox_request_pairs(struct lua_State *L) /** * Creates a netbox transport object (userdata) and pushes it to Lua stack. - * Takes the following arguments: uri (string, number, or table), + * Takes the following arguments: uri (string or table) or fd (number), * user (string or nil), password (string or nil), callback (function), * connect_timeout (number or nil), reconnect_after (number or nil), * fetch_schema (boolean or nil), auth_type (string or nil). @@ -2222,8 +2233,20 @@ luaT_netbox_new_transport(struct lua_State *L) lua_setmetatable(L, -2); /* Initialize options from Lua arguments. */ struct netbox_options *opts = &transport->opts; - if (luaT_uri_create(L, 1, &opts->uri) != 0) - return luaT_error(L); + if (lua_type(L, 1) == LUA_TNUMBER) { + if (!luaL_tointeger_strict(L, 1, &opts->fd) || opts->fd < 0) { + diag_set(IllegalParams, + "Invalid fd: expected nonnegative integer"); + return luaT_error(L); + } + } else { + if (luaT_uri_create(L, 1, &opts->uri) != 0) + return luaT_error(L); + if (iostream_ctx_create(&transport->io_ctx, IOSTREAM_CLIENT, + &opts->uri) != 0) { + return luaT_error(L); + } + } if (!lua_isnil(L, 2)) opts->user = xstrdup(luaL_checkstring(L, 2)); if (!lua_isnil(L, 3)) @@ -2252,10 +2275,6 @@ luaT_netbox_new_transport(struct lua_State *L) "net.box: user is not defined"); return luaT_error(L); } - if (iostream_ctx_create(&transport->io_ctx, IOSTREAM_CLIENT, - &opts->uri) != 0) { - return luaT_error(L); - } return 1; } @@ -2927,7 +2946,8 @@ netbox_worker_f(va_list ap) */ assert(fiber()->storage.lua.stack == NULL); fiber()->storage.lua.stack = L; - const double reconnect_after = transport->opts.reconnect_after; + const double reconnect_after = !uri_is_nil(&transport->opts.uri) ? + transport->opts.reconnect_after : 0; while (!fiber_is_cancelled()) { if (netbox_transport_connect(transport) == 0) { int rc = luaT_cpcall(L, netbox_connection_handler_f, @@ -2973,9 +2993,15 @@ luaT_netbox_transport_start(struct lua_State *L) struct lua_State *fiber_L = lua_newthread(L); transport->coro_ref = luaL_ref(L, LUA_REGISTRYINDEX); transport->self_ref = luaL_ref(L, LUA_REGISTRYINDEX); - const char *name = tt_sprintf("%s:%s (net.box)", - transport->opts.uri.host ?: "", - transport->opts.uri.service ?: ""); + const char *name; + if (!uri_is_nil(&transport->opts.uri)) { + name = tt_sprintf("%s:%s (net.box)", + transport->opts.uri.host ?: "", + transport->opts.uri.service ?: ""); + } else { + assert(transport->opts.fd >= 0); + name = tt_sprintf("fd=%d (net.box)", transport->opts.fd); + } transport->worker = fiber_new_system(name, netbox_worker_f); if (transport->worker == NULL) { luaL_unref(L, LUA_REGISTRYINDEX, transport->coro_ref); diff --git a/src/box/lua/net_box.lua b/src/box/lua/net_box.lua index d3387ee35dc5f975dc48ab361affb458f899c411..acb22aa8bef2a9dab052ae263d844cceed08c251 100644 --- a/src/box/lua/net_box.lua +++ b/src/box/lua/net_box.lua @@ -17,6 +17,7 @@ local check_select_opts = box.internal.check_select_opts local check_index_arg = box.internal.check_index_arg local check_space_arg = box.internal.check_space_arg local check_primary_index = box.internal.check_primary_index +local check_param = utils.check_param local check_param_table = utils.check_param_table local ibuf_t = ffi.typeof('struct ibuf') @@ -117,7 +118,11 @@ local function parse_connect_params(host_or_uri, ...) -- self? host_or_uri port? if port == nil and (type(host_or_uri) == 'string' or type(host_or_uri) == 'number' or type(host_or_uri) == 'table') then - uri = host_or_uri + if type(host_or_uri) == 'number' then + uri = tostring(host_or_uri) + else + uri = host_or_uri + end elseif (type(host_or_uri) == 'string' or host_or_uri == nil) and (type(port) == 'string' or type(port) == 'number') then uri = urilib.format({host = host_or_uri, service = tostring(port)}) @@ -132,6 +137,7 @@ local function remote_serialize(self) return { host = self.host, port = self.port, + fd = self.fd, opts = next(self.opts) and self.opts, state = self.state, error = self.error, @@ -237,22 +243,34 @@ end local space_metatable, index_metatable -local function new_sm(uri, opts) - local parsed_uri, err = urilib.parse(uri) - if not parsed_uri then - error(err) - end - if opts.user == nil and opts.password == nil then - opts.user, opts.password = parsed_uri.login, parsed_uri.password - end - if opts.auth_type == nil and parsed_uri.params ~= nil and - parsed_uri.params.auth_type ~= nil then - opts.auth_type = parsed_uri.params.auth_type[1] +local function new_sm(uri_or_fd, opts) + local host, port, fd + if type(uri_or_fd) == 'string' or type(uri_or_fd) == 'table' then + local parsed_uri, err = urilib.parse(uri_or_fd) + if not parsed_uri then + error(err) + end + if opts.user == nil and opts.password == nil then + opts.user, opts.password = parsed_uri.login, parsed_uri.password + end + if opts.auth_type == nil and parsed_uri.params ~= nil and + parsed_uri.params.auth_type ~= nil then + opts.auth_type = parsed_uri.params.auth_type[1] + end + host, port = parsed_uri.host, parsed_uri.service + else + assert(type(uri_or_fd) == 'number') + fd = uri_or_fd end - local host, port = parsed_uri.host, parsed_uri.service local user, password = opts.user, opts.password; opts.password = nil local last_reconnect_error - local remote = {host = host, port = port, opts = opts, state = 'initial'} + local remote = { + host = host, + port = port, + fd = fd, + opts = opts, + state = 'initial', + } local function callback(what, ...) if remote._fiber == nil then remote._fiber = fiber.self() @@ -421,7 +439,7 @@ local function new_sm(uri, opts) end remote._callback = callback local transport = internal.new_transport( - uri, user, password, weak_callback, + uri_or_fd, user, password, weak_callback, opts.connect_timeout, opts.reconnect_after, opts.fetch_schema, opts.auth_type) weak_refs.transport = transport @@ -460,6 +478,18 @@ local function connect(...) return new_sm(uri, opts) end +-- +-- Create a connection from a file descriptor number. +-- +-- The file descriptor should point to a socket and be switched to +-- the non-blocking mode. +-- +local function from_fd(fd, opts) + check_param(fd, 'fd', 'number') + check_param_table(opts, CONNECT_OPTION_TYPES) + return new_sm(fd, opts or {}) +end + local function check_remote_arg(remote, method) if type(remote) ~= 'table' then local fmt = 'Use remote:%s(...) instead of remote.%s(...):' @@ -1233,6 +1263,7 @@ end this_module = { connect = connect, new = connect, -- Tarantool < 1.7.1 compatibility, + from_fd = from_fd, } function this_module.timeout(timeout, ...) diff --git a/test/box-luatest/gh_8928_net_box_from_fd_test.lua b/test/box-luatest/gh_8928_net_box_from_fd_test.lua new file mode 100644 index 0000000000000000000000000000000000000000..0995121d7041736791185c4d22be4912572dc72b --- /dev/null +++ b/test/box-luatest/gh_8928_net_box_from_fd_test.lua @@ -0,0 +1,144 @@ +local fiber = require('fiber') +local net = require('net.box') +local server = require('luatest.server') +local socket = require('socket') +local yaml = require('yaml') +local t = require('luatest') + +local g = t.group() + +local function serialize(x) + return yaml.decode(yaml.encode(x)) +end + +g.before_all(function(cg) + cg.server = server:new() + cg.server:start() + cg.server:exec(function() + box.session.su('admin', function() + box.schema.user.create('alice', {password = 'secret'}) + box.schema.user.grant('alice', 'super') + end) + end) + cg.connect = function() + return socket.tcp_connect('unix/', cg.server.net_box_uri) + end +end) + +g.after_all(function(cg) + cg.server:drop() +end) + +-- Checks errors raised on invalid arguments. +g.test_invalid_args = function() + t.assert_error_msg_equals( + "Illegal parameters, fd should be a number", + net.from_fd) + t.assert_error_msg_equals( + "Illegal parameters, options should be a table", + net.from_fd, 0, 0) + t.assert_error_msg_equals( + "Illegal parameters, unexpected option 'foo'", + net.from_fd, 0, {foo = 'bar'}) + t.assert_error_msg_equals( + "Illegal parameters, " .. + "options parameter 'user' should be of type string", + net.from_fd, 0, {user = 123}) + t.assert_error_msg_equals( + "Illegal parameters, " .. + "options parameter 'fetch_schema' should be of type boolean", + net.from_fd, 0, {fetch_schema = 'foo'}) + t.assert_error_msg_equals( + "Invalid fd: expected nonnegative integer", + net.from_fd, -1) + t.assert_error_msg_equals( + "Invalid fd: expected nonnegative integer", + net.from_fd, 2 ^ 31) + t.assert_error_msg_equals( + "Invalid fd: expected nonnegative integer", + net.from_fd, 1.5) +end + +-- Checks basic functionality. +g.test_basic = function(cg) + local s = cg.connect() + local fd = s:fd() + local c = net.from_fd(fd) + s:detach() + t.assert_equals(c.state, 'active') + + t.assert_equals(c.fd, fd) + t.assert_is(c.host, nil) + t.assert_is(c.port, nil) + + local v = serialize(c) + t.assert_equals(v.fd, fd) + t.assert_is(v.host, nil) + t.assert_is(v.port, nil) + + t.assert_equals(c:call('box.session.type'), 'binary') + t.assert_equals(c:call('box.session.user'), 'guest') + t.assert_equals(c:call('box.session.peer'), 'unix/:(socket)') + + local c2 = net.connect(cg.server.net_box_uri) + local v2 = serialize(c2) + + v.fd = nil + v2.host = nil + v2.port = nil + t.assert_equals(v, v2) + + c2:close() + c:close() +end + +-- Checks that fd is closed with connection. +g.test_fd_closed = function(cg) + local s = cg.connect() + t.assert_covers(s:name(), {type = 'SOCK_STREAM'}) + local c = net.from_fd(s:fd()) + t.assert_equals(c.state, 'active') + c:close() + t.helpers.retrying({}, function() + t.assert_is(s:name(), nil) + end) + t.assert_not(s:close()) +end + +-- Checks that authentication works. +g.test_auth = function(cg) + local s = cg.connect() + local c = net.from_fd(s:fd(), {user = 'alice', password = 'secret'}) + s:detach() + t.assert_equals(c.state, 'active') + t.assert_equals(c:call('box.session.user'), 'alice') + c:close() +end + +g.after_test('test_reconnect', function() + box.error.injection.set('ERRINJ_NETBOX_IO_ERROR', false) +end) + +-- Checks that reconnect is disabled. +g.test_reconnect = function(cg) + t.tarantool.skip_if_not_debug() + local s = cg.connect() + local c = net.from_fd(s:fd(), {reconnect_after = 0.1}) + s:detach() + t.assert(c:ping()) + box.error.injection.set('ERRINJ_NETBOX_IO_ERROR', true) + t.assert_not(c:ping()) + t.assert_equals(c.state, 'error') + box.error.injection.set('ERRINJ_NETBOX_IO_ERROR', false) + fiber.sleep(0.1) + t.assert_equals(c.state, 'error') + c:close() +end + +-- Passing an invalid fd. +g.test_invalid_fd = function() + local c = net.from_fd(9000) + t.assert_equals(c.state, 'error') + t.assert_str_contains(c.error, 'Bad file descriptor') + c:close() +end