Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-48747][SQL] Add code point iterator to UTF8String #47123

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.function.Function;
import java.util.Iterator;
import java.util.Map;
import java.util.regex.Pattern;

Expand Down Expand Up @@ -405,6 +406,105 @@ public boolean isValid() {
return true;
}

/**
* Code point iteration over a UTF8String can be done using one of two modes:
* 1. CODE_POINT_ITERATOR_ASSUME_VALID: The caller ensures that the UTF8String is valid and does
* not contain any invalid UTF-8 byte sequences. In this case, the code point iterator will
* return the code points in the current string one by one, as integers. If an invalid code
* point is found within the string during iteration, an exception will be thrown. This mode
* is more dangerous, but faster - since no scan is needed prior to beginning iteration.
* 2. CODE_POINT_ITERATOR_MAKE_VALID: The caller does not ensure that the UTF8String is valid,
* but instead expects the code point iterator to first check whether the current UTF8String
* is valid, then perform the invalid byte sequence replacement using `makeValid`, and finally
* begin the code point iteration over the resulting valid UTF8String. However, the original
* UTF8String stays unchanged. This mode is safer, but slower - due to initial validation.
* The default mode is CODE_POINT_ITERATOR_ASSUME_VALID.
*/
public enum CodePointIteratorType {
CODE_POINT_ITERATOR_ASSUME_VALID, // USE ONLY WITH VALID STRINGS
CODE_POINT_ITERATOR_MAKE_VALID
}

/**
* Returns a code point iterator for this UTF8String.
*/
public Iterator<Integer> codePointIterator() {
return codePointIterator(CodePointIteratorType.CODE_POINT_ITERATOR_ASSUME_VALID);
}

public Iterator<Integer> codePointIterator(CodePointIteratorType iteratorMode) {
if (iteratorMode == CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID && !isValid()) {
return makeValid().codePointIterator();
}
return new CodePointIterator();
}

/**
* Code point iterator implementation for the UTF8String class. The iterator will return code
* points in the current string one by one, as integers. However, the code point iterator is only
* guaranteed to work if the current UTF8String does not contain any invalid UTF-8 byte sequences.
* If the current string contains any invalid UTF-8 byte sequences, exceptions will be thrown.
*/
private class CodePointIterator implements Iterator<Integer> {
// Byte index used to iterate over the current UTF8String.
private int byteIndex = 0;

@Override
public boolean hasNext() {
return byteIndex < numBytes;
}

@Override
public Integer next() {
if (!hasNext()) {
throw new IndexOutOfBoundsException();
}
int codePoint = codePointFrom(byteIndex);
byteIndex += numBytesForFirstByte(getByte(byteIndex));
return codePoint;
}
}

/**
* Reverse version of the code point iterator for this UTF8String, returns code points in the
* current string one by one, as integers, in reverse order. The logic is similar to the above.
*/

public Iterator<Integer> reverseCodePointIterator() {
return reverseCodePointIterator(CodePointIteratorType.CODE_POINT_ITERATOR_ASSUME_VALID);
}

public Iterator<Integer> reverseCodePointIterator(CodePointIteratorType iteratorMode) {
if (iteratorMode == CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID && !isValid()) {
return makeValid().reverseCodePointIterator();
}
return new ReverseCodePointIterator();
}

private class ReverseCodePointIterator implements Iterator<Integer> {
private int byteIndex = numBytes - 1;

@Override
public boolean hasNext() {
return byteIndex >= 0;
}

@Override
public Integer next() {
if (!hasNext()) {
throw new IndexOutOfBoundsException();
}
while (byteIndex > 0 && isContinuationByte(getByte(byteIndex))) {
--byteIndex;
}
return codePointFrom(byteIndex--);
}

private boolean isContinuationByte(byte b) {
return (b & 0xC0) == 0x80;
}
}

/**
* Returns a substring of this.
* @param start the position of first code point
Expand Down Expand Up @@ -477,10 +577,53 @@ public boolean contains(final UTF8String substring) {
}

/**
* Returns the byte at position `i`.
* Returns the byte at (byte) position `byteIndex`. If byte index is invalid, returns 0.
*/
public byte getByte(int byteIndex) {
return Platform.getByte(base, offset + byteIndex);
}

/**
* Returns the code point at (char) position `charIndex`. If char index is invalid, throws
* exception. Note that this method is not efficient as it needs to traverse the UTF-8 string.
* If `byteIndex` of the first byte in the code point is known, use `codePointFrom` instead.
*/
public int getChar(int charIndex) {
if (charIndex < 0 || charIndex >= numChars()) {
throw new IndexOutOfBoundsException();
}
int charCount = 0, byteCount = 0;
while (charCount < charIndex) {
byteCount += numBytesForFirstByte(getByte(byteCount));
charCount += 1;
}
return codePointFrom(byteCount);
}

/**
* Returns the code point starting from the byte at position `byteIndex`.
* If byte index is invalid, throws exception.
*/
public byte getByte(int i) {
return Platform.getByte(base, offset + i);
public int codePointFrom(int byteIndex) {
if (byteIndex < 0 || byteIndex >= numBytes) {
throw new IndexOutOfBoundsException();
}
byte b = getByte(byteIndex);
int numBytes = numBytesForFirstByte(b);
return switch (numBytes) {
case 1 ->
b & 0x7F;
case 2 ->
((b & 0x1F) << 6) | (getByte(byteIndex + 1) & 0x3F);
case 3 ->
((b & 0x0F) << 12) | ((getByte(byteIndex + 1) & 0x3F) << 6) |
(getByte(byteIndex + 2) & 0x3F);
case 4 ->
((b & 0x07) << 18) | ((getByte(byteIndex + 1) & 0x3F) << 12) |
((getByte(byteIndex + 2) & 0x3F) << 6) | (getByte(byteIndex + 3) & 0x3F);
default ->
throw new IllegalStateException("Error in UTF-8 code point");
};
}

public boolean matchAt(final UTF8String s, int pos) {
Expand Down
Loading