diff --git a/sbroad-core/src/ir/expression/tests.rs b/sbroad-core/src/ir/expression/tests.rs index 2ac4a2801163e4e66c4951d10c92245817a6d8ec..795c7a22b89d66a254563521789cd8a42ef9ac63 100644 --- a/sbroad-core/src/ir/expression/tests.rs +++ b/sbroad-core/src/ir/expression/tests.rs @@ -1,7 +1,8 @@ -use crate::ir::tests::column_integer_user_non_null; +use crate::ir::operator::Arithmetic; +use crate::ir::tests::{column_integer_user_non_null, sharding_column}; use pretty_assertions::assert_eq; -use crate::ir::relation::{SpaceEngine, Table}; +use crate::ir::relation::{Column, SpaceEngine, Table, Type}; use crate::ir::value::Value; use crate::ir::{Plan, SbroadError}; @@ -66,3 +67,83 @@ fn rel_nodes_from_reference_in_proj() { assert_eq!(1, rel_set.len()); assert_eq!(Some(&scan_id), rel_set.get(&scan_id)); } + +#[test] +fn derive_expr_type() { + fn column(name: String, ty: Type) -> Column { + Column { + name, + r#type: ty, + role: Default::default(), + is_nullable: false, + } + } + + let mut plan = Plan::default(); + let t = Table::new_sharded( + "t", + vec![ + column(String::from("a"), Type::Integer), + column(String::from("b"), Type::Integer), + column(String::from("c"), Type::Unsigned), + column(String::from("d"), Type::Decimal), + column(String::from("e"), Type::Decimal), + column(String::from("f"), Type::Double), + sharding_column(), + ], + &["a"], + &["a"], + SpaceEngine::Memtx, + ) + .unwrap(); + plan.add_rel(t); + let scan_id = plan.add_scan("t", None).unwrap(); + let a_id = plan.add_row_from_child(scan_id, &["a"]).unwrap(); + let b_id = plan.add_row_from_child(scan_id, &["b"]).unwrap(); + let c_id = plan.add_row_from_child(scan_id, &["c"]).unwrap(); + let d_id = plan.add_row_from_child(scan_id, &["d"]).unwrap(); + let e_id = plan.add_row_from_child(scan_id, &["e"]).unwrap(); + let f_id = plan.add_row_from_child(scan_id, &["f"]).unwrap(); + + // b/c + let arith_divide_id = plan + .add_arithmetic_to_plan(b_id, Arithmetic::Divide, c_id, false) + .unwrap(); + let expr = plan.get_expression_node(arith_divide_id).unwrap(); + assert_eq!(expr.calculate_type(&plan).unwrap(), Type::Integer); + + // d*e + let arith_multiply_id = plan + .add_arithmetic_to_plan(d_id, Arithmetic::Multiply, e_id, false) + .unwrap(); + let expr = plan.get_expression_node(arith_multiply_id).unwrap(); + assert_eq!(expr.calculate_type(&plan).unwrap(), Type::Decimal); + + // (b/c + d*e) + let arith_addition_id = plan + .add_arithmetic_to_plan(arith_divide_id, Arithmetic::Add, arith_multiply_id, true) + .unwrap(); + let expr = plan.get_expression_node(arith_addition_id).unwrap(); + assert_eq!(expr.calculate_type(&plan).unwrap(), Type::Decimal); + + // (b/c + d*e) * f + let arith_multiply_id2 = plan + .add_arithmetic_to_plan(arith_addition_id, Arithmetic::Multiply, f_id, false) + .unwrap(); + let expr = plan.get_expression_node(arith_multiply_id2).unwrap(); + assert_eq!(expr.calculate_type(&plan).unwrap(), Type::Double); + + // a + (b/c + d*e) * f + let arith_addition_id2 = plan + .add_arithmetic_to_plan(a_id, Arithmetic::Add, arith_multiply_id2, false) + .unwrap(); + let expr = plan.get_expression_node(arith_addition_id2).unwrap(); + assert_eq!(expr.calculate_type(&plan).unwrap(), Type::Double); + + // a + (b/c + d*e) * f - b + let arith_subract_id = plan + .add_arithmetic_to_plan(arith_addition_id2, Arithmetic::Subtract, b_id, false) + .unwrap(); + let expr = plan.get_expression_node(arith_subract_id).unwrap(); + assert_eq!(expr.calculate_type(&plan).unwrap(), Type::Double); +} diff --git a/sbroad-core/src/ir/expression/types.rs b/sbroad-core/src/ir/expression/types.rs index c850880af7120c262075b2dc77d65d1a0017c504..08e231eed248212b3a1de3cd33f8705fc3ca92f7 100644 --- a/sbroad-core/src/ir/expression/types.rs +++ b/sbroad-core/src/ir/expression/types.rs @@ -41,11 +41,11 @@ impl Expression { let left_type = plan.get_node_type(*left)?; let right_type = plan.get_node_type(*right)?; match (&left_type, &right_type) { - (Type::Double, Type::Unsigned | Type::Integer | Type::Decimal) + (Type::Double, Type::Double | Type::Unsigned | Type::Integer | Type::Decimal) | (Type::Unsigned | Type::Integer | Type::Decimal, Type::Double) => { Ok(Type::Double) } - (Type::Decimal, Type::Unsigned | Type::Integer) + (Type::Decimal, Type::Decimal | Type::Unsigned | Type::Integer) | (Type::Unsigned | Type::Integer, Type::Decimal) => Ok(Type::Decimal), (Type::Integer, Type::Unsigned | Type::Integer) | (Type::Unsigned, Type::Integer) => Ok(Type::Integer),