diff --git a/src/lua/bsdsocket.lua b/src/lua/bsdsocket.lua index 6273f442bff0ff1f854b20da25364214ac236249..d0b0d1f79de73aec127f15a7b93b06f90c5e9d54 100644 --- a/src/lua/bsdsocket.lua +++ b/src/lua/bsdsocket.lua @@ -7,6 +7,9 @@ local TIMEOUT_INFINITY = 500 * 365 * 86400 local ffi = require 'ffi' ffi.cdef[[ + struct socket { + int fd; + }; typedef uint32_t socklen_t; typedef ptrdiff_t ssize_t; @@ -15,9 +18,9 @@ ffi.cdef[[ int bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen); - ssize_t write(int fh, const char *octets, size_t len); + ssize_t write(int fd, const char *octets, size_t len); ssize_t read(int fd, void *buf, size_t count); - int listen(int fh, int backlog); + int listen(int fd, int backlog); int socket(int domain, int type, int protocol); int close(int s); int shutdown(int s, int how); @@ -30,7 +33,7 @@ ffi.cdef[[ int bsdsocket_local_resolve(const char *host, const char *port, struct sockaddr *addr, socklen_t *socklen); - int bsdsocket_nonblock(int fh, int mode); + int bsdsocket_nonblock(int fd, int mode); int setsockopt(int s, int level, int iname, const void *opt, size_t optlen); int getsockopt(int s, int level, int iname, void *ptr, size_t *optlen); @@ -59,6 +62,34 @@ local function printf(fmt, ...) print(sprintf(fmt, ...)) end +local socket_t = ffi.typeof('struct socket'); + +local function socket_cdata_gc(socket) + ffi.C.close(socket.fd) +end + +local function check_socket(self) + local socket = type(self) == 'table' and self.socket + if not ffi.istype(socket_t, socket) then + error('Usage: socket:method()'); + end + return socket.fd +end + +local function bless_socket(fd) + -- Make socket to be non-blocked by default + if ffi.C.bsdsocket_nonblock(fd, 1) < 0 then + local errno = box.errno() + ffi.C.close(fd) + box.errno(errno) + return nil + end + + local socket = ffi.new(socket_t, fd); + ffi.gc(socket, socket_cdata_gc) + return setmetatable({ socket = socket }, box.socket.internal.socket_mt) +end + local function get_ivalue(table, key) if type(key) == 'number' then return key @@ -85,6 +116,7 @@ end local socket_methods = {} socket_methods.errno = function(self) + check_socket(self) if self['_errno'] == nil then return 0 else @@ -93,6 +125,7 @@ socket_methods.errno = function(self) end socket_methods.error = function(self) + check_socket(self) if self['_errno'] == nil then return nil else @@ -103,6 +136,7 @@ end local addr = ffi.new('struct sockaddr') local addr_len = ffi.new('socklen_t[1]') socket_methods.sysconnect = function(self, host, port) + local fd = check_socket(self) self._errno = nil host = tostring(host) @@ -111,7 +145,7 @@ socket_methods.sysconnect = function(self, host, port) addr_len[0] = ffi.sizeof(addr) local res = ffi.C.bsdsocket_local_resolve(host, port, addr, addr_len) if res == 0 then - res = ffi.C.connect(self.fh, addr, addr_len[0]); + res = ffi.C.connect(fd, addr, addr_len[0]); if res == 0 then return true end @@ -121,8 +155,9 @@ socket_methods.sysconnect = function(self, host, port) end socket_methods.syswrite = function(self, octets) + local fd = check_socket(self) self._errno = nil - local done = ffi.C.write(self.fh, octets, string.len(octets)) + local done = ffi.C.write(fd, octets, string.len(octets)) if done < 0 then self._errno = box.errno() return nil @@ -131,9 +166,10 @@ socket_methods.syswrite = function(self, octets) end socket_methods.sysread = function(self, len) + local fd = check_socket(self) self._errno = nil local buf = ffi.new('char[?]', len) - local res = ffi.C.read(self.fh, buf, len) + local res = ffi.C.read(fd, buf, len) if res < 0 then self._errno = box.errno() @@ -145,16 +181,17 @@ socket_methods.sysread = function(self, len) end socket_methods.nonblock = function(self, nb) + local fd = check_socket(self) self._errno = nil local res if nb == nil then - res = ffi.C.bsdsocket_nonblock(self.fh, 0x80) + res = ffi.C.bsdsocket_nonblock(fd, 0x80) elseif nb then - res = ffi.C.bsdsocket_nonblock(self.fh, 1) + res = ffi.C.bsdsocket_nonblock(fd, 1) else - res = ffi.C.bsdsocket_nonblock(self.fh, 0) + res = ffi.C.bsdsocket_nonblock(fd, 0) end if res < 0 then @@ -170,48 +207,38 @@ socket_methods.nonblock = function(self, nb) end local function wait_safely(self, what, timeout) + local fd = check_socket(self) local f = box.fiber.self() local fid = f:id() + self._errno = nil + timeout = timeout or TIMEOUT_INFINITY + if self.waiters == nil then self.waiters = {} end self.waiters[fid] = f - local res = box.socket.internal.iowait(self.fh, what, timeout) + local wres = box.socket.internal.iowait(fd, what, timeout) self.waiters[fid] = nil box.fiber.testcancel() - return res -end - -socket_methods.readable = function(self, timeout) - self._errno = nil - if timeout == nil then - timeout = TIMEOUT_INFINITY - end - - local wres = wait_safely(self, 0, timeout) if wres == 0 then self._errno = box.errno.ETIMEDOUT - return false + return 0 end - return true + return wres end -socket_methods.wait = function(self, timeout) - self._errno = nil - if timeout == nil then - timeout = TIMEOUT_INFINITY - end +socket_methods.readable = function(self, timeout) + return wait_safely(self, 0, timeout) ~= 0 +end +socket_methods.wait = function(self, timeout) local wres = wait_safely(self, 2, timeout) - if wres == 0 then - self._errno = box.errno.ETIMEDOUT - return + return nil end - local res = '' if bit.band(wres, 1) ~= 0 then res = res .. 'R' @@ -223,26 +250,16 @@ socket_methods.wait = function(self, timeout) end socket_methods.writable = function(self, timeout) - self._errno = nil - if timeout == nil then - timeout = TIMEOUT_INFINITY - end - - local wres = wait_safely(self, 1, timeout) - - if wres == 0 then - self._errno = box.errno.ETIMEDOUT - return false - end - return true + return wait_safely(self, 1, timeout) ~= 0 end socket_methods.listen = function(self, backlog) + local fd = check_socket(self) self._errno = nil if backlog == nil then backlog = 256 end - local res = ffi.C.listen(self.fh, backlog) + local res = ffi.C.listen(fd, backlog) if res < 0 then self._errno = box.errno() return false @@ -251,6 +268,7 @@ socket_methods.listen = function(self, backlog) end socket_methods.bind = function(self, host, port) + local fd = check_socket(self) self._errno = nil host = tostring(host) @@ -259,7 +277,7 @@ socket_methods.bind = function(self, host, port) addr_len[0] = ffi.sizeof(addr) local res = ffi.C.bsdsocket_local_resolve(host, port, addr, addr_len) if res == 0 then - res = ffi.C.bind(self.fh, addr, addr_len[0]); + res = ffi.C.bind(fd, addr, addr_len[0]); end if res == 0 then return true @@ -270,6 +288,7 @@ socket_methods.bind = function(self, host, port) end socket_methods.close = function(self) + local fd = check_socket(self) if self.waiters ~= nil then for fid, fiber in pairs(self.waiters) do fiber:wakeup() @@ -278,14 +297,16 @@ socket_methods.close = function(self) end self._errno = nil - if ffi.C.close(self.fh) < 0 then + if ffi.C.close(fd) < 0 then self._errno = box.errno() return false end + ffi.gc(self.socket, nil) return true end socket_methods.shutdown = function(self, how) + local fd = check_socket(self) local hvariants = { ['R'] = 0, ['READ'] = 0, @@ -304,7 +325,7 @@ socket_methods.shutdown = function(self, how) ihow = 2 end self._errno = nil - if ffi.C.shutdown(self.fh, ihow) < 0 then + if ffi.C.shutdown(fd, ihow) < 0 then self._errno = box.errno() return false end @@ -312,6 +333,8 @@ socket_methods.shutdown = function(self, how) end socket_methods.setsockopt = function(self, level, name, value) + local fd = check_socket(self) + local info = get_ivalue(box.socket.internal.SO_OPT, name) if info == nil then @@ -349,7 +372,7 @@ socket_methods.setsockopt = function(self, level, name, value) if info.type == 1 then local value = ffi.new("int[1]", value) - local res = ffi.C.setsockopt(self.fh, + local res = ffi.C.setsockopt(fd, level, info.iname, value, ffi.sizeof('int')) if res < 0 then @@ -360,7 +383,7 @@ socket_methods.setsockopt = function(self, level, name, value) end if info.type == 2 then - local res = ffi.C.setsockopt(self.fh, + local res = ffi.C.setsockopt(fd, level, info.iname, value, ffi.sizeof('size_t')) if res < 0 then self._errno = box.errno() @@ -376,6 +399,8 @@ socket_methods.setsockopt = function(self, level, name, value) end socket_methods.getsockopt = function(self, level, name) + local fd = check_socket(self) + local info = get_ivalue(box.socket.internal.SO_OPT, name) if info == nil then @@ -403,7 +428,7 @@ socket_methods.getsockopt = function(self, level, name) if info.type == 1 then local value = ffi.new("int[1]", 0) local len = ffi.new("size_t[1]", ffi.sizeof('int')) - local res = ffi.C.getsockopt(self.fh, level, info.iname, value, len) + local res = ffi.C.getsockopt(fd, level, info.iname, value, len) if res < 0 then self._errno = box.errno() @@ -419,7 +444,7 @@ socket_methods.getsockopt = function(self, level, name) if info.type == 2 then local value = ffi.new("char[256]", { 0 }) local len = ffi.new("size_t[1]", 256) - local res = ffi.C.getsockopt(self.fh, level, info.iname, value, len) + local res = ffi.C.getsockopt(fd, level, info.iname, value, len) if res < 0 then self._errno = box.errno() return nil @@ -434,13 +459,14 @@ socket_methods.getsockopt = function(self, level, name) end socket_methods.linger = function(self, active, timeout) + local fd = check_socket(self) local info = box.socket.internal.SO_OPT.SO_LINGER self._errno = nil if active == nil then local value = ffi.new("linger_t[1]") local len = ffi.new("size_t[1]", 2 * ffi.sizeof('int')) - local res = ffi.C.getsockopt(self.fh, + local res = ffi.C.getsockopt(fd, box.socket.internal.SOL_SOCKET, info.iname, value, len) if res < 0 then self._errno = box.errno() @@ -468,7 +494,7 @@ socket_methods.linger = function(self, active, timeout) local value = ffi.new("linger_t[1]", { { active = iactive, timeout = timeout } }) local len = 2 * ffi.sizeof('int') - local res = ffi.C.setsockopt(self.fh, + local res = ffi.C.setsockopt(fd, box.socket.internal.SOL_SOCKET, info.iname, value, len) if res < 0 then self._errno = box.errno() @@ -479,25 +505,16 @@ socket_methods.linger = function(self, active, timeout) end socket_methods.accept = function(self) - + local fd = check_socket(self) self._errno = nil - local fh = ffi.C.accept(self.fh, nil, nil) + local cfd = ffi.C.accept(fd, nil, nil) - if fh < 1 then + if cfd < 1 then self._errno = box.errno() return nil end - - fh = tonumber(fh) - - -- Make socket to be non-blocked by default - -- ignore result - ffi.C.bsdsocket_nonblock(fh, 1) - - local socket = { fh = fh } - setmetatable(socket, box.socket.internal.socket_mt) - return socket + return bless_socket(cfd) end local function readchunk(self, size, timeout) @@ -627,6 +644,7 @@ local function readline(self, limit, eol, timeout) end socket_methods.read = function(self, opts, timeout) + check_socket(self) timeout = timeout and tonumber(timeout) or TIMEOUT_INFINITY if type(opts) == 'number' then return readchunk(self, opts, timeout) @@ -647,6 +665,7 @@ socket_methods.read = function(self, opts, timeout) end socket_methods.write = function(self, octets, timeout) + check_socket(self) if timeout == nil then timeout = TIMEOUT_INFINITY end @@ -672,10 +691,11 @@ socket_methods.write = function(self, octets, timeout) end socket_methods.send = function(self, octets, flags) + local fd = check_socket(self) local iflags = get_iflags(box.socket.internal.SEND_FLAGS, flags) self._errno = nil - local res = ffi.C.send(self.fh, octets, string.len(octets), iflags) + local res = ffi.C.send(fd, octets, string.len(octets), iflags) if res == -1 then self._errno = box.errno() return false @@ -684,6 +704,7 @@ socket_methods.send = function(self, octets, flags) end socket_methods.recv = function(self, size, flags) + local fd = check_socket(self) local iflags = get_iflags(box.socket.internal.SEND_FLAGS, flags) if iflags == nil then self._errno = box.errno.EINVAL @@ -696,7 +717,7 @@ socket_methods.recv = function(self, size, flags) self._errno = nil local buf = ffi.new("char[?]", size) - local res = ffi.C.recv(self.fh, buf, size, iflags) + local res = ffi.C.recv(fd, buf, size, iflags) if res == -1 then self._errno = box.errno() @@ -706,13 +727,14 @@ socket_methods.recv = function(self, size, flags) end socket_methods.recvfrom = function(self, size, flags) + local fd = check_socket(self) local iflags = get_iflags(box.socket.internal.SEND_FLAGS, flags) if iflags == nil then self._errno = box.errno.EINVAL return nil end self._errno = nil - local res, from = box.socket.internal.recvfrom(self.fh, size, iflags) + local res, from = box.socket.internal.recvfrom(fd, size, iflags) if res == nil then self._errno = box.errno() return nil @@ -721,6 +743,7 @@ socket_methods.recvfrom = function(self, size, flags) end socket_methods.sendto = function(self, host, port, octets, flags) + local fd = check_socket(self) local iflags = get_iflags(box.socket.internal.SEND_FLAGS, flags) if iflags == nil then @@ -740,7 +763,7 @@ socket_methods.sendto = function(self, host, port, octets, flags) addr_len[0] = ffi.sizeof(addr) local res = ffi.C.bsdsocket_local_resolve(host, port, addr, addr_len) if res == 0 then - res = ffi.C.sendto(self.fh, octets, string.len(octets), iflags, + res = ffi.C.sendto(fd, octets, string.len(octets), iflags, addr, addr_len[0]) end if res < 0 then @@ -777,24 +800,11 @@ local function create_socket(domain, stype, proto) iproto = p.p_proto end - local fh = ffi.C.socket(idomain, itype, iproto) - if fh < 0 then + local fd = ffi.C.socket(idomain, itype, iproto) + if fd < 1 then return nil end - - fh = tonumber(fh) - - -- Make socket to be non-blocked by default - if ffi.C.bsdsocket_nonblock(fh, 1) < 0 then - local errno = box.errno() - ffi.C.close(fh) - box.errno(errno) - return nil - end - - local socket = { fh = fh } - setmetatable(socket, box.socket.internal.socket_mt) - return socket + return bless_socket(fd) end local function getaddrinfo(host, port, timeout, opts) @@ -875,7 +885,8 @@ local soname_mt = { } socket_methods.name = function(self) - local aka = box.socket.internal.name(self.fh) + local fd = check_socket(self) + local aka = box.socket.internal.name(fd) if aka == nil then self._errno = box.errno() return nil @@ -886,7 +897,8 @@ socket_methods.name = function(self) end socket_methods.peer = function(self) - local peer = box.socket.internal.peer(self.fh) + local fd = check_socket(self) + local peer = box.socket.internal.peer(fd) if peer == nil then self._errno = box.errno() return nil @@ -896,6 +908,10 @@ socket_methods.peer = function(self) return peer end +socket_methods.fd = function(self) + return check_socket(self) +end + -- tcp connector local function tcp_connect_remote(remote, timeout) local s = create_socket(remote.family, remote.type, remote.protocol) @@ -1010,7 +1026,7 @@ local function tcp_server_remote(list, prepare, handler) for _, s in pairs(server.s) do box.fiber.wrap(function(s) - box.fiber.name(sprintf("listen_fd=%d",s.fh)) + box.fiber.name(sprintf("listen_fd=%d", s:fd())) while s:readable() do @@ -1059,8 +1075,10 @@ box.socket.internal = { socket_mt = { __index = socket_methods, __tostring = function(self) + local fd = check_socket(self) + local save_errno = self._errno - local name = sprintf("fd %d", self.fh) + local name = sprintf("fd %d", fd) local aka = self:name() if aka ~= nil then name = sprintf("%s, aka %s:%s", name, aka.host, aka.port) diff --git a/test/box/bsdsocket.result b/test/box/bsdsocket.result index adb29993ece401f6a5996853d5eb658f4ccdf7b9..8242bbbf89e0006c3bad4ff61182f69408946481 100644 --- a/test/box/bsdsocket.result +++ b/test/box/bsdsocket.result @@ -30,6 +30,9 @@ lua type(s:error()) --- - nil ... +lua for k in pairs(getmetatable(s).__index) do local r, msg = pcall(s[k]); if not msg:match('Usage:') then print(k) end end +--- +... lua s:nonblock(false) --- - false @@ -219,7 +222,7 @@ lua s:getsockopt('SOL_SOCKET', 'SO_DEBUG') ... lua s:setsockopt('SOL_SOCKET', 'SO_ACCEPTCONN', 1) --- -error: '[string "-- bsdsocket.lua (internal file)..."]:322: Socket option SO_ACCEPTCONN is read only' +error: '[string "-- bsdsocket.lua (internal file)..."]:345: Socket option SO_ACCEPTCONN is read only' ... lua s:getsockopt('SOL_SOCKET', 'SO_RCVBUF') > 32 --- @@ -458,7 +461,7 @@ lua s:bind('unix/', '/tmp/tarantool-test-socket') --- - true ... -lua string.match(tostring(sc), 'fd %d+, aka unix/:/tmp/tarantool%-test%-socket') ~= nil +lua string.match(tostring(s), 'fd %d+, aka unix/:/tmp/tarantool%-test%-socket') ~= nil --- - true ... @@ -816,12 +819,20 @@ lua box.socket.tcp_connect('unix/', '/tmp/tarantool-test-socket'), box.errno() = - nil - true ... -lua s = box.socket.tcp_connect('127.0.0.1', box.cfg.primary_port) +lua sa = box.socket.tcp_connect('127.0.0.1', box.cfg.primary_port) +--- +... +lua sa:close() --- + - true ... -lua sa = { fh = 512 } setmetatable(sa, getmetatable(s)) +lua sa.socket.fd = 512 --- ... +lua sa:fd() +--- + - 512 +... lua tostring(sa) --- - fd 512 @@ -834,6 +845,12 @@ lua sa:writable(0) --- - true ... +lua sa = nil +--- +... +lua s = box.socket.tcp_connect('127.0.0.1', box.cfg.primary_port) +--- +... lua ch = box.ipc.channel() --- ... @@ -1022,3 +1039,32 @@ lua server:stop() --- - true ... +lua s = box.socket('AF_UNIX', 'SOCK_STREAM', 'ip') +--- +... +lua s:bind('unix/', '/tmp/tarantool-test-socket') +--- + - true +... +lua s:listen() +--- + - true +... +lua s = nil +--- +... +lua collectgarbage('collect') +--- + - 0 +... +lua collectgarbage('collect') +--- + - 0 +... +lua client, errno = box.socket.tcp_connect('unix/', '/tmp/tarantool-test-socket'), box.errno() +--- +... +lua errno == box.errno.ECONNREFUSED +--- + - true +... diff --git a/test/box/bsdsocket.test b/test/box/bsdsocket.test index 30754928cb2afcceeadb1711103dad12684c495e..92e89b2a9d634cbc687f13d41b2d820459f842a0 100644 --- a/test/box/bsdsocket.test +++ b/test/box/bsdsocket.test @@ -13,6 +13,8 @@ exec admin "lua type(s)" exec admin "lua string.match(tostring(s), 'fd %d+, aka 0%.0%.0%.0:0') ~= nil" exec admin "lua s:errno()" exec admin "lua type(s:error())" +# Invalid arguments +exec admin "lua for k in pairs(getmetatable(s).__index) do local r, msg = pcall(s[k]); if not msg:match('Usage:') then print(k) end end" exec admin "lua s:nonblock(false)" exec admin "lua s:sysconnect('127.0.0.1', box.cfg.primary_port)" @@ -139,7 +141,7 @@ exec admin "lua s:nonblock()" if os.path.exists('/tmp/tarantool-test-socket'): os.unlink('/tmp/tarantool-test-socket') exec admin "lua s:bind('unix/', '/tmp/tarantool-test-socket')" -exec admin "lua string.match(tostring(sc), 'fd %d+, aka unix/:/tmp/tarantool%-test%-socket') ~= nil" +exec admin "lua string.match(tostring(s), 'fd %d+, aka unix/:/tmp/tarantool%-test%-socket') ~= nil" exec admin "lua s:listen(1234)" exec admin "lua sc = box.socket('PF_UNIX', 'SOCK_STREAM', 'ip')" @@ -262,12 +264,16 @@ if os.path.exists(path): os.unlink(path) exec admin "lua box.socket.tcp_connect('unix/', '{}'), box.errno() == box.errno.ENOENT".format(path) -exec admin "lua s = box.socket.tcp_connect('127.0.0.1', box.cfg.primary_port)" -exec admin "lua sa = { fh = 512 } setmetatable(sa, getmetatable(s))" +exec admin "lua sa = box.socket.tcp_connect('127.0.0.1', box.cfg.primary_port)" +exec admin "lua sa:close()" +exec admin "lua sa.socket.fd = 512" +exec admin "lua sa:fd()" exec admin "lua tostring(sa)" exec admin "lua sa:readable(0)" exec admin "lua sa:writable(0)" +exec admin "lua sa = nil" +exec admin "lua s = box.socket.tcp_connect('127.0.0.1', box.cfg.primary_port)" exec admin "lua ch = box.ipc.channel()" exec admin "lua f = box.fiber.wrap(function() s:read(12) ch:put(true) end)" exec admin "lua box.fiber.sleep(.1)" @@ -328,3 +334,14 @@ exec admin "lua client:read(123)" exec admin "lua server:stop()" os.unlink(path) + +# Test that socket is closed on GC +exec admin "lua s = box.socket('AF_UNIX', 'SOCK_STREAM', 'ip')" +exec admin "lua s:bind('unix/', '{}')".format(path) +exec admin "lua s:listen()" +exec admin "lua s = nil" +exec admin "lua collectgarbage('collect')" +exec admin "lua collectgarbage('collect')" +exec admin "lua client, errno = box.socket.tcp_connect('unix/', '{}'), box.errno()".format(path) +exec admin "lua errno == box.errno.ECONNREFUSED" +os.unlink(path)