mirror of
https://github.com/mii443/usls.git
synced 2025-08-22 15:45:41 +00:00
Bug Fixes: Increased the number of available releases and fixed the Trocr bug. (#58)
This commit is contained in:
@ -17,9 +17,9 @@ fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// 3. Fetch tags and files
|
// 3. Fetch tags and files
|
||||||
let hub = Hub::default().with_owner("jamjamjon").with_repo("usls");
|
let hub = Hub::default().with_owner("jamjamjon").with_repo("usls");
|
||||||
for tag in hub.tags().iter() {
|
for (i, tag) in hub.tags().iter().enumerate() {
|
||||||
let files = hub.files(tag);
|
let files = hub.files(tag);
|
||||||
println!("{} => {:?}", tag, files); // Should be empty
|
println!("{} :: {} => {:?}", i, tag, files); // Should be empty
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use aksr::Builder;
|
use aksr::Builder;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
use log::{error, info, warn};
|
use log::{debug, info, warn};
|
||||||
use ndarray::{Array, IxDyn};
|
use ndarray::{Array, IxDyn};
|
||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
use ort::{
|
use ort::{
|
||||||
@ -275,7 +275,7 @@ impl Engine {
|
|||||||
{
|
{
|
||||||
match x.try_extract_tensor::<T>() {
|
match x.try_extract_tensor::<T>() {
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
error!("Failed to extract from ort outputs: {:?}", err);
|
debug!("Failed to extract from ort outputs: {:?}. A default value has been generated.", err);
|
||||||
Array::zeros(0).into_dyn()
|
Array::zeros(0).into_dyn()
|
||||||
}
|
}
|
||||||
Ok(x) => x.view().mapv(map_fn).into_owned(),
|
Ok(x) => x.view().mapv(map_fn).into_owned(),
|
||||||
|
@ -460,8 +460,10 @@ impl Hub {
|
|||||||
let cache = to.path()?.join(Self::cache_file(owner, repo));
|
let cache = to.path()?.join(Self::cache_file(owner, repo));
|
||||||
let is_file_expired = Self::is_file_expired(&cache, ttl)?;
|
let is_file_expired = Self::is_file_expired(&cache, ttl)?;
|
||||||
let body = if is_file_expired {
|
let body = if is_file_expired {
|
||||||
let gh_api_release =
|
let gh_api_release = format!(
|
||||||
format!("https://api.github.com/repos/{}/{}/releases", owner, repo);
|
"https://api.github.com/repos/{}/{}/releases?per_page=100",
|
||||||
|
owner, repo
|
||||||
|
);
|
||||||
Self::fetch_and_cache_releases(&gh_api_release, &cache)?
|
Self::fetch_and_cache_releases(&gh_api_release, &cache)?
|
||||||
} else {
|
} else {
|
||||||
std::fs::read_to_string(&cache)?
|
std::fs::read_to_string(&cache)?
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
use aksr::Builder;
|
use aksr::Builder;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use image::DynamicImage;
|
use image::DynamicImage;
|
||||||
|
use ndarray::{s, Axis};
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
elapsed,
|
elapsed,
|
||||||
models::{BaseModelTextual, BaseModelVisual},
|
models::{BaseModelTextual, BaseModelVisual},
|
||||||
Options, Scale, Ts, Xs, Ys, X, Y,
|
LogitsSampler, Options, Scale, Ts, Xs, Ys, X, Y,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Debug, Copy, Clone)]
|
#[derive(Debug, Copy, Clone)]
|
||||||
@ -96,85 +97,6 @@ impl TrOCR {
|
|||||||
Ok(ys)
|
Ok(ys)
|
||||||
}
|
}
|
||||||
|
|
||||||
// fn generate(&mut self, encoder_hidden_states: &X) -> Result<Vec<Vec<u32>>> {
|
|
||||||
// // input_ids
|
|
||||||
// let input_ids = X::from(vec![self.decoder_start_token_id as f32])
|
|
||||||
// .insert_axis(0)?
|
|
||||||
// .repeat(0, self.encoder.batch())?;
|
|
||||||
|
|
||||||
// // decoder
|
|
||||||
// let mut decoder_outputs = self.decoder.inference(Xs::from(vec![
|
|
||||||
// input_ids.clone(),
|
|
||||||
// encoder_hidden_states.clone(),
|
|
||||||
// ]))?;
|
|
||||||
|
|
||||||
// // encoder kvs
|
|
||||||
// let encoder_kvs: Vec<_> = (3..4 * self.n_kvs)
|
|
||||||
// .step_by(4)
|
|
||||||
// .flat_map(|i| [i, i + 1])
|
|
||||||
// .map(|i| decoder_outputs[i].clone())
|
|
||||||
// .collect();
|
|
||||||
|
|
||||||
// // token ids
|
|
||||||
// let mut token_ids: Vec<Vec<u32>> = vec![vec![]; self.encoder.batch()];
|
|
||||||
// let mut finished = vec![false; self.encoder.batch()];
|
|
||||||
// let mut last_tokens: Vec<f32> = vec![0.; self.encoder.batch()];
|
|
||||||
// let mut logits_sampler = LogitsSampler::new();
|
|
||||||
|
|
||||||
// // generate
|
|
||||||
// for _ in 0..self.max_length {
|
|
||||||
// let logits = &decoder_outputs[0];
|
|
||||||
// let decoder_kvs: Vec<_> = (1..(4 * self.n_kvs) - 2)
|
|
||||||
// .step_by(4)
|
|
||||||
// .flat_map(|i| [i, i + 1])
|
|
||||||
// .map(|i| decoder_outputs[i].clone())
|
|
||||||
// .collect();
|
|
||||||
|
|
||||||
// // decode each token for each batch
|
|
||||||
// for (i, logit) in logits.axis_iter(Axis(0)).enumerate() {
|
|
||||||
// if !finished[i] {
|
|
||||||
// let token_id = logits_sampler.decode(
|
|
||||||
// &logit
|
|
||||||
// .slice(s![-1, ..])
|
|
||||||
// .into_owned()
|
|
||||||
// .into_raw_vec_and_offset()
|
|
||||||
// .0,
|
|
||||||
// )?;
|
|
||||||
|
|
||||||
// if token_id == self.eos_token_id {
|
|
||||||
// finished[i] = true;
|
|
||||||
// } else {
|
|
||||||
// token_ids[i].push(token_id);
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // update
|
|
||||||
// last_tokens[i] = token_id as f32;
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // all finished?
|
|
||||||
// if finished.iter().all(|&x| x) {
|
|
||||||
// break;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // build inputs
|
|
||||||
// let input_ids = X::from(last_tokens.clone()).insert_axis(1)?;
|
|
||||||
// let mut xs = vec![input_ids, encoder_hidden_states.clone()];
|
|
||||||
// for i in 0..self.n_kvs {
|
|
||||||
// xs.push(decoder_kvs[i * 2].clone());
|
|
||||||
// xs.push(decoder_kvs[i * 2 + 1].clone());
|
|
||||||
// xs.push(encoder_kvs[i * 2].clone());
|
|
||||||
// xs.push(encoder_kvs[i * 2 + 1].clone());
|
|
||||||
// }
|
|
||||||
// xs.push(X::ones(&[1])); // use_cache
|
|
||||||
|
|
||||||
// // generate
|
|
||||||
// decoder_outputs = self.decoder_merged.inference(xs.into())?;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// Ok(token_ids)
|
|
||||||
// }
|
|
||||||
|
|
||||||
fn generate(&mut self, encoder_hidden_states: &X) -> Result<Vec<Vec<u32>>> {
|
fn generate(&mut self, encoder_hidden_states: &X) -> Result<Vec<Vec<u32>>> {
|
||||||
// input_ids
|
// input_ids
|
||||||
let input_ids = X::from(vec![self.decoder_start_token_id as f32])
|
let input_ids = X::from(vec![self.decoder_start_token_id as f32])
|
||||||
@ -196,6 +118,9 @@ impl TrOCR {
|
|||||||
|
|
||||||
// token ids
|
// token ids
|
||||||
let mut token_ids: Vec<Vec<u32>> = vec![vec![]; self.encoder.batch()];
|
let mut token_ids: Vec<Vec<u32>> = vec![vec![]; self.encoder.batch()];
|
||||||
|
let mut finished = vec![false; self.encoder.batch()];
|
||||||
|
let mut last_tokens: Vec<f32> = vec![0.; self.encoder.batch()];
|
||||||
|
let logits_sampler = LogitsSampler::new();
|
||||||
|
|
||||||
// generate
|
// generate
|
||||||
for _ in 0..self.max_length {
|
for _ in 0..self.max_length {
|
||||||
@ -207,18 +132,34 @@ impl TrOCR {
|
|||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// decode each token for each batch
|
// decode each token for each batch
|
||||||
let (finished, last_tokens) = self.decoder_merged.processor().par_generate(
|
for (i, logit) in logits.axis_iter(Axis(0)).enumerate() {
|
||||||
logits,
|
if !finished[i] {
|
||||||
&mut token_ids,
|
let token_id = logits_sampler.decode(
|
||||||
self.eos_token_id,
|
&logit
|
||||||
)?;
|
.slice(s![-1, ..])
|
||||||
|
.into_owned()
|
||||||
|
.into_raw_vec_and_offset()
|
||||||
|
.0,
|
||||||
|
)?;
|
||||||
|
|
||||||
if finished {
|
if token_id == self.eos_token_id {
|
||||||
|
finished[i] = true;
|
||||||
|
} else {
|
||||||
|
token_ids[i].push(token_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// update
|
||||||
|
last_tokens[i] = token_id as f32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// all finished?
|
||||||
|
if finished.iter().all(|&x| x) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// build inputs
|
// build inputs
|
||||||
let input_ids = X::from(last_tokens).insert_axis(1)?;
|
let input_ids = X::from(last_tokens.clone()).insert_axis(1)?;
|
||||||
let mut xs = vec![input_ids, encoder_hidden_states.clone()];
|
let mut xs = vec![input_ids, encoder_hidden_states.clone()];
|
||||||
for i in 0..self.n_kvs {
|
for i in 0..self.n_kvs {
|
||||||
xs.push(decoder_kvs[i * 2].clone());
|
xs.push(decoder_kvs[i * 2].clone());
|
||||||
@ -235,6 +176,66 @@ impl TrOCR {
|
|||||||
Ok(token_ids)
|
Ok(token_ids)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fn generate(&mut self, encoder_hidden_states: &X) -> Result<Vec<Vec<u32>>> {
|
||||||
|
// // input_ids
|
||||||
|
// let input_ids = X::from(vec![self.decoder_start_token_id as f32])
|
||||||
|
// .insert_axis(0)?
|
||||||
|
// .repeat(0, self.encoder.batch())?;
|
||||||
|
|
||||||
|
// // decoder
|
||||||
|
// let mut decoder_outputs = self.decoder.inference(Xs::from(vec![
|
||||||
|
// input_ids.clone(),
|
||||||
|
// encoder_hidden_states.clone(),
|
||||||
|
// ]))?;
|
||||||
|
|
||||||
|
// // encoder kvs
|
||||||
|
// let encoder_kvs: Vec<_> = (3..4 * self.n_kvs)
|
||||||
|
// .step_by(4)
|
||||||
|
// .flat_map(|i| [i, i + 1])
|
||||||
|
// .map(|i| decoder_outputs[i].clone())
|
||||||
|
// .collect();
|
||||||
|
|
||||||
|
// // token ids
|
||||||
|
// let mut token_ids: Vec<Vec<u32>> = vec![vec![]; self.encoder.batch()];
|
||||||
|
|
||||||
|
// // generate
|
||||||
|
// for _ in 0..self.max_length {
|
||||||
|
// let logits = &decoder_outputs[0];
|
||||||
|
// let decoder_kvs: Vec<_> = (1..(4 * self.n_kvs) - 2)
|
||||||
|
// .step_by(4)
|
||||||
|
// .flat_map(|i| [i, i + 1])
|
||||||
|
// .map(|i| decoder_outputs[i].clone())
|
||||||
|
// .collect();
|
||||||
|
|
||||||
|
// // decode each token for each batch
|
||||||
|
// let (finished, last_tokens) = self.decoder_merged.processor().par_generate(
|
||||||
|
// logits,
|
||||||
|
// &mut token_ids,
|
||||||
|
// self.eos_token_id,
|
||||||
|
// )?;
|
||||||
|
|
||||||
|
// if finished {
|
||||||
|
// break;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // build inputs
|
||||||
|
// let input_ids = X::from(last_tokens).insert_axis(1)?;
|
||||||
|
// let mut xs = vec![input_ids, encoder_hidden_states.clone()];
|
||||||
|
// for i in 0..self.n_kvs {
|
||||||
|
// xs.push(decoder_kvs[i * 2].clone());
|
||||||
|
// xs.push(decoder_kvs[i * 2 + 1].clone());
|
||||||
|
// xs.push(encoder_kvs[i * 2].clone());
|
||||||
|
// xs.push(encoder_kvs[i * 2 + 1].clone());
|
||||||
|
// }
|
||||||
|
// xs.push(X::ones(&[1])); // use_cache
|
||||||
|
|
||||||
|
// // generate
|
||||||
|
// decoder_outputs = self.decoder_merged.inference(xs.into())?;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// Ok(token_ids)
|
||||||
|
// }
|
||||||
|
|
||||||
pub fn decode(&self, token_ids: Vec<Vec<u32>>) -> Result<Ys> {
|
pub fn decode(&self, token_ids: Vec<Vec<u32>>) -> Result<Ys> {
|
||||||
// decode
|
// decode
|
||||||
let texts = self
|
let texts = self
|
||||||
|
Reference in New Issue
Block a user