diff --git a/src/pgproto.rs b/src/pgproto.rs index da9f8629a88f2f9a3a47a95a5ffcf3b51d48f111..9698a6089903e6adcf7216c086ceb949e7525d49 100644 --- a/src/pgproto.rs +++ b/src/pgproto.rs @@ -1,7 +1,12 @@ -use crate::address::Address; -use crate::tlog; -use tarantool::coio::CoIOListener; -use thiserror::Error; +use self::{ + client::PgClient, + error::PgResult, + tls::{TlsAcceptor, TlsConfig, TlsError}, +}; +use crate::{address::Address, introspection::Introspection, tlog, traft::error::Error}; +use std::path::{Path, PathBuf}; +use stream::PgStream; +use tarantool::coio::{CoIOListener, CoIOStream}; mod client; mod entrypoints; @@ -12,57 +17,7 @@ mod storage; mod stream; mod tls; -use self::client::PgClient; -use self::error::PgResult; -use self::tls::{TlsAcceptor, TlsConfig}; -use crate::introspection::Introspection; -use std::cell::Cell; -use std::path::{Path, PathBuf}; -use stream::PgStream; -use tarantool::{coio::CoIOStream, fiber::JoinHandle}; - -pub use error::PgError; - -fn server_start(context: Context) { - let mut handles = vec![]; - while let Ok(raw) = context.server.accept() { - let stream = PgStream::new(raw); - handles.push(handle_client(stream, context.tls_acceptor.clone())); - } - - // TODO: this feels forced; find a better way. - for handle in handles { - handle.join(); - } -} - -fn handle_client( - client: PgStream<CoIOStream>, - tls_acceptor: Option<TlsAcceptor>, -) -> JoinHandle<'static, ()> { - tlog!(Info, "spawning a new fiber for postgres client connection"); - tarantool::fiber::start(move || { - let res = do_handle_client(client, tls_acceptor); - if let Err(e) = res { - tlog!(Error, "postgres client connection error: {e}"); - } - }) -} - -fn do_handle_client( - stream: PgStream<CoIOStream>, - tls_acceptor: Option<TlsAcceptor>, -) -> PgResult<()> { - let mut client = PgClient::accept(stream, tls_acceptor)?; - client.send_parameter("server_version", "15.0")?; - client.send_parameter("server_encoding", "UTF8")?; - client.send_parameter("client_encoding", "UTF8")?; - client.send_parameter("date_style", "ISO YMD")?; - client.send_parameter("integer_datetimes", "on")?; - client.process_messages_loop()?; - Ok(()) -} - +/// Main postgres server configuration. #[derive(PartialEq, Default, Debug, Clone, serde::Deserialize, serde::Serialize, Introspection)] #[serde(deny_unknown_fields)] pub struct Config { @@ -89,34 +44,79 @@ impl Config { } } -#[derive(Error, Debug)] -pub enum ConfigError { - #[error("bad port: {0}")] - BadPort(String), +fn server_start(context: Context) { + while let Ok(raw) = context.server.accept() { + let stream = PgStream::new(raw); + if let Err(e) = handle_client(stream, context.tls_acceptor.clone()) { + tlog!(Error, "failed to handle client {e}"); + } + } +} + +fn handle_client( + client: PgStream<CoIOStream>, + tls_acceptor: Option<TlsAcceptor>, +) -> tarantool::Result<()> { + tlog!(Info, "spawning a new fiber for postgres client connection"); + + tarantool::fiber::Builder::new() + .name("pgproto::client") + .func(move || { + let res = do_handle_client(client, tls_acceptor); + if let Err(e) = res { + tlog!(Error, "postgres client connection error: {e}"); + } + }) + .start_non_joinable()?; + + Ok(()) } +fn do_handle_client( + stream: PgStream<CoIOStream>, + tls_acceptor: Option<TlsAcceptor>, +) -> PgResult<()> { + let mut client = PgClient::accept(stream, tls_acceptor)?; + + // Send important parameters to the client. + client + .send_parameter("server_version", "15.0")? + .send_parameter("server_encoding", "UTF8")? + .send_parameter("client_encoding", "UTF8")? + .send_parameter("date_style", "ISO YMD")? + .send_parameter("integer_datetimes", "on")?; + + client.process_messages_loop()?; + + Ok(()) +} + +fn new_tls_acceptor(data_dir: &Path) -> Result<TlsAcceptor, TlsError> { + let tls_config = TlsConfig::from_data_dir(data_dir)?; + TlsAcceptor::new(&tls_config) +} + +/// Server execution context. pub struct Context { server: CoIOListener, tls_acceptor: Option<TlsAcceptor>, } impl Context { - pub fn new(config: &Config, data_dir: &Path) -> PgResult<Self> { + pub fn new(config: &Config, data_dir: &Path) -> Result<Self, Error> { assert!(config.enabled(), "must be checked before the call"); let listen = config.listen(); let host = listen.host.as_str(); - let port = listen - .port - .parse::<u16>() - .map_err(|_| ConfigError::BadPort(listen.port.clone()))?; - - let tls_acceptor = if config.ssl() { - let tls_config = TlsConfig::from_data_dir(data_dir)?; - Some(TlsAcceptor::new(&tls_config)?) - } else { - None - }; + let port = listen.port.parse::<u16>().map_err(|_| { + Error::invalid_configuration(format!("bad postgres port {}", listen.port)) + })?; + + let tls_acceptor = config + .ssl() + .then(|| new_tls_acceptor(data_dir)) + .transpose() + .map_err(Error::invalid_configuration)?; let addr = (host, port); tlog!(Info, "starting postgres server at {:?}...", addr); @@ -130,22 +130,13 @@ impl Context { } /// Start a postgres server fiber. -/// -/// WARNING: It must be called only once, otherwise a panic will happen. -pub fn start(config: &Config, data_dir: PathBuf) -> PgResult<()> { +pub fn start(config: &Config, data_dir: PathBuf) -> Result<(), Error> { let context = Context::new(config, &data_dir)?; - let handler = tarantool::fiber::start(move || server_start(context)); - - // There's currently no way of detaching a fiber without leaking memory, - // so we have to store it's join handle somewhere. - // - // From JoinHandle's doc: - // NOTE: if `JoinHandle` is dropped before [`JoinHandle::join`] is called on it - // a panic will happen. - thread_local! { - static FIBER_JOIN_HANDLE: Cell<Option<JoinHandle<'static, ()>>> = Cell::new(None); - } - FIBER_JOIN_HANDLE.replace(Some(handler)); + + tarantool::fiber::Builder::new() + .name("pgproto") + .func(move || server_start(context)) + .start_non_joinable()?; Ok(()) } diff --git a/src/pgproto/entrypoints.rs b/src/pgproto/entrypoints.rs index b9462a8a9c52a148b638b91a95ebb4a8785e7c03..763feffbe79cecca9f40567ffbb6bcabb44505d9 100644 --- a/src/pgproto/entrypoints.rs +++ b/src/pgproto/entrypoints.rs @@ -339,7 +339,7 @@ impl Entrypoints { let json: String = self .simple_query .call_with_args((client_id, sql)) - .map_err(|e| PgError::TarantoolError(e.into()))?; + .map_err(|e| PgError::LuaError(e.into()))?; simple_execute_result_from_json(&json) } @@ -353,7 +353,7 @@ impl Entrypoints { ) -> PgResult<()> { self.parse .call_with_args((client_id, name, sql, param_oids)) - .map_err(|e| PgError::TarantoolError(e.into())) + .map_err(|e| PgError::LuaError(e.into())) } /// Handler for a Bind message. See self.bind for the details. @@ -367,7 +367,7 @@ impl Entrypoints { ) -> PgResult<()> { self.bind .call_with_args((id, statement, portal, params, result_format)) - .map_err(|e| PgError::TarantoolError(e.into())) + .map_err(|e| PgError::LuaError(e.into())) } /// Handler for an Execute message. See self.execute for the details. @@ -375,7 +375,7 @@ impl Entrypoints { let json: String = self .execute .call_with_args((id, portal)) - .map_err(|e| PgError::TarantoolError(e.into()))?; + .map_err(|e| PgError::LuaError(e.into()))?; execute_result_from_json(&json) } @@ -384,7 +384,7 @@ impl Entrypoints { let json: String = self .describe_portal .call_with_args((client_id, portal)) - .map_err(|e| PgError::TarantoolError(e.into()))?; + .map_err(|e| PgError::LuaError(e.into()))?; let describe = serde_json::from_str(&json)?; Ok(describe) } @@ -398,7 +398,7 @@ impl Entrypoints { let json: String = self .describe_statement .call_with_args((client_id, statement)) - .map_err(|e| PgError::TarantoolError(e.into()))?; + .map_err(|e| PgError::LuaError(e.into()))?; let describe = serde_json::from_str(&json)?; Ok(describe) } @@ -407,28 +407,28 @@ impl Entrypoints { pub fn close_portal(&self, id: ClientId, portal: &str) -> PgResult<()> { self.close_portal .call_with_args((id, portal)) - .map_err(|e| PgError::TarantoolError(e.into())) + .map_err(|e| PgError::LuaError(e.into())) } /// Handler for a Close message. See self.close_statement for the details. pub fn close_statement(&self, client_id: ClientId, statement: &str) -> PgResult<()> { self.close_statement .call_with_args((client_id, statement)) - .map_err(|e| PgError::TarantoolError(e.into())) + .map_err(|e| PgError::LuaError(e.into())) } /// Close all the client statements and portals. See self.close_client_statements for the details. pub fn close_client_statements(&self, client_id: ClientId) -> PgResult<()> { self.close_client_statements .call_with_args(client_id) - .map_err(|e| PgError::TarantoolError(e.into())) + .map_err(|e| PgError::LuaError(e.into())) } /// Close client statements with its portals. pub fn close_client_portals(&self, client_id: ClientId) -> PgResult<()> { self.close_client_portals .call_with_args(client_id) - .map_err(|e| PgError::TarantoolError(e.into())) + .map_err(|e| PgError::LuaError(e.into())) } } diff --git a/src/pgproto/error.rs b/src/pgproto/error.rs index 553d270a89442bfa2b3fd6d94ad6b8207cd51fb8..8d78bcd4d298a3e4bb3031f61c9b2ece87ce06b2 100644 --- a/src/pgproto/error.rs +++ b/src/pgproto/error.rs @@ -1,5 +1,5 @@ +use super::tls::TlsError; use pgwire::error::{ErrorInfo, PgWireError}; -use std::env; use std::error; use std::io; use std::num::{ParseFloatError, ParseIntError}; @@ -7,9 +7,6 @@ use std::str::ParseBoolError; use std::string::FromUtf8Error; use thiserror::Error; -use super::tls::TlsError; -use super::ConfigError; - pub type PgResult<T> = Result<T, PgError>; /// See <https://www.postgresql.org/docs/current/errcodes-appendix.html>. @@ -34,7 +31,7 @@ pub enum PgError { PgWireError(#[from] PgWireError), #[error("lua error: {0}")] - TarantoolError(#[from] tarantool::tlua::LuaError), + LuaError(#[from] tarantool::tlua::LuaError), #[error("json error: {0}")] JsonError(#[from] serde_json::Error), @@ -44,12 +41,6 @@ pub enum PgError { #[error("tls error: {0}")] TlsError(#[from] TlsError), - - #[error("env error: {0}")] - EnvError(#[from] env::VarError), - - #[error("config error: {0}")] - ConfigError(#[from] ConfigError), } #[derive(Error, Debug)] diff --git a/src/pgproto/server.rs b/src/pgproto/server.rs index c1177e2c28953ac738763d1b12196f1e5ad19a12..dac9d682edc06469e8ca5c5c61c5cd42093ad1e9 100644 --- a/src/pgproto/server.rs +++ b/src/pgproto/server.rs @@ -2,7 +2,7 @@ use crate::tlog; use std::io; use tarantool::coio::CoIOListener; -pub fn new_listener(addr: (&str, u16)) -> io::Result<CoIOListener> { +pub fn new_listener(addr: (&str, u16)) -> tarantool::Result<CoIOListener> { let mut socket = None; let mut f = |_| { let wrapped = std::net::TcpListener::bind(addr); @@ -12,9 +12,10 @@ pub fn new_listener(addr: (&str, u16)) -> io::Result<CoIOListener> { }; if tarantool::coio::coio_call(&mut f, ()) != 0 { - return Err(io::Error::last_os_error()); + return Err(io::Error::last_os_error().into()); } let socket = socket.expect("uninitialized socket")?; - tarantool::coio::CoIOListener::try_from(socket) + let listener = tarantool::coio::CoIOListener::try_from(socket)?; + Ok(listener) } diff --git a/src/pgproto/tls.rs b/src/pgproto/tls.rs index cfc9837044068ad6aeac59925aff6b968008d15b..e80e7c70f7bcaa2fea6feec4bbfa81f067cb27e2 100644 --- a/src/pgproto/tls.rs +++ b/src/pgproto/tls.rs @@ -1,4 +1,3 @@ -use super::error::PgResult; use openssl::{ error::ErrorStack, ssl::{self, HandshakeError, SslFiletype, SslMethod, SslStream}, @@ -42,7 +41,7 @@ pub struct TlsConfig { } impl TlsConfig { - pub fn from_data_dir(data_dir: &Path) -> PgResult<Self> { + pub fn from_data_dir(data_dir: &Path) -> Result<Self, TlsError> { // We should use the absolute paths here, because SslContextBuilder::set_certificate_chain_file // fails for relative paths with an unclear error, represented as an empty error stack. let cert = data_dir.join("server.crt"); diff --git a/src/traft/error.rs b/src/traft/error.rs index a8895f49d957917729ee089f0c31be05331a441b..a1662cf40ffb96f8ee71425ce39e183d6a57c450 100644 --- a/src/traft/error.rs +++ b/src/traft/error.rs @@ -1,7 +1,6 @@ use std::fmt::{Debug, Display}; use crate::instance::InstanceId; -use crate::pgproto; use crate::plugin::PluginError; use crate::traft::{RaftId, RaftTerm}; use tarantool::error::{BoxError, IntoBoxError}; @@ -86,9 +85,6 @@ pub enum Error { #[error(transparent)] Plugin(#[from] PluginError), - #[error(transparent)] - PgProto(#[from] pgproto::PgError), - #[error("{0}")] Other(Box<dyn std::error::Error>), }