diff --git a/src/misc/ops.rs b/src/misc/ops.rs index e30a929..8d416ce 100644 --- a/src/misc/ops.rs +++ b/src/misc/ops.rs @@ -7,7 +7,8 @@ use fast_image_resize::{ FilterType, ResizeAlg, ResizeOptions, Resizer, }; use image::{DynamicImage, GenericImageView}; -use ndarray::{concatenate, s, Array, Array3, Axis, IntoDimension, Ix2, IxDyn}; +use ndarray::{concatenate, s, Array, Array3, ArrayView1, Axis, IntoDimension, Ix2, IxDyn, Zip}; + use rayon::prelude::*; pub enum Ops<'a> { @@ -27,7 +28,7 @@ pub enum Ops<'a> { } impl Ops<'_> { - pub fn normalize(x: Array, min: f32, max: f32) -> Result> { + pub fn normalize(x: &mut Array, min: f32, max: f32) -> Result<()> { if min >= max { anyhow::bail!( "Invalid range in `normalize`: `min` ({}) must be less than `max` ({}).", @@ -35,7 +36,10 @@ impl Ops<'_> { max ); } - Ok((x - min) / (max - min)) + let range = max - min; + x.par_mapv_inplace(|x| (x - min) / range); + + Ok(()) } pub fn sigmoid(x: Array) -> Array { @@ -74,11 +78,11 @@ impl Ops<'_> { } pub fn standardize( - x: Array, - mean: &[f32], - std: &[f32], + x: &mut Array, + mean: ArrayView1, + std: ArrayView1, dim: usize, - ) -> Result> { + ) -> Result<()> { if mean.len() != std.len() { anyhow::bail!( "`standardize`: `mean` and `std` lengths are not equal. Mean length: {}, Std length: {}.", @@ -86,6 +90,7 @@ impl Ops<'_> { std.len() ); } + let shape = x.shape(); if dim >= shape.len() || shape[dim] != mean.len() { anyhow::bail!( @@ -95,11 +100,20 @@ impl Ops<'_> { mean.len() ); } - let mut shape = vec![1; shape.len()]; - shape[dim] = mean.len(); - let mean = Array::from_shape_vec(shape.clone(), mean.to_vec())?; - let std = Array::from_shape_vec(shape, std.to_vec())?; - Ok((x - mean) / std) + let mean_broadcast = mean.broadcast(shape).ok_or_else(|| { + anyhow::anyhow!("Failed to broadcast `mean` to the shape of the input array.") + })?; + let std_broadcast = std.broadcast(shape).ok_or_else(|| { + anyhow::anyhow!("Failed to broadcast `std` to the shape of the input array.") + })?; + Zip::from(x) + .and(mean_broadcast) + .and(std_broadcast) + .par_for_each(|x_val, &mean_val, &std_val| { + *x_val = (*x_val - mean_val) / std_val; + }); + + Ok(()) } pub fn permute(x: Array, shape: &[usize]) -> Result> { diff --git a/src/misc/processor.rs b/src/misc/processor.rs index 3ada0bb..2fbd424 100644 --- a/src/misc/processor.rs +++ b/src/misc/processor.rs @@ -6,6 +6,8 @@ use fast_image_resize::{ }; use image::{DynamicImage, GenericImageView}; use ndarray::{s, Array, Axis}; +use rayon::prelude::*; +use std::sync::Mutex; use tokenizers::{Encoding, Tokenizer}; use crate::{LogitsSampler, X}; @@ -67,10 +69,10 @@ impl Processor { } pub fn process_images(&mut self, xs: &[DynamicImage]) -> Result { - // reset - self.reset_image0_status(); - - let mut x = self.resize_batch(xs)?; + // self.reset_image0_status(); + let (mut x, image0s_size, scale_factors_hw) = self.par_resize(xs)?; + self.image0s_size = image0s_size; + self.scale_factors_hw = scale_factors_hw; if self.do_normalize { x = x.normalize(0., 255.)?; } @@ -85,6 +87,7 @@ impl Processor { if self.unsigned { x = x.unsigned(); } + Ok(x) } @@ -341,26 +344,144 @@ impl Processor { Ok(y.into()) } - pub fn resize_batch(&mut self, xs: &[DynamicImage]) -> Result { - // TODO: par resize - if xs.is_empty() { - anyhow::bail!("Found no input images.") + #[allow(clippy::type_complexity)] + pub fn resize2(&self, x: &DynamicImage) -> Result<(X, (u32, u32), Vec)> { + if self.image_width + self.image_height == 0 { + anyhow::bail!( + "Invalid target height: {} or width: {}.", + self.image_height, + self.image_width + ); } - let mut ys = Array::ones(( - xs.len(), - self.image_height as usize, - self.image_width as usize, - 3, - )) + let image0s_size: (u32, u32); // original image height and width + let scale_factors_hw: Vec; + + let buffer = match x.dimensions() { + (w, h) if (w, h) == (self.image_height, self.image_width) => { + image0s_size = (h, w); + scale_factors_hw = vec![1., 1.]; + x.to_rgb8().into_raw() + } + (w0, h0) => { + image0s_size = (h0, w0); + let (mut resizer, options) = Self::build_resizer_filter(self.resize_filter)?; + + if let ResizeMode::FitExact = self.resize_mode { + let mut dst = Image::new(self.image_width, self.image_height, PixelType::U8x3); + resizer.resize(x, &mut dst, &options)?; + scale_factors_hw = vec![ + (self.image_height as f32 / h0 as f32), + (self.image_width as f32 / w0 as f32), + ]; + + dst.into_vec() + } else { + let (w, h) = match self.resize_mode { + ResizeMode::Letterbox | ResizeMode::FitAdaptive => { + let r = (self.image_width as f32 / w0 as f32) + .min(self.image_height as f32 / h0 as f32); + scale_factors_hw = vec![r, r]; + + ( + (w0 as f32 * r).round() as u32, + (h0 as f32 * r).round() as u32, + ) + } + ResizeMode::FitHeight => { + let r = self.image_height as f32 / h0 as f32; + scale_factors_hw = vec![1.0, r]; + ((r * w0 as f32).round() as u32, self.image_height) + } + ResizeMode::FitWidth => { + // scale factor + let r = self.image_width as f32 / w0 as f32; + scale_factors_hw = vec![r, 1.0]; + (self.image_width, (r * h0 as f32).round() as u32) + } + + _ => unreachable!(), + }; + + let mut dst = Image::from_vec_u8( + self.image_width, + self.image_height, + vec![ + self.padding_value; + 3 * self.image_height as usize * self.image_width as usize + ], + PixelType::U8x3, + )?; + let (l, t) = if let ResizeMode::Letterbox = self.resize_mode { + if w == self.image_width { + (0, (self.image_height - h) / 2) + } else { + ((self.image_width - w) / 2, 0) + } + } else { + (0, 0) + }; + + let mut dst_cropped = CroppedImageMut::new(&mut dst, l, t, w, h)?; + resizer.resize(x, &mut dst_cropped, &options)?; + dst.into_vec() + } + } + }; + + let y = Array::from_shape_vec( + (self.image_height as usize, self.image_width as usize, 3), + buffer, + )? + .mapv(|x| x as f32) .into_dyn(); - xs.iter().enumerate().try_for_each(|(idx, x)| { - let y = self.resize(x)?; - ys.slice_mut(s![idx, .., .., ..]).assign(&y); - anyhow::Ok(()) - })?; + Ok((y.into(), image0s_size, scale_factors_hw)) + } - Ok(ys.into()) + #[allow(clippy::type_complexity)] + pub fn par_resize(&self, xs: &[DynamicImage]) -> Result<(X, Vec<(u32, u32)>, Vec>)> { + match xs.len() { + 0 => anyhow::bail!("Found no input images."), + 1 => { + let (y, image0_size, scale_factors) = self.resize2(&xs[0])?; + Ok((y.insert_axis(0)?, vec![image0_size], vec![scale_factors])) + } + _ => { + let ys = Mutex::new( + Array::zeros(( + xs.len(), + self.image_height as usize, + self.image_width as usize, + 3, + )) + .into_dyn(), + ); + + let results: Result)>> = xs + .par_iter() + .enumerate() + .map(|(idx, x)| { + let (y, image0_size, scale_factors) = self.resize2(x)?; + { + let mut ys_guard = ys + .lock() + .map_err(|e| anyhow::anyhow!("Mutex lock error: {e}"))?; + ys_guard.slice_mut(s![idx, .., .., ..]).assign(&y); + } + + Ok((image0_size, scale_factors)) + }) + .collect(); + + let (image0s_size, scale_factors_hw): (Vec<_>, Vec<_>) = + results?.into_iter().unzip(); + let ys_inner = ys + .into_inner() + .map_err(|e| anyhow::anyhow!("Mutex into_inner error: {e}"))?; + + Ok((ys_inner.into(), image0s_size, scale_factors_hw)) + } + } } } diff --git a/src/xy/x.rs b/src/xy/x.rs index 6b4b482..17c07ec 100644 --- a/src/xy/x.rs +++ b/src/xy/x.rs @@ -172,12 +172,13 @@ impl X { } pub fn normalize(mut self, min_: f32, max_: f32) -> Result { - self.0 = Ops::normalize(self.0, min_, max_)?; + Ops::normalize(&mut self.0, min_, max_)?; + Ok(self) } pub fn standardize(mut self, mean: &[f32], std: &[f32], dim: usize) -> Result { - self.0 = Ops::standardize(self.0, mean, std, dim)?; + Ops::standardize(&mut self.0, mean.into(), std.into(), dim)?; Ok(self) } @@ -187,7 +188,7 @@ impl X { } pub fn unsigned(mut self) -> Self { - self.0 = self.0.mapv(|x| if x < 0.0 { 0.0 } else { x }); + self.0.par_mapv_inplace(|x| if x < 0.0 { 0.0 } else { x }); self }