diff --git a/Cargo.lock b/Cargo.lock index 5565ec42b0..3675406d8e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -18514,12 +18514,15 @@ version = "0.1.0" dependencies = [ "aspasia", "audio-utils", + "dasp", "futures-util", "host", "language", "owhisper-client", "owhisper-interface", + "pyannote-local", "ractor", + "rodio", "serde", "specta", "specta-typescript", diff --git a/apps/desktop/src/components/main/body/sessions/outer-header/overflow/diarize.tsx b/apps/desktop/src/components/main/body/sessions/outer-header/overflow/diarize.tsx new file mode 100644 index 0000000000..e9d77637d5 --- /dev/null +++ b/apps/desktop/src/components/main/body/sessions/outer-header/overflow/diarize.tsx @@ -0,0 +1,162 @@ +import { useMutation } from "@tanstack/react-query"; +import { Loader2Icon, UsersIcon } from "lucide-react"; + +import { + type DiarizationSegment, + commands as listener2Commands, +} from "@hypr/plugin-listener2"; +import { + DropdownMenuItem, + DropdownMenuSub, + DropdownMenuSubContent, + DropdownMenuSubTrigger, +} from "@hypr/ui/components/ui/dropdown-menu"; + +import * as main from "../../../../../../store/tinybase/store/main"; +import type { SpeakerHintWithId } from "../../../../../../store/transcript/types"; +import { + parseTranscriptHints, + parseTranscriptWords, + updateTranscriptHints, +} from "../../../../../../store/transcript/utils"; +import { id } from "../../../../../../utils"; + +export function Diarize({ sessionId }: { sessionId: string }) { + const store = main.UI.useStore(main.STORE_ID); + const transcriptIds = main.UI.useSliceRowIds( + main.INDEXES.transcriptBySession, + sessionId, + main.STORE_ID, + ); + const checkpoints = main.UI.useCheckpoints(main.STORE_ID); + + const { mutate, isPending } = useMutation({ + mutationFn: async (maxSpeakers: number) => { + const result = await listener2Commands.diarizeSession( + sessionId, + maxSpeakers, + ); + if (result.status === "error") { + throw new Error(result.error); + } + return result.data; + }, + onSuccess: (segments: DiarizationSegment[]) => { + if (!store || !transcriptIds || transcriptIds.length === 0) { + return; + } + + const firstStartedAt = store.getCell( + "transcripts", + transcriptIds[0], + "started_at", + ); + + for (const transcriptId of transcriptIds) { + const startedAt = store.getCell( + "transcripts", + transcriptId, + "started_at", + ); + const offset = + typeof startedAt === "number" && typeof firstStartedAt === "number" + ? startedAt - firstStartedAt + : 0; + + const words = parseTranscriptWords(store, transcriptId); + const existingHints = parseTranscriptHints(store, transcriptId); + const newHints: SpeakerHintWithId[] = []; + + for (const word of words) { + if (word.start_ms === undefined || word.end_ms === undefined) { + continue; + } + + const wordStartSec = (word.start_ms + offset) / 1000; + const wordEndSec = (word.end_ms + offset) / 1000; + const wordMidSec = (wordStartSec + wordEndSec) / 2; + + const matchedSegment = segments.find( + (seg) => seg.start <= wordMidSec && wordMidSec < seg.end, + ); + + if (matchedSegment) { + newHints.push({ + id: id(), + word_id: word.id, + type: "provider_speaker_index", + value: JSON.stringify({ + speaker_index: matchedSegment.speaker, + provider: "pyannote-local", + }), + }); + } + } + + const filteredHints = existingHints.filter((h) => { + if (h.type !== "provider_speaker_index") return true; + try { + const v = JSON.parse(h.value as string); + return v.provider !== "pyannote-local"; + } catch { + return true; + } + }); + + updateTranscriptHints(store, transcriptId, [ + ...filteredHints, + ...newHints, + ]); + } + + checkpoints?.addCheckpoint("diarize_speakers"); + }, + }); + + return ( + + + {isPending ? : } + {isPending ? "Diarizing..." : "Diarize Speakers"} + + + { + e.preventDefault(); + mutate(2); + }} + disabled={isPending} + > + 2 speakers + + { + e.preventDefault(); + mutate(3); + }} + disabled={isPending} + > + 3 speakers + + { + e.preventDefault(); + mutate(4); + }} + disabled={isPending} + > + 4 speakers + + { + e.preventDefault(); + mutate(6); + }} + disabled={isPending} + > + Auto (up to 6) + + + + ); +} diff --git a/apps/desktop/src/components/main/body/sessions/outer-header/overflow/index.tsx b/apps/desktop/src/components/main/body/sessions/outer-header/overflow/index.tsx index d8a99f32ca..4a42358921 100644 --- a/apps/desktop/src/components/main/body/sessions/outer-header/overflow/index.tsx +++ b/apps/desktop/src/components/main/body/sessions/outer-header/overflow/index.tsx @@ -14,6 +14,7 @@ import { import type { EditorView } from "../../../../../../store/zustand/tabs/schema"; import { useHasTranscript } from "../../shared"; import { DeleteNote, DeleteRecording } from "./delete"; +import { Diarize } from "./diarize"; import { ExportPDF } from "./export-pdf"; import { ExportTranscript } from "./export-transcript"; import { Listening } from "./listening"; @@ -55,6 +56,7 @@ export function OverflowButton({ {hasTranscript && } + {audioExists.data && hasTranscript && } diff --git a/crates/data/build.rs b/crates/data/build.rs index b379cc3e86..79cd2287ca 100644 --- a/crates/data/build.rs +++ b/crates/data/build.rs @@ -30,6 +30,11 @@ fn run(name: &str) { } fn main() { + println!("cargo:rerun-if-changed=src/english_3/raw.json"); + println!("cargo:rerun-if-changed=src/english_4/raw.json"); + println!("cargo:rerun-if-changed=src/english_5/raw.json"); + println!("cargo:rerun-if-changed=src/english_7/raw.json"); + run("english_3"); run("english_4"); run("english_5"); diff --git a/crates/pyannote-local/src/diarize.rs b/crates/pyannote-local/src/diarize.rs new file mode 100644 index 0000000000..099d7493f4 --- /dev/null +++ b/crates/pyannote-local/src/diarize.rs @@ -0,0 +1,252 @@ +use crate::embedding::EmbeddingExtractor; +use crate::identify::EmbeddingManager; +use crate::segmentation::Segmenter; + +#[derive(Debug, Clone, serde::Serialize, specta::Type)] +pub struct DiarizationSegment { + pub start: f64, + pub end: f64, + pub speaker: usize, +} + +pub struct DiarizeOptions { + pub max_speakers: usize, + pub threshold: f32, + pub min_segment_duration: f64, +} + +impl Default for DiarizeOptions { + fn default() -> Self { + Self { + max_speakers: 6, + threshold: 0.5, + min_segment_duration: 0.5, + } + } +} + +pub fn diarize( + samples: &[i16], + sample_rate: u32, + options: Option, +) -> Result, crate::Error> { + let options = options.unwrap_or_default(); + + let mut segmenter = Segmenter::new(sample_rate)?; + let segments = segmenter.process(samples, sample_rate)?; + + let mut extractor = EmbeddingExtractor::new(); + let mut manager = EmbeddingManager::new(options.max_speakers, options.threshold); + + let mut result = Vec::with_capacity(segments.len()); + + for segment in &segments { + if segment.end - segment.start < options.min_segment_duration { + continue; + } + let embedding = extractor.compute(segment.samples.iter().copied())?; + let speaker = manager.identify(&embedding); + result.push(DiarizationSegment { + start: segment.start, + end: segment.end, + speaker, + }); + } + + smooth_speakers(&mut result); + + Ok(result) +} + +fn smooth_speakers(segments: &mut [DiarizationSegment]) { + if segments.len() < 3 { + return; + } + for i in 1..segments.len() - 1 { + let prev = segments[i - 1].speaker; + let next = segments[i + 1].speaker; + if prev == next && segments[i].speaker != prev { + segments[i].speaker = prev; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn audio_from_bytes(bytes: &[u8]) -> Vec { + bytes + .chunks_exact(2) + .map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]])) + .collect() + } + + #[test] + fn test_diarize_english_1() { + let audio = audio_from_bytes(hypr_data::english_1::AUDIO); + let segments = diarize(&audio, 16000, None).unwrap(); + + assert!(!segments.is_empty(), "should produce at least one segment"); + for seg in &segments { + assert!(seg.end > seg.start); + println!( + "{:.2} - {:.2} [Speaker {}]", + seg.start, seg.end, seg.speaker + ); + } + } + + #[test] + fn test_diarize_english_2() { + let audio = audio_from_bytes(hypr_data::english_2::AUDIO); + let segments = diarize(&audio, 16000, None).unwrap(); + + assert!(!segments.is_empty(), "should produce at least one segment"); + for seg in &segments { + assert!(seg.end > seg.start); + println!( + "{:.2} - {:.2} [Speaker {}]", + seg.start, seg.end, seg.speaker + ); + } + } + + #[test] + #[ignore] + fn test_diarize_real_session() { + use dasp::sample::Sample; + + let path = std::path::PathBuf::from(env!("HOME")) + .join("Library/Application Support/hyprnote/sessions") + .join("ee73358b-c65e-4b62-9506-df14404d937b/audio.wav"); + + let f32_samples: Vec = rodio::Decoder::try_from(std::fs::File::open(&path).unwrap()) + .unwrap() + .collect(); + + let i16_samples: Vec = f32_samples.iter().map(|s| s.to_sample()).collect(); + + println!( + "Audio: {:.1}s, {} samples", + i16_samples.len() as f64 / 16000.0, + i16_samples.len() + ); + + let segments = diarize(&i16_samples, 16000, None).unwrap(); + + println!("\n{} segments found:\n", segments.len()); + let mut speakers = std::collections::HashSet::new(); + for seg in &segments { + speakers.insert(seg.speaker); + let dur = seg.end - seg.start; + println!( + " {:>6.2}s - {:>6.2}s ({:>5.2}s) Speaker {}", + seg.start, seg.end, dur, seg.speaker + ); + } + println!("\nUnique speakers: {}", speakers.len()); + } + + #[test] + #[ignore] + fn test_diarize_real_session_stereo() { + use dasp::sample::Sample; + use rodio::Source; + + let path = std::path::PathBuf::from(env!("HOME")) + .join("Library/Application Support/hyprnote/sessions") + .join("72cf1b45-e63d-40d2-b931-8980ad88734b/audio.wav"); + + let decoder = rodio::Decoder::try_from(std::fs::File::open(&path).unwrap()).unwrap(); + let channels = decoder.channels() as usize; + let f32_samples: Vec = decoder.collect(); + + let mono: Vec = if channels > 1 { + f32_samples + .chunks_exact(channels) + .map(|frame| frame.iter().sum::() / channels as f32) + .collect() + } else { + f32_samples + }; + + let i16_samples: Vec = mono.iter().map(|s| s.to_sample()).collect(); + + println!( + "Audio: {:.1}s, {} samples (mixed from {} ch)", + i16_samples.len() as f64 / 16000.0, + i16_samples.len(), + channels + ); + + let segments = diarize(&i16_samples, 16000, None).unwrap(); + + println!("\n{} segments found:\n", segments.len()); + let mut speakers = std::collections::HashSet::new(); + for seg in &segments { + speakers.insert(seg.speaker); + let dur = seg.end - seg.start; + println!( + " {:>6.2}s - {:>6.2}s ({:>5.2}s) Speaker {}", + seg.start, seg.end, dur, seg.speaker + ); + } + println!("\nUnique speakers: {}", speakers.len()); + } + + #[test] + #[ignore] + fn test_diarize_real_session_stereo_2speakers() { + use dasp::sample::Sample; + use rodio::Source; + + let path = std::path::PathBuf::from(env!("HOME")) + .join("Library/Application Support/hyprnote/sessions") + .join("72cf1b45-e63d-40d2-b931-8980ad88734b/audio.wav"); + + let decoder = rodio::Decoder::try_from(std::fs::File::open(&path).unwrap()).unwrap(); + let channels = decoder.channels() as usize; + let f32_samples: Vec = decoder.collect(); + + let mono: Vec = if channels > 1 { + f32_samples + .chunks_exact(channels) + .map(|frame| frame.iter().sum::() / channels as f32) + .collect() + } else { + f32_samples + }; + + let i16_samples: Vec = mono.iter().map(|s| s.to_sample()).collect(); + + println!( + "Audio: {:.1}s, {} samples (mixed from {} ch, max_speakers=2)", + i16_samples.len() as f64 / 16000.0, + i16_samples.len(), + channels + ); + + let opts = DiarizeOptions { + max_speakers: 2, + ..Default::default() + }; + let segments = diarize(&i16_samples, 16000, Some(opts)).unwrap(); + + println!("\n{} segments found:\n", segments.len()); + let mut speaker_time: std::collections::HashMap = + std::collections::HashMap::new(); + for seg in &segments { + *speaker_time.entry(seg.speaker).or_default() += seg.end - seg.start; + let dur = seg.end - seg.start; + println!( + " {:>6.2}s - {:>6.2}s ({:>5.2}s) Speaker {}", + seg.start, seg.end, dur, seg.speaker + ); + } + println!(); + for (spk, time) in &speaker_time { + println!("Speaker {}: {:.1}s total", spk, time); + } + } +} diff --git a/crates/pyannote-local/src/embedding.rs b/crates/pyannote-local/src/embedding.rs index 17b095915e..0186e7ce32 100644 --- a/crates/pyannote-local/src/embedding.rs +++ b/crates/pyannote-local/src/embedding.rs @@ -41,11 +41,6 @@ impl EmbeddingExtractor { let embeddings = ort_out.iter().copied().collect::>(); Ok(embeddings) } - - pub fn cluster(&self, _n_clusters: usize, embeddings: &[f32]) -> Vec { - let assignments = vec![0; embeddings.len()]; - assignments - } } #[cfg(test)] diff --git a/crates/pyannote-local/src/identify.rs b/crates/pyannote-local/src/identify.rs new file mode 100644 index 0000000000..ada4daeb12 --- /dev/null +++ b/crates/pyannote-local/src/identify.rs @@ -0,0 +1,165 @@ +use std::collections::HashMap; + +use simsimd::SpatialSimilarity; + +pub struct EmbeddingManager { + max_speakers: usize, + speakers: HashMap>, + speaker_counts: HashMap, + next_speaker_id: usize, + threshold: f32, +} + +impl EmbeddingManager { + pub fn new(max_speakers: usize, threshold: f32) -> Self { + Self { + max_speakers, + speakers: HashMap::new(), + speaker_counts: HashMap::new(), + next_speaker_id: 0, + threshold, + } + } + + pub fn identify(&mut self, embedding: &[f32]) -> usize { + let (best_id, best_similarity) = self.find_best_match(embedding); + + if let Some(id) = best_id { + if best_similarity > self.threshold { + self.update_centroid(id, embedding); + return id; + } + } + + if self.speakers.len() < self.max_speakers { + let id = self.next_speaker_id; + self.next_speaker_id += 1; + self.speakers.insert(id, embedding.to_vec()); + self.speaker_counts.insert(id, 1); + return id; + } + + if let Some(id) = best_id { + self.update_centroid(id, embedding); + id + } else { + 0 + } + } + + fn update_centroid(&mut self, id: usize, embedding: &[f32]) { + let count = self.speaker_counts.entry(id).or_insert(1); + if let Some(centroid) = self.speakers.get_mut(&id) { + let n = *count as f32; + for (c, &e) in centroid.iter_mut().zip(embedding.iter()) { + *c = (*c * n + e) / (n + 1.0); + } + *count += 1; + } + } + + fn find_best_match(&self, embedding: &[f32]) -> (Option, f32) { + let mut best_id = None; + let mut best_similarity = f32::NEG_INFINITY; + + for (&id, known) in &self.speakers { + let distance = f32::cosine(embedding, known).unwrap_or(1.0); + let similarity = 1.0 - distance as f32; + if similarity > best_similarity { + best_similarity = similarity; + best_id = Some(id); + } + } + + (best_id, best_similarity) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_embedding(dim: usize, dominant_axis: usize) -> Vec { + let mut emb = vec![0.0f32; dim]; + emb[dominant_axis] = 1.0; + emb + } + + #[test] + fn test_identify_matches_similar() { + let mut manager = EmbeddingManager::new(6, 0.5); + + let emb1 = make_embedding(128, 0); + let mut emb2 = make_embedding(128, 0); + emb2[1] = 0.1; + + let id1 = manager.identify(&emb1); + let id2 = manager.identify(&emb2); + assert_eq!(id1, id2, "similar embeddings should match"); + } + + #[test] + fn test_identify_separates_different() { + let mut manager = EmbeddingManager::new(6, 0.5); + + let emb1 = make_embedding(128, 0); + let emb2 = make_embedding(128, 64); + + let id1 = manager.identify(&emb1); + let id2 = manager.identify(&emb2); + assert_ne!(id1, id2, "orthogonal embeddings should get different IDs"); + } + + #[test] + fn test_max_speakers_limit() { + let mut manager = EmbeddingManager::new(2, 0.5); + + let emb1 = make_embedding(128, 0); + let emb2 = make_embedding(128, 64); + let emb3 = make_embedding(128, 127); + + let id1 = manager.identify(&emb1); + let id2 = manager.identify(&emb2); + let id3 = manager.identify(&emb3); + + assert_ne!(id1, id2); + let unique: std::collections::HashSet<_> = [id1, id2, id3].into_iter().collect(); + assert!( + unique.len() <= 2, + "should not exceed max_speakers=2, got {unique:?}" + ); + } + + #[test] + #[ignore] + fn test_identify_same_speaker_real_audio() { + use crate::embedding::EmbeddingExtractor; + use dasp::sample::{FromSample, Sample}; + + fn get_audio>(path: &str) -> Vec { + let base = std::path::Path::new(env!("CARGO_MANIFEST_DIR")); + let p = base.join("src/data").join(path); + let f32_samples = rodio::Decoder::try_from(std::fs::File::open(p).unwrap()) + .unwrap() + .collect::>(); + f32_samples + .iter() + .map(|s| s.to_sample()) + .collect::>() + } + + let mut extractor = EmbeddingExtractor::new(); + let mut manager = EmbeddingManager::new(6, 0.5); + + let emb1 = extractor + .compute(get_audio::("male_welcome_1.mp3").into_iter()) + .unwrap(); + let emb2 = extractor + .compute(get_audio::("male_welcome_2.mp3").into_iter()) + .unwrap(); + + let id1 = manager.identify(&emb1); + let id2 = manager.identify(&emb2); + assert_eq!(id1, id2, "same speaker should get same ID"); + } +} diff --git a/crates/pyannote-local/src/lib.rs b/crates/pyannote-local/src/lib.rs index ef2ad579b1..c821184331 100644 --- a/crates/pyannote-local/src/lib.rs +++ b/crates/pyannote-local/src/lib.rs @@ -1,4 +1,6 @@ +pub mod diarize; pub mod embedding; +pub mod identify; pub mod segmentation; mod error; diff --git a/plugins/listener2/Cargo.toml b/plugins/listener2/Cargo.toml index 14dcab1981..fa061c31f8 100644 --- a/plugins/listener2/Cargo.toml +++ b/plugins/listener2/Cargo.toml @@ -19,6 +19,10 @@ tauri-plugin-settings = { workspace = true } hypr-audio-utils = { workspace = true } hypr-host = { workspace = true } hypr-language = { workspace = true } +hypr-pyannote-local = { workspace = true } + +dasp = { workspace = true } +rodio = { workspace = true } owhisper-client = { workspace = true, features = ["argmax"] } owhisper-interface = { workspace = true } diff --git a/plugins/listener2/js/bindings.gen.ts b/plugins/listener2/js/bindings.gen.ts index a7d3d873b1..a1e36ca7a7 100644 --- a/plugins/listener2/js/bindings.gen.ts +++ b/plugins/listener2/js/bindings.gen.ts @@ -30,6 +30,14 @@ async exportToVtt(sessionId: string, words: VttWord[]) : Promise> { + try { + return { status: "ok", data: await TAURI_INVOKE("plugin:listener2|diarize_session", { sessionId, maxSpeakers }) }; +} catch (e) { + if(e instanceof Error) throw e; + else return { status: "error", error: e as any }; +} +}, async isSupportedLanguagesBatch(provider: string, model: string | null, languages: string[]) : Promise> { try { return { status: "ok", data: await TAURI_INVOKE("plugin:listener2|is_supported_languages_batch", { provider, model, languages }) }; @@ -79,6 +87,7 @@ export type BatchProvider = "deepgram" | "soniox" | "assemblyai" | "am" export type BatchResponse = { metadata: JsonValue; results: BatchResults } export type BatchResults = { channels: BatchChannel[] } export type BatchWord = { word: string; start: number; end: number; confidence: number; speaker: number | null; punctuated_word: string | null } +export type DiarizationSegment = { start: number; end: number; speaker: number } export type JsonValue = null | boolean | number | string | JsonValue[] | Partial<{ [key in string]: JsonValue }> export type StreamAlternatives = { transcript: string; words: StreamWord[]; confidence: number; languages?: string[] } export type StreamChannel = { alternatives: StreamAlternatives[] } diff --git a/plugins/listener2/permissions/autogenerated/commands/diarize_session.toml b/plugins/listener2/permissions/autogenerated/commands/diarize_session.toml new file mode 100644 index 0000000000..5599dd120f --- /dev/null +++ b/plugins/listener2/permissions/autogenerated/commands/diarize_session.toml @@ -0,0 +1,13 @@ +# Automatically generated - DO NOT EDIT! + +"$schema" = "../../schemas/schema.json" + +[[permission]] +identifier = "allow-diarize-session" +description = "Enables the diarize_session command without any pre-configured scope." +commands.allow = ["diarize_session"] + +[[permission]] +identifier = "deny-diarize-session" +description = "Denies the diarize_session command without any pre-configured scope." +commands.deny = ["diarize_session"] diff --git a/plugins/listener2/permissions/autogenerated/reference.md b/plugins/listener2/permissions/autogenerated/reference.md index 2b209550ec..7c0d74324b 100644 --- a/plugins/listener2/permissions/autogenerated/reference.md +++ b/plugins/listener2/permissions/autogenerated/reference.md @@ -7,6 +7,7 @@ Default permissions for the plugin - `allow-run-batch` - `allow-parse-subtitle` - `allow-export-to-vtt` +- `allow-diarize-session` - `allow-is-supported-languages-batch` - `allow-suggest-providers-for-languages-batch` - `allow-list-documented-language-codes-batch` @@ -20,6 +21,32 @@ Default permissions for the plugin + + + +`listener2:allow-diarize-session` + + + + +Enables the diarize_session command without any pre-configured scope. + + + + + + + +`listener2:deny-diarize-session` + + + + +Denies the diarize_session command without any pre-configured scope. + + + + diff --git a/plugins/listener2/permissions/default.toml b/plugins/listener2/permissions/default.toml index 13233e78d2..86a487f3e5 100644 --- a/plugins/listener2/permissions/default.toml +++ b/plugins/listener2/permissions/default.toml @@ -4,6 +4,7 @@ permissions = [ "allow-run-batch", "allow-parse-subtitle", "allow-export-to-vtt", + "allow-diarize-session", "allow-is-supported-languages-batch", "allow-suggest-providers-for-languages-batch", "allow-list-documented-language-codes-batch", diff --git a/plugins/listener2/permissions/schemas/schema.json b/plugins/listener2/permissions/schemas/schema.json index 391f1e79a1..6c5c94dd29 100644 --- a/plugins/listener2/permissions/schemas/schema.json +++ b/plugins/listener2/permissions/schemas/schema.json @@ -294,6 +294,18 @@ "PermissionKind": { "type": "string", "oneOf": [ + { + "description": "Enables the diarize_session command without any pre-configured scope.", + "type": "string", + "const": "allow-diarize-session", + "markdownDescription": "Enables the diarize_session command without any pre-configured scope." + }, + { + "description": "Denies the diarize_session command without any pre-configured scope.", + "type": "string", + "const": "deny-diarize-session", + "markdownDescription": "Denies the diarize_session command without any pre-configured scope." + }, { "description": "Enables the export_to_vtt command without any pre-configured scope.", "type": "string", @@ -367,10 +379,10 @@ "markdownDescription": "Denies the suggest_providers_for_languages_batch command without any pre-configured scope." }, { - "description": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-run-batch`\n- `allow-parse-subtitle`\n- `allow-export-to-vtt`\n- `allow-is-supported-languages-batch`\n- `allow-suggest-providers-for-languages-batch`\n- `allow-list-documented-language-codes-batch`", + "description": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-run-batch`\n- `allow-parse-subtitle`\n- `allow-export-to-vtt`\n- `allow-diarize-session`\n- `allow-is-supported-languages-batch`\n- `allow-suggest-providers-for-languages-batch`\n- `allow-list-documented-language-codes-batch`", "type": "string", "const": "default", - "markdownDescription": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-run-batch`\n- `allow-parse-subtitle`\n- `allow-export-to-vtt`\n- `allow-is-supported-languages-batch`\n- `allow-suggest-providers-for-languages-batch`\n- `allow-list-documented-language-codes-batch`" + "markdownDescription": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-run-batch`\n- `allow-parse-subtitle`\n- `allow-export-to-vtt`\n- `allow-diarize-session`\n- `allow-is-supported-languages-batch`\n- `allow-suggest-providers-for-languages-batch`\n- `allow-list-documented-language-codes-batch`" } ] } diff --git a/plugins/listener2/src/commands.rs b/plugins/listener2/src/commands.rs index 3100a8c0db..98a969170b 100644 --- a/plugins/listener2/src/commands.rs +++ b/plugins/listener2/src/commands.rs @@ -2,6 +2,7 @@ use owhisper_client::AdapterKind; use std::str::FromStr; use crate::{BatchParams, Listener2PluginExt, Subtitle, VttWord}; +use hypr_pyannote_local::diarize::DiarizationSegment; #[tauri::command] #[specta::specta] @@ -99,6 +100,19 @@ pub async fn suggest_providers_for_languages_batch( Ok(supported) } +#[tauri::command] +#[specta::specta] +pub async fn diarize_session( + app: tauri::AppHandle, + session_id: String, + max_speakers: usize, +) -> Result, String> { + app.listener2() + .diarize_session(session_id, max_speakers) + .await + .map_err(|e| e.to_string()) +} + #[tauri::command] #[specta::specta] pub async fn list_documented_language_codes_batch( diff --git a/plugins/listener2/src/error.rs b/plugins/listener2/src/error.rs index 0ee00641a7..a45d85bf0f 100644 --- a/plugins/listener2/src/error.rs +++ b/plugins/listener2/src/error.rs @@ -12,6 +12,8 @@ pub enum Error { SpawnError(#[from] ractor::SpawnErr), #[error("batch start failed: {0}")] BatchStartFailed(String), + #[error("diarization failed: {0}")] + DiarizeFailed(String), } impl Serialize for Error { diff --git a/plugins/listener2/src/ext.rs b/plugins/listener2/src/ext.rs index 7d5cd05bc0..1e7c30546a 100644 --- a/plugins/listener2/src/ext.rs +++ b/plugins/listener2/src/ext.rs @@ -95,6 +95,74 @@ impl<'a, R: tauri::Runtime, M: tauri::Manager> Listener2<'a, R, M> { } } + pub async fn diarize_session( + &self, + session_id: String, + max_speakers: usize, + ) -> Result, crate::Error> { + use dasp::sample::Sample; + use rodio::Source; + use tauri_plugin_settings::SettingsPluginExt; + + let base = self + .manager + .settings() + .cached_vault_base() + .map_err(|e| crate::Error::DiarizeFailed(e.to_string()))?; + + let session_dir = base.join("sessions").join(&session_id); + + let audio_path = if session_dir.join("audio.wav").exists() { + session_dir.join("audio.wav") + } else if session_dir.join("audio.ogg").exists() { + session_dir.join("audio.ogg") + } else { + return Err(crate::Error::DiarizeFailed( + "no audio file found".to_string(), + )); + }; + + let segments = tokio::task::spawn_blocking(move || { + let decoder = hypr_audio_utils::source_from_path(&audio_path) + .map_err(|e| crate::Error::DiarizeFailed(e.to_string()))?; + + let channels = decoder.channels() as usize; + let sample_rate = decoder.sample_rate(); + let f32_samples: Vec = decoder.collect(); + + let mono: Vec = if channels > 1 { + f32_samples + .chunks_exact(channels) + .map(|frame| frame.iter().sum::() / channels as f32) + .collect() + } else { + f32_samples + }; + + let resampled = if sample_rate != 16000 { + let source = rodio::buffer::SamplesBuffer::new(1, sample_rate, mono); + hypr_audio_utils::resample_audio(source, 16000) + .map_err(|e| crate::Error::DiarizeFailed(e.to_string()))? + } else { + mono + }; + + let i16_samples: Vec = resampled.iter().map(|s| s.to_sample()).collect(); + + let opts = hypr_pyannote_local::diarize::DiarizeOptions { + max_speakers, + ..Default::default() + }; + + hypr_pyannote_local::diarize::diarize(&i16_samples, 16000, Some(opts)) + .map_err(|e| crate::Error::DiarizeFailed(e.to_string())) + }) + .await + .map_err(|e| crate::Error::DiarizeFailed(format!("join error: {e}")))?; + + segments + } + pub fn parse_subtitle(&self, path: String) -> Result { use aspasia::TimedSubtitleFile; let sub = TimedSubtitleFile::new(&path).unwrap(); diff --git a/plugins/listener2/src/lib.rs b/plugins/listener2/src/lib.rs index 6a21df2a0b..676f8763f0 100644 --- a/plugins/listener2/src/lib.rs +++ b/plugins/listener2/src/lib.rs @@ -29,6 +29,7 @@ fn make_specta_builder() -> tauri_specta::Builder { commands::run_batch::, commands::parse_subtitle::, commands::export_to_vtt::, + commands::diarize_session::, commands::is_supported_languages_batch::, commands::suggest_providers_for_languages_batch::, commands::list_documented_language_codes_batch::,