From f84401ec477380f98f7917ea959b04b495a9f856 Mon Sep 17 00:00:00 2001
From: EmirVildanov <reddog201030@gmail.com>
Date: Tue, 22 Oct 2024 23:18:14 +0300
Subject: [PATCH] feat: support IS [NOT] expression

---
 doc/sql/query.ebnf                       |  2 +-
 sbroad-core/src/frontend/sql.rs          | 46 +++++++++---
 sbroad-core/src/frontend/sql/ir/tests.rs | 94 ++++++++++++++++++++++++
 sbroad-core/src/frontend/sql/query.pest  |  5 +-
 4 files changed, 132 insertions(+), 15 deletions(-)

diff --git a/doc/sql/query.ebnf b/doc/sql/query.ebnf
index addedb7ac..c38100a64 100644
--- a/doc/sql/query.ebnf
+++ b/doc/sql/query.ebnf
@@ -38,7 +38,7 @@ expression  ::= ('NOT'* (
                     | 'NOT'? 'EXISTS' '(' dql ')'
                     | '(' dql ')'
                     | '(' expression (',' expression)* ')'
-                ) ('IS' 'NOT'? 'NULL')?)
+                ) ('IS' 'NOT'? ('NULL' | 'TRUE' | 'FALSE' | 'UNKNOWN'))*)
                 | expression
                 (
                     'NOT'? 'BETWEEN' expression 'AND'
diff --git a/sbroad-core/src/frontend/sql.rs b/sbroad-core/src/frontend/sql.rs
index 524486d5d..42972581d 100644
--- a/sbroad-core/src/frontend/sql.rs
+++ b/sbroad-core/src/frontend/sql.rs
@@ -1550,7 +1550,7 @@ fn parse_param<M: Metadata>(
 lazy_static::lazy_static! {
     static ref PRATT_PARSER: PrattParser<Rule> = {
         use pest::pratt_parser::{Assoc::{Left, Right}, Op};
-        use Rule::{Add, And, Between, ConcatInfixOp, Divide, Eq, Escape, Gt, GtEq, In, IsNullPostfix, CastPostfix, Like, Lt, LtEq, Multiply, NotEq, Or, Subtract, UnaryNot};
+        use Rule::{Add, And, Between, ConcatInfixOp, Divide, Eq, Escape, Gt, GtEq, In, IsPostfix, CastPostfix, Like, Lt, LtEq, Multiply, NotEq, Or, Subtract, UnaryNot};
 
         // Precedence is defined lowest to highest.
         PrattParser::new()
@@ -1568,7 +1568,7 @@ lazy_static::lazy_static! {
             )
             .op(Op::infix(Add, Left) | Op::infix(Subtract, Left))
             .op(Op::infix(Multiply, Left) | Op::infix(Divide, Left) | Op::infix(ConcatInfixOp, Left))
-            .op(Op::postfix(IsNullPostfix))
+            .op(Op::postfix(IsPostfix))
             .op(Op::postfix(CastPostfix))
     };
 }
@@ -1793,9 +1793,10 @@ enum ParseExpression {
         is_not: bool,
         child: Box<ParseExpression>,
     },
-    IsNull {
+    Is {
         is_not: bool,
         child: Box<ParseExpression>,
+        value: Option<bool>,
     },
     Cast {
         cast_type: CastType,
@@ -2270,10 +2271,20 @@ impl ParseExpression {
                     op_id
                 }
             }
-            ParseExpression::IsNull { is_not, child } => {
+            ParseExpression::Is {
+                is_not,
+                child,
+                value,
+            } => {
                 let child_plan_id = child.populate_plan(plan, worker)?;
                 let child_covered_with_row = plan.row(child_plan_id)?;
-                let op_id = plan.add_unary(Unary::IsNull, child_covered_with_row)?;
+                let op_id = match value {
+                    None => plan.add_unary(Unary::IsNull, child_covered_with_row)?,
+                    Some(b) => {
+                        let right_operand = plan.add_const(Value::Boolean(*b));
+                        plan.add_bool(child_covered_with_row, Bool::Eq, right_operand)?
+                    }
+                };
                 if *is_not {
                     plan.add_unary(Unary::Not, op_id)?
                 } else {
@@ -2826,14 +2837,25 @@ where
                     let cast_type = cast_type_from_pair(ty_pair)?;
                     Ok(ParseExpression::Cast { child: Box::new(child), cast_type })
                 }
-                Rule::IsNullPostfix => {
-                    let is_not = match op.into_inner().len() {
-                        1 => true,
-                        0 => false,
-                        _ => unreachable!("IsNull must have 0 or 1 children")
+                Rule::IsPostfix => {
+                    let mut inner = op.into_inner();
+                    let (is_not, value_index) = match inner.len() {
+                        2 => (true, 1),
+                        1 => (false, 0),
+                        _ => unreachable!("Is must have 1 or 2 children")
                     };
-                    Ok(ParseExpression::IsNull { is_not, child: Box::new(child)})
-                },
+                    let value_rule = inner
+                        .nth(value_index)
+                        .expect("Value must be present under Is")
+                        .as_rule();
+                    let value = match value_rule {
+                        Rule::True => Some(true),
+                        Rule::False => Some(false),
+                        Rule::Unknown | Rule::Null => None,
+                        _ => unreachable!("Is value must be TRUE, FALSE, NULL or UNKNOWN")
+                    };
+                    Ok(ParseExpression::Is { is_not, child: Box::new(child), value })
+                }
                 rule => unreachable!("Expr::parse expected postfix operator, found {:?}", rule),
             }
         })
diff --git a/sbroad-core/src/frontend/sql/ir/tests.rs b/sbroad-core/src/frontend/sql/ir/tests.rs
index 0daa3b1dc..0fe1f6c25 100644
--- a/sbroad-core/src/frontend/sql/ir/tests.rs
+++ b/sbroad-core/src/frontend/sql/ir/tests.rs
@@ -24,6 +24,11 @@ fn sql_to_optimized_ir_add_motions_err(query: &str) -> SbroadError {
     plan.add_motions().unwrap_err()
 }
 
+fn check_output(input: &str, params: Vec<Value>, expected_explain: &str) {
+    let plan = sql_to_optimized_ir(input, params);
+    assert_eq!(expected_explain, plan.as_explain().unwrap());
+}
+
 #[test]
 fn front_sql1() {
     let input = r#"SELECT "identification_number", "product_code" FROM "hash_testing"
@@ -409,6 +414,95 @@ execution options:
     assert_eq!(expected_explain, plan.as_explain().unwrap());
 }
 
+#[test]
+fn front_sql_is_true() {
+    check_output(
+        "select true is true",
+        vec![],
+        r#"projection (ROW(true::boolean) = true::boolean -> "col_1")
+execution options:
+    vdbe_max_steps = 45000
+    vtable_max_rows = 5000
+"#,
+    );
+
+    check_output(
+        "select true is not true",
+        vec![],
+        r#"projection (not ROW(true::boolean) = true::boolean -> "col_1")
+execution options:
+    vdbe_max_steps = 45000
+    vtable_max_rows = 5000
+"#,
+    );
+}
+
+#[test]
+fn front_sql_is_false() {
+    check_output(
+        "select true is false",
+        vec![],
+        r#"projection (ROW(true::boolean) = false::boolean -> "col_1")
+execution options:
+    vdbe_max_steps = 45000
+    vtable_max_rows = 5000
+"#,
+    );
+
+    check_output(
+        "select true is not false",
+        vec![],
+        r#"projection (not ROW(true::boolean) = false::boolean -> "col_1")
+execution options:
+    vdbe_max_steps = 45000
+    vtable_max_rows = 5000
+"#,
+    );
+}
+
+#[test]
+fn front_sql_is_null_unknown() {
+    check_output(
+        "select true is null",
+        vec![],
+        r#"projection (ROW(true::boolean) is null -> "col_1")
+execution options:
+    vdbe_max_steps = 45000
+    vtable_max_rows = 5000
+"#,
+    );
+
+    check_output(
+        "select true is unknown",
+        vec![],
+        r#"projection (ROW(true::boolean) is null -> "col_1")
+execution options:
+    vdbe_max_steps = 45000
+    vtable_max_rows = 5000
+"#,
+    );
+
+    check_output(
+        "select true is not null",
+        vec![],
+        r#"projection (not ROW(true::boolean) is null -> "col_1")
+execution options:
+    vdbe_max_steps = 45000
+    vtable_max_rows = 5000
+"#,
+    );
+
+    check_output(
+        "select true is not unknown",
+        vec![],
+        r#"projection (not ROW(true::boolean) is null -> "col_1")
+execution options:
+    vdbe_max_steps = 45000
+    vtable_max_rows = 5000
+"#,
+    );
+}
+
 #[test]
 fn front_sql_between_with_additional_non_bool_value_from_left() {
     let input = r#"SELECT * FROM "test_space" WHERE 42 and 1 between 2 and 3"#;
diff --git a/sbroad-core/src/frontend/sql/query.pest b/sbroad-core/src/frontend/sql/query.pest
index 78ce15ee3..43ee7474c 100644
--- a/sbroad-core/src/frontend/sql/query.pest
+++ b/sbroad-core/src/frontend/sql/query.pest
@@ -320,10 +320,11 @@ Expr = ${ ExprAtomValue ~ (ExprInfixOpo ~ ExprAtomValue)* }
             Lt    = { "<" }
             LtEq  = { "<=" }
             NotEq = { "<>" | "!=" }
-    ExprAtomValue = _{ (UnaryNot ~ W)* ~ AtomicExpr ~ CastPostfix* ~ (W ~ IsNullPostfix)? }
+    ExprAtomValue = _{ (UnaryNot ~ W)* ~ AtomicExpr ~ CastPostfix* ~ (W ~ IsPostfix)* }
         UnaryNot   = { NotFlag }
         CastPostfix = { "::" ~ ColumnDefType }
-        IsNullPostfix = ${ ^"is" ~ W ~ (NotFlag ~ W )? ~ ^"null" }
+        IsPostfix = ${ ^"is" ~ W ~ (NotFlag ~ W)? ~ (True | False | Unknown | Null) }
+            Unknown = { ^"unknown" }
         AtomicExpr = _{ Literal | Parameter | CastOp | Trim | CurrentDate | IdentifierWithOptionalContinuation | ExpressionInParentheses | UnaryOperator | Case | SubQuery | Row }
             Literal = { True | False | Null | Double | Decimal | Unsigned | Integer | SingleQuotedString }
                 True     = { ^"true" }
-- 
GitLab