From f8ac1dbe959d01ee25e5447c5eab86a061ee16f4 Mon Sep 17 00:00:00 2001
From: Sergey V <sv@picodata.io>
Date: Mon, 30 May 2022 20:00:40 +0300
Subject: [PATCH] feat: --cluster-id parameter

* Make the `--cluster-id` CLI parameter mandatory.
* Handle cluster_id mismatch in raft_join.
  When an instance attempts to join the cluster and the instance's
  `--cluster-id` parameter mismatches the cluster_id of the cluster,
  an error is raised inside the raft_join handler.
---
 src/args.rs              |  3 +-
 src/main.rs              |  5 +++
 src/traft/mod.rs         |  1 +
 src/traft/node.rs        | 15 +++++++++
 src/traft/storage.rs     | 13 ++++++++
 src/traft/topology.rs    |  1 +
 test/int/conftest.py     | 22 ++++++++++---
 test/int/test_joining.py | 67 ++++++++++++++++++++++++++++++++--------
 8 files changed, 109 insertions(+), 18 deletions(-)

diff --git a/src/args.rs b/src/args.rs
index 248062a2af..6fce8296ed 100644
--- a/src/args.rs
+++ b/src/args.rs
@@ -24,7 +24,7 @@ pub enum Picodata {
 pub struct Run {
     #[clap(long, value_name = "name", env = "PICODATA_CLUSTER_ID")]
     /// Name of the cluster
-    pub cluster_id: Option<String>,
+    pub cluster_id: String,
 
     #[clap(
         long,
@@ -268,6 +268,7 @@ mod tests {
     fn test_parse() {
         let _env_dump = EnvDump::new();
 
+        std::env::set_var("PICODATA_CLUSTER_ID", "cluster1");
         std::env::set_var("PICODATA_INSTANCE_ID", "instance-id-from-env");
         std::env::set_var("PICODATA_PEER", "peer-from-env");
         {
diff --git a/src/main.rs b/src/main.rs
index 4615c1e78d..61cda8e237 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -411,6 +411,7 @@ fn start_boot(args: &args::Run) {
 
     let mut topology = traft::Topology::from_peers(vec![]);
     let req = traft::JoinRequest {
+        cluster_id: args.cluster_id.clone(),
         instance_id: args.instance_id.clone(),
         replicaset_id: args.replicaset_id.clone(),
         advertise_address: args.advertise_address(),
@@ -469,6 +470,7 @@ fn start_boot(args: &args::Run) {
         traft::Storage::persist_commit(1).unwrap();
         traft::Storage::persist_term(1).unwrap();
         traft::Storage::persist_id(raft_id).unwrap();
+        traft::Storage::persist_cluster_id(&args.cluster_id).unwrap();
         Ok(())
     })
     .unwrap();
@@ -480,6 +482,7 @@ fn start_join(args: &args::Run, leader_address: String) {
     tlog!(Info, ">>>>> start_join({leader_address})");
 
     let req = traft::JoinRequest {
+        cluster_id: args.cluster_id.clone(),
         instance_id: args.instance_id.clone(),
         replicaset_id: args.replicaset_id.clone(),
         voter: false,
@@ -517,6 +520,7 @@ fn start_join(args: &args::Run, leader_address: String) {
             traft::Storage::persist_peer(&peer).unwrap();
         }
         traft::Storage::persist_id(raft_id).unwrap();
+        traft::Storage::persist_cluster_id(&args.cluster_id).unwrap();
         Ok(())
     })
     .unwrap();
@@ -589,6 +593,7 @@ fn postjoin(args: &args::Run) {
 
         tlog!(Warning, "initiating self-promotion of {me:?}");
         let req = traft::JoinRequest {
+            cluster_id: args.cluster_id.clone(),
             instance_id: me.instance_id.clone(),
             replicaset_id: None, // TODO
             voter: true,
diff --git a/src/traft/mod.rs b/src/traft/mod.rs
index dd108472da..2ca6144806 100644
--- a/src/traft/mod.rs
+++ b/src/traft/mod.rs
@@ -285,6 +285,7 @@ pub trait ContextCoercion: Serialize + DeserializeOwned {
 /// Request to join the cluster.
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct JoinRequest {
+    pub cluster_id: String,
     pub instance_id: String,
     pub replicaset_id: Option<String>,
     pub advertise_address: String,
diff --git a/src/traft/node.rs b/src/traft/node.rs
index 0465b51716..b33906cc16 100644
--- a/src/traft/node.rs
+++ b/src/traft/node.rs
@@ -48,6 +48,12 @@ pub enum Error {
     Timeout,
     #[error("{0}")]
     Raft(#[from] RaftError),
+    /// cluster_id of the joining peer mismatches the cluster_id of the cluster
+    #[error("cannot join the instance to the cluster: cluster_id mismatch: cluster_id of the instance = {instance_cluster_id:?}, cluster_id of the cluster = {cluster_cluster_id:?}")]
+    ClusterIdMismatch {
+        instance_cluster_id: String,
+        cluster_cluster_id: String,
+    },
 }
 
 #[derive(Clone, Debug, tlua::Push, tlua::PushInto)]
@@ -708,6 +714,15 @@ fn raft_interact(pbs: Vec<traft::MessagePb>) -> Result<(), Box<dyn StdError>> {
 fn raft_join(req: JoinRequest) -> Result<JoinResponse, Box<dyn StdError>> {
     let node = global()?;
 
+    let cluster_id = Storage::cluster_id()?.ok_or("cluster_id is not set yet")?;
+
+    if req.cluster_id != cluster_id {
+        return Err(Box::new(Error::ClusterIdMismatch {
+            instance_cluster_id: req.cluster_id,
+            cluster_cluster_id: cluster_id,
+        }));
+    }
+
     let instance_id = req.instance_id.clone();
     node.join_one(req)?;
 
diff --git a/src/traft/storage.rs b/src/traft/storage.rs
index 16d84c995f..e5fd49c24c 100644
--- a/src/traft/storage.rs
+++ b/src/traft/storage.rs
@@ -200,6 +200,10 @@ impl Storage {
         Storage::raft_state("id")
     }
 
+    pub fn cluster_id() -> Result<Option<String>, StorageError> {
+        Storage::raft_state("cluster_id")
+    }
+
     /// Node generation i.e. the number of restarts.
     pub fn gen() -> Result<Option<u64>, StorageError> {
         Storage::raft_state("gen")
@@ -252,6 +256,15 @@ impl Storage {
         Ok(())
     }
 
+    pub fn persist_cluster_id(id: &str) -> Result<(), StorageError> {
+        Storage::space(RAFT_STATE)?
+            // We use `insert` instead of `replace` here
+            // because `cluster_id` should never be changed.
+            .insert(&("cluster_id", id))
+            .map_err(box_err!())?;
+        Ok(())
+    }
+
     pub fn persist_peer(peer: &traft::Peer) -> Result<(), StorageError> {
         Storage::space(RAFT_GROUP)?
             .replace(peer)
diff --git a/src/traft/topology.rs b/src/traft/topology.rs
index 71997947bd..49f71bafdc 100644
--- a/src/traft/topology.rs
+++ b/src/traft/topology.rs
@@ -171,6 +171,7 @@ mod tests {
             $voter:literal
         ) => {
             &JoinRequest {
+                cluster_id: "cluster1".into(),
                 instance_id: $instance_id.into(),
                 replicaset_id: $replicaset_id.map(|v: &str| v.into()),
                 advertise_address: $advertise_address.into(),
diff --git a/test/int/conftest.py b/test/int/conftest.py
index f78ee4af6b..8cfc19c27d 100644
--- a/test/int/conftest.py
+++ b/test/int/conftest.py
@@ -9,7 +9,8 @@ import signal
 import subprocess
 
 from shutil import rmtree
-from typing import Generator
+from typing import Generator, Iterator
+from itertools import count
 from pathlib import Path
 from contextlib import contextmanager, suppress
 from dataclasses import dataclass, field
@@ -128,7 +129,7 @@ OUT_LOCK = threading.Lock()
 @dataclass
 class Instance:
     binary_path: str
-
+    cluster_id: str
     instance_id: str
     data_dir: str
     peers: list[str]
@@ -154,6 +155,7 @@ class Instance:
         # fmt: off
         return [
             self.binary_path, "run",
+            "--cluster-id", self.cluster_id,
             "--instance-id", self.instance_id,
             "--data-dir", self.data_dir,
             "--listen", self.listen,
@@ -324,7 +326,7 @@ class Instance:
 @dataclass
 class Cluster:
     binary_path: str
-
+    id: str
     data_dir: str
     base_host: str
     base_port: int
@@ -357,6 +359,7 @@ class Cluster:
 
         instance = Instance(
             binary_path=self.binary_path,
+            cluster_id=self.id,
             instance_id=f"i{i}",
             data_dir=f"{self.data_dir}/i{i}",
             host=self.base_host,
@@ -401,8 +404,18 @@ def binary_path(compile) -> str:
     return os.path.realpath(Path(__file__) / "../../../target/debug/picodata")
 
 
+@pytest.fixture(scope="session")
+def cluster_ids(xdist_worker_number) -> Iterator[str]:
+    return (f"cluster-{xdist_worker_number}-{i}" for i in count())
+
+
 @pytest.fixture
-def cluster(binary_path, tmpdir, xdist_worker_number) -> Generator[Cluster, None, None]:
+def cluster(
+    binary_path,
+    tmpdir,
+    xdist_worker_number,
+    cluster_ids,
+) -> Generator[Cluster, None, None]:
     n = xdist_worker_number
     assert isinstance(n, int)
     assert n >= 0
@@ -414,6 +427,7 @@ def cluster(binary_path, tmpdir, xdist_worker_number) -> Generator[Cluster, None
 
     cluster = Cluster(
         binary_path=binary_path,
+        id=next(cluster_ids),
         data_dir=tmpdir,
         base_host="127.0.0.1",
         base_port=base_port,
diff --git a/test/int/test_joining.py b/test/int/test_joining.py
index 7beee863d1..56a9f8d982 100644
--- a/test/int/test_joining.py
+++ b/test/int/test_joining.py
@@ -1,5 +1,7 @@
+from functools import partial
 import os
 import errno
+import re
 import signal
 import pytest
 
@@ -25,21 +27,23 @@ def cluster3(cluster: Cluster):
     return cluster
 
 
-def raft_join(peer: Instance, id: str, timeout: float):
-    instance_id = f"{id}"
+def raft_join(
+    peer: Instance, cluster_id: str, instance_id: str, timeout_seconds: float | int
+):
     replicaset_id = None
     # Workaround slow address resolving. Intentionally use
     # invalid address format to eliminate blocking DNS requests.
     # See https://git.picodata.io/picodata/picodata/tarantool-module/-/issues/81
-    address = f"nowhere/{id}"
+    address = f"nowhere/{instance_id}"
     is_voter = False
     return peer.call(
         ".raft_join",
+        cluster_id,
         instance_id,
         replicaset_id,
         address,
         is_voter,
-        timeout=timeout,
+        timeout=timeout_seconds,
     )
 
 
@@ -54,14 +58,18 @@ def test_concurrency(cluster2: Cluster):
 
     # First request blocks the `join_loop` until i2 is resumed.
     with pytest.raises(OSError) as e0:
-        raft_join(i1, "fake-0", timeout=0.1)
+        raft_join(
+            peer=i1, cluster_id=cluster2.id, instance_id="fake-0", timeout_seconds=0.1
+        )
     assert e0.value.errno == errno.ECONNRESET
 
     # Subsequent requests get batched
     executor = ThreadPoolExecutor()
-    f1 = executor.submit(raft_join, i1, "fake-1", timeout=5)
-    f2 = executor.submit(raft_join, i1, "fake-2", timeout=5)
-    f3 = executor.submit(raft_join, i1, "fake-3", timeout=0.1)
+    submit_join = partial(executor.submit, raft_join, peer=i1, cluster_id=cluster2.id)
+
+    f1 = submit_join(instance_id="fake-1", timeout_seconds=5)
+    f2 = submit_join(instance_id="fake-2", timeout_seconds=5)
+    f3 = submit_join(instance_id="fake-3", timeout_seconds=0.1)
 
     # Make sure all requests reach the server before resuming i2.
     with pytest.raises(OSError) as e1:
@@ -72,8 +80,8 @@ def test_concurrency(cluster2: Cluster):
     os.killpg(i2.process.pid, signal.SIGCONT)
     eprint(f"{i2} signalled with SIGCONT")
 
-    peer1 = f1.result()[0]["peer"]
-    peer2 = f2.result()[0]["peer"]
+    peer1 = f1.result()[0]["peer"]  # type: ignore
+    peer2 = f2.result()[0]["peer"]  # type: ignore
     assert peer1["instance_id"] == "fake-1"
     assert peer2["instance_id"] == "fake-2"
     # Make sure the batching works as expected
@@ -85,7 +93,9 @@ def test_request_follower(cluster2: Cluster):
     i2.assert_raft_status("Follower")
 
     with pytest.raises(TarantoolError) as e:
-        raft_join(i2, "fake-0", timeout=1)
+        raft_join(
+            peer=i2, cluster_id=cluster2.id, instance_id="fake-0", timeout_seconds=1
+        )
     assert e.value.args == ("ER_PROC_C", "not a leader")
 
 
@@ -95,6 +105,7 @@ def test_uuids(cluster2: Cluster):
 
     peer_1 = i1.call(
         ".raft_join",
+        cluster2.id,
         i1.instance_id,
         None,  # replicaset_id
         i1.listen,  # address
@@ -106,6 +117,7 @@ def test_uuids(cluster2: Cluster):
 
     peer_2 = i1.call(
         ".raft_join",
+        cluster2.id,
         i2.instance_id,
         None,  # replicaset_id
         i2.listen,  # address
@@ -115,9 +127,18 @@ def test_uuids(cluster2: Cluster):
     assert peer_2["instance_uuid"] == i2.eval("return box.info.uuid")
     assert peer_2["replicaset_uuid"] == i2.eval("return box.info.cluster.uuid")
 
+    def join():
+        return raft_join(
+            peer=i1,
+            cluster_id=cluster2.id,
+            instance_id="fake",
+            timeout_seconds=1,
+        )
+
     # Two consequent requests must obtain same raft_id and instance_id
-    fake_peer_1 = raft_join(i1, "fake", timeout=1)[0]["peer"]
-    fake_peer_2 = raft_join(i1, "fake", timeout=1)[0]["peer"]
+    fake_peer_1 = join()[0]["peer"]
+    fake_peer_2 = join()[0]["peer"]
+
     assert fake_peer_1["instance_id"] == "fake"
     assert fake_peer_2["instance_id"] == "fake"
     assert fake_peer_1["raft_id"] == fake_peer_2["raft_id"]
@@ -207,3 +228,23 @@ def test_replication(cluster2: Cluster):
                 assert cfg_replication[0] == [i1.listen, i2.listen]
         else:
             assert cfg_replication[0] == [i1.listen, i2.listen]
+
+
+def test_cluster_id_mismatch(instance: Instance):
+    wrong_cluster_id = "wrong-cluster-id"
+
+    assert wrong_cluster_id != instance.cluster_id
+
+    expected_error_re = re.escape(
+        "cannot join the instance to the cluster: cluster_id mismatch:"
+        ' cluster_id of the instance = "wrong-cluster-id",'
+        f' cluster_id of the cluster = "{instance.cluster_id}"'
+    )
+
+    with pytest.raises(TarantoolError, match=expected_error_re):
+        raft_join(
+            peer=instance,
+            cluster_id=wrong_cluster_id,
+            instance_id="whatever",
+            timeout_seconds=1,
+        )
-- 
GitLab