Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-48747][SQL] Add code point iterator to UTF8String #47123

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.function.Function;
import java.util.Iterator;
import java.util.Map;
import java.util.regex.Pattern;

Expand Down Expand Up @@ -405,6 +406,105 @@ public boolean isValid() {
return true;
}

/**
* Code point iteration over a UTF8String can be done using one of two modes:
* 1. CODE_POINT_ITERATOR_ASSUME_VALID: The caller ensures that the UTF8String is valid and does
* not contain any invalid UTF-8 byte sequences. In this case, the code point iterator will
* return the code points in the current string one by one, as integers. If an invalid code
* point is found within the string during iteration, an exception will be thrown. This mode
* is more dangerous, but faster - since no scan is needed prior to beginning iteration.
* 2. CODE_POINT_ITERATOR_MAKE_VALID: The caller does not ensure that the UTF8String is valid,
* but instead expects the code point iterator to first check whether the current UTF8String
* is valid, then perform the invalid byte sequence replacement using `makeValid`, and finally
* begin the code point iteration over the resulting valid UTF8String. However, the original
* UTF8String stays unchanged. This mode is safer, but slower - due to initial validation.
* The default mode is CODE_POINT_ITERATOR_ASSUME_VALID.
*/
public enum CodePointIteratorType {
CODE_POINT_ITERATOR_ASSUME_VALID,
CODE_POINT_ITERATOR_MAKE_VALID
}

/**
* Returns a code point iterator for this UTF8String.
*/
public Iterator<Integer> codePointIterator() {
return codePointIterator(CodePointIteratorType.CODE_POINT_ITERATOR_ASSUME_VALID);
}

public Iterator<Integer> codePointIterator(CodePointIteratorType iteratorMode) {
if (iteratorMode == CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID && !isValid()) {
return makeValid().codePointIterator();
}
return new CodePointIterator();
}

/**
* Code point iterator implementation for the UTF8String class. The iterator will return code
* points in the current string one by one, as integers. However, the code point iterator is only
* guaranteed to work if the current UTF8String does not contain any invalid UTF-8 byte sequences.
* If the current string contains any invalid UTF-8 byte sequences, exceptions will be thrown.
*/
private class CodePointIterator implements Iterator<Integer> {
// Byte index used to iterate over the current UTF8String.
private int byteIndex = 0;

@Override
public boolean hasNext() {
return byteIndex < numBytes;
}

@Override
public Integer next() {
if (!hasNext()) {
throw new IndexOutOfBoundsException();
}
int codePoint = codePointFrom(byteIndex);
byteIndex += numBytesForFirstByte(getByte(byteIndex));
return codePoint;
}
}

/**
* Reverse version of the code point iterator for this UTF8String, returns code points in the
* current string one by one, as integers, in reverse order. The logic is similar to the above.
*/

public Iterator<Integer> reverseCodePointIterator() {
return reverseCodePointIterator(CodePointIteratorType.CODE_POINT_ITERATOR_ASSUME_VALID);
}

public Iterator<Integer> reverseCodePointIterator(CodePointIteratorType iteratorMode) {
if (iteratorMode == CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID && !isValid()) {
return makeValid().reverseCodePointIterator();
}
return new ReverseCodePointIterator();
}

private class ReverseCodePointIterator implements Iterator<Integer> {
private int byteIndex = numBytes - 1;

@Override
public boolean hasNext() {
return byteIndex >= 0;
}

@Override
public Integer next() {
if (!hasNext()) {
throw new IndexOutOfBoundsException();
}
while (byteIndex > 0 && isContinuationByte(getByte(byteIndex))) {
--byteIndex;
}
return codePointFrom(byteIndex--);
}

private boolean isContinuationByte(byte b) {
return (b & 0xC0) == 0x80;
}
}

/**
* Returns a substring of this.
* @param start the position of first code point
Expand Down Expand Up @@ -483,6 +583,46 @@ public byte getByte(int i) {
return Platform.getByte(base, offset + i);
}

/**
* Returns the code point at position `i`.
uros-db marked this conversation as resolved.
Show resolved Hide resolved
*/
public int getChar(int i) {
if (i < 0 || i >= numChars()) {
throw new IndexOutOfBoundsException();
}
int charCount = 0, byteCount = 0;
while (charCount < i) {
byteCount += numBytesForFirstByte(getByte(byteCount));
charCount += 1;
}
return codePointFrom(byteCount);
}

/**
* Returns the code point starting from the byte at position `index`.
*/
public int codePointFrom(int index) {
if (index < 0 || index >= numBytes) {
throw new IndexOutOfBoundsException();
}
byte b = getByte(index);
int numBytes = numBytesForFirstByte(b);
return switch (numBytes) {
case 1 ->
b & 0x7F;
case 2 ->
((b & 0x1F) << 6) | (getByte(index + 1) & 0x3F);
case 3 ->
((b & 0x0F) << 12) | ((getByte(index + 1) & 0x3F) << 6) |
(getByte(index + 2) & 0x3F);
case 4 ->
((b & 0x07) << 18) | ((getByte(index + 1) & 0x3F) << 12) |
((getByte(index + 2) & 0x3F) << 6) | (getByte(index + 3) & 0x3F);
default ->
throw new IllegalArgumentException("Invalid UTF-8 sequence");
};
}

public boolean matchAt(final UTF8String s, int pos) {
if (s.numBytes + pos > numBytes || pos < 0) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,112 @@ public void isValid() {
testIsValid("0x9C 0x76 0x17", "0xEF 0xBF 0xBD 0x76 0x17");
}

@Test
public void utf8StringCodePoints() {
String s = "aéह 日å!";
UTF8String s0 = fromString(s);
for (int i = 0; i < s.length(); ++i) {
assertEquals(s.codePointAt(i), s0.getChar(i));
}

UTF8String s1 = fromBytes(new byte[] {0x41, (byte) 0xC3, (byte) 0xB1, (byte) 0xE2,
(byte) 0x82, (byte) 0xAC, (byte) 0xF0, (byte) 0x90, (byte) 0x8D, (byte) 0x88});
// numBytesForFirstByte
assertEquals(1, UTF8String.numBytesForFirstByte(s1.getByte(0)));
assertEquals(2, UTF8String.numBytesForFirstByte(s1.getByte(1)));
assertEquals(3, UTF8String.numBytesForFirstByte(s1.getByte(3)));
assertEquals(4, UTF8String.numBytesForFirstByte(s1.getByte(6)));
// getByte
assertEquals((byte) 0x41, s1.getByte(0));
assertEquals((byte) 0xC3, s1.getByte(1));
assertEquals((byte) 0xE2, s1.getByte(3));
assertEquals((byte) 0xF0, s1.getByte(6));
// codePointFrom
assertEquals(0x41, s1.codePointFrom(0));
assertEquals(0xF1, s1.codePointFrom(1));
assertEquals(0x20AC, s1.codePointFrom(3));
assertEquals(0x10348, s1.codePointFrom(6));
assertThrows(IndexOutOfBoundsException.class, () -> s1.codePointFrom(-1));
assertThrows(IndexOutOfBoundsException.class, () -> s1.codePointFrom(99));
// getChar
assertEquals(0x41, s1.getChar(0));
assertEquals(0xF1, s1.getChar(1));
assertEquals(0x20AC, s1.getChar(2));
assertEquals(0x10348, s1.getChar(3));
assertThrows(IndexOutOfBoundsException.class, () -> s1.getChar(-1));
assertThrows(IndexOutOfBoundsException.class, () -> s1.getChar(99));

UTF8String s2 = fromString("Añ€𐍈");
// numBytesForFirstByte
assertEquals(1, UTF8String.numBytesForFirstByte(s2.getByte(0)));
assertEquals(2, UTF8String.numBytesForFirstByte(s2.getByte(1)));
assertEquals(3, UTF8String.numBytesForFirstByte(s2.getByte(3)));
assertEquals(4, UTF8String.numBytesForFirstByte(s2.getByte(6)));
// getByte
assertEquals((byte) 0x41, s2.getByte(0));
assertEquals((byte) 0xC3, s2.getByte(1));
assertEquals((byte) 0xE2, s2.getByte(3));
assertEquals((byte) 0xF0, s2.getByte(6));
// codePointFrom
assertEquals(0x41, s2.codePointFrom(0));
assertEquals(0xF1, s2.codePointFrom(1));
assertEquals(0x20AC, s2.codePointFrom(3));
assertEquals(0x10348, s2.codePointFrom(6));
assertThrows(IndexOutOfBoundsException.class, () -> s2.codePointFrom(-1));
assertThrows(IndexOutOfBoundsException.class, () -> s2.codePointFrom(99));
// getChar
assertEquals(0x41, s2.getChar(0));
assertEquals(0xF1, s2.getChar(1));
assertEquals(0x20AC, s2.getChar(2));
assertEquals(0x10348, s2.getChar(3));
assertThrows(IndexOutOfBoundsException.class, () -> s2.getChar(-1));
assertThrows(IndexOutOfBoundsException.class, () -> s2.getChar(99));

UTF8String s3 = EMPTY_UTF8;
// codePointFrom
assertThrows(IndexOutOfBoundsException.class, () -> s3.codePointFrom(0));
assertThrows(IndexOutOfBoundsException.class, () -> s3.codePointFrom(-1));
assertThrows(IndexOutOfBoundsException.class, () -> s3.codePointFrom(99));
// getChar
assertThrows(IndexOutOfBoundsException.class, () -> s3.getChar(0));
assertThrows(IndexOutOfBoundsException.class, () -> s3.getChar(-1));
assertThrows(IndexOutOfBoundsException.class, () -> s3.getChar(99));
}

private void testCodePointIterator(String str) {
UTF8String s = fromString(str);
Iterator<Integer> it = s.codePointIterator();
for (int i = 0; i < str.length(); ++i) {
assertTrue(it.hasNext());
assertEquals(str.charAt(i), (int) it.next());
}
assertFalse(it.hasNext());
}
@Test
public void codePointIterator() {
testCodePointIterator("");
testCodePointIterator("abc");
testCodePointIterator("a!2&^R");
testCodePointIterator("aéह 日å!");
}

private void testReverseCodePointIterator(String str) {
UTF8String s = fromString(str);
Iterator<Integer> it = s.reverseCodePointIterator();
for (int i = str.length() - 1; i >= 0 ; --i) {
assertTrue(it.hasNext());
assertEquals(str.charAt(i), (int) it.next());
}
assertFalse(it.hasNext());
}
@Test
public void reverseCodePointIterator() {
testReverseCodePointIterator("");
testReverseCodePointIterator("abc");
testReverseCodePointIterator("a!2&^R");
testReverseCodePointIterator("aéह 日å!");
}

@Test
public void toBinaryString() {
assertEquals(ZERO_UTF8, UTF8String.toBinaryString(0));
Expand Down