From 8f5054fdd6a32e2ac16e1525c120d9eb749f5c4c Mon Sep 17 00:00:00 2001
From: Kaitmazian Maksim <m.kaitmazian@picodata.io>
Date: Thu, 14 Mar 2024 19:56:13 +0300
Subject: [PATCH] refactor: move rows encoding preparation from picodata to
 pgproto

---
 pgproto/src/client/extended_query.rs | 22 ++++++++++------------
 pgproto/src/entrypoints.rs           | 23 ++++++++++++++++++-----
 pgproto/src/storage.rs               |  4 ++--
 pgproto/src/storage/describe.rs      | 17 +++++++++++++++++
 pgproto/src/storage/value.rs         |  9 +++++++++
 5 files changed, 56 insertions(+), 19 deletions(-)

diff --git a/pgproto/src/client/extended_query.rs b/pgproto/src/client/extended_query.rs
index c6c113432b..616778306d 100644
--- a/pgproto/src/client/extended_query.rs
+++ b/pgproto/src/client/extended_query.rs
@@ -32,27 +32,24 @@ pub fn process_parse_message(
     Ok(())
 }
 
-/// Map any encoding format to per-parameter format just like pg does it in
+/// Map any encoding format to per-column or per-parameter format just like pg does it in
 /// [exec_bind_message](https://github.com/postgres/postgres/blob/5c7038d70bb9c4d28a80b0a2051f73fafab5af3f/src/backend/tcop/postgres.c#L1840-L1845)
 /// or [PortalSetResultFormat](https://github.com/postgres/postgres/blob/5c7038d70bb9c4d28a80b0a2051f73fafab5af3f/src/backend/tcop/pquery.c#L623).
-fn prepare_parameter_encoding_format(
-    formats: &[RawFormat],
-    nparams: usize,
-) -> PgResult<Vec<Format>> {
-    if formats.len() == nparams {
+fn prepare_encoding_format(formats: &[RawFormat], n: usize) -> PgResult<Vec<Format>> {
+    if formats.len() == n {
         // format specified for each column
         formats.iter().map(|i| Format::try_from(*i)).collect()
     } else if formats.len() == 1 {
         // single format specified, use it for each column
-        Ok(vec![Format::try_from(formats[0])?; nparams])
+        Ok(vec![Format::try_from(formats[0])?; n])
     } else if formats.is_empty() {
         // no format specified, use the default for each column
-        Ok(vec![Format::Text; nparams])
+        Ok(vec![Format::Text; n])
     } else {
         Err(PgError::ProtocolViolation(format!(
-            "got {} format codes for {} columns",
+            "got {} format codes for {} items",
             formats.len(),
-            nparams
+            n
         )))
     }
 }
@@ -62,7 +59,7 @@ fn decode_parameter_values(
     param_oids: &[Oid],
     formats: &[RawFormat],
 ) -> PgResult<Vec<PgValue>> {
-    let formats = prepare_parameter_encoding_format(formats, params.len())?;
+    let formats = prepare_encoding_format(formats, params.len())?;
     if params.len() != param_oids.len() {
         return Err(PgError::ProtocolViolation(format!(
             "got {} parameters, {} oids and {} formats",
@@ -86,7 +83,8 @@ pub fn process_bind_message(
     let params = mem::take(bind.parameters_mut());
     let formats = bind.parameter_format_codes();
     let params = decode_parameter_values(params, &describe.param_oids, formats)?;
-    let result_format = bind.result_column_format_codes();
+    let ncolumns = describe.ncolumns();
+    let result_format = prepare_encoding_format(bind.result_column_format_codes(), ncolumns)?;
 
     manager.bind(
         bind.statement_name().as_deref(),
diff --git a/pgproto/src/entrypoints.rs b/pgproto/src/entrypoints.rs
index 8fcc7f04a9..044da70fc8 100644
--- a/pgproto/src/entrypoints.rs
+++ b/pgproto/src/entrypoints.rs
@@ -4,7 +4,7 @@ use crate::{
     storage::{
         describe::{PortalDescribe, QueryType, StatementDescribe},
         result::ExecuteResult,
-        value::PgValue,
+        value::{Format, PgValue},
     },
 };
 use postgres_types::Oid;
@@ -78,8 +78,7 @@ fn parse_explain(res: Value) -> PgResult<DqlResult> {
     })
 }
 
-fn execute_result_from_json(json: &str) -> PgResult<ExecuteResult> {
-    let raw: RawExecuteResult = serde_json::from_str(json)?;
+fn execute_result_from_raw_result(raw: RawExecuteResult) -> PgResult<ExecuteResult> {
     match raw.describe.query_type() {
         QueryType::Dql => {
             let res = parse_dql(raw.result)?;
@@ -94,6 +93,20 @@ fn execute_result_from_json(json: &str) -> PgResult<ExecuteResult> {
     }
 }
 
+fn execute_result_from_json(json: &str) -> PgResult<ExecuteResult> {
+    let raw: RawExecuteResult = serde_json::from_str(json)?;
+    execute_result_from_raw_result(raw)
+}
+
+fn simple_execute_result_from_json(json: &str) -> PgResult<ExecuteResult> {
+    let mut raw: RawExecuteResult = serde_json::from_str(json)?;
+    // Simple query supports only the text format.
+    // We couldn't set the format when we were calling bind, because we didn't know the number of columns,
+    // but after executing the whole simple query pipeline we have a description containing this number.
+    raw.describe.set_text_output_format();
+    execute_result_from_raw_result(raw)
+}
+
 type Entrypoint = LuaFunction<PushGuard<LuaThread>>;
 
 /// List of lua functions from sbroad that implement PG protcol API.
@@ -327,7 +340,7 @@ impl Entrypoints {
             .simple_query
             .call_with_args((client_id, sql))
             .map_err(|e| PgError::TarantoolError(e.into()))?;
-        execute_result_from_json(&json)
+        simple_execute_result_from_json(&json)
     }
 
     /// Handler for a Parse message. See self.parse for the details.
@@ -350,7 +363,7 @@ impl Entrypoints {
         statement: &str,
         portal: &str,
         params: Vec<PgValue>,
-        result_format: &[i16],
+        result_format: Vec<Format>,
     ) -> PgResult<()> {
         self.bind
             .call_with_args((id, statement, portal, params, result_format))
diff --git a/pgproto/src/storage.rs b/pgproto/src/storage.rs
index bf62dd8f23..1858ded3a0 100644
--- a/pgproto/src/storage.rs
+++ b/pgproto/src/storage.rs
@@ -1,6 +1,6 @@
 use self::describe::{PortalDescribe, StatementDescribe};
 use self::result::ExecuteResult;
-use self::value::PgValue;
+use self::value::{Format, PgValue};
 use crate::client::ClientId;
 use crate::entrypoints::PG_ENTRYPOINTS;
 use crate::error::PgResult;
@@ -68,7 +68,7 @@ impl StorageManager {
         statement: Option<&str>,
         portal: Option<&str>,
         params: Vec<PgValue>,
-        result_format: &[i16],
+        result_format: Vec<Format>,
     ) -> PgResult<()> {
         PG_ENTRYPOINTS.with(|entrypoints| {
             entrypoints.borrow().bind(
diff --git a/pgproto/src/storage/describe.rs b/pgproto/src/storage/describe.rs
index 23f69a6659..9ffedcb2e7 100644
--- a/pgproto/src/storage/describe.rs
+++ b/pgproto/src/storage/describe.rs
@@ -22,6 +22,12 @@ pub struct StatementDescribe {
     pub param_oids: Vec<Oid>,
 }
 
+impl StatementDescribe {
+    pub fn ncolumns(&self) -> usize {
+        self.describe.metadata.len()
+    }
+}
+
 #[derive(Debug, Clone, Default, Deserialize)]
 pub struct PortalDescribe {
     #[serde(flatten)]
@@ -59,6 +65,17 @@ impl PortalDescribe {
     pub fn output_format(&self) -> &[Format] {
         &self.output_format
     }
+
+    // Enforce use of the text format for output rows. We use it for simple query, as it supports only the text format.
+    pub fn set_text_output_format(&mut self) {
+        let mut output_format = Vec::new();
+        output_format.resize(self.ncolumns(), Format::Text);
+        self.output_format = output_format;
+    }
+
+    pub fn ncolumns(&self) -> usize {
+        self.describe.metadata.len()
+    }
 }
 
 #[derive(Debug, Deserialize, PartialEq, Eq, Clone)]
diff --git a/pgproto/src/storage/value.rs b/pgproto/src/storage/value.rs
index a25039b8a9..eed314716d 100644
--- a/pgproto/src/storage/value.rs
+++ b/pgproto/src/storage/value.rs
@@ -34,6 +34,15 @@ pub enum Format {
     Binary = 1,
 }
 
+impl<L: AsLua> PushInto<L> for Format {
+    type Err = tarantool::tlua::Void;
+
+    fn push_into_lua(self, lua: L) -> Result<tarantool::tlua::PushGuard<L>, (Self::Err, L)> {
+        let value = self as RawFormat;
+        value.push_into_lua(lua)
+    }
+}
+
 impl TryFrom<RawFormat> for Format {
     type Error = PgError;
     fn try_from(value: RawFormat) -> Result<Self, Self::Error> {
-- 
GitLab