Skip to content
Snippets Groups Projects
Commit ba99dd2e authored by Dmitry Ivanov's avatar Dmitry Ivanov
Browse files

chores(pgproto): simplify PgError

parent 0e246bd7
No related branches found
No related tags found
1 merge request!999chores(pgproto): simplify error handling
use self::{
client::PgClient,
error::PgResult,
tls::{TlsAcceptor, TlsConfig, TlsError},
};
use self::{client::PgClient, error::PgResult, tls::TlsAcceptor};
use crate::{address::Address, introspection::Introspection, tlog, traft::error::Error};
use std::path::{Path, PathBuf};
use stream::PgStream;
......@@ -94,11 +90,6 @@ fn do_handle_client(
Ok(())
}
fn new_tls_acceptor(data_dir: &Path) -> Result<TlsAcceptor, TlsError> {
let tls_config = TlsConfig::from_data_dir(data_dir)?;
TlsAcceptor::new(&tls_config)
}
/// Server execution context.
pub struct Context {
server: CoIOListener,
......@@ -117,7 +108,7 @@ impl Context {
let tls_acceptor = config
.ssl()
.then(|| new_tls_acceptor(data_dir))
.then(|| TlsAcceptor::new_from_dir(data_dir))
.transpose()
.map_err(Error::invalid_configuration)?;
......
use super::tls::TlsError;
use pgwire::error::{ErrorInfo, PgWireError};
use std::error;
use std::io;
use std::num::{ParseFloatError, ParseIntError};
use std::str::ParseBoolError;
use std::string::FromUtf8Error;
use tarantool::error::BoxError;
use tarantool::error::IntoBoxError;
use std::{error, io};
use tarantool::error::{BoxError, IntoBoxError};
use thiserror::Error;
pub type PgResult<T> = Result<T, PgError>;
......@@ -26,69 +20,63 @@ pub enum PgError {
#[error("authentication failed for user '{0}'")]
InvalidPassword(String),
#[error("IO error: {0}")]
IoError(#[from] io::Error),
// Server could not encode value into client's format.
#[error("encoding error: {0}")]
EncodingError(Box<dyn error::Error>),
#[error("pgwire error: {0}")]
PgWireError(#[from] PgWireError),
#[error("lua error: {0}")]
LuaError(#[from] tarantool::tlua::LuaError),
#[error("json error: {0}")]
JsonError(#[from] serde_json::Error),
// Server could not decode value recieved from client.
#[error("{0}")]
#[error("decoding error: {0}")]
DecodingError(#[from] DecodingError),
#[error("tls error: {0}")]
TlsError(#[from] TlsError),
#[error("sbroad error: {0}")]
SbroadError(#[from] sbroad::errors::SbroadError),
#[error("traft error: {0}")]
TraftError(Box<crate::traft::error::Error>),
// Common error for postges protocol helpers.
#[error("pgwire error: {0}")]
PgWireError(#[from] PgWireError),
#[error("tarantool error: {0}")]
TarantoolError(#[from] tarantool::error::Error),
// This is picodata's main app error which incapsulates
// everything else, including sbroad and tarantool errors.
#[error("picodata error: {0}")]
PicodataError(#[from] crate::traft::error::Error),
#[error("encoding error: {0}")]
RmpSerdeEncode(#[from] rmp_serde::encode::Error),
// Generic IO error (TLS/SSL errors also go here).
#[error("IO error: {0}")]
IoError(#[from] io::Error),
#[error("{0}")]
Other(Box<dyn error::Error>),
}
impl From<sbroad::errors::SbroadError> for PgError {
#[inline(always)]
fn from(e: sbroad::errors::SbroadError) -> Self {
crate::traft::error::Error::from(e).into()
}
}
impl From<tarantool::error::Error> for PgError {
#[inline(always)]
fn from(e: tarantool::error::Error) -> Self {
crate::traft::error::Error::from(e).into()
}
}
#[derive(Error, Debug)]
pub enum DecodingError {
#[error("failed to decode int: {0}")]
ParseIntError(#[from] ParseIntError),
ParseIntError(#[from] std::num::ParseIntError),
#[error("failed to decode float: {0}")]
ParseFloatError(#[from] ParseFloatError),
ParseFloatError(#[from] std::num::ParseFloatError),
#[error("from utf8 error: {0}")]
FromUtf8Error(#[from] FromUtf8Error),
FromUtf8Error(#[from] std::string::FromUtf8Error),
#[error("failed to decode bool: {0}")]
ParseBoolError(#[from] ParseBoolError),
ParseBoolError(#[from] std::str::ParseBoolError),
#[error("decoding error: {0}")]
Other(Box<dyn error::Error>),
}
impl From<crate::traft::error::Error> for PgError {
fn from(value: crate::traft::error::Error) -> Self {
PgError::TraftError(value.into())
}
}
/// Build error info from PgError.
impl PgError {
pub fn info(&self) -> ErrorInfo {
......
......@@ -150,26 +150,18 @@ impl<S: io::Read + io::Write> PgStream<S> {
impl<S: io::Read + io::Write> PgStream<S> {
pub fn into_secure(self, acceptor: &TlsAcceptor) -> PgResult<PgStream<S>> {
let Self {
socket,
ibuf,
obuf,
startup_processed,
} = self;
if let PgSocket::Plain(socket) = socket {
let secure_socket = acceptor.accept(socket)?;
Ok(PgStream {
socket: PgSocket::Secure(secure_socket),
ibuf,
obuf,
startup_processed,
})
} else {
// this should be checked during the handshake
Err(PgError::ProtocolViolation(
"BUG: stream is already secured".into(),
))
}
let PgSocket::Plain(socket) = self.socket else {
return Err(PgError::ProtocolViolation(
"BUG: cannot upgrade TLS stream".into(),
));
};
let secure_socket = acceptor.accept(socket).map_err(io::Error::other)?;
let stream = PgStream {
socket: PgSocket::Secure(secure_socket),
..self
};
Ok(stream)
}
}
......@@ -6,34 +6,26 @@ use std::{fs, io, path::Path, path::PathBuf, rc::Rc};
use thiserror::Error;
#[derive(Error, Debug)]
pub enum TlsError {
#[error("error stack: {0}")]
ErrorStack(#[from] ErrorStack),
pub enum TlsHandshakeError {
#[error("setup failure: {0}")]
SetupFailure(ErrorStack),
#[error("handshake failure")]
HandshakeFailure,
}
#[derive(Error, Debug)]
pub enum TlsConfigError {
#[error("error stack: {0}")]
ErrorStack(#[from] ErrorStack),
// A helper error that indicates that the error happened with a cert file.
#[error("cert file error '{0}': {1}")]
CertFile(PathBuf, std::io::Error),
// A helper error that indicates that the error happened with a key file.
#[error("key file error '{0}': {1}")]
KeyFile(PathBuf, std::io::Error),
}
impl<S> From<HandshakeError<S>> for TlsError {
fn from(value: HandshakeError<S>) -> Self {
match value {
HandshakeError::SetupFailure(stack) => TlsError::SetupFailure(stack),
_ => TlsError::HandshakeFailure,
}
}
}
#[derive(Debug)]
pub struct TlsConfig {
cert: PathBuf,
......@@ -41,14 +33,14 @@ pub struct TlsConfig {
}
impl TlsConfig {
pub fn from_data_dir(data_dir: &Path) -> Result<Self, TlsError> {
pub fn from_data_dir(data_dir: &Path) -> Result<Self, TlsConfigError> {
// We should use the absolute paths here, because SslContextBuilder::set_certificate_chain_file
// fails for relative paths with an unclear error, represented as an empty error stack.
let cert = data_dir.join("server.crt");
let cert = fs::canonicalize(&cert).map_err(|e| TlsError::CertFile(cert, e))?;
let cert = fs::canonicalize(&cert).map_err(|e| TlsConfigError::CertFile(cert, e))?;
let key = data_dir.join("server.key");
let key = fs::canonicalize(&key).map_err(|e| TlsError::KeyFile(key, e))?;
let key = fs::canonicalize(&key).map_err(|e| TlsConfigError::KeyFile(key, e))?;
// TODO: Make sure that the file permissions are set to 0640 or 0600.
// See https://www.postgresql.org/docs/current/ssl-tcp.html#SSL-SETUP for details.
......@@ -62,17 +54,25 @@ pub type TlsStream<S> = SslStream<S>;
pub struct TlsAcceptor(Rc<ssl::SslAcceptor>);
impl TlsAcceptor {
pub fn new(config: &TlsConfig) -> Result<Self, TlsError> {
pub fn new(config: &TlsConfig) -> Result<Self, TlsConfigError> {
let mut builder = ssl::SslAcceptor::mozilla_intermediate_v5(SslMethod::tls())?;
builder.set_certificate_chain_file(&config.cert)?;
builder.set_private_key_file(&config.key, SslFiletype::PEM)?;
Ok(Self(builder.build().into()))
}
pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, TlsError>
pub fn new_from_dir(data_dir: &Path) -> Result<Self, TlsConfigError> {
let tls_config = TlsConfig::from_data_dir(data_dir)?;
Self::new(&tls_config)
}
pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, TlsHandshakeError>
where
S: io::Read + io::Write,
{
self.0.accept(stream).map_err(Into::into)
self.0.accept(stream).map_err(|e| match e {
HandshakeError::SetupFailure(stack) => TlsHandshakeError::SetupFailure(stack),
_ => TlsHandshakeError::HandshakeFailure,
})
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment