From 901366f89008b61ab869d9fe77d5261a663a327b Mon Sep 17 00:00:00 2001
From: Eugine Blikh <bigbes@gmail.com>
Date: Fri, 13 Sep 2013 18:57:10 +0400
Subject: [PATCH] Fix number of bugs and output

---
 test/lib/box_connection.py       | 44 ++++++++++++++++------
 test/lib/sql_ast.py              | 63 +++++++++++++++++---------------
 test/lib/tarantool_connection.py |  2 -
 test/lib/tarantool_server.py     | 15 ++++++++
 4 files changed, 81 insertions(+), 43 deletions(-)

diff --git a/test/lib/box_connection.py b/test/lib/box_connection.py
index be93bbc60c..4829290cb7 100644
--- a/test/lib/box_connection.py
+++ b/test/lib/box_connection.py
@@ -23,8 +23,12 @@ __author__ = "Konstantin Osipov <kostja.osipov@gmail.com>"
 import os
 import sql
 import sys
+import errno
+import ctypes
 import socket
 import struct
+import warnings
+
 from tarantool_connection import TarantoolConnection
 
 try:
@@ -32,8 +36,8 @@ try:
     tnt_py = os.path.join(tnt_py, 'tarantool-python/src')
     sys.path.append(tnt_py)
     from tarantool import Connection as tnt_connection
+    from tarantool import Schema
 except ImportError:
-    raise
     sys.stderr.write("\n\nNo tarantool-python library found\n")
     sys.exit(1)
 
@@ -43,25 +47,41 @@ class BoxConnection(TarantoolConnection):
         self.py_con = tnt_connection(host, port, connect_now=False)
         self.py_con.error = False
         self.sort = False
+ 
+    def connect(self):
+        self.py_con.connect()
+    
+    def disconnect(self):
+        self.py_con.close()
+    
+    def reconnect(self):
+        self.disconnect()
+        self.connect() 
+
+    def set_schema(self, schemadict):
+        self.py_con.schema = Schema(schemadict)
 
-    def recvall(self, length):
-        res = ""
-        while len(res) < length:
-            buf = self.socket.recv(length - len(res))
-            if not buf:
-                raise RuntimeError("Got EOF from socket, the server has "
-                                   "probably crashed")
-            res = res + buf
-        return res
+    def check_connection(self):
+        rc = self.py_con._recv(self.py_con._socket.fileno(), '', 0, socket.MSG_DONTWAIT)
+        if ctypes.get_errno() == errno.EAGAIN:
+            ctypes.set_errno(0)
+            return True
+        return False
+
+    def execute(self, command, silent=True):
+        return self.execute_no_reconnect(command, silent)
 
     def execute_no_reconnect(self, command, silent=True):
         statement = sql.parse("sql", command)
         if statement == None:
             return "You have an error in your SQL syntax\n"
         statement.sort = self.sort
-
+        
+        response = None
         request = statement.pack(self.py_con)
-        response = self.py_con._send_request(request, False)
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            response = self.py_con._send_request(request, False)
 
         if not silent:
             print command
diff --git a/test/lib/sql_ast.py b/test/lib/sql_ast.py
index d7e054e570..88884e31b1 100644
--- a/test/lib/sql_ast.py
+++ b/test/lib/sql_ast.py
@@ -17,7 +17,6 @@ try:
             RequestDelete,
     )
 except ImportError:
-    raise
     sys.stderr.write("\n\nNo tarantool-python library found\n")
     sys.exit(1)
 
@@ -83,7 +82,13 @@ ER = {
 }
 
 def format_error(response):
-    return "An error occurred: {0}".format(ER[response.return_code >> 8])
+    return "---\n- error: '{1}'\n...".format(ER[response.return_code],
+            response.return_message)
+
+def format_yamllike(response):
+    table = ("\n"+"\n".join(["- "+str(list(k)) for k in response])) \
+            if len(response) else ""
+    return "---{0}\n...".format(table)
 
 class Statement(object):
     def __init__(self):
@@ -100,53 +105,57 @@ class StatementPing(Statement):
     def unpack(self, response):
         if response._return_code:
             return format_error(response)
-        return "ok\n---"
+        return "---\n- ok\n..."
 
 class StatementInsert(Statement):
     def __init__(self, table_name, value_list):
         self.space_no = table_name
-        self.flags = 0x03 # ADD
+        self.flags = 0x03 # ADD + RET
         self.value_list = value_list
 
     def pack(self, connection):
-        return RequestInsert(connection, self.space_no, self.value_list, self.flags)
+        return RequestInsert(connection, self.space_no, self.value_list,
+                self.flags)
 
     def unpack(self, response):
-        if response._return_code:
+        if response.return_code:
             return format_error(response)
-        return "Insert OK, {0} row affected".format(len(response))
+        return format_yamllike(response)
 
 class StatementReplace(Statement):
     def __init__(self, table_name, value_list):
         self.space_no = table_name
-        self.flags = 0x05 # REPLACE
+        self.flags = 0x05 # REPLACE + RET
         self.value_list = value_list
 
     def pack(self, connection):
-        return RequestInsert(connection, self.space_no, self.value_list, self.flags)
+        return RequestInsert(connection, self.space_no, self.value_list,
+                self.flags)
 
     def unpack(self, response):
-        if response._return_code:
+        if response.return_code:
             return format_error(response)
-        return "Insert OK, {0} row affected".format(len(response))
+        return format_yamllike(response)
 
 class StatementUpdate(Statement):
     def __init__(self, table_name, update_list, where):
         self.space_no = table_name
         self.flags = 0
-        key_no = where[0]
-        if key_no != 0:
-            raise RuntimeError("UPDATE can only be made by the primary key (#0)")
+        self.key_no = where[0]
+        if self.key_no != 0:
+            raise RuntimeError("UPDATE can only be made by the"
+                    " primary key (#0)")
         self.value_list = where[1]
         self.update_list = [(pair[0], '=', pair[1]) for pair in update_list]
 
     def pack(self, connection):
-        return RequestUpdate(connection, self.space_no, self.value_list, True)
+        return RequestUpdate(connection, self.space_no, self.value_list,
+                self.update_list, True)
 
     def unpack(self, response):
-        if response._return_code:
+        if response.return_code:
             return format_error(response)
-        return "Update OK, {0} row affected".format(len(response))
+        return format_yamllike(response)
 
 class StatementDelete(Statement):
     def __init__(self, table_name, where):
@@ -154,16 +163,17 @@ class StatementDelete(Statement):
         self.flags = 0
         key_no = where[0]
         if key_no != 0:
-            raise RuntimeError("DELETE can only be made by the primary key (#0)")
+            raise RuntimeError("DELETE can only be made by the "
+                    "primary key (#0)")
         self.value_list = where[1]
 
     def pack(self, connection):
         return RequestDelete(connection, self.space_no, self.value_list, True)
 
     def unpack(self, response):
-        if response._return_code:
+        if response.return_code:
             return format_error(response)
-        return "Delete OK, {0} row affected".format(len(response))
+        return format_yamllike(response)
 
 class StatementSelect(Statement):
     def __init__(self, table_name, where, limit):
@@ -179,7 +189,8 @@ class StatementSelect(Statement):
                 if self.index_no == None:
                     self.index_no = index_no
                 elif self.index_no != index_no:
-                    raise RuntimeError("All key values in a disjunction must refer to the same index")
+                    raise RuntimeError("All key values in a disjunction must "
+                            "refer to the same index")
         self.offset = 0
         self.limit = limit
 
@@ -188,17 +199,11 @@ class StatementSelect(Statement):
                 self.key_list , self.offset, self.limit)
 
     def unpack(self, response):
-        if response._return_code:
+        if response.return_code:
             return format_error(response)
         if self.sort:
             response = sorted(response[0:])
-        if not len(response):
-            return "No match"
-        elif len(response) == 1:
-            return "Found 1 tuple:\n" + str(response[0])
-        else:
-            return "Found {0} tuples:\n".format(len(response)) + \
-                    "\n".join([str(tup) for tup in response])
+        return format_yamllike(response)
 
 class StatementCall(StatementSelect):
     def __init__(self, proc_name, value_list):
diff --git a/test/lib/tarantool_connection.py b/test/lib/tarantool_connection.py
index e05ee6baba..b8b9246f52 100644
--- a/test/lib/tarantool_connection.py
+++ b/test/lib/tarantool_connection.py
@@ -23,7 +23,6 @@ __author__ = "Konstantin Osipov <kostja.osipov@gmail.com>"
 
 import socket
 import sys
-import cStringIO
 import errno
 
 class TarantoolConnection(object):
@@ -31,7 +30,6 @@ class TarantoolConnection(object):
         self.host = host
         self.port = port
         self.is_connected = False
-        self.stream = cStringIO.StringIO()
         self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         self.socket.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
 
diff --git a/test/lib/tarantool_server.py b/test/lib/tarantool_server.py
index 1761ee7320..39f5f4e0a9 100644
--- a/test/lib/tarantool_server.py
+++ b/test/lib/tarantool_server.py
@@ -25,6 +25,16 @@ try:
 except ImportError:
     import StringIO
 
+try:
+    tnt_py = os.path.dirname(os.path.abspath(__file__))
+    tnt_py = os.path.join(tnt_py, 'tarantool-python/src')
+    sys.path.append(tnt_py)
+    import tarantool
+    from tarantool import Connection as tnt_connection
+except ImportError:
+    sys.stderr.write("\n\nNo tarantool-python library found\n")
+    sys.exit(1)
+
 def check_port(port):
     """Check if the port we're connecting to is available"""
     try:
@@ -240,6 +250,11 @@ class LuaTest(FuncTest):
 
 class PythonTest(FuncTest):
     def execute(self, server):
+        Schema = tarantool.Schema
+        tntNUM = tarantool.NUM
+        tntSTR = tarantool.STR
+        tntNUM64 = tarantool.NUM64
+        tntRAW = tarantool.RAW
         execfile(self.name, dict(locals(), **server.__dict__))
 
 class TarantoolConfigFile:
-- 
GitLab