@@ -304,121 +304,120 @@ Array<ScheduleRule> ScheduleRule::DefaultHexagon() {
304304 };
305305}
306306
307- int GetVLMAX (int vlen, int lmul, int max_sew) {
308- return (lmul * vlen) / max_sew;
309- }
307+ int GetVLMAX (int vlen, int lmul, int max_sew) { return (lmul * vlen) / max_sew; }
310308
311309Array<ScheduleRule> ScheduleRule::DefaultRISCV (int vlen) {
312- Array<ScheduleRule> rules;
313-
314- rules.push_back (ScheduleRule::ApplyCustomRule ());
315-
316- rules.push_back (ScheduleRule::InlineConstantScalars ());
317-
318- rules.push_back (ScheduleRule::AutoInline (
319- /* into_producer=*/ false ,
320- /* into_consumer=*/ true ,
321- /* inline_const_tensor=*/ true ,
322- /* disallow_if_then_else=*/ true ,
323- /* require_injective=*/ true ,
324- /* require_ordered=*/ true ,
325- /* disallow_op=*/ Array<String>{" tir.exp" }));
326-
327- rules.push_back (ScheduleRule::AddRFactor (
328- /* max_jobs_per_core=*/ 16 ,
329- /* max_innermost_factor=*/ Integer (64 )));
330-
331- int vlmax = 0 ;
332- int RISCV_MIN_VL = 4 ;
333- std::vector<std::string> vmul_types = {" multivmul" , " vmul" , " vmacc" };
334- String intrin_name = " " ;
335- int j = 1 ;
336-
337- for (const std::string& vmul_type : vmul_types) {
338- if (vmul_type == " multivmul" ) j = GetVLMAX (vlen, 1 , 32 );
339- else
340- j = 1 ;
341-
342- // Registering for int16
343- vlmax = GetVLMAX (vlen, 8 , 32 );
344- while (vlmax >= RISCV_MIN_VL) {
345- intrin_name = " rvv_int16_" + vmul_type + " _" + \
346- std::to_string (j) + " _" + std::to_string (vlmax) + " _m8" ;
347- rules.push_back (ScheduleRule::MultiLevelTilingWithIntrin (
348- /* intrin_name=*/ intrin_name,
349- /* structure=*/ " SSRSRS" ,
350- /* tile_binds=*/ std::nullopt ,
351- /* max_innermost_factor=*/ Integer (vlmax),
352- /* vector_load_lens=*/ std::nullopt ,
353- /* reuse_read=*/ std::nullopt ,
354- /* reuse_write=*/
355- Map<String, ffi::Any>{ {" req" , String (" may" )},
356- {" levels" , Array<Integer>{1 , 2 }},
357- {" scope" , String (" global" )}}));
358- vlmax /= 2 ;
359- }
360-
361- // Registering for float16
362- vlmax = GetVLMAX (vlen, 8 , 16 );
363- if (vmul_type == " multivmul" ) j = GetVLMAX (vlen, 1 , 32 );
364- else
365- j = 1 ;
366-
367- while (vlmax >= RISCV_MIN_VL) {
368- intrin_name = " rvv_float16_" + vmul_type + " _" + \
369- std::to_string (j) + " _" + std::to_string (vlmax) + " _m8" ;
370- rules.push_back (ScheduleRule::MultiLevelTilingWithIntrin (
371- /* intrin_name=*/ intrin_name,
372- /* structure=*/ " SSRSRS" ,
373- /* tile_binds=*/ std::nullopt ,
374- /* max_innermost_factor=*/ Integer (vlmax),
375- /* vector_load_lens=*/ std::nullopt ,
376- /* reuse_read=*/ std::nullopt ,
377- /* reuse_write=*/
378- Map<String, ffi::Any>{ {" req" , String (" may" )},
379- {" levels" , Array<Integer>{1 , 2 }},
380- {" scope" , String (" global" )}}));
381- vlmax /= 2 ;
382- }
383-
384- vlmax = GetVLMAX (vlen, 8 , 32 );
385- while (vlmax >= RISCV_MIN_VL) {
386- intrin_name = " rvv_float32_" + vmul_type + " _" + \
387- std::to_string (j) + " _" + std::to_string (vlmax) + " _m8" ;
388- rules.push_back (ScheduleRule::MultiLevelTilingWithIntrin (
389- /* intrin_name=*/ intrin_name,
390- /* structure=*/ " SSRSRS" ,
391- /* tile_binds=*/ std::nullopt ,
392- /* max_innermost_factor=*/ Integer (vlmax),
393- /* vector_load_lens=*/ std::nullopt ,
394- /* reuse_read=*/ std::nullopt ,
395- /* reuse_write=*/
396- Map<String, ffi::Any>{ {" req" , String (" may" )},
397- {" levels" , Array<Integer>{1 , 2 }},
398- {" scope" , String (" global" )}}));
399- vlmax /= 2 ;
400- }
310+ Array<ScheduleRule> rules;
311+
312+ rules.push_back (ScheduleRule::ApplyCustomRule ());
313+
314+ rules.push_back (ScheduleRule::InlineConstantScalars ());
315+
316+ rules.push_back (ScheduleRule::AutoInline (
317+ /* into_producer=*/ false ,
318+ /* into_consumer=*/ true ,
319+ /* inline_const_tensor=*/ true ,
320+ /* disallow_if_then_else=*/ true ,
321+ /* require_injective=*/ true ,
322+ /* require_ordered=*/ true ,
323+ /* disallow_op=*/ Array<String>{" tir.exp" }));
324+
325+ rules.push_back (ScheduleRule::AddRFactor (
326+ /* max_jobs_per_core=*/ 16 ,
327+ /* max_innermost_factor=*/ Integer (64 )));
328+
329+ int vlmax = 0 ;
330+ int RISCV_MIN_VL = 4 ;
331+ std::vector<std::string> vmul_types = {" multivmul" , " vmul" , " vmacc" };
332+ String intrin_name = " " ;
333+ int j = 1 ;
334+
335+ for (const std::string& vmul_type : vmul_types) {
336+ if (vmul_type == " multivmul" )
337+ j = GetVLMAX (vlen, 1 , 32 );
338+ else
339+ j = 1 ;
340+
341+ // Registering for int16
342+ vlmax = GetVLMAX (vlen, 8 , 32 );
343+ while (vlmax >= RISCV_MIN_VL) {
344+ intrin_name =
345+ " rvv_int16_" + vmul_type + " _" + std::to_string (j) + " _" + std::to_string (vlmax) + " _m8" ;
346+ rules.push_back (ScheduleRule::MultiLevelTilingWithIntrin (
347+ /* intrin_name=*/ intrin_name,
348+ /* structure=*/ " SSRSRS" ,
349+ /* tile_binds=*/ std::nullopt ,
350+ /* max_innermost_factor=*/ Integer (vlmax),
351+ /* vector_load_lens=*/ std::nullopt ,
352+ /* reuse_read=*/ std::nullopt ,
353+ /* reuse_write=*/
354+ Map<String, ffi::Any>{{" req" , String (" may" )},
355+ {" levels" , Array<Integer>{1 , 2 }},
356+ {" scope" , String (" global" )}}));
357+ vlmax /= 2 ;
401358 }
402- rules.push_back (ScheduleRule::MultiLevelTiling (
403- /* structure=*/ " SSRSRS" ,
404- /* tile_binds=*/ std::nullopt ,
405- /* max_innermost_factor=*/ Integer (64 ),
406- /* vector_load_lens=*/ std::nullopt ,
407- /* reuse_read=*/ std::nullopt ,
408- /* reuse_write=*/
409- Map<String, ffi::Any>{ {" req" , String (" may" )},
410- {" levels" , Array<Integer>{1 , 2 }},
411- {" scope" , String (" global" )}}));
412-
413- rules.push_back (ScheduleRule::ParallelizeVectorizeUnroll (
414- /* max_jobs_per_core=*/ 16 ,
415- /* max_vectorize_extent=*/ 64 ,
416- /* unroll_max_steps=*/ Array<Integer>{0 , 16 , 64 , 512 },
417- /* unroll_explicit=*/ true ));
418359
419- rules.push_back (ScheduleRule::RandomComputeLocation ());
360+ // Registering for float16
361+ vlmax = GetVLMAX (vlen, 8 , 16 );
362+ if (vmul_type == " multivmul" )
363+ j = GetVLMAX (vlen, 1 , 32 );
364+ else
365+ j = 1 ;
366+
367+ while (vlmax >= RISCV_MIN_VL) {
368+ intrin_name = " rvv_float16_" + vmul_type + " _" + std::to_string (j) + " _" +
369+ std::to_string (vlmax) + " _m8" ;
370+ rules.push_back (ScheduleRule::MultiLevelTilingWithIntrin (
371+ /* intrin_name=*/ intrin_name,
372+ /* structure=*/ " SSRSRS" ,
373+ /* tile_binds=*/ std::nullopt ,
374+ /* max_innermost_factor=*/ Integer (vlmax),
375+ /* vector_load_lens=*/ std::nullopt ,
376+ /* reuse_read=*/ std::nullopt ,
377+ /* reuse_write=*/
378+ Map<String, ffi::Any>{{" req" , String (" may" )},
379+ {" levels" , Array<Integer>{1 , 2 }},
380+ {" scope" , String (" global" )}}));
381+ vlmax /= 2 ;
382+ }
420383
421- return rules;
384+ vlmax = GetVLMAX (vlen, 8 , 32 );
385+ while (vlmax >= RISCV_MIN_VL) {
386+ intrin_name = " rvv_float32_" + vmul_type + " _" + std::to_string (j) + " _" +
387+ std::to_string (vlmax) + " _m8" ;
388+ rules.push_back (ScheduleRule::MultiLevelTilingWithIntrin (
389+ /* intrin_name=*/ intrin_name,
390+ /* structure=*/ " SSRSRS" ,
391+ /* tile_binds=*/ std::nullopt ,
392+ /* max_innermost_factor=*/ Integer (vlmax),
393+ /* vector_load_lens=*/ std::nullopt ,
394+ /* reuse_read=*/ std::nullopt ,
395+ /* reuse_write=*/
396+ Map<String, ffi::Any>{{" req" , String (" may" )},
397+ {" levels" , Array<Integer>{1 , 2 }},
398+ {" scope" , String (" global" )}}));
399+ vlmax /= 2 ;
400+ }
401+ }
402+ rules.push_back (ScheduleRule::MultiLevelTiling (
403+ /* structure=*/ " SSRSRS" ,
404+ /* tile_binds=*/ std::nullopt ,
405+ /* max_innermost_factor=*/ Integer (64 ),
406+ /* vector_load_lens=*/ std::nullopt ,
407+ /* reuse_read=*/ std::nullopt ,
408+ /* reuse_write=*/
409+ Map<String, ffi::Any>{
410+ {" req" , String (" may" )}, {" levels" , Array<Integer>{1 , 2 }}, {" scope" , String (" global" )}}));
411+
412+ rules.push_back (ScheduleRule::ParallelizeVectorizeUnroll (
413+ /* max_jobs_per_core=*/ 16 ,
414+ /* max_vectorize_extent=*/ 64 ,
415+ /* unroll_max_steps=*/ Array<Integer>{0 , 16 , 64 , 512 },
416+ /* unroll_explicit=*/ true ));
417+
418+ rules.push_back (ScheduleRule::RandomComputeLocation ());
419+
420+ return rules;
422421}
423422
424423Array<ScheduleRule> GetARMNeonSpecificRules () {
0 commit comments