// usls/src/models/trocr/impl.rs
use aksr::Builder;
use anyhow::Result;
use ndarray::{s, Axis};
use rayon::prelude::*;
use std::str::FromStr;

use crate::{elapsed, Config, Engine, Image, LogitsSampler, Processor, Scale, Ts, Xs, X, Y};

#[derive(Debug, Copy, Clone)]
pub enum TrOCRKind {
    Printed,
    HandWritten,
}

impl FromStr for TrOCRKind {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "printed" => Ok(Self::Printed),
            "handwritten" | "hand-written" => Ok(Self::HandWritten),
            x => anyhow::bail!("Unsupported TrOCRKind: {}", x),
        }
    }
}
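
// `TrOCRKind` is parsed from plain strings, so callers can do e.g.
// `"printed".parse::<TrOCRKind>()?` or `TrOCRKind::from_str("hand-written")?`.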

#[derive(Debug, Builder)]
pub struct TrOCR {
    encoder: Engine,
    decoder: Engine,
    decoder_merged: Engine,
    max_length: u32,
    eos_token_id: u32,
    decoder_start_token_id: u32,
    ts: Ts,
    n_kvs: usize,
    processor: Processor,
    batch: usize,
    height: usize,
    width: usize,
}
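
// Pipeline: the vision encoder embeds the image once; the plain decoder runs the
// first generation step and exposes the encoder/decoder KV caches; the merged
// decoder reuses those caches for every following autoregressive step.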
impl TrOCR {
    pub fn new(config: Config) -> Result<Self> {
        let encoder = Engine::try_from_config(&config.visual)?;
        let decoder = Engine::try_from_config(&config.textual_decoder)?;
        let decoder_merged = Engine::try_from_config(&config.textual_decoder_merged)?;
        let (batch, height, width) = (
            encoder.batch().opt(),
            encoder.try_height().unwrap_or(&384.into()).opt(),
            encoder.try_width().unwrap_or(&384.into()).opt(),
        );
        let ts = Ts::merge(&[encoder.ts(), decoder.ts(), decoder_merged.ts()]);

        // "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>",
        // "model_max_length": 1000000000000000019884624838656,
        // let bos_token = "<s>";
        // let eos_token = "</s>";
        // let sep_token = "</s>";
        // let bos_token_id = 0;
        // let pad_token_id = 1;
        let max_length = 1024; // TODO
        let eos_token_id = 2;
        let decoder_start_token_id = 2;
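
        // Number of decoder layers, i.e. how many KV-cache (key, value) pairs each pass produces.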
        let n_kvs = match &config.scale {
            Some(Scale::S) => 6,
            Some(Scale::B) => 12,
            _ => unimplemented!(),
        };
        let processor = Processor::try_from_config(&config.processor)?
            .with_image_width(width as _)
            .with_image_height(height as _);

        Ok(Self {
            encoder,
            decoder,
            decoder_merged,
            max_length,
            ts,
            eos_token_id,
            decoder_start_token_id,
            n_kvs,
            batch,
            width,
            height,
            processor,
        })
    }
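
    // Runs the vision encoder once; the returned hidden states (expected shape
    // `[batch, seq_len, hidden]`) are fed to every decoding step below.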
    pub fn encode(&mut self, xs: &[Image]) -> Result<X> {
        let ys = self.processor.process_images(xs)?;
        self.batch = xs.len(); // update
        let ys = self.encoder.run(ys.into())?;

        Ok(ys[0].to_owned())
    }
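
    // `forward` = encode -> autoregressive generate -> detokenize, with each stage
    // timed through the `elapsed!` macro and accumulated in `self.ts`.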
    pub fn forward(&mut self, xs: &[Image]) -> Result<Vec<Y>> {
        let encoder_hidden_states = elapsed!("encode", self.ts, { self.encode(xs)? });
        let generated = elapsed!("generate", self.ts, {
            self.generate(&encoder_hidden_states)?
        });
        let ys = elapsed!("decode", self.ts, { self.decode(generated)? });

        Ok(ys)
    }

    fn generate(&mut self, encoder_hidden_states: &X) -> Result<Vec<Vec<u32>>> {
        // input_ids: one decoder start token per batch element, shape `[batch, 1]`
        let input_ids = X::from(vec![self.decoder_start_token_id as f32])
            .insert_axis(0)?
            .repeat(0, self.batch)?;

        // first decoding step with the plain (cache-less) decoder
        let mut decoder_outputs = self.decoder.run(Xs::from(vec![
            input_ids.clone(),
            encoder_hidden_states.clone(),
        ]))?;
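
        // The exported decoder is expected to emit `1 + 4 * n_kvs` tensors: logits at
        // index 0, then for each layer its self-attention (key, value) followed by its
        // cross-attention (key, value). The index arithmetic below picks those pairs apart.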
        // encoder (cross-attention) kvs: indices 3, 4, 7, 8, ... — computed once, reused every step
        let encoder_kvs: Vec<_> = (3..4 * self.n_kvs)
            .step_by(4)
            .flat_map(|i| [i, i + 1])
            .map(|i| decoder_outputs[i].clone())
            .collect();

        // per-sample token buffers and termination flags
        let mut token_ids: Vec<Vec<u32>> = vec![vec![]; self.batch];
        let mut finished = vec![false; self.batch];
        let mut last_tokens: Vec<f32> = vec![0.; self.batch];
        let logits_sampler = LogitsSampler::new();

        // autoregressive generation loop
        for _ in 0..self.max_length {
            let logits = &decoder_outputs[0];
            // decoder (self-attention) kvs from the previous step: indices 1, 2, 5, 6, ...
            let decoder_kvs: Vec<_> = (1..(4 * self.n_kvs) - 2)
                .step_by(4)
                .flat_map(|i| [i, i + 1])
                .map(|i| decoder_outputs[i].clone())
                .collect();

            // decode one token per batch element
            for (i, logit) in logits.axis_iter(Axis(0)).enumerate() {
                if !finished[i] {
                    let token_id = logits_sampler.decode(
                        &logit
                            .slice(s![-1, ..])
                            .into_owned()
                            .into_raw_vec_and_offset()
                            .0,
                    )?;
                    if token_id == self.eos_token_id {
                        finished[i] = true;
                    } else {
                        token_ids[i].push(token_id);
                    }

                    // feed this token back in the next step
                    last_tokens[i] = token_id as f32;
                }
            }

            // stop early once every sequence has emitted EOS
            if finished.iter().all(|&x| x) {
                break;
            }

            // build inputs for the merged (KV-cached) decoder
            let input_ids = X::from(last_tokens.clone()).insert_axis(1)?;
            let mut xs = vec![input_ids, encoder_hidden_states.clone()];
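            // Merged-decoder input order: ids, encoder states, then per layer
            // (self K, self V, cross K, cross V), and finally a `use_cache` flag.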
            for i in 0..self.n_kvs {
                xs.push(decoder_kvs[i * 2].clone());
                xs.push(decoder_kvs[i * 2 + 1].clone());
                xs.push(encoder_kvs[i * 2].clone());
                xs.push(encoder_kvs[i * 2 + 1].clone());
            }
            xs.push(X::ones(&[1])); // use_cache

            // next step with the merged decoder, reusing the caches
            decoder_outputs = self.decoder_merged.run(xs.into())?;
        }

        Ok(token_ids)
    }

    pub fn decode(&self, token_ids: Vec<Vec<u32>>) -> Result<Vec<Y>> {
        // decode
        let texts = self.processor.decode_tokens_batch(&token_ids, false)?;

        // to texts
        let texts = texts
            .into_par_iter()
            .map(|x| Y::default().with_texts(&[x.into()]))
            .collect::<Vec<_>>();

        Ok(texts)
    }

    pub fn summary(&self) {
        self.ts.summary();
    }
}
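
// A minimal usage sketch. The exact `Config` preset and how images are loaded are
// assumptions; adjust them to whatever the surrounding crate actually exposes:
//
//     let config = Config::default(); // hypothetical: a TrOCR-specific preset in practice
//     let mut model = TrOCR::new(config)?;
//     let ys = model.forward(&images)?; // `images: Vec<Image>`, loaded elsewhere
//     model.summary();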