From 787ad45d4b99b85b8a4a6582ad71091704282f8c Mon Sep 17 00:00:00 2001 From: hansel Date: Wed, 17 Jul 2024 14:44:45 +0800 Subject: [PATCH 1/3] deserialize --- src/v1/chat_completion.rs | 63 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 59 insertions(+), 4 deletions(-) diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index 28284be..faa4295 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -1,11 +1,11 @@ use serde::ser::SerializeMap; -use serde::{Deserialize, Serialize, Serializer}; +use serde::{Deserialize, Serialize, Serializer, Deserializer}; use serde_json::Value; use std::collections::HashMap; - +use serde::de::{self, MapAccess, SeqAccess, Visitor}; use crate::impl_builder_methods; use crate::v1::common; - +use std::fmt; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] pub enum ToolChoiceType { None, @@ -104,7 +104,7 @@ pub enum MessageRole { tool, } -#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum Content { Text(String), ImageUrl(Vec), @@ -128,6 +128,61 @@ impl serde::Serialize for Content { } } +impl<'de> Deserialize<'de> for Content { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct ContentVisitor; + + impl<'de> Visitor<'de> for ContentVisitor { + type Value = Content; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a valid content type") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + Ok(Content::Text(value.to_string())) + } + + fn visit_seq(self, seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let image_urls: Vec = Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?; + Ok(Content::ImageUrl(image_urls)) + } + + fn visit_map(self, map: M) -> Result + where + M: serde::de::MapAccess<'de>, + { + let image_urls: Vec = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?; + Ok(Content::ImageUrl(image_urls)) + } + + fn visit_none(self) -> Result + where + E: de::Error, + { + Ok(Content::Text(String::new())) + } + + fn visit_unit(self) -> Result + where + E: de::Error, + { + Ok(Content::Text(String::new())) + } + } + + deserializer.deserialize_any(ContentVisitor) + } +} #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] #[allow(non_camel_case_types)] pub enum ContentType { From c6e231bdf49a3f6cc35874a4ceeceb2d4d93451c Mon Sep 17 00:00:00 2001 From: hansel Date: Wed, 17 Jul 2024 14:48:09 +0800 Subject: [PATCH 2/3] fmt --- src/v1/chat_completion.rs | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index faa4295..983b935 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -1,10 +1,10 @@ -use serde::ser::SerializeMap; -use serde::{Deserialize, Serialize, Serializer, Deserializer}; -use serde_json::Value; -use std::collections::HashMap; -use serde::de::{self, MapAccess, SeqAccess, Visitor}; use crate::impl_builder_methods; use crate::v1::common; +use serde::de::{self, MapAccess, SeqAccess, Visitor}; +use serde::ser::SerializeMap; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde_json::Value; +use std::collections::HashMap; use std::fmt; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] pub enum ToolChoiceType { @@ -153,7 +153,8 @@ impl<'de> Deserialize<'de> for Content { where A: serde::de::SeqAccess<'de>, { - let image_urls: Vec = Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?; + let image_urls: Vec = + Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?; Ok(Content::ImageUrl(image_urls)) } @@ -161,7 +162,8 @@ impl<'de> Deserialize<'de> for Content { where M: serde::de::MapAccess<'de>, { - let image_urls: Vec = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?; + let image_urls: Vec = + Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?; Ok(Content::ImageUrl(image_urls)) } From a128d333b07194f4f384c384e6ceaa724f9478dd Mon Sep 17 00:00:00 2001 From: hansel Date: Wed, 17 Jul 2024 14:50:31 +0800 Subject: [PATCH 3/3] ref imports directly --- src/v1/chat_completion.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index 983b935..b999b92 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -151,7 +151,7 @@ impl<'de> Deserialize<'de> for Content { fn visit_seq(self, seq: A) -> Result where - A: serde::de::SeqAccess<'de>, + A: SeqAccess<'de>, { let image_urls: Vec = Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?; @@ -160,7 +160,7 @@ impl<'de> Deserialize<'de> for Content { fn visit_map(self, map: M) -> Result where - M: serde::de::MapAccess<'de>, + M: MapAccess<'de>, { let image_urls: Vec = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;