From f172850a019c738cb6522428f67f685957f245c1 Mon Sep 17 00:00:00 2001 From: Kaitmazian Maksim <m.kaitmazian@picodata.io> Date: Fri, 19 Jan 2024 17:21:47 +0300 Subject: [PATCH] refactor: refactor PgValue --- pgproto/src/entrypoints.rs | 15 +++---- pgproto/src/storage/value.rs | 79 ++++++++++++++++++++++++++++-------- 2 files changed, 67 insertions(+), 27 deletions(-) diff --git a/pgproto/src/entrypoints.rs b/pgproto/src/entrypoints.rs index ea85a0c2b3..7ad5a1b1d6 100644 --- a/pgproto/src/entrypoints.rs +++ b/pgproto/src/entrypoints.rs @@ -31,12 +31,10 @@ fn parse_dql(res: Value) -> PgResult<Vec<Row>> { } let res: DqlResult = serde_json::from_value(res)?; - let rows = res - .rows + res.rows .into_iter() - .map(|row| row.into_iter().map(PgValue::from).collect()) - .collect(); - Ok(rows) + .map(|row| row.into_iter().map(PgValue::try_from).collect()) + .collect() } fn parse_dml(res: Value) -> PgResult<usize> { @@ -51,11 +49,10 @@ fn parse_dml(res: Value) -> PgResult<usize> { fn parse_explain(res: Value) -> PgResult<Vec<Row>> { let res: Vec<Value> = serde_json::from_value(res)?; - Ok(res - .into_iter() + res.into_iter() // every row must be a vector - .map(|val| vec![PgValue::from(val)]) - .collect()) + .map(|val| Ok(vec![PgValue::try_from(val)?])) + .collect() } fn execute_result_from_json(json: &str) -> PgResult<ExecuteResult> { diff --git a/pgproto/src/storage/value.rs b/pgproto/src/storage/value.rs index a3a2cd803f..0275db9d82 100644 --- a/pgproto/src/storage/value.rs +++ b/pgproto/src/storage/value.rs @@ -3,6 +3,7 @@ use pgwire::api::Type; use pgwire::types::ToSqlText; use postgres_types::IsNull; use serde_json::Value; +use serde_repr::Deserialize_repr; use std::str; use crate::error::{PgError, PgResult}; @@ -20,37 +21,79 @@ pub fn type_from_name(name: &str) -> PgResult<Type> { } } +#[derive(Debug, Clone, Copy, Deserialize_repr)] +#[repr(i16)] +pub enum Format { + Text = 0, + Binary = 1, +} + +impl TryFrom<i16> for Format { + type Error = PgError; + fn try_from(value: i16) -> Result<Self, Self::Error> { + match value { + 0 => Ok(Format::Text), + 1 => Ok(Format::Binary), + _ => Err(PgError::FeatureNotSupported(format!( + "encoding type {value}" + ))), + } + } +} + #[derive(Debug)] -pub struct PgValue(Value); +pub enum PgValue { + Integer(i64), + Float(f64), + Boolean(bool), + Text(String), + Null, +} + +impl TryFrom<Value> for PgValue { + type Error = PgError; -impl From<Value> for PgValue { - fn from(value: Value) -> Self { - PgValue(value) + fn try_from(value: Value) -> Result<Self, Self::Error> { + let ret = match value { + Value::Number(number) => { + if number.is_f64() { + PgValue::Float(number.as_f64().unwrap()) + } else if number.is_i64() { + PgValue::Integer(number.as_i64().unwrap()) + } else { + Err(PgError::FeatureNotSupported(format!( + "unsupported type {number}" + )))? + } + } + Value::String(string) => PgValue::Text(string), + Value::Bool(bool) => PgValue::Boolean(bool), + Value::Null => PgValue::Null, + _ => Err(PgError::FeatureNotSupported(format!( + "unsupported type {value}" + )))?, + }; + Ok(ret) } } impl PgValue { pub fn encode(&self, buf: &mut BytesMut) -> PgResult<Option<Bytes>> { - // TODO: add ToSqlText::to_sql_text_checked for type checking. - // Value::Bool(bool).to_sql_text(&Type::FLOAT8) doesn't result in an error. - let do_encode = |buf: &mut BytesMut| match &self.0 { - Value::Bool(val) => { + let do_encode = |buf: &mut BytesMut| match &self { + PgValue::Boolean(val) => { buf.put_u8(if *val { b't' } else { b'f' }); Ok(IsNull::No) } - Value::String(string) => string.to_sql_text(&Type::TEXT, buf), - Value::Number(number) => { - if number.is_f64() { - number.as_f64().to_sql_text(&Type::FLOAT8, buf)?; - } else { - number.as_i64().to_sql_text(&Type::INT8, buf)?; - } + PgValue::Integer(number) => { + number.to_sql_text(&Type::INT8, buf)?; Ok(IsNull::No) } - _ => { - let value = &self.0; - Err(format!("can't encode value {value:?}"))? + PgValue::Float(number) => { + number.to_sql_text(&Type::FLOAT8, buf)?; + Ok(IsNull::No) } + PgValue::Text(string) => string.to_sql_text(&Type::TEXT, buf), + PgValue::Null => Ok(IsNull::Yes), }; let len = buf.len(); -- GitLab