diff --git a/csrc/models/llama/llama_attention.cpp b/csrc/models/llama/llama_attention.cpp index fe76479b..9cd4828f 100644 --- a/csrc/models/llama/llama_attention.cpp +++ b/csrc/models/llama/llama_attention.cpp @@ -202,8 +202,9 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta } infinicore::Tensor attn_output; - if (false) { - // experimental nineoothed flash attention + if (q_reshaped->device().getType() == infinicore::Device::Type::NVIDIA + || q_reshaped->device().getType() == infinicore::Device::Type::ILUVATAR + || q_reshaped->device().getType() == infinicore::Device::Type::CAMBRICON) { attn_output = infinicore::op::flash_attention(q_reshaped, k_total, v_total, total_sequence_lengths.value(), scaling_, true); attn_output = attn_output->permute({0, 2, 1, 3}) ->contiguous() diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 8b614980..8c9e242f 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -73,7 +73,7 @@ def forward( past_kv_lengths._underlying if past_kv_lengths is not None else None ) total_kv_lengths = ( - total_kv_lengths._underlying if past_kv_lengths is not None else None + total_kv_lengths._underlying if total_kv_lengths is not None else None ) input_offsets = input_offsets._underlying if input_offsets is not None else None block_tables = block_tables._underlying if block_tables is not None else None