From 6c037819bce9057c8c75c595bc369d81d4e7847f Mon Sep 17 00:00:00 2001
From: Kaitmazian Maksim <m.kaitmazian@picodata.io>
Date: Sun, 14 Jul 2024 19:52:12 +0300
Subject: [PATCH] pgproto: refactor portal states

---
 src/pgproto/backend/storage.rs | 31 ++++++++++++++++++-------------
 test/int/test_pgproto.py       |  2 +-
 2 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/src/pgproto/backend/storage.rs b/src/pgproto/backend/storage.rs
index a60f4ba814..fe86afcb33 100644
--- a/src/pgproto/backend/storage.rs
+++ b/src/pgproto/backend/storage.rs
@@ -387,9 +387,14 @@ pub struct Portal {
 #[derive(Debug, Default)]
 enum PortalState {
     #[default]
+    /// Portal has just been created.
     NotStarted,
-    Running(IntoIter<Vec<PgValue>>),
-    Finished(Option<ExecuteResult>),
+    /// Portal has been executed and contains rows to be sent in batches.
+    StreamingRows(IntoIter<Vec<PgValue>>),
+    /// Portal has been executed and contains a result ready to be sent.
+    ResultReady(ExecuteResult),
+    /// Portal has been executed, and a result has been sent.
+    Done,
 }
 
 /// Get rows from dql-like(dql or explain) query execution result.
@@ -451,19 +456,19 @@ impl Portal {
         loop {
             match &mut self.state {
                 PortalState::NotStarted => self.start()?,
-                PortalState::Finished(Some(_)) => {
-                    let state = std::mem::replace(&mut self.state, PortalState::Finished(None));
+                PortalState::ResultReady(_) => {
+                    let state = std::mem::replace(&mut self.state, PortalState::Done);
                     match state {
-                        PortalState::Finished(Some(result)) => return Ok(result),
+                        PortalState::ResultReady(result) => return Ok(result),
                         _ => unreachable!(),
                     }
                 }
-                PortalState::Running(ref mut stored_rows) => {
+                PortalState::StreamingRows(ref mut stored_rows) => {
                     let taken: Vec<_> = stored_rows.take(max_rows).collect();
                     let row_count = taken.len();
                     let rows = Rows::new(taken, self.describe.row_info());
                     if stored_rows.len() == 0 {
-                        self.state = PortalState::Finished(None);
+                        self.state = PortalState::Done;
                         return Ok(ExecuteResult::FinishedDql {
                             rows,
                             tag: self.describe.command_tag(),
@@ -496,14 +501,14 @@ impl Portal {
         );
         let tuple = dispatch(query)?;
         self.state = match self.describe().query_type() {
-            QueryType::Dml => {
-                let row_count = get_row_count_from_tuple(&tuple)?;
+            QueryType::Acl | QueryType::Ddl => {
                 let tag = self.describe().command_tag();
-                PortalState::Finished(Some(ExecuteResult::Dml { row_count, tag }))
+                PortalState::ResultReady(ExecuteResult::AclOrDdl { tag })
             }
-            QueryType::Acl | QueryType::Ddl => {
+            QueryType::Dml => {
+                let row_count = get_row_count_from_tuple(&tuple)?;
                 let tag = self.describe().command_tag();
-                PortalState::Finished(Some(ExecuteResult::AclOrDdl { tag }))
+                PortalState::ResultReady(ExecuteResult::Dml { row_count, tag })
             }
             QueryType::Dql | QueryType::Explain => {
                 let mp_rows = get_rows_from_tuple(&tuple)?;
@@ -512,7 +517,7 @@ impl Portal {
                     .into_iter()
                     .map(|row| mp_row_into_pg_row(row, metadata))
                     .collect::<PgResult<Vec<Vec<_>>>>()?;
-                PortalState::Running(pg_rows.into_iter())
+                PortalState::StreamingRows(pg_rows.into_iter())
             }
         };
         Ok(())
diff --git a/test/int/test_pgproto.py b/test/int/test_pgproto.py
index ef36e11e42..eaa7c5ced5 100644
--- a/test/int/test_pgproto.py
+++ b/test/int/test_pgproto.py
@@ -413,7 +413,7 @@ def test_interactive_portals(pg_client: PgClient):
     assert [1, "kek"] in data["rows"] or [2, "lol"] in data["rows"]
     assert data["is_finished"] is True
 
-    with pytest.raises(ReturnError, match="Can't execute portal in state Finished"):
+    with pytest.raises(ReturnError, match="Can't execute portal in state Done"):
         data = pg_client.execute("", 1)
 
     sql = """ explain select * from "t" """
-- 
GitLab