Skip to content

Commit

Permalink
Make Java consistently reject unmatched end-group tag.
Browse files Browse the repository at this point in the history
This brings it into conformance with our spec and other languages.  Some parse paths already did this check, and all of them prohibit *nested* unmatched end-group tags.

PiperOrigin-RevId: 705225060
  • Loading branch information
mkruskal-google authored and copybara-github committed Dec 11, 2024
1 parent b9cb184 commit a4d4bfe
Show file tree
Hide file tree
Showing 14 changed files with 1,426 additions and 1,018 deletions.
1 change: 1 addition & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ maven.install(
"com.google.j2objc:j2objc-annotations:2.8",
"com.google.guava:guava:32.0.1-jre",
"com.google.guava:guava-testlib:32.0.1-jre",
"com.google.testparameterinjector:test-parameter-injector:1.18",
"com.google.truth:truth:1.1.2",
"junit:junit:4.13.2",
"org.mockito:mockito-core:4.3.1",
Expand Down
4 changes: 0 additions & 4 deletions conformance/failure_list_java.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,3 @@ Required.*.JsonInput.RepeatedFieldWrongElementTypeExpectingStringsGotBool
Required.*.JsonInput.RepeatedFieldWrongElementTypeExpectingStringsGotInt # Should have failed to parse, but didn't.
Required.*.JsonInput.StringFieldNotAString # Should have failed to parse, but didn't.
Required.*.ProtobufInput.UnknownOrdering.ProtobufOutput # Unknown field mismatch
Required.*.ProtobufInput.UnmatchedEndGroup # Should have failed to parse, but didn't.
Required.*.ProtobufInput.UnmatchedEndGroupUnknown # Should have failed to parse, but didn't.
Required.*.ProtobufInput.UnmatchedEndGroupWithData # Should have failed to parse, but didn't.
Required.*.ProtobufInput.UnmatchedEndGroupWrongType # Should have failed to parse, but didn't.
4 changes: 0 additions & 4 deletions conformance/failure_list_jruby.txt
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,3 @@ Required.Editions_Proto2.ProtobufInput.UnknownOrdering.ProtobufOutput
Required.Editions_Proto3.ProtobufInput.UnknownOrdering.ProtobufOutput
Required.Proto2.ProtobufInput.UnknownOrdering.ProtobufOutput
Required.Proto3.ProtobufInput.UnknownOrdering.ProtobufOutput
Required.*.ProtobufInput.UnmatchedEndGroup # Should have failed to parse, but didn't.
Required.*.ProtobufInput.UnmatchedEndGroupUnknown # Should have failed to parse, but didn't.
Required.*.ProtobufInput.UnmatchedEndGroupWithData # Should have failed to parse, but didn't.
Required.*.ProtobufInput.UnmatchedEndGroupWrongType # Should have failed to parse, but didn't.
1 change: 1 addition & 0 deletions java/core/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ junit_tests(
":lite_test_protos_java_proto",
":test_util",
"@protobuf_maven//:com_google_guava_guava",
"@protobuf_maven//:com_google_testparameterinjector_test_parameter_injector",
"@protobuf_maven//:com_google_truth_truth",
"@protobuf_maven//:junit_junit",
"@protobuf_maven//:org_mockito_mockito_core",
Expand Down
97 changes: 59 additions & 38 deletions java/core/src/main/java/com/google/protobuf/CodedInputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ public abstract class CodedInputStream {
private static volatile int defaultRecursionLimit = 100;

/** Visible for subclasses. See setRecursionLimit() */
int recursionDepth;
int messageDepth;

int groupDepth;

int recursionLimit = defaultRecursionLimit;

Expand Down Expand Up @@ -173,11 +175,22 @@ static CodedInputStream newInstance(ByteBuffer buf, boolean bufferIsImmutable) {
}

public void checkRecursionLimit() throws InvalidProtocolBufferException {
if (recursionDepth >= recursionLimit) {
if (messageDepth + groupDepth >= recursionLimit) {
throw InvalidProtocolBufferException.recursionLimitExceeded();
}
}

/**
* Verifies that the last tag was 0 if we aren't inside a group.
*
* @throws InvalidProtocolBufferException The last tag was not 0 and we aren't inside a group.
*/
public void checkValidEndTag() throws InvalidProtocolBufferException {
if (groupDepth == 0) {
checkLastTagWas(0);
}
}

/** Disable construction/inheritance outside of this class. */
private CodedInputStream() {}

Expand Down Expand Up @@ -231,9 +244,9 @@ public void skipMessage() throws IOException {
return;
}
checkRecursionLimit();
++recursionDepth;
++groupDepth;
boolean fieldSkipped = skipField(tag);
--recursionDepth;
--groupDepth;
if (!fieldSkipped) {
return;
}
Expand All @@ -251,9 +264,9 @@ public void skipMessage(CodedOutputStream output) throws IOException {
return;
}
checkRecursionLimit();
++recursionDepth;
++groupDepth;
boolean fieldSkipped = skipField(tag, output);
--recursionDepth;
--groupDepth;
if (!fieldSkipped) {
return;
}
Expand Down Expand Up @@ -668,6 +681,7 @@ public boolean skipField(final int tag) throws IOException {
WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP));
return true;
case WireFormat.WIRETYPE_END_GROUP:
checkValidEndTag();
return false;
case WireFormat.WIRETYPE_FIXED32:
skipRawBytes(FIXED32_SIZE);
Expand Down Expand Up @@ -714,6 +728,7 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
}
case WireFormat.WIRETYPE_END_GROUP:
{
checkValidEndTag();
return false;
}
case WireFormat.WIRETYPE_FIXED32:
Expand Down Expand Up @@ -815,10 +830,10 @@ public void readGroup(
final ExtensionRegistryLite extensionRegistry)
throws IOException {
checkRecursionLimit();
++recursionDepth;
++groupDepth;
builder.mergeFrom(this, extensionRegistry);
checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP));
--recursionDepth;
--groupDepth;
}

@Override
Expand All @@ -828,10 +843,10 @@ public <T extends MessageLite> T readGroup(
final ExtensionRegistryLite extensionRegistry)
throws IOException {
checkRecursionLimit();
++recursionDepth;
++groupDepth;
T result = parser.parsePartialFrom(this, extensionRegistry);
checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP));
--recursionDepth;
--groupDepth;
return result;
}

Expand All @@ -849,10 +864,10 @@ public void readMessage(
final int length = readRawVarint32();
checkRecursionLimit();
final int oldLimit = pushLimit(length);
++recursionDepth;
++messageDepth;
builder.mergeFrom(this, extensionRegistry);
checkLastTagWas(0);
--recursionDepth;
--messageDepth;
if (getBytesUntilLimit() != 0) {
throw InvalidProtocolBufferException.truncatedMessage();
}
Expand All @@ -865,10 +880,10 @@ public <T extends MessageLite> T readMessage(
int length = readRawVarint32();
checkRecursionLimit();
final int oldLimit = pushLimit(length);
++recursionDepth;
++messageDepth;
T result = parser.parsePartialFrom(this, extensionRegistry);
checkLastTagWas(0);
--recursionDepth;
--messageDepth;
if (getBytesUntilLimit() != 0) {
throw InvalidProtocolBufferException.truncatedMessage();
}
Expand Down Expand Up @@ -1361,6 +1376,7 @@ public boolean skipField(final int tag) throws IOException {
WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP));
return true;
case WireFormat.WIRETYPE_END_GROUP:
checkValidEndTag();
return false;
case WireFormat.WIRETYPE_FIXED32:
skipRawBytes(FIXED32_SIZE);
Expand Down Expand Up @@ -1407,6 +1423,7 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
}
case WireFormat.WIRETYPE_END_GROUP:
{
checkValidEndTag();
return false;
}
case WireFormat.WIRETYPE_FIXED32:
Expand Down Expand Up @@ -1513,10 +1530,10 @@ public void readGroup(
final ExtensionRegistryLite extensionRegistry)
throws IOException {
checkRecursionLimit();
++recursionDepth;
++groupDepth;
builder.mergeFrom(this, extensionRegistry);
checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP));
--recursionDepth;
--groupDepth;
}

@Override
Expand All @@ -1526,10 +1543,10 @@ public <T extends MessageLite> T readGroup(
final ExtensionRegistryLite extensionRegistry)
throws IOException {
checkRecursionLimit();
++recursionDepth;
++groupDepth;
T result = parser.parsePartialFrom(this, extensionRegistry);
checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP));
--recursionDepth;
--groupDepth;
return result;
}

Expand All @@ -1547,10 +1564,10 @@ public void readMessage(
final int length = readRawVarint32();
checkRecursionLimit();
final int oldLimit = pushLimit(length);
++recursionDepth;
++messageDepth;
builder.mergeFrom(this, extensionRegistry);
checkLastTagWas(0);
--recursionDepth;
--messageDepth;
if (getBytesUntilLimit() != 0) {
throw InvalidProtocolBufferException.truncatedMessage();
}
Expand All @@ -1563,10 +1580,10 @@ public <T extends MessageLite> T readMessage(
int length = readRawVarint32();
checkRecursionLimit();
final int oldLimit = pushLimit(length);
++recursionDepth;
++messageDepth;
T result = parser.parsePartialFrom(this, extensionRegistry);
checkLastTagWas(0);
--recursionDepth;
--messageDepth;
if (getBytesUntilLimit() != 0) {
throw InvalidProtocolBufferException.truncatedMessage();
}
Expand Down Expand Up @@ -2107,6 +2124,7 @@ public boolean skipField(final int tag) throws IOException {
WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP));
return true;
case WireFormat.WIRETYPE_END_GROUP:
checkValidEndTag();
return false;
case WireFormat.WIRETYPE_FIXED32:
skipRawBytes(FIXED32_SIZE);
Expand Down Expand Up @@ -2153,6 +2171,7 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
}
case WireFormat.WIRETYPE_END_GROUP:
{
checkValidEndTag();
return false;
}
case WireFormat.WIRETYPE_FIXED32:
Expand Down Expand Up @@ -2296,10 +2315,10 @@ public void readGroup(
final ExtensionRegistryLite extensionRegistry)
throws IOException {
checkRecursionLimit();
++recursionDepth;
++groupDepth;
builder.mergeFrom(this, extensionRegistry);
checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP));
--recursionDepth;
--groupDepth;
}

@Override
Expand All @@ -2309,10 +2328,10 @@ public <T extends MessageLite> T readGroup(
final ExtensionRegistryLite extensionRegistry)
throws IOException {
checkRecursionLimit();
++recursionDepth;
++groupDepth;
T result = parser.parsePartialFrom(this, extensionRegistry);
checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP));
--recursionDepth;
--groupDepth;
return result;
}

Expand All @@ -2330,10 +2349,10 @@ public void readMessage(
final int length = readRawVarint32();
checkRecursionLimit();
final int oldLimit = pushLimit(length);
++recursionDepth;
++messageDepth;
builder.mergeFrom(this, extensionRegistry);
checkLastTagWas(0);
--recursionDepth;
--messageDepth;
if (getBytesUntilLimit() != 0) {
throw InvalidProtocolBufferException.truncatedMessage();
}
Expand All @@ -2346,10 +2365,10 @@ public <T extends MessageLite> T readMessage(
int length = readRawVarint32();
checkRecursionLimit();
final int oldLimit = pushLimit(length);
++recursionDepth;
++messageDepth;
T result = parser.parsePartialFrom(this, extensionRegistry);
checkLastTagWas(0);
--recursionDepth;
--messageDepth;
if (getBytesUntilLimit() != 0) {
throw InvalidProtocolBufferException.truncatedMessage();
}
Expand Down Expand Up @@ -3234,6 +3253,7 @@ public boolean skipField(final int tag) throws IOException {
WireFormat.makeTag(WireFormat.getTagFieldNumber(tag), WireFormat.WIRETYPE_END_GROUP));
return true;
case WireFormat.WIRETYPE_END_GROUP:
checkValidEndTag();
return false;
case WireFormat.WIRETYPE_FIXED32:
skipRawBytes(FIXED32_SIZE);
Expand Down Expand Up @@ -3280,6 +3300,7 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
}
case WireFormat.WIRETYPE_END_GROUP:
{
checkValidEndTag();
return false;
}
case WireFormat.WIRETYPE_FIXED32:
Expand Down Expand Up @@ -3393,10 +3414,10 @@ public void readGroup(
final ExtensionRegistryLite extensionRegistry)
throws IOException {
checkRecursionLimit();
++recursionDepth;
++groupDepth;
builder.mergeFrom(this, extensionRegistry);
checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP));
--recursionDepth;
--groupDepth;
}

@Override
Expand All @@ -3406,10 +3427,10 @@ public <T extends MessageLite> T readGroup(
final ExtensionRegistryLite extensionRegistry)
throws IOException {
checkRecursionLimit();
++recursionDepth;
++groupDepth;
T result = parser.parsePartialFrom(this, extensionRegistry);
checkLastTagWas(WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP));
--recursionDepth;
--groupDepth;
return result;
}

Expand All @@ -3427,10 +3448,10 @@ public void readMessage(
final int length = readRawVarint32();
checkRecursionLimit();
final int oldLimit = pushLimit(length);
++recursionDepth;
++messageDepth;
builder.mergeFrom(this, extensionRegistry);
checkLastTagWas(0);
--recursionDepth;
--messageDepth;
if (getBytesUntilLimit() != 0) {
throw InvalidProtocolBufferException.truncatedMessage();
}
Expand All @@ -3443,10 +3464,10 @@ public <T extends MessageLite> T readMessage(
int length = readRawVarint32();
checkRecursionLimit();
final int oldLimit = pushLimit(length);
++recursionDepth;
++messageDepth;
T result = parser.parsePartialFrom(this, extensionRegistry);
checkLastTagWas(0);
--recursionDepth;
--messageDepth;
if (getBytesUntilLimit() != 0) {
throw InvalidProtocolBufferException.truncatedMessage();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,14 @@ public <T> void mergeMessageField(
private <T> void mergeMessageFieldInternal(
T target, Schema<T> schema, ExtensionRegistryLite extensionRegistry) throws IOException {
int size = input.readUInt32();
if (input.recursionDepth >= input.recursionLimit) {
throw InvalidProtocolBufferException.recursionLimitExceeded();
}
input.checkRecursionLimit();

// Push the new limit.
final int prevLimit = input.pushLimit(size);
++input.recursionDepth;
++input.messageDepth;
schema.mergeFrom(target, this, extensionRegistry);
input.checkLastTagWas(0);
--input.recursionDepth;
--input.messageDepth;
// Restore the previous limit.
input.popLimit(prevLimit);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,8 @@ private static void mergeMessageSetExtensionFromCodedStream(
// We haven't seen a type ID yet or we want parse message lazily.
rawBytes = input.readBytes();

} else if (tag == WireFormat.MESSAGE_SET_ITEM_END_TAG) {
break;
} else { // Unknown tag. Skip it.
if (!input.skipField(tag)) {
break; // End of group
Expand Down
Loading

0 comments on commit a4d4bfe

Please sign in to comment.