diff --git a/src/ffi/dtype.cc b/src/ffi/dtype.cc index 74c9eeb4..14cfa5dc 100644 --- a/src/ffi/dtype.cc +++ b/src/ffi/dtype.cc @@ -206,17 +206,61 @@ inline DLDataType StringViewToDLDataType_(std::string_view str) { dtype.bits = 32; dtype.lanes = 1; const char* scan; + const char* str_end = str.data() + str.length(); + + // Helper lambda to parse decimal digits from a bounded string_view + // Returns the parsed value and updates *ptr to point past the last digit + auto parse_digits = [](const char** ptr, const char* end) -> uint32_t { + uint64_t value = 0; + const char* start_ptr = *ptr; + while (*ptr < end && **ptr >= '0' && **ptr <= '9') { + value = value * 10 + (**ptr - '0'); + (*ptr)++; + } + if (value > UINT32_MAX) { + TVM_FFI_THROW(ValueError) << "Integer value in dtype string '" + << std::string_view(start_ptr, *ptr - start_ptr) + << "' is out of range for uint32_t"; + } + return static_cast(value); + }; + + // Helper lambda to parse lanes specification (e.g., "x16" or "xvscalex4") + // Returns the parsed lanes value and updates *ptr to point past the lanes specification + // Supports scalable vectors with the "xvscale" prefix (represented as negative lanes) + auto parse_lanes = [&](const char** ptr, const char* end, const std::string_view& dtype_str, + bool allow_scalable = false) -> uint16_t { + int multiplier = 1; + // Check for "xvscale" prefix for scalable vectors + if (allow_scalable && (end - *ptr >= 7) && strncmp(*ptr, "xvscale", 7) == 0) { + multiplier = -1; + *ptr += 7; + } + if (*ptr >= end || **ptr != 'x') { + return 1; // No lanes specification, default to 1 + } + (*ptr)++; // Skip 'x' + const char* digits_start = *ptr; + uint32_t lanes_val = parse_digits(ptr, end); + if (*ptr == digits_start || lanes_val == 0) { + TVM_FFI_THROW(ValueError) << "Invalid lanes specification in dtype '" << dtype_str + << "'. Lanes must be a positive integer."; + } + if (lanes_val > UINT16_MAX) { + TVM_FFI_THROW(ValueError) << "Lanes value " << lanes_val + << " is out of range for uint16_t in dtype '" << dtype_str << "'"; + } + return static_cast(multiplier * lanes_val); + }; auto parse_float = [&](const std::string_view& str, int offset, int code, int bits) { dtype.code = static_cast(code); dtype.bits = static_cast(bits); scan = str.data() + offset; - char* endpt = nullptr; - if (*scan == 'x') { - dtype.lanes = static_cast(strtoul(scan + 1, &endpt, 10)); - scan = endpt; - } - if (scan != str.data() + str.length()) { + const char* endpt = scan; + dtype.lanes = parse_lanes(&endpt, str_end, str); + scan = endpt; + if (scan != str_end) { TVM_FFI_THROW(ValueError) << "unknown dtype `" << str << '`'; } return dtype; @@ -293,19 +337,18 @@ inline DLDataType StringViewToDLDataType_(std::string_view str) { scan = str.data(); TVM_FFI_THROW(ValueError) << "unknown dtype `" << str << '`'; } - char* xdelim; // emulate sscanf("%ux%u", bits, lanes) - uint8_t bits = static_cast(strtoul(scan, &xdelim, 10)); - if (bits != 0) dtype.bits = bits; - int scalable_multiplier = 1; - if (strncmp(xdelim, "xvscale", 7) == 0) { - scalable_multiplier = -1; - xdelim += 7; - } - char* endpt = xdelim; - if (*xdelim == 'x') { - dtype.lanes = static_cast(scalable_multiplier * strtoul(xdelim + 1, &endpt, 10)); + // Parse bits manually to handle non-null-terminated string_view + const char* xdelim = scan; + uint32_t bits_val = parse_digits(&xdelim, str_end); + if (bits_val > UINT8_MAX) { + TVM_FFI_THROW(ValueError) << "Bits value " << bits_val + << " is out of range for uint8_t in dtype '" << str << "'"; } - if (endpt != str.data() + str.length()) { + uint8_t bits = static_cast(bits_val); + if (bits != 0) dtype.bits = bits; + const char* endpt = xdelim; + dtype.lanes = parse_lanes(&endpt, str_end, str, /*allow_scalable=*/true); + if (endpt != str_end) { TVM_FFI_THROW(ValueError) << "unknown dtype `" << str << '`'; } return dtype; diff --git a/tests/cpp/test_dtype.cc b/tests/cpp/test_dtype.cc index 79fc9d7c..67b74d8e 100644 --- a/tests/cpp/test_dtype.cc +++ b/tests/cpp/test_dtype.cc @@ -127,4 +127,90 @@ TEST(DataType, AnyConversionWithString) { EXPECT_EQ(opt_v1.value().bits, 16); EXPECT_EQ(opt_v1.value().lanes, 2); } + +TEST(DType, NonNullTerminatedStringView) { + // Simulate memory scenario similar to Electron where memory after string + // contains garbage data (digits from previous strings) + // + // We test by calling TVMFFIDataTypeFromString directly with TVMFFIByteArray + // to bypass String's automatic null-termination + + // Helper lambda to test with raw byte array (no null terminator) + auto test_dtype_from_bytes = [](const char* data, size_t size) -> DLDataType { + TVMFFIByteArray byte_array{data, size}; + DLDataType dtype; + int ret = TVMFFIDataTypeFromString(&byte_array, &dtype); + EXPECT_EQ(ret, 0) << "TVMFFIDataTypeFromString failed"; + return dtype; + }; + + // Test 1: "float16" followed by digit garbage + char buffer1[] = "float16999888777"; + DLDataType dtype1 = test_dtype_from_bytes(buffer1, 7); // Only "float16" + EXPECT_EQ(dtype1.code, kDLFloat); + EXPECT_EQ(dtype1.bits, 16); // Should be 16, not 16999888777! + EXPECT_EQ(dtype1.lanes, 1); + + // Test 2: "int32" followed by "x4" from previous leftover + char buffer2[] = "int32x4extradata"; + DLDataType dtype2 = test_dtype_from_bytes(buffer2, 5); // Only "int32" + EXPECT_EQ(dtype2.code, kDLInt); + EXPECT_EQ(dtype2.bits, 32); // Should be 32, not parse the 'x4' + EXPECT_EQ(dtype2.lanes, 1); // Should be 1, not 4 + + // Test 3: "uint8" followed by more digits + char buffer3[] = "uint8192"; + DLDataType dtype3 = test_dtype_from_bytes(buffer3, 5); // Only "uint8" + EXPECT_EQ(dtype3.code, kDLUInt); + EXPECT_EQ(dtype3.bits, 8); // Should be 8, not 8192 + EXPECT_EQ(dtype3.lanes, 1); + + // Test 4: "bfloat16" followed by "x2" garbage + char buffer4[] = "bfloat16x2garbage"; + DLDataType dtype4 = test_dtype_from_bytes(buffer4, 8); // Only "bfloat16" + EXPECT_EQ(dtype4.code, kDLBfloat); + EXPECT_EQ(dtype4.bits, 16); + EXPECT_EQ(dtype4.lanes, 1); // Should be 1, not 2 + + // Test 5: "bfloat16x2" - lanes within bounds (should work) + DLDataType dtype5 = test_dtype_from_bytes(buffer4, 10); // "bfloat16x2" + EXPECT_EQ(dtype5.code, kDLBfloat); + EXPECT_EQ(dtype5.bits, 16); + EXPECT_EQ(dtype5.lanes, 2); // Should correctly parse x2 + + // Test 6: Truly non-null-terminated - overwrite null byte + char buffer6[] = "float64AAAAA"; + buffer6[7] = 'X'; // Ensure no null terminator at position 7 + DLDataType dtype6 = test_dtype_from_bytes(buffer6, 7); // "float64" + EXPECT_EQ(dtype6.code, kDLFloat); + EXPECT_EQ(dtype6.bits, 64); + EXPECT_EQ(dtype6.lanes, 1); + + // Test 7: "int8" followed by "x16" pattern + char buffer7[] = "int8x16leftovers"; + DLDataType dtype7 = test_dtype_from_bytes(buffer7, 4); // Only "int8" + EXPECT_EQ(dtype7.code, kDLInt); + EXPECT_EQ(dtype7.bits, 8); + EXPECT_EQ(dtype7.lanes, 1); // Should be 1, not 16 + + // Test 8: With actual x specification that should parse + DLDataType dtype8 = test_dtype_from_bytes(buffer7, 7); // "int8x16" + EXPECT_EQ(dtype8.code, kDLInt); + EXPECT_EQ(dtype8.bits, 8); + EXPECT_EQ(dtype8.lanes, 16); // Should correctly parse x16 + + // Test 9: Scalable vector - "int32xvscalex4" + char buffer9[] = "int32xvscalex4extra"; + DLDataType dtype9 = test_dtype_from_bytes(buffer9, 14); // "int32xvscalex4" + EXPECT_EQ(dtype9.code, kDLInt); + EXPECT_EQ(dtype9.bits, 32); + EXPECT_EQ(dtype9.lanes, static_cast(-4)); // Scalable: -4 + + // Test 10: Scalable vector with garbage after + char buffer10[] = "float16xvscalex8999"; + DLDataType dtype10 = test_dtype_from_bytes(buffer10, 16); // "float16xvscalex8" + EXPECT_EQ(dtype10.code, kDLFloat); + EXPECT_EQ(dtype10.bits, 16); + EXPECT_EQ(dtype10.lanes, static_cast(-8)); // Should be -8, not parse "999" +} } // namespace