diff --git a/aphrodite/common/config.py b/aphrodite/common/config.py index 4a2c3097b..e5331247f 100644 --- a/aphrodite/common/config.py +++ b/aphrodite/common/config.py @@ -1576,6 +1576,7 @@ class LoRAConfig: # This is a constant. lora_vocab_padding_size: ClassVar[int] = 256 long_lora_scaling_factors: Optional[Tuple[float]] = None + enable_lora_modules_to_save: bool = False def __post_init__(self): # Setting the maximum rank to 256 should be able to satisfy the vast diff --git a/aphrodite/engine/args_tools.py b/aphrodite/engine/args_tools.py index b35e2b403..7adbde9e7 100644 --- a/aphrodite/engine/args_tools.py +++ b/aphrodite/engine/args_tools.py @@ -141,6 +141,7 @@ class EngineArgs: lora_dtype: str = "auto" max_cpu_loras: Optional[int] = None long_lora_scaling_factors: Optional[Tuple[float]] = None + enable_lora_modules_to_save: bool = False fully_sharded_loras: bool = False qlora_adapter_name_or_path: Optional[str] = None enable_prompt_adapter: bool = False @@ -832,6 +833,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=None, help="Category: Adapter Options\n" "Name or path of the LoRA adapter to use.") + parser.add_argument( + "--enable-lora-modules-to-save", + action="store_true", + help="Category: Adapter Options\n" + "If True, fully trained lm_head and embed_tokens " + "in LoRA will be used instead of A*B-style adapters.") parser.add_argument('--enable-prompt-adapter', action='store_true', help='Category: Adapter Options\n' @@ -1058,6 +1065,7 @@ def create_engine_config(self, ) -> EngineConfig: lora_extra_vocab_size=self.lora_extra_vocab_size, long_lora_scaling_factors=self.long_lora_scaling_factors, lora_dtype=self.lora_dtype, + enable_lora_modules_to_save=self.enable_lora_modules_to_save, max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None diff --git a/aphrodite/lora/layers.py b/aphrodite/lora/layers.py index 5ad557277..8193fc152 100644 --- 
a/aphrodite/lora/layers.py +++ b/aphrodite/lora/layers.py @@ -27,7 +27,7 @@ from aphrodite.modeling.layers.rotary_embedding import ( LinearScalingRotaryEmbedding, RotaryEmbedding) from aphrodite.modeling.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) if TYPE_CHECKING: pass @@ -64,6 +64,25 @@ def dec(*args, **kwargs): return dec +class TensorPropertiesMixin: + + @property + def dtype(self): + return self._dtype + + @dtype.setter + def dtype(self, value): + self._dtype = value + + @property + def device(self): + return self._device + + @device.setter + def device(self, value): + self._device = value + + @dataclass class LoRAMapping(AdapterMapping): is_prefill: bool = False @@ -124,11 +143,13 @@ def can_replace_layer( raise NotImplementedError -class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA, TensorPropertiesMixin): def __init__(self, base_layer: VocabParallelEmbedding) -> None: super().__init__() self.base_layer = base_layer + self.dtype = self.base_layer.weight.dtype + self.device = self.base_layer.weight.device self.embeddings_slice: Optional[Tuple[int, int]] self.embeddings_weights: Optional[torch.Tensor] @@ -155,25 +176,20 @@ def create_lora_weights( self.embeddings_slice = None self.embeddings_weights = None - self.embeddings_tensors = torch.zeros( - ( - max_loras, - lora_config.lora_extra_vocab_size, - self.base_layer.embedding_dim, - ), - dtype=self.base_layer.weight.dtype, - device=self.base_layer.weight.device, - ) - self.lora_a_stacked = torch.zeros( - ( - max_loras, - self.base_layer.org_vocab_size + - lora_config.lora_extra_vocab_size, - lora_config.max_lora_rank, - ), - dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, - ) + self.embeddings_tensors = torch.zeros(( + max_loras, + lora_config.lora_extra_vocab_size, + self.base_layer.embedding_dim, + ), + dtype=self.dtype, + device=self.device) + 
self.lora_a_stacked = torch.zeros(( + max_loras, + self.base_layer.org_vocab_size + lora_config.lora_extra_vocab_size, + lora_config.max_lora_rank, + ), + dtype=self.dtype, + device=self.device) self.lora_b_stacked = torch.zeros( ( max_loras, @@ -182,7 +198,7 @@ def create_lora_weights( lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) self.lora_a_stacked_2d = self.lora_a_stacked.view( self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], @@ -260,6 +276,9 @@ def can_replace_layer( packed_modules_list: List, model_config: Optional[PretrainedConfig], ) -> bool: + # do not use A*B-style LoRA, try to use modules_to_save + if lora_config.enable_lora_modules_to_save: + return False return type(source_layer) is VocabParallelEmbedding @@ -1010,7 +1029,7 @@ def can_replace_layer( return type(source_layer) is RowParallelLinear -class LogitsProcessorWithLoRA(BaseLayerWithLoRA): +class LogitsProcessorWithLoRA(BaseLayerWithLoRA, TensorPropertiesMixin): """ LoRA wrapper for LogitsProcessor, with extra logic to handle the application of the LoRA adapter and added LoRA vocabulary. 
@@ -1306,3 +1325,134 @@ def can_replace_layer( def extra_repr(self) -> str: return self.base_layer.extra_repr() + + +class ModulesToSaveWrapper(BaseLayerWithLoRA, TensorPropertiesMixin): + """ + LoRA wrapper for lm_head and embed_tokens layers, inspired by the peft + ModulesToSaveWrapper. It keeps per-LoRA copies of the base_layer weights + and exposes the base_layer attributes through properties in such a way + that the wrapper returns the attributes of the wrapped base_layer, + so clients can call ModulesToSaveWrapper exactly as the base_layer module + + Args: + base_layer: layer to replace by Wrapper: + VocabParallelEmbedding (for embed_tokens) + or ParallelLMHead (for lm_head) + """ + + implemented_layers = ['lm_head', 'embed_tokens'] + + def __init__( + self, base_layer: Union[VocabParallelEmbedding, + ParallelLMHead]) -> None: + super().__init__() + self.base_layer = base_layer + + self.device = _get_lora_device(self.base_layer) + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + @property + def padded_vocab_size(self): + # number of embeddings with paddings and with max_lora_extra_vocab_size + return self.base_layer.num_embeddings_padded + + @property + def org_vocab_size(self): + return self.base_layer.org_vocab_size + + @property + def embedding_dim(self): + return self.base_layer.embedding_dim + + @property + def bias(self): + return self.base_layer.bias + + @property + def linear_method(self): + if self.punica_wrapper.no_lora: + return self.base_layer.linear_method + + return self + + @property + def weight(self): + return self.base_layer.weight + + def apply(self, lm_head: 'ModulesToSaveWrapper', + hidden_states: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + + assert isinstance(self.base_layer, ParallelLMHead) + + logits = self.punica_wrapper.bgmv_sample(hidden_states, + self._lora_tensors, + self.base_layer.weight) + + if bias is not None: + logits += bias + + return logits + + def embedding(self, embed_tokens: 'ModulesToSaveWrapper', + masked_input: 
torch.LongTensor): + assert isinstance(self.base_layer, VocabParallelEmbedding) + embeddings = self.punica_wrapper.bgmv_embedding( + masked_input, self._lora_tensors, self.base_layer.weight) + return embeddings + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + + self.dtype = lora_config.lora_dtype + + # lora_tensors - lm_head tensors in case of ParallelLMHead base + # or embed_tokens tensors in case of VocabParallelEmbedding + self._lora_tensors = torch.zeros( + (max_loras, self.padded_vocab_size, self.base_layer.embedding_dim), + dtype=self.base_layer.weight.dtype, + device=self.device, + ) + for index in range(max_loras): + self.reset_lora(index) + + def reset_lora(self, index: int): + weights = self.base_layer.weight + self._lora_tensors[index, :weights.shape[0], :weights.shape[1]].copy_( + weights, non_blocking=True) + + def set_lora( + self, + index: int, + lora_a: Optional[torch.Tensor], + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + assert lora_a is None + assert embeddings_tensor is None + + self.reset_lora(index) + self._lora_tensors[index, :lora_b.shape[0], :lora_b.shape[1]].copy_( + lora_b, non_blocking=True) + + def forward(self, *args, **kwargs): + return type(self.base_layer).forward(self, *args, **kwargs) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + if not lora_config.enable_lora_modules_to_save: + return False + return type(source_layer) in (ParallelLMHead, VocabParallelEmbedding) diff --git a/aphrodite/lora/lora.py b/aphrodite/lora/lora.py index 1ba7082cc..73da35f7d 100644 --- a/aphrodite/lora/lora.py +++ b/aphrodite/lora/lora.py @@ -12,21 +12,28 @@ class LoRALayerWeights: def __init__( self, module_name: str, - rank: int, + rank: Optional[int], lora_alpha: int, - lora_a: 
torch.Tensor, + lora_a: Optional[torch.Tensor], lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor] = None, scaling: Optional[float] = None, ) -> None: + """ + rank == None means that we have full rank tensors (ModulesToSave) + in this case: + lora_a=None + lora_b=full rank tensor + """ self.module_name = module_name self.rank = rank self.lora_alpha = lora_alpha self.lora_a = lora_a self.lora_b = lora_b self.embeddings_tensor = embeddings_tensor + self.scaling: Optional[float] - if scaling is None: + if (scaling is None) and (self.rank is not None): self.scaling = self.lora_alpha / self.rank else: self.scaling = scaling @@ -41,7 +48,10 @@ def optimize(self) -> "LoRALayerWeights": @property def input_dim(self) -> int: - return self.lora_a.shape[0] + if self.lora_a is not None: + return self.lora_a.shape[0] + + return self.lora_b.shape[0] @property def output_dim(self) -> int: @@ -62,34 +72,53 @@ def create_dummy_lora_weights( module_name: str, input_dim: int, output_dim: int, - rank: int, + rank: Optional[int], dtype: torch.types.Device, device: torch.device, embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights": pin_memory = str(device) == "cpu" and is_pin_memory_available() - lora_a = torch.zeros([input_dim, rank], - dtype=dtype, - device=device, - pin_memory=pin_memory) - lora_b = torch.zeros([rank, output_dim], - dtype=dtype, - device=device, - pin_memory=pin_memory) - embeddings_tensor = torch.rand( - 10, - embeddings_tensor_dim, - dtype=dtype, - device=device, - pin_memory=pin_memory) if embeddings_tensor_dim else None + if rank is None: + lora_a = None + lora_b = torch.zeros([input_dim, output_dim], + dtype=dtype, + device=device, + pin_memory=pin_memory) + embeddings_tensor = None + scaling = 1 + else: + lora_a = torch.zeros([input_dim, rank], + dtype=dtype, + device=device, + pin_memory=pin_memory) + lora_b = torch.zeros([rank, output_dim], + dtype=dtype, + device=device, + pin_memory=pin_memory) + scaling = None + + 
embeddings_tensor = torch.rand( + 10, + embeddings_tensor_dim, + dtype=dtype, + device=device, + pin_memory=pin_memory) if embeddings_tensor_dim else None return cls( module_name, rank=rank, lora_alpha=1, lora_a=lora_a, lora_b=lora_b, + scaling=scaling, embeddings_tensor=embeddings_tensor, ) + def lora_a_pin_memory(self): + if self.lora_a is not None: + self.lora_a = self.lora_a.pin_memory() + + def lora_b_pin_memory(self): + self.lora_b = self.lora_b.pin_memory() + class PackedLoRALayerWeights(LoRALayerWeights): """LoRA used for packed layers (eg. qkv_proj).""" @@ -97,7 +126,7 @@ class PackedLoRALayerWeights(LoRALayerWeights): def __init__( self, module_name: str, - rank: int, + rank: Optional[int], lora_alphas: List[Optional[int]], lora_a: List[Optional[torch.Tensor]], lora_b: List[Optional[torch.Tensor]], @@ -113,7 +142,7 @@ def __init__( embeddings_tensor=None, ) self.lora_alphas = lora_alphas - if scaling is None: + if (scaling is None) and (self.rank is not None): self.scaling = [ # type: ignore lora_alpha / self.rank # type: ignore # noqa for lora_alpha in self.lora_alphas diff --git a/aphrodite/lora/models.py b/aphrodite/lora/models.py index 8d94986bc..d3d0f7e6b 100644 --- a/aphrodite/lora/models.py +++ b/aphrodite/lora/models.py @@ -21,7 +21,7 @@ from aphrodite.common.utils import is_pin_memory_available from aphrodite.lora.layers import (BaseLayerWithLoRA, LinearScalingRotaryEmbeddingWithLora, - LoRAMapping) + LoRAMapping, ModulesToSaveWrapper) from aphrodite.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from aphrodite.lora.punica import PunicaWrapper from aphrodite.lora.utils import (from_layer, from_layer_logits_processor, @@ -105,6 +105,7 @@ def from_lora_tensors( rank: int, lora_alpha: int, tensors: Dict[str, torch.Tensor], + enable_lora_modules_to_save: bool = False, device: str = "cuda", dtype: Optional[torch.dtype] = None, embeddings: Optional[Dict[str, torch.Tensor]] = None, @@ -117,7 +118,8 @@ def from_lora_tensors( pin_memory = 
str(device) == "cpu" and is_pin_memory_available() loras: Dict[str, LoRALayerWeights] = {} for tensor_name, tensor in tensors.items(): - module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name) + module_name, is_lora_a = parse_fine_tuned_lora_name( + tensor_name, enable_lora_modules_to_save) if module_name not in loras: lora_embeddings_tensor = None if embeddings: @@ -139,8 +141,12 @@ def from_lora_tensors( loras[module_name].lora_a = tensor.to(device=device, dtype=dtype).t() if pin_memory: - loras[module_name].lora_a = loras[ - module_name].lora_a.pin_memory() + loras[module_name].lora_a_pin_memory() + elif is_lora_a is None: # this is modules_to_save tensor + loras[module_name].lora_b = tensor.to(device=device, + dtype=dtype) + if pin_memory: + loras[module_name].lora_b_pin_memory() else: loras[module_name].lora_b = tensor.to(device=device, dtype=dtype).t() @@ -166,10 +172,12 @@ def from_local_checkpoint( cls, lora_dir: str, expected_lora_modules: List[str], + expected_modules_to_save: List[str], *, max_position_embeddings: Optional[int] = None, lora_model_id: Optional[int] = None, device: str = "cuda", + enable_lora_modules_to_save: bool = False, dtype: Optional[torch.dtype] = None, target_embedding_padding: Optional[int] = None, embedding_modules: Optional[Dict[str, str]] = None, @@ -213,9 +221,15 @@ def from_local_checkpoint( with safetensors.safe_open(lora_tensor_path, framework="pt") as f: # type: ignore for lora_module in f.keys(): # noqa - module_name, _ = parse_fine_tuned_lora_name(lora_module) + module_name, is_lora_a = parse_fine_tuned_lora_name( + lora_module, enable_lora_modules_to_save) part_name = module_name.split(".")[-1] - if part_name not in expected_lora_modules: + + is_expected_module_to_save = (is_lora_a is None) and ( + part_name in expected_modules_to_save) + + if (part_name not in expected_lora_modules + ) and not is_expected_module_to_save: unexpected_modules.append(module_name) if unexpected_modules: raise ValueError( @@ -278,6 
+292,7 @@ def from_local_checkpoint( rank=rank, lora_alpha=lora_alpha, tensors=tensors, + enable_lora_modules_to_save=enable_lora_modules_to_save, device=device, dtype=dtype, embeddings=embeddings, @@ -452,7 +467,9 @@ def _create_lora_modules(self): self.scaling_factor_to_offset = \ new_module.scaling_factor_to_offset # (yard1): TODO make this more robust - if "lm_head" in module_name: + # replace lm_head by A*B lora if needed + if (("lm_head" in module_name) + and not self.lora_config.enable_lora_modules_to_save): logits_processor_module = self.model.get_submodule( "logits_processor") new_module = replace_submodule( @@ -498,12 +515,14 @@ def create_dummy_lora( hasattr(module.base_layer, "embedding_dim") else module.base_layer.weight.shape[1]) + rank_ = None if isinstance(module, + ModulesToSaveWrapper) else rank lora = LoRALayerWeights.create_dummy_lora_weights( module_name, input_dim, output_dim, - rank, - module.lora_a_stacked.dtype, + rank_, + module.dtype, "cpu", embeddings_tensor_dim=embeddings_tensor_dim) else: diff --git a/aphrodite/lora/ops/bgmv_embed.py b/aphrodite/lora/ops/bgmv_embed.py new file mode 100644 index 000000000..2a0c71a29 --- /dev/null +++ b/aphrodite/lora/ops/bgmv_embed.py @@ -0,0 +1,126 @@ +import torch +import triton +import triton.language as tl + +from .utils import get_lora_op_configs + + +@triton.jit +def _bgmv_embed_kernel( + tokens, # pointer to tokens array + embed_tokens_all, # pointer to embedded tokens - all + embed_tokens_base, # pointer to embedded tokens - base + token_indices, # pointer to token indices + embeddings, # pointer to output embeddings + num_tokens, # number of tokens + HIDDEN_DIM: tl.constexpr, # hidden dimension + VOCAB_SIZE: tl.constexpr, # vocabulary size + BLOCK_N: tl.constexpr # block size (number of tokens per block) +): + # Calculate the starting index for this block + start_idx = tl.program_id(0) * BLOCK_N + # Create an array of offsets for the tokens in this block + offs_n = start_idx + tl.arange(0, 
BLOCK_N) + # Create a mask to handle cases where we exceed num_tokens + mask = offs_n < num_tokens + + # Load lora_index and tokens for the current block (masked) + lora_index = tl.load(token_indices + offs_n, mask=mask, other=-1) + cur_tokens = tl.load(tokens + offs_n, mask=mask, other=0) + + # Compute offsets into the embedding matrices + hidden_range = tl.arange(0, HIDDEN_DIM) + offsets_embed = cur_tokens[:, None] * HIDDEN_DIM + hidden_range[ + None, :] # Shape: (BLOCK_N, HIDDEN_DIM) + + # Load embeddings from embed_tokens_base + embeddings_base = tl.load(embed_tokens_base + offsets_embed, + mask=mask[:, None], + other=0.0) + + # Initialize embeddings_block with embeddings_base + embeddings_block = embeddings_base + + # Create a mask for tokens that require loading from embed_tokens_all + mask_all = (lora_index != -1) & mask + + # For tokens with lora_index != -1, load from embed_tokens_all + + # Calculate base offsets for tokens with lora_index != -1 + # Use tl.where to avoid invalid memory accesses + base_offsets_all = tl.where(mask_all, lora_index * HIDDEN_DIM * VOCAB_SIZE, + 0) + # Calculate full offsets into embed_tokens_all + full_offsets_all = base_offsets_all[:, None] + offsets_embed + # Load embeddings from embed_tokens_all + embeddings_all = tl.load(embed_tokens_all + full_offsets_all, + mask=mask_all[:, None], + other=0.0) + # Overwrite embeddings_block where lora_index != -1 + embeddings_block = tl.where(mask_all[:, None], embeddings_all, + embeddings_block) + + # Calculate the offsets where embeddings should be stored + output_offsets = offs_n[:, None] * HIDDEN_DIM + hidden_range[None, :] + + # Store embeddings_block to the output embeddings array + tl.store(embeddings + output_offsets, embeddings_block, mask=mask[:, None]) + + +@torch.inference_mode() +def _bgmv_embed( + tokens: torch.Tensor, + embed_tokens_all: torch.Tensor, + embed_tokens_base: torch.Tensor, + token_indices: torch.Tensor, +) -> torch.Tensor: + """ + Args: + tokens - [num_tokens] 
- input tokens + embed_tokens_all - [num_loras, vocab_size, hidden_dim] + modules_to_save embeddings + embed_tokens_base - [vocab_size, hidden_dim] - base layer + embeddings will be applied to tokens with index=-1 + token_indices - [num_tokens] LoRA indices from 0 to num_loras-1, + -1 means no LoRA, embed_tokens_base will be used + returns: + embeddings: [num_tokens, hidden_dim] + """ + + assert embed_tokens_all.dtype == embed_tokens_base.dtype + assert tokens.dtype == torch.int64 + assert token_indices.dtype == torch.int64 + + assert embed_tokens_base.is_contiguous() + assert embed_tokens_all.is_contiguous() + + vocab_size, hidden_dim = embed_tokens_all.shape[-2:] + num_tokens = tokens.shape[0] + embeddings = torch.zeros((num_tokens, hidden_dim), + dtype=embed_tokens_all.dtype, + device=embed_tokens_all.device) + + grid = lambda meta: (triton.cdiv(num_tokens, meta['BLOCK_N']), ) + + config = get_lora_op_configs("embed", num_tokens, hidden_dim) + + _bgmv_embed_kernel[grid]( + tokens, + embed_tokens_all, + embed_tokens_base, + token_indices, + embeddings, + num_tokens, + HIDDEN_DIM=hidden_dim, + VOCAB_SIZE=vocab_size, + **config, + ) + return embeddings + + +try: + bgmv_embed = torch.library.custom_op("lora::bgmv_embed", + _bgmv_embed, + mutates_args=[]) +except AttributeError: + bgmv_embed = _bgmv_embed diff --git a/aphrodite/lora/ops/bgmv_sample.py b/aphrodite/lora/ops/bgmv_sample.py new file mode 100644 index 000000000..1ad0b732c --- /dev/null +++ b/aphrodite/lora/ops/bgmv_sample.py @@ -0,0 +1,90 @@ +import torch +import triton +import triton.language as tl + +from .utils import get_lora_op_configs + + +@triton.jit +def _bgmv_sample_kernel(hidden_state_ptr, lm_heads_all_ptr, lm_head_base_ptr, + logits_ptr, sampling_indices_tensor_ptr, + HIDDEN_DIM: tl.constexpr, VOCAB_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr): + + cur_token = tl.program_id(axis=0) + + logits_start_idx = tl.program_id(axis=1) * BLOCK_N + + lora_index = tl.load(sampling_indices_tensor_ptr + 
cur_token) + + hidden_state = tl.load(hidden_state_ptr + HIDDEN_DIM * cur_token + + tl.arange(0, HIDDEN_DIM)) + hidden_state = hidden_state.expand_dims(0) + + offsets_embed = tl.arange(0, HIDDEN_DIM) + offsets_logits = logits_start_idx + tl.arange(0, BLOCK_N) + + offset_base_layer = offsets_embed[ + None, :] + offsets_logits[:, None] * HIDDEN_DIM + offset_lora = lora_index * (VOCAB_SIZE * HIDDEN_DIM) + offset_base_layer + + if lora_index == -1: + weights = tl.load(lm_head_base_ptr + offset_base_layer) + else: + weights = tl.load(lm_heads_all_ptr + offset_lora) + + logits = tl.sum(weights * hidden_state, axis=1) + + tl.store(logits_ptr + cur_token * VOCAB_SIZE + offsets_logits, logits) + + +@torch.inference_mode() +def _bgmv_sample( + hidden_state: torch.Tensor, + lm_heads_all: torch.Tensor, + lm_head_base: torch.Tensor, + sampling_indices_tensor: torch.Tensor, +) -> torch.Tensor: + """ + Args: + hidden_state - [num_tokens, hidden_dim] + lm_heads_all - [num_loras, vocab_size, hidden_dim] + sampling_indices_tensor - [num_tokens] - indexes from 0 to num_loras-1 + """ + assert hidden_state.dtype == lm_heads_all.dtype + + assert hidden_state.size(-1) == lm_heads_all.size(-1) + assert hidden_state.is_contiguous() + assert lm_heads_all.is_contiguous() + + vocab_size = lm_heads_all.shape[-2] + logits = torch.zeros((hidden_state.size(0), vocab_size), + dtype=hidden_state.dtype, + device=hidden_state.device) + + num_tokens = sampling_indices_tensor.shape[0] + hidden_dim = hidden_state.shape[-1] + + grid = lambda meta: (num_tokens, triton.cdiv(vocab_size, meta['BLOCK_N'])) + + config = get_lora_op_configs("sample", num_tokens, hidden_dim) + + _bgmv_sample_kernel[grid]( + hidden_state, + lm_heads_all, + lm_head_base, + logits, + sampling_indices_tensor, + HIDDEN_DIM=hidden_dim, + VOCAB_SIZE=vocab_size, + **config, + ) + return logits + + +try: + bgmv_sample = torch.library.custom_op("lora::bgmv_sample", + _bgmv_sample, + mutates_args=[]) +except AttributeError: + bgmv_sample = 
_bgmv_sample diff --git a/aphrodite/lora/ops/utils.py b/aphrodite/lora/ops/utils.py index 7c3e27313..03c24def2 100644 --- a/aphrodite/lora/ops/utils.py +++ b/aphrodite/lora/ops/utils.py @@ -21,6 +21,12 @@ def _check_divisibility(hidden_size: int): def _get_default_config(op_type: str, batch: int, hidden_size: int): + if op_type == "sample": + return {"BLOCK_N": 2} + + if op_type == "embed": + return {"BLOCK_N": 4} + if op_type == "expand": return { "BLOCK_N": 256, diff --git a/aphrodite/lora/punica.py b/aphrodite/lora/punica.py index 1264c214b..b8b071638 100644 --- a/aphrodite/lora/punica.py +++ b/aphrodite/lora/punica.py @@ -13,8 +13,10 @@ from aphrodite.triton_utils import HAS_TRITON if HAS_TRITON and not is_xpu(): + from aphrodite.lora.ops.bgmv_embed import bgmv_embed from aphrodite.lora.ops.bgmv_expand import bgmv_expand from aphrodite.lora.ops.bgmv_expand_slice import bgmv_expand_slice + from aphrodite.lora.ops.bgmv_sample import bgmv_sample from aphrodite.lora.ops.bgmv_shrink import bgmv_shrink from aphrodite.lora.ops.sgmv_expand import sgmv_expand from aphrodite.lora.ops.sgmv_expand_slice import sgmv_expand_slice @@ -603,3 +605,50 @@ def add_lora_logits(self, bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale) bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True) y = y.view_as(y_org) + + def bgmv_sample(self, hidden_states: torch.Tensor, + lm_heads_all: torch.Tensor, lm_head_base: torch.Tensor): + ''' + hidden_states - [num_tokens, hidden_dim] + lm_heads_all - [num_loras, vocab_size, hidden_dim] + the same as: + vocab_size=lm_heads_all.shape[-2] + num_tokens=hidden_states.size(0) + + logits = torch.zeros((num_tokens, vocab_size), + dtype=hidden_states.dtype, + device=hidden_states.device) + + for i in range(len(hidden_states)): + if indices[i]==-1: + logits[i]=lm_head_base @ hidden_states[i] + else: + logits[i]=lm_heads_all[indices[i]] @ hidden_states[i] + ''' + + indices = self.sampler_indices + + logits = 
bgmv_sample(hidden_states, lm_heads_all, lm_head_base, + indices) + return logits + + def bgmv_embedding(self, tokens: torch.LongTensor, + embed_tokens_all: torch.Tensor, + embed_tokens_base: torch.Tensor) -> torch.Tensor: + ''' + embed_tokens_all - [num_loras, vocab_size, hidden_dim] + modules_to_save embeddings + embed_tokens_base - [vocab_size, hidden_dim] - base layer + embeddings will be applied to tokens with index=-1 + tokens - [num_tokens] + returns: + embeddings: [num_tokens, hidden_dim] + + ''' + + embeddings = bgmv_embed(tokens, + embed_tokens_all, + embed_tokens_base, + token_indices=self.token_lora_indices.long()) + + return embeddings diff --git a/aphrodite/lora/utils.py b/aphrodite/lora/utils.py index 1063616d8..53770579e 100644 --- a/aphrodite/lora/utils.py +++ b/aphrodite/lora/utils.py @@ -23,6 +23,7 @@ LogitsProcessorWithLoRA, MergedColumnParallelLinearWithLoRA, MergedQKVParallelLinearWithLora, + ModulesToSaveWrapper, QKVParallelLinearWithLora, ReplicatedLinearWithLoRA, RowParallelLinearWithLoRA, @@ -46,6 +47,7 @@ MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA, LinearScalingRotaryEmbeddingWithLora, + ModulesToSaveWrapper, } @@ -89,7 +91,10 @@ def replace_submodule(model: nn.Module, module_name: str, return new_module -def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: +def parse_fine_tuned_lora_name( + name: str, + enable_lora_modules_to_save: bool = False +) -> Tuple[str, Optional[bool]]: """Parse the name of lora weights. args: @@ -98,7 +103,8 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: return: Tuple(module_name, is_lora_a): module_name: the name of the module, e.g. model.dense1, - is_lora_a whether the tensor is lora_a or lora_b. + is_lora_a whether the tensor is lora_a or lora_b. 
+ None - if tensor is for ModulesToSaveWrapper """ parts = name.split(".") @@ -106,6 +112,16 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: if parts[-1] == "weight": if parts[-2] == "lora_A" or parts[-2] == "lora_B": return ".".join(parts[2:-2]), parts[-2] == "lora_A" + if parts[-2] in ModulesToSaveWrapper.implemented_layers: + + if not enable_lora_modules_to_save: + error_msg = f"""enable_lora_modules_to_save is False, + but found tensor name {name} in LoRA checkpoint. + Set enable_lora_modules_to_save=True to process + lm_head and embed_tokens as fully trained tensors""" + raise ValueError(error_msg) + + return '.'.join(parts[2:-1]), None elif parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A" diff --git a/aphrodite/lora/worker_manager.py b/aphrodite/lora/worker_manager.py index cd8327964..968308882 100644 --- a/aphrodite/lora/worker_manager.py +++ b/aphrodite/lora/worker_manager.py @@ -88,13 +88,17 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: packed_modules_mapping[module]) else: expected_lora_modules.append(module) + expected_modules_to_save: List[str] = model.modules_to_save lora_path = get_adapter_absolute_path(lora_request.lora_path) lora = self._lora_model_cls.from_local_checkpoint( lora_path, expected_lora_modules, + expected_modules_to_save, max_position_embeddings=self.max_position_embeddings, lora_model_id=lora_request.lora_int_id, device="cpu", + enable_lora_modules_to_save=self._adapter_manager.lora_config. 
+ enable_lora_modules_to_save, dtype=self.lora_config.lora_dtype, target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, diff --git a/aphrodite/modeling/models/interfaces.py b/aphrodite/modeling/models/interfaces.py index c51c0e9ee..24881e417 100644 --- a/aphrodite/modeling/models/interfaces.py +++ b/aphrodite/modeling/models/interfaces.py @@ -72,6 +72,7 @@ class SupportsLoRA(Protocol): supported_lora_modules: ClassVar[List[str]] embedding_modules: ClassVar[Dict[str, str]] embedding_padding_modules: ClassVar[List[str]] + modules_to_save: ClassVar[List[str]] = ["lm_head", "embed_tokens"] # lora_config is None when LoRA is not enabled def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: