mirror of
https://github.com/mii443/usls.git
synced 2025-08-22 15:45:41 +00:00
Bug Fixes: Increased the number of available releases and fixed the Trocr bug. (#58)
This commit is contained in:
@ -17,9 +17,9 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
// 3. Fetch tags and files
|
||||
let hub = Hub::default().with_owner("jamjamjon").with_repo("usls");
|
||||
for tag in hub.tags().iter() {
|
||||
for (i, tag) in hub.tags().iter().enumerate() {
|
||||
let files = hub.files(tag);
|
||||
println!("{} => {:?}", tag, files); // Should be empty
|
||||
println!("{} :: {} => {:?}", i, tag, files); // Should be empty
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -1,7 +1,7 @@
|
||||
use aksr::Builder;
|
||||
use anyhow::Result;
|
||||
use half::{bf16, f16};
|
||||
use log::{error, info, warn};
|
||||
use log::{debug, info, warn};
|
||||
use ndarray::{Array, IxDyn};
|
||||
#[allow(unused_imports)]
|
||||
use ort::{
|
||||
@ -275,7 +275,7 @@ impl Engine {
|
||||
{
|
||||
match x.try_extract_tensor::<T>() {
|
||||
Err(err) => {
|
||||
error!("Failed to extract from ort outputs: {:?}", err);
|
||||
debug!("Failed to extract from ort outputs: {:?}. A default value has been generated.", err);
|
||||
Array::zeros(0).into_dyn()
|
||||
}
|
||||
Ok(x) => x.view().mapv(map_fn).into_owned(),
|
||||
|
@ -460,8 +460,10 @@ impl Hub {
|
||||
let cache = to.path()?.join(Self::cache_file(owner, repo));
|
||||
let is_file_expired = Self::is_file_expired(&cache, ttl)?;
|
||||
let body = if is_file_expired {
|
||||
let gh_api_release =
|
||||
format!("https://api.github.com/repos/{}/{}/releases", owner, repo);
|
||||
let gh_api_release = format!(
|
||||
"https://api.github.com/repos/{}/{}/releases?per_page=100",
|
||||
owner, repo
|
||||
);
|
||||
Self::fetch_and_cache_releases(&gh_api_release, &cache)?
|
||||
} else {
|
||||
std::fs::read_to_string(&cache)?
|
||||
|
@ -1,12 +1,13 @@
|
||||
use aksr::Builder;
|
||||
use anyhow::Result;
|
||||
use image::DynamicImage;
|
||||
use ndarray::{s, Axis};
|
||||
use rayon::prelude::*;
|
||||
|
||||
use crate::{
|
||||
elapsed,
|
||||
models::{BaseModelTextual, BaseModelVisual},
|
||||
Options, Scale, Ts, Xs, Ys, X, Y,
|
||||
LogitsSampler, Options, Scale, Ts, Xs, Ys, X, Y,
|
||||
};
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
@ -96,85 +97,6 @@ impl TrOCR {
|
||||
Ok(ys)
|
||||
}
|
||||
|
||||
// fn generate(&mut self, encoder_hidden_states: &X) -> Result<Vec<Vec<u32>>> {
|
||||
// // input_ids
|
||||
// let input_ids = X::from(vec![self.decoder_start_token_id as f32])
|
||||
// .insert_axis(0)?
|
||||
// .repeat(0, self.encoder.batch())?;
|
||||
|
||||
// // decoder
|
||||
// let mut decoder_outputs = self.decoder.inference(Xs::from(vec![
|
||||
// input_ids.clone(),
|
||||
// encoder_hidden_states.clone(),
|
||||
// ]))?;
|
||||
|
||||
// // encoder kvs
|
||||
// let encoder_kvs: Vec<_> = (3..4 * self.n_kvs)
|
||||
// .step_by(4)
|
||||
// .flat_map(|i| [i, i + 1])
|
||||
// .map(|i| decoder_outputs[i].clone())
|
||||
// .collect();
|
||||
|
||||
// // token ids
|
||||
// let mut token_ids: Vec<Vec<u32>> = vec![vec![]; self.encoder.batch()];
|
||||
// let mut finished = vec![false; self.encoder.batch()];
|
||||
// let mut last_tokens: Vec<f32> = vec![0.; self.encoder.batch()];
|
||||
// let mut logits_sampler = LogitsSampler::new();
|
||||
|
||||
// // generate
|
||||
// for _ in 0..self.max_length {
|
||||
// let logits = &decoder_outputs[0];
|
||||
// let decoder_kvs: Vec<_> = (1..(4 * self.n_kvs) - 2)
|
||||
// .step_by(4)
|
||||
// .flat_map(|i| [i, i + 1])
|
||||
// .map(|i| decoder_outputs[i].clone())
|
||||
// .collect();
|
||||
|
||||
// // decode each token for each batch
|
||||
// for (i, logit) in logits.axis_iter(Axis(0)).enumerate() {
|
||||
// if !finished[i] {
|
||||
// let token_id = logits_sampler.decode(
|
||||
// &logit
|
||||
// .slice(s![-1, ..])
|
||||
// .into_owned()
|
||||
// .into_raw_vec_and_offset()
|
||||
// .0,
|
||||
// )?;
|
||||
|
||||
// if token_id == self.eos_token_id {
|
||||
// finished[i] = true;
|
||||
// } else {
|
||||
// token_ids[i].push(token_id);
|
||||
// }
|
||||
|
||||
// // update
|
||||
// last_tokens[i] = token_id as f32;
|
||||
// }
|
||||
// }
|
||||
|
||||
// // all finished?
|
||||
// if finished.iter().all(|&x| x) {
|
||||
// break;
|
||||
// }
|
||||
|
||||
// // build inputs
|
||||
// let input_ids = X::from(last_tokens.clone()).insert_axis(1)?;
|
||||
// let mut xs = vec![input_ids, encoder_hidden_states.clone()];
|
||||
// for i in 0..self.n_kvs {
|
||||
// xs.push(decoder_kvs[i * 2].clone());
|
||||
// xs.push(decoder_kvs[i * 2 + 1].clone());
|
||||
// xs.push(encoder_kvs[i * 2].clone());
|
||||
// xs.push(encoder_kvs[i * 2 + 1].clone());
|
||||
// }
|
||||
// xs.push(X::ones(&[1])); // use_cache
|
||||
|
||||
// // generate
|
||||
// decoder_outputs = self.decoder_merged.inference(xs.into())?;
|
||||
// }
|
||||
|
||||
// Ok(token_ids)
|
||||
// }
|
||||
|
||||
fn generate(&mut self, encoder_hidden_states: &X) -> Result<Vec<Vec<u32>>> {
|
||||
// input_ids
|
||||
let input_ids = X::from(vec![self.decoder_start_token_id as f32])
|
||||
@ -196,6 +118,9 @@ impl TrOCR {
|
||||
|
||||
// token ids
|
||||
let mut token_ids: Vec<Vec<u32>> = vec![vec![]; self.encoder.batch()];
|
||||
let mut finished = vec![false; self.encoder.batch()];
|
||||
let mut last_tokens: Vec<f32> = vec![0.; self.encoder.batch()];
|
||||
let logits_sampler = LogitsSampler::new();
|
||||
|
||||
// generate
|
||||
for _ in 0..self.max_length {
|
||||
@ -207,18 +132,34 @@ impl TrOCR {
|
||||
.collect();
|
||||
|
||||
// decode each token for each batch
|
||||
let (finished, last_tokens) = self.decoder_merged.processor().par_generate(
|
||||
logits,
|
||||
&mut token_ids,
|
||||
self.eos_token_id,
|
||||
for (i, logit) in logits.axis_iter(Axis(0)).enumerate() {
|
||||
if !finished[i] {
|
||||
let token_id = logits_sampler.decode(
|
||||
&logit
|
||||
.slice(s![-1, ..])
|
||||
.into_owned()
|
||||
.into_raw_vec_and_offset()
|
||||
.0,
|
||||
)?;
|
||||
|
||||
if finished {
|
||||
if token_id == self.eos_token_id {
|
||||
finished[i] = true;
|
||||
} else {
|
||||
token_ids[i].push(token_id);
|
||||
}
|
||||
|
||||
// update
|
||||
last_tokens[i] = token_id as f32;
|
||||
}
|
||||
}
|
||||
|
||||
// all finished?
|
||||
if finished.iter().all(|&x| x) {
|
||||
break;
|
||||
}
|
||||
|
||||
// build inputs
|
||||
let input_ids = X::from(last_tokens).insert_axis(1)?;
|
||||
let input_ids = X::from(last_tokens.clone()).insert_axis(1)?;
|
||||
let mut xs = vec![input_ids, encoder_hidden_states.clone()];
|
||||
for i in 0..self.n_kvs {
|
||||
xs.push(decoder_kvs[i * 2].clone());
|
||||
@ -235,6 +176,66 @@ impl TrOCR {
|
||||
Ok(token_ids)
|
||||
}
|
||||
|
||||
// fn generate(&mut self, encoder_hidden_states: &X) -> Result<Vec<Vec<u32>>> {
|
||||
// // input_ids
|
||||
// let input_ids = X::from(vec![self.decoder_start_token_id as f32])
|
||||
// .insert_axis(0)?
|
||||
// .repeat(0, self.encoder.batch())?;
|
||||
|
||||
// // decoder
|
||||
// let mut decoder_outputs = self.decoder.inference(Xs::from(vec![
|
||||
// input_ids.clone(),
|
||||
// encoder_hidden_states.clone(),
|
||||
// ]))?;
|
||||
|
||||
// // encoder kvs
|
||||
// let encoder_kvs: Vec<_> = (3..4 * self.n_kvs)
|
||||
// .step_by(4)
|
||||
// .flat_map(|i| [i, i + 1])
|
||||
// .map(|i| decoder_outputs[i].clone())
|
||||
// .collect();
|
||||
|
||||
// // token ids
|
||||
// let mut token_ids: Vec<Vec<u32>> = vec![vec![]; self.encoder.batch()];
|
||||
|
||||
// // generate
|
||||
// for _ in 0..self.max_length {
|
||||
// let logits = &decoder_outputs[0];
|
||||
// let decoder_kvs: Vec<_> = (1..(4 * self.n_kvs) - 2)
|
||||
// .step_by(4)
|
||||
// .flat_map(|i| [i, i + 1])
|
||||
// .map(|i| decoder_outputs[i].clone())
|
||||
// .collect();
|
||||
|
||||
// // decode each token for each batch
|
||||
// let (finished, last_tokens) = self.decoder_merged.processor().par_generate(
|
||||
// logits,
|
||||
// &mut token_ids,
|
||||
// self.eos_token_id,
|
||||
// )?;
|
||||
|
||||
// if finished {
|
||||
// break;
|
||||
// }
|
||||
|
||||
// // build inputs
|
||||
// let input_ids = X::from(last_tokens).insert_axis(1)?;
|
||||
// let mut xs = vec![input_ids, encoder_hidden_states.clone()];
|
||||
// for i in 0..self.n_kvs {
|
||||
// xs.push(decoder_kvs[i * 2].clone());
|
||||
// xs.push(decoder_kvs[i * 2 + 1].clone());
|
||||
// xs.push(encoder_kvs[i * 2].clone());
|
||||
// xs.push(encoder_kvs[i * 2 + 1].clone());
|
||||
// }
|
||||
// xs.push(X::ones(&[1])); // use_cache
|
||||
|
||||
// // generate
|
||||
// decoder_outputs = self.decoder_merged.inference(xs.into())?;
|
||||
// }
|
||||
|
||||
// Ok(token_ids)
|
||||
// }
|
||||
|
||||
pub fn decode(&self, token_ids: Vec<Vec<u32>>) -> Result<Ys> {
|
||||
// decode
|
||||
let texts = self
|
||||
|
Reference in New Issue
Block a user