diff --git a/src/lua/utils.c b/src/lua/utils.c index 870bc42869e20ac5c23f4c0c8a244a8b55232e97..2f579b6b643b219c1889732170b88900e4baf038 100644 --- a/src/lua/utils.c +++ b/src/lua/utils.c @@ -499,6 +499,22 @@ luaT_call(struct lua_State *L, int nargs, int nreturns) return 0; } +int +luaT_dostring(struct lua_State *L, const char *str) +{ + int top = lua_gettop(L); + if (luaL_loadstring(L, str) != 0) { + diag_set(LuajitError, lua_tostring(L, -1)); + lua_settop(L, top); + return -1; + } + if (luaT_call(L, 0, LUA_MULTRET) != 0) { + lua_settop(L, top); + return -1; + } + return 0; +} + int luaT_cpcall(lua_State *L, lua_CFunction func, void *ud) { diff --git a/src/lua/utils.h b/src/lua/utils.h index 3823ad79c7dfe9b46b98620a9f88184a76bae47f..ce90b977561a072cc5d0098f37b0604f4704a358 100644 --- a/src/lua/utils.h +++ b/src/lua/utils.h @@ -317,6 +317,13 @@ luaL_toint64(struct lua_State *L, int idx); LUA_API int luaT_call(lua_State *L, int nargs, int nreturns); +/* + * Like luaL_dostring(), but in case of error sets fiber diag instead + * of putting error on stack. + */ +int +luaT_dostring(struct lua_State *L, const char *str); + /** * Like lua_cpcall(), but with the proper support of Tarantool errors. * \sa lua_cpcall() diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 66c225ff3bf49afa862467b7a395e7949c5cf8af..e1aa9c99d2444778d3ce3acf1f35129316c434c6 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -568,3 +568,12 @@ create_unit_test(PREFIX getenv_safe SOURCES getenv_safe.c core_test_utils.c LIBRARIES unit core ) + +create_unit_test(PREFIX lua_utils + SOURCES lua_utils.c + LIBRARIES unit core server + ${LUAJIT_LIBRARIES} + ${CURL_LIBRARIES} + ${LIBYAML_LIBRARIES} + ${READLINE_LIBRARIES} +) diff --git a/test/unit/lua_utils.c b/test/unit/lua_utils.c new file mode 100644 index 0000000000000000000000000000000000000000..dcdf6f072456002bb36a3fd2ae4ea9d207bd55ea --- /dev/null +++ b/test/unit/lua_utils.c @@ -0,0 +1,140 @@ +#include <stdio.h> + +#include "diag.h" +#include "fiber.h" +#include "lualib.h" +#include "memory.h" +#include "reflection.h" + +#include "lua/utils.h" + +#define UNIT_TAP_COMPATIBLE 1 +#include "unit.h" + +static void +check_error(const char *type, const char *msg) +{ + struct error *err = diag_last_error(&fiber()->diag); + ok(strcmp(err->type->name, type) == 0, + "expected %s, got %s", type, err->type->name); + ok(strcmp(err->errmsg, msg) == 0, + "expected '%s', got '%s'", msg, err->errmsg); +} + +static void +test_toerror(lua_State *L) +{ + plan(4); + header(); + + /* test NON Tarantool error on stack */ + lua_pushstring(L, "test Lua error"); + luaT_toerror(L); + check_error("LuajitError", "test Lua error"); + /* + * luaT_toerror adds param on stack thru luaT_tolstring. + * Unfortunately the latter is public API. + */ + lua_pop(L, 2); + + /* test Tarantool error on stack */ + struct error *e = BuildIllegalParams(__FILE__, __LINE__, + "test non-Lua error"); + luaT_pusherror(L, e); + luaT_toerror(L); + check_error("IllegalParams", "test non-Lua error"); + lua_pop(L, 1); + + footer(); + check_plan(); +} + +static void +test_call(lua_State *L) +{ + plan(6); + header(); + + int v; + const char *expr; + + /* test no error on call */ + expr = "local a = {...} return a[1], a[2]"; + fail_unless(luaL_loadstring(L, expr) == 0); + lua_pushinteger(L, 3); + lua_pushinteger(L, 5); + ok(luaT_call(L, 2, 2) == 0, "call no error"); + fail_if(lua_gettop(L) != 2); + v = lua_tointeger(L, -2); + is(v, 3, "got %d", v); + v = lua_tointeger(L, -1); + is(v, 5, "got %d", v); + lua_pop(L, 2); + + /* test with error on call */ + expr = "return error('test error')"; + fail_unless(luaL_loadstring(L, expr) == 0); + ok(luaT_call(L, 0, 0) != 0, "call with error"); + check_error("LuajitError", "test error"); + /* See comment is test_toerror about stack size. */ + lua_pop(L, 2); + + footer(); + check_plan(); +} + +static void +test_dostring(lua_State *L) +{ + plan(11); + header(); + + int v; + /* test no error on call */ + ok(luaT_dostring(L, "return 3, 5") == 0, "call no error"); + fail_if(lua_gettop(L) != 2); + v = lua_tointeger(L, -2); + is(v, 3, "got %d", v); + v = lua_tointeger(L, -1); + is(v, 5, "got %d", v); + lua_pop(L, 2); + + /* test with error on call */ + const char *expr = "return error('test error')"; + ok(luaT_dostring(L, expr) != 0, "call with error"); + check_error("LuajitError", "test error"); + ok(lua_gettop(L) == 0, "got %d", lua_gettop(L)); + + /* test code loading error */ + ok(luaT_dostring(L, "*") != 0, "code loading error"); + check_error("LuajitError", + "[string \"*\"]:1: unexpected symbol near '*'"); + ok(lua_gettop(L) == 0, "got %d", lua_gettop(L)); + + footer(); + check_plan(); +} + +int +main(void) +{ + plan(3); + header(); + + struct lua_State *L = luaL_newstate(); + luaL_openlibs(L); + memory_init(); + fiber_init(fiber_c_invoke); + tarantool_lua_error_init(L); + + test_toerror(L); + test_call(L); + test_dostring(L); + + fiber_free(); + memory_free(); + lua_close(L); + + footer(); + return check_plan(); +}