2020**Author**: `Federico Peccia <https://fPecc.github.io/>`_
2121"""
2222import re
23+ import logging
2324from tvm .script import tir as T
24- from tvm .target .datatype import lower_call_pure_extern , register , register_op
25+ from tvm .target .codegen import llvm_get_vector_width
2526from .. import TensorIntrin
2627
28+ logger = logging .getLogger (__name__ )
29+
2730#####################################################
2831# LLVM RISC-V Intrinsic usage:
2932# https://llvm.org/docs/RISCV/RISCVVectorExtension.html
@@ -327,7 +330,7 @@ def rvv_multivmul(J: int, vlmax: int, input_dtype: str, output_dtype: str, lmul:
327330 @T .prim_func
328331 def rvv_multivmul_desc (
329332 A : T .Buffer ((int (vlmax ),), input_dtype , align = 4 , offset_factor = 1 ),
330- B : T .Buffer ((J , int (vlmax )), kernel_dtype , align = 4 , offset_factor = 1 ),
333+ B : T .Buffer ((J , int (vlmax )), input_dtype , align = 4 , offset_factor = 1 ),
331334 C : T .Buffer ((J ,), output_dtype , align = 4 , offset_factor = 1 ),
332335 ) -> None :
333336 with T .block ("root" ):
@@ -345,7 +348,7 @@ def rvv_multivmul_desc(
345348 def rvv_multivmul_llvm_impl (
346349 A : T .Buffer ((int (vlmax ),), input_dtype , align = 4 , offset_factor = 1 ),
347350 B : T .Buffer (
348- (J , int (vlmax )), kernel_dtype , align = 4 , offset_factor = 1 , strides = [T .int32 (), T .int32 ()]
351+ (J , int (vlmax )), input_dtype , align = 4 , offset_factor = 1 , strides = [T .int32 (), T .int32 ()]
349352 ),
350353 C : T .Buffer ((J ,), output_dtype , align = 4 , offset_factor = 1 ),
351354 ) -> None :
@@ -530,7 +533,7 @@ def rvv_vmul(J: int, vlmax: int, input_dtype: str, output_dtype: str, lmul: int)
530533 @T .prim_func
531534 def rvv_vmul_desc (
532535 A : T .Buffer ((int (vlmax ),), input_dtype , align = 4 , offset_factor = 1 ),
533- B : T .Buffer ((int (vlmax ),), kernel_dtype , align = 4 , offset_factor = 1 ),
536+ B : T .Buffer ((int (vlmax ),), input_dtype , align = 4 , offset_factor = 1 ),
534537 C : T .Buffer ((1 ,), output_dtype , align = 4 , offset_factor = 1 ),
535538 ) -> None :
536539 with T .block ("root" ):
@@ -544,7 +547,7 @@ def rvv_vmul_desc(
544547 @T .prim_func
545548 def rvv_vmul_llvm_impl (
546549 A : T .Buffer ((int (vlmax ),), input_dtype , align = 4 , offset_factor = 1 ),
547- B : T .Buffer ((int (vlmax ),), kernel_dtype , align = 4 , offset_factor = 1 ),
550+ B : T .Buffer ((int (vlmax ),), input_dtype , align = 4 , offset_factor = 1 ),
548551 C : T .Buffer ((1 ,), output_dtype , align = 4 , offset_factor = 1 ),
549552 ) -> None :
550553
@@ -690,7 +693,7 @@ def register_intrinsic_combinations(
690693
691694 desc , impl = generator (J , current_vlmax , input_dtype , output_dtype , lmul )
692695
693- print (f"Registering intrin { name } ..." )
696+ logger . debug (f"Registering intrin { name } ..." )
694697
695698 TensorIntrin .register (name , desc , impl , override = True )
696699
@@ -701,33 +704,15 @@ def register_riscv_tensor_intrinsics(target):
701704 target_kind = target .kind .name
702705 assert target_kind in ["llvm" ]
703706
704- #####################################################
705- # Register custom RVV types for C code generation
706- #####################################################
707- dtype_counter = 0
708- for bits in [8 , 16 , 32 , 64 ]:
709- for dtype in ["int" , "uint" , "float" ]:
710- for m in [1 , 2 , 4 , 8 ]:
711- custom_rvv_type = f"v{ dtype } { bits } m{ m } _t"
712- register (custom_rvv_type , 150 + dtype_counter )
713- register_op (
714- lower_call_pure_extern ,
715- "Call" ,
716- "c" ,
717- custom_rvv_type ,
718- intrinsic_name = "tir.call_pure_extern" ,
719- )
720- dtype_counter += 1
721-
722- vlen = get_vlen_from_mattrs (target .mattr )
707+ vlen = llvm_get_vector_width (target )
723708
724709 for vmul_type , func , outer_loops in zip (
725710 ["vmacc" , "multivmul" , "vmul" ],
726711 [rvv_vmacc , rvv_multivmul , rvv_vmul ],
727712 [[1 ], [get_vlmax (vlen , lmul = 1 , max_sew = 32 )], [1 ]],
728713 ):
729714
730- for idtype , odtype in zip (["int16" , "float32" ], ["int32" , "float32" ]):
715+ for idtype , odtype in zip (["int16" , "float16" , " float32" ], ["int32" , "float16 " , "float32" ]):
731716
732717 if idtype == "float32" and vmul_type == "multivmul" :
733718 continue
0 commit comments