Skip to content

Commit 71bbe91

Browse files
nan2088nan
andauthored
Fix TVMFFIEnvSetDLPackManagedTensorAllocator to correctly return the original allocator (#371)
Before this commit, TVMFFIEnvSetDLPackManagedTensorAllocator incorrectly set the previous allocator in: TVMFFIEnvSetDLPackManagedTensorAllocator(NewAllocator, 0, &pre_allocator); --------- Co-authored-by: nan <[email protected]>
1 parent 46ab644 commit 71bbe91

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

src/ffi/extra/env_context.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,12 @@ class EnvContext {
6565
void SetDLPackManagedTensorAllocator(DLPackManagedTensorAllocator allocator,
6666
int write_to_global_context,
6767
DLPackManagedTensorAllocator* opt_out_original_allocator) {
68-
dlpack_allocator_ = allocator;
68+
if (opt_out_original_allocator != nullptr) {
69+
*opt_out_original_allocator = GetDLPackManagedTensorAllocator();
70+
}
6971
if (write_to_global_context != 0) {
7072
GlobalTensorAllocator() = allocator;
7173
}
72-
if (opt_out_original_allocator != nullptr) {
73-
*opt_out_original_allocator = dlpack_allocator_;
74-
}
7574
dlpack_allocator_ = allocator;
7675
}
7776

tests/cpp/extra/test_c_env_api.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ int TestDLPackManagedTensorAllocatorError(DLTensor* prototype, DLManagedTensorVe
5959
return -1;
6060
}
6161

62+
TEST(CEnvAPI, TVMFFIEnvSetDLPackManagedTensorAllocator) {
63+
auto old_allocator = TVMFFIEnvGetDLPackManagedTensorAllocator();
64+
DLPackManagedTensorAllocator pre_allocator;
65+
TVMFFIEnvSetDLPackManagedTensorAllocator(TestDLPackManagedTensorAllocator, 0, &pre_allocator);
66+
EXPECT_EQ(old_allocator, pre_allocator);
67+
TVMFFIEnvSetDLPackManagedTensorAllocator(old_allocator, 0, nullptr);
68+
}
69+
6270
TEST(CEnvAPI, TVMFFIEnvTensorAlloc) {
6371
auto old_allocator = TVMFFIEnvGetDLPackManagedTensorAllocator();
6472
TVMFFIEnvSetDLPackManagedTensorAllocator(TestDLPackManagedTensorAllocator, 0, nullptr);

0 commit comments

Comments
 (0)