From 61130805b15c506657818606f3a85a3b2b32311d Mon Sep 17 00:00:00 2001
From: Vladimir Davydov <vdavydov@tarantool.org>
Date: Thu, 3 Aug 2023 12:00:28 +0300
Subject: [PATCH] lua/socket: introduce socket.socketpair

The new function is a wrapper around the socketpair system call.
It takes the same arguments as the socket constructor and returns
two socket objects representing the two ends of the newly created
socket pair on success.

It may be useful for establishing a communication channel between
related processes.

Closes #8927

@TarantoolBot document
Title: Document new socket functions

Two socket module functions and one socket object method were added to
Tarantool 3.0:

- `socket.from_fd(fd)`: constructs a socket object from the file
  descriptor number. Returns the new socket object on success. Never
  fails. Note, the function doesn't perform any checks on the given
  file descriptor so technically it's possible to pass a closed file
  descriptor or a file descriptor that refers to a file, in which case
  the new socket object methods may not work as expected.

- `socket:detach()`: like `socket:close()` but doesn't close the socket
  file descriptor, only switches the socket object to the closed state.
  Returns nothing. If the socket was already detached or closed, raises
  an exception.

  Along with `socket.from_fd`, this method may be used for transferring
  file descriptor ownership from one socket to another:

  ```Lua
  local socket = require('socket')
  local s1 = socket('AF_INET', 'SOCK_STREAM', 'tcp')
  local s2 = socket.from_fd(s1:fd())
  s1:detach()
  ```

- `socket.socketpair(domain, type, proto)`: a wrapper around the
  [`socketpair`][1] system call. Returns two socket objects representing
  the two ends of the new socket pair on success. On failure, returns
  nil and sets [`errno`][2].

  Example:

  ```Lua
  local errno = require('errno')
  local socket = require('socket')
  local s1, s1 = socket.socketpair('AF_UNIX', 'SOCK_STREAM', 0)
  if not s1 then
      error('socketpair: ' .. errno.strerror())
  end
  s1:send('foo')
  assert(s2:recv() == 'foo')
  s1:close()
  s2:close()
  ```

[1]: https://man7.org/linux/man-pages/man2/socketpair.2.html
[2]: https://www.tarantool.io/en/doc/latest/reference/reference_lua/errno/
---
 changelogs/unreleased/gh-8927-socketpair.md |  4 +++
 src/lua/socket.lua                          | 31 ++++++++++++++++++++-
 test/app-luatest/socket_test.lua            | 24 ++++++++++++++++
 3 files changed, 58 insertions(+), 1 deletion(-)
 create mode 100644 changelogs/unreleased/gh-8927-socketpair.md

diff --git a/changelogs/unreleased/gh-8927-socketpair.md b/changelogs/unreleased/gh-8927-socketpair.md
new file mode 100644
index 0000000000..66f09d53b8
--- /dev/null
+++ b/changelogs/unreleased/gh-8927-socketpair.md
@@ -0,0 +1,4 @@
+## feature/lua/socket
+
+* Introduced new socket functions `socket.socketpair`, `socket.from_fd`, and
+  `socket:detach` (gh-8927).
diff --git a/src/lua/socket.lua b/src/lua/socket.lua
index 5aac3f662f..2858148049 100644
--- a/src/lua/socket.lua
+++ b/src/lua/socket.lua
@@ -32,6 +32,7 @@ ffi.cdef[[
     ssize_t read(int fd, void *buf, size_t count);
     int listen(int fd, int backlog);
     int socket(int domain, int type, int protocol);
+    int socketpair(int domain, int type, int protocol, int sv[2]);
     int coio_close(int s);
     int shutdown(int s, int how);
     ssize_t send(int sockfd, const void *buf, size_t len, int flags);
@@ -933,7 +934,7 @@ local function socket_sendto(self, host, port, octets, flags)
     return tonumber(res)
 end
 
-local function socket_new(domain, stype, proto)
+local function check_socket_args(domain, stype, proto)
     local idomain = get_ivalue(internal.DOMAIN, domain)
     if idomain == nil then
         boxerrno(boxerrno.EINVAL)
@@ -951,6 +952,14 @@ local function socket_new(domain, stype, proto)
         return nil
     end
 
+    return idomain, itype, iproto
+end
+
+local function socket_new(domain, stype, proto)
+    local idomain, itype, iproto = check_socket_args(domain, stype, proto)
+    if idomain == nil then
+        return nil
+    end
     local fd = ffi.C.socket(idomain, itype, iproto)
     if fd >= 0 then
         local socket = make_socket(fd, itype)
@@ -962,6 +971,25 @@ local function socket_new(domain, stype, proto)
     end
 end
 
+local function socket_socketpair(domain, stype, proto)
+    local idomain, itype, iproto = check_socket_args(domain, stype, proto)
+    if idomain == nil then
+        return nil
+    end
+    local sv = ffi.new('int[2]')
+    if ffi.C.socketpair(idomain, itype, iproto, sv) ~= 0 then
+        return nil
+    end
+    local s1 = make_socket(sv[0], itype)
+    local s2 = make_socket(sv[1], itype)
+    if not s1:nonblock(true) or not s2:nonblock(true) then
+        s1:close()
+        s2:close()
+        return nil
+    end
+    return s1, s2
+end
+
 local function socket_from_fd(fd)
     if type(fd) ~= 'number' then
         error('fd must be a number')
@@ -1622,6 +1650,7 @@ end
 
 return setmetatable({
     from_fd = socket_from_fd;
+    socketpair = socket_socketpair;
     getaddrinfo = getaddrinfo,
     tcp_connect = tcp_connect,
     tcp_server = tcp_server,
diff --git a/test/app-luatest/socket_test.lua b/test/app-luatest/socket_test.lua
index 92374ff8d4..db099b415b 100644
--- a/test/app-luatest/socket_test.lua
+++ b/test/app-luatest/socket_test.lua
@@ -1,3 +1,4 @@
+local errno = require('errno')
 local socket = require('socket')
 local t = require('luatest')
 
@@ -58,3 +59,26 @@ g.test_detach = function()
     t.assert_is_not(s2:name(), nil)
     t.assert(s2:close())
 end
+
+g.test_socketpair = function()
+    t.assert_is(socket.socketpair(), nil)
+    t.assert_equals(errno(), errno.EINVAL)
+    t.assert_is(socket.socketpair('foo'), nil)
+    t.assert_equals(errno(), errno.EINVAL)
+    t.assert_is(socket.socketpair('AF_UNIX', 'bar'), nil)
+    t.assert_equals(errno(), errno.EINVAL)
+    t.assert_is(socket.socketpair('AF_UNIX', 'SOCK_STREAM', 'baz'), nil)
+    t.assert_equals(errno(), errno.EPROTOTYPE)
+    t.assert_is(socket.socketpair('AF_INET', 'SOCK_STREAM', 0), nil)
+    t.assert_equals(errno(), errno.EOPNOTSUPP)
+
+    local s1, s2 = socket.socketpair('AF_UNIX', 'SOCK_STREAM', 0)
+    t.assert(s1)
+    t.assert(s2)
+    t.assert(s1:nonblock())
+    t.assert(s2:nonblock())
+    t.assert_equals(s1:send('foo'), 3)
+    t.assert_equals(s2:recv(), 'foo')
+    t.assert(s1:close())
+    t.assert(s2:close())
+end
-- 
GitLab