@@ -44,6 +44,7 @@ def __init__(self):
4444 def forward (self ):
4545 return infinicore .add (self .a , self .b )
4646
47+
4748infinicore_model_infer = InfiniCoreNet ()
4849# ============================================================
4950# 2. 加载权重
@@ -75,6 +76,91 @@ def forward(self):
7576
7677
7778# ============================================================
78- # 5. to测试,buffer测试
79+ # 5. to测试 - 测试模型在不同设备间的转换
7980# ============================================================
80- # 等待添加
81+ print ("\n " + "=" * 60 )
82+ print ("5. to测试 - 设备转换测试" )
83+ print ("=" * 60 )
84+
85+
86+ def print_model_state (model , title = "状态" ):
87+ """打印模型的参数状态"""
88+ print (f"\n { title } :" )
89+ print ("-" * 40 )
90+ print ("Parameters:" )
91+ for name , param in model .named_parameters ():
92+ print (
93+ f" { name } : shape={ param .shape } , dtype={ param .dtype } , device={ param .device } "
94+ )
95+
96+
97+ def verify_device_conversion (model , target_device , use_type_check = False ):
98+ """验证模型参数的设备转换"""
99+ print ("转换后的Parameters:" )
100+ for name , param in model .named_parameters ():
101+ print (
102+ f" { name } : shape={ param .shape } , dtype={ param .dtype } , device={ param .device } "
103+ )
104+ if use_type_check :
105+ # 当使用字符串参数时,只检查设备类型
106+ expected_type = (
107+ target_device if isinstance (target_device , str ) else target_device .type
108+ )
109+ assert param .device .type == expected_type , (
110+ f"参数 { name } 的设备转换失败: 期望类型 { expected_type } , 实际 { param .device .type } "
111+ )
112+ else :
113+ # 使用device对象时,进行完整比较
114+ assert param .device == target_device , (
115+ f"参数 { name } 的设备转换失败: 期望 { target_device } , 实际 { param .device } "
116+ )
117+
118+
119+ # 5.1 打印初始状态
120+ print_model_state (infinicore_model_infer , "5.1 初始状态" )
121+
122+ # 定义设备转换测试用例列表
123+ device_conversion_cases = [
124+ {
125+ "name" : "5.2 转换到CUDA设备" ,
126+ "description" : "使用 infinicore.device('cuda', 0)" ,
127+ "target" : infinicore .device ("cuda" , 0 ),
128+ "use_type_check" : False ,
129+ "success_msg" : "✓ CUDA设备转换验证通过" ,
130+ },
131+ {
132+ "name" : "5.3 转换到CPU设备" ,
133+ "description" : "使用 infinicore.device('cpu', 0)" ,
134+ "target" : infinicore .device ("cpu" , 0 ),
135+ "use_type_check" : False ,
136+ "success_msg" : "✓ CPU设备转换验证通过" ,
137+ },
138+ {
139+ "name" : "5.4 转换到CUDA设备" ,
140+ "description" : "使用字符串 'cuda'" ,
141+ "target" : "cuda" ,
142+ "use_type_check" : True ,
143+ "success_msg" : "✓ 字符串参数设备转换验证通过" ,
144+ },
145+ ]
146+
147+ # 循环测试每个设备转换用例
148+ for case in device_conversion_cases :
149+ print (f"\n { case ['name' ]} ({ case ['description' ]} ):" )
150+ print ("-" * 40 )
151+ infinicore_model_infer .to (case ["target" ])
152+ verify_device_conversion (
153+ infinicore_model_infer , case ["target" ], use_type_check = case ["use_type_check" ]
154+ )
155+ print (case ["success_msg" ])
156+
157+ # 5.5 验证to方法返回self(链式调用支持)
158+ print ("\n 5.5 测试to方法的返回值(链式调用):" )
159+ print ("-" * 40 )
160+ result = infinicore_model_infer .to (infinicore .device ("cpu" , 0 ))
161+ assert result is infinicore_model_infer , "to方法应该返回self以支持链式调用"
162+ print ("✓ to方法返回值验证通过" )
163+
164+ print ("\n " + "=" * 60 )
165+ print ("所有to测试通过!" )
166+ print ("=" * 60 + "\n " )
0 commit comments