From 13310ca4cd5ddf96f1e9f209ebe6294bd43601cd Mon Sep 17 00:00:00 2001 From: Kaitmazian Maksim <m.kaitmazian@picodata.io> Date: Thu, 9 Nov 2023 19:36:25 +0300 Subject: [PATCH] feat: handle SslRequest --- pgproto/Cargo.lock | 49 ++++++++++++++++++++++++++--------- pgproto/Cargo.toml | 2 +- pgproto/src/client/startup.rs | 40 +++++++++++++++++++++------- pgproto/src/messages.rs | 5 ++++ pgproto/src/stream.rs | 6 +++++ pgproto/test/ssl_test.py | 41 +++++++++++++++++++++++++++++ 6 files changed, 121 insertions(+), 22 deletions(-) create mode 100644 pgproto/test/ssl_test.py diff --git a/pgproto/Cargo.lock b/pgproto/Cargo.lock index 193214500d..bbbeb61ab5 100644 --- a/pgproto/Cargo.lock +++ b/pgproto/Cargo.lock @@ -702,9 +702,8 @@ dependencies = [ [[package]] name = "pgwire" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "593c5af58c6394873b84c6fabf31f97e49ab29a56809e7fd240c1bcc4e5d272f" +version = "0.16.1" +source = "git+https://github.com/sunng87/pgwire?rev=6e653e851703bb743a96ba7a8e0f1fde6a5abeb0#6e653e851703bb743a96ba7a8e0f1fde6a5abeb0" dependencies = [ "async-trait", "base64 0.21.2", @@ -718,7 +717,7 @@ dependencies = [ "md5", "postgres-types", "rand", - "ring", + "ring 0.17.3", "stringprep", "thiserror", "time 0.3.17", @@ -881,12 +880,26 @@ dependencies = [ "cc", "libc", "once_cell", - "spin", - "untrusted", + "spin 0.5.2", + "untrusted 0.7.1", "web-sys", "winapi", ] +[[package]] +name = "ring" +version = "0.17.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9babe80d5c16becf6594aa32ad2be8fe08498e7ae60b77de8df700e67f191d7e" +dependencies = [ + "cc", + "getrandom", + "libc", + "spin 0.9.8", + "untrusted 0.9.0", + "windows-sys", +] + [[package]] name = "rmp" version = "0.8.12" @@ -934,7 +947,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d1feddffcfcc0b33f5c6ce9a29e341e4cd59c3f78e7ee45f4a40c038b1d6cbb" dependencies = [ "log", - "ring", + "ring 0.16.20", "rustls-webpki", "sct", ] @@ -945,8 +958,8 @@ version = "0.101.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "261e9e0888cba427c3316e6322805653c9425240b6fd96cee7cb671ab70ab8d0" dependencies = [ - "ring", - "untrusted", + "ring 0.16.20", + "untrusted 0.7.1", ] [[package]] @@ -961,8 +974,8 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" dependencies = [ - "ring", - "untrusted", + "ring 0.16.20", + "untrusted 0.7.1", ] [[package]] @@ -1077,6 +1090,12 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "spki" version = "0.7.2" @@ -1379,6 +1398,12 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "uuid" version = "0.8.2" @@ -1582,7 +1607,7 @@ dependencies = [ "der", "hex", "pem", - "ring", + "ring 0.16.20", "signature", "spki", "thiserror", diff --git a/pgproto/Cargo.toml b/pgproto/Cargo.toml index 4a4bc3e20b..61621e4417 100644 --- a/pgproto/Cargo.toml +++ b/pgproto/Cargo.toml @@ -12,7 +12,7 @@ bytes = "1.4.0" ctor = "0.2.4" log = "0.4.20" once_cell = "1.18.0" -pgwire = "0.16.0" +pgwire = { git = "https://github.com/sunng87/pgwire", rev = "6e653e851703bb743a96ba7a8e0f1fde6a5abeb0" } rand = "0.8.5" rmp = "0.8.12" rmp-serde = "1.0.0" diff --git a/pgproto/src/client/startup.rs b/pgproto/src/client/startup.rs index bf432e59fd..038807a0e7 100644 --- a/pgproto/src/client/startup.rs +++ b/pgproto/src/client/startup.rs @@ -1,4 +1,7 @@ +use pgwire::messages::startup::Startup; + use crate::error::{PgError, PgResult}; +use crate::messages; use crate::stream::{FeMessage, PgStream}; use std::{collections::BTreeMap, io}; @@ -7,15 +10,7 @@ pub struct ClientParams { pub rest: BTreeMap<String, String>, } -/// Read startup message, verify parameters and return them. -pub fn handshake(stream: &mut PgStream<impl io::Read>) -> PgResult<ClientParams> { - let message = stream.read_message()?; - let FeMessage::Startup(mut startup) = message else { - return Err(PgError::ProtocolViolation(format!( - "expected Startup, got {message:?}" - ))); - }; - +fn parse_startup(mut startup: Startup) -> PgResult<ClientParams> { let mut parameters = std::mem::take(startup.parameters_mut()); log::debug!("client parameters: {parameters:?}"); @@ -32,3 +27,30 @@ pub fn handshake(stream: &mut PgStream<impl io::Read>) -> PgResult<ClientParams> rest: parameters, }) } +/// Respond to SslRequest if you receive it, read startup message, verify parameters and return them. +pub fn handshake(stream: &mut PgStream<impl io::Read + io::Write>) -> PgResult<ClientParams> { + let mut expect_startup = false; + loop { + let message = stream.read_message()?; + // At the beginning we can get SslRequest or Startup. + match message { + FeMessage::Startup(startup) => return parse_startup(startup), + FeMessage::SslRequest(_) => { + if expect_startup { + return Err(PgError::ProtocolViolation(format!( + "expected Startup, got {message:?}" + ))); + } else { + stream.write_message(messages::ssl_refuse())?; + // After SslRequest, only Startup is expected. + expect_startup = true; + } + } + _ => { + return Err(PgError::ProtocolViolation(format!( + "expected Startup or SslRequest, got {message:?}" + ))) + } + } + } +} diff --git a/pgproto/src/messages.rs b/pgproto/src/messages.rs index 7a2c071e39..86e50d8bd2 100644 --- a/pgproto/src/messages.rs +++ b/pgproto/src/messages.rs @@ -1,6 +1,7 @@ use crate::stream::BeMessage; use pgwire::error::ErrorInfo; use pgwire::messages::data::{DataRow, RowDescription}; +use pgwire::messages::response::SslResponse; use pgwire::messages::{response, startup::*}; /// MD5AuthRequest requests md5 password from the frontend. @@ -42,3 +43,7 @@ pub fn row_description(row_description: RowDescription) -> BeMessage { pub fn data_row(data_row: DataRow) -> BeMessage { BeMessage::DataRow(data_row) } + +pub fn ssl_refuse() -> BeMessage { + BeMessage::SslResponse(SslResponse::Refuse) +} diff --git a/pgproto/src/stream.rs b/pgproto/src/stream.rs index 94fd066511..74a37e7e55 100644 --- a/pgproto/src/stream.rs +++ b/pgproto/src/stream.rs @@ -1,5 +1,6 @@ use crate::error::PgResult; use bytes::{BufMut, BytesMut}; +use pgwire::messages::startup::SslRequest; use std::io::{self, ErrorKind::UnexpectedEof}; // Public re-exports. @@ -56,6 +57,11 @@ impl<S: io::Read> PgStream<S> { return FeMessage::decode(&mut self.ibuf).map_err(|e| e.into()); } + // Try to decode SslRequest first, as it fits the Startup format with an invalid version. + if let Some(ssl_request) = SslRequest::decode(&mut self.ibuf)? { + return Ok(Some(FeMessage::SslRequest(ssl_request))); + } + // This is done once at connection startup. let startup = Startup::decode(&mut self.ibuf)?.map(|x| { log::debug!("received StartupPacket from client"); diff --git a/pgproto/test/ssl_test.py b/pgproto/test/ssl_test.py new file mode 100644 index 0000000000..eeb963ea9f --- /dev/null +++ b/pgproto/test/ssl_test.py @@ -0,0 +1,41 @@ +import pytest +import pg8000.dbapi as pg # type: ignore +from conftest import Postgres +import os + + +def test_ssl_request_handling(postgres: Postgres): + host = "127.0.0.1" + port = 5432 + + postgres.start(host, port) + i1 = postgres.instance + + user = "user" + password = "password" + i1.eval("box.cfg{auth_type='md5'}") + i1.call("pico.create_user", user, password, dict(timeout=3)) + + # disable: only try a non-SSL connection + os.environ["PGSSLMODE"] = "disable" + conn = pg.Connection(user, password=password, host=host, port=port) + conn.close() + + # prefer: first try an SSL connection; if that fails, + # try a non-SSL connection. + # As ssl is not supported, server will respond to SslRequest with + # SslRefuse and client will try a non-SSL connection. + os.environ["PGSSLMODE"] = "prefer" + conn = pg.Connection(user, password=password, host=host, port=port) + conn.close() + + # require: only try an SSL connection. + # As ssl is not supported, server will respond to SslRequest with + # SslRefuse and client won't try to connect again. + # Client we will see: server does not support SSL, but SSL was required, + # but client doesn't have to inform the server. + os.environ["PGSSLMODE"] = "require" + with pytest.raises( + pg.DatabaseError, match=f"authentication failed for user '{user}'" + ): + pg.Connection(user, password="wrong password", host=host, port=port) -- GitLab