From 13310ca4cd5ddf96f1e9f209ebe6294bd43601cd Mon Sep 17 00:00:00 2001
From: Kaitmazian Maksim <m.kaitmazian@picodata.io>
Date: Thu, 9 Nov 2023 19:36:25 +0300
Subject: [PATCH] feat: handle SslRequest

---
 pgproto/Cargo.lock            | 49 ++++++++++++++++++++++++++---------
 pgproto/Cargo.toml            |  2 +-
 pgproto/src/client/startup.rs | 40 +++++++++++++++++++++-------
 pgproto/src/messages.rs       |  5 ++++
 pgproto/src/stream.rs         |  6 +++++
 pgproto/test/ssl_test.py      | 41 +++++++++++++++++++++++++++++
 6 files changed, 121 insertions(+), 22 deletions(-)
 create mode 100644 pgproto/test/ssl_test.py

diff --git a/pgproto/Cargo.lock b/pgproto/Cargo.lock
index 193214500d..bbbeb61ab5 100644
--- a/pgproto/Cargo.lock
+++ b/pgproto/Cargo.lock
@@ -702,9 +702,8 @@ dependencies = [
 
 [[package]]
 name = "pgwire"
-version = "0.16.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "593c5af58c6394873b84c6fabf31f97e49ab29a56809e7fd240c1bcc4e5d272f"
+version = "0.16.1"
+source = "git+https://github.com/sunng87/pgwire?rev=6e653e851703bb743a96ba7a8e0f1fde6a5abeb0#6e653e851703bb743a96ba7a8e0f1fde6a5abeb0"
 dependencies = [
  "async-trait",
  "base64 0.21.2",
@@ -718,7 +717,7 @@ dependencies = [
  "md5",
  "postgres-types",
  "rand",
- "ring",
+ "ring 0.17.3",
  "stringprep",
  "thiserror",
  "time 0.3.17",
@@ -881,12 +880,26 @@ dependencies = [
  "cc",
  "libc",
  "once_cell",
- "spin",
- "untrusted",
+ "spin 0.5.2",
+ "untrusted 0.7.1",
  "web-sys",
  "winapi",
 ]
 
+[[package]]
+name = "ring"
+version = "0.17.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9babe80d5c16becf6594aa32ad2be8fe08498e7ae60b77de8df700e67f191d7e"
+dependencies = [
+ "cc",
+ "getrandom",
+ "libc",
+ "spin 0.9.8",
+ "untrusted 0.9.0",
+ "windows-sys",
+]
+
 [[package]]
 name = "rmp"
 version = "0.8.12"
@@ -934,7 +947,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "1d1feddffcfcc0b33f5c6ce9a29e341e4cd59c3f78e7ee45f4a40c038b1d6cbb"
 dependencies = [
  "log",
- "ring",
+ "ring 0.16.20",
  "rustls-webpki",
  "sct",
 ]
@@ -945,8 +958,8 @@ version = "0.101.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "261e9e0888cba427c3316e6322805653c9425240b6fd96cee7cb671ab70ab8d0"
 dependencies = [
- "ring",
- "untrusted",
+ "ring 0.16.20",
+ "untrusted 0.7.1",
 ]
 
 [[package]]
@@ -961,8 +974,8 @@ version = "0.7.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4"
 dependencies = [
- "ring",
- "untrusted",
+ "ring 0.16.20",
+ "untrusted 0.7.1",
 ]
 
 [[package]]
@@ -1077,6 +1090,12 @@ version = "0.5.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
 
+[[package]]
+name = "spin"
+version = "0.9.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
+
 [[package]]
 name = "spki"
 version = "0.7.2"
@@ -1379,6 +1398,12 @@ version = "0.7.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
 
+[[package]]
+name = "untrusted"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
+
 [[package]]
 name = "uuid"
 version = "0.8.2"
@@ -1582,7 +1607,7 @@ dependencies = [
  "der",
  "hex",
  "pem",
- "ring",
+ "ring 0.16.20",
  "signature",
  "spki",
  "thiserror",
diff --git a/pgproto/Cargo.toml b/pgproto/Cargo.toml
index 4a4bc3e20b..61621e4417 100644
--- a/pgproto/Cargo.toml
+++ b/pgproto/Cargo.toml
@@ -12,7 +12,7 @@ bytes = "1.4.0"
 ctor = "0.2.4"
 log = "0.4.20"
 once_cell = "1.18.0"
-pgwire = "0.16.0"
+pgwire = { git = "https://github.com/sunng87/pgwire", rev = "6e653e851703bb743a96ba7a8e0f1fde6a5abeb0" }
 rand = "0.8.5"
 rmp = "0.8.12"
 rmp-serde = "1.0.0"
diff --git a/pgproto/src/client/startup.rs b/pgproto/src/client/startup.rs
index bf432e59fd..038807a0e7 100644
--- a/pgproto/src/client/startup.rs
+++ b/pgproto/src/client/startup.rs
@@ -1,4 +1,7 @@
+use pgwire::messages::startup::Startup;
+
 use crate::error::{PgError, PgResult};
+use crate::messages;
 use crate::stream::{FeMessage, PgStream};
 use std::{collections::BTreeMap, io};
 
@@ -7,15 +10,7 @@ pub struct ClientParams {
     pub rest: BTreeMap<String, String>,
 }
 
-/// Read startup message, verify parameters and return them.
-pub fn handshake(stream: &mut PgStream<impl io::Read>) -> PgResult<ClientParams> {
-    let message = stream.read_message()?;
-    let FeMessage::Startup(mut startup) = message else {
-        return Err(PgError::ProtocolViolation(format!(
-            "expected Startup, got {message:?}"
-        )));
-    };
-
+fn parse_startup(mut startup: Startup) -> PgResult<ClientParams> {
     let mut parameters = std::mem::take(startup.parameters_mut());
     log::debug!("client parameters: {parameters:?}");
 
@@ -32,3 +27,30 @@ pub fn handshake(stream: &mut PgStream<impl io::Read>) -> PgResult<ClientParams>
         rest: parameters,
     })
 }
+/// Respond to SslRequest if you receive it, read startup message, verify parameters and return them.
+pub fn handshake(stream: &mut PgStream<impl io::Read + io::Write>) -> PgResult<ClientParams> {
+    let mut expect_startup = false;
+    loop {
+        let message = stream.read_message()?;
+        // At the beginning we can get SslRequest or Startup.
+        match message {
+            FeMessage::Startup(startup) => return parse_startup(startup),
+            FeMessage::SslRequest(_) => {
+                if expect_startup {
+                    return Err(PgError::ProtocolViolation(format!(
+                        "expected Startup, got {message:?}"
+                    )));
+                } else {
+                    stream.write_message(messages::ssl_refuse())?;
+                    // After SslRequest, only Startup is expected.
+                    expect_startup = true;
+                }
+            }
+            _ => {
+                return Err(PgError::ProtocolViolation(format!(
+                    "expected Startup or SslRequest, got {message:?}"
+                )))
+            }
+        }
+    }
+}
diff --git a/pgproto/src/messages.rs b/pgproto/src/messages.rs
index 7a2c071e39..86e50d8bd2 100644
--- a/pgproto/src/messages.rs
+++ b/pgproto/src/messages.rs
@@ -1,6 +1,7 @@
 use crate::stream::BeMessage;
 use pgwire::error::ErrorInfo;
 use pgwire::messages::data::{DataRow, RowDescription};
+use pgwire::messages::response::SslResponse;
 use pgwire::messages::{response, startup::*};
 
 /// MD5AuthRequest requests md5 password from the frontend.
@@ -42,3 +43,7 @@ pub fn row_description(row_description: RowDescription) -> BeMessage {
 pub fn data_row(data_row: DataRow) -> BeMessage {
     BeMessage::DataRow(data_row)
 }
+
+pub fn ssl_refuse() -> BeMessage {
+    BeMessage::SslResponse(SslResponse::Refuse)
+}
diff --git a/pgproto/src/stream.rs b/pgproto/src/stream.rs
index 94fd066511..74a37e7e55 100644
--- a/pgproto/src/stream.rs
+++ b/pgproto/src/stream.rs
@@ -1,5 +1,6 @@
 use crate::error::PgResult;
 use bytes::{BufMut, BytesMut};
+use pgwire::messages::startup::SslRequest;
 use std::io::{self, ErrorKind::UnexpectedEof};
 
 // Public re-exports.
@@ -56,6 +57,11 @@ impl<S: io::Read> PgStream<S> {
             return FeMessage::decode(&mut self.ibuf).map_err(|e| e.into());
         }
 
+        // Try to decode SslRequest first, as it fits the Startup format with an invalid version.
+        if let Some(ssl_request) = SslRequest::decode(&mut self.ibuf)? {
+            return Ok(Some(FeMessage::SslRequest(ssl_request)));
+        }
+
         // This is done once at connection startup.
         let startup = Startup::decode(&mut self.ibuf)?.map(|x| {
             log::debug!("received StartupPacket from client");
diff --git a/pgproto/test/ssl_test.py b/pgproto/test/ssl_test.py
new file mode 100644
index 0000000000..eeb963ea9f
--- /dev/null
+++ b/pgproto/test/ssl_test.py
@@ -0,0 +1,41 @@
+import pytest
+import pg8000.dbapi as pg  # type: ignore
+from conftest import Postgres
+import os
+
+
+def test_ssl_request_handling(postgres: Postgres):
+    host = "127.0.0.1"
+    port = 5432
+
+    postgres.start(host, port)
+    i1 = postgres.instance
+
+    user = "user"
+    password = "password"
+    i1.eval("box.cfg{auth_type='md5'}")
+    i1.call("pico.create_user", user, password, dict(timeout=3))
+
+    # disable: only try a non-SSL connection
+    os.environ["PGSSLMODE"] = "disable"
+    conn = pg.Connection(user, password=password, host=host, port=port)
+    conn.close()
+
+    # prefer: first try an SSL connection; if that fails,
+    #         try a non-SSL connection.
+    # As ssl is not supported, server will respond to SslRequest with
+    # SslRefuse and client will try a non-SSL connection.
+    os.environ["PGSSLMODE"] = "prefer"
+    conn = pg.Connection(user, password=password, host=host, port=port)
+    conn.close()
+
+    # require: only try an SSL connection.
+    # As ssl is not supported, server will respond to SslRequest with
+    # SslRefuse and client won't try to connect again.
+    # Client we will see: server does not support SSL, but SSL was required,
+    # but client doesn't have to inform the server.
+    os.environ["PGSSLMODE"] = "require"
+    with pytest.raises(
+        pg.DatabaseError, match=f"authentication failed for user '{user}'"
+    ):
+        pg.Connection(user, password="wrong password", host=host, port=port)
-- 
GitLab