From 09a8a5f0fb8f8f45cd684473aacaff55eb427f92 Mon Sep 17 00:00:00 2001
From: Kaitmazian Maksim <m.kaitmazian@picodata.io>
Date: Fri, 29 Mar 2024 15:09:47 +0300
Subject: [PATCH] refactor(pgproto): move proc_pg_* exports in a separate file

---
 src/pgproto/backend.rs        | 172 ++++++++++------------------------
 src/pgproto/backend/pgproc.rs | 128 +++++++++++++++++++++++++
 2 files changed, 180 insertions(+), 120 deletions(-)
 create mode 100644 src/pgproto/backend/pgproc.rs

diff --git a/src/pgproto/backend.rs b/src/pgproto/backend.rs
index fd5126ef12..50965b06cd 100644
--- a/src/pgproto/backend.rs
+++ b/src/pgproto/backend.rs
@@ -1,8 +1,5 @@
 use self::describe::{PortalDescribe, StatementDescribe};
-use self::storage::{
-    with_portals_mut, Portal, Statement, UserPortalNames, UserStatementNames, PG_PORTALS,
-    PG_STATEMENTS,
-};
+use self::storage::{with_portals_mut, Portal, Statement, PG_PORTALS, PG_STATEMENTS};
 use super::client::ClientId;
 use super::error::{PgError, PgResult};
 use crate::pgproto::storage::value::Format;
@@ -11,7 +8,6 @@ use crate::sql::otm::TracerKind;
 use crate::sql::router::RouterRuntime;
 use crate::sql::with_tracer;
 use crate::traft::error::Error;
-use ::tarantool::proc;
 use opentelemetry::sdk::trace::Tracer;
 use opentelemetry::Context;
 use postgres_types::Oid;
@@ -19,62 +15,19 @@ use sbroad::executor::engine::helpers::normalize_name_for_space_api;
 use sbroad::executor::engine::{QueryCache, Router, TableVersionMap};
 use sbroad::executor::lru::Cache;
 use sbroad::frontend::Ast;
-use sbroad::ir::value::{LuaValue, Value};
+use sbroad::ir::value::Value;
 use sbroad::ir::Plan as IrPlan;
 use sbroad::otm::{query_id, query_span, OTM_CHAR_LIMIT};
 use sbroad::utils::MutexLike;
-use serde::Deserialize;
 use smol_str::ToSmolStr;
 use std::rc::Rc;
 use tarantool::session::with_su;
 use tarantool::tuple::Tuple;
 
-pub mod describe;
+mod pgproc;
 mod storage;
 
-struct BindArgs {
-    id: ClientId,
-    stmt_name: String,
-    portal_name: String,
-    params: Vec<Value>,
-    encoding_format: Vec<u8>,
-    traceable: bool,
-}
-
-impl<'de> Deserialize<'de> for BindArgs {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: serde::Deserializer<'de>,
-    {
-        #[derive(Deserialize)]
-        struct EncodedBindArgs(
-            ClientId,
-            String,
-            String,
-            Option<Vec<LuaValue>>,
-            Vec<u8>,
-            Option<bool>,
-        );
-
-        let EncodedBindArgs(id, stmt_name, portal_name, params, encoding_format, traceable) =
-            EncodedBindArgs::deserialize(deserializer)?;
-
-        let params = params
-            .unwrap_or_default()
-            .into_iter()
-            .map(Value::from)
-            .collect::<Vec<Value>>();
-
-        Ok(Self {
-            id,
-            stmt_name,
-            portal_name,
-            params,
-            encoding_format,
-            traceable: traceable.unwrap_or(false),
-        })
-    }
-}
+pub mod describe;
 
 // helper function to get `TracerRef`
 fn get_tracer_param(traceable: bool) -> &'static Tracer {
@@ -82,17 +35,15 @@ fn get_tracer_param(traceable: bool) -> &'static Tracer {
     kind.get_tracer()
 }
 
-#[proc(packed_args)]
-pub fn proc_pg_bind(args: BindArgs) -> PgResult<()> {
-    let BindArgs {
-        id,
-        stmt_name,
-        portal_name,
-        params,
-        encoding_format: output_format,
-        traceable,
-    } = args;
-    let key = (id, stmt_name.into());
+pub fn bind(
+    client_id: ClientId,
+    stmt_name: String,
+    portal_name: String,
+    params: Vec<Value>,
+    output_format: Vec<u8>,
+    traceable: bool,
+) -> PgResult<()> {
+    let key = (client_id, stmt_name.into());
     let Some(statement) = PG_STATEMENTS.with(|storage| storage.borrow().get(&key)) else {
         return Err(PgError::Other(
             format!("Couldn't find statement \'{}\'.", key.1).into(),
@@ -120,65 +71,15 @@ pub fn proc_pg_bind(args: BindArgs) -> PgResult<()> {
         },
     )?;
 
-    PG_PORTALS.with(|storage| storage.borrow_mut().put((id, portal_name.into()), portal))?;
+    PG_PORTALS.with(|storage| {
+        storage
+            .borrow_mut()
+            .put((client_id, portal_name.into()), portal)
+    })?;
     Ok(())
 }
 
-#[proc]
-pub fn proc_pg_statements(id: ClientId) -> UserStatementNames {
-    UserStatementNames::new(id)
-}
-
-#[proc]
-pub fn proc_pg_portals(id: ClientId) -> UserPortalNames {
-    UserPortalNames::new(id)
-}
-
-#[proc]
-pub fn proc_pg_close_stmt(id: ClientId, name: String) {
-    // Close can't cause an error in PG.
-    PG_STATEMENTS.with(|storage| storage.borrow_mut().remove(&(id, name.into())));
-}
-
-#[proc]
-pub fn proc_pg_close_portal(id: ClientId, name: String) {
-    // Close can't cause an error in PG.
-    PG_PORTALS.with(|storage| storage.borrow_mut().remove(&(id, name.into())));
-}
-
-#[proc]
-pub fn proc_pg_close_client_stmts(id: ClientId) {
-    PG_STATEMENTS.with(|storage| storage.borrow_mut().remove_by_client_id(id))
-}
-
-#[proc]
-pub fn proc_pg_close_client_portals(id: ClientId) {
-    PG_PORTALS.with(|storage| storage.borrow_mut().remove_by_client_id(id))
-}
-
-#[proc]
-pub fn proc_pg_describe_stmt(id: ClientId, name: String) -> PgResult<StatementDescribe> {
-    let key = (id, name.into());
-    let Some(statement) = PG_STATEMENTS.with(|storage| storage.borrow().get(&key)) else {
-        return Err(PgError::Other(
-            format!("Couldn't find statement \'{}\'.", key.1).into(),
-        ));
-    };
-    Ok(statement.describe().clone())
-}
-
-#[proc]
-pub fn proc_pg_describe_portal(id: ClientId, name: String) -> PgResult<PortalDescribe> {
-    with_portals_mut((id, name.into()), |portal| Ok(portal.describe().clone()))
-}
-
-#[proc]
-pub fn proc_pg_execute(
-    id: ClientId,
-    name: String,
-    max_rows: i64,
-    traceable: bool,
-) -> PgResult<Tuple> {
+pub fn execute(id: ClientId, name: String, max_rows: i64, traceable: bool) -> PgResult<Tuple> {
     let max_rows = if max_rows <= 0 { i64::MAX } else { max_rows };
     let name = Rc::from(name);
 
@@ -199,8 +100,7 @@ pub fn proc_pg_execute(
     })
 }
 
-#[proc]
-pub fn proc_pg_parse(
+pub fn parse(
     cid: ClientId,
     name: String,
     query: String,
@@ -258,3 +158,35 @@ pub fn proc_pg_parse(
         },
     )
 }
+
+pub fn describe_stmt(id: ClientId, name: String) -> PgResult<StatementDescribe> {
+    let key = (id, name.into());
+    let Some(statement) = PG_STATEMENTS.with(|storage| storage.borrow().get(&key)) else {
+        return Err(PgError::Other(
+            format!("Couldn't find statement \'{}\'.", key.1).into(),
+        ));
+    };
+    Ok(statement.describe().clone())
+}
+
+pub fn describe_portal(id: ClientId, name: String) -> PgResult<PortalDescribe> {
+    with_portals_mut((id, name.into()), |portal| Ok(portal.describe().clone()))
+}
+
+pub fn close_stmt(id: ClientId, name: String) {
+    // Close can't cause an error in PG.
+    PG_STATEMENTS.with(|storage| storage.borrow_mut().remove(&(id, name.into())));
+}
+
+pub fn close_portal(id: ClientId, name: String) {
+    // Close can't cause an error in PG.
+    PG_PORTALS.with(|storage| storage.borrow_mut().remove(&(id, name.into())));
+}
+
+pub fn close_client_stmts(id: ClientId) {
+    PG_STATEMENTS.with(|storage| storage.borrow_mut().remove_by_client_id(id))
+}
+
+pub fn close_client_portals(id: ClientId) {
+    PG_PORTALS.with(|storage| storage.borrow_mut().remove_by_client_id(id))
+}
diff --git a/src/pgproto/backend/pgproc.rs b/src/pgproto/backend/pgproc.rs
new file mode 100644
index 0000000000..5f98cf925c
--- /dev/null
+++ b/src/pgproto/backend/pgproc.rs
@@ -0,0 +1,128 @@
+use super::describe::{PortalDescribe, StatementDescribe};
+use super::storage::{UserPortalNames, UserStatementNames};
+use crate::pgproto::backend;
+use crate::pgproto::{client::ClientId, error::PgResult};
+use ::tarantool::proc;
+use postgres_types::Oid;
+use sbroad::ir::value::{LuaValue, Value};
+use serde::Deserialize;
+use tarantool::tuple::Tuple;
+
+struct BindArgs {
+    id: ClientId,
+    stmt_name: String,
+    portal_name: String,
+    params: Vec<Value>,
+    encoding_format: Vec<u8>,
+    traceable: bool,
+}
+
+impl<'de> Deserialize<'de> for BindArgs {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        struct EncodedBindArgs(
+            ClientId,
+            String,
+            String,
+            Option<Vec<LuaValue>>,
+            Vec<u8>,
+            Option<bool>,
+        );
+
+        let EncodedBindArgs(id, stmt_name, portal_name, params, encoding_format, traceable) =
+            EncodedBindArgs::deserialize(deserializer)?;
+
+        let params = params
+            .unwrap_or_default()
+            .into_iter()
+            .map(Value::from)
+            .collect::<Vec<Value>>();
+
+        Ok(Self {
+            id,
+            stmt_name,
+            portal_name,
+            params,
+            encoding_format,
+            traceable: traceable.unwrap_or(false),
+        })
+    }
+}
+
+#[proc(packed_args)]
+pub fn proc_pg_bind(args: BindArgs) -> PgResult<()> {
+    let BindArgs {
+        id,
+        stmt_name,
+        portal_name,
+        params,
+        encoding_format: output_format,
+        traceable,
+    } = args;
+
+    backend::bind(id, stmt_name, portal_name, params, output_format, traceable)
+}
+
+#[proc]
+pub fn proc_pg_describe_stmt(id: ClientId, name: String) -> PgResult<StatementDescribe> {
+    backend::describe_stmt(id, name)
+}
+
+#[proc]
+pub fn proc_pg_describe_portal(id: ClientId, name: String) -> PgResult<PortalDescribe> {
+    backend::describe_portal(id, name)
+}
+
+#[proc]
+pub fn proc_pg_execute(
+    id: ClientId,
+    name: String,
+    max_rows: i64,
+    traceable: bool,
+) -> PgResult<Tuple> {
+    backend::execute(id, name, max_rows, traceable)
+}
+
+#[proc]
+pub fn proc_pg_parse(
+    id: ClientId,
+    name: String,
+    query: String,
+    param_oids: Vec<Oid>,
+    traceable: bool,
+) -> PgResult<()> {
+    backend::parse(id, name, query, param_oids, traceable)
+}
+
+#[proc]
+pub fn proc_pg_close_stmt(id: ClientId, name: String) {
+    backend::close_stmt(id, name)
+}
+
+#[proc]
+pub fn proc_pg_close_portal(id: ClientId, name: String) {
+    backend::close_portal(id, name)
+}
+
+#[proc]
+pub fn proc_pg_close_client_stmts(id: ClientId) {
+    backend::close_client_stmts(id)
+}
+
+#[proc]
+pub fn proc_pg_close_client_portals(id: ClientId) {
+    backend::close_client_portals(id)
+}
+
+#[proc]
+pub fn proc_pg_statements(id: ClientId) -> UserStatementNames {
+    UserStatementNames::new(id)
+}
+
+#[proc]
+pub fn proc_pg_portals(id: ClientId) -> UserPortalNames {
+    UserPortalNames::new(id)
+}
-- 
GitLab