diff --git a/src/lua/bsdsocket.cc b/src/lua/bsdsocket.cc index 9e1e1693ad30c545036faeb8955df280d6bf9fe0..422571af9aef4323fe2188c44657682353a013c4 100644 --- a/src/lua/bsdsocket.cc +++ b/src/lua/bsdsocket.cc @@ -336,8 +336,8 @@ bsdsocket_local_resolve(const char *host, const char *port, } /* IPv6 */ - char ipv6[16]; - if (inet_pton(AF_INET6, host, ipv6) == 1) { + struct in6_addr ipv6; + if (inet_pton(AF_INET6, host, &ipv6) == 1) { struct sockaddr_in6 *inaddr6 = (struct sockaddr_in6 *) addr; if (*socklen < sizeof(*inaddr6)) { errno = ENOBUFS; @@ -345,8 +345,8 @@ bsdsocket_local_resolve(const char *host, const char *port, } memset(inaddr6, 0, sizeof(*inaddr6)); inaddr6->sin6_family = AF_INET6; - inaddr6->sin6_port = htonl(atol(port)); - memcpy(inaddr6->sin6_addr.s6_addr, ipv6, 16); + inaddr6->sin6_port = htons(atoi(port)); + memcpy(inaddr6->sin6_addr.s6_addr, &ipv6, sizeof(ipv6)); *socklen = sizeof(*inaddr6); return 0; } @@ -804,6 +804,24 @@ lbox_bsdsocket_peername(struct lua_State *L) return 1; } +static int +lbox_bsdsocket_accept(struct lua_State *L) +{ + int fh = lua_tointeger(L, 1); + + struct sockaddr_storage fa; + socklen_t len = sizeof(fa); + + int sc = accept(fh, (struct sockaddr*)&fa, &len); + if (sc < 0) { + lua_pushnil(L); + return 1; + } + lua_pushnumber(L, sc); + lbox_bsdsocket_push_addr(L, (struct sockaddr *)&fa, len); + return 2; +} + static int lbox_bsdsocket_recvfrom(struct lua_State *L) { @@ -860,6 +878,7 @@ tarantool_lua_bsdsocket_init(struct lua_State *L) { "peer", lbox_bsdsocket_peername }, { "recvfrom", lbox_bsdsocket_recvfrom }, { "abort", lbox_bsdsocket_abort }, + { "accept", lbox_bsdsocket_accept }, { NULL, NULL } }; diff --git a/src/lua/bsdsocket.lua b/src/lua/bsdsocket.lua index 871bdbd356a6fdf65f1d67199b3681c9b0a3dffc..e7aa926cda7ada49175a24c5132020c859babdab 100644 --- a/src/lua/bsdsocket.lua +++ b/src/lua/bsdsocket.lua @@ -504,12 +504,12 @@ socket_methods.accept = function(self) local fd = check_socket(self) self._errno = nil - local cfd = ffi.C.accept(fd, nil, nil) - if cfd < 1 then + local cfd, from = internal.accept(fd) + if cfd == nil then self._errno = box.errno() return nil end - return bless_socket(cfd) + return bless_socket(cfd), from end local function readchunk(self, size, timeout) @@ -964,96 +964,80 @@ local function tcp_connect(host, port, timeout) return nil end -local function tcp_server_remote(list, prepare, handler) - local slist = {} - - -- bind/create sockets - for _, addr in pairs(list) do - local s = create_socket(addr.family, addr.type, addr.protocol) - - local ok = false - if s ~= nil then - local backlog = prepare(s) - if s:bind(addr.host, addr.port) then - if s:listen(backlog) then - ok = true - end - end - end - - -- errors - if not ok then - if s ~= nil then - s:close() - end - local save_errno = boxerrno() - for _, s in pairs(slist) do - s:close() - end - boxerrno(save_errno) - return nil - end - - table.insert(slist, s) - end - - local server = { s = slist } +local function tcp_server_handler(server, sc, from) + fiber.name(sprintf("%s/client/%s:%s", server.name, from.host, from.port)) + server.handler(sc, from) + sc:close() +end - server.stop = function() - if #server.s == 0 then - return false +local function tcp_server_loop(server, s, addr) + fiber.name(sprintf("%s/listen/%s:%s", server.name, addr.host, addr.port)) + while s:readable() do + local sc, from = s:accept() + if sc == nil then + break end - for _, s in pairs(server.s) do - s:close() - end - server.s = {} - return true + fiber.create(tcp_server_handler, server, sc, from) end - - for _, s in pairs(server.s) do - fiber.create(function(s) - fiber.name(sprintf("listen_fd=%d",s:fd())) - - while s:readable() do - - local sc = s:accept() - - if sc == nil then - break - end - - fiber.create(function(sc) - pcall(handler, sc) - sc:close() - end, sc) - end - end, s) + if addr.family == 'AF_UNIX' and addr.port then + os.remove(addr.port) -- remove unix socket end - - return server end -local function tcp_server(host, port, prepare, handler, timeout) - if handler == nil then - handler = prepare - prepare = function() end - end +local function tcp_server_usage() + error('Usage: socket.tcp_server(host, port, handler | opts)') +end - if type(prepare) ~= 'function' or type(handler) ~= 'function' then - error("Usage: socket.tcp_server(host, port[, prepare], handler)") +local function tcp_server(host, port, opts, timeout) + local server = {} + if type(opts) == 'function' then + server.handler = opts + elseif type(opts) == 'table' then + if type(opts.handler) ~='function' or (opts.prepare ~= nil and + type(opts.prepare) ~= 'function') then + tcp_server_usage() + end + for k, v in pairs(opts) do + server[k] = v + end + else + tcp_server_usage() end - + server.name = server.name or 'server' + timeout = timeout and tonumber(timeout) or TIMEOUT_INFINITY + local dns if host == 'unix/' then - return tcp_server_remote({{host = host, port = port, protocol = 0, - family = 'PF_UNIX', type = 'SOCK_STREAM' }}, prepare, handler) + dns = {{host = host, port = port, family = 'AF_UNIX', protocol = 0, + type = 'SOCK_STREAM' }} + else + dns = getaddrinfo(host, port, timeout, { type = 'SOCK_STREAM' }) + if dns == nil then + return nil + end end - local dns = getaddrinfo(host, port, timeout, { type = 'SOCK_STREAM', - protocol = 'tcp' }) - if dns == nil then - return nil + for _, addr in ipairs(dns) do + local s = create_socket(addr.family, addr.type, addr.protocol) + if s ~= nil then + local backlog + if server.prepare then + backlog = server.prepare(s) + else + s:setsockopt('SOL_SOCKET', 'SO_REUSEADDR', 1) -- ignore error + end + if not s:bind(addr.host, addr.port) or not s:listen(backlog) then + local save_errno = boxerrno() + s:close() + boxerrno(save_errno) + return nil + end + fiber.create(tcp_server_loop, server, s, addr) + return s, addr + end end - return tcp_server_remote(dns, prepare, handler) + -- DNS resolved successfully, but addresss family is not supported + boxerrno(boxerrno.EAFNOSUPPORT) + return nil end socket_mt = { diff --git a/test/box/bsdsocket.result b/test/box/bsdsocket.result index ccc0973dea38030d9b48b65242cb2c55859ef2ea..8e8fc8adbb0a34515a9a3bd0d4b6d1b039a5a647 100644 --- a/test/box/bsdsocket.result +++ b/test/box/bsdsocket.result @@ -22,6 +22,9 @@ log = require 'log' errno = require 'errno' --- ... +fio = require 'fio' +--- +... type(socket) --- - table @@ -349,8 +352,19 @@ sc:write('Hello, world') --- - true ... -sa = s:accept() +sa, addr = s:accept() +--- +... +addr2 = sa:name() +--- +... +addr2.host == addr.host +--- +- true +... +addr2.family == addr.family --- +- true ... sa:nonblock(1) --- @@ -1209,9 +1223,13 @@ os.remove(path) --- - true ... -server = socket.tcp_server('unix/', path, function(s) s:write('Hello, world') end) +server, addr = socket.tcp_server('unix/', path, function(s) s:write('Hello, world') end) --- ... +type(addr) +--- +- table +... server ~= nil --- - true @@ -1230,11 +1248,83 @@ client:read(123) --- - Hello, world ... -server:stop() +server:close() --- - true ... -os.remove(path) +-- unix socket automatically removed +fio.stat(path) == nil +--- +- true +... +--# setopt delimiter ';' +server, addr = socket.tcp_server('localhost', 0, { handler = function(s) + s:read(2) + s:write('Hello, world') +end, name = 'testserv'}); +--- +... +--# setopt delimiter '' +type(addr) +--- +- table +... +server ~= nil +--- +- true +... +addr2 = server:name() +--- +... +addr.host == addr2.host +--- +- true +... +addr.family == addr2.family +--- +- true +... +fiber.sleep(.5) +--- +... +client = socket.tcp_connect(addr2.host, addr2.port) +--- +... +client ~= nil +--- +- true +... +-- Check that listen and client fibers have appropriate names +cnt = 0 +--- +... +--# setopt delimiter ';' +for i=100,200 do + local f = fiber.find(i) + if f and f:name():match('^testserv/') then + cnt = cnt + 1 + end +end; +--- +... +--# setopt delimiter '' +cnt +--- +- 2 +... +client:write('hi') +--- +- true +... +client:read(123) +--- +- Hello, world +... +client:close() +--- +- true +... +server:close() --- - true ... @@ -1275,11 +1365,7 @@ client:read{ line = { "\n\n", "\r\n\r\n" } } --- - "Hello\r\n\r\n" ... -server:stop() ---- -- true -... -os.remove(path) +server:close() --- - true ... diff --git a/test/box/bsdsocket.test.lua b/test/box/bsdsocket.test.lua index 7ead95599262bb1498d51354ba5cd88928b7ed1b..e18bdecc5d46513d63efd0b1b6998220a8c8338a 100644 --- a/test/box/bsdsocket.test.lua +++ b/test/box/bsdsocket.test.lua @@ -6,6 +6,7 @@ fiber = require 'fiber' msgpack = require 'msgpack' log = require 'log' errno = require 'errno' +fio = require 'fio' type(socket) socket('PF_INET', 'SOCK_STREAM', 'tcp121222'); @@ -114,7 +115,10 @@ sc:writable(10) sc:write('Hello, world') -sa = s:accept() +sa, addr = s:accept() +addr2 = sa:name() +addr2.host == addr.host +addr2.family == addr.family sa:nonblock(1) sa:read(8) sa:read(3) @@ -404,37 +408,59 @@ s:close() os.remove(path) -server = socket.tcp_server('unix/', path, function(s) s:write('Hello, world') end) +server, addr = socket.tcp_server('unix/', path, function(s) s:write('Hello, world') end) +type(addr) server ~= nil fiber.sleep(.5) client = socket.tcp_connect('unix/', path) client ~= nil client:read(123) -server:stop() -os.remove(path) +server:close() +-- unix socket automatically removed +fio.stat(path) == nil +--# setopt delimiter ';' +server, addr = socket.tcp_server('localhost', 0, { handler = function(s) + s:read(2) + s:write('Hello, world') +end, name = 'testserv'}); +--# setopt delimiter '' +type(addr) +server ~= nil +addr2 = server:name() +addr.host == addr2.host +addr.family == addr2.family +fiber.sleep(.5) +client = socket.tcp_connect(addr2.host, addr2.port) +client ~= nil +-- Check that listen and client fibers have appropriate names +cnt = 0 +--# setopt delimiter ';' +for i=100,200 do + local f = fiber.find(i) + if f and f:name():match('^testserv/') then + cnt = cnt + 1 + end +end; +--# setopt delimiter '' +cnt +client:write('hi') +client:read(123) +client:close() +server:close() longstring = string.rep("abc", 65535) server = socket.tcp_server('unix/', path, function(s) s:write(longstring) end) - client = socket.tcp_connect('unix/', path) client:read(#longstring) == longstring - client = socket.tcp_connect('unix/', path) client:read(#longstring + 1) == longstring - client = socket.tcp_connect('unix/', path) client:read(#longstring - 1) == string.sub(longstring, 1, #longstring - 1) - - longstring = "Hello\r\n\r\nworld\n\n" - client = socket.tcp_connect('unix/', path) client:read{ line = { "\n\n", "\r\n\r\n" } } - - -server:stop() -os.remove(path) +server:close() -- Test that socket is closed on GC