From c62505e3d0530e07f723cb4f123f284d0b9dee19 Mon Sep 17 00:00:00 2001
From: Georgy Moshkin <gmoshkin@picodata.io>
Date: Mon, 4 Jul 2022 19:42:06 +0300
Subject: [PATCH] feat: implement failure domains

---
 src/traft/mod.rs      | 22 ++++++++++++++++
 src/traft/topology.rs | 61 ++++++++++++++++++++++++++++++++++++-------
 src/util.rs           | 12 +++++++++
 3 files changed, 85 insertions(+), 10 deletions(-)

diff --git a/src/traft/mod.rs b/src/traft/mod.rs
index 0df2c296d7..17e216882a 100644
--- a/src/traft/mod.rs
+++ b/src/traft/mod.rs
@@ -532,6 +532,28 @@ pub struct FailureDomains {
     data: HashMap<Uppercase, Uppercase>,
 }
 
+impl FailureDomains {
+    pub fn contains_name(&self, name: &Uppercase) -> bool {
+        self.data.contains_key(name)
+    }
+
+    pub fn names(&self) -> std::collections::hash_map::Keys<Uppercase, Uppercase> {
+        self.data.keys()
+    }
+
+    pub fn intersects(&self, other: &Self) -> bool {
+        for (name, value) in &self.data {
+            match other.data.get(name) {
+                Some(other_value) if value == other_value => {
+                    return true;
+                }
+                _ => {}
+            }
+        }
+        false
+    }
+}
+
 impl<I, K, V> From<I> for FailureDomains
 where
     I: IntoIterator<Item = (K, V)>,
diff --git a/src/traft/topology.rs b/src/traft/topology.rs
index a5a7089b4e..23d29ff7be 100644
--- a/src/traft/topology.rs
+++ b/src/traft/topology.rs
@@ -6,6 +6,7 @@ use crate::traft::FailureDomains;
 use crate::traft::Health;
 use crate::traft::Peer;
 use crate::traft::{InstanceId, RaftId, ReplicasetId};
+use crate::util::Uppercase;
 
 use raft::INVALID_INDEX;
 
@@ -13,6 +14,7 @@ pub struct Topology {
     replication_factor: u8,
     max_raft_id: RaftId,
 
+    failure_domain_names: HashSet<Uppercase>,
     instance_map: HashMap<InstanceId, Peer>,
     replicaset_map: BTreeMap<ReplicasetId, HashSet<InstanceId>>,
 }
@@ -22,6 +24,7 @@ impl Topology {
         let mut ret = Self {
             replication_factor: 2,
             max_raft_id: 0,
+            failure_domain_names: Default::default(),
             instance_map: Default::default(),
             replicaset_map: Default::default(),
         };
@@ -51,6 +54,8 @@ impl Topology {
                 .remove(&old_peer.instance_id);
         }
 
+        self.failure_domain_names
+            .extend(peer.failure_domains.names().cloned());
         self.instance_map.insert(instance_id.clone(), peer);
         self.replicaset_map
             .entry(replicaset_id)
@@ -64,10 +69,14 @@ impl Topology {
     }
 
     fn choose_replicaset_id(&self, failure_domains: &FailureDomains) -> String {
-        // TODO: implement logic
-        let _ = failure_domains;
-        for (replicaset_id, peers) in self.replicaset_map.iter() {
+        'next_replicaset: for (replicaset_id, peers) in self.replicaset_map.iter() {
             if peers.len() < self.replication_factor as usize {
+                for peer_id in peers {
+                    let peer = self.instance_map.get(peer_id).unwrap();
+                    if peer.failure_domains.intersects(failure_domains) {
+                        continue 'next_replicaset;
+                    }
+                }
                 return replicaset_id.clone();
             }
         }
@@ -82,6 +91,22 @@ impl Topology {
         }
     }
 
+    pub fn check_required_failure_domains(&self, fd: &FailureDomains) -> Result<(), String> {
+        let mut res = Vec::new();
+        for domain_name in &self.failure_domain_names {
+            if !fd.contains_name(domain_name) {
+                res.push(domain_name.to_string());
+            }
+        }
+
+        if res.is_empty() {
+            return Ok(());
+        }
+
+        res.sort();
+        Err(format!("missing failure domain names: {}", res.join(", ")))
+    }
+
     pub fn join(
         &mut self,
         instance_id: Option<String>,
@@ -98,6 +123,8 @@ impl Topology {
             }
         }
 
+        self.check_required_failure_domains(&failure_domains)?;
+
         // Anyway, `join` always produces a new raft_id.
         let raft_id = self.max_raft_id + 1;
         let instance_id: String = instance_id.unwrap_or_else(|| self.choose_instance_id(raft_id));
@@ -244,8 +271,9 @@ mod tests {
     }
 
     macro_rules! faildoms {
-        ($($k:tt : $v:tt),* $(,)?) => {
-            FailureDomains::from([$((stringify!($k), stringify!($v))),*])
+        ($(,)?) => { FailureDomains::default() };
+        ($($k:tt : $v:tt),+ $(,)?) => {
+            FailureDomains::from([$((stringify!($k), stringify!($v))),+])
         }
     }
 
@@ -400,7 +428,6 @@ mod tests {
     }
 
     #[test]
-    #[should_panic]
     fn failure_domains() {
         let mut t = Topology::from_peers(peers![]).with_replication_factor(3);
 
@@ -441,9 +468,9 @@ mod tests {
 
         assert_eq!(
             join!(t, None, None, "-", faildoms! {os: Arch})
-                .unwrap()
-                .replicaset_id,
-            "r2",
+                .unwrap_err()
+                .to_string(),
+            "missing failure domain names: PLANET",
         );
 
         assert_eq!(
@@ -454,10 +481,24 @@ mod tests {
         );
 
         assert_eq!(
-            join!(t, None, None, "-", faildoms! {os: Mac})
+            join!(t, None, None, "-", faildoms! {planet: Venus, os: Mac})
+                .unwrap()
+                .replicaset_id,
+            "r2",
+        );
+
+        assert_eq!(
+            join!(t, None, None, "-", faildoms! {planet: Mars, os: Mac})
                 .unwrap()
                 .replicaset_id,
             "r3",
         );
+
+        assert_eq!(
+            join!(t, None, None, "-", faildoms! {})
+                .unwrap_err()
+                .to_string(),
+            "missing failure domain names: OS, PLANET",
+        );
     }
 }
diff --git a/src/util.rs b/src/util.rs
index f8ba7dcbfd..d5ea903a89 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -95,6 +95,18 @@ impl std::ops::Deref for Uppercase {
     }
 }
 
+impl std::fmt::Display for Uppercase {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        std::fmt::Display::fmt(&self.0, f)
+    }
+}
+
+impl std::borrow::Borrow<str> for Uppercase {
+    fn borrow(&self) -> &str {
+        &*self.0
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
-- 
GitLab