diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py b/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py
index 8511ee9c9f..9bb328269a 100644
--- a/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py
+++ b/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py
@@ -49,12 +49,6 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf
                                                           backend='atbgraph')
             else:
                 self.model = torch.compile(self.model, fullgraph=True, dynamic=True, backend='atbgraph')
-            if SocVersion.is_Ascend310P() and hasattr(self.model, 'get_logits'):
-                # Compile get_logits for Ascend310P to use ATB linear since we would convert weight to nz format
-                self.model.get_logits = torch.compile(self.model.get_logits,
-                                                      fullgraph=True,
-                                                      dynamic=True,
-                                                      backend='atbgraph')
 
     def check_enable_graph(self):
         """Check enable graph."""
diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
index 72bad9f30c..7e95a16d35 100644
--- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
+++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
@@ -13,7 +13,6 @@
 from lmdeploy.utils import get_logger
 
 from ..op_backend import DlinferOpsBackend
-from .utils import nd_to_nz_spec
 
 logger = get_logger('lmdeploy')
 
@@ -26,10 +25,9 @@ class SocVersion:
     @lru_cache(maxsize=1)
     def device_name(cls) -> str:
         try:
-            import torch_npu
-            return torch_npu.npu.get_device_name()
+            return torch.npu.get_device_name()
         except ImportError:
-            logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly. ')
+            logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly.')
         except Exception as e:
             logger.warning(f'Error during Ascend get device name: {str(e)}. '
                            'Please check your Ascend environment configuration.')
@@ -108,12 +106,6 @@ def get_k_block_shape(
     ) -> Tuple[int, ...]:
         if SocVersion.is_Ascend910():
             return (block_size, num_heads, head_size)
-        elif SocVersion.is_Ascend310P():
-            return (
-                (num_heads * head_size + 15) // 16,
-                block_size,
-                16,
-            )
         else:
             raise ValueError(f'dlinfer does not support {SocVersion.device_name()} device currently.')
 
@@ -126,160 +118,115 @@ def get_v_block_shape(
     ) -> Tuple[int, ...]:
         if SocVersion.is_Ascend910():
             return (block_size, num_heads, head_size)
-        elif SocVersion.is_Ascend310P():
-            return (
-                (num_heads * head_size + 15) // 16,
-                block_size,
-                16,
-            )
         else:
             raise ValueError(f'dlinfer does not support {SocVersion.device_name()} device currently.')
 
-    @classmethod
-    @lru_cache(maxsize=1)
-    def enable_aclgraph(cls) -> bool:
-        if os.getenv('ASCEND_GRAPH_MODE', 'aclgraph') == 'aclgraph' and not SocVersion.is_Ascend310P():
-            return True
-        elif os.getenv('ASCEND_GRAPH_MODE', 'aclgraph') == 'atbgraph' or SocVersion.is_Ascend310P():
-            return False
-        else:
-            raise ValueError(f"unsupported ASCEND_GRAPH_MODE: {os.getenv('ASCEND_GRAPH_MODE')}")
-
     @classmethod
     def update_step_context(cls, step_context):
         """Update step context."""
 
+        block_num, block_size, *_ = step_context.kv_caches[0][0].shape
+        is_unpaged_prefill = False
+        if not step_context.is_decoding:
+            is_unpaged_prefill = all((step_context.q_seqlens == step_context.kv_seqlens).tolist())
+        if step_context.block_offsets.dtype != torch.int32:
+            step_context.block_offsets = step_context.block_offsets.to(torch.int32)
+        if not (step_context.is_decoding or is_unpaged_prefill):
+            step_context.block_offsets = step_context.block_offsets.repeat_interleave(step_context.q_seqlens, 0)
+        if step_context.kv_seqlens.dtype != torch.int32:
+            step_context.kv_seqlens = step_context.kv_seqlens.to(torch.int32)
+        if step_context.q_seqlens.dtype != torch.int32:
+            step_context.q_seqlens = step_context.q_seqlens.to(torch.int32)
+
         def get_total_slots():
             if cls.total_slots is None:
                 cls.total_slots = torch.arange(block_num * block_size,
-                                               dtype=torch.long,
+                                               dtype=torch.int32,
                                                device=step_context.block_offsets.device)
                 cls.total_slots = cls.total_slots.view(block_num, block_size)
             return cls.total_slots
 
-        kv_start_indices, attention_mask = [], []
-        if SocVersion.is_Ascend910():
-            block_num, block_size, *_ = step_context.kv_caches[0][0].shape
-        elif SocVersion.is_Ascend310P():
-            block_num, _, block_size, _ = step_context.kv_caches[0][0].shape
+        def get_cpu_seqlens(is_decoding, is_unpaged_prefill):
+            if is_decoding:
+                q_seqlens_cpu, kv_seqlens_cpu = None, step_context.kv_seqlens.cpu()
+            elif is_unpaged_prefill:
+                q_seqlens_cpu = kv_seqlens_cpu = step_context.q_seqlens.cpu()
+            else:
+                q_seqlens_cpu = step_context.q_seqlens.cpu()
+                kv_seqlens_cpu = step_context.kv_seqlens.cpu()
+            return q_seqlens_cpu, kv_seqlens_cpu
 
-        is_unpaged_prefill = False
-        if not step_context.is_decoding:
-            is_unpaged_prefill = \
-                all((step_context.q_seqlens ==
-                     step_context.kv_seqlens).tolist())
-        q_seqlens_list = step_context.q_seqlens.tolist()
-        kv_seqlens_list = step_context.kv_seqlens.tolist()
-        max_q_seq_len = max(q_seqlens_list)
-        max_kv_seq_len = max(kv_seqlens_list)
-
-        if step_context.is_decoding:
-            # collect kv_start_indices without using a for-loop,
-            # (fill kv-cache for just ONE token during the decoding phase)
-            idx = (step_context.kv_seqlens - 1) % block_size
-            block_num = (step_context.kv_seqlens - 1) // block_size
-            last_block = step_context.block_offsets.gather(1, block_num.view(-1, 1)).view(-1)
-            kv_start_indices = last_block * block_size + idx
-        else:
-            for i in range(step_context.q_start_loc.size(0)):
-                q_seq_len = q_seqlens_list[i]
-                kv_seq_len = kv_seqlens_list[i]
-
-                # collect kv start indices during the prefill phase.
-                history_length = kv_seq_len - q_seq_len
-                total_slots = get_total_slots()
-                slot_tables = total_slots[step_context.block_offsets[i]].view(-1)
-                slots = slot_tables[history_length:kv_seq_len]
-                kv_start_indices.append(slots)
-
-                # collect attention mask of paged_prefill attention stage.
-                if not is_unpaged_prefill:
-                    single_attention_mask = torch.logical_not(
-                        torch.tril(
+        def get_list_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_cpu=None, kv_seqlens_cpu=None):
+            if is_decoding:
+                q_seqlens_list, kv_seqlens_list = None, None
+            elif is_unpaged_prefill:
+                q_seqlens_list = kv_seqlens_list = q_seqlens_cpu.tolist()
+            else:
+                q_seqlens_list, kv_seqlens_list = q_seqlens_cpu.tolist(), kv_seqlens_cpu.tolist()
+            return q_seqlens_list, kv_seqlens_list
+
+        def get_max_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_list=None, kv_seqlens_list=None):
+            if is_decoding:
+                max_q_seq_len, max_kv_seq_len = 1, None
+            elif is_unpaged_prefill:
+                max_q_seq_len = max_kv_seq_len = max(q_seqlens_list)
+            else:
+                max_q_seq_len = max(q_seqlens_list)
+                max_kv_seq_len = max(kv_seqlens_list)
+            return max_q_seq_len, max_kv_seq_len
+
+        def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_seqlens_list, kv_seqlens_list,
+                                                    max_q_seq_len, max_kv_seq_len):
+            kv_start_indices, attention_mask = [], []
+            if is_decoding:
+                idx = (step_context.kv_seqlens - 1) % block_size
+                block_num = (step_context.kv_seqlens - 1) // block_size
+                last_block = step_context.block_offsets.gather(1, block_num.view(-1, 1)).view(-1)
+                kv_start_indices = last_block * block_size + idx
+            else:
+                for i in range(step_context.q_start_loc.size(0)):
+                    q_seq_len = q_seqlens_list[i]
+                    kv_seq_len = kv_seqlens_list[i]
+
+                    history_length = kv_seq_len - q_seq_len
+                    total_slots = get_total_slots()
+                    slot_tables = total_slots[step_context.block_offsets[i]].view(-1)
+                    slots = slot_tables[history_length:kv_seq_len]
+                    kv_start_indices.append(slots)
+
+                    if not is_unpaged_prefill:
+                        single_attention_mask = torch.triu(
                             torch.ones(q_seq_len,
                                        step_context.block_offsets.shape[1] * block_size,
                                        dtype=torch.bool,
                                        device=step_context.block_offsets.device),
-                            diagonal=kv_seq_len - q_seq_len,
-                        ))
-                    attention_mask.append(single_attention_mask)
-
-            kv_start_indices = torch.cat(kv_start_indices)
-
-        if step_context.is_decoding:
-            # prepare some params of paged_decode attention stage.
-            q_start_loc_cpu, q_seqlens_cpu = None, None
-        elif is_unpaged_prefill:
-            # prepare some params of unpaged_prefill attention stage.
-            q_start_loc_cpu, kv_seqlens_cpu = None, None
-            q_seqlens_cpu = step_context.q_seqlens.cpu().to(torch.int32)
-            if SocVersion.is_Ascend910():
-                single_attention_mask = torch.logical_not(
-                    torch.tril(
-                        torch.ones(max_q_seq_len, max_kv_seq_len, dtype=torch.bool).cuda(),
-                        diagonal=max_kv_seq_len - max_q_seq_len,
-                    ))
-                attention_mask.append(single_attention_mask)
-            elif SocVersion.is_Ascend310P():
-                if not cls.enable_graph:
-                    for i in range(q_seqlens_cpu.size(0)):
-                        single_attention_mask = torch.zeros(q_seqlens_cpu[i],
-                                                            q_seqlens_cpu[i]).fill_(-float('inf')).cuda()
-                        single_attention_mask = torch.triu(single_attention_mask, diagonal=1)
+                            diagonal=kv_seq_len - q_seq_len + 1,
+                        )
                         attention_mask.append(single_attention_mask)
-                else:
-                    # Transdata needs dtype to be float16 or int8
-                    single_attention_mask = torch.triu(
-                        torch.ones(max_q_seq_len, max_kv_seq_len, dtype=torch.float16).fill_(-float('inf')).cuda(),
-                        diagonal=max_kv_seq_len - max_q_seq_len + 1,
-                    )
-                    # Convert to NZ format
-                    attention_mask.append(nd_to_nz_spec(single_attention_mask))
-            else:
-                raise ValueError(f"dlinfer doesn't support {SocVersion.device_name()} device currently.")
-        else:
-            # prepare some params of paged_prefill attention stage.
-            q_start_loc_cpu, q_seqlens_cpu = None, None
-            attention_mask = [torch.cat([mask for mask in attention_mask])]
-
-        if cls.enable_graph:
-            kv_start_indices = kv_start_indices.view(-1).to(torch.int32)
-            import torch._dynamo as dynamo
-            if not is_unpaged_prefill:
-                step_context.block_offsets = step_context.block_offsets.to(torch.int32)
-                if not step_context.is_decoding:
-                    step_context.block_offsets = step_context.block_offsets\
-                        .repeat_interleave(step_context.q_seqlens, 0)
-                dynamo.mark_dynamic(step_context.block_offsets, [0, 1])
-            kv_seqlens = step_context.kv_seqlens.cpu().to(torch.int32)
-            if not step_context.is_decoding:
+
                 if is_unpaged_prefill:
-                    if SocVersion.is_Ascend910():
-                        attention_mask = [mask.half() for mask in attention_mask]
+                    attention_mask.append(
+                        torch.triu(torch.ones(max_q_seq_len,
+                                              max_kv_seq_len,
+                                              dtype=step_context.kv_caches[0][0].dtype,
+                                              device=step_context.block_offsets.device),
+                                   diagonal=max_kv_seq_len - max_q_seq_len + 1))
                 else:
-                    if SocVersion.is_Ascend910():
-                        attention_mask = [
-                            torch.cat([mask.half() * cls.half_negative_inf for mask in attention_mask]).unsqueeze(1)
-                        ]
-                    elif SocVersion.is_Ascend310P():
-                        # Convert mask to NZ format.
-                        attention_mask = [
-                            nd_to_nz_spec(torch.cat([mask.half() * cls.half_negative_inf for mask in attention_mask]))
-                        ]
-                    else:
-                        raise ValueError(f"dlinfer doesn't support {SocVersion.device_name()} device currently.")
-                kv_seqlens = kv_seqlens.repeat_interleave(step_context.q_seqlens, 0)
-        else:
-            if step_context.is_decoding:
-                kv_seqlens_cpu = step_context.kv_seqlens.cpu().to(torch.int32)
-            elif is_unpaged_prefill:
-                pass
-            else:
-                kv_seqlens_cpu = step_context.kv_seqlens.repeat_interleave(step_context.q_seqlens, 0).cpu()
-                block_offsets_int32 = step_context.block_offsets.to(torch.int32)
-                step_context.block_offsets = block_offsets_int32\
-                    .repeat_interleave(step_context.q_seqlens, 0)
-            kv_seqlens = kv_seqlens_cpu
+                    attention_mask = [torch.cat(attention_mask)]
+
+                kv_start_indices = torch.cat(kv_start_indices)
+
+            return kv_start_indices, attention_mask
+
+        q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(step_context.is_decoding, is_unpaged_prefill)
+        q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_cpu,
+                                                           kv_seqlens_cpu)
+        max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_list,
+                                                        kv_seqlens_list)
+        kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(step_context.is_decoding,
+                                                                                   is_unpaged_prefill, q_seqlens_list,
+                                                                                   kv_seqlens_list, max_q_seq_len,
+                                                                                   max_kv_seq_len)
 
         if not cls.enable_graph and step_context.kv_quant_policy == 8:
             record_file = os.getenv('ASCEND_QUANT_RECORD_FILE')
@@ -298,9 +245,9 @@ def get_total_slots():
         attn_metadata = attn_meta_cls(
             step_context.is_decoding,
             step_context.block_offsets,
-            q_start_loc=q_start_loc_cpu,
+            q_start_loc=None,
             q_seqlens=q_seqlens_cpu,
-            kv_seqlens=kv_seqlens,
+            kv_seqlens=kv_seqlens_cpu,
             kv_start_indices=kv_start_indices,
             block_size=block_size,
             attention_mask=attention_mask,
@@ -318,23 +265,14 @@ def get_total_slots():
     def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig,
                            backend_config: BackendConfig, device: torch.device):
         """Build graph runner."""
-        if AscendOpsBackend.enable_aclgraph():
-            from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner
-            return CUDAGraphRunner(model, model_config, cache_config, backend_config, device)
-        else:
-            from .graph_runner import AscendGraphRunner
-            ascend_graph_runner = AscendGraphRunner(model, model_config, cache_config, backend_config, device)
-            AscendOpsBackend.enable_graph = ascend_graph_runner.enable_graph
-            return ascend_graph_runner
+        from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner
+        return CUDAGraphRunner(model, model_config, cache_config, backend_config, device)
 
     @staticmethod
     def init():
         """Initialize Ascend backend."""
         try:
             from torch_npu.contrib import transfer_to_npu  # noqa: F401
-            if SocVersion.is_Ascend310P():
-                # NOTE: Ascend310P has a bug with InternVL vision embedding using interpolate.
-                torch.npu.set_compile_mode(jit_compile=False)
         except ImportError:
             logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly. '
                            'Ascend initialization skipped.')
diff --git a/lmdeploy/pytorch/backends/dlinfer/linear.py b/lmdeploy/pytorch/backends/dlinfer/linear.py
index ec682bba8b..fbe717f5c2 100644
--- a/lmdeploy/pytorch/backends/dlinfer/linear.py
+++ b/lmdeploy/pytorch/backends/dlinfer/linear.py
@@ -17,14 +17,6 @@ def update_weights(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = No
         """Update weights."""
         if os.getenv('DLINFER_LINEAR_USE_NN_LAYOUT', '0') == '1':
             weight = weight.data.t().contiguous()
-        if weight.device.type == 'npu':
-            from .ascend import SocVersion
-            if SocVersion.is_Ascend310P() and not os.getenv('DLINFER_DISABLE_LINEAR_NZ_FORMAT', '0') == '1':
-                # Ascend 310P device need weight to be NZ format, so Transdata it initially.
-                # Transdata Linear weight by default, if Error occurs, please set
-                # DLINFER_DISABLE_LINEAR_NZ_FORMAT=1 to disable transdata.
-                from .ascend.utils import nd_to_nz_spec
-                weight = nd_to_nz_spec(weight)
         return weight, bias
 
     def forward(self,