diff --git a/src/lua/bsdsocket.lua b/src/lua/bsdsocket.lua index 7795d8d93355cc3fe47d6c8a7e00b3416d210834..d2bce9c5bac94939f70e60986bf24bf6c3f060c2 100644 --- a/src/lua/bsdsocket.lua +++ b/src/lua/bsdsocket.lua @@ -1,15 +1,17 @@ -- bsdsocket.lua (internal file) -do - local TIMEOUT_INFINITY = 500 * 365 * 86400 local ffi = require('ffi') local boxerrno = require('errno') local internal = require('socket.internal') local fiber = require('fiber') +package.loaded['socket.internal'] = nil ffi.cdef[[ + struct socket { + int fd; + }; typedef uint32_t socklen_t; typedef ptrdiff_t ssize_t; @@ -18,9 +20,9 @@ ffi.cdef[[ int bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen); - ssize_t write(int fh, const char *octets, size_t len); + ssize_t write(int fd, const char *octets, size_t len); ssize_t read(int fd, void *buf, size_t count); - int listen(int fh, int backlog); + int listen(int fd, int backlog); int socket(int domain, int type, int protocol); int close(int s); int shutdown(int s, int how); @@ -33,7 +35,7 @@ ffi.cdef[[ int bsdsocket_local_resolve(const char *host, const char *port, struct sockaddr *addr, socklen_t *socklen); - int bsdsocket_nonblock(int fh, int mode); + int bsdsocket_nonblock(int fd, int mode); int setsockopt(int s, int level, int iname, const void *opt, size_t optlen); int getsockopt(int s, int level, int iname, void *ptr, size_t *optlen); @@ -48,12 +50,6 @@ ffi.cdef[[ struct protoent *getprotobyname(const char *name); ]] -ffi.cdef([[ -struct sockaddr { - char _data[256]; /* enough to fit any address */ -}; -]]); - local function sprintf(fmt, ...) return string.format(fmt, ...) end @@ -62,6 +58,35 @@ local function printf(fmt, ...) print(sprintf(fmt, ...)) end +local socket_t = ffi.typeof('struct socket'); + +local function socket_cdata_gc(socket) + ffi.C.close(socket.fd) +end + +local function check_socket(self) + local socket = type(self) == 'table' and self.socket + if not ffi.istype(socket_t, socket) then + error('Usage: socket:method()'); + end + return socket.fd +end + +local socket_mt +local function bless_socket(fd) + -- Make socket to be non-blocked by default + if ffi.C.bsdsocket_nonblock(fd, 1) < 0 then + local errno = box.errno() + ffi.C.close(fd) + box.errno(errno) + return nil + end + + local socket = ffi.new(socket_t, fd); + ffi.gc(socket, socket_cdata_gc) + return setmetatable({ socket = socket }, socket_mt) +end + local function get_ivalue(table, key) if type(key) == 'number' then return key @@ -86,9 +111,9 @@ local function get_iflags(table, flags) return res end -local socket_mt local socket_methods = {} socket_methods.errno = function(self) + check_socket(self) if self['_errno'] == nil then return 0 else @@ -97,6 +122,7 @@ socket_methods.errno = function(self) end socket_methods.error = function(self) + check_socket(self) if self['_errno'] == nil then return nil else @@ -104,18 +130,21 @@ socket_methods.error = function(self) end end -local addr = ffi.new('struct sockaddr') +-- addrbuf is equivalent to struct sockaddr_storage +local addrbuf = ffi.new('char[128]') -- enough to fit any address +local addr = ffi.cast('struct sockaddr *', addrbuf) local addr_len = ffi.new('socklen_t[1]') socket_methods.sysconnect = function(self, host, port) + local fd = check_socket(self) self._errno = nil host = tostring(host) port = tostring(port) - addr_len[0] = ffi.sizeof(addr) + addr_len[0] = ffi.sizeof(addrbuf) local res = ffi.C.bsdsocket_local_resolve(host, port, addr, addr_len) if res == 0 then - res = ffi.C.connect(self.fh, addr, addr_len[0]); + res = ffi.C.connect(fd, addr, addr_len[0]); if res == 0 then return true end @@ -125,8 +154,9 @@ socket_methods.sysconnect = function(self, host, port) end socket_methods.syswrite = function(self, octets) + local fd = check_socket(self) self._errno = nil - local done = ffi.C.write(self.fh, octets, string.len(octets)) + local done = ffi.C.write(fd, octets, string.len(octets)) if done < 0 then self._errno = boxerrno() return nil @@ -135,9 +165,10 @@ socket_methods.syswrite = function(self, octets) end socket_methods.sysread = function(self, len) + local fd = check_socket(self) self._errno = nil local buf = ffi.new('char[?]', len) - local res = ffi.C.read(self.fh, buf, len) + local res = ffi.C.read(fd, buf, len) if res < 0 then self._errno = boxerrno() @@ -149,16 +180,17 @@ socket_methods.sysread = function(self, len) end socket_methods.nonblock = function(self, nb) + local fd = check_socket(self) self._errno = nil local res if nb == nil then - res = ffi.C.bsdsocket_nonblock(self.fh, 0x80) + res = ffi.C.bsdsocket_nonblock(fd, 0x80) elseif nb then - res = ffi.C.bsdsocket_nonblock(self.fh, 1) + res = ffi.C.bsdsocket_nonblock(fd, 1) else - res = ffi.C.bsdsocket_nonblock(self.fh, 0) + res = ffi.C.bsdsocket_nonblock(fd, 0) end if res < 0 then @@ -174,48 +206,34 @@ socket_methods.nonblock = function(self, nb) end local function wait_safely(self, what, timeout) + local fd = check_socket(self) local f = fiber.self() local fid = f:id() + self._errno = nil + timeout = timeout or TIMEOUT_INFINITY + if self.waiters == nil then self.waiters = {} end self.waiters[fid] = true - local res = internal.iowait(self.fh, what, timeout) + local res = internal.iowait(fd, what, timeout) self.waiters[fid] = nil fiber.testcancel() + if res == 0 then + self._errno = boxerrno.ETIMEDOUT + return 0 + end return res end socket_methods.readable = function(self, timeout) - self._errno = nil - if timeout == nil then - timeout = TIMEOUT_INFINITY - end - - local wres = wait_safely(self, 0, timeout) - - if wres == 0 then - self._errno = boxerrno.ETIMEDOUT - return false - end - return true + return wait_safely(self, 0, timeout) ~= 0 end socket_methods.wait = function(self, timeout) - self._errno = nil - if timeout == nil then - timeout = TIMEOUT_INFINITY - end - local wres = wait_safely(self, 2, timeout) - - if wres == 0 then - self._errno = boxerrno.ETIMEDOUT - return - end - local res = '' if bit.band(wres, 1) ~= 0 then res = res .. 'R' @@ -227,26 +245,16 @@ socket_methods.wait = function(self, timeout) end socket_methods.writable = function(self, timeout) - self._errno = nil - if timeout == nil then - timeout = TIMEOUT_INFINITY - end - - local wres = wait_safely(self, 1, timeout) - - if wres == 0 then - self._errno = boxerrno.ETIMEDOUT - return false - end - return true + return wait_safely(self, 1, timeout) ~= 0 end socket_methods.listen = function(self, backlog) + local fd = check_socket(self) self._errno = nil if backlog == nil then backlog = 256 end - local res = ffi.C.listen(self.fh, backlog) + local res = ffi.C.listen(fd, backlog) if res < 0 then self._errno = boxerrno() return false @@ -255,15 +263,16 @@ socket_methods.listen = function(self, backlog) end socket_methods.bind = function(self, host, port) + local fd = check_socket(self) self._errno = nil host = tostring(host) port = tostring(port) - addr_len[0] = ffi.sizeof(addr) + addr_len[0] = ffi.sizeof(addrbuf) local res = ffi.C.bsdsocket_local_resolve(host, port, addr, addr_len) if res == 0 then - res = ffi.C.bind(self.fh, addr, addr_len[0]); + res = ffi.C.bind(fd, addr, addr_len[0]); end if res == 0 then return true @@ -274,6 +283,7 @@ socket_methods.bind = function(self, host, port) end socket_methods.close = function(self) + local fd = check_socket(self) if self.waiters ~= nil then for fid in pairs(self.waiters) do internal.abort(fid) @@ -282,14 +292,16 @@ socket_methods.close = function(self) end self._errno = nil - if ffi.C.close(self.fh) < 0 then + if ffi.C.close(fd) < 0 then self._errno = boxerrno() return false end + ffi.gc(self.socket, nil) return true end socket_methods.shutdown = function(self, how) + local fd = check_socket(self) local hvariants = { ['R'] = 0, ['READ'] = 0, @@ -308,7 +320,7 @@ socket_methods.shutdown = function(self, how) ihow = 2 end self._errno = nil - if ffi.C.shutdown(self.fh, ihow) < 0 then + if ffi.C.shutdown(fd, ihow) < 0 then self._errno = boxerrno() return false end @@ -316,6 +328,8 @@ socket_methods.shutdown = function(self, how) end socket_methods.setsockopt = function(self, level, name, value) + local fd = check_socket(self) + local info = get_ivalue(internal.SO_OPT, name) if info == nil then @@ -353,7 +367,7 @@ socket_methods.setsockopt = function(self, level, name, value) if info.type == 1 then local value = ffi.new("int[1]", value) - local res = ffi.C.setsockopt(self.fh, + local res = ffi.C.setsockopt(fd, level, info.iname, value, ffi.sizeof('int')) if res < 0 then @@ -364,7 +378,7 @@ socket_methods.setsockopt = function(self, level, name, value) end if info.type == 2 then - local res = ffi.C.setsockopt(self.fh, + local res = ffi.C.setsockopt(fd, level, info.iname, value, ffi.sizeof('size_t')) if res < 0 then self._errno = boxerrno() @@ -380,6 +394,8 @@ socket_methods.setsockopt = function(self, level, name, value) end socket_methods.getsockopt = function(self, level, name) + local fd = check_socket(self) + local info = get_ivalue(internal.SO_OPT, name) if info == nil then @@ -407,7 +423,7 @@ socket_methods.getsockopt = function(self, level, name) if info.type == 1 then local value = ffi.new("int[1]", 0) local len = ffi.new("size_t[1]", ffi.sizeof('int')) - local res = ffi.C.getsockopt(self.fh, level, info.iname, value, len) + local res = ffi.C.getsockopt(fd, level, info.iname, value, len) if res < 0 then self._errno = boxerrno() @@ -423,7 +439,7 @@ socket_methods.getsockopt = function(self, level, name) if info.type == 2 then local value = ffi.new("char[256]", { 0 }) local len = ffi.new("size_t[1]", 256) - local res = ffi.C.getsockopt(self.fh, level, info.iname, value, len) + local res = ffi.C.getsockopt(fd, level, info.iname, value, len) if res < 0 then self._errno = boxerrno() return nil @@ -438,13 +454,14 @@ socket_methods.getsockopt = function(self, level, name) end socket_methods.linger = function(self, active, timeout) + local fd = check_socket(self) local info = internal.SO_OPT.SO_LINGER self._errno = nil if active == nil then local value = ffi.new("linger_t[1]") local len = ffi.new("size_t[1]", 2 * ffi.sizeof('int')) - local res = ffi.C.getsockopt(self.fh, + local res = ffi.C.getsockopt(fd, internal.SOL_SOCKET, info.iname, value, len) if res < 0 then self._errno = boxerrno() @@ -472,7 +489,7 @@ socket_methods.linger = function(self, active, timeout) local value = ffi.new("linger_t[1]", { { active = iactive, timeout = timeout } }) local len = 2 * ffi.sizeof('int') - local res = ffi.C.setsockopt(self.fh, + local res = ffi.C.setsockopt(fd, internal.SOL_SOCKET, info.iname, value, len) if res < 0 then self._errno = boxerrno() @@ -483,25 +500,15 @@ socket_methods.linger = function(self, active, timeout) end socket_methods.accept = function(self) - + local fd = check_socket(self) self._errno = nil - local fh = ffi.C.accept(self.fh, nil, nil) - - if fh < 1 then - self._errno = boxerrno() + local cfd = ffi.C.accept(fd, nil, nil) + if cfd < 1 then + self._errno = box.errno() return nil end - - fh = tonumber(fh) - - -- Make socket to be non-blocked by default - -- ignore result - ffi.C.bsdsocket_nonblock(fh, 1) - - local socket = { fh = fh } - setmetatable(socket, socket_mt) - return socket + return bless_socket(cfd) end local function readchunk(self, size, timeout) @@ -631,6 +638,7 @@ local function readline(self, limit, eol, timeout) end socket_methods.read = function(self, opts, timeout) + check_socket(self) timeout = timeout and tonumber(timeout) or TIMEOUT_INFINITY if type(opts) == 'number' then return readchunk(self, opts, timeout) @@ -651,6 +659,7 @@ socket_methods.read = function(self, opts, timeout) end socket_methods.write = function(self, octets, timeout) + check_socket(self) if timeout == nil then timeout = TIMEOUT_INFINITY end @@ -676,10 +685,11 @@ socket_methods.write = function(self, octets, timeout) end socket_methods.send = function(self, octets, flags) + local fd = check_socket(self) local iflags = get_iflags(internal.SEND_FLAGS, flags) self._errno = nil - local res = ffi.C.send(self.fh, octets, string.len(octets), iflags) + local res = ffi.C.send(fd, octets, string.len(octets), iflags) if res == -1 then self._errno = boxerrno() return false @@ -688,6 +698,7 @@ socket_methods.send = function(self, octets, flags) end socket_methods.recv = function(self, size, flags) + local fd = check_socket(self) local iflags = get_iflags(internal.SEND_FLAGS, flags) if iflags == nil then self._errno = boxerrno.EINVAL @@ -700,7 +711,7 @@ socket_methods.recv = function(self, size, flags) self._errno = nil local buf = ffi.new("char[?]", size) - local res = ffi.C.recv(self.fh, buf, size, iflags) + local res = ffi.C.recv(fd, buf, size, iflags) if res == -1 then self._errno = boxerrno() @@ -710,13 +721,14 @@ socket_methods.recv = function(self, size, flags) end socket_methods.recvfrom = function(self, size, flags) + local fd = check_socket(self) local iflags = get_iflags(internal.SEND_FLAGS, flags) if iflags == nil then self._errno = boxerrno.EINVAL return nil end self._errno = nil - local res, from = internal.recvfrom(self.fh, size, iflags) + local res, from = internal.recvfrom(fd, size, iflags) if res == nil then self._errno = boxerrno() return nil @@ -725,6 +737,7 @@ socket_methods.recvfrom = function(self, size, flags) end socket_methods.sendto = function(self, host, port, octets, flags) + local fd = check_socket(self) local iflags = get_iflags(internal.SEND_FLAGS, flags) if iflags == nil then @@ -741,10 +754,10 @@ socket_methods.sendto = function(self, host, port, octets, flags) port = tostring(port) octets = tostring(octets) - addr_len[0] = ffi.sizeof(addr) + addr_len[0] = ffi.sizeof(addrbuf) local res = ffi.C.bsdsocket_local_resolve(host, port, addr, addr_len) if res == 0 then - res = ffi.C.sendto(self.fh, octets, string.len(octets), iflags, + res = ffi.C.sendto(fd, octets, string.len(octets), iflags, addr, addr_len[0]) end if res < 0 then @@ -779,24 +792,11 @@ local function create_socket(domain, stype, proto) iproto = tonumber(proto) end - local fh = ffi.C.socket(idomain, itype, iproto) - if fh < 0 then - return nil - end - - fh = tonumber(fh) - - -- Make socket to be non-blocked by default - if ffi.C.bsdsocket_nonblock(fh, 1) < 0 then - local errno = boxerrno() - ffi.C.close(fh) - boxerrno(errno) + local fd = ffi.C.socket(idomain, itype, iproto) + if fd < 1 then return nil end - - local socket = { fh = fh } - setmetatable(socket, socket_mt) - return socket + return bless_socket(fd) end local function getaddrinfo(host, port, timeout, opts) @@ -869,7 +869,8 @@ local soname_mt = { } socket_methods.name = function(self) - local aka = internal.name(self.fh) + local fd = check_socket(self) + local aka = internal.name(fd) if aka == nil then self._errno = boxerrno() return nil @@ -880,7 +881,8 @@ socket_methods.name = function(self) end socket_methods.peer = function(self) - local peer = internal.peer(self.fh) + local fd = check_socket(self) + local peer = internal.peer(fd) if peer == nil then self._errno = boxerrno() return nil @@ -890,6 +892,10 @@ socket_methods.peer = function(self) return peer end +socket_methods.fd = function(self) + return check_socket(self) +end + -- tcp connector local function tcp_connect_remote(remote, timeout) local s = create_socket(remote.family, remote.type, remote.protocol) @@ -1005,7 +1011,7 @@ local function tcp_server_remote(list, prepare, handler) for _, s in pairs(server.s) do fiber.create(function(s) - fiber.name(sprintf("listen_fd=%d",s.fh)) + fiber.name(sprintf("listen_fd=%d",s:fd())) while s:readable() do @@ -1052,8 +1058,10 @@ end socket_mt = { __index = socket_methods, __tostring = function(self) + local fd = check_socket(self) + local save_errno = self._errno - local name = sprintf("fd %d", self.fh) + local name = sprintf("fd %d", fd) local aka = self:name() if aka ~= nil then name = sprintf("%s, aka %s:%s", name, aka.host, aka.port) @@ -1067,18 +1075,10 @@ socket_mt = { end } -if package.loaded.socket == nil then - package.loaded.socket = {} -end - - -setmetatable(package.loaded.socket, { - __call = function(self, ...) return create_socket(...) end, - __index = { - getaddrinfo = getaddrinfo, - tcp_connect = tcp_connect, - tcp_server = tcp_server - } +package.loaded.socket = setmetatable({ + getaddrinfo = getaddrinfo, + tcp_connect = tcp_connect, + tcp_server = tcp_server +}, { + __call = function(self, ...) return create_socket(...) end; }) - -end diff --git a/test/box/bsdsocket.result b/test/box/bsdsocket.result index 2bd743400646fa1f37113ac80c07234452c8be02..9e3316a4c2ead2b27f80f330a4e61f6879517d27 100644 --- a/test/box/bsdsocket.result +++ b/test/box/bsdsocket.result @@ -46,6 +46,17 @@ type(s:error()) --- - nil ... +-- Invalid arguments +--# setopt delimiter ';' +for k in pairs(getmetatable(s).__index) do + local r, msg = pcall(s[k]) + if not msg:match('Usage:') then + error("Arguments is not checked for "..k) + end +end; +--- +... +--# setopt delimiter '' port = string.gsub(box.cfg.listen, '^.*:', '') --- ... @@ -255,7 +266,7 @@ s:getsockopt('SOL_SOCKET', 'SO_DEBUG') ... s:setsockopt('SOL_SOCKET', 'SO_ACCEPTCONN', 1) --- -- error: '[string "-- bsdsocket.lua (internal file)..."]:326: Socket option SO_ACCEPTCONN +- error: '[string "-- bsdsocket.lua (internal file)..."]:340: Socket option SO_ACCEPTCONN is read only' ... s:getsockopt('SOL_SOCKET', 'SO_RCVBUF') > 32 @@ -914,6 +925,36 @@ socket.tcp_connect('unix/', path), errno() == errno.ENOENT - null - true ... +-- invalid fd +s = socket('AF_INET', 'SOCK_STREAM', 'tcp') +--- +... +s:read(9) +--- +- null +... +s:close() +--- +- true +... +s.socket.fd = 512 +--- +... +tostring(s) +--- +- fd 512 +... +s:readable(0) +--- +- true +... +s:writable(0) +--- +- true +... +s = nil +--- +... -- close port = 65454 --- @@ -948,25 +989,6 @@ end, serv); s = socket.tcp_connect('127.0.0.1', port) --- ... -s:read(9) ---- -- Tarantool -... -sa = setmetatable({ fh = 512 }, getmetatable(s)) ---- -... -tostring(sa) ---- -- fd 512 -... -sa:readable(0) ---- -- true -... -sa:writable(0) ---- -- true -... ch = fiber.channel() --- ... @@ -1260,3 +1282,35 @@ os.remove(path) --- - true ... +-- Test that socket is closed on GC +s = socket('AF_UNIX', 'SOCK_STREAM', 'ip') +--- +... +s:bind('unix/', path) +--- +- true +... +s:listen() +--- +- true +... +s = nil +--- +... +collectgarbage('collect') +--- +- 0 +... +collectgarbage('collect') +--- +- 0 +... +socket.tcp_connect('unix/', path), errno() == errno.ECONNREFUSED +--- +- null +- true +... +os.remove(path) +--- +- true +... diff --git a/test/box/bsdsocket.test.lua b/test/box/bsdsocket.test.lua index 8a78cea3edce78e81f6639e0c3bcc80daeefab1e..be6b5ccda69a36950acfcb937be4c66be4ca34bc 100644 --- a/test/box/bsdsocket.test.lua +++ b/test/box/bsdsocket.test.lua @@ -14,6 +14,15 @@ s:wait(.01) type(s) s:errno() type(s:error()) +-- Invalid arguments +--# setopt delimiter ';' +for k in pairs(getmetatable(s).__index) do + local r, msg = pcall(s[k]) + if not msg:match('Usage:') then + error("Arguments is not checked for "..k) + end +end; +--# setopt delimiter '' port = string.gsub(box.cfg.listen, '^.*:', '') @@ -286,6 +295,16 @@ socket.tcp_connect('unix/', path), errno() == errno.ECONNREFUSED os.remove(path) socket.tcp_connect('unix/', path), errno() == errno.ENOENT +-- invalid fd +s = socket('AF_INET', 'SOCK_STREAM', 'tcp') +s:read(9) +s:close() +s.socket.fd = 512 +tostring(s) +s:readable(0) +s:writable(0) +s = nil + -- close port = 65454 serv = socket('AF_INET', 'SOCK_STREAM', 'tcp') @@ -304,12 +323,6 @@ end, serv); --# setopt delimiter '' s = socket.tcp_connect('127.0.0.1', port) -s:read(9) -sa = setmetatable({ fh = 512 }, getmetatable(s)) -tostring(sa) -sa:readable(0) -sa:writable(0) - ch = fiber.channel() f = fiber.create(function() s:read(12) ch:put(true) end) s:close() @@ -423,4 +436,12 @@ server:stop() os.remove(path) - +-- Test that socket is closed on GC +s = socket('AF_UNIX', 'SOCK_STREAM', 'ip') +s:bind('unix/', path) +s:listen() +s = nil +collectgarbage('collect') +collectgarbage('collect') +socket.tcp_connect('unix/', path), errno() == errno.ECONNREFUSED +os.remove(path)