diff --git a/sbroad-core/src/ir/relation.rs b/sbroad-core/src/ir/relation.rs index ecc5f561a144666c88409bef8e0cc55c7c458fa7..95073ce20aa04171fbdd5dfe38a5a1800e9835b4 100644 --- a/sbroad-core/src/ir/relation.rs +++ b/sbroad-core/src/ir/relation.rs @@ -67,6 +67,22 @@ impl Type { v => Err(SbroadError::NotImplemented(Entity::Type, v.to_string())), } } + + /// The type of the column is scalar. + /// Only scalar types can be used as a distribution key. + pub fn is_scalar(&self) -> bool { + matches!( + self, + Type::Boolean + | Type::Decimal + | Type::Double + | Type::Integer + | Type::Number + | Type::Scalar + | Type::String + | Type::Unsigned + ) + } } /// A role of the column in the relation. @@ -341,7 +357,24 @@ impl Table { let positions = keys .iter() .map(|name| match pos_map.get(*name) { - Some(pos) => Ok(*pos), + Some(pos) => { + // Check that the column type is scalar. + // Compound types are not supported as sharding keys. + let column = &columns.get(*pos).ok_or_else(|| { + SbroadError::FailedTo( + Action::Create, + Some(Entity::Column), + format!("column {name} not found at position {pos}"), + ) + })?; + if !column.r#type.is_scalar() { + return Err(SbroadError::Invalid( + Entity::Column, + Some(format!("column {name} at position {pos} is not scalar",)), + )); + } + Ok(*pos) + } None => Err(SbroadError::Invalid(Entity::ShardingKey, None)), }) .collect::<Result<Vec<usize>, _>>()?; diff --git a/sbroad-core/src/ir/relation/tests.rs b/sbroad-core/src/ir/relation/tests.rs index 7628305beef5e54e703e117106f7f8067dedba33..eb46947d377bd3e454dc389b2c52314e77209e40 100644 --- a/sbroad-core/src/ir/relation/tests.rs +++ b/sbroad-core/src/ir/relation/tests.rs @@ -126,6 +126,26 @@ fn table_seg_wrong_key() { ); } +#[test] +fn table_seg_compound_type_in_key() { + assert_eq!( + Table::new_seg( + "t", + vec![ + Column::new("bucket_id", Type::Unsigned, ColumnRole::Sharding), + Column::new("a", Type::Array, ColumnRole::User), + ], + &["a"], + SpaceEngine::Memtx, + ) + .unwrap_err(), + SbroadError::Invalid( + Entity::Column, + Some("column a at position 1 is not scalar".into()), + ) + ); +} + #[test] fn table_seg_serialized() { let t = Table::new_seg(