
Sensevoice timestamp error when processing multiple VAD segments: missing index into encoder_out_lens #2339

njzheng opened this issue Dec 26, 2024 · 1 comment


njzheng commented Dec 26, 2024

https://github.com/modelscope/FunASR/blob/a3a1c55c4c674768e6cfef1dc9c10d2e1c3335c7/funasr/models/sense_voice/model.py#L928C20-L928C48

This fixes some timestamp-alignment bugs for English; the changed lines are marked with <--.

funasr/models/sense_voice/model.py

            if output_timestamp:
                from itertools import groupby
                timestamp = []
                # NOTE: may not match token_int, e.g. __file s --> __files
                tokens = tokenizer.text2tokens(text)[4:]
                if len(tokens) != len(token_int) - 4:
                    # NOTE: re-encoding may yield different ids; hopefully this does not affect the probabilities much
                    token_int = sum(tokenizer.encode(tokens), token_int[:4])  # <--
                assert len(tokens) == len(token_int) - 4

                logits_speech = self.ctc.softmax(encoder_out)[i, 4:encoder_out_lens[i].item(), :]

                pred = logits_speech.argmax(-1).cpu()
                logits_speech[pred == self.blank_id, self.blank_id] = 0
                align = ctc_forced_align(
                    logits_speech.unsqueeze(0).float(),
                    torch.tensor(token_int[4:]).unsqueeze(0).long().to(logits_speech.device),
                    (encoder_out_lens[i] - 4).long(),  # <-- index into encoder_out_lens
                    torch.tensor(len(token_int) - 4).unsqueeze(0).long().to(logits_speech.device),
                    ignore_id=self.ignore_id,
                )


                _start = 0
                token_id = 0
                ts_max = encoder_out_lens[i] - 4
                # for pred_token, pred_frame in align:
                for pred_token, pred_frame in groupby(align[0, :encoder_out_lens[i]]):  # <--
                    _end = _start + len(list(pred_frame))
                    if pred_token != 0:
                        ts_left = max((_start*60-30)/1000, 0)
                        ts_right = min((_end*60-30)/1000, (ts_max*60-30)/1000)
                        timestamp.append([tokens[token_id], ts_left, ts_right])
                        token_id += 1
                    _start = _end
                # timestamp = self.post(timestamp)
                word_new, timestamp = self.timestamp_post(timestamp)  # <--
                result_i = {"key": key[i], "text": text, "timestamp": timestamp}
                results.append(result_i)
            else:
                result_i = {"key": key[i], "text": text}
                results.append(result_i)
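The alignment-collapsing loop above can be sketched in isolation. This is an illustrative standalone version (not FunASR's code): it assumes 60 ms encoder frames with a 30 ms center offset, as in the `(_start*60-30)/1000` expressions in the snippet, and that token id 0 is the CTC blank; the function name and the toy alignment are made up for the example.

```python
# Minimal sketch of collapsing a per-frame CTC alignment into token spans.
# Assumes 60 ms frames with a 30 ms center offset; token id 0 is the blank.
from itertools import groupby

def collapse_alignment(align, tokens, frame_ms=60, offset_ms=30):
    """align: per-frame token ids; tokens: the non-blank token strings in order.
    Returns a list of [token, start_seconds, end_seconds] spans."""
    spans = []
    start = 0
    token_id = 0
    for pred_token, run in groupby(align):
        end = start + len(list(run))          # frames covered by this run
        if pred_token != 0:                   # skip blank runs
            ts_left = max((start * frame_ms - offset_ms) / 1000, 0)
            ts_right = (end * frame_ms - offset_ms) / 1000
            spans.append([tokens[token_id], ts_left, ts_right])
            token_id += 1
        start = end
    return spans

# Frames: blank, blank, "he", "he", blank, "llo"
print(collapse_alignment([0, 0, 5, 5, 0, 7], ["he", "llo"]))
# → [['he', 0.09, 0.21], ['llo', 0.27, 0.33]]
```

Because `groupby` merges consecutive identical ids, each run of the same non-blank id becomes a single span, mirroring what the patched loop does with `align[0, :encoder_out_lens[i]]`.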
    # rewritten post-processing function
    def timestamp_post(self, timestamp):
        timestamp_new = []
        word_new = []
        word_cat = ''
        for i, t in enumerate(timestamp):
            word, start, end = t
            if word == "▁":
                continue
            if i == 0:
                # timestamp_new.append([word, start, end])
                timestamp_new.append([int(start * 1000), int(end * 1000)])
                word_cat += word
            elif word.startswith("▁"):
                word = word[1:]
                # timestamp_new.append([word, start, end])
                timestamp_new.append([int(start * 1000), int(end * 1000)])
                word_new.append(word_cat)
                word_cat = word
            elif len(timestamp) > 1 and len(word) == 1:
                timestamp_new[-1][1] = int(end * 1000)
                word_cat += word
            else:
                # timestamp_new[-1][0] += word
                timestamp_new[-1][1] = int(end * 1000)
                word_cat += word
        if len(word_cat) > 0:
            word_new.append(word_cat)

        assert len(word_new) == len(timestamp_new)
        return word_new, timestamp_new
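The merging rule in `timestamp_post` can be illustrated on its own: a SentencePiece token starting with "▁" opens a new word, a bare "▁" is skipped, and any other token extends the previous word and its end time. The following is a simplified standalone sketch (function name and sample data are invented for the example; times are converted to milliseconds as above).

```python
# Standalone sketch of the subword-to-word merge performed by timestamp_post.
# "▁"-prefixed tokens start a new word; other tokens extend the previous one.
def merge_subwords(timestamp):
    """timestamp: list of [token, start_s, end_s].
    Returns (words, spans_ms) with one [start_ms, end_ms] span per word."""
    words, spans = [], []
    for token, start, end in timestamp:
        if token == "▁":                      # bare boundary carries no text
            continue
        if token.startswith("▁") or not words:
            words.append(token.lstrip("▁"))   # open a new word
            spans.append([int(start * 1000), int(end * 1000)])
        else:                                 # continuation piece
            words[-1] += token
            spans[-1][1] = int(end * 1000)    # extend the word's end time
    return words, spans

print(merge_subwords([["▁hel", 0.09, 0.21], ["lo", 0.27, 0.33],
                      ["▁world", 0.39, 0.63]]))
# → (['hello', 'world'], [[90, 330], [390, 630]])
```

Like `timestamp_post`, this keeps one span per merged word, so the word list and span list always have equal length.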

There is another issue: because rich_transcription_postprocess() removes words such as "The." from the text, the corresponding entries also need to be removed from word_new and timestamp.

psk-github commented

Hi, I tested this locally. With mixed Chinese-English audio, the output timestamps seem to cover only the English tokens. Is this fix intended for English-only audio?
