16 changes: 5 additions & 11 deletions moshi/moshi/offline.py
@@ -90,34 +90,31 @@ def wrap_with_system_tags(text: str) -> str:
     return f"<system> {cleaned} <system>"
 
 
-def warmup(mimi: MimiModel, other_mimi: MimiModel, lm_gen: LMGen, device: str, frame_size: int):
+def warmup(mimi: MimiModel, lm_gen: LMGen, device: str, frame_size: int):
     """Run a short warmup loop to initialize CUDA graphs and streaming state.
 
     Replicates the same warmup behavior as server.py: zeros → encode → LMGen.step → decode.
     """
     for _ in range(4):
         chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=device)
         codes = mimi.encode(chunk)
-        _ = other_mimi.encode(chunk)
         for c in range(codes.shape[-1]):
             tokens = lm_gen.step(codes[:, :, c : c + 1])
             if tokens is None:
                 continue
             # Decode agent audio channels to ensure decode graphs/states are primed
             _ = mimi.decode(tokens[:, 1:9])
-            _ = other_mimi.decode(tokens[:, 1:9])
     if torch.cuda.is_available():
         torch.cuda.synchronize()
 
 
-def decode_tokens_to_pcm(mimi: MimiModel, other_mimi: MimiModel, lm_gen: LMGen, tokens: torch.Tensor) -> np.ndarray:
+def decode_tokens_to_pcm(mimi: MimiModel, tokens: torch.Tensor) -> np.ndarray:
     """Decode a single step of model tokens to PCM using Mimi.
 
     tokens is shaped [B, dep_q+1, 1]; channels 1..dep_q are the agent audio codebooks.
     Returns a 1D float32 numpy array (mono) for the current frame.
     """
     pcm = mimi.decode(tokens[:, 1:9])
-    _ = other_mimi.decode(tokens[:, 1:9])
     pcm = pcm.detach().cpu().numpy()[0, 0]
     return pcm
 
@@ -191,7 +188,6 @@ def run_inference(
     if mimi_weight is None:
         mimi_weight = hf_hub_download(hf_repo, loaders.MIMI_NAME) # type: ignore
     mimi = loaders.get_mimi(mimi_weight, device)
-    other_mimi = loaders.get_mimi(mimi_weight, device)
     log("info", "mimi loaded")
 
     # 2) Load tokenizer
@@ -224,12 +220,11 @@
     )
     # Keep models in streaming mode similar to the server
     mimi.streaming_forever(1)
-    other_mimi.streaming_forever(1)
     lm_gen.streaming_forever(1)
 
     # 5) Warmup
     log("info", "warming up the model")
-    warmup(mimi, other_mimi, lm_gen, device, frame_size)
+    warmup(mimi, lm_gen, device, frame_size)
 
     # 6) Prompt configuration (text + voice)
     # System text tokens (k=0) and agent voice-prompt audio (k=1..dep_q) are forced
@@ -248,7 +243,6 @@
     # - Text prompt injection
     # - Final audio silence
     mimi.reset_streaming()
-    other_mimi.reset_streaming()
     lm_gen.reset_streaming()
     lm_gen.step_system_prompts(mimi)
     # Reset mimi streaming after voice prompt encoding
@@ -280,7 +274,7 @@ def run_inference(
             if tokens is None:
                 continue
             # Decode current sampled agent frame to PCM
-            pcm = decode_tokens_to_pcm(mimi, other_mimi, lm_gen, tokens)
+            pcm = decode_tokens_to_pcm(mimi, tokens)
             generated_frames.append(pcm)
             # Decode text token
             text_token = tokens[0, 0, 0].item()
@@ -428,4 +422,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
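
The loop these offline.py hunks converge on is: Mimi encodes a fixed-size PCM frame into codebook tokens, LMGen.step consumes one token column per step and returns tokens shaped [B, dep_q+1, 1] (or None while the generator is still priming), and channels 1..dep_q are decoded back to PCM by the same Mimi instance that did the encoding. A minimal sketch of that contract, assuming mimi and lm_gen are already loaded and put in streaming mode as in the diff above; generate_frames is an illustrative helper name, not part of the module:

import numpy as np
import torch

def generate_frames(mimi, lm_gen, pcm_in: torch.Tensor, frame_size: int) -> np.ndarray:
    # pcm_in: [1, 1, T] float32 waveform on the same device as the models.
    out = []
    for start in range(0, pcm_in.shape[-1] - frame_size + 1, frame_size):
        codes = mimi.encode(pcm_in[:, :, start : start + frame_size])
        for c in range(codes.shape[-1]):
            tokens = lm_gen.step(codes[:, :, c : c + 1])
            if tokens is None:
                continue  # generator still priming; nothing to decode yet
            # Channel 0 is the text token; 1..dep_q are the agent audio codebooks.
            pcm = mimi.decode(tokens[:, 1:9])
            out.append(pcm.detach().cpu().numpy()[0, 0])
    return np.concatenate(out) if out else np.zeros(0, dtype=np.float32)
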
12 changes: 1 addition & 11 deletions moshi/moshi/server.py
@@ -89,16 +89,14 @@ def wrap_with_system_tags(text: str) -> str:
 @dataclass
 class ServerState:
     mimi: MimiModel
-    other_mimi: MimiModel
     text_tokenizer: sentencepiece.SentencePieceProcessor
     lm_gen: LMGen
     lock: asyncio.Lock
 
-    def __init__(self, mimi: MimiModel, other_mimi: MimiModel, text_tokenizer: sentencepiece.SentencePieceProcessor,
+    def __init__(self, mimi: MimiModel, text_tokenizer: sentencepiece.SentencePieceProcessor,
                  lm: LMModel, device: str | torch.device, voice_prompt_dir: str | None = None,
                  save_voice_prompt_embeddings: bool = False):
         self.mimi = mimi
-        self.other_mimi = other_mimi
         self.text_tokenizer = text_tokenizer
         self.device = device
         self.voice_prompt_dir = voice_prompt_dir
@@ -113,20 +111,17 @@ def __init__(self, mimi: MimiModel, other_mimi: MimiModel, text_tokenizer: sente
 
         self.lock = asyncio.Lock()
         self.mimi.streaming_forever(1)
-        self.other_mimi.streaming_forever(1)
         self.lm_gen.streaming_forever(1)
 
     def warmup(self):
         for _ in range(4):
             chunk = torch.zeros(1, 1, self.frame_size, dtype=torch.float32, device=self.device)
             codes = self.mimi.encode(chunk)
-            _ = self.other_mimi.encode(chunk)
             for c in range(codes.shape[-1]):
                 tokens = self.lm_gen.step(codes[:, :, c: c + 1])
                 if tokens is None:
                     continue
                 _ = self.mimi.decode(tokens[:, 1:9])
-                _ = self.other_mimi.decode(tokens[:, 1:9])
 
         if self.device.type == 'cuda':
             torch.cuda.synchronize()
@@ -222,14 +217,12 @@ async def opus_loop():
                 chunk = torch.from_numpy(chunk)
                 chunk = chunk.to(device=self.device)[None, None]
                 codes = self.mimi.encode(chunk)
-                _ = self.other_mimi.encode(chunk)
                 for c in range(codes.shape[-1]):
                     tokens = self.lm_gen.step(codes[:, :, c: c + 1])
                     if tokens is None:
                         continue
                     assert tokens.shape[1] == self.lm_gen.lm_model.dep_q + 1
                     main_pcm = self.mimi.decode(tokens[:, 1:9])
-                    _ = self.other_mimi.decode(tokens[:, 1:9])
                     main_pcm = main_pcm.cpu()
                     opus_writer.append_pcm(main_pcm[0, 0].numpy())
                     text_token = tokens[0, 0, 0].item()
@@ -263,7 +256,6 @@ async def send_loop():
         opus_writer = sphn.OpusStreamWriter(self.mimi.sample_rate)
         opus_reader = sphn.OpusStreamReader(self.mimi.sample_rate)
         self.mimi.reset_streaming()
-        self.other_mimi.reset_streaming()
         self.lm_gen.reset_streaming()
         async def is_alive():
             if close or ws.closed:
@@ -432,7 +424,6 @@ def main():
     if args.mimi_weight is None:
         args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME)
     mimi = loaders.get_mimi(args.mimi_weight, args.device)
-    other_mimi = loaders.get_mimi(args.mimi_weight, args.device)
     logger.info("mimi loaded")
 
     if args.tokenizer is None:
@@ -447,7 +438,6 @@
     logger.info("moshi loaded")
     state = ServerState(
         mimi=mimi,
-        other_mimi=other_mimi,
         text_tokenizer=text_tokenizer,
         lm=lm,
         device=args.device,
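
The server.py changes reduce ServerState to the same single-codec contract. A hedged sketch of bringing the state up with the new constructor, mirroring main() above; the MOSHI_NAME and TEXT_TOKENIZER_NAME loader constants and the moshi.server import path are assumptions (only MIMI_NAME appears in this diff), and the repo id is a placeholder:

import sentencepiece
import torch
from huggingface_hub import hf_hub_download

from moshi.models import loaders
from moshi.server import ServerState  # assumed import path for this fork

device = torch.device("cuda")
hf_repo = "kyutai/moshiko-pytorch-bf16"  # placeholder repo id

mimi = loaders.get_mimi(hf_hub_download(hf_repo, loaders.MIMI_NAME), device)
text_tokenizer = sentencepiece.SentencePieceProcessor(
    model_file=hf_hub_download(hf_repo, loaders.TEXT_TOKENIZER_NAME)  # assumed constant
)
lm = loaders.get_moshi_lm(hf_hub_download(hf_repo, loaders.MOSHI_NAME), device)  # assumed constant

state = ServerState(
    mimi=mimi,
    text_tokenizer=text_tokenizer,
    lm=lm,
    device=device,
)
state.warmup()  # zeros -> encode -> LMGen.step -> decode, then cuda.synchronize()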