From cd1802d9b0a89bf928831bfed0a38e87a82ed8a5 Mon Sep 17 00:00:00 2001
From: Erik Khamitov <e.khamitov@picodata.io>
Date: Tue, 24 Dec 2024 15:29:54 +0300
Subject: [PATCH] feat(sql): support VARCHAR without limit

---
 sbroad/sbroad-core/src/frontend/sql.rs        | 26 ++++++++------
 .../sbroad-core/src/frontend/sql/query.pest   |  2 +-
 sbroad/sbroad-core/src/ir/expression/cast.rs  |  5 ++-
 test/int/test_sql.py                          | 34 +++++++++++++++++++
 4 files changed, 55 insertions(+), 12 deletions(-)

diff --git a/sbroad/sbroad-core/src/frontend/sql.rs b/sbroad/sbroad-core/src/frontend/sql.rs
index f063cda6b8..d0ecb2b22e 100644
--- a/sbroad/sbroad-core/src/frontend/sql.rs
+++ b/sbroad/sbroad-core/src/frontend/sql.rs
@@ -2395,16 +2395,22 @@ fn cast_type_from_pair(type_pair: Pair<Rule>) -> Result<CastType, SbroadError> {
     }
 
     let mut type_pairs_inner = column_def_type.into_inner();
-    let varchar_length = type_pairs_inner
-        .next()
-        .expect("Length is missing under Varchar");
-    let len = varchar_length.as_str().parse::<usize>().map_err(|e| {
-        SbroadError::ParsingError(
-            Entity::Value,
-            format_smolstr!("Failed to parse varchar length: {e:?}."),
-        )
-    })?;
-    Ok(CastType::Varchar(len))
+    let type_cast = type_pairs_inner.next().map_or_else(
+        || Ok(CastType::Text),
+        |varchar_length| {
+            varchar_length
+                .as_str()
+                .parse::<usize>()
+                .map(CastType::Varchar)
+                .map_err(|e| {
+                    SbroadError::ParsingError(
+                        Entity::Value,
+                        format_smolstr!("Failed to parse varchar length: {e:?}."),
+                    )
+                })
+        },
+    )?;
+    Ok(type_cast)
 }
 
 /// Function responsible for parsing expressions using Pratt parser.
diff --git a/sbroad/sbroad-core/src/frontend/sql/query.pest b/sbroad/sbroad-core/src/frontend/sql/query.pest
index 5b6554b431..21648aaf6f 100644
--- a/sbroad/sbroad-core/src/frontend/sql/query.pest
+++ b/sbroad/sbroad-core/src/frontend/sql/query.pest
@@ -414,7 +414,7 @@ Expr = ${ ExprAtomValue ~ (ExprInfixOpo ~ ExprAtomValue)* }
                     TypeText = { ^"text" }
                     TypeUuid = { ^"uuid" }
                     TypeUnsigned = { ^"unsigned" }
-                    TypeVarchar = !{ ^"varchar" ~ "(" ~ Unsigned ~ ")" }
+                    TypeVarchar = { ^"varchar" ~ ("(" ~ WO ~ Unsigned ~ WO ~ ")")? }
             UnaryOperator = _{ Exists }
                 Exists = ${ (NotFlag ~ W)? ~ ^"exists" ~ W ~ SubQuery }
             Row = !{ "(" ~ Expr ~ ("," ~ Expr)* ~ ")" }
diff --git a/sbroad/sbroad-core/src/ir/expression/cast.rs b/sbroad/sbroad-core/src/ir/expression/cast.rs
index 92a5da90ed..e620012377 100644
--- a/sbroad/sbroad-core/src/ir/expression/cast.rs
+++ b/sbroad/sbroad-core/src/ir/expression/cast.rs
@@ -97,7 +97,10 @@ impl From<&Type> for SmolStr {
             Type::Text => "text".to_smolstr(),
             Type::Uuid => "uuid".to_smolstr(),
             Type::Unsigned => "unsigned".to_smolstr(),
-            Type::Varchar(length) => format_smolstr!("varchar({length})"),
+            Type::Varchar(length) => match length {
+                0 => "varchar".to_smolstr(),
+                _ => format_smolstr!("varchar({length})"),
+            },
         }
     }
 }
diff --git a/test/int/test_sql.py b/test/int/test_sql.py
index 332891baf5..be10ce89f9 100644
--- a/test/int/test_sql.py
+++ b/test/int/test_sql.py
@@ -5928,3 +5928,37 @@ Exceeded maximum number of rows (1) in virtual table: 2"""
         i1.sql(
             f"SELECT * FROM (VALUES (1), (2)) OPTION (VTABLE_MAX_ROWS = {new_vtable_max_rows})"
         )
+
+
+def test_varchar_cast(cluster: Cluster):
+    cluster.deploy(instance_count=1)
+    i1 = cluster.instances[0]
+
+    ddl = i1.sql(
+        """
+        create table t (a varchar not null, primary key (a))
+        using memtx
+        distributed by (a)
+        """
+    )
+    assert ddl["row_count"] == 1
+
+    dml = i1.sql("""insert into t values ('test_string')""")
+    assert dml["row_count"] == 1
+
+    data = i1.sql("""select cast(a as varchar(20)) from t""")
+    assert data == [["test_string"]]
+
+    # VARCHAR cast without length
+    data = i1.sql("""select cast(a as varchar) from t""")
+    assert data == [["test_string"]]
+
+    # VARCHAR cast with string literal
+    data = i1.sql("""select cast('direct_string' as varchar)""")
+    assert data == [["direct_string"]]
+
+    data = i1.sql(
+        """select cast(a as varchar) from t where a = 'test_string'""",
+        strip_metadata=False,
+    )
+    assert data["metadata"] == [{"name": "col_1", "type": "string"}]
-- 
GitLab