From a22e541260bac9ba4ba4f9797ce7805e9e7f990a Mon Sep 17 00:00:00 2001
From: Georgy Moshkin <gmoshkin@picodata.io>
Date: Mon, 4 Jul 2022 13:25:49 +0300
Subject: [PATCH] feat: failure domains are always capitalized

---
 docs/topology.md |  8 ++++-
 src/args.rs      | 38 +++++++++++----------
 src/main.rs      | 11 ++-----
 src/traft/mod.rs | 36 ++++++++++++++++++--
 src/util.rs      | 86 ++++++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 150 insertions(+), 29 deletions(-)

diff --git a/docs/topology.md b/docs/topology.md
index 9d9e78fcb5..32357ca4d9 100644
--- a/docs/topology.md
+++ b/docs/topology.md
@@ -106,7 +106,7 @@ picodata run --init-replication-factor 2 --failure-domain region=us,zone=us-west
 
 Добавление инстанса в репликасет происходит по следующим правилам:
 
-- Если в каком-либо репликасете количество инстансов меньше необходимого фактора репликации, то новый инстанс добавляется в него при условии, что их параметры `--failure-domain` отличаются.
+- Если в каком-либо репликасете количество инстансов меньше необходимого фактора репликации, то новый инстанс добавляется в него при условии, что их параметры `--failure-domain` отличаются (регистр символов не учитывается).
 - Если подходящих репликасетов нет, то Picodata создает новый репликасет.
 
 Параметр `--failure-domain` играет роль только в момент добавления инстанса в кластер. **Принадлежность инстанса репликасету впоследствии не меняется**.
@@ -118,6 +118,12 @@ picodata run --init-replication-factor 2 --failure-domain region=us,zone=us-west
 
 Добавляемый инстанс должен обладать тем же набором параметров, которые уже есть в кластере. Например, инстанс `dc=msk` не сможет присоединиться к кластеру с `--failure-domain region=eu/us` и вернет ошибку.
 
+Как говорилось выше, сравнение доменов отказоустойчивости производится не
+учитывая регистр символов, таким образом два инстанса с аргументами
+`--failure-domain region=us` и `--failure-domain REGION=US` относятся к одному
+региону и следовательно не будут добавлены в один репликасет (за исключением
+случаев описанных ниже).
+
 ## Кейс: два датацентра по две реплики
 
 Picodata старается не объединять в один репликасет инстансы, у которых совпадает хотя бы один домен. Но иногда это все же необходимо. Чтобы ограничить Picodata в бесконечном создании репликасетов, можно воспользоваться флагом `--max-replicaset-count` (по умолчанию `inf`).
diff --git a/src/args.rs b/src/args.rs
index dac60d4276..96d4d9c709 100644
--- a/src/args.rs
+++ b/src/args.rs
@@ -1,13 +1,15 @@
 use clap::Parser;
 use std::{
     borrow::Cow,
-    collections::HashMap,
     ffi::{CStr, CString},
 };
 use tarantool::log::SayLevel;
 use tarantool::tlua::{self, c_str};
 use thiserror::Error;
 
+use crate::traft::FailureDomains;
+use crate::util::Uppercase;
+
 #[derive(Debug, Parser)]
 #[clap(name = "picodata", version = env!("CARGO_PKG_VERSION"))]
 pub enum Picodata {
@@ -99,7 +101,7 @@ pub struct Run {
         value_name = "key=value",
         require_value_delimiter = true,
         use_value_delimiter = true,
-        parse(try_from_str = try_parse_kv),
+        parse(try_from_str = try_parse_kv_uppercase),
         env = "PICODATA_FAILURE_DOMAIN"
     )]
     /// Comma-separated list describing physical location of the server.
@@ -109,7 +111,7 @@ pub struct Run {
     /// same value. Instead, new replicasets will be created.
     /// Replicasets will be populated with instances from different
     /// failure domains until the desired replication factor is reached.
-    pub failure_domains: Vec<(String, String)>,
+    pub failure_domains: Vec<(Uppercase, Uppercase)>,
 
     #[clap(long, value_name = "name", env = "PICODATA_REPLICASET_ID")]
     /// Name of the replicaset
@@ -179,12 +181,12 @@ impl Run {
         }
     }
 
-    pub fn failure_domains(&self) -> HashMap<&str, &str> {
-        let mut ret = HashMap::new();
-        for (k, v) in &self.failure_domains {
-            ret.insert(k.as_ref(), v.as_ref());
-        }
-        ret
+    pub fn failure_domains(&self) -> FailureDomains {
+        FailureDomains::from(
+            self.failure_domains
+                .iter()
+                .map(|(k, v)| (k.clone(), v.clone())),
+        )
     }
 }
 
@@ -263,11 +265,13 @@ fn try_parse_address(text: &str) -> Result<String, ParseAddressError> {
     Ok(format!("{host}:{port}"))
 }
 
-fn try_parse_kv(s: &str) -> Result<(String, String), String> {
-    let pos = s
-        .find('=')
+/// Parses a '=' sepparated string of key and value and converts both to
+/// uppercase.
+fn try_parse_kv_uppercase(s: &str) -> Result<(Uppercase, Uppercase), String> {
+    let (key, value) = s
+        .split_once('=')
         .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{}`", s))?;
-    Ok((s[..pos].into(), s[pos + 1..].into()))
+    Ok((key.into(), value.into()))
 }
 
 #[cfg(test)]
@@ -333,7 +337,7 @@ mod tests {
             assert_eq!(parsed.listen, "localhost:3301"); // default
             assert_eq!(parsed.advertise_address(), "localhost:3301"); // default
             assert_eq!(parsed.log_level(), SayLevel::Info); // default
-            assert_eq!(parsed.failure_domains(), HashMap::new()); // default
+            assert_eq!(parsed.failure_domains(), FailureDomains::default()); // default
 
             let parsed = parse![Run, "--instance-id", "instance-id-from-args"];
             assert_eq!(
@@ -412,13 +416,13 @@ mod tests {
             let parsed = parse![Run,];
             assert_eq!(
                 parsed.failure_domains(),
-                HashMap::from([("k1", "env1"), ("k2", "env2")])
+                FailureDomains::from([("K1", "ENV1"), ("K2", "ENV2")])
             );
 
             let parsed = parse![Run, "--failure-domain", "k1=arg1,k1=arg1-again"];
             assert_eq!(
                 parsed.failure_domains(),
-                HashMap::from([("k1", "arg1-again")])
+                FailureDomains::from([("K1", "ARG1-AGAIN")])
             );
 
             let parsed = parse![
@@ -430,7 +434,7 @@ mod tests {
             ];
             assert_eq!(
                 parsed.failure_domains(),
-                HashMap::from([("k2", "arg2"), ("k3", "arg3"), ("k4", "arg4")])
+                FailureDomains::from([("K2", "ARG2"), ("K3", "ARG3"), ("K4", "ARG4")])
             );
         }
     }
diff --git a/src/main.rs b/src/main.rs
index acea68b8be..bd052fee06 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -420,10 +420,7 @@ fn start_boot(args: &args::Run) {
         args.instance_id(),
         args.replicaset_id.clone(),
         args.advertise_address(),
-        args.failure_domains()
-            .into_iter()
-            .map(|(k, v)| (k.into(), v.into()))
-            .collect(),
+        args.failure_domains(),
     );
     let raft_id = peer.raft_id;
     let instance_id = peer.instance_id.clone();
@@ -505,11 +502,7 @@ fn start_join(args: &args::Run, leader_address: String) {
         instance_id: args.instance_id(),
         replicaset_id: args.replicaset_id.clone(),
         advertise_address: args.advertise_address(),
-        failure_domains: args
-            .failure_domains()
-            .into_iter()
-            .map(|(k, v)| (k.into(), v.into()))
-            .collect(),
+        failure_domains: args.failure_domains(),
     };
 
     let fn_name = stringify_cfunc!(traft::node::raft_join);
diff --git a/src/traft/mod.rs b/src/traft/mod.rs
index 4e98940b98..3aa5001ed6 100644
--- a/src/traft/mod.rs
+++ b/src/traft/mod.rs
@@ -8,6 +8,7 @@ pub mod node;
 mod storage;
 pub mod topology;
 
+use crate::util::Uppercase;
 use ::raft::prelude as raft;
 use ::tarantool::tuple::AsTuple;
 use serde::de::DeserializeOwned;
@@ -28,8 +29,6 @@ pub type RaftId = u64;
 pub type InstanceId = String;
 pub type ReplicasetId = String;
 
-pub type FailureDomains = HashMap<String, String>;
-
 //////////////////////////////////////////////////////////////////////////////////////////
 /// Timestamps for raft entries.
 ///
@@ -518,3 +517,36 @@ pub fn replicaset_uuid(replicaset_id: &str) -> String {
     let uuid = Uuid::new_v3(&NAMESPACE_REPLICASET_UUID, replicaset_id.as_bytes());
     uuid.hyphenated().to_string()
 }
+
+////////////////////////////////////////////////////////////////////////////////
+/// Failure domains of a given instance.
+#[derive(Default, Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize)]
+pub struct FailureDomains {
+    #[serde(flatten)]
+    data: HashMap<Uppercase, Uppercase>,
+}
+
+impl<I, K, V> From<I> for FailureDomains
+where
+    I: IntoIterator<Item = (K, V)>,
+    Uppercase: From<K>,
+    Uppercase: From<V>,
+{
+    fn from(data: I) -> Self {
+        Self {
+            data: data
+                .into_iter()
+                .map(|(k, v)| (Uppercase::from(k), Uppercase::from(v)))
+                .collect(),
+        }
+    }
+}
+
+impl<'a> IntoIterator for &'a FailureDomains {
+    type IntoIter = <&'a HashMap<Uppercase, Uppercase> as IntoIterator>::IntoIter;
+    type Item = <&'a HashMap<Uppercase, Uppercase> as IntoIterator>::Item;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.data.iter()
+    }
+}
diff --git a/src/util.rs b/src/util.rs
index b0aa9313ac..f8ba7dcbfd 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -23,3 +23,89 @@ macro_rules! warn_or_panic {
         }
     };
 }
+
+////////////////////////////////////////////////////////////////////////////////
+/// A wrapper around `String` that garantees the string is uppercase by
+/// converting it to uppercase (if needed) on construction.
+#[derive(Default, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, serde::Serialize)]
+pub struct Uppercase(String);
+
+impl<'de> serde::Deserialize<'de> for Uppercase {
+    fn deserialize<D>(de: D) -> Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        Ok(String::deserialize(de)?.into())
+    }
+}
+
+impl<L: ::tarantool::tlua::AsLua> ::tarantool::tlua::Push<L> for Uppercase {
+    type Err = ::tarantool::tlua::Void;
+
+    fn push_to_lua(&self, lua: L) -> Result<tarantool::tlua::PushGuard<L>, (Self::Err, L)> {
+        self.0.push_to_lua(lua)
+    }
+}
+
+impl<L: ::tarantool::tlua::AsLua> ::tarantool::tlua::PushOne<L> for Uppercase {}
+
+impl<L: ::tarantool::tlua::AsLua> ::tarantool::tlua::PushInto<L> for Uppercase {
+    type Err = ::tarantool::tlua::Void;
+
+    fn push_into_lua(self, lua: L) -> Result<tarantool::tlua::PushGuard<L>, (Self::Err, L)> {
+        self.0.push_into_lua(lua)
+    }
+}
+
+impl<L: ::tarantool::tlua::AsLua> ::tarantool::tlua::PushOneInto<L> for Uppercase {}
+
+impl<L: ::tarantool::tlua::AsLua> ::tarantool::tlua::LuaRead<L> for Uppercase {
+    fn lua_read_at_position(lua: L, index: std::num::NonZeroI32) -> Result<Self, L> {
+        Ok(String::lua_read_at_position(lua, index)?.into())
+    }
+}
+
+impl From<String> for Uppercase {
+    fn from(s: String) -> Self {
+        if s.chars().all(char::is_uppercase) {
+            Self(s)
+        } else {
+            Self(s.to_uppercase())
+        }
+    }
+}
+
+impl From<&str> for Uppercase {
+    fn from(s: &str) -> Self {
+        Self(s.to_uppercase())
+    }
+}
+
+impl From<Uppercase> for String {
+    fn from(u: Uppercase) -> Self {
+        u.0
+    }
+}
+
+impl std::ops::Deref for Uppercase {
+    type Target = String;
+
+    fn deref(&self) -> &String {
+        &self.0
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn uppercase() {
+        assert_eq!(&*Uppercase::from(""), "");
+        assert_eq!(&*Uppercase::from("hello"), "HELLO");
+        assert_eq!(&*Uppercase::from("HELLO"), "HELLO");
+        assert_eq!(&*Uppercase::from("123-?!"), "123-?!");
+        assert_eq!(&*Uppercase::from(String::from("hello")), "HELLO");
+        assert_eq!(&*Uppercase::from(String::from("HELLO")), "HELLO");
+    }
+}
-- 
GitLab