|
31 | 31 | import warnings
|
32 | 32 | from collections import OrderedDict
|
33 | 33 | from collections.abc import Mapping
|
| 34 | +from copy import deepcopy |
34 | 35 | from pathlib import Path
|
35 | 36 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
36 | 37 |
|
37 |
| -import deepcopy |
38 | 38 | import numpy as np
|
39 | 39 | import paddle
|
40 | 40 | import paddle.amp.auto_cast as autocast
|
@@ -878,7 +878,7 @@ def _load_flex_checkpoint(self, resume_from_checkpoint):
|
878 | 878 | for k, v in optimizer_sharded_state_dict.items():
|
879 | 879 | v.local_tensor._clear_to_zero_allocation()
|
880 | 880 |
|
881 |
| - if isinstance(self.optimizer.inner_opt, DygraphShardingOptimizerV2): |
| 881 | + if isinstance(self.optimizer._inner_opt, DygraphShardingOptimizerV2): |
882 | 882 | color_to_comm_buffer_list = self.optimizer._color_to_comm_buffer_list
|
883 | 883 | for color, _comm_buffer_list in color_to_comm_buffer_list.items():
|
884 | 884 | for comm_buffer in _comm_buffer_list:
|
@@ -954,16 +954,10 @@ def _load_flex_checkpoint(self, resume_from_checkpoint):
|
954 | 954 |
|
955 | 955 | for k, v in optimizer_sharded_state_dict.items():
|
956 | 956 | source_tensor = optimizer_sharded_state_dict_pin[k]
|
957 |
| - target_tensor = paddle.zeros_like(v.local_tensor) |
958 |
| - if source_tensor.place != target_tensor.place: |
959 |
| - source_tensor = source_tensor.to(target_tensor.place) |
960 |
| - paddle.assign(source_tensor, target_tensor) |
961 |
| - target_tensor_pin = target_tensor.cpu() |
962 |
| - del target_tensor |
963 |
| - target_tensor_pin._share_buffer_to(v.local_tensor) |
| 957 | + source_tensor._share_buffer_to(v.local_tensor) |
964 | 958 | del source_tensor
|
965 | 959 |
|
966 |
| - if isinstance(self.optimizer.inner_opt, DygraphShardingOptimizerV2): |
| 960 | + if isinstance(self.optimizer._inner_opt, DygraphShardingOptimizerV2): |
967 | 961 | color_to_comm_buffer_list = self.optimizer._color_to_comm_buffer_list
|
968 | 962 | for color, _comm_buffer_list in color_to_comm_buffer_list.items():
|
969 | 963 | for comm_buffer in _comm_buffer_list:
|
|
0 commit comments