update fine tuning api endpoint

This commit is contained in:
d-roak
2024-01-15 18:23:14 +00:00
parent 399c3cdf5c
commit d104c47c05
4 changed files with 174 additions and 298 deletions

View File

@ -16,11 +16,9 @@ use crate::v1::file::{
FileRetrieveContentResponse, FileRetrieveRequest, FileRetrieveResponse, FileUploadRequest,
FileUploadResponse,
};
use crate::v1::fine_tune::{
CancelFineTuneRequest, CancelFineTuneResponse, CreateFineTuneRequest, CreateFineTuneResponse,
DeleteFineTuneModelRequest, DeleteFineTuneModelResponse, ListFineTuneEventsRequest,
ListFineTuneEventsResponse, ListFineTuneResponse, RetrieveFineTuneRequest,
RetrieveFineTuneResponse,
use crate::v1::fine_tuning::{
CreateFineTuningJobRequest, ListFineTuningJobEventsRequest, RetrieveFineTuningJobRequest, CancelFineTuningJobRequest,
FineTuningPagination, FineTuningJobObject, FineTuningJobEvent,
};
use crate::v1::image::{
ImageEditRequest, ImageEditResponse, ImageGenerationRequest, ImageGenerationResponse,
@ -336,69 +334,57 @@ impl Client {
Ok(AudioSpeechResponse { result: true })
}
pub fn create_fine_tune(
pub fn create_fine_tuning_job(
&self,
req: CreateFineTuneRequest,
) -> Result<CreateFineTuneResponse, APIError> {
let res = self.post("/fine-tunes", &req)?;
let r = res.json::<CreateFineTuneResponse>();
req: CreateFineTuningJobRequest,
) -> Result<FineTuningJobObject, APIError> {
let res = self.post("/fine_tuning/jobs", &req)?;
let r = res.json::<FineTuningJobObject>();
match r {
Ok(r) => Ok(r),
Err(e) => Err(self.new_error(e)),
}
}
pub fn list_fine_tune(&self) -> Result<ListFineTuneResponse, APIError> {
let res = self.get("/fine-tunes")?;
let r = res.json::<ListFineTuneResponse>();
pub fn list_fine_tuning_jobs(&self) -> Result<FineTuningPagination<FineTuningJobObject>, APIError> {
let res = self.get("/fine_tuning/jobs")?;
let r = res.json::<FineTuningPagination<FineTuningJobObject>>();
match r {
Ok(r) => Ok(r),
Err(e) => Err(self.new_error(e)),
}
}
pub fn retrieve_fine_tune(
pub fn list_fine_tuning_job_events(
&self,
req: RetrieveFineTuneRequest,
) -> Result<RetrieveFineTuneResponse, APIError> {
let res = self.get(&format!("/fine_tunes/{}", req.fine_tune_id))?;
let r = res.json::<RetrieveFineTuneResponse>();
req: ListFineTuningJobEventsRequest,
) -> Result<FineTuningPagination<FineTuningJobEvent>, APIError> {
let res = self.get(&format!("/fine_tuning/jobs/{}/events", req.fine_tuning_job_id))?;
let r = res.json::<FineTuningPagination<FineTuningJobEvent>>();
match r {
Ok(r) => Ok(r),
Err(e) => Err(self.new_error(e)),
}
}
pub fn cancel_fine_tune(
pub fn retrieve_fine_tuning_job(
&self,
req: CancelFineTuneRequest,
) -> Result<CancelFineTuneResponse, APIError> {
let res = self.post(&format!("/fine_tunes/{}/cancel", req.fine_tune_id), &req)?;
let r = res.json::<CancelFineTuneResponse>();
req: RetrieveFineTuningJobRequest,
) -> Result<FineTuningJobObject, APIError> {
let res = self.get(&format!("/fine_tuning/jobs/{}", req.fine_tuning_job_id))?;
let r = res.json::<FineTuningJobObject>();
match r {
Ok(r) => Ok(r),
Err(e) => Err(self.new_error(e)),
}
}
pub fn list_fine_tune_events(
pub fn cancel_fine_tuning_job(
&self,
req: ListFineTuneEventsRequest,
) -> Result<ListFineTuneEventsResponse, APIError> {
let res = self.get(&format!("/fine-tunes/{}/events", req.fine_tune_id))?;
let r = res.json::<ListFineTuneEventsResponse>();
match r {
Ok(r) => Ok(r),
Err(e) => Err(self.new_error(e)),
}
}
pub fn delete_fine_tune(
&self,
req: DeleteFineTuneModelRequest,
) -> Result<DeleteFineTuneModelResponse, APIError> {
let res = self.delete(&format!("/models/{}", req.model_id))?;
let r = res.json::<DeleteFineTuneModelResponse>();
req: CancelFineTuningJobRequest,
) -> Result<FineTuningJobObject, APIError> {
let res = self.post(&format!("/fine_tuning/jobs/{}/cancel", req.fine_tuning_job_id), &req)?;
let r = res.json::<FineTuningJobObject>();
match r {
Ok(r) => Ok(r),
Err(e) => Err(self.new_error(e)),

View File

@ -1,257 +0,0 @@
use serde::{Deserialize, Serialize};
use crate::impl_builder_methods;
#[derive(Debug, Serialize, Clone)]
pub struct CreateFineTuneRequest {
pub training_file: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub validation_file: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n_epochs: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub batch_size: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub learning_rate_multiplier: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_loss_weight: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub compute_classification_metrics: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub classification_n_classes: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub classification_positive_class: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub classification_betas: Option<Vec<f32>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub suffix: Option<String>,
}
impl CreateFineTuneRequest {
pub fn new(training_file: String) -> Self {
Self {
training_file,
validation_file: None,
model: None,
n_epochs: None,
batch_size: None,
learning_rate_multiplier: None,
prompt_loss_weight: None,
compute_classification_metrics: None,
classification_n_classes: None,
classification_positive_class: None,
classification_betas: None,
suffix: None,
}
}
}
impl_builder_methods!(
CreateFineTuneRequest,
validation_file: String,
model: String,
n_epochs: i32,
batch_size: i32,
learning_rate_multiplier: f32,
prompt_loss_weight: f32,
compute_classification_metrics: bool,
classification_n_classes: i32,
classification_positive_class: String,
classification_betas: Vec<f32>,
suffix: String
);
#[derive(Debug, Deserialize)]
pub struct CreateFineTuneResponse {
pub id: String,
pub object: String,
pub model: String,
pub created_at: i64,
pub events: Vec<FineTuneEvent>,
pub fine_tuned_model: Option<FineTunedModel>,
pub hyperparams: HyperParams,
pub organization_id: String,
pub result_files: Vec<ResultFile>,
pub status: String,
pub validation_files: Vec<ValidationFile>,
pub training_files: Vec<TrainingFile>,
pub updated_at: i64,
}
#[derive(Debug, Deserialize)]
pub struct FineTuneEvent {
pub object: String,
pub created_at: i64,
pub level: String,
pub message: String,
}
#[derive(Debug, Deserialize)]
pub struct FineTunedModel {
pub id: String,
pub object: String,
pub model_details: ModelDetails,
}
#[derive(Debug, Deserialize)]
pub struct ModelDetails {
pub architecture: String,
pub created_at: i64,
pub id: String,
pub object: String,
pub prompt: String,
pub samples_seen: i64,
}
#[derive(Debug, Deserialize)]
pub struct HyperParams {
pub batch_size: i32,
pub learning_rate_multiplier: f32,
pub n_epochs: i32,
pub prompt_loss_weight: f32,
}
#[derive(Debug, Deserialize)]
pub struct ResultFile {
pub id: String,
pub object: String,
pub bytes: i64,
pub created_at: i64,
pub filename: String,
pub purpose: String,
}
#[derive(Debug, Deserialize)]
pub struct ValidationFile {
pub id: String,
pub object: String,
pub bytes: i64,
pub created_at: i64,
pub filename: String,
pub purpose: String,
}
#[derive(Debug, Deserialize)]
pub struct TrainingFile {
pub id: String,
pub object: String,
pub bytes: i64,
pub created_at: i64,
pub filename: String,
pub purpose: String,
}
#[derive(Debug, Deserialize)]
pub struct ListFineTuneResponse {
pub object: String,
pub data: Vec<FineTuneData>,
}
#[derive(Debug, Deserialize)]
pub struct FineTuneData {
pub id: String,
pub object: String,
pub model: String,
pub created_at: u64,
pub fine_tuned_model: Option<String>,
pub hyperparams: HyperParams,
pub organization_id: String,
pub result_files: Vec<ResultFile>,
pub status: String,
pub validation_files: Vec<ValidationFile>,
pub training_files: Vec<TrainingFile>,
pub updated_at: u64,
}
#[derive(Debug, Deserialize)]
pub struct RetrieveFineTuneRequest {
pub fine_tune_id: String,
}
impl RetrieveFineTuneRequest {
pub fn new(fine_tune_id: String) -> Self {
Self { fine_tune_id }
}
}
#[derive(Debug, Deserialize)]
pub struct RetrieveFineTuneResponse {
pub id: String,
pub object: String,
pub model: String,
pub created_at: i64,
pub events: Vec<FineTuneEvent>,
pub fine_tuned_model: Option<FineTunedModel>,
pub hyperparams: HyperParams,
pub organization_id: String,
pub result_files: Vec<ResultFile>,
pub status: String,
pub validation_files: Vec<ValidationFile>,
pub training_files: Vec<TrainingFile>,
pub updated_at: i64,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct CancelFineTuneRequest {
pub fine_tune_id: String,
}
impl CancelFineTuneRequest {
pub fn new(fine_tune_id: String) -> Self {
Self { fine_tune_id }
}
}
#[derive(Debug, Deserialize)]
pub struct CancelFineTuneResponse {
pub id: String,
pub object: String,
pub model: String,
pub created_at: i64,
pub events: Vec<FineTuneEvent>,
pub fine_tuned_model: Option<String>,
pub hyperparams: HyperParams,
pub organization_id: String,
pub result_files: Vec<ResultFile>,
pub status: String,
pub validation_files: Vec<ValidationFile>,
pub training_files: Vec<TrainingFile>,
pub updated_at: i64,
}
#[derive(Debug, Deserialize)]
pub struct ListFineTuneEventsRequest {
pub fine_tune_id: String,
}
impl ListFineTuneEventsRequest {
pub fn new(fine_tune_id: String) -> Self {
Self { fine_tune_id }
}
}
#[derive(Debug, Deserialize)]
pub struct ListFineTuneEventsResponse {
pub object: String,
pub data: Vec<FineTuneEvent>,
}
#[derive(Debug, Deserialize)]
pub struct DeleteFineTuneModelRequest {
pub model_id: String,
}
impl DeleteFineTuneModelRequest {
pub fn new(model_id: String) -> Self {
Self { model_id }
}
}
#[derive(Debug, Deserialize)]
pub struct DeleteFineTuneModelResponse {
pub id: String,
pub object: String,
pub deleted: bool,
}

147
src/v1/fine_tuning.rs Normal file
View File

@ -0,0 +1,147 @@
use serde::{Deserialize, Serialize};
use crate::impl_builder_methods;
#[derive(Debug, Serialize, Clone)]
pub struct CreateFineTuningJobRequest {
pub model: String,
pub training_file: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub hyperparameters: Option<HyperParameters>,
#[serde(skip_serializing_if = "Option::is_none")]
pub suffix: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub validation_file: Option<String>,
}
impl CreateFineTuningJobRequest {
pub fn new(model: String, training_file: String) -> Self {
Self {
model,
training_file,
hyperparameters: None,
suffix: None,
validation_file: None,
}
}
}
impl_builder_methods!(
CreateFineTuningJobRequest,
hyperparameters: HyperParameters,
suffix: String,
validation_file: String
);
#[derive(Debug, Serialize)]
pub struct ListFineTuningJobsRequest {
// TODO pass as query params
#[serde(skip_serializing_if = "Option::is_none")]
pub after: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<i64>,
}
impl ListFineTuningJobsRequest {
pub fn new(fine_tune_id: String) -> Self {
Self {
after: None,
limit: None,
}
}
}
#[derive(Debug, Serialize)]
pub struct ListFineTuningJobEventsRequest {
pub fine_tuning_job_id: String,
// TODO pass as query params
#[serde(skip_serializing_if = "Option::is_none")]
pub after: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<i64>,
}
impl ListFineTuningJobEventsRequest {
pub fn new(fine_tuning_job_id: String) -> Self {
Self {
fine_tuning_job_id,
after: None,
limit: None,
}
}
}
#[derive(Debug, Serialize)]
pub struct RetrieveFineTuningJobRequest {
pub fine_tuning_job_id: String,
}
impl RetrieveFineTuningJobRequest {
pub fn new(fine_tuning_job_id: String) -> Self {
Self { fine_tuning_job_id }
}
}
#[derive(Debug, Serialize)]
pub struct CancelFineTuningJobRequest {
pub fine_tuning_job_id: String,
}
impl CancelFineTuningJobRequest {
pub fn new(fine_tuning_job_id: String) -> Self {
Self { fine_tuning_job_id }
}
}
#[derive(Debug, Deserialize)]
pub struct FineTuningPagination<T> {
pub object: String,
pub data: Vec<T>,
pub has_more: bool,
}
#[derive(Debug, Deserialize)]
pub struct FineTuningJobObject {
pub id: String,
pub created_at: i64,
pub error: Option<FineTuningJobError>,
pub fine_tuned_model: Option<String>,
pub finished_at: Option<String>,
pub hyperparameters: HyperParameters,
pub model: String,
pub object: String,
pub organization_id: String,
pub result_files: Vec<String>,
pub status: String,
pub trained_tokens: Option<i64>,
pub training_file: String,
pub validation_file: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct FineTuningJobError {
pub code: String,
pub message: String,
pub param: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct FineTuningJobEvent {
pub id: String,
pub created_at: i64,
pub level: String,
pub message: String,
pub object: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HyperParameters {
#[serde(skip_serializing_if = "Option::is_none")]
pub batch_size: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub learning_rate_multiplier: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n_epochs: Option<String>,
}

View File

@ -7,7 +7,7 @@ pub mod completion;
pub mod edit;
pub mod embedding;
pub mod file;
pub mod fine_tune;
pub mod fine_tuning;
pub mod image;
pub mod moderation;