Skip to content
Snippets Groups Projects
Commit b7c65039 authored by Rimma Tolkacheva's avatar Rimma Tolkacheva Committed by Igor Munkin
Browse files

test/fuzz: introduce class Context

The context object is created to manage the context of Lua program.
It will be used in the next commit to check if `break` or `return` is
inside a breakable or returnable code block.

NO_CHANGELOG=internal
NO_DOC=fuzzer fix
parent 754af7a9
No related branches found
No related tags found
No related merge requests found
......@@ -8,6 +8,8 @@
#include <stack>
#include <string>
#include <trivia/util.h>
using namespace lua_grammar;
#define PROTO_TOSTRING(TYPE, VAR_NAME) \
......@@ -184,16 +186,101 @@ GetCondition(const std::string &counter_name, const std::string &then_block)
return retval;
}
/**
* Class that registers and provides context during code
* generation.
* Used to generate correct Lua code.
*/
class Context {
public:
enum class BlockType {
kReturnable,
kBreakable,
};
void step_in(BlockType type)
{
block_stack_.push(type);
if (type == BlockType::kReturnable) {
++returnable_counter_;
}
}
void step_out()
{
assert(!block_stack_.empty());
if (block_stack_.top() == BlockType::kReturnable) {
assert(returnable_counter_ > 0);
--returnable_counter_;
}
block_stack_.pop();
}
std::string get_next_block_setup()
{
std::size_t id = GetCounterIdProvider().next();
std::string counter_name = GetCounterName(id);
return GetCondition(counter_name, get_exit_statement_()) +
GetCounterIncrement(counter_name);
}
bool break_is_possible()
{
return !block_stack_.empty() &&
block_stack_.top() == BlockType::kBreakable;
}
bool return_is_possible()
{
return returnable_counter_ > 0;
}
private:
std::string get_exit_statement_()
{
assert(!block_stack_.empty());
switch (block_stack_.top()) {
case BlockType::kBreakable:
return "break";
case BlockType::kReturnable:
return "return";
}
unreachable();
}
std::stack<BlockType> block_stack_;
/*
* The returnable block can be exited with return from
* the breakable block within it, but the breakable block
* cannot be exited with break from the returnable block within
* it.
* Valid code:
* `function foo() while true do return end end`
* Erroneous code:
* `while true do function foo() break end end`
* This counter is used to check if `return` is possible.
*/
uint64_t returnable_counter_ = 0;
};
Context&
GetContext()
{
static Context context;
return context;
}
/**
* Block may be placed not only in a cycle, so specially for cycles
* there is a function that will add a break condition and a
* counter increment.
*/
std::string
BlockToStringCycleProtected(const Block &block, const std::string &counter_name)
BlockToStringCycleProtected(const Block &block)
{
std::string retval = GetCondition(counter_name, "break");
retval += GetCounterIncrement(counter_name);
std::string retval = GetContext().get_next_block_setup();
retval += ChunkToString(block.chunk());
return retval;
}
......@@ -204,11 +291,10 @@ BlockToStringCycleProtected(const Block &block, const std::string &counter_name)
* BlockToStringCycleProtected().
*/
std::string
DoBlockToStringCycleProtected(const DoBlock &block,
const std::string &counter_name)
DoBlockToStringCycleProtected(const DoBlock &block)
{
std::string retval = "do\n";
retval += BlockToStringCycleProtected(block.block(), counter_name);
retval += BlockToStringCycleProtected(block.block());
retval += "end\n";
return retval;
}
......@@ -219,8 +305,7 @@ DoBlockToStringCycleProtected(const DoBlock &block,
* increment.
*/
std::string
FuncBodyToStringReqProtected(const FuncBody &body,
const std::string &counter_name)
FuncBodyToStringReqProtected(const FuncBody &body)
{
std::string body_str = "( ";
if (body.has_parlist()) {
......@@ -228,8 +313,7 @@ FuncBodyToStringReqProtected(const FuncBody &body,
}
body_str += " )\n\t";
body_str += GetCondition(counter_name, "return");
body_str += GetCounterIncrement(counter_name);
body_str += GetContext().get_next_block_setup();
body_str += BlockToString(body.block());
body_str += "end\n";
......@@ -473,15 +557,14 @@ PROTO_TOSTRING(DoBlock, block)
*/
PROTO_TOSTRING(WhileCycle, whilecycle)
{
const auto id = GetCounterIdProvider().next();
auto counter_name = GetCounterName(id);
GetContext().step_in(Context::BlockType::kBreakable);
std::string whilecycle_str = "while ";
whilecycle_str += ExpressionToString(whilecycle.condition());
whilecycle_str += " ";
whilecycle_str += DoBlockToStringCycleProtected(whilecycle.doblock(),
counter_name);
whilecycle_str += DoBlockToStringCycleProtected(whilecycle.doblock());
GetContext().step_out();
return whilecycle_str;
}
......@@ -490,15 +573,14 @@ PROTO_TOSTRING(WhileCycle, whilecycle)
*/
PROTO_TOSTRING(RepeatCycle, repeatcycle)
{
const auto id = GetCounterIdProvider().next();
auto counter_name = GetCounterName(id);
GetContext().step_in(Context::BlockType::kBreakable);
std::string repeatcycle_str = "repeat\n";
repeatcycle_str += BlockToStringCycleProtected(repeatcycle.block(),
counter_name);
repeatcycle_str += BlockToStringCycleProtected(repeatcycle.block());
repeatcycle_str += "until ";
repeatcycle_str += ExpressionToString(repeatcycle.condition());
GetContext().step_out();
return repeatcycle_str;
}
......@@ -538,8 +620,7 @@ NESTED_PROTO_TOSTRING(ElseIfBlock, elseifblock, IfStatement)
*/
PROTO_TOSTRING(ForCycleName, forcyclename)
{
const auto id = GetCounterIdProvider().next();
auto counter_name = GetCounterName(id);
GetContext().step_in(Context::BlockType::kBreakable);
std::string forcyclename_str = "for ";
forcyclename_str += NameToString(forcyclename.name());
......@@ -554,7 +635,9 @@ PROTO_TOSTRING(ForCycleName, forcyclename)
forcyclename_str += " ";
forcyclename_str += DoBlockToStringCycleProtected(
forcyclename.doblock(), counter_name);
forcyclename.doblock());
GetContext().step_out();
return forcyclename_str;
}
......@@ -563,8 +646,7 @@ PROTO_TOSTRING(ForCycleName, forcyclename)
*/
PROTO_TOSTRING(ForCycleList, forcyclelist)
{
const auto id = GetCounterIdProvider().next();
auto counter_name = GetCounterName(id);
GetContext().step_in(Context::BlockType::kBreakable);
std::string forcyclelist_str = "for ";
forcyclelist_str += NameListToString(forcyclelist.names());
......@@ -572,7 +654,9 @@ PROTO_TOSTRING(ForCycleList, forcyclelist)
forcyclelist_str += ExpressionListToString(forcyclelist.expressions());
forcyclelist_str += " ";
forcyclelist_str += DoBlockToStringCycleProtected(
forcyclelist.doblock(), counter_name);
forcyclelist.doblock());
GetContext().step_out();
return forcyclelist_str;
}
......@@ -581,13 +665,13 @@ PROTO_TOSTRING(ForCycleList, forcyclelist)
*/
PROTO_TOSTRING(Function, func)
{
const auto id = GetCounterIdProvider().next();
auto counter_name = GetCounterName(id);
GetContext().step_in(Context::BlockType::kReturnable);
std::string func_str = "function ";
func_str += FuncNameToString(func.name());
func_str += FuncBodyToStringReqProtected(func.body(), counter_name);
func_str += FuncBodyToStringReqProtected(func.body());
GetContext().step_out();
return func_str;
}
......@@ -639,15 +723,14 @@ NESTED_PROTO_TOSTRING(ParList, parlist, FuncBody)
*/
PROTO_TOSTRING(LocalFunc, localfunc)
{
const auto id = GetCounterIdProvider().next();
auto counter_name = GetCounterName(id);
GetContext().step_in(Context::BlockType::kReturnable);
std::string localfunc_str = "local function ";
localfunc_str += NameToString(localfunc.name());
localfunc_str += " ";
localfunc_str += FuncBodyToStringReqProtected(localfunc.funcbody(),
counter_name);
localfunc_str += FuncBodyToStringReqProtected(localfunc.funcbody());
GetContext().step_out();
return localfunc_str;
}
......@@ -789,12 +872,12 @@ PROTO_TOSTRING(Expression, expr)
NESTED_PROTO_TOSTRING(AnonFunc, func, Expression)
{
const auto id = GetCounterIdProvider().next();
auto counter_name = GetCounterName(id);
GetContext().step_in(Context::BlockType::kReturnable);
std::string retval = "function ";
retval += FuncBodyToStringReqProtected(func.body(), counter_name);
retval += FuncBodyToStringReqProtected(func.body());
GetContext().step_out();
return retval;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment