From edd00db172c5fb5dbd0f15494564e9d911a66520 Mon Sep 17 00:00:00 2001
From: Kaitmazian Maksim <m.kaitmazian@picodata.io>
Date: Fri, 11 Aug 2023 20:24:30 +0300
Subject: [PATCH] feat: support explain messages

---
 pgproto/src/postgres/attributes.c |  9 ++++
 pgproto/src/postgres/attributes.h | 12 +++++-
 pgproto/src/postgres/messages.c   | 13 ++++--
 pgproto/src/postgres/postgres.c   |  7 +++
 pgproto/test/simple_query_test.py | 71 ++++++++++++++++++++++++++++++-
 5 files changed, 107 insertions(+), 5 deletions(-)

diff --git a/pgproto/src/postgres/attributes.c b/pgproto/src/postgres/attributes.c
index 503c593343..5c3d1737d1 100644
--- a/pgproto/src/postgres/attributes.c
+++ b/pgproto/src/postgres/attributes.c
@@ -258,3 +258,12 @@ parse_metadata(const char **data,
 	}
 	return 0;
 }
+
+void
+row_description_explain(struct row_description *row_desc)
+{
+	row_desc->natts = 1;
+	row_desc->atts = box_region_alloc(sizeof(*row_desc->atts));
+	pg_attribute_text(row_desc->atts, "QUERY PLAN", strlen("QUERY PLAN"),
+			  TEXT_FORMAT, TYPEMOD_DEFAULT);
+}
diff --git a/pgproto/src/postgres/attributes.h b/pgproto/src/postgres/attributes.h
index a58c7828e5..4239944767 100644
--- a/pgproto/src/postgres/attributes.h
+++ b/pgproto/src/postgres/attributes.h
@@ -48,10 +48,20 @@ struct row_description {
 };
 
 /**
- * Get row description from the metadata.
+ * Get a row description from the metadata.
  * Format is not mentioned in metadata so the caller must choose it him self.
+ * After the call metadata is consumed and the data points to the begining of
+ * the rows array.
+ * Allocates on box region.
  */
 int
 parse_metadata(const char **data,
 	       struct row_description *row_desc,
 	       uint16_t format);
+
+/**
+ * Get a row description for an explain query.
+ * Allocates on box region.
+ */
+void
+row_description_explain(struct row_description *row_desc);
diff --git a/pgproto/src/postgres/messages.c b/pgproto/src/postgres/messages.c
index 0b78f3b2ed..eabbb5bc9f 100644
--- a/pgproto/src/postgres/messages.c
+++ b/pgproto/src/postgres/messages.c
@@ -70,9 +70,16 @@ send_data_row(struct pg_port *port, const char **data,
 	pg_begin_msg(port, 'D');
 	pg_write_uint16(port, row_desc->natts);
 	const struct pg_attribute *atts = row_desc->atts;
-	assert(mp_typeof(**data) == MP_ARRAY);
-	uint32_t row_size = mp_decode_array(data);
-	assert(row_size == row_desc->natts);
+	/**
+	 * All queries except explain return rows as arrays,
+	 * explain returns strings, so there is no need for decoding.
+	 */
+	if (mp_typeof(**data) == MP_ARRAY) {
+		uint32_t row_size = mp_decode_array(data);
+		assert(row_size == row_desc->natts);
+	} else {
+		assert(mp_typeof(**data) == MP_STR);
+	}
 	for (uint16_t i = 0; i < row_desc->natts; ++i)
 		atts[i].write(&atts[i], port, data);
 	pg_end_msg(port);
diff --git a/pgproto/src/postgres/postgres.c b/pgproto/src/postgres/postgres.c
index 4c1cbe9581..5fe171d64d 100644
--- a/pgproto/src/postgres/postgres.c
+++ b/pgproto/src/postgres/postgres.c
@@ -101,6 +101,13 @@ process_query_response(struct pg_port *port, const char **response)
 	assert(size == 1);
 	assert(mp_typeof(**data) == MP_ARRAY);
 	size = mp_decode_array(data);
+	if (mp_typeof(**data) == MP_ARRAY) {
+		/** Explain query */
+		struct row_description row_desc;
+		row_description_explain(&row_desc);
+		send_row_description_message(port, &row_desc);
+		return send_data_rows(port, data, &row_desc);
+	}
 	assert(mp_typeof(**data) == MP_MAP);
 	size = mp_decode_map(data);
 	uint32_t len;
diff --git a/pgproto/test/simple_query_test.py b/pgproto/test/simple_query_test.py
index 9f9d7e185b..5e8606afaa 100644
--- a/pgproto/test/simple_query_test.py
+++ b/pgproto/test/simple_query_test.py
@@ -2,7 +2,6 @@ import pytest
 from conftest import Cluster
 import pg8000.dbapi as pg
 import os
-import time
 
 def start_pg_server(instance, host, service):
     start_pg_server_lua_code = f"""
@@ -111,6 +110,74 @@ def test_simple_flow_session(cluster: Cluster):
 
     stop_pg_server(i1)
 
+def test_explain(cluster: Cluster):
+    cluster.deploy(instance_count=1)
+    i1 = cluster.instances[0]
+
+    host = '127.0.0.1'
+    service = '54321'
+    start_pg_server(i1, host, service)
+
+    user = 'admin'
+    password = 'password'
+    i1.eval("box.cfg{auth_type='md5'}")
+    i1.eval(f"box.schema.user.passwd('{user}', '{password}')")
+
+    os.environ['PGSSLMODE'] = 'disable'
+    conn = pg.Connection(user, password=password, host=host, port=int(service))
+    conn.autocommit = True
+    cur = conn.cursor()
+
+    cur.execute("""
+        create table "explain" (
+            "id" integer not null,
+            primary key ("id")
+        )
+        using memtx distributed by ("id")
+        option (timeout = 3);
+    """)
+
+    query = """
+        insert into "explain" values (0);
+    """
+    cur.execute("explain " + query)
+    plan = cur.fetchall()
+    assert 'insert "explain" on conflict: fail' in plan[0]
+    assert '    motion [policy: local segment([ref("COLUMN_1")])]' in plan[1]
+    assert '        values' in plan[2]
+    assert '            value row (data=ROW(0::unsigned))' in plan[3]
+    assert 'execution options:' in plan[4]
+
+    cur.execute(query)
+    cur.execute("explain " + query)
+    plan = cur.fetchall()
+    assert 'insert "explain" on conflict: fail' in plan[0]
+    assert '    motion [policy: local segment([ref("COLUMN_1")])]' in plan[1]
+    assert '        values' in plan[2]
+    assert '            value row (data=ROW(0::unsigned))' in plan[3]
+    assert 'execution options:' in plan[4]
+
+    query = """
+        select * from "explain";
+    """
+    cur.execute("explain " + query)
+    plan = cur.fetchall()
+    assert 'projection ("explain"."id"::integer -> "id")' in plan[0]
+    assert '    scan "explain"' in plan[1]
+    assert 'execution options:' in plan[2]
+
+    cur.execute(query)
+    cur.execute("explain " + query)
+    plan = cur.fetchall()
+    assert 'projection ("explain"."id"::integer -> "id")' in plan[0]
+    assert '    scan "explain"' in plan[1]
+    assert 'execution options:' in plan[2]
+
+    cur.execute('drop table "explain";')
+
+    stop_pg_server(i1)
+
+
 # Aggregates return value type is decimal, which is currently not supported,
 # so an error is expected
 def test_aggregate_error(cluster: Cluster):
@@ -152,3 +219,5 @@ def test_aggregate_error(cluster: Cluster):
         cur.execute("""
             SELECT SUM("id") FROM "tall";
         """)
+
+    stop_pg_server(i1)
-- 
GitLab