Merge pull request #166 from dongri/add-models-api

Add models api
Dongri Jin
2025-05-17 09:59:09 +09:00
committed by GitHub
8 changed files with 78 additions and 16 deletions
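
Besides wiring the new /v1/models endpoints (list_models, retrieve_model, delete_model) into the client and adding src/v1/model.rs, the commit moves DeletionStatus into v1::common and replaces the client's public headers field with a new public response_headers field, which the two example diffs directly below reflect. Listing and retrieving models is covered by the new examples/model.rs; delete_model is not, so here is a minimal sketch of that call, assuming the same builder setup as the bundled example and a hypothetical fine-tuned model id (the API only lets you delete models you own):

use openai_api_rs::v1::api::OpenAIClient;
use std::env;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
    let mut client = OpenAIClient::builder().with_api_key(api_key).build()?;

    // "ft:gpt-4o-mini:my-org::abc123" is a placeholder; substitute a fine-tuned model you own.
    let status = client
        .delete_model("ft:gpt-4o-mini:my-org::abc123".to_string())
        .await?;
    println!("Deleted {} ({}): {}", status.id, status.object, status.deleted);

    Ok(())
}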

View File

@@ -23,7 +23,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     println!("Content: {:?}", result.choices[0].message.content);
     // print response headers
-    for (key, value) in client.headers.unwrap().iter() {
+    for (key, value) in client.response_headers.unwrap().iter() {
         println!("{}: {:?}", key, value);
     }

examples/model.rs (new file)

@@ -0,0 +1,23 @@
use openai_api_rs::v1::api::OpenAIClient;
use std::env;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
    let mut client = OpenAIClient::builder().with_api_key(api_key).build()?;

    let result = client.list_models().await?;
    let models = result.data;

    for model in models {
        println!("Model id: {:?}", model.id);
    }

    let result = client.retrieve_model("gpt-4.1".to_string()).await?;
    println!("Model id: {:?}", result.id);
    println!("Model object: {:?}", result.object);

    Ok(())
}

// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example model

View File

@@ -24,7 +24,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let result = client.chat_completion(req).await?;
     println!("Content: {:?}", result.choices[0].message.content);
-    println!("Response Headers: {:?}", client.headers);
+    println!("Response Headers: {:?}", client.response_headers);

     Ok(())
 }

src/v1/api.rs

@@ -1,6 +1,6 @@
 use crate::v1::assistant::{
-    AssistantFileObject, AssistantFileRequest, AssistantObject, AssistantRequest, DeletionStatus,
-    ListAssistant, ListAssistantFile,
+    AssistantFileObject, AssistantFileRequest, AssistantObject, AssistantRequest, ListAssistant,
+    ListAssistantFile,
 };
 use crate::v1::audio::{
     AudioSpeechRequest, AudioSpeechResponse, AudioTranscriptionRequest, AudioTranscriptionResponse,
@@ -30,6 +30,7 @@ use crate::v1::message::{
     CreateMessageRequest, ListMessage, ListMessageFile, MessageFileObject, MessageObject,
     ModifyMessageRequest,
 };
+use crate::v1::model::{ModelResponse, ModelsResponse};
 use crate::v1::moderation::{CreateModerationRequest, CreateModerationResponse};
 use crate::v1::run::{
     CreateRunRequest, CreateThreadAndRunRequest, ListRun, ListRunStep, ModifyRunRequest, RunObject,
@@ -70,7 +71,8 @@ pub struct OpenAIClient {
     organization: Option<String>,
     proxy: Option<String>,
     timeout: Option<u64>,
-    pub headers: Option<HeaderMap>,
+    headers: Option<HeaderMap>,
+    pub response_headers: Option<HeaderMap>,
 }

 impl OpenAIClientBuilder {
@@ -124,6 +126,7 @@ impl OpenAIClientBuilder {
             proxy: self.proxy,
             timeout: self.timeout,
             headers: self.headers,
+            response_headers: None,
         })
     }
 }
@@ -237,7 +240,7 @@ impl OpenAIClient {
         let text = response.text().await.unwrap_or_else(|_| "".to_string());
         match serde_json::from_str::<T>(&text) {
             Ok(parsed) => {
-                self.headers = Some(headers);
+                self.response_headers = Some(headers);
                 Ok(parsed)
             }
             Err(e) => Err(APIError::CustomError {
@@ -507,7 +510,7 @@ impl OpenAIClient {
     pub async fn delete_assistant(
         &mut self,
         assistant_id: String,
-    ) -> Result<DeletionStatus, APIError> {
+    ) -> Result<common::DeletionStatus, APIError> {
         self.delete(&format!("assistants/{}", assistant_id)).await
     }
@@ -544,7 +547,7 @@ impl OpenAIClient {
         &mut self,
         assistant_id: String,
         file_id: String,
-    ) -> Result<DeletionStatus, APIError> {
+    ) -> Result<common::DeletionStatus, APIError> {
         self.delete(&format!("assistants/{}/files/{}", assistant_id, file_id))
             .await
     }
@@ -586,7 +589,10 @@ impl OpenAIClient {
         self.post(&format!("threads/{}", thread_id), &req).await
     }

-    pub async fn delete_thread(&mut self, thread_id: String) -> Result<DeletionStatus, APIError> {
+    pub async fn delete_thread(
+        &mut self,
+        thread_id: String,
+    ) -> Result<common::DeletionStatus, APIError> {
         self.delete(&format!("threads/{}", thread_id)).await
     }
@@ -781,6 +787,22 @@ impl OpenAIClient {
         let url = Self::query_params(limit, None, after, None, "batches".to_string());
         self.get(&url).await
     }

+    pub async fn list_models(&mut self) -> Result<ModelsResponse, APIError> {
+        self.get("models").await
+    }
+
+    pub async fn retrieve_model(&mut self, model_id: String) -> Result<ModelResponse, APIError> {
+        self.get(&format!("models/{}", model_id)).await
+    }
+
+    pub async fn delete_model(
+        &mut self,
+        model_id: String,
+    ) -> Result<common::DeletionStatus, APIError> {
+        self.delete(&format!("models/{}", model_id)).await
+    }
+
     fn build_url_with_preserved_query(&self, path: &str) -> Result<String, url::ParseError> {
         let (base, query_opt) = match self.api_endpoint.split_once('?') {
             Some((b, q)) => (b.trim_end_matches('/'), Some(q)),
@@ -797,6 +819,7 @@ impl OpenAIClient {
         }
         Ok(url.to_string())
     }

     fn query_params(
         limit: Option<i64>,
         order: Option<String>,
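
For callers migrating from the old public client.headers field: after this change the private headers field only holds whatever custom headers the builder was given, while the headers captured from the last response are exposed through the new public response_headers field. A minimal sketch of reading them, using list_models from this same commit as the triggering call:

use openai_api_rs::v1::api::OpenAIClient;
use std::env;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
    let mut client = OpenAIClient::builder().with_api_key(api_key).build()?;

    // Any successful request populates response_headers.
    let _ = client.list_models().await?;

    if let Some(headers) = &client.response_headers {
        for (key, value) in headers.iter() {
            println!("{}: {:?}", key, value);
        }
    }

    Ok(())
}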

src/v1/assistant.rs

@@ -95,13 +95,6 @@ pub struct VectorStores {
     pub metadata: Option<HashMap<String, String>>,
 }

-#[derive(Debug, Deserialize, Serialize)]
-pub struct DeletionStatus {
-    pub id: String,
-    pub object: String,
-    pub deleted: bool,
-}
-
 #[derive(Debug, Deserialize, Serialize)]
 pub struct ListAssistant {
     pub object: String,

src/v1/common.rs

@@ -7,6 +7,13 @@ pub struct Usage {
     pub total_tokens: i32,
 }

+#[derive(Debug, Deserialize, Serialize)]
+pub struct DeletionStatus {
+    pub id: String,
+    pub object: String,
+    pub deleted: bool,
+}
+
 #[macro_export]
 macro_rules! impl_builder_methods {
     ($builder:ident, $($field:ident: $field_type:ty),*) => {
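
This is the same struct removed from src/v1/assistant.rs above; hoisting it into v1::common lets the assistant, thread, and new model delete endpoints share one return type. A quick sketch of how a deletion payload maps onto it, assuming v1::common is publicly exported like the other v1 modules and using an illustrative JSON body:

use openai_api_rs::v1::common::DeletionStatus;

fn main() -> Result<(), serde_json::Error> {
    // Illustrative body in the shape the delete endpoints return.
    let body = r#"{"id":"ft:gpt-4o-mini:my-org::abc123","object":"model","deleted":true}"#;
    let status: DeletionStatus = serde_json::from_str(body)?;
    assert!(status.deleted);
    println!("{} ({}) deleted", status.id, status.object);
    Ok(())
}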

src/v1/mod.rs

@@ -11,6 +11,7 @@ pub mod embedding;
 pub mod file;
 pub mod fine_tuning;
 pub mod image;
+pub mod model;
 pub mod moderation;

 // beta

src/v1/model.rs (new file)

@@ -0,0 +1,15 @@
use serde::{Deserialize, Serialize};

#[derive(Debug, Deserialize, Serialize)]
pub struct ModelsResponse {
    pub object: String,
    pub data: Vec<ModelResponse>,
}

#[derive(Debug, Deserialize, Serialize)]
pub struct ModelResponse {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub owned_by: String,
}