6 changes: 0 additions & 6 deletions lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py
@@ -49,12 +49,6 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf
backend='atbgraph')
else:
self.model = torch.compile(self.model, fullgraph=True, dynamic=True, backend='atbgraph')
if SocVersion.is_Ascend310P() and hasattr(self.model, 'get_logits'):
# Compile get_logits for Ascend310P to use ATB linear since we would convert weight to nz format
self.model.get_logits = torch.compile(self.model.get_logits,
fullgraph=True,
dynamic=True,
backend='atbgraph')

def check_enable_graph(self):
"""Check enable graph."""
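Note on the removed hunk above: it special-cased Ascend310P by compiling only the model's `get_logits` method, so the logits projection could go through ATB linear with NZ-format weights. For reference, a minimal sketch of compiling a single bound method with `torch.compile`; the toy model and the default backend below are illustrative stand-ins, not the project's code:

```python
# Sketch: compile only one method of a model, leaving the rest eager.
# TinyLM and the default (inductor) backend stand in for the removed
# atbgraph-compiled get_logits path.
import torch


class TinyLM(torch.nn.Module):

    def __init__(self, hidden: int = 16, vocab: int = 32):
        super().__init__()
        self.body = torch.nn.Linear(hidden, hidden)
        self.lm_head = torch.nn.Linear(hidden, vocab)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)

    def get_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.lm_head(hidden_states)


model = TinyLM()
# Only the logits projection is traced and compiled; forward stays eager.
model.get_logits = torch.compile(model.get_logits, fullgraph=True, dynamic=True)
logits = model.get_logits(model(torch.randn(2, 16)))
```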
262 changes: 95 additions & 167 deletions lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
@@ -2,7 +2,6 @@
import itertools
import os
import re
from functools import lru_cache
from pathlib import Path
from typing import Dict, Tuple

@@ -13,34 +12,22 @@
from lmdeploy.utils import get_logger

from ..op_backend import DlinferOpsBackend
from .utils import nd_to_nz_spec

logger = get_logger('lmdeploy')


class SocVersion:
Ascend310P: str = 'Ascend310P'
Ascend910: str = 'Ascend910'

@classmethod
@lru_cache(maxsize=1)
def device_name(cls) -> str:
try:
import torch_npu
return torch_npu.npu.get_device_name()
except ImportError:
logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly. ')
except Exception as e:
logger.warning(f'Error during Ascend get device name: {str(e)}. '
'Please check your Ascend environment configuration.')
device_name: str = torch.npu.get_device_name()

@classmethod
def is_Ascend310P(cls) -> bool:
return cls.device_name().startswith(cls.Ascend310P)
return cls.device_name.startswith(cls.Ascend310P)

@classmethod
def is_Ascend910(cls) -> bool:
return cls.device_name().startswith(cls.Ascend910)
return cls.device_name.startswith(cls.Ascend910)
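The change above replaces the lazily cached `device_name()` classmethod with a plain class attribute, so the NPU probe now runs once at import time and failures surface immediately instead of being logged and swallowed. A minimal sketch of the two patterns, with an illustrative probe standing in for `torch_npu.npu.get_device_name()`:

```python
# Sketch of the two caching styles swapped by this hunk. probe_device() is a
# made-up stand-in for torch_npu.npu.get_device_name().
from functools import lru_cache


def probe_device() -> str:
    return 'Ascend910B'


class LazySoc:
    """Removed pattern: probe on first call, cache the result, handle errors."""

    @classmethod
    @lru_cache(maxsize=1)
    def device_name(cls) -> str:
        return probe_device()


class EagerSoc:
    """New pattern: probe once when the module is imported."""
    device_name: str = probe_device()


assert LazySoc.device_name() == EagerSoc.device_name
```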


class AscendKVQuantMeta:
@@ -108,12 +95,6 @@ def get_k_block_shape(
) -> Tuple[int, ...]:
if SocVersion.is_Ascend910():
return (block_size, num_heads, head_size)
elif SocVersion.is_Ascend310P():
return (
(num_heads * head_size + 15) // 16,
block_size,
16,
)
else:
raise ValueError(f'dlinfer does not support {SocVersion.device_name()} device currently.')

@@ -126,160 +107,116 @@
) -> Tuple[int, ...]:
if SocVersion.is_Ascend910():
return (block_size, num_heads, head_size)
elif SocVersion.is_Ascend310P():
return (
(num_heads * head_size + 15) // 16,
block_size,
16,
)
else:
raise ValueError(f'dlinfer does not support {SocVersion.device_name()} device currently.')

@classmethod
@lru_cache(maxsize=1)
def enable_aclgraph(cls) -> bool:
if os.getenv('ASCEND_GRAPH_MODE', 'aclgraph') == 'aclgraph' and not SocVersion.is_Ascend310P():
return True
elif os.getenv('ASCEND_GRAPH_MODE', 'aclgraph') == 'atbgraph' or SocVersion.is_Ascend310P():
return False
else:
raise ValueError(f"unsupported ASCEND_GRAPH_MODE: {os.getenv('ASCEND_GRAPH_MODE')}")

@classmethod
def update_step_context(cls, step_context):
"""Update step context."""

kv_start_indices, attention_mask = [], []
block_num, block_size, *_ = step_context.kv_caches[0][0].shape
is_unpaged_prefill = False
if not step_context.is_decoding:
is_unpaged_prefill = all((step_context.q_seqlens == step_context.kv_seqlens).tolist())
if step_context.block_offsets.dtype != torch.int32:
step_context.block_offsets = step_context.block_offsets.to(torch.int32)
if not (step_context.is_decoding or is_unpaged_prefill):
step_context.block_offsets = step_context.block_offsets.repeat_interleave(step_context.q_seqlens, 0)
if step_context.kv_seqlens.dtype != torch.int32:
step_context.kv_seqlens = step_context.kv_seqlens.to(torch.int32)
if step_context.q_seqlens.dtype != torch.int32:
step_context.q_seqlens = step_context.q_seqlens.to(torch.int32)

def get_total_slots():
if cls.total_slots is None:
cls.total_slots = torch.arange(block_num * block_size,
dtype=torch.long,
dtype=torch.int32,
device=step_context.block_offsets.device)
cls.total_slots = cls.total_slots.view(block_num, block_size)
return cls.total_slots

kv_start_indices, attention_mask = [], []
if SocVersion.is_Ascend910():
block_num, block_size, *_ = step_context.kv_caches[0][0].shape
elif SocVersion.is_Ascend310P():
block_num, _, block_size, _ = step_context.kv_caches[0][0].shape
def get_cpu_seqlens(is_decoding, is_unpaged_prefill):
if is_decoding:
q_seqlens_cpu, kv_seqlens_cpu = None, step_context.kv_seqlens.cpu()
elif is_unpaged_prefill:
q_seqlens_cpu = kv_seqlens_cpu = step_context.q_seqlens.cpu()
else:
q_seqlens_cpu = step_context.q_seqlens.cpu()
kv_seqlens_cpu = step_context.kv_seqlens.cpu()
return q_seqlens_cpu, kv_seqlens_cpu

is_unpaged_prefill = False
if not step_context.is_decoding:
is_unpaged_prefill = \
all((step_context.q_seqlens ==
step_context.kv_seqlens).tolist())
q_seqlens_list = step_context.q_seqlens.tolist()
kv_seqlens_list = step_context.kv_seqlens.tolist()
max_q_seq_len = max(q_seqlens_list)
max_kv_seq_len = max(kv_seqlens_list)

if step_context.is_decoding:
# collect kv_start_indices without using a for-loop,
# (fill kv-cache for just ONE token during the decoding phase)
idx = (step_context.kv_seqlens - 1) % block_size
block_num = (step_context.kv_seqlens - 1) // block_size
last_block = step_context.block_offsets.gather(1, block_num.view(-1, 1)).view(-1)
kv_start_indices = last_block * block_size + idx
else:
for i in range(step_context.q_start_loc.size(0)):
q_seq_len = q_seqlens_list[i]
kv_seq_len = kv_seqlens_list[i]

# collect kv start indices during the prefill phase.
history_length = kv_seq_len - q_seq_len
total_slots = get_total_slots()
slot_tables = total_slots[step_context.block_offsets[i]].view(-1)
slots = slot_tables[history_length:kv_seq_len]
kv_start_indices.append(slots)

# collect attention mask of paged_prefill attention stage.
if not is_unpaged_prefill:
single_attention_mask = torch.logical_not(
torch.tril(
def get_list_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_cpu=None, kv_seqlens_cpu=None):
if is_decoding:
q_seqlens_list, kv_seqlens_list = None, None
elif is_unpaged_prefill:
q_seqlens_list = kv_seqlens_list = q_seqlens_cpu.tolist()
else:
q_seqlens_list, kv_seqlens_list = q_seqlens_cpu.tolist(), kv_seqlens_cpu.tolist()
return q_seqlens_list, kv_seqlens_list

def get_max_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_list=None, kv_seqlens_list=None):
if is_decoding:
max_q_seq_len, max_kv_seq_len = 1, None
elif is_unpaged_prefill:
max_q_seq_len = max_kv_seq_len = max(q_seqlens_list)
else:
max_q_seq_len = max(q_seqlens_list)
max_kv_seq_len = max(kv_seqlens_list)
return max_q_seq_len, max_kv_seq_len

def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_seqlens_list, kv_seqlens_list,
max_q_seq_len, max_kv_seq_len):
kv_start_indices, attention_mask = [], []
if is_decoding:
idx = (step_context.kv_seqlens - 1) % block_size
block_num = (step_context.kv_seqlens - 1) // block_size
last_block = step_context.block_offsets.gather(1, block_num.view(-1, 1)).view(-1)
kv_start_indices = last_block * block_size + idx
else:
for i in range(step_context.q_start_loc.size(0)):
q_seq_len = q_seqlens_list[i]
kv_seq_len = kv_seqlens_list[i]

history_length = kv_seq_len - q_seq_len
total_slots = get_total_slots()
slot_tables = total_slots[step_context.block_offsets[i]].view(-1)
slots = slot_tables[history_length:kv_seq_len]
kv_start_indices.append(slots)

if not is_unpaged_prefill:
single_attention_mask = torch.triu(
torch.ones(q_seq_len,
step_context.block_offsets.shape[1] * block_size,
dtype=torch.bool,
device=step_context.block_offsets.device),
diagonal=kv_seq_len - q_seq_len,
))
attention_mask.append(single_attention_mask)

kv_start_indices = torch.cat(kv_start_indices)

if step_context.is_decoding:
# prepare some params of paged_decode attention stage.
q_start_loc_cpu, q_seqlens_cpu = None, None
elif is_unpaged_prefill:
# prepare some params of unpaged_prefill attention stage.
q_start_loc_cpu, kv_seqlens_cpu = None, None
q_seqlens_cpu = step_context.q_seqlens.cpu().to(torch.int32)
if SocVersion.is_Ascend910():
single_attention_mask = torch.logical_not(
torch.tril(
torch.ones(max_q_seq_len, max_kv_seq_len, dtype=torch.bool).cuda(),
diagonal=max_kv_seq_len - max_q_seq_len,
))
attention_mask.append(single_attention_mask)
elif SocVersion.is_Ascend310P():
if not cls.enable_graph:
for i in range(q_seqlens_cpu.size(0)):
single_attention_mask = torch.zeros(q_seqlens_cpu[i],
q_seqlens_cpu[i]).fill_(-float('inf')).cuda()
single_attention_mask = torch.triu(single_attention_mask, diagonal=1)
diagonal=kv_seq_len - q_seq_len + 1,
)
attention_mask.append(single_attention_mask)
else:
# Transdata needs dtype to be float16 or int8
single_attention_mask = torch.triu(
torch.ones(max_q_seq_len, max_kv_seq_len, dtype=torch.float16).fill_(-float('inf')).cuda(),
diagonal=max_kv_seq_len - max_q_seq_len + 1,
)
# Convert to NZ format
attention_mask.append(nd_to_nz_spec(single_attention_mask))
else:
raise ValueError(f"dlinfer doesn't support {SocVersion.device_name()} device currently.")
else:
# prepare some params of paged_prefill attention stage.
q_start_loc_cpu, q_seqlens_cpu = None, None
attention_mask = [torch.cat([mask for mask in attention_mask])]

if cls.enable_graph:
kv_start_indices = kv_start_indices.view(-1).to(torch.int32)
import torch._dynamo as dynamo
if not is_unpaged_prefill:
step_context.block_offsets = step_context.block_offsets.to(torch.int32)
if not step_context.is_decoding:
step_context.block_offsets = step_context.block_offsets\
.repeat_interleave(step_context.q_seqlens, 0)
dynamo.mark_dynamic(step_context.block_offsets, [0, 1])
kv_seqlens = step_context.kv_seqlens.cpu().to(torch.int32)
if not step_context.is_decoding:

if is_unpaged_prefill:
if SocVersion.is_Ascend910():
attention_mask = [mask.half() for mask in attention_mask]
attention_mask.append(
torch.triu(torch.ones(max_q_seq_len,
max_kv_seq_len,
dtype=step_context.kv_caches[0][0].dtype,
device=step_context.block_offsets.device),
diagonal=max_kv_seq_len - max_q_seq_len + 1))
else:
if SocVersion.is_Ascend910():
attention_mask = [
torch.cat([mask.half() * cls.half_negative_inf for mask in attention_mask]).unsqueeze(1)
]
elif SocVersion.is_Ascend310P():
# Convert mask to NZ format.
attention_mask = [
nd_to_nz_spec(torch.cat([mask.half() * cls.half_negative_inf for mask in attention_mask]))
]
else:
raise ValueError(f"dlinfer doesn't support {SocVersion.device_name()} device currently.")
kv_seqlens = kv_seqlens.repeat_interleave(step_context.q_seqlens, 0)
else:
if step_context.is_decoding:
kv_seqlens_cpu = step_context.kv_seqlens.cpu().to(torch.int32)
elif is_unpaged_prefill:
pass
else:
kv_seqlens_cpu = step_context.kv_seqlens.repeat_interleave(step_context.q_seqlens, 0).cpu()
block_offsets_int32 = step_context.block_offsets.to(torch.int32)
step_context.block_offsets = block_offsets_int32\
.repeat_interleave(step_context.q_seqlens, 0)
kv_seqlens = kv_seqlens_cpu
attention_mask = [torch.cat(attention_mask)]

kv_start_indices = torch.cat(kv_start_indices)

return kv_start_indices, attention_mask

q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(step_context.is_decoding, is_unpaged_prefill)
q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_cpu,
kv_seqlens_cpu)
max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_list,
kv_seqlens_list)
kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(step_context.is_decoding,
is_unpaged_prefill, q_seqlens_list,
kv_seqlens_list, max_q_seq_len,
max_kv_seq_len)

if not cls.enable_graph and step_context.kv_quant_policy == 8:
record_file = os.getenv('ASCEND_QUANT_RECORD_FILE')
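In the refactored `update_step_context` above, the decode branch of `get_kv_start_indices_and_attention_mask` locates the cache slot of the single token being appended per sequence without a Python loop. A small worked example of that arithmetic; the tensors below are made up for illustration:

```python
# Decode-phase slot math: physical_block * block_size + offset_in_block.
import torch

block_size = 16
kv_seqlens = torch.tensor([35, 7])            # current kv length per sequence
block_offsets = torch.tensor([[3, 9, 7, 0],   # logical -> physical block table
                              [5, 0, 0, 0]])

pos = kv_seqlens - 1                          # index of the token being written
idx = pos % block_size                        # offset inside its logical block
logical_block = pos // block_size             # which logical block holds it
last_block = block_offsets.gather(1, logical_block.view(-1, 1)).view(-1)
kv_start_indices = last_block * block_size + idx
print(kv_start_indices)                       # tensor([114,  86])
```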
@@ -298,9 +235,9 @@ def get_total_slots():
attn_metadata = attn_meta_cls(
step_context.is_decoding,
step_context.block_offsets,
q_start_loc=q_start_loc_cpu,
q_start_loc=None,
q_seqlens=q_seqlens_cpu,
kv_seqlens=kv_seqlens,
kv_seqlens=kv_seqlens_cpu,
kv_start_indices=kv_start_indices,
block_size=block_size,
attention_mask=attention_mask,
@@ -318,23 +255,14 @@ def get_total_slots():
def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig,
backend_config: BackendConfig, device: torch.device):
"""Build graph runner."""
if AscendOpsBackend.enable_aclgraph():
from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner
return CUDAGraphRunner(model, model_config, cache_config, backend_config, device)
else:
from .graph_runner import AscendGraphRunner
ascend_graph_runner = AscendGraphRunner(model, model_config, cache_config, backend_config, device)
AscendOpsBackend.enable_graph = ascend_graph_runner.enable_graph
return ascend_graph_runner
from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner
return CUDAGraphRunner(model, model_config, cache_config, backend_config, device)
Collaborator review comment: let's make a new aclgraphrunner instead
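One possible shape for that suggestion, purely hypothetical and not part of this diff: a thin Ascend-named subclass that reuses the CUDA graph-capture machinery and leaves room for Ascend-only checks, which `build_graph_runner` could then return instead of `CUDAGraphRunner` directly.

```python
# Hypothetical sketch only; the class name and its contents are invented to
# illustrate the reviewer's suggestion, not an existing API.
from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner


class AscendACLGraphRunner(CUDAGraphRunner):
    """Reuse CUDA graph capture under an Ascend-specific entry point."""

    def __init__(self, model, model_config, cache_config, backend_config, device):
        super().__init__(model, model_config, cache_config, backend_config, device)
        # Ascend-only setup (e.g. SocVersion checks) could live here.
```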


@staticmethod
def init():
"""Initialize Ascend backend."""
try:
from torch_npu.contrib import transfer_to_npu # noqa: F401
if SocVersion.is_Ascend310P():
# NOTE: Ascend310P has a bug with InternVL vision embedding using interpolate.
torch.npu.set_compile_mode(jit_compile=False)
except ImportError:
logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly. '
'Ascend initialization skipped.')
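A worked example of the prefill mask built in the unpaged-prefill branch of `update_step_context` above: `torch.triu(..., diagonal=kv_len - q_len + 1)` marks exactly the future positions each query must not attend to. The numbers below are illustrative:

```python
# With q_len=3 and kv_len=5, the three queries sit at absolute positions
# 2, 3 and 4; ones mark the positions they must NOT attend to.
import torch

q_len, kv_len = 3, 5
mask = torch.triu(torch.ones(q_len, kv_len, dtype=torch.bool),
                  diagonal=kv_len - q_len + 1)
print(mask.long())
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]])
```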
8 changes: 0 additions & 8 deletions lmdeploy/pytorch/backends/dlinfer/linear.py
@@ -17,14 +17,6 @@ def update_weights(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = No
"""Update weights."""
if os.getenv('DLINFER_LINEAR_USE_NN_LAYOUT', '0') == '1':
weight = weight.data.t().contiguous()
if weight.device.type == 'npu':
from .ascend import SocVersion
if SocVersion.is_Ascend310P() and not os.getenv('DLINFER_DISABLE_LINEAR_NZ_FORMAT', '0') == '1':
# Ascend 310P device need weight to be NZ format, so Transdata it initially.
# Transdata Linear weight by default, if Error occurs, please set
# DLINFER_DISABLE_LINEAR_NZ_FORMAT=1 to disable transdata.
from .ascend.utils import nd_to_nz_spec
weight = nd_to_nz_spec(weight)
return weight, bias

def forward(self,
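The block removed above converted linear weights to NZ layout for Ascend 310P inside `update_weights`, guarded by `DLINFER_DISABLE_LINEAR_NZ_FORMAT`. A self-contained sketch of that removed guard, using the project's `SocVersion` and `nd_to_nz_spec` helpers; after this PR the conversion no longer happens here:

```python
# Sketch of the removed NZ-format guard; shown for reference only, it no
# longer runs after this change.
import os

import torch

from lmdeploy.pytorch.backends.dlinfer.ascend import SocVersion
from lmdeploy.pytorch.backends.dlinfer.ascend.utils import nd_to_nz_spec


def maybe_to_nz(weight: torch.Tensor) -> torch.Tensor:
    """Return the weight in NZ layout on Ascend 310P unless disabled via env."""
    if (weight.device.type == 'npu' and SocVersion.is_Ascend310P()
            and os.getenv('DLINFER_DISABLE_LINEAR_NZ_FORMAT', '0') != '1'):
        return nd_to_nz_spec(weight)
    return weight
```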