diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index 9667088d5..8f669e27c 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -99,6 +99,14 @@ class OfflineSpeakerDiarizationPyannoteImpl segmentations.clear(); + if (labels.size() == 1) { + if (callback) { + callback(1, 1, callback_arg); + } + + return HandleOneChunkSpecialCase(labels[0], n); + } + // labels[i] is a 0-1 matrix of shape (num_frames, num_speakers) // speaker count per frame @@ -201,7 +209,7 @@ class OfflineSpeakerDiarizationPyannoteImpl } int32_t num_chunks = (n - window_size) / window_shift + 1; - bool has_last_chunk = (n - window_size) % window_shift > 0; + bool has_last_chunk = ((n - window_size) % window_shift) > 0; ans.reserve(num_chunks + has_last_chunk); @@ -524,9 +532,9 @@ class OfflineSpeakerDiarizationPyannoteImpl count(seq, Eigen::all).array() += labels[i].array(); } - bool has_last_chunk = (num_samples - window_size) % window_shift > 0; + bool has_last_chunk = ((num_samples - window_size) % window_shift) > 0; - if (has_last_chunk) { + if (!has_last_chunk) { return count; } @@ -622,6 +630,27 @@ class OfflineSpeakerDiarizationPyannoteImpl return ans; } + OfflineSpeakerDiarizationResult HandleOneChunkSpecialCase( + const Matrix2DInt32 &final_labels, int32_t num_samples) const { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + int32_t window_shift = meta_data.window_shift; + int32_t receptive_field_shift = meta_data.receptive_field_shift; + + bool has_last_chunk = (num_samples - window_size) % window_shift > 0; + if (!has_last_chunk) { + return ComputeResult(final_labels); + } + + int32_t num_frames = final_labels.rows(); + + int32_t new_num_frames = num_samples / receptive_field_shift; + + num_frames = (new_num_frames <= num_frames) ? new_num_frames : num_frames; + + return ComputeResult(final_labels(Eigen::seq(0, num_frames), Eigen::all)); + } + void MergeSegments( std::vector *segments) const { float min_duration_off = config_.min_duration_off;