From 165967300ad861efd2f47b1841e3703eb14ea6f2 Mon Sep 17 00:00:00 2001
From: Denis Smirnov <sd@picodata.io>
Date: Thu, 20 Oct 2022 19:00:25 +0700
Subject: [PATCH] feat: implement SQL type cast operator

---
 doc/sql/feature_taxonomy.md               |   4 +-
 sbroad-core/src/backend/sql/ir.rs         |   2 +
 sbroad-core/src/backend/sql/tree.rs       |  24 ++++
 sbroad-core/src/executor/tests.rs         |   3 +
 sbroad-core/src/executor/tests/cast.rs    | 156 ++++++++++++++++++++++
 sbroad-core/src/frontend/sql.rs           |  62 ++++++++-
 sbroad-core/src/frontend/sql/ast.rs       |  26 ++++
 sbroad-core/src/frontend/sql/ast/tests.rs |   2 +-
 sbroad-core/src/frontend/sql/query.pest   |  19 ++-
 sbroad-core/src/ir/explain.rs             |   3 +
 sbroad-core/src/ir/expression.rs          |  11 ++
 sbroad-core/src/ir/expression/cast.rs     |  80 +++++++++++
 sbroad-core/src/ir/tree.rs                |   8 +-
 13 files changed, 393 insertions(+), 7 deletions(-)
 create mode 100644 sbroad-core/src/executor/tests/cast.rs
 create mode 100644 sbroad-core/src/ir/expression/cast.rs

diff --git a/doc/sql/feature_taxonomy.md b/doc/sql/feature_taxonomy.md
index a27747b9be..4861440cc6 100644
--- a/doc/sql/feature_taxonomy.md
+++ b/doc/sql/feature_taxonomy.md
@@ -441,8 +441,8 @@
 **no**
 
 ## F201. CAST function.
-1. Subclause 6.13, “cast specification”: For all supported data types: **no**
-1. Subclause 6.26, “value expression”: cast specification: **no**
+1. Subclause 6.13, “cast specification”: For all supported data types: **yes**
+1. Subclause 6.26, “value expression”: cast specification: **yes**
 
 ## F221. Explicit defaults.
 1. Subclause 6.5, “contextually typed value specification”: default specification: **no**
diff --git a/sbroad-core/src/backend/sql/ir.rs b/sbroad-core/src/backend/sql/ir.rs
index f4c6703a2f..8a010302ff 100644
--- a/sbroad-core/src/backend/sql/ir.rs
+++ b/sbroad-core/src/backend/sql/ir.rs
@@ -167,6 +167,7 @@ impl ExecutionPlan {
                         sql.push_str("as ");
                         sql.push_str(s);
                     }
+                    SyntaxData::Cast => sql.push_str("CAST"),
                     SyntaxData::CloseParenthesis => sql.push(')'),
                     SyntaxData::Comma => sql.push(','),
                     SyntaxData::Condition => sql.push_str("ON"),
@@ -202,6 +203,7 @@ impl ExecutionPlan {
                             Node::Expression(expr) => match expr {
                                 Expression::Alias { .. }
                                 | Expression::Bool { .. }
+                                | Expression::Cast { .. }
                                 | Expression::Row { .. }
                                 | Expression::Unary { .. } => {}
                                 Expression::Constant { value, .. } => {
diff --git a/sbroad-core/src/backend/sql/tree.rs b/sbroad-core/src/backend/sql/tree.rs
index 06bb45a10f..3dffae8763 100644
--- a/sbroad-core/src/backend/sql/tree.rs
+++ b/sbroad-core/src/backend/sql/tree.rs
@@ -20,6 +20,8 @@ use sbroad_proc::otm_child_span;
 pub enum SyntaxData {
     /// "as alias_name"
     Alias(String),
+    /// "cast"
+    Cast,
     /// ")"
     CloseParenthesis,
     /// ","
@@ -64,6 +66,14 @@ impl SyntaxNode {
         }
     }
 
+    fn new_cast() -> Self {
+        SyntaxNode {
+            data: SyntaxData::Cast,
+            left: None,
+            right: Vec::new(),
+        }
+    }
+
     fn new_close() -> Self {
         SyntaxNode {
             data: SyntaxData::CloseParenthesis,
@@ -626,6 +636,20 @@ impl<'p> SyntaxPlan<'p> {
                 }
             },
             Node::Expression(expr) => match expr {
+                Expression::Cast { child, to } => {
+                    let sn = SyntaxNode::new_pointer(
+                        id,
+                        Some(self.nodes.push_syntax_node(SyntaxNode::new_cast())),
+                        vec![
+                            self.nodes.push_syntax_node(SyntaxNode::new_open()),
+                            self.nodes.get_syntax_node_id(*child)?,
+                            self.nodes
+                                .push_syntax_node(SyntaxNode::new_alias(String::from(to))),
+                            self.nodes.push_syntax_node(SyntaxNode::new_close()),
+                        ],
+                    );
+                    Ok(self.nodes.push_syntax_node(sn))
+                }
                 Expression::Constant { .. } => {
                     let sn = SyntaxNode::new_parameter(id);
                     Ok(self.nodes.push_syntax_node(sn))
diff --git a/sbroad-core/src/executor/tests.rs b/sbroad-core/src/executor/tests.rs
index a6feabe7e0..141a32949f 100644
--- a/sbroad-core/src/executor/tests.rs
+++ b/sbroad-core/src/executor/tests.rs
@@ -1352,6 +1352,9 @@ mod between;
 #[cfg(test)]
 mod bucket_id;
 
+#[cfg(test)]
+mod cast;
+
 #[cfg(test)]
 mod empty_motion;
 
diff --git a/sbroad-core/src/executor/tests/cast.rs b/sbroad-core/src/executor/tests/cast.rs
new file mode 100644
index 0000000000..32c5e74785
--- /dev/null
+++ b/sbroad-core/src/executor/tests/cast.rs
@@ -0,0 +1,156 @@
+use pretty_assertions::assert_eq;
+
+use crate::backend::sql::ir::PatternWithParams;
+use crate::executor::engine::mock::RouterRuntimeMock;
+use crate::executor::result::ProducerResult;
+use crate::ir::value::Value;
+
+use super::*;
+
+#[test]
+fn cast1_test() {
+    cast_check(
+        r#"SELECT CAST('1' as any) FROM "t1""#,
+        r#"SELECT CAST (? as any) as "COLUMN_1" FROM "t1""#,
+        vec![Value::from("1")],
+    );
+}
+
+#[test]
+fn cast2_test() {
+    cast_check(
+        r#"SELECT CAST(true as bool) FROM "t1""#,
+        r#"SELECT CAST (? as bool) as "COLUMN_1" FROM "t1""#,
+        vec![Value::from(true)],
+    );
+}
+
+#[test]
+fn cast3_test() {
+    cast_check(
+        r#"SELECT CAST(false as boolean) FROM "t1""#,
+        r#"SELECT CAST (? as bool) as "COLUMN_1" FROM "t1""#,
+        vec![Value::from(false)],
+    );
+}
+
+#[test]
+fn cast4_test() {
+    cast_check(
+        r#"SELECT CAST('1.0' as decimal) FROM "t1""#,
+        r#"SELECT CAST (? as decimal) as "COLUMN_1" FROM "t1""#,
+        vec![Value::from("1.0")],
+    );
+}
+
+#[test]
+fn cast5_test() {
+    cast_check(
+        r#"SELECT CAST('1.0' as double) FROM "t1""#,
+        r#"SELECT CAST (? as double) as "COLUMN_1" FROM "t1""#,
+        vec![Value::from("1.0")],
+    );
+}
+
+#[test]
+fn cast6_test() {
+    cast_check(
+        r#"SELECT CAST('1' as int) FROM "t1""#,
+        r#"SELECT CAST (? as int) as "COLUMN_1" FROM "t1""#,
+        vec![Value::from("1")],
+    );
+}
+
+#[test]
+fn cast7_test() {
+    cast_check(
+        r#"SELECT CAST('1' as integer) FROM "t1""#,
+        r#"SELECT CAST (? as int) as "COLUMN_1" FROM "t1""#,
+        vec![Value::from("1")],
+    );
+}
+
+#[test]
+fn cast8_test() {
+    cast_check(
+        r#"SELECT CAST('1' as number) FROM "t1""#,
+        r#"SELECT CAST (? as number) as "COLUMN_1" FROM "t1""#,
+        vec![Value::from("1")],
+    );
+}
+
+#[test]
+fn cast9_test() {
+    cast_check(
+        r#"SELECT CAST('1' as scalar) FROM "t1""#,
+        r#"SELECT CAST (? as scalar) as "COLUMN_1" FROM "t1""#,
+        vec![Value::from("1")],
+    );
+}
+
+#[test]
+fn cast10_test() {
+    cast_check(
+        r#"SELECT CAST(1 as string) FROM "t1""#,
+        r#"SELECT CAST (? as string) as "COLUMN_1" FROM "t1""#,
+        vec![Value::from(1_u64)],
+    );
+}
+
+#[test]
+fn cast11_test() {
+    cast_check(
+        r#"SELECT CAST(1 as text) FROM "t1""#,
+        r#"SELECT CAST (? as text) as "COLUMN_1" FROM "t1""#,
+        vec![Value::from(1_u64)],
+    );
+}
+
+#[test]
+fn cast12_test() {
+    cast_check(
+        r#"SELECT CAST('1' as unsigned) FROM "t1""#,
+        r#"SELECT CAST (? as unsigned) as "COLUMN_1" FROM "t1""#,
+        vec![Value::from("1")],
+    );
+}
+
+#[test]
+fn cast13_test() {
+    cast_check(
+        r#"SELECT CAST(1 as varchar(10)) FROM "t1""#,
+        r#"SELECT CAST (? as varchar(10)) as "COLUMN_1" FROM "t1""#,
+        vec![Value::from(1_u64)],
+    );
+}
+
+#[test]
+fn cast14_test() {
+    cast_check(
+        r#"SELECT CAST(bucket_id("a") as varchar(100)) FROM "t1""#,
+        r#"SELECT CAST ("BUCKET_ID" ("t1"."a") as varchar(100)) as "COLUMN_1" FROM "t1""#,
+        vec![],
+    );
+}
+
+fn cast_check(sql: &str, pattern: &str, params: Vec<Value>) {
+    let coordinator = RouterRuntimeMock::new();
+
+    let mut query = Query::new(&coordinator, sql, vec![]).unwrap();
+    let result = *query
+        .dispatch()
+        .unwrap()
+        .downcast::<ProducerResult>()
+        .unwrap();
+
+    let mut expected = ProducerResult::new();
+
+    expected.rows.push(vec![
+        Value::String("Execute query on all buckets".to_string()),
+        Value::String(String::from(PatternWithParams::new(
+            pattern.to_string(),
+            params,
+        ))),
+    ]);
+    assert_eq!(expected, result);
+}
diff --git a/sbroad-core/src/frontend/sql.rs b/sbroad-core/src/frontend/sql.rs
index 1a55183f4d..98d15e76c8 100644
--- a/sbroad-core/src/frontend/sql.rs
+++ b/sbroad-core/src/frontend/sql.rs
@@ -15,6 +15,7 @@ use crate::frontend::sql::ast::{
 };
 use crate::frontend::sql::ir::Translation;
 use crate::frontend::Ast;
+use crate::ir::expression::cast::Type as CastType;
 use crate::ir::expression::Expression;
 use crate::ir::operator::{Bool, Relational, Unary};
 use crate::ir::value::Value;
@@ -530,6 +531,45 @@ impl Ast for AbstractSyntaxTree {
                     let and_id = plan.add_cond(greater_eq_id, Bool::And, less_eq_id)?;
                     map.add(*id, and_id);
                 }
+                Type::Cast => {
+                    let ast_child_id = node.children.first().ok_or_else(|| {
+                        QueryPlannerError::CustomError("Condition has no children.".into())
+                    })?;
+                    let plan_child_id = map.get(*ast_child_id)?;
+                    let ast_type_id = node.children.get(1).ok_or_else(|| {
+                        QueryPlannerError::CustomError(
+                            "Cast type node id is not found among cast children.".into(),
+                        )
+                    })?;
+                    let ast_type = self.nodes.get_node(*ast_type_id)?;
+                    let cast_type = if ast_type.rule == Type::TypeVarchar {
+                        // Get the length of the varchar.
+                        let ast_len_id = ast_type.children.first().ok_or_else(|| {
+                            QueryPlannerError::CustomError(
+                                "Cast type length node id is not found among cast children.".into(),
+                            )
+                        })?;
+                        let ast_len = self.nodes.get_node(*ast_len_id)?;
+                        let len = ast_len
+                            .value
+                            .as_ref()
+                            .ok_or_else(|| {
+                                QueryPlannerError::CustomError("Varchar length is empty".into())
+                            })?
+                            .parse::<usize>()
+                            .map_err(|e| {
+                                QueryPlannerError::CustomError(format!(
+                                    "Failed to parse varchar length: {:?}",
+                                    e
+                                ))
+                            })?;
+                        Ok(CastType::Varchar(len))
+                    } else {
+                        CastType::try_from(&ast_type.rule)
+                    }?;
+                    let cast_id = plan.add_cast(plan_child_id, cast_type)?;
+                    map.add(*id, cast_id);
+                }
                 Type::Condition => {
                     let ast_child_id = node.children.first().ok_or_else(|| {
                         QueryPlannerError::CustomError("Condition has no children.".into())
@@ -769,10 +809,22 @@ impl Ast for AbstractSyntaxTree {
                 Type::AliasName
                 | Type::ColumnName
                 | Type::FunctionName
+                | Type::Length
                 | Type::ScanName
                 | Type::Select
                 | Type::SubQueryName
-                | Type::TargetColumns => {}
+                | Type::TargetColumns
+                | Type::TypeAny
+                | Type::TypeBool
+                | Type::TypeDecimal
+                | Type::TypeDouble
+                | Type::TypeInt
+                | Type::TypeNumber
+                | Type::TypeScalar
+                | Type::TypeString
+                | Type::TypeText
+                | Type::TypeUnsigned
+                | Type::TypeVarchar => {}
                 rule => {
                     return Err(QueryPlannerError::CustomError(format!(
                         "Not implements type: {:?}",
@@ -870,6 +922,10 @@ impl Plan {
                         child: ref param_id,
                         ..
                     }
+                    | Expression::Cast {
+                        child: ref param_id,
+                        ..
+                    }
                     | Expression::Unary {
                         child: ref param_id,
                         ..
@@ -945,6 +1001,10 @@ impl Plan {
                         child: ref mut param_id,
                         ..
                     }
+                    | Expression::Cast {
+                        child: ref mut param_id,
+                        ..
+                    }
                     | Expression::Unary {
                         child: ref mut param_id,
                         ..
diff --git a/sbroad-core/src/frontend/sql/ast.rs b/sbroad-core/src/frontend/sql/ast.rs
index 83122b2e19..f71a269f2d 100644
--- a/sbroad-core/src/frontend/sql/ast.rs
+++ b/sbroad-core/src/frontend/sql/ast.rs
@@ -29,6 +29,7 @@ pub enum Type {
     And,
     Asterisk,
     Between,
+    Cast,
     Column,
     ColumnName,
     Condition,
@@ -48,6 +49,7 @@ pub enum Type {
     Integer,
     IsNull,
     IsNotNull,
+    Length,
     Lt,
     LtEq,
     Name,
@@ -71,6 +73,17 @@ pub enum Type {
     Table,
     TargetColumns,
     True,
+    TypeAny,
+    TypeBool,
+    TypeDecimal,
+    TypeDouble,
+    TypeInt,
+    TypeNumber,
+    TypeScalar,
+    TypeString,
+    TypeText,
+    TypeUnsigned,
+    TypeVarchar,
     UnionAll,
     Unsigned,
     Value,
@@ -87,6 +100,7 @@ impl Type {
             Rule::And => Ok(Type::And),
             Rule::Asterisk => Ok(Type::Asterisk),
             Rule::Between => Ok(Type::Between),
+            Rule::Cast => Ok(Type::Cast),
             Rule::Column => Ok(Type::Column),
             Rule::ColumnName => Ok(Type::ColumnName),
             Rule::Condition => Ok(Type::Condition),
@@ -106,6 +120,7 @@ impl Type {
             Rule::Insert => Ok(Type::Insert),
             Rule::IsNull => Ok(Type::IsNull),
             Rule::IsNotNull => Ok(Type::IsNotNull),
+            Rule::Length => Ok(Type::Length),
             Rule::Lt => Ok(Type::Lt),
             Rule::LtEq => Ok(Type::LtEq),
             Rule::Name => Ok(Type::Name),
@@ -128,6 +143,17 @@ impl Type {
             Rule::Table => Ok(Type::Table),
             Rule::TargetColumns => Ok(Type::TargetColumns),
             Rule::True => Ok(Type::True),
+            Rule::TypeAny => Ok(Type::TypeAny),
+            Rule::TypeBool => Ok(Type::TypeBool),
+            Rule::TypeDecimal => Ok(Type::TypeDecimal),
+            Rule::TypeDouble => Ok(Type::TypeDouble),
+            Rule::TypeInt => Ok(Type::TypeInt),
+            Rule::TypeNumber => Ok(Type::TypeNumber),
+            Rule::TypeScalar => Ok(Type::TypeScalar),
+            Rule::TypeString => Ok(Type::TypeString),
+            Rule::TypeText => Ok(Type::TypeText),
+            Rule::TypeUnsigned => Ok(Type::TypeUnsigned),
+            Rule::TypeVarchar => Ok(Type::TypeVarchar),
             Rule::UnionAll => Ok(Type::UnionAll),
             Rule::Unsigned => Ok(Type::Unsigned),
             Rule::Value => Ok(Type::Value),
diff --git a/sbroad-core/src/frontend/sql/ast/tests.rs b/sbroad-core/src/frontend/sql/ast/tests.rs
index d4b648cff3..ff14b409cd 100644
--- a/sbroad-core/src/frontend/sql/ast/tests.rs
+++ b/sbroad-core/src/frontend/sql/ast/tests.rs
@@ -152,7 +152,7 @@ fn invalid_query() {
 1 | select a frAm t
   |        ^---
   |
-  = expected Alias, Asterisk, Function, True, False, Null, Decimal, Double, Integer, Unsigned, Row, or Parameter"#,
+  = expected Alias, Asterisk, Function, Cast, True, False, Null, Decimal, Double, Integer, Unsigned, Row, or Parameter"#,
         ),
         format!("{}", ast),
     );
diff --git a/sbroad-core/src/frontend/sql/query.pest b/sbroad-core/src/frontend/sql/query.pest
index 44f41c8f6c..1d2bc59bca 100644
--- a/sbroad-core/src/frontend/sql/query.pest
+++ b/sbroad-core/src/frontend/sql/query.pest
@@ -77,6 +77,23 @@ Function = { FunctionName ~ ("(" ~ FunctionArgs ~ ")") }
     FunctionName = @{ Name }
     FunctionArgs = _{ (Expr ~ ("," ~ Expr)*)? }
 
+Cast = { ^"cast" ~ "(" ~ Expr ~ ^"as" ~ Type ~ ")" }
+    Type = _{ TypeAny | TypeBool | TypeDecimal | TypeDouble
+             | TypeInt | TypeNumber | TypeScalar | TypeString
+             | TypeText | TypeUnsigned | TypeVarchar }
+    TypeAny = { ^"any" }
+    TypeBool = { (^"boolean" | ^"bool") }
+    TypeDecimal = { ^"decimal" }
+    TypeDouble = { ^"double" }
+    TypeInt = { (^"integer" | ^"int") }
+    TypeNumber = { ^"number" }
+    TypeScalar = { ^"scalar" }
+    TypeString = { ^"string" }
+    TypeText = { ^"text" }
+    TypeUnsigned = { ^"unsigned" }
+    TypeVarchar = { ^"varchar" ~ "(" ~ Length ~ ")" }
+        Length = @{ Unsigned }
+
 
 NameString = @{ !(WHITESPACE* ~ Keyword ~ WHITESPACE) ~ ('А' .. 'Я' | 'а' .. 'я' | 'A' .. 'Z' | 'a'..'z' | "-" | "_" | ASCII_DIGIT)+ }
 String = @{ !(WHITESPACE* ~ Keyword ~ WHITESPACE) ~ (Character | ("'" ~ "'") | "\"")* }
@@ -92,7 +109,7 @@ Punctuation = _{
 }
 Other = _{ "\\" | "/" | "@" | "%" | "&" | "*" | "#" | WHITESPACE }
 
-Value = _{ Parameter | Row | True | False | Null | Decimal | Double | Unsigned | Integer | SingleQuotedString | Function | Reference }
+Value = _{ Parameter | Row | True | False | Null | Decimal | Double | Unsigned | Integer | SingleQuotedString | Cast | Function | Reference }
     True = @{ ^"true" }
     False = @{ ^"false" }
     Null = @{ ^"null" }
diff --git a/sbroad-core/src/ir/explain.rs b/sbroad-core/src/ir/explain.rs
index 63e1848d20..641b5b19d9 100644
--- a/sbroad-core/src/ir/explain.rs
+++ b/sbroad-core/src/ir/explain.rs
@@ -41,6 +41,7 @@ impl Col {
                     column.alias = Some(name.to_string());
                 }
                 Expression::Bool { .. }
+                | Expression::Cast { .. }
                 | Expression::StableFunction { .. }
                 | Expression::Row { .. }
                 | Expression::Constant { .. }
@@ -223,6 +224,7 @@ impl Row {
                     row.add_col(RowVal::Const(value.clone()));
                 }
                 Expression::Bool { .. }
+                | Expression::Cast { .. }
                 | Expression::StableFunction { .. }
                 | Expression::Row { .. }
                 | Expression::Alias { .. }
@@ -309,6 +311,7 @@ impl Selection {
             }
             Expression::Reference { .. }
             | Expression::StableFunction { .. }
+            | Expression::Cast { .. }
             | Expression::Constant { .. }
             | Expression::Alias { .. } => {
                 return Err(QueryPlannerError::CustomError(
diff --git a/sbroad-core/src/ir/expression.rs b/sbroad-core/src/ir/expression.rs
index 8c36d26cd9..22d31c798b 100644
--- a/sbroad-core/src/ir/expression.rs
+++ b/sbroad-core/src/ir/expression.rs
@@ -19,6 +19,8 @@ use super::distribution::Distribution;
 use super::value::Value;
 use super::{operator, Node, Nodes, Plan};
 
+pub mod cast;
+
 /// Tuple tree build blocks.
 ///
 /// A tuple describes a single portion of data moved among cluster nodes.
@@ -52,6 +54,15 @@ pub enum Expression {
         /// Right branch expression node index in the plan node arena.
         right: usize,
     },
+    /// Type cast expression.
+    ///
+    /// Example: `cast(a as text)`.
+    Cast {
+        /// Target expression that must be casted to another type.
+        child: usize,
+        /// Cast type.
+        to: cast::Type,
+    },
     /// Constant expressions.
     ///
     // Example: `42`.
diff --git a/sbroad-core/src/ir/expression/cast.rs b/sbroad-core/src/ir/expression/cast.rs
new file mode 100644
index 0000000000..7617fa7998
--- /dev/null
+++ b/sbroad-core/src/ir/expression/cast.rs
@@ -0,0 +1,80 @@
+use crate::errors::QueryPlannerError;
+use crate::frontend::sql::ast::Type as AstType;
+use crate::ir::expression::Expression;
+use crate::ir::{Node, Plan};
+use serde::{Deserialize, Serialize};
+
+#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
+pub enum Type {
+    Any,
+    Boolean,
+    Decimal,
+    Double,
+    Integer,
+    Number,
+    Scalar,
+    String,
+    Text,
+    Unsigned,
+    Varchar(usize),
+}
+
+impl TryFrom<&AstType> for Type {
+    type Error = QueryPlannerError;
+
+    /// Pay attention that we can't build `Type::Varchar(length)` from string
+    /// because it has an additional length parameter. It should be constructed
+    /// separately.
+    fn try_from(ast_type: &AstType) -> Result<Self, Self::Error> {
+        match ast_type {
+            AstType::TypeAny => Ok(Type::Any),
+            AstType::TypeBool => Ok(Type::Boolean),
+            AstType::TypeDecimal => Ok(Type::Decimal),
+            AstType::TypeDouble => Ok(Type::Double),
+            AstType::TypeInt => Ok(Type::Integer),
+            AstType::TypeNumber => Ok(Type::Number),
+            AstType::TypeScalar => Ok(Type::Scalar),
+            AstType::TypeString => Ok(Type::String),
+            AstType::TypeText => Ok(Type::Text),
+            AstType::TypeUnsigned => Ok(Type::Unsigned),
+            _ => Err(QueryPlannerError::CustomError(format!(
+                "Unsupported type: {:?}",
+                ast_type
+            ))),
+        }
+    }
+}
+
+impl From<&Type> for String {
+    fn from(t: &Type) -> Self {
+        match t {
+            Type::Any => "any".to_string(),
+            Type::Boolean => "bool".to_string(),
+            Type::Decimal => "decimal".to_string(),
+            Type::Double => "double".to_string(),
+            Type::Integer => "int".to_string(),
+            Type::Number => "number".to_string(),
+            Type::Scalar => "scalar".to_string(),
+            Type::String => "string".to_string(),
+            Type::Text => "text".to_string(),
+            Type::Unsigned => "unsigned".to_string(),
+            Type::Varchar(length) => format!("varchar({})", length),
+        }
+    }
+}
+
+impl Plan {
+    /// Adds a cast expression to the plan.
+    ///
+    /// # Errors
+    /// - Child node is not of the expression type.
+    pub fn add_cast(&mut self, expr_id: usize, to_type: Type) -> Result<usize, QueryPlannerError> {
+        self.get_expression_node(expr_id)?;
+        let cast_expr = Expression::Cast {
+            child: expr_id,
+            to: to_type,
+        };
+        let cast_id = self.nodes.push(Node::Expression(cast_expr));
+        Ok(cast_id)
+    }
+}
diff --git a/sbroad-core/src/ir/tree.rs b/sbroad-core/src/ir/tree.rs
index ac51f8f164..4a38194dbd 100644
--- a/sbroad-core/src/ir/tree.rs
+++ b/sbroad-core/src/ir/tree.rs
@@ -114,7 +114,9 @@ impl<'n> Iterator for ExpressionIterator<'n> {
     fn next(&mut self) -> Option<Self::Item> {
         match self.nodes.arena.get(*self.current) {
             Some(Node::Expression(
-                Expression::Alias { child, .. } | Expression::Unary { child, .. },
+                Expression::Alias { child, .. }
+                | Expression::Cast { child, .. }
+                | Expression::Unary { child, .. },
             )) => {
                 let child_step = *self.child.borrow();
                 if child_step == 0 {
@@ -284,7 +286,9 @@ impl<'p> Iterator for SubtreeIterator<'p> {
             return match child {
                 Node::Parameter => None,
                 Node::Expression(exp) => match exp {
-                    Expression::Alias { child, .. } | Expression::Unary { child, .. } => {
+                    Expression::Alias { child, .. }
+                    | Expression::Cast { child, .. }
+                    | Expression::Unary { child, .. } => {
                         let step = *self.child.borrow();
                         *self.child.borrow_mut() += 1;
                         if step == 0 {
-- 
GitLab