Skip to content

Commit 49d9a0c

Browse files
committed
Fixes/changes based on comments on PR
1 parent 918d22e commit 49d9a0c

File tree

4 files changed

+13
-34
lines changed

4 files changed

+13
-34
lines changed

python/tvm/meta_schedule/tune_context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from tvm.runtime import Object
2929
from tvm.target import Target
3030
from tvm.tir import PrimFunc, Schedule
31+
from tvm.target.codegen import target_has_features
3132

3233
from . import _ffi_api
3334
from .logging import Logger, get_logger, get_logging_func
@@ -118,8 +119,7 @@ def __init__(
118119
if not isinstance(target, Target):
119120
target = Target(target)
120121
if "riscv_cpu" in target.keys:
121-
base_features = str(target.attrs["march"]).split("_")[0].replace("rv", "")
122-
if "v" in base_features:
122+
if target_has_features("v", target):
123123
# Because the RVV intrinsics depend on the target, we register them here
124124
# pylint: disable=import-outside-toplevel
125125
from tvm.tir.tensor_intrin.riscv_cpu import register_riscv_tensor_intrinsics

python/tvm/target/target.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -638,16 +638,12 @@ def riscv_cpu(model="sifive-u54", options=None):
638638
# cc: riscv64-unknown-linux-gnu-g++ -march=rv64gc -mabi=lp64d -mcpu=sifive-u74
639639
],
640640
"bpi-f3": [
641-
# "-model=sifive-u74",
642641
"-mtriple=riscv64-unknown-linux-gnu",
643642
"-mcpu=generic",
644-
# "-march=rv64gcv_zvl256b",
645-
# "-mcpu=generic-rv64",
646643
"-mfloat-abi=hard",
647644
"-num-cores=8",
648645
"-mabi=lp64d",
649646
"-mattr=+v,+zvl256b",
650-
# cc: riscv64-unknown-linux-gnu-g++ -march=rv64gc -mabi=lp64d -mcpu=generic -mattr=+v
651647
],
652648
}
653649
pre_defined_opt = trans_table.get(model, ["-model=%s" % model])

python/tvm/tir/tensor_intrin/riscv_cpu.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@
2020
**Author**: `Federico Peccia <https://fPecc.github.io/>`_
2121
"""
2222
import re
23+
import logging
2324
from tvm.script import tir as T
24-
from tvm.target.datatype import lower_call_pure_extern, register, register_op
25+
from tvm.target.codegen import llvm_get_vector_width
2526
from .. import TensorIntrin
2627

28+
logger = logging.getLogger(__name__)
29+
2730
#####################################################
2831
# LLVM RISC-V Intrinsic usage:
2932
# https://llvm.org/docs//RISCV/RISCVVectorExtension.html
@@ -327,7 +330,7 @@ def rvv_multivmul(J: int, vlmax: int, input_dtype: str, output_dtype: str, lmul:
327330
@T.prim_func
328331
def rvv_multivmul_desc(
329332
A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1),
330-
B: T.Buffer((J, int(vlmax)), kernel_dtype, align=4, offset_factor=1),
333+
B: T.Buffer((J, int(vlmax)), input_dtype, align=4, offset_factor=1),
331334
C: T.Buffer((J,), output_dtype, align=4, offset_factor=1),
332335
) -> None:
333336
with T.block("root"):
@@ -345,7 +348,7 @@ def rvv_multivmul_desc(
345348
def rvv_multivmul_llvm_impl(
346349
A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1),
347350
B: T.Buffer(
348-
(J, int(vlmax)), kernel_dtype, align=4, offset_factor=1, strides=[T.int32(), T.int32()]
351+
(J, int(vlmax)), input_dtype, align=4, offset_factor=1, strides=[T.int32(), T.int32()]
349352
),
350353
C: T.Buffer((J,), output_dtype, align=4, offset_factor=1),
351354
) -> None:
@@ -530,7 +533,7 @@ def rvv_vmul(J: int, vlmax: int, input_dtype: str, output_dtype: str, lmul: int)
530533
@T.prim_func
531534
def rvv_vmul_desc(
532535
A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1),
533-
B: T.Buffer((int(vlmax),), kernel_dtype, align=4, offset_factor=1),
536+
B: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1),
534537
C: T.Buffer((1,), output_dtype, align=4, offset_factor=1),
535538
) -> None:
536539
with T.block("root"):
@@ -544,7 +547,7 @@ def rvv_vmul_desc(
544547
@T.prim_func
545548
def rvv_vmul_llvm_impl(
546549
A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1),
547-
B: T.Buffer((int(vlmax),), kernel_dtype, align=4, offset_factor=1),
550+
B: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1),
548551
C: T.Buffer((1,), output_dtype, align=4, offset_factor=1),
549552
) -> None:
550553

@@ -690,7 +693,7 @@ def register_intrinsic_combinations(
690693

691694
desc, impl = generator(J, current_vlmax, input_dtype, output_dtype, lmul)
692695

693-
print(f"Registering intrin {name}...")
696+
logger.debug(f"Registering intrin {name}...")
694697

695698
TensorIntrin.register(name, desc, impl, override=True)
696699

@@ -701,33 +704,15 @@ def register_riscv_tensor_intrinsics(target):
701704
target_kind = target.kind.name
702705
assert target_kind in ["llvm"]
703706

704-
#####################################################
705-
# Register custom RVV types for C code generation
706-
#####################################################
707-
dtype_counter = 0
708-
for bits in [8, 16, 32, 64]:
709-
for dtype in ["int", "uint", "float"]:
710-
for m in [1, 2, 4, 8]:
711-
custom_rvv_type = f"v{dtype}{bits}m{m}_t"
712-
register(custom_rvv_type, 150 + dtype_counter)
713-
register_op(
714-
lower_call_pure_extern,
715-
"Call",
716-
"c",
717-
custom_rvv_type,
718-
intrinsic_name="tir.call_pure_extern",
719-
)
720-
dtype_counter += 1
721-
722-
vlen = get_vlen_from_mattrs(target.mattr)
707+
vlen = llvm_get_vector_width(target)
723708

724709
for vmul_type, func, outer_loops in zip(
725710
["vmacc", "multivmul", "vmul"],
726711
[rvv_vmacc, rvv_multivmul, rvv_vmul],
727712
[[1], [get_vlmax(vlen, lmul=1, max_sew=32)], [1]],
728713
):
729714

730-
for idtype, odtype in zip(["int16", "float32"], ["int32", "float32"]):
715+
for idtype, odtype in zip(["int16", "float16", "float32"], ["int32", "float16", "float32"]):
731716

732717
if idtype == "float32" and vmul_type == "multivmul":
733718
continue

src/target/source/codegen_c.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,6 @@ std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExp
268268
<< " + " << index_str << " / " << div_factor << ")";
269269
} else if (t == buffer_element_dtype) {
270270
os << buffer_str << "[" << index_str << "]";
271-
} else if (t == buffer_element_dtype) {
272-
os << buffer_str << "[" << index_str << "]";
273271
} else {
274272
os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")";
275273
}

0 commit comments

Comments
 (0)