// pgproto.rs
use self::{client::PgClient, error::PgResult, tls::TlsAcceptor};
use crate::{address::Address, introspection::Introspection, tlog, traft::error::Error};
use std::path::{Path, PathBuf};
use stream::PgStream;
use tarantool::coio::{CoIOListener, CoIOStream};

mod backend;
mod client;
mod error;
mod messages;
mod server;
mod stream;
mod tls;
mod value;

/// Default upper bound on the number of postgres statements.
///
/// NOTE(review): presumably a limit on prepared statements kept by the
/// server; the place where it is applied is not in this file — confirm.
pub const DEFAULT_MAX_PG_STATEMENTS: usize = 50;
/// Default upper bound on the number of postgres portals.
///
/// NOTE(review): presumably a limit on open portals kept by the server;
/// the place where it is applied is not in this file — confirm.
pub const DEFAULT_MAX_PG_PORTALS: usize = 50;

/// Main postgres server configuration.
///
/// Unknown fields are rejected on deserialization (`deny_unknown_fields`).
#[derive(PartialEq, Default, Debug, Clone, serde::Deserialize, serde::Serialize, Introspection)]
#[serde(deny_unknown_fields)]
pub struct Config {
    /// Address the postgres server listens on.
    ///
    /// The server is considered enabled only when this is `Some`
    /// (see [`Config::enabled`]).
    pub listen: Option<Address>,

    /// Whether a TLS acceptor should be set up for client connections.
    ///
    /// NOTE(review): the `config_default = false` attribute presumably makes
    /// the configuration machinery fill in `Some(false)` when the option is
    /// omitted; [`Config::ssl`] panics if the value is still `None`.
    #[introspection(config_default = false)]
    pub ssl: Option<bool>,
}

impl Config {
    /// Returns `true` when a listen address is configured, which is what
    /// turns pgproto on.
    pub fn enabled(&self) -> bool {
        self.listen.is_some()
    }

    /// Returns a copy of the configured listen address.
    ///
    /// # Panics
    ///
    /// Panics if no address is configured; callers are expected to check
    /// [`Config::enabled`] first.
    pub fn listen(&self) -> Address {
        let address = self
            .listen
            .as_ref()
            .expect("must be checked before the call");
        address.clone()
    }

    /// Returns the value of the `ssl` option.
    ///
    /// # Panics
    ///
    /// Panics if the option is unset; it is expected to always carry a
    /// value by the time this is called.
    pub fn ssl(&self) -> bool {
        let ssl = self.ssl;
        ssl.expect("set by default")
    }
}

/// Accept loop of the postgres server.
///
/// Takes ownership of `context`, accepts incoming connections on its
/// listener and hands each of them to `handle_client`. A failure to hand off
/// a single client is logged and does not stop the loop.
///
/// The loop ends when `accept` fails. The error is logged before returning,
/// so the server never stops accepting connections silently.
fn server_start(context: Context) {
    loop {
        let raw = match context.server.accept() {
            Ok(raw) => raw,
            Err(e) => {
                tlog!(
                    Error,
                    "failed to accept postgres client connection, stopping the server: {e}"
                );
                break;
            }
        };

        let stream = PgStream::new(raw);
        // Each client gets its own clone of the (optional) TLS acceptor.
        if let Err(e) = handle_client(stream, context.tls_acceptor.clone()) {
            tlog!(Error, "failed to handle client {e}");
        }
    }
}

/// Starts a dedicated non-joinable fiber named `pgproto::client` that
/// serves one postgres client connection.
///
/// Errors raised while serving the connection are logged inside the fiber.
/// The returned error only reports a failure to start the fiber itself.
fn handle_client(
    client: PgStream<CoIOStream>,
    tls_acceptor: Option<TlsAcceptor>,
) -> tarantool::Result<()> {
    tlog!(Info, "spawning a new fiber for postgres client connection");

    // Body of the per-client fiber: serve the connection, log any failure.
    let serve = move || {
        if let Err(e) = do_handle_client(client, tls_acceptor) {
            tlog!(Error, "postgres client connection error: {e}");
        }
    };

    tarantool::fiber::Builder::new()
        .name("pgproto::client")
        .func(serve)
        .start_non_joinable()?;
    Ok(())
}

/// Serves a single postgres client connection from start to finish.
///
/// Accepts the client on `stream` (passing along the optional TLS acceptor),
/// sends it a fixed set of server parameters and then runs the client's
/// message processing loop. The first error from any step is returned.
fn do_handle_client(
    stream: PgStream<CoIOStream>,
    tls_acceptor: Option<TlsAcceptor>,
) -> PgResult<()> {
    // Parameters reported to every client, sent in this order.
    // NOTE(review): PostgreSQL itself reports this setting as `DateStyle`;
    // confirm that clients do not depend on that exact name.
    let startup_parameters = [
        ("server_version", "15.0"),
        ("server_encoding", "UTF8"),
        ("client_encoding", "UTF8"),
        ("date_style", "ISO YMD"),
        ("integer_datetimes", "on"),
    ];

    let mut client = PgClient::accept(stream, tls_acceptor)?;

    for (name, value) in startup_parameters {
        client.send_parameter(name, value)?;
    }

    client.process_messages_loop()?;

    Ok(())
}

/// Server execution context.
///
/// Built by [`Context::new`] and then moved into the server fiber, which
/// accepts connections from it.
pub struct Context {
    /// Listener the server accepts client connections from.
    server: CoIOListener,
    /// TLS acceptor, present only when `ssl` is enabled in the config;
    /// a clone of it is handed to every accepted client.
    tls_acceptor: Option<TlsAcceptor>,
}

impl Context {
    pub fn new(config: &Config, data_dir: &Path) -> Result<Self, Error> {
        assert!(config.enabled(), "must be checked before the call");

        let listen = config.listen();
        let host = listen.host.as_str();
        let port = listen.port.parse::<u16>().map_err(|_| {
            Error::invalid_configuration(format!("bad postgres port {}", listen.port))
        })?;

        let tls_acceptor = config
            .ssl()
            .then(|| TlsAcceptor::new_from_dir(data_dir))
            .transpose()
            .map_err(Error::invalid_configuration)?;

        let addr = (host, port);
        tlog!(Info, "starting postgres server at {:?}...", addr);
        let server = server::new_listener(addr)?;

        Ok(Self {
            server,
            tls_acceptor,
        })
    }
}

/// Starts the postgres server.
///
/// Builds the server [`Context`] from `config` and `data_dir`, then runs the
/// accept loop in a non-joinable fiber named `pgproto`.
///
/// # Errors
///
/// Returns an error if the context cannot be built or the fiber cannot be
/// started.
pub fn start(config: &Config, data_dir: PathBuf) -> Result<(), Error> {
    let context = Context::new(config, data_dir.as_path())?;
    let accept_loop = move || server_start(context);

    tarantool::fiber::Builder::new()
        .name("pgproto")
        .func(accept_loop)
        .start_non_joinable()?;

    Ok(())
}