50
50
ResumeMemoryOccupationReqInput ,
51
51
UpdateWeightsFromTensorReqInput ,
52
52
)
53
- from sglang .srt .model_executor .model_runner import LocalSerializedTensor
54
53
from sglang .srt .utils import (
55
54
MultiprocessingSerializer ,
56
55
assert_pkg_version ,
@@ -397,49 +396,21 @@ def generate(self, query: List[Dict], is_eval: bool) -> List[Dict]:
397
396
self .flush_cache ()
398
397
return outputs
399
398
400
def update_weights_from_ipc_handles(self, reduce_data, load_format="flattened_bucket"):
    """Gather each rank's serialized weight bucket and load it into the engine.

    Every rank in the "tp" mesh contributes its locally serialized bucket
    (produced by ``MultiprocessingSerializer.serialize``); the gather
    destination is the first rank of the "tp" mesh, which then hands the
    gathered list to the SGLang engine via ``update_weights_from_tensor``.

    Args:
        reduce_data: This rank's serialized flattened-tensor bucket.
        load_format: Format tag forwarded to the engine. Kept as a keyword
            parameter (defaulting to ``"flattened_bucket"``) so callers that
            still pass ``load_format="flattened_bucket"`` explicitly keep
            working after the per-tensor code path was removed.
    """
    # gather_object requires object_gather_list to be non-None only on the
    # destination rank; NOTE(review): is_engine() is assumed to identify the
    # gather destination (first rank of the "tp" mesh) — confirm.
    gathered_data = [None] * self._tp_size if self.is_engine() else None
    dist.gather_object(
        obj=reduce_data,
        object_gather_list=gathered_data,
        dst=self.cpu_mesh["tp"].mesh.tolist()[0],
        group=self.cpu_mesh["tp"].get_group(),
    )
    if self.is_engine():
        self.llm.update_weights_from_tensor(
            named_tensors=gathered_data,
            load_format=load_format,
        )
    # Ensure the weight update has completed before the caller reuses or
    # frees the underlying CUDA buffers.
    torch.cuda.synchronize()
444
415
445
416
def flush_cache (self ):
@@ -567,6 +538,8 @@ def parameter_sync(self):
567
538
568
539
@torch .no_grad ()
569
540
def update_weights_from_buckets (self , buckets : List [Optional ['BucketInfo' ]]):
541
+ """Used for Mcore2SGLang Parameter Sync
542
+ """
570
543
from sglang .srt .patch_torch import monkey_patch_torch_reductions
571
544
monkey_patch_torch_reductions ()
572
545
param_id_to_update = set ()
@@ -584,10 +557,12 @@ def update_weights_from_buckets(self, buckets: List[Optional['BucketInfo']]):
584
557
for shard_idx , (offset , sharded_tensor_info ) in enumerate (bucket .recv_layout ):
585
558
param_id_to_bucket [sharded_tensor_info .param_id ].append ((bucket_idx , shard_idx ))
586
559
560
+ # 1-dim concated flattened tensor
587
561
buffer = None
588
562
buffer_offset = 0
589
563
buffer_size = 4 * 1024 ** 3
590
- metadatas = []
564
+ # metadata: name, shape, dtype, start_idx, end_idx, numel for every tensor item in buffer
565
+ metadatas : List [FlattenedTensorMetadata ] = []
591
566
for param_id in param_id_to_update :
592
567
param_name = self .param_id_to_local_name [param_id ]
593
568
shard_info = self .param_id_to_metadata [param_id ]
@@ -600,7 +575,7 @@ def update_weights_from_buckets(self, buckets: List[Optional['BucketInfo']]):
600
575
serialized_bucket = MultiprocessingSerializer .serialize (
601
576
bucket_dict , output_str = True
602
577
)
603
- self .update_weights_from_ipc_handles (serialized_bucket , load_format = "flattened_bucket" )
578
+ self .update_weights_from_ipc_handles (serialized_bucket )
604
579
buffer = torch .empty (buffer_size , dtype = shard_info .dtype , device = 'cuda' )
605
580
buffer_offset = 0
606
581
metadatas = []
@@ -630,7 +605,7 @@ def update_weights_from_buckets(self, buckets: List[Optional['BucketInfo']]):
630
605
serialized_bucket = MultiprocessingSerializer .serialize (
631
606
bucket_dict , output_str = True
632
607
)
633
- self .update_weights_from_ipc_handles (serialized_bucket , load_format = "flattened_bucket" )
608
+ self .update_weights_from_ipc_handles (serialized_bucket )
634
609
635
610
del buffer , weight , shard , bucket_dict
636
611
torch .cuda .synchronize ()
@@ -727,49 +702,22 @@ async def generate_per_request(self, query: Dict, is_eval: bool) -> Dict:
727
702
)
728
703
return outputs
729
704
730
async def update_weights_from_ipc_handles(self, reduce_data, load_format="flattened_bucket"):
    """Async variant: gather serialized weight buckets and load them into the engine.

    Mirrors the synchronous ``update_weights_from_ipc_handles`` except that the
    engine call is awaited. Every rank in the "tp" mesh contributes its locally
    serialized bucket; the gather destination (first rank of the "tp" mesh)
    forwards the gathered list to the SGLang engine.

    Args:
        reduce_data: This rank's serialized flattened-tensor bucket
            (from ``MultiprocessingSerializer.serialize``).
        load_format: Format tag forwarded to the engine. Kept as a keyword
            parameter (defaulting to ``"flattened_bucket"``) so callers that
            still pass it explicitly keep working.
    """
    # Only the destination rank supplies a receive list to gather_object;
    # NOTE(review): is_engine() is assumed to match the dst rank below — confirm.
    gathered_data = [None] * self._tp_size if self.is_engine() else None
    dist.gather_object(
        obj=reduce_data,
        object_gather_list=gathered_data,
        dst=self.cpu_mesh["tp"].mesh.tolist()[0],
        group=self.cpu_mesh["tp"].get_group(),
    )
    if self.is_engine():
        await self.llm.update_weights_from_tensor(
            named_tensors=gathered_data,
            load_format=load_format,
        )
    # Block until the update finishes so buffers can be safely reused/freed.
    torch.cuda.synchronize()
774
722
775
723
@torch .no_grad ()
@@ -807,7 +755,7 @@ async def update_weights_from_buckets(self, buckets: List[Optional['BucketInfo']
807
755
serialized_bucket = MultiprocessingSerializer .serialize (
808
756
bucket_dict , output_str = True
809
757
)
810
- await self .update_weights_from_ipc_handles (serialized_bucket , load_format = "flattened_bucket" )
758
+ await self .update_weights_from_ipc_handles (serialized_bucket )
811
759
buffer = torch .empty (buffer_size , dtype = shard_info .dtype , device = 'cuda' )
812
760
buffer_offset = 0
813
761
metadatas = []
@@ -837,7 +785,7 @@ async def update_weights_from_buckets(self, buckets: List[Optional['BucketInfo']
837
785
serialized_bucket = MultiprocessingSerializer .serialize (
838
786
bucket_dict , output_str = True
839
787
)
840
- await self .update_weights_from_ipc_handles (serialized_bucket , load_format = "flattened_bucket" )
788
+ await self .update_weights_from_ipc_handles (serialized_bucket )
841
789
842
790
del buffer , weight , shard , bucket_dict
843
791
torch .cuda .synchronize ()
0 commit comments