Bug Fixes: Increased the number of available releases and fixed the TrOCR bug. (#58)

This commit is contained in:
Jamjamjon
2025-01-22 14:55:12 +08:00
committed by GitHub
parent 475a680703
commit b3072e8da3
4 changed files with 96 additions and 93 deletions

View File

@ -17,9 +17,9 @@ fn main() -> anyhow::Result<()> {
// 3. Fetch tags and files // 3. Fetch tags and files
let hub = Hub::default().with_owner("jamjamjon").with_repo("usls"); let hub = Hub::default().with_owner("jamjamjon").with_repo("usls");
for tag in hub.tags().iter() { for (i, tag) in hub.tags().iter().enumerate() {
let files = hub.files(tag); let files = hub.files(tag);
println!("{} => {:?}", tag, files); // Should be empty println!("{} :: {} => {:?}", i, tag, files); // Should be empty
} }
Ok(()) Ok(())

View File

@ -1,7 +1,7 @@
use aksr::Builder; use aksr::Builder;
use anyhow::Result; use anyhow::Result;
use half::{bf16, f16}; use half::{bf16, f16};
use log::{error, info, warn}; use log::{debug, info, warn};
use ndarray::{Array, IxDyn}; use ndarray::{Array, IxDyn};
#[allow(unused_imports)] #[allow(unused_imports)]
use ort::{ use ort::{
@ -275,7 +275,7 @@ impl Engine {
{ {
match x.try_extract_tensor::<T>() { match x.try_extract_tensor::<T>() {
Err(err) => { Err(err) => {
error!("Failed to extract from ort outputs: {:?}", err); debug!("Failed to extract from ort outputs: {:?}. A default value has been generated.", err);
Array::zeros(0).into_dyn() Array::zeros(0).into_dyn()
} }
Ok(x) => x.view().mapv(map_fn).into_owned(), Ok(x) => x.view().mapv(map_fn).into_owned(),

View File

@ -460,8 +460,10 @@ impl Hub {
let cache = to.path()?.join(Self::cache_file(owner, repo)); let cache = to.path()?.join(Self::cache_file(owner, repo));
let is_file_expired = Self::is_file_expired(&cache, ttl)?; let is_file_expired = Self::is_file_expired(&cache, ttl)?;
let body = if is_file_expired { let body = if is_file_expired {
let gh_api_release = let gh_api_release = format!(
format!("https://api.github.com/repos/{}/{}/releases", owner, repo); "https://api.github.com/repos/{}/{}/releases?per_page=100",
owner, repo
);
Self::fetch_and_cache_releases(&gh_api_release, &cache)? Self::fetch_and_cache_releases(&gh_api_release, &cache)?
} else { } else {
std::fs::read_to_string(&cache)? std::fs::read_to_string(&cache)?

View File

@ -1,12 +1,13 @@
use aksr::Builder; use aksr::Builder;
use anyhow::Result; use anyhow::Result;
use image::DynamicImage; use image::DynamicImage;
use ndarray::{s, Axis};
use rayon::prelude::*; use rayon::prelude::*;
use crate::{ use crate::{
elapsed, elapsed,
models::{BaseModelTextual, BaseModelVisual}, models::{BaseModelTextual, BaseModelVisual},
Options, Scale, Ts, Xs, Ys, X, Y, LogitsSampler, Options, Scale, Ts, Xs, Ys, X, Y,
}; };
#[derive(Debug, Copy, Clone)] #[derive(Debug, Copy, Clone)]
@ -96,85 +97,6 @@ impl TrOCR {
Ok(ys) Ok(ys)
} }
// fn generate(&mut self, encoder_hidden_states: &X) -> Result<Vec<Vec<u32>>> {
// // input_ids
// let input_ids = X::from(vec![self.decoder_start_token_id as f32])
// .insert_axis(0)?
// .repeat(0, self.encoder.batch())?;
// // decoder
// let mut decoder_outputs = self.decoder.inference(Xs::from(vec![
// input_ids.clone(),
// encoder_hidden_states.clone(),
// ]))?;
// // encoder kvs
// let encoder_kvs: Vec<_> = (3..4 * self.n_kvs)
// .step_by(4)
// .flat_map(|i| [i, i + 1])
// .map(|i| decoder_outputs[i].clone())
// .collect();
// // token ids
// let mut token_ids: Vec<Vec<u32>> = vec![vec![]; self.encoder.batch()];
// let mut finished = vec![false; self.encoder.batch()];
// let mut last_tokens: Vec<f32> = vec![0.; self.encoder.batch()];
// let mut logits_sampler = LogitsSampler::new();
// // generate
// for _ in 0..self.max_length {
// let logits = &decoder_outputs[0];
// let decoder_kvs: Vec<_> = (1..(4 * self.n_kvs) - 2)
// .step_by(4)
// .flat_map(|i| [i, i + 1])
// .map(|i| decoder_outputs[i].clone())
// .collect();
// // decode each token for each batch
// for (i, logit) in logits.axis_iter(Axis(0)).enumerate() {
// if !finished[i] {
// let token_id = logits_sampler.decode(
// &logit
// .slice(s![-1, ..])
// .into_owned()
// .into_raw_vec_and_offset()
// .0,
// )?;
// if token_id == self.eos_token_id {
// finished[i] = true;
// } else {
// token_ids[i].push(token_id);
// }
// // update
// last_tokens[i] = token_id as f32;
// }
// }
// // all finished?
// if finished.iter().all(|&x| x) {
// break;
// }
// // build inputs
// let input_ids = X::from(last_tokens.clone()).insert_axis(1)?;
// let mut xs = vec![input_ids, encoder_hidden_states.clone()];
// for i in 0..self.n_kvs {
// xs.push(decoder_kvs[i * 2].clone());
// xs.push(decoder_kvs[i * 2 + 1].clone());
// xs.push(encoder_kvs[i * 2].clone());
// xs.push(encoder_kvs[i * 2 + 1].clone());
// }
// xs.push(X::ones(&[1])); // use_cache
// // generate
// decoder_outputs = self.decoder_merged.inference(xs.into())?;
// }
// Ok(token_ids)
// }
fn generate(&mut self, encoder_hidden_states: &X) -> Result<Vec<Vec<u32>>> { fn generate(&mut self, encoder_hidden_states: &X) -> Result<Vec<Vec<u32>>> {
// input_ids // input_ids
let input_ids = X::from(vec![self.decoder_start_token_id as f32]) let input_ids = X::from(vec![self.decoder_start_token_id as f32])
@ -196,6 +118,9 @@ impl TrOCR {
// token ids // token ids
let mut token_ids: Vec<Vec<u32>> = vec![vec![]; self.encoder.batch()]; let mut token_ids: Vec<Vec<u32>> = vec![vec![]; self.encoder.batch()];
let mut finished = vec![false; self.encoder.batch()];
let mut last_tokens: Vec<f32> = vec![0.; self.encoder.batch()];
let logits_sampler = LogitsSampler::new();
// generate // generate
for _ in 0..self.max_length { for _ in 0..self.max_length {
@ -207,18 +132,34 @@ impl TrOCR {
.collect(); .collect();
// decode each token for each batch // decode each token for each batch
let (finished, last_tokens) = self.decoder_merged.processor().par_generate( for (i, logit) in logits.axis_iter(Axis(0)).enumerate() {
logits, if !finished[i] {
&mut token_ids, let token_id = logits_sampler.decode(
self.eos_token_id, &logit
)?; .slice(s![-1, ..])
.into_owned()
.into_raw_vec_and_offset()
.0,
)?;
if finished { if token_id == self.eos_token_id {
finished[i] = true;
} else {
token_ids[i].push(token_id);
}
// update
last_tokens[i] = token_id as f32;
}
}
// all finished?
if finished.iter().all(|&x| x) {
break; break;
} }
// build inputs // build inputs
let input_ids = X::from(last_tokens).insert_axis(1)?; let input_ids = X::from(last_tokens.clone()).insert_axis(1)?;
let mut xs = vec![input_ids, encoder_hidden_states.clone()]; let mut xs = vec![input_ids, encoder_hidden_states.clone()];
for i in 0..self.n_kvs { for i in 0..self.n_kvs {
xs.push(decoder_kvs[i * 2].clone()); xs.push(decoder_kvs[i * 2].clone());
@ -235,6 +176,66 @@ impl TrOCR {
Ok(token_ids) Ok(token_ids)
} }
// fn generate(&mut self, encoder_hidden_states: &X) -> Result<Vec<Vec<u32>>> {
// // input_ids
// let input_ids = X::from(vec![self.decoder_start_token_id as f32])
// .insert_axis(0)?
// .repeat(0, self.encoder.batch())?;
// // decoder
// let mut decoder_outputs = self.decoder.inference(Xs::from(vec![
// input_ids.clone(),
// encoder_hidden_states.clone(),
// ]))?;
// // encoder kvs
// let encoder_kvs: Vec<_> = (3..4 * self.n_kvs)
// .step_by(4)
// .flat_map(|i| [i, i + 1])
// .map(|i| decoder_outputs[i].clone())
// .collect();
// // token ids
// let mut token_ids: Vec<Vec<u32>> = vec![vec![]; self.encoder.batch()];
// // generate
// for _ in 0..self.max_length {
// let logits = &decoder_outputs[0];
// let decoder_kvs: Vec<_> = (1..(4 * self.n_kvs) - 2)
// .step_by(4)
// .flat_map(|i| [i, i + 1])
// .map(|i| decoder_outputs[i].clone())
// .collect();
// // decode each token for each batch
// let (finished, last_tokens) = self.decoder_merged.processor().par_generate(
// logits,
// &mut token_ids,
// self.eos_token_id,
// )?;
// if finished {
// break;
// }
// // build inputs
// let input_ids = X::from(last_tokens).insert_axis(1)?;
// let mut xs = vec![input_ids, encoder_hidden_states.clone()];
// for i in 0..self.n_kvs {
// xs.push(decoder_kvs[i * 2].clone());
// xs.push(decoder_kvs[i * 2 + 1].clone());
// xs.push(encoder_kvs[i * 2].clone());
// xs.push(encoder_kvs[i * 2 + 1].clone());
// }
// xs.push(X::ones(&[1])); // use_cache
// // generate
// decoder_outputs = self.decoder_merged.inference(xs.into())?;
// }
// Ok(token_ids)
// }
pub fn decode(&self, token_ids: Vec<Vec<u32>>) -> Result<Ys> { pub fn decode(&self, token_ids: Vec<Vec<u32>>) -> Result<Ys> {
// decode // decode
let texts = self let texts = self