diff --git a/src/bootstrap_entries.rs b/src/bootstrap_entries.rs index bc9d2f6959a201326a21686db31194bf22c085ac..ed6fb2f0c2a72ed5ae34ba036502ec8bd4f58dd8 100644 --- a/src/bootstrap_entries.rs +++ b/src/bootstrap_entries.rs @@ -5,10 +5,10 @@ use ::tarantool::msgpack; use crate::config::PicodataConfig; use crate::instance::Instance; +use crate::pgproto; use crate::replicaset::Replicaset; use crate::schema; use crate::schema::ADMIN_ID; -use crate::sql::pgproto; use crate::storage; use crate::storage::ClusterwideTable; use crate::storage::PropertyName; diff --git a/src/pgproto.rs b/src/pgproto.rs index 9698a6089903e6adcf7216c086ceb949e7525d49..44f20f4d739807eb9650ee0d8242bc4f27163663 100644 --- a/src/pgproto.rs +++ b/src/pgproto.rs @@ -8,6 +8,7 @@ use std::path::{Path, PathBuf}; use stream::PgStream; use tarantool::coio::{CoIOListener, CoIOStream}; +mod backend; mod client; mod entrypoints; mod error; @@ -17,6 +18,9 @@ mod storage; mod stream; mod tls; +pub const DEFAULT_MAX_PG_STATEMENTS: usize = 50; +pub const DEFAULT_MAX_PG_PORTALS: usize = 50; + /// Main postgres server configuration. #[derive(PartialEq, Default, Debug, Clone, serde::Deserialize, serde::Serialize, Introspection)] #[serde(deny_unknown_fields)] diff --git a/src/pgproto/backend.rs b/src/pgproto/backend.rs new file mode 100644 index 0000000000000000000000000000000000000000..56d397d78c1a3747a17262e8f704e2ba293998f9 --- /dev/null +++ b/src/pgproto/backend.rs @@ -0,0 +1,253 @@ +use self::storage::{ + with_portals_mut, Portal, PortalDescribe, Statement, StatementDescribe, UserPortalNames, + UserStatementNames, PG_PORTALS, PG_STATEMENTS, +}; +use super::client::ClientId; +use super::error::{PgError, PgResult}; +use crate::schema::ADMIN_ID; +use crate::sql::otm::TracerKind; +use crate::sql::router::RouterRuntime; +use crate::sql::with_tracer; +use crate::traft::error::Error; +use ::tarantool::proc; +use opentelemetry::sdk::trace::Tracer; +use opentelemetry::Context; +use postgres_types::Oid; +use sbroad::executor::engine::helpers::normalize_name_for_space_api; +use sbroad::executor::engine::{QueryCache, Router, TableVersionMap}; +use sbroad::executor::lru::Cache; +use sbroad::frontend::Ast; +use sbroad::ir::value::{LuaValue, Value}; +use sbroad::ir::Plan as IrPlan; +use sbroad::otm::{query_id, query_span, OTM_CHAR_LIMIT}; +use sbroad::utils::MutexLike; +use serde::Deserialize; +use smol_str::ToSmolStr; +use std::rc::Rc; +use tarantool::session::with_su; +use tarantool::tuple::Tuple; + +mod storage; + +struct BindArgs { + id: ClientId, + stmt_name: String, + portal_name: String, + params: Vec<Value>, + encoding_format: Vec<u8>, + traceable: bool, +} + +impl<'de> Deserialize<'de> for BindArgs { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + struct EncodedBindArgs( + ClientId, + String, + String, + Option<Vec<LuaValue>>, + Vec<u8>, + Option<bool>, + ); + + let EncodedBindArgs(id, stmt_name, portal_name, params, encoding_format, traceable) = + EncodedBindArgs::deserialize(deserializer)?; + + let params = params + .unwrap_or_default() + .into_iter() + .map(Value::from) + .collect::<Vec<Value>>(); + + Ok(Self { + id, + stmt_name, + portal_name, + params, + encoding_format, + traceable: traceable.unwrap_or(false), + }) + } +} + +// helper function to get `TracerRef` +fn get_tracer_param(traceable: bool) -> &'static Tracer { + let kind = TracerKind::from_traceable(traceable); + kind.get_tracer() +} + +#[proc(packed_args)] +pub fn proc_pg_bind(args: BindArgs) -> PgResult<()> { + let BindArgs { + id, + stmt_name, + portal_name, + params, + encoding_format: output_format, + traceable, + } = args; + let key = (id, stmt_name.into()); + let Some(statement) = PG_STATEMENTS.with(|storage| storage.borrow().get(&key)) else { + return Err(PgError::Other( + format!("Couldn't find statement \'{}\'.", key.1).into(), + )); + }; + let mut plan = statement.plan().clone(); + let ctx = with_tracer(Context::new(), TracerKind::from_traceable(traceable)); + let portal = query_span::<PgResult<_>, _>( + "\"api.router.bind\"", + statement.id(), + get_tracer_param(traceable), + &ctx, + statement.query_pattern(), + || { + if !plan.is_ddl()? && !plan.is_acl()? { + plan.bind_params(params)?; + plan.apply_options()?; + plan.optimize()?; + } + Portal::new(plan, statement.clone(), output_format) + }, + )?; + + PG_PORTALS.with(|storage| storage.borrow_mut().put((id, portal_name.into()), portal))?; + Ok(()) +} + +#[proc] +pub fn proc_pg_statements(id: ClientId) -> UserStatementNames { + UserStatementNames::new(id) +} + +#[proc] +pub fn proc_pg_portals(id: ClientId) -> UserPortalNames { + UserPortalNames::new(id) +} + +#[proc] +pub fn proc_pg_close_stmt(id: ClientId, name: String) { + // Close can't cause an error in PG. + PG_STATEMENTS.with(|storage| storage.borrow_mut().remove(&(id, name.into()))); +} + +#[proc] +pub fn proc_pg_close_portal(id: ClientId, name: String) { + // Close can't cause an error in PG. + PG_PORTALS.with(|storage| storage.borrow_mut().remove(&(id, name.into()))); +} + +#[proc] +pub fn proc_pg_close_client_stmts(id: ClientId) { + PG_STATEMENTS.with(|storage| storage.borrow_mut().remove_by_client_id(id)) +} + +#[proc] +pub fn proc_pg_close_client_portals(id: ClientId) { + PG_PORTALS.with(|storage| storage.borrow_mut().remove_by_client_id(id)) +} + +#[proc] +pub fn proc_pg_describe_stmt(id: ClientId, name: String) -> PgResult<StatementDescribe> { + let key = (id, name.into()); + let Some(statement) = PG_STATEMENTS.with(|storage| storage.borrow().get(&key)) else { + return Err(PgError::Other( + format!("Couldn't find statement \'{}\'.", key.1).into(), + )); + }; + Ok(statement.describe().clone()) +} + +#[proc] +pub fn proc_pg_describe_portal(id: ClientId, name: String) -> PgResult<PortalDescribe> { + with_portals_mut((id, name.into()), |portal| Ok(portal.describe().clone())) +} + +#[proc] +pub fn proc_pg_execute( + id: ClientId, + name: String, + max_rows: i64, + traceable: bool, +) -> PgResult<Tuple> { + let max_rows = if max_rows <= 0 { i64::MAX } else { max_rows }; + let name = Rc::from(name); + + let statement = with_portals_mut((id, Rc::clone(&name)), |portal| { + // We are cloning Rc here. + Ok(portal.statement().clone()) + })?; + with_portals_mut((id, name), |portal| { + let ctx = with_tracer(Context::new(), TracerKind::from_traceable(traceable)); + query_span::<PgResult<Tuple>, _>( + "\"api.router.execute\"", + statement.id(), + get_tracer_param(traceable), + &ctx, + statement.query_pattern(), + || portal.execute(max_rows as usize), + ) + }) +} + +#[proc] +pub fn proc_pg_parse( + cid: ClientId, + name: String, + query: String, + param_oids: Vec<Oid>, + traceable: bool, +) -> PgResult<()> { + let id = query_id(&query); + // Keep the query patterns for opentelemetry spans short enough. + let sql = query + .char_indices() + .filter_map(|(i, c)| if i <= OTM_CHAR_LIMIT { Some(c) } else { None }) + .collect::<String>(); + let ctx = with_tracer(Context::new(), TracerKind::from_traceable(traceable)); + query_span::<PgResult<()>, _>( + "\"api.router.parse\"", + &id.clone(), + get_tracer_param(traceable), + &ctx, + &sql.clone(), + || { + let runtime = RouterRuntime::new().map_err(Error::from)?; + let mut cache = runtime.cache().lock(); + let cache_entry = with_su(ADMIN_ID, || cache.get(&query.to_smolstr()))??; + if let Some(plan) = cache_entry { + let statement = + Statement::new(id.to_string(), sql.clone(), plan.clone(), param_oids)?; + PG_STATEMENTS + .with(|cache| cache.borrow_mut().put((cid, name.into()), statement))?; + return Ok(()); + } + let metadata = &*runtime.metadata().lock(); + let plan = with_su(ADMIN_ID, || -> PgResult<IrPlan> { + let mut plan = + <RouterRuntime as Router>::ParseTree::transform_into_plan(&query, metadata)?; + if runtime.provides_versions() { + let mut table_version_map = + TableVersionMap::with_capacity(plan.relations.tables.len()); + for table in plan.relations.tables.keys() { + let normalized = normalize_name_for_space_api(table); + let version = runtime.get_table_version(normalized.as_str())?; + table_version_map.insert(normalized, version); + } + plan.version_map = table_version_map; + } + Ok(plan) + }) + .map_err(|e| PgError::Other(e.into()))??; + if !plan.is_ddl()? && !plan.is_acl()? { + cache.put(query.into(), plan.clone())?; + } + let statement = Statement::new(id.to_string(), sql, plan, param_oids)?; + PG_STATEMENTS + .with(|storage| storage.borrow_mut().put((cid, name.into()), statement))?; + Ok(()) + }, + ) +} diff --git a/src/sql/pgproto.rs b/src/pgproto/backend/storage.rs similarity index 92% rename from src/sql/pgproto.rs rename to src/pgproto/backend/storage.rs index ce6b3552360c38ce14bf942ef88c7354df0eecc6..5daf95302192321269010d1848f6407bdd525865 100644 --- a/src/sql/pgproto.rs +++ b/src/pgproto/backend/storage.rs @@ -1,7 +1,6 @@ -//! PostgreSQL protocol. - -use crate::traft::error::Error; -use crate::traft::{self, node}; +use crate::pgproto::error::{PgError, PgResult}; +use crate::pgproto::{DEFAULT_MAX_PG_PORTALS, DEFAULT_MAX_PG_STATEMENTS}; +use crate::traft::node; use ::tarantool::tuple::Tuple; use rmpv::Value; use sbroad::errors::{Entity, SbroadError}; @@ -27,11 +26,8 @@ use std::vec::IntoIter; use tarantool::proc::{Return, ReturnMsgpack}; use tarantool::tuple::FunctionCtx; -use super::dispatch; -use super::router::RouterRuntime; - -pub const DEFAULT_MAX_PG_STATEMENTS: usize = 50; -pub const DEFAULT_MAX_PG_PORTALS: usize = 50; +use crate::sql::dispatch; +use crate::sql::router::RouterRuntime; pub type ClientId = u32; @@ -54,9 +50,10 @@ impl<S> PgStorage<S> { } } - pub fn put(&mut self, key: (ClientId, Rc<str>), value: S) -> traft::Result<()> { + pub fn put(&mut self, key: (ClientId, Rc<str>), value: S) -> PgResult<()> { if self.len() >= self.capacity() { - return Err(Error::Other("Statement storage is full".into())); + // TODO: it should be configuration_limit_exceeded error + return Err(PgError::Other("Statement storage is full".into())); } if key.1.is_empty() { @@ -68,7 +65,8 @@ impl<S> PgStorage<S> { match self.map.entry(key) { Entry::Occupied(entry) => { let (id, name) = entry.key(); - Err(Error::Other( + // TODO: it should be duplicate_cursor or duplicate_prepared_statement error + Err(PgError::Other( format!("Duplicated name \'{name}\' for client {id}").into(), )) } @@ -102,10 +100,6 @@ impl<S> PgStorage<S> { self.map.len() } - pub fn is_empty(&self) -> bool { - self.map.is_empty() - } - pub fn capacity(&self) -> usize { self.capacity } @@ -129,7 +123,7 @@ impl StatementStorage { } } - pub fn put(&mut self, key: (ClientId, Rc<str>), statement: Statement) -> traft::Result<()> { + pub fn put(&mut self, key: (ClientId, Rc<str>), statement: Statement) -> PgResult<()> { self.0.put(key, StatementHolder(statement)) } @@ -152,14 +146,6 @@ impl StatementStorage { pub fn len(&self) -> usize { self.0.len() } - - pub fn is_empty(&self) -> bool { - self.0.is_empty() - } - - pub fn capacity(&self) -> usize { - self.0.capacity() - } } impl Default for StatementStorage { @@ -186,7 +172,7 @@ impl PortalStorage { } } - pub fn put(&mut self, key: (ClientId, Rc<str>), portal: Portal) -> traft::Result<()> { + pub fn put(&mut self, key: (ClientId, Rc<str>), portal: Portal) -> PgResult<()> { self.0.put(key, portal)?; Ok(()) } @@ -212,14 +198,6 @@ impl PortalStorage { pub fn len(&self) -> usize { self.0.len() } - - pub fn is_empty(&self) -> bool { - self.0.is_empty() - } - - pub fn capacity(&self) -> usize { - self.0.capacity() - } } impl Default for PortalStorage { @@ -233,12 +211,12 @@ thread_local! { pub static PG_PORTALS: Rc<RefCell<PortalStorage>> = Rc::new(RefCell::new(PortalStorage::new())); } -pub fn with_portals_mut<T, F>(key: (ClientId, Rc<str>), f: F) -> traft::Result<T> +pub fn with_portals_mut<T, F>(key: (ClientId, Rc<str>), f: F) -> PgResult<T> where - F: FnOnce(&mut Portal) -> traft::Result<T>, + F: FnOnce(&mut Portal) -> PgResult<T>, { let mut portal: Portal = PG_PORTALS.with(|storage| { - storage.borrow_mut().remove(&key).ok_or(Error::Other( + storage.borrow_mut().remove(&key).ok_or(PgError::Other( format!("Couldn't find portal \'{}\'.", key.1).into(), )) })?; @@ -275,7 +253,7 @@ impl StatementInner { query_pattern: String, plan: Plan, specified_param_oids: Vec<u32>, - ) -> Result<Self, Error> { + ) -> PgResult<Self> { let param_oids = derive_param_oids(&plan, specified_param_oids)?; let describe = StatementDescribe::new(Describe::new(&plan)?, param_oids); Ok(Self { @@ -317,7 +295,7 @@ impl Statement { sql: String, plan: Plan, specified_param_oids: Vec<u32>, - ) -> Result<Self, Error> { + ) -> PgResult<Self> { Ok(Self(Rc::new(StatementInner::new( id, sql, @@ -380,17 +358,14 @@ impl StatementDescribe { // TODO: use const from pgwire once pgproto is merged to picodata const TEXT_OID: u32 = 25; -fn derive_param_oids(plan: &Plan, mut param_oids: Vec<Oid>) -> Result<Vec<Oid>, Error> { +fn derive_param_oids(plan: &Plan, mut param_oids: Vec<Oid>) -> PgResult<Vec<Oid>> { let params_count = plan.get_param_set().len(); if params_count < param_oids.len() { - return Err(Error::Other( - format!( - "query has {} parameters, but {} were given", - params_count, - param_oids.len() - ) - .into(), - )); + return Err(PgError::ProtocolViolation(format!( + "query has {} parameters, but {} were given", + params_count, + param_oids.len() + ))); } // Postgres derives oids of unspecified parameters depending on the context. @@ -441,7 +416,7 @@ fn tuple_as_rows(tuple: &Tuple) -> Option<Vec<Value>> { None } -fn take_rows(rows: &mut IntoIter<Value>, max_rows: usize) -> traft::Result<Tuple> { +fn take_rows(rows: &mut IntoIter<Value>, max_rows: usize) -> PgResult<Tuple> { let is_finished = rows.len() <= max_rows; let rows = rows.take(max_rows).collect(); #[derive(Serialize)] @@ -450,13 +425,13 @@ fn take_rows(rows: &mut IntoIter<Value>, max_rows: usize) -> traft::Result<Tuple is_finished: bool, } let result = RunningResult { rows, is_finished }; - let mp = rmp_serde::to_vec_named(&vec![result]).map_err(|e| Error::Other(e.into()))?; - let ret = Tuple::try_from_slice(&mp).map_err(|e| Error::Other(e.into()))?; + let mp = rmp_serde::to_vec_named(&vec![result])?; + let ret = Tuple::try_from_slice(&mp)?; Ok(ret) } impl Portal { - pub fn new(plan: Plan, statement: Statement, output_format: Vec<u8>) -> Result<Self, Error> { + pub fn new(plan: Plan, statement: Statement, output_format: Vec<u8>) -> PgResult<Self> { let stmt_describe = statement.describe(); let describe = PortalDescribe::new(stmt_describe.describe.clone(), output_format); Ok(Self { @@ -467,7 +442,7 @@ impl Portal { }) } - pub fn execute(&mut self, max_rows: usize) -> traft::Result<Tuple> { + pub fn execute(&mut self, max_rows: usize) -> PgResult<Tuple> { loop { match &mut self.state { PortalState::NotStarted => self.start()?, @@ -485,7 +460,7 @@ impl Portal { return Ok(res); } _ => { - return Err(Error::Other( + return Err(PgError::Other( format!("Can't execute portal in state {:?}", self.state).into(), )) } @@ -493,8 +468,8 @@ impl Portal { } } - fn start(&mut self) -> traft::Result<()> { - let runtime = RouterRuntime::new().map_err(Error::from)?; + fn start(&mut self) -> PgResult<()> { + let runtime = RouterRuntime::new()?; let query = Query::from_parts( self.plan.is_explain(), // XXX: the router runtime cache contains only unbinded IR plans to @@ -714,7 +689,7 @@ impl Describe { self } - pub fn new(plan: &Plan) -> Result<Self, SbroadError> { + pub fn new(plan: &Plan) -> PgResult<Self> { let command_tag = if plan.is_explain() { CommandTag::Explain } else { diff --git a/src/pgproto/error.rs b/src/pgproto/error.rs index 8d78bcd4d298a3e4bb3031f61c9b2ece87ce06b2..bc25864d6a1ec548085df969af5f0417fba30973 100644 --- a/src/pgproto/error.rs +++ b/src/pgproto/error.rs @@ -5,6 +5,8 @@ use std::io; use std::num::{ParseFloatError, ParseIntError}; use std::str::ParseBoolError; use std::string::FromUtf8Error; +use tarantool::error::BoxError; +use tarantool::error::IntoBoxError; use thiserror::Error; pub type PgResult<T> = Result<T, PgError>; @@ -41,6 +43,21 @@ pub enum PgError { #[error("tls error: {0}")] TlsError(#[from] TlsError), + + #[error("sbroad error: {0}")] + SbroadError(#[from] sbroad::errors::SbroadError), + + #[error("traft error: {0}")] + TraftError(Box<crate::traft::error::Error>), + + #[error("tarantool error: {0}")] + TarantoolError(#[from] tarantool::error::Error), + + #[error("encoding error: {0}")] + RmpSerdeEncode(#[from] rmp_serde::encode::Error), + + #[error("{0}")] + Other(Box<dyn error::Error>), } #[derive(Error, Debug)] @@ -61,6 +78,12 @@ pub enum DecodingError { Other(Box<dyn error::Error>), } +impl From<crate::traft::error::Error> for PgError { + fn from(value: crate::traft::error::Error) -> Self { + PgError::TraftError(value.into()) + } +} + /// Build error info from PgError. impl PgError { pub fn info(&self) -> ErrorInfo { @@ -85,3 +108,9 @@ impl PgError { } } } + +impl IntoBoxError for PgError { + fn into_box_error(self) -> BoxError { + self.to_string().into_box_error() + } +} diff --git a/src/sql.rs b/src/sql.rs index 9671b5de523bd9db8aef4780a569da2a4a0b12b9..db3baf994ed878855df8b1d2ee5b9b3315ed02cd 100644 --- a/src/sql.rs +++ b/src/sql.rs @@ -7,10 +7,6 @@ use crate::schema::{ RoutineLanguage, RoutineParamDef, RoutineParams, RoutineSecurity, SchemaObjectType, ShardingFn, UserDef, ADMIN_ID, }; -use crate::sql::pgproto::{ - with_portals_mut, Portal, PortalDescribe, Statement, StatementDescribe, UserPortalNames, - UserStatementNames, PG_PORTALS, PG_STATEMENTS, -}; use crate::sql::router::RouterRuntime; use crate::sql::storage::StorageRuntime; use crate::storage::space_by_name; @@ -21,18 +17,14 @@ use crate::traft::{self, node}; use crate::util::{duration_from_secs_f64_clamped, effective_user_id}; use crate::{cas, unwrap_ok_or}; -use opentelemetry::sdk::trace::Tracer; use opentelemetry::{baggage::BaggageExt, Context, KeyValue}; use sbroad::backend::sql::ir::{EncodedPatternWithParams, PatternWithParams}; use sbroad::debug; use sbroad::errors::{Action, Entity, SbroadError}; use sbroad::executor::engine::helpers::{decode_msgpack, normalize_name_for_space_api}; -use sbroad::executor::engine::{QueryCache, Router, TableVersionMap}; -use sbroad::executor::lru::Cache; use sbroad::executor::protocol::{EncodedRequiredData, RequiredData}; use sbroad::executor::result::ConsumerResult; use sbroad::executor::Query; -use sbroad::frontend::Ast; use sbroad::ir::acl::{Acl, AlterOption, GrantRevokeType, Privilege as SqlPrivilege}; use sbroad::ir::block::Block; use sbroad::ir::ddl::{Ddl, ParamDef}; @@ -40,15 +32,13 @@ use sbroad::ir::expression::Expression; use sbroad::ir::operator::Relational; use sbroad::ir::relation::Type; use sbroad::ir::tree::traversal::{PostOrderWithFilter, REL_CAPACITY}; -use sbroad::ir::value::{LuaValue, Value}; +use sbroad::ir::value::Value; use sbroad::ir::{Node as IrNode, Plan as IrPlan}; -use sbroad::otm::{query_id, query_span, OTM_CHAR_LIMIT}; -use serde::Deserialize; -use smol_str::{format_smolstr, SmolStr, ToSmolStr}; +use sbroad::otm::query_span; +use smol_str::{format_smolstr, SmolStr}; use tarantool::access_control::{box_access_check_ddl, SchemaObjectType as TntSchemaObjectType}; use tarantool::schema::function::func_next_reserved_id; -use self::pgproto::{ClientId, Oid}; use crate::storage::Clusterwide; use ::tarantool::access_control::{box_access_check_space, PrivType}; use ::tarantool::auth::{AuthData, AuthDef, AuthMethod}; @@ -57,14 +47,11 @@ use ::tarantool::session::{with_su, UserId}; use ::tarantool::space::{FieldType, Space, SpaceId, SystemSpace}; use ::tarantool::time::Instant; use ::tarantool::tuple::{RawBytes, Tuple}; -use sbroad::utils::MutexLike; -use std::rc::Rc; use std::str::FromStr; use tarantool::session; pub mod otm; -pub mod pgproto; pub mod router; pub mod storage; use otm::TracerKind; @@ -190,7 +177,7 @@ fn check_routine_privileges(plan: &IrPlan) -> traft::Result<()> { Ok(()) } -fn dispatch(mut query: Query<RouterRuntime>) -> traft::Result<Tuple> { +pub fn dispatch(mut query: Query<RouterRuntime>) -> traft::Result<Tuple> { if query.is_ddl().map_err(Error::from)? || query.is_acl().map_err(Error::from)? { let ir_plan = query.get_exec_plan().get_ir_plan(); let top_id = ir_plan.get_top().map_err(Error::from)?; @@ -333,231 +320,6 @@ pub fn dispatch_query(encoded_params: EncodedPatternWithParams) -> traft::Result ) } -struct BindArgs { - id: ClientId, - stmt_name: String, - portal_name: String, - params: Vec<Value>, - encoding_format: Vec<u8>, - traceable: bool, -} - -impl<'de> Deserialize<'de> for BindArgs { - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> - where - D: serde::Deserializer<'de>, - { - #[derive(Deserialize)] - struct EncodedBindArgs( - ClientId, - String, - String, - Option<Vec<LuaValue>>, - Vec<u8>, - Option<bool>, - ); - - let EncodedBindArgs(id, stmt_name, portal_name, params, encoding_format, traceable) = - EncodedBindArgs::deserialize(deserializer)?; - - let params = params - .unwrap_or_default() - .into_iter() - .map(Value::from) - .collect::<Vec<Value>>(); - - Ok(Self { - id, - stmt_name, - portal_name, - params, - encoding_format, - traceable: traceable.unwrap_or(false), - }) - } -} - -// helper function to get `TracerRef` -fn get_tracer_param(traceable: bool) -> &'static Tracer { - let kind = TracerKind::from_traceable(traceable); - kind.get_tracer() -} - -#[proc(packed_args)] -pub fn proc_pg_bind(args: BindArgs) -> traft::Result<()> { - let BindArgs { - id, - stmt_name, - portal_name, - params, - encoding_format: output_format, - traceable, - } = args; - let key = (id, stmt_name.into()); - let Some(statement) = PG_STATEMENTS.with(|storage| storage.borrow().get(&key)) else { - return Err(Error::Other( - format!("Couldn't find statement \'{}\'.", key.1).into(), - )); - }; - let mut plan = statement.plan().clone(); - let ctx = with_tracer(Context::new(), TracerKind::from_traceable(traceable)); - let portal = query_span::<traft::Result<_>, _>( - "\"api.router.bind\"", - statement.id(), - get_tracer_param(traceable), - &ctx, - statement.query_pattern(), - || { - if !plan.is_ddl()? && !plan.is_acl()? { - plan.bind_params(params)?; - plan.apply_options()?; - plan.optimize()?; - } - Portal::new(plan, statement.clone(), output_format) - }, - )?; - - PG_PORTALS.with(|storage| storage.borrow_mut().put((id, portal_name.into()), portal))?; - Ok(()) -} - -#[proc] -pub fn proc_pg_statements(id: ClientId) -> UserStatementNames { - UserStatementNames::new(id) -} - -#[proc] -pub fn proc_pg_portals(id: ClientId) -> UserPortalNames { - UserPortalNames::new(id) -} - -#[proc] -pub fn proc_pg_close_stmt(id: ClientId, name: String) { - // Close can't cause an error in PG. - PG_STATEMENTS.with(|storage| storage.borrow_mut().remove(&(id, name.into()))); -} - -#[proc] -pub fn proc_pg_close_portal(id: ClientId, name: String) { - // Close can't cause an error in PG. - PG_PORTALS.with(|storage| storage.borrow_mut().remove(&(id, name.into()))); -} - -#[proc] -pub fn proc_pg_close_client_stmts(id: ClientId) { - PG_STATEMENTS.with(|storage| storage.borrow_mut().remove_by_client_id(id)) -} - -#[proc] -pub fn proc_pg_close_client_portals(id: ClientId) { - PG_PORTALS.with(|storage| storage.borrow_mut().remove_by_client_id(id)) -} - -#[proc] -pub fn proc_pg_describe_stmt(id: ClientId, name: String) -> Result<StatementDescribe, Error> { - let key = (id, name.into()); - let Some(statement) = PG_STATEMENTS.with(|storage| storage.borrow().get(&key)) else { - return Err(Error::Other( - format!("Couldn't find statement \'{}\'.", key.1).into(), - )); - }; - Ok(statement.describe().clone()) -} - -#[proc] -pub fn proc_pg_describe_portal(id: ClientId, name: String) -> traft::Result<PortalDescribe> { - with_portals_mut((id, name.into()), |portal| Ok(portal.describe().clone())) -} - -#[proc] -pub fn proc_pg_execute( - id: ClientId, - name: String, - max_rows: i64, - traceable: bool, -) -> traft::Result<Tuple> { - let max_rows = if max_rows <= 0 { i64::MAX } else { max_rows }; - let name = Rc::from(name); - - let statement = with_portals_mut((id, Rc::clone(&name)), |portal| { - // We are cloning Rc here. - Ok(portal.statement().clone()) - })?; - with_portals_mut((id, name), |portal| { - let ctx = with_tracer(Context::new(), TracerKind::from_traceable(traceable)); - query_span::<traft::Result<Tuple>, _>( - "\"api.router.execute\"", - statement.id(), - get_tracer_param(traceable), - &ctx, - statement.query_pattern(), - || portal.execute(max_rows as usize), - ) - }) -} - -#[proc] -pub fn proc_pg_parse( - cid: ClientId, - name: String, - query: String, - param_oids: Vec<Oid>, - traceable: bool, -) -> traft::Result<()> { - let id = query_id(&query); - // Keep the query patterns for opentelemetry spans short enough. - let sql = query - .char_indices() - .filter_map(|(i, c)| if i <= OTM_CHAR_LIMIT { Some(c) } else { None }) - .collect::<String>(); - let ctx = with_tracer(Context::new(), TracerKind::from_traceable(traceable)); - query_span::<traft::Result<()>, _>( - "\"api.router.parse\"", - &id.clone(), - get_tracer_param(traceable), - &ctx, - &sql.clone(), - || { - let runtime = RouterRuntime::new().map_err(Error::from)?; - let mut cache = runtime.cache().lock(); - let cache_entry = with_su(ADMIN_ID, || cache.get(&query.to_smolstr()))??; - if let Some(plan) = cache_entry { - let statement = - Statement::new(id.to_string(), sql.clone(), plan.clone(), param_oids)?; - PG_STATEMENTS - .with(|cache| cache.borrow_mut().put((cid, name.into()), statement))?; - return Ok(()); - } - let metadata = &*runtime.metadata().lock(); - let plan = with_su(ADMIN_ID, || -> traft::Result<IrPlan> { - let mut plan = - <RouterRuntime as Router>::ParseTree::transform_into_plan(&query, metadata) - .map_err(Error::from)?; - if runtime.provides_versions() { - let mut table_version_map = - TableVersionMap::with_capacity(plan.relations.tables.len()); - for table in plan.relations.tables.keys() { - let normalized = normalize_name_for_space_api(table); - let version = runtime - .get_table_version(normalized.as_str()) - .map_err(Error::from)?; - table_version_map.insert(normalized, version); - } - plan.version_map = table_version_map; - } - Ok(plan) - })??; - if !plan.is_ddl()? && !plan.is_acl()? { - cache.put(query.to_smolstr(), plan.clone())?; - } - let statement = Statement::new(id.to_string(), sql, plan, param_oids)?; - PG_STATEMENTS - .with(|storage| storage.borrow_mut().put((cid, name.into()), statement))?; - Ok(()) - }, - ) -} - impl TryFrom<&SqlPrivilege> for PrivilegeType { type Error = SbroadError; diff --git a/src/storage.rs b/src/storage.rs index bb09d6387a2c3c31e4fd9df86c23e06bc1d7e495..fa67fdaa8e5083b85a9759a00227af61150c13c5 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -21,6 +21,7 @@ use tarantool::util::NumOrStr; use crate::access_control::{user_by_id, UserMetadataKind}; use crate::failure_domain::FailureDomain; use crate::instance::{self, Instance}; +use crate::pgproto::{DEFAULT_MAX_PG_PORTALS, DEFAULT_MAX_PG_STATEMENTS}; use crate::replicaset::Replicaset; use crate::schema::{ Distribution, PrivilegeType, SchemaObjectType, ServiceDef, ServiceRouteItem, ServiceRouteKey, @@ -29,7 +30,6 @@ use crate::schema::{IndexDef, IndexOption, TableDef}; use crate::schema::{PluginDef, INITIAL_SCHEMA_VERSION}; use crate::schema::{PrivilegeDef, RoutineDef, UserDef}; use crate::schema::{ADMIN_ID, PUBLIC_ID, UNIVERSE_ID}; -use crate::sql::pgproto::{DEFAULT_MAX_PG_PORTALS, DEFAULT_MAX_PG_STATEMENTS}; use crate::tier::Tier; use crate::traft; use crate::traft::error::Error;