diff --git a/src/luamod.rs b/src/luamod.rs index 9c37830e1122be8759ed95491721007c6bbf0630..843970c8549c2647128f7fa02690f92f9f129863 100644 --- a/src/luamod.rs +++ b/src/luamod.rs @@ -633,6 +633,8 @@ pub(crate) fn setup(args: &args::Run) { }) }, ); + // FIXME: space param inconsistency + // op requires `space` to be a String, but ranges in predicate require space to be a `SpaceId` luamod_set( &l, "cas", diff --git a/src/rpc/cas.rs b/src/rpc/cas.rs index 83a0d9388e1e1e79b4358a3299b0445558906593..5b47a77e72410e6dd2faaf34ed9cfa2e62e4866f 100644 --- a/src/rpc/cas.rs +++ b/src/rpc/cas.rs @@ -274,30 +274,6 @@ impl Predicate { requested: self.index, conflict_index: entry_index, }; - let check_bounds = |key_def: &KeyDef, key: &Tuple, range: &Range| { - let min_satisfied = match &range.key_min { - Bound::Included(bound) => key_def.compare_with_key(key, bound).is_ge(), - Bound::Excluded(bound) => key_def.compare_with_key(key, bound).is_gt(), - Bound::Unbounded => true, - }; - // short-circuit - if !min_satisfied { - return Ok(()); - } - let max_satisfied = match &range.key_max { - Bound::Included(bound) => key_def.compare_with_key(key, bound).is_le(), - Bound::Excluded(bound) => key_def.compare_with_key(key, bound).is_lt(), - Bound::Unbounded => true, - }; - // min_satisfied && max_satisfied - if max_satisfied { - // If entry found that modified a tuple in bounds - cancel CaS. - Err(error()) - } else { - Ok(()) - } - }; - let ddl_keys: Lazy<Vec<Tuple>> = Lazy::new(|| { use crate::storage::PropertyName::*; @@ -328,23 +304,31 @@ impl Predicate { Op::Dml(Dml::Update { key, .. } | Dml::Delete { key, .. }) => { let key = Tuple::new(key)?; let key_def = storage.key_def_for_key(space, 0)?; - check_bounds(&key_def, &key, range)?; + if range.contains(&key_def, &key) { + return Err(error()); + } } Op::Dml(Dml::Insert { tuple, .. } | Dml::Replace { tuple, .. }) => { let tuple = Tuple::new(tuple)?; let key_def = storage.key_def(space, 0)?; - check_bounds(&key_def, &tuple, range)?; + if range.contains(&key_def, &tuple) { + return Err(error()); + } } Op::DdlPrepare { .. } | Op::DdlCommit | Op::DdlAbort => { let key_def = storage.key_def_for_key(space, 0)?; for key in ddl_keys.iter() { - check_bounds(&key_def, key, range)?; + if range.contains(&key_def, key) { + return Err(error()); + } } } Op::PersistInstance(op) => { let key = Tuple::new(&(&op.0.instance_id,))?; let key_def = storage.key_def_for_key(space, 0)?; - check_bounds(&key_def, &key, range)?; + if range.contains(&key_def, &key) { + return Err(error()); + } } Op::Nop => (), }; @@ -376,62 +360,115 @@ impl Range { pub fn new(space: SpaceId) -> Self { Self { space, - key_min: Bound::Unbounded, - key_max: Bound::Unbounded, + key_min: Bound::unbounded(), + key_max: Bound::unbounded(), } } /// Add a "greater than" restriction. #[inline(always)] pub fn gt(mut self, key: impl ToTupleBuffer) -> Self { - let tuple = key.to_tuple_buffer().expect("cannot fail"); - self.key_min = Bound::Excluded(tuple); + self.key_min = Bound::excluded(&key); self } /// Add a "greater or equal" restriction. #[inline(always)] pub fn ge(mut self, key: impl ToTupleBuffer) -> Self { - let tuple = key.to_tuple_buffer().expect("cannot fail"); - self.key_min = Bound::Included(tuple); + self.key_min = Bound::included(&key); self } /// Add a "less than" restriction. #[inline(always)] pub fn lt(mut self, key: impl ToTupleBuffer) -> Self { - let tuple = key.to_tuple_buffer().expect("cannot fail"); - self.key_max = Bound::Excluded(tuple); + self.key_max = Bound::excluded(&key); self } /// Add a "less or equal" restriction. #[inline(always)] pub fn le(mut self, key: impl ToTupleBuffer) -> Self { - let tuple = key.to_tuple_buffer().expect("cannot fail"); - self.key_max = Bound::Included(tuple); + self.key_max = Bound::included(&key); self } /// Add a "equal" restriction. #[inline(always)] pub fn eq(mut self, key: impl ToTupleBuffer) -> Self { - let tuple = key.to_tuple_buffer().expect("cannot fail"); - self.key_min = Bound::Included(tuple.clone()); - self.key_max = Bound::Included(tuple); + self.key_min = Bound::included(&key); + self.key_max = Bound::included(&key); self } + + pub fn contains(&self, key_def: &KeyDef, tuple: &Tuple) -> bool { + let min_satisfied = match self.key_min.kind { + BoundKind::Included => key_def + .compare_with_key(tuple, self.key_min.key.as_ref().unwrap()) + .is_ge(), + BoundKind::Excluded => key_def + .compare_with_key(tuple, self.key_min.key.as_ref().unwrap()) + .is_gt(), + BoundKind::Unbounded => true, + }; + // short-circuit + if !min_satisfied { + return false; + } + let max_satisfied = match self.key_max.kind { + BoundKind::Included => key_def + .compare_with_key(tuple, self.key_max.key.as_ref().unwrap()) + .is_le(), + BoundKind::Excluded => key_def + .compare_with_key(tuple, self.key_max.key.as_ref().unwrap()) + .is_lt(), + BoundKind::Unbounded => true, + }; + // min_satisfied && max_satisfied + max_satisfied + } } /// A bound for keys. #[derive(Clone, Debug, ::serde::Serialize, ::serde::Deserialize, tlua::LuaRead)] -#[serde(rename_all = "snake_case", tag = "kind", content = "value")] -pub enum Bound { - #[serde(with = "serde_bytes")] - Included(TupleBuffer), +pub struct Bound { + kind: BoundKind, #[serde(with = "serde_bytes")] - Excluded(TupleBuffer), - Unbounded, + key: Option<TupleBuffer>, +} + +impl Bound { + pub fn included(key: &impl ToTupleBuffer) -> Self { + Self { + kind: BoundKind::Included, + key: Some(key.to_tuple_buffer().expect("cannot fail")), + } + } + + pub fn excluded(key: &impl ToTupleBuffer) -> Self { + Self { + kind: BoundKind::Excluded, + key: Some(key.to_tuple_buffer().expect("cannot fail")), + } + } + + pub fn unbounded() -> Self { + Self { + kind: BoundKind::Unbounded, + key: None, + } + } +} + +::tarantool::define_str_enum! { + /// A bound for keys. + #[derive(Default)] + pub enum BoundKind { + Included = "included", + Excluded = "excluded", + #[default] + Unbounded = "unbounded", + } } /// Get space that the operation touches. diff --git a/test/conftest.py b/test/conftest.py index 179a642db4d13c43a3c20908c16da002e99c4a78..fd8af98f1766fcc1e8fe167728e51c839b77dd02 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -195,21 +195,21 @@ class RaftStatus: class CasRange: - key_min = dict(kind="unbounded", value=None) - key_max = dict(kind="unbounded", value=None) + key_min = dict(kind="unbounded", key=None) + key_max = dict(kind="unbounded", key=None) repr_min = "unbounded" repr_max = "unbounded" @property def key_min_packed(self) -> dict: key = self.key_min.copy() - key["value"] = msgpack.packb([key["value"]]) + key["key"] = msgpack.packb([key["key"][0]]) return key @property def key_max_packed(self) -> dict: key = self.key_max.copy() - key["value"] = msgpack.packb([key["value"]]) + key["key"] = msgpack.packb([key["key"][0]]) return key def __repr__(self): @@ -233,21 +233,21 @@ class CasRange: Example: `CasRange(ge=1) # [1, +infinity)` """ if gt is not None: - self.key_min = dict(kind="excluded", value=gt) + self.key_min = dict(kind="excluded", key=(gt,)) self.repr_min = f'gt="{gt}"' if ge is not None: - self.key_min = dict(kind="included", value=ge) + self.key_min = dict(kind="included", key=(ge,)) self.repr_min = f'ge="{ge}"' if lt is not None: - self.key_max = dict(kind="excluded", value=lt) + self.key_max = dict(kind="excluded", key=(lt,)) self.repr_max = f'lt="{lt}"' if le is not None: - self.key_max = dict(kind="included", value=le) + self.key_max = dict(kind="included", key=(le,)) self.repr_max = f'le="{le}"' if eq is not None: - self.key_min = dict(kind="included", value=eq) - self.key_max = dict(kind="included", value=eq) + self.key_min = dict(kind="included", key=(eq,)) + self.key_max = dict(kind="included", key=(eq,)) self.repr_min = f'ge="{eq}"' self.repr_max = f'le="{eq}"' @@ -1077,7 +1077,7 @@ class Cluster: def cas( self, dml_kind: Literal["insert", "replace", "delete"], - space: str, + space: str | int, tuple: Tuple | List, index: int | None = None, term: int | None = None, @@ -1096,12 +1096,14 @@ class Cluster: if instance is None: instance = self.instances[0] + space_id = instance.space_id(space) + predicate_ranges = [] if ranges is not None: for range in ranges: predicate_ranges.append( dict( - space=space, + space=space_id, key_min=range.key_min, key_max=range.key_max, ) diff --git a/test/int/test_cas.py b/test/int/test_cas.py index 732f5cf0882ef5efcba5fd408ff052f3afca5777..c2c6891f51ab6736e37a6619a9cf924803da31b2 100644 --- a/test/int/test_cas.py +++ b/test/int/test_cas.py @@ -1,5 +1,5 @@ import pytest -from conftest import Instance, TarantoolError, CasRange +from conftest import Cluster, Instance, TarantoolError, ReturnError, CasRange _3_SEC = 3 @@ -145,3 +145,42 @@ def test_cas_predicate(instance: Instance): instance.raft_wait_index(ret, _3_SEC) assert instance.raft_read_index(_3_SEC) == ret assert property("flower") == "tulip" + + +# Previous tests use stored procedure `.proc_cas`, this one uses `pico.cas` lua api instead +def test_cas_lua_api(cluster: Cluster): + def property(k: str): + return cluster.instances[0].eval( + """ + local tuple = box.space._pico_property:get(...) + return tuple and tuple.value + """, + k, + ) + + cluster.deploy(instance_count=3) + read_index = cluster.instances[0].raft_read_index(_3_SEC) + + # Successful insert + ret = cluster.cas("insert", "_pico_property", ["fruit", "apple"], read_index) + assert ret == read_index + 1 + cluster.raft_wait_index(ret, _3_SEC) + assert cluster.instances[0].raft_read_index(_3_SEC) == ret + assert property("fruit") == "apple" + + # CaS rejected + with pytest.raises(ReturnError) as e5: + cluster.cas( + "insert", + "_pico_property", + ["fruit", "orange"], + index=read_index, + ranges=[CasRange(eq="fruit")], + ) + assert e5.value.args == ( + "Network error: service responded with error: " + + "compare-and-swap request failed: " + + f"comparison failed for index {read_index} " + + f"as it conflicts with {read_index+1}", + ) + pass