From 09bb3e96be20fc53e8bfea48fe593d45ccb275b8 Mon Sep 17 00:00:00 2001
From: Alexander Turenko <alexander.turenko@tarantool.org>
Date: Mon, 4 Mar 2024 17:02:28 +0300
Subject: [PATCH] lyaml: fix alias serialization

Issue #8350 was introduced by commit b42302f52ada ("lua-yaml: enable
aliasing for objects returned by __serialize"), so that commit is
effectively reverted.

The idea is to call all object __serialize methods recursively before
finding references. The new serialization pass stores the mapping from
the original object to the serialized representation.

After this, the reference analysis pass and the encoding pass use this
mapping to replace original objects with the serialized representation.

As a result, the reference analysis has complete information about
the objects and no references are missed.

Closes #8350
Closes #8310
Closes #8321

NO_DOC=bugfix

Co-authored-by: Nikolay Shirokovskiy <nshirokovskiy@tarantool.org>
(cherry picked from commit 610f5fb7e7e3d2759135f32c00d6db1583f2a496)
---
 .../fix-yaml-alias-serialization.md           |   3 +
 src/lua/serializer.c                          | 152 +++++++++++++++++-
 src/lua/serializer.h                          |  40 ++++-
 src/lua/utils.c                               |  18 +++
 src/lua/utils.h                               |   7 +
 test/app-luatest/serializer_test.lua          |  64 ++++++++
 test/box/tuple.result                         |  12 +-
 third_party/lua-yaml/lyaml.cc                 |  34 ++--
 8 files changed, 311 insertions(+), 19 deletions(-)
 create mode 100644 changelogs/unreleased/fix-yaml-alias-serialization.md
 create mode 100644 test/app-luatest/serializer_test.lua

diff --git a/changelogs/unreleased/fix-yaml-alias-serialization.md b/changelogs/unreleased/fix-yaml-alias-serialization.md
new file mode 100644
index 0000000000..057322ca5a
--- /dev/null
+++ b/changelogs/unreleased/fix-yaml-alias-serialization.md
@@ -0,0 +1,3 @@
+## bugfix/lua
+
+* Fixed alias detection in the YAML encoder (gh-8350, gh-8310, gh-8321).
diff --git a/src/lua/serializer.c b/src/lua/serializer.c
index 7c350a3ad4..5661950260 100644
--- a/src/lua/serializer.c
+++ b/src/lua/serializer.c
@@ -329,7 +329,6 @@ lua_field_try_serialize(struct lua_State *L, struct luaL_serializer *cfg,
 		if (luaL_tofield(L, cfg, -1, field) != 0)
 			return -1;
 		lua_replace(L, idx);
-		field->serialized = true;
 		return 0;
 	}
 	if (!lua_isstring(L, -1)) {
@@ -441,7 +440,6 @@ luaL_tofield(struct lua_State *L, struct luaL_serializer *cfg, int index,
 	field->type = MP_NIL;
 	field->ext_type = MP_UNKNOWN_EXTENSION;
 	field->compact = false;
-	field->serialized = false;
 
 	if (index < 0)
 		index = lua_gettop(L) + index + 1;
@@ -642,6 +640,156 @@ luaL_convertfield(struct lua_State *L, struct luaL_serializer *cfg, int idx,
 
 /* }}} Fill luaL_field */
 
+/* {{{ luaT_reftable */
+
+/**
+ * Traversal context shared by all recursion levels when building
+ * the table that maps original objects to serialized ones.
+ */
+struct reftable_new_ctx {
+	/** Serialization options passed to luaL_checkfield(). */
+	struct luaL_serializer *cfg;
+	/** Absolute Lua stack index of the reference table. */
+	int reftable_index;
+	/** Absolute Lua stack index of the visited objects table. */
+	int visited_index;
+};
+
+/**
+ * Serialize the object on top of the Lua stack and, recursively,
+ * all its descendants, and record an {original -> serialized}
+ * pair in the reference table for every object that is replaced
+ * by another one during the serialization.
+ *
+ * The serialization is performed using luaL_checkfield().
+ * The function leaves the Lua stack size unchanged.
+ */
+static void
+luaT_reftable_new_impl(struct lua_State *L, struct reftable_new_ctx *ctx)
+{
+	struct luaL_field field;
+
+	/*
+	 * We're not interested in values that can't have
+	 * __serialize or __tostring metamethods.
+	 */
+	if (!lua_istable(L, -1) &&
+	    !luaL_iscdata(L, -1) &&
+	    !lua_isuserdata(L, -1))
+		return;
+
+	/*
+	 * Check if the object is already visited.
+	 *
+	 * This prevents infinite recursion on reference cycles.
+	 */
+	if (luaT_hasfield(L, -1, ctx->visited_index))
+		return;
+
+	/* Mark the object as visited. */
+	lua_pushvalue(L, -1);
+	lua_pushboolean(L, true);
+	lua_settable(L, ctx->visited_index);
+
+	/*
+	 * Check if the object is already saved in the reference
+	 * table.
+	 */
+	if (luaT_hasfield(L, -1, ctx->reftable_index))
+		return;
+
+	/*
+	 * Copy the original object and serialize it. The
+	 * luaL_checkfield() function replaces the value on the
+	 * Lua stack with the serialized one (or leaves it as is).
+	 */
+	lua_pushvalue(L, -1);
+	luaL_checkfield(L, ctx->cfg, -1, &field);
+
+	/*
+	 * Save {original object -> serialized object} in the
+	 * reference table.
+	 */
+	if (!lua_rawequal(L, -1, -2)) {
+		lua_pushvalue(L, -2); /* original object */
+		lua_pushvalue(L, -2); /* serialized object */
+		lua_settable(L, ctx->reftable_index);
+	}
+
+	/*
+	 * Check if the serialized object is already saved in the
+	 * reference table.
+	 */
+	if (luaT_hasfield(L, -1, ctx->reftable_index)) {
+		lua_pop(L, 1);
+		return;
+	}
+
+	/*
+	 * If the serialized object is a table, recurse into each
+	 * value and then each key (lua_next() needs the key kept).
+	 */
+	if (lua_istable(L, -1)) {
+		lua_pushnil(L);
+		while (lua_next(L, -2)) {
+			luaT_reftable_new_impl(L, ctx);
+			lua_pop(L, 1);
+			luaT_reftable_new_impl(L, ctx);
+		}
+	}
+
+	/* Pop the serialized value, leave the original one. */
+	lua_pop(L, 1);
+}
+
+int
+luaT_reftable_new(struct lua_State *L, struct luaL_serializer *cfg, int idx)
+{
+	/*
+	 * The tables pushed below would shift a relative index,
+	 * so make it absolute. Pseudo-indexes are kept as is.
+	 */
+	if (idx < 0 && idx > LUA_REGISTRYINDEX)
+		idx += lua_gettop(L) + 1;
+
+	/* Create a reference table and a visited objects table. */
+	struct reftable_new_ctx ctx;
+	ctx.cfg = cfg;
+	lua_newtable(L);
+	ctx.reftable_index = lua_gettop(L);
+	lua_newtable(L);
+	ctx.visited_index = lua_gettop(L);
+
+	/*
+	 * Copy the given object on top of the Lua stack and
+	 * traverse it recursively, filling the reference table
+	 * for every object that is changed by the serialization.
+	 */
+	lua_pushvalue(L, idx);
+	luaT_reftable_new_impl(L, &ctx);
+
+	/*
+	 * Pop the copy of the given object and the visited
+	 * objects table. Leave the reference table on the top.
+	 */
+	lua_pop(L, 2);
+	return 1;
+}
+
+void
+luaT_reftable_serialize(struct lua_State *L, int reftable_index)
+{
+	if (reftable_index < 0 && reftable_index > LUA_REGISTRYINDEX)
+		reftable_index += lua_gettop(L) + 1; /* push shifts it */
+	lua_pushvalue(L, -1);
+	lua_gettable(L, reftable_index);
+	if (!lua_isnil(L, -1))
+		lua_insert(L, -2); /* found: keep the serialized one */
+	lua_pop(L, 1);
+}
+
+/* }}} luaT_reftable */
+
 int
 tarantool_lua_serializer_init(struct lua_State *L)
 {
diff --git a/src/lua/serializer.h b/src/lua/serializer.h
index 7ba85a051e..6148ff907d 100644
--- a/src/lua/serializer.h
+++ b/src/lua/serializer.h
@@ -241,8 +241,6 @@ struct luaL_field {
 	/* subtypes of MP_EXT */
 	enum mp_extension_type ext_type;
 	bool compact;                /* a flag used by YAML serializer */
-	/** Set if __serialize metamethod was called for this field. */
-	bool serialized;
 };
 
 /**
@@ -382,6 +380,44 @@ luaL_checkfinite(struct lua_State *L, struct luaL_serializer *cfg,
 		luaL_error(L, "number must not be NaN or Inf");
 }
 
+/* {{{ luaT_reftable */
+
+/**
+ * Serialize the object at the given Lua stack index and all the
+ * descendant ones recursively and create a mapping from the
+ * original objects to the resulting ones.
+ *
+ * The mapping (a Lua table) is pushed on top of the Lua stack.
+ * The function returns the number of values pushed to the
+ * stack, which is always 1.
+ *
+ * The serialization is performed using luaL_checkfield() with the
+ * provided configuration.
+ *
+ * A table that indirectly references itself is a valid input for
+ * this function: it tracks visited objects internally to break
+ * the cycle.
+ *
+ * If an error is raised by a __serialize or __tostring
+ * metamethod, it is raised by this function (not caught).
+ */
+int
+luaT_reftable_new(struct lua_State *L, struct luaL_serializer *cfg, int idx);
+
+/**
+ * Look for an object from top of the Lua stack in the reference
+ * table and, if found, replace it with the saved serialized
+ * object.
+ *
+ * If the object is not found, do nothing.
+ *
+ * The function leaves the stack size unchanged.
+ */
+void
+luaT_reftable_serialize(struct lua_State *L, int reftable_index);
+
+/* }}} luaT_reftable */
+
 int
 tarantool_lua_serializer_init(struct lua_State *L);
 
diff --git a/src/lua/utils.c b/src/lua/utils.c
index cbc5534a7f..cccde3a912 100644
--- a/src/lua/utils.c
+++ b/src/lua/utils.c
@@ -744,6 +744,24 @@ luaL_checkconstchar(struct lua_State *L, int idx, const char **res,
 	return 0;
 }
 
+bool
+luaT_hasfield(struct lua_State *L, int obj_index, int table_index)
+{
+	/*
+	 * lua_pushvalue() changes the size of the Lua stack, so
+	 * calling lua_gettable() with a relative index would pick
+	 * up a wrong object. Pseudo-indexes are left untouched.
+	 */
+	if (table_index < 0 && table_index > LUA_REGISTRYINDEX)
+		table_index += lua_gettop(L) + 1;
+
+	lua_pushvalue(L, obj_index);
+	lua_gettable(L, table_index);
+	bool res = !lua_isnil(L, -1);
+	lua_pop(L, 1);
+	return res;
+}
+
 lua_State *
 luaT_state(void)
 {
diff --git a/src/lua/utils.h b/src/lua/utils.h
index 3e66c04017..c07ab10318 100644
--- a/src/lua/utils.h
+++ b/src/lua/utils.h
@@ -490,6 +490,13 @@ int
 luaL_checkconstchar(struct lua_State *L, int idx, const char **res,
 		    uint32_t *cdata_type_p);
 
+/**
+ * Whether the object at the given valid index is a key with a
+ * non-nil value in the table at the given valid index.
+ */
+bool
+luaT_hasfield(struct lua_State *L, int obj_index, int table_index);
+
 /* {{{ Helper functions to interact with a Lua iterator from C */
 
 /**
diff --git a/test/app-luatest/serializer_test.lua b/test/app-luatest/serializer_test.lua
new file mode 100644
index 0000000000..d5ab94d012
--- /dev/null
+++ b/test/app-luatest/serializer_test.lua
@@ -0,0 +1,64 @@
+local yaml = require('yaml')
+local fiber = require('fiber')
+local t = require('luatest')
+
+local g = t.group()
+
+local strip = function(str)
+    return (str:gsub('^%s*', ''):gsub('\n%s*', '\n')) -- drop gsub's count
+end
+
+local function serialize(o, s)
+    local codec = s or yaml
+    return codec.decode(codec.encode(o))
+end
+
+g.test_recursion = function()
+    local cyclic = {}
+    cyclic.x = cyclic
+    local decoded = serialize(cyclic)
+    t.assert(rawequal(decoded, decoded.x))
+end
+
+g.test_stress = function()
+    local lenient = yaml.new()
+    lenient.cfg({encode_use_tostring = true})
+
+    -- Must neither raise an error nor loop forever.
+    serialize(_G, lenient)
+end
+
+g.test_gh_8350_no_unnecessary_anchors = function()
+    local outer = {{}}
+    setmetatable(outer, {__serialize = function() return {outer[1]} end})
+    local want = [[
+        ---
+        - []
+        ...
+    ]]
+    t.assert_equals(yaml.encode(outer), strip(want))
+end
+
+g.test_gh_8310_alias_across_serialize_method = function()
+    local target = {}
+    local proxy_mt = {__serialize = function() return target end}
+    local pair = {target, setmetatable({}, proxy_mt)}
+    local want = [[
+        ---
+        - &0 []
+        - *0
+        ...
+    ]]
+    t.assert_equals(yaml.encode(pair), strip(want))
+end
+
+g.test_gh_8321_alias_between_same_udata_objects = function()
+    local decoded = serialize({fiber.self(), fiber.self()})
+    t.assert(rawequal(decoded[1], decoded[2]))
+end
+
+g.test_gh_8321_alias_between_same_cdata_objects = function()
+    local tup = box.tuple.new({})
+    local decoded = serialize({tup, tup})
+    t.assert(rawequal(decoded[1], decoded[2]))
+end
diff --git a/test/box/tuple.result b/test/box/tuple.result
index 0298b2c9f7..279b27e1c3 100644
--- a/test/box/tuple.result
+++ b/test/box/tuple.result
@@ -534,8 +534,8 @@ gen, init, state = t:pairs()
 gen, init, state
 ---
 - gen: <tuple iterator>
-  param: ['a', 'b', 'c']
-- ['a', 'b', 'c']
+  param: &0 ['a', 'b', 'c']
+- *0
 - null
 ...
 state, val = gen(init, state)
@@ -641,16 +641,16 @@ r
 t:pairs(nil)
 ---
 - gen: <tuple iterator>
-  param: ['a', 'b', 'c']
-- ['a', 'b', 'c']
+  param: &0 ['a', 'b', 'c']
+- *0
 - null
 ...
 t:pairs("fdsaf")
 ---
 - state: fdsaf
   gen: <tuple iterator>
-  param: ['a', 'b', 'c']
-- ['a', 'b', 'c']
+  param: &0 ['a', 'b', 'c']
+- *0
 - fdsaf
 ...
 --------------------------------------------------------------------------------
diff --git a/third_party/lua-yaml/lyaml.cc b/third_party/lua-yaml/lyaml.cc
index 4759bb7ec8..ed0fca40ae 100644
--- a/third_party/lua-yaml/lyaml.cc
+++ b/third_party/lua-yaml/lyaml.cc
@@ -99,6 +99,7 @@ struct lua_yaml_dumper {
 
    lua_State *outputL;
    luaL_Buffer yamlbuf;
+   int reftable_index;
 };
 
 /**
@@ -616,8 +617,6 @@ static int yaml_is_flow_mode(struct lua_yaml_dumper *dumper) {
    return 0;
 }
 
-static void find_references(struct lua_yaml_dumper *dumper);
-
 static int dump_node(struct lua_yaml_dumper *dumper)
 {
    size_t len = 0;
@@ -632,14 +631,13 @@ static int dump_node(struct lua_yaml_dumper *dumper)
    bool unused;
    (void) unused;
 
+   luaT_reftable_serialize(dumper->L, dumper->reftable_index);
    yaml_char_t *anchor = get_yaml_anchor(dumper);
    if (anchor && !*anchor)
       return 1;
 
    int top = lua_gettop(dumper->L);
    luaL_checkfield(dumper->L, dumper->cfg, top, &field);
-   if (field.serialized)
-      find_references(dumper);
    switch(field.type) {
    case MP_UINT:
       snprintf(buf, sizeof(buf) - 1, "%" PRIu64, field.ival);
@@ -773,9 +771,12 @@ static int append_output(void *arg, unsigned char *buf, size_t len) {
 }
 
 static void find_references(struct lua_yaml_dumper *dumper) {
-   int newval = -1, type = lua_type(dumper->L, -1);
-   if (type != LUA_TTABLE)
-      return;
+   int newval = -1;
+
+   lua_pushvalue(dumper->L, -1); /* push copy of table */
+   luaT_reftable_serialize(dumper->L, dumper->reftable_index);
+   if (lua_type(dumper->L, -1) != LUA_TTABLE)
+      goto done;
 
    lua_pushvalue(dumper->L, -1); /* push copy of table */
    lua_rawget(dumper->L, dumper->anchortable_index);
@@ -790,7 +791,7 @@ static void find_references(struct lua_yaml_dumper *dumper) {
       lua_rawset(dumper->L, dumper->anchortable_index);
    }
    if (newval)
-      return;
+      goto done;
 
    /* recursively process other table values */
    lua_pushnil(dumper->L);
@@ -799,6 +800,17 @@ static void find_references(struct lua_yaml_dumper *dumper) {
       lua_pop(dumper->L, 1);
       find_references(dumper); /* find references on key */
    }
+
+done:
+   /*
+    * Pop the serialized object, leave the original object on top
+    * of the Lua stack.
+    *
+    * NB: It is important for the cycle above: it assumes that
+    * table keys are not changed in the recursive call. Otherwise
+    * it would feed an incorrect key to lua_next().
+    */
+   lua_pop(dumper->L, 1);
 }
 
 int
@@ -845,12 +857,16 @@ lua_yaml_encode(lua_State *L, struct luaL_serializer *serializer,
    lua_newtable(L);
    dumper.anchortable_index = lua_gettop(L);
    dumper.anchor_number = 0;
+
+   luaT_reftable_new(L, dumper.cfg, 1);
+   dumper.reftable_index = lua_gettop(L);
+
    lua_pushvalue(L, 1); /* push copy of arg we're processing */
    find_references(&dumper);
    dump_document(&dumper);
    if (dumper.error)
       goto error;
-   lua_pop(L, 2); /* pop copied arg and anchor table */
+   lua_pop(L, 3); /* pop copied arg and anchor/ref tables */
 
    if (!yaml_stream_end_event_initialize(&ev) ||
        !yaml_emitter_emit(&dumper.emitter, &ev) ||
-- 
GitLab