From e388a9427146497f0fadda3cf1e113efc206f1c3 Mon Sep 17 00:00:00 2001 From: yaofengchen Date: Mon, 1 Dec 2025 07:32:30 +0000 Subject: [PATCH 1/6] refactor ascend op_backend --- .../backends/dlinfer/ascend/op_backend.py | 258 +++++++----------- 1 file changed, 92 insertions(+), 166 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 72bad9f30c..f2a346f4d8 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -21,26 +21,15 @@ class SocVersion: Ascend310P: str = 'Ascend310P' Ascend910: str = 'Ascend910' - - @classmethod - @lru_cache(maxsize=1) - def device_name(cls) -> str: - try: - import torch_npu - return torch_npu.npu.get_device_name() - except ImportError: - logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly. ') - except Exception as e: - logger.warning(f'Error during Ascend get device name: {str(e)}. ' - 'Please check your Ascend environment configuration.') + device_name: str = torch.npu.get_device_name() @classmethod def is_Ascend310P(cls) -> bool: - return cls.device_name().startswith(cls.Ascend310P) + return cls.device_name.startswith(cls.Ascend310P) @classmethod def is_Ascend910(cls) -> bool: - return cls.device_name().startswith(cls.Ascend910) + return cls.device_name.startswith(cls.Ascend910) class AscendKVQuantMeta: @@ -108,12 +97,6 @@ def get_k_block_shape( ) -> Tuple[int, ...]: if SocVersion.is_Ascend910(): return (block_size, num_heads, head_size) - elif SocVersion.is_Ascend310P(): - return ( - (num_heads * head_size + 15) // 16, - block_size, - 16, - ) else: raise ValueError(f'dlinfer does not support {SocVersion.device_name()} device currently.') @@ -126,160 +109,109 @@ def get_v_block_shape( ) -> Tuple[int, ...]: if SocVersion.is_Ascend910(): return (block_size, num_heads, head_size) - elif SocVersion.is_Ascend310P(): - return ( - (num_heads * head_size + 15) // 16, - block_size, - 16, - ) else: raise ValueError(f'dlinfer does not support {SocVersion.device_name()} device currently.') - @classmethod - @lru_cache(maxsize=1) - def enable_aclgraph(cls) -> bool: - if os.getenv('ASCEND_GRAPH_MODE', 'aclgraph') == 'aclgraph' and not SocVersion.is_Ascend310P(): - return True - elif os.getenv('ASCEND_GRAPH_MODE', 'aclgraph') == 'atbgraph' or SocVersion.is_Ascend310P(): - return False - else: - raise ValueError(f"unsupported ASCEND_GRAPH_MODE: {os.getenv('ASCEND_GRAPH_MODE')}") - @classmethod def update_step_context(cls, step_context): """Update step context.""" + kv_start_indices, attention_mask = [], [] + block_num, block_size, *_ = step_context.kv_caches[0][0].shape + is_unpaged_prefill = False + if not step_context.is_decoding: + is_unpaged_prefill = all((step_context.q_seqlens == step_context.kv_seqlens).tolist()) + if step_context.block_offsets.dtype != torch.int32: + step_context.block_offsets = step_context.block_offsets.to(torch.int32) + if step_context.kv_seqlens.dtype != torch.int32: + step_context.kv_seqlens = step_context.kv_seqlens.to(torch.int32) + if step_context.q_seqlens.dtype != torch.int32: + step_context.q_seqlens = step_context.q_seqlens.to(torch.int32) + if cls.enable_graph: + import torch._dynamo as dynamo + dynamo.mark_dynamic(step_context.block_offsets, [0, 1]) + def get_total_slots(): if cls.total_slots is None: cls.total_slots = torch.arange(block_num * block_size, - dtype=torch.long, + dtype=torch.int32, 
device=step_context.block_offsets.device) cls.total_slots = cls.total_slots.view(block_num, block_size) return cls.total_slots - kv_start_indices, attention_mask = [], [] - if SocVersion.is_Ascend910(): - block_num, block_size, *_ = step_context.kv_caches[0][0].shape - elif SocVersion.is_Ascend310P(): - block_num, _, block_size, _ = step_context.kv_caches[0][0].shape - - is_unpaged_prefill = False - if not step_context.is_decoding: - is_unpaged_prefill = \ - all((step_context.q_seqlens == - step_context.kv_seqlens).tolist()) - q_seqlens_list = step_context.q_seqlens.tolist() - kv_seqlens_list = step_context.kv_seqlens.tolist() - max_q_seq_len = max(q_seqlens_list) - max_kv_seq_len = max(kv_seqlens_list) - - if step_context.is_decoding: - # collect kv_start_indices without using a for-loop, - # (fill kv-cache for just ONE token during the decoding phase) - idx = (step_context.kv_seqlens - 1) % block_size - block_num = (step_context.kv_seqlens - 1) // block_size - last_block = step_context.block_offsets.gather(1, block_num.view(-1, 1)).view(-1) - kv_start_indices = last_block * block_size + idx - else: - for i in range(step_context.q_start_loc.size(0)): - q_seq_len = q_seqlens_list[i] - kv_seq_len = kv_seqlens_list[i] - - # collect kv start indices during the prefill phase. - history_length = kv_seq_len - q_seq_len - total_slots = get_total_slots() - slot_tables = total_slots[step_context.block_offsets[i]].view(-1) - slots = slot_tables[history_length:kv_seq_len] - kv_start_indices.append(slots) - - # collect attention mask of paged_prefill attention stage. - if not is_unpaged_prefill: - single_attention_mask = torch.logical_not( - torch.tril( - torch.ones(q_seq_len, - step_context.block_offsets.shape[1] * block_size, - dtype=torch.bool, - device=step_context.block_offsets.device), - diagonal=kv_seq_len - q_seq_len, - )) - attention_mask.append(single_attention_mask) - - kv_start_indices = torch.cat(kv_start_indices) - - if step_context.is_decoding: - # prepare some params of paged_decode attention stage. - q_start_loc_cpu, q_seqlens_cpu = None, None - elif is_unpaged_prefill: - # prepare some params of unpaged_prefill attention stage. 
- q_start_loc_cpu, kv_seqlens_cpu = None, None - q_seqlens_cpu = step_context.q_seqlens.cpu().to(torch.int32) - if SocVersion.is_Ascend910(): - single_attention_mask = torch.logical_not( - torch.tril( - torch.ones(max_q_seq_len, max_kv_seq_len, dtype=torch.bool).cuda(), - diagonal=max_kv_seq_len - max_q_seq_len, - )) - attention_mask.append(single_attention_mask) - elif SocVersion.is_Ascend310P(): - if not cls.enable_graph: - for i in range(q_seqlens_cpu.size(0)): - single_attention_mask = torch.zeros(q_seqlens_cpu[i], - q_seqlens_cpu[i]).fill_(-float('inf')).cuda() - single_attention_mask = torch.triu(single_attention_mask, diagonal=1) - attention_mask.append(single_attention_mask) - else: - # Transdata needs dtype to be float16 or int8 - single_attention_mask = torch.triu( - torch.ones(max_q_seq_len, max_kv_seq_len, dtype=torch.float16).fill_(-float('inf')).cuda(), - diagonal=max_kv_seq_len - max_q_seq_len + 1, - ) - # Convert to NZ format - attention_mask.append(nd_to_nz_spec(single_attention_mask)) + def get_cpu_seqlens(is_decoding, is_unpaged_prefill): + if is_decoding: + q_seqlens_cpu, kv_seqlens_cpu = None, step_context.kv_seqlens.cpu() + elif is_unpaged_prefill: + q_seqlens_cpu = kv_seqlens_cpu = step_context.q_seqlens.cpu() else: - raise ValueError(f"dlinfer doesn't support {SocVersion.device_name()} device currently.") - else: - # prepare some params of paged_prefill attention stage. - q_start_loc_cpu, q_seqlens_cpu = None, None - attention_mask = [torch.cat([mask for mask in attention_mask])] + q_seqlens_cpu = step_context.q_seqlens.cpu() + kv_seqlens_cpu = step_context.kv_seqlens.cpu() + return q_seqlens_cpu, kv_seqlens_cpu + + def get_list_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_cpu=None, kv_seqlens_cpu=None): + if is_decoding: + q_seqlens_list, kv_seqlens_list = None, None + elif is_unpaged_prefill: + q_seqlens_list = kv_seqlens_list = q_seqlens_cpu.tolist() + else: + q_seqlens_list, kv_seqlens_list = q_seqlens_cpu.tolist(), kv_seqlens_cpu.tolist() + return q_seqlens_list, kv_seqlens_list - if cls.enable_graph: - kv_start_indices = kv_start_indices.view(-1).to(torch.int32) - import torch._dynamo as dynamo - if not is_unpaged_prefill: - step_context.block_offsets = step_context.block_offsets.to(torch.int32) - if not step_context.is_decoding: - step_context.block_offsets = step_context.block_offsets\ - .repeat_interleave(step_context.q_seqlens, 0) - dynamo.mark_dynamic(step_context.block_offsets, [0, 1]) - kv_seqlens = step_context.kv_seqlens.cpu().to(torch.int32) - if not step_context.is_decoding: - if is_unpaged_prefill: - if SocVersion.is_Ascend910(): - attention_mask = [mask.half() for mask in attention_mask] - else: - if SocVersion.is_Ascend910(): - attention_mask = [ - torch.cat([mask.half() * cls.half_negative_inf for mask in attention_mask]).unsqueeze(1) - ] - elif SocVersion.is_Ascend310P(): - # Convert mask to NZ format. 
- attention_mask = [ - nd_to_nz_spec(torch.cat([mask.half() * cls.half_negative_inf for mask in attention_mask])) - ] - else: - raise ValueError(f"dlinfer doesn't support {SocVersion.device_name()} device currently.") - kv_seqlens = kv_seqlens.repeat_interleave(step_context.q_seqlens, 0) - else: - if step_context.is_decoding: - kv_seqlens_cpu = step_context.kv_seqlens.cpu().to(torch.int32) + def get_max_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_list=None, kv_seqlens_list=None): + if is_decoding: + max_q_seq_len, max_kv_seq_len = 1, None elif is_unpaged_prefill: - pass + max_q_seq_len = max_kv_seq_len = max(q_seqlens_list) + else: + max_q_seq_len = max(q_seqlens_list) + max_kv_seq_len = max(kv_seqlens_list) + return max_q_seq_len, max_kv_seq_len + + def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_seqlens_list, kv_seqlens_list, max_q_seq_len, max_kv_seq_len): + kv_start_indices, attention_mask = [], [] + if is_decoding: + idx = (step_context.kv_seqlens - 1) % block_size + block_num = (step_context.kv_seqlens - 1) // block_size + last_block = step_context.block_offsets.gather(1, block_num.view(-1, 1)).view(-1) + kv_start_indices = last_block * block_size + idx else: - kv_seqlens_cpu = step_context.kv_seqlens.repeat_interleave(step_context.q_seqlens, 0).cpu() - block_offsets_int32 = step_context.block_offsets.to(torch.int32) - step_context.block_offsets = block_offsets_int32\ - .repeat_interleave(step_context.q_seqlens, 0) - kv_seqlens = kv_seqlens_cpu + for i in range(step_context.q_start_loc.size(0)): + q_seq_len = q_seqlens_list[i] + kv_seq_len = kv_seqlens_list[i] + + history_length = kv_seq_len - q_seq_len + total_slots = get_total_slots() + slot_tables = total_slots[step_context.block_offsets[i]].view(-1) + slots = slot_tables[history_length:kv_seq_len] + kv_start_indices.append(slots) + + if not is_unpaged_prefill: + single_attention_mask = torch.logical_not( + torch.tril( + torch.ones(q_seq_len, + step_context.block_offsets.shape[1] * block_size, + dtype=step_context.kv_caches[0][0].dtype, + device=step_context.block_offsets.device), + diagonal=kv_seq_len - q_seq_len, + )) + attention_mask.append(single_attention_mask) + + if is_unpaged_prefill: + attention_mask.append(torch.logical_not( + torch.tril( + torch.ones(max_q_seq_len, max_kv_seq_len, dtype=step_context.kv_caches[0][0].dtype, device=step_context.block_offsets.device), + diagonal=max_kv_seq_len - max_q_seq_len, + ))) + + kv_start_indices = torch.cat(kv_start_indices) + + return kv_start_indices, attention_mask + + q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(step_context.is_decoding, is_unpaged_prefill) + q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_cpu, kv_seqlens_cpu) + max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_list, kv_seqlens_list) + kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(step_context.is_decoding, is_unpaged_prefill, q_seqlens_list, kv_seqlens_list, max_q_seq_len, max_kv_seq_len) if not cls.enable_graph and step_context.kv_quant_policy == 8: record_file = os.getenv('ASCEND_QUANT_RECORD_FILE') @@ -298,9 +230,9 @@ def get_total_slots(): attn_metadata = attn_meta_cls( step_context.is_decoding, step_context.block_offsets, - q_start_loc=q_start_loc_cpu, + q_start_loc=None, q_seqlens=q_seqlens_cpu, - kv_seqlens=kv_seqlens, + kv_seqlens=kv_seqlens_cpu, kv_start_indices=kv_start_indices, block_size=block_size, 
attention_mask=attention_mask, @@ -318,14 +250,8 @@ def get_total_slots(): def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig, backend_config: BackendConfig, device: torch.device): """Build graph runner.""" - if AscendOpsBackend.enable_aclgraph(): - from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner - return CUDAGraphRunner(model, model_config, cache_config, backend_config, device) - else: - from .graph_runner import AscendGraphRunner - ascend_graph_runner = AscendGraphRunner(model, model_config, cache_config, backend_config, device) - AscendOpsBackend.enable_graph = ascend_graph_runner.enable_graph - return ascend_graph_runner + from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner + return CUDAGraphRunner(model, model_config, cache_config, backend_config, device) @staticmethod def init(): From 080dc120e83358a0cb6b3ec115e82583c1fee836 Mon Sep 17 00:00:00 2001 From: yaofengchen Date: Tue, 2 Dec 2025 07:56:16 +0000 Subject: [PATCH 2/6] refactor mask --- .../backends/dlinfer/ascend/op_backend.py | 42 +++++++++++-------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index f2a346f4d8..ea221cb3e2 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -2,7 +2,6 @@ import itertools import os import re -from functools import lru_cache from pathlib import Path from typing import Dict, Tuple @@ -13,7 +12,6 @@ from lmdeploy.utils import get_logger from ..op_backend import DlinferOpsBackend -from .utils import nd_to_nz_spec logger = get_logger('lmdeploy') @@ -123,6 +121,8 @@ def update_step_context(cls, step_context): is_unpaged_prefill = all((step_context.q_seqlens == step_context.kv_seqlens).tolist()) if step_context.block_offsets.dtype != torch.int32: step_context.block_offsets = step_context.block_offsets.to(torch.int32) + if not (step_context.is_decoding or is_unpaged_prefill): + step_context.block_offsets = step_context.block_offsets.repeat_interleave(step_context.q_seqlens, 0) if step_context.kv_seqlens.dtype != torch.int32: step_context.kv_seqlens = step_context.kv_seqlens.to(torch.int32) if step_context.q_seqlens.dtype != torch.int32: @@ -148,7 +148,7 @@ def get_cpu_seqlens(is_decoding, is_unpaged_prefill): q_seqlens_cpu = step_context.q_seqlens.cpu() kv_seqlens_cpu = step_context.kv_seqlens.cpu() return q_seqlens_cpu, kv_seqlens_cpu - + def get_list_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_cpu=None, kv_seqlens_cpu=None): if is_decoding: q_seqlens_list, kv_seqlens_list = None, None @@ -168,7 +168,8 @@ def get_max_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_list=None, kv_seq max_kv_seq_len = max(kv_seqlens_list) return max_q_seq_len, max_kv_seq_len - def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_seqlens_list, kv_seqlens_list, max_q_seq_len, max_kv_seq_len): + def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_seqlens_list, kv_seqlens_list, + max_q_seq_len, max_kv_seq_len): kv_start_indices, attention_mask = [], [] if is_decoding: idx = (step_context.kv_seqlens - 1) % block_size @@ -187,31 +188,38 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_s kv_start_indices.append(slots) if not is_unpaged_prefill: - single_attention_mask = torch.logical_not( - torch.tril( + single_attention_mask = 
torch.triu( torch.ones(q_seq_len, step_context.block_offsets.shape[1] * block_size, - dtype=step_context.kv_caches[0][0].dtype, + dtype=torch.bool, device=step_context.block_offsets.device), - diagonal=kv_seq_len - q_seq_len, - )) + diagonal=kv_seq_len - q_seq_len + 1, + ) attention_mask.append(single_attention_mask) if is_unpaged_prefill: - attention_mask.append(torch.logical_not( - torch.tril( - torch.ones(max_q_seq_len, max_kv_seq_len, dtype=step_context.kv_caches[0][0].dtype, device=step_context.block_offsets.device), - diagonal=max_kv_seq_len - max_q_seq_len, - ))) + attention_mask.append(torch.triu( + torch.ones(max_q_seq_len, + max_kv_seq_len, + dtype=step_context.kv_caches[0][0].dtype, + device=step_context.block_offsets.device), + diagonal=max_kv_seq_len - max_q_seq_len + 1)) + else: + attention_mask = [torch.cat(attention_mask)] kv_start_indices = torch.cat(kv_start_indices) return kv_start_indices, attention_mask q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(step_context.is_decoding, is_unpaged_prefill) - q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_cpu, kv_seqlens_cpu) - max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_list, kv_seqlens_list) - kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(step_context.is_decoding, is_unpaged_prefill, q_seqlens_list, kv_seqlens_list, max_q_seq_len, max_kv_seq_len) + q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_cpu, + kv_seqlens_cpu) + max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_list, + kv_seqlens_list) + kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(step_context.is_decoding, + is_unpaged_prefill, q_seqlens_list, + kv_seqlens_list, max_q_seq_len, + max_kv_seq_len) if not cls.enable_graph and step_context.kv_quant_policy == 8: record_file = os.getenv('ASCEND_QUANT_RECORD_FILE') From 0840a8272d00bc35320567269f766880c845efe8 Mon Sep 17 00:00:00 2001 From: yaofengchen Date: Tue, 2 Dec 2025 08:09:00 +0000 Subject: [PATCH 3/6] format code --- .../backends/dlinfer/ascend/op_backend.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index ea221cb3e2..3fa084b777 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -189,21 +189,21 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_s if not is_unpaged_prefill: single_attention_mask = torch.triu( - torch.ones(q_seq_len, - step_context.block_offsets.shape[1] * block_size, - dtype=torch.bool, - device=step_context.block_offsets.device), - diagonal=kv_seq_len - q_seq_len + 1, - ) + torch.ones(q_seq_len, + step_context.block_offsets.shape[1] * block_size, + dtype=torch.bool, + device=step_context.block_offsets.device), + diagonal=kv_seq_len - q_seq_len + 1, + ) attention_mask.append(single_attention_mask) if is_unpaged_prefill: - attention_mask.append(torch.triu( - torch.ones(max_q_seq_len, - max_kv_seq_len, - dtype=step_context.kv_caches[0][0].dtype, - device=step_context.block_offsets.device), - diagonal=max_kv_seq_len - max_q_seq_len + 1)) + attention_mask.append( + torch.triu(torch.ones(max_q_seq_len, + max_kv_seq_len, + 
dtype=step_context.kv_caches[0][0].dtype, + device=step_context.block_offsets.device), + diagonal=max_kv_seq_len - max_q_seq_len + 1)) else: attention_mask = [torch.cat(attention_mask)] From 24cb83288b2b61c4c72f5c6ce15cc7e222978c5d Mon Sep 17 00:00:00 2001 From: yaofengchen Date: Tue, 2 Dec 2025 08:29:58 +0000 Subject: [PATCH 4/6] remove 310P judge --- lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py | 6 ------ lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py | 3 --- lmdeploy/pytorch/backends/dlinfer/linear.py | 8 -------- 3 files changed, 17 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py b/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py index 8511ee9c9f..9bb328269a 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py @@ -49,12 +49,6 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf backend='atbgraph') else: self.model = torch.compile(self.model, fullgraph=True, dynamic=True, backend='atbgraph') - if SocVersion.is_Ascend310P() and hasattr(self.model, 'get_logits'): - # Compile get_logits for Ascend310P to use ATB linear since we would convert weight to nz format - self.model.get_logits = torch.compile(self.model.get_logits, - fullgraph=True, - dynamic=True, - backend='atbgraph') def check_enable_graph(self): """Check enable graph.""" diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 3fa084b777..3726f210c4 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -266,9 +266,6 @@ def init(): """Initialize Ascend backend.""" try: from torch_npu.contrib import transfer_to_npu # noqa: F401 - if SocVersion.is_Ascend310P(): - # NOTE: Ascend310P has a bug with InternVL vision embedding using interpolate. - torch.npu.set_compile_mode(jit_compile=False) except ImportError: logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly. ' 'Ascend initialization skipped.') diff --git a/lmdeploy/pytorch/backends/dlinfer/linear.py b/lmdeploy/pytorch/backends/dlinfer/linear.py index ec682bba8b..fbe717f5c2 100644 --- a/lmdeploy/pytorch/backends/dlinfer/linear.py +++ b/lmdeploy/pytorch/backends/dlinfer/linear.py @@ -17,14 +17,6 @@ def update_weights(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = No """Update weights.""" if os.getenv('DLINFER_LINEAR_USE_NN_LAYOUT', '0') == '1': weight = weight.data.t().contiguous() - if weight.device.type == 'npu': - from .ascend import SocVersion - if SocVersion.is_Ascend310P() and not os.getenv('DLINFER_DISABLE_LINEAR_NZ_FORMAT', '0') == '1': - # Ascend 310P device need weight to be NZ format, so Transdata it initially. - # Transdata Linear weight by default, if Error occurs, please set - # DLINFER_DISABLE_LINEAR_NZ_FORMAT=1 to disable transdata. 
- from .ascend.utils import nd_to_nz_spec - weight = nd_to_nz_spec(weight) return weight, bias def forward(self, From 9a9604a8b3e10fa555a72e34245e5883fc6808a7 Mon Sep 17 00:00:00 2001 From: yaofengchen Date: Thu, 4 Dec 2025 06:32:24 +0000 Subject: [PATCH 5/6] remove unused code --- lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 3726f210c4..ace236d908 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -127,9 +127,6 @@ def update_step_context(cls, step_context): step_context.kv_seqlens = step_context.kv_seqlens.to(torch.int32) if step_context.q_seqlens.dtype != torch.int32: step_context.q_seqlens = step_context.q_seqlens.to(torch.int32) - if cls.enable_graph: - import torch._dynamo as dynamo - dynamo.mark_dynamic(step_context.block_offsets, [0, 1]) def get_total_slots(): if cls.total_slots is None: From bd877febe926c3962cf6aa8f09b2e949839d0695 Mon Sep 17 00:00:00 2001 From: yaofengchen Date: Thu, 4 Dec 2025 09:15:30 +0000 Subject: [PATCH 6/6] update code --- .../backends/dlinfer/ascend/op_backend.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index ace236d908..7e95a16d35 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -2,6 +2,7 @@ import itertools import os import re +from functools import lru_cache from pathlib import Path from typing import Dict, Tuple @@ -19,15 +20,25 @@ class SocVersion: Ascend310P: str = 'Ascend310P' Ascend910: str = 'Ascend910' - device_name: str = torch.npu.get_device_name() + + @classmethod + @lru_cache(maxsize=1) + def device_name(cls) -> str: + try: + return torch.npu.get_device_name() + except ImportError: + logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly.') + except Exception as e: + logger.warning(f'Error during Ascend get device name: {str(e)}. ' + 'Please check your Ascend environment configuration.') @classmethod def is_Ascend310P(cls) -> bool: - return cls.device_name.startswith(cls.Ascend310P) + return cls.device_name().startswith(cls.Ascend310P) @classmethod def is_Ascend910(cls) -> bool: - return cls.device_name.startswith(cls.Ascend910) + return cls.device_name().startswith(cls.Ascend910) class AscendKVQuantMeta: @@ -114,7 +125,6 @@ def get_v_block_shape( def update_step_context(cls, step_context): """Update step context.""" - kv_start_indices, attention_mask = [], [] block_num, block_size, *_ = step_context.kv_caches[0][0].shape is_unpaged_prefill = False if not step_context.is_decoding:
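
A minimal standalone sketch (illustrative only, not part of the patch series) showing that the triu-based causal mask introduced in PATCH 2/6 "refactor mask" produces the same pattern as the logical_not(tril(...)) construction it replaces. The sequence lengths below are hypothetical and neither dlinfer nor torch_npu is needed to run it:

    import torch

    q_seq_len, kv_seq_len = 3, 7          # hypothetical prefill lengths
    history = kv_seq_len - q_seq_len      # tokens already present in the kv cache

    # old form: True marks positions a query token must NOT attend to
    old_mask = torch.logical_not(
        torch.tril(torch.ones(q_seq_len, kv_seq_len, dtype=torch.bool),
                   diagonal=history))

    # new form after the refactor: the same mask, built directly with triu
    new_mask = torch.triu(torch.ones(q_seq_len, kv_seq_len, dtype=torch.bool),
                          diagonal=history + 1)

    assert torch.equal(old_mask, new_mask)

The equivalence holds for any mask width (the paged-prefill path uses block_offsets.shape[1] * block_size columns instead of kv_seq_len), since tril(..., diagonal=d) keeps j <= i + d and its negation is exactly triu(..., diagonal=d + 1).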