From 8c4a931f2c96c36af92cc48b8e68332fd9231671 Mon Sep 17 00:00:00 2001 From: Kaitmazian Maksim <m.kaitmazian@picodata.io> Date: Thu, 19 Dec 2024 12:45:26 +0300 Subject: [PATCH 1/3] refactor(sbroad): Value::cast --- .../src/executor/engine/helpers.rs | 8 +- sbroad/sbroad-core/src/executor/result.rs | 2 +- sbroad/sbroad-core/src/executor/vtable.rs | 2 +- sbroad/sbroad-core/src/ir/value.rs | 112 ++++++++++-------- sbroad/sbroad-core/src/ir/value/tests.rs | 6 +- src/sql.rs | 2 +- 6 files changed, 75 insertions(+), 57 deletions(-) diff --git a/sbroad/sbroad-core/src/executor/engine/helpers.rs b/sbroad/sbroad-core/src/executor/engine/helpers.rs index d1d880d4fc..3321911e4a 100644 --- a/sbroad/sbroad-core/src/executor/engine/helpers.rs +++ b/sbroad/sbroad-core/src/executor/engine/helpers.rs @@ -649,7 +649,7 @@ pub fn build_insert_args<'t>( )), ) })?; - insert_tuple.push(value.cast(table_type)?); + insert_tuple.push(value.cast_and_encode(table_type)?); } TupleBuilderCommand::SetValue(value) => { insert_tuple.push(EncodedValue::Ref(MsgPackValue::from(value))); @@ -1693,7 +1693,7 @@ fn execute_sharded_update( )), ) })?; - insert_tuple.push(value.cast(table_type)?); + insert_tuple.push(value.cast_and_encode(table_type)?); } TupleBuilderCommand::CalculateBucketId(_) => { insert_tuple.push(EncodedValue::Ref(MsgPackValue::Unsigned(bucket_id))); @@ -1784,7 +1784,7 @@ pub fn build_update_args<'t>( )), ) })?; - key_tuple.push(value.cast(table_type)?); + key_tuple.push(value.cast_and_encode(table_type)?); } TupleBuilderCommand::UpdateColToCastedPos(table_col, pos, table_type) => { let value = vt_tuple.get(*pos).ok_or_else(|| { @@ -1798,7 +1798,7 @@ pub fn build_update_args<'t>( let op = [ EncodedValue::Ref(MsgPackValue::from(eq_op())), EncodedValue::Owned(LuaValue::Unsigned(*table_col as u64)), - value.cast(table_type)?, + value.cast_and_encode(table_type)?, ]; ops.push(op); } diff --git a/sbroad/sbroad-core/src/executor/result.rs b/sbroad/sbroad-core/src/executor/result.rs index fe274ce97c..132f08004c 100644 --- a/sbroad/sbroad-core/src/executor/result.rs +++ b/sbroad/sbroad-core/src/executor/result.rs @@ -167,7 +167,7 @@ impl ProducerResult { if value.get_type() == column.r#type { tuple.push(value); } else { - tuple.push(Value::from(value.cast(&column.r#type)?)); + tuple.push(value.cast(column.r#type)?); } } data.push(tuple); diff --git a/sbroad/sbroad-core/src/executor/vtable.rs b/sbroad/sbroad-core/src/executor/vtable.rs index 817afb0033..2152b2d28c 100644 --- a/sbroad/sbroad-core/src/executor/vtable.rs +++ b/sbroad/sbroad-core/src/executor/vtable.rs @@ -194,7 +194,7 @@ impl VirtualTable { for tuple in self.get_mut_tuples() { for (i, v) in tuple.iter_mut().enumerate() { let (_, ty) = fixed_types.get(i).expect("Type expected."); - let cast_value = v.cast(ty)?; + let cast_value = v.cast_and_encode(ty)?; match cast_value { EncodedValue::Ref(_) => { // Value type is already ok. diff --git a/sbroad/sbroad-core/src/ir/value.rs b/sbroad/sbroad-core/src/ir/value.rs index 000a7bba0c..7e90cb1680 100644 --- a/sbroad/sbroad-core/src/ir/value.rs +++ b/sbroad/sbroad-core/src/ir/value.rs @@ -780,111 +780,127 @@ impl Value { } } - /// Cast a value to a different type and wrap into encoded value. - /// If the target type is the same as the current type, the value - /// is returned by reference. Otherwise, the value is cloned. - /// - /// # Errors - /// - the value cannot be cast to the given type. + /// Cast a value to a different type. #[allow(clippy::too_many_lines)] - pub fn cast(&self, column_type: &Type) -> Result<EncodedValue, SbroadError> { + pub fn cast(self, column_type: Type) -> Result<Self, SbroadError> { let cast_error = SbroadError::Invalid( Entity::Value, Some(format_smolstr!("Failed to cast {self} to {column_type}.")), ); match column_type { - Type::Any => Ok(self.into()), + Type::Any => Ok(self), Type::Array | Type::Map => match self { - Value::Null => Ok(Value::Null.into()), + Value::Null => Ok(Value::Null), _ => Err(cast_error), }, Type::Boolean => match self { - Value::Boolean(_) => Ok(self.into()), - Value::Null => Ok(Value::Null.into()), + Value::Boolean(_) => Ok(self), + Value::Null => Ok(Value::Null), _ => Err(cast_error), }, Type::Datetime => match self { - Value::Null => Ok(Value::Null.into()), - Value::Datetime(_) => Ok(self.into()), + Value::Null => Ok(Value::Null), + Value::Datetime(_) => Ok(self), _ => Err(cast_error), }, Type::Decimal => match self { - Value::Decimal(_) => Ok(self.into()), + Value::Decimal(_) => Ok(self), Value::Double(v) => Ok(Value::Decimal( Decimal::from_str(&format!("{v}")).map_err(|_| cast_error)?, - ) - .into()), - Value::Integer(v) => Ok(Value::Decimal(Decimal::from(*v)).into()), - Value::Unsigned(v) => Ok(Value::Decimal(Decimal::from(*v)).into()), - Value::Null => Ok(Value::Null.into()), + )), + Value::Integer(v) => Ok(Value::Decimal(Decimal::from(v))), + Value::Unsigned(v) => Ok(Value::Decimal(Decimal::from(v))), + Value::Null => Ok(Value::Null), _ => Err(cast_error), }, Type::Double => match self { - Value::Double(_) => Ok(self.into()), - Value::Decimal(v) => Ok(Value::Double(Double::from_str(&format!("{v}"))?).into()), - Value::Integer(v) => Ok(Value::Double(Double::from(*v)).into()), - Value::Unsigned(v) => Ok(Value::Double(Double::from(*v)).into()), - Value::Null => Ok(Value::Null.into()), + Value::Double(_) => Ok(self), + Value::Decimal(v) => Ok(Value::Double(Double::from_str(&format!("{v}"))?)), + Value::Integer(v) => Ok(Value::Double(Double::from(v))), + Value::Unsigned(v) => Ok(Value::Double(Double::from(v))), + Value::Null => Ok(Value::Null), _ => Err(cast_error), }, Type::Integer => match self { - Value::Integer(_) => Ok(self.into()), - Value::Decimal(v) => Ok(Value::Integer(v.to_i64().ok_or(cast_error)?).into()), + Value::Integer(_) => Ok(self), + Value::Decimal(v) => Ok(Value::Integer(v.to_i64().ok_or(cast_error)?)), Value::Double(v) => v .to_string() .parse::<i64>() .map(Value::Integer) - .map(EncodedValue::from) .map_err(|_| cast_error), - Value::Unsigned(v) => { - Ok(Value::Integer(i64::try_from(*v).map_err(|_| cast_error)?).into()) - } - Value::Null => Ok(Value::Null.into()), + Value::Unsigned(v) => Ok(Value::Integer(i64::try_from(v).map_err(|_| cast_error)?)), + Value::Null => Ok(Value::Null), _ => Err(cast_error), }, Type::Scalar => match self { Value::Tuple(_) => Err(cast_error), - _ => Ok(self.into()), + _ => Ok(self), }, Type::String => match self { - Value::String(_) => Ok(self.into()), - Value::Null => Ok(Value::Null.into()), + Value::String(_) => Ok(self), + Value::Null => Ok(Value::Null), _ => Err(cast_error), }, Type::Uuid => match self { - Value::Uuid(_) => Ok(self.into()), - Value::String(v) => { - Ok(Value::Uuid(Uuid::parse_str(v).map_err(|_| cast_error)?).into()) - } - Value::Null => Ok(Value::Null.into()), + Value::Uuid(_) => Ok(self), + Value::String(v) => Ok(Value::Uuid(Uuid::parse_str(&v).map_err(|_| cast_error)?)), + Value::Null => Ok(Value::Null), _ => Err(cast_error), }, Type::Number => match self { Value::Integer(_) | Value::Decimal(_) | Value::Double(_) | Value::Unsigned(_) => { - Ok(self.into()) + Ok(self) } - Value::Null => Ok(Value::Null.into()), + Value::Null => Ok(Value::Null), _ => Err(cast_error), }, Type::Unsigned => match self { - Value::Unsigned(_) => Ok(self.into()), - Value::Integer(v) => { - Ok(Value::Unsigned(u64::try_from(*v).map_err(|_| cast_error)?).into()) - } - Value::Decimal(v) => Ok(Value::Unsigned(v.to_u64().ok_or(cast_error)?).into()), + Value::Unsigned(_) => Ok(self), + Value::Integer(v) => Ok(Value::Unsigned(u64::try_from(v).map_err(|_| cast_error)?)), + Value::Decimal(v) => Ok(Value::Unsigned(v.to_u64().ok_or(cast_error)?)), Value::Double(v) => v .to_string() .parse::<u64>() .map(Value::Unsigned) - .map(EncodedValue::from) .map_err(|_| cast_error), - Value::Null => Ok(Value::Null.into()), + Value::Null => Ok(Value::Null), _ => Err(cast_error), }, } } + /// Cast a value to a different type and wrap into encoded value. + /// If the target type is the same as the current type, the value + /// is returned by reference. Otherwise, the value is cloned. + /// + /// # Errors + /// - the value cannot be cast to the given type. + #[allow(clippy::too_many_lines)] + pub fn cast_and_encode(&self, column_type: &Type) -> Result<EncodedValue, SbroadError> { + // First, try variants returning EncodedValue::Ref to avoid cloning. + match (column_type, self) { + (Type::Any | Type::Scalar, value) => return Ok(value.into()), + (Type::Boolean, Value::Boolean(_)) => return Ok(self.into()), + (Type::Datetime, Value::Datetime(_)) => return Ok(self.into()), + (Type::Decimal, Value::Decimal(_)) => return Ok(self.into()), + (Type::Double, Value::Double(_)) => return Ok(self.into()), + (Type::Integer, Value::Integer(_)) => return Ok(self.into()), + (Type::String, Value::String(_)) => return Ok(self.into()), + (Type::Uuid, Value::Uuid(_)) => return Ok(self.into()), + (Type::Unsigned, Value::Unsigned(_)) => return Ok(self.into()), + ( + Type::Number, + Value::Integer(_) | Value::Decimal(_) | Value::Double(_) | Value::Unsigned(_), + ) => return Ok(self.into()), + _ => (), + } + + // Then, apply cast with clone. + self.clone().cast(*column_type).map(Into::into) + } + #[must_use] pub fn get_type(&self) -> Type { match self { diff --git a/sbroad/sbroad-core/src/ir/value/tests.rs b/sbroad/sbroad-core/src/ir/value/tests.rs index a2d019fbdb..4067eaf45c 100644 --- a/sbroad/sbroad-core/src/ir/value/tests.rs +++ b/sbroad/sbroad-core/src/ir/value/tests.rs @@ -34,7 +34,9 @@ fn uuid() { Some(TrivalentOrdering::Equal) ); assert_eq!( - Value::String(uid.to_string()).cast(&Type::Uuid).is_ok(), + Value::String(uid.to_string()) + .cast_and_encode(&Type::Uuid) + .is_ok(), true ); assert_eq!(v_uid.partial_cmp(&Value::String(t_uid_2.to_string())), None); @@ -44,7 +46,7 @@ fn uuid() { fn uuid_negative() { assert_eq!( Value::String("hello".to_string()) - .cast(&Type::Uuid) + .cast_and_encode(&Type::Uuid) .unwrap_err(), SbroadError::Invalid( Entity::Value, diff --git a/src/sql.rs b/src/sql.rs index 11ca2985af..7618f5e84d 100644 --- a/src/sql.rs +++ b/src/sql.rs @@ -1135,7 +1135,7 @@ fn alter_system_ir_node_to_op_or_result( else { return Err(Error::other(format!("unknown parameter: '{param_name}'"))); }; - let Ok(casted_value) = param_value.cast(&expected_type) else { + let Ok(casted_value) = param_value.cast_and_encode(&expected_type) else { let actual_type = value_type_str(param_value); return Err(Error::other(format!( "invalid value for '{param_name}' expected {expected_type}, got {actual_type}", -- GitLab From 818b3337bcdb319ede0c222579323b162c3b4b56 Mon Sep 17 00:00:00 2001 From: Kaitmazian Maksim <m.kaitmazian@picodata.io> Date: Thu, 19 Dec 2024 14:03:12 +0300 Subject: [PATCH 2/3] opt(sbroad): fold constants under casts Close https://git.picodata.io/core/picodata/-/issues/1228 --- sbroad/sbroad-core/src/executor.rs | 1 + .../sbroad-core/src/executor/tests/concat.rs | 2 +- sbroad/sbroad-core/src/ir/explain/tests.rs | 3 + .../src/ir/explain/tests/cast_constants.rs | 91 +++++++++++++++++ .../src/ir/explain/tests/concat.rs | 4 +- sbroad/sbroad-core/src/ir/transformation.rs | 1 + .../src/ir/transformation/cast_constants.rs | 97 +++++++++++++++++++ 7 files changed, 196 insertions(+), 3 deletions(-) create mode 100644 sbroad/sbroad-core/src/ir/explain/tests/cast_constants.rs create mode 100644 sbroad/sbroad-core/src/ir/transformation/cast_constants.rs diff --git a/sbroad/sbroad-core/src/executor.rs b/sbroad/sbroad-core/src/executor.rs index 7d219ec286..039a197479 100644 --- a/sbroad/sbroad-core/src/executor.rs +++ b/sbroad/sbroad-core/src/executor.rs @@ -57,6 +57,7 @@ impl Plan { /// # Errors /// - Failed to optimize the plan. pub fn optimize(&mut self) -> Result<(), SbroadError> { + self.cast_constants()?; self.replace_in_operator()?; self.push_down_not()?; self.split_columns()?; diff --git a/sbroad/sbroad-core/src/executor/tests/concat.rs b/sbroad/sbroad-core/src/executor/tests/concat.rs index c0a6aa8221..09df269f7d 100644 --- a/sbroad/sbroad-core/src/executor/tests/concat.rs +++ b/sbroad/sbroad-core/src/executor/tests/concat.rs @@ -5,7 +5,7 @@ use crate::ir::value::Value; fn concat1_test() { broadcast_check( r#"SELECT CAST('1' as string) || 'hello' FROM "t1""#, - r#"SELECT (CAST (? as string)) || (?) as "col_1" FROM "t1""#, + r#"SELECT (?) || (?) as "col_1" FROM "t1""#, vec![Value::from("1"), Value::from("hello")], ); } diff --git a/sbroad/sbroad-core/src/ir/explain/tests.rs b/sbroad/sbroad-core/src/ir/explain/tests.rs index 0767a6c4b9..0c931a896e 100644 --- a/sbroad/sbroad-core/src/ir/explain/tests.rs +++ b/sbroad/sbroad-core/src/ir/explain/tests.rs @@ -585,3 +585,6 @@ mod delete; #[cfg(test)] mod query_explain; + +#[cfg(test)] +mod cast_constants; diff --git a/sbroad/sbroad-core/src/ir/explain/tests/cast_constants.rs b/sbroad/sbroad-core/src/ir/explain/tests/cast_constants.rs new file mode 100644 index 0000000000..924d1011b3 --- /dev/null +++ b/sbroad/sbroad-core/src/ir/explain/tests/cast_constants.rs @@ -0,0 +1,91 @@ +use crate::executor::{engine::mock::RouterRuntimeMock, Query}; +use pretty_assertions::assert_eq; + +fn assert_expain_matches(sql: &str, expected_explain: &str) { + let metadata = &RouterRuntimeMock::new(); + let mut query = Query::new(metadata, sql, vec![]).unwrap(); + let actual_explain = query.to_explain().unwrap(); + assert_eq!(actual_explain.as_str(), expected_explain); +} + +#[test] +fn select_values_rows() { + assert_expain_matches( + "SELECT * FROM (VALUES (1::int, 2::decimal::unsigned, 'txt'::text::text::text))", + r#"projection ("COLUMN_1"::integer -> "COLUMN_1", "COLUMN_2"::unsigned -> "COLUMN_2", "COLUMN_3"::string -> "COLUMN_3") + scan + values + value row (data=ROW(1::integer, 2::unsigned, 'txt'::string)) +execution options: + vdbe_max_steps = 45000 + vtable_max_rows = 5000 +buckets = any +"#, + ); +} + +#[test] +fn insert_values_rows() { + assert_expain_matches( + "INSERT INTO t1 VALUES ('txt'::text::text::text, 2::decimal::unsigned::double::integer)", + r#"insert "t1" on conflict: fail + motion [policy: segment([ref("COLUMN_1"), ref("COLUMN_2")])] + values + value row (data=ROW('txt'::string, 2::integer)) +execution options: + vdbe_max_steps = 45000 + vtable_max_rows = 5000 +buckets = unknown +"#, + ); +} + +#[test] +fn select_selection() { + assert_expain_matches( + "SELECT * FROM t3 WHERE a = 'kek'::text::text::text", + r#"projection ("t3"."a"::string -> "a", "t3"."b"::integer -> "b") + selection ROW("t3"."a"::string) = ROW('kek'::string) + scan "t3" +execution options: + vdbe_max_steps = 45000 + vtable_max_rows = 5000 +buckets = [1610] +"#, + ); +} + +#[test] +fn update_selection() { + assert_expain_matches( + "UPDATE t SET c = 2 WHERE a = 1::int::int and b = 2::unsigned::decimal", + r#"update "t" +"c" = "col_0" + motion [policy: local] + projection (2::unsigned -> "col_0", "t"."b"::unsigned -> "col_1") + selection ROW("t"."a"::unsigned) = ROW(1::integer) and ROW("t"."b"::unsigned) = ROW(2::decimal) + scan "t" +execution options: + vdbe_max_steps = 45000 + vtable_max_rows = 5000 +buckets = [550] +"#, + ); +} + +#[test] +fn delete_selection() { + assert_expain_matches( + r#"DELETE FROM "t2" where "e" = 3::unsigned and "f" = 2::decimal"#, + r#"delete "t2" + motion [policy: local] + projection ("t2"."g"::unsigned -> "pk_col_0", "t2"."h"::unsigned -> "pk_col_1") + selection ROW("t2"."e"::unsigned) = ROW(3::unsigned) and ROW("t2"."f"::unsigned) = ROW(2::decimal) + scan "t2" +execution options: + vdbe_max_steps = 45000 + vtable_max_rows = 5000 +buckets = [9374] +"#, + ); +} diff --git a/sbroad/sbroad-core/src/ir/explain/tests/concat.rs b/sbroad/sbroad-core/src/ir/explain/tests/concat.rs index 1e65e6c1ef..cc3026cbbf 100644 --- a/sbroad/sbroad-core/src/ir/explain/tests/concat.rs +++ b/sbroad/sbroad-core/src/ir/explain/tests/concat.rs @@ -6,7 +6,7 @@ fn concat1_test() { r#"SELECT CAST('1' as string) || 'hello' FROM "t1""#, &format!( "{}\n{}\n{}\n{}\n{}\n", - r#"projection (ROW('1'::string::string) || ROW('hello'::string) -> "col_1")"#, + r#"projection (ROW('1'::string) || ROW('hello'::string) -> "col_1")"#, r#" scan "t1""#, r#"execution options:"#, r#" vdbe_max_steps = 45000"#, @@ -22,7 +22,7 @@ fn concat2_test() { &format!( "{}\n{}\n{}\n{}\n{}\n{}\n", r#"projection ("t1"."a"::string -> "a")"#, - r#" selection ROW(ROW(ROW('1'::string::string) || ROW("func"(('hello'::string))::integer)) || ROW('2'::string)) = ROW(42::unsigned)"#, + r#" selection ROW(ROW(ROW('1'::string) || ROW("func"(('hello'::string))::integer)) || ROW('2'::string)) = ROW(42::unsigned)"#, r#" scan "t1""#, r#"execution options:"#, r#" vdbe_max_steps = 45000"#, diff --git a/sbroad/sbroad-core/src/ir/transformation.rs b/sbroad/sbroad-core/src/ir/transformation.rs index 333c762bff..49adc95ca4 100644 --- a/sbroad/sbroad-core/src/ir/transformation.rs +++ b/sbroad/sbroad-core/src/ir/transformation.rs @@ -3,6 +3,7 @@ //! Contains rule-based transformations. pub mod bool_in; +pub mod cast_constants; pub mod dnf; pub mod equality_propagation; pub mod merge_tuples; diff --git a/sbroad/sbroad-core/src/ir/transformation/cast_constants.rs b/sbroad/sbroad-core/src/ir/transformation/cast_constants.rs new file mode 100644 index 0000000000..6e284214cb --- /dev/null +++ b/sbroad/sbroad-core/src/ir/transformation/cast_constants.rs @@ -0,0 +1,97 @@ +use crate::{ + errors::SbroadError, + ir::{ + node::{ + expression::{Expression, MutExpression}, + Cast, Constant, Node, NodeId, Row, + }, + relation::Type, + tree::traversal::{LevelNode, PostOrderWithFilter}, + value::Value, + Plan, + }, +}; + +fn apply_cast(plan: &Plan, child_id: NodeId, target_type: Type) -> Option<Value> { + match plan.get_expression_node(child_id).ok()? { + Expression::Constant(Constant { value }) => { + let value = value.clone(); + value.cast(target_type).ok() + } + Expression::Cast(Cast { + child: cast_child, + to: cast_type, + }) => { + let cast_type = cast_type.as_relation_type(); + let value = apply_cast(plan, *cast_child, cast_type); + // Note: We don't throw errors if casting fails. + // It's possible that some type and value combinations are missing, + // but in such cases, we simply skip this evaluation and continue with other casts. + // An optimization failure should not prevent the execution of the plan. + value.and_then(|x| x.cast(target_type).ok()) + } + _ => None, + } +} + +impl Plan { + /// Evaluates cast constant expressions and replaces them with actual values in the plan. + /// + /// This function focuses on simplifying the plan by eliminating unnecessary casts in selection + /// expressions, enabling bucket filtering and in value rows, enabling local materialization. + pub fn cast_constants(&mut self) -> Result<(), SbroadError> { + // For simplicity, we only evaluate constants wrapped in Row, + // e.g., Row(Cast(Constant), Cast(Cast(Constant))). + // This approach includes the target cases from the function comment + // (selection expressions and values rows). + let rows_filter = |node_id| { + matches!( + self.get_node(node_id), + Ok(Node::Expression(Expression::Row(_))) + ) + }; + let mut subtree = PostOrderWithFilter::with_capacity( + |node| self.subtree_iter(node, false), + self.nodes.len(), + Box::new(rows_filter), + ); + + let top_id = self.get_top()?; + subtree.populate_nodes(top_id); + let row_ids = subtree.take_nodes(); + drop(subtree); + + let mut new_list = Vec::new(); + for LevelNode(_, row_id) in row_ids { + // Clone row children list to overcome borrow checker. + new_list.clear(); + if let Expression::Row(Row { list, .. }) = self.get_expression_node(row_id)? { + new_list.clone_from(list); + } + + // Try to apply cast to constants, push new values in the plan and remember ids in a + // copy if row children list. + for row_child in new_list.iter_mut() { + if let Expression::Cast(Cast { + child: cast_child, + to, + }) = self.get_expression_node(*row_child)? + { + let to = to.as_relation_type(); + if let Some(value) = apply_cast(self, *cast_child, to) { + *row_child = self.add_const(value); + } + } + } + + // Change row children to the new ones with casts applied. + if let MutExpression::Row(Row { ref mut list, .. }) = + self.get_mut_expression_node(row_id)? + { + new_list.clone_into(list); + } + } + + Ok(()) + } +} -- GitLab From dded2b2b91aa70f0fadfbfc876e2104ffa60c1f5 Mon Sep 17 00:00:00 2001 From: Kaitmazian Maksim <m.kaitmazian@picodata.io> Date: Thu, 19 Dec 2024 16:49:16 +0300 Subject: [PATCH 3/3] refactor(sbroad): avoid redundant error allocation in Value::cast --- sbroad/sbroad-core/src/ir/value.rs | 65 ++++++++++++++++++------------ 1 file changed, 39 insertions(+), 26 deletions(-) diff --git a/sbroad/sbroad-core/src/ir/value.rs b/sbroad/sbroad-core/src/ir/value.rs index 7e90cb1680..18d85399fb 100644 --- a/sbroad/sbroad-core/src/ir/value.rs +++ b/sbroad/sbroad-core/src/ir/value.rs @@ -783,36 +783,39 @@ impl Value { /// Cast a value to a different type. #[allow(clippy::too_many_lines)] pub fn cast(self, column_type: Type) -> Result<Self, SbroadError> { - let cast_error = SbroadError::Invalid( - Entity::Value, - Some(format_smolstr!("Failed to cast {self} to {column_type}.")), - ); + fn cast_error(value: &Value, column_type: Type) -> SbroadError { + SbroadError::Invalid( + Entity::Value, + Some(format_smolstr!("Failed to cast {value} to {column_type}.")), + ) + } match column_type { Type::Any => Ok(self), Type::Array | Type::Map => match self { Value::Null => Ok(Value::Null), - _ => Err(cast_error), + _ => Err(cast_error(&self, column_type)), }, Type::Boolean => match self { Value::Boolean(_) => Ok(self), Value::Null => Ok(Value::Null), - _ => Err(cast_error), + _ => Err(cast_error(&self, column_type)), }, Type::Datetime => match self { Value::Null => Ok(Value::Null), Value::Datetime(_) => Ok(self), - _ => Err(cast_error), + _ => Err(cast_error(&self, column_type)), }, Type::Decimal => match self { Value::Decimal(_) => Ok(self), - Value::Double(v) => Ok(Value::Decimal( - Decimal::from_str(&format!("{v}")).map_err(|_| cast_error)?, + Value::Double(ref v) => Ok(Value::Decimal( + Decimal::from_str(&format!("{v}")) + .map_err(|_| cast_error(&self, column_type))?, )), Value::Integer(v) => Ok(Value::Decimal(Decimal::from(v))), Value::Unsigned(v) => Ok(Value::Decimal(Decimal::from(v))), Value::Null => Ok(Value::Null), - _ => Err(cast_error), + _ => Err(cast_error(&self, column_type)), }, Type::Double => match self { Value::Double(_) => Ok(self), @@ -820,53 +823,63 @@ impl Value { Value::Integer(v) => Ok(Value::Double(Double::from(v))), Value::Unsigned(v) => Ok(Value::Double(Double::from(v))), Value::Null => Ok(Value::Null), - _ => Err(cast_error), + _ => Err(cast_error(&self, column_type)), }, Type::Integer => match self { Value::Integer(_) => Ok(self), - Value::Decimal(v) => Ok(Value::Integer(v.to_i64().ok_or(cast_error)?)), - Value::Double(v) => v + Value::Decimal(v) => Ok(Value::Integer( + v.to_i64().ok_or_else(|| cast_error(&self, column_type))?, + )), + Value::Double(ref v) => v .to_string() .parse::<i64>() .map(Value::Integer) - .map_err(|_| cast_error), - Value::Unsigned(v) => Ok(Value::Integer(i64::try_from(v).map_err(|_| cast_error)?)), + .map_err(|_| cast_error(&self, column_type)), + Value::Unsigned(v) => Ok(Value::Integer( + i64::try_from(v).map_err(|_| cast_error(&self, column_type))?, + )), Value::Null => Ok(Value::Null), - _ => Err(cast_error), + _ => Err(cast_error(&self, column_type)), }, Type::Scalar => match self { - Value::Tuple(_) => Err(cast_error), + Value::Tuple(_) => Err(cast_error(&self, column_type)), _ => Ok(self), }, Type::String => match self { Value::String(_) => Ok(self), Value::Null => Ok(Value::Null), - _ => Err(cast_error), + _ => Err(cast_error(&self, column_type)), }, Type::Uuid => match self { Value::Uuid(_) => Ok(self), - Value::String(v) => Ok(Value::Uuid(Uuid::parse_str(&v).map_err(|_| cast_error)?)), + Value::String(ref v) => Ok(Value::Uuid( + Uuid::parse_str(v).map_err(|_| cast_error(&self, column_type))?, + )), Value::Null => Ok(Value::Null), - _ => Err(cast_error), + _ => Err(cast_error(&self, column_type)), }, Type::Number => match self { Value::Integer(_) | Value::Decimal(_) | Value::Double(_) | Value::Unsigned(_) => { Ok(self) } Value::Null => Ok(Value::Null), - _ => Err(cast_error), + _ => Err(cast_error(&self, column_type)), }, Type::Unsigned => match self { Value::Unsigned(_) => Ok(self), - Value::Integer(v) => Ok(Value::Unsigned(u64::try_from(v).map_err(|_| cast_error)?)), - Value::Decimal(v) => Ok(Value::Unsigned(v.to_u64().ok_or(cast_error)?)), - Value::Double(v) => v + Value::Integer(v) => Ok(Value::Unsigned( + u64::try_from(v).map_err(|_| cast_error(&self, column_type))?, + )), + Value::Decimal(v) => Ok(Value::Unsigned( + v.to_u64().ok_or_else(|| cast_error(&self, column_type))?, + )), + Value::Double(ref v) => v .to_string() .parse::<u64>() .map(Value::Unsigned) - .map_err(|_| cast_error), + .map_err(|_| cast_error(&self, column_type)), Value::Null => Ok(Value::Null), - _ => Err(cast_error), + _ => Err(cast_error(&self, column_type)), }, } } -- GitLab