diff --git a/src/args.rs b/src/args.rs
index 248062a2af6f685787a4bb408ac10fe06d2150be..6fce8296ed61cc72a0777d882bbba69c525fa61b 100644
--- a/src/args.rs
+++ b/src/args.rs
@@ -24,7 +24,7 @@ pub enum Picodata {
 pub struct Run {
     #[clap(long, value_name = "name", env = "PICODATA_CLUSTER_ID")]
     /// Name of the cluster
-    pub cluster_id: Option<String>,
+    pub cluster_id: String,
 
     #[clap(
         long,
@@ -268,6 +268,7 @@ mod tests {
     fn test_parse() {
         let _env_dump = EnvDump::new();
 
+        std::env::set_var("PICODATA_CLUSTER_ID", "cluster1");
         std::env::set_var("PICODATA_INSTANCE_ID", "instance-id-from-env");
         std::env::set_var("PICODATA_PEER", "peer-from-env");
         {
diff --git a/src/main.rs b/src/main.rs
index 4615c1e78d3e259f526a5e0dcf24d3aa98e5f687..61cda8e237065ed4809f818efba89e351f61aa97 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -411,6 +411,7 @@ fn start_boot(args: &args::Run) {
     let mut topology = traft::Topology::from_peers(vec![]);
 
     let req = traft::JoinRequest {
+        cluster_id: args.cluster_id.clone(),
         instance_id: args.instance_id.clone(),
         replicaset_id: args.replicaset_id.clone(),
         advertise_address: args.advertise_address(),
@@ -469,6 +470,7 @@ fn start_boot(args: &args::Run) {
         traft::Storage::persist_commit(1).unwrap();
         traft::Storage::persist_term(1).unwrap();
         traft::Storage::persist_id(raft_id).unwrap();
+        traft::Storage::persist_cluster_id(&args.cluster_id).unwrap();
         Ok(())
     })
     .unwrap();
@@ -480,6 +482,7 @@ fn start_join(args: &args::Run, leader_address: String) {
     tlog!(Info, ">>>>> start_join({leader_address})");
 
     let req = traft::JoinRequest {
+        cluster_id: args.cluster_id.clone(),
         instance_id: args.instance_id.clone(),
         replicaset_id: args.replicaset_id.clone(),
         voter: false,
@@ -517,6 +520,7 @@ fn start_join(args: &args::Run, leader_address: String) {
             traft::Storage::persist_peer(&peer).unwrap();
         }
         traft::Storage::persist_id(raft_id).unwrap();
+        traft::Storage::persist_cluster_id(&args.cluster_id).unwrap();
         Ok(())
     })
     .unwrap();
@@ -589,6 +593,7 @@ fn postjoin(args: &args::Run) {
     tlog!(Warning, "initiating self-promotion of {me:?}");
 
     let req = traft::JoinRequest {
+        cluster_id: args.cluster_id.clone(),
         instance_id: me.instance_id.clone(),
         replicaset_id: None, // TODO
        voter: true,
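Note on the args.rs change: with clap's derive API, switching the field from `Option<String>` to `String` is what turns `--cluster-id` into a required argument, so an instance can no longer start without a cluster name coming from either the flag or PICODATA_CLUSTER_ID. A minimal sketch of that behavior, assuming clap 3.x with the `derive` and `env` features (this trimmed `Run` is illustrative, not the crate's full struct):

    use clap::Parser;

    #[derive(Parser)]
    struct Run {
        /// Name of the cluster
        #[clap(long, value_name = "name", env = "PICODATA_CLUSTER_ID")]
        cluster_id: String,
    }

    fn main() {
        // With a plain `String`, parsing exits with a usage error unless
        // either `--cluster-id <name>` or PICODATA_CLUSTER_ID is provided;
        // an `Option<String>` would have silently yielded `None` instead.
        let run = Run::parse();
        println!("cluster_id = {}", run.cluster_id);
    }
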
diff --git a/src/traft/mod.rs b/src/traft/mod.rs
index dd108472daf4603a164e97d5d36e7eed555d9675..2ca614480607f3b129068721aef1a82c1fa49a29 100644
--- a/src/traft/mod.rs
+++ b/src/traft/mod.rs
@@ -285,6 +285,7 @@ pub trait ContextCoercion: Serialize + DeserializeOwned {
 /// Request to join the cluster.
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct JoinRequest {
+    pub cluster_id: String,
     pub instance_id: String,
     pub replicaset_id: Option<String>,
     pub advertise_address: String,
diff --git a/src/traft/node.rs b/src/traft/node.rs
index 0465b51716db0c81bb9a07c2a8af14092172909a..b33906cc16021f18f4965fb5a91fea87f6a08d7a 100644
--- a/src/traft/node.rs
+++ b/src/traft/node.rs
@@ -48,6 +48,12 @@ pub enum Error {
     Timeout,
     #[error("{0}")]
     Raft(#[from] RaftError),
+    /// cluster_id of the joining peer does not match the cluster_id of the cluster
+    #[error("cannot join the instance to the cluster: cluster_id mismatch: cluster_id of the instance = {instance_cluster_id:?}, cluster_id of the cluster = {cluster_cluster_id:?}")]
+    ClusterIdMismatch {
+        instance_cluster_id: String,
+        cluster_cluster_id: String,
+    },
 }
 
 #[derive(Clone, Debug, tlua::Push, tlua::PushInto)]
@@ -708,6 +714,15 @@ fn raft_interact(pbs: Vec<traft::MessagePb>) -> Result<(), Box<dyn StdError>> {
 fn raft_join(req: JoinRequest) -> Result<JoinResponse, Box<dyn StdError>> {
     let node = global()?;
 
+    let cluster_id = Storage::cluster_id()?.ok_or("cluster_id is not set yet")?;
+
+    if req.cluster_id != cluster_id {
+        return Err(Box::new(Error::ClusterIdMismatch {
+            instance_cluster_id: req.cluster_id,
+            cluster_cluster_id: cluster_id,
+        }));
+    }
+
     let instance_id = req.instance_id.clone();
     node.join_one(req)?;
diff --git a/src/traft/storage.rs b/src/traft/storage.rs
index 16d84c995f2f335a8e259995565d8fa01526abb2..e5fd49c24c09e849f79246bb9170d681d186cdc7 100644
--- a/src/traft/storage.rs
+++ b/src/traft/storage.rs
@@ -200,6 +200,10 @@ impl Storage {
         Storage::raft_state("id")
     }
 
+    pub fn cluster_id() -> Result<Option<String>, StorageError> {
+        Storage::raft_state("cluster_id")
+    }
+
     /// Node generation i.e. the number of restarts.
     pub fn gen() -> Result<Option<u64>, StorageError> {
         Storage::raft_state("gen")
@@ -252,6 +256,15 @@ impl Storage {
         Ok(())
     }
 
+    pub fn persist_cluster_id(id: &str) -> Result<(), StorageError> {
+        Storage::space(RAFT_STATE)?
+            // We use `insert` instead of `replace` here
+            // because `cluster_id` should never be changed.
+            .insert(&("cluster_id", id))
+            .map_err(box_err!())?;
+        Ok(())
+    }
+
     pub fn persist_peer(peer: &traft::Peer) -> Result<(), StorageError> {
         Storage::space(RAFT_GROUP)?
             .replace(peer)
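The comment in `persist_cluster_id` marks the key design choice: Tarantool's `insert` fails on a duplicate primary key, whereas `replace` would silently overwrite, so the stored `cluster_id` is write-once for the lifetime of the instance's data directory. A self-contained sketch of the same write-once contract, using a `HashMap` as a stand-in for the `raft_state` space (illustrative only, not the crate's code):

    use std::collections::HashMap;

    /// Stand-in for `Storage::persist_cluster_id`: the first write succeeds,
    /// any later write is rejected instead of overwriting the stored value.
    fn persist_cluster_id(state: &mut HashMap<String, String>, id: &str) -> Result<(), String> {
        if let Some(existing) = state.get("cluster_id") {
            return Err(format!(
                "cluster_id is already set to {existing:?}, refusing to overwrite with {id:?}"
            ));
        }
        state.insert("cluster_id".into(), id.into());
        Ok(())
    }

    fn main() {
        let mut state = HashMap::new();
        assert!(persist_cluster_id(&mut state, "cluster1").is_ok());
        // Restarting with a different --cluster-id must not rewrite history:
        assert!(persist_cluster_id(&mut state, "cluster2").is_err());
    }
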
diff --git a/src/traft/topology.rs b/src/traft/topology.rs
index 71997947bd76399ba63cbd58acf47997a96c11ff..49f71bafdc3a9e52a89e48b239a21a4c2b3386c8 100644
--- a/src/traft/topology.rs
+++ b/src/traft/topology.rs
@@ -171,6 +171,7 @@ mod tests {
         $voter:literal
     ) => {
         &JoinRequest {
+            cluster_id: "cluster1".into(),
             instance_id: $instance_id.into(),
             replicaset_id: $replicaset_id.map(|v: &str| v.into()),
             advertise_address: $advertise_address.into(),
diff --git a/test/int/conftest.py b/test/int/conftest.py
index f78ee4af6bb49c2ff28861ca858a9f46016598b3..8cfc19c27deb55a4f1bc37eaed432e31674889f8 100644
--- a/test/int/conftest.py
+++ b/test/int/conftest.py
@@ -9,7 +9,8 @@ import signal
 import subprocess
 
 from shutil import rmtree
-from typing import Generator
+from typing import Generator, Iterator
+from itertools import count
 from pathlib import Path
 from contextlib import contextmanager, suppress
 from dataclasses import dataclass, field
@@ -128,7 +129,7 @@ OUT_LOCK = threading.Lock()
 @dataclass
 class Instance:
     binary_path: str
-
+    cluster_id: str
     instance_id: str
     data_dir: str
     peers: list[str]
@@ -154,6 +155,7 @@ class Instance:
         # fmt: off
         return [
             self.binary_path, "run",
+            "--cluster-id", self.cluster_id,
            "--instance-id", self.instance_id,
            "--data-dir", self.data_dir,
            "--listen", self.listen,
@@ -324,7 +326,7 @@ class Instance:
 @dataclass
 class Cluster:
     binary_path: str
-
+    id: str
     data_dir: str
     base_host: str
     base_port: int
@@ -357,6 +359,7 @@ class Cluster:
         instance = Instance(
             binary_path=self.binary_path,
+            cluster_id=self.id,
             instance_id=f"i{i}",
             data_dir=f"{self.data_dir}/i{i}",
             host=self.base_host,
@@ -401,8 +404,18 @@ def binary_path(compile) -> str:
     return os.path.realpath(Path(__file__) / "../../../target/debug/picodata")
 
 
+@pytest.fixture(scope="session")
+def cluster_ids(xdist_worker_number) -> Iterator[str]:
+    return (f"cluster-{xdist_worker_number}-{i}" for i in count())
+
+
 @pytest.fixture
-def cluster(binary_path, tmpdir, xdist_worker_number) -> Generator[Cluster, None, None]:
+def cluster(
+    binary_path,
+    tmpdir,
+    xdist_worker_number,
+    cluster_ids,
+) -> Generator[Cluster, None, None]:
     n = xdist_worker_number
     assert isinstance(n, int)
     assert n >= 0
@@ -414,6 +427,7 @@ def cluster(binary_path, tmpdir, xdist_worker_number) -> Generator[Cluster, None
 
     cluster = Cluster(
         binary_path=binary_path,
+        id=next(cluster_ids),
         data_dir=tmpdir,
         base_host="127.0.0.1",
         base_port=base_port,
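The session-scoped `cluster_ids` fixture hands each test a fresh id from an infinite generator, so clusters created by different tests, and by different pytest-xdist workers, can never collide on cluster_id; that is what lets the mismatch check be tested in isolation below. The same idea expressed as a Rust iterator, purely for illustration:

    /// Infinite stream of unique cluster ids for a given xdist worker,
    /// mirroring the `cluster_ids` pytest fixture above.
    fn cluster_ids(worker: usize) -> impl Iterator<Item = String> {
        (0..).map(move |i| format!("cluster-{worker}-{i}"))
    }

    fn main() {
        let mut ids = cluster_ids(0);
        assert_eq!(ids.next().as_deref(), Some("cluster-0-0"));
        assert_eq!(ids.next().as_deref(), Some("cluster-0-1"));
    }
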
diff --git a/test/int/test_joining.py b/test/int/test_joining.py
index 7beee863d17e17585fc1412414c0e4dc624a6291..56a9f8d982948bcec7c0fb19a49558fa76bb2b2b 100644
--- a/test/int/test_joining.py
+++ b/test/int/test_joining.py
@@ -1,5 +1,7 @@
+from functools import partial
 import os
 import errno
+import re
 import signal
 
 import pytest
@@ -25,21 +27,23 @@ def cluster3(cluster: Cluster):
     return cluster
 
 
-def raft_join(peer: Instance, id: str, timeout: float):
-    instance_id = f"{id}"
+def raft_join(
+    peer: Instance, cluster_id: str, instance_id: str, timeout_seconds: float | int
+):
     replicaset_id = None
     # Workaround slow address resolving. Intentionally use
     # invalid address format to eliminate blocking DNS requests.
     # See https://git.picodata.io/picodata/picodata/tarantool-module/-/issues/81
-    address = f"nowhere/{id}"
+    address = f"nowhere/{instance_id}"
     is_voter = False
 
     return peer.call(
         ".raft_join",
+        cluster_id,
         instance_id,
         replicaset_id,
         address,
         is_voter,
-        timeout=timeout,
+        timeout=timeout_seconds,
     )
 
 
@@ -54,14 +58,18 @@ def test_concurrency(cluster2: Cluster):
 
     # First request blocks the `join_loop` until i2 is resumed.
     with pytest.raises(OSError) as e0:
-        raft_join(i1, "fake-0", timeout=0.1)
+        raft_join(
+            peer=i1, cluster_id=cluster2.id, instance_id="fake-0", timeout_seconds=0.1
+        )
     assert e0.value.errno == errno.ECONNRESET
 
     # Subsequent requests get batched
     executor = ThreadPoolExecutor()
-    f1 = executor.submit(raft_join, i1, "fake-1", timeout=5)
-    f2 = executor.submit(raft_join, i1, "fake-2", timeout=5)
-    f3 = executor.submit(raft_join, i1, "fake-3", timeout=0.1)
+    submit_join = partial(executor.submit, raft_join, peer=i1, cluster_id=cluster2.id)
+
+    f1 = submit_join(instance_id="fake-1", timeout_seconds=5)
+    f2 = submit_join(instance_id="fake-2", timeout_seconds=5)
+    f3 = submit_join(instance_id="fake-3", timeout_seconds=0.1)
 
     # Make sure all requests reach the server before resuming i2.
     with pytest.raises(OSError) as e1:
@@ -72,8 +80,8 @@ def test_concurrency(cluster2: Cluster):
     os.killpg(i2.process.pid, signal.SIGCONT)
     eprint(f"{i2} signalled with SIGCONT")
 
-    peer1 = f1.result()[0]["peer"]
-    peer2 = f2.result()[0]["peer"]
+    peer1 = f1.result()[0]["peer"]  # type: ignore
+    peer2 = f2.result()[0]["peer"]  # type: ignore
     assert peer1["instance_id"] == "fake-1"
     assert peer2["instance_id"] == "fake-2"
     # Make sure the batching works as expected
@@ -85,7 +93,9 @@ def test_request_follower(cluster2: Cluster):
     i2.assert_raft_status("Follower")
 
     with pytest.raises(TarantoolError) as e:
-        raft_join(i2, "fake-0", timeout=1)
+        raft_join(
+            peer=i2, cluster_id=cluster2.id, instance_id="fake-0", timeout_seconds=1
+        )
 
     assert e.value.args == ("ER_PROC_C", "not a leader")
 
@@ -95,6 +105,7 @@ def test_uuids(cluster2: Cluster):
 
     peer_1 = i1.call(
         ".raft_join",
+        cluster2.id,
         i1.instance_id,
         None,  # replicaset_id
         i1.listen,  # address
@@ -106,6 +117,7 @@ def test_uuids(cluster2: Cluster):
 
     peer_2 = i1.call(
         ".raft_join",
+        cluster2.id,
         i2.instance_id,
         None,  # replicaset_id
         i2.listen,  # address
@@ -115,9 +127,18 @@ def test_uuids(cluster2: Cluster):
     assert peer_2["instance_uuid"] == i2.eval("return box.info.uuid")
     assert peer_2["replicaset_uuid"] == i2.eval("return box.info.cluster.uuid")
 
+    def join():
+        return raft_join(
+            peer=i1,
+            cluster_id=cluster2.id,
+            instance_id="fake",
+            timeout_seconds=1,
+        )
+
     # Two consequent requests must obtain same raft_id and instance_id
-    fake_peer_1 = raft_join(i1, "fake", timeout=1)[0]["peer"]
-    fake_peer_2 = raft_join(i1, "fake", timeout=1)[0]["peer"]
+    fake_peer_1 = join()[0]["peer"]
+    fake_peer_2 = join()[0]["peer"]
+
     assert fake_peer_1["instance_id"] == "fake"
     assert fake_peer_2["instance_id"] == "fake"
     assert fake_peer_1["raft_id"] == fake_peer_2["raft_id"]
"{instance.cluster_id}"' + ) + + with pytest.raises(TarantoolError, match=expected_error_re): + raft_join( + peer=instance, + cluster_id=wrong_cluster_id, + instance_id="whatever", + timeout_seconds=1, + )