diff --git a/src/lua/decimal.c b/src/lua/decimal.c index de6586c8b86e5e92561da1ff6b1b2f9bbb550a49..7f9358787269264bf4cb436f550c09ea7827937e 100644 --- a/src/lua/decimal.c +++ b/src/lua/decimal.c @@ -69,6 +69,10 @@ ldecimal_##name(struct lua_State *L) { \ static int \ ldecimal_##name(struct lua_State *L) { \ assert(lua_gettop(L) == 2); \ + if (lua_isnil(L, 1) || lua_isnil(L, 2)) { \ + luaL_error(L, "attempt to compare decimal with nil"); \ + return 1; \ + } \ decimal_t *lhs = lua_todecimal(L, 1); \ decimal_t *rhs = lua_todecimal(L, 2); \ lua_pushboolean(L, decimal_compare(lhs, rhs) cmp 0); \ @@ -226,10 +230,23 @@ LDECIMAL_FUNC(exp, exp) LDECIMAL_FUNC(sqrt, sqrt) LDECIMAL_FUNC(abs, abs) -LDECIMAL_CMPOP(eq, ==) LDECIMAL_CMPOP(lt, <) LDECIMAL_CMPOP(le, <=) +static int +ldecimal_eq(struct lua_State *L) +{ + assert(lua_gettop(L) == 2); + if (lua_isnil(L, 1) || lua_isnil(L, 2)) { + lua_pushboolean(L, false); + return 1; + } + decimal_t *lhs = lua_todecimal(L, 1); + decimal_t *rhs = lua_todecimal(L, 2); + lua_pushboolean(L, decimal_compare(lhs, rhs) == 0); + return 1; +} + static int ldecimal_minus(struct lua_State *L) { diff --git a/test/app/decimal.result b/test/app/decimal.result index c632f57a7ebbfde89fc8888d320e8d7eb816a7f1..2e44928bb78df46ce0ebbac9bf88c9ccc4c4d461 100644 --- a/test/app/decimal.result +++ b/test/app/decimal.result @@ -223,6 +223,32 @@ b | - '0.1' | ... +-- check comparsion with nil +a == nil + | --- + | - false + | ... +a ~= nil + | --- + | - true + | ... +a > nil + | --- + | - error: '[string "return a > nil "]:1: attempt to compare decimal with nil' + | ... +a < nil + | --- + | - error: '[string "return a < nil "]:1: attempt to compare decimal with nil' + | ... +a >= nil + | --- + | - error: '[string "return a >= nil "]:1: attempt to compare decimal with nil' + | ... +a <= nil + | --- + | - error: '[string "return a <= nil "]:1: attempt to compare decimal with nil' + | ... + decimal.sqrt(a) | --- | - '3.1622776601683793319988935444327185337' diff --git a/test/app/decimal.test.lua b/test/app/decimal.test.lua index 40f1f29deb907b63c0501f6a96bad68fc0724a12..d83522b4592931f6eb9bd2294c318a86ebf2c65e 100644 --- a/test/app/decimal.test.lua +++ b/test/app/decimal.test.lua @@ -62,6 +62,14 @@ a ~= b a b +-- check comparsion with nil +a == nil +a ~= nil +a > nil +a < nil +a >= nil +a <= nil + decimal.sqrt(a) decimal.ln(a) decimal.log10(a)