From e1039ff2aa1a57c3f9c99d6db5cb43fcf2e06423 Mon Sep 17 00:00:00 2001 From: Kaitmazian Maksim <m.kaitmazian@picodata.io> Date: Thu, 19 Dec 2024 12:45:26 +0300 Subject: [PATCH] refactor(sbroad): Value::cast --- .../src/executor/engine/helpers.rs | 8 +- sbroad/sbroad-core/src/executor/result.rs | 2 +- sbroad/sbroad-core/src/executor/vtable.rs | 2 +- sbroad/sbroad-core/src/ir/value.rs | 112 ++++++++++-------- sbroad/sbroad-core/src/ir/value/tests.rs | 6 +- src/sql.rs | 2 +- 6 files changed, 75 insertions(+), 57 deletions(-) diff --git a/sbroad/sbroad-core/src/executor/engine/helpers.rs b/sbroad/sbroad-core/src/executor/engine/helpers.rs index d1d880d4fc..3321911e4a 100644 --- a/sbroad/sbroad-core/src/executor/engine/helpers.rs +++ b/sbroad/sbroad-core/src/executor/engine/helpers.rs @@ -649,7 +649,7 @@ pub fn build_insert_args<'t>( )), ) })?; - insert_tuple.push(value.cast(table_type)?); + insert_tuple.push(value.cast_and_encode(table_type)?); } TupleBuilderCommand::SetValue(value) => { insert_tuple.push(EncodedValue::Ref(MsgPackValue::from(value))); @@ -1693,7 +1693,7 @@ fn execute_sharded_update( )), ) })?; - insert_tuple.push(value.cast(table_type)?); + insert_tuple.push(value.cast_and_encode(table_type)?); } TupleBuilderCommand::CalculateBucketId(_) => { insert_tuple.push(EncodedValue::Ref(MsgPackValue::Unsigned(bucket_id))); @@ -1784,7 +1784,7 @@ pub fn build_update_args<'t>( )), ) })?; - key_tuple.push(value.cast(table_type)?); + key_tuple.push(value.cast_and_encode(table_type)?); } TupleBuilderCommand::UpdateColToCastedPos(table_col, pos, table_type) => { let value = vt_tuple.get(*pos).ok_or_else(|| { @@ -1798,7 +1798,7 @@ pub fn build_update_args<'t>( let op = [ EncodedValue::Ref(MsgPackValue::from(eq_op())), EncodedValue::Owned(LuaValue::Unsigned(*table_col as u64)), - value.cast(table_type)?, + value.cast_and_encode(table_type)?, ]; ops.push(op); } diff --git a/sbroad/sbroad-core/src/executor/result.rs b/sbroad/sbroad-core/src/executor/result.rs index fe274ce97c..132f08004c 100644 --- a/sbroad/sbroad-core/src/executor/result.rs +++ b/sbroad/sbroad-core/src/executor/result.rs @@ -167,7 +167,7 @@ impl ProducerResult { if value.get_type() == column.r#type { tuple.push(value); } else { - tuple.push(Value::from(value.cast(&column.r#type)?)); + tuple.push(value.cast(column.r#type)?); } } data.push(tuple); diff --git a/sbroad/sbroad-core/src/executor/vtable.rs b/sbroad/sbroad-core/src/executor/vtable.rs index 817afb0033..2152b2d28c 100644 --- a/sbroad/sbroad-core/src/executor/vtable.rs +++ b/sbroad/sbroad-core/src/executor/vtable.rs @@ -194,7 +194,7 @@ impl VirtualTable { for tuple in self.get_mut_tuples() { for (i, v) in tuple.iter_mut().enumerate() { let (_, ty) = fixed_types.get(i).expect("Type expected."); - let cast_value = v.cast(ty)?; + let cast_value = v.cast_and_encode(ty)?; match cast_value { EncodedValue::Ref(_) => { // Value type is already ok. diff --git a/sbroad/sbroad-core/src/ir/value.rs b/sbroad/sbroad-core/src/ir/value.rs index 000a7bba0c..7e90cb1680 100644 --- a/sbroad/sbroad-core/src/ir/value.rs +++ b/sbroad/sbroad-core/src/ir/value.rs @@ -780,111 +780,127 @@ impl Value { } } - /// Cast a value to a different type and wrap into encoded value. - /// If the target type is the same as the current type, the value - /// is returned by reference. Otherwise, the value is cloned. - /// - /// # Errors - /// - the value cannot be cast to the given type. + /// Cast a value to a different type. #[allow(clippy::too_many_lines)] - pub fn cast(&self, column_type: &Type) -> Result<EncodedValue, SbroadError> { + pub fn cast(self, column_type: Type) -> Result<Self, SbroadError> { let cast_error = SbroadError::Invalid( Entity::Value, Some(format_smolstr!("Failed to cast {self} to {column_type}.")), ); match column_type { - Type::Any => Ok(self.into()), + Type::Any => Ok(self), Type::Array | Type::Map => match self { - Value::Null => Ok(Value::Null.into()), + Value::Null => Ok(Value::Null), _ => Err(cast_error), }, Type::Boolean => match self { - Value::Boolean(_) => Ok(self.into()), - Value::Null => Ok(Value::Null.into()), + Value::Boolean(_) => Ok(self), + Value::Null => Ok(Value::Null), _ => Err(cast_error), }, Type::Datetime => match self { - Value::Null => Ok(Value::Null.into()), - Value::Datetime(_) => Ok(self.into()), + Value::Null => Ok(Value::Null), + Value::Datetime(_) => Ok(self), _ => Err(cast_error), }, Type::Decimal => match self { - Value::Decimal(_) => Ok(self.into()), + Value::Decimal(_) => Ok(self), Value::Double(v) => Ok(Value::Decimal( Decimal::from_str(&format!("{v}")).map_err(|_| cast_error)?, - ) - .into()), - Value::Integer(v) => Ok(Value::Decimal(Decimal::from(*v)).into()), - Value::Unsigned(v) => Ok(Value::Decimal(Decimal::from(*v)).into()), - Value::Null => Ok(Value::Null.into()), + )), + Value::Integer(v) => Ok(Value::Decimal(Decimal::from(v))), + Value::Unsigned(v) => Ok(Value::Decimal(Decimal::from(v))), + Value::Null => Ok(Value::Null), _ => Err(cast_error), }, Type::Double => match self { - Value::Double(_) => Ok(self.into()), - Value::Decimal(v) => Ok(Value::Double(Double::from_str(&format!("{v}"))?).into()), - Value::Integer(v) => Ok(Value::Double(Double::from(*v)).into()), - Value::Unsigned(v) => Ok(Value::Double(Double::from(*v)).into()), - Value::Null => Ok(Value::Null.into()), + Value::Double(_) => Ok(self), + Value::Decimal(v) => Ok(Value::Double(Double::from_str(&format!("{v}"))?)), + Value::Integer(v) => Ok(Value::Double(Double::from(v))), + Value::Unsigned(v) => Ok(Value::Double(Double::from(v))), + Value::Null => Ok(Value::Null), _ => Err(cast_error), }, Type::Integer => match self { - Value::Integer(_) => Ok(self.into()), - Value::Decimal(v) => Ok(Value::Integer(v.to_i64().ok_or(cast_error)?).into()), + Value::Integer(_) => Ok(self), + Value::Decimal(v) => Ok(Value::Integer(v.to_i64().ok_or(cast_error)?)), Value::Double(v) => v .to_string() .parse::<i64>() .map(Value::Integer) - .map(EncodedValue::from) .map_err(|_| cast_error), - Value::Unsigned(v) => { - Ok(Value::Integer(i64::try_from(*v).map_err(|_| cast_error)?).into()) - } - Value::Null => Ok(Value::Null.into()), + Value::Unsigned(v) => Ok(Value::Integer(i64::try_from(v).map_err(|_| cast_error)?)), + Value::Null => Ok(Value::Null), _ => Err(cast_error), }, Type::Scalar => match self { Value::Tuple(_) => Err(cast_error), - _ => Ok(self.into()), + _ => Ok(self), }, Type::String => match self { - Value::String(_) => Ok(self.into()), - Value::Null => Ok(Value::Null.into()), + Value::String(_) => Ok(self), + Value::Null => Ok(Value::Null), _ => Err(cast_error), }, Type::Uuid => match self { - Value::Uuid(_) => Ok(self.into()), - Value::String(v) => { - Ok(Value::Uuid(Uuid::parse_str(v).map_err(|_| cast_error)?).into()) - } - Value::Null => Ok(Value::Null.into()), + Value::Uuid(_) => Ok(self), + Value::String(v) => Ok(Value::Uuid(Uuid::parse_str(&v).map_err(|_| cast_error)?)), + Value::Null => Ok(Value::Null), _ => Err(cast_error), }, Type::Number => match self { Value::Integer(_) | Value::Decimal(_) | Value::Double(_) | Value::Unsigned(_) => { - Ok(self.into()) + Ok(self) } - Value::Null => Ok(Value::Null.into()), + Value::Null => Ok(Value::Null), _ => Err(cast_error), }, Type::Unsigned => match self { - Value::Unsigned(_) => Ok(self.into()), - Value::Integer(v) => { - Ok(Value::Unsigned(u64::try_from(*v).map_err(|_| cast_error)?).into()) - } - Value::Decimal(v) => Ok(Value::Unsigned(v.to_u64().ok_or(cast_error)?).into()), + Value::Unsigned(_) => Ok(self), + Value::Integer(v) => Ok(Value::Unsigned(u64::try_from(v).map_err(|_| cast_error)?)), + Value::Decimal(v) => Ok(Value::Unsigned(v.to_u64().ok_or(cast_error)?)), Value::Double(v) => v .to_string() .parse::<u64>() .map(Value::Unsigned) - .map(EncodedValue::from) .map_err(|_| cast_error), - Value::Null => Ok(Value::Null.into()), + Value::Null => Ok(Value::Null), _ => Err(cast_error), }, } } + /// Cast a value to a different type and wrap into encoded value. + /// If the target type is the same as the current type, the value + /// is returned by reference. Otherwise, the value is cloned. + /// + /// # Errors + /// - the value cannot be cast to the given type. + #[allow(clippy::too_many_lines)] + pub fn cast_and_encode(&self, column_type: &Type) -> Result<EncodedValue, SbroadError> { + // First, try variants returning EncodedValue::Ref to avoid cloning. + match (column_type, self) { + (Type::Any | Type::Scalar, value) => return Ok(value.into()), + (Type::Boolean, Value::Boolean(_)) => return Ok(self.into()), + (Type::Datetime, Value::Datetime(_)) => return Ok(self.into()), + (Type::Decimal, Value::Decimal(_)) => return Ok(self.into()), + (Type::Double, Value::Double(_)) => return Ok(self.into()), + (Type::Integer, Value::Integer(_)) => return Ok(self.into()), + (Type::String, Value::String(_)) => return Ok(self.into()), + (Type::Uuid, Value::Uuid(_)) => return Ok(self.into()), + (Type::Unsigned, Value::Unsigned(_)) => return Ok(self.into()), + ( + Type::Number, + Value::Integer(_) | Value::Decimal(_) | Value::Double(_) | Value::Unsigned(_), + ) => return Ok(self.into()), + _ => (), + } + + // Then, apply cast with clone. + self.clone().cast(*column_type).map(Into::into) + } + #[must_use] pub fn get_type(&self) -> Type { match self { diff --git a/sbroad/sbroad-core/src/ir/value/tests.rs b/sbroad/sbroad-core/src/ir/value/tests.rs index a2d019fbdb..4067eaf45c 100644 --- a/sbroad/sbroad-core/src/ir/value/tests.rs +++ b/sbroad/sbroad-core/src/ir/value/tests.rs @@ -34,7 +34,9 @@ fn uuid() { Some(TrivalentOrdering::Equal) ); assert_eq!( - Value::String(uid.to_string()).cast(&Type::Uuid).is_ok(), + Value::String(uid.to_string()) + .cast_and_encode(&Type::Uuid) + .is_ok(), true ); assert_eq!(v_uid.partial_cmp(&Value::String(t_uid_2.to_string())), None); @@ -44,7 +46,7 @@ fn uuid() { fn uuid_negative() { assert_eq!( Value::String("hello".to_string()) - .cast(&Type::Uuid) + .cast_and_encode(&Type::Uuid) .unwrap_err(), SbroadError::Invalid( Entity::Value, diff --git a/src/sql.rs b/src/sql.rs index 11ca2985af..7618f5e84d 100644 --- a/src/sql.rs +++ b/src/sql.rs @@ -1135,7 +1135,7 @@ fn alter_system_ir_node_to_op_or_result( else { return Err(Error::other(format!("unknown parameter: '{param_name}'"))); }; - let Ok(casted_value) = param_value.cast(&expected_type) else { + let Ok(casted_value) = param_value.cast_and_encode(&expected_type) else { let actual_type = value_type_str(param_value); return Err(Error::other(format!( "invalid value for '{param_name}' expected {expected_type}, got {actual_type}", -- GitLab