Skip to content
Snippets Groups Projects
Commit 13310ca4 authored by Maksim Kaitmazian's avatar Maksim Kaitmazian Committed by Maksim Kaitmazian
Browse files

feat: handle SslRequest

parent 121144f4
No related branches found
No related tags found
1 merge request!920pgproto module
......@@ -702,9 +702,8 @@ dependencies = [
[[package]]
name = "pgwire"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "593c5af58c6394873b84c6fabf31f97e49ab29a56809e7fd240c1bcc4e5d272f"
version = "0.16.1"
source = "git+https://github.com/sunng87/pgwire?rev=6e653e851703bb743a96ba7a8e0f1fde6a5abeb0#6e653e851703bb743a96ba7a8e0f1fde6a5abeb0"
dependencies = [
"async-trait",
"base64 0.21.2",
......@@ -718,7 +717,7 @@ dependencies = [
"md5",
"postgres-types",
"rand",
"ring",
"ring 0.17.3",
"stringprep",
"thiserror",
"time 0.3.17",
......@@ -881,12 +880,26 @@ dependencies = [
"cc",
"libc",
"once_cell",
"spin",
"untrusted",
"spin 0.5.2",
"untrusted 0.7.1",
"web-sys",
"winapi",
]
[[package]]
name = "ring"
version = "0.17.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9babe80d5c16becf6594aa32ad2be8fe08498e7ae60b77de8df700e67f191d7e"
dependencies = [
"cc",
"getrandom",
"libc",
"spin 0.9.8",
"untrusted 0.9.0",
"windows-sys",
]
[[package]]
name = "rmp"
version = "0.8.12"
......@@ -934,7 +947,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d1feddffcfcc0b33f5c6ce9a29e341e4cd59c3f78e7ee45f4a40c038b1d6cbb"
dependencies = [
"log",
"ring",
"ring 0.16.20",
"rustls-webpki",
"sct",
]
......@@ -945,8 +958,8 @@ version = "0.101.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "261e9e0888cba427c3316e6322805653c9425240b6fd96cee7cb671ab70ab8d0"
dependencies = [
"ring",
"untrusted",
"ring 0.16.20",
"untrusted 0.7.1",
]
[[package]]
......@@ -961,8 +974,8 @@ version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4"
dependencies = [
"ring",
"untrusted",
"ring 0.16.20",
"untrusted 0.7.1",
]
[[package]]
......@@ -1077,6 +1090,12 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
[[package]]
name = "spki"
version = "0.7.2"
......@@ -1379,6 +1398,12 @@ version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
[[package]]
name = "untrusted"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "uuid"
version = "0.8.2"
......@@ -1582,7 +1607,7 @@ dependencies = [
"der",
"hex",
"pem",
"ring",
"ring 0.16.20",
"signature",
"spki",
"thiserror",
......
......@@ -12,7 +12,7 @@ bytes = "1.4.0"
ctor = "0.2.4"
log = "0.4.20"
once_cell = "1.18.0"
pgwire = "0.16.0"
pgwire = { git = "https://github.com/sunng87/pgwire", rev = "6e653e851703bb743a96ba7a8e0f1fde6a5abeb0" }
rand = "0.8.5"
rmp = "0.8.12"
rmp-serde = "1.0.0"
......
use pgwire::messages::startup::Startup;
use crate::error::{PgError, PgResult};
use crate::messages;
use crate::stream::{FeMessage, PgStream};
use std::{collections::BTreeMap, io};
......@@ -7,15 +10,7 @@ pub struct ClientParams {
pub rest: BTreeMap<String, String>,
}
/// Read startup message, verify parameters and return them.
pub fn handshake(stream: &mut PgStream<impl io::Read>) -> PgResult<ClientParams> {
let message = stream.read_message()?;
let FeMessage::Startup(mut startup) = message else {
return Err(PgError::ProtocolViolation(format!(
"expected Startup, got {message:?}"
)));
};
fn parse_startup(mut startup: Startup) -> PgResult<ClientParams> {
let mut parameters = std::mem::take(startup.parameters_mut());
log::debug!("client parameters: {parameters:?}");
......@@ -32,3 +27,30 @@ pub fn handshake(stream: &mut PgStream<impl io::Read>) -> PgResult<ClientParams>
rest: parameters,
})
}
/// Respond to SslRequest if you receive it, read startup message, verify parameters and return them.
pub fn handshake(stream: &mut PgStream<impl io::Read + io::Write>) -> PgResult<ClientParams> {
let mut expect_startup = false;
loop {
let message = stream.read_message()?;
// At the beginning we can get SslRequest or Startup.
match message {
FeMessage::Startup(startup) => return parse_startup(startup),
FeMessage::SslRequest(_) => {
if expect_startup {
return Err(PgError::ProtocolViolation(format!(
"expected Startup, got {message:?}"
)));
} else {
stream.write_message(messages::ssl_refuse())?;
// After SslRequest, only Startup is expected.
expect_startup = true;
}
}
_ => {
return Err(PgError::ProtocolViolation(format!(
"expected Startup or SslRequest, got {message:?}"
)))
}
}
}
}
use crate::stream::BeMessage;
use pgwire::error::ErrorInfo;
use pgwire::messages::data::{DataRow, RowDescription};
use pgwire::messages::response::SslResponse;
use pgwire::messages::{response, startup::*};
/// MD5AuthRequest requests md5 password from the frontend.
......@@ -42,3 +43,7 @@ pub fn row_description(row_description: RowDescription) -> BeMessage {
pub fn data_row(data_row: DataRow) -> BeMessage {
BeMessage::DataRow(data_row)
}
pub fn ssl_refuse() -> BeMessage {
BeMessage::SslResponse(SslResponse::Refuse)
}
use crate::error::PgResult;
use bytes::{BufMut, BytesMut};
use pgwire::messages::startup::SslRequest;
use std::io::{self, ErrorKind::UnexpectedEof};
// Public re-exports.
......@@ -56,6 +57,11 @@ impl<S: io::Read> PgStream<S> {
return FeMessage::decode(&mut self.ibuf).map_err(|e| e.into());
}
// Try to decode SslRequest first, as it fits the Startup format with an invalid version.
if let Some(ssl_request) = SslRequest::decode(&mut self.ibuf)? {
return Ok(Some(FeMessage::SslRequest(ssl_request)));
}
// This is done once at connection startup.
let startup = Startup::decode(&mut self.ibuf)?.map(|x| {
log::debug!("received StartupPacket from client");
......
import pytest
import pg8000.dbapi as pg # type: ignore
from conftest import Postgres
import os
def test_ssl_request_handling(postgres: Postgres):
host = "127.0.0.1"
port = 5432
postgres.start(host, port)
i1 = postgres.instance
user = "user"
password = "password"
i1.eval("box.cfg{auth_type='md5'}")
i1.call("pico.create_user", user, password, dict(timeout=3))
# disable: only try a non-SSL connection
os.environ["PGSSLMODE"] = "disable"
conn = pg.Connection(user, password=password, host=host, port=port)
conn.close()
# prefer: first try an SSL connection; if that fails,
# try a non-SSL connection.
# As ssl is not supported, server will respond to SslRequest with
# SslRefuse and client will try a non-SSL connection.
os.environ["PGSSLMODE"] = "prefer"
conn = pg.Connection(user, password=password, host=host, port=port)
conn.close()
# require: only try an SSL connection.
# As ssl is not supported, server will respond to SslRequest with
# SslRefuse and client won't try to connect again.
# Client we will see: server does not support SSL, but SSL was required,
# but client doesn't have to inform the server.
os.environ["PGSSLMODE"] = "require"
with pytest.raises(
pg.DatabaseError, match=f"authentication failed for user '{user}'"
):
pg.Connection(user, password="wrong password", host=host, port=port)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment