Skip to content

Commit

Permalink
Handle audio files less than 10s for speaker diarization.
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Oct 11, 2024
1 parent 1d061df commit b200bd9
Showing 1 changed file with 32 additions and 3 deletions.
35 changes: 32 additions & 3 deletions sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ class OfflineSpeakerDiarizationPyannoteImpl

segmentations.clear();

if (labels.size() == 1) {
if (callback) {
callback(1, 1, callback_arg);
}

return HandleOneChunkSpecialCase(labels[0], n);
}

// labels[i] is a 0-1 matrix of shape (num_frames, num_speakers)

// speaker count per frame
Expand Down Expand Up @@ -201,7 +209,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
}

int32_t num_chunks = (n - window_size) / window_shift + 1;
bool has_last_chunk = (n - window_size) % window_shift > 0;
bool has_last_chunk = ((n - window_size) % window_shift) > 0;

ans.reserve(num_chunks + has_last_chunk);

Expand Down Expand Up @@ -524,9 +532,9 @@ class OfflineSpeakerDiarizationPyannoteImpl
count(seq, Eigen::all).array() += labels[i].array();
}

bool has_last_chunk = (num_samples - window_size) % window_shift > 0;
bool has_last_chunk = ((num_samples - window_size) % window_shift) > 0;

if (has_last_chunk) {
if (!has_last_chunk) {
return count;
}

Expand Down Expand Up @@ -622,6 +630,27 @@ class OfflineSpeakerDiarizationPyannoteImpl
return ans;
}

OfflineSpeakerDiarizationResult HandleOneChunkSpecialCase(
const Matrix2DInt32 &final_labels, int32_t num_samples) const {
const auto &meta_data = segmentation_model_.GetModelMetaData();
int32_t window_size = meta_data.window_size;
int32_t window_shift = meta_data.window_shift;
int32_t receptive_field_shift = meta_data.receptive_field_shift;

bool has_last_chunk = (num_samples - window_size) % window_shift > 0;
if (!has_last_chunk) {
return ComputeResult(final_labels);
}

int32_t num_frames = final_labels.rows();

int32_t new_num_frames = num_samples / receptive_field_shift;

num_frames = (new_num_frames <= num_frames) ? new_num_frames : num_frames;

return ComputeResult(final_labels(Eigen::seq(0, num_frames), Eigen::all));
}

void MergeSegments(
std::vector<OfflineSpeakerDiarizationSegment> *segments) const {
float min_duration_off = config_.min_duration_off;
Expand Down

0 comments on commit b200bd9

Please sign in to comment.