Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • core/picodata
1 result
Show changes
Commits on Source (2)
......@@ -137,6 +137,15 @@ pub(super) fn prepare(
)
);
#[rustfmt::skip]
init_entries_push_op(
op::Dml::insert(
ClusterwideTable::Property,
&(PropertyName::MaxPgStatements, pgproto::DEFAULT_MAX_PG_STATEMENTS),
ADMIN_ID,
)
);
#[rustfmt::skip]
init_entries_push_op(
op::Dml::insert(
......
......@@ -136,11 +136,16 @@ fn init_sbroad() {
for (module, func) in &[
("sbroad", "sql"),
("pgproto", "pg_bind"),
("pgproto", "pg_close"),
("pgproto", "pg_describe"),
("pgproto", "pg_close_stmt"),
("pgproto", "pg_close_portal"),
("pgproto", "pg_describe_stmt"),
("pgproto", "pg_describe_portal"),
("pgproto", "pg_execute"),
("pgproto", "pg_parse"),
("pgproto", "pg_statements"),
("pgproto", "pg_portals"),
("pgproto", "pg_close_client_stmts"),
("pgproto", "pg_close_client_portals"),
] {
let program = format!(
r#"
......
......@@ -7,7 +7,8 @@ use crate::schema::{
ADMIN_ID,
};
use crate::sql::pgproto::{
with_portals, BoxedPortal, Describe, Descriptor, UserDescriptors, PG_PORTALS,
with_portals_mut, Portal, PortalDescribe, Statement, StatementDescribe, UserPortalNames,
UserStatementNames, PG_PORTALS, PG_STATEMENTS,
};
use crate::sql::router::RouterRuntime;
use crate::sql::storage::StorageRuntime;
......@@ -26,7 +27,6 @@ use sbroad::debug;
use sbroad::errors::{Action, Entity, SbroadError};
use sbroad::executor::engine::helpers::{decode_msgpack, normalize_name_for_space_api};
use sbroad::executor::engine::{QueryCache, Router, TableVersionMap};
use sbroad::executor::ir::ExecutionPlan;
use sbroad::executor::lru::Cache;
use sbroad::executor::protocol::{EncodedRequiredData, RequiredData};
use sbroad::executor::result::ConsumerResult;
......@@ -46,6 +46,7 @@ use serde::Deserialize;
use tarantool::access_control::{box_access_check_ddl, SchemaObjectType as TntSchemaObjectType};
use tarantool::schema::function::func_next_reserved_id;
use self::pgproto::{ClientId, Oid};
use crate::storage::Clusterwide;
use ::tarantool::access_control::{box_access_check_space, PrivType};
use ::tarantool::auth::{AuthData, AuthDef, AuthMethod};
......@@ -54,11 +55,12 @@ use ::tarantool::session::{with_su, UserId};
use ::tarantool::space::{FieldType, Space, SpaceId, SystemSpace};
use ::tarantool::time::Instant;
use ::tarantool::tuple::{RawBytes, Tuple};
use std::collections::HashMap;
use std::rc::Rc;
use std::str::FromStr;
use tarantool::session;
pub mod otm;
pub mod pgproto;
pub mod router;
pub mod storage;
......@@ -329,26 +331,30 @@ pub fn dispatch_query(encoded_params: EncodedPatternWithParams) -> traft::Result
}
struct BindArgs {
descriptor: Descriptor,
id: ClientId,
stmt_name: String,
portal_name: String,
params: Vec<Value>,
encoding_format: Vec<u8>,
traceable: bool,
}
impl BindArgs {
fn take(self) -> (Descriptor, Vec<Value>, bool) {
(self.descriptor, self.params, self.traceable)
}
}
impl<'de> Deserialize<'de> for BindArgs {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
struct EncodedBindArgs(Descriptor, Option<Vec<LuaValue>>, Option<bool>);
struct EncodedBindArgs(
ClientId,
String,
String,
Option<Vec<LuaValue>>,
Vec<u8>,
Option<bool>,
);
let EncodedBindArgs(descriptor, params, traceable) =
let EncodedBindArgs(id, stmt_name, portal_name, params, encoding_format, traceable) =
EncodedBindArgs::deserialize(deserializer)?;
let params = params
......@@ -358,8 +364,11 @@ impl<'de> Deserialize<'de> for BindArgs {
.collect::<Vec<Value>>();
Ok(Self {
descriptor,
id,
stmt_name,
portal_name,
params,
encoding_format,
traceable: traceable.unwrap_or(false),
})
}
......@@ -373,90 +382,125 @@ fn get_tracer_param(traceable: bool) -> &'static Tracer {
#[proc(packed_args)]
pub fn proc_pg_bind(args: BindArgs) -> traft::Result<()> {
let (key, params, traceable) = args.take();
with_portals(key, |portal| {
let mut plan = std::mem::take(portal.plan_mut());
let ctx = with_tracer(Context::new(), TracerKind::from_traceable(traceable));
query_span::<traft::Result<()>, _>(
"\"api.router.bind\"",
portal.id(),
get_tracer_param(traceable),
&ctx,
portal.sql(),
|| {
if !plan.is_ddl()? && !plan.is_acl()? {
plan.bind_params(params)?;
plan.apply_options()?;
plan.optimize()?;
}
Ok(())
},
)?;
*portal.plan_mut() = plan;
Ok(())
})
let BindArgs {
id,
stmt_name,
portal_name,
params,
encoding_format: output_format,
traceable,
} = args;
let key = (id, stmt_name.into());
let Some(statement) = PG_STATEMENTS.with(|storage| storage.borrow().get(&key)) else {
return Err(Error::Other(
format!("Couldn't find statement \'{}\'.", key.1).into(),
));
};
let mut plan = statement.plan().clone();
let ctx = with_tracer(Context::new(), TracerKind::from_traceable(traceable));
let portal = query_span::<traft::Result<_>, _>(
"\"api.router.bind\"",
statement.id(),
get_tracer_param(traceable),
&ctx,
statement.query_pattern(),
|| {
if !plan.is_ddl()? && !plan.is_acl()? {
plan.bind_params(params)?;
plan.apply_options()?;
plan.optimize()?;
}
Portal::new(plan, statement.clone(), output_format)
},
)?;
PG_PORTALS.with(|storage| storage.borrow_mut().put((id, portal_name.into()), portal))?;
Ok(())
}
#[proc]
pub fn proc_pg_portals() -> UserDescriptors {
UserDescriptors::new()
pub fn proc_pg_statements(id: ClientId) -> UserStatementNames {
UserStatementNames::new(id)
}
#[proc]
pub fn proc_pg_close(key: Descriptor) -> traft::Result<()> {
let portal: BoxedPortal = PG_PORTALS.with(|storage| storage.borrow_mut().remove(key))?;
drop(portal);
Ok(())
pub fn proc_pg_portals(id: ClientId) -> UserPortalNames {
UserPortalNames::new(id)
}
#[proc]
pub fn proc_pg_describe(key: Descriptor, traceable: bool) -> traft::Result<Describe> {
with_portals(key, |portal| {
let ctx = with_tracer(Context::new(), TracerKind::from_traceable(traceable));
let description = query_span::<traft::Result<Describe>, _>(
"\"api.router.describe\"",
portal.id(),
get_tracer_param(traceable),
&ctx,
portal.sql(),
|| Describe::new(portal.plan()).map_err(Error::from),
)?;
Ok(description)
})
pub fn proc_pg_close_stmt(id: ClientId, name: String) {
// Close can't cause an error in PG.
PG_STATEMENTS.with(|storage| storage.borrow_mut().remove(&(id, name.into())));
}
#[proc]
pub fn proc_pg_close_portal(id: ClientId, name: String) {
// Close can't cause an error in PG.
PG_PORTALS.with(|storage| storage.borrow_mut().remove(&(id, name.into())));
}
#[proc]
pub fn proc_pg_close_client_stmts(id: ClientId) {
PG_STATEMENTS.with(|storage| storage.borrow_mut().remove_by_client_id(id))
}
#[proc]
pub fn proc_pg_close_client_portals(id: ClientId) {
PG_PORTALS.with(|storage| storage.borrow_mut().remove_by_client_id(id))
}
#[proc]
pub fn proc_pg_describe_stmt(id: ClientId, name: String) -> Result<StatementDescribe, Error> {
let key = (id, name.into());
let Some(statement) = PG_STATEMENTS.with(|storage| storage.borrow().get(&key)) else {
return Err(Error::Other(
format!("Couldn't find statement \'{}\'.", key.1).into(),
));
};
Ok(statement.describe().clone())
}
#[proc]
pub fn proc_pg_execute(key: Descriptor, traceable: bool) -> traft::Result<Tuple> {
with_portals(key, |portal| {
pub fn proc_pg_describe_portal(id: ClientId, name: String) -> traft::Result<PortalDescribe> {
with_portals_mut((id, name.into()), |portal| Ok(portal.describe().clone()))
}
#[proc]
pub fn proc_pg_execute(
id: ClientId,
name: String,
max_rows: i64,
traceable: bool,
) -> traft::Result<Tuple> {
let max_rows = if max_rows <= 0 { i64::MAX } else { max_rows };
let name = Rc::from(name);
let statement = with_portals_mut((id, Rc::clone(&name)), |portal| {
// We are cloning Rc here.
Ok(portal.statement().clone())
})?;
with_portals_mut((id, name), |portal| {
let ctx = with_tracer(Context::new(), TracerKind::from_traceable(traceable));
let res = query_span::<traft::Result<Tuple>, _>(
query_span::<traft::Result<Tuple>, _>(
"\"api.router.execute\"",
portal.id(),
statement.id(),
get_tracer_param(traceable),
&ctx,
portal.sql(),
|| {
let runtime = RouterRuntime::new().map_err(Error::from)?;
let query = Query::from_parts(
portal.plan().is_explain(),
// XXX: the router runtime cache contains only unbinded IR plans to
// speed up SQL parsing and metadata resolution. We need to clone the
// plan here as its IR would be mutate during query execution (bind,
// optimization, dispatch steps). Otherwise we'll polute the parsing
// cache entry.
ExecutionPlan::from(portal.plan().clone()),
&runtime,
HashMap::new(),
);
dispatch(query)
},
)?;
Ok(res)
statement.query_pattern(),
|| portal.execute(max_rows as usize),
)
})
}
#[proc]
pub fn proc_pg_parse(query: String, traceable: bool) -> traft::Result<Descriptor> {
pub fn proc_pg_parse(
cid: ClientId,
name: String,
query: String,
param_oids: Vec<Oid>,
traceable: bool,
) -> traft::Result<()> {
let id = query_id(&query);
// Keep the query patterns for opentelemetry spans short enough.
let sql = query
......@@ -464,7 +508,7 @@ pub fn proc_pg_parse(query: String, traceable: bool) -> traft::Result<Descriptor
.filter_map(|(i, c)| if i <= OTM_CHAR_LIMIT { Some(c) } else { None })
.collect::<String>();
let ctx = with_tracer(Context::new(), TracerKind::from_traceable(traceable));
query_span::<traft::Result<Descriptor>, _>(
query_span::<traft::Result<()>, _>(
"\"api.router.parse\"",
&id.clone(),
get_tracer_param(traceable),
......@@ -477,10 +521,10 @@ pub fn proc_pg_parse(query: String, traceable: bool) -> traft::Result<Descriptor
.try_borrow_mut()
.map_err(|e| Error::Other(format!("runtime query cache: {e:?}").into()))?;
if let Some(plan) = cache.get(&query)? {
let portal = BoxedPortal::new(id, sql.clone(), plan.clone());
let descriptor = portal.descriptor();
PG_PORTALS.with(|cache| cache.borrow_mut().put(descriptor, portal))?;
return Ok(descriptor);
let statement = Statement::new(id, sql.clone(), plan.clone(), param_oids)?;
PG_STATEMENTS
.with(|cache| cache.borrow_mut().put((cid, name.into()), statement))?;
return Ok(());
}
let ast = <RouterRuntime as Router>::ParseTree::new(&query).map_err(Error::from)?;
let metadata = &*runtime.metadata().map_err(Error::from)?;
......@@ -503,10 +547,10 @@ pub fn proc_pg_parse(query: String, traceable: bool) -> traft::Result<Descriptor
if !plan.is_ddl()? && !plan.is_acl()? {
cache.put(query, plan.clone())?;
}
let portal = BoxedPortal::new(id, sql, plan);
let descriptor = portal.descriptor();
PG_PORTALS.with(|storage| storage.borrow_mut().put(descriptor, portal))?;
Ok(descriptor)
let statement = Statement::new(id, sql, plan, param_oids)?;
PG_STATEMENTS
.with(|storage| storage.borrow_mut().put((cid, name.into()), statement))?;
Ok(())
},
)
}
......
dispatch_query
execute
proc_pg_bind
proc_pg_close
proc_pg_describe
proc_pg_close_stmt
proc_pg_close_portal
proc_pg_describe_stmt
proc_pg_describe_portal
proc_pg_execute
proc_pg_parse
proc_pg_portals
proc_pg_statements
proc_pg_close_client_stmts
proc_pg_close_client_portals
require('sbroad.core-router')
require('sbroad.core-storage')
local helper = require('sbroad.helper')
local ffi = require('ffi')
local function check_descriptor(descriptor)
local desc_type = type(descriptor)
if desc_type == "number" or desc_type == "cdata" and ffi.sizeof(descriptor) == 8 then
return true
local function verify_out_formats(formats)
if type(formats) ~= "table" then
return false
end
return false
for _, format in ipairs(formats) do
if format ~= 0 and format ~= 1 then
return false
end
end
return true
end
local function pg_bind(...)
local n_args = select("#", ...)
if n_args == 0 or n_args > 3 then
return nil, "Usage: pg_bind(descriptor, params[, traceable])"
if n_args == 0 or n_args > 6 then
return nil, "Usage: pg_bind(id, stmt_name, portal_name, params, out_formats[, traceable])"
end
local descriptor, params, traceable = ...
if not check_descriptor(descriptor) then
return nil, "descriptor must be a number"
local id, stmt_name, portal_name, params, out_formats, traceable = ...
if type(id) ~= "number" or (type(id) == "number" and id < 0) then
return nil, "id must be a non-negative number"
end
if type(stmt_name) ~= "string" then
return nil, "name must be a string"
end
if type(portal_name) ~= "string" then
return nil, "name must be a string"
end
if params ~= nil and type(params) ~= "table" then
return nil, "parameters must be a table"
end
if not verify_out_formats(out_formats) then
return nil, "out_formats must be an array of 0 and 1"
end
if traceable ~= nil and type(traceable) ~= "boolean" then
return nil, "trace flag must be a boolean"
end
......@@ -32,7 +46,9 @@ local function pg_bind(...)
local ok, err = pcall(
function()
return box.func[".proc_pg_bind"]:call({ descriptor, params, traceable })
return box.func[".proc_pg_bind"]:call({
id, stmt_name, portal_name, params, out_formats, traceable
})
end
)
......@@ -43,13 +59,17 @@ local function pg_bind(...)
return true
end
local function pg_close(descriptor)
if not check_descriptor(descriptor) then
return nil, "descriptor must be a number"
local function pg_close_stmt(id, name)
if type(id) ~= "number" or (type(id) == "number" and id < 0) then
return nil, "id must be a non-negative number"
end
if type(name) ~= "string" then
return nil, "name must be a string"
end
local ok, err = pcall(
function()
return box.func[".proc_pg_close"]:call({ descriptor })
return box.func[".proc_pg_close_stmt"]:call({ id, name })
end
)
......@@ -60,25 +80,69 @@ local function pg_close(descriptor)
return true
end
local function pg_describe(...)
local function pg_close_portal(id, name)
if type(id) ~= "number" or (type(id) == "number" and id < 0) then
return nil, "id must be a non-negative number"
end
if type(name) ~= "string" then
return nil, "name must be a string"
end
local ok, err = pcall(
function()
return box.func[".proc_pg_close_portal"]:call({ id, name })
end
)
if ok == false then
return nil, err
end
return true
end
local function pg_describe_stmt(...)
local n_args = select("#", ...)
if n_args == 0 or n_args > 2 then
return nil, "Usage: pg_describe(descriptor[, traceable])"
return nil, "Usage: pg_describe_stmt(id, name)"
end
local descriptor, traceable = ...
if not check_descriptor(descriptor) then
return nil, "descriptor must be a number"
local id, name = ...
if type(id) ~= "number" or (type(id) == "number" and id < 0) then
return nil, "id must be a non-negative number"
end
if traceable ~= nil and type(traceable) ~= "boolean" then
return nil, "trace flag must be a boolean"
if type(name) ~= "string" then
return nil, "name must be a string"
end
if traceable == nil then
traceable = false
local ok, res = pcall(
function()
return box.func[".proc_pg_describe_stmt"]:call({ id, name })
end
)
if ok == false then
return nil, res
end
return helper.format_result(res)
end
local function pg_describe_portal(...)
local n_args = select("#", ...)
if n_args == 0 or n_args > 3 then
return nil, "Usage: pg_describe_portal(id, name)"
end
local id, name = ...
if type(id) ~= "number" or (type(id) == "number" and id < 0) then
return nil, "id must be a non-negative number"
end
if type(name) ~= "string" then
return nil, "name must be a string"
end
local ok, res = pcall(
function()
return box.func[".proc_pg_describe"]:call({ descriptor, traceable })
return box.func[".proc_pg_describe_portal"]:call({ id, name })
end
)
......@@ -91,12 +155,18 @@ end
local function pg_execute(...)
local n_args = select("#", ...)
if n_args == 0 or n_args > 2 then
return nil, "Usage: pg_execute(descriptor[, traceable])"
if n_args == 0 or n_args > 4 then
return nil, "Usage: pg_execute(id, name, max_rows[, traceable])"
end
local descriptor, traceable = ...
if not check_descriptor(descriptor) then
return nil, "descriptor must be a number"
local id, name, max_rows, traceable = ...
if type(id) ~= "number" or (type(id) == "number" and id < 0) then
return nil, "id must be a non-negative number"
end
if type(name) ~= "string" then
return nil, "name must be a string"
end
if type(max_rows) ~= "number" then
return nil, "max_rows must be a number"
end
if traceable ~= nil and type(traceable) ~= "boolean" then
return nil, "trace flag must be a boolean"
......@@ -107,7 +177,7 @@ local function pg_execute(...)
local ok, res = pcall(
function()
return box.func[".proc_pg_execute"]:call({ descriptor, traceable })
return box.func[".proc_pg_execute"]:call({ id, name, max_rows, traceable })
end
)
......@@ -118,15 +188,38 @@ local function pg_execute(...)
return helper.format_result(res[1])
end
local function verify_param_oids(param_oids)
if type(param_oids) ~= "table" then
return false
end
for _, oid in ipairs(param_oids) do
if type(oid) ~= "number" or (type(oid) == "number" and oid < 0) then
return false
end
end
return true
end
local function pg_parse(...)
local n_args = select("#", ...)
if n_args == 0 or n_args > 2 then
return nil, "Usage: pg_parse(query[, traceable])"
if n_args == 0 or n_args > 5 then
return nil, "Usage: pg_parse(id, name, query, param_oids, [, traceable])"
end
local id, name, query, param_oids, traceable = ...
if type(id) ~= "number" or (type(id) == "number" and id < 0) then
return nil, "id must be a non-negative number"
end
if type(name) ~= "string" then
return nil, "name must be a string"
end
local query, traceable = ...
if type(query) ~= "string" then
return nil, "query pattern must be a string"
end
if not verify_param_oids(param_oids) then
return nil, "param_oids must be a list of non-negative integers"
end
if traceable ~= nil and type(traceable) ~= "boolean" then
return nil, "trace flag must be a boolean"
end
......@@ -134,9 +227,50 @@ local function pg_parse(...)
traceable = false
end
local ok, err = pcall(
function()
return box.func[".proc_pg_parse"]:call({
id, name, query, param_oids, traceable
})
end
)
if ok == false then
return nil, err
end
return true
end
local function pg_close_client_stmts(id)
if type(id) ~= "number" or (type(id) == "number" and id < 0) then
return nil, "id must be a non-negative number"
end
return pcall(
function()
return box.func[".proc_pg_close_client_stmts"]:call({id})
end
)
end
local function pg_close_client_portals(id)
if type(id) ~= "number" or (type(id) == "number" and id < 0) then
return nil, "id must be a non-negative number"
end
return pcall(
function()
return box.func[".proc_pg_close_client_portals"]:call({id})
end
)
end
local function pg_statements(id)
if type(id) ~= "number" or (type(id) == "number" and id < 0) then
return nil, "id must be a non-negative number"
end
local ok, res = pcall(
function()
return box.func[".proc_pg_parse"]:call({ query, traceable })
return box.func[".proc_pg_statements"]:call({id})
end
)
......@@ -144,13 +278,16 @@ local function pg_parse(...)
return nil, res
end
return tonumber(res)
return helper.format_result(res)
end
local function pg_portals()
local function pg_portals(id)
if type(id) ~= "number" or (type(id) == "number" and id < 0) then
return nil, "id must be a non-negative number"
end
local ok, res = pcall(
function()
return box.func[".proc_pg_portals"]:call({})
return box.func[".proc_pg_portals"]:call({id})
end
)
......@@ -163,9 +300,14 @@ end
return {
pg_bind = pg_bind,
pg_close = pg_close,
pg_describe = pg_describe,
pg_close_stmt = pg_close_stmt,
pg_close_portal = pg_close_portal,
pg_describe_stmt = pg_describe_stmt,
pg_describe_portal = pg_describe_portal,
pg_execute = pg_execute,
pg_parse = pg_parse,
pg_portals = pg_portals,
pg_close_client_portals = pg_close_client_portals,
pg_close_client_stmts = pg_close_client_stmts,
pg_statements = pg_statements,
pg_portals = pg_portals
}
This diff is collapsed.
......@@ -23,7 +23,7 @@ use crate::schema::{Distribution, PrivilegeType, SchemaObjectType};
use crate::schema::{IndexDef, TableDef};
use crate::schema::{PrivilegeDef, RoleDef, RoutineDef, UserDef};
use crate::schema::{ADMIN_ID, PUBLIC_ID, UNIVERSE_ID};
use crate::sql::pgproto::DEFAULT_MAX_PG_PORTALS;
use crate::sql::pgproto::{DEFAULT_MAX_PG_PORTALS, DEFAULT_MAX_PG_STATEMENTS};
use crate::tier::Tier;
use crate::tlog;
use crate::traft;
......@@ -1130,6 +1130,9 @@ impl From<ClusterwideTable> for SpaceId {
/// to an unresponsive instance.
MaxHeartbeatPeriod = "max_heartbeat_period",
/// PG statement storage size.
MaxPgStatements = "max_pg_statements",
/// PG portal storage size.
MaxPgPortals = "max_pg_portals",
......@@ -1210,6 +1213,7 @@ impl PropertyName {
| Self::PasswordMinLength
| Self::MaxLoginAttempts
| Self::MaxPgPortals
| Self::MaxPgStatements
| Self::SnapshotChunkMaxSize => {
// Check it's an unsigned integer.
_ = new.field::<u64>(1).map_err(map_err)?;
......@@ -1242,6 +1246,7 @@ impl PropertyName {
Self::PasswordMinLength
| Self::MaxLoginAttempts
| Self::SnapshotChunkMaxSize
| Self::MaxPgStatements
| Self::MaxPgPortals => {
let v = tuple.field::<usize>(1)?.ok_or_else(bad_value)?;
Some(format!("{v}"))
......@@ -1466,6 +1471,14 @@ impl Properties {
Ok(res)
}
#[inline]
pub fn max_pg_statements(&self) -> tarantool::Result<usize> {
let res = self
.get(PropertyName::MaxPgStatements)?
.unwrap_or(DEFAULT_MAX_PG_STATEMENTS);
Ok(res)
}
#[inline]
pub fn max_pg_portals(&self) -> tarantool::Result<usize> {
let res = self
......
......@@ -1505,32 +1505,90 @@ class Cluster:
)
class PgStorage:
def __init__(self, instance: Instance):
self.instance: Instance = instance
self.client_ids: list[int] = []
def statements(self, id: int):
return self.instance.call("pico.pg_statements", id)
def portals(self, id: int):
return self.instance.call("pico.pg_portals", id)
def bind(self, id, *params):
return self.instance.call("pico.pg_bind", id, *params, False)
def close_stmt(self, id: int, name: str):
return self.instance.call("pico.pg_close_stmt", id, name)
def close_portal(self, id: int, name: str):
return self.instance.call("pico.pg_close_portal", id, name)
def describe_stmt(self, id: int, name: str) -> dict:
return self.instance.call("pico.pg_describe_stmt", id, name)
def describe_portal(self, id: int, name: str) -> dict:
return self.instance.call("pico.pg_describe_portal", id, name)
def execute(self, id: int, name: str, max_rows: int) -> dict:
return self.instance.call("pico.pg_execute", id, name, max_rows, False)
def flush(self):
for id in self.client_ids:
for name in self.statements(id)["available"]:
self.close_stmt(id, name)
for name in self.portals(id)["available"]:
self.close_portal(id, name)
def parse(
self, id: int, name: str, sql: str, param_oids: list[int] | None = None
) -> int:
param_oids = param_oids if param_oids is not None else []
return self.instance.call("pico.pg_parse", id, name, sql, param_oids, False)
def new_client(self, id: int):
self.client_ids.append(id)
return PgClient(self, id)
@dataclass
class PortalStorage:
instance: Instance
class PgClient:
storage: PgStorage
id: int
@property
def instance(self) -> Instance:
return self.storage.instance
@property
def statements(self):
return self.storage.statements(self.id)
@property
def descriptors(self):
return self.instance.call("pico.pg_portals")
def portals(self):
return self.storage.portals(self.id)
def bind(self, *params):
return self.instance.call("pico.pg_bind", *params, False)
return self.storage.bind(self.id, *params)
def close(self, descriptor: int):
return self.instance.call("pico.pg_close", descriptor)
def close_stmt(self, name: str):
return self.storage.close_stmt(self.id, name)
def describe(self, descriptor: int) -> dict:
return self.instance.call("pico.pg_describe", descriptor, False)
def close_portal(self, name: str):
return self.storage.close_portal(self.id, name)
def execute(self, descriptor: int) -> dict:
return self.instance.call("pico.pg_execute", descriptor, False)
def describe_stmt(self, name: str) -> dict:
return self.storage.describe_stmt(self.id, name)
def flush(self):
for descriptor in self.descriptors["available"]:
self.close(descriptor)
def describe_portal(self, name: str) -> dict:
return self.storage.describe_portal(self.id, name)
def execute(self, name: str, max_rows: int = -1) -> dict:
return self.storage.execute(self.id, name, max_rows)
def parse(self, sql: str) -> int:
return self.instance.call("pico.pg_parse", sql, False)
def parse(self, name: str, sql: str, param_oids: list[int] | None = None) -> int:
return self.storage.parse(self.id, name, sql, param_oids)
@pytest.fixture(scope="session")
......@@ -1602,11 +1660,19 @@ def instance(cluster: Cluster, pytestconfig) -> Generator[Instance, None, None]:
@pytest.fixture
def pg_portals(instance: Instance) -> Generator[PortalStorage, None, None]:
"""Returns a PG portal storage on a single instance."""
portals = PortalStorage(instance)
yield portals
portals.flush()
def pg_storage(instance: Instance) -> Generator[PgStorage, None, None]:
"""Returns a PG storage on a single instance."""
storage = PgStorage(instance)
yield storage
storage.flush()
@pytest.fixture
def pg_client(instance: Instance) -> Generator[PgClient, None, None]:
"""Returns a PG client on a single instance."""
storage = PgStorage(instance)
yield storage.new_client(0)
storage.flush()
def retrying(fn, timeout=3):
......
......@@ -343,6 +343,7 @@ def test_raft_log(instance: Instance):
| 0 | 1 |Insert({_pico_property}, ["password_enforce_specialchars",false])|
| 0 | 1 |Insert({_pico_property}, ["auto_offline_timeout",5.0])|
| 0 | 1 |Insert({_pico_property}, ["max_heartbeat_period",5.0])|
| 0 | 1 |Insert({_pico_property}, ["max_pg_statements",50])|
| 0 | 1 |Insert({_pico_property}, ["max_pg_portals",50])|
| 0 | 1 |Insert({_pico_property}, ["snapshot_chunk_max_size",16777216])|
| 0 | 1 |Insert({_pico_property}, ["snapshot_read_view_close_timeout",86400.0])|
......
from conftest import Instance, PortalStorage, ReturnError
from conftest import PgStorage, PgClient, ReturnError
import pytest
def test_extended_ddl(pg_portals: PortalStorage):
assert len(pg_portals.descriptors["available"]) == 0
def test_extended_ddl(pg_client: PgClient):
assert len(pg_client.statements["available"]) == 0
assert len(pg_client.portals["available"]) == 0
ddl = """
create table "t" ("key" int not null, "value" string not null, primary key ("key"))
using vinyl
distributed by ("key")
option (timeout = 3)
"""
id = pg_portals.parse(ddl)
assert len(pg_portals.descriptors["available"]) == 1
pg_client.parse("", ddl)
assert len(pg_client.statements["available"]) == 1
assert len(pg_client.portals["available"]) == 0
desc = pg_portals.describe(id)
desc = pg_client.describe_stmt("")
assert desc["param_oids"] == []
assert desc["query_type"] == 1
assert desc["command_tag"] == 2
assert desc["metadata"] == []
data = pg_portals.execute(id)
pg_client.bind("", "portal", [], [])
assert len(pg_client.portals["available"]) == 1
data = pg_client.execute("portal")
assert data["row_count"] == 1
assert len(pg_portals.descriptors["available"]) == 1
assert len(pg_client.statements["available"]) == 1
pg_portals.close(id)
assert len(pg_portals.descriptors["available"]) == 0
pg_client.close_stmt("")
assert len(pg_client.statements["available"]) == 0
def test_extended_dml(pg_portals: PortalStorage):
instance = pg_portals.instance
def test_extended_dml(pg_client: PgClient):
instance = pg_client.instance
instance.sql(
"""
create table "t" ("key" int not null, "value" string not null, primary key ("key"))
......@@ -37,28 +42,35 @@ def test_extended_dml(pg_portals: PortalStorage):
"""
)
assert len(pg_portals.descriptors["available"]) == 0
assert len(pg_client.statements["available"]) == 0
assert len(pg_client.portals["available"]) == 0
dml = """ insert into "t" values (?, ?) """
id = pg_portals.parse(dml)
assert len(pg_portals.descriptors["available"]) == 1
pg_client.parse("", dml)
assert len(pg_client.statements["available"]) == 1
assert len(pg_client.portals["available"]) == 0
pg_portals.bind(id, [1, "a"])
pg_client.bind("", "", [1, "a"], [])
assert len(pg_client.portals["available"]) == 1
desc = pg_portals.describe(id)
desc = pg_client.describe_stmt("")
# params were not specified, so they are treated as text
assert desc["param_oids"] == [25, 25]
assert desc["query_type"] == 2
assert desc["command_tag"] == 9
assert desc["metadata"] == []
data = pg_portals.execute(id)
data = pg_client.execute("")
pg_client.close_portal("")
assert data["row_count"] == 1
assert len(pg_portals.descriptors["available"]) == 1
assert len(pg_client.statements["available"]) == 1
assert len(pg_client.portals["available"]) == 0
pg_portals.close(id)
assert len(pg_portals.descriptors["available"]) == 0
pg_client.close_stmt("")
assert len(pg_client.statements["available"]) == 0
def test_extended_dql(pg_portals: PortalStorage):
instance = pg_portals.instance
def test_extended_dql(pg_client: PgClient):
instance = pg_client.instance
instance.sql(
"""
create table "t" ("key" int not null, "value" string not null, primary key ("key"))
......@@ -69,14 +81,20 @@ def test_extended_dql(pg_portals: PortalStorage):
)
# First query
assert len(pg_portals.descriptors["available"]) == 0
assert len(pg_client.statements["available"]) == 0
assert len(pg_client.portals["available"]) == 0
dql = """ select * from "t" """
id1 = pg_portals.parse(dql)
assert len(pg_portals.descriptors["available"]) == 1
id1 = "1"
pg_client.parse(id1, dql)
assert len(pg_client.statements["available"]) == 1
assert len(pg_client.portals["available"]) == 0
pg_portals.bind(id1, [])
pg_client.bind(id1, id1, [], [])
assert len(pg_client.statements["available"]) == 1
assert len(pg_client.portals["available"]) == 1
desc = pg_portals.describe(id1)
desc = pg_client.describe_stmt(id1)
assert desc["param_oids"] == []
assert desc["query_type"] == 3
assert desc["command_tag"] == 12
assert desc["metadata"] == [
......@@ -84,49 +102,183 @@ def test_extended_dql(pg_portals: PortalStorage):
{"name": '"value"', "type": "string"},
]
data = pg_portals.execute(id1)
data = pg_client.execute(id1)
pg_client.close_portal(id1)
assert len(pg_client.statements["available"]) == 1
assert len(pg_client.portals["available"]) == 0
assert data["rows"] == []
# Second query
dql = """ select "value" as "value_alias" from "t" """
id2 = pg_portals.parse(dql)
assert len(pg_portals.descriptors["available"]) == 2
id2 = "2"
pg_client.parse(id2, dql)
assert len(pg_client.statements["available"]) == 2
assert len(pg_client.portals["available"]) == 0
pg_portals.bind(id2, [])
pg_client.bind(id2, id2, [], [])
assert len(pg_client.statements["available"]) == 2
assert len(pg_client.portals["available"]) == 1
desc = pg_portals.describe(id2)
desc = pg_client.describe_stmt(id2)
assert desc["param_oids"] == []
assert desc["query_type"] == 3
assert desc["command_tag"] == 12
assert desc["metadata"] == [{"name": '"value_alias"', "type": "string"}]
# Third query
dql = """ select * from (values (1)) """
id3 = pg_portals.parse(dql)
assert len(pg_portals.descriptors["available"]) == 3
id3 = "3"
pg_client.parse(id3, dql)
assert len(pg_client.statements["available"]) == 3
assert len(pg_client.portals["available"]) == 1
pg_portals.bind(id3, [])
pg_client.bind(id3, id3, [], [])
assert len(pg_client.statements["available"]) == 3
assert len(pg_client.portals["available"]) == 2
desc = pg_portals.describe(id3)
desc = pg_client.describe_stmt(id3)
assert desc["param_oids"] == []
assert desc["query_type"] == 3
assert desc["command_tag"] == 12
assert desc["metadata"] == [{"name": '"COLUMN_1"', "type": "unsigned"}]
# Flush the cache
pg_portals.flush()
assert len(pg_portals.descriptors["available"]) == 0
def test_extended_errors(pg_portals: PortalStorage):
def test_extended_errors(pg_client: PgClient):
sql = """ invalid syntax """
with pytest.raises(ReturnError, match="rule parsing error"):
pg_portals.parse(sql)
pg_client.parse("", sql)
sql = """ select * from "t" """
with pytest.raises(ReturnError, match="space t not found"):
pg_portals.parse(sql)
pg_client.parse("", sql)
def test_updates_with_unnamed(pg_client: PgClient):
sql = """
create table "t" ("val" int not null, primary key ("val"))
using vinyl
distributed by ("val")
option (timeout = 3)
"""
pg_client.parse("", sql)
pg_client.bind("", "", [], [])
pg_client.execute("")
pg_client.close_portal("")
assert len(pg_client.statements["available"]) == 1
assert len(pg_client.portals["available"]) == 0
sql = """
insert into "t" values (1), (2), (3)
"""
# update statement
pg_client.parse("", sql)
pg_client.bind("", "", [], [])
pg_client.execute("")
pg_client.close_portal("")
assert len(pg_client.statements["available"]) == 1
assert len(pg_client.portals["available"]) == 0
sql = """
select * from "t" where "val" = 1
"""
# update statement
pg_client.parse("", sql)
pg_client.bind("", "", [], [])
row_with_one = pg_client.execute("")["rows"][0]
pg_client.close_portal("")
assert row_with_one == [1]
assert len(pg_client.statements["available"]) == 1
assert len(pg_client.portals["available"]) == 0
# bind portal
pg_client.bind("", "", [], [])
assert len(pg_client.portals["available"]) == 1
sql = """
select * from "t" where "val" = 2
"""
pg_client.parse("", sql)
# update portal
pg_client.bind("", "", [], [])
row_with_two = pg_client.execute("")["rows"][0]
pg_client.close_portal("")
assert row_with_two == [2]
assert len(pg_client.statements["available"]) == 1
assert len(pg_client.portals["available"]) == 0
def test_updates_with_named(pg_client: PgClient):
name = "named"
sql = """
create table "t" ("val" int not null, primary key ("val"))
using vinyl
distributed by ("val")
option (timeout = 3)
"""
pg_client.parse(name, sql)
with pytest.raises(ReturnError, match=f"Duplicated name '{name}'"):
pg_client.parse(name, sql)
pg_client.bind(name, name, [], [])
with pytest.raises(ReturnError, match=f"Duplicated name '{name}'"):
pg_client.bind(name, name, [], [])
def test_close_nonexistent(pg_client: PgClient):
try:
assert len(pg_client.statements["available"]) == 0
assert len(pg_client.portals["available"]) == 0
pg_client.close_portal("")
pg_client.close_stmt("")
pg_client.close_portal("nonexistent")
pg_client.close_stmt("nonexistent")
except Exception:
pytest.fail("close causes errors")
def test_statement_close(pg_client: PgClient):
sql = """
create table "t" ("val" int not null, primary key ("val"))
using vinyl
distributed by ("val")
option (timeout = 3)
"""
pg_client.parse("", sql)
assert len(pg_client.statements["available"]) == 1
assert len(pg_client.portals["available"]) == 0
pg_client.bind("", "1", [], [])
pg_client.bind("", "2", [], [])
pg_client.bind("", "3", [], [])
assert len(pg_client.statements["available"]) == 1
assert len(pg_client.portals["available"]) == 3
# close_stmt also closes statement's portals
pg_client.close_stmt("")
assert pg_client.statements == {"available": [], "total": 0}
assert pg_client.portals == {"available": [], "total": 0}
pg_client.parse("", sql)
assert len(pg_client.statements["available"]) == 1
assert len(pg_client.portals["available"]) == 0
pg_client.bind("", "1", [], [])
pg_client.bind("", "2", [], [])
pg_client.bind("", "3", [], [])
assert len(pg_client.portals["available"]) == 3
# update for unnamed == close_stmt("") + parse("", sql)
pg_client.parse("", sql)
assert pg_client.statements == {"available": [""], "total": 1}
assert pg_client.portals == {"available": [], "total": 0}
def test_visibility(pg_storage: PgStorage):
instance = pg_storage.instance
client1 = pg_storage.new_client(1)
client2 = pg_storage.new_client(2)
def test_portal_visibility(instance: Instance):
ddl = """
create table "t" ("key" int not null, "value" string not null, primary key ("key"))
using vinyl
......@@ -136,22 +288,145 @@ def test_portal_visibility(instance: Instance):
instance.sql(ddl)
sql = """ select * from "t" """
id = instance.eval(
f"""
box.session.su("admin")
local res = pico.pg_parse([[{sql}]])
box.session.su("guest")
return res
client1.parse("", sql)
assert client1.statements == {"available": [""], "total": 1}
assert client2.statements == {"available": [], "total": 1}
assert client1.portals == {"available": [], "total": 0}
assert client2.portals == {"available": [], "total": 0}
client1.bind("", "", [], [])
assert client1.statements == {"available": [""], "total": 1}
assert client2.statements == {"available": [], "total": 1}
assert client1.portals == {"available": [""], "total": 1}
assert client2.portals == {"available": [], "total": 1}
client2.close_portal("")
client2.close_stmt("")
assert client1.statements == {"available": [""], "total": 1}
assert client2.statements == {"available": [], "total": 1}
assert client1.portals == {"available": [""], "total": 1}
assert client2.portals == {"available": [], "total": 1}
client1.close_portal("")
assert client1.statements == {"available": [""], "total": 1}
assert client2.statements == {"available": [], "total": 1}
assert client1.portals == {"available": [], "total": 0}
assert client2.portals == {"available": [], "total": 0}
client1.close_stmt("")
assert client1.statements == {"available": [], "total": 0}
assert client2.statements == {"available": [], "total": 0}
assert client1.portals == {"available": [], "total": 0}
assert client2.portals == {"available": [], "total": 0}
def test_param_oids(pg_client: PgClient):
instance = pg_client.instance
instance.sql(
"""
create table "t" ("key" int not null, "value" string not null, primary key ("key"))
using vinyl
distributed by ("key")
option (timeout = 3)
"""
)
admin_portals = instance.eval(
""" return box.session.su("admin", pico.pg_portals) """
)
assert admin_portals == {"available": [id], "total": 1}
guest_portals = instance.eval(
""" return box.session.su("guest", pico.pg_portals) """
sql = """ insert into "t" values (?, ?) """
pg_client.parse("", sql, [1, 2])
pg_client.bind("", "", [1, "a"], [])
desc = pg_client.describe_stmt("")
assert desc["param_oids"] == [1, 2]
assert desc["query_type"] == 2
assert desc["command_tag"] == 9
assert desc["metadata"] == []
sql = """ insert into "t" values (?, ?) """
pg_client.parse("", sql, [42, 42])
pg_client.bind("", "", [1, "a"], [])
desc = pg_client.describe_stmt("")
assert desc["param_oids"] == [42, 42]
assert desc["query_type"] == 2
assert desc["command_tag"] == 9
assert desc["metadata"] == []
sql = """ insert into "t" values (1, ?) """
pg_client.parse("", sql, [1])
pg_client.bind("", "", [1, "a"], [])
desc = pg_client.describe_stmt("")
assert desc["param_oids"] == [1]
assert desc["query_type"] == 2
assert desc["command_tag"] == 9
assert desc["metadata"] == []
sql = """ insert into "t" values (?, ?) """
pg_client.parse("", sql)
pg_client.bind("", "", [1, "a"], [])
desc = pg_client.describe_stmt("")
assert desc["param_oids"] == [25, 25]
assert desc["query_type"] == 2
assert desc["command_tag"] == 9
assert desc["metadata"] == []
def test_interactive_portals(pg_client: PgClient):
instance = pg_client.instance
instance.sql(
"""
create table "t" ("key" int not null, "value" string not null, primary key ("key"))
using vinyl
distributed by ("key")
option (timeout = 3)
"""
)
assert guest_portals == {"available": [], "total": 1}
with pytest.raises(ReturnError, match="No such descriptor"):
instance.eval(f""" return box.session.su("guest", pico.pg_close, {id}) """)
assert instance.eval(f""" return box.session.su("admin", pico.pg_close, {id}) """)
sql = """ select * from "t" """
pg_client.parse("", sql)
pg_client.bind("", "", [], [])
data = pg_client.execute("", -1)
assert data["rows"] == []
assert data["is_finished"] is True
sql = """ insert into "t" values (1, 'kek'), (2, 'lol') """
instance.sql(sql)
sql = """ select * from "t" """
pg_client.parse("", sql)
pg_client.bind("", "", [], [])
# -1 - fetch all
data = pg_client.execute("", -1)
assert len(data["rows"]) == 2
assert [1, "kek"] in data["rows"]
assert [2, "lol"] in data["rows"]
assert data["is_finished"] is True
sql = """ select * from "t" """
pg_client.parse("", sql)
pg_client.bind("", "", [], [])
data = pg_client.execute("", 1)
assert len(data["rows"]) == 1
assert [1, "kek"] in data["rows"] or [2, "lol"] in data["rows"]
assert data["is_finished"] is False
data = pg_client.execute("", 1)
assert len(data["rows"]) == 1
assert [1, "kek"] in data["rows"] or [2, "lol"] in data["rows"]
assert data["is_finished"] is True
with pytest.raises(ReturnError, match="Can't execute portal in state Finished"):
data = pg_client.execute("", 1)
sql = """ explain select * from "t" """
pg_client.parse("", sql)
pg_client.bind("", "", [], [])
data = pg_client.execute("", 1)
assert len(data["rows"]) == 1
assert [
"""projection ("t"."key"::integer -> "key", "t"."value"::string -> "value")"""
] == data["rows"]
assert data["is_finished"] is False
data = pg_client.execute("", -1)
assert len(data["rows"]) == 4
assert """ scan "t\"""" in data["rows"]
assert """execution options:""" in data["rows"]
assert """sql_vdbe_max_steps = 45000""" in data["rows"]
assert """vtable_max_rows = 5000""" in data["rows"]
assert data["is_finished"] is True
......@@ -260,7 +260,7 @@ def test_read_from_system_tables(cluster: Cluster):
{"name": "key", "type": "string"},
{"name": "value", "type": "any"},
]
assert len(data["rows"]) == 15
assert len(data["rows"]) == 16
data = i1.sql(
"""
......