Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 61 additions & 18 deletions src/ffi/dtype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,61 @@ inline DLDataType StringViewToDLDataType_(std::string_view str) {
dtype.bits = 32;
dtype.lanes = 1;
const char* scan;
const char* str_end = str.data() + str.length();

// Helper lambda to parse decimal digits from a bounded string_view
// Returns the parsed value and updates *ptr to point past the last digit
auto parse_digits = [](const char** ptr, const char* end) -> uint32_t {
uint64_t value = 0;
const char* start_ptr = *ptr;
while (*ptr < end && **ptr >= '0' && **ptr <= '9') {
value = value * 10 + (**ptr - '0');
(*ptr)++;
}
Comment on lines +216 to +219
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of parse_digits has a subtle bug where value can overflow uint64_t if the input string contains a very large number (more than 19 digits). If value overflows, it will wrap around (which is defined behavior for unsigned integers), and the subsequent check value > UINT32_MAX might fail to detect the overflow, leading to incorrect parsing.

To make this function more robust, you should check for potential overflow before performing the multiplication and addition. Since the final value should fit in a uint32_t, you can check against UINT32_MAX within the loop.

    while (*ptr < end && **ptr >= '0' && **ptr <= '9') {
      uint8_t digit = **ptr - '0';
      if (value > UINT32_MAX / 10 || (value == UINT32_MAX / 10 && digit > UINT32_MAX % 10)) {
        // Number is too large for uint32_t, set to overflow and consume rest of digits.
        value = (uint64_t)UINT32_MAX + 1;
        while (*ptr < end && **ptr >= '0' && **ptr <= '9') {
          (*ptr)++;
        }
        break;
      }
      value = value * 10 + digit;
      (*ptr)++;
    }

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gemini is getting too much

if (value > UINT32_MAX) {
TVM_FFI_THROW(ValueError) << "Integer value in dtype string '"
<< std::string_view(start_ptr, *ptr - start_ptr)
<< "' is out of range for uint32_t";
}
return static_cast<uint32_t>(value);
};

// Helper lambda to parse an optional lanes specification (e.g., "x16" or "xvscalex4").
//
// \param ptr            In/out cursor; on return it points past whatever was consumed.
// \param end            One-past-the-end of the dtype string; nothing at or beyond it is read.
// \param dtype_str      The full dtype string, used only for error messages.
// \param allow_scalable Whether the "xvscale" prefix (scalable vectors) is accepted.
// \return The lanes value; 1 when no lanes suffix is present. Scalable lanes are returned
//         as the negated count wrapped into uint16_t (e.g. "xvscalex4" -> uint16_t(-4)).
//
// Raises ValueError (via TVM_FFI_THROW) when the lanes digits are missing, zero, exceed
// UINT16_MAX, or when an "xvscale" prefix is not followed by "x<digits>".
auto parse_lanes = [&](const char** ptr, const char* end, const std::string_view& dtype_str,
                       bool allow_scalable = false) -> uint16_t {
  bool scalable = false;
  // Check for "xvscale" prefix for scalable vectors (7 == strlen("xvscale")).
  // The length guard keeps strncmp from reading past `end` on unterminated input.
  if (allow_scalable && (end - *ptr >= 7) && strncmp(*ptr, "xvscale", 7) == 0) {
    scalable = true;
    *ptr += 7;
  }
  if (*ptr >= end || **ptr != 'x') {
    if (scalable) {
      // "xvscale" on its own carries no lane count; previously this fell through and
      // silently produced lanes == 1, accepting strings such as "int32xvscale".
      TVM_FFI_THROW(ValueError) << "Invalid lanes specification in dtype '" << dtype_str
                                << "'. Lanes must be a positive integer.";
    }
    return 1;  // No lanes specification, default to 1
  }
  (*ptr)++;  // Skip 'x'
  const char* digits_start = *ptr;
  uint32_t lanes_val = parse_digits(ptr, end);
  // Reject "x" with no digits after it as well as an explicit zero.
  if (*ptr == digits_start || lanes_val == 0) {
    TVM_FFI_THROW(ValueError) << "Invalid lanes specification in dtype '" << dtype_str
                              << "'. Lanes must be a positive integer.";
  }
  if (lanes_val > UINT16_MAX) {
    TVM_FFI_THROW(ValueError) << "Lanes value " << lanes_val
                              << " is out of range for uint16_t in dtype '" << dtype_str << "'";
  }
  if (scalable) {
    // lanes_val <= UINT16_MAX here, so the signed negation cannot overflow; the cast
    // wraps it into uint16_t, matching the original `multiplier * lanes_val` result.
    return static_cast<uint16_t>(-static_cast<int32_t>(lanes_val));
  }
  return static_cast<uint16_t>(lanes_val);
};

auto parse_float = [&](const std::string_view& str, int offset, int code, int bits) {
dtype.code = static_cast<uint8_t>(code);
dtype.bits = static_cast<uint8_t>(bits);
scan = str.data() + offset;
char* endpt = nullptr;
if (*scan == 'x') {
dtype.lanes = static_cast<uint16_t>(strtoul(scan + 1, &endpt, 10));
scan = endpt;
}
if (scan != str.data() + str.length()) {
const char* endpt = scan;
dtype.lanes = parse_lanes(&endpt, str_end, str);
scan = endpt;
if (scan != str_end) {
TVM_FFI_THROW(ValueError) << "unknown dtype `" << str << '`';
}
return dtype;
Expand Down Expand Up @@ -293,19 +337,18 @@ inline DLDataType StringViewToDLDataType_(std::string_view str) {
scan = str.data();
TVM_FFI_THROW(ValueError) << "unknown dtype `" << str << '`';
}
char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
if (bits != 0) dtype.bits = bits;
int scalable_multiplier = 1;
if (strncmp(xdelim, "xvscale", 7) == 0) {
scalable_multiplier = -1;
xdelim += 7;
}
char* endpt = xdelim;
if (*xdelim == 'x') {
dtype.lanes = static_cast<uint16_t>(scalable_multiplier * strtoul(xdelim + 1, &endpt, 10));
// Parse bits manually to handle non-null-terminated string_view
const char* xdelim = scan;
uint32_t bits_val = parse_digits(&xdelim, str_end);
if (bits_val > UINT8_MAX) {
TVM_FFI_THROW(ValueError) << "Bits value " << bits_val
<< " is out of range for uint8_t in dtype '" << str << "'";
}
if (endpt != str.data() + str.length()) {
uint8_t bits = static_cast<uint8_t>(bits_val);
if (bits != 0) dtype.bits = bits;
const char* endpt = xdelim;
dtype.lanes = parse_lanes(&endpt, str_end, str, /*allow_scalable=*/true);
if (endpt != str_end) {
TVM_FFI_THROW(ValueError) << "unknown dtype `" << str << '`';
}
return dtype;
Expand Down
86 changes: 86 additions & 0 deletions tests/cpp/test_dtype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,90 @@ TEST(DataType, AnyConversionWithString) {
EXPECT_EQ(opt_v1.value().bits, 16);
EXPECT_EQ(opt_v1.value().lanes, 2);
}

// Checks that dtype parsing honors the explicit byte length it is given and never
// consumes characters that sit in memory after the end of the slice.
TEST(DType, NonNullTerminatedStringView) {
  // Simulate memory scenario similar to Electron where memory after string
  // contains garbage data (digits from previous strings)
  //
  // We test by calling TVMFFIDataTypeFromString directly with TVMFFIByteArray
  // to bypass String's automatic null-termination

  // Helper lambda to test with raw byte array (no null terminator).
  // Only the first `size` bytes of `data` belong to the dtype string.
  auto test_dtype_from_bytes = [](const char* data, size_t size) -> DLDataType {
    TVMFFIByteArray byte_array{data, size};
    // Value-initialize so a failed parse is reported once by the EXPECT below and the
    // caller's follow-up checks compare against zeros instead of indeterminate memory.
    DLDataType dtype{};
    int ret = TVMFFIDataTypeFromString(&byte_array, &dtype);
    EXPECT_EQ(ret, 0) << "TVMFFIDataTypeFromString failed";
    return dtype;
  };

  // Test 1: "float16" followed by digit garbage
  char buffer1[] = "float16999888777";
  DLDataType dtype1 = test_dtype_from_bytes(buffer1, 7);  // Only "float16"
  EXPECT_EQ(dtype1.code, kDLFloat);
  EXPECT_EQ(dtype1.bits, 16);  // Should be 16, not 16999888777!
  EXPECT_EQ(dtype1.lanes, 1);

  // Test 2: "int32" followed by "x4" from previous leftover
  char buffer2[] = "int32x4extradata";
  DLDataType dtype2 = test_dtype_from_bytes(buffer2, 5);  // Only "int32"
  EXPECT_EQ(dtype2.code, kDLInt);
  EXPECT_EQ(dtype2.bits, 32);  // Should be 32, not parse the 'x4'
  EXPECT_EQ(dtype2.lanes, 1);  // Should be 1, not 4

  // Test 3: "uint8" followed by more digits
  char buffer3[] = "uint8192";
  DLDataType dtype3 = test_dtype_from_bytes(buffer3, 5);  // Only "uint8"
  EXPECT_EQ(dtype3.code, kDLUInt);
  EXPECT_EQ(dtype3.bits, 8);  // Should be 8, not 8192
  EXPECT_EQ(dtype3.lanes, 1);

  // Test 4: "bfloat16" followed by "x2" garbage
  char buffer4[] = "bfloat16x2garbage";
  DLDataType dtype4 = test_dtype_from_bytes(buffer4, 8);  // Only "bfloat16"
  EXPECT_EQ(dtype4.code, kDLBfloat);
  EXPECT_EQ(dtype4.bits, 16);
  EXPECT_EQ(dtype4.lanes, 1);  // Should be 1, not 2

  // Test 5: "bfloat16x2" - lanes within bounds (should work)
  DLDataType dtype5 = test_dtype_from_bytes(buffer4, 10);  // "bfloat16x2"
  EXPECT_EQ(dtype5.code, kDLBfloat);
  EXPECT_EQ(dtype5.bits, 16);
  EXPECT_EQ(dtype5.lanes, 2);  // Should correctly parse x2

  // Test 6: slice immediately followed by a non-digit, non-NUL byte
  char buffer6[] = "float64AAAAA";
  buffer6[7] = 'X';  // Byte right after the slice is neither a digit nor '\0'
  DLDataType dtype6 = test_dtype_from_bytes(buffer6, 7);  // "float64"
  EXPECT_EQ(dtype6.code, kDLFloat);
  EXPECT_EQ(dtype6.bits, 64);
  EXPECT_EQ(dtype6.lanes, 1);

  // Test 6b: truly non-null-terminated - the array holds exactly the dtype bytes,
  // so any read past the slice is out of bounds (detectable under a sanitizer build).
  const char buffer6b[] = {'f', 'l', 'o', 'a', 't', '6', '4'};
  DLDataType dtype6b = test_dtype_from_bytes(buffer6b, sizeof(buffer6b));  // "float64"
  EXPECT_EQ(dtype6b.code, kDLFloat);
  EXPECT_EQ(dtype6b.bits, 64);
  EXPECT_EQ(dtype6b.lanes, 1);

  // Test 7: "int8" followed by "x16" pattern
  char buffer7[] = "int8x16leftovers";
  DLDataType dtype7 = test_dtype_from_bytes(buffer7, 4);  // Only "int8"
  EXPECT_EQ(dtype7.code, kDLInt);
  EXPECT_EQ(dtype7.bits, 8);
  EXPECT_EQ(dtype7.lanes, 1);  // Should be 1, not 16

  // Test 8: With actual x specification that should parse
  DLDataType dtype8 = test_dtype_from_bytes(buffer7, 7);  // "int8x16"
  EXPECT_EQ(dtype8.code, kDLInt);
  EXPECT_EQ(dtype8.bits, 8);
  EXPECT_EQ(dtype8.lanes, 16);  // Should correctly parse x16

  // Test 9: Scalable vector - "int32xvscalex4"
  char buffer9[] = "int32xvscalex4extra";
  DLDataType dtype9 = test_dtype_from_bytes(buffer9, 14);  // "int32xvscalex4"
  EXPECT_EQ(dtype9.code, kDLInt);
  EXPECT_EQ(dtype9.bits, 32);
  EXPECT_EQ(dtype9.lanes, static_cast<uint16_t>(-4));  // Scalable: -4

  // Test 10: Scalable vector with garbage after
  char buffer10[] = "float16xvscalex8999";
  DLDataType dtype10 = test_dtype_from_bytes(buffer10, 16);  // "float16xvscalex8"
  EXPECT_EQ(dtype10.code, kDLFloat);
  EXPECT_EQ(dtype10.bits, 16);
  EXPECT_EQ(dtype10.lanes, static_cast<uint16_t>(-8));  // Should be -8, not parse "999"
}
} // namespace