From 402947369c6b105012c1b78f0dbe5074c24f6898 Mon Sep 17 00:00:00 2001
From: Denis Smirnov <sd@picodata.io>
Date: Thu, 11 Jul 2024 22:15:08 +0700
Subject: [PATCH] feat: use rust allocated tuples for binary data

Co-authored-by: Georgy Moshkin <gmoshkin@picodata.io>
---
 Cargo.lock                                 |   1 +
 sbroad-core/Cargo.toml                     |   1 +
 sbroad-core/src/executor/engine/helpers.rs |  49 ++++++--
 sbroad-core/src/executor/protocol.rs       | 137 +++++++++++++++++++--
 4 files changed, 169 insertions(+), 19 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 2e8c13c5e..a239d924e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1145,6 +1145,7 @@ dependencies = [
  "rmpv",
  "sbroad-proc",
  "serde",
+ "serde_bytes",
  "serde_yaml",
  "smol_str",
  "tarantool",
diff --git a/sbroad-core/Cargo.toml b/sbroad-core/Cargo.toml
index 1e7847507..4ccad98a7 100644
--- a/sbroad-core/Cargo.toml
+++ b/sbroad-core/Cargo.toml
@@ -26,6 +26,7 @@ rmp-serde = "1.0"
 rmpv = "1.0"
 sbroad-proc = { path = "../sbroad-proc", version = "0.1" }
 serde = { version = "1.0", features = ["derive", "rc"] }
+serde_bytes = "0.11"
 serde_yaml = "0.8"
 uuid = { version = "1.1", features = ["v4", "fast-rng", "macro-diagnostics"] }
 smol_str = { version = "0.2", features = ["serde"] }
diff --git a/sbroad-core/src/executor/engine/helpers.rs b/sbroad-core/src/executor/engine/helpers.rs
index 09213ebc2..1d7cbcd5a 100644
--- a/sbroad-core/src/executor/engine/helpers.rs
+++ b/sbroad-core/src/executor/engine/helpers.rs
@@ -45,7 +45,7 @@ use crate::{
     executor::{
         bucket::Buckets,
         ir::{ExecutionPlan, QueryType},
-        protocol::{Binary, EncodedOptionalData, EncodedRequiredData, OptionalData, RequiredData},
+        protocol::{Binary, EncodedOptionalData, OptionalData, RequiredData},
         result::{ConsumerResult, MetadataColumn, ProducerResult},
         vtable::{VTableTuple, VirtualTable},
     },
@@ -172,9 +172,8 @@ pub fn build_required_binary(exec_plan: &mut ExecutionPlan) -> Result<Binary, Sb
         schema_info,
         tables,
     );
-    let encoded_required_data = EncodedRequiredData::try_from(required)?;
-    let raw_required_data: Vec<u8> = encoded_required_data.into();
-    Ok(raw_required_data.into())
+    let required_as_tuple = required.to_tuple()?;
+    Ok(required_as_tuple.into())
 }
 
 /// # Errors
@@ -248,10 +247,8 @@ pub fn build_optional_binary(mut exec_plan: ExecutionPlan) -> Result<Binary, Sbr
     };
     let vtables_meta = exec_plan.remove_vtables()?;
     let optional_data = OptionalData::new(exec_plan, ordered, vtables_meta);
-
-    let encoded_optional_data = EncodedOptionalData::try_from(optional_data)?;
-    let raw_optional_data: Vec<u8> = encoded_optional_data.into();
-    Ok(raw_optional_data.into())
+    let optional_as_tuple = optional_data.to_tuple()?;
+    Ok(optional_as_tuple.into())
 }
 
 /// Helper struct for storing optional data extracted
@@ -306,14 +303,31 @@ pub fn decode_msgpack(tuple_buf: &[u8]) -> Result<DecodeOutput, SbroadError> {
             )),
         ));
     }
-    let data_len = rmp::decode::read_str_len(&mut stream).map_err(|e| {
+
+    // Decode required data.
+    let req_array_len = rmp::decode::read_array_len(&mut stream).map_err(|e| {
+        SbroadError::FailedTo(
+            Action::Decode,
+            Some(Entity::MsgPack),
+            format_smolstr!("required array length: {e:?}"),
+        )
+    })? as usize;
+    if req_array_len != 1 {
+        return Err(SbroadError::Invalid(
+            Entity::Tuple,
+            Some(format_smolstr!(
+                "expected array of 1 element in required, got {req_array_len}"
+            )),
+        ));
+    }
+    let req_data_len = rmp::decode::read_str_len(&mut stream).map_err(|e| {
         SbroadError::FailedTo(
             Action::Decode,
             Some(Entity::MsgPack),
             format_smolstr!("read required data length: {e:?}"),
         )
     })? as usize;
-    let mut data: Vec<u8> = vec![0_u8; data_len];
+    let mut data: Vec<u8> = vec![0_u8; req_data_len];
     stream.read_exact_buf(&mut data).map_err(|e| {
         SbroadError::FailedTo(
             Action::Decode,
@@ -324,6 +338,21 @@ pub fn decode_msgpack(tuple_buf: &[u8]) -> Result<DecodeOutput, SbroadError> {
 
     let mut optional_data = None;
     if array_len == 3 {
+        let opt_array_len = rmp::decode::read_array_len(&mut stream).map_err(|e| {
+            SbroadError::FailedTo(
+                Action::Decode,
+                Some(Entity::MsgPack),
+                format_smolstr!("optional array length: {e:?}"),
+            )
+        })? as usize;
+        if opt_array_len != 1 {
+            return Err(SbroadError::Invalid(
+                Entity::Tuple,
+                Some(format_smolstr!(
+                    "expected array of 1 element in optional, got {opt_array_len}"
+                )),
+            ));
+        }
         let opt_len = rmp::decode::read_str_len(&mut stream).map_err(|e| {
             SbroadError::FailedTo(
                 Action::Decode,
diff --git a/sbroad-core/src/executor/protocol.rs b/sbroad-core/src/executor/protocol.rs
index 2e41f1c5a..42c57c697 100644
--- a/sbroad-core/src/executor/protocol.rs
+++ b/sbroad-core/src/executor/protocol.rs
@@ -1,5 +1,5 @@
 use opentelemetry::Context;
-use rmp::decode::{read_array_len, Bytes, RmpRead};
+use rmp::decode::{read_array_len, read_str_len, Bytes, RmpRead};
 use serde::{Deserialize, Serialize};
 use smol_str::{format_smolstr, SmolStr};
 use std::collections::HashMap;
@@ -24,12 +24,108 @@ use super::vtable::VirtualTableMeta;
 
 pub type VTablesMeta = HashMap<NodeId, VirtualTableMeta>;
 
-#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
-pub struct Binary(Vec<u8>);
+pub fn rust_allocated_tuple_from_bincode<T>(value: &T) -> Result<Tuple, SmolStr>
+where
+    T: ?Sized + serde::Serialize,
+{
+    let type_name = std::any::type_name::<T>();
+
+    let res = bincode::serialized_size(value);
+    let bincode_size = match res {
+        Ok(v) => v,
+        Err(e) => {
+            let msg = format_smolstr!("failed getting serialized size for {type_name}: {e}");
+            tarantool::say_warn!("{msg}");
+            return Err(msg);
+        }
+    };
+    if bincode_size > u32::MAX as u64 {
+        let msg = format_smolstr!(
+            "serialized value of {type_name} is too big: {bincode_size} > {}",
+            u32::MAX
+        );
+        tarantool::say_warn!("{msg}");
+        return Err(msg);
+    }
+
+    let mut msgpack_header = [0_u8; 6];
+    // array of len 1
+    msgpack_header[0] = b'\x91';
+    // string with 32bit length
+    msgpack_header[1] = b'\xdb';
+    // 32bit length of string
+    msgpack_header[2..].copy_from_slice(&(bincode_size as u32).to_be_bytes());
+
+    let capacity = msgpack_header.len() + bincode_size as usize;
+    let mut builder = tarantool::tuple::TupleBuilder::rust_allocated();
+    builder.reserve(capacity);
+    builder.append(&msgpack_header);
+
+    let res = bincode::serialize_into(&mut builder, value);
+    match res {
+        Ok(()) => {}
+        Err(e) => {
+            let msg = format_smolstr!("failed serializing value of {type_name}: {e}");
+            tarantool::say_warn!("{msg}");
+            return Err(msg);
+        }
+    }
+
+    let tuple = builder.into_tuple();
+    let tuple = match tuple {
+        Ok(v) => v,
+        Err(e) => {
+            let msg = format_smolstr!(
+                "failed creating a tuple from serialized data for {type_name}: {e}"
+            );
+            tarantool::say_warn!("{msg}");
+            return Err(msg);
+        }
+    };
+
+    Ok(tuple)
+}
+
+pub fn rust_allocated_tuple_from_bytes(data: &[u8]) -> Tuple {
+    let mut msgpack_header = [0_u8; 6];
+    // array of len 1
+    msgpack_header[0] = b'\x91';
+    // string with 32bit length
+    msgpack_header[1] = b'\xdb';
+    // 32bit length of string
+    msgpack_header[2..].copy_from_slice(&(data.len() as u32).to_be_bytes());
+
+    let capacity = msgpack_header.len() + data.len();
+    let mut builder = tarantool::tuple::TupleBuilder::rust_allocated();
+    builder.reserve(capacity);
+    builder.append(&msgpack_header);
+
+    builder.append(data);
+
+    let tuple = builder.into_tuple();
+    match tuple {
+        Ok(v) => v,
+        Err(e) => {
+            unreachable!("can't fail msgpack validation, msgpack header is valid: {e}");
+        }
+    }
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
+pub struct Binary(#[serde(with = "serde_bytes")] Tuple);
 
 impl From<Vec<u8>> for Binary {
+    #[inline(always)]
     fn from(value: Vec<u8>) -> Self {
-        Binary(value)
+        let tuple = rust_allocated_tuple_from_bytes(&value);
+        Binary(tuple)
+    }
+}
+
+impl From<Tuple> for Binary {
+    #[inline(always)]
+    fn from(tuple: tarantool::tuple::Tuple) -> Self {
+        Binary(tuple)
     }
 }
 
@@ -40,8 +136,7 @@ where
     type Err = Void;
 
     fn push_into_lua(self, lua: L) -> Result<PushGuard<L>, (Void, L)> {
-        let encoded = unsafe { String::from_utf8_unchecked(self.0) };
-        encoded.push_into_lua(lua)
+        self.0.push_into_lua(lua)
     }
 }
 
@@ -52,8 +147,7 @@ where
     type Err = Void;
 
     fn push_to_lua(&self, lua: L) -> Result<PushGuard<L>, (Self::Err, L)> {
-        let encoded = unsafe { std::str::from_utf8_unchecked(&self.0) };
-        encoded.push_to_lua(lua)
+        self.0.push_to_lua(lua)
     }
 }
 
@@ -148,7 +242,7 @@ impl<'e> IntoIterator for &'e EncodedRows {
     fn into_iter(self) -> Self::IntoIter {
         let capacity = *self.marking.iter().max().unwrap_or(&0);
         EncodedRowsIter {
-            stream: Bytes::from(self.encoded.0.as_slice()),
+            stream: Bytes::from(self.encoded.0.data()),
             marking: &self.marking,
             position: 0,
             // Allocate buffer for encoded row.
@@ -175,6 +269,11 @@ impl<'e> Iterator for EncodedRowsIter<'e> {
         let cur_pos = self.position;
         self.position += 1;
         if cur_pos == 0 {
+            // Array of one element wrapping a string (binary) with an array of encoded tuples.
+            let wrapper_array_len = read_array_len(&mut self.stream).expect("wrapping array");
+            assert_eq!(wrapper_array_len, 1);
+            let str_len = read_str_len(&mut self.stream).expect("string length");
+            assert_eq!(str_len as usize, self.stream.remaining_slice().len());
             let array_len = read_array_len(&mut self.stream).expect("encoded rows length");
             assert_eq!(array_len as usize, self.marking.len());
         }
@@ -263,6 +362,16 @@ impl TryFrom<&[u8]> for RequiredData {
 }
 
 impl RequiredData {
+    const ENTITY: Entity = Entity::RequiredData;
+
+    /// Construct a tuple, i.e. msgpack array of one binary string
+    /// containing the bincode encoding of `self`.
+    #[inline(always)]
+    pub fn to_tuple(&self) -> Result<tarantool::tuple::Tuple, SbroadError> {
+        rust_allocated_tuple_from_bincode(self)
+            .map_err(|msg| SbroadError::FailedTo(Action::Serialize, Some(Self::ENTITY), msg))
+    }
+
     #[must_use]
     pub fn new(
         plan_id: SmolStr,
@@ -366,6 +475,16 @@ impl TryFrom<&[u8]> for OptionalData {
 }
 
 impl OptionalData {
+    const ENTITY: Entity = Entity::OptionalData;
+
+    /// Construct a tuple, i.e. msgpack array of one binary string
+    /// containing the bincode encoding of `self`.
+    #[inline(always)]
+    pub fn to_tuple(&self) -> Result<tarantool::tuple::Tuple, SbroadError> {
+        rust_allocated_tuple_from_bincode(self)
+            .map_err(|msg| SbroadError::FailedTo(Action::Serialize, Some(Self::ENTITY), msg))
+    }
+
     #[must_use]
     pub fn new(
         exec_plan: ExecutionPlan,
-- 
GitLab