Skip to content

Commit 918d22e

Browse files
committed
Added diff proposal to reuse target.llvm_get_vector_width
1 parent b558b12 commit 918d22e

File tree

4 files changed

+7
-64
lines changed

4 files changed

+7
-64
lines changed

src/meta_schedule/space_generator/space_generator.cc

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@
1818
*/
1919
#include <tvm/ffi/reflection/registry.h>
2020

21-
#include "../../runtime/regex.h"
2221
#include "../../target/parsers/aprofile.h"
23-
#include "../../target/parsers/cpu.h"
2422
#include "../utils.h"
2523

2624
namespace tvm {
@@ -41,13 +39,14 @@ String GetRuleKindFromTarget(const Target& target) {
4139
return "avx512";
4240
}
4341
}
42+
bool have_rvv = target_has_feature_fn_ptr("v", target).cast<bool>();
43+
if (have_rvv) {
44+
return "rvv";
45+
}
4446

4547
TargetJSON target_json = target::parsers::aprofile::ParseTarget(target->Export());
4648
TargetFeatures afeatures = Downcast<TargetFeatures>(target_json.at("features"));
4749

48-
if (Downcast<Bool>(afeatures.at("has_rvv"))) {
49-
return "rvv";
50-
}
5150
if (Downcast<Bool>(afeatures.at("has_dotprod"))) {
5251
return "dotprod";
5352
}
@@ -88,28 +87,10 @@ String GetRuleKindFromTarget(const Target& target) {
8887
throw;
8988
}
9089

91-
std::string GetRISCVMarchFromTarget(const Target& target) {
92-
if (target->kind->name == "c") {
93-
if (Optional<String> opt_march = target->GetAttr<String>("march")) {
94-
return opt_march.value();
95-
}
96-
}
97-
return "";
98-
}
99-
100-
int GetRISCVVLENFromCTarget(const Target& target) {
101-
auto march = GetRISCVMarchFromTarget(target);
102-
int vlen = 0;
103-
if (march.find("zvl") != std::string::npos) {
104-
vlen = tvm::target::parsers::cpu::extractVLENFromString(march);
105-
}
106-
return vlen;
107-
}
108-
10990
int GetRISCVVLENFromLLVMTarget(const Target& target) {
110-
TargetJSON target_json = target::parsers::aprofile::ParseTarget(target->Export());
111-
TargetFeatures afeatures = Downcast<TargetFeatures>(target_json.at("features"));
112-
int vlen = Downcast<Integer>(afeatures.at("rvv_vlen"))->value;
91+
static auto llvm_get_vector_width_fn =
92+
tvm::ffi::Function::GetGlobalRequired("target.llvm_get_vector_width");
93+
const int vlen = llvm_get_vector_width_fn(target).cast<int>();
11394
return vlen;
11495
}
11596

src/target/parsers/aprofile.cc

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@
2727
#include <memory>
2828
#include <string>
2929

30-
#include "../../runtime/regex.h"
3130
#include "../../support/utils.h"
3231
#include "../llvm/llvm_instance.h"
33-
#include "cpu.h"
3432

3533
namespace tvm {
3634
namespace target {
@@ -82,16 +80,6 @@ bool CheckContains(Array<String> array, String predicate) {
8280
return std::any_of(array.begin(), array.end(), [&](String var) { return var == predicate; });
8381
}
8482

85-
int FindRISCVVLEN(Map<String, String> features) {
86-
int vlen = 128;
87-
for (auto const& feature : features) {
88-
std::string feature_str = Downcast<String>(feature.first);
89-
if (feature_str.find("zvl") != std::string::npos) {
90-
vlen = tvm::target::parsers::cpu::extractVLENFromString(feature_str);
91-
}
92-
}
93-
return vlen;
94-
}
9583

9684
static TargetFeatures GetFeatures(TargetJSON target) {
9785
#ifdef TVM_LLVM_VERSION
@@ -122,8 +110,6 @@ static TargetFeatures GetFeatures(TargetJSON target) {
122110
return {{"is_aarch64", Bool(IsAArch64(mtriple))},
123111
{"has_asimd", Bool(has_feature("neon"))},
124112
{"has_sve", Bool(has_feature("sve"))},
125-
{"has_rvv", Bool(has_feature("v"))},
126-
{"rvv_vlen", Integer(FindRISCVVLEN(features))},
127113
{"has_dotprod", Bool(has_feature("dotprod"))},
128114
{"has_matmul_i8", Bool(has_feature("i8mm"))},
129115
{"has_fp16_simd", Bool(has_feature("fullfp16"))},

src/target/parsers/cpu.cc

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -60,28 +60,6 @@ TargetJSON ParseTarget(TargetJSON target) {
6060
return target;
6161
}
6262

63-
int extractVLENFromString(const std::string& input) {
64-
for (size_t i = 0; i + 4 <= input.size(); ++i) {
65-
// Look for the starting sequence "zvl"
66-
if (input[i] == 'z' && input[i + 1] == 'v' && input[i + 2] == 'l') {
67-
size_t j = i + 3;
68-
std::string number;
69-
70-
// Collect digits
71-
while (j < input.size() && std::isdigit(input[j])) {
72-
number += input[j];
73-
++j;
74-
}
75-
76-
// Check if followed by 'b' after digits
77-
if (!number.empty() && j < input.size() && input[j] == 'b') {
78-
return std::stoi(number); // Convert the number to int
79-
}
80-
}
81-
}
82-
83-
throw std::runtime_error("No valid pattern found");
84-
}
8563

8664
} // namespace cpu
8765
} // namespace parsers

src/target/parsers/cpu.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,13 @@
2727

2828
#include <tvm/target/target.h>
2929

30-
#include <string>
3130

3231
namespace tvm {
3332
namespace target {
3433
namespace parsers {
3534
namespace cpu {
3635

3736
TargetJSON ParseTarget(TargetJSON target);
38-
int extractVLENFromString(const std::string& input);
3937

4038
} // namespace cpu
4139
} // namespace parsers

0 commit comments

Comments
 (0)