. whitepace - add empty lines on all files

This commit is contained in:
Jeremy Chone
2024-09-17 16:52:21 -07:00
parent 9fe3018730
commit 896b9d4f72
62 changed files with 438 additions and 432 deletions

View File

@ -32,4 +32,4 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
Ok(())
}
}

View File

@ -51,4 +51,4 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
Ok(())
}
}

View File

@ -49,4 +49,4 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
Ok(())
}
}

View File

@ -41,4 +41,4 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
print_chat_stream(chat_res, None).await?;
Ok(())
}
}

View File

@ -24,4 +24,4 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
Ok(())
}
}

View File

@ -56,4 +56,4 @@ pub struct WebRequestData {
pub payload: Value,
}
// endregion: --- WebRequestData
// endregion: --- WebRequestData

View File

@ -189,4 +189,4 @@ struct AnthropicRequestParts {
// TODO: need to add tools
}
// endregion: --- Support
// endregion: --- Support

View File

@ -10,4 +10,4 @@ mod streamer;
pub use adapter_impl::*;
pub use streamer::*;
// endregion: --- Modules
// endregion: --- Modules

View File

@ -9,170 +9,170 @@ use std::task::{Context, Poll};
use value_ext::JsonValueExt;
pub struct AnthropicStreamer {
inner: EventSource,
options: StreamerOptions,
inner: EventSource,
options: StreamerOptions,
// -- Set by the poll_next
/// Flag to prevent polling the EventSource after a MessageStop event
done: bool,
captured_data: StreamerCapturedData,
// -- Set by the poll_next
/// Flag to prevent polling the EventSource after a MessageStop event
done: bool,
captured_data: StreamerCapturedData,
}
impl AnthropicStreamer {
pub fn new(inner: EventSource, model_iden: ModelIden, options_set: ChatOptionsSet<'_, '_>) -> Self {
Self {
inner,
done: false,
options: StreamerOptions::new(model_iden, options_set),
captured_data: Default::default(),
}
}
pub fn new(inner: EventSource, model_iden: ModelIden, options_set: ChatOptionsSet<'_, '_>) -> Self {
Self {
inner,
done: false,
options: StreamerOptions::new(model_iden, options_set),
captured_data: Default::default(),
}
}
}
impl futures::Stream for AnthropicStreamer {
type Item = Result<InterStreamEvent>;
type Item = Result<InterStreamEvent>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.done {
return Poll::Ready(None);
}
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.done {
return Poll::Ready(None);
}
while let Poll::Ready(event) = Pin::new(&mut self.inner).poll_next(cx) {
// NOTE: At this point, we capture more events than needed for genai::StreamItem, but it serves as documentation.
match event {
Some(Ok(Event::Open)) => return Poll::Ready(Some(Ok(InterStreamEvent::Start))),
Some(Ok(Event::Message(message))) => {
let message_type = message.event.as_str();
while let Poll::Ready(event) = Pin::new(&mut self.inner).poll_next(cx) {
// NOTE: At this point, we capture more events than needed for genai::StreamItem, but it serves as documentation.
match event {
Some(Ok(Event::Open)) => return Poll::Ready(Some(Ok(InterStreamEvent::Start))),
Some(Ok(Event::Message(message))) => {
let message_type = message.event.as_str();
match message_type {
"message_start" => {
self.capture_usage(message_type, &message.data)?;
continue;
}
"message_delta" => {
self.capture_usage(message_type, &message.data)?;
continue;
}
"content_block_start" => {
continue;
}
"content_block_delta" => {
let mut data: Value =
serde_json::from_str(&message.data).map_err(|serde_error| Error::StreamParse {
model_iden: self.options.model_iden.clone(),
serde_error,
})?;
let content: String = data.x_take("/delta/text")?;
match message_type {
"message_start" => {
self.capture_usage(message_type, &message.data)?;
continue;
}
"message_delta" => {
self.capture_usage(message_type, &message.data)?;
continue;
}
"content_block_start" => {
continue;
}
"content_block_delta" => {
let mut data: Value =
serde_json::from_str(&message.data).map_err(|serde_error| Error::StreamParse {
model_iden: self.options.model_iden.clone(),
serde_error,
})?;
let content: String = data.x_take("/delta/text")?;
// Add to the captured_content if chat options say so
if self.options.capture_content {
match self.captured_data.content {
Some(ref mut c) => c.push_str(&content),
None => self.captured_data.content = Some(content.clone()),
}
}
// Add to the captured_content if chat options say so
if self.options.capture_content {
match self.captured_data.content {
Some(ref mut c) => c.push_str(&content),
None => self.captured_data.content = Some(content.clone()),
}
}
return Poll::Ready(Some(Ok(InterStreamEvent::Chunk(content))));
}
"content_block_stop" => {
continue;
}
// -- END MESSAGE
"message_stop" => {
// Make sure we do not poll the EventSource anymore on the next poll.
// NOTE: This way, the last MessageStop event is still sent,
// but then, on the next poll, it will be stopped.
self.done = true;
return Poll::Ready(Some(Ok(InterStreamEvent::Chunk(content))));
}
"content_block_stop" => {
continue;
}
// -- END MESSAGE
"message_stop" => {
// Make sure we do not poll the EventSource anymore on the next poll.
// NOTE: This way, the last MessageStop event is still sent,
// but then, on the next poll, it will be stopped.
self.done = true;
// Capture the usage
let captured_usage = if self.options.capture_usage {
self.captured_data.usage.take().map(|mut usage| {
// Compute the total if any of input/output are not null
if usage.input_tokens.is_some() || usage.output_tokens.is_some() {
usage.total_tokens =
Some(usage.input_tokens.unwrap_or(0) + usage.output_tokens.unwrap_or(0));
}
usage
})
} else {
None
};
// Capture the usage
let captured_usage = if self.options.capture_usage {
self.captured_data.usage.take().map(|mut usage| {
// Compute the total if any of input/output are not null
if usage.input_tokens.is_some() || usage.output_tokens.is_some() {
usage.total_tokens =
Some(usage.input_tokens.unwrap_or(0) + usage.output_tokens.unwrap_or(0));
}
usage
})
} else {
None
};
let inter_stream_end = InterStreamEnd {
captured_usage,
captured_content: self.captured_data.content.take(),
};
let inter_stream_end = InterStreamEnd {
captured_usage,
captured_content: self.captured_data.content.take(),
};
// TODO: Need to capture the data as needed
return Poll::Ready(Some(Ok(InterStreamEvent::End(inter_stream_end))));
}
// TODO: Need to capture the data as needed
return Poll::Ready(Some(Ok(InterStreamEvent::End(inter_stream_end))));
}
"ping" => continue, // Loop to the next event
other => println!("UNKNOWN MESSAGE TYPE: {other}"),
}
}
Some(Err(err)) => {
println!("Error: {}", err);
return Poll::Ready(Some(Err(Error::ReqwestEventSource(err))));
}
None => return Poll::Ready(None),
}
}
Poll::Pending
}
"ping" => continue, // Loop to the next event
other => println!("UNKNOWN MESSAGE TYPE: {other}"),
}
}
Some(Err(err)) => {
println!("Error: {}", err);
return Poll::Ready(Some(Err(Error::ReqwestEventSource(err))));
}
None => return Poll::Ready(None),
}
}
Poll::Pending
}
}
// Support
impl AnthropicStreamer {
fn capture_usage(&mut self, message_type: &str, message_data: &str) -> Result<()> {
if self.options.capture_usage {
let data = self.parse_message_data(message_data)?;
// TODO: Might want to exit early if usage is not found
fn capture_usage(&mut self, message_type: &str, message_data: &str) -> Result<()> {
if self.options.capture_usage {
let data = self.parse_message_data(message_data)?;
// TODO: Might want to exit early if usage is not found
let (input_path, output_path) = if message_type == "message_start" {
("/message/usage/input_tokens", "/message/usage/output_tokens")
} else if message_type == "message_delta" {
("/usage/input_tokens", "/usage/output_tokens")
} else {
// TODO: Use tracing
println!(
"TRACING DEBUG - Anthropic message type not supported for input/output tokens: {message_type}"
);
return Ok(()); // For now permissive
};
let (input_path, output_path) = if message_type == "message_start" {
("/message/usage/input_tokens", "/message/usage/output_tokens")
} else if message_type == "message_delta" {
("/usage/input_tokens", "/usage/output_tokens")
} else {
// TODO: Use tracing
println!(
"TRACING DEBUG - Anthropic message type not supported for input/output tokens: {message_type}"
);
return Ok(()); // For now permissive
};
// -- Capture/Add the eventual input_tokens
// NOTE: Permissive on this one, if error, treat as nonexistent (for now)
if let Ok(input_tokens) = data.x_get::<i32>(input_path) {
let val = self
.captured_data
.usage
.get_or_insert(MetaUsage::default())
.input_tokens
.get_or_insert(0);
*val += input_tokens;
}
// -- Capture/Add the eventual input_tokens
// NOTE: Permissive on this one, if error, treat as nonexistent (for now)
if let Ok(input_tokens) = data.x_get::<i32>(input_path) {
let val = self
.captured_data
.usage
.get_or_insert(MetaUsage::default())
.input_tokens
.get_or_insert(0);
*val += input_tokens;
}
if let Ok(output_tokens) = data.x_get::<i32>(output_path) {
let val = self
.captured_data
.usage
.get_or_insert(MetaUsage::default())
.output_tokens
.get_or_insert(0);
*val += output_tokens;
}
}
if let Ok(output_tokens) = data.x_get::<i32>(output_path) {
let val = self
.captured_data
.usage
.get_or_insert(MetaUsage::default())
.output_tokens
.get_or_insert(0);
*val += output_tokens;
}
}
Ok(())
}
Ok(())
}
/// Simple wrapper for now, with the corresponding map_err.
/// Might have more logic later.
fn parse_message_data(&self, payload: &str) -> Result<Value> {
serde_json::from_str(payload).map_err(|serde_error| Error::StreamParse {
model_iden: self.options.model_iden.clone(),
serde_error,
})
}
}
/// Simple wrapper for now, with the corresponding map_err.
/// Might have more logic later.
fn parse_message_data(&self, payload: &str) -> Result<Value> {
serde_json::from_str(payload).map_err(|serde_error| Error::StreamParse {
model_iden: self.options.model_iden.clone(),
serde_error,
})
}
}

View File

@ -10,4 +10,4 @@ mod streamer;
pub use adapter_impl::*;
pub use streamer::*;
// endregion: --- Modules
// endregion: --- Modules

View File

@ -136,4 +136,4 @@ impl futures::Stream for CohereStreamer {
}
Poll::Pending
}
}
}

View File

@ -232,4 +232,4 @@ struct GeminiChatRequestParts {
contents: Vec<Value>,
}
// endregion: --- Support
// endregion: --- Support

View File

@ -10,4 +10,4 @@ mod streamer;
pub use adapter_impl::*;
pub use streamer::*;
// endregion: --- Modules
// endregion: --- Modules

View File

@ -125,4 +125,4 @@ impl futures::Stream for GeminiStreamer {
}
Poll::Pending
}
}
}

View File

@ -60,4 +60,4 @@ impl Adapter for GroqAdapter {
) -> Result<ChatStreamResponse> {
OpenAIAdapter::to_chat_stream(model_iden, reqwest_builder, options_set)
}
}
}

View File

@ -8,4 +8,4 @@ mod adapter_impl;
pub use adapter_impl::*;
// endregion: --- Modules
// endregion: --- Modules

View File

@ -77,4 +77,4 @@ impl Adapter for OllamaAdapter {
) -> Result<ChatStreamResponse> {
OpenAIAdapter::to_chat_stream(model_iden, reqwest_builder, options_set)
}
}
}

View File

@ -6,6 +6,6 @@
mod adapter_impl;
pub use adapter_impl::*;
pub use adapter_impl::*;
// endregion: --- Modules
// endregion: --- Modules

View File

@ -257,4 +257,4 @@ struct OpenAIRequestParts {
messages: Vec<Value>,
}
// endregion: --- Support
// endregion: --- Support

View File

@ -10,4 +10,4 @@ mod streamer;
pub use adapter_impl::*;
pub use streamer::*;
// endregion: --- Modules
// endregion: --- Modules

View File

@ -134,4 +134,4 @@ impl futures::Stream for OpenAIStreamer {
}
Poll::Pending
}
}
}

View File

@ -33,4 +33,4 @@ pub struct StreamerCapturedData {
pub content: Option<String>,
}
// endregion: --- Streamer Captured Data
// endregion: --- Streamer Captured Data

View File

@ -22,4 +22,4 @@ pub enum InterStreamEvent {
Start,
Chunk(String),
End(InterStreamEnd),
}
}

View File

@ -27,4 +27,4 @@ pub use adapter_kind::*;
// -- Crate modules
pub(crate) mod inter_stream;
// endregion: --- Modules
// endregion: --- Modules

View File

@ -45,4 +45,4 @@ pub fn get_api_key(model_iden: ModelIden, client_config: &ClientConfig) -> Resul
.to_string();
Ok(key)
}
}

View File

@ -166,4 +166,4 @@ impl ChatOptionsSet<'_, '_> {
}
}
// endregion: --- ChatOptionsSet
// endregion: --- ChatOptionsSet

View File

@ -159,4 +159,4 @@ pub struct ToolExtra {
tool_id: String,
}
// endregion: --- ChatMessage
// endregion: --- ChatMessage

View File

@ -65,4 +65,4 @@ pub struct MetaUsage {
pub total_tokens: Option<i32>,
}
// endregion: --- MetaUsage
// endregion: --- MetaUsage

View File

@ -51,4 +51,4 @@ impl JsonSpec {
self.description = Some(description.into());
self
}
}
}

View File

@ -96,4 +96,4 @@ impl From<InterStreamEnd> for StreamEnd {
}
}
// endregion: --- ChatStreamEvent
// endregion: --- ChatStreamEvent

View File

@ -85,4 +85,4 @@ where
// Remote(Url),
// Base64(String)
// }
// ```
// ```

View File

@ -22,4 +22,4 @@ pub use tool::*;
pub mod printer;
// endregion: --- Modules
// endregion: --- Modules

View File

@ -134,4 +134,4 @@ impl core::fmt::Display for Error {
impl std::error::Error for Error {}
// endregion: --- Error Boilerplate
// endregion: --- Error Boilerplate

View File

@ -11,4 +11,4 @@ pub struct Tool {
fn_name: String,
fn_description: String,
params: Value,
}
}

View File

@ -81,4 +81,4 @@ impl ClientBuilder {
};
Client { inner: Arc::new(inner) }
}
}
}

View File

@ -109,4 +109,4 @@ impl Client {
Ok(res)
}
}
}

View File

@ -53,4 +53,4 @@ pub(super) struct ClientInner {
pub(super) config: ClientConfig,
}
// endregion: --- ClientInner
// endregion: --- ClientInner

View File

@ -46,4 +46,4 @@ impl ClientConfig {
pub fn chat_options(&self) -> Option<&ChatOptions> {
self.chat_options.as_ref()
}
}
}

View File

@ -9,4 +9,4 @@ pub use builder::*;
pub use client_types::*;
pub use config::*;
// endregion: --- Modules
// endregion: --- Modules

View File

@ -6,4 +6,4 @@ mod model_name;
pub use model_iden::*;
pub use model_name::*;
// endregion: --- Modules
// endregion: --- Modules

View File

@ -35,4 +35,4 @@ where
model_name: model_name.into(),
}
}
}
}

View File

@ -50,4 +50,4 @@ impl Deref for ModelName {
fn deref(&self) -> &Self::Target {
&self.0
}
}
}

View File

@ -97,4 +97,4 @@ impl core::fmt::Display for Error {
impl std::error::Error for Error {}
// endregion: --- Error Boilerplate
// endregion: --- Error Boilerplate

View File

@ -18,4 +18,4 @@ pub mod chat;
pub mod resolver;
pub mod webc;
// endregion: --- Modules
// endregion: --- Modules

View File

@ -176,4 +176,4 @@ impl std::fmt::Debug for AuthData {
}
}
// endregion: --- AuthData Std Impls
// endregion: --- AuthData Std Impls

View File

@ -6,28 +6,28 @@ pub type Result<T> = core::result::Result<T, Error>;
/// Resolver error type.
#[derive(Debug, From)]
pub enum Error {
/// The API key environment variable was not found.
ApiKeyEnvNotFound {
/// The name of the environment variable.
env_name: String,
},
/// The API key environment variable was not found.
ApiKeyEnvNotFound {
/// The name of the environment variable.
env_name: String,
},
/// The `AuthData` is not a single value.
ResolverAuthDataNotSingleValue,
/// The `AuthData` is not a single value.
ResolverAuthDataNotSingleValue,
/// Custom error message.
#[from]
Custom(String),
/// Custom error message.
#[from]
Custom(String),
}
// region: --- Error Boilerplate
impl core::fmt::Display for Error {
fn fmt(&self, fmt: &mut core::fmt::Formatter) -> core::result::Result<(), core::fmt::Error> {
write!(fmt, "{self:?}")
}
fn fmt(&self, fmt: &mut core::fmt::Formatter) -> core::result::Result<(), core::fmt::Error> {
write!(fmt, "{self:?}")
}
}
impl std::error::Error for Error {}
// endregion: --- Error Boilerplate
// endregion: --- Error Boilerplate

View File

@ -13,4 +13,4 @@ pub use auth_resolver::*;
pub use error::{Error, Result};
pub use model_mapper::*;
// endregion: --- Modules
// endregion: --- Modules

View File

@ -95,4 +95,4 @@ where
}
}
// endregion: --- IntoModelMapperFn
// endregion: --- IntoModelMapperFn

View File

@ -38,4 +38,4 @@ impl core::fmt::Display for Error {
impl std::error::Error for Error {}
// endregion: --- Error Boilerplate
// endregion: --- Error Boilerplate

View File

@ -14,4 +14,4 @@ pub(crate) use web_stream::*;
// only public for external use
pub use error::Error;
// endregion: --- Modules
// endregion: --- Modules

View File

@ -111,4 +111,4 @@ impl WebResponse {
}
}
// endregion: --- WebResponse
// endregion: --- WebResponse

View File

@ -15,150 +15,150 @@ use std::task::{Context, Poll};
/// - It is the responsibility of the user of this stream to wrap it into a semantically correct stream of events depending on the domain.
#[allow(clippy::type_complexity)]
pub struct WebStream {
stream_mode: StreamMode,
reqwest_builder: Option<RequestBuilder>,
response_future: Option<Pin<Box<dyn Future<Output = Result<Response, Box<dyn Error>>> + Send>>>,
bytes_stream: Option<Pin<Box<dyn Stream<Item = Result<Bytes, Box<dyn Error>>> + Send>>>,
// If a poll was a partial message, then we kept the previous part
partial_message: Option<String>,
// If a poll retrieved multiple messages, we keep them to be sent in the next poll
remaining_messages: Option<VecDeque<String>>,
stream_mode: StreamMode,
reqwest_builder: Option<RequestBuilder>,
response_future: Option<Pin<Box<dyn Future<Output = Result<Response, Box<dyn Error>>> + Send>>>,
bytes_stream: Option<Pin<Box<dyn Stream<Item = Result<Bytes, Box<dyn Error>>> + Send>>>,
// If a poll was a partial message, then we kept the previous part
partial_message: Option<String>,
// If a poll retrieved multiple messages, we keep them to be sent in the next poll
remaining_messages: Option<VecDeque<String>>,
}
pub enum StreamMode {
// This is used for Cohere with a single `\n`
Delimiter(&'static str),
// This is for Gemini (standard JSON array, pretty formatted)
PrettyJsonArray,
// This is used for Cohere with a single `\n`
Delimiter(&'static str),
// This is for Gemini (standard JSON array, pretty formatted)
PrettyJsonArray,
}
impl WebStream {
pub fn new_with_delimiter(reqwest_builder: RequestBuilder, message_delimiter: &'static str) -> Self {
Self {
stream_mode: StreamMode::Delimiter(message_delimiter),
reqwest_builder: Some(reqwest_builder),
response_future: None,
bytes_stream: None,
partial_message: None,
remaining_messages: None,
}
}
pub fn new_with_delimiter(reqwest_builder: RequestBuilder, message_delimiter: &'static str) -> Self {
Self {
stream_mode: StreamMode::Delimiter(message_delimiter),
reqwest_builder: Some(reqwest_builder),
response_future: None,
bytes_stream: None,
partial_message: None,
remaining_messages: None,
}
}
pub fn new_with_pretty_json_array(reqwest_builder: RequestBuilder) -> Self {
Self {
stream_mode: StreamMode::PrettyJsonArray,
reqwest_builder: Some(reqwest_builder),
response_future: None,
bytes_stream: None,
partial_message: None,
remaining_messages: None,
}
}
pub fn new_with_pretty_json_array(reqwest_builder: RequestBuilder) -> Self {
Self {
stream_mode: StreamMode::PrettyJsonArray,
reqwest_builder: Some(reqwest_builder),
response_future: None,
bytes_stream: None,
partial_message: None,
remaining_messages: None,
}
}
}
impl Stream for WebStream {
type Item = Result<String, Box<dyn Error>>;
type Item = Result<String, Box<dyn Error>>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
// -- First, we check if we have any remaining messages to send.
if let Some(ref mut remaining_messages) = this.remaining_messages {
if let Some(msg) = remaining_messages.pop_front() {
return Poll::Ready(Some(Ok(msg)));
}
}
// -- First, we check if we have any remaining messages to send.
if let Some(ref mut remaining_messages) = this.remaining_messages {
if let Some(msg) = remaining_messages.pop_front() {
return Poll::Ready(Some(Ok(msg)));
}
}
// -- Then execute the web poll and processing loop
loop {
if let Some(ref mut fut) = this.response_future {
match Pin::new(fut).poll(cx) {
Poll::Ready(Ok(response)) => {
let bytes_stream = response.bytes_stream().map_err(|e| Box::new(e) as Box<dyn Error>);
this.bytes_stream = Some(Box::pin(bytes_stream));
this.response_future = None;
}
Poll::Ready(Err(e)) => {
this.response_future = None;
return Poll::Ready(Some(Err(e)));
}
Poll::Pending => return Poll::Pending,
}
}
// -- Then execute the web poll and processing loop
loop {
if let Some(ref mut fut) = this.response_future {
match Pin::new(fut).poll(cx) {
Poll::Ready(Ok(response)) => {
let bytes_stream = response.bytes_stream().map_err(|e| Box::new(e) as Box<dyn Error>);
this.bytes_stream = Some(Box::pin(bytes_stream));
this.response_future = None;
}
Poll::Ready(Err(e)) => {
this.response_future = None;
return Poll::Ready(Some(Err(e)));
}
Poll::Pending => return Poll::Pending,
}
}
if let Some(ref mut stream) = this.bytes_stream {
match stream.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(bytes))) => {
let buff_string = match String::from_utf8(bytes.to_vec()) {
Ok(s) => s,
Err(e) => return Poll::Ready(Some(Err(Box::new(e) as Box<dyn Error>))),
};
if let Some(ref mut stream) = this.bytes_stream {
match stream.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(bytes))) => {
let buff_string = match String::from_utf8(bytes.to_vec()) {
Ok(s) => s,
Err(e) => return Poll::Ready(Some(Err(Box::new(e) as Box<dyn Error>))),
};
// -- Iterate through the parts
let buff_response = match this.stream_mode {
StreamMode::Delimiter(delimiter) => {
process_buff_string_delimited(buff_string, &mut this.partial_message, delimiter)
}
StreamMode::PrettyJsonArray => {
new_with_pretty_json_array(buff_string, &mut this.partial_message)
}
};
// -- Iterate through the parts
let buff_response = match this.stream_mode {
StreamMode::Delimiter(delimiter) => {
process_buff_string_delimited(buff_string, &mut this.partial_message, delimiter)
}
StreamMode::PrettyJsonArray => {
new_with_pretty_json_array(buff_string, &mut this.partial_message)
}
};
let BuffResponse {
mut first_message,
next_messages,
candidate_message,
} = buff_response?;
let BuffResponse {
mut first_message,
next_messages,
candidate_message,
} = buff_response?;
// -- Add next_messages as remaining messages if present
if let Some(next_messages) = next_messages {
this.remaining_messages.get_or_insert(VecDeque::new()).extend(next_messages);
}
// -- Add next_messages as remaining messages if present
if let Some(next_messages) = next_messages {
this.remaining_messages.get_or_insert(VecDeque::new()).extend(next_messages);
}
// -- If we still have a candidate, it's the partial for the next one
if let Some(candidate_message) = candidate_message {
// For now, we will just log this
if this.partial_message.is_some() {
println!("GENAI - WARNING - partial_message is not none");
}
this.partial_message = Some(candidate_message);
}
// -- If we still have a candidate, it's the partial for the next one
if let Some(candidate_message) = candidate_message {
// For now, we will just log this
if this.partial_message.is_some() {
println!("GENAI - WARNING - partial_message is not none");
}
this.partial_message = Some(candidate_message);
}
// -- If we have a first message, we have to send it.
if let Some(first_message) = first_message.take() {
return Poll::Ready(Some(Ok(first_message)));
} else {
continue;
}
}
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(None) => {
if let Some(partial) = this.partial_message.take() {
if !partial.is_empty() {
return Poll::Ready(Some(Ok(partial)));
}
}
this.bytes_stream = None;
}
Poll::Pending => return Poll::Pending,
}
}
// -- If we have a first message, we have to send it.
if let Some(first_message) = first_message.take() {
return Poll::Ready(Some(Ok(first_message)));
} else {
continue;
}
}
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(None) => {
if let Some(partial) = this.partial_message.take() {
if !partial.is_empty() {
return Poll::Ready(Some(Ok(partial)));
}
}
this.bytes_stream = None;
}
Poll::Pending => return Poll::Pending,
}
}
if let Some(reqwest_builder) = this.reqwest_builder.take() {
let fut = async move { reqwest_builder.send().await.map_err(|e| Box::new(e) as Box<dyn Error>) };
this.response_future = Some(Box::pin(fut));
continue;
}
if let Some(reqwest_builder) = this.reqwest_builder.take() {
let fut = async move { reqwest_builder.send().await.map_err(|e| Box::new(e) as Box<dyn Error>) };
this.response_future = Some(Box::pin(fut));
continue;
}
return Poll::Ready(None);
}
}
return Poll::Ready(None);
}
}
}
struct BuffResponse {
first_message: Option<String>,
next_messages: Option<Vec<String>>,
candidate_message: Option<String>,
first_message: Option<String>,
next_messages: Option<Vec<String>>,
candidate_message: Option<String>,
}
/// Process a string buffer for the pretty_json_array (for Gemini)
@ -172,99 +172,99 @@ struct BuffResponse {
/// for each array item (which seems to be the case with Gemini).
/// This probably needs to be made more robust later.
fn new_with_pretty_json_array(
buff_string: String,
_partial_message: &mut Option<String>,
buff_string: String,
_partial_message: &mut Option<String>,
) -> Result<BuffResponse, crate::Error> {
let buff_str = buff_string.trim();
let buff_str = buff_string.trim();
let mut messages: Vec<String> = Vec::new();
let mut messages: Vec<String> = Vec::new();
// -- Capture the array start/end and each eventual sub-object (assuming only one sub-object)
let (array_start, rest_str) = match buff_str.strip_prefix('[') {
Some(rest) => (Some("["), rest.trim()),
None => (None, buff_str),
};
// -- Capture the array start/end and each eventual sub-object (assuming only one sub-object)
let (array_start, rest_str) = match buff_str.strip_prefix('[') {
Some(rest) => (Some("["), rest.trim()),
None => (None, buff_str),
};
// Remove the eventual ',' prefix and suffix.
let rest_str = rest_str.strip_prefix(',').unwrap_or(rest_str);
let rest_str = rest_str.strip_suffix(',').unwrap_or(rest_str);
// Remove the eventual ',' prefix and suffix.
let rest_str = rest_str.strip_prefix(',').unwrap_or(rest_str);
let rest_str = rest_str.strip_suffix(',').unwrap_or(rest_str);
let (rest_str, array_end) = match rest_str.strip_suffix(']') {
Some(rest) => (rest.trim(), Some("]")),
None => (rest_str, None),
};
let (rest_str, array_end) = match rest_str.strip_suffix(']') {
Some(rest) => (rest.trim(), Some("]")),
None => (rest_str, None),
};
// -- Prep the BuffResponse
if let Some(array_start) = array_start {
messages.push(array_start.to_string());
}
if !rest_str.is_empty() {
messages.push(rest_str.to_string());
}
// We ignore the comma
if let Some(array_end) = array_end {
messages.push(array_end.to_string());
}
// -- Prep the BuffResponse
if let Some(array_start) = array_start {
messages.push(array_start.to_string());
}
if !rest_str.is_empty() {
messages.push(rest_str.to_string());
}
// We ignore the comma
if let Some(array_end) = array_end {
messages.push(array_end.to_string());
}
// -- Return the buf response
let first_message = if !messages.is_empty() {
Some(messages[0].to_string())
} else {
None
};
// -- Return the buf response
let first_message = if !messages.is_empty() {
Some(messages[0].to_string())
} else {
None
};
let next_messages = if messages.len() > 1 {
Some(messages[1..].to_vec())
} else {
None
};
let next_messages = if messages.len() > 1 {
Some(messages[1..].to_vec())
} else {
None
};
Ok(BuffResponse {
first_message,
next_messages,
candidate_message: None,
})
Ok(BuffResponse {
first_message,
next_messages,
candidate_message: None,
})
}
/// Process a string buffer for the delimited mode (e.g., Cohere)
fn process_buff_string_delimited(
buff_string: String,
partial_message: &mut Option<String>,
delimiter: &str,
buff_string: String,
partial_message: &mut Option<String>,
delimiter: &str,
) -> Result<BuffResponse, crate::Error> {
let mut first_message: Option<String> = None;
let mut candidate_message: Option<String> = None;
let mut next_messages: Option<Vec<String>> = None;
let mut first_message: Option<String> = None;
let mut candidate_message: Option<String> = None;
let mut next_messages: Option<Vec<String>> = None;
let parts = buff_string.split(delimiter);
let parts = buff_string.split(delimiter);
for part in parts {
// If we already have a candidate, the candidate becomes the message
if let Some(candidate_message) = candidate_message.take() {
// If candidate is empty, we skip
if !candidate_message.is_empty() {
let message = candidate_message.to_string();
if first_message.is_none() {
first_message = Some(message);
} else {
next_messages.get_or_insert_with(Vec::new).push(message);
}
} else {
continue;
}
} else {
// And then, this part becomes the candidate
if let Some(partial) = partial_message.take() {
candidate_message = Some(format!("{partial}{part}"));
} else {
candidate_message = Some(part.to_string());
}
}
}
for part in parts {
// If we already have a candidate, the candidate becomes the message
if let Some(candidate_message) = candidate_message.take() {
// If candidate is empty, we skip
if !candidate_message.is_empty() {
let message = candidate_message.to_string();
if first_message.is_none() {
first_message = Some(message);
} else {
next_messages.get_or_insert_with(Vec::new).push(message);
}
} else {
continue;
}
} else {
// And then, this part becomes the candidate
if let Some(partial) = partial_message.take() {
candidate_message = Some(format!("{partial}{part}"));
} else {
candidate_message = Some(part.to_string());
}
}
}
Ok(BuffResponse {
first_message,
next_messages,
candidate_message,
})
}
Ok(BuffResponse {
first_message,
next_messages,
candidate_message,
})
}

View File

@ -183,7 +183,10 @@ pub async fn common_test_chat_stream_simple_ok(model: &str) -> Result<()> {
let stream_end = extract_stream_end(chat_res.stream).await?;
// -- Check no meta_usage and captured_content
assert!(stream_end.captured_usage.is_none(), "StreamEnd should not have any meta_usage");
assert!(
stream_end.captured_usage.is_none(),
"StreamEnd should not have any meta_usage"
);
assert!(
stream_end.captured_content.is_none(),
"StreamEnd should not have any captured_content"
@ -207,7 +210,10 @@ pub async fn common_test_chat_stream_capture_content_ok(model: &str) -> Result<(
// -- Check meta_usage
// Should be None as not captured
assert!(stream_end.captured_usage.is_none(), "StreamEnd should not have any meta_usage");
assert!(
stream_end.captured_usage.is_none(),
"StreamEnd should not have any meta_usage"
);
// -- Check captured_content
let captured_content = get_option_value!(stream_end.captured_content);
@ -277,4 +283,4 @@ pub async fn common_test_resolver_auth_ok(model: &str, auth_data: AuthData) -> R
Ok(())
}
// endregion: --- With Resolvers
// endregion: --- With Resolvers

View File

@ -36,4 +36,4 @@ pub async fn extract_stream_end(mut chat_stream: ChatStream) -> Result<StreamEnd
}
stream_end.ok_or("Should have a StreamEnd event".into())
}
}

View File

@ -15,4 +15,4 @@ pub mod common_tests;
pub type Result<T> = core::result::Result<T, Box<dyn std::error::Error>>;
// endregion: --- Modules
// endregion: --- Modules

View File

@ -6,4 +6,4 @@ pub fn seed_chat_req_simple() -> ChatRequest {
ChatMessage::system("Answer in one sentence"),
ChatMessage::user("Why is the sky blue?"),
])
}
}

View File

@ -52,4 +52,4 @@ async fn test_resolver_auth_ok() -> Result<()> {
common_tests::common_test_resolver_auth_ok(MODEL, AuthData::from_env("ANTHROPIC_API_KEY")).await
}
// endregion: --- Resolver Tests
// endregion: --- Resolver Tests

View File

@ -47,4 +47,4 @@ async fn test_resolver_auth_ok() -> Result<()> {
common_tests::common_test_resolver_auth_ok(MODEL, AuthData::from_env("COHERE_API_KEY")).await
}
// endregion: --- Resolver Tests
// endregion: --- Resolver Tests

View File

@ -11,17 +11,17 @@ const MODEL: &str = "gemini-1.5-flash-latest";
#[tokio::test]
async fn test_chat_simple_ok() -> Result<()> {
common_tests::common_test_chat_simple_ok(MODEL).await
common_tests::common_test_chat_simple_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_json_structured_ok() -> Result<()> {
common_tests::common_test_chat_json_structured_ok(MODEL, true).await
common_tests::common_test_chat_json_structured_ok(MODEL, true).await
}
#[tokio::test]
async fn test_chat_temperature_ok() -> Result<()> {
common_tests::common_test_chat_temperature_ok(MODEL).await
common_tests::common_test_chat_temperature_ok(MODEL).await
}
// endregion: --- Chat
@ -30,17 +30,17 @@ async fn test_chat_temperature_ok() -> Result<()> {
#[tokio::test]
async fn test_chat_stream_simple_ok() -> Result<()> {
common_tests::common_test_chat_stream_simple_ok(MODEL).await
common_tests::common_test_chat_stream_simple_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_stream_capture_content_ok() -> Result<()> {
common_tests::common_test_chat_stream_capture_content_ok(MODEL).await
common_tests::common_test_chat_stream_capture_content_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_stream_capture_all_ok() -> Result<()> {
common_tests::common_test_chat_stream_capture_all_ok(MODEL).await
common_tests::common_test_chat_stream_capture_all_ok(MODEL).await
}
// endregion: --- Chat Stream Tests
@ -49,7 +49,7 @@ async fn test_chat_stream_capture_all_ok() -> Result<()> {
#[tokio::test]
async fn test_resolver_auth_ok() -> Result<()> {
common_tests::common_test_resolver_auth_ok(MODEL, AuthData::from_env("GEMINI_API_KEY")).await
common_tests::common_test_resolver_auth_ok(MODEL, AuthData::from_env("GEMINI_API_KEY")).await
}
// endregion: --- Resolver Tests
// endregion: --- Resolver Tests

View File

@ -54,4 +54,4 @@ async fn test_resolver_auth_ok() -> Result<()> {
common_tests::common_test_resolver_auth_ok(MODEL, AuthData::from_env("GROQ_API_KEY")).await
}
// endregion: --- Resolver Tests
// endregion: --- Resolver Tests

View File

@ -54,4 +54,4 @@ async fn test_resolver_auth_ok() -> Result<()> {
common_tests::common_test_resolver_auth_ok(MODEL, AuthData::from_single("ollama")).await
}
// endregion: --- Resolver Tests
// endregion: --- Resolver Tests

View File

@ -11,22 +11,22 @@ const MODEL: &str = "gpt-4o-mini";
#[tokio::test]
async fn test_chat_simple_ok() -> Result<()> {
common_tests::common_test_chat_simple_ok(MODEL).await
common_tests::common_test_chat_simple_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_json_mode_ok() -> Result<()> {
common_tests::common_test_chat_json_mode_ok(MODEL, true).await
common_tests::common_test_chat_json_mode_ok(MODEL, true).await
}
#[tokio::test]
async fn test_chat_json_structured_ok() -> Result<()> {
common_tests::common_test_chat_json_structured_ok(MODEL, true).await
common_tests::common_test_chat_json_structured_ok(MODEL, true).await
}
#[tokio::test]
async fn test_chat_temperature_ok() -> Result<()> {
common_tests::common_test_chat_temperature_ok(MODEL).await
common_tests::common_test_chat_temperature_ok(MODEL).await
}
// endregion: --- Chat
@ -35,17 +35,17 @@ async fn test_chat_temperature_ok() -> Result<()> {
#[tokio::test]
async fn test_chat_stream_simple_ok() -> Result<()> {
common_tests::common_test_chat_stream_simple_ok(MODEL).await
common_tests::common_test_chat_stream_simple_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_stream_capture_content_ok() -> Result<()> {
common_tests::common_test_chat_stream_capture_content_ok(MODEL).await
common_tests::common_test_chat_stream_capture_content_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_stream_capture_all_ok() -> Result<()> {
common_tests::common_test_chat_stream_capture_all_ok(MODEL).await
common_tests::common_test_chat_stream_capture_all_ok(MODEL).await
}
// endregion: --- Chat Stream Tests
@ -54,7 +54,7 @@ async fn test_chat_stream_capture_all_ok() -> Result<()> {
#[tokio::test]
async fn test_resolver_auth_ok() -> Result<()> {
common_tests::common_test_resolver_auth_ok(MODEL, AuthData::from_env("OPENAI_API_KEY")).await
common_tests::common_test_resolver_auth_ok(MODEL, AuthData::from_env("OPENAI_API_KEY")).await
}
// endregion: --- Resolver Tests
// endregion: --- Resolver Tests