Merge pull request #5 from dongri/add-audio-fine-tunes

Add audio and fine_tunes
This commit is contained in:
Dongri Jin
2023-04-16 14:01:27 +09:00
committed by GitHub
6 changed files with 417 additions and 3 deletions

View File

@ -70,7 +70,7 @@ Check out the [full API documentation](https://platform.openai.com/docs/api-refe
- [x] [Edits](https://platform.openai.com/docs/api-reference/edits)
- [x] [Images](https://platform.openai.com/docs/api-reference/images)
- [x] [Embeddings](https://platform.openai.com/docs/api-reference/embeddings)
- [ ] [Audio](https://platform.openai.com/docs/api-reference/audio)
- [x] [Audio](https://platform.openai.com/docs/api-reference/audio)
- [x] [Files](https://platform.openai.com/docs/api-reference/files)
- [ ] [Fine-tunes](https://platform.openai.com/docs/api-reference/fine-tunes)
- [ ] [Moderations](https://platform.openai.com/docs/api-reference/moderations)
- [x] [Fine-tunes](https://platform.openai.com/docs/api-reference/fine-tunes)
- [x] [Moderations](https://platform.openai.com/docs/api-reference/moderations)

View File

@ -1,3 +1,7 @@
use crate::v1::audio::{
AudioTranscriptionRequest, AudioTranscriptionResponse, AudioTranslationRequest,
AudioTranslationResponse,
};
use crate::v1::chat_completion::{ChatCompletionRequest, ChatCompletionResponse};
use crate::v1::completion::{CompletionRequest, CompletionResponse};
use crate::v1::edit::{EditRequest, EditResponse};
@ -8,10 +12,18 @@ use crate::v1::file::{
FileRetrieveContentResponse, FileRetrieveRequest, FileRetrieveResponse, FileUploadRequest,
FileUploadResponse,
};
use crate::v1::fine_tune::{
CancelFineTuneRequest, CancelFineTuneResponse, CreateFineTuneRequest, CreateFineTuneResponse,
DeleteFineTuneModelRequest, DeleteFineTuneModelResponse, ListFineTuneEventsRequest,
ListFineTuneEventsResponse, ListFineTuneResponse, RetrieveFineTuneRequest,
RetrieveFineTuneResponse,
};
use crate::v1::image::{
ImageEditRequest, ImageEditResponse, ImageGenerationRequest, ImageGenerationResponse,
ImageVariationRequest, ImageVariationResponse,
};
use crate::v1::moderation::{CreateModerationRequest, CreateModerationResponse};
use reqwest::Response;
const APU_URL_V1: &str = "https://api.openai.com/v1";
@ -232,6 +244,117 @@ impl Client {
}
}
pub async fn audio_transcription(
&self,
req: AudioTranscriptionRequest,
) -> Result<AudioTranscriptionResponse, APIError> {
let res = self.post("/audio/transcriptions", &req).await?;
let r = res.json::<AudioTranscriptionResponse>().await;
match r {
Ok(r) => Ok(r),
Err(e) => Err(self.new_error(e)),
}
}
pub async fn audio_translation(
&self,
req: AudioTranslationRequest,
) -> Result<AudioTranslationResponse, APIError> {
let res = self.post("/audio/translations", &req).await?;
let r = res.json::<AudioTranslationResponse>().await;
match r {
Ok(r) => Ok(r),
Err(e) => Err(self.new_error(e)),
}
}
pub async fn create_fine_tune(
&self,
req: CreateFineTuneRequest,
) -> Result<CreateFineTuneResponse, APIError> {
let res = self.post("/fine-tunes", &req).await?;
let r = res.json::<CreateFineTuneResponse>().await;
match r {
Ok(r) => Ok(r),
Err(e) => Err(self.new_error(e)),
}
}
pub async fn list_fine_tune(&self) -> Result<ListFineTuneResponse, APIError> {
let res = self.get("/fine-tunes").await?;
let r = res.json::<ListFineTuneResponse>().await;
match r {
Ok(r) => Ok(r),
Err(e) => Err(self.new_error(e)),
}
}
pub async fn retrieve_fine_tune(
&self,
req: RetrieveFineTuneRequest,
) -> Result<RetrieveFineTuneResponse, APIError> {
let res = self
.get(&format!("/fine_tunes/{}", req.fine_tune_id))
.await?;
let r = res.json::<RetrieveFineTuneResponse>().await;
match r {
Ok(r) => Ok(r),
Err(e) => Err(self.new_error(e)),
}
}
pub async fn cancel_fine_tune(
&self,
req: CancelFineTuneRequest,
) -> Result<CancelFineTuneResponse, APIError> {
let res = self
.post(&format!("/fine_tunes/{}/cancel", req.fine_tune_id), &req)
.await?;
let r = res.json::<CancelFineTuneResponse>().await;
match r {
Ok(r) => Ok(r),
Err(e) => Err(self.new_error(e)),
}
}
pub async fn list_fine_tune_events(
&self,
req: ListFineTuneEventsRequest,
) -> Result<ListFineTuneEventsResponse, APIError> {
let res = self
.get(&format!("/fine-tunes/{}/events", req.fine_tune_id))
.await?;
let r = res.json::<ListFineTuneEventsResponse>().await;
match r {
Ok(r) => Ok(r),
Err(e) => Err(self.new_error(e)),
}
}
pub async fn delete_fine_tune(
&self,
req: DeleteFineTuneModelRequest,
) -> Result<DeleteFineTuneModelResponse, APIError> {
let res = self.delete(&format!("/models/{}", req.model_id)).await?;
let r = res.json::<DeleteFineTuneModelResponse>().await;
match r {
Ok(r) => Ok(r),
Err(e) => Err(self.new_error(e)),
}
}
pub async fn create_moderation(
&self,
req: CreateModerationRequest,
) -> Result<CreateModerationResponse, APIError> {
let res = self.post("/moderations", &req).await?;
let r = res.json::<CreateModerationResponse>().await;
match r {
Ok(r) => Ok(r),
Err(e) => Err(self.new_error(e)),
}
}
fn new_error(&self, err: reqwest::Error) -> APIError {
APIError {
message: err.to_string(),

39
src/v1/audio.rs Normal file
View File

@ -0,0 +1,39 @@
use serde::{Deserialize, Serialize};
pub const WHISPER_1: &str = "whisper-1";
#[derive(Debug, Serialize)]
pub struct AudioTranscriptionRequest {
pub model: String,
pub file: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub language: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct AudioTranscriptionResponse {
pub text: String,
}
#[derive(Debug, Serialize)]
pub struct AudioTranslationRequest {
pub model: String,
pub file: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
}
#[derive(Debug, Deserialize)]
pub struct AudioTranslationResponse {
pub text: String,
}

197
src/v1/fine_tune.rs Normal file
View File

@ -0,0 +1,197 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize)]
pub struct CreateFineTuneRequest {
pub training_file: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub validation_file: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n_epochs: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub batch_size: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub learning_rate_multiplier: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_loss_weight: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub compute_classification_metrics: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub classification_n_classes: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub classification_positive_class: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub classification_betas: Option<Vec<f32>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub suffix: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct CreateFineTuneResponse {
pub id: String,
pub object: String,
pub model: String,
pub created_at: i64,
pub events: Vec<FineTuneEvent>,
pub fine_tuned_model: Option<FineTunedModel>,
pub hyperparams: HyperParams,
pub organization_id: String,
pub result_files: Vec<ResultFile>,
pub status: String,
pub validation_files: Vec<ValidationFile>,
pub training_files: Vec<TrainingFile>,
pub updated_at: i64,
}
#[derive(Debug, Deserialize)]
pub struct FineTuneEvent {
pub object: String,
pub created_at: i64,
pub level: String,
pub message: String,
}
#[derive(Debug, Deserialize)]
pub struct FineTunedModel {
pub id: String,
pub object: String,
pub model_details: ModelDetails,
}
#[derive(Debug, Deserialize)]
pub struct ModelDetails {
pub architecture: String,
pub created_at: i64,
pub id: String,
pub object: String,
pub prompt: String,
pub samples_seen: i64,
}
#[derive(Debug, Deserialize)]
pub struct HyperParams {
pub batch_size: i32,
pub learning_rate_multiplier: f32,
pub n_epochs: i32,
pub prompt_loss_weight: f32,
}
#[derive(Debug, Deserialize)]
pub struct ResultFile {
pub id: String,
pub object: String,
pub bytes: i64,
pub created_at: i64,
pub filename: String,
pub purpose: String,
}
#[derive(Debug, Deserialize)]
pub struct ValidationFile {
pub id: String,
pub object: String,
pub bytes: i64,
pub created_at: i64,
pub filename: String,
pub purpose: String,
}
#[derive(Debug, Deserialize)]
pub struct TrainingFile {
pub id: String,
pub object: String,
pub bytes: i64,
pub created_at: i64,
pub filename: String,
pub purpose: String,
}
#[derive(Debug, Deserialize)]
pub struct ListFineTuneResponse {
pub object: String,
pub data: Vec<FineTuneData>,
}
#[derive(Debug, Deserialize)]
pub struct FineTuneData {
pub id: String,
pub object: String,
pub model: String,
pub created_at: u64,
pub fine_tuned_model: Option<String>,
pub hyperparams: HyperParams,
pub organization_id: String,
pub result_files: Vec<ResultFile>,
pub status: String,
pub validation_files: Vec<ValidationFile>,
pub training_files: Vec<TrainingFile>,
pub updated_at: u64,
}
#[derive(Debug, Deserialize)]
pub struct RetrieveFineTuneRequest {
pub fine_tune_id: String,
}
#[derive(Debug, Deserialize)]
pub struct RetrieveFineTuneResponse {
pub id: String,
pub object: String,
pub model: String,
pub created_at: i64,
pub events: Vec<FineTuneEvent>,
pub fine_tuned_model: Option<FineTunedModel>,
pub hyperparams: HyperParams,
pub organization_id: String,
pub result_files: Vec<ResultFile>,
pub status: String,
pub validation_files: Vec<ValidationFile>,
pub training_files: Vec<TrainingFile>,
pub updated_at: i64,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct CancelFineTuneRequest {
pub fine_tune_id: String,
}
#[derive(Debug, Deserialize)]
pub struct CancelFineTuneResponse {
pub id: String,
pub object: String,
pub model: String,
pub created_at: i64,
pub events: Vec<FineTuneEvent>,
pub fine_tuned_model: Option<String>,
pub hyperparams: HyperParams,
pub organization_id: String,
pub result_files: Vec<ResultFile>,
pub status: String,
pub validation_files: Vec<ValidationFile>,
pub training_files: Vec<TrainingFile>,
pub updated_at: i64,
}
#[derive(Debug, Deserialize)]
pub struct ListFineTuneEventsRequest {
pub fine_tune_id: String,
}
#[derive(Debug, Deserialize)]
pub struct ListFineTuneEventsResponse {
pub object: String,
pub data: Vec<FineTuneEvent>,
}
#[derive(Debug, Deserialize)]
pub struct DeleteFineTuneModelRequest {
pub model_id: String,
}
#[derive(Debug, Deserialize)]
pub struct DeleteFineTuneModelResponse {
pub id: String,
pub object: String,
pub deleted: bool,
}

View File

@ -1,11 +1,14 @@
pub mod common;
pub mod error;
pub mod audio;
pub mod chat_completion;
pub mod completion;
pub mod edit;
pub mod embedding;
pub mod file;
pub mod fine_tune;
pub mod image;
pub mod moderation;
pub mod api;

52
src/v1/moderation.rs Normal file
View File

@ -0,0 +1,52 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize)]
pub struct CreateModerationRequest {
pub input: String,
}
#[derive(Debug, Deserialize)]
pub struct CreateModerationResponse {
pub id: String,
pub model: String,
pub results: Vec<ModerationResult>,
}
#[derive(Debug, Deserialize)]
pub struct ModerationResult {
pub categories: ModerationCategories,
pub category_scores: ModerationCategoryScores,
pub flagged: bool,
}
#[derive(Debug, Deserialize)]
pub struct ModerationCategories {
#[serde(rename = "hate")]
pub is_hate: bool,
#[serde(rename = "hate/threatening")]
pub is_hate_threatening: bool,
#[serde(rename = "self-harm")]
pub is_self_harm: bool,
pub sexual: bool,
#[serde(rename = "sexual/minors")]
pub is_sexual_minors: bool,
pub violence: bool,
#[serde(rename = "violence/graphic")]
pub is_violence_graphic: bool,
}
#[derive(Debug, Deserialize)]
pub struct ModerationCategoryScores {
#[serde(rename = "hate")]
pub hate_score: f64,
#[serde(rename = "hate/threatening")]
pub hate_threatening_score: f64,
#[serde(rename = "self-harm")]
pub self_harm_score: f64,
pub sexual: f64,
#[serde(rename = "sexual/minors")]
pub sexual_minors_score: f64,
pub violence: f64,
#[serde(rename = "violence/graphic")]
pub violence_graphic_score: f64,
}