deserialize

This commit is contained in:
hansel
2024-07-17 14:44:45 +08:00
parent 4a11c78393
commit 787ad45d4b

View File

@ -1,11 +1,11 @@
use serde::ser::SerializeMap; use serde::ser::SerializeMap;
use serde::{Deserialize, Serialize, Serializer}; use serde::{Deserialize, Serialize, Serializer, Deserializer};
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap; use std::collections::HashMap;
use serde::de::{self, MapAccess, SeqAccess, Visitor};
use crate::impl_builder_methods; use crate::impl_builder_methods;
use crate::v1::common; use crate::v1::common;
use std::fmt;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub enum ToolChoiceType { pub enum ToolChoiceType {
None, None,
@ -104,7 +104,7 @@ pub enum MessageRole {
tool, tool,
} }
#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum Content { pub enum Content {
Text(String), Text(String),
ImageUrl(Vec<ImageUrl>), ImageUrl(Vec<ImageUrl>),
@ -128,6 +128,61 @@ impl serde::Serialize for Content {
} }
} }
impl<'de> Deserialize<'de> for Content {
fn deserialize<D>(deserializer: D) -> Result<Content, D::Error>
where
D: Deserializer<'de>,
{
struct ContentVisitor;
impl<'de> Visitor<'de> for ContentVisitor {
type Value = Content;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a valid content type")
}
fn visit_str<E>(self, value: &str) -> Result<Content, E>
where
E: de::Error,
{
Ok(Content::Text(value.to_string()))
}
fn visit_seq<A>(self, seq: A) -> Result<Content, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let image_urls: Vec<ImageUrl> = Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
Ok(Content::ImageUrl(image_urls))
}
fn visit_map<M>(self, map: M) -> Result<Content, M::Error>
where
M: serde::de::MapAccess<'de>,
{
let image_urls: Vec<ImageUrl> = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
Ok(Content::ImageUrl(image_urls))
}
fn visit_none<E>(self) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(Content::Text(String::new()))
}
fn visit_unit<E>(self) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(Content::Text(String::new()))
}
}
deserializer.deserialize_any(ContentVisitor)
}
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
pub enum ContentType { pub enum ContentType {