From 211c145c4c5e8d910c55111664a61bfa9222e502 Mon Sep 17 00:00:00 2001 From: mii443 Date: Tue, 17 Sep 2024 19:19:40 +0900 Subject: [PATCH] fix resampling --- .vscode/launch.json | 16 +++++++++++++++ src/audio/resampling.rs | 19 ++++++++++++++++-- src/commands/run.rs | 11 +++++------ src/device/virtual_device.rs | 38 ++++++++++++++++++++++-------------- 4 files changed, 61 insertions(+), 23 deletions(-) create mode 100644 .vscode/launch.json diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..10efcb2 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Debug", + "program": "${workspaceFolder}/", + "args": [], + "cwd": "${workspaceFolder}" + } + ] +} \ No newline at end of file diff --git a/src/audio/resampling.rs b/src/audio/resampling.rs index cb825a8..b26ff93 100644 --- a/src/audio/resampling.rs +++ b/src/audio/resampling.rs @@ -1,4 +1,4 @@ -use rubato::{Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction}; +use rubato::{FastFixedIn, FastFixedOut, PolynomialDegree, Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction}; const MAX_CHUNK_SIZE: usize = 1024; @@ -64,4 +64,19 @@ pub fn resampling( } Ok(output) -} \ No newline at end of file +} + +pub fn get_resampler(from: f64, to: f64, channels: usize, chunk_size: usize) -> FastFixedIn +where + T: rubato::Sample +{ + let resample_ratio = to / from; + + FastFixedIn::::new( + resample_ratio, + 2.0, + PolynomialDegree::Cubic, + chunk_size, + channels + ).unwrap() +} diff --git a/src/commands/run.rs b/src/commands/run.rs index f32ad83..b3addd8 100644 --- a/src/commands/run.rs +++ b/src/commands/run.rs @@ -1,13 +1,12 @@ use std::{ - cmp::min, - io::Read, - sync::{Arc, Mutex}, + cmp::min, io::Read, mem::MaybeUninit, sync::{Arc, Mutex} }; use cpal::{ traits::{DeviceTrait, HostTrait, StreamTrait}, Device, Stream, StreamConfig, SupportedStreamConfig, }; +use rubato::FastFixedIn; use crate::{args::Run, device::virtual_device::VirtualDevice}; @@ -351,7 +350,7 @@ fn input_callback(data: &[T], channels: usize, virtual_device: Arc = data.iter().map(|d| d.to_f32().unwrap()).collect(); + let data: Vec = data.iter().map(|d| d.to_f64().unwrap()).collect(); let audio_data = reshape_audio_data(&data, channels); virtual_device .lock() @@ -377,7 +376,7 @@ fn output_callback( data.len() / channels, ); let mut count = 0; - while audio_data.is_none() && count < 1 { + while audio_data.is_none() && count < 1000 { audio_data = virtual_device.take_output( index, min(channels as u8, vd_channels), @@ -401,7 +400,7 @@ fn output_callback( let audio_data: Vec = audio_data .iter() - .map(|d| T::from_f32(*d).unwrap()) + .map(|d| T::from_f64(*d).unwrap()) .collect(); data.clone_from_slice(&audio_data); diff --git a/src/device/virtual_device.rs b/src/device/virtual_device.rs index 99079a0..3e22441 100644 --- a/src/device/virtual_device.rs +++ b/src/device/virtual_device.rs @@ -1,6 +1,8 @@ -use std::collections::HashMap; +use std::{collections::HashMap, mem::MaybeUninit, sync::{Arc, Mutex}}; -use crate::audio::resampling::resampling; +use rubato::{FastFixedIn, Resampler}; + +use crate::audio::resampling::{get_resampler, resampling}; pub struct VirtualDevice { pub name: String, @@ -8,7 +10,7 @@ pub struct VirtualDevice { pub sample_rate: u32, output_index: HashMap>, - output_buffer: HashMap>>, + output_buffer: HashMap>, Arc>>>)>, } impl VirtualDevice { @@ -31,7 +33,7 @@ impl VirtualDevice { if let std::collections::hash_map::Entry::Vacant(e) = self.output_index.entry(sample_rate) { e.insert(Vec::new()); self.output_buffer - .insert(sample_rate, vec![vec![]; self.channels as usize]); + .insert(sample_rate, (vec![vec![]; self.channels as usize], Arc::new(Mutex::new(None)))); } let min_index = self.get_min_index(sample_rate); self.output_index @@ -47,17 +49,17 @@ impl VirtualDevice { channels: u8, sample_rate: u32, take_size: usize, - ) -> Option>> { + ) -> Option>> { let mut buffer = vec![Vec::with_capacity(take_size); channels as usize]; let start = self.output_index[&sample_rate][index]; let end = start + take_size; for channel in 0..channels { - if end >= self.output_buffer[&sample_rate][channel as usize].len() { + if end >= self.output_buffer[&sample_rate].0[channel as usize].len() { println!( "End of buffer: {}, {}[{}]", end, - self.output_buffer[&sample_rate][channel as usize].len(), + self.output_buffer[&sample_rate].0[channel as usize].len(), channel as usize ); return None; @@ -67,7 +69,7 @@ impl VirtualDevice { for i in start..end { for channel in 0..channels { buffer[channel as usize] - .push(self.output_buffer[&sample_rate][channel as usize][i]); + .push(self.output_buffer[&sample_rate].0[channel as usize][i]); } } self.output_index.get_mut(&sample_rate).unwrap()[index] = end; @@ -75,8 +77,8 @@ impl VirtualDevice { let min = self.get_min_index(sample_rate); if min != 0 { for i in 0..self.channels as usize { - let len = self.output_buffer[&sample_rate][i].len(); - self.output_buffer.get_mut(&sample_rate).unwrap()[i] + let len = self.output_buffer[&sample_rate].0[i].len(); + self.output_buffer.get_mut(&sample_rate).unwrap().0[i] .drain(0..(if len < min { len } else { min })); } for i in 0..self.output_index.len() { @@ -87,17 +89,23 @@ impl VirtualDevice { Some(buffer) } - pub fn write_input_multiple_channels(&mut self, input_buffer: &[Vec]) { + pub fn write_input_multiple_channels(&mut self, input_buffer: &[Vec]) { for (sample_rate, buffer) in self.output_buffer.iter_mut() { + let mut guard = buffer.1.lock().unwrap(); + let resampler = { + if guard.is_none() { + *guard = Some(get_resampler::(self.sample_rate as f64, *sample_rate as f64, self.channels as usize, input_buffer[0].len())); + } + guard.as_mut().unwrap() + }; + let buffer_resample = if self.sample_rate == *sample_rate { input_buffer.to_vec() } else { - resampling(self.sample_rate, *sample_rate, input_buffer.to_vec()).unwrap() + resampler.process(input_buffer, None).unwrap() }; - println!("Resampling: {} -> {}", input_buffer[0].len(), buffer_resample[0].len()); - - (0..self.channels as usize).for_each(|i| buffer[i].extend(buffer_resample[i].iter())); + (0..self.channels as usize).for_each(|i| buffer.0[i].extend(buffer_resample[i].iter())); } } }