
Commit 056cae0

WIP
1 parent 997395d commit 056cae0

File tree

4 files changed: +293 −13 lines


docs/source/training_tutorials/sft_lora_finetune_llm.py

+1-1
@@ -38,7 +38,7 @@ def training_function(script_args, training_args):
         lora_alpha=16,
         lora_dropout=0.05,
         # target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj"],
-        target_modules=["q_proj", "v_proj"],
+        target_modules=["q_proj", "k_proj", "v_proj"],
         bias="none",
         task_type="CAUSAL_LM",
     )
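For reference, a minimal sketch of the PEFT configuration this change produces. The `r` value and the import are assumptions (they sit outside this hunk); the remaining fields mirror the diff above.

```python
from peft import LoraConfig

# Sketch only: `r` is assumed, the other fields mirror the hunk above.
peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj"],  # "k_proj" is the addition made by this commit
    bias="none",
    task_type="CAUSAL_LM",
)
```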

optimum/neuron/distributed/utils.py

+91-12
@@ -34,7 +34,7 @@
 from ..utils import DynamicPatch, Patcher
 from ..utils.import_utils import is_neuronx_distributed_available
 from ..utils.misc import download_checkpoints_in_cache, is_precompilation
-from ..utils.peft_utils import NeuronPeftModel
+from ..utils.peft_utils import LoraGQAQKVParallelLinear, NeuronPeftModel
 from ..utils.require_utils import requires_neuronx_distributed, requires_peft, requires_safetensors, requires_torch_xla


@@ -67,6 +67,11 @@ def __init__(self, *args, **kwargs):


 def get_base_model_and_peft_prefix(model: torch.nn.Module) -> Tuple[torch.nn.Module, str]:
+    """
+    Retrieves the base model and the associated PEFT prefix.
+
+    It also attaches a callable to get the original PEFT model from the base model if needed.
+    """
     if is_peft_available() and isinstance(model, NeuronPeftModel):
         from peft.tuners.tuners_utils import BaseTunerLayer

@@ -81,9 +86,18 @@ def get_base_model_and_peft_prefix(model: torch.nn.Module) -> Tuple[torch.nn.Module, str]:
         for mod in model.modules():
             if isinstance(mod, BaseTunerLayer):
                 mod._peft_config = model.peft_config
+
+        # We need to provide a way for the base model to get access to the PEFT model instance to be able to call
+        # methods that might be needed during parallelization, such as injecting new tuner layers.
+        # We attach a function instead of the attribute itself to avoid an infinite loop when looping over the model's
+        # modules.
+        def peft_model():
+            return model
+        orig_model._peft_model = peft_model
     else:
         peft_prefix = ""
         orig_model = model
+
     return orig_model, peft_prefix


@@ -212,15 +226,28 @@ def __init__(
         self.num_attention_heads = num_attention_heads
         self.num_key_value_heads = num_key_value_heads

+        # Creating these aliases for LoRA.
+        self.in_features = self.input_size
+        self.out_features = self.output_sizes
+
     def get_parameter_names_mapping(
         self, named_modules: Dict[str, torch.nn.Module], reversed: bool = False
     ) -> Dict[str, str]:
         module_to_name = {v: k for k, v in named_modules.items()}
         fully_qualified_name = module_to_name[self]
         parent_module_name, _ = fully_qualified_name.rsplit(".", maxsplit=1)
+
+        # There are 2 cases:
+        # 1. The parent module is an "actual" module from the original model
+        # 2. The parent module is a Lora layer wrapping the QGAQKVColumnParallelLinear
+        parent_module = named_modules[parent_module_name]
+        if isinstance(parent_module, LoraGQAQKVParallelLinear):
+            parent_module_name, _ = parent_module_name.rsplit(".", maxsplit=1)
+
         mapping = {}
         for qkv_proj_name, proj_name in self._qkv_proj_name_to_proj_name.items():
             proj_qualified_name = f"{parent_module_name}.{proj_name}"
+            print(parent_module_name, proj_name)
             proj_module = named_modules[proj_qualified_name]

             original_qualified_name = f"{parent_module_name}.{proj_name}"
@@ -235,6 +262,8 @@ def get_parameter_names_mapping(
                 mapping[f"{original_qualified_name}.bias"] = f"{fully_qualified_name}.bias_{qkv_proj_name}"
         if reversed:
             mapping = {v: k for k, v in mapping.items()}
+        print(mapping)
+        # assert 3==2
         return mapping


@@ -275,6 +304,9 @@ def get_output_projection_qualified_names_after_qga_qkv_replacement(model: torch.nn.Module
     for name, mod in named_modules.items():
         if isinstance(mod, OptimumGQAQKVColumnParallelLinear):
             parent_name = name.rsplit(".", maxsplit=1)[0]
+            parent_module = named_modules[parent_name]
+            if isinstance(parent_module, LoraGQAQKVParallelLinear):
+                parent_name, _ = parent_name.rsplit(".", maxsplit=1)
             output_projection_name = f"{parent_name}.{mod.output_proj_name}"
             qualified_names.add(f"{output_projection_name}.weight")
             if model.get_submodule(output_projection_name).bias is not None:
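Both hunks above walk up the qualified name to find the attention block that owns the projections, with one extra `rsplit` when a LoRA wrapper sits between the attention block and the fused QKV layer. An illustrative sketch; the qualified names and the `base_layer` nesting are assumptions, only the `rsplit` logic mirrors the diff.

```python
# Hypothetical qualified names.
plain_name = "model.layers.0.self_attn.qkv_proj"
wrapped_name = "model.layers.0.self_attn.qkv_proj.base_layer"  # extra level added by a LoRA wrapper

parent, _ = plain_name.rsplit(".", maxsplit=1)
print(parent)  # model.layers.0.self_attn -> already the attention block

parent, _ = wrapped_name.rsplit(".", maxsplit=1)
print(parent)  # model.layers.0.self_attn.qkv_proj -> the LoRA wrapper, one level short
parent, _ = parent.rsplit(".", maxsplit=1)
print(parent)  # model.layers.0.self_attn -> the attention block
```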
@@ -1361,19 +1393,64 @@ def inplace_linears_to_gqa_qkv_column_parallel_linear(
     key_is_peft_tuner = False
     value_is_peft_tuner = False
     if is_peft_available():
-        from peft.tuners.tuners_utils import BaseTunerLayer
+        from peft.tuners.tuners_utils import BaseTunerLayer, BaseTuner

         query_is_peft_tuner = isinstance(query_linear, BaseTunerLayer)
         key_is_peft_tuner = isinstance(key_linear, BaseTunerLayer)
         value_is_peft_tuner = isinstance(value_linear, BaseTunerLayer)

-    def get_parent_and_base_layer_in_tuner_layer(tuner_layer):
-        parent = tuner_layer
-        base_layer = tuner_layer
-        while hasattr(base_layer, "base_layer"):
-            parent = base_layer
-            base_layer = base_layer.base_layer
-        return parent, base_layer
+    # TODO: add test that all are peft tuners or none of them are.
+    if query_is_peft_tuner and key_is_peft_tuner and value_is_peft_tuner:
+        if isinstance(model, BaseTuner):
+            peft_model = model
+        elif hasattr(model, "_peft_model"):
+            peft_model = model._peft_model()
+        else:
+            raise RuntimeError(
+                "`model` must be a `PeftModel` or have the `_peft_model` method to be able to retrieve the "
+                "associated `PeftModel`."
+            )
+
+        new_module = None
+        for adapter_name in peft_model.base_model.active_adapters:
+            lora_config = peft_model.peft_config[adapter_name]
+            r = lora_config.r
+            lora_alpha = lora_config.lora_alpha
+            lora_dropout = lora_config.lora_dropout
+            if new_module is None:
+                # TODO: add other keyword arguments
+                new_module = LoraGQAQKVParallelLinear(
+                    gqa_qkv_column_parallel_linear,
+                    adapter_name,
+                    r=r,
+                    lora_alpha=lora_alpha,
+                    lora_dropout=lora_dropout,
+                )
+                setattr(attention_layer, gqa_qkv_proj_name, new_module)
+            else:
+                new_module.update_layer(
+                    adapter_name,
+                    r,
+                    lora_alpha,
+                    lora_dropout
+                )
+
+        # peft_model._create_and_replace(
+        #     peft_model.peft_config[adapter_name],
+        #     adapter_name,
+        #     gqa_qkv_column_parallel_linear,
+        #     f"{attention_layer_qualified_name}.{gqa_qkv_proj_name}",
+        #     attention_layer,
+        #     attention_layer_qualified_name,
+        # )
+
+    # def get_parent_and_base_layer_in_tuner_layer(tuner_layer):
+    #     parent = tuner_layer
+    #     base_layer = tuner_layer
+    #     while hasattr(base_layer, "base_layer"):
+    #         parent = base_layer
+    #         base_layer = base_layer.base_layer
+    #     return parent, base_layer

     fake_q_proj = FakeProj(
         f"{attention_layer_qualified_name}.{queries_name}",
@@ -1383,7 +1460,7 @@ def get_parent_and_base_layer_in_tuner_layer(tuner_layer):
         attention_layer_qualified_name,
         gqa_qkv_proj_name,
     )
-    if query_is_peft_tuner:
+    if False: # query_is_peft_tuner:
         parent, _ = get_parent_and_base_layer_in_tuner_layer(query_linear)
         setattr(parent, "base_layer", fake_q_proj)
         _parallelize_active_adapters(
@@ -1405,7 +1482,7 @@ def get_parent_and_base_layer_in_tuner_layer(tuner_layer):
         attention_layer_qualified_name,
         gqa_qkv_proj_name,
     )
-    if key_is_peft_tuner:
+    if False: # key_is_peft_tuner:
         parent, _ = get_parent_and_base_layer_in_tuner_layer(key_linear)
         setattr(parent, "base_layer", fake_k_proj)
         _parallelize_active_adapters(
@@ -1427,7 +1504,7 @@ def get_parent_and_base_layer_in_tuner_layer(tuner_layer):
         attention_layer_qualified_name,
         gqa_qkv_proj_name,
     )
-    if value_is_peft_tuner:
+    if False: # value_is_peft_tuner:
         parent, _ = get_parent_and_base_layer_in_tuner_layer(value_linear)
         setattr(parent, "base_layer", fake_v_proj)
         _parallelize_active_adapters(
@@ -1441,6 +1518,8 @@ def get_parent_and_base_layer_in_tuner_layer(tuner_layer):
     else:
         setattr(attention_layer, values_name, fake_v_proj)

+    print(model)
+

 @requires_neuronx_distributed
 def delete_tensor_model_parallel_attributes(tensor: torch.Tensor):

optimum/neuron/trainers.py

+1
@@ -1654,6 +1654,7 @@ def make_inputs_require_grad(module, input, output):
             model = get_peft_model(model, peft_config, autocast_adapter_dtype=False)
         else:
             model = get_peft_model(model, peft_config)
+
         if (
             args is not None
             and args.bf16
