From e75beef3640dcd4ac97d2e7ce1d933d7b76b0de7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=D0=9A=D0=B8=D1=80=D0=B8=D0=BB=D0=BB=20=D0=91=D0=B5=D0=B7?=
 =?UTF-8?q?=D1=83=D0=B3=D0=BB=D1=8B=D0=B9?= <k.bezuglyi@picodata.io>
Date: Tue, 10 Sep 2024 11:07:52 +0300
Subject: [PATCH] test: add tests for IPROTO_EXECUTE redirection

---
 test/conftest.py        |  8 ++++
 test/int/test_iproto.py | 97 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 105 insertions(+)
 create mode 100644 test/int/test_iproto.py

diff --git a/test/conftest.py b/test/conftest.py
index 25adf0912d..35582221f3 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -309,6 +309,10 @@ def normalize_net_box_result(func):
         if result is None:
             return
 
+        # This is special case for non-SELECT or non-VALUES IPROTO_EXECUTE requests
+        if hasattr(result, "affected_row_count") and result.data is None:
+            return {"row_count": result.affected_row_count}
+
         match result.data:
             case []:
                 return None
@@ -529,6 +533,10 @@ class Connection(tarantool.Connection):  # type: ignore
     def eval(self, expr, *args, on_push=None, on_push_ctx=None):
         return super().eval(expr, *args, on_push=on_push, on_push_ctx=on_push_ctx)
 
+    @normalize_net_box_result
+    def execute(self, query, params=None):
+        return super().execute(query, params)
+
     def sql(self, sql: str, *params, options=None, sudo=False) -> dict[str, list]:
         """Run SQL query and return result"""
         if sudo:
diff --git a/test/int/test_iproto.py b/test/int/test_iproto.py
new file mode 100644
index 0000000000..a2434afbf5
--- /dev/null
+++ b/test/int/test_iproto.py
@@ -0,0 +1,97 @@
+import pytest
+
+from conftest import Instance, Connection, MalformedAPI, TarantoolError, ErrorCode
+
+
+TABLE_NAME = "warehouse"
+
+
+def create_connection(instance: Instance):
+    user_name = "kelthuzad"
+    user_pass = "g$$dP4ss"
+
+    dcl = instance.sql(
+        f"""
+        CREATE USER {user_name} WITH PASSWORD '{user_pass}' USING chap-sha1
+    """
+    )
+    assert dcl["row_count"] == 1
+    acl = instance.sql(f"GRANT READ ON TABLE {TABLE_NAME} TO {user_name}")
+    assert acl["row_count"] == 1
+    acl = instance.sql(f"GRANT WRITE ON TABLE {TABLE_NAME} TO {user_name}")
+    assert acl["row_count"] == 1
+
+    conn = Connection(
+        instance.host,
+        instance.port,
+        user=user_name,
+        password=user_pass,
+        connect_now=True,
+        reconnect_max_attempts=0,
+    )
+    assert conn
+
+    return conn
+
+
+def test_iproto_execute(instance: Instance):
+    # https://docs.picodata.io/picodata/stable/reference/legend/#create_test_tables
+    ddl = instance.sql(
+        f"""
+        CREATE TABLE {TABLE_NAME} (
+            id INTEGER NOT NULL,
+            item TEXT NOT NULL,
+            type TEXT NOT NULL,
+            PRIMARY KEY (id))
+        USING memtx DISTRIBUTED BY (id)
+        OPTION (TIMEOUT = 3.0)
+    """
+    )
+    assert ddl["row_count"] == 1
+
+    # https://docs.picodata.io/picodata/stable/reference/legend/#populate_test_tables
+    data = instance.sql(
+        f"""
+        INSERT INTO {TABLE_NAME} VALUES
+            (1, 'bricks', 'heavy'),
+            (2, 'panels', 'light')
+    """
+    )
+    assert data["row_count"] == 2
+
+    conn = create_connection(instance)
+
+    with pytest.raises(TarantoolError) as data:
+        conn.execute(f"SELECT * FROM {TABLE_NAME}")
+    assert data.value.args[:2] == (
+        "ER_ACCESS_DENIED",
+        f"Read access to space '{TABLE_NAME}' is denied for user 'guest'",
+    )
+
+    acl = instance.sql(
+        f"""
+        GRANT READ ON TABLE {TABLE_NAME} TO guest
+    """
+    )
+    assert acl["row_count"] == 1
+
+    with pytest.raises(MalformedAPI) as dql:
+        conn.execute(f"SELECT * FROM {TABLE_NAME}")
+    assert dql.value.args == ([1, "bricks", "heavy"], [2, "panels", "light"])
+
+    with pytest.raises(TarantoolError) as dql:  # type: ignore
+        conn.execute(f"SELECT * FRUM {TABLE_NAME}")
+    assert dql.value.args[:2] == (
+        ErrorCode.Other,
+        f"sbroad: rule parsing error:  --> 1:8\n  |\n1 | SELECT * FRUM {TABLE_NAME}\n  |        ^---\n  |\n  = expected Identifier or Distinct",  # noqa: E501
+    )
+
+    acl = instance.sql(
+        f"""
+        GRANT WRITE ON TABLE {TABLE_NAME} TO guest
+    """
+    )
+    assert acl["row_count"] == 1
+
+    dml = conn.execute(f"DELETE FROM {TABLE_NAME} WHERE id = 1")
+    assert dml["row_count"] == 1
-- 
GitLab