
How to reduce memory when decoding on CPU? #1672

Open

tz301 opened this issue Jun 29, 2024 · 8 comments

tz301 commented Jun 29, 2024

With zipformer I can get good performance.

Currently, when I decode on CPU one utterance at a time (no batching), the memory usage goes up to 2.5 GB. The token (vocabulary) size is 5000 and I use greedy_search for decoding. I tried reducing it to 4000, but the memory usage did not decrease much.

Any ideas on how to reduce it without an obvious performance degradation?
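
For context, the peak figure above can be measured from inside the process with the standard library; a minimal sketch (my addition, not part of the original report):

import resource

# Peak resident set size (RSS) of the current process.
# Note: ru_maxrss is in kilobytes on Linux, but in bytes on macOS.
peak_kb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
print(f'peak RSS: {peak_kb / 1024:.1f} MB')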

Some model configurations below:
num-encoder-layers=2,2,2,3,2,2
downsampling-factor=1,2,4,8,4,2
feedforward-dim=256,384,512,768,512,384
num-heads=4,4,4,4,4,4
encoder-dim=192,256,256,384,256,256
query-head-dim=24
value-head-dim=8
pos-head-dim=4
pos-dim=24
encoder-unmasked-dim=192,192,256,256,256,192
cnn-module-kernel=31,31,15,15,15,31
decoder-dim=256
joiner-dim=256

@csukuangfj (Collaborator)

Could you tell us which script you are using?

Have you changed any code, or are you using the original code from us?

Also, please tell us whether you are using a streaming or a non-streaming model, and what the typical wave duration of your test files is.

It would be great if you could post the complete decoding command.

tz301 (Author) commented Jul 1, 2024

> Could you tell us which script you are using?
>
> Have you changed any code, or are you using the original code from us?
>
> Also, please tell us whether you are using a streaming or a non-streaming model, and what the typical wave duration of your test files is.
>
> It would be great if you could post the complete decoding command.

Hi @csukuangfj,

I have exported the model using torch.jit.export and written my own decoding code, which is used in an offline scenario. However, the core decoding code is from icefall.

Actually my waves are usually longer than 1 minute, but the max duration of a wave file sent to ASR is 20s (they are force-cut by an energy VAD).

I have attached the code below.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import logging
import math
from argparse import ArgumentParser
from pathlib import Path

import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import greedy_search_batch
from icefall.lexicon import Lexicon
from pydub import AudioSegment
from torch.nn.utils.rnn import pad_sequence

_LOGGER = logging.getLogger(__name__)


class Decoder:

    def __init__(self, lang_dir):
        self._sp = spm.SentencePieceProcessor()
        self._sp.load(str(lang_dir / 'bpe.model'))

        self._lexicon = Lexicon(lang_dir)

        self.blank_id = self._lexicon.token_table['<blk>']
        self.vocab_size = max(self._lexicon.tokens) + 1
        _LOGGER.info('Decoder Init Succeed.')

    @property
    def lexicon(self):
        return self._lexicon

    def decode_tokens(self, tokens):
        token_table = self._lexicon.token_table
        return self._sp.decode([token_table[idx] for idx in tokens])

    def decode(self, model, encoder_out, encoder_out_lens):
        pred = list()
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        for i in range(encoder_out.size(0)):
            pred.append(self.decode_tokens(hyp_tokens[i]))
        return pred


class Asr:

    def __init__(self, model_dir):
        torch.set_num_threads(1)
        torch.set_num_interop_threads(1)
        model_dir = Path(model_dir)
        self._device = torch.device('cpu')
        self._feat_extractor = self.__get_feat_extractor()
        self._model = self._init_model(model_dir)
        self._decoder = Decoder(model_dir / 'lang')
        _LOGGER.info('Asr Model init succeed.')

    def __get_feat_extractor(self):
        opts = kaldifeat.FbankOptions()
        opts.device = self._device
        opts.frame_opts.dither = 0
        opts.frame_opts.snip_edges = False
        opts.frame_opts.samp_freq = 8000
        opts.mel_opts.num_bins = 80
        feat_extractor = kaldifeat.Fbank(opts)
        _LOGGER.info('Fbank feat extractor init succeed.')
        return feat_extractor

    def _init_model(self, model_dir):
        model = torch.jit.load(model_dir / 'model.pt')
        model.eval()
        model.to(self._device)
        _LOGGER.info('ZipFormer Load succeed.')
        return model

    @staticmethod
    def _read_wav_files(wav_files):
        waves = list()
        for wav_file in wav_files:
            wave, sample_rate = torchaudio.load(wav_file)
            waves.append(wave[0].contiguous())
        return waves

    def _get_feature(self, wav_files):
        waves = self._read_wav_files(wav_files)
        waves = [w.to(self._device) for w in waves]

        features = self._feat_extractor(waves)

        # Compute the lengths before padding; after pad_sequence every
        # row would report the same (max) length.
        feature_lengths = torch.tensor(
            [f.size(0) for f in features], device=self._device
        )

        features = pad_sequence(
            features,
            batch_first=True,
            padding_value=math.log(1e-10)
        )
        return features, feature_lengths

    def _encode(self, features, feature_lengths):
        """编码.

        Args:
            features: 特征.
            feature_lengths: 特征长度.

        Returns:
            编码输出和编码输出长度.
        """
        encoder_out, encoder_out_lens = self._model.encoder(
            features=features,
            feature_lengths=feature_lengths
        )
        return encoder_out, encoder_out_lens

    def recognize(self, wav_files):
        features, feature_lengths = self._get_feature(wav_files)
        encoder_out, encoder_out_lens = self._encode(features, feature_lengths)
        texts = self._decoder.decode(self._model, encoder_out, encoder_out_lens)
        return texts


def _main():
    parser = ArgumentParser('recognize')
    parser.add_argument('model_dir', type=Path, help='model directory')
    parser.add_argument('wav_dir', type=Path, help='wav directory')
    args = parser.parse_args()

    asr = Asr(args.model_dir)
    for wav_file in args.wav_dir.iterdir():
        duration = AudioSegment.from_file(wav_file).duration_seconds
        text = asr.recognize([wav_file])[0]
        _LOGGER.info(f'[{wav_file}]: [{duration:.3f}s] [{text}]')


if __name__ == '__main__':
    LOGGER_FORMAT = ('%(asctime)s.%(msecs)03d - %(name)s:%(lineno)s '
                     '- %(funcName)s() - %(levelname)s - %(message)s')
    logging.basicConfig(format=LOGGER_FORMAT, level=logging.INFO)
    _main()

csukuangfj (Collaborator) commented Jul 1, 2024

I see.

Please use

@torch.no_grad()
def _main():

as we do in our decoding scripts.
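
For reference, an equivalent alternative (my sketch, not part of the reply above) is to disable autograd only around the inference call; torch.inference_mode() (PyTorch >= 1.9) is a stricter, slightly cheaper variant of torch.no_grad():

import torch

# Minimal sketch: scope the autograd-disabled region to the inference
# call instead of decorating the whole main(). Both variants stop
# autograd from keeping activations alive between calls.
def transcribe(asr, wav_file):
    with torch.inference_mode():
        return asr.recognize([wav_file])[0]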

tz301 (Author) commented Jul 2, 2024

Hi @csukuangfj,

Yeah, adding @torch.no_grad() seems to work; the memory decreased from 2.5 GB to 2.1 GB.

I'm not sure whether it is normal to use ~2 GB of memory. Are there any other ideas to decrease it?
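
One technique sometimes used to cut CPU memory further (an assumption on my part, not something verified in this thread) is dynamic int8 quantization of the Linear layers before export; a hypothetical sketch, assuming the original (non-scripted) model is available and scriptable:

import torch

# Hypothetical sketch: quantize Linear layers to int8 before scripting.
# Weights of the quantized layers take roughly 4x less memory; the
# accuracy impact has to be checked on real data.
def quantize_and_export(model, out_path='model_int8.pt'):
    quantized = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    torch.jit.script(quantized).save(out_path)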

@csukuangfj (Collaborator)

Could you post your updated code?

tz301 (Author) commented Jul 2, 2024

@csukuangfj

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import logging
import math
from argparse import ArgumentParser
from pathlib import Path

import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import greedy_search_batch
from icefall.lexicon import Lexicon
from pydub import AudioSegment
from torch.nn.utils.rnn import pad_sequence

_LOGGER = logging.getLogger(__name__)


class Decoder:

    def __init__(self, lang_dir):
        self._sp = spm.SentencePieceProcessor()
        self._sp.load(str(lang_dir / 'bpe.model'))

        self._lexicon = Lexicon(lang_dir)

        self.blank_id = self._lexicon.token_table['<blk>']
        self.vocab_size = max(self._lexicon.tokens) + 1
        _LOGGER.info('Decoder Init Succeed.')

    @property
    def lexicon(self):
        return self._lexicon

    def decode_tokens(self, tokens):
        token_table = self._lexicon.token_table
        return self._sp.decode([token_table[idx] for idx in tokens])

    def decode(self, model, encoder_out, encoder_out_lens):
        pred = list()
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        for i in range(encoder_out.size(0)):
            pred.append(self.decode_tokens(hyp_tokens[i]))
        return pred


class Asr:

    def __init__(self, model_dir):
        torch.set_num_threads(1)
        torch.set_num_interop_threads(1)
        model_dir = Path(model_dir)
        self._device = torch.device('cpu')
        self._feat_extractor = self.__get_feat_extractor()
        self._model = self._init_model(model_dir)
        self._decoder = Decoder(model_dir / 'lang')
        _LOGGER.info('Asr Model init succeed.')

    def __get_feat_extractor(self):
        opts = kaldifeat.FbankOptions()
        opts.device = self._device
        opts.frame_opts.dither = 0
        opts.frame_opts.snip_edges = False
        opts.frame_opts.samp_freq = 8000
        opts.mel_opts.num_bins = 80
        feat_extractor = kaldifeat.Fbank(opts)
        _LOGGER.info('Fbank feat extractor init succeed.')
        return feat_extractor

    def _init_model(self, model_dir):
        model = torch.jit.load(model_dir / 'model.pt')
        model.eval()
        model.to(self._device)
        _LOGGER.info('ZipFormer Load succeed.')
        return model

    @staticmethod
    def _read_wav_files(wav_files):
        waves = list()
        for wav_file in wav_files:
            wave, sample_rate = torchaudio.load(wav_file)
            waves.append(wave[0].contiguous())
        return waves

    def _get_feature(self, wav_files):
        waves = self._read_wav_files(wav_files)
        waves = [w.to(self._device) for w in waves]

        features = self._feat_extractor(waves)

        # Compute the lengths before padding; after pad_sequence every
        # row would report the same (max) length.
        feature_lengths = torch.tensor(
            [f.size(0) for f in features], device=self._device
        )

        features = pad_sequence(
            features,
            batch_first=True,
            padding_value=math.log(1e-10)
        )
        return features, feature_lengths

    def _encode(self, features, feature_lengths):
        """编码.

        Args:
            features: 特征.
            feature_lengths: 特征长度.

        Returns:
            编码输出和编码输出长度.
        """
        encoder_out, encoder_out_lens = self._model.encoder(
            features=features,
            feature_lengths=feature_lengths
        )
        return encoder_out, encoder_out_lens

    def recognize(self, wav_files):
        features, feature_lengths = self._get_feature(wav_files)
        encoder_out, encoder_out_lens = self._encode(features, feature_lengths)
        texts = self._decoder.decode(self._model, encoder_out, encoder_out_lens)
        return texts


@torch.no_grad()
def _main():
    parser = ArgumentParser('recognize')
    parser.add_argument('model_dir', type=Path, help='model directory')
    parser.add_argument('wav_dir', type=Path, help='wav directory')
    args = parser.parse_args()

    asr = Asr(args.model_dir)
    for wav_file in args.wav_dir.iterdir():
        duration = AudioSegment.from_file(wav_file).duration_seconds
        text = asr.recognize([wav_file])[0]
        _LOGGER.info(f'[{wav_file}]: [{duration:.3f}s] [{text}]')


if __name__ == '__main__':
    LOGGER_FORMAT = ('%(asctime)s.%(msecs)03d - %(name)s:%(lineno)s '
                     '- %(funcName)s() - %(levelname)s - %(message)s')
    logging.basicConfig(format=LOGGER_FORMAT, level=logging.INFO)
    _main()

@csukuangfj (Collaborator)

Does the memory grow linearly from 0 to 2.1 GB and then stay at 2.1 GB?


_LOGGER.info(f'[{wav_file}]: [{duration:.3f}s] [{text}]')

Could you give the output of the above log? What is the max value of duration?

tz301 (Author) commented Jul 3, 2024

> Does the memory grow linearly from 0 to 2.1 GB and then stay at 2.1 GB?
>
> _LOGGER.info(f'[{wav_file}]: [{duration:.3f}s] [{text}]')
>
> Could you give the output of the above log? What is the max value of duration?

The max duration is 20s in my wav files.

The memory first grows to around 1.5 GB for the first few wavs, then grows slowly to 2.1 GB and stays there.

The first few wavs (around 5) are extremely slow and may take 1~2 minutes to finish decoding. I'm not sure whether it's normal for the warm-up of this ASR model to take this long.
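
If the slow first calls come from TorchScript's profiling executor, which specializes and optimizes the graph during the first few invocations, a warm-up on dummy audio right after loading would move that one-time cost out of the first real requests. A minimal sketch (my assumption; the attribute names come from the script above):

import math

import torch
from torch.nn.utils.rnn import pad_sequence


def warm_up(asr, num_runs=3, seconds=5, sample_rate=8000):
    # Low-level noise rather than pure silence keeps the log-mel
    # features well-defined with dither disabled.
    dummy_wave = 0.01 * torch.randn(seconds * sample_rate)
    for _ in range(num_runs):
        with torch.inference_mode():
            features = asr._feat_extractor([dummy_wave])
            feature_lengths = torch.tensor([f.size(0) for f in features])
            features = pad_sequence(
                features, batch_first=True, padding_value=math.log(1e-10)
            )
            asr._encode(features, feature_lengths)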
