diff --git a/src/main/java/com/github/packageurl/internal/StringUtil.java b/src/main/java/com/github/packageurl/internal/StringUtil.java index e8ed080..ee1a45a 100644 --- a/src/main/java/com/github/packageurl/internal/StringUtil.java +++ b/src/main/java/com/github/packageurl/internal/StringUtil.java @@ -21,10 +21,10 @@ */ package com.github.packageurl.internal; +import static java.lang.Byte.toUnsignedInt; + import com.github.packageurl.ValidationException; -import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; -import java.util.stream.IntStream; /** * String utility for validation and encoding. @@ -35,6 +35,24 @@ public final class StringUtil { private static final byte PERCENT_CHAR = '%'; + private static final boolean[] UNRESERVED_CHARS = new boolean[128]; + + static { + for (char c = '0'; c <= '9'; c++) { + UNRESERVED_CHARS[c] = true; + } + for (char c = 'A'; c <= 'Z'; c++) { + UNRESERVED_CHARS[c] = true; + } + for (char c = 'a'; c <= 'z'; c++) { + UNRESERVED_CHARS[c] = true; + } + UNRESERVED_CHARS['-'] = true; + UNRESERVED_CHARS['.'] = true; + UNRESERVED_CHARS['_'] = true; + UNRESERVED_CHARS['~'] = true; + } + private StringUtil() { throw new AssertionError("Cannot instantiate StringUtil"); } @@ -48,10 +66,6 @@ private StringUtil() { * @since 2.0.0 */ public static String toLowerCase(String s) { - if (s == null) { - return null; - } - int pos = indexOfFirstUpperCaseChar(s); if (pos == -1) { @@ -59,10 +73,9 @@ public static String toLowerCase(String s) { } char[] chars = s.toCharArray(); - int length = chars.length; - for (int i = pos; i < length; i++) { - chars[i] = (char) toLowerCase(chars[i]); + for (int length = chars.length; pos < length; pos++) { + chars[pos] = (char) toLowerCase(chars[pos]); } return new String(chars); @@ -77,26 +90,22 @@ public static String toLowerCase(String s) { * @since 2.0.0 */ public static String percentDecode(final String source) { - if (source == null || source.isEmpty()) { + if (source.indexOf(PERCENT_CHAR) == -1) { return source; } byte[] bytes = source.getBytes(StandardCharsets.UTF_8); - int i = indexOfFirstPercentChar(bytes); - - if (i == -1) { - return source; - } + int readPos = indexOfFirstPercentChar(bytes); + int writePos = readPos; int length = bytes.length; - int writePos = i; - while (i < length) { - byte b = bytes[i]; + while (readPos < length) { + byte b = bytes[readPos]; if (b == PERCENT_CHAR) { - bytes[writePos++] = percentDecode(bytes, i++); - i += 2; + bytes[writePos++] = percentDecode(bytes, readPos++); + readPos += 2; } else { - bytes[writePos++] = bytes[i++]; + bytes[writePos++] = bytes[readPos++]; } } @@ -112,34 +121,29 @@ public static String percentDecode(final String source) { * @since 2.0.0 */ public static String percentEncode(final String source) { - if (source == null || source.isEmpty()) { - return source; - } - byte[] bytes = source.getBytes(StandardCharsets.UTF_8); - int start = indexOfFirstNonAsciiChar(bytes); - if (start == -1) { + if (!shouldEncode(source)) { return source; } - int length = bytes.length; - ByteBuffer buffer = ByteBuffer.allocate(start + ((length - start) * 3)); - if (start != 0) { - buffer.put(bytes, 0, start); - } - for (int i = start; i < length; i++) { - byte b = bytes[i]; - if (shouldEncode(b)) { - byte b1 = (byte) Character.toUpperCase(Character.forDigit((b >> 4) & 0xF, 16)); - byte b2 = (byte) Character.toUpperCase(Character.forDigit(b & 0xF, 16)); - buffer.put(PERCENT_CHAR); - buffer.put(b1); - buffer.put(b2); + byte[] src = source.getBytes(StandardCharsets.UTF_8); + byte[] dest = new byte[3 * src.length]; + + int writePos = 0; + for (byte b : src) { + if (shouldEncode(toUnsignedInt(b))) { + dest[writePos++] = PERCENT_CHAR; + dest[writePos++] = toHexDigit(b >> 4); + dest[writePos++] = toHexDigit(b); } else { - buffer.put(b); + dest[writePos++] = b; } } - return new String(buffer.array(), 0, buffer.position(), StandardCharsets.UTF_8); + return new String(dest, 0, writePos, StandardCharsets.UTF_8); + } + + private static byte toHexDigit(int b) { + return (byte) Character.toUpperCase(Character.forDigit(b & 0xF, 16)); } /** @@ -178,14 +182,34 @@ public static boolean isValidCharForKey(int c) { return (isAlphaNumeric(c) || c == '.' || c == '_' || c == '-'); } + /** + * Returns {@code true} if the character is in the unreserved RFC 3986 set. + *

+ * Warning: Profiling shows that the performance of {@link #percentEncode} relies heavily on this method. + * Modify with care. + *

+ * @param c non-negative integer. + */ private static boolean isUnreserved(int c) { - return (isValidCharForKey(c) || c == '~'); + return c < 128 && UNRESERVED_CHARS[c]; } + /** + * @param c non-negative integer + */ private static boolean shouldEncode(int c) { return !isUnreserved(c); } + private static boolean shouldEncode(String s) { + for (int i = 0, length = s.length(); i < length; i++) { + if (shouldEncode(s.charAt(i))) { + return true; + } + } + return false; + } + private static boolean isAlpha(int c) { return (isLowerCase(c) || isUpperCase(c)); } @@ -195,7 +219,7 @@ private static boolean isAlphaNumeric(int c) { } private static boolean isUpperCase(int c) { - return (c >= 'A' && c <= 'Z'); + return 'A' <= c && c <= 'Z'; } private static boolean isLowerCase(int c) { @@ -207,34 +231,21 @@ private static int toLowerCase(int c) { } private static int indexOfFirstUpperCaseChar(String s) { - int length = s.length(); - - for (int i = 0; i < length; i++) { + for (int i = 0, length = s.length(); i < length; i++) { if (isUpperCase(s.charAt(i))) { return i; } } - return -1; } - private static int indexOfFirstNonAsciiChar(byte[] bytes) { - int length = bytes.length; - int start = -1; - for (int i = 0; i < length; i++) { - if (shouldEncode(bytes[i])) { - start = i; - break; + private static int indexOfFirstPercentChar(final byte[] bytes) { + for (int i = 0, length = bytes.length; i < length; i++) { + if (bytes[i] == PERCENT_CHAR) { + return i; } } - return start; - } - - private static int indexOfFirstPercentChar(final byte[] bytes) { - return IntStream.range(0, bytes.length) - .filter(i -> bytes[i] == PERCENT_CHAR) - .findFirst() - .orElse(-1); + return -1; } private static byte percentDecode(final byte[] bytes, final int start) { diff --git a/src/test/java/com/github/packageurl/internal/StringUtilBenchmark.java b/src/test/java/com/github/packageurl/internal/StringUtilBenchmark.java index d75537f..e05b534 100644 --- a/src/test/java/com/github/packageurl/internal/StringUtilBenchmark.java +++ b/src/test/java/com/github/packageurl/internal/StringUtilBenchmark.java @@ -31,6 +31,7 @@ import org.openjdk.jmh.annotations.OutputTimeUnit; import org.openjdk.jmh.annotations.Param; import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; import org.openjdk.jmh.annotations.State; import org.openjdk.jmh.infra.Blackhole; @@ -62,8 +63,14 @@ public class StringUtilBenchmark { @Param({"0", "0.1", "0.5"}) private double nonAsciiProb; - private final String[] decodedData = createDecodedData(); - private final String[] encodedData = encodeData(decodedData); + private String[] decodedData; + private String[] encodedData; + + @Setup + public void setup() { + decodedData = createDecodedData(); + encodedData = encodeData(decodedData); + } private String[] createDecodedData() { Random random = new Random(); @@ -87,7 +94,10 @@ private static String[] encodeData(String[] decodedData) { for (int i = 0; i < encodedData.length; i++) { encodedData[i] = StringUtil.percentEncode(decodedData[i]); if (!StringUtil.percentDecode(encodedData[i]).equals(decodedData[i])) { - throw new RuntimeException("Invalid implementation of `percentEncode` and `percentDecode`."); + throw new RuntimeException( + "Invalid implementation of `percentEncode` and `percentDecode`.\nOriginal data: " + + encodedData[i] + "\nEncoded and decoded data: " + + StringUtil.percentDecode(encodedData[i])); } } return encodedData;