Skip to content
Snippets Groups Projects
Commit ed8f5c37 authored by Dmitry Ivanov's avatar Dmitry Ivanov
Browse files

chores: use appropriate fiber spawning primitives in pgproto

parent ec67b2b0
No related branches found
No related tags found
1 merge request!997chores: use appropriate fiber spawning primitives in pgproto
Pipeline #40170 passed
use crate::address::Address;
use crate::tlog;
use tarantool::coio::CoIOListener;
use thiserror::Error;
use self::{
client::PgClient,
error::PgResult,
tls::{TlsAcceptor, TlsConfig, TlsError},
};
use crate::{address::Address, introspection::Introspection, tlog, traft::error::Error};
use std::path::{Path, PathBuf};
use stream::PgStream;
use tarantool::coio::{CoIOListener, CoIOStream};
mod client;
mod entrypoints;
......@@ -12,57 +17,7 @@ mod storage;
mod stream;
mod tls;
use self::client::PgClient;
use self::error::PgResult;
use self::tls::{TlsAcceptor, TlsConfig};
use crate::introspection::Introspection;
use std::cell::Cell;
use std::path::{Path, PathBuf};
use stream::PgStream;
use tarantool::{coio::CoIOStream, fiber::JoinHandle};
pub use error::PgError;
fn server_start(context: Context) {
let mut handles = vec![];
while let Ok(raw) = context.server.accept() {
let stream = PgStream::new(raw);
handles.push(handle_client(stream, context.tls_acceptor.clone()));
}
// TODO: this feels forced; find a better way.
for handle in handles {
handle.join();
}
}
fn handle_client(
client: PgStream<CoIOStream>,
tls_acceptor: Option<TlsAcceptor>,
) -> JoinHandle<'static, ()> {
tlog!(Info, "spawning a new fiber for postgres client connection");
tarantool::fiber::start(move || {
let res = do_handle_client(client, tls_acceptor);
if let Err(e) = res {
tlog!(Error, "postgres client connection error: {e}");
}
})
}
fn do_handle_client(
stream: PgStream<CoIOStream>,
tls_acceptor: Option<TlsAcceptor>,
) -> PgResult<()> {
let mut client = PgClient::accept(stream, tls_acceptor)?;
client.send_parameter("server_version", "15.0")?;
client.send_parameter("server_encoding", "UTF8")?;
client.send_parameter("client_encoding", "UTF8")?;
client.send_parameter("date_style", "ISO YMD")?;
client.send_parameter("integer_datetimes", "on")?;
client.process_messages_loop()?;
Ok(())
}
/// Main postgres server configuration.
#[derive(PartialEq, Default, Debug, Clone, serde::Deserialize, serde::Serialize, Introspection)]
#[serde(deny_unknown_fields)]
pub struct Config {
......@@ -89,34 +44,79 @@ impl Config {
}
}
#[derive(Error, Debug)]
pub enum ConfigError {
#[error("bad port: {0}")]
BadPort(String),
fn server_start(context: Context) {
while let Ok(raw) = context.server.accept() {
let stream = PgStream::new(raw);
if let Err(e) = handle_client(stream, context.tls_acceptor.clone()) {
tlog!(Error, "failed to handle client {e}");
}
}
}
fn handle_client(
client: PgStream<CoIOStream>,
tls_acceptor: Option<TlsAcceptor>,
) -> tarantool::Result<()> {
tlog!(Info, "spawning a new fiber for postgres client connection");
tarantool::fiber::Builder::new()
.name("pgproto::client")
.func(move || {
let res = do_handle_client(client, tls_acceptor);
if let Err(e) = res {
tlog!(Error, "postgres client connection error: {e}");
}
})
.start_non_joinable()?;
Ok(())
}
fn do_handle_client(
stream: PgStream<CoIOStream>,
tls_acceptor: Option<TlsAcceptor>,
) -> PgResult<()> {
let mut client = PgClient::accept(stream, tls_acceptor)?;
// Send important parameters to the client.
client
.send_parameter("server_version", "15.0")?
.send_parameter("server_encoding", "UTF8")?
.send_parameter("client_encoding", "UTF8")?
.send_parameter("date_style", "ISO YMD")?
.send_parameter("integer_datetimes", "on")?;
client.process_messages_loop()?;
Ok(())
}
fn new_tls_acceptor(data_dir: &Path) -> Result<TlsAcceptor, TlsError> {
let tls_config = TlsConfig::from_data_dir(data_dir)?;
TlsAcceptor::new(&tls_config)
}
/// Server execution context.
pub struct Context {
server: CoIOListener,
tls_acceptor: Option<TlsAcceptor>,
}
impl Context {
pub fn new(config: &Config, data_dir: &Path) -> PgResult<Self> {
pub fn new(config: &Config, data_dir: &Path) -> Result<Self, Error> {
assert!(config.enabled(), "must be checked before the call");
let listen = config.listen();
let host = listen.host.as_str();
let port = listen
.port
.parse::<u16>()
.map_err(|_| ConfigError::BadPort(listen.port.clone()))?;
let tls_acceptor = if config.ssl() {
let tls_config = TlsConfig::from_data_dir(data_dir)?;
Some(TlsAcceptor::new(&tls_config)?)
} else {
None
};
let port = listen.port.parse::<u16>().map_err(|_| {
Error::invalid_configuration(format!("bad postgres port {}", listen.port))
})?;
let tls_acceptor = config
.ssl()
.then(|| new_tls_acceptor(data_dir))
.transpose()
.map_err(Error::invalid_configuration)?;
let addr = (host, port);
tlog!(Info, "starting postgres server at {:?}...", addr);
......@@ -130,22 +130,13 @@ impl Context {
}
/// Start a postgres server fiber.
///
/// WARNING: It must be called only once, otherwise a panic will happen.
pub fn start(config: &Config, data_dir: PathBuf) -> PgResult<()> {
pub fn start(config: &Config, data_dir: PathBuf) -> Result<(), Error> {
let context = Context::new(config, &data_dir)?;
let handler = tarantool::fiber::start(move || server_start(context));
// There's currently no way of detaching a fiber without leaking memory,
// so we have to store it's join handle somewhere.
//
// From JoinHandle's doc:
// NOTE: if `JoinHandle` is dropped before [`JoinHandle::join`] is called on it
// a panic will happen.
thread_local! {
static FIBER_JOIN_HANDLE: Cell<Option<JoinHandle<'static, ()>>> = Cell::new(None);
}
FIBER_JOIN_HANDLE.replace(Some(handler));
tarantool::fiber::Builder::new()
.name("pgproto")
.func(move || server_start(context))
.start_non_joinable()?;
Ok(())
}
......@@ -339,7 +339,7 @@ impl Entrypoints {
let json: String = self
.simple_query
.call_with_args((client_id, sql))
.map_err(|e| PgError::TarantoolError(e.into()))?;
.map_err(|e| PgError::LuaError(e.into()))?;
simple_execute_result_from_json(&json)
}
......@@ -353,7 +353,7 @@ impl Entrypoints {
) -> PgResult<()> {
self.parse
.call_with_args((client_id, name, sql, param_oids))
.map_err(|e| PgError::TarantoolError(e.into()))
.map_err(|e| PgError::LuaError(e.into()))
}
/// Handler for a Bind message. See self.bind for the details.
......@@ -367,7 +367,7 @@ impl Entrypoints {
) -> PgResult<()> {
self.bind
.call_with_args((id, statement, portal, params, result_format))
.map_err(|e| PgError::TarantoolError(e.into()))
.map_err(|e| PgError::LuaError(e.into()))
}
/// Handler for an Execute message. See self.execute for the details.
......@@ -375,7 +375,7 @@ impl Entrypoints {
let json: String = self
.execute
.call_with_args((id, portal))
.map_err(|e| PgError::TarantoolError(e.into()))?;
.map_err(|e| PgError::LuaError(e.into()))?;
execute_result_from_json(&json)
}
......@@ -384,7 +384,7 @@ impl Entrypoints {
let json: String = self
.describe_portal
.call_with_args((client_id, portal))
.map_err(|e| PgError::TarantoolError(e.into()))?;
.map_err(|e| PgError::LuaError(e.into()))?;
let describe = serde_json::from_str(&json)?;
Ok(describe)
}
......@@ -398,7 +398,7 @@ impl Entrypoints {
let json: String = self
.describe_statement
.call_with_args((client_id, statement))
.map_err(|e| PgError::TarantoolError(e.into()))?;
.map_err(|e| PgError::LuaError(e.into()))?;
let describe = serde_json::from_str(&json)?;
Ok(describe)
}
......@@ -407,28 +407,28 @@ impl Entrypoints {
pub fn close_portal(&self, id: ClientId, portal: &str) -> PgResult<()> {
self.close_portal
.call_with_args((id, portal))
.map_err(|e| PgError::TarantoolError(e.into()))
.map_err(|e| PgError::LuaError(e.into()))
}
/// Handler for a Close message. See self.close_statement for the details.
pub fn close_statement(&self, client_id: ClientId, statement: &str) -> PgResult<()> {
self.close_statement
.call_with_args((client_id, statement))
.map_err(|e| PgError::TarantoolError(e.into()))
.map_err(|e| PgError::LuaError(e.into()))
}
/// Close all the client statements and portals. See self.close_client_statements for the details.
pub fn close_client_statements(&self, client_id: ClientId) -> PgResult<()> {
self.close_client_statements
.call_with_args(client_id)
.map_err(|e| PgError::TarantoolError(e.into()))
.map_err(|e| PgError::LuaError(e.into()))
}
/// Close client statements with its portals.
pub fn close_client_portals(&self, client_id: ClientId) -> PgResult<()> {
self.close_client_portals
.call_with_args(client_id)
.map_err(|e| PgError::TarantoolError(e.into()))
.map_err(|e| PgError::LuaError(e.into()))
}
}
......
use super::tls::TlsError;
use pgwire::error::{ErrorInfo, PgWireError};
use std::env;
use std::error;
use std::io;
use std::num::{ParseFloatError, ParseIntError};
......@@ -7,9 +7,6 @@ use std::str::ParseBoolError;
use std::string::FromUtf8Error;
use thiserror::Error;
use super::tls::TlsError;
use super::ConfigError;
pub type PgResult<T> = Result<T, PgError>;
/// See <https://www.postgresql.org/docs/current/errcodes-appendix.html>.
......@@ -34,7 +31,7 @@ pub enum PgError {
PgWireError(#[from] PgWireError),
#[error("lua error: {0}")]
TarantoolError(#[from] tarantool::tlua::LuaError),
LuaError(#[from] tarantool::tlua::LuaError),
#[error("json error: {0}")]
JsonError(#[from] serde_json::Error),
......@@ -44,12 +41,6 @@ pub enum PgError {
#[error("tls error: {0}")]
TlsError(#[from] TlsError),
#[error("env error: {0}")]
EnvError(#[from] env::VarError),
#[error("config error: {0}")]
ConfigError(#[from] ConfigError),
}
#[derive(Error, Debug)]
......
......@@ -2,7 +2,7 @@ use crate::tlog;
use std::io;
use tarantool::coio::CoIOListener;
pub fn new_listener(addr: (&str, u16)) -> io::Result<CoIOListener> {
pub fn new_listener(addr: (&str, u16)) -> tarantool::Result<CoIOListener> {
let mut socket = None;
let mut f = |_| {
let wrapped = std::net::TcpListener::bind(addr);
......@@ -12,9 +12,10 @@ pub fn new_listener(addr: (&str, u16)) -> io::Result<CoIOListener> {
};
if tarantool::coio::coio_call(&mut f, ()) != 0 {
return Err(io::Error::last_os_error());
return Err(io::Error::last_os_error().into());
}
let socket = socket.expect("uninitialized socket")?;
tarantool::coio::CoIOListener::try_from(socket)
let listener = tarantool::coio::CoIOListener::try_from(socket)?;
Ok(listener)
}
use super::error::PgResult;
use openssl::{
error::ErrorStack,
ssl::{self, HandshakeError, SslFiletype, SslMethod, SslStream},
......@@ -42,7 +41,7 @@ pub struct TlsConfig {
}
impl TlsConfig {
pub fn from_data_dir(data_dir: &Path) -> PgResult<Self> {
pub fn from_data_dir(data_dir: &Path) -> Result<Self, TlsError> {
// We should use the absolute paths here, because SslContextBuilder::set_certificate_chain_file
// fails for relative paths with an unclear error, represented as an empty error stack.
let cert = data_dir.join("server.crt");
......
use std::fmt::{Debug, Display};
use crate::instance::InstanceId;
use crate::pgproto;
use crate::plugin::PluginError;
use crate::traft::{RaftId, RaftTerm};
use tarantool::error::{BoxError, IntoBoxError};
......@@ -86,9 +85,6 @@ pub enum Error {
#[error(transparent)]
Plugin(#[from] PluginError),
#[error(transparent)]
PgProto(#[from] pgproto::PgError),
#[error("{0}")]
Other(Box<dyn std::error::Error>),
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment