mingyang91 committed on
Commit
81301f1
•
1 Parent(s): 708bc33

split whisper into a standalone crate

Browse files
Cargo.lock CHANGED
@@ -1544,6 +1544,7 @@ version = "0.1.0"
1544
  dependencies = [
1545
  "anyhow",
1546
  "async-stream",
 
1547
  "aws-config",
1548
  "aws-sdk-polly",
1549
  "aws-sdk-transcribestreaming",
@@ -1552,6 +1553,7 @@ dependencies = [
1552
  "futures-util",
1553
  "fvad",
1554
  "hound",
 
1555
  "once_cell",
1556
  "poem",
1557
  "serde",
@@ -1562,6 +1564,7 @@ dependencies = [
1562
  "tracing",
1563
  "tracing-subscriber",
1564
  "tracing-test",
 
1565
  "whisper-rs",
1566
  "whisper-rs-sys",
1567
  ]
@@ -2156,9 +2159,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
2156
 
2157
  [[package]]
2158
  name = "tokio"
2159
- version = "1.33.0"
2160
  source = "registry+https://github.com/rust-lang/crates.io-index"
2161
- checksum = "4f38200e3ef7995e5ef13baec2f432a6da0aa9ac495b2c0e8f3b7eec2c92d653"
2162
  dependencies = [
2163
  "backtrace",
2164
  "bytes",
@@ -2174,9 +2177,9 @@ dependencies = [
2174
 
2175
  [[package]]
2176
  name = "tokio-macros"
2177
- version = "2.1.0"
2178
  source = "registry+https://github.com/rust-lang/crates.io-index"
2179
- checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e"
2180
  dependencies = [
2181
  "proc-macro2",
2182
  "quote",
@@ -2529,6 +2532,22 @@ dependencies = [
2529
  "rustix",
2530
  ]
2531
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2532
  [[package]]
2533
  name = "whisper-rs"
2534
  version = "0.9.0-rc.2"
 
1544
  dependencies = [
1545
  "anyhow",
1546
  "async-stream",
1547
+ "async-trait",
1548
  "aws-config",
1549
  "aws-sdk-polly",
1550
  "aws-sdk-transcribestreaming",
 
1553
  "futures-util",
1554
  "fvad",
1555
  "hound",
1556
+ "lazy_static",
1557
  "once_cell",
1558
  "poem",
1559
  "serde",
 
1564
  "tracing",
1565
  "tracing-subscriber",
1566
  "tracing-test",
1567
+ "whisper",
1568
  "whisper-rs",
1569
  "whisper-rs-sys",
1570
  ]
 
2159
 
2160
  [[package]]
2161
  name = "tokio"
2162
+ version = "1.34.0"
2163
  source = "registry+https://github.com/rust-lang/crates.io-index"
2164
+ checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9"
2165
  dependencies = [
2166
  "backtrace",
2167
  "bytes",
 
2177
 
2178
  [[package]]
2179
  name = "tokio-macros"
2180
+ version = "2.2.0"
2181
  source = "registry+https://github.com/rust-lang/crates.io-index"
2182
+ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
2183
  dependencies = [
2184
  "proc-macro2",
2185
  "quote",
 
2532
  "rustix",
2533
  ]
2534
 
2535
+ [[package]]
2536
+ name = "whisper"
2537
+ version = "0.1.0"
2538
+ dependencies = [
2539
+ "fvad",
2540
+ "hound",
2541
+ "lazy_static",
2542
+ "once_cell",
2543
+ "serde",
2544
+ "tokio",
2545
+ "tracing",
2546
+ "tracing-test",
2547
+ "whisper-rs",
2548
+ "whisper-rs-sys",
2549
+ ]
2550
+
2551
  [[package]]
2552
  name = "whisper-rs"
2553
  version = "0.9.0-rc.2"
Cargo.toml CHANGED
@@ -3,6 +3,9 @@ name = "polyhedron"
3
  version = "0.1.0"
4
  edition = "2021"
5
 
 
 
 
6
  [dependencies]
7
  anyhow = "1.0"
8
  async-stream = "0.3"
@@ -21,6 +24,12 @@ tokio-stream = "0.1"
21
  tracing = { version = "0.1", features = [] }
22
  tracing-subscriber = { version = "0.3", features = ["env-filter"] }
23
  fvad = "0.1"
 
 
 
 
 
 
24
 
25
  [dependencies.poem]
26
  version = "1.3"
 
3
  version = "0.1.0"
4
  edition = "2021"
5
 
6
+ [workspace]
7
+ members = ["whisper"]
8
+
9
  [dependencies]
10
  anyhow = "1.0"
11
  async-stream = "0.3"
 
24
  tracing = { version = "0.1", features = [] }
25
  tracing-subscriber = { version = "0.3", features = ["env-filter"] }
26
  fvad = "0.1"
27
+ whisper = { path="whisper", optional = true }
28
+ async-trait = "0.1.74"
29
+ lazy_static = { version = "1.4.0", features = [] }
30
+
31
+ [features]
32
+ whisper = ["dep:whisper"]
33
 
34
  [dependencies.poem]
35
  version = "1.3"
src/asr/aws.rs ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use async_trait::async_trait;
2
+ use tokio::sync::broadcast::Receiver;
3
+ use crate::asr::{ASR, Event};
4
+
5
+ struct AWS_ASR {
6
+ aws: aws_sdk_transcribestreaming::Client,
7
+ }
8
+ #[async_trait]
9
+ impl ASR for AWS_ASR {
10
+ async fn frame(&mut self, frame: &[i16]) -> anyhow::Result<()> {
11
+ todo!()
12
+ }
13
+
14
+ fn subscribe(&mut self) -> Receiver<Event> {
15
+ todo!()
16
+ }
17
+ }
src/asr/whisper.rs ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #[cfg(feature = "whisper")]
2
+ pub mod whisper_asr {
3
+ use async_trait::async_trait;
4
+ use tokio::{select, spawn};
5
+ use tokio::sync::broadcast::Receiver;
6
+ use tokio::sync::broadcast::error::RecvError;
7
+ use lazy_static::lazy_static;
8
+ use whisper::config::WhisperConfig;
9
+
10
+ extern crate whisper;
11
+
12
+ use whisper::handler::{Error, Output, WhisperHandler, Context};
13
+ use crate::asr::{ASR, Event};
14
+ use crate::config::SETTINGS;
15
+
16
+ lazy_static! {
17
+ pub static ref CONTEXT: Context = Context::new(&SETTINGS.whisper.model)
18
+ .expect("Failed to initialize whisper context");
19
+ }
20
+
21
+ pub struct Whisper_ASR {
22
+ whisper: WhisperHandler,
23
+ tx: tokio::sync::broadcast::Sender<Event>,
24
+ }
25
+
26
+ impl Whisper_ASR {
27
+ pub async fn from_config() -> Result<Whisper_ASR, Error> {
28
+ let whisper = CONTEXT.create_handler(&SETTINGS.whisper, "".to_string())?;
29
+ let mut output_rx = whisper.subscribe();
30
+ let (tx, _) = tokio::sync::broadcast::channel(64);
31
+ let shared_tx = tx.clone();
32
+ let fut = async move {
33
+ loop {
34
+ select! {
35
+ poll = output_rx.recv() => {
36
+ match poll {
37
+ Ok(outputs) => {
38
+ for output in outputs {
39
+ let res = match output {
40
+ Output::Stable(segment) => tx.send(Event {
41
+ transcript: segment.text,
42
+ is_final: true,
43
+ }),
44
+ Output::Unstable(segment) => tx.send(Event {
45
+ transcript: segment.text,
46
+ is_final: false,
47
+ }),
48
+ };
49
+ if let Err(e) = res {
50
+ tracing::warn!("Failed to send whisper event: {}", e);
51
+ break
52
+ }
53
+ }
54
+ },
55
+ Err(RecvError::Closed) => break,
56
+ Err(RecvError::Lagged(lagged)) => {
57
+ tracing::warn!("Whisper ASR output lagged: {}", lagged);
58
+ }
59
+ }
60
+ },
61
+ }
62
+ }
63
+ };
64
+ spawn(fut);
65
+ Ok(Self { whisper, tx: shared_tx })
66
+ }
67
+ }
68
+
69
+ #[async_trait]
70
+ impl ASR for Whisper_ASR {
71
+ async fn frame(&mut self, frame: &[i16]) -> anyhow::Result<()> {
72
+ Ok(self.whisper.send_i16(frame.to_vec()).await?)
73
+ }
74
+
75
+ fn subscribe(&mut self) -> Receiver<Event> {
76
+ self.tx.subscribe()
77
+ }
78
+ }
79
+ }
src/config.rs CHANGED
@@ -1,106 +1,15 @@
1
- use std::{env, ffi::c_int, net::IpAddr};
2
 
3
  use config::{Config, Environment, File};
4
  use once_cell::sync::Lazy;
5
  use serde::Deserialize;
6
- use whisper_rs::{FullParams};
7
  use tracing::debug;
 
 
8
 
9
  pub(crate) static SETTINGS: Lazy<Settings> =
10
  Lazy::new(|| Settings::new().expect("Failed to initialize settings"));
11
 
12
- #[derive(Debug, Deserialize, Clone)]
13
- pub(crate) struct WhisperConfig {
14
- pub(crate) params: WhisperParams,
15
- pub(crate) step_ms: usize,
16
- pub(crate) length_ms: usize,
17
- pub(crate) keep_ms: usize,
18
- pub(crate) model: String,
19
- pub(crate) max_prompt_tokens: usize,
20
- pub(crate) context_confidence_threshold: f32,
21
- }
22
-
23
- #[allow(dead_code)]
24
- #[derive(Debug, Deserialize, Clone)]
25
- pub(crate) struct WhisperParams {
26
- pub(crate) n_threads: Option<usize>,
27
- pub(crate) max_tokens: Option<u32>,
28
- pub(crate) audio_ctx: Option<u32>,
29
- pub(crate) speed_up: Option<bool>,
30
- pub(crate) translate: Option<bool>,
31
- pub(crate) no_context: Option<bool>,
32
- pub(crate) print_special: Option<bool>,
33
- pub(crate) print_realtime: Option<bool>,
34
- pub(crate) print_progress: Option<bool>,
35
- pub(crate) token_timestamps: Option<bool>,
36
- pub(crate) no_timestamps: Option<bool>,
37
- pub(crate) temperature_inc: Option<f32>,
38
- pub(crate) entropy_threshold: Option<f32>,
39
- pub(crate) single_segment: Option<bool>,
40
- pub(crate) suppress_non_speech_tokens: Option<bool>,
41
- pub(crate) n_max_text_ctx: Option<usize>,
42
- // pub(crate) tinydiarize: bool,
43
- pub(crate) language: Option<String>,
44
- }
45
-
46
- impl WhisperParams {
47
- pub(crate) fn to_full_params<'a, 'b>(&'a self, tokens: &'b [c_int]) -> FullParams<'a, 'b> {
48
- let mut param = FullParams::new(Default::default());
49
- if let Some(print_progress) = self.print_progress.as_ref() {
50
- param.set_print_progress(*print_progress);
51
- }
52
- if let Some(print_special) = self.print_special.as_ref() {
53
- param.set_print_special(*print_special);
54
- }
55
- if let Some(print_realtime) = self.print_realtime.as_ref() {
56
- param.set_print_realtime(*print_realtime);
57
- }
58
- if let Some(single_segment) = self.single_segment.as_ref() {
59
- param.set_single_segment(*single_segment);
60
- }
61
- if let Some(no_timestamps) = self.no_timestamps.as_ref() {
62
- param.set_print_timestamps(!no_timestamps);
63
- }
64
- if let Some(token_timestamps) = self.token_timestamps.as_ref() {
65
- param.set_token_timestamps(*token_timestamps);
66
- }
67
- if let Some(translate) = self.translate.as_ref() {
68
- param.set_translate(*translate);
69
- }
70
- if let Some(max_tokens) = self.max_tokens.as_ref() {
71
- param.set_max_tokens(*max_tokens as i32);
72
- }
73
- param.set_language(self.language.as_deref());
74
- if let Some(n_threads) = self.n_threads.as_ref() {
75
- param.set_n_threads(*n_threads as i32);
76
- }
77
- if let Some(audio_ctx) = self.audio_ctx.as_ref() {
78
- param.set_audio_ctx(*audio_ctx as i32);
79
- }
80
- if let Some(speed_up) = self.speed_up.as_ref() {
81
- param.set_speed_up(*speed_up);
82
- }
83
- // param.set_tdrz_enable(self.tinydiarize);
84
- if let Some(temperature_inc) = self.temperature_inc.as_ref() {
85
- param.set_temperature_inc(*temperature_inc);
86
- }
87
- if let Some(suppress_non_speech_tokens) = self.suppress_non_speech_tokens.as_ref() {
88
- param.set_suppress_non_speech_tokens(*suppress_non_speech_tokens);
89
- }
90
- if let Some(no_context) = self.no_context.as_ref() {
91
- param.set_no_context(*no_context);
92
- }
93
- if let Some(entropy_threshold) = self.entropy_threshold.as_ref() {
94
- param.set_entropy_thold(*entropy_threshold);
95
- }
96
- if let Some(n_max_text_ctx) = self.n_max_text_ctx.as_ref() {
97
- param.set_n_max_text_ctx(*n_max_text_ctx as i32);
98
- }
99
-
100
- param.set_tokens(tokens);
101
- param
102
- }
103
- }
104
 
105
  #[derive(Debug, Deserialize)]
106
  pub(crate) struct Server {
@@ -110,7 +19,8 @@ pub(crate) struct Server {
110
 
111
  #[derive(Debug, Deserialize)]
112
  pub struct Settings {
113
- pub(crate) whisper: WhisperConfig,
 
114
  pub(crate) server: Server,
115
  }
116
 
 
1
+ use std::{env, net::IpAddr};
2
 
3
  use config::{Config, Environment, File};
4
  use once_cell::sync::Lazy;
5
  use serde::Deserialize;
 
6
  use tracing::debug;
7
+ #[cfg(feature = "whisper")]
8
+ use crate::whisper;
9
 
10
  pub(crate) static SETTINGS: Lazy<Settings> =
11
  Lazy::new(|| Settings::new().expect("Failed to initialize settings"));
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  #[derive(Debug, Deserialize)]
15
  pub(crate) struct Server {
 
19
 
20
  #[derive(Debug, Deserialize)]
21
  pub struct Settings {
22
+ #[cfg(feature = "whisper")]
23
+ pub(crate) whisper: whisper::config::WhisperConfig,
24
  pub(crate) server: Server,
25
  }
26
 
src/main.rs CHANGED
@@ -5,6 +5,8 @@
5
 
6
  #![allow(clippy::result_large_err)]
7
 
 
 
8
  use aws_sdk_transcribestreaming::meta::PKG_VERSION;
9
  use futures_util::{stream::StreamExt, SinkExt};
10
  use poem::{
@@ -22,12 +24,13 @@ use tokio::select;
22
  use tracing::debug;
23
  use tracing_subscriber::{fmt, prelude::*, EnvFilter};
24
 
25
- use crate::{config::*, lesson::*, whisper::*};
 
 
26
 
27
  mod config;
28
- mod group;
29
  mod lesson;
30
- mod whisper;
31
 
32
  #[derive(Clone)]
33
  struct Context {
@@ -104,20 +107,23 @@ async fn stream_speaker(
104
  ws.on_upgrade(|mut socket| async move {
105
  let _origin_tx = lesson.voice_channel();
106
  let mut transcribe_rx = lesson.transcript_channel();
107
- let mut whisper = WhisperHandler::new(SETTINGS.whisper.clone(), prompt)
 
108
  .expect("failed to create whisper");
 
109
  let mut whisper_transcribe_rx = whisper.subscribe();
110
  loop {
111
  select! {
112
- w = whisper_transcribe_rx.recv() => {
113
- let Ok(_txt) = w else {
114
- // TODO: handle msg
115
- continue
116
- };
117
- }
118
  msg = socket.next() => {
119
  match msg.as_ref() {
120
  Some(Ok(Message::Binary(bin))) => {
 
121
  let _ = whisper.send_bytes(bin.to_vec()).await; // whisper test
122
  // if let Err(e) = origin_tx.send(bin.to_vec()).await {
123
  // tracing::warn!("failed to send voice: {}", e);
 
5
 
6
  #![allow(clippy::result_large_err)]
7
 
8
+ #[cfg(feature = "whisper")]
9
+ extern crate whisper;
10
  use aws_sdk_transcribestreaming::meta::PKG_VERSION;
11
  use futures_util::{stream::StreamExt, SinkExt};
12
  use poem::{
 
24
  use tracing::debug;
25
  use tracing_subscriber::{fmt, prelude::*, EnvFilter};
26
 
27
+ use crate::{config::*, lesson::*};
28
+ #[cfg(feature = "whisper")]
29
+ use crate::whisper::*;
30
 
31
  mod config;
 
32
  mod lesson;
33
+ mod asr;
34
 
35
  #[derive(Clone)]
36
  struct Context {
 
107
  ws.on_upgrade(|mut socket| async move {
108
  let _origin_tx = lesson.voice_channel();
109
  let mut transcribe_rx = lesson.transcript_channel();
110
+ #[cfg(feature = "whisper")]
111
+ let mut whisper = asr::whisper::whisper_asr::CONTEXT.create_handler(&SETTINGS.whisper, prompt)
112
  .expect("failed to create whisper");
113
+ #[cfg(feature = "whisper")]
114
  let mut whisper_transcribe_rx = whisper.subscribe();
115
  loop {
116
  select! {
117
+ // w = whisper_transcribe_rx.recv() => {
118
+ // let Ok(_txt) = w else {
119
+ // // TODO: handle msg
120
+ // continue
121
+ // };
122
+ // }
123
  msg = socket.next() => {
124
  match msg.as_ref() {
125
  Some(Ok(Message::Binary(bin))) => {
126
+ #[cfg(feature = "whisper")]
127
  let _ = whisper.send_bytes(bin.to_vec()).await; // whisper test
128
  // if let Err(e) = origin_tx.send(bin.to_vec()).await {
129
  // tracing::warn!("failed to send voice: {}", e);
whisper/ggml-metal.metal ADDED
The diff for this file is too large to render. See raw diff
 
whisper/src/config.rs ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use serde::Deserialize;
2
+ use whisper_rs::{FullParams, WhisperToken};
3
+
4
+ #[derive(Debug, Deserialize, Clone)]
5
+ pub struct WhisperConfig {
6
+ pub(crate) params: WhisperParams,
7
+ pub(crate) step_ms: usize,
8
+ pub(crate) length_ms: usize,
9
+ pub(crate) keep_ms: usize,
10
+ pub model: String,
11
+ pub(crate) max_prompt_tokens: usize,
12
+ pub(crate) context_confidence_threshold: f32,
13
+ }
14
+
15
+ #[allow(dead_code)]
16
+ #[derive(Debug, Deserialize, Clone)]
17
+ pub(crate) struct WhisperParams {
18
+ pub(crate) n_threads: Option<usize>,
19
+ pub(crate) max_tokens: Option<u32>,
20
+ pub(crate) audio_ctx: Option<u32>,
21
+ pub(crate) speed_up: Option<bool>,
22
+ pub(crate) translate: Option<bool>,
23
+ pub(crate) no_context: Option<bool>,
24
+ pub(crate) print_special: Option<bool>,
25
+ pub(crate) print_realtime: Option<bool>,
26
+ pub(crate) print_progress: Option<bool>,
27
+ pub(crate) token_timestamps: Option<bool>,
28
+ pub(crate) no_timestamps: Option<bool>,
29
+ pub(crate) temperature_inc: Option<f32>,
30
+ pub(crate) entropy_threshold: Option<f32>,
31
+ pub(crate) single_segment: Option<bool>,
32
+ pub(crate) suppress_non_speech_tokens: Option<bool>,
33
+ pub(crate) n_max_text_ctx: Option<usize>,
34
+ // pub(crate) tinydiarize: bool,
35
+ pub(crate) language: Option<String>,
36
+ }
37
+
38
+ impl WhisperParams {
39
+ pub(crate) fn to_full_params<'a, 'b>(&'a self, tokens: &'b [WhisperToken]) -> FullParams<'a, 'b> {
40
+ let mut param = FullParams::new(Default::default());
41
+ if let Some(print_progress) = self.print_progress.as_ref() {
42
+ param.set_print_progress(*print_progress);
43
+ }
44
+ if let Some(print_special) = self.print_special.as_ref() {
45
+ param.set_print_special(*print_special);
46
+ }
47
+ if let Some(print_realtime) = self.print_realtime.as_ref() {
48
+ param.set_print_realtime(*print_realtime);
49
+ }
50
+ if let Some(single_segment) = self.single_segment.as_ref() {
51
+ param.set_single_segment(*single_segment);
52
+ }
53
+ if let Some(no_timestamps) = self.no_timestamps.as_ref() {
54
+ param.set_print_timestamps(!no_timestamps);
55
+ }
56
+ if let Some(token_timestamps) = self.token_timestamps.as_ref() {
57
+ param.set_token_timestamps(*token_timestamps);
58
+ }
59
+ if let Some(translate) = self.translate.as_ref() {
60
+ param.set_translate(*translate);
61
+ }
62
+ if let Some(max_tokens) = self.max_tokens.as_ref() {
63
+ param.set_max_tokens(*max_tokens as i32);
64
+ }
65
+ param.set_language(self.language.as_deref());
66
+ if let Some(n_threads) = self.n_threads.as_ref() {
67
+ param.set_n_threads(*n_threads as i32);
68
+ }
69
+ if let Some(audio_ctx) = self.audio_ctx.as_ref() {
70
+ param.set_audio_ctx(*audio_ctx as i32);
71
+ }
72
+ if let Some(speed_up) = self.speed_up.as_ref() {
73
+ param.set_speed_up(*speed_up);
74
+ }
75
+ // param.set_tdrz_enable(self.tinydiarize);
76
+ if let Some(temperature_inc) = self.temperature_inc.as_ref() {
77
+ param.set_temperature_inc(*temperature_inc);
78
+ }
79
+ if let Some(suppress_non_speech_tokens) = self.suppress_non_speech_tokens.as_ref() {
80
+ param.set_suppress_non_speech_tokens(*suppress_non_speech_tokens);
81
+ }
82
+ if let Some(no_context) = self.no_context.as_ref() {
83
+ param.set_no_context(*no_context);
84
+ }
85
+ if let Some(entropy_threshold) = self.entropy_threshold.as_ref() {
86
+ param.set_entropy_thold(*entropy_threshold);
87
+ }
88
+ if let Some(n_max_text_ctx) = self.n_max_text_ctx.as_ref() {
89
+ param.set_n_max_text_ctx(*n_max_text_ctx as i32);
90
+ }
91
+
92
+ param.set_tokens(tokens);
93
+ param
94
+ }
95
+ }
{src → whisper/src}/group.rs RENAMED
File without changes
src/whisper.rs → whisper/src/handler.rs RENAMED
@@ -6,35 +6,47 @@ use std::{
6
  };
7
  use fvad::SampleRate;
8
 
9
- use once_cell::sync::Lazy;
10
- use tokio::sync::{broadcast, mpsc, oneshot};
11
  use tokio::time::Instant;
12
- use tracing::{debug, trace, warn};
13
- use whisper_rs::{convert_integer_to_float_audio, WhisperContext, WhisperState, WhisperToken, WhisperTokenData};
14
 
15
- use crate::config::{Settings, SETTINGS};
16
  use crate::{config::WhisperConfig, group::GroupedWithin};
17
 
18
  const WHISPER_SAMPLE_RATE: usize = whisper_rs_sys::WHISPER_SAMPLE_RATE as usize;
19
 
20
- static WHISPER_CONTEXT: Lazy<WhisperContext> = Lazy::new(|| {
21
- let settings = Settings::new().expect("Failed to initialize settings.");
22
- if tracing::enabled!(tracing::Level::DEBUG) {
23
- let info = whisper_rs::print_system_info();
24
- debug!("system_info: n_threads = {} / {} | {}\n",
25
- settings.whisper.params.n_threads.unwrap_or(0),
26
- std::thread::available_parallelism().map(|c| c.get()).unwrap_or(0),
27
- info);
 
 
 
 
28
  }
29
- WhisperContext::new(&settings.whisper.model).expect("failed to create WhisperContext")
30
- });
 
31
 
 
 
 
 
 
 
 
 
 
32
 
33
  #[derive(Debug)]
34
- pub(crate) enum Error {
35
  WhisperError {
36
  description: String,
37
- error: whisper_rs::WhisperError,
38
  },
39
  }
40
 
@@ -97,22 +109,25 @@ pub struct WhisperHandler {
97
  }
98
 
99
  impl WhisperHandler {
100
- pub(crate) fn new(config: WhisperConfig, prompt: String) -> Result<Self, Error> {
 
 
 
 
 
 
 
 
101
  let vad_slice_size = WHISPER_SAMPLE_RATE / 100 * 3;
102
  let (stop_handle, mut stop_signal) = oneshot::channel();
103
  let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<i16>>(128);
104
  let (transcription_tx, _) = broadcast::channel::<Vec<Output>>(128);
105
  let shared_transcription_tx = transcription_tx.clone();
106
- let state = WHISPER_CONTEXT
107
- .create_state()
108
- .map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
109
- let preset_prompt_tokens = WHISPER_CONTEXT
110
- .tokenize(prompt.as_str(), SETTINGS.whisper.max_prompt_tokens)
111
- .map_err(|e| Error::whisper_error("failed to tokenize prompt", e))?;
112
  tokio::task::spawn_blocking(move || {
113
  let mut vad = fvad::Fvad::new().expect("failed to create VAD")
114
  .set_sample_rate(SampleRate::Rate16kHz);
115
- let mut detector = Detector::new(state, &SETTINGS.whisper, preset_prompt_tokens);
116
  let mut grouped = GroupedWithin::new(
117
  detector.n_samples_step,
118
  Duration::from_millis(config.step_ms as u64),
@@ -368,7 +383,9 @@ mod test {
368
  use std::io::{stdout, Write};
369
  use hound;
370
  use tracing_test;
371
- use tracing::info;
 
 
372
 
373
  async fn print_output(output: Output) {
374
  match output {
@@ -386,18 +403,50 @@ mod test {
386
  }
387
  stdout().flush().unwrap();
388
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  #[tokio::test]
390
  #[tracing_test::traced_test]
391
  async fn test_whisper_handler() {
392
- let mut whisper_handler = WhisperHandler::new(
393
- SETTINGS.whisper.clone(),
394
- "Harry Potter and the Philosopher's Stone".to_string(),
395
- ).expect("failed to create WhisperHandler");
396
 
397
  let wav = hound::WavReader::open("samples/ADHD_1A.wav")
398
  .expect("failed to open wav");
399
  let spec = wav.spec();
400
- println!("{:?}", spec);
401
  let samples = wav
402
  .into_samples::<i16>()
403
  .map(|s| s.unwrap())
@@ -423,9 +472,9 @@ mod test {
423
 
424
  match output {
425
  Output::Stable(stable) => {
426
- println!("{}", stable.text);
427
  },
428
- Output::Unstable(unstable) => {
429
 
430
  }
431
  }
 
6
  };
7
  use fvad::SampleRate;
8
 
9
+ use tokio::sync::{broadcast, mpsc, oneshot, OnceCell};
 
10
  use tokio::time::Instant;
11
+ use tracing::{warn};
12
+ use whisper_rs::{convert_integer_to_float_audio, WhisperContext, WhisperError, WhisperState, WhisperToken, WhisperTokenData};
13
 
 
14
  use crate::{config::WhisperConfig, group::GroupedWithin};
15
 
16
  const WHISPER_SAMPLE_RATE: usize = whisper_rs_sys::WHISPER_SAMPLE_RATE as usize;
17
 
18
+ pub struct Context {
19
+ context: WhisperContext,
20
+ }
21
+
22
+ impl <'a> Context {
23
+ pub fn new(model: &str) -> Result<Context, WhisperError> {
24
+ WhisperContext::new(model)
25
+ .map(|context| Self { context })
26
+ }
27
+
28
+ pub fn create_handler(&'static self, config: &'static WhisperConfig, prompt: String) -> Result<WhisperHandler, Error> {
29
+ WhisperHandler::new(&self.context, config, prompt)
30
  }
31
+ }
32
+
33
+ static WHISPER_CONTEXT: OnceCell<WhisperContext> = OnceCell::const_new();
34
 
35
+ async fn initialize_whisper_context(model: String) -> WhisperContext {
36
+ tokio::task::spawn_blocking(move || {
37
+ WhisperContext::new(&model).expect("failed to create WhisperContext")
38
+ }).await.expect("failed to spawn")
39
+ }
40
+
41
+ async fn get_whisper_context(model: String) -> &'static WhisperContext {
42
+ WHISPER_CONTEXT.get_or_init(|| initialize_whisper_context(model)).await
43
+ }
44
 
45
  #[derive(Debug)]
46
+ pub enum Error {
47
  WhisperError {
48
  description: String,
49
+ error: WhisperError,
50
  },
51
  }
52
 
 
109
  }
110
 
111
  impl WhisperHandler {
112
+
113
+ fn new(whisper_context: &'static WhisperContext, config: &'static WhisperConfig, prompt: String) -> Result<Self, Error> {
114
+ // let whisper_context = get_whisper_context(config.model.clone()).await;
115
+ let state = whisper_context
116
+ .create_state()
117
+ .map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
118
+ let preset_prompt_tokens = whisper_context
119
+ .tokenize(&prompt, config.max_prompt_tokens)
120
+ .map_err(|e| Error::whisper_error("failed to tokenize prompt", e))?;
121
  let vad_slice_size = WHISPER_SAMPLE_RATE / 100 * 3;
122
  let (stop_handle, mut stop_signal) = oneshot::channel();
123
  let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<i16>>(128);
124
  let (transcription_tx, _) = broadcast::channel::<Vec<Output>>(128);
125
  let shared_transcription_tx = transcription_tx.clone();
126
+
 
 
 
 
 
127
  tokio::task::spawn_blocking(move || {
128
  let mut vad = fvad::Fvad::new().expect("failed to create VAD")
129
  .set_sample_rate(SampleRate::Rate16kHz);
130
+ let mut detector = Detector::new(state, &config, preset_prompt_tokens);
131
  let mut grouped = GroupedWithin::new(
132
  detector.n_samples_step,
133
  Duration::from_millis(config.step_ms as u64),
 
383
  use std::io::{stdout, Write};
384
  use hound;
385
  use tracing_test;
386
+ use tracing::{info, debug};
387
+ use crate::config::WhisperParams;
388
+ use lazy_static::lazy_static;
389
 
390
  async fn print_output(output: Output) {
391
  match output {
 
403
  }
404
  stdout().flush().unwrap();
405
  }
406
+
407
+ lazy_static! {
408
+ static ref CONFIG: WhisperConfig = WhisperConfig {
409
+ length_ms: 5000,
410
+ step_ms: 500,
411
+ keep_ms: 200,
412
+ model: "models/ggml-large-v3.bin".to_string(),
413
+ max_prompt_tokens: 32,
414
+ context_confidence_threshold: 0.5,
415
+ params: WhisperParams {
416
+ n_threads: None,
417
+ max_tokens: None,
418
+ audio_ctx: None,
419
+ speed_up: None,
420
+ translate: None,
421
+ no_context: None,
422
+ print_special: None,
423
+ print_realtime: None,
424
+ print_progress: None,
425
+ token_timestamps: None,
426
+ no_timestamps: None,
427
+ temperature_inc: None,
428
+ entropy_threshold: None,
429
+ single_segment: Some(true),
430
+ suppress_non_speech_tokens: None,
431
+ n_max_text_ctx: None,
432
+ language: Some("en".to_string()),
433
+ }
434
+ };
435
+
436
+ static ref CONTEXT: Context = Context::new(&CONFIG.model).expect("failed to create WhisperContext");
437
+ }
438
+
439
  #[tokio::test]
440
  #[tracing_test::traced_test]
441
  async fn test_whisper_handler() {
442
+ let mut whisper_handler = CONTEXT
443
+ .create_handler(&CONFIG, "Harry Potter and the Philosopher's Stone".to_string())
444
+ .expect("failed to create WhisperHandler");
 
445
 
446
  let wav = hound::WavReader::open("samples/ADHD_1A.wav")
447
  .expect("failed to open wav");
448
  let spec = wav.spec();
449
+ info!("{:?}", spec);
450
  let samples = wav
451
  .into_samples::<i16>()
452
  .map(|s| s.unwrap())
 
472
 
473
  match output {
474
  Output::Stable(stable) => {
475
+ debug!("{}", stable.text);
476
  },
477
+ Output::Unstable(_unstable) => {
478
 
479
  }
480
  }
whisper/src/main.rs ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+
2
+
3
+ #[tokio::main]
4
+ async fn main() {
5
+
6
+ }