diff --git a/README.md b/README.md index 87a03a9..53d1727 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ Check out the [full API documentation](https://platform.openai.com/docs/api-refe - [x] [Edits](https://platform.openai.com/docs/api-reference/edits) - [x] [Images](https://platform.openai.com/docs/api-reference/images) - [x] [Embeddings](https://platform.openai.com/docs/api-reference/embeddings) -- [ ] [Audio](https://platform.openai.com/docs/api-reference/audio) +- [x] [Audio](https://platform.openai.com/docs/api-reference/audio) - [x] [Files](https://platform.openai.com/docs/api-reference/files) -- [ ] [Fine-tunes](https://platform.openai.com/docs/api-reference/fine-tunes) -- [ ] [Moderations](https://platform.openai.com/docs/api-reference/moderations) +- [x] [Fine-tunes](https://platform.openai.com/docs/api-reference/fine-tunes) +- [x] [Moderations](https://platform.openai.com/docs/api-reference/moderations) diff --git a/src/v1/api.rs b/src/v1/api.rs index 7463756..93bf9d9 100644 --- a/src/v1/api.rs +++ b/src/v1/api.rs @@ -1,3 +1,7 @@ +use crate::v1::audio::{ + AudioTranscriptionRequest, AudioTranscriptionResponse, AudioTranslationRequest, + AudioTranslationResponse, +}; use crate::v1::chat_completion::{ChatCompletionRequest, ChatCompletionResponse}; use crate::v1::completion::{CompletionRequest, CompletionResponse}; use crate::v1::edit::{EditRequest, EditResponse}; @@ -8,10 +12,18 @@ use crate::v1::file::{ FileRetrieveContentResponse, FileRetrieveRequest, FileRetrieveResponse, FileUploadRequest, FileUploadResponse, }; +use crate::v1::fine_tune::{ + CancelFineTuneRequest, CancelFineTuneResponse, CreateFineTuneRequest, CreateFineTuneResponse, + DeleteFineTuneModelRequest, DeleteFineTuneModelResponse, ListFineTuneEventsRequest, + ListFineTuneEventsResponse, ListFineTuneResponse, RetrieveFineTuneRequest, + RetrieveFineTuneResponse, +}; use crate::v1::image::{ ImageEditRequest, ImageEditResponse, ImageGenerationRequest, ImageGenerationResponse, ImageVariationRequest, ImageVariationResponse, }; +use crate::v1::moderation::{CreateModerationRequest, CreateModerationResponse}; + use reqwest::Response; const APU_URL_V1: &str = "https://api.openai.com/v1"; @@ -232,6 +244,117 @@ impl Client { } } + pub async fn audio_transcription( + &self, + req: AudioTranscriptionRequest, + ) -> Result { + let res = self.post("/audio/transcriptions", &req).await?; + let r = res.json::().await; + match r { + Ok(r) => Ok(r), + Err(e) => Err(self.new_error(e)), + } + } + + pub async fn audio_translation( + &self, + req: AudioTranslationRequest, + ) -> Result { + let res = self.post("/audio/translations", &req).await?; + let r = res.json::().await; + match r { + Ok(r) => Ok(r), + Err(e) => Err(self.new_error(e)), + } + } + + pub async fn create_fine_tune( + &self, + req: CreateFineTuneRequest, + ) -> Result { + let res = self.post("/fine-tunes", &req).await?; + let r = res.json::().await; + match r { + Ok(r) => Ok(r), + Err(e) => Err(self.new_error(e)), + } + } + + pub async fn list_fine_tune(&self) -> Result { + let res = self.get("/fine-tunes").await?; + let r = res.json::().await; + match r { + Ok(r) => Ok(r), + Err(e) => Err(self.new_error(e)), + } + } + + pub async fn retrieve_fine_tune( + &self, + req: RetrieveFineTuneRequest, + ) -> Result { + let res = self + .get(&format!("/fine_tunes/{}", req.fine_tune_id)) + .await?; + let r = res.json::().await; + match r { + Ok(r) => Ok(r), + Err(e) => Err(self.new_error(e)), + } + } + + pub async fn cancel_fine_tune( + &self, + req: CancelFineTuneRequest, + ) -> Result { + let res = self + .post(&format!("/fine_tunes/{}/cancel", req.fine_tune_id), &req) + .await?; + let r = res.json::().await; + match r { + Ok(r) => Ok(r), + Err(e) => Err(self.new_error(e)), + } + } + + pub async fn list_fine_tune_events( + &self, + req: ListFineTuneEventsRequest, + ) -> Result { + let res = self + .get(&format!("/fine-tunes/{}/events", req.fine_tune_id)) + .await?; + let r = res.json::().await; + match r { + Ok(r) => Ok(r), + Err(e) => Err(self.new_error(e)), + } + } + + pub async fn delete_fine_tune( + &self, + req: DeleteFineTuneModelRequest, + ) -> Result { + let res = self.delete(&format!("/models/{}", req.model_id)).await?; + let r = res.json::().await; + match r { + Ok(r) => Ok(r), + Err(e) => Err(self.new_error(e)), + } + } + + pub async fn create_moderation( + &self, + req: CreateModerationRequest, + ) -> Result { + let res = self.post("/moderations", &req).await?; + let r = res.json::().await; + match r { + Ok(r) => Ok(r), + Err(e) => Err(self.new_error(e)), + } + } + fn new_error(&self, err: reqwest::Error) -> APIError { APIError { message: err.to_string(), diff --git a/src/v1/audio.rs b/src/v1/audio.rs new file mode 100644 index 0000000..1948dd3 --- /dev/null +++ b/src/v1/audio.rs @@ -0,0 +1,39 @@ +use serde::{Deserialize, Serialize}; + +pub const WHISPER_1: &str = "whisper-1"; + +#[derive(Debug, Serialize)] +pub struct AudioTranscriptionRequest { + pub model: String, + pub file: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub language: Option, +} + +#[derive(Debug, Deserialize)] +pub struct AudioTranscriptionResponse { + pub text: String, +} + +#[derive(Debug, Serialize)] +pub struct AudioTranslationRequest { + pub model: String, + pub file: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, +} + +#[derive(Debug, Deserialize)] +pub struct AudioTranslationResponse { + pub text: String, +} diff --git a/src/v1/fine_tune.rs b/src/v1/fine_tune.rs new file mode 100644 index 0000000..25d9c7e --- /dev/null +++ b/src/v1/fine_tune.rs @@ -0,0 +1,197 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize)] +pub struct CreateFineTuneRequest { + pub training_file: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub validation_file: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub n_epochs: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub batch_size: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub learning_rate_multiplier: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_loss_weight: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub compute_classification_metrics: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub classification_n_classes: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub classification_positive_class: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub classification_betas: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub suffix: Option, +} + +#[derive(Debug, Deserialize)] +pub struct CreateFineTuneResponse { + pub id: String, + pub object: String, + pub model: String, + pub created_at: i64, + pub events: Vec, + pub fine_tuned_model: Option, + pub hyperparams: HyperParams, + pub organization_id: String, + pub result_files: Vec, + pub status: String, + pub validation_files: Vec, + pub training_files: Vec, + pub updated_at: i64, +} + +#[derive(Debug, Deserialize)] +pub struct FineTuneEvent { + pub object: String, + pub created_at: i64, + pub level: String, + pub message: String, +} + +#[derive(Debug, Deserialize)] +pub struct FineTunedModel { + pub id: String, + pub object: String, + pub model_details: ModelDetails, +} + +#[derive(Debug, Deserialize)] +pub struct ModelDetails { + pub architecture: String, + pub created_at: i64, + pub id: String, + pub object: String, + pub prompt: String, + pub samples_seen: i64, +} + +#[derive(Debug, Deserialize)] +pub struct HyperParams { + pub batch_size: i32, + pub learning_rate_multiplier: f32, + pub n_epochs: i32, + pub prompt_loss_weight: f32, +} + +#[derive(Debug, Deserialize)] +pub struct ResultFile { + pub id: String, + pub object: String, + pub bytes: i64, + pub created_at: i64, + pub filename: String, + pub purpose: String, +} + +#[derive(Debug, Deserialize)] +pub struct ValidationFile { + pub id: String, + pub object: String, + pub bytes: i64, + pub created_at: i64, + pub filename: String, + pub purpose: String, +} + +#[derive(Debug, Deserialize)] +pub struct TrainingFile { + pub id: String, + pub object: String, + pub bytes: i64, + pub created_at: i64, + pub filename: String, + pub purpose: String, +} + +#[derive(Debug, Deserialize)] +pub struct ListFineTuneResponse { + pub object: String, + pub data: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct FineTuneData { + pub id: String, + pub object: String, + pub model: String, + pub created_at: u64, + pub fine_tuned_model: Option, + pub hyperparams: HyperParams, + pub organization_id: String, + pub result_files: Vec, + pub status: String, + pub validation_files: Vec, + pub training_files: Vec, + pub updated_at: u64, +} + +#[derive(Debug, Deserialize)] +pub struct RetrieveFineTuneRequest { + pub fine_tune_id: String, +} + +#[derive(Debug, Deserialize)] +pub struct RetrieveFineTuneResponse { + pub id: String, + pub object: String, + pub model: String, + pub created_at: i64, + pub events: Vec, + pub fine_tuned_model: Option, + pub hyperparams: HyperParams, + pub organization_id: String, + pub result_files: Vec, + pub status: String, + pub validation_files: Vec, + pub training_files: Vec, + pub updated_at: i64, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct CancelFineTuneRequest { + pub fine_tune_id: String, +} + +#[derive(Debug, Deserialize)] +pub struct CancelFineTuneResponse { + pub id: String, + pub object: String, + pub model: String, + pub created_at: i64, + pub events: Vec, + pub fine_tuned_model: Option, + pub hyperparams: HyperParams, + pub organization_id: String, + pub result_files: Vec, + pub status: String, + pub validation_files: Vec, + pub training_files: Vec, + pub updated_at: i64, +} + +#[derive(Debug, Deserialize)] +pub struct ListFineTuneEventsRequest { + pub fine_tune_id: String, +} + +#[derive(Debug, Deserialize)] +pub struct ListFineTuneEventsResponse { + pub object: String, + pub data: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct DeleteFineTuneModelRequest { + pub model_id: String, +} + +#[derive(Debug, Deserialize)] +pub struct DeleteFineTuneModelResponse { + pub id: String, + pub object: String, + pub deleted: bool, +} diff --git a/src/v1/mod.rs b/src/v1/mod.rs index c8ca2ad..718339f 100644 --- a/src/v1/mod.rs +++ b/src/v1/mod.rs @@ -1,11 +1,14 @@ pub mod common; pub mod error; +pub mod audio; pub mod chat_completion; pub mod completion; pub mod edit; pub mod embedding; pub mod file; +pub mod fine_tune; pub mod image; +pub mod moderation; pub mod api; diff --git a/src/v1/moderation.rs b/src/v1/moderation.rs new file mode 100644 index 0000000..58bc861 --- /dev/null +++ b/src/v1/moderation.rs @@ -0,0 +1,52 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize)] +pub struct CreateModerationRequest { + pub input: String, +} + +#[derive(Debug, Deserialize)] +pub struct CreateModerationResponse { + pub id: String, + pub model: String, + pub results: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct ModerationResult { + pub categories: ModerationCategories, + pub category_scores: ModerationCategoryScores, + pub flagged: bool, +} + +#[derive(Debug, Deserialize)] +pub struct ModerationCategories { + #[serde(rename = "hate")] + pub is_hate: bool, + #[serde(rename = "hate/threatening")] + pub is_hate_threatening: bool, + #[serde(rename = "self-harm")] + pub is_self_harm: bool, + pub sexual: bool, + #[serde(rename = "sexual/minors")] + pub is_sexual_minors: bool, + pub violence: bool, + #[serde(rename = "violence/graphic")] + pub is_violence_graphic: bool, +} + +#[derive(Debug, Deserialize)] +pub struct ModerationCategoryScores { + #[serde(rename = "hate")] + pub hate_score: f64, + #[serde(rename = "hate/threatening")] + pub hate_threatening_score: f64, + #[serde(rename = "self-harm")] + pub self_harm_score: f64, + pub sexual: f64, + #[serde(rename = "sexual/minors")] + pub sexual_minors_score: f64, + pub violence: f64, + #[serde(rename = "violence/graphic")] + pub violence_graphic_score: f64, +}