Skip to content

Commit ec17a86

Browse files
author
pengcheng888
committed
issue/890 - 为python端的nn.module添加to函数
1 parent 3b5afff commit ec17a86

File tree

2 files changed

+117
-12
lines changed

2 files changed

+117
-12
lines changed

python/infinicore/nn/modules/module.py

Lines changed: 29 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232

3333
import infinicore
3434

35+
from ...device import device as InfiniCoreDevice
3536
from ...tensor import Tensor
3637
from ..parameter import InfiniCoreParameter as Parameter
3738

@@ -481,15 +482,14 @@ def _load_from_state_dict(
481482
f"While copying the parameter named {key}, expected Tensor from checkpoint but received {type(input_param)}"
482483
)
483484

484-
if (
485-
(param.shape == input_param.shape)
486-
and (param.dtype == input_param.dtype)
487-
and (param.device == input_param.device)
485+
if (param.shape == input_param.shape) and (
486+
param.dtype == input_param.dtype
488487
):
489488
param.copy_(input_param)
490489
else:
491-
print(f"param '{name}' don't match input_param '{key}'")
492-
setattr(self, name, input_param)
490+
raise KeyError(
491+
f"param '{name}' don't match input_param '{key}' with shape or dtype"
492+
)
493493

494494
elif strict:
495495
missing_keys.append(key)
@@ -848,10 +848,29 @@ def eval(self: T) -> T:
848848
Returns:
849849
Module: self
850850
"""
851-
pass
851+
raise KeyError("not support")
852852

853853
def _apply(self, fn, recurse=True):
854-
raise KeyError("not support")
854+
if recurse:
855+
for module in self.children():
856+
module._apply(fn)
855857

856-
def to(self, *args, **kwargs):
857-
raise KeyError("not support")
858+
for key, param in self._parameters.items():
859+
if param is not None:
860+
setattr(self, key, fn(param))
861+
862+
for key, buf in self._buffers.items():
863+
if buf is not None:
864+
setattr(self, key, fn(buf))
865+
866+
return self
867+
868+
def to(self, device: str | InfiniCoreDevice):
869+
if device is None:
870+
raise ValueError("device cannot be None")
871+
device = InfiniCoreDevice(device)
872+
873+
def convert(t):
874+
return t.to(device)
875+
876+
return self._apply(convert)

test/infinicore/nn/module.py

Lines changed: 88 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@ def __init__(self):
4444
def forward(self):
4545
return infinicore.add(self.a, self.b)
4646

47+
4748
infinicore_model_infer = InfiniCoreNet()
4849
# ============================================================
4950
# 2. 加载权重
@@ -75,6 +76,91 @@ def forward(self):
7576

7677

7778
# ============================================================
78-
# 5. to测试,buffer测试
79+
# 5. to测试 - 测试模型在不同设备间的转换
7980
# ============================================================
80-
# 等待添加
81+
print("\n" + "=" * 60)
82+
print("5. to测试 - 设备转换测试")
83+
print("=" * 60)
84+
85+
86+
def print_model_state(model, title="状态"):
87+
"""打印模型的参数状态"""
88+
print(f"\n{title}:")
89+
print("-" * 40)
90+
print("Parameters:")
91+
for name, param in model.named_parameters():
92+
print(
93+
f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}"
94+
)
95+
96+
97+
def verify_device_conversion(model, target_device, use_type_check=False):
98+
"""验证模型参数的设备转换"""
99+
print("转换后的Parameters:")
100+
for name, param in model.named_parameters():
101+
print(
102+
f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}"
103+
)
104+
if use_type_check:
105+
# 当使用字符串参数时,只检查设备类型
106+
expected_type = (
107+
target_device if isinstance(target_device, str) else target_device.type
108+
)
109+
assert param.device.type == expected_type, (
110+
f"参数 {name} 的设备转换失败: 期望类型 {expected_type}, 实际 {param.device.type}"
111+
)
112+
else:
113+
# 使用device对象时,进行完整比较
114+
assert param.device == target_device, (
115+
f"参数 {name} 的设备转换失败: 期望 {target_device}, 实际 {param.device}"
116+
)
117+
118+
119+
# 5.1 打印初始状态
120+
print_model_state(infinicore_model_infer, "5.1 初始状态")
121+
122+
# 定义设备转换测试用例列表
123+
device_conversion_cases = [
124+
{
125+
"name": "5.2 转换到CUDA设备",
126+
"description": "使用 infinicore.device('cuda', 0)",
127+
"target": infinicore.device("cuda", 0),
128+
"use_type_check": False,
129+
"success_msg": "✓ CUDA设备转换验证通过",
130+
},
131+
{
132+
"name": "5.3 转换到CPU设备",
133+
"description": "使用 infinicore.device('cpu', 0)",
134+
"target": infinicore.device("cpu", 0),
135+
"use_type_check": False,
136+
"success_msg": "✓ CPU设备转换验证通过",
137+
},
138+
{
139+
"name": "5.4 转换到CUDA设备",
140+
"description": "使用字符串 'cuda'",
141+
"target": "cuda",
142+
"use_type_check": True,
143+
"success_msg": "✓ 字符串参数设备转换验证通过",
144+
},
145+
]
146+
147+
# 循环测试每个设备转换用例
148+
for case in device_conversion_cases:
149+
print(f"\n{case['name']} ({case['description']}):")
150+
print("-" * 40)
151+
infinicore_model_infer.to(case["target"])
152+
verify_device_conversion(
153+
infinicore_model_infer, case["target"], use_type_check=case["use_type_check"]
154+
)
155+
print(case["success_msg"])
156+
157+
# 5.5 验证to方法返回self(链式调用支持)
158+
print("\n5.5 测试to方法的返回值(链式调用):")
159+
print("-" * 40)
160+
result = infinicore_model_infer.to(infinicore.device("cpu", 0))
161+
assert result is infinicore_model_infer, "to方法应该返回self以支持链式调用"
162+
print("✓ to方法返回值验证通过")
163+
164+
print("\n" + "=" * 60)
165+
print("所有to测试通过!")
166+
print("=" * 60 + "\n")

0 commit comments

Comments
 (0)