diff --git a/sbroad b/sbroad index 4cb04bfc0b957aee2d1529db276e3ead758ffe62..73576c4bd9dd047016fbcf287e83452aad8b3416 160000 --- a/sbroad +++ b/sbroad @@ -1 +1 @@ -Subproject commit 4cb04bfc0b957aee2d1529db276e3ead758ffe62 +Subproject commit 73576c4bd9dd047016fbcf287e83452aad8b3416 diff --git a/src/sql.rs b/src/sql.rs index 4781c09f5dd1d152cea4cd7962ca89a2f52069eb..a271239772b91796e035ba7933e67c9e0617e924 100644 --- a/src/sql.rs +++ b/src/sql.rs @@ -75,7 +75,7 @@ fn check_table_privileges(plan: &IrPlan) -> traft::Result<()> { false }; let mut plan_traversal = PostOrderWithFilter::with_capacity( - |node| plan.subtree_iter(node), + |node| plan.subtree_iter(node, false), REL_CAPACITY, Box::new(filter), ); diff --git a/src/sql/pgproto.rs b/src/sql/pgproto.rs index 6a8b88b6ec6c5a791c9cb566ef27e7dbd9be1468..32d88e81c761e249ab2e4cfad4799d5504bd8e67 100644 --- a/src/sql/pgproto.rs +++ b/src/sql/pgproto.rs @@ -308,6 +308,7 @@ impl TryFrom<&Node> for CommandTag { | Relational::Projection { .. } | Relational::ScanRelation { .. } | Relational::ScanSubQuery { .. } + | Relational::Intersect { .. } | Relational::Selection { .. } | Relational::GroupBy { .. } | Relational::Having { .. } diff --git a/src/sql/router.rs b/src/sql/router.rs index 4fb164eabe06d298a79a629886b602bf298f63e1..1c447e0de0d814dda3b73618d526cd2aab83cad4 100644 --- a/src/sql/router.rs +++ b/src/sql/router.rs @@ -551,30 +551,34 @@ impl Metadata for RouterMetadata { columns.push(column); } + let normalized_name = normalize_name_from_sql(table_name); + let pk_cols = space_pk_columns(&name, &columns)?; + let pk_cols_str: &[&str] = &pk_cols.iter().map(String::as_str).collect::<Vec<_>>(); + // Try to find the sharding columns of the space in "_pico_table". // If nothing found then the space is local and we can't query it with // distributed SQL. let is_system_table = ClusterwideTable::values() .iter() .any(|sys_name| *sys_name == name.as_str()); - let shard_key_cols: Vec<String> = if is_system_table { - vec![] - } else { - Self::get_shard_cols(&name, &meta)? - }; - let sharding_key_arg: &[&str] = &shard_key_cols + + if is_system_table { + return Table::new_system(&normalized_name, columns, pk_cols_str); + } + let sharded_columns = Self::get_shard_cols(&name, &meta)?; + if sharded_columns.is_empty() { + return Table::new_global(&normalized_name, columns, pk_cols_str); + } + let sharding_columns_str: &[&str] = &sharded_columns .iter() .map(String::as_str) .collect::<Vec<_>>(); - let pk_cols = space_pk_columns(&name, &columns)?; - let pk_arg = &pk_cols.iter().map(String::as_str).collect::<Vec<_>>(); - Table::new( - &normalize_name_from_sql(table_name), + Table::new_sharded( + &normalized_name, columns, - sharding_key_arg, - pk_arg, + sharding_columns_str, + pk_cols_str, engine.into(), - is_system_table, ) } diff --git a/test/int/test_sql.py b/test/int/test_sql.py index 26dd30464e5c20e3973fba30825902c66df03613..38fe0db1bdd9891a47522cc4944e63c49e094c56 100644 --- a/test/int/test_sql.py +++ b/test/int/test_sql.py @@ -692,6 +692,149 @@ def test_union_all_on_global_tbls(cluster: Cluster): Retriable(rps=5, timeout=6).call(check_complex_segment_child) +def test_except_on_global_tbls(cluster: Cluster): + cluster.deploy(instance_count=1) + i1 = cluster.instances[0] + + ddl = i1.sql( + """ + create table g (a int not null, b int not null, primary key (a)) + using memtx + distributed globally + option (timeout = 3) + """ + ) + assert ddl["row_count"] == 1 + for i in range(1, 6): + index = i1.cas("insert", "G", [i, i]) + i1.raft_wait_index(index, 3) + + ddl = i1.sql( + """ + create table s (c int not null, d int not null, primary key (c)) + using memtx + distributed by (c) + option (timeout = 3) + """ + ) + assert ddl["row_count"] == 1 + + data = i1.sql("""insert into s values (3, 2), (4, 3), (5, 4), (6, 5), (7, 6);""") + assert data["row_count"] == 5 + + def check_global_vs_global(): + data = i1.sql( + """ + select a from g + except + select a - 1 from g + """, + timeout=2, + ) + assert data["rows"] == [[5]] + + Retriable(rps=5, timeout=6).call(check_global_vs_global) + + def check_global_vs_segment(): + data = i1.sql( + """ + select a from g + except + select c from s + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[1], [2]] + + Retriable(rps=5, timeout=6).call(check_global_vs_segment) + + def check_global_vs_any(): + data = i1.sql( + """ + select b from g + except + select d from s + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[1]] + + Retriable(rps=5, timeout=6).call(check_global_vs_any) + + def check_global_vs_single(): + data = i1.sql( + """ + select b from g + where b = 1 or b = 2 + except + select sum(d) from s + where d = 3 + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[1], [2]] + + Retriable(rps=5, timeout=6).call(check_global_vs_single) + + def check_single_vs_global(): + data = i1.sql( + """ + select sum(d) from s + where d = 3 or d = 2 + except + select b from g + where b = 1 or b = 2 + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[5]] + + Retriable(rps=5, timeout=6).call(check_single_vs_global) + + def check_segment_vs_global(): + data = i1.sql( + """ + select c from s + except + select a from g + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[6], [7]] + + Retriable(rps=5, timeout=6).call(check_segment_vs_global) + + def check_any_vs_global(): + data = i1.sql( + """ + select d from s + except + select b from g + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[6]] + + Retriable(rps=5, timeout=6).call(check_any_vs_global) + + def check_multiple_excepts(): + data = i1.sql( + """ + select a + 5 from g + where a = 1 or a = 2 + except select * from ( + select d from s + except + select b from g + ) + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[7]] + + Retriable(rps=5, timeout=6).call(check_multiple_excepts) + + def test_hash(cluster: Cluster): cluster.deploy(instance_count=1) i1 = cluster.instances[0]