From 5a667cb79c9982c0a9a07c71762d0cc8a8e9a0a7 Mon Sep 17 00:00:00 2001
From: Kaitmazian Maksim <m.kaitmazian@picodata.io>
Date: Wed, 20 Nov 2024 15:54:38 +0300
Subject: [PATCH] feat(pgproto): support vdbe_max_steps and vtable_max_rows
 options

With these changes, a user can specify new default values for the
vdbe_max_steps and vtable_max_rows options in the connection string.

For example, the following connection string sets both options to 42:
postgres://postgres:Passw0rd@localhost:5432?options=vtable_max_rows%3D42,vdbe_max_steps%3D42
---
 CHANGELOG.md                    |   8 +++
 src/pgproto/backend.rs          |  63 ++++++++++++++++---
 src/pgproto/backend/describe.rs |  12 ++--
 src/pgproto/backend/pgproc.rs   |   2 +-
 src/pgproto/client.rs           |   4 +-
 src/pgproto/client/startup.rs   |  89 ++++++++++++++++++++++-----
 src/pgproto/error.rs            |   6 ++
 test/pgproto/options_test.py    | 104 ++++++++++++++++++++++++++++++++
 8 files changed, 261 insertions(+), 27 deletions(-)
 create mode 100644 test/pgproto/options_test.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0e9087aff5..0c2c31d48e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -139,6 +139,14 @@ to 2 and 3.
 - EXPLAIN estimates query buckets
 - SQL supports `COALESCE` function
 
+### Pgproto
+- "vdbe_max_steps" and "vtable_max_rows" options are supported in the connection
+  string. These options allow overriding the default values of the
+  corresponding execution options used in SQL (VDBE_MAX_STEPS and VTABLE_MAX_ROWS).
+
+  For example, the following connection string sets both options to 42:
+  postgres://postgres:Passw0rd@localhost:5432?options=vtable_max_rows%3D42,vdbe_max_steps%3D42
+
 ### Fixes
 
 - Fixed bucket rebalancing for sharded tables
diff --git a/src/pgproto/backend.rs b/src/pgproto/backend.rs
index 96cd5f079f..671201d78c 100644
--- a/src/pgproto/backend.rs
+++ b/src/pgproto/backend.rs
@@ -1,10 +1,10 @@
 use self::{
-    describe::{PortalDescribe, StatementDescribe},
+    describe::{PortalDescribe, QueryType, StatementDescribe},
     result::ExecuteResult,
     storage::{with_portals_mut, Portal, Statement, PG_PORTALS, PG_STATEMENTS},
 };
 use super::{
-    client::ClientId,
+    client::{ClientId, ClientParams},
     error::{PgError, PgResult},
     value::PgValue,
 };
@@ -16,8 +16,8 @@ use crate::{
 use crate::{tlog, traft::error::Error};
 use bytes::Bytes;
 use postgres_types::Oid;
-use sbroad::errors::SbroadError;
-use sbroad::ir::value::Value as SbroadValue;
+use sbroad::ir::{value::Value as SbroadValue, OptionKind};
+use sbroad::{errors::SbroadError, ir::OptionSpec};
 use sbroad::{
     executor::{
         engine::{QueryCache, Router, TableVersionMap},
@@ -86,12 +86,43 @@ fn prepare_encoding_format(formats: &[RawFormat], n: usize) -> PgResult<Vec<Fiel
     }
 }
 
+/// Combine per-query options with connection defaults; a query-level value always wins.
+fn apply_default_options(
+    query_options: &[OptionSpec],
+    default_options: &[OptionSpec],
+) -> Vec<OptionSpec> {
+    // One slot per option kind; query options fill them first (last occurrence wins).
+    let (mut steps_slot, mut rows_slot) = (None, None);
+    for spec in query_options {
+        match spec.kind {
+            OptionKind::VdbeMaxSteps => steps_slot = Some(spec),
+            OptionKind::VTableMaxRows => rows_slot = Some(spec),
+        }
+    }
+
+    // A default is taken only while its slot is still empty (first occurrence wins).
+    for spec in default_options {
+        match spec.kind {
+            OptionKind::VdbeMaxSteps if steps_slot.is_none() => steps_slot = Some(spec),
+            OptionKind::VTableMaxRows if rows_slot.is_none() => rows_slot = Some(spec),
+            _ => {}
+        }
+    }
+
+    // Clone whatever was chosen, steps before rows, dropping the empty slots.
+    [steps_slot, rows_slot]
+        .into_iter()
+        .filter_map(|slot| slot.cloned())
+        .collect()
+}
+
 pub fn bind(
     client_id: ClientId,
     stmt_name: String,
     portal_name: String,
     params: Vec<Value>,
     result_format: Vec<FieldFormat>,
+    default_options: Vec<OptionSpec>,
 ) -> PgResult<()> {
     let key = (client_id, stmt_name.into());
     let Some(statement) = PG_STATEMENTS.with(|storage| storage.borrow().get(&key)) else {
@@ -99,7 +130,13 @@ pub fn bind(
             format!("Couldn't find statement \'{}\'.", key.1).into(),
         ));
     };
+
     let mut plan = statement.plan().clone();
+    let is_dql = matches!(statement.describe().query_type(), QueryType::Dql);
+    if is_dql && !default_options.is_empty() {
+        plan.raw_options = apply_default_options(&plan.raw_options, &default_options);
+    }
+
     if !plan.is_empty() && !plan.is_ddl()? && !plan.is_acl()? && !plan.is_plugin()? {
         plan.bind_params(params)?;
         plan.apply_options()?;
@@ -193,10 +230,12 @@ pub struct Backend {
     /// A unique identificator of a postgres client. It is used as a part of a key in the portal
     /// storage, allowing to store in a single storage portals from many clients.
     client_id: ClientId,
+
+    params: ClientParams,
 }
 
 impl Backend {
-    pub fn new() -> Self {
+    pub fn new(params: ClientParams) -> Self {
         /// Generate a unique client id.
         fn unique_id() -> ClientId {
             static ID_COUNTER: AtomicU32 = AtomicU32::new(0);
@@ -205,6 +244,7 @@ impl Backend {
 
         Self {
             client_id: unique_id(),
+            params,
         }
     }
 
@@ -273,6 +313,7 @@ impl Backend {
                 "".into(),
                 params,
                 vec![FieldFormat::Text; ncolumns],
+                vec![],
             )?;
             self.execute(None, -1)
         };
@@ -326,8 +367,16 @@ impl Backend {
         let params_format = prepare_encoding_format(params_format, params.len())?;
         let result_format = prepare_encoding_format(result_format, describe.ncolumns())?;
         let params = decode_parameter_values(params, &describe.param_oids, &params_format)?;
-
-        bind(self.client_id, statement, portal, params, result_format)
+        let default_options = self.params.execution_options();
+
+        bind(
+            self.client_id,
+            statement,
+            portal,
+            params,
+            result_format,
+            default_options,
+        )
     }
 
     /// Handler for an Execute message.
diff --git a/src/pgproto/backend/describe.rs b/src/pgproto/backend/describe.rs
index 4d8465f68d..522ce5e2be 100644
--- a/src/pgproto/backend/describe.rs
+++ b/src/pgproto/backend/describe.rs
@@ -27,7 +27,7 @@ use tarantool::{
     tuple::FunctionCtx,
 };
 
-#[derive(Debug, Clone, Default, Deserialize_repr, Serialize_repr)]
+#[derive(Debug, Clone, Copy, Default, Deserialize_repr, Serialize_repr)]
 #[repr(u8)]
 pub enum QueryType {
     Acl = 0,
@@ -394,8 +394,8 @@ impl Describe {
 }
 
 impl Describe {
-    pub fn query_type(&self) -> &QueryType {
-        &self.query_type
+    pub fn query_type(&self) -> QueryType {
+        self.query_type
     }
 
     pub fn command_tag(&self) -> CommandTag {
@@ -439,6 +439,10 @@ impl StatementDescribe {
             param_oids,
         }
     }
+
+    pub fn query_type(&self) -> QueryType {
+        self.describe.query_type()
+    }
 }
 
 impl StatementDescribe {
@@ -507,7 +511,7 @@ impl PortalDescribe {
             .collect()
     }
 
-    pub fn query_type(&self) -> &QueryType {
+    pub fn query_type(&self) -> QueryType {
         self.describe.query_type()
     }
 
diff --git a/src/pgproto/backend/pgproc.rs b/src/pgproto/backend/pgproc.rs
index 80a91c7f48..57d01b8928 100644
--- a/src/pgproto/backend/pgproc.rs
+++ b/src/pgproto/backend/pgproc.rs
@@ -62,7 +62,7 @@ pub fn proc_pg_bind(args: BindArgs) -> PgResult<()> {
         encoding_format: output_format,
     } = args;
 
-    backend::bind(id, stmt_name, portal_name, params, output_format)
+    backend::bind(id, stmt_name, portal_name, params, output_format, vec![])
 }
 
 #[proc]
diff --git a/src/pgproto/client.rs b/src/pgproto/client.rs
index aa627c6867..360204850d 100644
--- a/src/pgproto/client.rs
+++ b/src/pgproto/client.rs
@@ -15,6 +15,8 @@ mod extended_query;
 mod simple_query;
 mod startup;
 
+pub use startup::ClientParams;
+
 pub type ClientId = u32;
 
 /// Postgres client representation.
@@ -42,7 +44,7 @@ impl<S: io::Read + io::Write> PgClient<S> {
         tlog!(Info, "client authenticated");
 
         Ok(PgClient {
-            backend: Backend::new(),
+            backend: Backend::new(params),
             loop_state: MessageLoopState::ReadyForQuery,
             stream,
         })
diff --git a/src/pgproto/client/startup.rs b/src/pgproto/client/startup.rs
index ab58828990..8666017afa 100644
--- a/src/pgproto/client/startup.rs
+++ b/src/pgproto/client/startup.rs
@@ -4,30 +4,91 @@ use crate::pgproto::stream::{FeMessage, PgStream};
 use crate::pgproto::tls::TlsAcceptor;
 use crate::tlog;
 use pgwire::messages::startup::Startup;
+use sbroad::ir::value::Value as SbroadValue;
+use sbroad::ir::{OptionKind, OptionParamValue, OptionSpec};
 use std::collections::BTreeMap;
 use std::io::{Read, Write};
 
+#[derive(Clone, Debug)]
 pub struct ClientParams {
     pub username: String,
+    pub vtable_max_rows: Option<u64>,
+    pub vdbe_max_steps: Option<u64>,
     pub _rest: BTreeMap<String, String>,
+    // NB: add more params as needed.
+    // Keep in mind that a client is required to send only "user".
 }
 
-fn parse_startup(startup: Startup) -> PgResult<ClientParams> {
-    let mut parameters = startup.parameters;
-    tlog!(Debug, "client parameters: {parameters:?}");
+impl ClientParams {
+    /// Builds client parameters: "user" is required, "options" may hold comma-separated
+    /// `name=value` pairs. Fails on a missing user, a pair without '=', or an unparsable limit.
+    fn new(mut parameters: BTreeMap<String, String>) -> PgResult<Self> {
+        let Some(username) = parameters.remove("user") else {
+            return Err(PgError::ProtocolViolation(
+                "parameter 'user' is missing".into(),
+            ));
+        };
 
-    let Some(username) = parameters.remove("user") else {
-        return Err(PgError::ProtocolViolation(
-            "parameter 'user' is missing".into(),
-        ));
-    };
+        let (mut vtable_max_rows, mut vdbe_max_steps) = (None, None);
+        if let Some(options) = parameters.get("options") {
+            for pair in options.split(',') {
+                let Some((name, val)) = pair.split_once('=') else {
+                    return Err(PgError::other("option with no value"));
+                };
+                match name {
+                    "vtable_max_rows" => {
+                        vtable_max_rows = Some(val.parse().map_err(PgError::other)?)
+                    }
+                    "vdbe_max_steps" => vdbe_max_steps = Some(val.parse().map_err(PgError::other)?),
+                    _ => {
+                        // Warn rather than fail: unknown PostgreSQL parameters are ignored
+                        // the same way, and an error would lock out any client that sends
+                        // unknown options (though we're not sure such clients exist).
+                        tlog!(Warning, "unknown option: '{name}'");
+                    }
+                }
+            }
+        }
 
-    // NB: add more params as needed.
-    // Keep in mind that a client is required to send only "user".
-    Ok(ClientParams {
-        username,
-        _rest: parameters,
-    })
+        Ok(Self {
+            username,
+            vtable_max_rows,
+            vdbe_max_steps,
+            _rest: parameters,
+        })
+    }
+
+    /// Returns the limits set for this connection as sbroad option specs (unset ones omitted).
+    pub fn execution_options(&self) -> Vec<OptionSpec> {
+        let mut opts = vec![];
+
+        if let Some(vdbe_max_steps) = self.vdbe_max_steps {
+            let vdbe_max_steps = OptionParamValue::Value {
+                val: SbroadValue::Unsigned(vdbe_max_steps),
+            };
+            opts.push(OptionSpec {
+                kind: OptionKind::VdbeMaxSteps,
+                val: vdbe_max_steps,
+            })
+        }
+
+        if let Some(vtable_max_rows) = self.vtable_max_rows {
+            let vtable_max_rows = OptionParamValue::Value {
+                val: SbroadValue::Unsigned(vtable_max_rows),
+            };
+            opts.push(OptionSpec {
+                kind: OptionKind::VTableMaxRows,
+                val: vtable_max_rows,
+            })
+        }
+
+        opts
+    }
+}
+
+fn parse_startup(startup: Startup) -> PgResult<ClientParams> {
+    tlog!(Debug, "client parameters: {:?}", &startup.parameters);
+    ClientParams::new(startup.parameters)
 }
 
 fn handle_ssl_request<S: Read + Write>(
diff --git a/src/pgproto/error.rs b/src/pgproto/error.rs
index fff953b809..1a6e6ee6c6 100644
--- a/src/pgproto/error.rs
+++ b/src/pgproto/error.rs
@@ -106,6 +106,12 @@ pub enum PgError {
     Other(Box<dyn std::error::Error>),
 }
 
+impl PgError {
+    pub fn other<E: Into<Box<dyn std::error::Error>>>(e: E) -> Self {
+        Self::Other(e.into())
+    }
+}
+
 impl From<sbroad::errors::SbroadError> for PgError {
     #[inline(always)]
     fn from(e: sbroad::errors::SbroadError) -> Self {
diff --git a/test/pgproto/options_test.py b/test/pgproto/options_test.py
new file mode 100644
index 0000000000..245aed3da5
--- /dev/null
+++ b/test/pgproto/options_test.py
@@ -0,0 +1,104 @@
+import psycopg
+import pytest
+from conftest import Postgres
+
+
+def test_vdbe_max_steps_and_vtable_max_rows_options(postgres: Postgres):
+    user = "postgres"
+    password = "Passw0rd"
+    host = postgres.host
+    port = postgres.port
+
+    postgres.instance.sql(f"CREATE USER \"{user}\" WITH PASSWORD '{password}'")
+
+    # Note that "vtable_max_rows%3D1" is an escaped version of "vtable_max_rows=1".
+
+    # Set the default for "vtable_max_rows" to 1.
+    conn = psycopg.connect(
+        f"postgres://{user}:{password}@{host}:{port}?options=vtable_max_rows%3D1",
+        autocommit=True,
+    )
+    with pytest.raises(
+        psycopg.InternalError,
+        match=r"Exceeded maximum number of rows \(1\) in virtual table: 2",
+    ):
+        conn.execute("SELECT * FROM (VALUES (1), (2))")
+
+    # Check if it still fails with "vdbe_max_steps" provided.
+    with pytest.raises(
+        psycopg.InternalError,
+        match=r"Exceeded maximum number of rows \(1\) in virtual table: 2",
+    ):
+        conn.execute("SELECT * FROM (VALUES (1), (2)) OPTION (VDBE_MAX_STEPS = 1000)")
+
+    # Specify "vtable_max_rows" in a query so the default is not used.
+    conn.execute("SELECT * FROM (VALUES (1), (2)) OPTION (VTABLE_MAX_ROWS = 2)")
+
+    # Set the default for "vdbe_max_steps" to 1.
+    conn = psycopg.connect(
+        f"postgres://{user}:{password}@{host}:{port}?options=vdbe_max_steps%3D1",
+        autocommit=True,
+    )
+    with pytest.raises(
+        psycopg.InternalError,
+        match="Reached a limit on max executed vdbe opcodes. Limit: 1",
+    ):
+        conn.execute("SELECT * FROM (VALUES (1), (2))")
+
+    # Check if it still fails with "vtable_max_rows" provided.
+    with pytest.raises(
+        psycopg.InternalError,
+        match="Reached a limit on max executed vdbe opcodes. Limit: 1",
+    ):
+        conn.execute("SELECT * FROM (VALUES (1), (2)) OPTION (VTABLE_MAX_ROWS = 1000)")
+
+    # Specify "vdbe_max_steps" in a query so the default is not used.
+    conn.execute("SELECT * FROM (VALUES (1), (2)) OPTION (VDBE_MAX_STEPS = 1000)")
+
+    # Set both options and reach "vdbe_max_steps" limit.
+    conn = psycopg.connect(
+        f"postgres://{user}:{password}@{host}:{port}?"
+        "options=vtable_max_rows%3D1,vdbe_max_steps%3D1",
+        autocommit=True,
+    )
+    with pytest.raises(
+        psycopg.InternalError,
+        match=r"Reached a limit on max executed vdbe opcodes. Limit: 1",
+    ):
+        conn.execute("SELECT * FROM (VALUES (1), (2))")
+
+    # Set both options and reach "vtable_max_rows" limit.
+    conn = psycopg.connect(
+        f"postgres://{user}:{password}@{host}:{port}?"
+        "options=vtable_max_rows%3D1,vdbe_max_steps%3D1000",
+        autocommit=True,
+    )
+    with pytest.raises(
+        psycopg.InternalError,
+        match=r"Exceeded maximum number of rows \(1\) in virtual table: 2",
+    ):
+        conn.execute("SELECT * FROM (VALUES (1), (2))")
+
+
+def test_repeating_options(postgres: Postgres):
+    user = "postgres"
+    password = "Passw0rd"
+    host = postgres.host
+    port = postgres.port
+
+    postgres.instance.sql(f"CREATE USER \"{user}\" WITH PASSWORD '{password}'")
+
+    # Note that "vtable_max_rows%3D1" is an escaped version of "vtable_max_rows=1".
+
+    # Check if the last option value is applied (3 -> 2 -> 1).
+    conn = psycopg.connect(
+        f"postgres://{user}:{password}@{host}:{port}?"
+        "options=vtable_max_rows%3D3,vtable_max_rows%3D2,"
+        "vtable_max_rows%3D1",
+        autocommit=True,
+    )
+    with pytest.raises(
+        psycopg.InternalError,
+        match=r"Exceeded maximum number of rows \(1\) in virtual table: 2",
+    ):
+        conn.execute("SELECT * FROM (VALUES (1), (2))")
-- 
GitLab