From 2812973b4c77642c22accb1f4d2d9840b4a60e32 Mon Sep 17 00:00:00 2001
From: Maksim Kaitmazian <m.kaitmazian@picodata.io>
Date: Mon, 11 Dec 2023 21:13:00 +0300
Subject: [PATCH] feat/refactoring: sync simple query with new API

---
 pgproto/src/client.rs                    |  12 +-
 pgproto/src/client/simple_query.rs       |  18 +--
 pgproto/src/entrypoints.rs               | 178 +++++++++++++++++++++++
 pgproto/src/lib.rs                       |   3 +-
 pgproto/src/sql.rs                       |   6 -
 pgproto/src/sql/handle.rs                | 127 ----------------
 pgproto/src/sql/portal.rs                | 134 -----------------
 pgproto/src/sql/statement.rs             |  41 ------
 pgproto/src/storage.rs                   |  56 +++++++
 pgproto/src/{sql => storage}/describe.rs |  28 ++--
 pgproto/src/storage/result.rs            |  69 +++++++
 pgproto/src/{sql => storage}/value.rs    |   0
 12 files changed, 341 insertions(+), 331 deletions(-)
 create mode 100644 pgproto/src/entrypoints.rs
 delete mode 100644 pgproto/src/sql.rs
 delete mode 100644 pgproto/src/sql/handle.rs
 delete mode 100644 pgproto/src/sql/portal.rs
 delete mode 100644 pgproto/src/sql/statement.rs
 create mode 100644 pgproto/src/storage.rs
 rename pgproto/src/{sql => storage}/describe.rs (80%)
 create mode 100644 pgproto/src/storage/result.rs
 rename pgproto/src/{sql => storage}/value.rs (100%)

diff --git a/pgproto/src/client.rs b/pgproto/src/client.rs
index 81983984a8..d2d9aaf3ce 100644
--- a/pgproto/src/client.rs
+++ b/pgproto/src/client.rs
@@ -1,6 +1,7 @@
 use crate::client::simple_query::process_query_message;
 use crate::error::*;
 use crate::messages;
+use crate::storage::StorageManager;
 use crate::stream::{BeMessage, FeMessage, PgStream};
 use pgwire::messages::startup::*;
 use std::io;
@@ -9,8 +10,12 @@ mod auth;
 mod simple_query;
 mod startup;
 
+pub type ClientId = u32;
+
 /// Postgres client representation.
 pub struct PgClient<S> {
+    // The portal and statement storage manager.
+    manager: StorageManager,
     /// Stream for network communication.
     stream: PgStream<S>,
 }
@@ -36,7 +41,10 @@ impl<S: io::Read + io::Write> PgClient<S> {
             error
         })?;
 
-        Ok(PgClient { stream })
+        Ok(PgClient {
+            manager: StorageManager::new(),
+            stream,
+        })
     }
 
     /// Send paraneter to the frontend.
@@ -77,7 +85,7 @@ impl<S: io::Read + io::Write> PgClient<S> {
         match message {
             FeMessage::Query(query) => {
                 log::info!("executing query");
-                process_query_message(&mut self.stream, query)?;
+                process_query_message(&mut self.stream, &self.manager, query)?;
                 Ok(ConnectionState::ReadyForQuery)
             }
             FeMessage::Terminate(_) => {
diff --git a/pgproto/src/client/simple_query.rs b/pgproto/src/client/simple_query.rs
index b670e1a398..9985ea8b90 100644
--- a/pgproto/src/client/simple_query.rs
+++ b/pgproto/src/client/simple_query.rs
@@ -1,27 +1,27 @@
+use crate::storage::StorageManager;
+use crate::{error::PgResult, messages, stream::PgStream};
 use pgwire::messages::simplequery::Query;
 use std::io;
 
-use crate::{error::PgResult, messages, sql::statement::Statement, stream::PgStream};
-
 pub fn process_query_message(
     stream: &mut PgStream<impl io::Read + io::Write>,
+    manager: &StorageManager,
     query: Query,
 ) -> PgResult<()> {
-    let statement = Statement::prepare(query.query())?;
-    let mut portal = statement.bind()?;
+    let mut query_result = manager.simple_query(query.query())?;
 
-    if portal.sends_rows() {
-        let row_description = messages::row_description(portal.row_description()?);
+    if let Some(row_description) = query_result.row_description()? {
+        let row_description = messages::row_description(row_description);
         stream.write_message_noflush(row_description)?;
     }
 
-    while let Some(data_row) = portal.execute_one()? {
+    for data_row in query_result.by_ref() {
         let data_row = messages::data_row(data_row);
         stream.write_message_noflush(data_row)?;
     }
 
-    let tag = portal.command_tag().as_str();
-    let command_complete = messages::command_complete(tag, portal.row_count());
+    let tag = query_result.command_tag().as_str();
+    let command_complete = messages::command_complete(tag, query_result.row_count());
     stream.write_message(command_complete)?;
     Ok(())
 }
diff --git a/pgproto/src/entrypoints.rs b/pgproto/src/entrypoints.rs
new file mode 100644
index 0000000000..49fdf0669e
--- /dev/null
+++ b/pgproto/src/entrypoints.rs
@@ -0,0 +1,178 @@
+use std::cell::RefCell;
+
+use crate::{
+    client::ClientId,
+    error::{PgError, PgResult},
+    storage::{
+        describe::{Describe, QueryType},
+        result::ExecuteResult,
+        value::PgValue,
+    },
+};
+use serde::Deserialize;
+use serde_json::Value;
+use tarantool::tlua::{LuaFunction, LuaThread, PushGuard};
+
+type Row = Vec<PgValue>;
+
+#[derive(Deserialize)]
+struct RawExecuteResult {
+    describe: Describe,
+    // tuple in the same format as tuples returned from pico.sql
+    result: Value,
+}
+
+fn parse_dql(res: Value) -> PgResult<Vec<Row>> {
+    #[derive(Deserialize)]
+    struct DqlResult {
+        rows: Vec<Vec<Value>>,
+        #[serde(rename = "metadata")]
+        _metadata: Value,
+    }
+
+    let res: DqlResult = serde_json::from_value(res)?;
+    let rows = res
+        .rows
+        .into_iter()
+        .map(|row| row.into_iter().map(PgValue::from).collect())
+        .collect();
+    Ok(rows)
+}
+
+fn parse_dml(res: Value) -> PgResult<usize> {
+    #[derive(Deserialize)]
+    struct DmlResult {
+        row_count: usize,
+    }
+
+    let res: DmlResult = serde_json::from_value(res)?;
+    Ok(res.row_count)
+}
+
+fn parse_explain(res: Value) -> PgResult<Vec<Row>> {
+    let res: Vec<Value> = serde_json::from_value(res)?;
+    Ok(res
+        .into_iter()
+        // every row must be a vector
+        .map(|val| vec![PgValue::from(val)])
+        .collect())
+}
+
+fn execute_result_from_json(json: &str) -> PgResult<ExecuteResult> {
+    let raw: RawExecuteResult = serde_json::from_str(json)?;
+    match raw.describe.query_type() {
+        QueryType::Dql => Ok(ExecuteResult::new(parse_dql(raw.result)?, raw.describe)),
+        QueryType::Explain => Ok(ExecuteResult::new(parse_explain(raw.result)?, raw.describe)),
+        QueryType::Acl | QueryType::Ddl => Ok(ExecuteResult::empty(0, raw.describe)),
+        QueryType::Dml => Ok(ExecuteResult::empty(parse_dml(raw.result)?, raw.describe)),
+    }
+}
+
+type Entrypoint = LuaFunction<PushGuard<LuaThread>>;
+
+/// List of Lua functions from sbroad that implement the PG protocol API.
+pub struct Entrypoints {
+    /// Handler for a Query message.
+    /// First, it closes the unnamed portal and statement, just like PG does when it gets a Query message.
+    /// After that, the extended pipeline is executed on the unnamed portal and statement: parse + bind + describe + execute.
+    /// It returns the query result (in the same format as pico.sql) and the description.
+    /// We need the description here for the command tag (CREATE TABLE, SELECT, etc.)
+    /// that is required to make a CommandComplete message as the response to the Query message,
+    /// so pico.sql alone is not enough.
+    ///
+    /// No resources need to be freed after the call.
+    simple_query: Entrypoint,
+
+    /// Close client statements and portals by the given client id.
+    close_client_statements: Entrypoint,
+}
+
+impl Entrypoints {
+    fn new() -> PgResult<Self> {
+        let simple_query = LuaFunction::load(
+            tarantool::lua_state(),
+            "
+            local function close_unnamed(client_id)
+                pico.pg_close_stmt(client_id, '')
+                pico.pg_close_portal(client_id, '')
+            end
+
+            local function parse_and_execute_unnamed(client_id, sql)
+                local res, err = pico.pg_parse(client_id, '', sql, {})
+                if res == nil then
+                    return nil, err
+                end
+
+                local res, err = pico.pg_bind(client_id, '', '')
+                if res == nil then
+                    return nil, err
+                end
+
+                local desc, err = pico.pg_describe_portal(client_id, '')
+                if desc == nil then
+                    return nil, err
+                end
+
+                local res, err = pico.pg_execute(client_id, '')
+                if res == nil then
+                    return nil, err
+                end
+
+                return {['describe'] = desc, ['result'] = res}
+            end
+
+            local client_id, sql = ...
+
+            -- Strictly speaking, this is part of the extended query protocol.
+            -- When a Query message is received, PG closes the unnamed portal and statement
+            -- and runs the extended query pipeline on them (parse + bind + execute).
+            close_unnamed(client_id)
+
+            local res, err = parse_and_execute_unnamed(client_id, sql)
+
+            -- After the execution, the portal and statement must be closed.
+            close_unnamed(client_id)
+
+            if res == nil then
+                error(err)
+            end
+
+            return require('json').encode(res)
+            ",
+        )?;
+
+        let close_client_statements = LuaFunction::load(
+            tarantool::lua_state(),
+            "
+            -- Closing a statement closes its portals,
+            -- so by closing all the statements we also close all the portals.
+            pico.pg_close_client_stmts(...)
+            ",
+        )?;
+
+        Ok(Self {
+            simple_query,
+            close_client_statements,
+        })
+    }
+
+    /// Handler for a Query message. See self.simple_query for the details.
+    pub fn simple_query(&self, client_id: ClientId, sql: &str) -> PgResult<ExecuteResult> {
+        let json: String = self
+            .simple_query
+            .call_with_args((client_id, sql))
+            .map_err(|e| PgError::TarantoolError(e.into()))?;
+        execute_result_from_json(&json)
+    }
+
+    /// Close all the client statements and portals. See self.close_client_statements for the details.
+    pub fn close_client_statements(&self, client_id: ClientId) -> PgResult<()> {
+        self.close_client_statements
+            .call_with_args(client_id)
+            .map_err(|e| PgError::TarantoolError(e.into()))
+    }
+}
+
+thread_local! {
+    pub static PG_ENTRYPOINTS: RefCell<Entrypoints> = RefCell::new(Entrypoints::new().unwrap())
+}
diff --git a/pgproto/src/lib.rs b/pgproto/src/lib.rs
index 809475d887..e6b8fe00a7 100644
--- a/pgproto/src/lib.rs
+++ b/pgproto/src/lib.rs
@@ -1,9 +1,10 @@
 mod client;
+mod entrypoints;
 mod error;
 mod helpers;
 mod messages;
 mod server;
-mod sql;
+mod storage;
 mod stream;
 
 use crate::client::PgClient;
diff --git a/pgproto/src/sql.rs b/pgproto/src/sql.rs
deleted file mode 100644
index 16709f6aa0..0000000000
--- a/pgproto/src/sql.rs
+++ /dev/null
@@ -1,6 +0,0 @@
-pub mod describe;
-pub mod portal;
-pub mod statement;
-pub mod value;
-
-mod handle;
diff --git a/pgproto/src/sql/handle.rs b/pgproto/src/sql/handle.rs
deleted file mode 100644
index dd495aa0a5..0000000000
--- a/pgproto/src/sql/handle.rs
+++ /dev/null
@@ -1,127 +0,0 @@
-use super::describe::Describe;
-use serde::de::DeserializeOwned;
-
-use crate::error::PgResult;
-
-#[derive(Debug)]
-pub struct Handle(usize);
-
-impl Handle {
-    /// Parse the statement and store its description in the engine;
-    /// then return an opaque handle (descriptor) to that description.
-    pub fn prepare(query: &str) -> PgResult<Handle> {
-        let code = format!(
-            "
-            local res, err = pico.pg_parse([[{query}]])
-
-            if res == nil then
-                error(err)
-            end
-
-            return res
-            "
-        );
-        let desc: usize = tarantool::lua_state().eval(&code)?;
-        Ok(desc.into())
-    }
-
-    /// Execute query and get a field from result.
-    pub fn execute_and_get<T: DeserializeOwned>(&self, name: &str) -> PgResult<T> {
-        let desc = self.0;
-        let code = format!(
-            "
-            local res, err = pico.pg_execute({desc})
-
-            if res == nil then
-                error(err)
-            end
-
-            return require('json').encode(res['{name}'])
-            "
-        );
-        let raw_result = tarantool::lua_state().eval::<String>(&code)?;
-        let result: T = serde_json::from_str(&raw_result)?;
-        Ok(result)
-    }
-
-    /// Execute query and get result.
-    pub fn execute<T: DeserializeOwned>(&self) -> PgResult<T> {
-        let desc = self.0;
-        let code = format!(
-            "
-            local res, err = pico.pg_execute({desc})
-
-            if res == nil then
-                error(err)
-            end
-
-            return require('json').encode(res)
-            "
-        );
-        let raw_result = tarantool::lua_state().eval::<String>(&code)?;
-        let result: T = serde_json::from_str(&raw_result)?;
-        Ok(result)
-    }
-
-    /// Bind parameters and optimize query.
-    pub fn bind(&self) -> PgResult<()> {
-        let desc = self.0;
-        let code = format!(
-            "
-            local res, err = pico.pg_bind({desc}, {{}})
-
-            if res == nil then
-                error(err)
-            end
-            "
-        );
-        tarantool::lua_state().eval::<()>(&code)?;
-        Ok(())
-    }
-
-    /// Get query description.
-    pub fn describe(&self) -> PgResult<Describe> {
-        let desc = self.0;
-        let code = format!(
-            "
-            local res, err = pico.pg_describe({desc})
-
-            if res == nil then
-                error(err)
-            end
-
-            return require('json').encode(res)
-            "
-        );
-        let describe_str = tarantool::lua_state().eval::<String>(&code)?;
-        let describe = serde_json::from_str(&describe_str)?;
-        Ok(describe)
-    }
-
-    fn close(&self) -> PgResult<()> {
-        let desc = self.0;
-        let code = format!(
-            "
-            local res, err = pico.pg_close({desc})
-
-            if res == nil then
-                error(err)
-            end
-            "
-        );
-        tarantool::lua_state().eval(&code)?;
-        Ok(())
-    }
-}
-
-impl Drop for Handle {
-    fn drop(&mut self) {
-        self.close().unwrap()
-    }
-}
-
-impl From<usize> for Handle {
-    fn from(value: usize) -> Self {
-        Handle(value)
-    }
-}
diff --git a/pgproto/src/sql/portal.rs b/pgproto/src/sql/portal.rs
deleted file mode 100644
index 03f02357e6..0000000000
--- a/pgproto/src/sql/portal.rs
+++ /dev/null
@@ -1,134 +0,0 @@
-use super::describe::{CommandTag, QueryType};
-use super::handle::Handle;
-use super::statement::Statement;
-use super::value::PgValue;
-use crate::error::PgResult;
-use bytes::BytesMut;
-use pgwire::messages::data::{DataRow, RowDescription};
-use serde_json::Value;
-use std::vec::IntoIter;
-
-enum PortalState {
-    /// Freshly created, not executed yet.
-    New,
-    /// Query is executing, can be deleted.
-    Running,
-    /// Query execution is finished.
-    Finished,
-}
-
-/// Portal is a result of statement binding.
-/// It allows you to execute queries and acts as a cursor over the resulting tuples.
-pub struct Portal {
-    // TODO: consider using statement name or reference
-    state: PortalState,
-    statement: Statement,
-    values_stream: IntoIter<Vec<PgValue>>,
-    row_count: usize,
-}
-
-impl Portal {
-    pub fn new(statement: Statement) -> Portal {
-        Portal {
-            statement,
-            state: PortalState::New,
-            values_stream: IntoIter::default(),
-            row_count: 0,
-        }
-    }
-
-    /// Take next tuple.
-    pub fn execute_one(&mut self) -> PgResult<Option<DataRow>> {
-        if let PortalState::New = self.state {
-            // NOTE: this changes the state.
-            self.execute_all_and_store()?;
-        }
-
-        if let PortalState::Finished = self.state {
-            return Ok(None);
-        }
-
-        let mut buf = BytesMut::new();
-        if let Some(values) = self.values_stream.next() {
-            let row = encode_row(values, &mut buf);
-            self.row_count += 1;
-            buf.clear();
-            Ok(Some(row))
-        } else {
-            self.state = PortalState::Finished;
-            Ok(None)
-        }
-    }
-
-    /// Get the format of the output tuples.
-    pub fn row_description(&self) -> PgResult<RowDescription> {
-        self.statement.describe().row_description()
-    }
-
-    pub fn command_tag(&self) -> &CommandTag {
-        self.statement.describe().command_tag()
-    }
-
-    pub fn sends_rows(&self) -> bool {
-        let query_type = self.statement.describe().query_type();
-        matches!(query_type, QueryType::Dql | QueryType::Explain)
-    }
-
-    /// Get the number of returned or modified tuples.
-    /// None is returned if portal doesn't return or modify tuples.
-    pub fn row_count(&self) -> Option<usize> {
-        let query_type = self.statement.describe().query_type();
-        match query_type {
-            QueryType::Dml | QueryType::Dql | QueryType::Explain => Some(self.row_count),
-            _ => None,
-        }
-    }
-
-    /// Execute statement and store result in portal.
-    fn execute_all_and_store(&mut self) -> PgResult<()> {
-        let query_type = self.statement.describe().query_type();
-        let handle = self.statement.handle();
-        match query_type {
-            QueryType::Dql => {
-                self.values_stream = execute_dql(handle)?.into_iter();
-                self.state = PortalState::Running;
-            }
-            QueryType::Acl | QueryType::Ddl | QueryType::Dml => {
-                self.row_count = execute_acl_ddl_or_dml(handle)?;
-                self.state = PortalState::Finished;
-            }
-            QueryType::Explain => {
-                self.values_stream = execute_explain(handle)?.into_iter();
-                self.state = PortalState::Running;
-            }
-        }
-        Ok(())
-    }
-}
-
-fn encode_row(values: Vec<PgValue>, buf: &mut BytesMut) -> DataRow {
-    let row = values.into_iter().map(|v| v.encode(buf).unwrap()).collect();
-    DataRow::new(row)
-}
-
-fn execute_dql(handle: &Handle) -> PgResult<Vec<Vec<PgValue>>> {
-    let raw: Vec<Vec<Value>> = handle.execute_and_get("rows")?;
-    Ok(raw
-        .into_iter()
-        .map(|row| row.into_iter().map(PgValue::from).collect())
-        .collect())
-}
-
-fn execute_acl_ddl_or_dml(handle: &Handle) -> PgResult<usize> {
-    let row_count: usize = handle.execute_and_get("row_count")?;
-    Ok(row_count)
-}
-
-fn execute_explain(handle: &Handle) -> PgResult<Vec<Vec<PgValue>>> {
-    let explain: Vec<Value> = handle.execute()?;
-    Ok(explain
-        .into_iter()
-        // every row must be a vector
-        .map(|val| vec![PgValue::from(val)])
-        .collect())
-}
diff --git a/pgproto/src/sql/statement.rs b/pgproto/src/sql/statement.rs
deleted file mode 100644
index 1c164ecda7..0000000000
--- a/pgproto/src/sql/statement.rs
+++ /dev/null
@@ -1,41 +0,0 @@
-use super::describe::Describe;
-use super::handle::Handle;
-use super::portal::Portal;
-use crate::error::PgResult;
-use std::rc::Rc;
-
-#[derive(Debug)]
-struct StatementImpl {
-    handle: Handle,
-    describe: Describe,
-}
-
-#[derive(Debug, Clone)]
-pub struct Statement(Rc<StatementImpl>);
-
-impl Statement {
-    /// Prepare statement from query.
-    pub fn prepare(query: &str) -> PgResult<Statement> {
-        let handle = Handle::prepare(query)?;
-        let describe = handle.describe()?;
-        let stmt = StatementImpl { handle, describe };
-
-        Ok(Statement(stmt.into()))
-    }
-
-    /// Create a portal by binding parameters to statement.
-    pub fn bind(&self) -> PgResult<Portal> {
-        self.0.handle.bind()?;
-        Ok(Portal::new(self.clone()))
-    }
-
-    /// Get the handle representing the statement.
-    pub fn handle(&self) -> &Handle {
-        &self.0.handle
-    }
-
-    /// Get statement description.
-    pub fn describe(&self) -> &Describe {
-        &self.0.describe
-    }
-}
diff --git a/pgproto/src/storage.rs b/pgproto/src/storage.rs
new file mode 100644
index 0000000000..544b8a6fda
--- /dev/null
+++ b/pgproto/src/storage.rs
@@ -0,0 +1,56 @@
+use self::result::ExecuteResult;
+use crate::client::ClientId;
+use crate::entrypoints::PG_ENTRYPOINTS;
+use crate::error::PgResult;
+use log::warn;
+use std::sync::atomic::{AtomicU32, Ordering};
+
+pub mod describe;
+pub mod result;
+pub mod value;
+
+fn unique_id() -> ClientId {
+    static ID_COUNTER: AtomicU32 = AtomicU32::new(0);
+    ID_COUNTER.fetch_add(1, Ordering::Relaxed)
+}
+
+/// Allows interacting with the real storage in sbroad
+/// and cleans up all the created portals and statements when it's dropped.
+pub struct StorageManager {
+    /// Unique id of a PG client.
+    /// Every portal and statement is tagged by this id in the sbroad storage,
+    /// so they all can be found and deleted using this id.
+    /// Since the id is unique, the statements and portals are isolated between users.
+    client_id: ClientId,
+}
+
+impl StorageManager {
+    pub fn new() -> Self {
+        Self {
+            client_id: unique_id(),
+        }
+    }
+
+    /// Handler for a Query message. See Entrypoints::simple_query for the details.
+    pub fn simple_query(&self, sql: &str) -> PgResult<ExecuteResult> {
+        PG_ENTRYPOINTS.with(|entrypoints| entrypoints.borrow().simple_query(self.client_id, sql))
+    }
+
+    fn on_disconnect(&self) -> PgResult<()> {
+        // Close all the statements along with their portals.
+        PG_ENTRYPOINTS
+            .with(|entrypoints| entrypoints.borrow().close_client_statements(self.client_id))
+    }
+}
+
+impl Drop for StorageManager {
+    fn drop(&mut self) {
+        match self.on_disconnect() {
+            Ok(_) => {}
+            Err(err) => warn!(
+                "failed to close user {} statements and portals: {:?}",
+                self.client_id, err
+            ),
+        }
+    }
+}
diff --git a/pgproto/src/sql/describe.rs b/pgproto/src/storage/describe.rs
similarity index 80%
rename from pgproto/src/sql/describe.rs
rename to pgproto/src/storage/describe.rs
index 9cf30a8449..2ad8877eb6 100644
--- a/pgproto/src/sql/describe.rs
+++ b/pgproto/src/storage/describe.rs
@@ -1,4 +1,4 @@
-use crate::{error::PgResult, sql::value};
+use crate::{error::PgResult, storage::value};
 use pgwire::messages::data::{FieldDescription, RowDescription};
 use postgres_types::Type;
 use serde::Deserialize;
@@ -111,15 +111,21 @@ impl Describe {
         &self.command_tag
     }
 
-    pub fn row_description(&self) -> PgResult<RowDescription> {
-        let row_description = self
-            .metadata
-            .iter()
-            .map(|col| {
-                let type_str = col.r#type.as_str();
-                value::type_from_name(type_str).map(|ty| field_description(col.name.clone(), ty))
-            })
-            .collect::<PgResult<_>>()?;
-        Ok(RowDescription::new(row_description))
+    pub fn row_description(&self) -> PgResult<Option<RowDescription>> {
+        match self.query_type() {
+            QueryType::Acl | QueryType::Ddl | QueryType::Dml => Ok(None),
+            QueryType::Dql | QueryType::Explain => {
+                let row_description = self
+                    .metadata
+                    .iter()
+                    .map(|col| {
+                        let type_str = col.r#type.as_str();
+                        value::type_from_name(type_str)
+                            .map(|ty| field_description(col.name.clone(), ty))
+                    })
+                    .collect::<PgResult<_>>()?;
+                Ok(Some(RowDescription::new(row_description)))
+            }
+        }
     }
 }
diff --git a/pgproto/src/storage/result.rs b/pgproto/src/storage/result.rs
new file mode 100644
index 0000000000..3ba298c09e
--- /dev/null
+++ b/pgproto/src/storage/result.rs
@@ -0,0 +1,69 @@
+use super::{
+    describe::{CommandTag, Describe, QueryType},
+    value::PgValue,
+};
+use crate::error::PgResult;
+use bytes::BytesMut;
+use pgwire::messages::data::{DataRow, RowDescription};
+use std::vec::IntoIter;
+
+fn encode_row(values: Vec<PgValue>, buf: &mut BytesMut) -> DataRow {
+    let row = values.into_iter().map(|v| v.encode(buf).unwrap()).collect();
+    DataRow::new(row)
+}
+
+pub struct ExecuteResult {
+    describe: Describe,
+    values_stream: IntoIter<Vec<PgValue>>,
+    row_count: usize,
+    buf: BytesMut,
+}
+
+impl ExecuteResult {
+    pub fn new(rows: Vec<Vec<PgValue>>, describe: Describe) -> Self {
+        let values_stream = rows.into_iter();
+        Self {
+            values_stream,
+            describe,
+            row_count: 0,
+            buf: BytesMut::default(),
+        }
+    }
+
+    pub fn empty(row_count: usize, describe: Describe) -> Self {
+        Self {
+            values_stream: Default::default(),
+            describe,
+            row_count,
+            buf: BytesMut::default(),
+        }
+    }
+
+    pub fn command_tag(&self) -> &CommandTag {
+        self.describe.command_tag()
+    }
+
+    pub fn row_description(&self) -> PgResult<Option<RowDescription>> {
+        self.describe.row_description()
+    }
+
+    pub fn row_count(&self) -> Option<usize> {
+        match self.describe.query_type() {
+            QueryType::Dml | QueryType::Dql | QueryType::Explain => Some(self.row_count),
+            _ => None,
+        }
+    }
+}
+
+impl Iterator for ExecuteResult {
+    type Item = DataRow;
+
+    fn next(&mut self) -> Option<DataRow> {
+        self.values_stream.next().map(|row| {
+            let row = encode_row(row, &mut self.buf);
+            self.buf.clear();
+            self.row_count += 1;
+            row
+        })
+    }
+}
diff --git a/pgproto/src/sql/value.rs b/pgproto/src/storage/value.rs
similarity index 100%
rename from pgproto/src/sql/value.rs
rename to pgproto/src/storage/value.rs
--
GitLab
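
Note on the new ExecuteResult API introduced above: rows are counted lazily inside Iterator::next, so command_tag() and row_count() describe the full result only after the iterator has been drained, which is why process_query_message builds CommandComplete last. A minimal sketch of the intended call order, assuming it lives inside the pgproto crate and that the rows and describe values were obtained elsewhere (e.g. via Entrypoints::simple_query); this is an illustration, not part of the patch:

    use crate::error::PgResult;
    use crate::storage::{describe::Describe, result::ExecuteResult, value::PgValue};

    // Sketch only: drains an ExecuteResult in the same order as process_query_message.
    fn drain_result(rows: Vec<Vec<PgValue>>, describe: Describe) -> PgResult<()> {
        let mut result = ExecuteResult::new(rows, describe);

        // The row description is derived from Describe and does not require iteration.
        let _row_description = result.row_description()?;

        // Each item is an already-encoded pgwire DataRow; draining advances row_count.
        for _data_row in result.by_ref() {}

        // Only now do these reflect the whole result.
        let _tag = result.command_tag().as_str();
        let _count = result.row_count();
        Ok(())
    }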