Skip to content

Commit 1d439b0

Browse files
committed
fix
1 parent f4b7526 commit 1d439b0

File tree

1 file changed

+4
-10
lines changed

1 file changed

+4
-10
lines changed

paddlenlp/trainer/trainer.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@
3131
import warnings
3232
from collections import OrderedDict
3333
from collections.abc import Mapping
34+
from copy import deepcopy
3435
from pathlib import Path
3536
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3637

37-
import deepcopy
3838
import numpy as np
3939
import paddle
4040
import paddle.amp.auto_cast as autocast
@@ -878,7 +878,7 @@ def _load_flex_checkpoint(self, resume_from_checkpoint):
878878
for k, v in optimizer_sharded_state_dict.items():
879879
v.local_tensor._clear_to_zero_allocation()
880880

881-
if isinstance(self.optimizer.inner_opt, DygraphShardingOptimizerV2):
881+
if isinstance(self.optimizer._inner_opt, DygraphShardingOptimizerV2):
882882
color_to_comm_buffer_list = self.optimizer._color_to_comm_buffer_list
883883
for color, _comm_buffer_list in color_to_comm_buffer_list.items():
884884
for comm_buffer in _comm_buffer_list:
@@ -954,16 +954,10 @@ def _load_flex_checkpoint(self, resume_from_checkpoint):
954954

955955
for k, v in optimizer_sharded_state_dict.items():
956956
source_tensor = optimizer_sharded_state_dict_pin[k]
957-
target_tensor = paddle.zeros_like(v.local_tensor)
958-
if source_tensor.place != target_tensor.place:
959-
source_tensor = source_tensor.to(target_tensor.place)
960-
paddle.assign(source_tensor, target_tensor)
961-
target_tensor_pin = target_tensor.cpu()
962-
del target_tensor
963-
target_tensor_pin._share_buffer_to(v.local_tensor)
957+
source_tensor._share_buffer_to(v.local_tensor)
964958
del source_tensor
965959

966-
if isinstance(self.optimizer.inner_opt, DygraphShardingOptimizerV2):
960+
if isinstance(self.optimizer._inner_opt, DygraphShardingOptimizerV2):
967961
color_to_comm_buffer_list = self.optimizer._color_to_comm_buffer_list
968962
for color, _comm_buffer_list in color_to_comm_buffer_list.items():
969963
for comm_buffer in _comm_buffer_list:

0 commit comments

Comments
 (0)