54 commits (changes shown from 7 commits)
ef95183
3rdparty tvm bump
LeiWang1999 Oct 22, 2025
3c175e4
bump tvm into v0.22.0
LeiWang1999 Oct 22, 2025
951f2de
lint fix
LeiWang1999 Oct 22, 2025
6d29c1e
rebase tvm
LeiWang1999 Oct 23, 2025
21e7a0a
Update submodule tvm to latest commit 3085bc4
LeiWang1999 Oct 23, 2025
3877610
Refactor: Update configuration retrieval in CopyNode and adjust test …
LeiWang1999 Oct 23, 2025
cd6daaf
lint fix
LeiWang1999 Oct 23, 2025
995315e
test fix
LeiWang1999 Oct 23, 2025
7ef5d01
add requirement
LeiWang1999 Oct 24, 2025
fdf4669
atomic_fix
LeiWang1999 Oct 24, 2025
9ecf41a
atomic_fix
LeiWang1999 Oct 24, 2025
68b6ada
phaseout py39
LeiWang1999 Oct 25, 2025
dc12ebc
optimize
LeiWang1999 Oct 25, 2025
751dbc7
optimize
LeiWang1999 Oct 25, 2025
fde6a50
Merge branch 'main' of https://github.com/tile-ai/tilelang into tvm_r…
LeiWang1999 Oct 25, 2025
38f7e49
lint fix
LeiWang1999 Oct 25, 2025
46dfea1
do not clean cache
LeiWang1999 Oct 26, 2025
33d6ad1
do not clean cache
LeiWang1999 Oct 26, 2025
c768919
Merge branch 'main' of https://github.com/tile-ai/tilelang into tvm_r…
LeiWang1999 Oct 27, 2025
f9a97c7
[Minor] Minor update for Python versions and dependencies
XuehaiPan Oct 27, 2025
c9d64fa
[Lint] fix lint for py39
XuehaiPan Oct 27, 2025
89df129
Merge remote-tracking branch 'upstream/main' into tvm_rebase
XuehaiPan Oct 27, 2025
4abf1d0
[Lint] fix lint for ROCm
XuehaiPan Oct 27, 2025
452e40f
[Build][CI] Sync CI changes from upstream/sdist
XuehaiPan Oct 27, 2025
b9306ab
[Lint] fix lint for ROCm
XuehaiPan Oct 27, 2025
40a6138
[Build][CI] Update `repair-wheel-command`
XuehaiPan Oct 27, 2025
565fa97
Merge remote-tracking branch 'upstream/main' into tvm_rebase
XuehaiPan Oct 27, 2025
305c86a
[Minor] update abi3audit result format
XuehaiPan Oct 27, 2025
4d56db2
[Lint] fix lint for ROCm
XuehaiPan Oct 27, 2025
daf6c55
[BugFix] fix build
XuehaiPan Oct 27, 2025
66ac445
[Lint] fix lint for ROCm
XuehaiPan Oct 27, 2025
fbc250e
[BugFix] set rpath for libtvm and libtvm_runtime
XuehaiPan Oct 27, 2025
6bbf6aa
[Deps] pin apache-tvm-ffi version
XuehaiPan Oct 27, 2025
4aa24ea
[Build] set Python 3.9 Limited API for Cython target
XuehaiPan Oct 27, 2025
1b21e95
[Build] set Python 3.9 Limited API for Cython target
XuehaiPan Oct 27, 2025
1de65aa
[Deps] Restore Python 3.8 support
XuehaiPan Oct 27, 2025
6e8ad0d
Merge remote-tracking branch 'upstream/main' into tvm_rebase
XuehaiPan Oct 27, 2025
f4baf32
Merge remote-tracking branch 'upstream/main' into tvm_rebase
XuehaiPan Oct 28, 2025
ba13761
[Build] use `apache-tvm-ffi`'s `libtvm_ffi`
XuehaiPan Oct 28, 2025
5be4057
[BugFix] use `;` as delimiter for RPATH on macOS
XuehaiPan Oct 28, 2025
5e2ed4f
[BugFix] use `--ignore-missing-dependencies` for `delocate-wheel`
XuehaiPan Oct 28, 2025
2d5faa8
[Build] support `sccache` if available
XuehaiPan Oct 28, 2025
9a48282
[Build] add CIBW import test
XuehaiPan Oct 28, 2025
cd9ab57
[Build][CI] enable ccache for CIBW on Linux
XuehaiPan Oct 28, 2025
1807298
[BugFix] set rpath for libtvm and libtvm_runtime
XuehaiPan Oct 28, 2025
aa4eb5d
Revert "[Build][CI] enable ccache for CIBW on Linux"
XuehaiPan Oct 28, 2025
024c1e3
[CI] fix perfbench bot
XuehaiPan Oct 28, 2025
f276a26
[BugFix] use Python 3.9 to build wheel
XuehaiPan Oct 28, 2025
8a23c08
[Minor] update perfbench bot envs
XuehaiPan Oct 28, 2025
8424e99
[BugFix] fix CIBW environment on Linux
XuehaiPan Oct 28, 2025
9ef647c
[CI] skip import test on CentOS 7
XuehaiPan Oct 28, 2025
4bf2524
Merge remote-tracking branch 'upstream/main' into tvm_rebase
XuehaiPan Oct 29, 2025
c9e9191
[CI] use Python urllib to download file instead of Wget
XuehaiPan Oct 29, 2025
74ec78c
Merge remote-tracking branch 'upstream/main' into tvm_rebase
XuehaiPan Oct 29, 2025
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 5bf17a to 0f1eba
10 changes: 4 additions & 6 deletions CMakeLists.txt
@@ -68,8 +68,6 @@ file(GLOB TILE_LANG_SRCS
src/target/utils.cc
src/target/codegen_cpp.cc
src/target/rt_mod_cpp.cc
# webgpu doesn't have system dependency
src/target/codegen_webgpu.cc
# intrin_rule doesn't have system dependency
src/target/intrin_rule*.cc
)
@@ -192,7 +190,7 @@ install(TARGETS tvm tvm_runtime tilelang_module tilelang LIBRARY DESTINATION til

# Copy tvm cython ext for wheels
# TODO: not necessary for editable builds
if(TVM_BUILD_FROM_SOURCE)
add_dependencies(tilelang tvm_cython)
install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/python/tvm/ffi/core.abi3.so" DESTINATION tilelang/3rdparty/tvm/python/tvm/ffi/)
endif()
# if(TVM_BUILD_FROM_SOURCE)
# add_dependencies(tilelang tvm_cython)
# install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/python/tvm/ffi/core.abi3.so" DESTINATION tilelang/3rdparty/tvm/python/tvm/ffi/)
# endif()
11 changes: 10 additions & 1 deletion cmake/load_tvm.cmake
@@ -11,8 +11,17 @@ endif()

set(TVM_INCLUDES
${TVM_SOURCE}/include
${TVM_SOURCE}/ffi/include
${TVM_SOURCE}/src
${TVM_SOURCE}/3rdparty/dlpack/include
${TVM_SOURCE}/3rdparty/dmlc-core/include
)

if(EXISTS ${TVM_SOURCE}/ffi/include)
list(APPEND TVM_INCLUDES ${TVM_SOURCE}/ffi/include)
elseif(EXISTS ${TVM_SOURCE}/3rdparty/tvm-ffi/include)
list(APPEND TVM_INCLUDES ${TVM_SOURCE}/3rdparty/tvm-ffi/include)
endif()

if(EXISTS ${TVM_SOURCE}/3rdparty/tvm-ffi/3rdparty/dlpack/include)
list(APPEND TVM_INCLUDES ${TVM_SOURCE}/3rdparty/tvm-ffi/3rdparty/dlpack/include)
endif()
55 changes: 33 additions & 22 deletions src/ir.cc
@@ -7,6 +7,9 @@
#include "./transform/common/attr.h"
#include "op/builtin.h"
#include "tvm/ffi/any.h"
#include <tvm/ffi/object.h>

#include "support/ffi_aliases.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/script/ir_builder/tir/ir.h>
@@ -37,7 +40,7 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
using namespace tvm::tir;
Var var = Var(name, dom->dtype);
// Create a frame that represents a loop over the given domain.
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.push_back(var);
n->doms.push_back(Range(0, dom));
n->f_make_for_loop = [](const Array<Var> &vars, const Array<Range> &doms,
@@ -52,7 +55,7 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
ForFrame ParallelFor(const Array<PrimExpr> &extents,
const Map<String, ObjectRef> &annotations) {
using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.reserve(extents.size());
n->doms.reserve(extents.size());
for (const auto &extent : extents) {
@@ -82,7 +85,7 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages,
const Array<Array<PrimExpr>> &sync,
const Array<Array<PrimExpr>> &groups) {
using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
DataType dtype = stop.dtype();
n->vars.push_back(Var("v", dtype));
n->doms.push_back(Range(std::move(start), stop));
@@ -113,7 +116,7 @@ ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size,
const PrimExpr &index, PrimExpr group_size) {
using namespace tvm::tir;
ICHECK(!domain.empty());
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.reserve(domain.size());
n->doms.reserve(domain.size());
PrimExpr domain_size = domain[0];
@@ -193,8 +196,8 @@ class KernelLaunchFrameNode : public TIRFrameNode {
"frames", &KernelLaunchFrameNode::frames);
}

static constexpr const char *_type_key = "tl.KernelLaunchFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(KernelLaunchFrameNode, TIRFrameNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.KernelLaunchFrame",
KernelLaunchFrameNode, TIRFrameNode);

public:
TVM_DLL void EnterWithScope() final {
@@ -218,14 +221,20 @@
*/
class KernelLaunchFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(KernelLaunchFrame, TIRFrame,
KernelLaunchFrameNode);
explicit KernelLaunchFrame(ObjectPtr<KernelLaunchFrameNode> data)
: TIRFrame(::tvm::ffi::UnsafeInit{}) {
ICHECK(data != nullptr);
data_ = std::move(data);
}
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(KernelLaunchFrame, TIRFrame,
KernelLaunchFrameNode);
};

KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size,
const Optional<Array<PrimExpr>> &block_size_opt,
const Map<String, ffi::Any> &attrs) {
ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>();
ObjectPtr<KernelLaunchFrameNode> n =
tvm::ffi::make_object<KernelLaunchFrameNode>();

// If the kernel is a CPU kernel, we don't need to launch any threads.
bool is_cpu_kernel_frame =
@@ -289,16 +298,14 @@ KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size,
return KernelLaunchFrame(n);
}

TVM_REGISTER_NODE_TYPE(KernelLaunchFrameNode);

TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("tl.Parallel", ParallelFor)
.def("tl.Pipelined", PipelinedFor)
.def("tl.Persistent", PersistentFor)
.def("tl.KernelLaunch", KernelLaunch);
});
}

class WarpSpecializeFrameNode : public TIRFrameNode {
public:
@@ -310,8 +317,8 @@ class WarpSpecializeFrameNode : public TIRFrameNode {
"frames", &WarpSpecializeFrameNode::frames);
}

static constexpr const char *_type_key = "tl.WarpSpecializeFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(WarpSpecializeFrameNode, TIRFrameNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.WarpSpecializeFrame",
WarpSpecializeFrameNode, TIRFrameNode);

public:
TVM_DLL void EnterWithScope() final {
@@ -330,15 +337,20 @@

class WarpSpecializeFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WarpSpecializeFrame,
TIRFrame,
WarpSpecializeFrameNode);
explicit WarpSpecializeFrame(ObjectPtr<WarpSpecializeFrameNode> data)
: TIRFrame(::tvm::ffi::UnsafeInit{}) {
ICHECK(data != nullptr);
data_ = std::move(data);
}
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(WarpSpecializeFrame, TIRFrame,
WarpSpecializeFrameNode);
};

WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids,
const PrimExpr &thread_idx,
int warp_group_size = 128) {
ObjectPtr<WarpSpecializeFrameNode> n = make_object<WarpSpecializeFrameNode>();
ObjectPtr<WarpSpecializeFrameNode> n =
tvm::ffi::make_object<WarpSpecializeFrameNode>();
PrimExpr condition;
std::vector<int> warp_groups;
warp_groups.reserve(warp_group_ids.size());
@@ -376,13 +388,12 @@ WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids,
return WarpSpecializeFrame(n);
}

TVM_REGISTER_NODE_TYPE(WarpSpecializeFrameNode);
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.WarpSpecialize", WarpSpecialize);
KernelLaunchFrameNode::RegisterReflection();
WarpSpecializeFrameNode::RegisterReflection();
});
}

} // namespace tl
} // namespace tvm
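Taken together, the src/ir.cc changes apply one migration recipe from the TVM v0.22 FFI rework: `_type_key` plus `TVM_DECLARE_FINAL_OBJECT_INFO` collapses into `TVM_FFI_DECLARE_OBJECT_INFO_FINAL`, `make_object` is now qualified as `tvm::ffi::make_object`, not-nullable refs spell out an explicit `ObjectPtr` constructor built on `::tvm::ffi::UnsafeInit`, and `TVM_REGISTER_NODE_TYPE` gives way to reflection registered inside `TVM_FFI_STATIC_INIT_BLOCK() { ... }` (now a plain function body rather than a braced macro argument). A condensed sketch of the pattern, using a hypothetical `ExampleFrameNode` that is not part of this PR and assuming the same headers and namespaces ir.cc already uses:

class ExampleFrameNode : public TIRFrameNode {
 public:
  // Fields would be registered here with def_ro(...), as in the
  // RegisterReflection bodies shown in the diff above.
  static void RegisterReflection();
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ExampleFrame", ExampleFrameNode,
                                    TIRFrameNode);
};

class ExampleFrame : public TIRFrame {
 public:
  // NOTNULLABLE refs now define their ObjectPtr constructor explicitly.
  explicit ExampleFrame(ObjectPtr<ExampleFrameNode> data)
      : TIRFrame(::tvm::ffi::UnsafeInit{}) {
    ICHECK(data != nullptr);
    data_ = std::move(data);
  }
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ExampleFrame, TIRFrame,
                                                ExampleFrameNode);
};

ExampleFrame MakeExampleFrame() {
  // make_object moved under tvm::ffi and is no longer found unqualified.
  ObjectPtr<ExampleFrameNode> n = tvm::ffi::make_object<ExampleFrameNode>();
  return ExampleFrame(n);
}

// Registration replaces TVM_REGISTER_NODE_TYPE: reflection plus GlobalDef,
// inside the new function-body form of the static init block.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  ExampleFrameNode::RegisterReflection();
  refl::GlobalDef().def("tl.MakeExampleFrame", MakeExampleFrame);
}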
44 changes: 12 additions & 32 deletions src/layout/layout.cc
@@ -64,13 +64,12 @@ Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
}
forward_index =
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });

auto n = make_object<LayoutNode>(input_size, forward_index);
auto n = tvm::ffi::make_object<LayoutNode>(input_size, forward_index);
data_ = std::move(n);
}

Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
auto n = make_object<LayoutNode>(input_size, forward_index);
auto n = tvm::ffi::make_object<LayoutNode>(input_size, forward_index);
data_ = std::move(n);
}

@@ -130,7 +129,6 @@ Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr> &vars) const {

Array<PrimExpr> transformed = forward_index_.Map(
[&](const PrimExpr &e) { return Substitute(e, vmap); });

// Concatenate with the remaining elements from vars
Array<PrimExpr> result;
for (size_t i = 0; i < vars.size() - InputDim(); i++) {
@@ -212,7 +210,7 @@ Fragment FragmentNode::DeReplicate() const {
factor = arith::ZeroAwareGCD(*rep_size, *idx_size);
}
if (factor == 1)
return GetRef<Fragment>(this);
return tvm::ffi::GetRef<Fragment>(this);

Map<Var, PrimExpr> vmap;
vmap.Set(ReplicationPlaceholder(), ReplicationPlaceholder() * factor +
@@ -224,7 +222,7 @@
}

Fragment FragmentNode::BindThreadRange(Range thread_range) const {
auto n = make_object<FragmentNode>(*this);
auto n = tvm::ffi::make_object<FragmentNode>(*this);
n->thread_range_ = thread_range;
return Fragment(n);
}
@@ -336,8 +334,8 @@ Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
forward_thread = Substitute(forward_thread, vmap);

auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
replicate_size);
auto n = tvm::ffi::make_object<FragmentNode>(input_size, forward_index,
forward_thread, replicate_size);
data_ = std::move(n);
}

@@ -348,8 +346,8 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
forward_thread = Substitute(
forward_thread, {{replicate_var.value(), ReplicationPlaceholder()}});
}
auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
replicate_size);
auto n = tvm::ffi::make_object<FragmentNode>(input_size, forward_index,
forward_thread, replicate_size);
data_ = std::move(n);
}

@@ -442,21 +440,6 @@ std::string FragmentNode::DebugOutput() const {
return ss.str();
}

bool LayoutNode::SEqualReduce(const LayoutNode *other,
SEqualReducer equal) const {
return equal(this->InputShape(), other->InputShape()) &&
equal(this->forward_index_, other->forward_index_);
}

bool FragmentNode::SEqualReduce(const FragmentNode *other,
SEqualReducer equal) const {
return equal(this->ReplicateExtent(), other->ReplicateExtent()) &&
equal(this->InputShape(), other->InputShape()) &&
equal(this->ThreadExtent(), other->ThreadExtent()) &&
equal(this->forward_index_, other->forward_index_) &&
equal(this->forward_thread_, other->forward_thread_);
}

bool LayoutNode::IsEqual(const LayoutNode *other, bool skip_index) const {
bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
ret &= StructuralEqual()(this->OutputShape(), other->OutputShape());
@@ -495,10 +478,7 @@ void FragmentNode::RegisterReflection() {
.def_ro("replicate_size", &FragmentNode::replicate_size_);
}

TVM_REGISTER_NODE_TYPE(LayoutNode);
TVM_REGISTER_NODE_TYPE(FragmentNode);

TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def_packed("tl.Layout",
@@ -582,13 +562,13 @@ TVM_FFI_STATIC_INIT_BLOCK({
.def("tl.make_linear_layout", [](int stride, int continuous) {
return makeGemmLayoutLinear(stride, continuous);
});
});
}

TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
LayoutNode::RegisterReflection();
FragmentNode::RegisterReflection();
});
}

} // namespace tl
} // namespace tvm
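The deleted `SEqualReduce` overloads are not re-implemented elsewhere: under the new FFI, structural equality and hashing are derived from the fields registered in `RegisterReflection()` once a node opts in with `_type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode` (declared in layout.h below). A minimal sketch of the resulting behavior, assuming two layouts `a` and `b` built in this translation unit from identical input shapes and forward indices:

// Illustrative helper only; relies on the LayoutNode/FragmentNode
// reflection registration above, not on any hand-written SEqualReduce.
bool StructurallySame(const Layout &a, const Layout &b) {
  // Tree-node equality walks the reflected fields (input shape and
  // forward_index for Layout; thread/replication fields for Fragment).
  return StructuralEqual()(a, b);
}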
20 changes: 11 additions & 9 deletions src/layout/layout.h
@@ -8,8 +8,11 @@

#include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/object.h>
#include <utility>

#include "../support/ffi_aliases.h"

namespace tvm {
namespace tl {

@@ -44,11 +47,10 @@ class LayoutNode : public Object {

virtual bool IsEqual(const LayoutNode *other, bool skip_index = false) const;

static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr const char *_type_key = "tl.Layout";
bool SEqualReduce(const LayoutNode *other, SEqualReducer equal) const;
static void RegisterReflection();
TVM_DECLARE_BASE_OBJECT_INFO(LayoutNode, Object);
TVM_FFI_DECLARE_OBJECT_INFO("tl.Layout", LayoutNode, Object);
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;

protected:
virtual Map<Var, Range> getVarMap() const;
@@ -65,7 +67,7 @@ class Layout : public ObjectRef {
TVM_DLL Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index);
TVM_DLL Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index);

TVM_DEFINE_OBJECT_REF_METHODS(Layout, ObjectRef, LayoutNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Layout, ObjectRef, LayoutNode);
};

class FragmentNode : public LayoutNode {
@@ -109,9 +111,9 @@ class FragmentNode : public LayoutNode {

static void RegisterReflection();

bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const;
static constexpr const char *_type_key = "tl.Fragment";
TVM_DECLARE_FINAL_OBJECT_INFO(FragmentNode, LayoutNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fragment", FragmentNode, LayoutNode);
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;

protected:
Map<Var, Range> getVarMap() const final;
@@ -132,7 +134,7 @@ class Fragment : public Layout {
PrimExpr forward_thread, PrimExpr replicate_size,
Optional<Var> replicate_var);

TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fragment, Layout, FragmentNode);
};

Var InputPlaceholder(size_t idx);
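One detail worth noting from layout.h: `Layout` and `Fragment` take the `_NULLABLE` flavor of the ref-methods macro, unlike the `_NOTNULLABLE` frames in src/ir.cc, so a default-constructed `Layout` is a null reference. A small sketch of the distinction (hypothetical helper, not part of the PR, assuming layout.h's declarations):

// Hypothetical demonstration of the NULLABLE ref semantics.
void DemoNullable(const Fragment &frag) {
  Layout maybe;              // NULLABLE ref: default-constructs to null
  ICHECK(!maybe.defined());
  maybe = frag;              // Fragment derives from Layout, so this upcasts
  ICHECK(maybe.defined());
}
// By contrast, KernelLaunchFrame and WarpSpecializeFrame (NOTNULLABLE) have
// no default constructor and must be created from a non-null ObjectPtr.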