Skip to content
Snippets Groups Projects

feat: implement mTLS for Pgproto

Merged Roman Kuzmin requested to merge kuzmin/pgproto-mtls into master
8 files
+ 190
17
Compare changes
  • Side-by-side
  • Inline
Files
8
+ 37
1
use openssl::ssl::SslVerifyMode;
use openssl::x509::store::X509StoreBuilder;
use openssl::x509::X509;
use openssl::{
error::ErrorStack,
ssl::{self, HandshakeError, SslFiletype, SslMethod, SslStream},
};
use std::{fs, io, path::Path, path::PathBuf, rc::Rc};
use thiserror::Error;
@@ -24,12 +28,17 @@ pub enum TlsConfigError {
#[error("key file error '{0}': {1}")]
KeyFile(PathBuf, std::io::Error),
#[error("ca file error '{0}': {1}")]
CaFile(PathBuf, std::io::Error),
}
#[derive(Debug)]
pub struct TlsConfig {
cert: PathBuf,
key: PathBuf,
// Optional CA certificate for peer certificates verification (mTLS).
ca_cert: Option<PathBuf>,
}
impl TlsConfig {
@@ -42,9 +51,16 @@ impl TlsConfig {
let key = data_dir.join("server.key");
let key = fs::canonicalize(&key).map_err(|e| TlsConfigError::KeyFile(key, e))?;
let ca_cert = data_dir.join("ca.crt");
let ca_cert = match fs::canonicalize(&ca_cert) {
Ok(path) => Some(path),
Err(e) if e.kind() == io::ErrorKind::NotFound => None,
Err(e) => Err(TlsConfigError::CaFile(ca_cert, e))?,
};
// TODO: Make sure that the file permissions are set to 0640 or 0600.
// See https://www.postgresql.org/docs/current/ssl-tcp.html#SSL-SETUP for details.
Ok(Self { key, cert })
Ok(Self { key, cert, ca_cert })
}
}
@@ -58,6 +74,18 @@ impl TlsAcceptor {
let mut builder = ssl::SslAcceptor::mozilla_intermediate_v5(SslMethod::tls())?;
builder.set_certificate_chain_file(&config.cert)?;
builder.set_private_key_file(&config.key, SslFiletype::PEM)?;
if let Some(path) = &config.ca_cert {
let pem = fs::read(path).map_err(|e| TlsConfigError::CaFile(path.clone(), e))?;
let mut store_builder = X509StoreBuilder::new()?;
store_builder.add_cert(X509::from_pem(&pem)?)?;
builder.set_verify_cert_store(store_builder.build())?;
let mut verify_mode = SslVerifyMode::PEER;
verify_mode.insert(SslVerifyMode::FAIL_IF_NO_PEER_CERT);
builder.set_verify(verify_mode);
}
Ok(Self(builder.build().into()))
}
@@ -75,4 +103,12 @@ impl TlsAcceptor {
_ => TlsHandshakeError::HandshakeFailure,
})
}
pub fn kind(&self) -> &'static str {
if self.0.context().verify_mode().contains(SslVerifyMode::PEER) {
"mTLS"
} else {
"TLS"
}
}
}
Loading