Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 59 additions & 11 deletions src/ffi/dtype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,46 @@ inline DLDataType StringViewToDLDataType_(std::string_view str) {
dtype.bits = 32;
dtype.lanes = 1;
const char* scan;
const char* str_end = str.data() + str.length();

// Parses a run of ASCII decimal digits from the bounded range [ptr, end).
//
// The input comes from a std::string_view and is NOT guaranteed to be
// null-terminated, so every dereference is guarded by `ptr < end`.
//
// ptr  - in/out cursor; on return it points one past the last digit consumed
//        (unchanged if the first character is not a digit).
// end  - one-past-the-end of the readable range.
//
// Returns the parsed value, or 0 if no digit was consumed (callers tell the
// two apart by checking whether `ptr` moved).
// Raises ValueError via TVM_FFI_THROW if the digit run does not fit in
// uint32_t.
auto parse_digits = [](const char*& ptr, const char* end) -> uint32_t {
  uint64_t value = 0;
  bool out_of_range = false;
  const char* start_ptr = ptr;
  while (ptr < end && *ptr >= '0' && *ptr <= '9') {
    // Stop accumulating once the value has left the uint32_t range. Without
    // this, a very long digit run would wrap the 64-bit accumulator and could
    // end up at or below UINT32_MAX again, silently defeating the range check.
    // While value <= UINT32_MAX, value * 10 + 9 cannot overflow uint64_t.
    if (!out_of_range) {
      value = value * 10 + static_cast<uint64_t>(*ptr - '0');
      out_of_range = value > UINT32_MAX;
    }
    // Keep consuming digits even after overflow so the error message below
    // reports the complete offending number.
    ++ptr;
  }
  if (out_of_range) {
    TVM_FFI_THROW(ValueError) << "Integer value in dtype string '"
                              << std::string_view(start_ptr,
                                                  static_cast<size_t>(ptr - start_ptr))
                              << "' is out of range for uint32_t";
  }
  return static_cast<uint32_t>(value);
};

auto parse_float = [&](const std::string_view& str, int offset, int code, int bits) {
dtype.code = static_cast<uint8_t>(code);
dtype.bits = static_cast<uint8_t>(bits);
scan = str.data() + offset;
char* endpt = nullptr;
if (*scan == 'x') {
dtype.lanes = static_cast<uint16_t>(strtoul(scan + 1, &endpt, 10));
const char* endpt = scan;
if (scan < str_end && *scan == 'x') {
endpt = scan + 1;
const char* digits_start = endpt;
uint32_t lanes_val = parse_digits(endpt, str_end);
if (endpt == digits_start || lanes_val == 0) {
TVM_FFI_THROW(ValueError) << "Invalid lanes specification in dtype '" << str
<< "'. Lanes must be a positive integer.";
}
if (lanes_val > UINT16_MAX) {
TVM_FFI_THROW(ValueError) << "Lanes value " << lanes_val
<< " is out of range for uint16_t in dtype '" << str << "'";
}
dtype.lanes = static_cast<uint16_t>(lanes_val);
scan = endpt;
}
if (scan != str.data() + str.length()) {
if (scan != str_end) {
TVM_FFI_THROW(ValueError) << "unknown dtype `" << str << '`';
}
return dtype;
Expand Down Expand Up @@ -293,19 +322,38 @@ inline DLDataType StringViewToDLDataType_(std::string_view str) {
scan = str.data();
TVM_FFI_THROW(ValueError) << "unknown dtype `" << str << '`';
}
char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
// Parse bits manually to handle non-null-terminated string_view
const char* xdelim = scan;
uint32_t bits_val = parse_digits(xdelim, str_end);
if (bits_val > UINT8_MAX) {
TVM_FFI_THROW(ValueError) << "Bits value " << bits_val << " is out of range for uint8_t in dtype '"
<< str << "'";
}
uint8_t bits = static_cast<uint8_t>(bits_val);
if (bits != 0) dtype.bits = bits;
int scalable_multiplier = 1;
if (strncmp(xdelim, "xvscale", 7) == 0) {
// Check bounds before dereferencing xdelim
if (xdelim < str_end && strncmp(xdelim, "xvscale", 7) == 0) {
scalable_multiplier = -1;
xdelim += 7;
}
char* endpt = xdelim;
if (*xdelim == 'x') {
dtype.lanes = static_cast<uint16_t>(scalable_multiplier * strtoul(xdelim + 1, &endpt, 10));
const char* endpt = xdelim;
// Check bounds before dereferencing xdelim
if (xdelim < str_end && *xdelim == 'x') {
endpt = xdelim + 1;
const char* digits_start = endpt;
uint32_t lanes_val = parse_digits(endpt, str_end);
if (endpt == digits_start || lanes_val == 0) {
TVM_FFI_THROW(ValueError) << "Invalid lanes specification in dtype '" << str
<< "'. Lanes must be a positive integer.";
}
if (lanes_val > UINT16_MAX) {
TVM_FFI_THROW(ValueError) << "Lanes value " << lanes_val
<< " is out of range for uint16_t in dtype '" << str << "'";
}
dtype.lanes = static_cast<uint16_t>(scalable_multiplier * lanes_val);
}
if (endpt != str.data() + str.length()) {
if (endpt != str_end) {
TVM_FFI_THROW(ValueError) << "unknown dtype `" << str << '`';
}
return dtype;
Expand Down
72 changes: 72 additions & 0 deletions tests/cpp/test_dtype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,76 @@ TEST(DataType, AnyConversionWithString) {
EXPECT_EQ(opt_v1.value().bits, 16);
EXPECT_EQ(opt_v1.value().lanes, 2);
}

// Verifies that dtype parsing honours the explicit byte length it is given
// and never consumes characters located after the end of the view.
//
// Each case passes a (pointer, size) pair whose `size` covers only a prefix
// of a larger buffer; the bytes right after the prefix are chosen so that a
// parser relying on null-termination (e.g. strtoul) would produce different
// bits/lanes values.
TEST(DType, NonNullTerminatedStringView) {
  // Simulate memory scenario similar to Electron where memory after string
  // contains garbage data (digits from previous strings)
  //
  // We test by calling TVMFFIDataTypeFromString directly with TVMFFIByteArray
  // to bypass String's automatic null-termination

  // Helper lambda to test with raw byte array (no null terminator).
  // Returns the parsed dtype; records a non-fatal failure if the C API
  // reports an error.
  auto test_dtype_from_bytes = [](const char* data, size_t size) -> DLDataType {
    TVMFFIByteArray byte_array{data, size};
    // Value-initialize: EXPECT_EQ below is non-fatal, so on failure the
    // caller's follow-up expectations still read this struct. Zeroing it
    // keeps those reads deterministic instead of touching indeterminate bytes.
    DLDataType dtype{};
    int ret = TVMFFIDataTypeFromString(&byte_array, &dtype);
    EXPECT_EQ(ret, 0) << "TVMFFIDataTypeFromString failed";
    return dtype;
  };

  // Test 1: "float16" followed by digit garbage
  char buffer1[] = "float16999888777";
  DLDataType dtype1 = test_dtype_from_bytes(buffer1, 7);  // Only "float16"
  EXPECT_EQ(dtype1.code, kDLFloat);
  EXPECT_EQ(dtype1.bits, 16);  // Should be 16, not 16999888777!
  EXPECT_EQ(dtype1.lanes, 1);

  // Test 2: "int32" followed by "x4" from previous leftover
  char buffer2[] = "int32x4extradata";
  DLDataType dtype2 = test_dtype_from_bytes(buffer2, 5);  // Only "int32"
  EXPECT_EQ(dtype2.code, kDLInt);
  EXPECT_EQ(dtype2.bits, 32);  // Should be 32, not parse the 'x4'
  EXPECT_EQ(dtype2.lanes, 1);  // Should be 1, not 4

  // Test 3: "uint8" followed by more digits
  char buffer3[] = "uint8192";
  DLDataType dtype3 = test_dtype_from_bytes(buffer3, 5);  // Only "uint8"
  EXPECT_EQ(dtype3.code, kDLUInt);
  EXPECT_EQ(dtype3.bits, 8);  // Should be 8, not 8192
  EXPECT_EQ(dtype3.lanes, 1);

  // Test 4: "bfloat16" followed by "x2" garbage
  char buffer4[] = "bfloat16x2garbage";
  DLDataType dtype4 = test_dtype_from_bytes(buffer4, 8);  // Only "bfloat16"
  EXPECT_EQ(dtype4.code, kDLBfloat);
  EXPECT_EQ(dtype4.bits, 16);
  EXPECT_EQ(dtype4.lanes, 1);  // Should be 1, not 2

  // Test 5: "bfloat16x2" - lanes within bounds (should work)
  DLDataType dtype5 = test_dtype_from_bytes(buffer4, 10);  // "bfloat16x2"
  EXPECT_EQ(dtype5.code, kDLBfloat);
  EXPECT_EQ(dtype5.bits, 16);
  EXPECT_EQ(dtype5.lanes, 2);  // Should correctly parse x2

  // Test 6: view followed directly by non-NUL, non-digit bytes.
  // Note: index 7 is the first trailing 'A' (the array's own NUL sits at
  // index 12), so the store below only swaps one garbage byte for another;
  // it guarantees the byte right after the 7-char view is not a terminator.
  char buffer6[] = "float64AAAAA";
  buffer6[7] = 'X';
  DLDataType dtype6 = test_dtype_from_bytes(buffer6, 7);  // "float64"
  EXPECT_EQ(dtype6.code, kDLFloat);
  EXPECT_EQ(dtype6.bits, 64);
  EXPECT_EQ(dtype6.lanes, 1);

  // Test 7: "int8" followed by "x16" pattern
  char buffer7[] = "int8x16leftovers";
  DLDataType dtype7 = test_dtype_from_bytes(buffer7, 4);  // Only "int8"
  EXPECT_EQ(dtype7.code, kDLInt);
  EXPECT_EQ(dtype7.bits, 8);
  EXPECT_EQ(dtype7.lanes, 1);  // Should be 1, not 16

  // Test 8: With actual x specification that should parse
  DLDataType dtype8 = test_dtype_from_bytes(buffer7, 7);  // "int8x16"
  EXPECT_EQ(dtype8.code, kDLInt);
  EXPECT_EQ(dtype8.bits, 8);
  EXPECT_EQ(dtype8.lanes, 16);  // Should correctly parse x16
}
} // namespace
Loading