|
10 | 10 | from torch.nn.parameter import Parameter
|
11 | 11 | from deepspeed.accelerator import get_accelerator
|
12 | 12 | from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
|
| 13 | +from deepspeed.runtime.zero.utils import is_zero_param |
13 | 14 | from abc import ABC, abstractmethod
|
14 | 15 | from typing import Iterable, Any, Optional, List, Tuple
|
15 | 16 | from .fusedqkv_utils import shard_value_with_share_qk, shard_chunk_mlp, prepare_tp_fused_qkvw
|
@@ -262,12 +263,13 @@ def __deepcopy__(self, memo):
|
262 | 263 | return new_obj
|
263 | 264 |
|
def extra_repr(self):
    """Build the summary string shown in this module's repr().

    Reports in/out features, whether a bias is present, and the weight
    dtype. When the weight is a ZeRO-partitioned parameter its local
    shape is meaningless, so the full shape is read from ``ds_shape``
    instead of ``shape``.
    """
    out_features, in_features, dtype = None, None, None
    if self.weight is not None:
        # ZeRO parameters expose the unpartitioned shape via ds_shape.
        weight_shape = self.weight.ds_shape if is_zero_param(self.weight) else self.weight.shape
        out_features, in_features = weight_shape[-2:]
        dtype = self.weight.dtype
    return f"in_features={in_features}, out_features={out_features}, bias={self.bias is not None}, dtype={dtype}"
271 | 273 |
|
272 | 274 | def move(self, tensor):
|
273 | 275 | # TODO: consider the timing of deletion
|
|
0 commit comments