From 66b12711ab358700f967fc8c627b10419c3835e7 Mon Sep 17 00:00:00 2001 From: Dongri Jin Date: Sat, 17 May 2025 09:53:33 +0900 Subject: [PATCH] Add models api --- examples/chat_completion.rs | 2 +- examples/model.rs | 23 +++++++++++++++++++++++ examples/openrouter.rs | 2 +- src/v1/api.rs | 37 ++++++++++++++++++++++++++++++------- src/v1/assistant.rs | 7 ------- src/v1/common.rs | 7 +++++++ src/v1/mod.rs | 1 + src/v1/model.rs | 15 +++++++++++++++ 8 files changed, 78 insertions(+), 16 deletions(-) create mode 100644 examples/model.rs create mode 100644 src/v1/model.rs diff --git a/examples/chat_completion.rs b/examples/chat_completion.rs index 64f564e..635add5 100644 --- a/examples/chat_completion.rs +++ b/examples/chat_completion.rs @@ -23,7 +23,7 @@ async fn main() -> Result<(), Box> { println!("Content: {:?}", result.choices[0].message.content); // print response headers - for (key, value) in client.headers.unwrap().iter() { + for (key, value) in client.response_headers.unwrap().iter() { println!("{}: {:?}", key, value); } diff --git a/examples/model.rs b/examples/model.rs new file mode 100644 index 0000000..18dc109 --- /dev/null +++ b/examples/model.rs @@ -0,0 +1,23 @@ +use openai_api_rs::v1::api::OpenAIClient; +use std::env; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); + let mut client = OpenAIClient::builder().with_api_key(api_key).build()?; + + let result = client.list_models().await?; + let models = result.data; + + for model in models { + println!("Model id: {:?}", model.id); + } + + let result = client.retrieve_model("gpt-4.1".to_string()).await?; + println!("Model id: {:?}", result.id); + println!("Model object: {:?}", result.object); + + Ok(()) +} + +// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example model diff --git a/examples/openrouter.rs b/examples/openrouter.rs index dd04e0a..5295bf4 100644 --- a/examples/openrouter.rs +++ b/examples/openrouter.rs @@ -24,7 +24,7 @@ async fn main() -> Result<(), Box> { let result = client.chat_completion(req).await?; println!("Content: {:?}", result.choices[0].message.content); - println!("Response Headers: {:?}", client.headers); + println!("Response Headers: {:?}", client.response_headers); Ok(()) } diff --git a/src/v1/api.rs b/src/v1/api.rs index 458dc34..1c50695 100644 --- a/src/v1/api.rs +++ b/src/v1/api.rs @@ -1,6 +1,6 @@ use crate::v1::assistant::{ - AssistantFileObject, AssistantFileRequest, AssistantObject, AssistantRequest, DeletionStatus, - ListAssistant, ListAssistantFile, + AssistantFileObject, AssistantFileRequest, AssistantObject, AssistantRequest, ListAssistant, + ListAssistantFile, }; use crate::v1::audio::{ AudioSpeechRequest, AudioSpeechResponse, AudioTranscriptionRequest, AudioTranscriptionResponse, @@ -30,6 +30,7 @@ use crate::v1::message::{ CreateMessageRequest, ListMessage, ListMessageFile, MessageFileObject, MessageObject, ModifyMessageRequest, }; +use crate::v1::model::{ModelResponse, ModelsResponse}; use crate::v1::moderation::{CreateModerationRequest, CreateModerationResponse}; use crate::v1::run::{ CreateRunRequest, CreateThreadAndRunRequest, ListRun, ListRunStep, ModifyRunRequest, RunObject, @@ -70,7 +71,8 @@ pub struct OpenAIClient { organization: Option, proxy: Option, timeout: Option, - pub headers: Option, + headers: Option, + pub response_headers: Option, } impl OpenAIClientBuilder { @@ -124,6 +126,7 @@ impl OpenAIClientBuilder { proxy: self.proxy, timeout: self.timeout, headers: self.headers, + response_headers: None, }) } } @@ -237,7 +240,7 @@ impl OpenAIClient { let text = response.text().await.unwrap_or_else(|_| "".to_string()); match serde_json::from_str::(&text) { Ok(parsed) => { - self.headers = Some(headers); + self.response_headers = Some(headers); Ok(parsed) } Err(e) => Err(APIError::CustomError { @@ -507,7 +510,7 @@ impl OpenAIClient { pub async fn delete_assistant( &mut self, assistant_id: String, - ) -> Result { + ) -> Result { self.delete(&format!("assistants/{}", assistant_id)).await } @@ -544,7 +547,7 @@ impl OpenAIClient { &mut self, assistant_id: String, file_id: String, - ) -> Result { + ) -> Result { self.delete(&format!("assistants/{}/files/{}", assistant_id, file_id)) .await } @@ -586,7 +589,10 @@ impl OpenAIClient { self.post(&format!("threads/{}", thread_id), &req).await } - pub async fn delete_thread(&mut self, thread_id: String) -> Result { + pub async fn delete_thread( + &mut self, + thread_id: String, + ) -> Result { self.delete(&format!("threads/{}", thread_id)).await } @@ -781,6 +787,22 @@ impl OpenAIClient { let url = Self::query_params(limit, None, after, None, "batches".to_string()); self.get(&url).await } + + pub async fn list_models(&mut self) -> Result { + self.get("models").await + } + + pub async fn retrieve_model(&mut self, model_id: String) -> Result { + self.get(&format!("models/{}", model_id)).await + } + + pub async fn delete_model( + &mut self, + model_id: String, + ) -> Result { + self.delete(&format!("models/{}", model_id)).await + } + fn build_url_with_preserved_query(&self, path: &str) -> Result { let (base, query_opt) = match self.api_endpoint.split_once('?') { Some((b, q)) => (b.trim_end_matches('/'), Some(q)), @@ -797,6 +819,7 @@ impl OpenAIClient { } Ok(url.to_string()) } + fn query_params( limit: Option, order: Option, diff --git a/src/v1/assistant.rs b/src/v1/assistant.rs index 319073e..e0e45d1 100644 --- a/src/v1/assistant.rs +++ b/src/v1/assistant.rs @@ -95,13 +95,6 @@ pub struct VectorStores { pub metadata: Option>, } -#[derive(Debug, Deserialize, Serialize)] -pub struct DeletionStatus { - pub id: String, - pub object: String, - pub deleted: bool, -} - #[derive(Debug, Deserialize, Serialize)] pub struct ListAssistant { pub object: String, diff --git a/src/v1/common.rs b/src/v1/common.rs index 9540d08..ab16946 100644 --- a/src/v1/common.rs +++ b/src/v1/common.rs @@ -7,6 +7,13 @@ pub struct Usage { pub total_tokens: i32, } +#[derive(Debug, Deserialize, Serialize)] +pub struct DeletionStatus { + pub id: String, + pub object: String, + pub deleted: bool, +} + #[macro_export] macro_rules! impl_builder_methods { ($builder:ident, $($field:ident: $field_type:ty),*) => { diff --git a/src/v1/mod.rs b/src/v1/mod.rs index 5271472..d44ed31 100644 --- a/src/v1/mod.rs +++ b/src/v1/mod.rs @@ -11,6 +11,7 @@ pub mod embedding; pub mod file; pub mod fine_tuning; pub mod image; +pub mod model; pub mod moderation; // beta diff --git a/src/v1/model.rs b/src/v1/model.rs new file mode 100644 index 0000000..2b0a044 --- /dev/null +++ b/src/v1/model.rs @@ -0,0 +1,15 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Deserialize, Serialize)] +pub struct ModelsResponse { + pub object: String, + pub data: Vec, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct ModelResponse { + pub id: String, + pub object: String, + pub created: i64, + pub owned_by: String, +}