from ..utils import DynamicPatch, Patcher
from ..utils.import_utils import is_neuronx_distributed_available
from ..utils.misc import download_checkpoints_in_cache, is_precompilation
- from ..utils.peft_utils import NeuronPeftModel
+ from ..utils.peft_utils import LoraGQAQKVParallelLinear, NeuronPeftModel
from ..utils.require_utils import requires_neuronx_distributed, requires_peft, requires_safetensors, requires_torch_xla


@@ -67,6 +67,11 @@ def __init__(self, *args, **kwargs):


def get_base_model_and_peft_prefix(model: torch.nn.Module) -> Tuple[torch.nn.Module, str]:
+     """
+     Retrieves the base model and the associated PEFT prefix.
+
+     It also attaches a callable to get the original PEFT model from the base model if needed.
+     """
    if is_peft_available() and isinstance(model, NeuronPeftModel):
        from peft.tuners.tuners_utils import BaseTunerLayer

@@ -81,9 +86,18 @@ def get_base_model_and_peft_prefix(model: torch.nn.Module) -> Tuple[torch.nn.Mod
        for mod in model.modules():
            if isinstance(mod, BaseTunerLayer):
                mod._peft_config = model.peft_config
+
+         # We need to provide a way for the base model to get access to the PEFT model instance to be able to call
+         # methods that might be needed during parallelization, such as injecting new tuner layers.
+         # We attach a function instead of the attribute itself to avoid an infinite loop when looping over the
+         # model's modules.
+         def peft_model():
+             return model
+         orig_model._peft_model = peft_model
    else:
        peft_prefix = ""
        orig_model = model
+
    return orig_model, peft_prefix


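The `_peft_model` callable attached above is what later parallelization code uses to get back to the `PeftModel` from the bare base model. A minimal sketch of that retrieval pattern (the helper name `resolve_peft_model` is illustrative; only `_peft_model` comes from this change):

# Sketch only: recover the PeftModel from a base model during parallelization.
def resolve_peft_model(base_model):
    if hasattr(base_model, "_peft_model"):
        # `_peft_model` is the zero-argument callable attached above.
        return base_model._peft_model()
    raise RuntimeError("No PEFT model is attached to this base model.")
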
@@ -212,15 +226,28 @@ def __init__(
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

+         # Creating these aliases for LoRA.
+         self.in_features = self.input_size
+         self.out_features = self.output_sizes
+
    def get_parameter_names_mapping(
        self, named_modules: Dict[str, torch.nn.Module], reversed: bool = False
    ) -> Dict[str, str]:
        module_to_name = {v: k for k, v in named_modules.items()}
        fully_qualified_name = module_to_name[self]
        parent_module_name, _ = fully_qualified_name.rsplit(".", maxsplit=1)
+
+         # There are 2 cases:
+         #   1. The parent module is an "actual" module from the original model.
+         #   2. The parent module is a LoRA layer wrapping the GQAQKVColumnParallelLinear.
+         parent_module = named_modules[parent_module_name]
+         if isinstance(parent_module, LoraGQAQKVParallelLinear):
+             parent_module_name, _ = parent_module_name.rsplit(".", maxsplit=1)
+
        mapping = {}
        for qkv_proj_name, proj_name in self._qkv_proj_name_to_proj_name.items():
            proj_qualified_name = f"{parent_module_name}.{proj_name}"
            proj_module = named_modules[proj_qualified_name]

            original_qualified_name = f"{parent_module_name}.{proj_name}"
@@ -235,6 +262,8 @@ def get_parameter_names_mapping(
                mapping[f"{original_qualified_name}.bias"] = f"{fully_qualified_name}.bias_{qkv_proj_name}"
        if reversed:
            mapping = {v: k for k, v in mapping.items()}
        return mapping


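To make the two parent-resolution cases concrete, here is a sketch with made-up module names, assuming the LoRA wrapper exposes the fused layer as `base_layer`, as PEFT tuner layers usually do:

# Case 1: the fused QKV layer hangs directly off the attention module.
#   "model.layers.0.self_attn.qkv_proj"            -> parent is "model.layers.0.self_attn"
# Case 2: LoRA wraps it, so the fused layer is reached through the wrapper's `base_layer`.
#   "model.layers.0.self_attn.qkv_proj.base_layer" -> the first rsplit yields the LoRA wrapper,
#   so one more rsplit is needed to land on "model.layers.0.self_attn", where the original
#   q/k/v projections live.
name = "model.layers.0.self_attn.qkv_proj.base_layer"
parent, _ = name.rsplit(".", maxsplit=1)    # "model.layers.0.self_attn.qkv_proj"
parent, _ = parent.rsplit(".", maxsplit=1)  # "model.layers.0.self_attn"
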
@@ -275,6 +304,9 @@ def get_output_projection_qualified_names_after_qga_qkv_replacement(model: torch
    for name, mod in named_modules.items():
        if isinstance(mod, OptimumGQAQKVColumnParallelLinear):
            parent_name = name.rsplit(".", maxsplit=1)[0]
+             parent_module = named_modules[parent_name]
+             if isinstance(parent_module, LoraGQAQKVParallelLinear):
+                 parent_name, _ = parent_name.rsplit(".", maxsplit=1)
            output_projection_name = f"{parent_name}.{mod.output_proj_name}"
            qualified_names.add(f"{output_projection_name}.weight")
            if model.get_submodule(output_projection_name).bias is not None:
@@ -1361,19 +1393,64 @@ def inplace_linears_to_gqa_qkv_column_parallel_linear(
    key_is_peft_tuner = False
    value_is_peft_tuner = False
    if is_peft_available():
-         from peft.tuners.tuners_utils import BaseTunerLayer
+         from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer

        query_is_peft_tuner = isinstance(query_linear, BaseTunerLayer)
        key_is_peft_tuner = isinstance(key_linear, BaseTunerLayer)
        value_is_peft_tuner = isinstance(value_linear, BaseTunerLayer)

-     def get_parent_and_base_layer_in_tuner_layer(tuner_layer):
-         parent = tuner_layer
-         base_layer = tuner_layer
-         while hasattr(base_layer, "base_layer"):
-             parent = base_layer
-             base_layer = base_layer.base_layer
-         return parent, base_layer
+     # TODO: add a check that either all of the projections are PEFT tuner layers or none of them are.
+     if query_is_peft_tuner and key_is_peft_tuner and value_is_peft_tuner:
+         if isinstance(model, BaseTuner):
+             peft_model = model
+         elif hasattr(model, "_peft_model"):
+             peft_model = model._peft_model()
+         else:
+             raise RuntimeError(
+                 "`model` must be a `PeftModel` or have the `_peft_model` attribute to be able to retrieve the "
+                 "associated `PeftModel`."
+             )
+
+         new_module = None
+         for adapter_name in peft_model.base_model.active_adapters:
+             lora_config = peft_model.peft_config[adapter_name]
+             r = lora_config.r
+             lora_alpha = lora_config.lora_alpha
+             lora_dropout = lora_config.lora_dropout
+             if new_module is None:
+                 # TODO: add other keyword arguments
+                 new_module = LoraGQAQKVParallelLinear(
+                     gqa_qkv_column_parallel_linear,
+                     adapter_name,
+                     r=r,
+                     lora_alpha=lora_alpha,
+                     lora_dropout=lora_dropout,
+                 )
+                 setattr(attention_layer, gqa_qkv_proj_name, new_module)
+             else:
+                 new_module.update_layer(
+                     adapter_name,
+                     r,
+                     lora_alpha,
+                     lora_dropout,
+                 )
+
+         # peft_model._create_and_replace(
+         #     peft_model.peft_config[adapter_name],
+         #     adapter_name,
+         #     gqa_qkv_column_parallel_linear,
+         #     f"{attention_layer_qualified_name}.{gqa_qkv_proj_name}",
+         #     attention_layer,
+         #     attention_layer_qualified_name,
+         # )
+
+         # def get_parent_and_base_layer_in_tuner_layer(tuner_layer):
+         #     parent = tuner_layer
+         #     base_layer = tuner_layer
+         #     while hasattr(base_layer, "base_layer"):
+         #         parent = base_layer
+         #         base_layer = base_layer.base_layer
+         #     return parent, base_layer

    fake_q_proj = FakeProj(
        f"{attention_layer_qualified_name}.{queries_name}",
@@ -1383,7 +1460,7 @@ def get_parent_and_base_layer_in_tuner_layer(tuner_layer):
        attention_layer_qualified_name,
        gqa_qkv_proj_name,
    )
-     if query_is_peft_tuner:
+     if False:  # query_is_peft_tuner:
        parent, _ = get_parent_and_base_layer_in_tuner_layer(query_linear)
        setattr(parent, "base_layer", fake_q_proj)
        _parallelize_active_adapters(
@@ -1405,7 +1482,7 @@ def get_parent_and_base_layer_in_tuner_layer(tuner_layer):
        attention_layer_qualified_name,
        gqa_qkv_proj_name,
    )
-     if key_is_peft_tuner:
+     if False:  # key_is_peft_tuner:
        parent, _ = get_parent_and_base_layer_in_tuner_layer(key_linear)
        setattr(parent, "base_layer", fake_k_proj)
        _parallelize_active_adapters(
@@ -1427,7 +1504,7 @@ def get_parent_and_base_layer_in_tuner_layer(tuner_layer):
        attention_layer_qualified_name,
        gqa_qkv_proj_name,
    )
-     if value_is_peft_tuner:
+     if False:  # value_is_peft_tuner:
        parent, _ = get_parent_and_base_layer_in_tuner_layer(value_linear)
        setattr(parent, "base_layer", fake_v_proj)
        _parallelize_active_adapters(
@@ -1441,6 +1518,8 @@ def get_parent_and_base_layer_in_tuner_layer(tuner_layer):
    else:
        setattr(attention_layer, values_name, fake_v_proj)


@requires_neuronx_distributed
def delete_tensor_model_parallel_attributes(tensor: torch.Tensor):
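
Regarding the TODO in `inplace_linears_to_gqa_qkv_column_parallel_linear` about requiring the projections to be uniformly wrapped, one possible guard is sketched below; the exception type and message wording are assumptions:

# Sketch of the all-or-none consistency check mentioned in the TODO: the fused GQA QKV
# replacement only makes sense if the query, key and value projections are wrapped uniformly.
peft_flags = (query_is_peft_tuner, key_is_peft_tuner, value_is_peft_tuner)
if any(peft_flags) and not all(peft_flags):
    raise ValueError(
        "Either all of the query, key and value projections must be PEFT tuner layers, or none of them."
    )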