//! Options for everything.

use aksr::Builder;
use anyhow::Result;
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};

use crate::{
    models::{SamKind, YOLOPredsFormat},
    try_fetch_file_stem, DType, Device, Engine, Hub, Iiix, Kind, LogitsSampler, MinOptMax,
    Processor, ResizeMode, Scale, Task, Version,
};

/// Options for building models and inference.
#[derive(Builder, Debug, Clone)]
pub struct Options {
    // Model configs
    pub model_file: String,
    pub model_name: &'static str,
    pub model_device: Device,
    pub model_dtype: DType,
    pub model_version: Option<Version>,
    pub model_task: Option<Task>,
    pub model_scale: Option<Scale>,
    pub model_kind: Option<Kind>,
    pub model_iiixs: Vec<Iiix>,
    pub model_spec: String,
    pub model_num_dry_run: usize,
    pub trt_fp16: bool,
    pub profile: bool,

    // Model component files
    pub model_encoder_file: Option<String>,
    pub model_decoder_file: Option<String>,
    pub visual_encoder_file: Option<String>,
    pub visual_decoder_file: Option<String>,
    pub textual_encoder_file: Option<String>,
    pub textual_decoder_file: Option<String>,

    // Processor configs
    #[args(except(setter))]
    pub image_width: u32,
    #[args(except(setter))]
    pub image_height: u32,
    pub resize_mode: ResizeMode,
    pub resize_filter: &'static str,
    pub padding_value: u8,
    pub letterbox_center: bool,
    pub normalize: bool,
    pub image_std: Vec<f32>,
    pub image_mean: Vec<f32>,
    pub nchw: bool,
    pub unsigned: bool,

    // Names
    pub class_names: Option<Vec<String>>,
    pub class_names_2: Option<Vec<String>>,
    pub class_names_3: Option<Vec<String>>,
    pub keypoint_names: Option<Vec<String>>,
    pub keypoint_names_2: Option<Vec<String>>,
    pub keypoint_names_3: Option<Vec<String>>,
    pub text_names: Option<Vec<String>>,
    pub text_names_2: Option<Vec<String>>,
    pub text_names_3: Option<Vec<String>>,
    pub category_names: Option<Vec<String>>,
    pub category_names_2: Option<Vec<String>>,
    pub category_names_3: Option<Vec<String>>,

    // Confidence thresholds
    pub class_confs: Vec<f32>,
    pub class_confs_2: Vec<f32>,
    pub class_confs_3: Vec<f32>,
    pub keypoint_confs: Vec<f32>,
    pub keypoint_confs_2: Vec<f32>,
    pub keypoint_confs_3: Vec<f32>,
    pub text_confs: Vec<f32>,
    pub text_confs_2: Vec<f32>,
    pub text_confs_3: Vec<f32>,

    // Files
    pub file: Option<String>,
    pub file_2: Option<String>,
    pub file_3: Option<String>,

    // For classification
    pub apply_softmax: Option<bool>,
    pub topk: Option<usize>,
    pub topk_2: Option<usize>,
    pub topk_3: Option<usize>,

    // For detection
    #[args(aka = "nc")]
    pub num_classes: Option<usize>,
    #[args(aka = "nk")]
    pub num_keypoints: Option<usize>,
    #[args(aka = "nm")]
    pub num_masks: Option<usize>,
    pub iou: Option<f32>,
    pub iou_2: Option<f32>,
    pub iou_3: Option<f32>,
    pub apply_nms: Option<bool>,
    pub find_contours: bool,
    pub yolo_preds_format: Option<YOLOPredsFormat>,
    pub classes_excluded: Vec<usize>,
    pub classes_retained: Vec<usize>,
    pub min_width: Option<f32>,
    pub min_height: Option<f32>,

    // Language models related
    pub model_max_length: Option<u64>,
    pub tokenizer_file: Option<String>,
    pub config_file: Option<String>,
    pub special_tokens_map_file: Option<String>,
    pub tokenizer_config_file: Option<String>,
    pub generation_config_file: Option<String>,
    pub vocab_file: Option<String>, // vocab.json file
    pub vocab_txt: Option<String>,  // vocab.txt file, not key-value pairs
    pub temperature: f32,
    pub topp: f32,

    // For DB (text detection)
    pub unclip_ratio: Option<f32>,
    pub binary_thresh: Option<f32>,

    // For SAM
    pub sam_kind: Option<SamKind>,  // TODO: remove
    pub low_res_mask: Option<bool>, // TODO: remove

    // Others
    pub ort_graph_opt_level: Option<u8>,
}

impl Default for Options {
    fn default() -> Self {
        Self {
            model_file: Default::default(),
            model_name: Default::default(),
            model_version: Default::default(),
            model_task: Default::default(),
            model_scale: Default::default(),
            model_kind: Default::default(),
            model_device: Device::Cpu(0),
            model_dtype: DType::Auto,
            model_spec: Default::default(),
            model_iiixs: Default::default(),
            model_num_dry_run: 3,
            trt_fp16: true,
            profile: false,
            normalize: true,
            image_mean: vec![],
            image_std: vec![],
            image_height: 640,
            image_width: 640,
            padding_value: 114,
            resize_mode: ResizeMode::FitExact,
            resize_filter: "Bilinear",
            letterbox_center: false,
            nchw: true,
            unsigned: false,
            class_names: None,
            class_names_2: None,
            class_names_3: None,
            category_names: None,
            category_names_2: None,
            category_names_3: None,
            keypoint_names: None,
            keypoint_names_2: None,
            keypoint_names_3: None,
            text_names: None,
            text_names_2: None,
            text_names_3: None,
            file: None,
            file_2: None,
            file_3: None,
            class_confs: vec![0.3f32],
            class_confs_2: vec![0.3f32],
            class_confs_3: vec![0.3f32],
            keypoint_confs: vec![0.3f32],
            keypoint_confs_2: vec![0.5f32],
            keypoint_confs_3: vec![0.5f32],
            text_confs: vec![0.4f32],
            text_confs_2: vec![0.4f32],
            text_confs_3: vec![0.4f32],
            apply_softmax: Some(false),
            num_classes: None,
            num_keypoints: None,
            num_masks: None,
            iou: None,
            iou_2: None,
            iou_3: None,
            find_contours: false,
            yolo_preds_format: None,
            classes_excluded: vec![],
            classes_retained: vec![],
            apply_nms: None,
            model_max_length: None,
            tokenizer_file: None,
            config_file: None,
            special_tokens_map_file: None,
            tokenizer_config_file: None,
            generation_config_file: None,
            vocab_file: None,
            vocab_txt: None,
            min_width: None,
            min_height: None,
            unclip_ratio: Some(1.5),
            binary_thresh: Some(0.2),
            sam_kind: None,
            low_res_mask: None,
            temperature: 1.,
            topp: 0.,
            topk: None,
            topk_2: None,
            topk_3: None,
            ort_graph_opt_level: None,
            model_encoder_file: None,
            model_decoder_file: None,
            visual_encoder_file: None,
            visual_decoder_file: None,
            textual_encoder_file: None,
            textual_decoder_file: None,
        }
    }
}

impl Options {
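    /// Creates an `Options` with all default values.
    ///
    /// A minimal usage sketch; the `with_model_name` and `with_model_dtype`
    /// setters are assumed to be generated by the `aksr::Builder` derive for
    /// the fields of the same name, and the model name is hypothetical:
    ///
    /// ```ignore
    /// let opts = Options::new()
    ///     .with_model_name("yolo")
    ///     .with_model_dtype(DType::Fp16)
    ///     .with_batch_size(1)
    ///     .commit()?;
    /// let engine = opts.to_engine()?;
    /// let processor = opts.to_processor()?;
    /// ```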
    pub fn new() -> Self {
        Default::default()
    }

    /// Builds an `Engine` from the model-related options.
    pub fn to_engine(&self) -> Result<Engine> {
        Engine {
            file: self.model_file.clone(),
            spec: self.model_spec.clone(),
            device: self.model_device,
            trt_fp16: self.trt_fp16,
            iiixs: self.model_iiixs.clone(),
            num_dry_run: self.model_num_dry_run,
            graph_opt_level: self.ort_graph_opt_level,
            ..Default::default()
        }
        .build()
    }

    /// Builds a `Processor` from the pre/post-processing options, including
    /// an optional tokenizer (for language and vision-language models) and an
    /// optional vocabulary loaded from `vocab.txt`.
    pub fn to_processor(&self) -> Result<Processor> {
        let logits_sampler = LogitsSampler::new()
            .with_temperature(self.temperature)
            .with_topp(self.topp);

        // Try to build the tokenizer
        let tokenizer = match self.model_kind {
            Some(Kind::Language) | Some(Kind::VisionLanguage) => Some(self.try_build_tokenizer()?),
            _ => None,
        };

        // Try to build the vocabulary from `vocab.txt`, fetching it from the
        // hub if it does not exist locally
        let vocab: Vec<String> = match &self.vocab_txt {
            Some(x) => {
                let file = if !std::path::PathBuf::from(&x).exists() {
                    Hub::default().try_fetch(&format!("{}/{}", self.model_name, x))?
                } else {
                    x.to_string()
                };
                std::fs::read_to_string(file)?
                    .lines()
                    .map(|line| line.to_string())
                    .collect()
            }
            None => vec![],
        };

        Ok(Processor {
            image_width: self.image_width,
            image_height: self.image_height,
            resize_mode: self.resize_mode.clone(),
            resize_filter: self.resize_filter,
            padding_value: self.padding_value,
            do_normalize: self.normalize,
            image_mean: self.image_mean.clone(),
            image_std: self.image_std.clone(),
            nchw: self.nchw,
            unsigned: self.unsigned,
            tokenizer,
            vocab,
            logits_sampler: Some(logits_sampler),
            ..Default::default()
        })
    }

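    /// Resolves `model_file` before the model is built: a path that exists
    /// locally is used as-is; otherwise a remote file name is derived (GitHub
    /// release URL, YOLO `[version]-[scale]-[task]` composition, or dtype
    /// suffix) and fetched via `Hub`, and `model_spec` is filled in.
    ///
    /// A sketch of the two paths, assuming the `with_*` setters generated by
    /// the `aksr::Builder` derive and a hypothetical hub entry:
    ///
    /// ```ignore
    /// // Local: the file exists on disk, so it is used directly.
    /// let local = Options::new()
    ///     .with_model_file("assets/model.onnx")
    ///     .commit()?;
    ///
    /// // Remote: no local file, so a name such as "fp16.onnx" is derived
    /// // from the dtype and fetched from the "some-model" hub entry.
    /// let remote = Options::new()
    ///     .with_model_name("some-model")
    ///     .with_model_dtype(DType::Fp16)
    ///     .commit()?;
    /// ```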
    pub fn commit(mut self) -> Result<Self> {
        // Identify the local model or fetch the remote model
        if std::path::PathBuf::from(&self.model_file).exists() {
            // Local
            self.model_spec = format!(
                "{}/{}",
                self.model_name,
                try_fetch_file_stem(&self.model_file)?
            );
        } else {
            // Remote
            if self.model_file.is_empty() && self.model_name.is_empty() {
                anyhow::bail!("Neither `model_name` nor `model_file` was specified. Failed to fetch model from remote.")
            }

            // Load
            match Hub::is_valid_github_release_url(&self.model_file) {
                Some((owner, repo, tag, _file_name)) => {
                    let stem = try_fetch_file_stem(&self.model_file)?;
                    self.model_spec =
                        format!("{}/{}-{}-{}-{}", self.model_name, owner, repo, tag, stem);
                    self.model_file = Hub::default().try_fetch(&self.model_file)?;
                }
                None => {
                    // Special YOLO case: compose the file name as
                    // [version]-[scale]-[task].onnx
                    if self.model_file.is_empty() && self.model_name == "yolo" {
                        let mut y = String::new();
                        if let Some(x) = self.model_version() {
                            y.push_str(&x.to_string());
                        }
                        if let Some(x) = self.model_scale() {
                            y.push_str(&format!("-{}", x));
                        }
                        if let Some(x) = self.model_task() {
                            y.push_str(&format!("-{}", x.yolo_str()));
                        }
                        y.push_str(".onnx");
                        self.model_file = y;
                    }

                    // Append the dtype to the model file name
                    match self.model_dtype {
                        d @ (DType::Auto | DType::Fp32) => {
                            if self.model_file.is_empty() {
                                self.model_file = format!("{}.onnx", d);
                            }
                        }
                        dtype => {
                            if self.model_file.is_empty() {
                                self.model_file = format!("{}.onnx", dtype);
                            } else {
                                // Insert the dtype before the ".onnx" suffix
                                let pos = self.model_file.len() - 5; // ".onnx"
                                let suffix = self.model_file.split_off(pos);
                                self.model_file =
                                    format!("{}-{}{}", self.model_file, dtype, suffix);
                            }
                        }
                    }

                    let stem = try_fetch_file_stem(&self.model_file)?;
                    self.model_spec = format!("{}/{}", self.model_name, stem);
                    self.model_file = Hub::default()
                        .try_fetch(&format!("{}/{}", self.model_name, self.model_file))?;
                }
            }
        }

        Ok(self)
    }

    /// Constrains input 0, axis 0 (the batch axis, in NCHW order).
    pub fn with_batch_size(mut self, x: usize) -> Self {
        self.model_iiixs.push(Iiix::from((0, 0, x.into())));
        self
    }

    /// Sets the image height and constrains input 0, axis 2.
    pub fn with_image_height(mut self, x: u32) -> Self {
        self.image_height = x;
        self.model_iiixs.push(Iiix::from((0, 2, x.into())));
        self
    }

    /// Sets the image width and constrains input 0, axis 3.
    pub fn with_image_width(mut self, x: u32) -> Self {
        self.image_width = x;
        self.model_iiixs.push(Iiix::from((0, 3, x.into())));
        self
    }

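    /// Constrains dimension `ii` of model input `i` with a `MinOptMax`
    /// range. Axis indices follow the input layout (NCHW by default, as in
    /// the `with_batch_size`/`with_image_height`/`with_image_width` helpers
    /// above).
    ///
    /// A sketch of a dynamic batch axis; the `(min, opt, max)` tuple
    /// conversion into `MinOptMax` is an assumption:
    ///
    /// ```ignore
    /// // Input 0, axis 0 (batch): min 1, opt 4, max 8.
    /// let opts = Options::new().with_model_ixx(0, 0, (1, 4, 8).into());
    /// ```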
    pub fn with_model_ixx(mut self, i: usize, ii: usize, x: MinOptMax) -> Self {
        self.model_iiixs.push(Iiix::from((i, ii, x)));
        self
    }

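    /// Excludes the given class ids from the results. Mutually exclusive
    /// with `retain_classes`: calling either clears the other list.
    ///
    /// For example, dropping class id 0 while keeping all other classes:
    ///
    /// ```ignore
    /// let opts = Options::new().exclude_classes(&[0]);
    /// ```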
    pub fn exclude_classes(mut self, xs: &[usize]) -> Self {
        self.classes_retained.clear();
        self.classes_excluded.extend_from_slice(xs);
        self
    }

    /// Retains only the given class ids in the results. Mutually exclusive
    /// with `exclude_classes`: calling either clears the other list.
    pub fn retain_classes(mut self, xs: &[usize]) -> Self {
        self.classes_excluded.clear();
        self.classes_retained.extend_from_slice(xs);
        self
    }

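    /// Builds a `tokenizers::Tokenizer`, fetching `config.json` (for
    /// `pad_token_id`), `tokenizer_config.json` (for `model_max_length` and
    /// `pad_token`), and `tokenizer.json` from the hub when no explicit file
    /// paths are set. Padding uses the `Fixed` strategy when
    /// `model_max_length` is set on these options (clamped by the config
    /// value), and `BatchLongest` otherwise.
    ///
    /// A sketch with a hypothetical model name whose hub entry provides a
    /// `tokenizer.json`:
    ///
    /// ```ignore
    /// let tokenizer = Options::new()
    ///     .with_model_name("some-language-model")
    ///     .try_build_tokenizer()?;
    /// let ids = tokenizer
    ///     .encode("a photo of a cat", true)
    ///     .map_err(|e| anyhow::anyhow!(e))?
    ///     .get_ids()
    ///     .to_vec();
    /// ```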
    pub fn try_build_tokenizer(&self) -> Result<Tokenizer> {
        let mut hub = Hub::default();

        // config file
        // TODO: save configs?
        let pad_id = match hub.try_fetch(
            self.config_file
                .as_ref()
                .unwrap_or(&format!("{}/config.json", self.model_name)),
        ) {
            Ok(x) => {
                let config: serde_json::Value = serde_json::from_str(&std::fs::read_to_string(x)?)?;
                config["pad_token_id"].as_u64().unwrap_or(0) as u32
            }
            Err(_err) => 0u32,
        };

        // tokenizer_config file
        let mut max_length = None;
        let mut pad_token = String::from("[PAD]");
        match hub.try_fetch(
            self.tokenizer_config_file
                .as_ref()
                .unwrap_or(&format!("{}/tokenizer_config.json", self.model_name)),
        ) {
            Err(_) => {}
            Ok(x) => {
                let tokenizer_config: serde_json::Value =
                    serde_json::from_str(&std::fs::read_to_string(x)?)?;
                max_length = tokenizer_config["model_max_length"].as_u64();
                pad_token = tokenizer_config["pad_token"]
                    .as_str()
                    .unwrap_or("[PAD]")
                    .to_string();
            }
        }

        // tokenizer file
        let mut tokenizer: tokenizers::Tokenizer = tokenizers::Tokenizer::from_file(
            hub.try_fetch(
                self.tokenizer_file
                    .as_ref()
                    .unwrap_or(&format!("{}/tokenizer.json", self.model_name)),
            )?,
        )
        .map_err(|err| anyhow::anyhow!("Failed to build tokenizer: {err}"))?;

        // Padding:
        // - if `model_max_length` is specified, use the `Fixed` strategy;
        // - otherwise, use the `BatchLongest` strategy.
        // TODO: if sequence_length is dynamic, `BatchLongest` is fine
        let tokenizer = match self.model_max_length {
            Some(n) => {
                let n = match max_length {
                    None => n,
                    Some(x) => x.min(n),
                };
                tokenizer
                    .with_padding(Some(PaddingParams {
                        strategy: PaddingStrategy::Fixed(n as _),
                        pad_token,
                        pad_id,
                        ..Default::default()
                    }))
                    .clone()
            }
            None => match max_length {
                Some(n) => tokenizer
                    .with_padding(Some(PaddingParams {
                        strategy: PaddingStrategy::BatchLongest,
                        pad_token,
                        pad_id,
                        ..Default::default()
                    }))
                    .with_truncation(Some(TruncationParams {
                        max_length: n as _,
                        ..Default::default()
                    }))
                    .map_err(|err| anyhow::anyhow!("Failed to truncate: {}", err))?
                    .clone(),
                None => tokenizer
                    .with_padding(Some(PaddingParams {
                        strategy: PaddingStrategy::BatchLongest,
                        pad_token,
                        pad_id,
                        ..Default::default()
                    }))
                    .clone(),
            },
        };

        // TODO: generation_config.json & special_tokens_map file

        Ok(tokenizer.into())
    }
}