From 6c037819bce9057c8c75c595bc369d81d4e7847f Mon Sep 17 00:00:00 2001 From: Kaitmazian Maksim <m.kaitmazian@picodata.io> Date: Sun, 14 Jul 2024 19:52:12 +0300 Subject: [PATCH] pgproto: refactor portal states --- src/pgproto/backend/storage.rs | 31 ++++++++++++++++++------------- test/int/test_pgproto.py | 2 +- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/src/pgproto/backend/storage.rs b/src/pgproto/backend/storage.rs index a60f4ba814..fe86afcb33 100644 --- a/src/pgproto/backend/storage.rs +++ b/src/pgproto/backend/storage.rs @@ -387,9 +387,14 @@ pub struct Portal { #[derive(Debug, Default)] enum PortalState { #[default] + /// Portal has just been created. NotStarted, - Running(IntoIter<Vec<PgValue>>), - Finished(Option<ExecuteResult>), + /// Portal has been executed and contains rows to be sent in batches. + StreamingRows(IntoIter<Vec<PgValue>>), + /// Portal has been executed and contains a result ready to be sent. + ResultReady(ExecuteResult), + /// Portal has been executed, and a result has been sent. + Done, } /// Get rows from dql-like(dql or explain) query execution result. @@ -451,19 +456,19 @@ impl Portal { loop { match &mut self.state { PortalState::NotStarted => self.start()?, - PortalState::Finished(Some(_)) => { - let state = std::mem::replace(&mut self.state, PortalState::Finished(None)); + PortalState::ResultReady(_) => { + let state = std::mem::replace(&mut self.state, PortalState::Done); match state { - PortalState::Finished(Some(result)) => return Ok(result), + PortalState::ResultReady(result) => return Ok(result), _ => unreachable!(), } } - PortalState::Running(ref mut stored_rows) => { + PortalState::StreamingRows(ref mut stored_rows) => { let taken: Vec<_> = stored_rows.take(max_rows).collect(); let row_count = taken.len(); let rows = Rows::new(taken, self.describe.row_info()); if stored_rows.len() == 0 { - self.state = PortalState::Finished(None); + self.state = PortalState::Done; return Ok(ExecuteResult::FinishedDql { rows, tag: self.describe.command_tag(), @@ -496,14 +501,14 @@ impl Portal { ); let tuple = dispatch(query)?; self.state = match self.describe().query_type() { - QueryType::Dml => { - let row_count = get_row_count_from_tuple(&tuple)?; + QueryType::Acl | QueryType::Ddl => { let tag = self.describe().command_tag(); - PortalState::Finished(Some(ExecuteResult::Dml { row_count, tag })) + PortalState::ResultReady(ExecuteResult::AclOrDdl { tag }) } - QueryType::Acl | QueryType::Ddl => { + QueryType::Dml => { + let row_count = get_row_count_from_tuple(&tuple)?; let tag = self.describe().command_tag(); - PortalState::Finished(Some(ExecuteResult::AclOrDdl { tag })) + PortalState::ResultReady(ExecuteResult::Dml { row_count, tag }) } QueryType::Dql | QueryType::Explain => { let mp_rows = get_rows_from_tuple(&tuple)?; @@ -512,7 +517,7 @@ impl Portal { .into_iter() .map(|row| mp_row_into_pg_row(row, metadata)) .collect::<PgResult<Vec<Vec<_>>>>()?; - PortalState::Running(pg_rows.into_iter()) + PortalState::StreamingRows(pg_rows.into_iter()) } }; Ok(()) diff --git a/test/int/test_pgproto.py b/test/int/test_pgproto.py index ef36e11e42..eaa7c5ced5 100644 --- a/test/int/test_pgproto.py +++ b/test/int/test_pgproto.py @@ -413,7 +413,7 @@ def test_interactive_portals(pg_client: PgClient): assert [1, "kek"] in data["rows"] or [2, "lol"] in data["rows"] assert data["is_finished"] is True - with pytest.raises(ReturnError, match="Can't execute portal in state Finished"): + with pytest.raises(ReturnError, match="Can't execute portal in state Done"): data = pg_client.execute("", 1) sql = """ explain select * from "t" """ -- GitLab