From b3072e8da35038294cfac522a11b631f22356e93 Mon Sep 17 00:00:00 2001
From: Jamjamjon <51357717+jamjamjon@users.noreply.github.com>
Date: Wed, 22 Jan 2025 14:55:12 +0800
Subject: [PATCH] Bug Fixes: Increased the number of available releases and
 fixed the Trocr bug. (#58)

---
 examples/hub/main.rs     |   4 +-
 src/misc/engine.rs       |   4 +-
 src/misc/hub.rs          |   6 +-
 src/models/trocr/impl.rs | 175 ++++++++++++++++++++-------------------
 4 files changed, 96 insertions(+), 93 deletions(-)

diff --git a/examples/hub/main.rs b/examples/hub/main.rs
index 45cc7b2..33f62e3 100644
--- a/examples/hub/main.rs
+++ b/examples/hub/main.rs
@@ -17,9 +17,9 @@ fn main() -> anyhow::Result<()> {
 
     // 3. Fetch tags and files
     let hub = Hub::default().with_owner("jamjamjon").with_repo("usls");
-    for tag in hub.tags().iter() {
+    for (i, tag) in hub.tags().iter().enumerate() {
         let files = hub.files(tag);
-        println!("{} => {:?}", tag, files); // Should be empty
+        println!("{} :: {} => {:?}", i, tag, files); // Should be empty
     }
 
     Ok(())
diff --git a/src/misc/engine.rs b/src/misc/engine.rs
index 04999ea..1c6543b 100644
--- a/src/misc/engine.rs
+++ b/src/misc/engine.rs
@@ -1,7 +1,7 @@
 use aksr::Builder;
 use anyhow::Result;
 use half::{bf16, f16};
-use log::{error, info, warn};
+use log::{debug, info, warn};
 use ndarray::{Array, IxDyn};
 #[allow(unused_imports)]
 use ort::{
@@ -275,7 +275,7 @@ impl Engine {
         {
             match x.try_extract_tensor::<T>() {
                 Err(err) => {
-                    error!("Failed to extract from ort outputs: {:?}", err);
+                    debug!("Failed to extract from ort outputs: {:?}. A default value has been generated.", err);
                     Array::zeros(0).into_dyn()
                 }
                 Ok(x) => x.view().mapv(map_fn).into_owned(),
diff --git a/src/misc/hub.rs b/src/misc/hub.rs
index acb2a43..0adcca1 100644
--- a/src/misc/hub.rs
+++ b/src/misc/hub.rs
@@ -460,8 +460,10 @@ impl Hub {
         let cache = to.path()?.join(Self::cache_file(owner, repo));
         let is_file_expired = Self::is_file_expired(&cache, ttl)?;
         let body = if is_file_expired {
-            let gh_api_release =
-                format!("https://api.github.com/repos/{}/{}/releases", owner, repo);
+            let gh_api_release = format!(
+                "https://api.github.com/repos/{}/{}/releases?per_page=100",
+                owner, repo
+            );
             Self::fetch_and_cache_releases(&gh_api_release, &cache)?
         } else {
             std::fs::read_to_string(&cache)?
diff --git a/src/models/trocr/impl.rs b/src/models/trocr/impl.rs
index c6894fd..07dc2ab 100644
--- a/src/models/trocr/impl.rs
+++ b/src/models/trocr/impl.rs
@@ -1,12 +1,13 @@
 use aksr::Builder;
 use anyhow::Result;
 use image::DynamicImage;
+use ndarray::{s, Axis};
 use rayon::prelude::*;
 
 use crate::{
     elapsed,
     models::{BaseModelTextual, BaseModelVisual},
-    Options, Scale, Ts, Xs, Ys, X, Y,
+    LogitsSampler, Options, Scale, Ts, Xs, Ys, X, Y,
 };
 
 #[derive(Debug, Copy, Clone)]
@@ -96,85 +97,6 @@ impl TrOCR {
         Ok(ys)
     }
 
-    // fn generate(&mut self, encoder_hidden_states: &X) -> Result<Vec<Vec<u32>>> {
-    //     // input_ids
-    //     let input_ids = X::from(vec![self.decoder_start_token_id as f32])
-    //         .insert_axis(0)?
-    //         .repeat(0, self.encoder.batch())?;
-
-    //     // decoder
-    //     let mut decoder_outputs = self.decoder.inference(Xs::from(vec![
-    //         input_ids.clone(),
-    //         encoder_hidden_states.clone(),
-    //     ]))?;
-
-    //     // encoder kvs
-    //     let encoder_kvs: Vec<_> = (3..4 * self.n_kvs)
-    //         .step_by(4)
-    //         .flat_map(|i| [i, i + 1])
-    //         .map(|i| decoder_outputs[i].clone())
-    //         .collect();
-
-    //     // token ids
-    //     let mut token_ids: Vec<Vec<u32>> = vec![vec![]; self.encoder.batch()];
-    //     let mut finished = vec![false; self.encoder.batch()];
-    //     let mut last_tokens: Vec<f32> = vec![0.; self.encoder.batch()];
-    //     let mut logits_sampler = LogitsSampler::new();
-
-    //     // generate
-    //     for _ in 0..self.max_length {
-    //         let logits = &decoder_outputs[0];
-    //         let decoder_kvs: Vec<_> = (1..(4 * self.n_kvs) - 2)
-    //             .step_by(4)
-    //             .flat_map(|i| [i, i + 1])
-    //             .map(|i| decoder_outputs[i].clone())
-    //             .collect();
-
-    //         // decode each token for each batch
-    //         for (i, logit) in logits.axis_iter(Axis(0)).enumerate() {
-    //             if !finished[i] {
-    //                 let token_id = logits_sampler.decode(
-    //                     &logit
-    //                         .slice(s![-1, ..])
-    //                         .into_owned()
-    //                         .into_raw_vec_and_offset()
-    //                         .0,
-    //                 )?;
-
-    //                 if token_id == self.eos_token_id {
-    //                     finished[i] = true;
-    //                 } else {
-    //                     token_ids[i].push(token_id);
-    //                 }
-
-    //                 // update
-    //                 last_tokens[i] = token_id as f32;
-    //             }
-    //         }
-
-    //         // all finished?
-    //         if finished.iter().all(|&x| x) {
-    //             break;
-    //         }
-
-    //         // build inputs
-    //         let input_ids = X::from(last_tokens.clone()).insert_axis(1)?;
-    //         let mut xs = vec![input_ids, encoder_hidden_states.clone()];
-    //         for i in 0..self.n_kvs {
-    //             xs.push(decoder_kvs[i * 2].clone());
-    //             xs.push(decoder_kvs[i * 2 + 1].clone());
-    //             xs.push(encoder_kvs[i * 2].clone());
-    //             xs.push(encoder_kvs[i * 2 + 1].clone());
-    //         }
-    //         xs.push(X::ones(&[1])); // use_cache
-
-    //         // generate
-    //         decoder_outputs = self.decoder_merged.inference(xs.into())?;
-    //     }
-
-    //     Ok(token_ids)
-    // }
-
     fn generate(&mut self, encoder_hidden_states: &X) -> Result<Vec<Vec<u32>>> {
         // input_ids
         let input_ids = X::from(vec![self.decoder_start_token_id as f32])
@@ -196,6 +118,9 @@ impl TrOCR {
 
         // token ids
         let mut token_ids: Vec<Vec<u32>> = vec![vec![]; self.encoder.batch()];
+        let mut finished = vec![false; self.encoder.batch()];
+        let mut last_tokens: Vec<f32> = vec![0.; self.encoder.batch()];
+        let logits_sampler = LogitsSampler::new();
 
         // generate
         for _ in 0..self.max_length {
@@ -207,18 +132,34 @@ impl TrOCR {
                 .collect();
 
             // decode each token for each batch
-            let (finished, last_tokens) = self.decoder_merged.processor().par_generate(
-                logits,
-                &mut token_ids,
-                self.eos_token_id,
-            )?;
+            for (i, logit) in logits.axis_iter(Axis(0)).enumerate() {
+                if !finished[i] {
+                    let token_id = logits_sampler.decode(
+                        &logit
+                            .slice(s![-1, ..])
+                            .into_owned()
+                            .into_raw_vec_and_offset()
+                            .0,
+                    )?;
 
-            if finished {
+                    if token_id == self.eos_token_id {
+                        finished[i] = true;
+                    } else {
+                        token_ids[i].push(token_id);
+                    }
+
+                    // update
+                    last_tokens[i] = token_id as f32;
+                }
+            }
+
+            // all finished?
+            if finished.iter().all(|&x| x) {
                 break;
             }
 
             // build inputs
-            let input_ids = X::from(last_tokens).insert_axis(1)?;
+            let input_ids = X::from(last_tokens.clone()).insert_axis(1)?;
             let mut xs = vec![input_ids, encoder_hidden_states.clone()];
             for i in 0..self.n_kvs {
                 xs.push(decoder_kvs[i * 2].clone());
@@ -235,6 +176,66 @@ impl TrOCR {
         Ok(token_ids)
     }
 
+    // fn generate(&mut self, encoder_hidden_states: &X) -> Result<Vec<Vec<u32>>> {
+    //     // input_ids
+    //     let input_ids = X::from(vec![self.decoder_start_token_id as f32])
+    //         .insert_axis(0)?
+    //         .repeat(0, self.encoder.batch())?;
+
+    //     // decoder
+    //     let mut decoder_outputs = self.decoder.inference(Xs::from(vec![
+    //         input_ids.clone(),
+    //         encoder_hidden_states.clone(),
+    //     ]))?;
+
+    //     // encoder kvs
+    //     let encoder_kvs: Vec<_> = (3..4 * self.n_kvs)
+    //         .step_by(4)
+    //         .flat_map(|i| [i, i + 1])
+    //         .map(|i| decoder_outputs[i].clone())
+    //         .collect();
+
+    //     // token ids
+    //     let mut token_ids: Vec<Vec<u32>> = vec![vec![]; self.encoder.batch()];
+
+    //     // generate
+    //     for _ in 0..self.max_length {
+    //         let logits = &decoder_outputs[0];
+    //         let decoder_kvs: Vec<_> = (1..(4 * self.n_kvs) - 2)
+    //             .step_by(4)
+    //             .flat_map(|i| [i, i + 1])
+    //             .map(|i| decoder_outputs[i].clone())
+    //             .collect();
+
+    //         // decode each token for each batch
+    //         let (finished, last_tokens) = self.decoder_merged.processor().par_generate(
+    //             logits,
+    //             &mut token_ids,
+    //             self.eos_token_id,
+    //         )?;
+
+    //         if finished {
+    //             break;
+    //         }
+
+    //         // build inputs
+    //         let input_ids = X::from(last_tokens).insert_axis(1)?;
+    //         let mut xs = vec![input_ids, encoder_hidden_states.clone()];
+    //         for i in 0..self.n_kvs {
+    //             xs.push(decoder_kvs[i * 2].clone());
+    //             xs.push(decoder_kvs[i * 2 + 1].clone());
+    //             xs.push(encoder_kvs[i * 2].clone());
+    //             xs.push(encoder_kvs[i * 2 + 1].clone());
+    //         }
+    //         xs.push(X::ones(&[1])); // use_cache
+
+    //         // generate
+    //         decoder_outputs = self.decoder_merged.inference(xs.into())?;
+    //     }
+
+    //     Ok(token_ids)
+    // }
+
     pub fn decode(&self, token_ids: Vec<Vec<u32>>) -> Result<Ys> {
         // decode
         let texts = self