diff --git a/test/box/sql.result b/test/box/sql.result index 947d4b4cd3a44d6f2b97cb02fd6e4ab547e5494f..27ace669657c18c0d7d46ad0d53c7814651fdb6c 100644 --- a/test/box/sql.result +++ b/test/box/sql.result @@ -27,3 +27,29 @@ delete from t0 where k0 = 1 Delete OK, 1 row affected select * from t0 where k0 = 1 No match +update t0 set k1 = "I am a new tuple" where k0=1 +Insert OK, 0 row affected +select * from t0 where k0 = 1 +No match +insert into t0 values (1, "I am a new tuple") +Insert OK, 1 row affected +select * from t0 where k0 = 1 +Found 1 tuple: +[1, 'I am a new tuple'] +update t0 set k1 = "I am the newest tuple" where k0=1 +Insert OK, 1 row affected +select * from t0 where k0 = 1 +Found 1 tuple: +[1, 'I am the newest tuple'] +update t0 set k1 = "Huh", k2 = "Oh-ho-ho" where k0=1 +An error occurred: ERR_CODE_ILLEGAL_PARAMS, 'Illegal parameters' +select * from t0 where k0 = 1 +Found 1 tuple: +[1, 'I am the newest tuple'] +insert into t0 values (1, "I am a new tuple", "stub") +Insert OK, 1 row affected +update t0 set k1 = "Huh", k2 = "Oh-ho-ho" where k0=1 +Insert OK, 1 row affected +select * from t0 where k0 = 1 +Found 1 tuple: +[1, 'Huh', 'Oh-ho-ho'] diff --git a/test/box/sql.test b/test/box/sql.test index 3353e832ecb582e7d09ba973550e209ddd848499..c0a040bd691283e76a2beadd17ab56172b21fe2a 100644 --- a/test/box/sql.test +++ b/test/box/sql.test @@ -16,5 +16,17 @@ server.restart() exec sql 'select * from t0 where k0 = 1' exec sql 'delete from t0 where k0 = 1' exec sql 'select * from t0 where k0 = 1' +exec sql 'update t0 set k1 = "I am a new tuple" where k0=1' +exec sql 'select * from t0 where k0 = 1' +exec sql 'insert into t0 values (1, "I am a new tuple")' +exec sql 'select * from t0 where k0 = 1' +exec sql 'update t0 set k1 = "I am the newest tuple" where k0=1' +exec sql 'select * from t0 where k0 = 1' +# this is illegal, can't change tuple dimension with update +exec sql 'update t0 set k1 = "Huh", k2 = "Oh-ho-ho" where k0=1' +exec sql 'select * from t0 where k0 = 1' +exec sql 'insert into t0 values (1, "I am a new tuple", "stub")' +exec sql 'update t0 set k1 = "Huh", k2 = "Oh-ho-ho" where k0=1' +exec sql 'select * from t0 where k0 = 1' # vim: syntax=python diff --git a/test/lib/sql.g b/test/lib/sql.g index a57fdd50012a68f803417314d496252f4c653bc8..594e9942fe2c5091d3ab2151b58beda797100475 100644 --- a/test/lib/sql.g +++ b/test/lib/sql.g @@ -1,4 +1,7 @@ import sql_ast +import re + +object_no_re = re.compile("[a-z_]*", re.I) %% @@ -26,18 +29,18 @@ parser sql: select {{ stmt = select }} | ping {{ stmt = ping }}) END {{ return stmt }} - rule insert: INSERT [INTO] ID VALUES value_list - {{ return sql_ast.StatementInsert(ID, value_list) }} - rule update: UPDATE ID SET update_list opt_where - {{ return sql_ast.StatementUpdate(ID, update_list, opt_where) }} - rule delete: DELETE FROM ID opt_where - {{ return sql_ast.StatementDelete(ID, opt_where) }} - rule select: SELECT '\*' FROM ID opt_where - {{ return sql_ast.StatementSelect(ID, opt_where) }} + rule insert: INSERT [INTO] ident VALUES value_list + {{ return sql_ast.StatementInsert(ident, value_list) }} + rule update: UPDATE ident SET update_list opt_where + {{ return sql_ast.StatementUpdate(ident, update_list, opt_where) }} + rule delete: DELETE FROM ident opt_where + {{ return sql_ast.StatementDelete(ident, opt_where) }} + rule select: SELECT '\*' FROM ident opt_where + {{ return sql_ast.StatementSelect(ident, opt_where) }} rule ping: PING {{ return sql_ast.StatementPing() }} - rule predicate: ID '=' constant - {{ return (ID, constant) }} + rule predicate: ident '=' constant + {{ return (ident, constant) }} rule opt_where: {{ return None }} | WHERE predicate {{ return predicate }} @@ -49,6 +52,7 @@ parser sql: {{ return update_list }} rule expr: constant {{ return constant }} rule constant: NUM {{ return int(NUM) }} | STR {{ return STR[1:-1] }} + rule ident: ID {{ return int(object_no_re.sub("", ID)) }} %% # SQL is case-insensitive, but in yapps it's not possible to diff --git a/test/lib/sql.py b/test/lib/sql.py index 80f87134b025006f3008dc2084f7e40a300c56b1..93593316b3b447c98a30a9496a46b06b5ce9e1ef 100644 --- a/test/lib/sql.py +++ b/test/lib/sql.py @@ -1,4 +1,7 @@ import sql_ast +import re + +object_no_re = re.compile("[a-z_]*", re.I) # Begin -- grammar generated by Yapps @@ -60,36 +63,36 @@ class sql(runtime.Parser): INSERT = self._scan('INSERT', context=_context) if self._peek('INTO', 'ID', context=_context) == 'INTO': INTO = self._scan('INTO', context=_context) - ID = self._scan('ID', context=_context) + ident = self.ident(_context) VALUES = self._scan('VALUES', context=_context) value_list = self.value_list(_context) - return sql_ast.StatementInsert(ID, value_list) + return sql_ast.StatementInsert(ident, value_list) def update(self, _parent=None): _context = self.Context(_parent, self._scanner, 'update', []) UPDATE = self._scan('UPDATE', context=_context) - ID = self._scan('ID', context=_context) + ident = self.ident(_context) SET = self._scan('SET', context=_context) update_list = self.update_list(_context) opt_where = self.opt_where(_context) - return sql_ast.StatementUpdate(ID, update_list, opt_where) + return sql_ast.StatementUpdate(ident, update_list, opt_where) def delete(self, _parent=None): _context = self.Context(_parent, self._scanner, 'delete', []) DELETE = self._scan('DELETE', context=_context) FROM = self._scan('FROM', context=_context) - ID = self._scan('ID', context=_context) + ident = self.ident(_context) opt_where = self.opt_where(_context) - return sql_ast.StatementDelete(ID, opt_where) + return sql_ast.StatementDelete(ident, opt_where) def select(self, _parent=None): _context = self.Context(_parent, self._scanner, 'select', []) SELECT = self._scan('SELECT', context=_context) self._scan("'\\*'", context=_context) FROM = self._scan('FROM', context=_context) - ID = self._scan('ID', context=_context) + ident = self.ident(_context) opt_where = self.opt_where(_context) - return sql_ast.StatementSelect(ID, opt_where) + return sql_ast.StatementSelect(ident, opt_where) def ping(self, _parent=None): _context = self.Context(_parent, self._scanner, 'ping', []) @@ -98,10 +101,10 @@ class sql(runtime.Parser): def predicate(self, _parent=None): _context = self.Context(_parent, self._scanner, 'predicate', []) - ID = self._scan('ID', context=_context) + ident = self.ident(_context) self._scan("'='", context=_context) constant = self.constant(_context) - return (ID, constant) + return (ident, constant) def opt_where(self, _parent=None): _context = self.Context(_parent, self._scanner, 'opt_where', []) @@ -154,6 +157,11 @@ class sql(runtime.Parser): STR = self._scan('STR', context=_context) return STR[1:-1] + def ident(self, _parent=None): + _context = self.Context(_parent, self._scanner, 'ident', []) + ID = self._scan('ID', context=_context) + return int(object_no_re.sub("", ID)) + def parse(rule, text): P = sql(sqlScanner(text)) diff --git a/test/lib/sql_ast.py b/test/lib/sql_ast.py index da08a3537f647dccca38a35423f2fc3b9cfb5683..fa0e04e50bb2ea777a6676a57d1eb6ea4a7bbea3 100644 --- a/test/lib/sql_ast.py +++ b/test/lib/sql_ast.py @@ -4,12 +4,17 @@ import ctypes # IPROTO header is always 3 4-byte ints: # command code, length, request id -IPROTO_HEADER_LEN = 12 -INSERT_REQUEST_FIXED_LEN = 8 -DELETE_REQUEST_FIXED_LEN = 4 -SELECT_REQUEST_FIXED_LEN = 20 +INT_FIELD_LEN = 4 +INT_BER_MAX_LEN = 5 +IPROTO_HEADER_LEN = 3*INT_FIELD_LEN +INSERT_REQUEST_FIXED_LEN = 2*INT_FIELD_LEN +UPDATE_REQUEST_FIXED_LEN = 2*INT_FIELD_LEN +DELETE_REQUEST_FIXED_LEN = INT_FIELD_LEN +SELECT_REQUEST_FIXED_LEN = 5*INT_FIELD_LEN PACKET_BUF_LEN = 2048 +UPDATE_SET_FIELD_OPCODE = 0 + # command code in IPROTO header INSERT_REQUEST_TYPE = 13 @@ -52,7 +57,6 @@ def format_error(return_code): return "An error occurred: {0}, \'{1}'".format(ER[return_code][0], ER[return_code][1]) -object_no_re = re.compile("[a-z_]*", re.I) def save_varint32(value): """Implement Perl pack's 'w' option, aka base 128 encoding.""" @@ -93,33 +97,48 @@ def opt_resize_buf(buf, newsize): return buf +def pack_field(value, buf, offset): + if type(value) is int: + buf = opt_resize_buf(buf, offset + INT_FIELD_LEN) + struct.pack_into("<cL", buf, offset, chr(INT_FIELD_LEN), value) + offset += INT_FIELD_LEN + 1 + elif type(value) is str: + opt_resize_buf(buf, offset + INT_BER_MAX_LEN + len(value)) + value_len_ber = save_varint32(len(value)) + struct.pack_into("{0}s{1}s".format(len(value_len_ber), len(value)), + buf, offset, value_len_ber, value) + offset += len(value_len_ber) + len(value) + else: + raise RuntimeError("Unsupported value type in value list") + return (buf, offset) + + def pack_tuple(value_list, buf, offset): """Represents <tuple> rule in tarantool protocol description. Pack tuple into a binary representation. buf and offset are in-out parameters, offset is advanced to the amount of bytes that it took to pack the tuple""" # length of int field: 1 byte - field len (is always 4), 4 bytes - data - INT_FIELD_LEN = 4 # max length of compressed integer - INT_BER_MAX_LEN = 5 cardinality = len(value_list) struct.pack_into("<L", buf, offset, cardinality) offset += INT_FIELD_LEN for value in value_list: - if type(value) is int: - buf = opt_resize_buf(buf, offset + INT_FIELD_LEN) - struct.pack_into("<cL", buf, offset, chr(INT_FIELD_LEN), value) - offset += INT_FIELD_LEN + 1 - elif type(value) is str: - opt_resize_buf(buf, offset + INT_BER_MAX_LEN + len(value)) - value_len_ber = save_varint32(len(value)) - struct.pack_into("{0}s{1}s".format(len(value_len_ber), len(value)), - buf, offset, value_len_ber, value) - offset += len(value_len_ber) + len(value) - else: - raise RuntimeError("Unsupported value type in value list") + (buf, offset) = pack_field(value, buf, offset) return buf, offset +def pack_operation_list(update_list, buf, offset): + buf = opt_resize_buf(buf, offset + INT_FIELD_LEN) + struct.pack_into("<L", buf, offset, len(update_list)) + offset += INT_FIELD_LEN + for update in update_list: + opt_resize_buf(buf, offset + INT_FIELD_LEN + 1) + struct.pack_into("<Lc", buf, offset, + update[0], + chr(UPDATE_SET_FIELD_OPCODE)) + offset += INT_FIELD_LEN + 1 + (buf, offset) = pack_field(update[1], buf, offset) + return (buf, offset) def unpack_tuple(response, offset): (size,cardinality) = struct.unpack("<LL", response[offset:offset + 8]) @@ -133,6 +152,7 @@ def unpack_tuple(response, offset): (data,) = struct.unpack("<L", data) res.append(data) return str(res), offset + class StatementPing: reqeust_type = PING_REQUEST_TYPE @@ -146,7 +166,7 @@ class StatementInsert(StatementPing): reqeust_type = INSERT_REQUEST_TYPE def __init__(self, table_name, value_list): - self.namespace_no = int(object_no_re.sub("", table_name)) + self.namespace_no = table_name self.flags = 0 self.value_list = value_list @@ -168,16 +188,34 @@ class StatementUpdate(StatementPing): reqeust_type = UPDATE_REQUEST_TYPE def __init__(self, table_name, update_list, where): - self.namespace_no = int(object_no_re.sub("", table_name)) + self.namespace_no = table_name + self.flags = 0 + key_no = where[0] + if key_no != 0: + raise RuntimeError("UPDATE can only be made by the primary key (#0)") + self.value_list = where[1:] self.update_list = update_list - self.where = where + + def pack(self): + buf = ctypes.create_string_buffer(PACKET_BUF_LEN) + struct.pack_into("<LL", buf, 0, self.namespace_no, self.flags) + (buf, offset) = pack_tuple(self.value_list, buf, UPDATE_REQUEST_FIXED_LEN) + (buf, offset) = pack_operation_list(self.update_list, buf, offset) + return buf[:offset] + + def unpack(self, response): + (return_code,) = struct.unpack("<L", response[:4]) + if return_code: + return format_error(return_code) + (result_code, row_count) = struct.unpack("<LL", response) + return "Insert OK, {0} row affected".format(row_count) class StatementDelete(StatementPing): reqeust_type = DELETE_REQUEST_TYPE def __init__(self, table_name, where): - self.namespace_no = int(object_no_re.sub("", table_name)) - key_no = int(object_no_re.sub("", where[0])) + self.namespace_no = table_name + key_no = where[0] if key_no != 0: raise RuntimeError("DELETE can only be made by the primary key (#0)") self.value_list = where[1:] @@ -199,10 +237,9 @@ class StatementSelect(StatementPing): reqeust_type = SELECT_REQUEST_TYPE def __init__(self, table_name, where): - self.namespace_no = int(object_no_re.sub("", table_name)) + self.namespace_no = table_name if where: - (index, key) = where - self.index_no = int(object_no_re.sub("", index)) + (self.index_no, key) = where self.key = [key] else: self.index_no = 0