From c64b2895f29969d3d25746e9b376a177e8bcfdbf Mon Sep 17 00:00:00 2001
From: Kaitmazian Maksim <m.kaitmazian@picodata.io>
Date: Wed, 12 Jun 2024 17:57:40 +0300
Subject: [PATCH] feat: support postgres cast notation

---
 doc/sql/query.ebnf                      |  2 +-
 sbroad-core/src/executor/tests/cast.rs  | 72 ++++++++++++++++++
 sbroad-core/src/frontend/sql.rs         | 97 +++++++++++++++----------
 sbroad-core/src/frontend/sql/query.pest | 12 ++-
 4 files changed, 139 insertions(+), 44 deletions(-)

diff --git a/doc/sql/query.ebnf b/doc/sql/query.ebnf
index 773c502140..43f63ce708 100644
--- a/doc/sql/query.ebnf
+++ b/doc/sql/query.ebnf
@@ -47,7 +47,7 @@ aggregate   ::= ('AVG' | 'COUNT' | 'MAX' | 'MIN' | 'SUM' | 'TOTAL') '(' ('DISTIN
 case        ::= 'CASE' expression?
                 ('WHEN' expression 'THEN' expression)+
                 ('ELSE' expression)? 'END'
-cast        ::= 'CAST' '(' expression 'AS' type ')'
+cast        ::= ('CAST' '(' expression 'AS' type ')') | (expression '::' type)
 to_char     ::= 'TO_CHAR' '(' expression ',' format ')'
 to_date     ::= 'TO_DATE' '(' expression ',' format ')'
 trim        ::= 'TRIM' '('
diff --git a/sbroad-core/src/executor/tests/cast.rs b/sbroad-core/src/executor/tests/cast.rs
index 6fb87f2792..2b4e3928ce 100644
--- a/sbroad-core/src/executor/tests/cast.rs
+++ b/sbroad-core/src/executor/tests/cast.rs
@@ -126,3 +126,75 @@ fn cast14_test() {
         vec![],
     );
 }
+
+#[test]
+fn pgcast1_test() {
+    broadcast_check(
+        r#"SELECT true::bool FROM "t1""#,
+        r#"SELECT CAST (? as bool) as "COL_1" FROM "t1""#,
+        vec![Value::from(true)],
+    );
+}
+
+#[test]
+fn pgcast2_test() {
+    broadcast_check(
+        r#"SELECT false::bool FROM "t1""#,
+        r#"SELECT CAST (? as bool) as "COL_1" FROM "t1""#,
+        vec![Value::from(false)],
+    );
+}
+
+#[test]
+fn pgcast3_test() {
+    broadcast_check(
+        r#"SELECT '1.0'::decimal FROM "t1""#,
+        r#"SELECT CAST (? as decimal) as "COL_1" FROM "t1""#,
+        vec![Value::from("1.0")],
+    );
+}
+
+#[test]
+fn pgcast4_test() {
+    broadcast_check(
+        r#"SELECT '1.0'::double FROM "t1""#,
+        r#"SELECT CAST (? as double) as "COL_1" FROM "t1""#,
+        vec![Value::from("1.0")],
+    );
+}
+
+#[test]
+fn pgcast5_test() {
+    broadcast_check(
+        r#"SELECT '1'::int FROM "t1""#,
+        r#"SELECT CAST (? as int) as "COL_1" FROM "t1""#,
+        vec![Value::from("1")],
+    );
+}
+
+#[test]
+fn pgcast6_test() {
+    broadcast_check(
+        r#"SELECT '1'::integer FROM "t1""#,
+        r#"SELECT CAST (? as int) as "COL_1" FROM "t1""#,
+        vec![Value::from("1")],
+    );
+}
+
+#[test]
+fn pgcast7_test() {
+    broadcast_check(
+        r#"SELECT 1::string FROM "t1""#,
+        r#"SELECT CAST (? as string) as "COL_1" FROM "t1""#,
+        vec![Value::from(1_u64)],
+    );
+}
+
+#[test]
+fn pgcast8_test() {
+    broadcast_check(
+        r#"SELECT 1::text FROM "t1""#,
+        r#"SELECT CAST (? as text) as "COL_1" FROM "t1""#,
+        vec![Value::from(1_u64)],
+    );
+}
diff --git a/sbroad-core/src/frontend/sql.rs b/sbroad-core/src/frontend/sql.rs
index 43557c9d62..f111cbb86f 100644
--- a/sbroad-core/src/frontend/sql.rs
+++ b/sbroad-core/src/frontend/sql.rs
@@ -1201,6 +1201,55 @@ fn parse_option<M: Metadata>(
     Ok(value)
 }
 
+fn parse_cast_expr<M: Metadata>(
+    pair: Pair<Rule>,
+    referred_relation_ids: &[usize],
+    worker: &mut ExpressionsWorker<M>,
+    plan: &mut Plan,
+) -> Result<ParseExpression, SbroadError> {
+    let mut inner_pairs = pair.into_inner();
+    let expr_pair = inner_pairs.next().expect("Cast has no expr child.");
+    let child_parse_expr =
+        parse_expr_pratt(expr_pair.into_inner(), referred_relation_ids, worker, plan)?;
+
+    let mut cast_types = Vec::with_capacity(inner_pairs.len());
+    for type_pairs in inner_pairs {
+        let cast_type = if type_pairs.as_rule() == Rule::ColumnDefType {
+            let mut column_def_type_pairs = type_pairs.into_inner();
+            let column_def_type = column_def_type_pairs
+                .next()
+                .expect("concrete type expected under ColumnDefType");
+            if column_def_type.as_rule() == Rule::TypeVarchar {
+                let mut type_pairs_inner = column_def_type.into_inner();
+                let varchar_length = type_pairs_inner
+                    .next()
+                    .expect("Length is missing under Varchar");
+                let len = varchar_length.as_str().parse::<usize>().map_err(|e| {
+                    SbroadError::ParsingError(
+                        Entity::Value,
+                        format_smolstr!("failed to parse varchar length: {e:?}"),
+                    )
+                })?;
+                Ok(CastType::Varchar(len))
+            } else {
+                CastType::try_from(&column_def_type.as_rule())
+            }
+        } else {
+            // TypeAny.
+            CastType::try_from(&type_pairs.as_rule())
+        }?;
+
+        cast_types.push(cast_type);
+    }
+
+    assert!(!cast_types.is_empty(), "cast expression has no cast types");
+
+    Ok(ParseExpression::Cast {
+        cast_types,
+        child: Box::new(child_parse_expr),
+    })
+}
+
 enum ParameterSource<'parameter> {
     AstNode {
         ast: &'parameter AbstractSyntaxTree,
@@ -1470,7 +1519,7 @@ enum ParseExpression {
         child: Box<ParseExpression>,
     },
     Cast {
-        cast_type: CastType,
+        cast_types: Vec<CastType>,
         child: Box<ParseExpression>,
     },
     Case {
@@ -1586,9 +1635,12 @@ impl ParseExpression {
                     plan.add_covered_with_parentheses(child_plan_id)
                 }
             }
-            ParseExpression::Cast { cast_type, child } => {
-                let child_plan_id = child.populate_plan(plan, worker)?;
-                plan.add_cast(child_plan_id, cast_type.clone())?
+            ParseExpression::Cast { cast_types, child } => {
+                let mut child_plan_id = child.populate_plan(plan, worker)?;
+                for cast_type in cast_types {
+                    child_plan_id = plan.add_cast(child_plan_id, cast_type.clone())?;
+                }
+                child_plan_id
             }
             ParseExpression::Case {
                 search_expr,
@@ -2142,41 +2194,8 @@ where
                     ParseExpression::Exists { is_not: first_is_not, child: Box::new(child_parse_expr)}
                 }
                 Rule::Trim => parse_trim(primary, referred_relation_ids, worker, plan)?,
-                Rule::Cast => {
-                    let mut inner_pairs = primary.into_inner();
-                    let expr_pair = inner_pairs.next().expect("Cast has no expr child.");
-                    let child_parse_expr = parse_expr_pratt(
-                        expr_pair.into_inner(),
-                        referred_relation_ids,
-                        worker,
-                        plan
-                    )?;
-                    let type_pairs = inner_pairs.next().expect("Cast has no type child");
-                    let cast_type = if type_pairs.as_rule() == Rule::ColumnDefType {
-                        let mut column_def_type_pairs = type_pairs.into_inner();
-                        let column_def_type = column_def_type_pairs.next()
-                            .expect("concrete type expected under ColumnDefType");
-                        if column_def_type.as_rule() == Rule::TypeVarchar {
-                            let mut type_pairs_inner = column_def_type.into_inner();
-                            let varchar_length = type_pairs_inner.next().expect("Length is missing under Varchar");
-                            let len = varchar_length
-                                .as_str()
-                                .parse::<usize>()
-                                .map_err(|e| {
-                                    SbroadError::ParsingError(
-                                        Entity::Value,
-                                        format_smolstr!("failed to parse varchar length: {e:?}"),
-                                    )
-                                })?;
-                            Ok(CastType::Varchar(len))
-                        } else {
-                            CastType::try_from(&column_def_type.as_rule())
-                        }
-                    } else {
-                        // TypeAny.
-                        CastType::try_from(&type_pairs.as_rule())
-                    }?;
-                    ParseExpression::Cast { cast_type, child: Box::new(child_parse_expr) }
+                Rule::CastOp | Rule::CastExpr => {
+                    parse_cast_expr(primary, referred_relation_ids, worker, plan)?
                 }
                 Rule::Case => {
                     let mut inner_pairs = primary.into_inner();
diff --git a/sbroad-core/src/frontend/sql/query.pest b/sbroad-core/src/frontend/sql/query.pest
index a7e3a07ac2..65a9e16a81 100644
--- a/sbroad-core/src/frontend/sql/query.pest
+++ b/sbroad-core/src/frontend/sql/query.pest
@@ -226,7 +226,7 @@ Identifier = @{ DelimitedIdentifier | RegularIdentifier  }
                             &IdentifierInapplicableSymbol }
         RegularIdentifierFirstApplicableSymbol = { !(IdentifierInapplicableSymbol | ASCII_DIGIT) ~ ANY }
         RegularIdentifierApplicableSymbol = { !IdentifierInapplicableSymbol ~ ANY }
-        IdentifierInapplicableSymbol = { WHITESPACE | "." | "," | "(" | EOF | ")" | "\""
+        IdentifierInapplicableSymbol = { WHITESPACE | "." | "," | "(" | EOF | ")" | "\"" | ":"
                                        | "'" | ArithInfixOp | ConcatInfixOp | NotEq | GtEq
                                        | Gt | LtEq | Lt | Eq }
         KeywordCoverage = { Keyword ~ IdentifierInapplicableSymbol }
@@ -267,10 +267,10 @@ Expr = { ExprAtomValue ~ (ExprInfixOp ~ ExprAtomValue)* }
             LtEq  = { "<=" }
             NotEq = { "<>" | "!=" }
             In    = { NotFlag? ~ ^"in" }
-    ExprAtomValue = _{ UnaryNot* ~ AtomicExpr ~ IsNullPostfix? }
+    ExprAtomValue = _{ CastExpr | (UnaryNot* ~ AtomicExpr ~ IsNullPostfix?) }
         UnaryNot   = @{ NotFlag }
         IsNullPostfix = { ^"is" ~ NotFlag? ~ ^"null" }
-        AtomicExpr = _{ Literal | Parameter | Cast | Trim | CurrentDate | IdentifierWithOptionalContinuation | ExpressionInParentheses | UnaryOperator | Case | SubQuery | Row }
+        AtomicExpr = _{ Literal | Parameter | CastOp | Trim | CurrentDate | IdentifierWithOptionalContinuation | ExpressionInParentheses | UnaryOperator | Case | SubQuery | Row }
             Literal = { True | False | Null | Double | Decimal | Unsigned | Integer | SingleQuotedString }
                 True     = { ^"true" }
                 False    = { ^"false" }
@@ -310,7 +310,7 @@ Expr = { ExprAtomValue ~ (ExprInfixOp ~ ExprAtomValue)* }
             }
                 CaseWhenBlock = { ^"when" ~ Expr ~ ^"then" ~ Expr }
                 CaseElseBlock = { ^"else" ~ Expr }
-            Cast = { ^"cast" ~ "(" ~ Expr ~ ^"as" ~ TypeCast ~ ")" }
+            CastOp = { ^"cast" ~ "(" ~ Expr ~ ^"as" ~ TypeCast ~ ")" }
                 TypeCast = _{ TypeAny | ColumnDefType }
                 ColumnDefType = { TypeBool | TypeDatetime | TypeDecimal | TypeDouble | TypeInt | TypeNumber
                                    | TypeScalar | TypeString | TypeText | TypeUnsigned | TypeVarchar | TypeUuid }
@@ -330,6 +330,10 @@ Expr = { ExprAtomValue ~ (ExprInfixOp ~ ExprAtomValue)* }
             UnaryOperator = _{ Exists }
                 Exists = { NotFlag? ~ ^"exists" ~ SubQuery }
             Row = { "(" ~ Expr ~ ("," ~ Expr)* ~ ")" }
+        CastExpr = { AtomicExprWrapped ~ ("::" ~ ColumnDefType)+ }
+            // In the CastOp rule, AtomicExpr is stored as a child of Expr, so we imitate that
+            // nesting here in order to parse the CastExpr and CastOp rules in the same way.
+            AtomicExprWrapped = { AtomicExpr }
 
 Distinct = { ^"distinct" }
 NotFlag = { ^"not" }
-- 
GitLab