From 8f5054fdd6a32e2ac16e1525c120d9eb749f5c4c Mon Sep 17 00:00:00 2001 From: Kaitmazian Maksim <m.kaitmazian@picodata.io> Date: Thu, 14 Mar 2024 19:56:13 +0300 Subject: [PATCH] refactor: move rows encoding preparation from picodata to pgproto --- pgproto/src/client/extended_query.rs | 22 ++++++++++------------ pgproto/src/entrypoints.rs | 23 ++++++++++++++++++----- pgproto/src/storage.rs | 4 ++-- pgproto/src/storage/describe.rs | 17 +++++++++++++++++ pgproto/src/storage/value.rs | 9 +++++++++ 5 files changed, 56 insertions(+), 19 deletions(-) diff --git a/pgproto/src/client/extended_query.rs b/pgproto/src/client/extended_query.rs index c6c113432b..616778306d 100644 --- a/pgproto/src/client/extended_query.rs +++ b/pgproto/src/client/extended_query.rs @@ -32,27 +32,24 @@ pub fn process_parse_message( Ok(()) } -/// Map any encoding format to per-parameter format just like pg does it in +/// Map any encoding format to per-column or per-parameter format just like pg does it in /// [exec_bind_message](https://github.com/postgres/postgres/blob/5c7038d70bb9c4d28a80b0a2051f73fafab5af3f/src/backend/tcop/postgres.c#L1840-L1845) /// or [PortalSetResultFormat](https://github.com/postgres/postgres/blob/5c7038d70bb9c4d28a80b0a2051f73fafab5af3f/src/backend/tcop/pquery.c#L623). -fn prepare_parameter_encoding_format( - formats: &[RawFormat], - nparams: usize, -) -> PgResult<Vec<Format>> { - if formats.len() == nparams { +fn prepare_encoding_format(formats: &[RawFormat], n: usize) -> PgResult<Vec<Format>> { + if formats.len() == n { // format specified for each column formats.iter().map(|i| Format::try_from(*i)).collect() } else if formats.len() == 1 { // single format specified, use it for each column - Ok(vec![Format::try_from(formats[0])?; nparams]) + Ok(vec![Format::try_from(formats[0])?; n]) } else if formats.is_empty() { // no format specified, use the default for each column - Ok(vec![Format::Text; nparams]) + Ok(vec![Format::Text; n]) } else { Err(PgError::ProtocolViolation(format!( - "got {} format codes for {} columns", + "got {} format codes for {} items", formats.len(), - nparams + n ))) } } @@ -62,7 +59,7 @@ fn decode_parameter_values( param_oids: &[Oid], formats: &[RawFormat], ) -> PgResult<Vec<PgValue>> { - let formats = prepare_parameter_encoding_format(formats, params.len())?; + let formats = prepare_encoding_format(formats, params.len())?; if params.len() != param_oids.len() { return Err(PgError::ProtocolViolation(format!( "got {} parameters, {} oids and {} formats", @@ -86,7 +83,8 @@ pub fn process_bind_message( let params = mem::take(bind.parameters_mut()); let formats = bind.parameter_format_codes(); let params = decode_parameter_values(params, &describe.param_oids, formats)?; - let result_format = bind.result_column_format_codes(); + let ncolumns = describe.ncolumns(); + let result_format = prepare_encoding_format(bind.result_column_format_codes(), ncolumns)?; manager.bind( bind.statement_name().as_deref(), diff --git a/pgproto/src/entrypoints.rs b/pgproto/src/entrypoints.rs index 8fcc7f04a9..044da70fc8 100644 --- a/pgproto/src/entrypoints.rs +++ b/pgproto/src/entrypoints.rs @@ -4,7 +4,7 @@ use crate::{ storage::{ describe::{PortalDescribe, QueryType, StatementDescribe}, result::ExecuteResult, - value::PgValue, + value::{Format, PgValue}, }, }; use postgres_types::Oid; @@ -78,8 +78,7 @@ fn parse_explain(res: Value) -> PgResult<DqlResult> { }) } -fn execute_result_from_json(json: &str) -> PgResult<ExecuteResult> { - let raw: RawExecuteResult = serde_json::from_str(json)?; +fn execute_result_from_raw_result(raw: RawExecuteResult) -> PgResult<ExecuteResult> { match raw.describe.query_type() { QueryType::Dql => { let res = parse_dql(raw.result)?; @@ -94,6 +93,20 @@ fn execute_result_from_json(json: &str) -> PgResult<ExecuteResult> { } } +fn execute_result_from_json(json: &str) -> PgResult<ExecuteResult> { + let raw: RawExecuteResult = serde_json::from_str(json)?; + execute_result_from_raw_result(raw) +} + +fn simple_execute_result_from_json(json: &str) -> PgResult<ExecuteResult> { + let mut raw: RawExecuteResult = serde_json::from_str(json)?; + // Simple query supports only the text format. + // We couldn't set the format when we were calling bind, because we didn't know the number of columns, + // but after executing the whole simple query pipeline we have a description containing this number. + raw.describe.set_text_output_format(); + execute_result_from_raw_result(raw) +} + type Entrypoint = LuaFunction<PushGuard<LuaThread>>; /// List of lua functions from sbroad that implement PG protcol API. @@ -327,7 +340,7 @@ impl Entrypoints { .simple_query .call_with_args((client_id, sql)) .map_err(|e| PgError::TarantoolError(e.into()))?; - execute_result_from_json(&json) + simple_execute_result_from_json(&json) } /// Handler for a Parse message. See self.parse for the details. @@ -350,7 +363,7 @@ impl Entrypoints { statement: &str, portal: &str, params: Vec<PgValue>, - result_format: &[i16], + result_format: Vec<Format>, ) -> PgResult<()> { self.bind .call_with_args((id, statement, portal, params, result_format)) diff --git a/pgproto/src/storage.rs b/pgproto/src/storage.rs index bf62dd8f23..1858ded3a0 100644 --- a/pgproto/src/storage.rs +++ b/pgproto/src/storage.rs @@ -1,6 +1,6 @@ use self::describe::{PortalDescribe, StatementDescribe}; use self::result::ExecuteResult; -use self::value::PgValue; +use self::value::{Format, PgValue}; use crate::client::ClientId; use crate::entrypoints::PG_ENTRYPOINTS; use crate::error::PgResult; @@ -68,7 +68,7 @@ impl StorageManager { statement: Option<&str>, portal: Option<&str>, params: Vec<PgValue>, - result_format: &[i16], + result_format: Vec<Format>, ) -> PgResult<()> { PG_ENTRYPOINTS.with(|entrypoints| { entrypoints.borrow().bind( diff --git a/pgproto/src/storage/describe.rs b/pgproto/src/storage/describe.rs index 23f69a6659..9ffedcb2e7 100644 --- a/pgproto/src/storage/describe.rs +++ b/pgproto/src/storage/describe.rs @@ -22,6 +22,12 @@ pub struct StatementDescribe { pub param_oids: Vec<Oid>, } +impl StatementDescribe { + pub fn ncolumns(&self) -> usize { + self.describe.metadata.len() + } +} + #[derive(Debug, Clone, Default, Deserialize)] pub struct PortalDescribe { #[serde(flatten)] @@ -59,6 +65,17 @@ impl PortalDescribe { pub fn output_format(&self) -> &[Format] { &self.output_format } + + // Enforce use of the text format for output rows. We use it for simple query, as it supports only the text format. + pub fn set_text_output_format(&mut self) { + let mut output_format = Vec::new(); + output_format.resize(self.ncolumns(), Format::Text); + self.output_format = output_format; + } + + pub fn ncolumns(&self) -> usize { + self.describe.metadata.len() + } } #[derive(Debug, Deserialize, PartialEq, Eq, Clone)] diff --git a/pgproto/src/storage/value.rs b/pgproto/src/storage/value.rs index a25039b8a9..eed314716d 100644 --- a/pgproto/src/storage/value.rs +++ b/pgproto/src/storage/value.rs @@ -34,6 +34,15 @@ pub enum Format { Binary = 1, } +impl<L: AsLua> PushInto<L> for Format { + type Err = tarantool::tlua::Void; + + fn push_into_lua(self, lua: L) -> Result<tarantool::tlua::PushGuard<L>, (Self::Err, L)> { + let value = self as RawFormat; + value.push_into_lua(lua) + } +} + impl TryFrom<RawFormat> for Format { type Error = PgError; fn try_from(value: RawFormat) -> Result<Self, Self::Error> { -- GitLab