From b40ce279d9d52cf426481ae34af5cf1fe1c7de5e Mon Sep 17 00:00:00 2001
From: Kaitmazian Maksim <m.kaitmazian@picodata.io>
Date: Thu, 19 Dec 2024 12:45:26 +0300
Subject: [PATCH] refactor(sbroad): Value::cast

---
 .../src/executor/engine/helpers.rs            |   8 +-
 sbroad/sbroad-core/src/executor/result.rs     |   2 +-
 sbroad/sbroad-core/src/executor/vtable.rs     |   2 +-
 sbroad/sbroad-core/src/ir/value.rs            | 112 ++++++++++--------
 sbroad/sbroad-core/src/ir/value/tests.rs      |   6 +-
 src/sql.rs                                    |   2 +-
 6 files changed, 75 insertions(+), 57 deletions(-)

diff --git a/sbroad/sbroad-core/src/executor/engine/helpers.rs b/sbroad/sbroad-core/src/executor/engine/helpers.rs
index d1d880d4fc..3321911e4a 100644
--- a/sbroad/sbroad-core/src/executor/engine/helpers.rs
+++ b/sbroad/sbroad-core/src/executor/engine/helpers.rs
@@ -649,7 +649,7 @@ pub fn build_insert_args<'t>(
                         )),
                     )
                 })?;
-                insert_tuple.push(value.cast(table_type)?);
+                insert_tuple.push(value.cast_and_encode(table_type)?);
             }
             TupleBuilderCommand::SetValue(value) => {
                 insert_tuple.push(EncodedValue::Ref(MsgPackValue::from(value)));
@@ -1693,7 +1693,7 @@ fn execute_sharded_update(
                                     )),
                                 )
                             })?;
-                            insert_tuple.push(value.cast(table_type)?);
+                            insert_tuple.push(value.cast_and_encode(table_type)?);
                         }
                         TupleBuilderCommand::CalculateBucketId(_) => {
                             insert_tuple.push(EncodedValue::Ref(MsgPackValue::Unsigned(bucket_id)));
@@ -1784,7 +1784,7 @@ pub fn build_update_args<'t>(
                         )),
                     )
                 })?;
-                key_tuple.push(value.cast(table_type)?);
+                key_tuple.push(value.cast_and_encode(table_type)?);
             }
             TupleBuilderCommand::UpdateColToCastedPos(table_col, pos, table_type) => {
                 let value = vt_tuple.get(*pos).ok_or_else(|| {
@@ -1798,7 +1798,7 @@ pub fn build_update_args<'t>(
                 let op = [
                     EncodedValue::Ref(MsgPackValue::from(eq_op())),
                     EncodedValue::Owned(LuaValue::Unsigned(*table_col as u64)),
-                    value.cast(table_type)?,
+                    value.cast_and_encode(table_type)?,
                 ];
                 ops.push(op);
             }
diff --git a/sbroad/sbroad-core/src/executor/result.rs b/sbroad/sbroad-core/src/executor/result.rs
index fe274ce97c..132f08004c 100644
--- a/sbroad/sbroad-core/src/executor/result.rs
+++ b/sbroad/sbroad-core/src/executor/result.rs
@@ -167,7 +167,7 @@ impl ProducerResult {
                 if value.get_type() == column.r#type {
                     tuple.push(value);
                 } else {
-                    tuple.push(Value::from(value.cast(&column.r#type)?));
+                    tuple.push(value.cast(column.r#type)?);
                 }
             }
             data.push(tuple);
diff --git a/sbroad/sbroad-core/src/executor/vtable.rs b/sbroad/sbroad-core/src/executor/vtable.rs
index 817afb0033..2152b2d28c 100644
--- a/sbroad/sbroad-core/src/executor/vtable.rs
+++ b/sbroad/sbroad-core/src/executor/vtable.rs
@@ -194,7 +194,7 @@ impl VirtualTable {
         for tuple in self.get_mut_tuples() {
             for (i, v) in tuple.iter_mut().enumerate() {
                 let (_, ty) = fixed_types.get(i).expect("Type expected.");
-                let cast_value = v.cast(ty)?;
+                let cast_value = v.cast_and_encode(ty)?;
                 match cast_value {
                     EncodedValue::Ref(_) => {
                         // Value type is already ok.
diff --git a/sbroad/sbroad-core/src/ir/value.rs b/sbroad/sbroad-core/src/ir/value.rs
index 000a7bba0c..7e90cb1680 100644
--- a/sbroad/sbroad-core/src/ir/value.rs
+++ b/sbroad/sbroad-core/src/ir/value.rs
@@ -780,111 +780,127 @@ impl Value {
         }
     }
 
-    /// Cast a value to a different type and wrap into encoded value.
-    /// If the target type is the same as the current type, the value
-    /// is returned by reference. Otherwise, the value is cloned.
-    ///
-    /// # Errors
-    /// - the value cannot be cast to the given type.
+    /// Cast a value to a different type.
     #[allow(clippy::too_many_lines)]
-    pub fn cast(&self, column_type: &Type) -> Result<EncodedValue, SbroadError> {
+    pub fn cast(self, column_type: Type) -> Result<Self, SbroadError> {
         let cast_error = SbroadError::Invalid(
             Entity::Value,
             Some(format_smolstr!("Failed to cast {self} to {column_type}.")),
         );
 
         match column_type {
-            Type::Any => Ok(self.into()),
+            Type::Any => Ok(self),
             Type::Array | Type::Map => match self {
-                Value::Null => Ok(Value::Null.into()),
+                Value::Null => Ok(Value::Null),
                 _ => Err(cast_error),
             },
             Type::Boolean => match self {
-                Value::Boolean(_) => Ok(self.into()),
-                Value::Null => Ok(Value::Null.into()),
+                Value::Boolean(_) => Ok(self),
+                Value::Null => Ok(Value::Null),
                 _ => Err(cast_error),
             },
             Type::Datetime => match self {
-                Value::Null => Ok(Value::Null.into()),
-                Value::Datetime(_) => Ok(self.into()),
+                Value::Null => Ok(Value::Null),
+                Value::Datetime(_) => Ok(self),
                 _ => Err(cast_error),
             },
             Type::Decimal => match self {
-                Value::Decimal(_) => Ok(self.into()),
+                Value::Decimal(_) => Ok(self),
                 Value::Double(v) => Ok(Value::Decimal(
                     Decimal::from_str(&format!("{v}")).map_err(|_| cast_error)?,
-                )
-                .into()),
-                Value::Integer(v) => Ok(Value::Decimal(Decimal::from(*v)).into()),
-                Value::Unsigned(v) => Ok(Value::Decimal(Decimal::from(*v)).into()),
-                Value::Null => Ok(Value::Null.into()),
+                )),
+                Value::Integer(v) => Ok(Value::Decimal(Decimal::from(v))),
+                Value::Unsigned(v) => Ok(Value::Decimal(Decimal::from(v))),
+                Value::Null => Ok(Value::Null),
                 _ => Err(cast_error),
             },
             Type::Double => match self {
-                Value::Double(_) => Ok(self.into()),
-                Value::Decimal(v) => Ok(Value::Double(Double::from_str(&format!("{v}"))?).into()),
-                Value::Integer(v) => Ok(Value::Double(Double::from(*v)).into()),
-                Value::Unsigned(v) => Ok(Value::Double(Double::from(*v)).into()),
-                Value::Null => Ok(Value::Null.into()),
+                Value::Double(_) => Ok(self),
+                Value::Decimal(v) => Ok(Value::Double(Double::from_str(&format!("{v}"))?)),
+                Value::Integer(v) => Ok(Value::Double(Double::from(v))),
+                Value::Unsigned(v) => Ok(Value::Double(Double::from(v))),
+                Value::Null => Ok(Value::Null),
                 _ => Err(cast_error),
             },
             Type::Integer => match self {
-                Value::Integer(_) => Ok(self.into()),
-                Value::Decimal(v) => Ok(Value::Integer(v.to_i64().ok_or(cast_error)?).into()),
+                Value::Integer(_) => Ok(self),
+                Value::Decimal(v) => Ok(Value::Integer(v.to_i64().ok_or(cast_error)?)),
                 Value::Double(v) => v
                     .to_string()
                     .parse::<i64>()
                     .map(Value::Integer)
-                    .map(EncodedValue::from)
                     .map_err(|_| cast_error),
-                Value::Unsigned(v) => {
-                    Ok(Value::Integer(i64::try_from(*v).map_err(|_| cast_error)?).into())
-                }
-                Value::Null => Ok(Value::Null.into()),
+                Value::Unsigned(v) => Ok(Value::Integer(i64::try_from(v).map_err(|_| cast_error)?)),
+                Value::Null => Ok(Value::Null),
                 _ => Err(cast_error),
             },
             Type::Scalar => match self {
                 Value::Tuple(_) => Err(cast_error),
-                _ => Ok(self.into()),
+                _ => Ok(self),
             },
             Type::String => match self {
-                Value::String(_) => Ok(self.into()),
-                Value::Null => Ok(Value::Null.into()),
+                Value::String(_) => Ok(self),
+                Value::Null => Ok(Value::Null),
                 _ => Err(cast_error),
             },
             Type::Uuid => match self {
-                Value::Uuid(_) => Ok(self.into()),
-                Value::String(v) => {
-                    Ok(Value::Uuid(Uuid::parse_str(v).map_err(|_| cast_error)?).into())
-                }
-                Value::Null => Ok(Value::Null.into()),
+                Value::Uuid(_) => Ok(self),
+                Value::String(v) => Ok(Value::Uuid(Uuid::parse_str(&v).map_err(|_| cast_error)?)),
+                Value::Null => Ok(Value::Null),
                 _ => Err(cast_error),
             },
             Type::Number => match self {
                 Value::Integer(_) | Value::Decimal(_) | Value::Double(_) | Value::Unsigned(_) => {
-                    Ok(self.into())
+                    Ok(self)
                 }
-                Value::Null => Ok(Value::Null.into()),
+                Value::Null => Ok(Value::Null),
                 _ => Err(cast_error),
             },
             Type::Unsigned => match self {
-                Value::Unsigned(_) => Ok(self.into()),
-                Value::Integer(v) => {
-                    Ok(Value::Unsigned(u64::try_from(*v).map_err(|_| cast_error)?).into())
-                }
-                Value::Decimal(v) => Ok(Value::Unsigned(v.to_u64().ok_or(cast_error)?).into()),
+                Value::Unsigned(_) => Ok(self),
+                Value::Integer(v) => Ok(Value::Unsigned(u64::try_from(v).map_err(|_| cast_error)?)),
+                Value::Decimal(v) => Ok(Value::Unsigned(v.to_u64().ok_or(cast_error)?)),
                 Value::Double(v) => v
                     .to_string()
                     .parse::<u64>()
                     .map(Value::Unsigned)
-                    .map(EncodedValue::from)
                     .map_err(|_| cast_error),
-                Value::Null => Ok(Value::Null.into()),
+                Value::Null => Ok(Value::Null),
                 _ => Err(cast_error),
             },
         }
     }
 
+    /// Cast a value to a different type and wrap into encoded value.
+    /// If the target type is the same as the current type, the value
+    /// is returned by reference. Otherwise, the value is cloned.
+    ///
+    /// # Errors
+    /// - the value cannot be cast to the given type.
+    #[allow(clippy::too_many_lines)]
+    pub fn cast_and_encode(&self, column_type: &Type) -> Result<EncodedValue, SbroadError> {
+        // First, try variants returning EncodedValue::Ref to avoid cloning.
+        match (column_type, self) {
+            (Type::Any | Type::Scalar, value) => return Ok(value.into()),
+            (Type::Boolean, Value::Boolean(_)) => return Ok(self.into()),
+            (Type::Datetime, Value::Datetime(_)) => return Ok(self.into()),
+            (Type::Decimal, Value::Decimal(_)) => return Ok(self.into()),
+            (Type::Double, Value::Double(_)) => return Ok(self.into()),
+            (Type::Integer, Value::Integer(_)) => return Ok(self.into()),
+            (Type::String, Value::String(_)) => return Ok(self.into()),
+            (Type::Uuid, Value::Uuid(_)) => return Ok(self.into()),
+            (Type::Unsigned, Value::Unsigned(_)) => return Ok(self.into()),
+            (
+                Type::Number,
+                Value::Integer(_) | Value::Decimal(_) | Value::Double(_) | Value::Unsigned(_),
+            ) => return Ok(self.into()),
+            _ => (),
+        }
+
+        // Then, apply cast with clone.
+        self.clone().cast(*column_type).map(Into::into)
+    }
+
     #[must_use]
     pub fn get_type(&self) -> Type {
         match self {
diff --git a/sbroad/sbroad-core/src/ir/value/tests.rs b/sbroad/sbroad-core/src/ir/value/tests.rs
index a2d019fbdb..4067eaf45c 100644
--- a/sbroad/sbroad-core/src/ir/value/tests.rs
+++ b/sbroad/sbroad-core/src/ir/value/tests.rs
@@ -34,7 +34,9 @@ fn uuid() {
         Some(TrivalentOrdering::Equal)
     );
     assert_eq!(
-        Value::String(uid.to_string()).cast(&Type::Uuid).is_ok(),
+        Value::String(uid.to_string())
+            .cast_and_encode(&Type::Uuid)
+            .is_ok(),
         true
     );
     assert_eq!(v_uid.partial_cmp(&Value::String(t_uid_2.to_string())), None);
@@ -44,7 +46,7 @@ fn uuid() {
 fn uuid_negative() {
     assert_eq!(
         Value::String("hello".to_string())
-            .cast(&Type::Uuid)
+            .cast_and_encode(&Type::Uuid)
             .unwrap_err(),
         SbroadError::Invalid(
             Entity::Value,
diff --git a/src/sql.rs b/src/sql.rs
index 11ca2985af..7618f5e84d 100644
--- a/src/sql.rs
+++ b/src/sql.rs
@@ -1135,7 +1135,7 @@ fn alter_system_ir_node_to_op_or_result(
             else {
                 return Err(Error::other(format!("unknown parameter: '{param_name}'")));
             };
-            let Ok(casted_value) = param_value.cast(&expected_type) else {
+            let Ok(casted_value) = param_value.cast_and_encode(&expected_type) else {
                 let actual_type = value_type_str(param_value);
                 return Err(Error::other(format!(
                     "invalid value for '{param_name}' expected {expected_type}, got {actual_type}",
-- 
GitLab