Merge pull request #97 from hanselke/deserialize_content

Adds deserializer for Content
This commit is contained in:
Dongri Jin
2024-07-17 17:41:58 +09:00
committed by GitHub

View File

@ -1,11 +1,11 @@
use serde::ser::SerializeMap;
use serde::{Deserialize, Serialize, Serializer};
use serde_json::Value;
use std::collections::HashMap;
use crate::impl_builder_methods; use crate::impl_builder_methods;
use crate::v1::common; use crate::v1::common;
use serde::de::{self, MapAccess, SeqAccess, Visitor};
use serde::ser::SerializeMap;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub enum ToolChoiceType { pub enum ToolChoiceType {
None, None,
@ -104,7 +104,7 @@ pub enum MessageRole {
tool, tool,
} }
#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum Content { pub enum Content {
Text(String), Text(String),
ImageUrl(Vec<ImageUrl>), ImageUrl(Vec<ImageUrl>),
@ -128,6 +128,63 @@ impl serde::Serialize for Content {
} }
} }
impl<'de> Deserialize<'de> for Content {
fn deserialize<D>(deserializer: D) -> Result<Content, D::Error>
where
D: Deserializer<'de>,
{
struct ContentVisitor;
impl<'de> Visitor<'de> for ContentVisitor {
type Value = Content;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a valid content type")
}
fn visit_str<E>(self, value: &str) -> Result<Content, E>
where
E: de::Error,
{
Ok(Content::Text(value.to_string()))
}
fn visit_seq<A>(self, seq: A) -> Result<Content, A::Error>
where
A: SeqAccess<'de>,
{
let image_urls: Vec<ImageUrl> =
Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
Ok(Content::ImageUrl(image_urls))
}
fn visit_map<M>(self, map: M) -> Result<Content, M::Error>
where
M: MapAccess<'de>,
{
let image_urls: Vec<ImageUrl> =
Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
Ok(Content::ImageUrl(image_urls))
}
fn visit_none<E>(self) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(Content::Text(String::new()))
}
fn visit_unit<E>(self) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(Content::Text(String::new()))
}
}
deserializer.deserialize_any(ContentVisitor)
}
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
pub enum ContentType { pub enum ContentType {