From a880388ef5636c36a30d013f1ba2d303e905f3af Mon Sep 17 00:00:00 2001
From: Egor Ivkov <e.ivkov@picodata.io>
Date: Fri, 10 Nov 2023 08:50:18 +0000
Subject: [PATCH] feat: set max login attempts

---
 CHANGELOG.md         |  3 +++
 src/lib.rs           | 62 ++++++++++++++++++++++++++++++++++++++++++++
 src/storage.rs       | 19 ++++++++++++++
 test/int/test_acl.py | 60 ++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 144 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f341f25764..ac80cad931 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,9 @@ with the `YY.0M.MICRO` scheme.
   Use `picodata connect --unix` to connect. Unlike connecting to a `--listen` address,
   console communication occurs in plain text and always operates under the admin account.
 
+- Restrict the number of login attempts through `picodata connect`. The limit can be set
+  through the `max_login_attempts` property. The default value is `5`.
+
 - _Clusterwide SQL_ now available via `\set language sql` in interactive console.
 
 - Interactive console is disabled by default. Enable it implicitly with `picodata run -i`.
diff --git a/src/lib.rs b/src/lib.rs
index 29b064f0c2..8d33b8fbe6 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -12,6 +12,8 @@ use ::tarantool::time::Instant;
 use ::tarantool::tlua;
 use ::tarantool::transaction::transaction;
 use rpc::{join, update_instance};
+use std::cell::OnceCell;
+use std::collections::HashMap;
 use std::convert::TryFrom;
 use std::io;
 use std::time::Duration;
@@ -281,6 +283,65 @@ fn redirect_interactive_sql() {
     .expect("overriding sql executor shouldn't fail")
 }
 
+/// Sets a check for user exceeding maximum number of login attempts through `picodata connect`.
+/// Also see [`PropertyName::MaxLoginAttempts`].
+fn set_login_attempts_check(storage: Clusterwide) {
+    use std::collections::hash_map::Entry;
+
+    // It's ok to lose this information during restart, so we keep it as a static.
+    static mut LOGIN_ATTEMPTS: OnceCell<HashMap<String, usize>> = OnceCell::new();
+    const ERROR: &str = "maximum number of login attempts exceeded";
+
+    let lua = ::tarantool::lua_state();
+    lua.exec_with(
+        "box.session.on_auth(...)",
+        tlua::function3(
+            move |user: String, status: bool, lua: tlua::LuaState| unsafe {
+                // SAFETY: Accessing `LOGIN_ATTEMPTS` is safe as it is only done from a single thread
+                LOGIN_ATTEMPTS.get_or_init(HashMap::new);
+                let attempts = LOGIN_ATTEMPTS.get_mut().expect("is initialized");
+
+                match attempts.entry(user) {
+                    Entry::Occupied(mut count) => {
+                        if *count.get()
+                            >= storage
+                                .properties
+                                .max_login_attempts()
+                                .expect("accessing storage should not fail")
+                        {
+                            // Currently `picodata connect` displays a generic error for all purposes
+                            // of connection failure. So this log is left to help during development/debugging.
+                            // Obviously this should not end up in production, so it is enabled only for debug
+                            // builds.
+                            // TODO: Remove once we fix `picodata connect` error display.
+                            if cfg!(debug_assertions) {
+                                tlog!(Warning, "{} (user=\"{}\")", ERROR, count.key());
+                            }
+                            // Raises an error instead of returning it as a function result.
+                            // This is the behavior required by `on_auth` trigger to drop the connection.
+                            // All the drop implementations are called, no need to clean anything up.
+                            tlua::ffi::lua_pushlstring(lua, ERROR.as_ptr() as _, ERROR.len());
+                            tlua::ffi::lua_error(lua);
+                            unreachable!();
+                        } else if status {
+                            // reset count on successful login
+                            count.remove();
+                        } else {
+                            *count.get_mut() += 1
+                        }
+                    }
+                    Entry::Vacant(count) => {
+                        if !status {
+                            count.insert(1);
+                        }
+                    }
+                }
+            },
+        ),
+    )
+    .expect("setting on auth trigger should not fail")
+}
+
 #[allow(clippy::enum_variant_names)]
 #[derive(Debug, Serialize, Deserialize)]
 pub enum Entrypoint {
@@ -328,6 +389,7 @@ fn init_common(args: &args::Run, cfg: &tarantool::Cfg) -> (Clusterwide, RaftSpac
     traft::event::init();
 
     let storage = Clusterwide::try_get(true).expect("storage initialization should never fail");
+    set_login_attempts_check(storage.clone());
     let raft_storage =
         RaftSpaceAccess::new().expect("raft storage initialization should never fail");
     (storage.clone(), raft_storage)
diff --git a/src/storage.rs b/src/storage.rs
index 22590c027c..87c1b2c69b 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -1075,6 +1075,15 @@ impl From<ClusterwideSpace> for SpaceId {
 
         PasswordMinLength = "password_min_length",
 
+        /// Maximum number of login attempts through `picodata connect`.
+        /// Each failed login attempt increases a local per-user counter of failed attempts.
+        /// When the counter reaches the value of this property, any subsequent logins
+        /// of this user will be denied.
+        /// The local counter for a user is reset on successful login.
+        ///
+        /// Default value is [`DEFAULT_MAX_LOGIN_ATTEMPTS`].
+        MaxLoginAttempts = "max_login_attempts",
+
         /// Number of seconds to wait before automatically changing an
         /// unresponsive instance's grade to Offline.
         AutoOfflineTimeout = "auto_offline_timeout",
@@ -1114,6 +1123,7 @@ pub const DEFAULT_AUTO_OFFLINE_TIMEOUT: f64 = 5.0;
 pub const DEFAULT_MAX_HEARTBEAT_PERIOD: f64 = 5.0;
 pub const DEFAULT_SNAPSHOT_CHUNK_MAX_SIZE: usize = 16 * 1024 * 1024;
 pub const DEFAULT_SNAPSHOT_READ_VIEW_CLOSE_TIMEOUT: f64 = (24 * 3600) as _;
+pub const DEFAULT_MAX_LOGIN_ATTEMPTS: usize = 5;
 
 impl Properties {
     pub fn new() -> tarantool::Result<Self> {
@@ -1173,6 +1183,15 @@ impl Properties {
         Ok(res)
     }
 
+    /// See [`PropertyName::MaxLoginAttempts`]
+    #[inline]
+    pub fn max_login_attempts(&self) -> tarantool::Result<usize> {
+        let res = self
+            .get(PropertyName::MaxLoginAttempts)?
+            .unwrap_or(DEFAULT_MAX_LOGIN_ATTEMPTS);
+        Ok(res)
+    }
+
     #[inline]
     pub fn pending_schema_change(&self) -> tarantool::Result<Option<Ddl>> {
         self.get(PropertyName::PendingSchemaChange)
diff --git a/test/int/test_acl.py b/test/int/test_acl.py
index 58fd7564e8..175baf4d8a 100644
--- a/test/int/test_acl.py
+++ b/test/int/test_acl.py
@@ -1,9 +1,11 @@
 import pytest
 from conftest import Cluster, Instance, TarantoolError, ReturnError
 from tarantool.error import NetworkError  # type: ignore
+from tarantool.connection import Connection  # type: ignore
 
 VALID_PASSWORD = "long enough"
 PASSWORD_MIN_LENGTH_KEY = "password_min_length"
+MAX_LOGIN_ATTEMPTS = 5
 
 
 def expected_min_password_violation_error(min_length: int):
@@ -26,6 +28,64 @@ def set_min_password_len(cluster: Cluster, i1: Instance, min_password_len: int):
     assert check[1] == min_password_len
 
 
+def test_max_login_attempts(cluster: Cluster):
+    i1, *_ = cluster.deploy(instance_count=1)
+
+    def connect(
+        i: Instance, user: str | None = None, password: str | None = None
+    ) -> Connection:
+        return Connection(
+            i.host,
+            i.port,
+            user=user,
+            password=password,
+            connect_now=True,
+            reconnect_max_attempts=0,
+        )
+
+    c = connect(i1)
+    c.eval(
+        """
+        box.session.su(1)
+        box.schema.user.create('foo', {password='bar'})
+        box.schema.user.grant('foo', 'read,write,execute', 'universe')
+        """
+    )
+    c.close()
+
+    # First login is successful
+    c = connect(i1, user="foo", password="bar")
+    assert c
+
+    # Several failed login attempts but one less than maximum
+    for _ in range(MAX_LOGIN_ATTEMPTS - 1):
+        with pytest.raises(
+            NetworkError, match="User not found or supplied credentials are invalid"
+        ):
+            # incorrect password
+            connect(i1, user="foo", password="baz")
+
+    # Still possible to login. Resets the attempts counter
+    c = connect(
+        i1,
+        user="foo",
+        password="bar",
+    )
+    assert c
+
+    # Maximum failed login attempts
+    for _ in range(MAX_LOGIN_ATTEMPTS):
+        with pytest.raises(
+            NetworkError, match="User not found or supplied credentials are invalid"
+        ):
+            # incorrect password
+            connect(i1, user="foo", password="baz")
+
+    # Next login even with correct password fails as the limit is reached
+    with pytest.raises(NetworkError, match="maximum number of login attempts exceeded"):
+        connect(i1, user="foo", password="bar")
+
+
 def test_acl_lua_api(cluster: Cluster):
     i1, *_ = cluster.deploy(instance_count=1)
 
-- 
GitLab