mirror of
https://github.com/mii443/usls.git
synced 2025-08-22 15:45:41 +00:00
Options -> ModelConfig
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
## Quick Start
|
||||
|
||||
```shell
|
||||
cargo run -r -F cuda --example florence2 -- --device cuda --scale base --dtype fp16
|
||||
cargo run -r -F cuda --example florence2 -- --device cuda --dtype fp16
|
||||
```
|
||||
|
||||
|
||||
|
@ -1,20 +1,16 @@
|
||||
use anyhow::Result;
|
||||
use usls::{models::Florence2, Annotator, DataLoader, Options, Scale, Style, Task};
|
||||
use usls::{models::Florence2, Annotator, DataLoader, ModelConfig, Style, Task};
|
||||
|
||||
#[derive(argh::FromArgs)]
|
||||
/// Example
|
||||
struct Args {
|
||||
/// dtype
|
||||
#[argh(option, default = "String::from(\"auto\")")]
|
||||
#[argh(option, default = "String::from(\"fp16\")")]
|
||||
dtype: String,
|
||||
|
||||
/// device
|
||||
#[argh(option, default = "String::from(\"cpu:0\")")]
|
||||
device: String,
|
||||
|
||||
/// scale
|
||||
#[argh(option, default = "String::from(\"base\")")]
|
||||
scale: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -29,51 +25,12 @@ fn main() -> Result<()> {
|
||||
let xs = DataLoader::try_read_n(&["images/green-car.jpg", "assets/bus.jpg"])?;
|
||||
|
||||
// build model
|
||||
let (
|
||||
options_vision_encoder,
|
||||
options_text_embed,
|
||||
options_encoder,
|
||||
options_decoder,
|
||||
options_decoder_merged,
|
||||
) = match args.scale.as_str().try_into()? {
|
||||
Scale::B => (
|
||||
Options::florence2_visual_encoder_base(),
|
||||
Options::florence2_textual_embed_base(),
|
||||
Options::florence2_texual_encoder_base(),
|
||||
Options::florence2_texual_decoder_base(),
|
||||
Options::florence2_texual_decoder_merged_base(),
|
||||
),
|
||||
Scale::L => todo!(),
|
||||
_ => anyhow::bail!("Unsupported Florence2 scale."),
|
||||
};
|
||||
|
||||
let mut model = Florence2::new(
|
||||
options_vision_encoder
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_batch_size(xs.len())
|
||||
.commit()?,
|
||||
options_text_embed
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_batch_size(xs.len())
|
||||
.commit()?,
|
||||
options_encoder
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_batch_size(xs.len())
|
||||
.commit()?,
|
||||
options_decoder
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_batch_size(xs.len())
|
||||
.commit()?,
|
||||
options_decoder_merged
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_batch_size(xs.len())
|
||||
.commit()?,
|
||||
)?;
|
||||
let config = ModelConfig::florence2_base()
|
||||
.with_dtype_all(args.dtype.as_str().try_into()?)
|
||||
.with_device_all(args.device.as_str().try_into()?)
|
||||
.with_batch_size_all(xs.len())
|
||||
.commit()?;
|
||||
let mut model = Florence2::new(config)?;
|
||||
|
||||
// tasks
|
||||
let tasks = [
|
||||
|
Reference in New Issue
Block a user