From 44a0fc615347373b09182dfe99a47417fc511e15 Mon Sep 17 00:00:00 2001 From: Kurdakov Alexander <kusancho12@gmail.com> Date: Wed, 27 Dec 2023 19:07:32 +0300 Subject: [PATCH] refactor: admin console --- src/cli/admin.rs | 199 ++++++++++++++++++++++++++++++++--- src/cli/connect.rs | 167 ++++++----------------------- src/cli/console.rs | 147 ++++++++++++++++++++++++++ src/cli/mod.rs | 1 + test/int/test_cli_connect.py | 30 ++++-- 5 files changed, 389 insertions(+), 155 deletions(-) create mode 100644 src/cli/console.rs diff --git a/src/cli/admin.rs b/src/cli/admin.rs index e634ce42bb..debc77be2f 100644 --- a/src/cli/admin.rs +++ b/src/cli/admin.rs @@ -1,20 +1,133 @@ -use crate::{ - tarantool_main, - util::{unwrap_or_terminate, validate_and_complete_unix_socket_path}, -}; +use std::io::{self, ErrorKind, Read, Write}; +use std::os::unix::net::UnixStream; +use std::str::from_utf8; +use std::time::Duration; + +use crate::tarantool_main; +use crate::util::unwrap_or_terminate; use super::args; +use super::console::{Console, ReplError}; + +/// Wrapper around unix socket with console-like interface +/// for communicating with tarantool console. +pub struct UnixClient { + socket: UnixStream, + buffer: Vec<u8>, +} + +#[derive(thiserror::Error, Debug)] +pub enum UnixClientError { + #[error("error during IO: {0}")] + Io(#[from] io::Error), + + #[error("malformed output: {0}")] + DeserializeMessageError(String), +} + +pub type Result<T> = std::result::Result<T, UnixClientError>; + +impl UnixClient { + const SERVER_DELIM: &str = "$EOF$\n"; + const CLIENT_DELIM: &[u8] = b"\n...\n"; + const WAIT_TIMEOUT: u64 = 10; + const INITIAL_BUFFER_SIZE: usize = 1024; + + fn from_stream(socket: UnixStream) -> Result<Self> { + socket.set_read_timeout(Some(Duration::from_secs(Self::WAIT_TIMEOUT)))?; + Ok(UnixClient { + socket, + buffer: vec![0; Self::INITIAL_BUFFER_SIZE], + }) + } + + /// Creates struct object using `path` for raw unix socket. + /// + /// Setup delimiter and ignore tarantool prompt. + fn new(path: &str) -> Result<Self> { + let socket = UnixStream::connect(path)?; + let mut client = Self::from_stream(socket)?; + + // set delimiter + let prelude: &str = "require(\"console\").delimiter(\"$EOF$\")\n"; + client.write_raw(prelude)?; + + // Ignore tarantool prompt. + // Prompt looks like: + // "Tarantool $version (Lua console) + // type 'help' for interactive help" + let prompt = client.read()?; + debug_assert!(prompt.contains("Tarantool")); + debug_assert!(prompt.contains("Lua console")); + + Ok(client) + } + + /// Writes message appended with delimiter to tarantool console + fn write(&mut self, line: &str) -> Result<()> { + self.write_raw(&(line.to_owned() + Self::SERVER_DELIM)) + } + + fn write_raw(&mut self, line: &str) -> Result<()> { + self.socket + .write_all(line.as_bytes()) + .map_err(UnixClientError::Io) + } + + /// Reads response from tarantool console. + /// Blocks until delimiter sequence or timeout is reached. + /// + /// # Errors + /// Returns error in the following cases: + /// 1. Read timeout + /// 2. Deserialization failure + fn read(&mut self) -> Result<String> { + let mut pos = 0; + loop { + let read = match self.socket.read(&mut self.buffer[pos..]) { + Ok(n) => n, + Err(ref e) if e.kind() == ErrorKind::Interrupted => { + continue; + } + Err(err) => return Err(err.into()), + }; + + pos += read; -fn connect_and_start_interacitve_console(args: args::Admin) -> Result<(), String> { - let endpoint = validate_and_complete_unix_socket_path(&args.socket_path)?; + // tarantool console appends delimiter to each response. + // Delimiter can be changed, but since we not do it manually, it's ok + if self.buffer[..pos].ends_with(Self::CLIENT_DELIM) { + break; + } - tarantool::lua_state() - .exec_with( - r#"local code, arg = ... - return load(code, '@src/connect.lua')(arg)"#, - (include_str!("connect.lua"), endpoint), - ) - .map_err(|e| e.to_string())?; + if pos == self.buffer.len() { + self.buffer.resize(pos * 2, 0); + } + } + + let deserialized = from_utf8(&self.buffer[..pos]) + .map_err(|err| UnixClientError::DeserializeMessageError(err.to_string()))? + .to_string(); + + return Ok(deserialized); + } +} + +fn admin_repl(args: args::Admin) -> core::result::Result<(), ReplError> { + let mut client = UnixClient::new(&args.socket_path).map_err(|err| { + ReplError::Other(format!( + "connection via unix socket by path '{}' is not established, reason: {}", + args.socket_path, err + )) + })?; + + let mut console = Console::new("picoadmin :) ")?; + + while let Some(line) = console.read()? { + client.write(&line)?; + let response = client.read()?; + console.write(&response); + } Ok(()) } @@ -25,10 +138,68 @@ pub fn main(args: args::Admin) -> ! { callback_data: args, callback_data_type: args::Admin, callback_body: { - unwrap_or_terminate(connect_and_start_interacitve_console(args)); + unwrap_or_terminate(admin_repl(args)); std::process::exit(0) } ); std::process::exit(rc); } + +#[cfg(test)] +mod tests { + use std::os::unix::net::UnixStream; + + use rmp::encode::RmpWrite; + + use super::{UnixClient, UnixClientError}; + + fn setup_client_server() -> (UnixClient, UnixStream) { + let (client, server) = UnixStream::pair().unwrap(); + let unix_client = UnixClient::from_stream(client).unwrap(); + (unix_client, server) + } + + #[test] + fn delimiter_timeout() { + let (mut client, mut server) = setup_client_server(); + server.write_bytes(b"output without delim").unwrap(); + let output = client.read(); + assert!(output.is_err()); + } + + #[test] + fn non_utf8_output() { + let (mut client, mut server) = setup_client_server(); + let non_utf = b"\x00\x9f\x92\x96\n...\n"; + server.write_bytes(non_utf).unwrap(); + let output = client.read(); + match output { + Err(UnixClientError::DeserializeMessageError(_)) => (), + _ => panic!(), + } + } + + #[test] + fn output_with_delimiter_is_accepted() { + let (mut client, mut server) = setup_client_server(); + server.write_bytes(b"output with delimiter\n...\n").unwrap(); + let output = client.read(); + assert!(output.is_ok()); + assert_eq!(output.unwrap(), "output with delimiter\n...\n"); + } + + #[test] + fn resize_logic() { + let (mut client, mut server) = setup_client_server(); + let initial_buf_size = client.buffer.len(); // 1024 + let delimiter = b"\n...\n"; + let mut big_output = vec![0u8; 1024]; + big_output.extend(delimiter); + server.write_bytes(&big_output.as_slice()).unwrap(); + let output = client.read(); + assert!(output.is_ok()); + assert!(client.buffer.len() > initial_buf_size); + assert!(client.buffer.len() == initial_buf_size * 2); + } +} diff --git a/src/cli/connect.rs b/src/cli/connect.rs index d2fcd79323..1255adc600 100644 --- a/src/cli/connect.rs +++ b/src/cli/connect.rs @@ -1,21 +1,17 @@ use std::fmt::Display; -use std::ops::ControlFlow; -use std::path::Path; use std::str::FromStr; -use std::{env, fs, io, process}; -use comfy_table::{ContentArrangement, Table}; - -use rustyline::error::ReadlineError; -use rustyline::DefaultEditor; -use tarantool::network::{client, AsClient, Client, Config}; +use tarantool::network::{AsClient, Client, Config}; use crate::tarantool_main; -use crate::util::unwrap_or_terminate; +use crate::util::{prompt_password, unwrap_or_terminate}; use super::args::{self, Address, DEFAULT_USERNAME}; +use super::console::{Console, ReplError}; +use comfy_table::{ContentArrangement, Table}; +use serde::{Deserialize, Serialize}; -pub(crate) fn get_password_from_file(path: &str) -> Result<String, String> { +fn get_password_from_file(path: &str) -> Result<String, String> { let content = std::fs::read_to_string(path).map_err(|e| { format!(r#"can't read password from password file by "{path}", reason: {e}"#) })?; @@ -33,22 +29,22 @@ pub(crate) fn get_password_from_file(path: &str) -> Result<String, String> { Ok(password.into()) } -#[derive(serde::Serialize, serde::Deserialize, Debug)] -struct ColDesc { +#[derive(Serialize, Deserialize, Debug)] +struct ColumnDesc { name: String, #[serde(rename = "type")] ty: String, } -impl Display for ColDesc { +impl Display for ColumnDesc { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str(&self.name) } } -#[derive(serde::Serialize, serde::Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] struct RowSet { - metadata: Vec<ColDesc>, + metadata: Vec<ColumnDesc>, rows: Vec<Vec<rmpv::Value>>, } @@ -67,7 +63,7 @@ impl Display for RowSet { } } -#[derive(serde::Serialize, serde::Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] struct RowCount { row_count: usize, } @@ -78,11 +74,16 @@ impl Display for RowCount { } } -#[derive(serde::Serialize, serde::Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] enum ResultSet { RowSet(Vec<RowSet>), RowCount(Vec<RowCount>), + // Option<()> here is for ignoring tarantool redundant "null" + // Example: + // --- + // - null <----- + // - "sbroad: rule parsing error: ... Error(Option<()>, String), } @@ -103,142 +104,43 @@ impl Display for ResultSet { } } -#[derive(thiserror::Error, Debug)] -enum SqlReplError { - #[error("{0}")] - Client(#[from] client::Error), - - #[error("Failed to prompt for a password: {0}")] - Prompt(io::Error), - - #[error("{0}")] - Io(io::Error), -} - -const HISTORY_FILE_NAME: &str = ".picodata_history"; - -async fn sql_repl_main(args: args::Connect) { - unwrap_or_terminate(sql_repl(args).await); - std::process::exit(0) -} - -// Ideally we should have an enum for all commands. For now we have only two options, usual line -// and only one special command. To not overengineer things at this point just handle this as ifs. -// When the set of commands grows it makes total sense to transform this to clear parse/execute pipeline -// and separate enum variants for each command variant. -fn handle_special_sequence(line: &str) -> Result<ControlFlow<String>, SqlReplError> { - if line != "\\e" { - eprintln!("Unknown special sequence"); - return Ok(ControlFlow::Continue(())); - } +fn sql_repl(args: args::Connect) -> Result<(), ReplError> { + let address = Address::from_str(&args.address).map_err(ReplError::Other)?; - let editor = match env::var_os("EDITOR") { - Some(e) => e, - None => { - eprintln!("EDITOR environment variable is not set"); - return Ok(ControlFlow::Continue(())); - } - }; - - let temp = tempfile::Builder::new() - .suffix(".sql") - .tempfile() - .map_err(SqlReplError::Io)?; - let status = process::Command::new(&editor) - .arg(temp.path()) - .status() - .map_err(SqlReplError::Io)?; - - if !status.success() { - eprintln!("{:?} returned non zero exit status: {}", editor, status); - return Ok(ControlFlow::Continue(())); - } - - let line = fs::read_to_string(temp.path()).map_err(SqlReplError::Io)?; - Ok(ControlFlow::Break(line)) -} - -async fn sql_repl(args: args::Connect) -> Result<(), SqlReplError> { - let address = unwrap_or_terminate(Address::from_str(&args.address)); let user = address.user.as_ref().unwrap_or(&args.user).clone(); let password = if user == DEFAULT_USERNAME { String::new() } else if let Some(path) = args.password_file { - unwrap_or_terminate(get_password_from_file(&path)) + get_password_from_file(&path).map_err(ReplError::Other)? } else { let prompt = format!("Enter password for {user}: "); - match crate::util::prompt_password(&prompt) { - Ok(password) => password, - Err(e) => { - return Err(SqlReplError::Prompt(e)); - } - } + prompt_password(&prompt) + .map_err(|err| ReplError::Other(format!("Failed to prompt for a password: {err}")))? }; - let client = Client::connect_with_config( + let client = ::tarantool::fiber::block_on(Client::connect_with_config( &address.host, address.port.parse().unwrap(), Config { creds: Some((user, password)), }, - ) - .await?; + ))?; // Check if connection is valid. We need to do it because connect is lazy // and we want to check whether authentication have succeeded or not - client.call("box.schema.user.info", &()).await?; - - // It is deprecated because of unexpected behavior on windows. - // We're ok with that. - #[allow(deprecated)] - let history_file = env::home_dir() - .unwrap_or_default() - .join(Path::new(HISTORY_FILE_NAME)); + ::tarantool::fiber::block_on(client.call("box.schema.user.info", &()))?; - let mut rl = DefaultEditor::new().unwrap(); - rl.load_history(&history_file).ok(); + let mut console = Console::new("picosql :) ")?; - loop { - let readline = rl.readline("picosql :) "); - match readline { - Ok(line) => { - let line = { - if line.starts_with('\\') { - match handle_special_sequence(&line)? { - ControlFlow::Continue(_) => continue, - ControlFlow::Break(line) => line, - } - } else { - line - } - }; + while let Some(line) = console.read()? { + let response = ::tarantool::fiber::block_on(client.call("pico.sql", &(line,)))?; - if line.is_empty() { - continue; - } + let res: ResultSet = response.decode().map_err(|err| { + ReplError::Other(format!("error occured while processing output: {}", err)) + })?; - let response = client.call("pico.sql", &(&line,)).await?; - let res: ResultSet = response - .decode() - .expect("Response must have the shape of ResultSet structure"); - println!("{res}"); - - rl.add_history_entry(line.as_str()).ok(); - rl.save_history(&history_file).ok(); - } - Err(ReadlineError::Interrupted) => { - println!("CTRL-C"); - } - Err(ReadlineError::Eof) => { - println!("CTRL-D"); - break; - } - Err(err) => { - println!("Error: {:?}", err); - break; - } - } + console.write(&res.to_string()); } Ok(()) @@ -250,7 +152,8 @@ pub fn main(args: args::Connect) -> ! { callback_data: (args,), callback_data_type: (args::Connect,), callback_body: { - ::tarantool::fiber::block_on(sql_repl_main(args)) + unwrap_or_terminate(sql_repl(args)); + std::process::exit(0) } ); std::process::exit(rc); diff --git a/src/cli/console.rs b/src/cli/console.rs new file mode 100644 index 0000000000..da2f52cb05 --- /dev/null +++ b/src/cli/console.rs @@ -0,0 +1,147 @@ +use std::env; +use std::fs::read_to_string; +use std::io; +use std::ops::ControlFlow; +use std::path::Path; +use std::path::PathBuf; +use std::process; + +use rustyline::{error::ReadlineError, history::FileHistory, DefaultEditor, Editor}; +use tarantool::network::client; + +use super::admin::UnixClientError; + +#[derive(thiserror::Error, Debug)] +pub enum ReplError { + #[error("{0}")] + Client(#[from] client::Error), + + #[error("{0}")] + UnixClient(#[from] UnixClientError), + + #[error("{0}")] + Io(#[from] io::Error), + + #[error("{0}")] + EditorError(#[from] ReadlineError), + + #[error("{0}")] + Other(String), +} + +pub type Result<T> = std::result::Result<T, ReplError>; + +/// Input/output handler +pub struct Console { + editor: Editor<(), FileHistory>, + history_file_path: PathBuf, + prompt: String, +} + +impl Console { + const HISTORY_FILE_NAME: &str = ".picodata_history"; + + // Ideally we should have an enum for all commands. For now we have only two options, usual line + // and only one special command. To not overengineer things at this point just handle this as ifs. + // When the set of commands grows it makes total sense to transform this to clear parse/execute pipeline + // and separate enum variants for each command variant. + fn process_line(&self, line: String) -> Result<ControlFlow<String>> { + if line.is_empty() { + return Ok(ControlFlow::Continue(())); + } + + if !line.starts_with('\\') { + return Ok(ControlFlow::Break(line)); + } + + if line == "\\e" { + let editor = match env::var_os("EDITOR") { + Some(e) => e, + None => { + self.write("EDITOR environment variable is not set"); + return Ok(ControlFlow::Continue(())); + } + }; + + let temp = tempfile::Builder::new().suffix(".sql").tempfile()?; + let status = process::Command::new(&editor).arg(temp.path()).status()?; + + if !status.success() { + self.write(&format!( + "{:?} returned non zero exit status: {}", + editor, status + )); + return Ok(ControlFlow::Continue(())); + } + + let line = read_to_string(temp.path()).map_err(ReplError::Io)?; + + return Ok(ControlFlow::Break(line)); + } else if line == "\\lua" { + return Ok(ControlFlow::Break("\\set language lua".to_owned())); + } else if line == "\\sql" { + return Ok(ControlFlow::Break("\\set language sql".to_owned())); + } + + self.write("Unknown special sequence"); + Ok(ControlFlow::Continue(())) + } + + pub fn new(prompt: &str) -> Result<Self> { + let mut editor = DefaultEditor::new()?; + + // newline by ALT + ENTER + editor.bind_sequence( + rustyline::KeyEvent(rustyline::KeyCode::Enter, rustyline::Modifiers::ALT), + rustyline::EventHandler::Simple(rustyline::Cmd::Newline), + ); + + // It is deprecated because of unexpected behavior on windows. + // We're ok with that. + #[allow(deprecated)] + let history_file_path = env::home_dir() + .unwrap_or_default() + .join(Path::new(Self::HISTORY_FILE_NAME)); + + editor.load_history(&history_file_path)?; + + Ok(Console { + editor, + history_file_path, + prompt: prompt.to_string(), + }) + } + + /// Reads from stdin. Takes into account treating special symbols. + pub fn read(&mut self) -> Result<Option<String>> { + loop { + let readline = self.editor.readline(&self.prompt); + match readline { + Ok(line) => { + let line = match self.process_line(line)? { + ControlFlow::Continue(_) => continue, + ControlFlow::Break(line) => line, + }; + + self.editor.add_history_entry(line.as_str())?; + self.editor.save_history(&self.history_file_path)?; + + return Ok(Some(line)); + } + Err(ReadlineError::Interrupted) => { + self.write("CTRL+C"); + continue; + } + Err(ReadlineError::Eof) => { + self.write("Bye"); + return Ok(None); + } + Err(err) => return Err(err.into()), + } + } + } + + pub fn write(&self, line: &str) { + println!("{}", line) + } +} diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 8b4eedd790..0ee7727275 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -1,6 +1,7 @@ pub mod admin; pub mod args; pub mod connect; +pub mod console; pub mod expel; pub mod init_cfg; pub mod run; diff --git a/test/int/test_cli_connect.py b/test/int/test_cli_connect.py index 5b3ab21dc3..9f7b338714 100644 --- a/test/int/test_cli_connect.py +++ b/test/int/test_cli_connect.py @@ -183,9 +183,10 @@ def test_admin_enoent(binary_path: str): ) cli.logfile = sys.stdout - cli.expect_exact("Connection is not established") + cli.expect_exact( + "connection via unix socket by path 'wrong/path/t.sock' is not established" + ) cli.expect_exact("No such file or directory") - cli.expect_exact("uri: unix/:./wrong/path/t.sock") cli.expect_exact(pexpect.EOF) @@ -199,9 +200,10 @@ def test_admin_econnrefused(binary_path: str): ) cli.logfile = sys.stdout - cli.expect_exact("Connection is not established") + cli.expect_exact( + "connection via unix socket by path '/dev/null' is not established" + ) cli.expect_exact("Connection refused") - cli.expect_exact("uri: unix/:/dev/null") cli.expect_exact(pexpect.EOF) @@ -215,7 +217,8 @@ def test_admin_invalid_path(binary_path: str): ) cli.logfile = sys.stdout - cli.expect_exact("invalid socket path: ./[][]") + cli.expect_exact("connection via unix socket by path './[][]' is not established") + cli.expect_exact("No such file or directory") cli.expect_exact(pexpect.EOF) @@ -229,7 +232,8 @@ def test_admin_empty_path(binary_path: str): ) cli.logfile = sys.stdout - cli.expect_exact("invalid socket path:") + cli.expect_exact("connection via unix socket by path '' is not established") + cli.expect_exact("Invalid argument") cli.expect_exact(pexpect.EOF) @@ -254,10 +258,17 @@ def test_connect_unix_ok_via_default_sock(cluster: Cluster): ) cli.logfile = sys.stdout - cli.expect_exact("connected to unix/:./admin.sock") - cli.expect_exact("unix/:./admin.sock>") + cli.expect_exact("picoadmin :) ") + + # Change language to SQL works + cli.sendline("\\sql") + cli.sendline("CREATE ROLE CHANGE_TO_SQL_WORKS") + cli.expect_exact("---\r\n") + cli.expect_exact("- row_count: 1\r\n") + cli.expect_exact("...\r\n") + cli.expect_exact("\r\n") - cli.sendline("\\set language lua") + cli.sendline("\\lua") cli.sendline("box.session.user()") cli.expect_exact("---\r\n") cli.expect_exact("- admin\r\n") @@ -266,6 +277,7 @@ def test_connect_unix_ok_via_default_sock(cluster: Cluster): eprint("^D") cli.sendcontrol("d") + cli.expect_exact("Bye") cli.expect_exact(pexpect.EOF) -- GitLab