diff --git a/src/pgproto.rs b/src/pgproto.rs index 44f20f4d739807eb9650ee0d8242bc4f27163663..e2a24117991849171867fd594bdfa75473c25e22 100644 --- a/src/pgproto.rs +++ b/src/pgproto.rs @@ -10,13 +10,12 @@ use tarantool::coio::{CoIOListener, CoIOStream}; mod backend; mod client; -mod entrypoints; mod error; mod messages; mod server; -mod storage; mod stream; mod tls; +mod value; pub const DEFAULT_MAX_PG_STATEMENTS: usize = 50; pub const DEFAULT_MAX_PG_PORTALS: usize = 50; diff --git a/src/pgproto/backend.rs b/src/pgproto/backend.rs index 50965b06cd438f8d6340a7da90b92e597cce9385..f99ff5352873ead5f1e1913a2584b3dd507118bd 100644 --- a/src/pgproto/backend.rs +++ b/src/pgproto/backend.rs @@ -1,13 +1,15 @@ use self::describe::{PortalDescribe, StatementDescribe}; +use self::result::ExecuteResult; use self::storage::{with_portals_mut, Portal, Statement, PG_PORTALS, PG_STATEMENTS}; use super::client::ClientId; use super::error::{PgError, PgResult}; -use crate::pgproto::storage::value::Format; +use crate::pgproto::value::{Format, PgValue, RawFormat}; use crate::schema::ADMIN_ID; use crate::sql::otm::TracerKind; use crate::sql::router::RouterRuntime; use crate::sql::with_tracer; use crate::traft::error::Error; +use bytes::Bytes; use opentelemetry::sdk::trace::Tracer; use opentelemetry::Context; use postgres_types::Oid; @@ -20,15 +22,61 @@ use sbroad::ir::Plan as IrPlan; use sbroad::otm::{query_id, query_span, OTM_CHAR_LIMIT}; use sbroad::utils::MutexLike; use smol_str::ToSmolStr; +use std::iter::zip; use std::rc::Rc; +use std::sync::atomic::{AtomicU32, Ordering}; use tarantool::session::with_su; -use tarantool::tuple::Tuple; mod pgproc; +mod result; mod storage; pub mod describe; +fn decode_parameter_values( + params: Vec<Option<Bytes>>, + param_oids: &[Oid], + formats: &[Format], +) -> PgResult<Vec<Value>> { + if params.len() != param_oids.len() { + return Err(PgError::ProtocolViolation(format!( + "got {} parameters, {} oids and {} formats", + params.len(), + param_oids.len(), + formats.len() + ))); + } + + zip(zip(params, param_oids), formats) + .map(|((bytes, oid), format)| { + let value = PgValue::decode(bytes.as_ref(), *oid, *format)?; + Ok(value.into_inner()) + }) + .collect() +} + +/// Map any encoding format to per-column or per-parameter format just like pg does it in +/// [exec_bind_message](https://github.com/postgres/postgres/blob/5c7038d70bb9c4d28a80b0a2051f73fafab5af3f/src/backend/tcop/postgres.c#L1840-L1845) +/// or [PortalSetResultFormat](https://github.com/postgres/postgres/blob/5c7038d70bb9c4d28a80b0a2051f73fafab5af3f/src/backend/tcop/pquery.c#L623). +fn prepare_encoding_format(formats: &[RawFormat], n: usize) -> PgResult<Vec<Format>> { + if formats.len() == n { + // format specified for each column + formats.iter().map(|i| Format::try_from(*i)).collect() + } else if formats.len() == 1 { + // single format specified, use it for each column + Ok(vec![Format::try_from(formats[0])?; n]) + } else if formats.is_empty() { + // no format specified, use the default for each column + Ok(vec![Format::Text; n]) + } else { + Err(PgError::ProtocolViolation(format!( + "got {} format codes for {} items", + formats.len(), + n + ))) + } +} + // helper function to get `TracerRef` fn get_tracer_param(traceable: bool) -> &'static Tracer { let kind = TracerKind::from_traceable(traceable); @@ -40,7 +88,7 @@ pub fn bind( stmt_name: String, portal_name: String, params: Vec<Value>, - output_format: Vec<u8>, + result_format: Vec<Format>, traceable: bool, ) -> PgResult<()> { let key = (client_id, stmt_name.into()); @@ -63,11 +111,7 @@ pub fn bind( plan.apply_options()?; plan.optimize()?; } - let format = output_format - .into_iter() - .map(|raw| Format::try_from(raw as i16).unwrap()) - .collect(); - Portal::new(plan, statement.clone(), format) + Portal::new(plan, statement.clone(), result_format) }, )?; @@ -79,7 +123,12 @@ pub fn bind( Ok(()) } -pub fn execute(id: ClientId, name: String, max_rows: i64, traceable: bool) -> PgResult<Tuple> { +pub fn execute( + id: ClientId, + name: String, + max_rows: i64, + traceable: bool, +) -> PgResult<ExecuteResult> { let max_rows = if max_rows <= 0 { i64::MAX } else { max_rows }; let name = Rc::from(name); @@ -89,7 +138,7 @@ pub fn execute(id: ClientId, name: String, max_rows: i64, traceable: bool) -> Pg })?; with_portals_mut((id, name), |portal| { let ctx = with_tracer(Context::new(), TracerKind::from_traceable(traceable)); - query_span::<PgResult<Tuple>, _>( + query_span::<PgResult<_>, _>( "\"api.router.execute\"", statement.id(), get_tracer_param(traceable), @@ -159,7 +208,7 @@ pub fn parse( ) } -pub fn describe_stmt(id: ClientId, name: String) -> PgResult<StatementDescribe> { +pub fn describe_statement(id: ClientId, name: &str) -> PgResult<StatementDescribe> { let key = (id, name.into()); let Some(statement) = PG_STATEMENTS.with(|storage| storage.borrow().get(&key)) else { return Err(PgError::Other( @@ -169,24 +218,169 @@ pub fn describe_stmt(id: ClientId, name: String) -> PgResult<StatementDescribe> Ok(statement.describe().clone()) } -pub fn describe_portal(id: ClientId, name: String) -> PgResult<PortalDescribe> { +pub fn describe_portal(id: ClientId, name: &str) -> PgResult<PortalDescribe> { with_portals_mut((id, name.into()), |portal| Ok(portal.describe().clone())) } -pub fn close_stmt(id: ClientId, name: String) { +pub fn close_statement(id: ClientId, name: &str) { // Close can't cause an error in PG. PG_STATEMENTS.with(|storage| storage.borrow_mut().remove(&(id, name.into()))); } -pub fn close_portal(id: ClientId, name: String) { +pub fn close_portal(id: ClientId, name: &str) { // Close can't cause an error in PG. PG_PORTALS.with(|storage| storage.borrow_mut().remove(&(id, name.into()))); } -pub fn close_client_stmts(id: ClientId) { +pub fn close_client_statements(id: ClientId) { PG_STATEMENTS.with(|storage| storage.borrow_mut().remove_by_client_id(id)) } pub fn close_client_portals(id: ClientId) { PG_PORTALS.with(|storage| storage.borrow_mut().remove_by_client_id(id)) } + +/// Each postgres client uses its own backend to handle incoming messages. +pub struct Backend { + /// A unique identificator of a postgres client. It is used as a part of a key in the portal + /// storage, allowing to store in a single storage portals from many clients. + client_id: ClientId, +} + +impl Backend { + pub fn new() -> Self { + /// Generate a unique client id. + fn unique_id() -> ClientId { + static ID_COUNTER: AtomicU32 = AtomicU32::new(0); + ID_COUNTER.fetch_add(1, Ordering::Relaxed) + } + + Self { + client_id: unique_id(), + } + } + + /// Execute a simple query. Handler for a Query message. + /// + /// First, it closes an unnamed portal and statement, just like PG does when gets a Query + /// messsage. After that the extended pipeline is executed on unnamed portal and statement: + /// parse + bind + describe + execute and result is returned. + /// + /// Note that it closes the uunamed portal and statement even in case of a failure. + pub fn simple_query(&self, sql: String) -> PgResult<ExecuteResult> { + let close_unnamed = || { + self.close_statement(None); + self.close_portal(None); + }; + + let simple_query = || { + close_unnamed(); + self.parse(None, sql, vec![])?; + self.bind(None, None, vec![], &[], &[Format::Text as RawFormat])?; + self.execute(None, -1) + }; + + simple_query().map_err(|err| { + close_unnamed(); + err + }) + } + + /// Handler for a Describe message. + /// + /// Describe a statement. + pub fn describe_statement(&self, name: Option<&str>) -> PgResult<StatementDescribe> { + let name = name.unwrap_or_default(); + describe_statement(self.client_id, name) + } + + /// Handler for a Describe message. + /// + /// Describe a portal. + pub fn describe_portal(&self, name: Option<&str>) -> PgResult<PortalDescribe> { + let name = name.unwrap_or_default(); + describe_portal(self.client_id, name) + } + + /// Handler for a Parse message. + /// + /// Create a statement from a query and store it in the statement storage with the + /// given name. In case of a conflict the strategy is the same with PG. + /// The statement lasts until it is explicitly closed. + pub fn parse(&self, name: Option<String>, sql: String, param_oids: Vec<Oid>) -> PgResult<()> { + let name = name.unwrap_or_default(); + parse(self.client_id, name, sql, param_oids, false) + } + + /// Handler for a Bind message. + /// + /// Copy the sources statement, create a portal by binding the given parameters and store it + /// in the portal storage. In case of a conflict the strategy is the same with PG. + /// The portal lasts until it is explicitly closed. + pub fn bind( + &self, + statement: Option<String>, + portal: Option<String>, + params: Vec<Option<Bytes>>, + params_format: &[RawFormat], + result_format: &[RawFormat], + ) -> PgResult<()> { + let statement = statement.unwrap_or_default(); + let portal = portal.unwrap_or_default(); + + let describe = describe_statement(self.client_id, &statement)?; + let params_format = prepare_encoding_format(params_format, params.len())?; + let result_format = prepare_encoding_format(result_format, describe.ncolumns())?; + let params = decode_parameter_values(params, &describe.param_oids, ¶ms_format)?; + + bind( + self.client_id, + statement, + portal, + params, + result_format, + true, + ) + } + + /// Handler for an Execute message. + /// + /// Take a portal from the storage and retrive at most max_rows rows from it. In case of + /// non-dql queries max_rows is ignored and result with no rows is returned. + pub fn execute(&self, portal: Option<String>, max_rows: i64) -> PgResult<ExecuteResult> { + let name = portal.unwrap_or_default(); + execute(self.client_id, name, max_rows, true) + } + + /// Handler for a Close message. + /// + /// Close a portal. It's not an error to close a non-existent portal. + pub fn close_portal(&self, name: Option<&str>) { + let name = name.unwrap_or_default(); + close_portal(self.client_id, name) + } + + /// Handler for a Close message. + /// + /// Close a statement. It's not an error to close a non-existent statement. + pub fn close_statement(&self, name: Option<&str>) { + let name = name.unwrap_or_default(); + close_statement(self.client_id, name) + } + + /// Close all the client's portals. It should be called at the end of the transaction. + pub fn close_all_portals(&self) { + close_client_portals(self.client_id) + } + + fn on_disconnect(&self) { + close_client_statements(self.client_id); + close_client_portals(self.client_id); + } +} + +impl Drop for Backend { + fn drop(&mut self) { + self.on_disconnect() + } +} diff --git a/src/pgproto/backend/describe.rs b/src/pgproto/backend/describe.rs index fdb2a109c34c11b0a5ea8a8cbc829cbddaa64cbd..65a19b8336c4e958885109af9fffc26dbb012f91 100644 --- a/src/pgproto/backend/describe.rs +++ b/src/pgproto/backend/describe.rs @@ -1,5 +1,5 @@ use crate::pgproto::error::PgResult; -use crate::pgproto::storage::value::{self, Format}; +use crate::pgproto::value::{self, Format}; use pgwire::messages::data::{FieldDescription, RowDescription}; use postgres_types::{Oid, Type}; use sbroad::errors::{Entity, SbroadError}; @@ -177,12 +177,13 @@ impl TryFrom<&Node> for CommandTag { #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] pub struct MetadataColumn { name: String, - r#type: String, + #[serde(rename = "type")] + ty: String, } impl MetadataColumn { fn new(name: String, ty: String) -> Self { - Self { name, r#type: ty } + Self { name, ty } } } @@ -305,7 +306,7 @@ impl Describe { .metadata .iter() .map(|col| { - let type_str = col.r#type.as_str(); + let type_str = col.ty.as_str(); value::type_from_name(type_str) .map(|ty| field_description(col.name.clone(), ty, Format::Text)) }) @@ -369,7 +370,7 @@ impl PortalDescribe { let output_format = &self.output_format; let row_description = zip(metadata, output_format) .map(|(col, format)| { - let type_str = col.r#type.as_str(); + let type_str = col.ty.as_str(); value::type_from_name(type_str) .map(|ty| field_description(col.name.clone(), ty, *format)) }) @@ -390,15 +391,4 @@ impl PortalDescribe { pub fn output_format(&self) -> &[Format] { &self.output_format } - - // Enforce use of the text format for output rows. We use it for simple query, as it supports only the text format. - pub fn set_text_output_format(&mut self) { - let mut output_format = Vec::new(); - output_format.resize(self.ncolumns(), Format::Text); - self.output_format = output_format; - } - - pub fn ncolumns(&self) -> usize { - self.describe.metadata.len() - } } diff --git a/src/pgproto/backend/pgproc.rs b/src/pgproto/backend/pgproc.rs index 5f98cf925ca38a7b002a5dedeeefb73f1954c388..763fdc892998b3485f02af574db5ed8d6ffc7e62 100644 --- a/src/pgproto/backend/pgproc.rs +++ b/src/pgproto/backend/pgproc.rs @@ -1,19 +1,21 @@ -use super::describe::{PortalDescribe, StatementDescribe}; +use super::describe::{PortalDescribe, QueryType, StatementDescribe}; use super::storage::{UserPortalNames, UserStatementNames}; use crate::pgproto::backend; +use crate::pgproto::error::DecodingError; +use crate::pgproto::value::Format; use crate::pgproto::{client::ClientId, error::PgResult}; use ::tarantool::proc; use postgres_types::Oid; use sbroad::ir::value::{LuaValue, Value}; -use serde::Deserialize; -use tarantool::tuple::Tuple; +use serde::{Deserialize, Serialize}; +use tarantool::tuple::{Encode, Tuple}; struct BindArgs { id: ClientId, stmt_name: String, portal_name: String, params: Vec<Value>, - encoding_format: Vec<u8>, + encoding_format: Vec<Format>, traceable: bool, } @@ -28,7 +30,7 @@ impl<'de> Deserialize<'de> for BindArgs { String, String, Option<Vec<LuaValue>>, - Vec<u8>, + Vec<i16>, Option<bool>, ); @@ -41,12 +43,17 @@ impl<'de> Deserialize<'de> for BindArgs { .map(Value::from) .collect::<Vec<Value>>(); + let format: Vec<_> = encoding_format + .into_iter() + .map(|raw| Format::try_from(raw).unwrap_or_default()) + .collect(); + Ok(Self { id, stmt_name, portal_name, params, - encoding_format, + encoding_format: format, traceable: traceable.unwrap_or(false), }) } @@ -68,12 +75,12 @@ pub fn proc_pg_bind(args: BindArgs) -> PgResult<()> { #[proc] pub fn proc_pg_describe_stmt(id: ClientId, name: String) -> PgResult<StatementDescribe> { - backend::describe_stmt(id, name) + backend::describe_statement(id, &name) } #[proc] pub fn proc_pg_describe_portal(id: ClientId, name: String) -> PgResult<PortalDescribe> { - backend::describe_portal(id, name) + backend::describe_portal(id, &name) } #[proc] @@ -83,7 +90,46 @@ pub fn proc_pg_execute( max_rows: i64, traceable: bool, ) -> PgResult<Tuple> { - backend::execute(id, name, max_rows, traceable) + let result = backend::execute(id, name, max_rows, traceable)?; + let bytes = match result.query_type() { + QueryType::Explain | QueryType::Dql => { + #[derive(Serialize)] + struct ProcResult { + rows: Vec<Vec<LuaValue>>, + is_finished: bool, + } + impl Encode for ProcResult {} + + let is_finished = result.is_portal_finished(); + let rows = result + .into_values_stream() + .map(|values| { + values + .into_iter() + .map(|v| LuaValue::from(v.into_inner())) + .collect() + }) + .collect(); + let result = ProcResult { rows, is_finished }; + rmp_serde::to_vec_named(&vec![result]) + } + QueryType::Acl | QueryType::Ddl | QueryType::Dml => { + #[derive(Serialize)] + struct ProcResult { + row_count: Option<usize>, + } + impl Encode for ProcResult {} + + let result = ProcResult { + row_count: result.row_count(), + }; + rmp_serde::to_vec_named(&vec![result]) + } + }; + + let bytes = bytes.map_err(|e| DecodingError::Other(e.into()))?; + let tuple = Tuple::try_from_slice(&bytes).map_err(|e| DecodingError::Other(e.into()))?; + Ok(tuple) } #[proc] @@ -99,17 +145,17 @@ pub fn proc_pg_parse( #[proc] pub fn proc_pg_close_stmt(id: ClientId, name: String) { - backend::close_stmt(id, name) + backend::close_statement(id, &name) } #[proc] pub fn proc_pg_close_portal(id: ClientId, name: String) { - backend::close_portal(id, name) + backend::close_portal(id, &name) } #[proc] pub fn proc_pg_close_client_stmts(id: ClientId) { - backend::close_client_stmts(id) + backend::close_client_statements(id) } #[proc] diff --git a/src/pgproto/storage/result.rs b/src/pgproto/backend/result.rs similarity index 75% rename from src/pgproto/storage/result.rs rename to src/pgproto/backend/result.rs index 1c80f9334928c247328cc2d969ac2fa1d4201860..e6e5a40d5b176eaaedfd8c8c0252219af77eec7e 100644 --- a/src/pgproto/storage/result.rs +++ b/src/pgproto/backend/result.rs @@ -1,6 +1,6 @@ -use super::value::{Format, PgValue}; use crate::pgproto::backend::describe::{CommandTag, PortalDescribe, QueryType}; use crate::pgproto::error::PgResult; +use crate::pgproto::value::{Format, PgValue}; use bytes::BytesMut; use pgwire::messages::data::{DataRow, RowDescription}; use std::iter::zip; @@ -13,6 +13,7 @@ fn encode_row(values: Vec<PgValue>, formats: &[Format], buf: &mut BytesMut) -> D DataRow::new(row) } +#[derive(Debug)] pub struct ExecuteResult { describe: PortalDescribe, values_stream: IntoIter<Vec<PgValue>>, @@ -22,28 +23,25 @@ pub struct ExecuteResult { } impl ExecuteResult { - pub fn new( - rows: Vec<Vec<PgValue>>, - describe: PortalDescribe, - is_portal_finished: bool, - ) -> Self { - let values_stream = rows.into_iter(); + /// Create a new finished result. It is used for non-dql queries. + pub fn new(row_count: usize, describe: PortalDescribe) -> Self { Self { - values_stream, + values_stream: Default::default(), describe, - row_count: 0, - is_portal_finished, + row_count, + is_portal_finished: true, buf: BytesMut::default(), } } - pub fn empty(row_count: usize, describe: PortalDescribe) -> Self { + /// Create a query result with rows. It is used for dql-like queries. + pub fn with_rows(self, rows: Vec<Vec<PgValue>>, is_portal_finished: bool) -> Self { + let values_stream = rows.into_iter(); Self { - values_stream: Default::default(), - describe, - row_count, - is_portal_finished: true, - buf: BytesMut::default(), + values_stream, + row_count: 0, + is_portal_finished, + ..self } } @@ -67,6 +65,16 @@ impl ExecuteResult { } } +impl ExecuteResult { + pub fn query_type(&self) -> &QueryType { + self.describe.describe.query_type() + } + + pub fn into_values_stream(self) -> IntoIter<Vec<PgValue>> { + self.values_stream + } +} + impl Iterator for ExecuteResult { type Item = DataRow; diff --git a/src/pgproto/backend/storage.rs b/src/pgproto/backend/storage.rs index 8479ef0f16fada998a149cac9b79f3d14a3149be..b02934526b144a53df476d9d02ae191e6615c4c0 100644 --- a/src/pgproto/backend/storage.rs +++ b/src/pgproto/backend/storage.rs @@ -1,6 +1,8 @@ +use super::describe::QueryType; use super::describe::{Describe, PortalDescribe, StatementDescribe}; +use super::result::ExecuteResult; use crate::pgproto::error::{PgError, PgResult}; -use crate::pgproto::storage::value::Format; +use crate::pgproto::value::{Format, PgValue}; use crate::pgproto::{DEFAULT_MAX_PG_PORTALS, DEFAULT_MAX_PG_STATEMENTS}; use crate::traft::node; use ::tarantool::tuple::Tuple; @@ -370,41 +372,50 @@ pub struct Portal { enum PortalState { #[default] NotStarted, - Running(IntoIter<Value>), - Finished(Option<Tuple>), + Running(IntoIter<Vec<PgValue>>), + Finished(Option<ExecuteResult>), } -/// Try to get rows from the query execution result. -/// Some is returned for dql-like queires (dql or explain), otherwise None is returned. -fn tuple_as_rows(tuple: &Tuple) -> Option<Vec<Value>> { +fn mp_row_into_pg_row(mp: Vec<Value>) -> PgResult<Vec<PgValue>> { + mp.into_iter().map(PgValue::try_from).collect() +} + +fn mp_rows_into_pg_rows(mp: Vec<Vec<Value>>) -> PgResult<Vec<Vec<PgValue>>> { + mp.into_iter().map(mp_row_into_pg_row).collect() +} + +/// Get rows from dql-like(dql or explain) query execution result. +fn get_rows_from_tuple(tuple: &Tuple) -> PgResult<Vec<Vec<PgValue>>> { #[derive(Deserialize, Default, Debug)] struct DqlResult { - rows: Vec<Value>, + rows: Vec<Vec<Value>>, } if let Ok(Some(res)) = tuple.field::<DqlResult>(0) { - return Some(res.rows); + return mp_rows_into_pg_rows(res.rows); } // Try to parse explain result. if let Ok(Some(res)) = tuple.field::<Vec<Value>>(0) { - return Some(res); + let rows = res.into_iter().map(|row| vec![row]).collect(); + return mp_rows_into_pg_rows(rows); } - None + Err(PgError::InternalError( + "couldn't get rows from the result tuple".into(), + )) } -fn take_rows(rows: &mut IntoIter<Value>, max_rows: usize) -> PgResult<Tuple> { - let is_finished = rows.len() <= max_rows; - let rows = rows.take(max_rows).collect(); - #[derive(Serialize)] - struct RunningResult { - rows: Vec<Value>, - is_finished: bool, - } - let result = RunningResult { rows, is_finished }; - let mp = rmp_serde::to_vec_named(&vec![result])?; - let ret = Tuple::try_from_slice(&mp)?; - Ok(ret) +/// Get row_count from result tuple. +fn get_row_count_from_tuple(tuple: &Tuple) -> PgResult<usize> { + #[derive(Deserialize)] + struct RowCount { + row_count: usize, + } + let res: RowCount = tuple.field(0)?.ok_or(PgError::InternalError( + "couldn't get row count from the result tuple".into(), + ))?; + + Ok(res.row_count) } impl Portal { @@ -419,27 +430,30 @@ impl Portal { }) } - pub fn execute(&mut self, max_rows: usize) -> PgResult<Tuple> { + pub fn execute(&mut self, max_rows: usize) -> PgResult<ExecuteResult> { loop { match &mut self.state { PortalState::NotStarted => self.start()?, - PortalState::Finished(Some(res)) => { - // clone only increments tuple's refcounter - let res = res.clone(); - self.state = PortalState::Finished(None); - return Ok(res); + PortalState::Finished(Some(_)) => { + let state = std::mem::replace(&mut self.state, PortalState::Finished(None)); + match state { + PortalState::Finished(Some(result)) => return Ok(result), + _ => unreachable!(), + } } PortalState::Running(ref mut rows) => { - let res = take_rows(rows, max_rows)?; + let taken = rows.take(max_rows).collect(); if rows.len() == 0 { self.state = PortalState::Finished(None); } - return Ok(res); + let is_finished = matches!(self.state, PortalState::Finished(_)); + return Ok(ExecuteResult::new(0, self.describe().clone()) + .with_rows(taken, is_finished)); } _ => { return Err(PgError::Other( format!("Can't execute portal in state {:?}", self.state).into(), - )) + )); } } } @@ -458,12 +472,20 @@ impl Portal { &runtime, HashMap::new(), ); - let res = dispatch(query)?; - if let Some(rows) = tuple_as_rows(&res) { - self.state = PortalState::Running(rows.into_iter()); - } else { - self.state = PortalState::Finished(Some(res)); - } + let tuple = dispatch(query)?; + self.state = match self.describe().query_type() { + QueryType::Dml => { + let row_count = get_row_count_from_tuple(&tuple)?; + PortalState::Finished(Some(ExecuteResult::new(row_count, self.describe().clone()))) + } + QueryType::Acl | QueryType::Ddl => { + PortalState::Finished(Some(ExecuteResult::new(0, self.describe().clone()))) + } + QueryType::Dql | QueryType::Explain => { + let rows = get_rows_from_tuple(&tuple)?.into_iter(); + PortalState::Running(rows) + } + }; Ok(()) } diff --git a/src/pgproto/client.rs b/src/pgproto/client.rs index 308375c0568e3a4f9a6170c13162cbd90a2a473a..c7eb472a36feadc887d9145ab088d58805d1e2cb 100644 --- a/src/pgproto/client.rs +++ b/src/pgproto/client.rs @@ -1,7 +1,7 @@ +use super::backend::Backend; use super::client::simple_query::process_query_message; use super::error::*; use super::messages; -use super::storage::StorageManager; use super::stream::{BeMessage, FeMessage, PgStream}; use super::tls::TlsAcceptor; use crate::tlog; @@ -17,8 +17,8 @@ pub type ClientId = u32; /// Postgres client representation. pub struct PgClient<S> { - // The portal and statement storage manager. - manager: StorageManager, + /// Postgres backend that handles queries. + backend: Backend, /// Stream for network communication. stream: PgStream<S>, @@ -40,7 +40,7 @@ impl<S: io::Read + io::Write> PgClient<S> { tlog!(Info, "client authenticated"); Ok(PgClient { - manager: StorageManager::new(), + backend: Backend::new(), loop_state: MessageLoopState::ReadyForQuery, stream, }) @@ -101,7 +101,7 @@ impl<S: io::Read + io::Write> PgClient<S> { match message { FeMessage::Query(query) => { tlog!(Info, "executing simple query: {}", query.query); - process_query_message(&mut self.stream, &self.manager, query)?; + process_query_message(&mut self.stream, &self.backend, query)?; self.loop_state = MessageLoopState::ReadyForQuery; } FeMessage::Parse(parse) => { @@ -112,7 +112,7 @@ impl<S: io::Read + io::Write> PgClient<S> { parse.query, ); self.loop_state = MessageLoopState::RunningExtendedQuery; - extended_query::process_parse_message(&mut self.stream, &self.manager, parse)?; + extended_query::process_parse_message(&mut self.stream, &self.backend, parse)?; } FeMessage::Bind(bind) => { tlog!( @@ -122,7 +122,7 @@ impl<S: io::Read + io::Write> PgClient<S> { bind.portal_name.as_deref().unwrap_or_default() ); self.loop_state = MessageLoopState::RunningExtendedQuery; - extended_query::process_bind_message(&mut self.stream, &self.manager, bind)?; + extended_query::process_bind_message(&mut self.stream, &self.backend, bind)?; } FeMessage::Execute(execute) => { tlog!( @@ -131,7 +131,7 @@ impl<S: io::Read + io::Write> PgClient<S> { execute.name.as_deref().unwrap_or_default() ); self.loop_state = MessageLoopState::RunningExtendedQuery; - extended_query::process_execute_message(&mut self.stream, &self.manager, execute)?; + extended_query::process_execute_message(&mut self.stream, &self.backend, execute)?; } FeMessage::Describe(describe) => { tlog!( @@ -143,7 +143,7 @@ impl<S: io::Read + io::Write> PgClient<S> { self.loop_state = MessageLoopState::RunningExtendedQuery; extended_query::process_describe_message( &mut self.stream, - &self.manager, + &self.backend, describe, )?; } @@ -155,7 +155,7 @@ impl<S: io::Read + io::Write> PgClient<S> { close.name.as_deref().unwrap_or_default() ); self.loop_state = MessageLoopState::RunningExtendedQuery; - extended_query::process_close_message(&mut self.stream, &self.manager, close)?; + extended_query::process_close_message(&mut self.stream, &self.backend, close)?; } FeMessage::Flush(_) => { tlog!(Info, "flushing"); @@ -165,7 +165,7 @@ impl<S: io::Read + io::Write> PgClient<S> { FeMessage::Sync(_) => { tlog!(Info, "syncing"); self.loop_state = MessageLoopState::ReadyForQuery; - extended_query::process_sync_mesage(&self.manager)?; + extended_query::process_sync_mesage(&self.backend); } FeMessage::Terminate(_) => { tlog!(Info, "terminating the session"); @@ -185,7 +185,7 @@ impl<S: io::Read + io::Write> PgClient<S> { loop { if let FeMessage::Sync(_) = self.stream.read_message()? { self.loop_state = MessageLoopState::ReadyForQuery; - extended_query::process_sync_mesage(&self.manager)?; + extended_query::process_sync_mesage(&self.backend); break; } } diff --git a/src/pgproto/client/extended_query.rs b/src/pgproto/client/extended_query.rs index d1c27f84e8858db6248d58e3eba96cad70961f32..cb6d140bf5fb33ba99f9d40982836bbea79be79d 100644 --- a/src/pgproto/client/extended_query.rs +++ b/src/pgproto/client/extended_query.rs @@ -1,16 +1,12 @@ -use crate::pgproto::storage::value::{Format, PgValue, RawFormat}; +use crate::pgproto::backend::Backend; use crate::pgproto::stream::{BeMessage, FeMessage}; use crate::pgproto::{ error::{PgError, PgResult}, messages, - storage::StorageManager, stream::PgStream, }; -use bytes::Bytes; use pgwire::messages::extendedquery::{Bind, Close, Describe, Execute, Parse}; -use postgres_types::Oid; use std::io::{Read, Write}; -use std::iter::zip; fn use_tarantool_parameter_placeholders(sql: &str) -> String { // TODO: delete it after the pg parameters are supported, @@ -23,98 +19,42 @@ fn use_tarantool_parameter_placeholders(sql: &str) -> String { pub fn process_parse_message( stream: &mut PgStream<impl Read + Write>, - manager: &StorageManager, + backend: &Backend, parse: Parse, ) -> PgResult<()> { let query = use_tarantool_parameter_placeholders(&parse.query); - manager.parse(parse.name.as_deref(), &query, &parse.type_oids)?; + backend.parse(parse.name, query, parse.type_oids)?; stream.write_message_noflush(messages::parse_complete())?; Ok(()) } -/// Map any encoding format to per-column or per-parameter format just like pg does it in -/// [exec_bind_message](https://github.com/postgres/postgres/blob/5c7038d70bb9c4d28a80b0a2051f73fafab5af3f/src/backend/tcop/postgres.c#L1840-L1845) -/// or [PortalSetResultFormat](https://github.com/postgres/postgres/blob/5c7038d70bb9c4d28a80b0a2051f73fafab5af3f/src/backend/tcop/pquery.c#L623). -fn prepare_encoding_format(formats: &[RawFormat], n: usize) -> PgResult<Vec<Format>> { - if formats.len() == n { - // format specified for each column - formats.iter().map(|i| Format::try_from(*i)).collect() - } else if formats.len() == 1 { - // single format specified, use it for each column - Ok(vec![Format::try_from(formats[0])?; n]) - } else if formats.is_empty() { - // no format specified, use the default for each column - Ok(vec![Format::Text; n]) - } else { - Err(PgError::ProtocolViolation(format!( - "got {} format codes for {} items", - formats.len(), - n - ))) - } -} - -fn decode_parameter_values( - params: &[Option<Bytes>], - param_oids: &[Oid], - formats: &[RawFormat], -) -> PgResult<Vec<PgValue>> { - let formats = prepare_encoding_format(formats, params.len())?; - if params.len() != param_oids.len() { - return Err(PgError::ProtocolViolation(format!( - "got {} parameters, {} oids and {} formats", - params.len(), - param_oids.len(), - formats.len() - ))); - } - - zip(zip(params, param_oids), formats) - .map(|((bytes, oid), format)| PgValue::decode(bytes.as_ref(), *oid, format)) - .collect() -} - pub fn process_bind_message( stream: &mut PgStream<impl Read + Write>, - manager: &StorageManager, + backend: &Backend, bind: Bind, ) -> PgResult<()> { - let describe = manager.describe_statement(bind.statement_name.as_deref())?; - let params = decode_parameter_values( - &bind.parameters, - &describe.param_oids, + backend.bind( + bind.statement_name, + bind.portal_name, + bind.parameters, &bind.parameter_format_codes, + &bind.result_column_format_codes, )?; - let result_format = - prepare_encoding_format(&bind.result_column_format_codes, describe.ncolumns())?; - manager.bind( - bind.statement_name.as_deref(), - bind.portal_name.as_deref(), - params, - result_format, - )?; stream.write_message_noflush(messages::bind_complete())?; Ok(()) } pub fn process_execute_message( stream: &mut PgStream<impl Read + Write>, - manager: &StorageManager, + backend: &Backend, execute: Execute, ) -> PgResult<()> { - let mut count = execute.max_rows as i64; - let mut execute_result = manager.execute(execute.name.as_deref())?; - if count <= 0 { - count = std::i64::MAX; - } + let max_rows = execute.max_rows as i64; + let mut execute_result = backend.execute(execute.name, max_rows)?; - for _ in 0..count { - if let Some(row) = execute_result.next() { - stream.write_message_noflush(messages::data_row(row))?; - } else { - break; - } + for row in execute_result.by_ref() { + stream.write_message_noflush(messages::data_row(row))?; } if execute_result.is_portal_finished() { @@ -129,10 +69,10 @@ pub fn process_execute_message( } fn describe_statement( - manager: &StorageManager, + backend: &Backend, statement: Option<&str>, ) -> PgResult<(BeMessage, BeMessage)> { - let stmt_describe = manager.describe_statement(statement)?; + let stmt_describe = backend.describe_statement(statement)?; let param_oids = stmt_describe.param_oids; let describe = stmt_describe.describe; @@ -147,8 +87,8 @@ fn describe_statement( } } -fn describe_portal(manager: &StorageManager, portal: Option<&str>) -> PgResult<BeMessage> { - let describe = manager.describe_portal(portal)?; +fn describe_portal(backend: &Backend, portal: Option<&str>) -> PgResult<BeMessage> { + let describe = backend.describe_portal(portal)?; if let Some(row_description) = describe.row_description()? { Ok(messages::row_description(row_description)) } else { @@ -158,19 +98,19 @@ fn describe_portal(manager: &StorageManager, portal: Option<&str>) -> PgResult<B pub fn process_describe_message( stream: &mut PgStream<impl Read + Write>, - manager: &StorageManager, + backend: &Backend, describe: Describe, ) -> PgResult<()> { let name = describe.name.as_deref(); match describe.target_type { b'S' => { - let (params_desc, rows_desc) = describe_statement(manager, name)?; + let (params_desc, rows_desc) = describe_statement(backend, name)?; stream.write_message_noflush(params_desc)?; stream.write_message_noflush(rows_desc)?; Ok(()) } b'P' => { - let rows_desc = describe_portal(manager, name)?; + let rows_desc = describe_portal(backend, name)?; stream.write_message_noflush(rows_desc)?; Ok(()) } @@ -183,13 +123,13 @@ pub fn process_describe_message( pub fn process_close_message( stream: &mut PgStream<impl Read + Write>, - manager: &StorageManager, + backend: &Backend, close: Close, ) -> PgResult<()> { let name = close.name.as_deref(); match close.target_type { - b'S' => manager.close_statement(name)?, - b'P' => manager.close_portal(name)?, + b'S' => backend.close_statement(name), + b'P' => backend.close_portal(name), _ => { return Err(PgError::ProtocolViolation(format!( "unknown close type \'{}\'", @@ -201,13 +141,13 @@ pub fn process_close_message( Ok(()) } -pub fn process_sync_mesage(manager: &StorageManager) -> PgResult<()> { +pub fn process_sync_mesage(backend: &Backend) { // By default, PG runs in autocommit mode, which means that every statement is ran inside its own transaction. // In simple query statement means the query inside a Query message. // In extended query statement means everything before a Sync message. // When PG gets a Sync mesage it finishes the current transaction by calling finish_xact_command, // which drops all non-holdable portals. We close all portals here because we don't have the holdable portals. - manager.close_all_portals() + backend.close_all_portals() } pub fn is_extended_query_message(message: &FeMessage) -> bool { diff --git a/src/pgproto/client/simple_query.rs b/src/pgproto/client/simple_query.rs index 6084527081fc5e74f1297104b0291d69d0c0e44d..7ab7c17904096a3e1501989f1748bcd0d9f86389 100644 --- a/src/pgproto/client/simple_query.rs +++ b/src/pgproto/client/simple_query.rs @@ -1,14 +1,14 @@ -use crate::pgproto::storage::StorageManager; +use crate::pgproto::backend::Backend; use crate::pgproto::{error::PgResult, messages, stream::PgStream}; use pgwire::messages::simplequery::Query; use std::io::{Read, Write}; pub fn process_query_message( stream: &mut PgStream<impl Read + Write>, - manager: &StorageManager, + backend: &Backend, query: Query, ) -> PgResult<()> { - let mut query_result = manager.simple_query(&query.query)?; + let mut query_result = backend.simple_query(query.query)?; if let Some(row_description) = query_result.row_description()? { let row_description = messages::row_description(row_description); diff --git a/src/pgproto/entrypoints.rs b/src/pgproto/entrypoints.rs deleted file mode 100644 index c4b6b65c36906d2489d48e5eb6cf238386878076..0000000000000000000000000000000000000000 --- a/src/pgproto/entrypoints.rs +++ /dev/null @@ -1,437 +0,0 @@ -use super::{ - backend::describe::{PortalDescribe, QueryType, StatementDescribe}, - client::ClientId, - error::{PgError, PgResult}, - storage::{ - result::ExecuteResult, - value::{Format, PgValue}, - }, -}; -use postgres_types::Oid; -use serde::Deserialize; -use serde_json::Value; -use std::cell::RefCell; -use tarantool::tlua::{LuaFunction, LuaThread, PushGuard}; - -type Row = Vec<PgValue>; - -#[derive(Deserialize)] -struct RawExecuteResult { - describe: PortalDescribe, - // tuple in the same format as tuples returned from pico.sql - result: Value, -} - -struct DqlResult { - rows: Vec<Row>, - is_finished: bool, -} - -fn parse_dql(res: Value) -> PgResult<DqlResult> { - #[derive(Deserialize)] - struct RawDqlResult { - rows: Vec<Vec<Value>>, - is_finished: bool, - } - - let res: RawDqlResult = serde_json::from_value(res)?; - let rows: PgResult<Vec<Row>> = res - .rows - .into_iter() - .map(|row| row.into_iter().map(PgValue::try_from).collect()) - .collect(); - - rows.map(|rows| DqlResult { - rows, - is_finished: res.is_finished, - }) -} - -fn parse_dml(res: Value) -> PgResult<usize> { - #[derive(Deserialize)] - struct RawDmlResult { - row_count: usize, - } - - let res: RawDmlResult = serde_json::from_value(res)?; - Ok(res.row_count) -} - -fn parse_explain(res: Value) -> PgResult<DqlResult> { - #[derive(Deserialize)] - struct RawExplainResult { - rows: Vec<Value>, - is_finished: bool, - } - - let res: RawExplainResult = serde_json::from_value(res)?; - let rows: PgResult<Vec<Row>> = res - .rows - .into_iter() - // every row must be a vector - .map(|val| Ok(vec![PgValue::try_from(val)?])) - .collect(); - - rows.map(|rows| DqlResult { - rows, - is_finished: res.is_finished, - }) -} - -fn execute_result_from_raw_result(raw: RawExecuteResult) -> PgResult<ExecuteResult> { - match raw.describe.query_type() { - QueryType::Dql => { - let res = parse_dql(raw.result)?; - Ok(ExecuteResult::new(res.rows, raw.describe, res.is_finished)) - } - QueryType::Explain => { - let res = parse_explain(raw.result)?; - Ok(ExecuteResult::new(res.rows, raw.describe, res.is_finished)) - } - QueryType::Acl | QueryType::Ddl => Ok(ExecuteResult::empty(0, raw.describe)), - QueryType::Dml => Ok(ExecuteResult::empty(parse_dml(raw.result)?, raw.describe)), - } -} - -fn execute_result_from_json(json: &str) -> PgResult<ExecuteResult> { - let raw: RawExecuteResult = serde_json::from_str(json)?; - execute_result_from_raw_result(raw) -} - -fn simple_execute_result_from_json(json: &str) -> PgResult<ExecuteResult> { - let mut raw: RawExecuteResult = serde_json::from_str(json)?; - // Simple query supports only the text format. - // We couldn't set the format when we were calling bind, because we didn't know the number of columns, - // but after executing the whole simple query pipeline we have a description containing this number. - raw.describe.set_text_output_format(); - execute_result_from_raw_result(raw) -} - -type Entrypoint = LuaFunction<PushGuard<LuaThread>>; - -/// List of lua functions from sbroad that implement PG protcol API. -pub struct Entrypoints { - /// Handler for a Query message. - /// First, it closes an unnamed portal and statement, just like PG does when gets a Query messsage. - /// After that the extended pipeline is executed on unnamed portal and statement: parse + bind + describe + execute. - /// It returns the query result (in the same format as pico.sql) and description. - /// We need the description here for the command tag (CREATE TABLE, SELECT, etc) - /// that is required to make a CommandComplete message as the response to the Query message, - /// so pico.sql is not enough. - /// - /// No resources to be free after the call. - simple_query: Entrypoint, - - /// Handler for a Parse message. - /// Create a statement from a query query and store it in the sbroad storage using the given id and name as a key. - /// In case of conflicts the strategy is the same with PG. - /// - /// The statement lasts until it is explicitly closed. - parse: Entrypoint, - - /// Handler for a Bind message. - /// Copy the sources statement, create a portal by binding parameters to it and stores the portal in the sbroad storage. - /// In case of conflicts the strategy is the same with PG. - /// - /// The portal lasts until it is explicitly closed or executed. - bind: Entrypoint, - - /// Handler for an Execute message. - /// - /// Remove a portal from the sbroad storage, run it til the end and return the result. - execute: Entrypoint, - - /// Handler for a Describe message. - /// Get a statement description. - describe_statement: Entrypoint, - - /// Handler for a Describe message. - /// Get a portal description. - describe_portal: Entrypoint, - - /// Handler for a Close message. - /// Close a portal. It's not an error to close a nonexistent portal. - close_portal: Entrypoint, - - /// Handler for a Close message. - /// Close a statement with its portals. It's not an error to close a nonexistent statement. - close_statement: Entrypoint, - - /// Close client portals by the given client id. - close_client_portals: Entrypoint, - - /// Close client statements with its portals by the given client id. - close_client_statements: Entrypoint, -} - -impl Entrypoints { - fn new() -> PgResult<Self> { - let simple_query = LuaFunction::load( - tarantool::lua_state(), - " - local function close_unnamed(client_id) - pico.pg_close_stmt(client_id, '') - pico.pg_close_portal(client_id, '') - end - - local function parse_and_execute_unnamed(client_id, sql) - local res, err = pico.pg_parse(client_id, '', sql, {}) - if res == nil then - return nil, err - end - - -- {}, {} => no parameters, default result encoding (text) - local res, err = pico.pg_bind(client_id, '', '', {}, {}) - if res == nil then - return nil, err - end - - local desc, err = pico.pg_describe_portal(client_id, '') - if desc == nil then - return nil, err - end - - -- -1 == fetch all - local res, err = pico.pg_execute(client_id, '', -1) - if res == nil then - return nil, err - end - - return {['describe'] = desc, ['result'] = res} - end - - local client_id, sql = ... - - -- Strictly speaking, this is a part of an extended query protocol. - -- When a query message is received, PG closes an unnamed portal and statement - -- and runs the extended query pipeline on them (parse + bind + execute). - close_unnamed(client_id) - - local res, err = parse_and_execute_unnamed(client_id, sql) - - -- After the execution, the portal and statement must be closed. - close_unnamed(client_id) - - if res == nil then - error(err) - end - - return require('json').encode(res) - ", - )?; - - let close_client_statements = LuaFunction::load( - tarantool::lua_state(), - " - -- closing a statement closes its portals, - -- so then we close all the statements we close all the portals too. - pico.pg_close_client_stmts(...) - ", - )?; - - let parse = LuaFunction::load( - tarantool::lua_state(), - " - local res, err = pico.pg_parse(...) - if res == nil then - error(err) - end - ", - )?; - - let bind = LuaFunction::load( - tarantool::lua_state(), - " - local res, err = pico.pg_bind(...) - if res == nil then - error(err) - end - ", - )?; - - let execute = LuaFunction::load( - tarantool::lua_state(), - " - local id, portal = ... - local desc, err = pico.pg_describe_portal(id, portal) - if desc == nil then - error(err) - end - - -- -1 == fetch all - local res, err = pico.pg_execute(id, portal, -1) - if res == nil then - error(err) - end - - return require('json').encode({['describe'] = desc, ['result'] = res}) - ", - )?; - - let describe_portal = LuaFunction::load( - tarantool::lua_state(), - " - local res, err = pico.pg_describe_portal(...) - if res == nil then - error(err) - end - return require('json').encode(res) - ", - )?; - - let describe_statement = LuaFunction::load( - tarantool::lua_state(), - " - local res, err = pico.pg_describe_stmt(...) - if res == nil then - error(err) - end - return require('json').encode(res) - ", - )?; - - let close_portal = LuaFunction::load( - tarantool::lua_state(), - " - local res, err = pico.pg_close_portal(...) - if res == nil then - error(err) - end - ", - )?; - - let close_statement = LuaFunction::load( - tarantool::lua_state(), - " - local res, err = pico.pg_close_stmt(...) - if res == nil then - error(err) - end - ", - )?; - - let close_client_portals = LuaFunction::load( - tarantool::lua_state(), - " - local res, err = pico.pg_close_client_portals(...) - if res == nil then - error(err) - end - ", - )?; - - Ok(Self { - simple_query, - close_client_statements, - parse, - bind, - execute, - describe_portal, - describe_statement, - close_portal, - close_statement, - close_client_portals, - }) - } - - /// Handler for a Query message. See self.simple_query for the details. - pub fn simple_query(&self, client_id: ClientId, sql: &str) -> PgResult<ExecuteResult> { - let json: String = self - .simple_query - .call_with_args((client_id, sql)) - .map_err(|e| PgError::LuaError(e.into()))?; - simple_execute_result_from_json(&json) - } - - /// Handler for a Parse message. See self.parse for the details. - pub fn parse( - &self, - client_id: ClientId, - name: &str, - sql: &str, - param_oids: &[Oid], - ) -> PgResult<()> { - self.parse - .call_with_args((client_id, name, sql, param_oids)) - .map_err(|e| PgError::LuaError(e.into())) - } - - /// Handler for a Bind message. See self.bind for the details. - pub fn bind( - &self, - id: ClientId, - statement: &str, - portal: &str, - params: Vec<PgValue>, - result_format: Vec<Format>, - ) -> PgResult<()> { - self.bind - .call_with_args((id, statement, portal, params, result_format)) - .map_err(|e| PgError::LuaError(e.into())) - } - - /// Handler for an Execute message. See self.execute for the details. - pub fn execute(&self, id: ClientId, portal: &str) -> PgResult<ExecuteResult> { - let json: String = self - .execute - .call_with_args((id, portal)) - .map_err(|e| PgError::LuaError(e.into()))?; - execute_result_from_json(&json) - } - - /// Handler for a Describe message. See self.describe_portal for the details. - pub fn describe_portal(&self, client_id: ClientId, portal: &str) -> PgResult<PortalDescribe> { - let json: String = self - .describe_portal - .call_with_args((client_id, portal)) - .map_err(|e| PgError::LuaError(e.into()))?; - let describe = serde_json::from_str(&json)?; - Ok(describe) - } - - /// Handler for a Describe message. See self.describe_statement for the details. - pub fn describe_statement( - &self, - client_id: ClientId, - statement: &str, - ) -> PgResult<StatementDescribe> { - let json: String = self - .describe_statement - .call_with_args((client_id, statement)) - .map_err(|e| PgError::LuaError(e.into()))?; - let describe = serde_json::from_str(&json)?; - Ok(describe) - } - - /// Handler for a Close message. See self.close_portal for the details. - pub fn close_portal(&self, id: ClientId, portal: &str) -> PgResult<()> { - self.close_portal - .call_with_args((id, portal)) - .map_err(|e| PgError::LuaError(e.into())) - } - - /// Handler for a Close message. See self.close_statement for the details. - pub fn close_statement(&self, client_id: ClientId, statement: &str) -> PgResult<()> { - self.close_statement - .call_with_args((client_id, statement)) - .map_err(|e| PgError::LuaError(e.into())) - } - - /// Close all the client statements and portals. See self.close_client_statements for the details. - pub fn close_client_statements(&self, client_id: ClientId) -> PgResult<()> { - self.close_client_statements - .call_with_args(client_id) - .map_err(|e| PgError::LuaError(e.into())) - } - - /// Close client statements with its portals. - pub fn close_client_portals(&self, client_id: ClientId) -> PgResult<()> { - self.close_client_portals - .call_with_args(client_id) - .map_err(|e| PgError::LuaError(e.into())) - } -} - -thread_local! { - pub static PG_ENTRYPOINTS: RefCell<Entrypoints> = RefCell::new(Entrypoints::new().unwrap()) -} diff --git a/src/pgproto/error.rs b/src/pgproto/error.rs index bc25864d6a1ec548085df969af5f0417fba30973..a3a9e1f5166073707fac6f5230419024a9f3874d 100644 --- a/src/pgproto/error.rs +++ b/src/pgproto/error.rs @@ -14,6 +14,9 @@ pub type PgResult<T> = Result<T, PgError>; /// See <https://www.postgresql.org/docs/current/errcodes-appendix.html>. #[derive(Error, Debug)] pub enum PgError { + #[error("internal error: {0}")] + InternalError(String), + #[error("protocol violation: {0}")] ProtocolViolation(String), @@ -26,6 +29,7 @@ pub enum PgError { #[error("IO error: {0}")] IoError(#[from] io::Error), + // Server could not encode value into client's format. #[error("encoding error: {0}")] EncodingError(Box<dyn error::Error>), @@ -38,6 +42,7 @@ pub enum PgError { #[error("json error: {0}")] JsonError(#[from] serde_json::Error), + // Server could not decode value recieved from client. #[error("{0}")] DecodingError(#[from] DecodingError), @@ -99,6 +104,7 @@ impl PgError { fn code(&self) -> &str { use PgError::*; match self { + InternalError(_) => "XX000", ProtocolViolation(_) => "08P01", FeatureNotSupported(_) => "0A000", InvalidPassword(_) => "28P01", diff --git a/src/pgproto/storage.rs b/src/pgproto/storage.rs deleted file mode 100644 index ce2abcf2696d0dc8195c6173ac9d1b633fbab048..0000000000000000000000000000000000000000 --- a/src/pgproto/storage.rs +++ /dev/null @@ -1,131 +0,0 @@ -use crate::tlog; - -use self::result::ExecuteResult; -use self::value::{Format, PgValue}; -use super::backend::describe::{PortalDescribe, StatementDescribe}; -use super::client::ClientId; -use super::entrypoints::PG_ENTRYPOINTS; -use super::error::PgResult; -use postgres_types::Oid; -use std::sync::atomic::{AtomicU32, Ordering}; - -pub mod result; -pub mod value; - -fn unique_id() -> ClientId { - static ID_COUNTER: AtomicU32 = AtomicU32::new(0); - ID_COUNTER.fetch_add(1, Ordering::Relaxed) -} - -/// It allows to interact with the real storage in sbroad -/// and cleanups all the created portals and statements when it's dropped. -pub struct StorageManager { - /// Unique id of a PG client. - /// Every portal and statement is tagged by this id in the sbroad storage, - /// so they all can be found and deleted using this id. - /// Since the id is unique the statements and portals are isolated between users. - client_id: ClientId, -} - -impl StorageManager { - pub fn new() -> Self { - Self { - client_id: unique_id(), - } - } - - /// Handler for a Query message. See Entrypoints::simple_query for the details. - pub fn simple_query(&self, sql: &str) -> PgResult<ExecuteResult> { - PG_ENTRYPOINTS.with(|entrypoints| entrypoints.borrow().simple_query(self.client_id, sql)) - } - - pub fn describe_statement(&self, name: Option<&str>) -> PgResult<StatementDescribe> { - PG_ENTRYPOINTS.with(|entrypoints| { - entrypoints - .borrow() - .describe_statement(self.client_id, name.unwrap_or("")) - }) - } - - pub fn describe_portal(&self, name: Option<&str>) -> PgResult<PortalDescribe> { - PG_ENTRYPOINTS.with(|entrypoints| { - entrypoints - .borrow() - .describe_portal(self.client_id, name.unwrap_or("")) - }) - } - - pub fn parse(&self, name: Option<&str>, sql: &str, param_oids: &[Oid]) -> PgResult<()> { - PG_ENTRYPOINTS.with(|entrypoints| { - entrypoints - .borrow() - .parse(self.client_id, name.unwrap_or(""), sql, param_oids) - }) - } - - pub fn bind( - &self, - statement: Option<&str>, - portal: Option<&str>, - params: Vec<PgValue>, - result_format: Vec<Format>, - ) -> PgResult<()> { - PG_ENTRYPOINTS.with(|entrypoints| { - entrypoints.borrow().bind( - self.client_id, - statement.unwrap_or(""), - portal.unwrap_or(""), - params, - result_format, - ) - }) - } - - pub fn execute(&self, portal: Option<&str>) -> PgResult<ExecuteResult> { - PG_ENTRYPOINTS.with(|entrypoints| { - entrypoints - .borrow() - .execute(self.client_id, portal.unwrap_or("")) - }) - } - - pub fn close_portal(&self, portal: Option<&str>) -> PgResult<()> { - PG_ENTRYPOINTS.with(|entrypoints| { - entrypoints - .borrow() - .close_portal(self.client_id, portal.unwrap_or("")) - }) - } - - pub fn close_statement(&self, statement: Option<&str>) -> PgResult<()> { - PG_ENTRYPOINTS.with(|entrypoints| { - entrypoints - .borrow() - .close_statement(self.client_id, statement.unwrap_or("")) - }) - } - - pub fn close_all_portals(&self) -> PgResult<()> { - PG_ENTRYPOINTS.with(|entrypoints| entrypoints.borrow().close_client_portals(self.client_id)) - } - - fn on_disconnect(&self) -> PgResult<()> { - // Close all the statements with its portals. - PG_ENTRYPOINTS - .with(|entrypoints| entrypoints.borrow().close_client_statements(self.client_id)) - } -} - -impl Drop for StorageManager { - fn drop(&mut self) { - match self.on_disconnect() { - Ok(_) => {} - Err(err) => tlog!( - Warning, - "failed to close user {} statements and portals: {:?}", - self.client_id, - err - ), - } - } -} diff --git a/src/pgproto/storage/value.rs b/src/pgproto/value.rs similarity index 76% rename from src/pgproto/storage/value.rs rename to src/pgproto/value.rs index 8b798ad4e65651053c2c84adc231a07463be6498..c4258374f468b9c2604679e706c7dd8886614b8e 100644 --- a/src/pgproto/storage/value.rs +++ b/src/pgproto/value.rs @@ -7,7 +7,6 @@ use sbroad::ir::value::Value; use serde_repr::{Deserialize_repr, Serialize_repr}; use std::error::Error; use std::str; -use tarantool::tlua::{AsLua, Nil, PushInto}; use crate::pgproto::error::{DecodingError, PgError, PgResult}; @@ -27,22 +26,14 @@ pub fn type_from_name(name: &str) -> PgResult<Type> { /// This type is used to send Format over the wire. pub type RawFormat = i16; -#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr)] +#[derive(Debug, Clone, Copy, Default, Serialize_repr, Deserialize_repr)] #[repr(i16)] pub enum Format { + #[default] Text = 0, Binary = 1, } -impl<L: AsLua> PushInto<L> for Format { - type Err = tarantool::tlua::Void; - - fn push_into_lua(self, lua: L) -> Result<tarantool::tlua::PushGuard<L>, (Self::Err, L)> { - let value = self as RawFormat; - value.push_into_lua(lua) - } -} - impl TryFrom<RawFormat> for Format { type Error = PgError; fn try_from(value: RawFormat) -> Result<Self, Self::Error> { @@ -60,6 +51,10 @@ impl TryFrom<RawFormat> for Format { pub struct PgValue(sbroad::ir::value::Value); impl PgValue { + pub fn into_inner(self) -> Value { + self.0 + } + fn integer(value: i64) -> Self { Self(Value::Integer(value)) } @@ -81,30 +76,39 @@ impl PgValue { } } -impl TryFrom<serde_json::Value> for PgValue { +impl TryFrom<rmpv::Value> for PgValue { type Error = PgError; - fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> { - let ret = match value { - serde_json::Value::Number(number) => { - if number.is_f64() { - PgValue::float(number.as_f64().unwrap()) - } else if number.is_i64() { - PgValue::integer(number.as_i64().unwrap()) + fn try_from(value: rmpv::Value) -> Result<Self, Self::Error> { + match value { + rmpv::Value::Nil => Ok(PgValue::null()), + rmpv::Value::Boolean(v) => Ok(PgValue::boolean(v)), + rmpv::Value::F32(v) => Ok(PgValue::float(v.into())), + rmpv::Value::F64(v) => Ok(PgValue::float(v)), + rmpv::Value::Integer(v) => { + let i = if v.is_i64() { + v.as_i64().unwrap() + } else if v.is_u64() { + // NOTE: u64::MAX can't be converted into i64 + i64::try_from(v.as_u64().unwrap()) + .map_err(|e| PgError::EncodingError(e.into()))? } else { - Err(PgError::FeatureNotSupported(format!( - "unsupported type {number}" - )))? - } + Err(PgError::EncodingError( + format!("couldn't encode integer: {v:?}").into(), + ))? + }; + Ok(PgValue::integer(i)) } - serde_json::Value::String(string) => PgValue::text(string), - serde_json::Value::Bool(bool) => PgValue::boolean(bool), - serde_json::Value::Null => PgValue::null(), - _ => Err(PgError::FeatureNotSupported(format!( - "unsupported type {value}" - )))?, - }; - Ok(ret) + rmpv::Value::String(v) => { + let Some(s) = v.as_str() else { + Err(PgError::EncodingError( + format!("couldn't encode string: {v:?}").into(), + ))? + }; + Ok(PgValue::text(s.to_owned())) + } + value => Err(PgError::FeatureNotSupported(format!("value: {value:?}"))), + } } } @@ -224,20 +228,3 @@ impl PgValue { } } } - -impl<L: AsLua> PushInto<L> for PgValue { - type Err = tarantool::tlua::Void; - - fn push_into_lua(self, lua: L) -> Result<tarantool::tlua::PushGuard<L>, (Self::Err, L)> { - match self.0 { - Value::Boolean(value) => value.push_into_lua(lua), - Value::String(value) => value.push_into_lua(lua), - Value::Integer(value) => value.push_into_lua(lua), - Value::Double(double) => double.value.push_into_lua(lua), - Value::Null => PushInto::push_into_lua(Nil, lua), - // Let's just panic for now. Anyway, we will get rid of this PushInto impl after - // we get rid of the lua entrypoints in the next merge request. - _ => panic!("unsupported value"), - } - } -} diff --git a/test/int/test_pgproto.py b/test/int/test_pgproto.py index 6e4a6db264cbb003ba87611ff65cb5ef7441c372..5b37e222eb99d238c7ed6ef29d11185cdaf959de 100644 --- a/test/int/test_pgproto.py +++ b/test/int/test_pgproto.py @@ -24,7 +24,7 @@ def test_extended_ddl(pg_client: PgClient): pg_client.bind("", "portal", [], []) assert len(pg_client.portals["available"]) == 1 data = pg_client.execute("portal") - assert data["row_count"] == 1 + assert data["row_count"] is None assert len(pg_client.statements["available"]) == 1 pg_client.close_stmt("") @@ -421,12 +421,12 @@ def test_interactive_portals(pg_client: PgClient): assert len(data["rows"]) == 1 assert [ """projection ("t"."key"::integer -> "key", "t"."value"::string -> "value")""" - ] == data["rows"] + ] == data["rows"][0] assert data["is_finished"] is False data = pg_client.execute("", -1) assert len(data["rows"]) == 4 - assert """ scan "t\"""" in data["rows"] - assert """execution options:""" in data["rows"] - assert """sql_vdbe_max_steps = 45000""" in data["rows"] - assert """vtable_max_rows = 5000""" in data["rows"] + assert [""" scan "t\""""] == data["rows"][0] + assert ["""execution options:"""] == data["rows"][1] + assert ["""sql_vdbe_max_steps = 45000"""] == data["rows"][2] + assert ["""vtable_max_rows = 5000"""] == data["rows"][3] assert data["is_finished"] is True