From 09a8a5f0fb8f8f45cd684473aacaff55eb427f92 Mon Sep 17 00:00:00 2001 From: Kaitmazian Maksim <m.kaitmazian@picodata.io> Date: Fri, 29 Mar 2024 15:09:47 +0300 Subject: [PATCH] refactor(pgproto): move proc_pg_* exports in a separate file --- src/pgproto/backend.rs | 172 ++++++++++------------------------ src/pgproto/backend/pgproc.rs | 128 +++++++++++++++++++++++++ 2 files changed, 180 insertions(+), 120 deletions(-) create mode 100644 src/pgproto/backend/pgproc.rs diff --git a/src/pgproto/backend.rs b/src/pgproto/backend.rs index fd5126ef12..50965b06cd 100644 --- a/src/pgproto/backend.rs +++ b/src/pgproto/backend.rs @@ -1,8 +1,5 @@ use self::describe::{PortalDescribe, StatementDescribe}; -use self::storage::{ - with_portals_mut, Portal, Statement, UserPortalNames, UserStatementNames, PG_PORTALS, - PG_STATEMENTS, -}; +use self::storage::{with_portals_mut, Portal, Statement, PG_PORTALS, PG_STATEMENTS}; use super::client::ClientId; use super::error::{PgError, PgResult}; use crate::pgproto::storage::value::Format; @@ -11,7 +8,6 @@ use crate::sql::otm::TracerKind; use crate::sql::router::RouterRuntime; use crate::sql::with_tracer; use crate::traft::error::Error; -use ::tarantool::proc; use opentelemetry::sdk::trace::Tracer; use opentelemetry::Context; use postgres_types::Oid; @@ -19,62 +15,19 @@ use sbroad::executor::engine::helpers::normalize_name_for_space_api; use sbroad::executor::engine::{QueryCache, Router, TableVersionMap}; use sbroad::executor::lru::Cache; use sbroad::frontend::Ast; -use sbroad::ir::value::{LuaValue, Value}; +use sbroad::ir::value::Value; use sbroad::ir::Plan as IrPlan; use sbroad::otm::{query_id, query_span, OTM_CHAR_LIMIT}; use sbroad::utils::MutexLike; -use serde::Deserialize; use smol_str::ToSmolStr; use std::rc::Rc; use tarantool::session::with_su; use tarantool::tuple::Tuple; -pub mod describe; +mod pgproc; mod storage; -struct BindArgs { - id: ClientId, - stmt_name: String, - portal_name: String, - params: Vec<Value>, - encoding_format: Vec<u8>, - traceable: bool, -} - -impl<'de> Deserialize<'de> for BindArgs { - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> - where - D: serde::Deserializer<'de>, - { - #[derive(Deserialize)] - struct EncodedBindArgs( - ClientId, - String, - String, - Option<Vec<LuaValue>>, - Vec<u8>, - Option<bool>, - ); - - let EncodedBindArgs(id, stmt_name, portal_name, params, encoding_format, traceable) = - EncodedBindArgs::deserialize(deserializer)?; - - let params = params - .unwrap_or_default() - .into_iter() - .map(Value::from) - .collect::<Vec<Value>>(); - - Ok(Self { - id, - stmt_name, - portal_name, - params, - encoding_format, - traceable: traceable.unwrap_or(false), - }) - } -} +pub mod describe; // helper function to get `TracerRef` fn get_tracer_param(traceable: bool) -> &'static Tracer { @@ -82,17 +35,15 @@ fn get_tracer_param(traceable: bool) -> &'static Tracer { kind.get_tracer() } -#[proc(packed_args)] -pub fn proc_pg_bind(args: BindArgs) -> PgResult<()> { - let BindArgs { - id, - stmt_name, - portal_name, - params, - encoding_format: output_format, - traceable, - } = args; - let key = (id, stmt_name.into()); +pub fn bind( + client_id: ClientId, + stmt_name: String, + portal_name: String, + params: Vec<Value>, + output_format: Vec<u8>, + traceable: bool, +) -> PgResult<()> { + let key = (client_id, stmt_name.into()); let Some(statement) = PG_STATEMENTS.with(|storage| storage.borrow().get(&key)) else { return Err(PgError::Other( format!("Couldn't find statement \'{}\'.", key.1).into(), @@ -120,65 +71,15 @@ pub fn proc_pg_bind(args: BindArgs) -> PgResult<()> { }, )?; - PG_PORTALS.with(|storage| storage.borrow_mut().put((id, portal_name.into()), portal))?; + PG_PORTALS.with(|storage| { + storage + .borrow_mut() + .put((client_id, portal_name.into()), portal) + })?; Ok(()) } -#[proc] -pub fn proc_pg_statements(id: ClientId) -> UserStatementNames { - UserStatementNames::new(id) -} - -#[proc] -pub fn proc_pg_portals(id: ClientId) -> UserPortalNames { - UserPortalNames::new(id) -} - -#[proc] -pub fn proc_pg_close_stmt(id: ClientId, name: String) { - // Close can't cause an error in PG. - PG_STATEMENTS.with(|storage| storage.borrow_mut().remove(&(id, name.into()))); -} - -#[proc] -pub fn proc_pg_close_portal(id: ClientId, name: String) { - // Close can't cause an error in PG. - PG_PORTALS.with(|storage| storage.borrow_mut().remove(&(id, name.into()))); -} - -#[proc] -pub fn proc_pg_close_client_stmts(id: ClientId) { - PG_STATEMENTS.with(|storage| storage.borrow_mut().remove_by_client_id(id)) -} - -#[proc] -pub fn proc_pg_close_client_portals(id: ClientId) { - PG_PORTALS.with(|storage| storage.borrow_mut().remove_by_client_id(id)) -} - -#[proc] -pub fn proc_pg_describe_stmt(id: ClientId, name: String) -> PgResult<StatementDescribe> { - let key = (id, name.into()); - let Some(statement) = PG_STATEMENTS.with(|storage| storage.borrow().get(&key)) else { - return Err(PgError::Other( - format!("Couldn't find statement \'{}\'.", key.1).into(), - )); - }; - Ok(statement.describe().clone()) -} - -#[proc] -pub fn proc_pg_describe_portal(id: ClientId, name: String) -> PgResult<PortalDescribe> { - with_portals_mut((id, name.into()), |portal| Ok(portal.describe().clone())) -} - -#[proc] -pub fn proc_pg_execute( - id: ClientId, - name: String, - max_rows: i64, - traceable: bool, -) -> PgResult<Tuple> { +pub fn execute(id: ClientId, name: String, max_rows: i64, traceable: bool) -> PgResult<Tuple> { let max_rows = if max_rows <= 0 { i64::MAX } else { max_rows }; let name = Rc::from(name); @@ -199,8 +100,7 @@ pub fn proc_pg_execute( }) } -#[proc] -pub fn proc_pg_parse( +pub fn parse( cid: ClientId, name: String, query: String, @@ -258,3 +158,35 @@ pub fn proc_pg_parse( }, ) } + +pub fn describe_stmt(id: ClientId, name: String) -> PgResult<StatementDescribe> { + let key = (id, name.into()); + let Some(statement) = PG_STATEMENTS.with(|storage| storage.borrow().get(&key)) else { + return Err(PgError::Other( + format!("Couldn't find statement \'{}\'.", key.1).into(), + )); + }; + Ok(statement.describe().clone()) +} + +pub fn describe_portal(id: ClientId, name: String) -> PgResult<PortalDescribe> { + with_portals_mut((id, name.into()), |portal| Ok(portal.describe().clone())) +} + +pub fn close_stmt(id: ClientId, name: String) { + // Close can't cause an error in PG. + PG_STATEMENTS.with(|storage| storage.borrow_mut().remove(&(id, name.into()))); +} + +pub fn close_portal(id: ClientId, name: String) { + // Close can't cause an error in PG. + PG_PORTALS.with(|storage| storage.borrow_mut().remove(&(id, name.into()))); +} + +pub fn close_client_stmts(id: ClientId) { + PG_STATEMENTS.with(|storage| storage.borrow_mut().remove_by_client_id(id)) +} + +pub fn close_client_portals(id: ClientId) { + PG_PORTALS.with(|storage| storage.borrow_mut().remove_by_client_id(id)) +} diff --git a/src/pgproto/backend/pgproc.rs b/src/pgproto/backend/pgproc.rs new file mode 100644 index 0000000000..5f98cf925c --- /dev/null +++ b/src/pgproto/backend/pgproc.rs @@ -0,0 +1,128 @@ +use super::describe::{PortalDescribe, StatementDescribe}; +use super::storage::{UserPortalNames, UserStatementNames}; +use crate::pgproto::backend; +use crate::pgproto::{client::ClientId, error::PgResult}; +use ::tarantool::proc; +use postgres_types::Oid; +use sbroad::ir::value::{LuaValue, Value}; +use serde::Deserialize; +use tarantool::tuple::Tuple; + +struct BindArgs { + id: ClientId, + stmt_name: String, + portal_name: String, + params: Vec<Value>, + encoding_format: Vec<u8>, + traceable: bool, +} + +impl<'de> Deserialize<'de> for BindArgs { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + struct EncodedBindArgs( + ClientId, + String, + String, + Option<Vec<LuaValue>>, + Vec<u8>, + Option<bool>, + ); + + let EncodedBindArgs(id, stmt_name, portal_name, params, encoding_format, traceable) = + EncodedBindArgs::deserialize(deserializer)?; + + let params = params + .unwrap_or_default() + .into_iter() + .map(Value::from) + .collect::<Vec<Value>>(); + + Ok(Self { + id, + stmt_name, + portal_name, + params, + encoding_format, + traceable: traceable.unwrap_or(false), + }) + } +} + +#[proc(packed_args)] +pub fn proc_pg_bind(args: BindArgs) -> PgResult<()> { + let BindArgs { + id, + stmt_name, + portal_name, + params, + encoding_format: output_format, + traceable, + } = args; + + backend::bind(id, stmt_name, portal_name, params, output_format, traceable) +} + +#[proc] +pub fn proc_pg_describe_stmt(id: ClientId, name: String) -> PgResult<StatementDescribe> { + backend::describe_stmt(id, name) +} + +#[proc] +pub fn proc_pg_describe_portal(id: ClientId, name: String) -> PgResult<PortalDescribe> { + backend::describe_portal(id, name) +} + +#[proc] +pub fn proc_pg_execute( + id: ClientId, + name: String, + max_rows: i64, + traceable: bool, +) -> PgResult<Tuple> { + backend::execute(id, name, max_rows, traceable) +} + +#[proc] +pub fn proc_pg_parse( + id: ClientId, + name: String, + query: String, + param_oids: Vec<Oid>, + traceable: bool, +) -> PgResult<()> { + backend::parse(id, name, query, param_oids, traceable) +} + +#[proc] +pub fn proc_pg_close_stmt(id: ClientId, name: String) { + backend::close_stmt(id, name) +} + +#[proc] +pub fn proc_pg_close_portal(id: ClientId, name: String) { + backend::close_portal(id, name) +} + +#[proc] +pub fn proc_pg_close_client_stmts(id: ClientId) { + backend::close_client_stmts(id) +} + +#[proc] +pub fn proc_pg_close_client_portals(id: ClientId) { + backend::close_client_portals(id) +} + +#[proc] +pub fn proc_pg_statements(id: ClientId) -> UserStatementNames { + UserStatementNames::new(id) +} + +#[proc] +pub fn proc_pg_portals(id: ClientId) -> UserPortalNames { + UserPortalNames::new(id) +} -- GitLab