From f172850a019c738cb6522428f67f685957f245c1 Mon Sep 17 00:00:00 2001
From: Kaitmazian Maksim <m.kaitmazian@picodata.io>
Date: Fri, 19 Jan 2024 17:21:47 +0300
Subject: [PATCH] refactor: refactor PgValue

---
 pgproto/src/entrypoints.rs   | 15 +++----
 pgproto/src/storage/value.rs | 79 ++++++++++++++++++++++++++++--------
 2 files changed, 67 insertions(+), 27 deletions(-)

diff --git a/pgproto/src/entrypoints.rs b/pgproto/src/entrypoints.rs
index ea85a0c2b3..7ad5a1b1d6 100644
--- a/pgproto/src/entrypoints.rs
+++ b/pgproto/src/entrypoints.rs
@@ -31,12 +31,10 @@ fn parse_dql(res: Value) -> PgResult<Vec<Row>> {
     }
 
     let res: DqlResult = serde_json::from_value(res)?;
-    let rows = res
-        .rows
+    res.rows
         .into_iter()
-        .map(|row| row.into_iter().map(PgValue::from).collect())
-        .collect();
-    Ok(rows)
+        .map(|row| row.into_iter().map(PgValue::try_from).collect())
+        .collect()
 }
 
 fn parse_dml(res: Value) -> PgResult<usize> {
@@ -51,11 +49,10 @@ fn parse_dml(res: Value) -> PgResult<usize> {
 
 fn parse_explain(res: Value) -> PgResult<Vec<Row>> {
     let res: Vec<Value> = serde_json::from_value(res)?;
-    Ok(res
-        .into_iter()
+    res.into_iter()
         // every row must be a vector
-        .map(|val| vec![PgValue::from(val)])
-        .collect())
+        .map(|val| Ok(vec![PgValue::try_from(val)?]))
+        .collect()
 }
 
 fn execute_result_from_json(json: &str) -> PgResult<ExecuteResult> {
diff --git a/pgproto/src/storage/value.rs b/pgproto/src/storage/value.rs
index a3a2cd803f..0275db9d82 100644
--- a/pgproto/src/storage/value.rs
+++ b/pgproto/src/storage/value.rs
@@ -3,6 +3,7 @@ use pgwire::api::Type;
 use pgwire::types::ToSqlText;
 use postgres_types::IsNull;
 use serde_json::Value;
+use serde_repr::Deserialize_repr;
 use std::str;
 
 use crate::error::{PgError, PgResult};
@@ -20,37 +21,79 @@ pub fn type_from_name(name: &str) -> PgResult<Type> {
     }
 }
 
+#[derive(Debug, Clone, Copy, Deserialize_repr)]
+#[repr(i16)]
+pub enum Format {
+    Text = 0,
+    Binary = 1,
+}
+
+impl TryFrom<i16> for Format {
+    type Error = PgError;
+    fn try_from(value: i16) -> Result<Self, Self::Error> {
+        match value {
+            0 => Ok(Format::Text),
+            1 => Ok(Format::Binary),
+            _ => Err(PgError::FeatureNotSupported(format!(
+                "encoding type {value}"
+            ))),
+        }
+    }
+}
+
 #[derive(Debug)]
-pub struct PgValue(Value);
+pub enum PgValue {
+    Integer(i64),
+    Float(f64),
+    Boolean(bool),
+    Text(String),
+    Null,
+}
+
+impl TryFrom<Value> for PgValue {
+    type Error = PgError;
 
-impl From<Value> for PgValue {
-    fn from(value: Value) -> Self {
-        PgValue(value)
+    fn try_from(value: Value) -> Result<Self, Self::Error> {
+        let ret = match value {
+            Value::Number(number) => {
+                if number.is_f64() {
+                    PgValue::Float(number.as_f64().unwrap())
+                } else if number.is_i64() {
+                    PgValue::Integer(number.as_i64().unwrap())
+                } else {
+                    Err(PgError::FeatureNotSupported(format!(
+                        "unsupported type {number}"
+                    )))?
+                }
+            }
+            Value::String(string) => PgValue::Text(string),
+            Value::Bool(bool) => PgValue::Boolean(bool),
+            Value::Null => PgValue::Null,
+            _ => Err(PgError::FeatureNotSupported(format!(
+                "unsupported type {value}"
+            )))?,
+        };
+        Ok(ret)
     }
 }
 
 impl PgValue {
     pub fn encode(&self, buf: &mut BytesMut) -> PgResult<Option<Bytes>> {
-        // TODO: add ToSqlText::to_sql_text_checked for type checking.
-        // Value::Bool(bool).to_sql_text(&Type::FLOAT8) doesn't result in an error.
-        let do_encode = |buf: &mut BytesMut| match &self.0 {
-            Value::Bool(val) => {
+        let do_encode = |buf: &mut BytesMut| match &self {
+            PgValue::Boolean(val) => {
                 buf.put_u8(if *val { b't' } else { b'f' });
                 Ok(IsNull::No)
             }
-            Value::String(string) => string.to_sql_text(&Type::TEXT, buf),
-            Value::Number(number) => {
-                if number.is_f64() {
-                    number.as_f64().to_sql_text(&Type::FLOAT8, buf)?;
-                } else {
-                    number.as_i64().to_sql_text(&Type::INT8, buf)?;
-                }
+            PgValue::Integer(number) => {
+                number.to_sql_text(&Type::INT8, buf)?;
                 Ok(IsNull::No)
             }
-            _ => {
-                let value = &self.0;
-                Err(format!("can't encode value {value:?}"))?
+            PgValue::Float(number) => {
+                number.to_sql_text(&Type::FLOAT8, buf)?;
+                Ok(IsNull::No)
             }
+            PgValue::Text(string) => string.to_sql_text(&Type::TEXT, buf),
+            PgValue::Null => Ok(IsNull::Yes),
         };
 
         let len = buf.len();
-- 
GitLab