From b36c87ce468565b998be8cca29e7ec838516dd7b Mon Sep 17 00:00:00 2001 From: Kaitmazian Maksim <m.kaitmazian@picodata.io> Date: Thu, 18 Jan 2024 19:18:53 +0300 Subject: [PATCH] feat: implement basic extended query protcol This commit adds supportion of extended query messages and error handling in extended query pipeline. Features that are not implemented: * binary encoding * parameterized queries --- pgproto/src/client.rs | 135 ++++++++++++++--- pgproto/src/client/extended_query.rs | 163 +++++++++++++++++++++ pgproto/src/entrypoints.rs | 209 ++++++++++++++++++++++++++- pgproto/src/messages.rs | 30 +++- pgproto/src/storage.rs | 63 ++++++++ pgproto/src/storage/describe.rs | 11 +- pgproto/src/storage/result.rs | 4 + pgproto/test/extended_query_test.py | 91 ++++++++++++ 8 files changed, 685 insertions(+), 21 deletions(-) create mode 100644 pgproto/src/client/extended_query.rs create mode 100644 pgproto/test/extended_query_test.py diff --git a/pgproto/src/client.rs b/pgproto/src/client.rs index d2d9aaf3ce..eb96bbf79b 100644 --- a/pgproto/src/client.rs +++ b/pgproto/src/client.rs @@ -7,6 +7,7 @@ use pgwire::messages::startup::*; use std::io; mod auth; +mod extended_query; mod simple_query; mod startup; @@ -18,6 +19,8 @@ pub struct PgClient<S> { manager: StorageManager, /// Stream for network communication. stream: PgStream<S>, + + loop_state: MessageLoopState, } impl<S: io::Read + io::Write> PgClient<S> { @@ -43,6 +46,7 @@ impl<S: io::Read + io::Write> PgClient<S> { Ok(PgClient { manager: StorageManager::new(), + loop_state: MessageLoopState::ReadyForQuery, stream, }) } @@ -61,8 +65,9 @@ impl<S: io::Read + io::Write> PgClient<S> { } #[derive(PartialEq)] -enum ConnectionState { +enum MessageLoopState { ReadyForQuery, + RunningExtendedQuery, Terminated, } @@ -79,37 +84,133 @@ impl PgError { impl<S: io::Read + io::Write> PgClient<S> { /// Receive a single message, process it, then send a proper response. - fn process_message(&mut self) -> PgResult<ConnectionState> { + fn process_message(&mut self) -> PgResult<()> { let message = self.stream.read_message()?; log::debug!("received {message:?}"); + + if self.is_running_extended_query() && !extended_query::is_extended_query_message(&message) + { + // According to the protocol, the extended query is expected to be finished by getting + // a Sync message, but the frontend can send a simple query message before Sync, + // which will finish the pipeline. In that case Postgres just changes the state without any + // errors or warnings. We can follow the Postgres way, but I think a warning might be helpful. + // + // See the discussion about getting a Query message while running extended query: + // https://postgrespro.com/list/thread-id/2416958. + log::warn!("got {message:?} message while running extended query"); + } + match message { FeMessage::Query(query) => { - log::info!("executing query"); + log::info!("executing simple query: {}", query.query()); process_query_message(&mut self.stream, &self.manager, query)?; - Ok(ConnectionState::ReadyForQuery) + self.loop_state = MessageLoopState::ReadyForQuery; + } + FeMessage::Parse(parse) => { + log::info!( + "parsing query \'{}\': {}", + parse.name().as_deref().unwrap_or_default(), + parse.query(), + ); + self.loop_state = MessageLoopState::RunningExtendedQuery; + extended_query::process_parse_message(&mut self.stream, &self.manager, parse)?; + } + FeMessage::Bind(bind) => { + log::info!( + "binding statement \'{}\' to portal \'{}\'", + bind.statement_name().as_deref().unwrap_or_default(), + bind.portal_name().as_deref().unwrap_or_default() + ); + self.loop_state = MessageLoopState::RunningExtendedQuery; + extended_query::process_bind_message(&mut self.stream, &self.manager, bind)?; + } + FeMessage::Execute(execute) => { + log::info!( + "executing portal \'{}\'", + execute.name().as_deref().unwrap_or_default() + ); + self.loop_state = MessageLoopState::RunningExtendedQuery; + extended_query::process_execute_message(&mut self.stream, &self.manager, execute)?; + } + FeMessage::Describe(describe) => { + log::info!( + "describing {} \'{}\'", + describe.target_type(), + describe.name().as_deref().unwrap_or_default() + ); + self.loop_state = MessageLoopState::RunningExtendedQuery; + extended_query::process_describe_message( + &mut self.stream, + &self.manager, + describe, + )?; + } + FeMessage::Close(close) => { + log::info!( + "closing {} \'{}\'", + close.target_type(), + close.name().as_deref().unwrap_or_default() + ); + self.loop_state = MessageLoopState::RunningExtendedQuery; + extended_query::process_close_message(&mut self.stream, &self.manager, close)?; + } + FeMessage::Flush(_) => { + log::info!("flushing"); + self.loop_state = MessageLoopState::RunningExtendedQuery; + self.stream.flush()?; + } + FeMessage::Sync(_) => { + log::info!("syncing"); + self.loop_state = MessageLoopState::ReadyForQuery; + extended_query::process_sync_mesage(&self.manager)?; } FeMessage::Terminate(_) => { log::info!("terminating the session"); - Ok(ConnectionState::Terminated) + self.loop_state = MessageLoopState::Terminated; } - message => Err(PgError::FeatureNotSupported(format!("{message:?}"))), - } + message => return Err(PgError::FeatureNotSupported(format!("{message:?}"))), + }; + Ok(()) + } + + fn process_error(&mut self, error: PgError) -> PgResult<()> { + log::info!("processing error: {error:?}"); + self.stream + .write_message(messages::error_response(error.info()))?; + error.check_fatality()?; + if let MessageLoopState::RunningExtendedQuery = self.loop_state { + loop { + if let FeMessage::Sync(_) = self.stream.read_message()? { + self.loop_state = MessageLoopState::ReadyForQuery; + extended_query::process_sync_mesage(&self.manager)?; + break; + } + } + }; + Ok(()) + } + + fn is_terminated(&self) -> bool { + matches!(self.loop_state, MessageLoopState::Terminated) + } + + fn is_running_extended_query(&self) -> bool { + matches!(self.loop_state, MessageLoopState::RunningExtendedQuery) } /// Process incoming client messages until we see an irrecoverable error. pub fn process_messages_loop(&mut self) -> PgResult<()> { log::info!("entering the message handling loop"); - loop { - self.stream.write_message(messages::ready_for_query())?; - match self.process_message() { - Ok(ConnectionState::ReadyForQuery) => continue, - Ok(ConnectionState::Terminated) => break Ok(()), - Err(error) => { - self.stream - .write_message(messages::error_response(error.info()))?; - error.check_fatality()?; - } + while !self.is_terminated() { + if let MessageLoopState::ReadyForQuery = self.loop_state { + self.stream.write_message(messages::ready_for_query())?; } + + match self.process_message() { + Ok(_) => continue, + Err(error) => self.process_error(error)?, + }; } + Ok(()) } } diff --git a/pgproto/src/client/extended_query.rs b/pgproto/src/client/extended_query.rs new file mode 100644 index 0000000000..a4ee18feef --- /dev/null +++ b/pgproto/src/client/extended_query.rs @@ -0,0 +1,163 @@ +use crate::stream::{BeMessage, FeMessage}; +use crate::{ + error::{PgError, PgResult}, + messages, + storage::StorageManager, + stream::PgStream, +}; +use pgwire::messages::extendedquery::{Bind, Close, Describe, Execute, Parse}; +use std::io; + +pub fn process_parse_message( + stream: &mut PgStream<impl io::Write>, + manager: &StorageManager, + parse: Parse, +) -> PgResult<()> { + if !parse.type_oids().is_empty() { + return Err(PgError::FeatureNotSupported("parameterized queries".into())); + } + manager.parse(parse.name().as_deref(), parse.query())?; + stream.write_message_noflush(messages::parse_complete())?; + Ok(()) +} + +pub fn process_bind_message( + stream: &mut PgStream<impl io::Write>, + manager: &StorageManager, + bind: Bind, +) -> PgResult<()> { + if !bind.parameters().is_empty() { + return Err(PgError::FeatureNotSupported("parameterized queries".into())); + } + + manager.bind( + bind.statement_name().as_deref(), + bind.portal_name().as_deref(), + )?; + stream.write_message_noflush(messages::bind_complete())?; + Ok(()) +} + +pub fn process_execute_message( + stream: &mut PgStream<impl io::Write>, + manager: &StorageManager, + execute: Execute, +) -> PgResult<()> { + let mut count = *execute.max_rows() as i64; + let mut portal = manager.execute(execute.name().as_deref())?; + if count <= 0 { + count = std::i64::MAX; + } + + for _ in 0..count { + if let Some(row) = portal.next() { + stream.write_message_noflush(messages::data_row(row))?; + } else { + break; + } + } + + if portal.is_empty() { + let tag = portal.command_tag().as_str(); + stream.write_message_noflush(messages::command_complete(tag, portal.row_count()))?; + } else { + stream.write_message_noflush(messages::portal_suspended())?; + } + + Ok(()) +} + +fn describe_statement( + manager: &StorageManager, + statement: Option<&str>, +) -> PgResult<(BeMessage, BeMessage)> { + let stmt_describe = manager.describe_statement(statement)?; + let param_oids = stmt_describe.param_oids; + let describe = stmt_describe.describe; + + let parameter_description = messages::parameter_description(param_oids); + if let Some(row_description) = describe.row_description()? { + Ok(( + parameter_description, + messages::row_description(row_description), + )) + } else { + Ok((parameter_description, messages::no_data())) + } +} + +fn describe_portal(manager: &StorageManager, portal: Option<&str>) -> PgResult<BeMessage> { + let describe = manager.describe_portal(portal)?; + if let Some(row_description) = describe.row_description()? { + Ok(messages::row_description(row_description)) + } else { + Ok(messages::no_data()) + } +} + +pub fn process_describe_message( + stream: &mut PgStream<impl io::Write>, + manager: &StorageManager, + describe: Describe, +) -> PgResult<()> { + let name = describe.name().as_deref(); + match describe.target_type() { + b'S' => { + let (params_desc, rows_desc) = describe_statement(manager, name)?; + stream.write_message_noflush(params_desc)?; + stream.write_message_noflush(rows_desc)?; + Ok(()) + } + b'P' => { + let rows_desc = describe_portal(manager, name)?; + stream.write_message_noflush(rows_desc)?; + Ok(()) + } + _ => Err(PgError::ProtocolViolation(format!( + "unknown describe type \'{}\'", + describe.target_type() + ))), + } +} + +pub fn process_close_message( + stream: &mut PgStream<impl io::Write>, + manager: &StorageManager, + close: Close, +) -> PgResult<()> { + let name = close.name().as_deref(); + match close.target_type() { + b'S' => manager.close_statement(name)?, + b'P' => manager.close_portal(name)?, + _ => { + return Err(PgError::ProtocolViolation(format!( + "unknown close type \'{}\'", + close.target_type() + ))); + } + } + stream.write_message_noflush(messages::close_complete())?; + Ok(()) +} + +pub fn process_sync_mesage(manager: &StorageManager) -> PgResult<()> { + // By default, PG runs in autocommit mode, which means that every statement is ran inside its own transaction. + // In simple query statement means the query inside a Query message. + // In extended query statement means everything before a Sync message. + // When PG gets a Sync mesage it finishes the current transaction by calling finish_xact_command, + // which drops all non-holdable portals. We close all portals here because we don't have the holdable portals. + manager.close_all_portals() +} + +pub fn is_extended_query_message(message: &FeMessage) -> bool { + matches!( + message, + FeMessage::Parse(_) + | FeMessage::Close(_) + | FeMessage::Bind(_) + | FeMessage::Describe(_) + | FeMessage::Execute(_) + | FeMessage::Flush(_) + | FeMessage::Sync(_) + ) +} diff --git a/pgproto/src/entrypoints.rs b/pgproto/src/entrypoints.rs index 49fdf0669e..ea85a0c2b3 100644 --- a/pgproto/src/entrypoints.rs +++ b/pgproto/src/entrypoints.rs @@ -4,7 +4,7 @@ use crate::{ client::ClientId, error::{PgError, PgResult}, storage::{ - describe::{Describe, QueryType}, + describe::{Describe, PortalDescribe, QueryType, StatementDescribe}, result::ExecuteResult, value::PgValue, }, @@ -83,7 +83,45 @@ pub struct Entrypoints { /// No resources to be free after the call. simple_query: Entrypoint, - /// Close client statements and portals by the given client id. + /// Handler for a Parse message. + /// Create a statement from a query query and store it in the sbroad storage using the given id and name as a key. + /// In case of conflicts the strategy is the same with PG. + /// + /// The statement lasts until it is explicitly closed. + parse: Entrypoint, + + /// Handler for a Bind message. + /// Copy the sources statement, create a portal by binding parameters to it and stores the portal in the sbroad storage. + /// In case of conflicts the strategy is the same with PG. + /// + /// The portal lasts until it is explicitly closed or executed. + bind: Entrypoint, + + /// Handler for an Execute message. + /// + /// Remove a portal from the sbroad storage, run it til the end and return the result. + execute: Entrypoint, + + /// Handler for a Describe message. + /// Get a statement description. + describe_statement: Entrypoint, + + /// Handler for a Describe message. + /// Get a portal description. + describe_portal: Entrypoint, + + /// Handler for a Close message. + /// Close a portal. It's not an error to close a nonexistent portal. + close_portal: Entrypoint, + + /// Handler for a Close message. + /// Close a statement with its portals. It's not an error to close a nonexistent statement. + close_statement: Entrypoint, + + /// Close client portals by the given client id. + close_client_portals: Entrypoint, + + /// Close client statements with its portals by the given client id. close_client_statements: Entrypoint, } @@ -150,9 +188,108 @@ impl Entrypoints { ", )?; + let parse = LuaFunction::load( + tarantool::lua_state(), + " + local client_id, name, sql = ... + local res, err = pico.pg_parse(client_id, name, sql, {}) + if res == nil then + error(err) + end + ", + )?; + + let bind = LuaFunction::load( + tarantool::lua_state(), + " + local res, err = pico.pg_bind(...) + if res == nil then + error(err) + end + ", + )?; + + let execute = LuaFunction::load( + tarantool::lua_state(), + " + local id, portal = ... + local desc, err = pico.pg_describe_portal(id, portal) + if desc == nil then + error(err) + end + + local res, err = pico.pg_execute(id, portal) + if res == nil then + error(err) + end + + return require('json').encode({['describe'] = desc, ['result'] = res}) + ", + )?; + + let describe_portal = LuaFunction::load( + tarantool::lua_state(), + " + local res, err = pico.pg_describe_portal(...) + if res == nil then + error(err) + end + return require('json').encode(res) + ", + )?; + + let describe_statement = LuaFunction::load( + tarantool::lua_state(), + " + local res, err = pico.pg_describe_stmt(...) + if res == nil then + error(err) + end + return require('json').encode(res) + ", + )?; + + let close_portal = LuaFunction::load( + tarantool::lua_state(), + " + local res, err = pico.pg_close_portal(...) + if res == nil then + error(err) + end + ", + )?; + + let close_statement = LuaFunction::load( + tarantool::lua_state(), + " + local res, err = pico.pg_close_stmt(...) + if res == nil then + error(err) + end + ", + )?; + + let close_client_portals = LuaFunction::load( + tarantool::lua_state(), + " + local res, err = pico.pg_close_client_portals(...) + if res == nil then + error(err) + end + ", + )?; + Ok(Self { simple_query, close_client_statements, + parse, + bind, + execute, + describe_portal, + describe_statement, + close_portal, + close_statement, + close_client_portals, }) } @@ -165,12 +302,80 @@ impl Entrypoints { execute_result_from_json(&json) } + /// Handler for a Parse message. See self.parse for the details. + pub fn parse(&self, client_id: ClientId, name: &str, sql: &str) -> PgResult<()> { + self.parse + .call_with_args((client_id, name, sql)) + .map_err(|e| PgError::TarantoolError(e.into())) + } + + /// Handler for a Bind message. See self.bind for the details. + pub fn bind(&self, id: ClientId, statement: &str, portal: &str) -> PgResult<()> { + self.bind + .call_with_args((id, statement, portal)) + .map_err(|e| PgError::TarantoolError(e.into())) + } + + /// Handler for an Execute message. See self.execute for the details. + pub fn execute(&self, id: ClientId, portal: &str) -> PgResult<ExecuteResult> { + let json: String = self + .execute + .call_with_args((id, portal)) + .map_err(|e| PgError::TarantoolError(e.into()))?; + execute_result_from_json(&json) + } + + /// Handler for a Describe message. See self.describe_portal for the details. + pub fn describe_portal(&self, client_id: ClientId, portal: &str) -> PgResult<PortalDescribe> { + let json: String = self + .describe_portal + .call_with_args((client_id, portal)) + .map_err(|e| PgError::TarantoolError(e.into()))?; + let describe = serde_json::from_str(&json)?; + Ok(describe) + } + + /// Handler for a Describe message. See self.describe_statement for the details. + pub fn describe_statement( + &self, + client_id: ClientId, + statement: &str, + ) -> PgResult<StatementDescribe> { + let json: String = self + .describe_statement + .call_with_args((client_id, statement)) + .map_err(|e| PgError::TarantoolError(e.into()))?; + let describe = serde_json::from_str(&json)?; + Ok(describe) + } + + /// Handler for a Close message. See self.close_portal for the details. + pub fn close_portal(&self, id: ClientId, portal: &str) -> PgResult<()> { + self.close_portal + .call_with_args((id, portal)) + .map_err(|e| PgError::TarantoolError(e.into())) + } + + /// Handler for a Close message. See self.close_statement for the details. + pub fn close_statement(&self, client_id: ClientId, statement: &str) -> PgResult<()> { + self.close_statement + .call_with_args((client_id, statement)) + .map_err(|e| PgError::TarantoolError(e.into())) + } + /// Close all the client statements and portals. See self.close_client_statements for the details. pub fn close_client_statements(&self, client_id: ClientId) -> PgResult<()> { self.close_client_statements .call_with_args(client_id) .map_err(|e| PgError::TarantoolError(e.into())) } + + /// Close client statements with its portals. + pub fn close_client_portals(&self, client_id: ClientId) -> PgResult<()> { + self.close_client_portals + .call_with_args(client_id) + .map_err(|e| PgError::TarantoolError(e.into())) + } } thread_local! { diff --git a/pgproto/src/messages.rs b/pgproto/src/messages.rs index 86e50d8bd2..065e01096e 100644 --- a/pgproto/src/messages.rs +++ b/pgproto/src/messages.rs @@ -1,8 +1,12 @@ use crate::stream::BeMessage; use pgwire::error::ErrorInfo; -use pgwire::messages::data::{DataRow, RowDescription}; +use pgwire::messages::data::{self, DataRow, ParameterDescription, RowDescription}; +use pgwire::messages::extendedquery::{ + BindComplete, CloseComplete, ParseComplete, PortalSuspended, +}; use pgwire::messages::response::SslResponse; use pgwire::messages::{response, startup::*}; +use postgres_types::Oid; /// MD5AuthRequest requests md5 password from the frontend. pub fn md5_auth_request(salt: &[u8; 4]) -> BeMessage { @@ -47,3 +51,27 @@ pub fn data_row(data_row: DataRow) -> BeMessage { pub fn ssl_refuse() -> BeMessage { BeMessage::SslResponse(SslResponse::Refuse) } + +pub fn parse_complete() -> BeMessage { + BeMessage::ParseComplete(ParseComplete::new()) +} + +pub fn no_data() -> BeMessage { + BeMessage::NoData(data::NoData::new()) +} + +pub fn bind_complete() -> BeMessage { + BeMessage::BindComplete(BindComplete::new()) +} + +pub fn portal_suspended() -> BeMessage { + BeMessage::PortalSuspended(PortalSuspended::new()) +} + +pub fn close_complete() -> BeMessage { + BeMessage::CloseComplete(CloseComplete::new()) +} + +pub fn parameter_description(type_ids: Vec<Oid>) -> BeMessage { + BeMessage::ParameterDescription(ParameterDescription::new(type_ids)) +} diff --git a/pgproto/src/storage.rs b/pgproto/src/storage.rs index 544b8a6fda..f06d4df66f 100644 --- a/pgproto/src/storage.rs +++ b/pgproto/src/storage.rs @@ -1,3 +1,4 @@ +use self::describe::{PortalDescribe, StatementDescribe}; use self::result::ExecuteResult; use crate::client::ClientId; use crate::entrypoints::PG_ENTRYPOINTS; @@ -36,6 +37,68 @@ impl StorageManager { PG_ENTRYPOINTS.with(|entrypoints| entrypoints.borrow().simple_query(self.client_id, sql)) } + pub fn describe_statement(&self, name: Option<&str>) -> PgResult<StatementDescribe> { + PG_ENTRYPOINTS.with(|entrypoints| { + entrypoints + .borrow() + .describe_statement(self.client_id, name.unwrap_or("")) + }) + } + + pub fn describe_portal(&self, name: Option<&str>) -> PgResult<PortalDescribe> { + PG_ENTRYPOINTS.with(|entrypoints| { + entrypoints + .borrow() + .describe_portal(self.client_id, name.unwrap_or("")) + }) + } + + pub fn parse(&self, name: Option<&str>, sql: &str) -> PgResult<()> { + PG_ENTRYPOINTS.with(|entrypoints| { + entrypoints + .borrow() + .parse(self.client_id, name.unwrap_or(""), sql) + }) + } + + pub fn bind(&self, statement: Option<&str>, portal: Option<&str>) -> PgResult<()> { + PG_ENTRYPOINTS.with(|entrypoints| { + entrypoints.borrow().bind( + self.client_id, + statement.unwrap_or(""), + portal.unwrap_or(""), + ) + }) + } + + pub fn execute(&self, portal: Option<&str>) -> PgResult<ExecuteResult> { + PG_ENTRYPOINTS.with(|entrypoints| { + entrypoints + .borrow() + .execute(self.client_id, portal.unwrap_or("")) + }) + } + + pub fn close_portal(&self, portal: Option<&str>) -> PgResult<()> { + PG_ENTRYPOINTS.with(|entrypoints| { + entrypoints + .borrow() + .close_portal(self.client_id, portal.unwrap_or("")) + }) + } + + pub fn close_statement(&self, statement: Option<&str>) -> PgResult<()> { + PG_ENTRYPOINTS.with(|entrypoints| { + entrypoints + .borrow() + .close_statement(self.client_id, statement.unwrap_or("")) + }) + } + + pub fn close_all_portals(&self) -> PgResult<()> { + PG_ENTRYPOINTS.with(|entrypoints| entrypoints.borrow().close_client_portals(self.client_id)) + } + fn on_disconnect(&self) -> PgResult<()> { // Close all the statements with its portals. PG_ENTRYPOINTS diff --git a/pgproto/src/storage/describe.rs b/pgproto/src/storage/describe.rs index 2ad8877eb6..a12258c209 100644 --- a/pgproto/src/storage/describe.rs +++ b/pgproto/src/storage/describe.rs @@ -1,6 +1,6 @@ use crate::{error::PgResult, storage::value}; use pgwire::messages::data::{FieldDescription, RowDescription}; -use postgres_types::Type; +use postgres_types::{Oid, Type}; use serde::Deserialize; use serde_repr::Deserialize_repr; @@ -13,6 +13,15 @@ pub struct Describe { metadata: Vec<MetadataColumn>, } +#[derive(Debug, Clone, Default, Deserialize)] +pub struct StatementDescribe { + #[serde(flatten)] + pub describe: Describe, + pub param_oids: Vec<Oid>, +} + +pub type PortalDescribe = Describe; + #[derive(Debug, Deserialize, PartialEq, Eq, Clone)] pub struct MetadataColumn { name: String, diff --git a/pgproto/src/storage/result.rs b/pgproto/src/storage/result.rs index 3ba298c09e..6378e4b6c3 100644 --- a/pgproto/src/storage/result.rs +++ b/pgproto/src/storage/result.rs @@ -43,6 +43,10 @@ impl ExecuteResult { self.describe.command_tag() } + pub fn is_empty(&self) -> bool { + self.values_stream.len() == 0 + } + pub fn row_description(&self) -> PgResult<Option<RowDescription>> { self.describe.row_description() } diff --git a/pgproto/test/extended_query_test.py b/pgproto/test/extended_query_test.py new file mode 100644 index 0000000000..ce047bab9e --- /dev/null +++ b/pgproto/test/extended_query_test.py @@ -0,0 +1,91 @@ +import pytest +import pg8000.native as pg # type: ignore +import os +from conftest import Postgres +from conftest import ReturnError +from pg8000.exceptions import DatabaseError # type: ignore + + +def test_extended_query(postgres: Postgres): + host = "127.0.0.1" + port = 5432 + + postgres.start(host, port) + i1 = postgres.instance + + user = "admin" + password = "password" + i1.eval("box.cfg{auth_type='md5'}") + i1.eval(f"box.schema.user.passwd('{user}', '{password}')") + + os.environ["PGSSLMODE"] = "disable" + conn = pg.Connection(user, password=password, host=host, port=port) + conn.autocommit = True + + ps = conn.prepare( + """ + create table "tall" ( + "id" integer not null, + "str" string, + "bool" boolean, + "real" double, + primary key ("id") + ) + using memtx distributed by ("id") + option (timeout = 3); + """ + ) + + # statement is prepared, but not executed yet + with pytest.raises(ReturnError, match="space TALL not found"): + i1.sql(""" select * from tall """) + + ps.run() + + ps = conn.prepare( + """ + INSERT INTO "tall" VALUES + (1, 'one', true, 0.1), + (2, 'to', false, 0.2), + (4, 'for', true, 0.4); + """ + ) + ps.run() + + ps = conn.prepare( + """ + SELECT * FROM "tall"; + """ + ) + + tuples = ps.run() + assert [1, "one", True, 0.1] in tuples + assert [2, "to", False, 0.2] in tuples + assert [4, "for", True, 0.4] in tuples + assert len(tuples) == 3 + + # rerun the same statement + tuples = ps.run() + assert [1, "one", True, 0.1] in tuples + assert [2, "to", False, 0.2] in tuples + assert [4, "for", True, 0.4] in tuples + assert len(tuples) == 3 + + ps.close() + + # run a closed statement + with pytest.raises(DatabaseError, match="Couldn't find statement"): + ps.run() + + # let's check that we are fine after some error handling + ps = conn.prepare( + """ + SELECT * FROM "tall"; + """ + ) + + tuples = ps.run() + assert [1, "one", True, 0.1] in tuples + assert [2, "to", False, 0.2] in tuples + assert [4, "for", True, 0.4] in tuples + assert len(tuples) == 3 -- GitLab