Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/ffi/extra/env_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,12 @@ class EnvContext {
void SetDLPackManagedTensorAllocator(DLPackManagedTensorAllocator allocator,
int write_to_global_context,
DLPackManagedTensorAllocator* opt_out_original_allocator) {
dlpack_allocator_ = allocator;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Removing this assignment from the beginning of the function is the correct first step to fix the bug where the new allocator was returned instead of the previous one.

However, a more subtle logic issue remains. After this change, the function will return the previous thread-local allocator. But GetDLPackManagedTensorAllocator returns the effective allocator (thread-local if available, otherwise global). For consistency, SetDLPackManagedTensorAllocator should also return the previously effective allocator.

To fix this, you should capture the result of GetDLPackManagedTensorAllocator() before modifying any state, and return that as the original allocator.

For example:

void SetDLPackManagedTensorAllocator(...) {
  if (opt_out_original_allocator != nullptr) {
    *opt_out_original_allocator = GetDLPackManagedTensorAllocator();
  }
  if (write_to_global_context != 0) {
    GlobalTensorAllocator() = allocator;
  }
  dlpack_allocator_ = allocator;
}

This would make the behavior consistent and robust. I've also suggested a test case in test_c_env_api.cc that would fail with the current implementation but pass with this suggested fix.

if (opt_out_original_allocator != nullptr) {
*opt_out_original_allocator = GetDLPackManagedTensorAllocator();
}
if (write_to_global_context != 0) {
GlobalTensorAllocator() = allocator;
}
if (opt_out_original_allocator != nullptr) {
*opt_out_original_allocator = dlpack_allocator_;
}
dlpack_allocator_ = allocator;
}

Expand Down
8 changes: 8 additions & 0 deletions tests/cpp/extra/test_c_env_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ int TestDLPackManagedTensorAllocatorError(DLTensor* prototype, DLManagedTensorVe
return -1;
}

TEST(CEnvAPI, TVMFFIEnvSetDLPackManagedTensorAllocator) {
auto old_allocator = TVMFFIEnvGetDLPackManagedTensorAllocator();
DLPackManagedTensorAllocator pre_allocator;
TVMFFIEnvSetDLPackManagedTensorAllocator(TestDLPackManagedTensorAllocator, 0, &pre_allocator);
EXPECT_EQ(old_allocator, pre_allocator);
TVMFFIEnvSetDLPackManagedTensorAllocator(old_allocator, 0, nullptr);
}
Comment on lines +62 to +68
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The added test case is good for the basic scenario, but it's not comprehensive enough. It passes because initially both the thread-local and global allocators are nullptr. The test doesn't cover the scenario where a global allocator is active and the thread-local one is not. In that case, TVMFFIEnvGetDLPackManagedTensorAllocator would return the global allocator, but TVMFFIEnvSetDLPackManagedTensorAllocator would incorrectly return nullptr as the "previous" allocator. This indicates a logic issue in SetDLPackManagedTensorAllocator.

I suggest adding a test case to cover this scenario, which would look something like this:

TEST(CEnvAPI, TVMFFIEnvSetDLPackManagedTensorAllocator_Global) {
  // Setup: set a global allocator, but no thread-local one.
  TVMFFIEnvSetDLPackManagedTensorAllocator(TestDLPackManagedTensorAllocatorError, 1, nullptr);
  TVMFFIEnvSetDLPackManagedTensorAllocator(nullptr, 0, nullptr); // Unset thread-local

  // The effective allocator should be the global one.
  auto old_allocator = TVMFFIEnvGetDLPackManagedTensorAllocator();
  EXPECT_EQ(old_allocator, TestDLPackManagedTensorAllocatorError);

  // Set a new thread-local allocator and get the previous one.
  DLPackManagedTensorAllocator pre_allocator;
  TVMFFIEnvSetDLPackManagedTensorAllocator(TestDLPackManagedTensorAllocator, 0, &pre_allocator);

  // The returned previous allocator should be the effective (global) one.
  EXPECT_EQ(old_allocator, pre_allocator);

  // Cleanup
  TVMFFIEnvSetDLPackManagedTensorAllocator(nullptr, 1, nullptr);
}

This test would fail with the current implementation and highlight the need for a fix in src/ffi/extra/env_context.cc.


TEST(CEnvAPI, TVMFFIEnvTensorAlloc) {
auto old_allocator = TVMFFIEnvGetDLPackManagedTensorAllocator();
TVMFFIEnvSetDLPackManagedTensorAllocator(TestDLPackManagedTensorAllocator, 0, nullptr);
Expand Down