From eb7b4686555aa65a6adcb9d68f209a61d5ca37ec Mon Sep 17 00:00:00 2001 From: Michael Stapelberg Date: Wed, 4 Dec 2024 14:33:39 +0100 Subject: [PATCH] all: Release the Opaque API For golang/protobuf#1657 Change-Id: I7b2b0c30506706015ce278e6054439c9ad9ef727 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/634815 TryBot-Bypass: Michael Stapelberg Reviewed-by: Joseph Tsai Reviewed-by: Damien Neil --- .../builder_test/builder_test.go | 744 ++++++++++ .../descriptor_test/descriptor_test.go | 43 + cmd/protoc-gen-go/internal_gengo/init.go | 7 +- .../internal_gengo/init_opaque.go | 33 + cmd/protoc-gen-go/internal_gengo/main.go | 5 +- cmd/protoc-gen-go/internal_gengo/opaque.go | 1306 +++++++++++++++++ cmd/protoc-gen-go/internal_gengo/reflect.go | 2 +- .../name_clash_test/name_clash_proto3_test.go | 810 ++++++++++ .../name_clash_test/name_clash_test.go | 898 ++++++++++++ .../opaque_default_test.go | 21 + .../opaque_map_test/opaque_map_test.go | 35 + .../testdata/nameclash/nameclash.go | 16 + .../nameclash/test_name_clash_hybrid.proto | 333 +++++ .../nameclash/test_name_clash_hybrid3.proto | 338 +++++ .../nameclash/test_name_clash_opaque.proto | 333 +++++ .../nameclash/test_name_clash_opaque3.proto | 338 +++++ .../nameclash/test_name_clash_open.proto | 151 ++ .../nameclash/test_name_clash_open3.proto | 164 +++ compiler/protogen/protogen.go | 59 +- compiler/protogen/protogen_apilevel.go | 192 +++ compiler/protogen/protogen_opaque.go | 79 + .../prototext/testmessages_opaque_test.go | 34 + encoding/prototext/testmessages_test.go | 7 + integration_test.go | 1 + internal/cmd/generate-protos/main.go | 184 ++- internal/cmd/generate-types/impl.go | 124 ++ internal/cmd/generate-types/impl_opaque.go | 77 + internal/cmd/generate-types/main.go | 2 + internal/cmd/generate-types/proto.go | 37 + internal/filedesc/build_test.go | 4 +- internal/filedesc/desc.go | 3 + internal/filedesc/editions.go | 4 + internal/genid/descriptor_gen.go | 27 +- internal/genid/go_features_gen.go | 17 + internal/genid/name.go | 12 + internal/impl/api_export_opaque.go | 128 ++ internal/impl/bitmap.go | 34 + internal/impl/bitmap_race.go | 126 ++ internal/impl/checkinit.go | 33 + internal/impl/codec_field_opaque.go | 264 ++++ internal/impl/codec_message.go | 13 + internal/impl/codec_message_opaque.go | 156 ++ internal/impl/decode.go | 56 +- internal/impl/encode.go | 78 + internal/impl/lazy.go | 433 ++++++ internal/impl/lazy_buffersharing_test.go | 151 ++ internal/impl/lazy_field_normalized_test.go | 156 ++ internal/impl/merge.go | 27 + internal/impl/message.go | 12 + internal/impl/message_opaque.go | 614 ++++++++ internal/impl/message_opaque_gen.go | 132 ++ internal/impl/message_reflect.go | 5 + internal/impl/message_reflect_field.go | 32 +- internal/impl/message_reflect_field_gen.go | 273 ++++ internal/impl/pointer_unsafe.go | 9 + internal/impl/pointer_unsafe_opaque.go | 42 + internal/impl/presence.go | 142 ++ internal/impl/validate.go | 16 + internal/protolazy/bufferreader.go | 364 +++++ internal/protolazy/lazy.go | 359 +++++ internal/protolazy/pointer_unsafe.go | 17 + internal/race_test/lazy/lazy_race_test.go | 494 +++++++ .../reflection_test/reflection_hybrid_test.go | 1003 +++++++++++++ .../reflection_large_opaque_test.go | 893 +++++++++++ .../reflection_test/reflection_opaque_test.go | 1045 +++++++++++++ .../reflection_test/reflection_open_test.go | 985 +++++++++++++ .../reflection_repeated_test.go | 43 + internal/reflection_test/reflection_test.go | 768 ++++++++++ internal/testprotos/enums/enums.proto | 4 +- internal/testprotos/irregular/test.proto | 20 +- .../lazy/lazy_normalized_wire_test.proto | 20 + internal/testprotos/lazy/lazy_tree.proto | 29 + .../messageset/messagesetpb/message_set.proto | 12 +- .../messageset/msetextpb/msetextpb.proto | 18 +- internal/testprotos/mixed/mixed.proto | 81 + internal/testprotos/news/news.proto | 4 +- internal/testprotos/required/required.proto | 39 +- internal/testprotos/test/test.proto | 4 + internal/testprotos/testeditions/test.proto | 130 ++ .../testprotos/testeditions/test_import.proto | 15 + .../testprotos/textpbeditions/test2.proto | 2 +- proto/decode.go | 16 + proto/encode.go | 3 +- proto/lazy_bench_test.go | 92 ++ proto/lazy_roundtrip_test.go | 125 ++ proto/messageset_test.go | 2 + proto/oneof_get_test.go | 273 ++++ proto/oneof_set_test.go | 312 ++++ proto/oneof_which_test.go | 189 +++ proto/repeated_test.go | 560 +++++++ proto/testmessages_opaque_test.go | 97 ++ proto/testmessages_test.go | 14 + proto/wrapperopaque.go | 80 + proto/wrapperopaque_test.go | 173 +++ reflect/protodesc/editions.go | 3 + runtime/protoiface/methods.go | 16 + runtime/protoimpl/impl.go | 4 + runtime/protolazy/protolazy.go | 31 + src/google/protobuf/go_features.proto | 22 + testing/prototest/message.go | 12 +- types/gofeaturespb/go_features.pb.go | 182 ++- 101 files changed, 17785 insertions(+), 142 deletions(-) create mode 100644 cmd/protoc-gen-go/builder_test/builder_test.go create mode 100644 cmd/protoc-gen-go/descriptor_test/descriptor_test.go create mode 100644 cmd/protoc-gen-go/internal_gengo/init_opaque.go create mode 100644 cmd/protoc-gen-go/internal_gengo/opaque.go create mode 100644 cmd/protoc-gen-go/name_clash_test/name_clash_proto3_test.go create mode 100644 cmd/protoc-gen-go/name_clash_test/name_clash_test.go create mode 100644 cmd/protoc-gen-go/opaque_default_test/opaque_default_test.go create mode 100644 cmd/protoc-gen-go/opaque_map_test/opaque_map_test.go create mode 100644 cmd/protoc-gen-go/testdata/nameclash/nameclash.go create mode 100644 cmd/protoc-gen-go/testdata/nameclash/test_name_clash_hybrid.proto create mode 100644 cmd/protoc-gen-go/testdata/nameclash/test_name_clash_hybrid3.proto create mode 100644 cmd/protoc-gen-go/testdata/nameclash/test_name_clash_opaque.proto create mode 100644 cmd/protoc-gen-go/testdata/nameclash/test_name_clash_opaque3.proto create mode 100644 cmd/protoc-gen-go/testdata/nameclash/test_name_clash_open.proto create mode 100644 cmd/protoc-gen-go/testdata/nameclash/test_name_clash_open3.proto create mode 100644 compiler/protogen/protogen_apilevel.go create mode 100644 compiler/protogen/protogen_opaque.go create mode 100644 encoding/prototext/testmessages_opaque_test.go create mode 100644 internal/cmd/generate-types/impl_opaque.go create mode 100644 internal/genid/name.go create mode 100644 internal/impl/api_export_opaque.go create mode 100644 internal/impl/bitmap.go create mode 100644 internal/impl/bitmap_race.go create mode 100644 internal/impl/codec_field_opaque.go create mode 100644 internal/impl/codec_message_opaque.go create mode 100644 internal/impl/lazy.go create mode 100644 internal/impl/lazy_buffersharing_test.go create mode 100644 internal/impl/lazy_field_normalized_test.go create mode 100644 internal/impl/message_opaque.go create mode 100644 internal/impl/message_opaque_gen.go create mode 100644 internal/impl/message_reflect_field_gen.go create mode 100644 internal/impl/pointer_unsafe_opaque.go create mode 100644 internal/impl/presence.go create mode 100644 internal/protolazy/bufferreader.go create mode 100644 internal/protolazy/lazy.go create mode 100644 internal/protolazy/pointer_unsafe.go create mode 100644 internal/race_test/lazy/lazy_race_test.go create mode 100644 internal/reflection_test/reflection_hybrid_test.go create mode 100644 internal/reflection_test/reflection_large_opaque_test.go create mode 100644 internal/reflection_test/reflection_opaque_test.go create mode 100644 internal/reflection_test/reflection_open_test.go create mode 100644 internal/reflection_test/reflection_repeated_test.go create mode 100644 internal/reflection_test/reflection_test.go create mode 100644 internal/testprotos/lazy/lazy_normalized_wire_test.proto create mode 100644 internal/testprotos/lazy/lazy_tree.proto create mode 100644 internal/testprotos/mixed/mixed.proto create mode 100644 internal/testprotos/testeditions/test_import.proto create mode 100644 proto/lazy_bench_test.go create mode 100644 proto/lazy_roundtrip_test.go create mode 100644 proto/oneof_get_test.go create mode 100644 proto/oneof_set_test.go create mode 100644 proto/oneof_which_test.go create mode 100644 proto/repeated_test.go create mode 100644 proto/testmessages_opaque_test.go create mode 100644 proto/wrapperopaque.go create mode 100644 proto/wrapperopaque_test.go create mode 100644 runtime/protolazy/protolazy.go diff --git a/cmd/protoc-gen-go/builder_test/builder_test.go b/cmd/protoc-gen-go/builder_test/builder_test.go new file mode 100644 index 000000000..029bfa4af --- /dev/null +++ b/cmd/protoc-gen-go/builder_test/builder_test.go @@ -0,0 +1,744 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests the opaque builders. +package builder_test + +import ( + "testing" + + testhybridpb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_hybrid" + testopaquepb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_opaque" + "google.golang.org/protobuf/proto" +) + +var enableLazy = proto.UnmarshalOptions{} +var disableLazy = proto.UnmarshalOptions{ + NoLazyDecoding: true, +} + +func roundtrip(t *testing.T, m proto.Message, unmarshalOpts proto.UnmarshalOptions) { + b, err := proto.Marshal(m) + if err != nil { + t.Fatalf("unable to Marshal proto: %v", err) + } + if err := unmarshalOpts.Unmarshal(b, m); err != nil { + t.Fatalf("roundtrip: unable to unmarshal proto: %v", err) + } +} + +func TestOpaqueBuilderLazy(t *testing.T) { + testLazyOptionalBuilder(t, enableLazy) +} + +func TestOpaqueBuilderEager(t *testing.T) { + testLazyOptionalBuilder(t, disableLazy) +} + +// testLazyOptionalBuilder exercises all optional fields in the testall_opaque_optional3_go_proto builder +func testLazyOptionalBuilder(t *testing.T, unmarshalOpts proto.UnmarshalOptions) { + // Create empty proto from builder + m := testopaquepb.TestAllTypes_builder{}.Build() + + roundtrip(t, m, unmarshalOpts) + + // Check lazy message field + m = testopaquepb.TestAllTypes_builder{ + OptionalLazyNestedMessage: testopaquepb.TestAllTypes_NestedMessage_builder{ + A: proto.Int32(1147), + }.Build(), + RepeatedNestedMessage: []*testopaquepb.TestAllTypes_NestedMessage{ + testopaquepb.TestAllTypes_NestedMessage_builder{ + A: proto.Int32(1247), + }.Build(), + }, + OneofNestedMessage: testopaquepb.TestAllTypes_NestedMessage_builder{ + A: proto.Int32(1347), + }.Build(), + MapStringNestedMessage: map[string]*testopaquepb.TestAllTypes_NestedMessage{ + "a": testopaquepb.TestAllTypes_NestedMessage_builder{ + A: proto.Int32(5), + }.Build(), + }, + }.Build() + + roundtrip(t, m, unmarshalOpts) + + if got, want := m.HasOptionalLazyNestedMessage(), true; got != want { + t.Errorf("Builder for field NestedMessage did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.GetOptionalLazyNestedMessage().GetA(), int32(1147); got != want { + t.Errorf("Builder for field NestedMessage did not work, got %v, wanted %v", got, want) + } + if got, want := len(m.GetRepeatedNestedMessage()), 1; got != want { + t.Errorf("Builder for field RepeatedNestedMessage did not set a field of expected length, got %v, wanted %v", got, want) + } + if got, want := m.GetRepeatedNestedMessage()[0].GetA(), int32(1247); got != want { + t.Errorf("Builder for field RepetedNestedMessage did not work, got %v, wanted %v", got, want) + } + if got, want := m.HasOneofNestedMessage(), true; got != want { + t.Errorf("Builder for field OneofNestedMessage did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.GetOneofNestedMessage().GetA(), int32(1347); got != want { + t.Errorf("Builder for field OneofNestedMessage did not work, got %v, wanted %v", got, want) + } + // Check map field + { + if got, want := len(m.GetMapStringNestedMessage()), 1; got != want { + t.Errorf("Builder for field MapStringNestedMessage did not work, got len %v, wanted len %v", got, want) + } + if got, want := m.GetMapStringNestedMessage()["a"].GetA(), int32(5); got != want { + t.Errorf("Builder for field MapStringNestedMessage did not work, got %v, wanted %v", got, want) + } + } +} + +// TestHybridOptionalBuilder exercises all optional fields in the testall_opaque_optional3_go_proto builder +func TestHybridOptionalBuilder(t *testing.T) { + // Create empty proto from builder + m := testhybridpb.TestAllTypes_builder{}.Build() + + // Check that no optional fields are present + // Check presence of each field + if got, want := m.HasOptionalInt32(), false; got != want { + t.Errorf("Builder for field OptionalInt32 did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalInt64(), false; got != want { + t.Errorf("Builder for field OptionalInt64 did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalUint32(), false; got != want { + t.Errorf("Builder for field OptionalUint32 did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalUint64(), false; got != want { + t.Errorf("Builder for field OptionalUint64 did not set presence, got %v, wanted %v", got, want) + } + + if got, want := m.HasOptionalSint32(), false; got != want { + t.Errorf("Builder for field OptionalSint32 did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalSint64(), false; got != want { + t.Errorf("Builder for field OptionalSint64 did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalFixed32(), false; got != want { + t.Errorf("Builder for field OptionalFixed32 did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalFixed64(), false; got != want { + t.Errorf("Builder for field OptionalFixed64 did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalSfixed32(), false; got != want { + t.Errorf("Builder for field OptionalSfixed32 did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalSfixed64(), false; got != want { + t.Errorf("Builder for field OptionalSfixed64 did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalFloat(), false; got != want { + t.Errorf("Builder for field OptionalFloat did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalDouble(), false; got != want { + t.Errorf("Builder for field OptionalDouble did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalBool(), false; got != want { + t.Errorf("Builder for field OptionalBool did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalString(), false; got != want { + t.Errorf("Builder for field OptionalString did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalBytes(), false; got != want { + t.Errorf("Builder for field OptionalBytes did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalNestedEnum(), false; got != want { + t.Errorf("Builder for field OptionalNestedEnum did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalNestedMessage(), false; got != want { + t.Errorf("Builder for field OptionalNestedMessage did not set presence, got %v, wanted %v", got, want) + } + + // Create builder with every optional field filled in + b := testhybridpb.TestAllTypes_builder{ + // Scalar fields (including bytes) + OptionalInt32: proto.Int32(3), + OptionalInt64: proto.Int64(64), + OptionalUint32: proto.Uint32(32), + OptionalUint64: proto.Uint64(4711), + OptionalSint32: proto.Int32(-23), + OptionalSint64: proto.Int64(-123132), + OptionalFixed32: proto.Uint32(6798421), + OptionalFixed64: proto.Uint64(876555776), + OptionalSfixed32: proto.Int32(-909038), + OptionalSfixed64: proto.Int64(-63728193629), + OptionalFloat: proto.Float32(781.0), + OptionalDouble: proto.Float64(-3456.3), + OptionalBool: proto.Bool(true), + OptionalString: proto.String("hello"), + OptionalBytes: []byte("goodbye"), + OptionalNestedEnum: testhybridpb.TestAllTypes_FOO.Enum(), + OptionalNestedMessage: testhybridpb.TestAllTypes_NestedMessage_builder{ + A: proto.Int32(1147), + }.Build(), + } + + m = b.Build() + + // Check presence of each optional field + if got, want := m.HasOptionalInt32(), true; got != want { + t.Errorf("Builder for field OptionalInt32 did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalInt64(), true; got != want { + t.Errorf("Builder for field OptionalInt64 did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalUint32(), true; got != want { + t.Errorf("Builder for field OptionalUint32 did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalUint64(), true; got != want { + t.Errorf("Builder for field OptionalUint64 did not set presence, got %v, wanted %v", got, want) + } + + if got, want := m.HasOptionalSint32(), true; got != want { + t.Errorf("Builder for field OptionalSint32 did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalSint64(), true; got != want { + t.Errorf("Builder for field OptionalSint64 did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalFixed32(), true; got != want { + t.Errorf("Builder for field OptionalFixed32 did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalFixed64(), true; got != want { + t.Errorf("Builder for field OptionalFixed64 did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalSfixed32(), true; got != want { + t.Errorf("Builder for field OptionalSfixed32 did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalSfixed64(), true; got != want { + t.Errorf("Builder for field OptionalSfixed64 did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalFloat(), true; got != want { + t.Errorf("Builder for field OptionalFloat did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalDouble(), true; got != want { + t.Errorf("Builder for field OptionalDouble did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalBool(), true; got != want { + t.Errorf("Builder for field OptionalBool did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalString(), true; got != want { + t.Errorf("Builder for field OptionalString did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalBytes(), true; got != want { + t.Errorf("Builder for field OptionalBytes did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalNestedEnum(), true; got != want { + t.Errorf("Builder for field OptionalNestedEnum did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalNestedMessage(), true; got != want { + t.Errorf("Builder for field OptionalNestedMessage did not set presence, got %v, wanted %v", got, want) + } + + // Check each optional field against the corresponding field in the builder + if got, want := m.GetOptionalInt32(), *b.OptionalInt32; got != want { + t.Errorf("Builder for field OptionalInt32 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetOptionalInt64(), *b.OptionalInt64; got != want { + t.Errorf("Builder for field OptionalInt64 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetOptionalUint32(), *b.OptionalUint32; got != want { + t.Errorf("Builder for field OptionalUint32 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetOptionalUint64(), *b.OptionalUint64; got != want { + t.Errorf("Builder for field OptionalUint64 did not work, got %v, wanted %v", got, want) + } + + if got, want := m.GetOptionalSint32(), *b.OptionalSint32; got != want { + t.Errorf("Builder for field OptionalSint32 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetOptionalSint64(), *b.OptionalSint64; got != want { + t.Errorf("Builder for field OptionalSint64 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetOptionalFixed32(), *b.OptionalFixed32; got != want { + t.Errorf("Builder for field OptionalFixed32 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetOptionalFixed64(), *b.OptionalFixed64; got != want { + t.Errorf("Builder for field OptionalFixed64 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetOptionalSfixed32(), *b.OptionalSfixed32; got != want { + t.Errorf("Builder for field OptionalSfixed32 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetOptionalSfixed64(), *b.OptionalSfixed64; got != want { + t.Errorf("Builder for field OptionalSfixed64 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetOptionalFloat(), *b.OptionalFloat; got != want { + t.Errorf("Builder for field OptionalFloat did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetOptionalDouble(), *b.OptionalDouble; got != want { + t.Errorf("Builder for field OptionalDouble did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetOptionalBool(), *b.OptionalBool; got != want { + t.Errorf("Builder for field OptionalBool did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetOptionalString(), *b.OptionalString; got != want { + t.Errorf("Builder for field OptionalString did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetOptionalBytes(), b.OptionalBytes; string(got) != string(want) { + t.Errorf("Builder for field OptionalBytes did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetOptionalNestedEnum(), *b.OptionalNestedEnum; got != want { + t.Errorf("Builder for field OptionalNestedEnum did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetOptionalNestedMessage().GetA(), int32(1147); got != want { + t.Errorf("Builder for field OptionalNestedMessage did not work, got %v, wanted %v", got, want) + } + +} + +// TestOpaqueBuilder exercises all non-oneof fields in the testall_opaque3_go_proto builder +func TestOpaqueBuilder(t *testing.T) { + // Create builder with every possible field filled in + b := testopaquepb.TestAllTypes_builder{ + // Scalar fields (including bytes) + SingularInt32: 3, + SingularInt64: 64, + SingularUint32: 32, + SingularUint64: 4711, + SingularSint32: -23, + SingularSint64: -123132, + SingularFixed32: 6798421, + SingularFixed64: 876555776, + SingularSfixed32: -909038, + SingularSfixed64: -63728193629, + SingularFloat: 781.0, + SingularDouble: -3456.3, + SingularBool: true, + SingularString: "hello", + SingularBytes: []byte("goodbye"), + OptionalNestedEnum: testopaquepb.TestAllTypes_FOO.Enum(), + OptionalNestedMessage: testopaquepb.TestAllTypes_NestedMessage_builder{ + A: proto.Int32(1147), + }.Build(), + RepeatedInt32: []int32{4}, + RepeatedInt64: []int64{65}, + RepeatedUint32: []uint32{33}, + RepeatedUint64: []uint64{4712}, + RepeatedSint32: []int32{-24}, + RepeatedSint64: []int64{-123133}, + RepeatedFixed32: []uint32{6798422}, + RepeatedFixed64: []uint64{876555777}, + RepeatedSfixed32: []int32{-909039}, + RepeatedSfixed64: []int64{-63728193630}, + RepeatedFloat: []float32{782.0}, + RepeatedDouble: []float64{-3457.3}, + RepeatedBool: []bool{false}, + RepeatedString: []string{"hello!"}, + RepeatedBytes: [][]byte{[]byte("goodbye!")}, + RepeatedNestedEnum: []testopaquepb.TestAllTypes_NestedEnum{testopaquepb.TestAllTypes_BAZ}, + RepeatedNestedMessage: []*testopaquepb.TestAllTypes_NestedMessage{testopaquepb.TestAllTypes_NestedMessage_builder{ + A: proto.Int32(1148), + }.Build()}, + MapInt32Int32: map[int32]int32{ + 89: 87, + 87: 89, + }, + MapInt64Int64: map[int64]int64{ + 345: 678, + 2121: 5432, + }, + MapUint32Uint32: map[uint32]uint32{ + 765476: 87658, + 4324: 6543, + }, + MapUint64Uint64: map[uint64]uint64{ + 2324: 543534, + 7657654: 675, + }, + MapSint32Sint32: map[int32]int32{ + -45243: -543353, + -54343: -33, + }, + MapSint64Sint64: map[int64]int64{ + -6754389: 34, + 467382: -676743, + }, + MapFixed32Fixed32: map[uint32]uint32{ + 43432: 4444, + 5555555: 666666, + }, + MapFixed64Fixed64: map[uint64]uint64{ + 777777: 888888, + 999999: 111111, + }, + MapSfixed32Sfixed32: map[int32]int32{ + -778989: -543, + -9999: 98765, + }, + MapSfixed64Sfixed64: map[int64]int64{ + 65486723: 89, + -76843592: -33, + }, + MapInt32Float: map[int32]float32{ + 543433: 7.5, + 3434333: 3.14, + }, + MapInt32Double: map[int32]float64{ + 876876: 34.34, + 987650: 35.35, + }, + MapBoolBool: map[bool]bool{ + true: true, + false: true, + }, + MapStringString: map[string]string{ + "hello?": "goodbye?", + "hi": "bye", + }, + MapStringBytes: map[string][]byte{ + "hi?": []byte("bye!"), + "bye?": []byte("hi!"), + }, + MapStringNestedMessage: map[string]*testopaquepb.TestAllTypes_NestedMessage{ + "nest": testopaquepb.TestAllTypes_NestedMessage_builder{ + A: proto.Int32(99), + }.Build(), + "mess": testopaquepb.TestAllTypes_NestedMessage_builder{ + A: proto.Int32(100), + }.Build(), + }, + MapStringNestedEnum: map[string]testopaquepb.TestAllTypes_NestedEnum{ + "bar": testopaquepb.TestAllTypes_BAR, + "baz": testopaquepb.TestAllTypes_BAZ, + }, + OneofUint32: proto.Uint32(77665544), + } + m := b.Build() + + // Check each field against the corresponding field in the builder + if got, want := m.GetSingularInt32(), b.SingularInt32; got != want { + t.Errorf("Builder for field FInt32 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetSingularInt64(), b.SingularInt64; got != want { + t.Errorf("Builder for field FInt64 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetSingularUint32(), b.SingularUint32; got != want { + t.Errorf("Builder for field FUint32 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetSingularUint64(), b.SingularUint64; got != want { + t.Errorf("Builder for field FUint64 did not work, got %v, wanted %v", got, want) + } + + if got, want := m.GetSingularSint32(), b.SingularSint32; got != want { + t.Errorf("Builder for field FSint32 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetSingularSint64(), b.SingularSint64; got != want { + t.Errorf("Builder for field FSint64 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetSingularFixed32(), b.SingularFixed32; got != want { + t.Errorf("Builder for field FFixed32 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetSingularFixed64(), b.SingularFixed64; got != want { + t.Errorf("Builder for field FFixed64 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetSingularSfixed32(), b.SingularSfixed32; got != want { + t.Errorf("Builder for field FSfixed32 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetSingularSfixed64(), b.SingularSfixed64; got != want { + t.Errorf("Builder for field FSfixed64 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetSingularFloat(), b.SingularFloat; got != want { + t.Errorf("Builder for field FFloat did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetSingularDouble(), b.SingularDouble; got != want { + t.Errorf("Builder for field FDouble did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetSingularBool(), b.SingularBool; got != want { + t.Errorf("Builder for field FBool did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetSingularString(), b.SingularString; got != want { + t.Errorf("Builder for field FString did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetSingularBytes(), b.SingularBytes; string(got) != string(want) { + t.Errorf("Builder for field FBytes did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetOptionalNestedEnum(), *b.OptionalNestedEnum; got != want { + t.Errorf("Builder for field FNestedEnum did not work, got %v, wanted %v", got, want) + } + if got, want := m.HasOptionalNestedMessage(), true; got != want { + t.Errorf("Builder for field FNestedMessage did not set presence, got %v, wanted %v", got, want) + } + if got, want := m.GetOptionalNestedMessage().GetA(), int32(1147); got != want { + t.Errorf("Builder for field FNestedMessage did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetRepeatedInt32()[0], b.RepeatedInt32[0]; got != want { + t.Errorf("Builder for repeated field RepeatedInt32 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetRepeatedInt64()[0], b.RepeatedInt64[0]; got != want { + t.Errorf("Builder for repeated field RepeatedInt64 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetRepeatedUint32()[0], b.RepeatedUint32[0]; got != want { + t.Errorf("Builder for repeated field RepeatedUint32 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetRepeatedUint64()[0], b.RepeatedUint64[0]; got != want { + t.Errorf("Builder for repeated field RepeatedUint64 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetRepeatedSint32()[0], b.RepeatedSint32[0]; got != want { + t.Errorf("Builder for repeated field RepeatedSint32 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetRepeatedSint64()[0], b.RepeatedSint64[0]; got != want { + t.Errorf("Builder for repeated field RepeatedSint64 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetRepeatedFixed32()[0], b.RepeatedFixed32[0]; got != want { + t.Errorf("Builder for repeated field RepeatedFixed32 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetRepeatedFixed64()[0], b.RepeatedFixed64[0]; got != want { + t.Errorf("Builder for repeated field RepeatedFixed64 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetRepeatedSfixed32()[0], b.RepeatedSfixed32[0]; got != want { + t.Errorf("Builder for repeated field RepeatedSfixed32 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetRepeatedSfixed64()[0], b.RepeatedSfixed64[0]; got != want { + t.Errorf("Builder for repeated field RepeatedSfixed64 did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetRepeatedFloat()[0], b.RepeatedFloat[0]; got != want { + t.Errorf("Builder for repeated field RepeatedFloat did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetRepeatedDouble()[0], b.RepeatedDouble[0]; got != want { + t.Errorf("Builder for repeated field RepeatedDouble did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetRepeatedBool()[0], b.RepeatedBool[0]; got != want { + t.Errorf("Builder for repeated field RepeatedBool did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetRepeatedString()[0], b.RepeatedString[0]; got != want { + t.Errorf("Builder for repeated field RepeatedString did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetRepeatedBytes()[0], b.RepeatedBytes[0]; string(got) != string(want) { + t.Errorf("Builder for repeated field RepeatedBytes did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetRepeatedNestedEnum()[0], b.RepeatedNestedEnum[0]; got != want { + t.Errorf("Builder for repeated field RepeatedNestedEnum did not work, got %v, wanted %v", got, want) + } + if got, want := m.GetRepeatedNestedMessage()[0].GetA(), int32(1148); got != want { + t.Errorf("Builder for repeated field RepeatedNestedMessage did not work, got %v, wanted %v", got, want) + } + + for key, want := range b.MapInt32Int32 { + if got := m.GetMapInt32Int32()[key]; got != want { + t.Errorf("Builder for map field MapInt32Int32[%v] did not work, got %v, wanted %v", key, got, want) + } + } + + for key, want := range b.MapInt64Int64 { + if got := m.GetMapInt64Int64()[key]; got != want { + t.Errorf("Builder for map field MapInt64Int64[%v] did not work, got %v, wanted %v", key, got, want) + } + } + for key, want := range b.MapUint32Uint32 { + if got := m.GetMapUint32Uint32()[key]; got != want { + t.Errorf("Builder for map field MapUint32Uint32[%v] did not work, got %v, wanted %v", key, got, want) + } + } + for key, want := range b.MapUint64Uint64 { + if got := m.GetMapUint64Uint64()[key]; got != want { + t.Errorf("Builder for map field MapUint64Uint64[%v] did not work, got %v, wanted %v", key, got, want) + } + } + for key, want := range b.MapSint32Sint32 { + if got := m.GetMapSint32Sint32()[key]; got != want { + t.Errorf("Builder for map field MapSint32Sint32[%v] did not work, got %v, wanted %v", key, got, want) + } + } + for key, want := range b.MapSint64Sint64 { + if got := m.GetMapSint64Sint64()[key]; got != want { + t.Errorf("Builder for map field MapSint64Sint64[%v] did not work, got %v, wanted %v", key, got, want) + } + } + for key, want := range b.MapFixed32Fixed32 { + if got := m.GetMapFixed32Fixed32()[key]; got != want { + t.Errorf("Builder for map field MapFixed32Fixed32[%v] did not work, got %v, wanted %v", key, got, want) + } + } + for key, want := range b.MapFixed64Fixed64 { + if got := m.GetMapFixed64Fixed64()[key]; got != want { + t.Errorf("Builder for map field MapFixed64Fixed64[%v] did not work, got %v, wanted %v", key, got, want) + } + } + for key, want := range b.MapSfixed32Sfixed32 { + if got := m.GetMapSfixed32Sfixed32()[key]; got != want { + t.Errorf("Builder for map field MapSfixed32Sfixed32[%v] did not work, got %v, wanted %v", key, got, want) + } + } + for key, want := range b.MapSfixed64Sfixed64 { + if got := m.GetMapSfixed64Sfixed64()[key]; got != want { + t.Errorf("Builder for map field MapSfixed64Sfixed64[%v] did not work, got %v, wanted %v", key, got, want) + } + } + for key, want := range b.MapInt32Float { + if got := m.GetMapInt32Float()[key]; got != want { + t.Errorf("Builder for map field MapInt32Float[%v] did not work, got %v, wanted %v", key, got, want) + } + } + for key, want := range b.MapInt32Double { + if got := m.GetMapInt32Double()[key]; got != want { + t.Errorf("Builder for map field MapInt32Double[%v] did not work, got %v, wanted %v", key, got, want) + } + } + for key, want := range b.MapBoolBool { + if got := m.GetMapBoolBool()[key]; got != want { + t.Errorf("Builder for map field MapBoolBool[%v] did not work, got %v, wanted %v", key, got, want) + } + } + for key, want := range b.MapStringString { + if got := m.GetMapStringString()[key]; got != want { + t.Errorf("Builder for map field MapStringString[%v] did not work, got %v, wanted %v", key, got, want) + } + } + for key, want := range b.MapStringBytes { + if got := m.GetMapStringBytes()[key]; string(got) != string(want) { + t.Errorf("Builder for map field MapStringBytes[%v] did not work, got %v, wanted %v", key, got, want) + } + } + for key, want := range b.MapStringNestedMessage { + if got := m.GetMapStringNestedMessage()[key]; got.GetA() != want.GetA() { + t.Errorf("Builder for map field MapStringNestedMessage[%v] did not work, got %v, wanted %v", key, got, want) + } + } + for key, want := range b.MapStringNestedEnum { + if got := m.GetMapStringNestedEnum()[key]; got != want { + t.Errorf("Builder for map field MapStringNestedEnum[%v] did not work, got %v, wanted %v", key, got, want) + } + } + if got, want := m.GetOneofUint32(), *b.OneofUint32; got != want { + t.Errorf("Builder for field OneofUint32 did not work, got %v, wanted %v", got, want) + } +} + +func TestOpaqueBuilderOneofsLazy(t *testing.T) { + testOpaqueBuilderOneofs(t, enableLazy) +} + +func TestOpaqueBuilderOneofsEager(t *testing.T) { + testOpaqueBuilderOneofs(t, disableLazy) +} + +// TestOpaqueBuilderOneofs test each oneof option in the builder separately +func testOpaqueBuilderOneofs(t *testing.T, unmarshalOpts proto.UnmarshalOptions) { + for _, task := range []struct { + set func() (any, int, *testopaquepb.TestAllTypes) + check func(any, *testopaquepb.TestAllTypes) (bool, any) + }{ + { + // uint32 + set: func() (any, int, *testopaquepb.TestAllTypes) { + val := uint32(6754) + return val, int(testopaquepb.TestAllTypes_OneofUint32_case), testopaquepb.TestAllTypes_builder{OneofUint32: &val}.Build() + }, + check: func(x any, m *testopaquepb.TestAllTypes) (bool, any) { + want := x.(uint32) + got := m.GetOneofUint32() + return want == got, got + }, + }, + { + // message + set: func() (any, int, *testopaquepb.TestAllTypes) { + val := testopaquepb.TestAllTypes_NestedMessage_builder{A: proto.Int32(5432678)}.Build() + return val, int(testopaquepb.TestAllTypes_OneofNestedMessage_case), testopaquepb.TestAllTypes_builder{OneofNestedMessage: val}.Build() + }, + check: func(x any, m *testopaquepb.TestAllTypes) (bool, any) { + want := x.(*testopaquepb.TestAllTypes_NestedMessage) + got := m.GetOneofNestedMessage() + return want.GetA() == got.GetA(), got + }, + }, + { + // string + set: func() (any, int, *testopaquepb.TestAllTypes) { + val := "random" + return val, int(testopaquepb.TestAllTypes_OneofString_case), testopaquepb.TestAllTypes_builder{OneofString: &val}.Build() + }, + check: func(x any, m *testopaquepb.TestAllTypes) (bool, any) { + want := x.(string) + got := m.GetOneofString() + return want == got, got + }, + }, + { + // bytes + set: func() (any, int, *testopaquepb.TestAllTypes) { + val := []byte("randombytes") + return val, int(testopaquepb.TestAllTypes_OneofBytes_case), testopaquepb.TestAllTypes_builder{OneofBytes: val}.Build() + }, + check: func(x any, m *testopaquepb.TestAllTypes) (bool, any) { + want := x.([]byte) + got := m.GetOneofBytes() + return string(want) == string(got), got + }, + }, + { + // uint64 + set: func() (any, int, *testopaquepb.TestAllTypes) { + val := uint64(65934287653) + return val, int(testopaquepb.TestAllTypes_OneofUint64_case), testopaquepb.TestAllTypes_builder{OneofUint64: &val}.Build() + }, + check: func(x any, m *testopaquepb.TestAllTypes) (bool, any) { + want := x.(uint64) + got := m.GetOneofUint64() + return want == got, got + }, + }, + { + // bool + set: func() (any, int, *testopaquepb.TestAllTypes) { + val := true + return val, int(testopaquepb.TestAllTypes_OneofBool_case), testopaquepb.TestAllTypes_builder{OneofBool: &val}.Build() + }, + check: func(x any, m *testopaquepb.TestAllTypes) (bool, any) { + want := x.(bool) + got := m.GetOneofBool() + return want == got, got + }, + }, + { + // float + set: func() (any, int, *testopaquepb.TestAllTypes) { + val := float32(-54.45) + return val, int(testopaquepb.TestAllTypes_OneofFloat_case), testopaquepb.TestAllTypes_builder{OneofFloat: &val}.Build() + }, + check: func(x any, m *testopaquepb.TestAllTypes) (bool, any) { + want := x.(float32) + got := m.GetOneofFloat() + return want == got, got + }, + }, + { + // double + set: func() (any, int, *testopaquepb.TestAllTypes) { + val := float64(-45.54) + return val, int(testopaquepb.TestAllTypes_OneofDouble_case), testopaquepb.TestAllTypes_builder{OneofDouble: &val}.Build() + }, + check: func(x any, m *testopaquepb.TestAllTypes) (bool, any) { + want := x.(float64) + got := m.GetOneofDouble() + return want == got, got + }, + }, + { + // enum + set: func() (any, int, *testopaquepb.TestAllTypes) { + val := testopaquepb.TestAllTypes_BAR + return val, int(testopaquepb.TestAllTypes_OneofEnum_case), testopaquepb.TestAllTypes_builder{OneofEnum: &val}.Build() + }, + check: func(x any, m *testopaquepb.TestAllTypes) (bool, any) { + want := x.(testopaquepb.TestAllTypes_NestedEnum) + got := m.GetOneofEnum() + return want == got, got + }, + }, + } { + want, cas, m := task.set() + gotCase := int(m.WhichOneofField()) + if gotCase != cas { + t.Errorf("Builder did not make which function return correct value, got %v, wanted %v for type %T", gotCase, cas, want) + } + ok, got := task.check(want, m) + if !ok { + t.Errorf("Builder did not set oneof field correctly, got %v, wanted %v for type %T", got, want, want) + } + } +} diff --git a/cmd/protoc-gen-go/descriptor_test/descriptor_test.go b/cmd/protoc-gen-go/descriptor_test/descriptor_test.go new file mode 100644 index 000000000..70c1bb43f --- /dev/null +++ b/cmd/protoc-gen-go/descriptor_test/descriptor_test.go @@ -0,0 +1,43 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package descriptor_test + +import ( + "testing" + + testopenpb "google.golang.org/protobuf/internal/testprotos/test" + testhybridpb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_hybrid" + testopaquepb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_opaque" +) + +func TestFileModeEnum(t *testing.T) { + var e any = testopenpb.ForeignEnum_FOREIGN_FOO + if _, ok := e.(interface{ EnumDescriptor() ([]byte, []int) }); !ok { + t.Errorf("Open V1 proto did not have deprecated method EnumDescriptor") + } + var oe any = testopaquepb.ForeignEnum_FOREIGN_FOO + if _, ok := oe.(interface{ EnumDescriptor() ([]byte, []int) }); ok { + t.Errorf("Opaque V0 proto did have deprecated method EnumDescriptor") + } + var he any = testhybridpb.ForeignEnum_FOREIGN_FOO + if _, ok := he.(interface{ EnumDescriptor() ([]byte, []int) }); ok { + t.Errorf("Hybrid proto did have deprecated method EnumDescriptor") + } +} + +func TestFileModeMessage(t *testing.T) { + var p any = &testopenpb.TestAllTypes{} + if _, ok := p.(interface{ Descriptor() ([]byte, []int) }); !ok { + t.Errorf("Open V1 proto did not have deprecated method Descriptor") + } + var op any = &testopaquepb.TestAllTypes{} + if _, ok := op.(interface{ Descriptor() ([]byte, []int) }); ok { + t.Errorf("Opaque V0 mode proto unexpectedly has deprecated Descriptor() method") + } + var hp any = &testhybridpb.TestAllTypes{} + if _, ok := hp.(interface{ EnumDescriptor() ([]byte, []int) }); ok { + t.Errorf("Hybrid proto did have deprecated method EnumDescriptor") + } +} diff --git a/cmd/protoc-gen-go/internal_gengo/init.go b/cmd/protoc-gen-go/internal_gengo/init.go index 369df13da..62de8bb1b 100644 --- a/cmd/protoc-gen-go/internal_gengo/init.go +++ b/cmd/protoc-gen-go/internal_gengo/init.go @@ -114,6 +114,7 @@ func newEnumInfo(f *fileInfo, enum *protogen.Enum) *enumInfo { e := &enumInfo{Enum: enum} e.genJSONMethod = true e.genRawDescMethod = true + opaqueNewEnumInfoHook(f, e) return e } @@ -123,8 +124,9 @@ type messageInfo struct { genRawDescMethod bool genExtRangeMethod bool - isTracked bool - hasWeak bool + isTracked bool + noInterface bool + hasWeak bool } func newMessageInfo(f *fileInfo, message *protogen.Message) *messageInfo { @@ -135,6 +137,7 @@ func newMessageInfo(f *fileInfo, message *protogen.Message) *messageInfo { for _, field := range m.Fields { m.hasWeak = m.hasWeak || field.Desc.IsWeak() } + opaqueNewMessageInfoHook(f, m) return m } diff --git a/cmd/protoc-gen-go/internal_gengo/init_opaque.go b/cmd/protoc-gen-go/internal_gengo/init_opaque.go new file mode 100644 index 000000000..221176a22 --- /dev/null +++ b/cmd/protoc-gen-go/internal_gengo/init_opaque.go @@ -0,0 +1,33 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal_gengo + +import "google.golang.org/protobuf/types/gofeaturespb" + +func (m *messageInfo) isOpen() bool { + return m.Message.APILevel == gofeaturespb.GoFeatures_API_OPEN +} + +func (m *messageInfo) isHybrid() bool { + return m.Message.APILevel == gofeaturespb.GoFeatures_API_HYBRID +} + +func (m *messageInfo) isOpaque() bool { + return m.Message.APILevel == gofeaturespb.GoFeatures_API_OPAQUE +} + +func opaqueNewEnumInfoHook(f *fileInfo, e *enumInfo) { + if f.File.APILevel != gofeaturespb.GoFeatures_API_OPEN { + e.genJSONMethod = false + e.genRawDescMethod = false + } +} + +func opaqueNewMessageInfoHook(f *fileInfo, m *messageInfo) { + if !m.isOpen() { + m.genRawDescMethod = false + m.genExtRangeMethod = false + } +} diff --git a/cmd/protoc-gen-go/internal_gengo/main.go b/cmd/protoc-gen-go/internal_gengo/main.go index a4c4595ec..e4933086d 100644 --- a/cmd/protoc-gen-go/internal_gengo/main.go +++ b/cmd/protoc-gen-go/internal_gengo/main.go @@ -367,6 +367,9 @@ func genMessage(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) { if m.Desc.IsMapEntry() { return } + if opaqueGenMessageHook(g, f, m) { + return + } // Message type declaration. g.AnnotateSymbol(m.GoIdent.GoName, protogen.Annotation{Location: m.Location}) @@ -657,7 +660,7 @@ func genMessageSetterMethods(g *protogen.GeneratedFile, f *fileInfo, m *messageI continue } - genNoInterfacePragma(g, m.isTracked) + genNoInterfacePragma(g, m.noInterface) g.AnnotateSymbol(m.GoIdent.GoName+".Set"+field.GoName, protogen.Annotation{ Location: field.Location, diff --git a/cmd/protoc-gen-go/internal_gengo/opaque.go b/cmd/protoc-gen-go/internal_gengo/opaque.go new file mode 100644 index 000000000..dafa095f0 --- /dev/null +++ b/cmd/protoc-gen-go/internal_gengo/opaque.go @@ -0,0 +1,1306 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal_gengo + +import ( + "fmt" + "strconv" + "strings" + "unicode" + "unicode/utf8" + + "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/internal/genid" + "google.golang.org/protobuf/reflect/protoreflect" + + "google.golang.org/protobuf/types/descriptorpb" +) + +func opaqueGenMessageHook(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo) bool { + opaqueGenMessage(g, f, message) + return true +} + +func opaqueGenMessage(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo) { + // Message type declaration. + g.AnnotateSymbol(message.GoIdent.GoName, protogen.Annotation{Location: message.Location}) + leadingComments := appendDeprecationSuffix(message.Comments.Leading, + message.Desc.ParentFile(), + message.Desc.Options().(*descriptorpb.MessageOptions).GetDeprecated()) + g.P(leadingComments, + "type ", message.GoIdent, " struct {") + + sf := f.allMessageFieldsByPtr[message] + if sf == nil { + sf = new(structFields) + f.allMessageFieldsByPtr[message] = sf + } + + var tags structTags + switch { + case message.isOpen(): + tags = structTags{{"protogen", "open.v1"}} + case message.isHybrid(): + tags = structTags{{"protogen", "hybrid.v1"}} + case message.isOpaque(): + tags = structTags{{"protogen", "opaque.v1"}} + } + + g.P(genid.State_goname, " ", protoimplPackage.Ident("MessageState"), tags) + sf.append(genid.State_goname) + fields := message.Fields + for _, field := range fields { + opaqueGenMessageField(g, f, message, field, sf) + } + opaqueGenMessageInternalFields(g, f, message, sf) + g.P("}") + g.P() + + genMessageKnownFunctions(g, f, message) + genMessageDefaultDecls(g, f, message) + opaqueGenMessageMethods(g, f, message) + opaqueGenMessageBuilder(g, f, message) + opaqueGenOneofWrapperTypes(g, f, message) +} + +// opaqueGenMessageField generates a struct field. +func opaqueGenMessageField(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo, field *protogen.Field, sf *structFields) { + if oneof := field.Oneof; oneof != nil && !oneof.Desc.IsSynthetic() { + // It would be a bit simpler to iterate over the oneofs below, + // but generating the field here keeps the contents of the Go + // struct in the same order as the contents of the source + // .proto file. + if field != oneof.Fields[0] { + return + } + opaqueGenOneofFields(g, f, message, oneof, sf) + return + } + + goType, pointer := opaqueFieldGoType(g, f, message, field) + if pointer { + goType = "*" + goType + } + protobufTagValue := fieldProtobufTagValue(field) + jsonTagValue := fieldJSONTagValue(field) + if g.InternalStripForEditionsDiff() { + if field.Desc.ContainingOneof() != nil && field.Desc.ContainingOneof().IsSynthetic() { + protobufTagValue = strings.ReplaceAll(protobufTagValue, ",oneof", "") + } + protobufTagValue = strings.ReplaceAll(protobufTagValue, ",proto3", "") + } + tags := structTags{ + {"protobuf", protobufTagValue}, + {"json", jsonTagValue}, + } + if field.Desc.IsMap() { + keyTagValue := fieldProtobufTagValue(field.Message.Fields[0]) + valTagValue := fieldProtobufTagValue(field.Message.Fields[1]) + keyTagValue = strings.ReplaceAll(keyTagValue, ",proto3", "") + valTagValue = strings.ReplaceAll(valTagValue, ",proto3", "") + tags = append(tags, structTags{ + {"protobuf_key", keyTagValue}, + {"protobuf_val", valTagValue}, + }...) + } + + name := field.GoName + if field.Desc.IsWeak() { + g.P("// Deprecated: Do not use. This will be deleted in the near future.") + name = genid.WeakFieldPrefix_goname + name + } else if message.isOpaque() { + name = "xxx_hidden_" + name + } + + if message.isOpaque() && !field.Desc.IsWeak() { + g.P(name, " ", goType, tags) + sf.append(name) + if message.isTracked { + g.P("// Deprecated: Do not use. This will be deleted in the near future.") + g.P("XXX_ft_", field.GoName, " struct{} `go:\"track\"`") + sf.append("XXX_ft_" + field.GoName) + } + } else { + if message.isTracked { + tags = append(tags, structTags{ + {"go", "track"}, + }...) + } + g.AnnotateSymbol(field.Parent.GoIdent.GoName+"."+name, protogen.Annotation{Location: field.Location}) + leadingComments := appendDeprecationSuffix(field.Comments.Leading, + field.Desc.ParentFile(), + field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated()) + g.P(leadingComments, + name, " ", goType, tags, + trailingComment(field.Comments.Trailing)) + sf.append(name) + } +} + +// opaqueGenOneofFields generates the message fields for a oneof. +func opaqueGenOneofFields(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo, oneof *protogen.Oneof, sf *structFields) { + tags := structTags{ + {"protobuf_oneof", string(oneof.Desc.Name())}, + } + if message.isTracked { + tags = append(tags, structTags{ + {"go", "track"}, + }...) + } + + oneofName := opaqueOneofFieldName(oneof, message.isOpaque()) + goType := opaqueOneofInterfaceName(oneof) + + if message.isOpaque() { + g.P(oneofName, " ", goType, tags) + sf.append(oneofName) + if message.isTracked { + g.P("// Deprecated: Do not use. This will be deleted in the near future.") + g.P("XXX_ft_", oneof.GoName, " struct{} `go:\"track\"`") + sf.append("XXX_ft_" + oneof.GoName) + } + return + } + + leadingComments := oneof.Comments.Leading + if leadingComments != "" { + leadingComments += "\n" + } + // NOTE(rsc): The extra \n here is working around #52605, + // making the comment be in Go 1.19 doc comment format + // even though it's not really a doc comment. + ss := []string{" Types that are valid to be assigned to ", oneofName, ":\n\n"} + for _, field := range oneof.Fields { + ss = append(ss, "\t*"+opaqueFieldOneofType(field, message.isOpaque()).GoName+"\n") + } + leadingComments += protogen.Comments(strings.Join(ss, "")) + g.P(leadingComments, oneofName, " ", goType, tags) + sf.append(oneofName) +} + +// opaqueGenMessageInternalFields adds additional XXX_ fields to a message struct. +func opaqueGenMessageInternalFields(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo, sf *structFields) { + if opaqueNeedsPresenceArray(message) { + if opaqueNeedsLazyStruct(message) { + g.P("// Deprecated: Do not use. This will be deleted in the near future.") + g.P("XXX_lazyUnmarshalInfo ", protoimplPackage.Ident("LazyUnmarshalInfo")) + sf.append("XXX_lazyUnmarshalInfo") + } + g.P("XXX_raceDetectHookData ", protoimplPackage.Ident("RaceDetectHookData")) + sf.append("XXX_raceDetectHookData") + + // Presence must be stored in a data type no larger than 32 bit: + // + // Presence used to be a uint64, accessed with atomic.LoadUint64, but it + // turns out that on 32-bit platforms like GOARCH=arm, the struct field + // was 32-bit aligned (not 64-bit aligned) and hence atomic accesses + // failed. + // + // The easiest solution was to switch to a uint32 on all platforms, + // which did not come with a performance penalty. + g.P("XXX_presence [", (opaqueNumPresenceFields(message)+31)/32, "]uint32") + sf.append("XXX_presence") + } + if message.hasWeak { + g.P(genid.WeakFields_goname, " ", protoimplPackage.Ident("WeakFields")) + sf.append(genid.WeakFields_goname) + } + if message.Desc.ExtensionRanges().Len() > 0 { + g.P(genid.ExtensionFields_goname, " ", protoimplPackage.Ident("ExtensionFields")) + sf.append(genid.ExtensionFields_goname) + } + g.P(genid.UnknownFields_goname, " ", protoimplPackage.Ident("UnknownFields")) + sf.append(genid.UnknownFields_goname) + g.P(genid.SizeCache_goname, " ", protoimplPackage.Ident("SizeCache")) + sf.append(genid.SizeCache_goname) +} + +func opaqueGenMessageMethods(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo) { + genMessageBaseMethods(g, f, message) + + isRepeated := func(field *protogen.Field) bool { + return field.Desc.Cardinality() == protoreflect.Repeated + } + + for _, field := range message.Fields { + if isFirstOneofField(field) && !message.isOpaque() { + opaqueGenGetOneof(g, f, message, field.Oneof) + } + opaqueGenGet(g, f, message, field) + } + for _, field := range message.Fields { + // For the plain open mode, we only have set methods for weak fields. + if message.isOpen() && !field.Desc.IsWeak() { + continue + } + opaqueGenSet(g, f, message, field) + } + for _, field := range message.Fields { + // Open API does not have Has method. + // Repeated (includes map) fields do not have Has method. + if message.isOpen() || isRepeated(field) { + continue + } + + if !field.Desc.HasPresence() { + continue + } + + if isFirstOneofField(field) { + opaqueGenHasOneof(g, f, message, field.Oneof) + } + opaqueGenHas(g, f, message, field) + } + for _, field := range message.Fields { + // Open API does not have Clear method. + // Repeated (includes map) fields do not have Clear method. + if message.isOpen() || isRepeated(field) { + continue + } + if !field.Desc.HasPresence() { + continue + } + + if isFirstOneofField(field) { + opaqueGenClearOneof(g, f, message, field.Oneof) + } + opaqueGenClear(g, f, message, field) + } + // Plain open protos do not have which methods. + if !message.isOpen() { + opaqueGenWhichOneof(g, f, message) + } + + if g.InternalStripForEditionsDiff() { + return + } +} + +func isLazy(field *protogen.Field) bool { + // Prerequisite: field is of kind message + if field.Message == nil { + return false + } + + // Was the field marked as [lazy = true] in the .proto file? + fopts := field.Desc.Options().(*descriptorpb.FieldOptions) + return fopts.GetLazy() +} + +// opaqueGenGet generates a Get method for a field. +func opaqueGenGet(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo, field *protogen.Field) { + goType, pointer := opaqueFieldGoType(g, f, message, field) + getterName, bcName := field.MethodName("Get") + + // If we need a backwards compatible getter name, we add it now. + if bcName != "" { + defer func() { + g.P("// Deprecated: Use ", getterName, " instead.") + g.P("func (x *", message.GoIdent, ") ", bcName, "() ", goType, " {") + g.P("return x.", getterName, "()") + g.P("}") + g.P() + }() + } + + leadingComments := appendDeprecationSuffix("", + field.Desc.ParentFile(), + field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated()) + fieldtrackNoInterface(g, message.isTracked) + g.AnnotateSymbol(message.GoIdent.GoName+"."+getterName, protogen.Annotation{Location: field.Location}) + + // Weak field. + if field.Desc.IsWeak() { + g.P(leadingComments, "func (x *", message.GoIdent, ") ", getterName, "() ", protoPackage.Ident("Message"), "{") + g.P("var w ", protoimplPackage.Ident("WeakFields")) + g.P("if x != nil {") + g.P("w = x.", genid.WeakFields_goname) + if message.isTracked { + g.P("_ = x.", genid.WeakFieldPrefix_goname+field.GoName) + } + g.P("}") + g.P("return ", protoimplPackage.Ident("X"), ".GetWeak(w, ", field.Desc.Number(), ", ", strconv.Quote(string(field.Message.Desc.FullName())), ")") + g.P("}") + g.P() + return + } + + defaultValue := fieldDefaultValue(g, f, message, field) + + // Oneof field. + if oneof := field.Oneof; oneof != nil && !oneof.Desc.IsSynthetic() { + structPtr := "x" + g.P(leadingComments, "func (x *", message.GoIdent, ") ", getterName, "() ", goType, " {") + g.P("if x != nil {") + if message.isOpaque() && message.isTracked { + g.P("_ = ", structPtr, ".XXX_ft_", field.Oneof.GoName) + } + g.P("if x, ok := ", structPtr, ".", opaqueOneofFieldName(oneof, message.isOpaque()), ".(*", opaqueFieldOneofType(field, message.isOpaque()), "); ok {") + g.P("return x.", field.GoName) + g.P("}") + // End if m != nil {. + g.P("}") + g.P("return ", defaultValue) + g.P("}") + g.P() + return + } + + // Non-oneof field for open type message. + if !message.isOpaque() { + g.P(leadingComments, "func (x *", message.GoIdent, ") ", getterName, "() ", goType, " {") + if !field.Desc.HasPresence() || defaultValue == "nil" { + g.P("if x != nil {") + } else { + g.P("if x != nil && x.", field.GoName, " != nil {") + } + star := "" + if pointer { + star = "*" + } + g.P("return ", star, " x.", field.GoName) + g.P("}") + g.P("return ", defaultValue) + g.P("}") + g.P() + return + } + + // Non-oneof field for opaque type message. + g.P(leadingComments, "func (x *", message.GoIdent, ") ", getterName, "() ", goType, "{") + structPtr := "x" + g.P("if x != nil {") + if message.isTracked { + g.P("_ = ", structPtr, ".XXX_ft_", field.GoName) + } + if usePresence(message, field) { + pi := opaqueFieldPresenceIndex(field) + ai := pi / 32 + // For + // + // 1. Message fields of lazy messages (unmarshalled lazily), + // 2. Fields with a default value, + // 3. Closed enums + // + // ...we check presence, but for other fields using presence, we can return + // whatever is there and it should be correct regardless of presence, which + // saves us an atomic operation. + isEnum := field.Desc.Kind() == protoreflect.EnumKind + usePresenceForRead := (isLazy(field)) || + field.Desc.HasDefault() || isEnum + + if usePresenceForRead { + g.P("if ", protoimplPackage.Ident("X"), ".Present(&(", structPtr, ".XXX_presence[", ai, "]),", pi, ") {") + } + // For lazy, check if pointer is nil and optionally unmarshal + if isLazy(field) { + // Since pointer to lazily unmarshaled sub-message can be written during a conceptual + // "read" operation, all read/write accesses to the pointer must be atomic. This + // function gets inlined on x86 as just a simple get and compare. Still need to make the + // slice accesses be atomic. + g.P("if ", protoimplPackage.Ident("X"), ".AtomicCheckPointerIsNil(&", structPtr, ".xxx_hidden_", field.GoName, ") {") + g.P(protoimplPackage.Ident("X"), ".UnmarshalField(", structPtr, ", ", field.Desc.Number(), ")") + g.P("}") + } + if field.Message == nil || field.Desc.IsMap() { + star := "" + if pointer { + star = "*" + } + if pointer { + g.P("if ", structPtr, ".xxx_hidden_", field.GoName, "!= nil {") + } + + g.P("return ", star, structPtr, ".xxx_hidden_", field.GoName) + if pointer { + g.P("}") + g.P("return ", defaultValue) + } + } else { + // We need to do an atomic load of the msg pointer field, but cannot explicitly use + // unsafe pointers here. We load the value and store into rv, via protoimpl.Pointer, + // which is aliased to unsafe.Pointer in pointer_unsafe.go, but is aliased to + // interface{} in pointer_reflect.go + star := "" + if pointer { + star = "*" + } + if isLazy(field) { + g.P("var rv ", star, goType) + g.P(protoimplPackage.Ident("X"), ".AtomicLoadPointer(", protoimplPackage.Ident("Pointer"), "(&", structPtr, ".xxx_hidden_", field.GoName, "), ", protoimplPackage.Ident("Pointer"), "(&rv))") + g.P("return ", star, "rv") + } else { + if pointer { + g.P("if ", structPtr, ".xxx_hidden_", field.GoName, "!= nil {") + } + g.P("return ", star, structPtr, ".xxx_hidden_", field.GoName) + if pointer { + g.P("}") + } + } + } + if usePresenceForRead { + g.P("}") + } + } else if pointer { + g.P("if ", structPtr, ".xxx_hidden_", field.GoName, " != nil {") + g.P("return *", structPtr, ".xxx_hidden_", field.GoName) + g.P("}") + } else { + g.P("return ", structPtr, ".xxx_hidden_", field.GoName) + } + // End if m != nil {. + g.P("}") + g.P("return ", defaultValue) + g.P("}") + g.P() +} + +// opaqueGenSet generates a Set method for a field. +func opaqueGenSet(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo, field *protogen.Field) { + goType, pointer := opaqueFieldGoType(g, f, message, field) + setterName, bcName := field.MethodName("Set") + + // If we need a backwards compatible setter name, we add it now. + if bcName != "" { + defer func() { + g.P("// Deprecated: Use ", setterName, " instead.") + g.P("func (x *", message.GoIdent, ") ", bcName, "(v ", goType, ") {") + g.P("x.", setterName, "(v)") + g.P("}") + g.P() + }() + } + + leadingComments := appendDeprecationSuffix("", + field.Desc.ParentFile(), + field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated()) + g.AnnotateSymbol(message.GoIdent.GoName+"."+setterName, protogen.Annotation{ + Location: field.Location, + Semantic: descriptorpb.GeneratedCodeInfo_Annotation_SET.Enum(), + }) + fieldtrackNoInterface(g, message.noInterface) + + // Weak field. + if field.Desc.IsWeak() { + g.P(leadingComments, "func (x *", message.GoIdent, ") ", setterName, "(v ", protoPackage.Ident("Message"), ") {") + g.P("var w *", protoimplPackage.Ident("WeakFields")) + g.P("if x != nil {") + g.P("w = &x.", genid.WeakFields_goname) + if message.isTracked { + g.P("_ = x.", genid.WeakFieldPrefix_goname+field.GoName) + } + g.P("}") + g.P(protoimplPackage.Ident("X"), ".SetWeak(w, ", field.Desc.Number(), ", ", strconv.Quote(string(field.Message.Desc.FullName())), ", v)") + g.P("}") + g.P() + return + } + + // Oneof field. + if oneof := field.Oneof; oneof != nil && !oneof.Desc.IsSynthetic() { + g.P(leadingComments, "func (x *", message.GoIdent, ") ", setterName, "(v ", goType, ") {") + structPtr := "x" + if message.isOpaque() && message.isTracked { + // Add access to zero field for tracking + g.P(structPtr, ".XXX_ft_", oneof.GoName, " = struct{}{}") + } + if field.Desc.Kind() == protoreflect.BytesKind { + g.P("if v == nil { v = []byte{} }") + } else if field.Message != nil { + g.P("if v == nil {") + g.P(structPtr, ".", opaqueOneofFieldName(oneof, message.isOpaque()), "= nil") + g.P("return") + g.P("}") + } + g.P(structPtr, ".", opaqueOneofFieldName(oneof, message.isOpaque()), "= &", opaqueFieldOneofType(field, message.isOpaque()), "{v}") + g.P("}") + g.P() + return + } + + // Non-oneof field for open type message. + if !message.isOpaque() { + g.P(leadingComments, "func (x *", message.GoIdent, ") ", setterName, "(v ", goType, ") {") + if field.Desc.Cardinality() != protoreflect.Repeated && field.Desc.Kind() == protoreflect.BytesKind { + g.P("if v == nil { v = []byte{} }") + } + amp := "" + if pointer { + amp = "&" + } + + v := "v" + g.P("x.", field.GoName, " = ", amp, v) + g.P("}") + g.P() + return + } + + // Non-oneof field for opaque type message. + g.P(leadingComments, "func (x *", message.GoIdent, ") ", setterName, "(v ", goType, ") {") + structPtr := "x" + if message.isTracked { + // Add access to zero field for tracking + g.P(structPtr, ".XXX_ft_", field.GoName, " = struct{}{}") + } + if field.Desc.Cardinality() != protoreflect.Repeated && field.Desc.Kind() == protoreflect.BytesKind { + g.P("if v == nil { v = []byte{} }") + } + amp := "" + if pointer { + amp = "&" + } + if usePresence(message, field) { + pi := opaqueFieldPresenceIndex(field) + ai := pi / 32 + + if field.Message != nil && field.Desc.IsList() { + g.P("var sv *", goType) + g.P(protoimplPackage.Ident("X"), ".AtomicLoadPointer(", protoimplPackage.Ident("Pointer"), "(&", structPtr, ".xxx_hidden_", field.GoName, "), ", protoimplPackage.Ident("Pointer"), "(&sv))") + g.P("if sv == nil {") + g.P("sv = &", goType, "{}") + g.P(protoimplPackage.Ident("X"), ".AtomicInitializePointer(", protoimplPackage.Ident("Pointer"), "(&", structPtr, ".xxx_hidden_", field.GoName, "), ", protoimplPackage.Ident("Pointer"), "(&sv))") + g.P("}") + g.P("*sv = v") + g.P(protoimplPackage.Ident("X"), ".SetPresent(&(", structPtr, ".XXX_presence[", ai, "]),", pi, ",", opaqueNumPresenceFields(message), ")") + } else if field.Message != nil && !field.Desc.IsMap() { + // Only for lazy messages do we need to set pointers atomically + if isLazy(field) { + g.P(protoimplPackage.Ident("X"), ".AtomicSetPointer(&", structPtr, ".xxx_hidden_", field.GoName, ", ", amp, "v)") + } else { + g.P(structPtr, ".xxx_hidden_", field.GoName, " = ", amp, "v") + } + // When setting a message or slice of messages to a nil + // value, we must clear the presence bit, else we will + // later think that this field still needs to be lazily decoded. + g.P("if v == nil {") + g.P(protoimplPackage.Ident("X"), ".ClearPresent(&(", structPtr, ".XXX_presence[", ai, "]),", pi, ")") + g.P("} else {") + g.P(protoimplPackage.Ident("X"), ".SetPresent(&(", structPtr, ".XXX_presence[", ai, "]),", pi, ",", opaqueNumPresenceFields(message), ")") + g.P("}") + } else { + // Any map or non-message, possibly repeated, field that uses presence (proto2 only) + g.P(structPtr, ".xxx_hidden_", field.GoName, " = ", amp, "v") + // For consistent behaviour with lazy fields, non-map repeated fields should be cleared when + // the last object is removed. Maps are cleared when set to a nil map. + if field.Desc.Cardinality() == protoreflect.Repeated { // Includes maps. + g.P("if v == nil {") + g.P(protoimplPackage.Ident("X"), ".ClearPresent(&(", structPtr, ".XXX_presence[", ai, "]),", pi, ")") + g.P("} else {") + } + g.P(protoimplPackage.Ident("X"), ".SetPresent(&(", structPtr, ".XXX_presence[", ai, "]),", pi, ",", opaqueNumPresenceFields(message), ")") + if field.Desc.Cardinality() == protoreflect.Repeated { + g.P("}") + } + } + } else { + // proto3 non-lazy fields + g.P(structPtr, ".xxx_hidden_", field.GoName, " = ", amp, "v") + } + g.P("}") + g.P() +} + +// usePresence returns true if the presence map should be used for a field. It +// is always true for lazy message types. It is also true for all scalar fields. +// repeated, map or message fields are not using the presence map. +func usePresence(message *messageInfo, field *protogen.Field) bool { + if !message.isOpaque() || field.Desc.IsWeak() { + return false + } + return opaqueFieldNeedsPresenceArray(message, field) +} + +func opaqueFieldNeedsPresenceArray(message *messageInfo, field *protogen.Field) bool { + // Non optional fields need presence if truly lazy field, i.e. are message fields. + if isLazy(field) { + return true + } + isNotOneof := field.Desc.ContainingOneof() == nil || field.Desc.ContainingOneof().IsSynthetic() + return field.Desc.HasPresence() && field.Message == nil && isNotOneof +} + +// opaqueGenHas generates a Has method for a field. +func opaqueGenHas(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo, field *protogen.Field) { + hasserName, _ := field.MethodName("Has") + + leadingComments := appendDeprecationSuffix("", + field.Desc.ParentFile(), + field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated()) + g.AnnotateSymbol(message.GoIdent.GoName+"."+hasserName, protogen.Annotation{Location: field.Location}) + fieldtrackNoInterface(g, message.noInterface) + + // Weak field. + if field.Desc.IsWeak() { + g.P(leadingComments, "func (x *", message.GoIdent, ") ", hasserName, "() bool {") + g.P("var w ", protoimplPackage.Ident("WeakFields")) + g.P("if x != nil {") + g.P("w = x.", genid.WeakFields_goname) + if message.isTracked { + g.P("_ = x.", genid.WeakFieldPrefix_goname+field.GoName) + } + g.P("}") + g.P("return ", protoimplPackage.Ident("X"), ".HasWeak(w, ", field.Desc.Number(), ")") + g.P("}") + g.P() + return + } + + // Oneof field. + if oneof := field.Oneof; oneof != nil && !oneof.Desc.IsSynthetic() { + g.P(leadingComments, "func (x *", message.GoIdent, ") ", hasserName, "() bool {") + structPtr := "x" + g.P("if ", structPtr, " == nil {") + g.P("return false") + g.P("}") + if message.isOpaque() && message.isTracked { + // Add access to zero field for tracking + g.P("_ = ", structPtr, ".", "XXX_ft_", oneof.GoName) + } + g.P("_, ok := ", structPtr, ".", opaqueOneofFieldName(oneof, message.isOpaque()), ".(*", opaqueFieldOneofType(field, message.isOpaque()), ")") + g.P("return ok") + g.P("}") + g.P() + return + } + + // Non-oneof field in open message. + if !message.isOpaque() { + g.P(leadingComments, "func (x *", message.GoIdent, ") ", hasserName, "() bool {") + g.P("if x == nil {") + g.P("return false") + g.P("}") + g.P("return ", "x.", field.GoName, " != nil") + g.P("}") + g.P() + return + } + + // Non-oneof field in opaque message. + g.P(leadingComments, "func (x *", message.GoIdent, ") ", hasserName, "() bool {") + g.P("if x == nil {") + g.P("return false") + g.P("}") + structPtr := "x" + if message.isTracked { + // Add access to zero field for tracking + g.P("_ = ", structPtr, ".", "XXX_ft_"+field.GoName) + } + if usePresence(message, field) { + pi := opaqueFieldPresenceIndex(field) + ai := pi / 32 + g.P("return ", protoimplPackage.Ident("X"), ".Present(&(", structPtr, ".XXX_presence[", ai, "]),", pi, ")") + } else { + // Has for proto3 message without presence + g.P("return ", structPtr, ".xxx_hidden_", field.GoName, " != nil") + } + + g.P("}") + g.P() +} + +// opaqueGenClear generates a Clear method for a field. +func opaqueGenClear(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo, field *protogen.Field) { + clearerName, _ := field.MethodName("Clear") + pi := opaqueFieldPresenceIndex(field) + ai := pi / 32 + + leadingComments := appendDeprecationSuffix("", + field.Desc.ParentFile(), + field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated()) + g.AnnotateSymbol(message.GoIdent.GoName+"."+clearerName, protogen.Annotation{ + Location: field.Location, + Semantic: descriptorpb.GeneratedCodeInfo_Annotation_SET.Enum(), + }) + fieldtrackNoInterface(g, message.noInterface) + + // Weak field. + if field.Desc.IsWeak() { + g.P(leadingComments, "func (x *", message.GoIdent, ") ", clearerName, "() {") + g.P("var w *", protoimplPackage.Ident("WeakFields")) + g.P("if x != nil {") + g.P("w = &x.", genid.WeakFields_goname) + if message.isTracked { + g.P("_ = x.", genid.WeakFieldPrefix_goname+field.GoName) + } + g.P("}") + g.P(protoimplPackage.Ident("X"), ".ClearWeak(w, ", field.Desc.Number(), ")") + g.P("}") + g.P() + return + } + + // Oneof field. + if oneof := field.Oneof; oneof != nil && !oneof.Desc.IsSynthetic() { + g.P(leadingComments, "func (x *", message.GoIdent, ") ", clearerName, "() {") + structPtr := "x" + if message.isOpaque() && message.isTracked { + // Add access to zero field for tracking + g.P(structPtr, ".", "XXX_ft_", oneof.GoName, " = struct{}{}") + } + g.P("if _, ok := ", structPtr, ".", opaqueOneofFieldName(oneof, message.isOpaque()), ".(*", opaqueFieldOneofType(field, message.isOpaque()), "); ok {") + g.P(structPtr, ".", opaqueOneofFieldName(oneof, message.isOpaque()), " = nil") + g.P("}") + g.P("}") + g.P() + return + } + + // Non-oneof field in open message. + if !message.isOpaque() { + g.P(leadingComments, "func (x *", message.GoIdent, ") ", clearerName, "() {") + g.P("x.", field.GoName, " = nil") + g.P("}") + g.P() + return + } + + // Non-oneof field in opaque message. + g.P(leadingComments, "func (x *", message.GoIdent, ") ", clearerName, "() {") + structPtr := "x" + if message.isTracked { + // Add access to zero field for tracking + g.P(structPtr, ".", "XXX_ft_", field.GoName, " = struct{}{}") + } + + if usePresence(message, field) { + g.P(protoimplPackage.Ident("X"), ".ClearPresent(&(", structPtr, ".XXX_presence[", ai, "]),", pi, ")") + } + + // Avoid needing to read the presence value in Get by ensuring that we set the + // right zero value (unless we have an explicit default, in which case we + // revert to presence checking in Get). Rationale: Get is called far more + // frequently than Clear, it should be as lean as possible. + zv := opaqueZeroValueForField(g, field) + // For lazy, (repeated) message fields are unmarshalled lazily. Hence they are + // assigned atomically in Getters (which are allowed to be called + // concurrently). Due to this, historically, the code generator would use + // atomic operations everywhere. + // + // TODO(b/291588964): Stop using atomic operations for non-presence fields in + // write calls (Set/Clear). Concurrent reads are allowed, + // but concurrent read/write or write/write are not, we + // shouldn't cater to it. + if isLazy(field) { + goType, _ := opaqueFieldGoType(g, f, message, field) + g.P(protoimplPackage.Ident("X"), ".AtomicSetPointer(&", structPtr, ".xxx_hidden_", field.GoName, ",(", goType, ")(", zv, "))") + } else if !field.Desc.HasDefault() { + g.P(structPtr, ".xxx_hidden_", field.GoName, " = ", zv) + } + g.P("}") + g.P() +} + +// Determine what value to set a cleared field to. +func opaqueZeroValueForField(g *protogen.GeneratedFile, field *protogen.Field) string { + if field.Desc.Cardinality() == protoreflect.Repeated { + return "nil" + } + switch field.Desc.Kind() { + case protoreflect.StringKind: + return "nil" + case protoreflect.MessageKind, protoreflect.GroupKind, protoreflect.BytesKind: + return "nil" + case protoreflect.BoolKind: + return "false" + case protoreflect.EnumKind: + return g.QualifiedGoIdent(field.Enum.Values[0].GoIdent) + default: + return "0" + } +} + +// opaqueGenGetOneof generates a Get function for a oneof union. +func opaqueGenGetOneof(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo, oneof *protogen.Oneof) { + ifName := opaqueOneofInterfaceName(oneof) + g.AnnotateSymbol(message.GoIdent.GoName+".Get"+oneof.GoName, protogen.Annotation{Location: oneof.Location}) + fieldtrackNoInterface(g, message.isTracked) + g.P("func (x *", message.GoIdent.GoName, ") Get", oneof.GoName, "() ", ifName, " {") + g.P("if x != nil {") + g.P("return x.", opaqueOneofFieldName(oneof, message.isOpaque())) + g.P("}") + g.P("return nil") + g.P("}") + g.P() +} + +// opaqueGenHasOneof generates a Has function for a oneof union. +func opaqueGenHasOneof(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo, oneof *protogen.Oneof) { + fieldtrackNoInterface(g, message.noInterface) + hasserName := oneof.MethodName("Has") + g.P("func (x *", message.GoIdent, ") ", hasserName, "() bool {") + g.P("if x == nil {") + g.P("return false") + g.P("}") + structPtr := "x" + if message.isOpaque() && message.isTracked { + // Add access to zero field for tracking + g.P("_ = ", structPtr, ".XXX_ft_", oneof.GoName) + } + g.P("return ", structPtr, ".", opaqueOneofFieldName(oneof, message.isOpaque()), " != nil") + g.P("}") + g.P() +} + +// opaqueGenClearOneof generates a Clear function for a oneof union. +func opaqueGenClearOneof(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo, oneof *protogen.Oneof) { + fieldtrackNoInterface(g, message.noInterface) + clearerName := oneof.MethodName("Clear") + g.P("func (x *", message.GoIdent, ") ", clearerName, "() {") + structPtr := "x" + if message.isOpaque() && message.isTracked { + // Add access to zero field for tracking + g.P(structPtr, ".", "XXX_ft_", oneof.GoName, " = struct{}{}") + } + g.P(structPtr, ".", opaqueOneofFieldName(oneof, message.isOpaque()), " = nil") + g.P("}") + g.P() +} + +// opaqueGenWhichOneof generates the Which method for each oneof union, as well as the case values for each member +// of that union. +func opaqueGenWhichOneof(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo) { + // Go through the message, and for each field that is the first of a oneof field, dig down + // and generate constants + the actual which method. + oneofIndex := 0 + for _, field := range message.Fields { + if oneof := field.Oneof; oneof != nil { + if !isFirstOneofField(field) { + continue + } + caseType := opaqueOneofCaseTypeName(oneof) + g.P("const ", message.GoIdent.GoName, "_", oneof.GoName, "_not_set_case ", caseType, " = ", 0) + for _, f := range oneof.Fields { + g.P("const ", message.GoIdent.GoName, "_", f.GoName, "_case ", caseType, " = ", f.Desc.Number()) + } + fieldtrackNoInterface(g, message.noInterface) + whicherName := oneof.MethodName("Which") + g.P("func (x *", message.GoIdent, ") ", whicherName, "() ", caseType, " {") + g.P("if x == nil {") + g.P("return ", message.GoIdent.GoName, "_", oneof.GoName, "_not_set_case ") + g.P("}") + g.P("switch x.", opaqueOneofFieldName(oneof, message.isOpaque()), ".(type) {") + for _, f := range oneof.Fields { + g.P("case *", opaqueFieldOneofType(f, message.isOpaque()), ":") + g.P("return ", message.GoIdent.GoName, "_", f.GoName, "_case") + } + g.P("default", ":") + g.P("return ", message.GoIdent.GoName, "_", oneof.GoName, "_not_set_case ") + g.P("}") + g.P("}") + g.P() + oneofIndex++ + } + } +} + +func opaqueNeedsPresenceArray(message *messageInfo) bool { + if !message.isOpaque() { + return false + } + for _, field := range message.Fields { + if opaqueFieldNeedsPresenceArray(message, field) { + return true + } + } + return false +} + +func opaqueNeedsLazyStruct(message *messageInfo) bool { + for _, field := range message.Fields { + if isLazy(field) { + return true + } + } + return false +} + +// opaqueGenMessageBuilder generates a Builder type for a message. +func opaqueGenMessageBuilder(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo) { + if message.isOpen() { + return + } + // Builder type. + bName := g.QualifiedGoIdent(message.GoIdent) + genid.BuilderSuffix_goname + g.AnnotateSymbol(message.GoIdent.GoName+genid.BuilderSuffix_goname, protogen.Annotation{Location: message.Location}) + + leadingComments := appendDeprecationSuffix("", + message.Desc.ParentFile(), + message.Desc.Options().(*descriptorpb.MessageOptions).GetDeprecated()) + g.P(leadingComments, "type ", bName, " struct {") + g.P("_ [0]func() // Prevents comparability and use of unkeyed literals for the builder.") + g.P() + for _, field := range message.Fields { + oneof := field.Oneof + if oneof == nil && field.Desc.IsWeak() { + continue + } + + goType, pointer := opaqueBuilderFieldGoType(g, f, message, field) + if pointer { + goType = "*" + goType + } else if oneof != nil && fieldDefaultValue(g, f, message, field) != "nil" { + goType = "*" + goType + } + // Track all non-oneof fields. Note: synthetic oneofs are an + // implementation detail of proto3 optional fields: + // go/proto-proposals/proto3-presence.md, which should be tracked. + tag := "" + if (oneof == nil || oneof.Desc.IsSynthetic()) && message.isTracked { + tag = "`go:\"track\"`" + } + if oneof != nil && oneof.Fields[0] == field && !oneof.Desc.IsSynthetic() { + if oneof.Comments.Leading != "" { + g.P(oneof.Comments.Leading) + g.P() + } + g.P("// Fields of oneof ", opaqueOneofFieldName(oneof, message.isOpaque()), ":") + } + g.AnnotateSymbol(field.Parent.GoIdent.GoName+genid.BuilderSuffix_goname+"."+field.BuilderFieldName(), protogen.Annotation{Location: field.Location}) + leadingComments := appendDeprecationSuffix(field.Comments.Leading, + field.Desc.ParentFile(), + field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated()) + g.P(leadingComments, + field.BuilderFieldName(), " ", goType, " ", tag) + if oneof != nil && oneof.Fields[len(oneof.Fields)-1] == field && !oneof.Desc.IsSynthetic() { + g.P("// -- end of ", opaqueOneofFieldName(oneof, message.isOpaque())) + } + } + g.P("}") + g.P() + + opaqueGenBuildMethod(g, f, message, bName) +} + +// opaqueGenBuildMethod generates the actual Build method for the builder +func opaqueGenBuildMethod(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo, bName string) { + // Build method on the builder type. + fieldtrackNoInterface(g, message.noInterface) + g.P("func (b0 ", bName, ") Build() *", message.GoIdent, " {") + g.P("m0 := &", message.GoIdent, "{}") + + if message.isTracked { + // Redeclare the builder and message types as local + // defined types, so that field tracking records the + // field uses against these types instead of the + // original struct types. + // + // TODO: Actually redeclare the struct types + // without `go:"track"` tags? + g.P("type (notrackB ", bName, "; notrackM ", message.GoIdent, ")") + g.P("b, x := (*notrackB)(&b0), (*notrackM)(m0)") + } else { + g.P("b, x := &b0, m0") + } + g.P("_, _ = b, x") + + for _, field := range message.Fields { + oneof := field.Oneof + if oneof == nil && field.Desc.IsWeak() { + continue + } + if oneof != nil && !oneof.Desc.IsSynthetic() { + qual := "" + if fieldDefaultValue(g, f, message, field) != "nil" { + qual = "*" + } + + g.P("if b.", field.BuilderFieldName(), " != nil {") + oneofName := opaqueOneofFieldName(oneof, message.isOpaque()) + oneofType := opaqueFieldOneofType(field, message.isOpaque()) + g.P("x.", oneofName, " = &", oneofType, "{", qual, "b.", field.BuilderFieldName(), "}") + g.P("}") + } else { // proto3 optional ends up here (synthetic oneof) + qual := "" + _, pointer := opaqueBuilderFieldGoType(g, f, message, field) + if pointer && message.isOpaque() && !field.Desc.IsList() && field.Desc.Kind() != protoreflect.StringKind { + qual = "*" + } else if message.isOpaque() && field.Desc.IsList() && field.Desc.Message() != nil { + qual = "&" + } + presence := usePresence(message, field) + if presence { + g.P("if b.", field.BuilderFieldName(), " != nil {") + } + if presence { + pi := opaqueFieldPresenceIndex(field) + g.P(protoimplPackage.Ident("X"), ".SetPresentNonAtomic(&(x.XXX_presence[", pi/32, "]),", pi, ",", opaqueNumPresenceFields(message), ")") + } + goName := field.GoName + if message.isOpaque() { + goName = "xxx_hidden_" + goName + } + g.P("x.", goName, " = ", qual, "b.", field.BuilderFieldName()) + if presence { + g.P("}") + } + } + } + + g.P("return m0") + g.P("}") + g.P() +} + +// opaqueBuilderFieldGoType does the same as opaqueFieldGoType, but corrects for +// types that are different in a builder +func opaqueBuilderFieldGoType(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo, field *protogen.Field) (goType string, pointer bool) { + goType, pointer = opaqueFieldGoType(g, f, message, field) + kind := field.Desc.Kind() + + // Use []T instead of *[]T for opaque repeated lists. + if message.isOpaque() && field.Desc.IsList() { + pointer = false + } + + // Use *T for optional fields. + optional := field.Desc.HasPresence() + if optional && + kind != protoreflect.GroupKind && + kind != protoreflect.MessageKind && + kind != protoreflect.BytesKind && + field.Desc.Cardinality() != protoreflect.Repeated { + pointer = true + } + + return goType, pointer +} + +func opaqueGenOneofWrapperTypes(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo) { + // TODO: We should avoid generating these wrapper types in pure-opaque mode. + if !message.isOpen() { + for _, oneof := range message.Oneofs { + if oneof.Desc.IsSynthetic() { + continue + } + caseTypeName := opaqueOneofCaseTypeName(oneof) + g.P("type ", caseTypeName, " ", protoreflectPackage.Ident("FieldNumber")) + g.P("") + + idx := f.allMessagesByPtr[message] + typesVar := messageTypesVarName(f) + g.P("func (x ", caseTypeName, ") String() string {") + g.P("md := ", typesVar, "[", idx, "].Descriptor()") + g.P("if x == 0 {") + g.P(`return "not set"`) + g.P("}") + g.P("return ", protoimplPackage.Ident("X"), ".MessageFieldStringOf(md, ", protoreflectPackage.Ident("FieldNumber"), "(x))") + g.P("}") + g.P() + } + } + for _, oneof := range message.Oneofs { + if oneof.Desc.IsSynthetic() { + continue + } + ifName := opaqueOneofInterfaceName(oneof) + g.P("type ", ifName, " interface {") + g.P(ifName, "()") + g.P("}") + g.P() + for _, field := range oneof.Fields { + name := opaqueFieldOneofType(field, message.isOpaque()) + g.AnnotateSymbol(name.GoName, protogen.Annotation{Location: field.Location}) + g.AnnotateSymbol(name.GoName+"."+field.GoName, protogen.Annotation{Location: field.Location}) + g.P("type ", name, " struct {") + goType, _ := opaqueFieldGoType(g, f, message, field) + protobufTagValue := fieldProtobufTagValue(field) + if g.InternalStripForEditionsDiff() { + protobufTagValue = strings.ReplaceAll(protobufTagValue, ",proto3", "") + } + tags := structTags{ + {"protobuf", protobufTagValue}, + } + leadingComments := appendDeprecationSuffix(field.Comments.Leading, + field.Desc.ParentFile(), + field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated()) + g.P(leadingComments, + field.GoName, " ", goType, tags, + trailingComment(field.Comments.Trailing)) + g.P("}") + g.P() + } + for _, field := range oneof.Fields { + g.P("func (*", opaqueFieldOneofType(field, message.isOpaque()), ") ", ifName, "() {}") + g.P() + } + } +} + +// opaqueFieldGoType returns the Go type used for a field. +// +// If it returns pointer=true, the struct field is a pointer to the type. +func opaqueFieldGoType(g *protogen.GeneratedFile, f *fileInfo, message *messageInfo, field *protogen.Field) (goType string, pointer bool) { + if field.Desc.IsWeak() { + return "struct{}", false + } + + pointer = true + switch field.Desc.Kind() { + case protoreflect.BoolKind: + goType = "bool" + case protoreflect.EnumKind: + goType = g.QualifiedGoIdent(field.Enum.GoIdent) + case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: + goType = "int32" + case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: + goType = "uint32" + case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: + goType = "int64" + case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: + goType = "uint64" + case protoreflect.FloatKind: + goType = "float32" + case protoreflect.DoubleKind: + goType = "float64" + case protoreflect.StringKind: + goType = "string" + case protoreflect.BytesKind: + goType = "[]byte" + pointer = false + case protoreflect.MessageKind, protoreflect.GroupKind: + goType = opaqueMessageFieldGoType(g, f, field, message.isOpaque()) + pointer = false + } + switch { + case field.Desc.IsList(): + goType = "[]" + goType + pointer = false + case field.Desc.IsMap(): + keyType, _ := opaqueFieldGoType(g, f, message, field.Message.Fields[0]) + valType, _ := opaqueFieldGoType(g, f, message, field.Message.Fields[1]) + return fmt.Sprintf("map[%v]%v", keyType, valType), false + } + + // Extension fields always have pointer type, even when defined in a proto3 file. + if !field.Desc.IsExtension() && !field.Desc.HasPresence() { + pointer = false + } + + if message.isOpaque() { + switch { + case field.Desc.IsList() && field.Desc.Message() != nil: + pointer = true + case !field.Desc.IsList() && field.Desc.Kind() == protoreflect.StringKind: + switch { + case field.Desc.HasPresence(): + pointer = true + default: + pointer = false + } + default: + pointer = false + } + } + + return goType, pointer +} + +func opaqueMessageFieldGoType(g *protogen.GeneratedFile, f *fileInfo, field *protogen.Field, isOpaque bool) string { + return "*" + g.QualifiedGoIdent(field.Message.GoIdent) +} + +// opaqueFieldPresenceIndex returns the index to pass to presence functions. +// +// TODO: field.Desc.Index() would be simpler, and would give space to record the presence of oneof fields. +func opaqueFieldPresenceIndex(field *protogen.Field) int { + structFieldIndex := 0 + for _, f := range field.Parent.Fields { + if field == f { + break + } + if f.Oneof == nil || isLastOneofField(f) { + structFieldIndex++ + } + } + return structFieldIndex +} + +// opaqueNumPresenceFields returns the number of fields that may be passed to presence functions. +// +// Since all fields in a oneof currently share a single entry in the presence bitmap, +// this is not just len(message.Fields). +func opaqueNumPresenceFields(message *messageInfo) int { + if len(message.Fields) == 0 { + return 0 + } + return opaqueFieldPresenceIndex(message.Fields[len(message.Fields)-1]) + 1 +} + +func fieldtrackNoInterface(g *protogen.GeneratedFile, isTracked bool) { + if isTracked { + g.P("//go:nointerface") + } +} + +// opaqueOneofFieldName returns the name of the struct field that holds +// the value of a oneof. +func opaqueOneofFieldName(oneof *protogen.Oneof, isOpaque bool) string { + if isOpaque { + return "xxx_hidden_" + oneof.GoName + } + return oneof.GoName +} + +func opaqueFieldOneofType(field *protogen.Field, isOpaque bool) protogen.GoIdent { + ident := protogen.GoIdent{ + GoImportPath: field.Parent.GoIdent.GoImportPath, + GoName: field.Parent.GoIdent.GoName + "_" + field.GoName, + } + // Check for collisions with nested messages or enums. + // + // This conflict resolution is incomplete: Among other things, it + // does not consider collisions with other oneof field types. +Loop: + for { + for _, message := range field.Parent.Messages { + if message.GoIdent == ident { + ident.GoName += "_" + continue Loop + } + } + for _, enum := range field.Parent.Enums { + if enum.GoIdent == ident { + ident.GoName += "_" + continue Loop + } + } + return unexportIdent(ident, isOpaque) + } +} + +// unexportIdent turns id into its unexported version (by lower-casing), but +// only if isOpaque is set. This function is used for oneof wrapper types, +// which remain exported in the non-opaque API for now. +func unexportIdent(id protogen.GoIdent, isOpaque bool) protogen.GoIdent { + if !isOpaque { + return id + } + r, sz := utf8.DecodeRuneInString(id.GoName) + if r == utf8.RuneError { + panic(fmt.Sprintf("Go identifier %q contains invalid UTF8?!", id.GoName)) + } + r = unicode.ToLower(r) + id.GoName = string(r) + id.GoName[sz:] + return id +} + +func opaqueOneofInterfaceName(oneof *protogen.Oneof) string { + return fmt.Sprintf("is%s_%s", oneof.Parent.GoIdent.GoName, oneof.GoName) +} +func opaqueOneofCaseTypeName(oneof *protogen.Oneof) string { + return fmt.Sprintf("case_%s_%s", oneof.Parent.GoIdent.GoName, oneof.GoName) +} + +// isFirstOneofField reports whether this is the first field in a oneof. +func isFirstOneofField(field *protogen.Field) bool { + return field.Oneof != nil && field == field.Oneof.Fields[0] && !field.Oneof.Desc.IsSynthetic() +} + +// isLastOneofField returns true if this is the last field in a oneof. +func isLastOneofField(field *protogen.Field) bool { + return field.Oneof != nil && field == field.Oneof.Fields[len(field.Oneof.Fields)-1] +} diff --git a/cmd/protoc-gen-go/internal_gengo/reflect.go b/cmd/protoc-gen-go/internal_gengo/reflect.go index 75939d96f..a3f91a85c 100644 --- a/cmd/protoc-gen-go/internal_gengo/reflect.go +++ b/cmd/protoc-gen-go/internal_gengo/reflect.go @@ -174,7 +174,7 @@ func genReflectFileDescriptor(gen *protogen.Plugin, g *protogen.GeneratedFile, f for _, oneof := range message.Oneofs { if !oneof.Desc.IsSynthetic() { for _, field := range oneof.Fields { - g.P("(*", field.GoIdent, ")(nil),") + g.P("(*", unexportIdent(field.GoIdent, message.isOpaque()), ")(nil),") } } } diff --git a/cmd/protoc-gen-go/name_clash_test/name_clash_proto3_test.go b/cmd/protoc-gen-go/name_clash_test/name_clash_proto3_test.go new file mode 100644 index 000000000..bd85f093e --- /dev/null +++ b/cmd/protoc-gen-go/name_clash_test/name_clash_proto3_test.go @@ -0,0 +1,810 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package name_clash_test + +import ( + "testing" + + "google.golang.org/protobuf/proto" + + hpb "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_hybrid3" + opb "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_opaque3" + pb "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_open3" +) + +// TestOpenMangling3 tests the backwards compatible mangling of fields +// who clashes with the getters. The expected behavior, which is +// somewhat surprising, is documented in the proto +// test_name_clash_open.proto itself. +func TestOpenMangling3(t *testing.T) { + m1 := &pb.M1{ + Foo: makeOpenM0(1), + GetFoo_: makeOpenM0(2), + GetGetFoo: makeOpenM0(3), + } + if m1.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m1.GetFoo().GetI1(), m1) + } + if m1.GetGetFoo_().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m1.GetGetFoo_().GetI1(), m1) + } + if m1.GetGetGetFoo().GetI1() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m1.GetGetGetFoo().GetI1(), m1) + } + m2 := &pb.M2{ + Foo: makeOpenM0(1), + GetFoo_: makeOpenM0(2), + GetGetFoo: makeOpenM0(3), + } + if m2.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m2.GetFoo().GetI1(), m2) + } + if m2.GetGetFoo_().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m2.GetGetFoo_().GetI1(), m2) + } + if m2.GetGetGetFoo().GetI1() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m2.GetGetGetFoo().GetI1(), m2) + } + m3 := &pb.M3{ + Foo_: makeOpenM0(1), + GetFoo: makeOpenM0(2), + GetGetFoo_: makeOpenM0(3), + } + if m3.GetFoo_().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m3.GetFoo_().GetI1(), m3) + } + if m3.GetGetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m3.GetGetFoo().GetI1(), m3) + } + if m3.GetGetGetFoo_().GetI1() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m3.GetGetGetFoo_().GetI1(), m3) + } + + m4 := &pb.M4{ + GetFoo: makeOpenM0(2), + GetGetFoo_: &pb.M4_GetGetGetFoo{GetGetGetFoo: 3}, + Foo_: makeOpenM0(1), + } + if m4.GetFoo_().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m4.GetFoo_().GetI1(), m4) + } + if m4.GetGetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m4.GetGetFoo().GetI1(), m4) + } + if m4.GetGetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m4.GetGetGetGetFoo(), m4) + } + + m5 := &pb.M5{ + GetFoo: makeOpenM0(2), + GetGetGetFoo: &pb.M5_GetGetFoo_{GetGetFoo_: 3}, + Foo_: makeOpenM0(1), + } + if m5.GetFoo_().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m5.GetFoo_().GetI1(), m5) + } + if m5.GetGetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m5.GetGetFoo().GetI1(), m5) + } + if m5.GetGetGetFoo_() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m5.GetGetGetFoo_(), m5) + } + + m6 := &pb.M6{ + GetGetFoo: &pb.M6_GetGetGetFoo{GetGetGetFoo: 3}, + GetFoo_: makeOpenM0(2), + Foo: makeOpenM0(1), + } + if m6.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m6.GetFoo().GetI1(), m6) + } + if m6.GetGetFoo_().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m6.GetGetFoo_().GetI1(), m6) + } + if m6.GetGetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_get_foo' has unexpected value %v for %T (expected 3)", m6.GetGetGetGetFoo(), m6) + } + + m7 := &pb.M7{ + GetGetFoo: &pb.M7_GetFoo_{GetFoo_: 3}, + Foo: makeOpenM0(1), + } + if m7.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m7.GetFoo().GetI1(), m7) + } + if m7.GetGetFoo_() != 3 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m7.GetGetFoo_(), m7) + } + m7.GetGetFoo = &pb.M7_Bar{Bar: true} + if !m7.GetBar() { + t.Errorf("Proto field 'bar' has unexpected value %v for %T (expected 3)", m7.GetBar(), m7) + } + + m8 := &pb.M8{ + GetGetGetFoo_: &pb.M8_GetGetFoo{GetGetFoo: 3}, + GetFoo_: makeOpenM0(2), + Foo: makeOpenM0(1), + } + if m8.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m8.GetFoo().GetI1(), m8) + } + if m8.GetGetFoo_().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m8.GetGetFoo_().GetI1(), m8) + } + if m8.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m8.GetGetGetFoo(), m8) + } + + m9 := &pb.M9{ + GetGetGetFoo_: &pb.M9_GetGetFoo{GetGetFoo: 3}, + Foo: makeOpenM0(1), + } + if m9.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m9.GetFoo(), m9) + } + if m9.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m9.GetGetGetFoo(), m9) + } + m9.GetGetGetFoo_ = &pb.M9_GetFoo_{GetFoo_: 2} + if m9.GetGetFoo_() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m9.GetGetFoo_(), m9) + } + +} + +// TestHybridMangling3 tests the backwards compatible mangling as well +// as new style mangling of fields who clashes with the getters. The +// expected behavior, which is somewhat surprising, is documented in +// the proto test_name_clash_hybrid.proto itself. +func TestHybridMangling3(t *testing.T) { + m1 := hpb.M1_builder{ + Foo: makeHybridM0(1), + GetFoo: makeHybridM0(2), + GetGetFoo: makeHybridM0(3), + }.Build() + if m1.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m1.GetFoo().GetI1(), m1) + } + if m1.Get_Foo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m1.GetFoo().GetI1(), m1) + } + if m1.GetGetFoo_().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m1.GetGetFoo_().GetI1(), m1) + } + if m1.Get_GetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m1.GetGetFoo_().GetI1(), m1) + } + if m1.GetGetGetFoo().GetI1() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m1.GetGetGetFoo().GetI1(), m1) + } + checkNameConsistency(t, m1) + m2 := hpb.M2_builder{ + Foo: makeHybridM0(1), + GetFoo: makeHybridM0(2), + GetGetFoo: makeHybridM0(3), + }.Build() + if m2.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m2.GetFoo().GetI1(), m2) + } + if m2.Get_Foo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m2.GetFoo().GetI1(), m2) + } + if m2.GetGetFoo_().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m2.GetGetFoo_().GetI1(), m2) + } + if m2.Get_GetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m2.GetGetFoo_().GetI1(), m2) + } + if m2.GetGetGetFoo().GetI1() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m2.GetGetGetFoo().GetI1(), m2) + } + checkNameConsistency(t, m2) + m3 := hpb.M3_builder{ + Foo: makeHybridM0(1), + GetFoo: makeHybridM0(2), + GetGetFoo: makeHybridM0(3), + }.Build() + if m3.GetFoo_().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m3.GetFoo_().GetI1(), m3) + } + if m3.Get_Foo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m3.GetFoo_().GetI1(), m3) + } + if m3.GetGetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m3.GetGetFoo().GetI1(), m3) + } + if m3.Get_GetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m3.GetGetFoo().GetI1(), m3) + } + if m3.GetGetGetFoo_().GetI1() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m3.GetGetGetFoo_().GetI1(), m3) + } + checkNameConsistency(t, m3) + + m4 := hpb.M4_builder{ + GetFoo: makeHybridM0(2), + GetGetGetFoo: proto.Int32(3), + Foo: makeHybridM0(1), + }.Build() + if m4.GetFoo_().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m4.GetFoo_().GetI1(), m4) + } + if m4.Get_Foo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m4.Get_Foo().GetI1(), m4) + } + if m4.GetGetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m4.GetGetFoo().GetI1(), m4) + } + if m4.Get_GetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m4.Get_GetFoo().GetI1(), m4) + } + if m4.GetGetGetFoo_().(*hpb.M4_GetGetGetFoo).GetGetGetFoo != 3 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 3)", m4.GetGetGetFoo_(), m4) + } + if !m4.HasGetGetFoo() { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected true)", m4.HasGetGetFoo(), m4) + } + if m4.GetGetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_get_foo' has unexpected value %v for %T (expected 3)", m4.GetGetGetGetFoo(), m4) + } + checkNameConsistency(t, m4) + + m5 := hpb.M5_builder{ + GetFoo: makeHybridM0(2), + GetGetFoo: proto.Int32(3), + Foo: makeHybridM0(1), + }.Build() + if m5.GetFoo_().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m5.GetFoo_().GetI1(), m5) + } + if m5.Get_Foo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m5.Get_Foo().GetI1(), m4) + } + if m5.GetGetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m5.GetGetFoo().GetI1(), m5) + } + if m5.Get_GetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m5.Get_GetFoo().GetI1(), m4) + } + if m5.GetGetGetFoo_() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m5.GetGetGetFoo_(), m5) + } + if m5.Get_GetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m5.Get_GetGetFoo(), m5) + } + checkNameConsistency(t, m5) + + m6 := hpb.M6_builder{ + GetGetGetFoo: proto.Int32(3), + GetFoo: makeHybridM0(2), + Foo: makeHybridM0(1), + }.Build() + if m6.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m6.GetFoo().GetI1(), m6) + } + if m6.Get_Foo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m6.Get_Foo().GetI1(), m6) + } + if m6.GetGetFoo_().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m6.GetGetFoo_().GetI1(), m6) + } + if m6.Get_GetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m6.Get_GetFoo().GetI1(), m6) + } + if m6.GetGetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_get_foo' has unexpected value %v for %T (expected 3)", m6.GetGetGetGetFoo(), m6) + } + checkNameConsistency(t, m6) + + m7 := hpb.M7_builder{ + GetFoo: proto.Int32(3), + Foo: makeHybridM0(1), + }.Build() + if m7.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m7.GetFoo().GetI1(), m7) + } + if m7.Get_Foo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m7.Get_Foo().GetI1(), m7) + } + if m7.GetGetFoo_() != 3 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 3)", m7.GetGetFoo_(), m7) + } + if m7.Get_GetFoo() != 3 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 3)", m7.Get_GetFoo(), m7) + } + m7.SetBar(true) + if !m7.GetBar() { + t.Errorf("Proto field 'bar' has unexpected value %v for %T (expected 3)", m7.GetBar(), m7) + } + checkNameConsistency(t, m7) + + m8 := hpb.M8_builder{ + GetGetFoo: proto.Int32(3), + GetFoo: makeHybridM0(2), + Foo: makeHybridM0(1), + }.Build() + if m8.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m8.GetFoo().GetI1(), m8) + } + if m8.Get_Foo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m8.Get_Foo().GetI1(), m8) + } + if m8.GetGetFoo_().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m8.GetGetFoo_().GetI1(), m8) + } + if m8.Get_GetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m8.Get_GetFoo().GetI1(), m8) + } + if m8.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m8.GetGetGetFoo(), m8) + } + checkNameConsistency(t, m8) + + m9 := hpb.M9_builder{ + GetGetFoo: proto.Int32(3), + Foo: makeHybridM0(1), + }.Build() + if m9.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m9.GetFoo().GetI1(), m9) + } + if m9.Get_Foo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m9.Get_Foo().GetI1(), m9) + } + if m9.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m9.GetGetGetFoo(), m9) + } + if m9.Get_GetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m9.Get_GetGetFoo(), m9) + } + m9.Set_GetFoo(2) + if m9.GetGetFoo_() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m9.GetGetFoo_(), m9) + } + if m9.Get_GetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m9.Get_GetFoo(), m9) + } + checkNameConsistency(t, m9) + m10 := hpb.M10_builder{ + Foo: makeHybridM0(1), + SetFoo: makeHybridM0(2), + }.Build() + m10.Set_Foo(makeHybridM0(47)) + if m10.Get_Foo().GetI1() != 47 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 47)", m10.Get_Foo().GetI1(), m10) + } + m10.SetSetFoo(makeHybridM0(11)) + if m10.GetSetFoo().GetI1() != 11 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 11)", m10.GetSetFoo().GetI1(), m10) + } + checkNameConsistency(t, m10) + m11 := hpb.M11_builder{ + Foo: makeHybridM0(1), + SetSetFoo: proto.Int32(2), + }.Build() + m11.Set_Foo(makeHybridM0(47)) + if m11.Get_Foo().GetI1() != 47 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 47)", m11.Get_Foo().GetI1(), m11) + } + m11.SetSetSetFoo(11) + if m11.GetSetSetFoo() != 11 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 11)", m11.GetSetSetFoo(), m11) + } + checkNameConsistency(t, m11) + m12 := hpb.M12_builder{ + Foo: makeHybridM0(1), + SetFoo: proto.Int32(2), + }.Build() + m12.Set_Foo(makeHybridM0(47)) + if m12.Get_Foo().GetI1() != 47 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 47)", m12.Get_Foo().GetI1(), m12) + } + m12.Set_SetFoo(11) + if m12.Get_SetFoo() != 11 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 11)", m12.Get_SetFoo(), m12) + } + checkNameConsistency(t, m12) + m13 := hpb.M13_builder{ + Foo: makeHybridM0(1), + HasFoo: makeHybridM0(2), + }.Build() + if !m13.Has_Foo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m13.Has_Foo(), m13) + } + if !m13.HasHasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m13.HasHasFoo(), m13) + } + checkNameConsistency(t, m13) + m14 := hpb.M14_builder{ + Foo: makeHybridM0(1), + HasHasFoo: proto.Int32(2), + }.Build() + if !m14.Has_Foo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m14.Has_Foo(), m14) + } + if !m14.Has_HasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m14.Has_HasFoo(), m14) + } + if !m14.HasHasHasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m14.HasHasHasFoo(), m14) + } + checkNameConsistency(t, m14) + m15 := hpb.M15_builder{ + Foo: makeHybridM0(1), + HasFoo: proto.Int32(2), + }.Build() + if !m15.Has_Foo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m15.Has_Foo(), m15) + } + if !m15.Has_HasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m15.Has_HasFoo(), m15) + } + if !m15.HasHasHasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m15.HasHasHasFoo(), m15) + } + checkNameConsistency(t, m15) + m16 := hpb.M16_builder{ + Foo: makeHybridM0(1), + ClearFoo: makeHybridM0(2), + }.Build() + m16.Clear_Foo() + if m16.Has_Foo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m16.Has_Foo(), m16) + } + m16.ClearClearFoo() + if m16.HasClearFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m16.HasClearFoo(), m16) + } + checkNameConsistency(t, m16) + m17 := hpb.M17_builder{ + Foo: makeHybridM0(1), + ClearClearFoo: proto.Int32(2), + }.Build() + m17.Clear_Foo() + if m17.Has_Foo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m17.Has_Foo(), m17) + } + m17.ClearClearClearFoo() + if m17.HasClearClearFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m17.HasClearClearFoo(), m17) + } + checkNameConsistency(t, m17) + m18 := hpb.M18_builder{ + Foo: makeHybridM0(1), + ClearFoo: proto.Int32(2), + }.Build() + m18.Clear_Foo() + if m18.Has_Foo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m18.Has_Foo(), m18) + } + m18.Clear_ClearFoo() + if m18.Has_ClearFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m18.Has_ClearFoo(), m18) + } + checkNameConsistency(t, m18) + m19 := hpb.M19_builder{ + Foo: makeHybridM0(1), + WhichFoo: proto.Int32(2), + }.Build() + if m19.WhichWhichWhichFoo() != hpb.M19_WhichFoo_case { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected M19_ClearFoo_case)", m19.WhichWhichWhichFoo(), m19) + } + checkNameConsistency(t, m19) + m20 := hpb.M20_builder{ + Foo: makeHybridM0(1), + WhichWhichFoo: proto.Int32(2), + }.Build() + if m20.Which_WhichFoo() != hpb.M20_WhichWhichFoo_case { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected M20_WhichWhichFoo_case)", m20.Which_WhichFoo(), m20) + } + checkNameConsistency(t, m20) + +} + +// TestOpaqueMangling3 tests the backwards compatible mangling as well +// as new style mangling of fields who clashes with the getters. The +// expected behavior, which is somewhat surprising, is documented in +// the proto test_name_clash_opaque.proto itself. +func TestOpaqueMangling3(t *testing.T) { + m1 := opb.M1_builder{ + Foo: makeOpaqueM0(1), + GetFoo: makeOpaqueM0(2), + GetGetFoo: makeOpaqueM0(3), + }.Build() + if m1.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m1.GetFoo().GetI1(), m1) + } + if m1.GetGetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m1.GetGetFoo().GetI1(), m1) + } + if m1.GetGetGetFoo().GetI1() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m1.GetGetGetFoo().GetI1(), m1) + } + checkNameConsistency(t, m1) + m2 := opb.M2_builder{ + Foo: makeOpaqueM0(1), + GetFoo: makeOpaqueM0(2), + GetGetFoo: makeOpaqueM0(3), + }.Build() + if m2.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m2.GetFoo().GetI1(), m2) + } + if m2.GetGetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m2.GetGetFoo().GetI1(), m2) + } + if m2.GetGetGetFoo().GetI1() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m2.GetGetGetFoo().GetI1(), m2) + } + checkNameConsistency(t, m2) + m3 := opb.M3_builder{ + Foo: makeOpaqueM0(1), + GetFoo: makeOpaqueM0(2), + GetGetFoo: makeOpaqueM0(3), + }.Build() + if m3.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m3.GetFoo().GetI1(), m3) + } + if m3.GetGetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m3.GetGetFoo().GetI1(), m3) + } + if m3.GetGetGetFoo().GetI1() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m3.GetGetGetFoo().GetI1(), m3) + } + checkNameConsistency(t, m3) + + m4 := opb.M4_builder{ + GetFoo: makeOpaqueM0(2), + GetGetGetFoo: proto.Int32(3), + Foo: makeOpaqueM0(1), + }.Build() + if m4.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m4.GetFoo().GetI1(), m4) + } + if m4.GetGetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m4.GetGetFoo().GetI1(), m4) + } + if !m4.HasGetGetFoo() { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected true)", m4.HasGetGetFoo(), m4) + } + if m4.GetGetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_get_foo' has unexpected value %v for %T (expected 3)", m4.GetGetGetGetFoo(), m4) + } + checkNameConsistency(t, m4) + + m5 := opb.M5_builder{ + GetFoo: makeOpaqueM0(2), + GetGetFoo: proto.Int32(3), + Foo: makeOpaqueM0(1), + }.Build() + if m5.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m5.GetFoo().GetI1(), m5) + } + if m5.GetGetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m5.GetGetFoo().GetI1(), m5) + } + if m5.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m5.GetGetGetFoo(), m5) + } + checkNameConsistency(t, m5) + + m6 := opb.M6_builder{ + GetGetGetFoo: proto.Int32(3), + GetFoo: makeOpaqueM0(2), + Foo: makeOpaqueM0(1), + }.Build() + if m6.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m6.GetFoo().GetI1(), m6) + } + if m6.GetGetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m6.GetGetFoo().GetI1(), m6) + } + if m6.GetGetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_get_foo' has unexpected value %v for %T (expected 3)", m6.GetGetGetGetFoo(), m6) + } + checkNameConsistency(t, m6) + + m7 := opb.M7_builder{ + GetFoo: proto.Int32(3), + Foo: makeOpaqueM0(1), + }.Build() + if m7.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m7.GetFoo().GetI1(), m7) + } + if m7.GetGetFoo() != 3 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 3)", m7.GetGetFoo(), m7) + } + m7.SetBar(true) + if !m7.GetBar() { + t.Errorf("Proto field 'bar' has unexpected value %v for %T (expected true)", m7.GetBar(), m7) + } + checkNameConsistency(t, m7) + + m8 := opb.M8_builder{ + GetGetFoo: proto.Int32(3), + GetFoo: makeOpaqueM0(2), + Foo: makeOpaqueM0(1), + }.Build() + if m8.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m8.GetFoo().GetI1(), m8) + } + if m8.GetGetFoo().GetI1() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m8.GetGetFoo().GetI1(), m8) + } + if m8.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m8.GetGetGetFoo(), m8) + } + checkNameConsistency(t, m8) + + m9 := opb.M9_builder{ + GetGetFoo: proto.Int32(3), + Foo: makeOpaqueM0(1), + }.Build() + if m9.GetFoo().GetI1() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m9.GetFoo().GetI1(), m9) + } + if m9.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m9.GetGetGetFoo(), m9) + } + m9.SetGetFoo(2) + if m9.GetGetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m9.GetGetFoo(), m9) + } + checkNameConsistency(t, m9) + m10 := opb.M10_builder{ + Foo: makeOpaqueM0(1), + SetFoo: makeOpaqueM0(2), + }.Build() + m10.SetFoo(makeOpaqueM0(48)) + if m10.GetFoo().GetI1() != 48 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 48)", m10.GetFoo().GetI1(), m10) + } + m10.SetSetFoo(makeOpaqueM0(11)) + if m10.GetSetFoo().GetI1() != 11 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 11)", m10.GetSetFoo().GetI1(), m10) + } + checkNameConsistency(t, m10) + m11 := opb.M11_builder{ + Foo: makeOpaqueM0(1), + SetSetFoo: proto.Int32(2), + }.Build() + m11.SetFoo(makeOpaqueM0(48)) + if m11.GetFoo().GetI1() != 48 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 48)", m11.GetFoo().GetI1(), m11) + } + m11.SetSetSetFoo(11) + if m11.GetSetSetFoo() != 11 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 11)", m11.GetSetSetFoo(), m11) + } + checkNameConsistency(t, m11) + m12 := opb.M12_builder{ + Foo: makeOpaqueM0(1), + SetFoo: proto.Int32(2), + }.Build() + m12.SetFoo(makeOpaqueM0(48)) + if m12.GetFoo().GetI1() != 48 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 48)", m12.GetFoo().GetI1(), m12) + } + m12.SetSetFoo(12) + if m12.GetSetFoo() != 12 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 12)", m12.GetSetFoo(), m12) + } + checkNameConsistency(t, m12) + m13 := opb.M13_builder{ + Foo: makeOpaqueM0(1), + HasFoo: makeOpaqueM0(2), + }.Build() + if !m13.HasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m13.HasFoo(), m13) + } + if !m13.HasHasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m13.HasHasFoo(), m13) + } + checkNameConsistency(t, m13) + m14 := opb.M14_builder{ + Foo: makeOpaqueM0(1), + HasHasFoo: proto.Int32(2), + }.Build() + if !m14.HasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m14.HasFoo(), m14) + } + if !m14.HasHasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m14.HasHasFoo(), m14) + } + if !m14.HasHasHasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m14.HasHasHasFoo(), m14) + } + checkNameConsistency(t, m14) + m15 := opb.M15_builder{ + Foo: makeOpaqueM0(1), + HasFoo: proto.Int32(2), + }.Build() + if !m15.HasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m15.HasFoo(), m15) + } + if !m15.HasHasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m15.HasHasFoo(), m15) + } + if !m15.HasHasHasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m15.HasHasHasFoo(), m15) + } + checkNameConsistency(t, m15) + m16 := opb.M16_builder{ + Foo: makeOpaqueM0(1), + ClearFoo: makeOpaqueM0(2), + }.Build() + m16.SetFoo(makeOpaqueM0(4711)) + m16.ClearFoo() + if m16.HasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m16.HasFoo(), m16) + } + m16.ClearClearFoo() + if m16.HasClearFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m16.HasClearFoo(), m16) + } + checkNameConsistency(t, m16) + m17 := opb.M17_builder{ + Foo: makeOpaqueM0(1), + ClearClearFoo: proto.Int32(2), + }.Build() + m17.SetFoo(makeOpaqueM0(4711)) + m17.ClearFoo() + if m17.HasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m17.HasFoo(), m17) + } + m17.ClearClearClearFoo() + if m17.HasClearClearFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m17.HasClearClearFoo(), m17) + } + checkNameConsistency(t, m17) + m18 := opb.M18_builder{ + Foo: makeOpaqueM0(1), + ClearFoo: proto.Int32(2), + }.Build() + m18.SetFoo(makeOpaqueM0(4711)) + m18.ClearFoo() + if m18.HasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m18.HasFoo(), m18) + } + m18.SetClearFoo(13) + m18.ClearClearFoo() + if m18.HasClearFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m18.HasClearFoo(), m18) + } + checkNameConsistency(t, m18) + m19 := opb.M19_builder{ + Foo: makeOpaqueM0(1), + WhichFoo: proto.Int32(2), + }.Build() + if m19.WhichWhichWhichFoo() != opb.M19_WhichFoo_case { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected M19_ClearFoo_case)", m19.WhichWhichWhichFoo(), m19) + } + checkNameConsistency(t, m19) + m20 := opb.M20_builder{ + Foo: makeOpaqueM0(1), + WhichWhichFoo: proto.Int32(2), + }.Build() + if m20.WhichWhichFoo() != opb.M20_WhichWhichFoo_case { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected M20_WhichWhichFoo_case)", m20.WhichWhichFoo(), m20) + } + checkNameConsistency(t, m20) + +} + +func makeOpenM0(x int32) *pb.M0 { + return &pb.M0{ + I1: x, + } +} + +func makeHybridM0(x int32) *hpb.M0 { + return hpb.M0_builder{ + I1: x, + }.Build() +} + +func makeOpaqueM0(x int32) *opb.M0 { + return opb.M0_builder{ + I1: x, + }.Build() +} diff --git a/cmd/protoc-gen-go/name_clash_test/name_clash_test.go b/cmd/protoc-gen-go/name_clash_test/name_clash_test.go new file mode 100644 index 000000000..6062c6343 --- /dev/null +++ b/cmd/protoc-gen-go/name_clash_test/name_clash_test.go @@ -0,0 +1,898 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package name_clash_test + +import ( + "reflect" + "testing" + + "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/internal/genid" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + descpb "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/gofeaturespb" + "google.golang.org/protobuf/types/pluginpb" + + hpb "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_hybrid" + opb "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_opaque" + pb "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_open" +) + +// TestOpenMangling tests the backwards compatible mangling of fields +// who clashes with the getters. The expected behavior, which is +// somewhat surprising, is documented in the proto +// test_name_clash_open.proto itself. +func TestOpenMangling(t *testing.T) { + m1 := &pb.M1{ + Foo: proto.Int32(1), + GetFoo_: proto.Int32(2), + GetGetFoo: proto.Int32(3), + } + if m1.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m1.GetFoo(), m1) + } + if m1.GetGetFoo_() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m1.GetGetFoo_(), m1) + } + if m1.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m1.GetGetGetFoo(), m1) + } + m2 := &pb.M2{ + Foo: proto.Int32(1), + GetFoo_: proto.Int32(2), + GetGetFoo: proto.Int32(3), + } + if m2.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m2.GetFoo(), m2) + } + if m2.GetGetFoo_() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m2.GetGetFoo_(), m2) + } + if m2.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m2.GetGetGetFoo(), m2) + } + m3 := &pb.M3{ + Foo_: proto.Int32(1), + GetFoo: proto.Int32(2), + GetGetFoo_: proto.Int32(3), + } + if m3.GetFoo_() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m3.GetFoo_(), m3) + } + if m3.GetGetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m3.GetGetFoo(), m3) + } + if m3.GetGetGetFoo_() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m3.GetGetGetFoo_(), m3) + } + + m4 := &pb.M4{ + GetFoo: proto.Int32(2), + GetGetFoo_: &pb.M4_GetGetGetFoo{GetGetGetFoo: 3}, + Foo_: proto.Int32(1), + } + if m4.GetFoo_() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m4.GetFoo_(), m4) + } + if m4.GetGetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m4.GetGetFoo(), m4) + } + if m4.GetGetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m4.GetGetGetGetFoo(), m4) + } + + m5 := &pb.M5{ + GetFoo: proto.Int32(2), + GetGetGetFoo: &pb.M5_GetGetFoo_{GetGetFoo_: 3}, + Foo_: proto.Int32(1), + } + if m5.GetFoo_() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m5.GetFoo_(), m5) + } + if m5.GetGetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m5.GetGetFoo(), m5) + } + if m5.GetGetGetFoo_() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m5.GetGetGetFoo_(), m5) + } + + m6 := &pb.M6{ + GetGetFoo: &pb.M6_GetGetGetFoo{GetGetGetFoo: 3}, + GetFoo_: proto.Int32(2), + Foo: proto.Int32(1), + } + if m6.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m6.GetFoo(), m6) + } + if m6.GetGetFoo_() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m6.GetGetFoo_(), m6) + } + if m6.GetGetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_get_foo' has unexpected value %v for %T (expected 3)", m6.GetGetGetGetFoo(), m6) + } + + m7 := &pb.M7{ + GetGetFoo: &pb.M7_GetFoo_{GetFoo_: 3}, + Foo: proto.Int32(1), + } + if m7.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m7.GetFoo(), m7) + } + if m7.GetGetFoo_() != 3 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m7.GetGetFoo_(), m7) + } + m7.GetGetFoo = &pb.M7_Bar{Bar: true} + if !m7.GetBar() { + t.Errorf("Proto field 'bar' has unexpected value %v for %T (expected 3)", m7.GetBar(), m7) + } + + m8 := &pb.M8{ + GetGetGetFoo_: &pb.M8_GetGetFoo{GetGetFoo: 3}, + GetFoo_: proto.Int32(2), + Foo: proto.Int32(1), + } + if m8.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m8.GetFoo(), m8) + } + if m8.GetGetFoo_() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m8.GetGetFoo_(), m8) + } + if m8.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m8.GetGetGetFoo(), m8) + } + + m9 := &pb.M9{ + GetGetGetFoo_: &pb.M9_GetGetFoo{GetGetFoo: 3}, + Foo: proto.Int32(1), + } + if m9.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m9.GetFoo(), m9) + } + if m9.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m9.GetGetGetFoo(), m9) + } + m9.GetGetGetFoo_ = &pb.M9_GetFoo_{GetFoo_: 2} + if m9.GetGetFoo_() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m9.GetGetFoo_(), m9) + } + +} + +// TestHybridMangling tests the backwards compatible mangling as well +// as new style mangling of fields who clashes with the getters. The +// expected behavior, which is somewhat surprising, is documented in +// the proto test_name_clash_hybrid.proto itself. +func TestHybridMangling(t *testing.T) { + m1 := hpb.M1_builder{ + Foo: proto.Int32(1), + GetFoo: proto.Int32(2), + GetGetFoo: proto.Int32(3), + }.Build() + if m1.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m1.GetFoo(), m1) + } + if m1.Get_Foo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m1.GetFoo(), m1) + } + if m1.GetGetFoo_() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m1.GetGetFoo_(), m1) + } + if m1.Get_GetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m1.GetGetFoo_(), m1) + } + if m1.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m1.GetGetGetFoo(), m1) + } + checkNameConsistency(t, m1) + m2 := hpb.M2_builder{ + Foo: proto.Int32(1), + GetFoo: proto.Int32(2), + GetGetFoo: proto.Int32(3), + }.Build() + if m2.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m2.GetFoo(), m2) + } + if m2.Get_Foo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m2.GetFoo(), m2) + } + if m2.GetGetFoo_() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m2.GetGetFoo_(), m2) + } + if m2.Get_GetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m2.GetGetFoo_(), m2) + } + if m2.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m2.GetGetGetFoo(), m2) + } + checkNameConsistency(t, m2) + m3 := hpb.M3_builder{ + Foo: proto.Int32(1), + GetFoo: proto.Int32(2), + GetGetFoo: proto.Int32(3), + }.Build() + if m3.GetFoo_() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m3.GetFoo_(), m3) + } + if m3.Get_Foo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m3.GetFoo_(), m3) + } + if m3.GetGetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m3.GetGetFoo(), m3) + } + if m3.Get_GetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m3.GetGetFoo(), m3) + } + if m3.GetGetGetFoo_() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m3.GetGetGetFoo_(), m3) + } + checkNameConsistency(t, m3) + + m4 := hpb.M4_builder{ + GetFoo: proto.Int32(2), + GetGetGetFoo: proto.Int32(3), + Foo: proto.Int32(1), + }.Build() + if m4.GetFoo_() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m4.GetFoo_(), m4) + } + if m4.Get_Foo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m4.Get_Foo(), m4) + } + if m4.GetGetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m4.GetGetFoo(), m4) + } + if m4.Get_GetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m4.Get_GetFoo(), m4) + } + if m4.GetGetGetFoo_().(*hpb.M4_GetGetGetFoo).GetGetGetFoo != 3 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 3)", m4.GetGetGetFoo_(), m4) + } + if !m4.HasGetGetFoo() { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected true)", m4.HasGetGetFoo(), m4) + } + if m4.GetGetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_get_foo' has unexpected value %v for %T (expected 3)", m4.GetGetGetGetFoo(), m4) + } + checkNameConsistency(t, m4) + + m5 := hpb.M5_builder{ + GetFoo: proto.Int32(2), + GetGetFoo: proto.Int32(3), + Foo: proto.Int32(1), + }.Build() + if m5.GetFoo_() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m5.GetFoo_(), m5) + } + if m5.Get_Foo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m5.Get_Foo(), m4) + } + if m5.GetGetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m5.GetGetFoo(), m5) + } + if m5.Get_GetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m5.Get_GetFoo(), m4) + } + if m5.GetGetGetFoo_() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m5.GetGetGetFoo_(), m5) + } + if m5.Get_GetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m5.Get_GetGetFoo(), m5) + } + checkNameConsistency(t, m5) + + m6 := hpb.M6_builder{ + GetGetGetFoo: proto.Int32(3), + GetFoo: proto.Int32(2), + Foo: proto.Int32(1), + }.Build() + if m6.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m6.GetFoo(), m6) + } + if m6.Get_Foo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m6.Get_Foo(), m6) + } + if m6.GetGetFoo_() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m6.GetGetFoo_(), m6) + } + if m6.Get_GetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m6.Get_GetFoo(), m6) + } + if m6.GetGetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_get_foo' has unexpected value %v for %T (expected 3)", m6.GetGetGetGetFoo(), m6) + } + checkNameConsistency(t, m6) + + m7 := hpb.M7_builder{ + GetFoo: proto.Int32(3), + Foo: proto.Int32(1), + }.Build() + if m7.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m7.GetFoo(), m7) + } + if m7.Get_Foo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m7.Get_Foo(), m7) + } + if m7.GetGetFoo_() != 3 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 3)", m7.GetGetFoo_(), m7) + } + if m7.Get_GetFoo() != 3 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 3)", m7.Get_GetFoo(), m7) + } + m7.SetBar(true) + if !m7.GetBar() { + t.Errorf("Proto field 'bar' has unexpected value %v for %T (expected 3)", m7.GetBar(), m7) + } + checkNameConsistency(t, m7) + + m8 := hpb.M8_builder{ + GetGetFoo: proto.Int32(3), + GetFoo: proto.Int32(2), + Foo: proto.Int32(1), + }.Build() + if m8.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m8.GetFoo(), m8) + } + if m8.Get_Foo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m8.Get_Foo(), m8) + } + if m8.GetGetFoo_() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m8.GetGetFoo_(), m8) + } + if m8.Get_GetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m8.Get_GetFoo(), m8) + } + if m8.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m8.GetGetGetFoo(), m8) + } + checkNameConsistency(t, m8) + + m9 := hpb.M9_builder{ + GetGetFoo: proto.Int32(3), + Foo: proto.Int32(1), + }.Build() + if m9.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m9.GetFoo(), m9) + } + if m9.Get_Foo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m9.Get_Foo(), m9) + } + if m9.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m9.GetGetGetFoo(), m9) + } + if m9.Get_GetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m9.Get_GetGetFoo(), m9) + } + m9.Set_GetFoo(2) + if m9.GetGetFoo_() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m9.GetGetFoo_(), m9) + } + if m9.Get_GetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m9.Get_GetFoo(), m9) + } + checkNameConsistency(t, m9) + m10 := hpb.M10_builder{ + Foo: proto.Int32(1), + SetFoo: proto.Int32(2), + }.Build() + m10.Set_Foo(47) + if m10.Get_Foo() != 47 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 47)", m10.Get_Foo(), m10) + } + m10.SetSetFoo(11) + if m10.GetSetFoo() != 11 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 11)", m10.GetSetFoo(), m10) + } + checkNameConsistency(t, m10) + m11 := hpb.M11_builder{ + Foo: proto.Int32(1), + SetSetFoo: proto.Int32(2), + }.Build() + m11.Set_Foo(47) + if m11.Get_Foo() != 47 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 47)", m11.Get_Foo(), m11) + } + m11.SetSetSetFoo(11) + if m11.GetSetSetFoo() != 11 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 11)", m11.GetSetSetFoo(), m11) + } + checkNameConsistency(t, m11) + m12 := hpb.M12_builder{ + Foo: proto.Int32(1), + SetFoo: proto.Int32(2), + }.Build() + m12.Set_Foo(47) + if m12.Get_Foo() != 47 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 47)", m12.Get_Foo(), m12) + } + m12.Set_SetFoo(11) + if m12.Get_SetFoo() != 11 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 11)", m12.Get_SetFoo(), m12) + } + checkNameConsistency(t, m12) + m13 := hpb.M13_builder{ + Foo: proto.Int32(1), + HasFoo: proto.Int32(2), + }.Build() + if !m13.Has_Foo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m13.Has_Foo(), m13) + } + if !m13.HasHasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m13.HasHasFoo(), m13) + } + checkNameConsistency(t, m13) + m14 := hpb.M14_builder{ + Foo: proto.Int32(1), + HasHasFoo: proto.Int32(2), + }.Build() + if !m14.Has_Foo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m14.Has_Foo(), m14) + } + if !m14.Has_HasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m14.Has_HasFoo(), m14) + } + if !m14.HasHasHasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m14.HasHasHasFoo(), m14) + } + checkNameConsistency(t, m14) + m15 := hpb.M15_builder{ + Foo: proto.Int32(1), + HasFoo: proto.Int32(2), + }.Build() + if !m15.Has_Foo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m15.Has_Foo(), m15) + } + if !m15.Has_HasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m15.Has_HasFoo(), m15) + } + if !m15.HasHasHasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m15.HasHasHasFoo(), m15) + } + checkNameConsistency(t, m15) + m16 := hpb.M16_builder{ + Foo: proto.Int32(1), + ClearFoo: proto.Int32(2), + }.Build() + m16.Clear_Foo() + if m16.Has_Foo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m16.Has_Foo(), m16) + } + m16.ClearClearFoo() + if m16.HasClearFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m16.HasClearFoo(), m16) + } + checkNameConsistency(t, m16) + m17 := hpb.M17_builder{ + Foo: proto.Int32(1), + ClearClearFoo: proto.Int32(2), + }.Build() + m17.Clear_Foo() + if m17.Has_Foo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m17.Has_Foo(), m17) + } + m17.ClearClearClearFoo() + if m17.HasClearClearFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m17.HasClearClearFoo(), m17) + } + checkNameConsistency(t, m17) + m18 := hpb.M18_builder{ + Foo: proto.Int32(1), + ClearFoo: proto.Int32(2), + }.Build() + m18.Clear_Foo() + if m18.Has_Foo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m18.Has_Foo(), m18) + } + m18.Clear_ClearFoo() + if m18.Has_ClearFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m18.Has_ClearFoo(), m18) + } + checkNameConsistency(t, m18) + m19 := hpb.M19_builder{ + Foo: proto.Int32(1), + WhichFoo: proto.Int32(2), + }.Build() + if m19.WhichWhichWhichFoo() != hpb.M19_WhichFoo_case { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected M19_ClearFoo_case)", m19.WhichWhichWhichFoo(), m19) + } + checkNameConsistency(t, m19) + m20 := hpb.M20_builder{ + Foo: proto.Int32(1), + WhichWhichFoo: proto.Int32(2), + }.Build() + if m20.Which_WhichFoo() != hpb.M20_WhichWhichFoo_case { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected M20_WhichWhichFoo_case)", m20.Which_WhichFoo(), m20) + } + checkNameConsistency(t, m20) + +} + +// TestOpaqueMangling tests the backwards compatible mangling as well +// as new style mangling of fields who clashes with the getters. The +// expected behavior, which is somewhat surprising, is documented in +// the proto test_name_clash_opaque.proto itself. +func TestOpaqueMangling(t *testing.T) { + m1 := opb.M1_builder{ + Foo: proto.Int32(1), + GetFoo: proto.Int32(2), + GetGetFoo: proto.Int32(3), + }.Build() + if m1.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m1.GetFoo(), m1) + } + if m1.GetGetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m1.GetGetFoo(), m1) + } + if m1.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m1.GetGetGetFoo(), m1) + } + checkNameConsistency(t, m1) + m2 := opb.M2_builder{ + Foo: proto.Int32(1), + GetFoo: proto.Int32(2), + GetGetFoo: proto.Int32(3), + }.Build() + if m2.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m2.GetFoo(), m2) + } + if m2.GetGetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m2.GetGetFoo(), m2) + } + if m2.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m2.GetGetGetFoo(), m2) + } + checkNameConsistency(t, m2) + m3 := opb.M3_builder{ + Foo: proto.Int32(1), + GetFoo: proto.Int32(2), + GetGetFoo: proto.Int32(3), + }.Build() + if m3.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m3.GetFoo(), m3) + } + if m3.GetGetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m3.GetGetFoo(), m3) + } + if m3.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m3.GetGetGetFoo(), m3) + } + checkNameConsistency(t, m3) + + m4 := opb.M4_builder{ + GetFoo: proto.Int32(2), + GetGetGetFoo: proto.Int32(3), + Foo: proto.Int32(1), + }.Build() + if m4.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m4.GetFoo(), m4) + } + if m4.GetGetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m4.GetGetFoo(), m4) + } + if !m4.HasGetGetFoo() { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected true)", m4.HasGetGetFoo(), m4) + } + if m4.GetGetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_get_foo' has unexpected value %v for %T (expected 3)", m4.GetGetGetGetFoo(), m4) + } + checkNameConsistency(t, m4) + + m5 := opb.M5_builder{ + GetFoo: proto.Int32(2), + GetGetFoo: proto.Int32(3), + Foo: proto.Int32(1), + }.Build() + if m5.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m5.GetFoo(), m5) + } + if m5.GetGetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m5.GetGetFoo(), m5) + } + if m5.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m5.GetGetGetFoo(), m5) + } + checkNameConsistency(t, m5) + + m6 := opb.M6_builder{ + GetGetGetFoo: proto.Int32(3), + GetFoo: proto.Int32(2), + Foo: proto.Int32(1), + }.Build() + if m6.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m6.GetFoo(), m6) + } + if m6.GetGetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m6.GetGetFoo(), m6) + } + if m6.GetGetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_get_foo' has unexpected value %v for %T (expected 3)", m6.GetGetGetGetFoo(), m6) + } + checkNameConsistency(t, m6) + + m7 := opb.M7_builder{ + GetFoo: proto.Int32(3), + Foo: proto.Int32(1), + }.Build() + if m7.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m7.GetFoo(), m7) + } + if m7.GetGetFoo() != 3 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 3)", m7.GetGetFoo(), m7) + } + m7.SetBar(true) + if !m7.GetBar() { + t.Errorf("Proto field 'bar' has unexpected value %v for %T (expected true)", m7.GetBar(), m7) + } + checkNameConsistency(t, m7) + + m8 := opb.M8_builder{ + GetGetFoo: proto.Int32(3), + GetFoo: proto.Int32(2), + Foo: proto.Int32(1), + }.Build() + if m8.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m8.GetFoo(), m8) + } + if m8.GetGetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m8.GetGetFoo(), m8) + } + if m8.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m8.GetGetGetFoo(), m8) + } + checkNameConsistency(t, m8) + + m9 := opb.M9_builder{ + GetGetFoo: proto.Int32(3), + Foo: proto.Int32(1), + }.Build() + if m9.GetFoo() != 1 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 1)", m9.GetFoo(), m9) + } + if m9.GetGetGetFoo() != 3 { + t.Errorf("Proto field 'get_get_foo' has unexpected value %v for %T (expected 3)", m9.GetGetGetFoo(), m9) + } + m9.SetGetFoo(2) + if m9.GetGetFoo() != 2 { + t.Errorf("Proto field 'get_foo' has unexpected value %v for %T (expected 2)", m9.GetGetFoo(), m9) + } + checkNameConsistency(t, m9) + m10 := opb.M10_builder{ + Foo: proto.Int32(1), + SetFoo: proto.Int32(2), + }.Build() + m10.SetFoo(48) + if m10.GetFoo() != 48 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 48)", m10.GetFoo(), m10) + } + m10.SetSetFoo(11) + if m10.GetSetFoo() != 11 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 11)", m10.GetSetFoo(), m10) + } + checkNameConsistency(t, m10) + m11 := opb.M11_builder{ + Foo: proto.Int32(1), + SetSetFoo: proto.Int32(2), + }.Build() + m11.SetFoo(48) + if m11.GetFoo() != 48 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 48)", m11.GetFoo(), m11) + } + m11.SetSetSetFoo(11) + if m11.GetSetSetFoo() != 11 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 11)", m11.GetSetSetFoo(), m11) + } + checkNameConsistency(t, m11) + m12 := opb.M12_builder{ + Foo: proto.Int32(1), + SetFoo: proto.Int32(2), + }.Build() + m12.SetFoo(48) + if m12.GetFoo() != 48 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 48)", m12.GetFoo(), m12) + } + m12.SetSetFoo(12) + if m12.GetSetFoo() != 12 { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected 12)", m12.GetSetFoo(), m12) + } + checkNameConsistency(t, m12) + m13 := opb.M13_builder{ + Foo: proto.Int32(1), + HasFoo: proto.Int32(2), + }.Build() + if !m13.HasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m13.HasFoo(), m13) + } + if !m13.HasHasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m13.HasHasFoo(), m13) + } + checkNameConsistency(t, m13) + m14 := opb.M14_builder{ + Foo: proto.Int32(1), + HasHasFoo: proto.Int32(2), + }.Build() + if !m14.HasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m14.HasFoo(), m14) + } + if !m14.HasHasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m14.HasHasFoo(), m14) + } + if !m14.HasHasHasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m14.HasHasHasFoo(), m14) + } + checkNameConsistency(t, m14) + m15 := opb.M15_builder{ + Foo: proto.Int32(1), + HasFoo: proto.Int32(2), + }.Build() + if !m15.HasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m15.HasFoo(), m15) + } + if !m15.HasHasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m15.HasHasFoo(), m15) + } + if !m15.HasHasHasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected true)", m15.HasHasHasFoo(), m15) + } + checkNameConsistency(t, m15) + m16 := opb.M16_builder{ + Foo: proto.Int32(1), + ClearFoo: proto.Int32(2), + }.Build() + m16.SetFoo(4711) + m16.ClearFoo() + if m16.HasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m16.HasFoo(), m16) + } + m16.ClearClearFoo() + if m16.HasClearFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m16.HasClearFoo(), m16) + } + checkNameConsistency(t, m16) + m17 := opb.M17_builder{ + Foo: proto.Int32(1), + ClearClearFoo: proto.Int32(2), + }.Build() + m17.SetFoo(4711) + m17.ClearFoo() + if m17.HasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m17.HasFoo(), m17) + } + m17.ClearClearClearFoo() + if m17.HasClearClearFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m17.HasClearClearFoo(), m17) + } + checkNameConsistency(t, m17) + m18 := opb.M18_builder{ + Foo: proto.Int32(1), + ClearFoo: proto.Int32(2), + }.Build() + m18.SetFoo(4711) + m18.ClearFoo() + if m18.HasFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m18.HasFoo(), m18) + } + m18.SetClearFoo(13) + m18.ClearClearFoo() + if m18.HasClearFoo() { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected false)", m18.HasClearFoo(), m18) + } + checkNameConsistency(t, m18) + m19 := opb.M19_builder{ + Foo: proto.Int32(1), + WhichFoo: proto.Int32(2), + }.Build() + if m19.WhichWhichWhichFoo() != opb.M19_WhichFoo_case { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected M19_ClearFoo_case)", m19.WhichWhichWhichFoo(), m19) + } + checkNameConsistency(t, m19) + m20 := opb.M20_builder{ + Foo: proto.Int32(1), + WhichWhichFoo: proto.Int32(2), + }.Build() + if m20.WhichWhichFoo() != opb.M20_WhichWhichFoo_case { + t.Errorf("Proto field 'foo' has unexpected value %v for %T (expected M20_WhichWhichFoo_case)", m20.WhichWhichFoo(), m20) + } + checkNameConsistency(t, m20) + +} + +func protogenFor(t *testing.T, m proto.Message) *protogen.Message { + t.Helper() + + md := m.ProtoReflect().Descriptor() + + // Construct a Protobuf plugin code generation request based on the + // transitive closure of dependencies of message m. + req := &pluginpb.CodeGeneratorRequest{ + ProtoFile: []*descpb.FileDescriptorProto{ + protodesc.ToFileDescriptorProto(descpb.File_google_protobuf_descriptor_proto), + protodesc.ToFileDescriptorProto(gofeaturespb.File_google_protobuf_go_features_proto), + protodesc.ToFileDescriptorProto(md.ParentFile()), + }, + } + plugin, err := protogen.Options{}.New(req) + if err != nil { + t.Fatalf("protogen.Options.New: %v", err) + } + if got, want := len(plugin.Files), len(req.ProtoFile); got != want { + t.Fatalf("protogen returned %d plugin.Files entries, expected %d", got, want) + } + file := plugin.Files[len(plugin.Files)-1] + for _, msg := range file.Messages { + if msg.Desc.FullName() != md.FullName() { + continue + } + return msg + } + t.Fatalf("BUG: message %q not found in protogen response", md.FullName()) + return nil +} + +// checkNameConsistency will go through the fields (deliberately avoiding +// protoreflect; querying protogen instead), and for each field check that one +// of the two naming schemes is used consistently, so that if you find +// e.g. Has_Foo, the setter will be Set_Foo and not SetFoo. It also checks that +// at least one of the naming schemes apply. +func checkNameConsistency(t *testing.T, m proto.Message) { + t.Helper() + // It's wrong to use Go reflection on a Message, but in this + // case, we're specifically looking at the implementation + typ := reflect.TypeOf(m) + // The info we need for one field + type fi struct { + name string + prefixes []string + } + // The method prefixes for different kinds of fields + repeatedPrefixes := []string{"Get", "Set"} + oneofPrefixes := []string{"Has", "Clear", "Which"} + optionalPrefixes := []string{"Get", "Set", "Has", "Clear"} + + fields := []fi{} + msg := protogenFor(t, m) + for _, f := range msg.Fields { + prefixes := optionalPrefixes + if f.Desc.Cardinality() == genid.Field_CARDINALITY_REPEATED_enum_value { + prefixes = repeatedPrefixes + } + fields = append(fields, fi{name: f.GoName, prefixes: prefixes}) + } + for _, o := range msg.Oneofs { + fields = append(fields, fi{name: o.GoName, prefixes: oneofPrefixes}) + } + if len(fields) == 0 { + t.Errorf("Message %v checked for consistency has no fields to check", reflect.TypeOf(m)) + } + // Check method names for all fields + for _, f := range fields { + // Remove trailing underscores added by old name mangling algorithm + for f.name[len(f.name)-1] == '_' { + f.name = f.name[:len(f.name)-1] + } + // Check consistency of either "underscored" methods or "non underscored" + found := "" + for _, infix := range []string{"_", ""} { + for _, prefix := range f.prefixes { + if m, ok := typ.MethodByName(prefix + infix + f.name); ok { + found = m.Name + break + } + } + if found != "" { + for _, prefix := range f.prefixes { + if _, ok := typ.MethodByName(prefix + infix + f.name); !ok { + t.Errorf("Field %s has inconsistent method names - found %s, but not %s", f.name, found, prefix+infix+f.name) + } + } + break + } + } + // If we found neither, something is wrong + if found == "" { + t.Errorf("Field %s has neither plain nor underscored methods.", f.name) + } + + } +} diff --git a/cmd/protoc-gen-go/opaque_default_test/opaque_default_test.go b/cmd/protoc-gen-go/opaque_default_test/opaque_default_test.go new file mode 100644 index 000000000..fb9349bb3 --- /dev/null +++ b/cmd/protoc-gen-go/opaque_default_test/opaque_default_test.go @@ -0,0 +1,21 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package opaque_default_test + +import ( + "testing" + + enumopaquepb "google.golang.org/protobuf/internal/testprotos/enums/enums_opaque" + testopaquepb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_opaque" +) + +// From the spec: "Proto2 enums use the first syntactic entry in the enum +// declaration as the default value where it is otherwise unspecified." +func TestOpaqueEnumDefaults(t *testing.T) { + m := &testopaquepb.RemoteDefault{} + if got, want := m.GetDefault(), enumopaquepb.Enum_DEFAULT; got != want { + t.Errorf("default enum value: got %v, expected %v", got, want) + } +} diff --git a/cmd/protoc-gen-go/opaque_map_test/opaque_map_test.go b/cmd/protoc-gen-go/opaque_map_test/opaque_map_test.go new file mode 100644 index 000000000..a856df2ee --- /dev/null +++ b/cmd/protoc-gen-go/opaque_map_test/opaque_map_test.go @@ -0,0 +1,35 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package opaque_map_test + +import ( + "testing" + + testopaquepb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_opaque" +) + +func TestOpaqueMap(t *testing.T) { + m := &testopaquepb.TestAllTypes{} + + m.SetMapStringString(map[string]string{"one": "eins"}) + if got, want := len(m.GetMapStringString()), 1; got != want { + t.Errorf("after setting map_string_string to a non-empty map: len(m.GetMapStringString()) = %v, want %v", got, want) + } + delete(m.GetMapStringString(), "one") + if got, want := len(m.GetMapStringString()), 0; got != want { + t.Errorf("after removing all elements from m_one: len(m.GetMapStringString()) = %v, want %v", got, want) + } + if got := m.GetMapStringString(); got == nil { + t.Errorf("after removing all elements from m_one: m.GetMapStringString() = nil, want non-nil map") + } + m.GetMapStringString()["two"] = "zwei" + if got, want := len(m.GetMapStringString()), 1; got != want { + t.Errorf("after adding new element to m_one: len(m.GetMapStringString()) = %v, want %v", got, want) + } + m.SetMapStringString(map[string]string{}) + if got := m.GetMapStringString(); got == nil { + t.Errorf("after setting m_one to an empty map: m.GetMapStringString() = nil, want non-nil map") + } +} diff --git a/cmd/protoc-gen-go/testdata/nameclash/nameclash.go b/cmd/protoc-gen-go/testdata/nameclash/nameclash.go new file mode 100644 index 000000000..dc63532aa --- /dev/null +++ b/cmd/protoc-gen-go/testdata/nameclash/nameclash.go @@ -0,0 +1,16 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package nameclash is imported by gen_test.go and itself imports the various +// nameclash proto variants. +package nameclash + +import ( + _ "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_hybrid" + _ "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_hybrid3" + _ "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_opaque" + _ "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_opaque3" + _ "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_open" + _ "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_open3" +) diff --git a/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_hybrid.proto b/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_hybrid.proto new file mode 100644 index 000000000..d15394c22 --- /dev/null +++ b/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_hybrid.proto @@ -0,0 +1,333 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This proto verifies that we keep the name mangling algorithm (which is +// position dependent) intact in the protoc_gen_go generator. The field names +// and the getter names have to be kept intact over time, both in the OPEN and +// in the HYBRID API. How fields are "mangled" is described in a comment per +// field. + +// The order of "evaluation" of fields is important. Fields are evaluated in +// order of appearance, except the oneof union names, that are evaluated after +// their first member. For each field, check if there is a previous field name +// or getter name that clashes with this field or it's getter. In case there is +// a clash, add an _ to the field name and repeat. In the case of oneof's, the +// union will be renamed if it clashes with it's first member, but not if it +// clashes with it's second. + +// This scheme is here for backwards compatibility. +// The type of clashes that can be are the following: +// 1 - My field name clashes with their getter name +// 2 - My getter name clashes with their field name + +edition = "2023"; + +package net.proto2.go.testdata.nameclashhybrid; + +import "google/protobuf/go_features.proto"; + +option go_package = "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_hybrid"; + +option features.(pb.go).api_level = API_HYBRID; + +message M1 { + // Old Scheme: + // initial name in Go | Clashes with field | type | final name + // Foo | - | - | Foo + // GetFoo | foo | 1 | GetFoo_ + // GetGetFoo | - | - | GetGetFoo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | - | - | GetGetGetFoo + int32 foo = 1; + int32 get_foo = 2; + int32 get_get_foo = 3; +} + +message M2 { + // Old Scheme: + // initial name in Go | Clashes with field | type | final name + // GetGetFoo | - | - | GetGetFoo + // GetFoo | get_get_foo | 2 | GetFoo_ + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | - | - | GetGetGetFoo + int32 get_get_foo = 3; + int32 get_foo = 2; + int32 foo = 1; +} + +message M3 { + // Old Scheme: + // initial name in Go | Clashes with field | type | final name + // GetFoo | - | - | GetFoo + // GetGetFoo | get_foo | 1 | GetGetFoo_ + // Foo | get_foo | 2 | Foo_ + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | - | - | GetGetGetFoo + int32 get_foo = 2; + int32 get_get_foo = 3; + int32 foo = 1; +} + +message M4 { + // Old Scheme: + // initial name in Go | Clashes with field | type | final name + // GetFoo | - | - | GetFoo + // GetGetFoo | get_foo | 1 | GetGetFoo_ + // GetGetGetFoo | - | - | GetGetGetFoo + // | | | + // Foo | get_foo | 2 | Foo_ + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetGetFoo | - | - | GetGetGetGetFoo + int32 get_foo = 2; + oneof get_get_foo { + int32 get_get_get_foo = 3; + } + int32 foo = 1; +} + +message M5 { + // Old Scheme: + // Note evaluation order - get_get_foo before get_get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetFoo | - | - | GetFoo + // GetGetGetFoo | - | - | GetGetGetFoo + // GetGetFoo | get_foo | 1 | GetGetFoo_ + // | | | + // Foo | get_foo | 2 | Foo_ + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | get_get_get_foo | G | Get_GetGetFoo + int32 get_foo = 2; + oneof get_get_get_foo { + int32 get_get_foo = 3; + } + int32 foo = 1; +} + +message M6 { + // Note evaluation order - get_get_get_foo before get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetFoo | - | - | GetGetFoo + // GetGetGetFoo | - | - | GetGetGetFoo + // | | | + // GetFoo | get_get_foo | 2 | GetFoo_ + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetGetFoo | - | - | GetGetGetGetFoo + oneof get_get_foo { + int32 get_get_get_foo = 3; + } + int32 get_foo = 2; + int32 foo = 1; +} + +message M7 { + // Note evaluation order - bar before get_get_foo, then get_get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetFoo | - | - | GetGetFoo + // Bar | - | - | Bar + // GetFoo | foo | 1 | GetFoo_ + // | | | + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // Bar | - | - | GetBar + oneof get_get_foo { + bool bar = 4; + int32 get_foo = 3; + } + int32 foo = 1; +} + +message M8 { + // Note evaluation order - get_get_foo before get_get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetGetFoo | get_get_foo | 1 | GetGetGetFoo_ + // GetGetFoo | - | - | GetGetFoo + // | | | + // GetFoo | get_get_foo | 2 | GetFoo_ + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | get_get_get_foo | G | Get_GetGetFoo + oneof get_get_get_foo { + int32 get_get_foo = 3; + } + int32 get_foo = 2; + int32 foo = 1; +} + +message M9 { + // Note evaluation order - get_get_foo before get_get_get_foo, then get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetGetFoo | get_get_foo | 1 | GetGetGetFoo_ + // GetGetFoo | - | - | GetGetFoo + // GetFoo | get_get_foo | 2 | GetFoo_ + // | | | + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | get_get_get_foo | G | Get_GetGetFoo + oneof get_get_get_foo { + int32 get_get_foo = 3; + int32 get_foo = 2; + } + int32 foo = 1; +} + +message M10 { + // Set Clashes - no concerns with get-mangling as legacy open struct + // does not have setters except for weak fields: + // initial name in Go | Clashes with field | Setter name + // Foo | set_foo | Set_Foo + // SetFoo | - | SetSetFoo + int32 foo = 1; + int32 set_foo = 2; +} + +message M11 { + // Set Clashes - no concerns with get-mangling as legacy open struct + // does not have setters except for weak fields: + // initial name in Go | Clashes with field | Setter name + // Foo | set_foo | Set_Foo + // SetSetFoo | - | SetSetSetFoo + int32 foo = 1; + oneof set_foo { + int32 set_set_foo = 2; + } +} + +message M12 { + // Set Clashes - no concerns with get-mangling as legacy open struct + // does not have setters except for weak fields: + // initial name in Go | Clashes with field | Setter name + // Foo | set_foo | Set_Foo + // SetFoo | set_set_foo | Set_SetFoo + int32 foo = 1; + oneof set_set_foo { + int32 set_foo = 2; + } +} + +message M13 { + // Has Clashes - no concerns with get-mangling as legacy open struct + // does not have hassers except for weak fields: + // initial name in Go | Clashes with field | Hasser name + // Foo | has_foo | Has_Foo + // HasFoo | - | HasHasFoo + int32 foo = 1; + int32 has_foo = 2; +} + +message M14 { + // Has Clashes - no concerns with get-mangling as legacy open struct + // does not have hassers except for weak fields: + // initial name in Go | Clashes with field | Hasser name + // Foo | has_foo | Has_Foo + // HasFoo | has_has_foo | Has_HasFoo + // HasHasFoo | - | HasHasHasFoo + int32 foo = 1; + oneof has_foo { + int32 has_has_foo = 2; + } +} + +message M15 { + // Has Clashes - no concerns with get-mangling as legacy open struct + // does not have hassers except for weak fields: + // initial name in Go | Clashes with field | Hasser name + // Foo | has_foo | Has_Foo + // HasFoo | has_has_foo | Has_HasFoo + // HasHasFoo | - | HasHasHasFoo + int32 foo = 1; + oneof has_has_foo { + int32 has_foo = 2; + } +} + +message M16 { + // Clear Clashes - no concerns with get-mangling as legacy open + // struct does not have clearers except for weak fields: + // initial name in Go | Clashes with field | Clearer name + // Foo | clear_foo | Clear_Foo + // ClearFoo | - | ClearClearFoo + int32 foo = 1; + int32 clear_foo = 2; +} + +message M17 { + // Clear Clashes - no concerns with get-mangling as legacy open + // struct does not have clearers except for weak fields: + // initial name in Go | Clashes with field | Clearer name + // Foo | clear_foo | Clear_Foo + // ClearFoo | clear_clear_foo | Clear_ClearFoo + // ClearClearFoo | - | ClearClearClearFoo + int32 foo = 1; + oneof clear_foo { + int32 clear_clear_foo = 2; + } +} + +message M18 { + // Clear Clashes - no concerns with get-mangling as legacy open + // struct does not have clearers except for weak fields: + // initial name in Go | Clashes with field | Clearer name + // Foo | clear_foo | Clear_Foo + // ClearFoo | clear_clear_foo | Clear_ClearFoo + // ClearClearFoo | - | ClearClearClearFoo + int32 foo = 1; + oneof clear_clear_foo { + int32 clear_foo = 2; + } +} + +message M19 { + // Which Clashes - no concerns with get-mangling as legacy open + // struct does not have whichers except for weak fields: + // initial name in Go | Clashes with field | Whicher name + // Foo | - | - + // WhichFoo | - | - + // WhichWhichFoo | - | WhichWhichWhichFoo + int32 foo = 1; + oneof which_which_foo { + int32 which_foo = 2; + } +} + +message M20 { + // Which Clashes - no concerns with get-mangling as legacy open + // struct does not have whichers except for weak fields: + // initial name in Go | Clashes with field | Whicher name + // Foo | - | - + // WhichFoo | which_which_foo | Which_WhichFoo + // WhichWhichFoo | - | - + int32 foo = 1; + oneof which_foo { + int32 which_which_foo = 2; + } +} diff --git a/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_hybrid3.proto b/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_hybrid3.proto new file mode 100644 index 000000000..8be92f611 --- /dev/null +++ b/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_hybrid3.proto @@ -0,0 +1,338 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This proto verifies that we keep the name mangling algorithm (which is +// position dependent) intact in the protoc_gen_go generator. The field names +// and the getter names have to be kept intact over time, both in the OPEN and +// in the HYBRID API. How fields are "mangled" is described in a comment per +// field. + +// The order of "evaluation" of fields is important. Fields are evaluated in +// order of appearance, except the oneof union names, that are evaluated after +// their first member. For each field, check if there is a previous field name +// or getter name that clashes with this field or it's getter. In case there is +// a clash, add an _ to the field name and repeat. In the case of oneof's, the +// union will be renamed if it clashes with it's first member, but not if it +// clashes with it's second. + +// This scheme is here for backwards compatibility. +// The type of clashes that can be are the following: +// 1 - My field name clashes with their getter name +// 2 - My getter name clashes with their field name + +edition = "2023"; + +package net.proto2.go.testdata.nameclashhybrid3; + +import "google/protobuf/go_features.proto"; + +option go_package = "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_hybrid3"; + +option features.field_presence = IMPLICIT; +option features.(pb.go).api_level = API_HYBRID; + +message M0 { + int32 i1 = 1; +} + +message M1 { + // Old Scheme: + // initial name in Go | Clashes with field | type | final name + // Foo | - | - | Foo + // GetFoo | foo | 1 | GetFoo_ + // GetGetFoo | - | - | GetGetFoo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | - | - | GetGetGetFoo + M0 foo = 1; + M0 get_foo = 2; + M0 get_get_foo = 3; +} + +message M2 { + // Old Scheme: + // initial name in Go | Clashes with field | type | final name + // GetGetFoo | - | - | GetGetFoo + // GetFoo | get_get_foo | 2 | GetFoo_ + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | - | - | GetGetGetFoo + M0 get_get_foo = 3; + M0 get_foo = 2; + M0 foo = 1; +} + +message M3 { + // Old Scheme: + // initial name in Go | Clashes with field | type | final name + // GetFoo | - | - | GetFoo + // GetGetFoo | get_foo | 1 | GetGetFoo_ + // Foo | get_foo | 2 | Foo_ + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | - | - | GetGetGetFoo + M0 get_foo = 2; + M0 get_get_foo = 3; + M0 foo = 1; +} + +message M4 { + // Old Scheme: + // initial name in Go | Clashes with field | type | final name + // GetFoo | - | - | GetFoo + // GetGetFoo | get_foo | 1 | GetGetFoo_ + // GetGetGetFoo | - | - | GetGetGetFoo + // | | | + // Foo | get_foo | 2 | Foo_ + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetGetFoo | - | - | GetGetGetGetFoo + M0 get_foo = 2; + oneof get_get_foo { + int32 get_get_get_foo = 3; + } + M0 foo = 1; +} + +message M5 { + // Old Scheme: + // Note evaluation order - get_get_foo before get_get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetFoo | - | - | GetFoo + // GetGetGetFoo | - | - | GetGetGetFoo + // GetGetFoo | get_foo | 1 | GetGetFoo_ + // | | | + // Foo | get_foo | 2 | Foo_ + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | get_get_get_foo | G | Get_GetGetFoo + M0 get_foo = 2; + oneof get_get_get_foo { + int32 get_get_foo = 3; + } + M0 foo = 1; +} + +message M6 { + // Note evaluation order - get_get_get_foo before get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetFoo | - | - | GetGetFoo + // GetGetGetFoo | - | - | GetGetGetFoo + // | | | + // GetFoo | get_get_foo | 2 | GetFoo_ + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetGetFoo | - | - | GetGetGetGetFoo + oneof get_get_foo { + int32 get_get_get_foo = 3; + } + M0 get_foo = 2; + M0 foo = 1; +} + +message M7 { + // Note evaluation order - bar before get_get_foo, then get_get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetFoo | - | - | GetGetFoo + // Bar | - | - | Bar + // GetFoo | foo | 1 | GetFoo_ + // | | | + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // Bar | - | - | GetBar + oneof get_get_foo { + bool bar = 4; + int32 get_foo = 3; + } + M0 foo = 1; +} + +message M8 { + // Note evaluation order - get_get_foo before get_get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetGetFoo | get_get_foo | 1 | GetGetGetFoo_ + // GetGetFoo | - | - | GetGetFoo + // | | | + // GetFoo | get_get_foo | 2 | GetFoo_ + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | get_get_get_foo | G | Get_GetGetFoo + oneof get_get_get_foo { + int32 get_get_foo = 3; + } + M0 get_foo = 2; + M0 foo = 1; +} + +message M9 { + // Note evaluation order - get_get_foo before get_get_get_foo, then get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetGetFoo | get_get_foo | 1 | GetGetGetFoo_ + // GetGetFoo | - | - | GetGetFoo + // GetFoo | get_get_foo | 2 | GetFoo_ + // | | | + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | get_get_get_foo | G | Get_GetGetFoo + oneof get_get_get_foo { + int32 get_get_foo = 3; + int32 get_foo = 2; + } + M0 foo = 1; +} + +message M10 { + // Set Clashes - no concerns with get-mangling as legacy open struct + // does not have setters except for weak fields: + // initial name in Go | Clashes with field | Setter name + // Foo | set_foo | Set_Foo + // SetFoo | - | SetSetFoo + M0 foo = 1; + M0 set_foo = 2; +} + +message M11 { + // Set Clashes - no concerns with get-mangling as legacy open struct + // does not have setters except for weak fields: + // initial name in Go | Clashes with field | Setter name + // Foo | set_foo | Set_Foo + // SetSetFoo | - | SetSetSetFoo + M0 foo = 1; + oneof set_foo { + int32 set_set_foo = 2; + } +} + +message M12 { + // Set Clashes - no concerns with get-mangling as legacy open struct + // does not have setters except for weak fields: + // initial name in Go | Clashes with field | Setter name + // Foo | set_foo | Set_Foo + // SetFoo | set_set_foo | Set_SetFoo + M0 foo = 1; + oneof set_set_foo { + int32 set_foo = 2; + } +} + +message M13 { + // Has Clashes - no concerns with get-mangling as legacy open struct + // does not have hassers except for weak fields: + // initial name in Go | Clashes with field | Hasser name + // Foo | has_foo | Has_Foo + // HasFoo | - | HasHasFoo + M0 foo = 1; + M0 has_foo = 2; +} + +message M14 { + // Has Clashes - no concerns with get-mangling as legacy open struct + // does not have hassers except for weak fields: + // initial name in Go | Clashes with field | Hasser name + // Foo | has_foo | Has_Foo + // HasFoo | has_has_foo | Has_HasFoo + // HasHasFoo | - | HasHasHasFoo + M0 foo = 1; + oneof has_foo { + int32 has_has_foo = 2; + } +} + +message M15 { + // Has Clashes - no concerns with get-mangling as legacy open struct + // does not have hassers except for weak fields: + // initial name in Go | Clashes with field | Hasser name + // Foo | has_foo | Has_Foo + // HasFoo | has_has_foo | Has_HasFoo + // HasHasFoo | - | HasHasHasFoo + M0 foo = 1; + oneof has_has_foo { + int32 has_foo = 2; + } +} + +message M16 { + // Clear Clashes - no concerns with get-mangling as legacy open + // struct does not have clearers except for weak fields: + // initial name in Go | Clashes with field | Clearer name + // Foo | clear_foo | Clear_Foo + // ClearFoo | - | ClearClearFoo + M0 foo = 1; + M0 clear_foo = 2; +} + +message M17 { + // Clear Clashes - no concerns with get-mangling as legacy open + // struct does not have clearers except for weak fields: + // initial name in Go | Clashes with field | Clearer name + // Foo | clear_foo | Clear_Foo + // ClearFoo | clear_clear_foo | Clear_ClearFoo + // ClearClearFoo | - | ClearClearClearFoo + M0 foo = 1; + oneof clear_foo { + int32 clear_clear_foo = 2; + } +} + +message M18 { + // Clear Clashes - no concerns with get-mangling as legacy open + // struct does not have clearers except for weak fields: + // initial name in Go | Clashes with field | Clearer name + // Foo | clear_foo | Clear_Foo + // ClearFoo | clear_clear_foo | Clear_ClearFoo + // ClearClearFoo | - | ClearClearClearFoo + M0 foo = 1; + oneof clear_clear_foo { + int32 clear_foo = 2; + } +} + +message M19 { + // Which Clashes - no concerns with get-mangling as legacy open + // struct does not have whichers except for weak fields: + // initial name in Go | Clashes with field | Whicher name + // Foo | - | - + // WhichFoo | - | - + // WhichWhichFoo | - | WhichWhichWhichFoo + M0 foo = 1; + oneof which_which_foo { + int32 which_foo = 2; + } +} + +message M20 { + // Which Clashes - no concerns with get-mangling as legacy open + // struct does not have whichers except for weak fields: + // initial name in Go | Clashes with field | Whicher name + // Foo | - | - + // WhichFoo | which_which_foo | Which_WhichFoo + // WhichWhichFoo | - | - + M0 foo = 1; + oneof which_foo { + int32 which_which_foo = 2; + } +} diff --git a/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_opaque.proto b/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_opaque.proto new file mode 100644 index 000000000..e23674c54 --- /dev/null +++ b/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_opaque.proto @@ -0,0 +1,333 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This proto verifies that we keep the name mangling algorithm (which is +// position dependent) intact in the protoc_gen_go generator. The field names +// and the getter names have to be kept intact over time, both in the OPEN and +// in the HYBRID API. How fields are "mangled" is described in a comment per +// field. + +// The order of "evaluation" of fields is important. Fields are evaluated in +// order of appearance, except the oneof union names, that are evaluated after +// their first member. For each field, check if there is a previous field name +// or getter name that clashes with this field or it's getter. In case there is +// a clash, add an _ to the field name and repeat. In the case of oneof's, the +// union will be renamed if it clashes with it's first member, but not if it +// clashes with it's second. + +// This scheme is here for backwards compatibility. +// The type of clashes that can be are the following: +// 1 - My field name clashes with their getter name +// 2 - My getter name clashes with their field name + +edition = "2023"; + +package net.proto2.go.testdata.nameclashopaque; + +import "google/protobuf/go_features.proto"; + +option go_package = "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_opaque"; + +option features.(pb.go).api_level = API_OPAQUE; + +message M1 { + // Old Scheme: + // initial name in Go | Clashes with field | type | final name + // Foo | - | - | Foo + // GetFoo | foo | 1 | GetFoo_ + // GetGetFoo | - | - | GetGetFoo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | - | - | GetGetGetFoo + int32 foo = 1; + int32 get_foo = 2; + int32 get_get_foo = 3; +} + +message M2 { + // Old Scheme: + // initial name in Go | Clashes with field | type | final name + // GetGetFoo | - | - | GetGetFoo + // GetFoo | get_get_foo | 2 | GetFoo_ + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | - | - | GetGetGetFoo + int32 get_get_foo = 3; + int32 get_foo = 2; + int32 foo = 1; +} + +message M3 { + // Old Scheme: + // initial name in Go | Clashes with field | type | final name + // GetFoo | - | - | GetFoo + // GetGetFoo | get_foo | 1 | GetGetFoo_ + // Foo | get_foo | 2 | Foo_ + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | - | - | GetGetGetFoo + int32 get_foo = 2; + int32 get_get_foo = 3; + int32 foo = 1; +} + +message M4 { + // Old Scheme: + // initial name in Go | Clashes with field | type | final name + // GetFoo | - | - | GetFoo + // GetGetFoo | get_foo | 1 | GetGetFoo_ + // GetGetGetFoo | - | - | GetGetGetFoo + // | | | + // Foo | get_foo | 2 | Foo_ + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetGetFoo | - | - | GetGetGetGetFoo + int32 get_foo = 2; + oneof get_get_foo { + int32 get_get_get_foo = 3; + } + int32 foo = 1; +} + +message M5 { + // Old Scheme: + // Note evaluation order - get_get_foo before get_get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetFoo | - | - | GetFoo + // GetGetGetFoo | - | - | GetGetGetFoo + // GetGetFoo | get_foo | 1 | GetGetFoo_ + // | | | + // Foo | get_foo | 2 | Foo_ + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | get_get_get_foo | G | Get_GetGetFoo + int32 get_foo = 2; + oneof get_get_get_foo { + int32 get_get_foo = 3; + } + int32 foo = 1; +} + +message M6 { + // Note evaluation order - get_get_get_foo before get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetFoo | - | - | GetGetFoo + // GetGetGetFoo | - | - | GetGetGetFoo + // | | | + // GetFoo | get_get_foo | 2 | GetFoo_ + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetGetFoo | - | - | GetGetGetGetFoo + oneof get_get_foo { + int32 get_get_get_foo = 3; + } + int32 get_foo = 2; + int32 foo = 1; +} + +message M7 { + // Note evaluation order - bar before get_get_foo, then get_get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetFoo | - | - | GetGetFoo + // Bar | - | - | Bar + // GetFoo | foo | 1 | GetFoo_ + // | | | + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // Bar | - | - | GetBar + oneof get_get_foo { + bool bar = 4; + int32 get_foo = 3; + } + int32 foo = 1; +} + +message M8 { + // Note evaluation order - get_get_foo before get_get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetGetFoo | get_get_foo | 1 | GetGetGetFoo_ + // GetGetFoo | - | - | GetGetFoo + // | | | + // GetFoo | get_get_foo | 2 | GetFoo_ + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | get_get_get_foo | G | Get_GetGetFoo + oneof get_get_get_foo { + int32 get_get_foo = 3; + } + int32 get_foo = 2; + int32 foo = 1; +} + +message M9 { + // Note evaluation order - get_get_foo before get_get_get_foo, then get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetGetFoo | get_get_foo | 1 | GetGetGetFoo_ + // GetGetFoo | - | - | GetGetFoo + // GetFoo | get_get_foo | 2 | GetFoo_ + // | | | + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | get_get_get_foo | G | Get_GetGetFoo + oneof get_get_get_foo { + int32 get_get_foo = 3; + int32 get_foo = 2; + } + int32 foo = 1; +} + +message M10 { + // Set Clashes - no concerns with get-mangling as legacy open struct + // does not have setters except for weak fields: + // initial name in Go | Clashes with field | Setter name + // Foo | set_foo | Set_Foo + // SetFoo | - | SetSetFoo + int32 foo = 1; + int32 set_foo = 2; +} + +message M11 { + // Set Clashes - no concerns with get-mangling as legacy open struct + // does not have setters except for weak fields: + // initial name in Go | Clashes with field | Setter name + // Foo | set_foo | Set_Foo + // SetSetFoo | - | SetSetSetFoo + int32 foo = 1; + oneof set_foo { + int32 set_set_foo = 2; + } +} + +message M12 { + // Set Clashes - no concerns with get-mangling as legacy open struct + // does not have setters except for weak fields: + // initial name in Go | Clashes with field | Setter name + // Foo | set_foo | Set_Foo + // SetFoo | set_set_foo | Set_SetFoo + int32 foo = 1; + oneof set_set_foo { + int32 set_foo = 2; + } +} + +message M13 { + // Has Clashes - no concerns with get-mangling as legacy open struct + // does not have hassers except for weak fields: + // initial name in Go | Clashes with field | Hasser name + // Foo | has_foo | Has_Foo + // HasFoo | - | HasHasFoo + int32 foo = 1; + int32 has_foo = 2; +} + +message M14 { + // Has Clashes - no concerns with get-mangling as legacy open struct + // does not have hassers except for weak fields: + // initial name in Go | Clashes with field | Hasser name + // Foo | has_foo | Has_Foo + // HasFoo | has_has_foo | Has_HasFoo + // HasHasFoo | - | HasHasHasFoo + int32 foo = 1; + oneof has_foo { + int32 has_has_foo = 2; + } +} + +message M15 { + // Has Clashes - no concerns with get-mangling as legacy open struct + // does not have hassers except for weak fields: + // initial name in Go | Clashes with field | Hasser name + // Foo | has_foo | Has_Foo + // HasFoo | has_has_foo | Has_HasFoo + // HasHasFoo | - | HasHasHasFoo + int32 foo = 1; + oneof has_has_foo { + int32 has_foo = 2; + } +} + +message M16 { + // Clear Clashes - no concerns with get-mangling as legacy open + // struct does not have clearers except for weak fields: + // initial name in Go | Clashes with field | Clearer name + // Foo | clear_foo | Clear_Foo + // ClearFoo | - | ClearClearFoo + int32 foo = 1; + int32 clear_foo = 2; +} + +message M17 { + // Clear Clashes - no concerns with get-mangling as legacy open + // struct does not have clearers except for weak fields: + // initial name in Go | Clashes with field | Clearer name + // Foo | clear_foo | Clear_Foo + // ClearFoo | clear_clear_foo | Clear_ClearFoo + // ClearClearFoo | - | ClearClearClearFoo + int32 foo = 1; + oneof clear_foo { + int32 clear_clear_foo = 2; + } +} + +message M18 { + // Clear Clashes - no concerns with get-mangling as legacy open + // struct does not have clearers except for weak fields: + // initial name in Go | Clashes with field | Clearer name + // Foo | clear_foo | Clear_Foo + // ClearFoo | clear_clear_foo | Clear_ClearFoo + // ClearClearFoo | - | ClearClearClearFoo + int32 foo = 1; + oneof clear_clear_foo { + int32 clear_foo = 2; + } +} + +message M19 { + // Which Clashes - no concerns with get-mangling as legacy open + // struct does not have whichers except for weak fields: + // initial name in Go | Clashes with field | Whicher name + // Foo | - | - + // WhichFoo | - | - + // WhichWhichFoo | - | WhichWhichWhichFoo + int32 foo = 1; + oneof which_which_foo { + int32 which_foo = 2; + } +} + +message M20 { + // Which Clashes - no concerns with get-mangling as legacy open + // struct does not have whichers except for weak fields: + // initial name in Go | Clashes with field | Whicher name + // Foo | - | - + // WhichFoo | which_which_foo | Which_WhichFoo + // WhichWhichFoo | - | - + int32 foo = 1; + oneof which_foo { + int32 which_which_foo = 2; + } +} diff --git a/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_opaque3.proto b/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_opaque3.proto new file mode 100644 index 000000000..fdafeee88 --- /dev/null +++ b/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_opaque3.proto @@ -0,0 +1,338 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This proto verifies that we keep the name mangling algorithm (which is +// position dependent) intact in the protoc_gen_go generator. The field names +// and the getter names have to be kept intact over time, both in the OPEN and +// in the HYBRID API. How fields are "mangled" is described in a comment per +// field. + +// The order of "evaluation" of fields is important. Fields are evaluated in +// order of appearance, except the oneof union names, that are evaluated after +// their first member. For each field, check if there is a previous field name +// or getter name that clashes with this field or it's getter. In case there is +// a clash, add an _ to the field name and repeat. In the case of oneof's, the +// union will be renamed if it clashes with it's first member, but not if it +// clashes with it's second. + +// This scheme is here for backwards compatibility. +// The type of clashes that can be are the following: +// 1 - My field name clashes with their getter name +// 2 - My getter name clashes with their field name + +edition = "2023"; + +package net.proto2.go.testdata.nameclashopaque3; + +import "google/protobuf/go_features.proto"; + +option go_package = "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_opaque3"; + +option features.field_presence = IMPLICIT; +option features.(pb.go).api_level = API_OPAQUE; + +message M0 { + int32 i1 = 1; +} + +message M1 { + // Old Scheme: + // initial name in Go | Clashes with field | type | final name + // Foo | - | - | Foo + // GetFoo | foo | 1 | GetFoo_ + // GetGetFoo | - | - | GetGetFoo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | - | - | GetGetGetFoo + M0 foo = 1; + M0 get_foo = 2; + M0 get_get_foo = 3; +} + +message M2 { + // Old Scheme: + // initial name in Go | Clashes with field | type | final name + // GetGetFoo | - | - | GetGetFoo + // GetFoo | get_get_foo | 2 | GetFoo_ + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | - | - | GetGetGetFoo + M0 get_get_foo = 3; + M0 get_foo = 2; + M0 foo = 1; +} + +message M3 { + // Old Scheme: + // initial name in Go | Clashes with field | type | final name + // GetFoo | - | - | GetFoo + // GetGetFoo | get_foo | 1 | GetGetFoo_ + // Foo | get_foo | 2 | Foo_ + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | - | - | GetGetGetFoo + M0 get_foo = 2; + M0 get_get_foo = 3; + M0 foo = 1; +} + +message M4 { + // Old Scheme: + // initial name in Go | Clashes with field | type | final name + // GetFoo | - | - | GetFoo + // GetGetFoo | get_foo | 1 | GetGetFoo_ + // GetGetGetFoo | - | - | GetGetGetFoo + // | | | + // Foo | get_foo | 2 | Foo_ + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetGetFoo | - | - | GetGetGetGetFoo + M0 get_foo = 2; + oneof get_get_foo { + int32 get_get_get_foo = 3; + } + M0 foo = 1; +} + +message M5 { + // Old Scheme: + // Note evaluation order - get_get_foo before get_get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetFoo | - | - | GetFoo + // GetGetGetFoo | - | - | GetGetGetFoo + // GetGetFoo | get_foo | 1 | GetGetFoo_ + // | | | + // Foo | get_foo | 2 | Foo_ + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | get_get_get_foo | G | Get_GetGetFoo + M0 get_foo = 2; + oneof get_get_get_foo { + int32 get_get_foo = 3; + } + M0 foo = 1; +} + +message M6 { + // Note evaluation order - get_get_get_foo before get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetFoo | - | - | GetGetFoo + // GetGetGetFoo | - | - | GetGetGetFoo + // | | | + // GetFoo | get_get_foo | 2 | GetFoo_ + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetGetFoo | - | - | GetGetGetGetFoo + oneof get_get_foo { + int32 get_get_get_foo = 3; + } + M0 get_foo = 2; + M0 foo = 1; +} + +message M7 { + // Note evaluation order - bar before get_get_foo, then get_get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetFoo | - | - | GetGetFoo + // Bar | - | - | Bar + // GetFoo | foo | 1 | GetFoo_ + // | | | + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // Bar | - | - | GetBar + oneof get_get_foo { + bool bar = 4; + int32 get_foo = 3; + } + M0 foo = 1; +} + +message M8 { + // Note evaluation order - get_get_foo before get_get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetGetFoo | get_get_foo | 1 | GetGetGetFoo_ + // GetGetFoo | - | - | GetGetFoo + // | | | + // GetFoo | get_get_foo | 2 | GetFoo_ + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | get_get_get_foo | G | Get_GetGetFoo + oneof get_get_get_foo { + int32 get_get_foo = 3; + } + M0 get_foo = 2; + M0 foo = 1; +} + +message M9 { + // Note evaluation order - get_get_foo before get_get_get_foo, then get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetGetFoo | get_get_foo | 1 | GetGetGetFoo_ + // GetGetFoo | - | - | GetGetFoo + // GetFoo | get_get_foo | 2 | GetFoo_ + // | | | + // Foo | - | - | Foo + // New Scheme: + // initial name in Go | Clashes with field | type | Getter name + // Foo | get_foo | G | Get_Foo + // GetFoo | get_get_foo | G | Get_GetFoo + // GetGetFoo | get_get_get_foo | G | Get_GetGetFoo + oneof get_get_get_foo { + int32 get_get_foo = 3; + int32 get_foo = 2; + } + M0 foo = 1; +} + +message M10 { + // Set Clashes - no concerns with get-mangling as legacy open struct + // does not have setters except for weak fields: + // initial name in Go | Clashes with field | Setter name + // Foo | set_foo | Set_Foo + // SetFoo | - | SetSetFoo + M0 foo = 1; + M0 set_foo = 2; +} + +message M11 { + // Set Clashes - no concerns with get-mangling as legacy open struct + // does not have setters except for weak fields: + // initial name in Go | Clashes with field | Setter name + // Foo | set_foo | Set_Foo + // SetSetFoo | - | SetSetSetFoo + M0 foo = 1; + oneof set_foo { + int32 set_set_foo = 2; + } +} + +message M12 { + // Set Clashes - no concerns with get-mangling as legacy open struct + // does not have setters except for weak fields: + // initial name in Go | Clashes with field | Setter name + // Foo | set_foo | Set_Foo + // SetFoo | set_set_foo | Set_SetFoo + M0 foo = 1; + oneof set_set_foo { + int32 set_foo = 2; + } +} + +message M13 { + // Has Clashes - no concerns with get-mangling as legacy open struct + // does not have hassers except for weak fields: + // initial name in Go | Clashes with field | Hasser name + // Foo | has_foo | Has_Foo + // HasFoo | - | HasHasFoo + M0 foo = 1; + M0 has_foo = 2; +} + +message M14 { + // Has Clashes - no concerns with get-mangling as legacy open struct + // does not have hassers except for weak fields: + // initial name in Go | Clashes with field | Hasser name + // Foo | has_foo | Has_Foo + // HasFoo | has_has_foo | Has_HasFoo + // HasHasFoo | - | HasHasHasFoo + M0 foo = 1; + oneof has_foo { + int32 has_has_foo = 2; + } +} + +message M15 { + // Has Clashes - no concerns with get-mangling as legacy open struct + // does not have hassers except for weak fields: + // initial name in Go | Clashes with field | Hasser name + // Foo | has_foo | Has_Foo + // HasFoo | has_has_foo | Has_HasFoo + // HasHasFoo | - | HasHasHasFoo + M0 foo = 1; + oneof has_has_foo { + int32 has_foo = 2; + } +} + +message M16 { + // Clear Clashes - no concerns with get-mangling as legacy open + // struct does not have clearers except for weak fields: + // initial name in Go | Clashes with field | Clearer name + // Foo | clear_foo | Clear_Foo + // ClearFoo | - | ClearClearFoo + M0 foo = 1; + M0 clear_foo = 2; +} + +message M17 { + // Clear Clashes - no concerns with get-mangling as legacy open + // struct does not have clearers except for weak fields: + // initial name in Go | Clashes with field | Clearer name + // Foo | clear_foo | Clear_Foo + // ClearFoo | clear_clear_foo | Clear_ClearFoo + // ClearClearFoo | - | ClearClearClearFoo + M0 foo = 1; + oneof clear_foo { + int32 clear_clear_foo = 2; + } +} + +message M18 { + // Clear Clashes - no concerns with get-mangling as legacy open + // struct does not have clearers except for weak fields: + // initial name in Go | Clashes with field | Clearer name + // Foo | clear_foo | Clear_Foo + // ClearFoo | clear_clear_foo | Clear_ClearFoo + // ClearClearFoo | - | ClearClearClearFoo + M0 foo = 1; + oneof clear_clear_foo { + int32 clear_foo = 2; + } +} + +message M19 { + // Which Clashes - no concerns with get-mangling as legacy open + // struct does not have whichers except for weak fields: + // initial name in Go | Clashes with field | Whicher name + // Foo | - | - + // WhichFoo | - | - + // WhichWhichFoo | - | WhichWhichWhichFoo + M0 foo = 1; + oneof which_which_foo { + int32 which_foo = 2; + } +} + +message M20 { + // Which Clashes - no concerns with get-mangling as legacy open + // struct does not have whichers except for weak fields: + // initial name in Go | Clashes with field | Whicher name + // Foo | - | - + // WhichFoo | which_which_foo | Which_WhichFoo + // WhichWhichFoo | - | - + M0 foo = 1; + oneof which_foo { + int32 which_which_foo = 2; + } +} diff --git a/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_open.proto b/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_open.proto new file mode 100644 index 000000000..10775a979 --- /dev/null +++ b/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_open.proto @@ -0,0 +1,151 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This proto verifies that we keep the name mangling algorithm (which is +// position dependent) intact in the protoc_gen_go generator. The field names +// and the getter names have to be kept intact over time, both in the OPEN and +// in the HYBRID API. How fields are "mangled" is described in a comment per +// field. + +// The order of "evaluation" of fields is important. Fields are evaluated in +// order of appearance, except the oneof union names, that are evaluated after +// their first member. For each field, check if there is a previous field name +// or getter name that clashes with this field or it's getter. In case there is +// a clash, add an _ to the field name and repeat. In the case of oneof's, the +// union will be renamed if it clashes with it's first member, but not if it +// clashes with it's second. + +// This scheme is here for backwards compatibility. +// The type of clashes that can be are the following: +// 1 - My field name clashes with their getter name +// 2 - My getter name clashes with their field name + +edition = "2023"; + +package net.proto2.go.testdata.nameclashopen; + +import "google/protobuf/go_features.proto"; + +option go_package = "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_open"; + +option features.(pb.go).api_level = API_OPEN; + +message M1 { + // initial name in Go | Clashes with field | type | final name + // Foo | - | - | Foo + // GetFoo | foo | 1 | GetFoo_ + // GetGetFoo | - | - | GetGetFoo + int32 foo = 1; + int32 get_foo = 2; + int32 get_get_foo = 3; +} + +message M2 { + // initial name in Go | Clashes with field | type | final name + // GetGetFoo | - | - | GetGetFoo + // GetFoo | get_get_foo | 2 | GetFoo_ + // Foo | - | - | Foo + int32 get_get_foo = 3; + int32 get_foo = 2; + int32 foo = 1; +} + +message M3 { + // initial name in Go | Clashes with field | type | final name + // GetFoo | - | - | GetFoo + // GetGetFoo | get_foo | 1 | GetGetFoo_ + // Foo | get_foo | 2 | Foo_ + int32 get_foo = 2; + int32 get_get_foo = 3; + int32 foo = 1; +} + +message M4 { + // initial name in Go | Clashes with field | type | final name + // GetFoo | - | - | GetFoo + // GetGetFoo | get_foo | 1 | GetGetFoo_ + // GetGetGetFoo | - | - | GetGetGetFoo + // | | | + // Foo | get_foo | 2 | Foo_ + int32 get_foo = 2; + oneof get_get_foo { + int32 get_get_get_foo = 3; + } + int32 foo = 1; +} + +message M5 { + // Note evaluation order - get_get_foo before get_get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetFoo | - | - | GetFoo + // GetGetGetFoo | - | - | GetGetGetFoo + // GetGetFoo | get_foo | 1 | GetGetFoo_ + // | | | + // Foo | get_foo | 2 | Foo_ + int32 get_foo = 2; + oneof get_get_get_foo { + int32 get_get_foo = 3; + } + int32 foo = 1; +} + +message M6 { + // Note evaluation order - get_get_get_foo before get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetFoo | - | - | GetGetFoo + // GetGetGetFoo | - | - | GetGetGetFoo + // | | | + // GetFoo | get_get_foo | 2 | GetFoo_ + // Foo | - | - | Foo + oneof get_get_foo { + int32 get_get_get_foo = 3; + } + int32 get_foo = 2; + int32 foo = 1; +} + +message M7 { + // Note evaluation order - bar before get_get_foo, then get_get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetFoo | - | - | GetGetFoo + // Bar | - | - | Bar + // GetFoo | foo | 1 | GetFoo_ + // | | | + // Foo | - | - | Foo + oneof get_get_foo { + bool bar = 4; + int32 get_foo = 3; + } + int32 foo = 1; +} + +message M8 { + // Note evaluation order - get_get_foo before get_get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetGetFoo | get_get_foo | 1 | GetGetGetFoo_ + // GetGetFoo | - | - | GetGetFoo + // | | | + // GetFoo | get_get_foo | 2 | GetFoo_ + // Foo | - | - | Foo + oneof get_get_get_foo { + int32 get_get_foo = 3; + } + int32 get_foo = 2; + int32 foo = 1; +} + +message M9 { + // Note evaluation order - get_get_foo before get_get_get_foo, then get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetGetFoo | get_get_foo | 1 | GetGetGetFoo_ + // GetGetFoo | - | - | GetGetFoo + // GetFoo | get_get_foo | 2 | GetFoo_ + // | | | + // Foo | - | - | Foo + oneof get_get_get_foo { + int32 get_get_foo = 3; + int32 get_foo = 2; + } + int32 foo = 1; +} diff --git a/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_open3.proto b/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_open3.proto new file mode 100644 index 000000000..442c046f4 --- /dev/null +++ b/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_open3.proto @@ -0,0 +1,164 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This proto verifies that we keep the name mangling algorithm (which is +// position dependent) intact in the protoc_gen_go generator. The field names +// and the getter names have to be kept intact over time, both in the OPEN and +// in the HYBRID API. How fields are "mangled" is described in a comment per +// field. + +// The order of "evaluation" of fields is important. Fields are evaluated in +// order of appearance, except the oneof union names, that are evaluated after +// their first member. For each field, check if there is a previous field name +// or getter name that clashes with this field or it's getter. In case there is +// a clash, add an _ to the field name and repeat. In the case of oneof's, the +// union will be renamed if it clashes with it's first member, but not if it +// clashes with it's second. + +// This scheme is here for backwards compatibility. +// The type of clashes that can be are the following: +// 1 - My field name clashes with their getter name +// 2 - My getter name clashes with their field name + +edition = "2023"; + +package net.proto2.go.testdata.nameclashopen3; + +import "google/protobuf/go_features.proto"; + +option go_package = "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/nameclash/test_name_clash_open3"; + +option features.field_presence = IMPLICIT; +option features.(pb.go).api_level = API_OPEN; + +message M0 { + int32 i1 = 1; +} + +message M1 { + // initial name in Go | Clashes with field | type | final name + // Foo | - | - | Foo + // GetFoo | foo | 1 | GetFoo_ + // GetGetFoo | - | - | GetGetFoo + M0 foo = 1; + M0 get_foo = 2; + M0 get_get_foo = 3; +} + +message M2 { + // initial name in Go | Clashes with field | type | final name + // GetGetFoo | - | - | GetGetFoo + // GetFoo | get_get_foo | 2 | GetFoo_ + // Foo | - | - | Foo + M0 get_get_foo = 3; + M0 get_foo = 2; + M0 foo = 1; +} + +message M3 { + // initial name in Go | Clashes with field | type | final name + // GetFoo | - | - | GetFoo + // GetGetFoo | get_foo | 1 | GetGetFoo_ + // Foo | get_foo | 2 | Foo_ + M0 get_foo = 2; + M0 get_get_foo = 3; + M0 foo = 1; +} + +message M4 { + // initial name in Go | Clashes with field | type | final name + // GetFoo | - | - | GetFoo + // GetGetFoo | get_foo | 1 | GetGetFoo_ + // GetGetGetFoo | - | - | GetGetGetFoo + // | | | + // Foo | get_foo | 2 | Foo_ + M0 get_foo = 2; + + oneof get_get_foo { + int32 get_get_get_foo = 3; + } + + M0 foo = 1; +} + +message M5 { + // Note evaluation order - get_get_foo before get_get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetFoo | - | - | GetFoo + // GetGetGetFoo | - | - | GetGetGetFoo + // GetGetFoo | get_foo | 1 | GetGetFoo_ + // | | | + // Foo | get_foo | 2 | Foo_ + M0 get_foo = 2; + + oneof get_get_get_foo { + int32 get_get_foo = 3; + } + + M0 foo = 1; +} + +message M6 { + // Note evaluation order - get_get_get_foo before get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetFoo | - | - | GetGetFoo + // GetGetGetFoo | - | - | GetGetGetFoo + // | | | + // GetFoo | get_get_foo | 2 | GetFoo_ + // Foo | - | - | Foo + oneof get_get_foo { + int32 get_get_get_foo = 3; + } + + M0 get_foo = 2; + M0 foo = 1; +} + +message M7 { + // Note evaluation order - bar before get_get_foo, then get_get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetFoo | - | - | GetGetFoo + // Bar | - | - | Bar + // GetFoo | foo | 1 | GetFoo_ + // | | | + // Foo | - | - | Foo + oneof get_get_foo { + bool bar = 4; + int32 get_foo = 3; + } + + M0 foo = 1; +} + +message M8 { + // Note evaluation order - get_get_foo before get_get_get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetGetFoo | get_get_foo | 1 | GetGetGetFoo_ + // GetGetFoo | - | - | GetGetFoo + // | | | + // GetFoo | get_get_foo | 2 | GetFoo_ + // Foo | - | - | Foo + oneof get_get_get_foo { + int32 get_get_foo = 3; + } + + M0 get_foo = 2; + M0 foo = 1; +} + +message M9 { + // Note evaluation order - get_get_foo before get_get_get_foo, then get_foo + // initial name in Go | Clashes with field | type | final name + // GetGetGetFoo | get_get_foo | 1 | GetGetGetFoo_ + // GetGetFoo | - | - | GetGetFoo + // GetFoo | get_get_foo | 2 | GetFoo_ + // | | | + // Foo | - | - | Foo + oneof get_get_get_foo { + int32 get_get_foo = 3; + int32 get_foo = 2; + } + + M0 foo = 1; +} diff --git a/compiler/protogen/protogen.go b/compiler/protogen/protogen.go index c2c9e9da7..0bff637c6 100644 --- a/compiler/protogen/protogen.go +++ b/compiler/protogen/protogen.go @@ -35,9 +35,9 @@ import ( "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/dynamicpb" "google.golang.org/protobuf/types/descriptorpb" - "google.golang.org/protobuf/types/dynamicpb" "google.golang.org/protobuf/types/gofeaturespb" "google.golang.org/protobuf/types/pluginpb" ) @@ -166,6 +166,10 @@ type Options struct { // This struct field is for internal use by Go Protobuf only. Do not use it, // we might remove it at any time. InternalStripForEditionsDiff *bool + + // DefaultAPILevel overrides which API to generate by default (despite what + // the editions feature default specifies). One of OPEN, HYBRID or OPAQUE. + DefaultAPILevel gofeaturespb.GoFeatures_APILevel } // New returns a new Plugin. @@ -250,9 +254,9 @@ func (opts Options) New(req *pluginpb.CodeGeneratorRequest) (*Plugin, error) { // Alternatively, build systems which want to exert full control over // import paths may specify M= flags. for _, fdesc := range gen.Request.ProtoFile { + filename := fdesc.GetName() // The "M" command-line flags take precedence over // the "go_package" option in the .proto source file. - filename := fdesc.GetName() impPath, pkgName := splitImportPathAndPackageName(fdesc.GetOptions().GetGoPackage()) if importPaths[filename] == "" && impPath != "" { importPaths[filename] = impPath @@ -460,6 +464,9 @@ type File struct { GeneratedFilenamePrefix string location Location + + // APILevel specifies which API to generate. One of OPEN, HYBRID or OPAQUE. + APILevel gofeaturespb.GoFeatures_APILevel } func newFile(gen *Plugin, p *descriptorpb.FileDescriptorProto, packageName GoPackageName, importPath GoImportPath) (*File, error) { @@ -476,6 +483,8 @@ func newFile(gen *Plugin, p *descriptorpb.FileDescriptorProto, packageName GoPac GoPackageName: packageName, GoImportPath: importPath, location: Location{SourceFile: desc.Path()}, + + APILevel: fileAPILevel(desc, gen.defaultAPILevel()), } // Determine the prefix for generated Go files. @@ -655,6 +664,9 @@ type Message struct { Location Location // location of this message Comments CommentSet // comments associated with this message + + // APILevel specifies which API to generate. One of OPEN, HYBRID or OPAQUE. + APILevel gofeaturespb.GoFeatures_APILevel } func newMessage(gen *Plugin, f *File, parent *Message, desc protoreflect.MessageDescriptor) *Message { @@ -664,11 +676,20 @@ func newMessage(gen *Plugin, f *File, parent *Message, desc protoreflect.Message } else { loc = f.location.appendPath(genid.FileDescriptorProto_MessageType_field_number, desc.Index()) } + + def := f.APILevel + if parent != nil { + // editions feature semantics: applies to nested messages. + def = parent.APILevel + } + message := &Message{ Desc: desc, GoIdent: newGoIdent(f, desc), Location: loc, Comments: makeCommentSet(gen, f.Desc.SourceLocations().ByDescriptor(desc)), + + APILevel: messageAPILevel(desc, def), } gen.messagesByName[desc.FullName()] = message for i, eds := 0, desc.Enums(); i < eds.Len(); i++ { @@ -766,6 +787,8 @@ func newMessage(gen *Plugin, f *File, parent *Message, desc protoreflect.Message } } + opaqueNewMessageHook(message) + return message } @@ -812,6 +835,18 @@ type Field struct { Location Location // location of this field Comments CommentSet // comments associated with this field + + // camelCase is the same as GoName, but without the name + // mangling. This is used in builders, where only the single + // name "Build" needs to be mangled. + camelCase string + + // hasConflictHybrid tells us if we are to insert an '_' into + // the method names, (e.g. SetFoo becomes Set_Foo). This will + // be set even if we generate opaque protos, as we will want + // to potentially generate these method names anyway + // (opaque-v0). + hasConflictHybrid bool } func newField(gen *Plugin, f *File, message *Message, desc protoreflect.FieldDescriptor) *Field { @@ -840,6 +875,9 @@ func newField(gen *Plugin, f *File, message *Message, desc protoreflect.FieldDes Location: loc, Comments: makeCommentSet(gen, f.Desc.SourceLocations().ByDescriptor(desc)), } + + opaqueNewFieldHook(desc, field) + return field } @@ -890,13 +928,24 @@ type Oneof struct { Location Location // location of this oneof Comments CommentSet // comments associated with this oneof + + // camelCase is the same as GoName, but without the name mangling. + // This is used in builders, which never have their names mangled + camelCase string + + // hasConflictHybrid tells us if we are to insert an '_' into + // the method names, (e.g. SetFoo becomes Set_Foo). This will + // be set even if we generate opaque protos, as we will want + // to potentially generate these method names anyway + // (opaque-v0). + hasConflictHybrid bool } func newOneof(gen *Plugin, f *File, message *Message, desc protoreflect.OneofDescriptor) *Oneof { loc := message.Location.appendPath(genid.DescriptorProto_OneofDecl_field_number, desc.Index()) camelCased := strs.GoCamelCase(string(desc.Name())) parentPrefix := message.GoIdent.GoName + "_" - return &Oneof{ + oneof := &Oneof{ Desc: desc, Parent: message, GoName: camelCased, @@ -907,6 +956,10 @@ func newOneof(gen *Plugin, f *File, message *Message, desc protoreflect.OneofDes Location: loc, Comments: makeCommentSet(gen, f.Desc.SourceLocations().ByDescriptor(desc)), } + + opaqueNewOneofHook(desc, oneof) + + return oneof } // Extension is an alias of [Field] for documentation. diff --git a/compiler/protogen/protogen_apilevel.go b/compiler/protogen/protogen_apilevel.go new file mode 100644 index 000000000..27276fa42 --- /dev/null +++ b/compiler/protogen/protogen_apilevel.go @@ -0,0 +1,192 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package protogen + +import ( + "fmt" + + "google.golang.org/protobuf/internal/filedesc" + "google.golang.org/protobuf/internal/genid" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/gofeaturespb" +) + +func fileAPILevel(fd protoreflect.FileDescriptor, def gofeaturespb.GoFeatures_APILevel) gofeaturespb.GoFeatures_APILevel { + level := gofeaturespb.GoFeatures_API_OPEN + level = def + if fd, ok := fd.(*filedesc.File); ok { + al := fd.L1.EditionFeatures.APILevel + if al != genid.GoFeatures_API_LEVEL_UNSPECIFIED_enum_value { + level = gofeaturespb.GoFeatures_APILevel(al) + } + } + + return level +} + +func messageAPILevel(md protoreflect.MessageDescriptor, def gofeaturespb.GoFeatures_APILevel) gofeaturespb.GoFeatures_APILevel { + level := def + if md, ok := md.(*filedesc.Message); ok { + al := md.L1.EditionFeatures.APILevel + if al != genid.GoFeatures_API_LEVEL_UNSPECIFIED_enum_value { + level = gofeaturespb.GoFeatures_APILevel(al) + } + } + + return level +} + +func (p *Plugin) defaultAPILevel() gofeaturespb.GoFeatures_APILevel { + if p.opts.DefaultAPILevel != gofeaturespb.GoFeatures_API_LEVEL_UNSPECIFIED { + return p.opts.DefaultAPILevel + } + + return gofeaturespb.GoFeatures_API_OPEN +} + +// MethodName returns the (possibly mangled) name of the generated accessor +// method, along with the backwards-compatible name (if needed). +// +// method must be one of Get, Set, Has, Clear. MethodName panics otherwise. +func (field *Field) MethodName(method string) (name, compat string) { + switch method { + case "Get": + return field.getterName() + + case "Set": + return field.setterName() + + case "Has", "Clear": + return field.methodName(method), "" + + default: + panic(fmt.Sprintf("Field.MethodName called for unknown method %q", method)) + } +} + +// methodName returns the (possibly mangled) name of the generated method with +// the given prefix. +// +// For the Open API, the return value is "". +func (field *Field) methodName(prefix string) string { + switch field.Parent.APILevel { + case gofeaturespb.GoFeatures_API_OPEN: + // In the Open API, only generate getters (no Has or Clear methods). + return "" + + case gofeaturespb.GoFeatures_API_HYBRID: + var infix string + if field.hasConflictHybrid { + infix = "_" + } + return prefix + infix + field.camelCase + + case gofeaturespb.GoFeatures_API_OPAQUE: + return prefix + field.camelCase + + default: + panic("BUG: message is neither open, nor hybrid, nor opaque?!") + } +} + +// getterName returns the (possibly mangled) name of the generated Get method, +// along with the backwards-compatible name (if needed). +func (field *Field) getterName() (getter, compat string) { + switch field.Parent.APILevel { + case gofeaturespb.GoFeatures_API_OPEN: + // In the Open API, only generate a getter with the old style mangled name. + return "Get" + field.GoName, "" + + case gofeaturespb.GoFeatures_API_HYBRID: + // In the Hybrid API, return the mangled getter name and the old style + // name if needed, for backwards compatibility with the Open API. + var infix string + if field.hasConflictHybrid { + infix = "_" + } + orig := "Get" + infix + field.camelCase + mangled := "Get" + field.GoName + if mangled == orig { + mangled = "" + } + return orig, mangled + + case gofeaturespb.GoFeatures_API_OPAQUE: + return field.methodName("Get"), "" + + default: + panic("BUG: message is neither open, nor hybrid, nor opaque?!") + } +} + +// setterName returns the (possibly mangled) name of the generated Set method, +// along with the backwards-compatible name (if needed). +func (field *Field) setterName() (setter, compat string) { + // TODO(b/359846588): remove weak field support? + if field.Desc.IsWeak() && field.Parent.APILevel != gofeaturespb.GoFeatures_API_OPAQUE { + switch field.Parent.APILevel { + case gofeaturespb.GoFeatures_API_OPEN: + return "Set" + field.GoName, "" + + default: + var infix string + if field.hasConflictHybrid { + infix = "_" + } + orig := "Set" + infix + field.camelCase + mangled := "Set" + field.GoName + if mangled == orig { + mangled = "" + } + return orig, mangled + } + } + return field.methodName("Set"), "" +} + +// BuilderFieldName returns the name of this field in the corresponding _builder +// struct. +func (field *Field) BuilderFieldName() string { + return field.camelCase +} + +// MethodName returns the (possibly mangled) name of the generated accessor +// method. +// +// method must be one of Has, Clear, Which. MethodName panics otherwise. +func (oneof *Oneof) MethodName(method string) string { + switch method { + case "Has", "Clear", "Which": + return oneof.methodName(method) + + default: + panic(fmt.Sprintf("Oneof.MethodName called for unknown method %q", method)) + } +} + +// methodName returns the (possibly mangled) name of the generated method with +// the given prefix. +// +// For the Open API, the return value is "". +func (oneof *Oneof) methodName(prefix string) string { + switch oneof.Parent.APILevel { + case gofeaturespb.GoFeatures_API_OPEN: + // In the Open API, only generate getters. + return "" + + case gofeaturespb.GoFeatures_API_HYBRID: + var infix string + if oneof.hasConflictHybrid { + infix = "_" + } + return prefix + infix + oneof.camelCase + + case gofeaturespb.GoFeatures_API_OPAQUE: + return prefix + oneof.camelCase + + default: + panic("BUG: message is neither open, nor hybrid, nor opaque?!") + } +} diff --git a/compiler/protogen/protogen_opaque.go b/compiler/protogen/protogen_opaque.go new file mode 100644 index 000000000..8b11cdbd5 --- /dev/null +++ b/compiler/protogen/protogen_opaque.go @@ -0,0 +1,79 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package protogen + +import ( + "google.golang.org/protobuf/internal/strs" + "google.golang.org/protobuf/reflect/protoreflect" +) + +func opaqueNewFieldHook(desc protoreflect.FieldDescriptor, field *Field) { + field.camelCase = strs.GoCamelCase(string(desc.Name())) +} + +func opaqueNewOneofHook(desc protoreflect.OneofDescriptor, oneof *Oneof) { + oneof.camelCase = strs.GoCamelCase(string(desc.Name())) +} + +func opaqueNewMessageHook(message *Message) { + // New name mangling scheme: Add a '_' between method base + // name (Get, Set, Clear etc) and original field name if + // needed. As a special case, there is one globally reserved + // name, e.g. "Build" thet still results in actual renaming of + // the builder field like in the old scheme. We begin by + // taking care of this special case. + for _, field := range message.Fields { + if field.camelCase == "Build" { + field.camelCase += "_" + } + } + + // Then find all names of the original field names, we do not want the old scheme to affect + // how we name things. + + camelCases := map[string]bool{} + for _, field := range message.Fields { + if field.Oneof != nil { + // We add the name of the union here (potentially many times). + camelCases[field.Oneof.camelCase] = true + } + // The member fields of the oneof are considered fields in the struct although + // they are not technically there. This is to allow changing a proto2 optional + // to a oneof with source code compatibility. + camelCases[field.camelCase] = true + + } + // For each field, check if any of it's methods would clash with an original field name + for _, field := range message.Fields { + // Every field (except the union fields, that are taken care of separately) has + // a Get and a Set method. + methods := []string{"Set", "Get"} + // For explicit presence fields, we also have Has and Clear. + if field.Desc.HasPresence() { + methods = append(methods, "Has", "Clear") + } + for _, method := range methods { + // If any method name clashes with a field name, all methods get a + // "_" inserted between the operation and the field name. + if camelCases[method+field.camelCase] { + field.hasConflictHybrid = true + } + } + } + // The union names for oneofs need only have a methods prefix if there is a clash with Has, Clear or Which in + // hybrid and opaque-v0. + for _, field := range message.Fields { + if field.Oneof == nil { + continue + } + for _, method := range []string{"Has", "Clear", "Which"} { + // Same logic as for regular fields - all methods get the "_" if one needs it. + if camelCases[method+field.Oneof.camelCase] { + field.Oneof.hasConflictHybrid = true + } + } + } + +} diff --git a/encoding/prototext/testmessages_opaque_test.go b/encoding/prototext/testmessages_opaque_test.go new file mode 100644 index 000000000..9f2a7a756 --- /dev/null +++ b/encoding/prototext/testmessages_opaque_test.go @@ -0,0 +1,34 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package prototext_test + +import ( + "fmt" + "strings" + + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + + _ "google.golang.org/protobuf/internal/testprotos/textpbeditions" + _ "google.golang.org/protobuf/internal/testprotos/textpbeditions/textpbeditions_opaque" +) + +var relatedMessages = func() map[protoreflect.MessageType][]protoreflect.MessageType { + related := map[protoreflect.MessageType][]protoreflect.MessageType{} + const opaqueNamePrefix = "opaque." + protoregistry.GlobalTypes.RangeMessages(func(mt protoreflect.MessageType) bool { + name := mt.Descriptor().FullName() + if !strings.HasPrefix(string(name), opaqueNamePrefix) { + return true + } + mt1, err := protoregistry.GlobalTypes.FindMessageByName(name[len(opaqueNamePrefix):]) + if err != nil { + panic(fmt.Sprintf("%v: can't find related message", name)) + } + related[mt1] = append(related[mt1], mt) + return true + }) + return related +}() diff --git a/encoding/prototext/testmessages_test.go b/encoding/prototext/testmessages_test.go index 4681db7f0..50da60f48 100644 --- a/encoding/prototext/testmessages_test.go +++ b/encoding/prototext/testmessages_test.go @@ -10,6 +10,13 @@ import ( ) func makeMessages(in protobuild.Message, messages ...proto.Message) []proto.Message { + + for _, m := range messages { + for _, mt := range relatedMessages[m.ProtoReflect().Type()] { + messages = append(messages, mt.New().Interface()) + } + } + for _, m := range messages { in.Build(m.ProtoReflect()) } diff --git a/integration_test.go b/integration_test.go index b05cdab27..0ec126fca 100644 --- a/integration_test.go +++ b/integration_test.go @@ -140,6 +140,7 @@ func TestIntegration(t *testing.T) { } runGo("Normal", command{}, "go", "test", "-race", "./...") + runGo("LazyDecoding", command{}, "go", "test", "./proto", "-test_lazy_unmarshal") runGo("Reflect", command{}, "go", "test", "-race", "-tags", "protoreflect", "./...") if goVersion == golangLatest { runGo("ProtoLegacyRace", command{}, "go", "test", "-race", "-tags", "protolegacy", "./...") diff --git a/internal/cmd/generate-protos/main.go b/internal/cmd/generate-protos/main.go index 5d352cfae..5d95201ef 100644 --- a/internal/cmd/generate-protos/main.go +++ b/internal/cmd/generate-protos/main.go @@ -101,10 +101,190 @@ func main() { // editions_default.binpb was not yet updated. generateEditionsDefaults() + // Generate versions of each testproto .proto file which use the Hybrid and + // Opaque API. This step needs to come first so that the next step will + // generate the .pb.go files for these extra .proto files. + generateOpaqueTestprotos() + generateLocalProtos() generateRemoteProtos() } +// gsed works roughly like sed(1), in that it processes a file with a list of +// replacement functions that are applied in order to each line. +func gsed(outFn, inFn string, repls ...func(line string) string) error { + if err := os.MkdirAll(filepath.Dir(outFn), 0755); err != nil { + return err + } + out, err := os.Create(outFn) + if err != nil { + return err + } + defer out.Close() + b, err := os.ReadFile(inFn) + if err != nil { + return err + } + lines := strings.Split(strings.TrimSpace(string(b)), "\n") + for idx, line := range lines { + for _, repl := range repls { + line = repl(line) + } + lines[idx] = line + } + if _, err := out.Write([]byte(strings.Join(lines, "\n"))); err != nil { + return err + } + return out.Close() +} + +// variantFn turns a relative path like +// internal/testprotos/annotation/annotation.proto into its corresponding +// Hybrid/Opaque API variant file name, +// e.g. internal/testprotos/annotation/annotation_hybrid/annotation.hybrid.proto +func variantFn(relPath, variant string) string { + base := strings.TrimSuffix(filepath.Base(relPath), ".proto") + dir := filepath.Dir(relPath) + return filepath.Join(dir, filepath.Base(dir)+"_"+variant, base) + "." + variant + ".proto" +} + +var ( + testProtoRe = regexp.MustCompile(`(internal/testprotos/.*[.]proto)`) + goPackageRe = regexp.MustCompile(`option go_package = "([^"]+)";`) + extRe = regexp.MustCompile(`_ext = ([0-9]+);`) +) + +func generateOpaqueDotProto(repoRoot, tmpDir, relPath string) { + // relPath is e.g. internal/testprotos/annotation/annotation.proto + ignored := func(p string) bool { + return strings.HasPrefix(p, "internal/testprotos/irregular") + } + inFn := filepath.Join(repoRoot, relPath) + + // create .hybrid.proto variant + hybridFn := variantFn(relPath, "hybrid") + outFn := filepath.Join(tmpDir, hybridFn) + check(gsed(outFn, inFn, []func(line string) string{ + func(line string) string { + if strings.HasPrefix(line, "package ") { + return strings.ReplaceAll(line, "package ", "package hybrid.") + } + return line + }, + func(line string) string { + if testProtoPath := testProtoRe.FindString(line); testProtoPath != "" && !ignored(testProtoPath) { + hybridFn := variantFn(testProtoPath, "hybrid") + return strings.ReplaceAll(line, testProtoPath, hybridFn) + } + return line + }, + func(line string) string { + if matches := goPackageRe.FindStringSubmatch(line); matches != nil { + goPkg := matches[1] + hybridGoPkg := strings.TrimSuffix(goPkg, "/") + "/" + filepath.Base(goPkg) + "_hybrid" + return `option go_package = "` + hybridGoPkg + `";` + "\n" + + `import "google/protobuf/go_features.proto";` + "\n" + + `option features.(pb.go).api_level = API_HYBRID;` + } + return line + }, + }...)) + + // create .opaque.proto variant + opaqueFn := variantFn(relPath, "opaque") + outFn = filepath.Join(tmpDir, opaqueFn) + check(gsed(outFn, inFn, []func(line string) string{ + func(line string) string { + if strings.HasPrefix(line, "package ") { + return strings.ReplaceAll(line, "package ", "package opaque.") + } + return line + }, + func(line string) string { + if testProtoPath := testProtoRe.FindString(line); testProtoPath != "" && !ignored(testProtoPath) { + hybridFn := variantFn(testProtoPath, "opaque") + return strings.ReplaceAll(line, testProtoPath, hybridFn) + } + return line + }, + func(line string) string { + if matches := goPackageRe.FindStringSubmatch(line); matches != nil { + goPkg := matches[1] + opaqueGoPkg := strings.TrimSuffix(goPkg, "/") + "/" + filepath.Base(goPkg) + "_opaque" + return `option go_package = "` + opaqueGoPkg + `";` + "\n" + + `import "google/protobuf/go_features.proto";` + "\n" + + `option features.(pb.go).api_level = API_OPAQUE;` + } + return line + }, + func(line string) string { + return strings.ReplaceAll(line, `full_name: ".goproto`, `full_name: ".opaque.goproto`) + }, + func(line string) string { + return strings.ReplaceAll(line, `type: ".goproto`, `type: ".opaque.goproto`) + }, + func(line string) string { + if matches := extRe.FindStringSubmatch(line); matches != nil { + trimmed := strings.TrimSuffix(matches[0], ";") + return strings.ReplaceAll(line, trimmed, trimmed+"0") + } + return line + }, + }...)) +} + +func generateOpaqueTestprotos() { + tmpDir, err := os.MkdirTemp(repoRoot, "tmp") + check(err) + defer os.RemoveAll(tmpDir) + + // Generate variants using the Hybrid and Opaque API for all local proto + // files (except version-locked files). + dirs := []struct { + path string + pkgPaths map[string]string // mapping of .proto path to Go package path + annotate map[string]bool // .proto files to annotate + exclude map[string]bool // .proto files to exclude from generation + }{ + {path: "internal/testprotos/required"}, + {path: "internal/testprotos/testeditions"}, + {path: "internal/testprotos/enums"}, + {path: "internal/testprotos/textpbeditions"}, + {path: "internal/testprotos/messageset"}, + { + path: "internal/testprotos/lazy", + exclude: map[string]bool{ + "internal/testprotos/lazy/lazy_extension_normalized_wire_test.proto": true, + "internal/testprotos/lazy/lazy_normalized_wire_test.proto": true, + "internal/testprotos/lazy/lazy_extension_test.proto": true, + }, + }, + } + excludeRx := regexp.MustCompile(`legacy/.*/`) + for _, d := range dirs { + srcDir := filepath.Join(repoRoot, filepath.FromSlash(d.path)) + filepath.Walk(srcDir, func(srcPath string, _ os.FileInfo, _ error) error { + if !strings.HasSuffix(srcPath, ".proto") || excludeRx.MatchString(srcPath) { + return nil + } + if strings.HasSuffix(srcPath, ".opaque.proto") || strings.HasSuffix(srcPath, ".hybrid.proto") { + return nil + } + relPath, err := filepath.Rel(repoRoot, srcPath) + check(err) + + if d.exclude[filepath.ToSlash(relPath)] { + return nil + } + + generateOpaqueDotProto(repoRoot, tmpDir, relPath) + return nil + }) + } + + syncOutput(repoRoot, tmpDir) +} + func generateEditionsDefaults() { dest := filepath.Join(repoRoot, "internal", "editiondefaults", "editions_defaults.binpb") srcDescriptorProto := filepath.Join(protoRoot, "src", "google", "protobuf", "descriptor.proto") @@ -472,7 +652,9 @@ func generateSourceContextStringer(gen *protogen.Plugin, file *protogen.File) { func syncOutput(dstDir, srcDir string) { filepath.Walk(srcDir, func(srcPath string, _ os.FileInfo, _ error) error { - if !strings.HasSuffix(srcPath, ".go") && !strings.HasSuffix(srcPath, ".meta") { + if !strings.HasSuffix(srcPath, ".go") && + !strings.HasSuffix(srcPath, ".meta") && + !strings.HasSuffix(srcPath, ".proto") { return nil } relPath, err := filepath.Rel(srcDir, srcPath) diff --git a/internal/cmd/generate-types/impl.go b/internal/cmd/generate-types/impl.go index 97bd25197..66aa69ba3 100644 --- a/internal/cmd/generate-types/impl.go +++ b/internal/cmd/generate-types/impl.go @@ -894,3 +894,127 @@ func merge{{.PointerMethod}}Slice(dst, src pointer, _ *coderFieldInfo, _ mergeOp {{end}} {{end}} `)) + +func generateImplField() string { + return mustExecute(implFieldTemplate, GoTypes) +} + +var implFieldTemplate = template.Must(template.New("").Parse(` +func getterForNullableScalar(fd protoreflect.FieldDescriptor, fs reflect.StructField, conv Converter, fieldOffset offset) func(p pointer) protoreflect.Value { + ft := fs.Type + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + if fd.Kind() == protoreflect.EnumKind { + elemType := fs.Type.Elem() + // Enums for nullable types. + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + rv := p.Apply(fieldOffset).Elem().AsValueOf(elemType) + if rv.IsNil() { + return conv.Zero() + } + return conv.PBValueOf(rv.Elem()) + } + } + switch ft.Kind() { +{{range . }} +{{- if eq . "string"}} case reflect.String: +{{- /* Handle string GoType -> bytes proto type specially */}} + if fd.Kind() == protoreflect.BytesKind { + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).StringPtr() + if *x == nil { + return conv.Zero() + } + if len(**x) == 0 { + return protoreflect.ValueOfBytes(nil) + } + return protoreflect.ValueOfBytes([]byte(**x)) + } + } +{{else if eq . "[]byte" }} case reflect.Slice: +{{- /* Handle []byte GoType -> string proto type specially */}} + if fd.Kind() == protoreflect.StringKind { + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Bytes() + if len(*x) == 0 { + return conv.Zero() + } + return protoreflect.ValueOfString(string(*x)) + } + } +{{else}} case {{.Kind }}: +{{end}} return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).{{.NullablePointerMethod}}() + if *x == nil { + return conv.Zero() + } + return protoreflect.ValueOf{{.PointerMethod}}({{.NullableStar}}*x) + } +{{end}} } + panic("unexpected protobuf kind: "+ft.Kind().String()) +} + +func getterForDirectScalar(fd protoreflect.FieldDescriptor, fs reflect.StructField, conv Converter, fieldOffset offset) func(p pointer) protoreflect.Value { + ft := fs.Type + if fd.Kind() == protoreflect.EnumKind { + // Enums for non nullable types. + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() + return conv.PBValueOf(rv) + } + } + switch ft.Kind() { +{{range . }} +{{- if eq . "string"}} case reflect.String: +{{- /* Handle string GoType -> bytes proto type specially */}} + if fd.Kind() == protoreflect.BytesKind { + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).String() + if len(*x) == 0 { + return protoreflect.ValueOfBytes(nil) + } + return protoreflect.ValueOfBytes([]byte(*x)) + } + } +{{else if eq . "[]byte" }} case reflect.Slice: +{{- /* Handle []byte GoType -> string proto type specially */}} + if fd.Kind() == protoreflect.StringKind { + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Bytes() + return protoreflect.ValueOfString(string(*x)) + } + } +{{else}} case {{.Kind}}: +{{end}} return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).{{.PointerMethod}}() + return protoreflect.ValueOf{{.PointerMethod}}(*x) + } +{{end}} } + panic("unexpected protobuf kind: "+ft.Kind().String()) +} +`)) diff --git a/internal/cmd/generate-types/impl_opaque.go b/internal/cmd/generate-types/impl_opaque.go new file mode 100644 index 000000000..cb405de45 --- /dev/null +++ b/internal/cmd/generate-types/impl_opaque.go @@ -0,0 +1,77 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "text/template" +) + +func generateImplMessageOpaque() string { + return mustExecute(messageOpaqueTemplate, GoTypes) +} + +var messageOpaqueTemplate = template.Must(template.New("").Parse(` +func getterForOpaqueNullableScalar(mi *MessageInfo, index uint32, fd protoreflect.FieldDescriptor, fs reflect.StructField, conv Converter, fieldOffset offset) func(p pointer) protoreflect.Value { + ft := fs.Type + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + if fd.Kind() == protoreflect.EnumKind { + // Enums for nullable opaque types. + return func(p pointer) protoreflect.Value { + if p.IsNil() || !mi.present(p, index) { + return conv.Zero() + } + rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() + return conv.PBValueOf(rv) + } + } + switch ft.Kind() { +{{range . }} +{{- if eq . "string"}} case reflect.String: +{{- /* Handle string GoType -> bytes proto type specially */}} + if fd.Kind() == protoreflect.BytesKind { + return func(p pointer) protoreflect.Value { + if p.IsNil() || !mi.present(p, index) { + return conv.Zero() + } + x := p.Apply(fieldOffset).StringPtr() + if *x == nil { + return conv.Zero() + } + if len(**x) == 0 { + return protoreflect.ValueOfBytes(nil) + } + return protoreflect.ValueOfBytes([]byte(**x)) + } + } +{{else if eq . "[]byte" }} case reflect.Slice: +{{- /* Handle []byte GoType -> string proto type specially */}} + if fd.Kind() == protoreflect.StringKind { + return func(p pointer) protoreflect.Value { + if p.IsNil() || !mi.present(p, index) { + return conv.Zero() + } + x := p.Apply(fieldOffset).Bytes() + return protoreflect.ValueOfString(string(*x)) + } + } +{{else}} case {{.Kind}}: +{{end}} return func(p pointer) protoreflect.Value { + if p.IsNil() || !mi.present(p, index) { + return conv.Zero() + } + x := p.Apply(fieldOffset).{{.OpaqueNullablePointerMethod}}() +{{- if eq . "string"}} + if *x == nil { + return conv.Zero() + } +{{- end}} + return protoreflect.ValueOf{{.PointerMethod}}({{.OpaqueNullableStar}}*x) + } +{{end}} } + panic("unexpected protobuf kind: "+ft.Kind().String()) +} +`)) diff --git a/internal/cmd/generate-types/main.go b/internal/cmd/generate-types/main.go index de829f2fa..597bd642a 100644 --- a/internal/cmd/generate-types/main.go +++ b/internal/cmd/generate-types/main.go @@ -44,6 +44,8 @@ func main() { writeSource("internal/impl/codec_gen.go", generateImplCodec()) writeSource("internal/impl/message_reflect_gen.go", generateImplMessage()) writeSource("internal/impl/merge_gen.go", generateImplMerge()) + writeSource("internal/impl/message_reflect_field_gen.go", generateImplField()) + writeSource("internal/impl/message_opaque_gen.go", generateImplMessageOpaque()) writeSource("proto/decode_gen.go", generateProtoDecode()) writeSource("proto/encode_gen.go", generateProtoEncode()) writeSource("proto/size_gen.go", generateProtoSize()) diff --git a/internal/cmd/generate-types/proto.go b/internal/cmd/generate-types/proto.go index c70cc21f5..94b5d97af 100644 --- a/internal/cmd/generate-types/proto.go +++ b/internal/cmd/generate-types/proto.go @@ -88,6 +88,43 @@ func (g GoType) PointerMethod() Expr { return Expr(strings.ToUpper(string(g[:1])) + string(g[1:])) } +// NullablePointerMethod is the "internal/impl".pointer method used to access a nullable pointer to this type. +func (g GoType) NullablePointerMethod() Expr { + if g == GoBytes { + return "Bytes" // Bytes are already nullable + } + return Expr(strings.ToUpper(string(g[:1])) + string(g[1:]) + "Ptr") +} + +// NullableStar is the prefix for dereferencing a nullable value of this type "*" or "". +func (g GoType) NullableStar() Expr { + if g == GoBytes { + return "" // bytes are stored as a slice even when nullable + } + return "*" +} + +// OpaqueNullablePointerMethod is the "internal/impl".pointer method used to access a opaque nullable pointer to this type. +func (g GoType) OpaqueNullablePointerMethod() Expr { + switch g { + case GoString: + return "StringPtr" // Strings have indirection even in opaque + case GoBytes: + return "Bytes" + default: + return Expr(strings.ToUpper(string(g[:1])) + string(g[1:])) + } + +} + +// OpaqueNullableStar is the prefix for dereferencing a opaque nullable value of this type. +func (g GoType) OpaqueNullableStar() Expr { + if g == GoString { + return "*" // Strings have indirection even in opaque + } + return "" +} + type ProtoKind struct { Name string WireType WireType diff --git a/internal/filedesc/build_test.go b/internal/filedesc/build_test.go index 8dd7323bf..60e7152d5 100644 --- a/internal/filedesc/build_test.go +++ b/internal/filedesc/build_test.go @@ -54,7 +54,9 @@ func TestInit(t *testing.T) { descPkg.Append("FileDescriptorProto.source_code_info"): true, descPkg.Append("FileDescriptorProto.syntax"): true, // Nothing is using edition yet. - descPkg.Append("FileDescriptorProto.edition"): true, + descPkg.Append("FileDescriptorProto.edition"): true, + descPkg.Append("FileDescriptorProto.edition_enum"): true, + descPkg.Append("FileDescriptorProto.edition_deprecated"): true, // Impossible to test proto3 optional in a proto2 file. descPkg.Append("FieldDescriptorProto.proto3_optional"): true, diff --git a/internal/filedesc/desc.go b/internal/filedesc/desc.go index f32529856..378b826fa 100644 --- a/internal/filedesc/desc.go +++ b/internal/filedesc/desc.go @@ -117,6 +117,9 @@ type ( // GenerateLegacyUnmarshalJSON determines if the plugin generates the // UnmarshalJSON([]byte) error method for enums. GenerateLegacyUnmarshalJSON bool + // APILevel controls which API (Open, Hybrid or Opaque) should be used + // for generated code (.pb.go files). + APILevel int } ) diff --git a/internal/filedesc/editions.go b/internal/filedesc/editions.go index 7611796e8..10132c9b3 100644 --- a/internal/filedesc/editions.go +++ b/internal/filedesc/editions.go @@ -32,6 +32,10 @@ func unmarshalGoFeature(b []byte, parent EditionFeatures) EditionFeatures { v, m := protowire.ConsumeVarint(b) b = b[m:] parent.GenerateLegacyUnmarshalJSON = protowire.DecodeBool(v) + case genid.GoFeatures_ApiLevel_field_number: + v, m := protowire.ConsumeVarint(b) + b = b[m:] + parent.APILevel = int(v) case genid.GoFeatures_StripEnumPrefix_field_number: v, m := protowire.ConsumeVarint(b) b = b[m:] diff --git a/internal/genid/descriptor_gen.go b/internal/genid/descriptor_gen.go index f30ab6b58..30a2fa6d9 100644 --- a/internal/genid/descriptor_gen.go +++ b/internal/genid/descriptor_gen.go @@ -1120,20 +1120,26 @@ const ( // Field names for google.protobuf.FeatureSetDefaults. const ( - FeatureSetDefaults_Defaults_field_name protoreflect.Name = "defaults" - FeatureSetDefaults_MinimumEdition_field_name protoreflect.Name = "minimum_edition" - FeatureSetDefaults_MaximumEdition_field_name protoreflect.Name = "maximum_edition" + FeatureSetDefaults_Defaults_field_name protoreflect.Name = "defaults" + FeatureSetDefaults_MinimumEditionDeprecated_field_name protoreflect.Name = "minimum_edition" + FeatureSetDefaults_MaximumEditionDeprecated_field_name protoreflect.Name = "maximum_edition" + FeatureSetDefaults_MinimumEdition_field_name protoreflect.Name = "minimum_edition" + FeatureSetDefaults_MaximumEdition_field_name protoreflect.Name = "maximum_edition" - FeatureSetDefaults_Defaults_field_fullname protoreflect.FullName = "google.protobuf.FeatureSetDefaults.defaults" - FeatureSetDefaults_MinimumEdition_field_fullname protoreflect.FullName = "google.protobuf.FeatureSetDefaults.minimum_edition" - FeatureSetDefaults_MaximumEdition_field_fullname protoreflect.FullName = "google.protobuf.FeatureSetDefaults.maximum_edition" + FeatureSetDefaults_Defaults_field_fullname protoreflect.FullName = "google.protobuf.FeatureSetDefaults.defaults" + FeatureSetDefaults_MinimumEditionDeprecated_field_fullname protoreflect.FullName = "google.protobuf.FeatureSetDefaults.minimum_edition" + FeatureSetDefaults_MaximumEditionDeprecated_field_fullname protoreflect.FullName = "google.protobuf.FeatureSetDefaults.maximum_edition" + FeatureSetDefaults_MinimumEdition_field_fullname protoreflect.FullName = "google.protobuf.FeatureSetDefaults.minimum_edition" + FeatureSetDefaults_MaximumEdition_field_fullname protoreflect.FullName = "google.protobuf.FeatureSetDefaults.maximum_edition" ) // Field numbers for google.protobuf.FeatureSetDefaults. const ( - FeatureSetDefaults_Defaults_field_number protoreflect.FieldNumber = 1 - FeatureSetDefaults_MinimumEdition_field_number protoreflect.FieldNumber = 4 - FeatureSetDefaults_MaximumEdition_field_number protoreflect.FieldNumber = 5 + FeatureSetDefaults_Defaults_field_number protoreflect.FieldNumber = 1 + FeatureSetDefaults_MinimumEditionDeprecated_field_number protoreflect.FieldNumber = 2 + FeatureSetDefaults_MaximumEditionDeprecated_field_number protoreflect.FieldNumber = 3 + FeatureSetDefaults_MinimumEdition_field_number protoreflect.FieldNumber = 4 + FeatureSetDefaults_MaximumEdition_field_number protoreflect.FieldNumber = 5 ) // Names for google.protobuf.FeatureSetDefaults.FeatureSetEditionDefault. @@ -1147,10 +1153,12 @@ const ( FeatureSetDefaults_FeatureSetEditionDefault_Edition_field_name protoreflect.Name = "edition" FeatureSetDefaults_FeatureSetEditionDefault_OverridableFeatures_field_name protoreflect.Name = "overridable_features" FeatureSetDefaults_FeatureSetEditionDefault_FixedFeatures_field_name protoreflect.Name = "fixed_features" + FeatureSetDefaults_FeatureSetEditionDefault_Features_field_name protoreflect.Name = "features" FeatureSetDefaults_FeatureSetEditionDefault_Edition_field_fullname protoreflect.FullName = "google.protobuf.FeatureSetDefaults.FeatureSetEditionDefault.edition" FeatureSetDefaults_FeatureSetEditionDefault_OverridableFeatures_field_fullname protoreflect.FullName = "google.protobuf.FeatureSetDefaults.FeatureSetEditionDefault.overridable_features" FeatureSetDefaults_FeatureSetEditionDefault_FixedFeatures_field_fullname protoreflect.FullName = "google.protobuf.FeatureSetDefaults.FeatureSetEditionDefault.fixed_features" + FeatureSetDefaults_FeatureSetEditionDefault_Features_field_fullname protoreflect.FullName = "google.protobuf.FeatureSetDefaults.FeatureSetEditionDefault.features" ) // Field numbers for google.protobuf.FeatureSetDefaults.FeatureSetEditionDefault. @@ -1158,6 +1166,7 @@ const ( FeatureSetDefaults_FeatureSetEditionDefault_Edition_field_number protoreflect.FieldNumber = 3 FeatureSetDefaults_FeatureSetEditionDefault_OverridableFeatures_field_number protoreflect.FieldNumber = 4 FeatureSetDefaults_FeatureSetEditionDefault_FixedFeatures_field_number protoreflect.FieldNumber = 5 + FeatureSetDefaults_FeatureSetEditionDefault_Features_field_number protoreflect.FieldNumber = 2 ) // Names for google.protobuf.SourceCodeInfo. diff --git a/internal/genid/go_features_gen.go b/internal/genid/go_features_gen.go index 09792d96f..f5ee7f5c2 100644 --- a/internal/genid/go_features_gen.go +++ b/internal/genid/go_features_gen.go @@ -21,18 +21,35 @@ const ( // Field names for pb.GoFeatures. const ( GoFeatures_LegacyUnmarshalJsonEnum_field_name protoreflect.Name = "legacy_unmarshal_json_enum" + GoFeatures_ApiLevel_field_name protoreflect.Name = "api_level" GoFeatures_StripEnumPrefix_field_name protoreflect.Name = "strip_enum_prefix" GoFeatures_LegacyUnmarshalJsonEnum_field_fullname protoreflect.FullName = "pb.GoFeatures.legacy_unmarshal_json_enum" + GoFeatures_ApiLevel_field_fullname protoreflect.FullName = "pb.GoFeatures.api_level" GoFeatures_StripEnumPrefix_field_fullname protoreflect.FullName = "pb.GoFeatures.strip_enum_prefix" ) // Field numbers for pb.GoFeatures. const ( GoFeatures_LegacyUnmarshalJsonEnum_field_number protoreflect.FieldNumber = 1 + GoFeatures_ApiLevel_field_number protoreflect.FieldNumber = 2 GoFeatures_StripEnumPrefix_field_number protoreflect.FieldNumber = 3 ) +// Full and short names for pb.GoFeatures.APILevel. +const ( + GoFeatures_APILevel_enum_fullname = "pb.GoFeatures.APILevel" + GoFeatures_APILevel_enum_name = "APILevel" +) + +// Enum values for pb.GoFeatures.APILevel. +const ( + GoFeatures_API_LEVEL_UNSPECIFIED_enum_value = 0 + GoFeatures_API_OPEN_enum_value = 1 + GoFeatures_API_HYBRID_enum_value = 2 + GoFeatures_API_OPAQUE_enum_value = 3 +) + // Full and short names for pb.GoFeatures.StripEnumPrefix. const ( GoFeatures_StripEnumPrefix_enum_fullname = "pb.GoFeatures.StripEnumPrefix" diff --git a/internal/genid/name.go b/internal/genid/name.go new file mode 100644 index 000000000..224f33930 --- /dev/null +++ b/internal/genid/name.go @@ -0,0 +1,12 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package genid + +const ( + NoUnkeyedLiteral_goname = "noUnkeyedLiteral" + NoUnkeyedLiteralA_goname = "XXX_NoUnkeyedLiteral" + + BuilderSuffix_goname = "_builder" +) diff --git a/internal/impl/api_export_opaque.go b/internal/impl/api_export_opaque.go new file mode 100644 index 000000000..6075d6f69 --- /dev/null +++ b/internal/impl/api_export_opaque.go @@ -0,0 +1,128 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package impl + +import ( + "strconv" + "sync/atomic" + "unsafe" + + "google.golang.org/protobuf/reflect/protoreflect" +) + +func (Export) UnmarshalField(msg any, fieldNum int32) { + UnmarshalField(msg.(protoreflect.ProtoMessage).ProtoReflect(), protoreflect.FieldNumber(fieldNum)) +} + +// Present checks the presence set for a certain field number (zero +// based, ordered by appearance in original proto file). part is +// a pointer to the correct element in the bitmask array, num is the +// field number unaltered. Example (field number 70 -> part = +// &m.XXX_presence[1], num = 70) +func (Export) Present(part *uint32, num uint32) bool { + // This hook will read an unprotected shadow presence set if + // we're unning under the race detector + raceDetectHookPresent(part, num) + return atomic.LoadUint32(part)&(1<<(num%32)) > 0 +} + +// SetPresent adds a field to the presence set. part is a pointer to +// the relevant element in the array and num is the field number +// unaltered. size is the number of fields in the protocol +// buffer. +func (Export) SetPresent(part *uint32, num uint32, size uint32) { + // This hook will mutate an unprotected shadow presence set if + // we're running under the race detector + raceDetectHookSetPresent(part, num, presenceSize(size)) + for { + old := atomic.LoadUint32(part) + if atomic.CompareAndSwapUint32(part, old, old|(1<<(num%32))) { + return + } + } +} + +// SetPresentNonAtomic is like SetPresent, but operates non-atomically. +// It is meant for use by builder methods, where the message is known not +// to be accessible yet by other goroutines. +func (Export) SetPresentNonAtomic(part *uint32, num uint32, size uint32) { + // This hook will mutate an unprotected shadow presence set if + // we're running under the race detector + raceDetectHookSetPresent(part, num, presenceSize(size)) + *part |= 1 << (num % 32) +} + +// ClearPresence removes a field from the presence set. part is a +// pointer to the relevant element in the presence array and num is +// the field number unaltered. +func (Export) ClearPresent(part *uint32, num uint32) { + // This hook will mutate an unprotected shadow presence set if + // we're running under the race detector + raceDetectHookClearPresent(part, num) + for { + old := atomic.LoadUint32(part) + if atomic.CompareAndSwapUint32(part, old, old&^(1<<(num%32))) { + return + } + } +} + +// interfaceToPointer takes a pointer to an empty interface whose value is a +// pointer type, and converts it into a "pointer" that points to the same +// target +func interfaceToPointer(i *any) pointer { + return pointer{p: (*[2]unsafe.Pointer)(unsafe.Pointer(i))[1]} +} + +func (p pointer) atomicGetPointer() pointer { + return pointer{p: atomic.LoadPointer((*unsafe.Pointer)(p.p))} +} + +func (p pointer) atomicSetPointer(q pointer) { + atomic.StorePointer((*unsafe.Pointer)(p.p), q.p) +} + +// AtomicCheckPointerIsNil takes an interface (which is a pointer to a +// pointer) and returns true if the pointed-to pointer is nil (using an +// atomic load). This function is inlineable and, on x86, just becomes a +// simple load and compare. +func (Export) AtomicCheckPointerIsNil(ptr any) bool { + return interfaceToPointer(&ptr).atomicGetPointer().IsNil() +} + +// AtomicSetPointer takes two interfaces (first is a pointer to a pointer, +// second is a pointer) and atomically sets the second pointer into location +// referenced by first pointer. Unfortunately, atomicSetPointer() does not inline +// (even on x86), so this does not become a simple store on x86. +func (Export) AtomicSetPointer(dstPtr, valPtr any) { + interfaceToPointer(&dstPtr).atomicSetPointer(interfaceToPointer(&valPtr)) +} + +// AtomicLoadPointer loads the pointer at the location pointed at by src, +// and stores that pointer value into the location pointed at by dst. +func (Export) AtomicLoadPointer(ptr Pointer, dst Pointer) { + *(*unsafe.Pointer)(unsafe.Pointer(dst)) = atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(ptr))) +} + +// AtomicInitializePointer makes ptr and dst point to the same value. +// +// If *ptr is a nil pointer, it sets *ptr = *dst. +// +// If *ptr is a non-nil pointer, it sets *dst = *ptr. +func (Export) AtomicInitializePointer(ptr Pointer, dst Pointer) { + if !atomic.CompareAndSwapPointer((*unsafe.Pointer)(ptr), unsafe.Pointer(nil), *(*unsafe.Pointer)(dst)) { + *(*unsafe.Pointer)(unsafe.Pointer(dst)) = atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(ptr))) + } +} + +// MessageFieldStringOf returns the field formatted as a string, +// either as the field name if resolvable otherwise as a decimal string. +func (Export) MessageFieldStringOf(md protoreflect.MessageDescriptor, n protoreflect.FieldNumber) string { + fd := md.Fields().ByNumber(n) + if fd != nil { + return string(fd.Name()) + } + return strconv.Itoa(int(n)) +} diff --git a/internal/impl/bitmap.go b/internal/impl/bitmap.go new file mode 100644 index 000000000..ea276547c --- /dev/null +++ b/internal/impl/bitmap.go @@ -0,0 +1,34 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !race + +package impl + +// There is no additional data as we're not running under race detector. +type RaceDetectHookData struct{} + +// Empty stubs for when not using the race detector. Calls to these from index.go should be optimized away. +func (presence) raceDetectHookPresent(num uint32) {} +func (presence) raceDetectHookSetPresent(num uint32, size presenceSize) {} +func (presence) raceDetectHookClearPresent(num uint32) {} +func (presence) raceDetectHookAllocAndCopy(src presence) {} + +// raceDetectHookPresent is called by the generated file interface +// (*proto.internalFuncs) Present to optionally read an unprotected +// shadow bitmap when race detection is enabled. In regular code it is +// a noop. +func raceDetectHookPresent(field *uint32, num uint32) {} + +// raceDetectHookSetPresent is called by the generated file interface +// (*proto.internalFuncs) SetPresent to optionally write an unprotected +// shadow bitmap when race detection is enabled. In regular code it is +// a noop. +func raceDetectHookSetPresent(field *uint32, num uint32, size presenceSize) {} + +// raceDetectHookClearPresent is called by the generated file interface +// (*proto.internalFuncs) ClearPresent to optionally write an unprotected +// shadow bitmap when race detection is enabled. In regular code it is +// a noop. +func raceDetectHookClearPresent(field *uint32, num uint32) {} diff --git a/internal/impl/bitmap_race.go b/internal/impl/bitmap_race.go new file mode 100644 index 000000000..e9a27583a --- /dev/null +++ b/internal/impl/bitmap_race.go @@ -0,0 +1,126 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build race + +package impl + +// When running under race detector, we add a presence map of bytes, that we can access +// in the hook functions so that we trigger the race detection whenever we have concurrent +// Read-Writes or Write-Writes. The race detector does not otherwise detect invalid concurrent +// access to lazy fields as all updates of bitmaps and pointers are done using atomic operations. +type RaceDetectHookData struct { + shadowPresence *[]byte +} + +// Hooks for presence bitmap operations that allocate, read and write the shadowPresence +// using non-atomic operations. +func (data *RaceDetectHookData) raceDetectHookAlloc(size presenceSize) { + sp := make([]byte, size) + atomicStoreShadowPresence(&data.shadowPresence, &sp) +} + +func (p presence) raceDetectHookPresent(num uint32) { + data := p.toRaceDetectData() + if data == nil { + return + } + sp := atomicLoadShadowPresence(&data.shadowPresence) + if sp != nil { + _ = (*sp)[num] + } +} + +func (p presence) raceDetectHookSetPresent(num uint32, size presenceSize) { + data := p.toRaceDetectData() + if data == nil { + return + } + sp := atomicLoadShadowPresence(&data.shadowPresence) + if sp == nil { + data.raceDetectHookAlloc(size) + sp = atomicLoadShadowPresence(&data.shadowPresence) + } + (*sp)[num] = 1 +} + +func (p presence) raceDetectHookClearPresent(num uint32) { + data := p.toRaceDetectData() + if data == nil { + return + } + sp := atomicLoadShadowPresence(&data.shadowPresence) + if sp != nil { + (*sp)[num] = 0 + + } +} + +// raceDetectHookAllocAndCopy allocates a new shadowPresence slice at lazy and copies +// shadowPresence bytes from src to lazy. +func (p presence) raceDetectHookAllocAndCopy(q presence) { + sData := q.toRaceDetectData() + dData := p.toRaceDetectData() + if sData == nil { + return + } + srcSp := atomicLoadShadowPresence(&sData.shadowPresence) + if srcSp == nil { + atomicStoreShadowPresence(&dData.shadowPresence, nil) + return + } + n := len(*srcSp) + dSlice := make([]byte, n) + atomicStoreShadowPresence(&dData.shadowPresence, &dSlice) + for i := 0; i < n; i++ { + dSlice[i] = (*srcSp)[i] + } +} + +// raceDetectHookPresent is called by the generated file interface +// (*proto.internalFuncs) Present to optionally read an unprotected +// shadow bitmap when race detection is enabled. In regular code it is +// a noop. +func raceDetectHookPresent(field *uint32, num uint32) { + data := findPointerToRaceDetectData(field, num) + if data == nil { + return + } + sp := atomicLoadShadowPresence(&data.shadowPresence) + if sp != nil { + _ = (*sp)[num] + } +} + +// raceDetectHookSetPresent is called by the generated file interface +// (*proto.internalFuncs) SetPresent to optionally write an unprotected +// shadow bitmap when race detection is enabled. In regular code it is +// a noop. +func raceDetectHookSetPresent(field *uint32, num uint32, size presenceSize) { + data := findPointerToRaceDetectData(field, num) + if data == nil { + return + } + sp := atomicLoadShadowPresence(&data.shadowPresence) + if sp == nil { + data.raceDetectHookAlloc(size) + sp = atomicLoadShadowPresence(&data.shadowPresence) + } + (*sp)[num] = 1 +} + +// raceDetectHookClearPresent is called by the generated file interface +// (*proto.internalFuncs) ClearPresent to optionally write an unprotected +// shadow bitmap when race detection is enabled. In regular code it is +// a noop. +func raceDetectHookClearPresent(field *uint32, num uint32) { + data := findPointerToRaceDetectData(field, num) + if data == nil { + return + } + sp := atomicLoadShadowPresence(&data.shadowPresence) + if sp != nil { + (*sp)[num] = 0 + } +} diff --git a/internal/impl/checkinit.go b/internal/impl/checkinit.go index f29e6a8fa..fe2c719ce 100644 --- a/internal/impl/checkinit.go +++ b/internal/impl/checkinit.go @@ -35,6 +35,12 @@ func (mi *MessageInfo) checkInitializedPointer(p pointer) error { } return nil } + + var presence presence + if mi.presenceOffset.IsValid() { + presence = p.Apply(mi.presenceOffset).PresenceInfo() + } + if mi.extensionOffset.IsValid() { e := p.Apply(mi.extensionOffset).Extensions() if err := mi.isInitExtensions(e); err != nil { @@ -45,6 +51,33 @@ func (mi *MessageInfo) checkInitializedPointer(p pointer) error { if !f.isRequired && f.funcs.isInit == nil { continue } + + if f.presenceIndex != noPresence { + if !presence.Present(f.presenceIndex) { + if f.isRequired { + return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName())) + } + continue + } + if f.funcs.isInit != nil { + f.mi.init() + if f.mi.needsInitCheck { + if f.isLazy && p.Apply(f.offset).AtomicGetPointer().IsNil() { + lazy := *p.Apply(mi.lazyOffset).LazyInfoPtr() + if !lazy.AllowedPartial() { + // Nothing to see here, it was checked on unmarshal + continue + } + mi.lazyUnmarshal(p, f.num) + } + if err := f.funcs.isInit(p.Apply(f.offset), f); err != nil { + return err + } + } + } + continue + } + fptr := p.Apply(f.offset) if f.isPointer && fptr.Elem().IsNil() { if f.isRequired { diff --git a/internal/impl/codec_field_opaque.go b/internal/impl/codec_field_opaque.go new file mode 100644 index 000000000..76818ea25 --- /dev/null +++ b/internal/impl/codec_field_opaque.go @@ -0,0 +1,264 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package impl + +import ( + "fmt" + "reflect" + + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/internal/errors" + "google.golang.org/protobuf/reflect/protoreflect" +) + +func makeOpaqueMessageFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) (*MessageInfo, pointerCoderFuncs) { + mi := getMessageInfo(ft) + if mi == nil { + panic(fmt.Sprintf("invalid field: %v: unsupported message type %v", fd.FullName(), ft)) + } + switch fd.Kind() { + case protoreflect.MessageKind: + return mi, pointerCoderFuncs{ + size: sizeOpaqueMessage, + marshal: appendOpaqueMessage, + unmarshal: consumeOpaqueMessage, + isInit: isInitOpaqueMessage, + merge: mergeOpaqueMessage, + } + case protoreflect.GroupKind: + return mi, pointerCoderFuncs{ + size: sizeOpaqueGroup, + marshal: appendOpaqueGroup, + unmarshal: consumeOpaqueGroup, + isInit: isInitOpaqueMessage, + merge: mergeOpaqueMessage, + } + } + panic("unexpected field kind") +} + +func sizeOpaqueMessage(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) { + return protowire.SizeBytes(f.mi.sizePointer(p.AtomicGetPointer(), opts)) + f.tagsize +} + +func appendOpaqueMessage(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { + mp := p.AtomicGetPointer() + calculatedSize := f.mi.sizePointer(mp, opts) + b = protowire.AppendVarint(b, f.wiretag) + b = protowire.AppendVarint(b, uint64(calculatedSize)) + before := len(b) + b, err := f.mi.marshalAppendPointer(b, mp, opts) + if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil { + return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize) + } + return b, err +} + +func consumeOpaqueMessage(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) { + if wtyp != protowire.BytesType { + return out, errUnknown + } + v, n := protowire.ConsumeBytes(b) + if n < 0 { + return out, errDecode + } + mp := p.AtomicGetPointer() + if mp.IsNil() { + mp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))) + } + o, err := f.mi.unmarshalPointer(v, mp, 0, opts) + if err != nil { + return out, err + } + out.n = n + out.initialized = o.initialized + return out, nil +} + +func isInitOpaqueMessage(p pointer, f *coderFieldInfo) error { + mp := p.AtomicGetPointer() + if mp.IsNil() { + return nil + } + return f.mi.checkInitializedPointer(mp) +} + +func mergeOpaqueMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) { + dstmp := dst.AtomicGetPointer() + if dstmp.IsNil() { + dstmp = dst.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))) + } + f.mi.mergePointer(dstmp, src.AtomicGetPointer(), opts) +} + +func sizeOpaqueGroup(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) { + return 2*f.tagsize + f.mi.sizePointer(p.AtomicGetPointer(), opts) +} + +func appendOpaqueGroup(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { + b = protowire.AppendVarint(b, f.wiretag) // start group + b, err := f.mi.marshalAppendPointer(b, p.AtomicGetPointer(), opts) + b = protowire.AppendVarint(b, f.wiretag+1) // end group + return b, err +} + +func consumeOpaqueGroup(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) { + if wtyp != protowire.StartGroupType { + return out, errUnknown + } + mp := p.AtomicGetPointer() + if mp.IsNil() { + mp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))) + } + o, e := f.mi.unmarshalPointer(b, mp, f.num, opts) + return o, e +} + +func makeOpaqueRepeatedMessageFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) (*MessageInfo, pointerCoderFuncs) { + if ft.Kind() != reflect.Ptr || ft.Elem().Kind() != reflect.Slice { + panic(fmt.Sprintf("invalid field: %v: unsupported type for opaque repeated message: %v", fd.FullName(), ft)) + } + mt := ft.Elem().Elem() // *[]*T -> *T + mi := getMessageInfo(mt) + if mi == nil { + panic(fmt.Sprintf("invalid field: %v: unsupported message type %v", fd.FullName(), mt)) + } + switch fd.Kind() { + case protoreflect.MessageKind: + return mi, pointerCoderFuncs{ + size: sizeOpaqueMessageSlice, + marshal: appendOpaqueMessageSlice, + unmarshal: consumeOpaqueMessageSlice, + isInit: isInitOpaqueMessageSlice, + merge: mergeOpaqueMessageSlice, + } + case protoreflect.GroupKind: + return mi, pointerCoderFuncs{ + size: sizeOpaqueGroupSlice, + marshal: appendOpaqueGroupSlice, + unmarshal: consumeOpaqueGroupSlice, + isInit: isInitOpaqueMessageSlice, + merge: mergeOpaqueMessageSlice, + } + } + panic("unexpected field kind") +} + +func sizeOpaqueMessageSlice(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) { + s := p.AtomicGetPointer().PointerSlice() + n := 0 + for _, v := range s { + n += protowire.SizeBytes(f.mi.sizePointer(v, opts)) + f.tagsize + } + return n +} + +func appendOpaqueMessageSlice(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { + s := p.AtomicGetPointer().PointerSlice() + var err error + for _, v := range s { + b = protowire.AppendVarint(b, f.wiretag) + siz := f.mi.sizePointer(v, opts) + b = protowire.AppendVarint(b, uint64(siz)) + before := len(b) + b, err = f.mi.marshalAppendPointer(b, v, opts) + if err != nil { + return b, err + } + if measuredSize := len(b) - before; siz != measuredSize { + return nil, errors.MismatchedSizeCalculation(siz, measuredSize) + } + } + return b, nil +} + +func consumeOpaqueMessageSlice(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) { + if wtyp != protowire.BytesType { + return out, errUnknown + } + v, n := protowire.ConsumeBytes(b) + if n < 0 { + return out, errDecode + } + mp := pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())) + o, err := f.mi.unmarshalPointer(v, mp, 0, opts) + if err != nil { + return out, err + } + sp := p.AtomicGetPointer() + if sp.IsNil() { + sp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.ft.Elem()))) + } + sp.AppendPointerSlice(mp) + out.n = n + out.initialized = o.initialized + return out, nil +} + +func isInitOpaqueMessageSlice(p pointer, f *coderFieldInfo) error { + sp := p.AtomicGetPointer() + if sp.IsNil() { + return nil + } + s := sp.PointerSlice() + for _, v := range s { + if err := f.mi.checkInitializedPointer(v); err != nil { + return err + } + } + return nil +} + +func mergeOpaqueMessageSlice(dst, src pointer, f *coderFieldInfo, opts mergeOptions) { + ds := dst.AtomicGetPointer() + if ds.IsNil() { + ds = dst.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.ft.Elem()))) + } + for _, sp := range src.AtomicGetPointer().PointerSlice() { + dm := pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())) + f.mi.mergePointer(dm, sp, opts) + ds.AppendPointerSlice(dm) + } +} + +func sizeOpaqueGroupSlice(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) { + s := p.AtomicGetPointer().PointerSlice() + n := 0 + for _, v := range s { + n += 2*f.tagsize + f.mi.sizePointer(v, opts) + } + return n +} + +func appendOpaqueGroupSlice(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { + s := p.AtomicGetPointer().PointerSlice() + var err error + for _, v := range s { + b = protowire.AppendVarint(b, f.wiretag) // start group + b, err = f.mi.marshalAppendPointer(b, v, opts) + if err != nil { + return b, err + } + b = protowire.AppendVarint(b, f.wiretag+1) // end group + } + return b, nil +} + +func consumeOpaqueGroupSlice(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) { + if wtyp != protowire.StartGroupType { + return out, errUnknown + } + mp := pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())) + out, err = f.mi.unmarshalPointer(b, mp, f.num, opts) + if err != nil { + return out, err + } + sp := p.AtomicGetPointer() + if sp.IsNil() { + sp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.ft.Elem()))) + } + sp.AppendPointerSlice(mp) + return out, err +} diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go index 78be9df34..2f7b363ec 100644 --- a/internal/impl/codec_message.go +++ b/internal/impl/codec_message.go @@ -32,6 +32,10 @@ type coderMessageInfo struct { needsInitCheck bool isMessageSet bool numRequiredFields uint8 + + lazyOffset offset + presenceOffset offset + presenceSize presenceSize } type coderFieldInfo struct { @@ -45,12 +49,19 @@ type coderFieldInfo struct { tagsize int // size of the varint-encoded tag isPointer bool // true if IsNil may be called on the struct field isRequired bool // true if field is required + + isLazy bool + presenceIndex uint32 } +const noPresence = 0xffffffff + func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) { mi.sizecacheOffset = invalidOffset mi.unknownOffset = invalidOffset mi.extensionOffset = invalidOffset + mi.lazyOffset = invalidOffset + mi.presenceOffset = si.presenceOffset if si.sizecacheOffset.IsValid() && si.sizecacheType == sizecacheType { mi.sizecacheOffset = si.sizecacheOffset @@ -127,6 +138,8 @@ func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) { validation: newFieldValidationInfo(mi, si, fd, ft), isPointer: fd.Cardinality() == protoreflect.Repeated || fd.HasPresence(), isRequired: fd.Cardinality() == protoreflect.Required, + + presenceIndex: noPresence, } mi.orderedCoderFields = append(mi.orderedCoderFields, cf) mi.coderFields[cf.num] = cf diff --git a/internal/impl/codec_message_opaque.go b/internal/impl/codec_message_opaque.go new file mode 100644 index 000000000..88c16ae5b --- /dev/null +++ b/internal/impl/codec_message_opaque.go @@ -0,0 +1,156 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package impl + +import ( + "fmt" + "reflect" + "sort" + + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/internal/encoding/messageset" + "google.golang.org/protobuf/internal/order" + "google.golang.org/protobuf/reflect/protoreflect" + piface "google.golang.org/protobuf/runtime/protoiface" +) + +func (mi *MessageInfo) makeOpaqueCoderMethods(t reflect.Type, si opaqueStructInfo) { + mi.sizecacheOffset = si.sizecacheOffset + mi.unknownOffset = si.unknownOffset + mi.unknownPtrKind = si.unknownType.Kind() == reflect.Ptr + mi.extensionOffset = si.extensionOffset + mi.lazyOffset = si.lazyOffset + mi.presenceOffset = si.presenceOffset + + mi.coderFields = make(map[protowire.Number]*coderFieldInfo) + fields := mi.Desc.Fields() + for i := 0; i < fields.Len(); i++ { + fd := fields.Get(i) + + fs := si.fieldsByNumber[fd.Number()] + if fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic() { + fs = si.oneofsByName[fd.ContainingOneof().Name()] + } + ft := fs.Type + var wiretag uint64 + if !fd.IsPacked() { + wiretag = protowire.EncodeTag(fd.Number(), wireTypes[fd.Kind()]) + } else { + wiretag = protowire.EncodeTag(fd.Number(), protowire.BytesType) + } + var fieldOffset offset + var funcs pointerCoderFuncs + var childMessage *MessageInfo + switch { + case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic(): + fieldOffset = offsetOf(fs, mi.Exporter) + case fd.IsWeak(): + fieldOffset = si.weakOffset + funcs = makeWeakMessageFieldCoder(fd) + case fd.Message() != nil && !fd.IsMap(): + fieldOffset = offsetOf(fs, mi.Exporter) + if fd.IsList() { + childMessage, funcs = makeOpaqueRepeatedMessageFieldCoder(fd, ft) + } else { + childMessage, funcs = makeOpaqueMessageFieldCoder(fd, ft) + } + default: + fieldOffset = offsetOf(fs, mi.Exporter) + childMessage, funcs = fieldCoder(fd, ft) + } + cf := &coderFieldInfo{ + num: fd.Number(), + offset: fieldOffset, + wiretag: wiretag, + ft: ft, + tagsize: protowire.SizeVarint(wiretag), + funcs: funcs, + mi: childMessage, + validation: newFieldValidationInfo(mi, si.structInfo, fd, ft), + isPointer: (fd.Cardinality() == protoreflect.Repeated || + fd.Kind() == protoreflect.MessageKind || + fd.Kind() == protoreflect.GroupKind), + isRequired: fd.Cardinality() == protoreflect.Required, + presenceIndex: noPresence, + } + + // TODO: Use presence for all fields. + // + // In some cases, such as maps, presence means only "might be set" rather + // than "is definitely set", but every field should have a presence bit to + // permit us to skip over definitely-unset fields at marshal time. + + var hasPresence bool + hasPresence, cf.isLazy = usePresenceForField(si, fd) + + if hasPresence { + cf.presenceIndex, mi.presenceSize = presenceIndex(mi.Desc, fd) + } + + mi.orderedCoderFields = append(mi.orderedCoderFields, cf) + mi.coderFields[cf.num] = cf + } + for i, oneofs := 0, mi.Desc.Oneofs(); i < oneofs.Len(); i++ { + if od := oneofs.Get(i); !od.IsSynthetic() { + mi.initOneofFieldCoders(od, si.structInfo) + } + } + if messageset.IsMessageSet(mi.Desc) { + if !mi.extensionOffset.IsValid() { + panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.Desc.FullName())) + } + if !mi.unknownOffset.IsValid() { + panic(fmt.Sprintf("%v: MessageSet with no unknown field", mi.Desc.FullName())) + } + mi.isMessageSet = true + } + sort.Slice(mi.orderedCoderFields, func(i, j int) bool { + return mi.orderedCoderFields[i].num < mi.orderedCoderFields[j].num + }) + + var maxDense protoreflect.FieldNumber + for _, cf := range mi.orderedCoderFields { + if cf.num >= 16 && cf.num >= 2*maxDense { + break + } + maxDense = cf.num + } + mi.denseCoderFields = make([]*coderFieldInfo, maxDense+1) + for _, cf := range mi.orderedCoderFields { + if int(cf.num) > len(mi.denseCoderFields) { + break + } + mi.denseCoderFields[cf.num] = cf + } + + // To preserve compatibility with historic wire output, marshal oneofs last. + if mi.Desc.Oneofs().Len() > 0 { + sort.Slice(mi.orderedCoderFields, func(i, j int) bool { + fi := fields.ByNumber(mi.orderedCoderFields[i].num) + fj := fields.ByNumber(mi.orderedCoderFields[j].num) + return order.LegacyFieldOrder(fi, fj) + }) + } + + mi.needsInitCheck = needsInitCheck(mi.Desc) + if mi.methods.Marshal == nil && mi.methods.Size == nil { + mi.methods.Flags |= piface.SupportMarshalDeterministic + mi.methods.Marshal = mi.marshal + mi.methods.Size = mi.size + } + if mi.methods.Unmarshal == nil { + mi.methods.Flags |= piface.SupportUnmarshalDiscardUnknown + mi.methods.Unmarshal = mi.unmarshal + } + if mi.methods.CheckInitialized == nil { + mi.methods.CheckInitialized = mi.checkInitialized + } + if mi.methods.Merge == nil { + mi.methods.Merge = mi.merge + } + if mi.methods.Equal == nil { + mi.methods.Equal = equal + } +} diff --git a/internal/impl/decode.go b/internal/impl/decode.go index cda0520c2..e0dd21fa5 100644 --- a/internal/impl/decode.go +++ b/internal/impl/decode.go @@ -34,6 +34,8 @@ func (o unmarshalOptions) Options() proto.UnmarshalOptions { AllowPartial: true, DiscardUnknown: o.DiscardUnknown(), Resolver: o.resolver, + + NoLazyDecoding: o.NoLazyDecoding(), } } @@ -41,13 +43,26 @@ func (o unmarshalOptions) DiscardUnknown() bool { return o.flags&protoiface.UnmarshalDiscardUnknown != 0 } -func (o unmarshalOptions) IsDefault() bool { - return o.flags == 0 && o.resolver == protoregistry.GlobalTypes +func (o unmarshalOptions) AliasBuffer() bool { return o.flags&protoiface.UnmarshalAliasBuffer != 0 } +func (o unmarshalOptions) Validated() bool { return o.flags&protoiface.UnmarshalValidated != 0 } +func (o unmarshalOptions) NoLazyDecoding() bool { + return o.flags&protoiface.UnmarshalNoLazyDecoding != 0 +} + +func (o unmarshalOptions) CanBeLazy() bool { + if o.resolver != protoregistry.GlobalTypes { + return false + } + // We ignore the UnmarshalInvalidateSizeCache even though it's not in the default set + return (o.flags & ^(protoiface.UnmarshalAliasBuffer | protoiface.UnmarshalValidated | protoiface.UnmarshalCheckRequired)) == 0 } var lazyUnmarshalOptions = unmarshalOptions{ resolver: protoregistry.GlobalTypes, - depth: protowire.DefaultRecursionLimit, + + flags: protoiface.UnmarshalAliasBuffer | protoiface.UnmarshalValidated, + + depth: protowire.DefaultRecursionLimit, } type unmarshalOutput struct { @@ -94,9 +109,30 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire. if flags.ProtoLegacy && mi.isMessageSet { return unmarshalMessageSet(mi, b, p, opts) } + + lazyDecoding := LazyEnabled() // default + if opts.NoLazyDecoding() { + lazyDecoding = false // explicitly disabled + } + if mi.lazyOffset.IsValid() && lazyDecoding { + return mi.unmarshalPointerLazy(b, p, groupTag, opts) + } + return mi.unmarshalPointerEager(b, p, groupTag, opts) +} + +// unmarshalPointerEager is the message unmarshalling function for all messages that are not lazy. +// The corresponding function for Lazy is in google_lazy.go. +func (mi *MessageInfo) unmarshalPointerEager(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) { + initialized := true var requiredMask uint64 var exts *map[int32]ExtensionField + + var presence presence + if mi.presenceOffset.IsValid() { + presence = p.Apply(mi.presenceOffset).PresenceInfo() + } + start := len(b) for len(b) > 0 { // Parse the tag (field number and wire type). @@ -154,6 +190,11 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire. if f.funcs.isInit != nil && !o.initialized { initialized = false } + + if f.presenceIndex != noPresence { + presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize) + } + default: // Possible extension. if exts == nil && mi.extensionOffset.IsValid() { @@ -222,7 +263,7 @@ func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp p return out, errUnknown } if flags.LazyUnmarshalExtensions { - if opts.IsDefault() && x.canLazy(xt) { + if opts.CanBeLazy() && x.canLazy(xt) { out, valid := skipExtension(b, xi, num, wtyp, opts) switch valid { case ValidationValid: @@ -270,6 +311,13 @@ func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp if n < 0 { return out, ValidationUnknown } + + if opts.Validated() { + out.initialized = true + out.n = n + return out, ValidationValid + } + out, st := xi.validation.mi.validate(v, 0, opts) out.n = n return out, st diff --git a/internal/impl/encode.go b/internal/impl/encode.go index 6254f5de4..b2e212291 100644 --- a/internal/impl/encode.go +++ b/internal/impl/encode.go @@ -10,6 +10,7 @@ import ( "sync/atomic" "google.golang.org/protobuf/internal/flags" + "google.golang.org/protobuf/internal/protolazy" "google.golang.org/protobuf/proto" piface "google.golang.org/protobuf/runtime/protoiface" ) @@ -71,11 +72,39 @@ func (mi *MessageInfo) sizePointerSlow(p pointer, opts marshalOptions) (size int e := p.Apply(mi.extensionOffset).Extensions() size += mi.sizeExtensions(e, opts) } + + var lazy **protolazy.XXX_lazyUnmarshalInfo + var presence presence + if mi.presenceOffset.IsValid() { + presence = p.Apply(mi.presenceOffset).PresenceInfo() + if mi.lazyOffset.IsValid() { + lazy = p.Apply(mi.lazyOffset).LazyInfoPtr() + } + } + for _, f := range mi.orderedCoderFields { if f.funcs.size == nil { continue } fptr := p.Apply(f.offset) + + if f.presenceIndex != noPresence { + if !presence.Present(f.presenceIndex) { + continue + } + + if f.isLazy && fptr.AtomicGetPointer().IsNil() { + if lazyFields(opts) { + size += (*lazy).SizeField(uint32(f.num)) + continue + } else { + mi.lazyUnmarshal(p, f.num) + } + } + size += f.funcs.size(fptr, f, opts) + continue + } + if f.isPointer && fptr.Elem().IsNil() { continue } @@ -134,11 +163,52 @@ func (mi *MessageInfo) marshalAppendPointer(b []byte, p pointer, opts marshalOpt return b, err } } + + var lazy **protolazy.XXX_lazyUnmarshalInfo + var presence presence + if mi.presenceOffset.IsValid() { + presence = p.Apply(mi.presenceOffset).PresenceInfo() + if mi.lazyOffset.IsValid() { + lazy = p.Apply(mi.lazyOffset).LazyInfoPtr() + } + } + for _, f := range mi.orderedCoderFields { if f.funcs.marshal == nil { continue } fptr := p.Apply(f.offset) + + if f.presenceIndex != noPresence { + if !presence.Present(f.presenceIndex) { + continue + } + if f.isLazy { + // Be careful, this field needs to be read atomically, like for a get + if f.isPointer && fptr.AtomicGetPointer().IsNil() { + if lazyFields(opts) { + b, _ = (*lazy).AppendField(b, uint32(f.num)) + continue + } else { + mi.lazyUnmarshal(p, f.num) + } + } + + b, err = f.funcs.marshal(b, fptr, f, opts) + if err != nil { + return b, err + } + continue + } else if f.isPointer && fptr.Elem().IsNil() { + continue + } + b, err = f.funcs.marshal(b, fptr, f, opts) + if err != nil { + return b, err + } + continue + } + if f.isPointer && fptr.Elem().IsNil() { continue } @@ -163,6 +233,14 @@ func fullyLazyExtensions(opts marshalOptions) bool { return opts.flags&piface.MarshalDeterministic == 0 } +// lazyFields returns true if we should attempt to keep fields lazy over size and marshal. +func lazyFields(opts marshalOptions) bool { + // When deterministic marshaling is requested, force an unmarshal for lazy + // fields to produce a deterministic result, instead of passing through + // bytes lazily that may or may not match what Go Protobuf would produce. + return opts.flags&piface.MarshalDeterministic == 0 +} + func (mi *MessageInfo) sizeExtensions(ext *map[int32]ExtensionField, opts marshalOptions) (n int) { if ext == nil { return 0 diff --git a/internal/impl/lazy.go b/internal/impl/lazy.go new file mode 100644 index 000000000..e8fb6c35b --- /dev/null +++ b/internal/impl/lazy.go @@ -0,0 +1,433 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package impl + +import ( + "fmt" + "math/bits" + "os" + "reflect" + "sort" + "sync/atomic" + + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/internal/errors" + "google.golang.org/protobuf/internal/protolazy" + "google.golang.org/protobuf/reflect/protoreflect" + preg "google.golang.org/protobuf/reflect/protoregistry" + piface "google.golang.org/protobuf/runtime/protoiface" +) + +var enableLazy int32 = func() int32 { + if os.Getenv("GOPROTODEBUG") == "nolazy" { + return 0 + } + return 1 +}() + +// EnableLazyUnmarshal enables lazy unmarshaling. +func EnableLazyUnmarshal(enable bool) { + if enable { + atomic.StoreInt32(&enableLazy, 1) + return + } + atomic.StoreInt32(&enableLazy, 0) +} + +// LazyEnabled reports whether lazy unmarshalling is currently enabled. +func LazyEnabled() bool { + return atomic.LoadInt32(&enableLazy) != 0 +} + +// UnmarshalField unmarshals a field in a message. +func UnmarshalField(m interface{}, num protowire.Number) { + switch m := m.(type) { + case *messageState: + m.messageInfo().lazyUnmarshal(m.pointer(), num) + case *messageReflectWrapper: + m.messageInfo().lazyUnmarshal(m.pointer(), num) + default: + panic(fmt.Sprintf("unsupported wrapper type %T", m)) + } +} + +func (mi *MessageInfo) lazyUnmarshal(p pointer, num protoreflect.FieldNumber) { + var f *coderFieldInfo + if int(num) < len(mi.denseCoderFields) { + f = mi.denseCoderFields[num] + } else { + f = mi.coderFields[num] + } + if f == nil { + panic(fmt.Sprintf("lazyUnmarshal: field info for %v.%v", mi.Desc.FullName(), num)) + } + lazy := *p.Apply(mi.lazyOffset).LazyInfoPtr() + start, end, found, _, multipleEntries := lazy.FindFieldInProto(uint32(num)) + if !found && multipleEntries == nil { + panic(fmt.Sprintf("lazyUnmarshal: can't find field data for %v.%v", mi.Desc.FullName(), num)) + } + // The actual pointer in the message can not be set until the whole struct is filled in, otherwise we will have races. + // Create another pointer and set it atomically, if we won the race and the pointer in the original message is still nil. + fp := pointerOfValue(reflect.New(f.ft)) + if multipleEntries != nil { + for _, entry := range multipleEntries { + mi.unmarshalField(lazy.Buffer()[entry.Start:entry.End], fp, f, lazy, lazy.UnmarshalFlags()) + } + } else { + mi.unmarshalField(lazy.Buffer()[start:end], fp, f, lazy, lazy.UnmarshalFlags()) + } + p.Apply(f.offset).AtomicSetPointerIfNil(fp.Elem()) +} + +func (mi *MessageInfo) unmarshalField(b []byte, p pointer, f *coderFieldInfo, lazyInfo *protolazy.XXX_lazyUnmarshalInfo, flags piface.UnmarshalInputFlags) error { + opts := lazyUnmarshalOptions + opts.flags |= flags + for len(b) > 0 { + // Parse the tag (field number and wire type). + var tag uint64 + if b[0] < 0x80 { + tag = uint64(b[0]) + b = b[1:] + } else if len(b) >= 2 && b[1] < 128 { + tag = uint64(b[0]&0x7f) + uint64(b[1])<<7 + b = b[2:] + } else { + var n int + tag, n = protowire.ConsumeVarint(b) + if n < 0 { + return errors.New("invalid wire data") + } + b = b[n:] + } + var num protowire.Number + if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) { + return errors.New("invalid wire data") + } else { + num = protowire.Number(n) + } + wtyp := protowire.Type(tag & 7) + if num == f.num { + o, err := f.funcs.unmarshal(b, p, wtyp, f, opts) + if err == nil { + b = b[o.n:] + continue + } + if err != errUnknown { + return err + } + } + n := protowire.ConsumeFieldValue(num, wtyp, b) + if n < 0 { + return errors.New("invalid wire data") + } + b = b[n:] + } + return nil +} + +func (mi *MessageInfo) skipField(b []byte, f *coderFieldInfo, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) { + fmi := f.validation.mi + if fmi == nil { + fd := mi.Desc.Fields().ByNumber(f.num) + if fd == nil || !fd.IsWeak() { + return out, ValidationUnknown + } + messageName := fd.Message().FullName() + messageType, err := preg.GlobalTypes.FindMessageByName(messageName) + if err != nil { + return out, ValidationUnknown + } + var ok bool + fmi, ok = messageType.(*MessageInfo) + if !ok { + return out, ValidationUnknown + } + } + fmi.init() + switch f.validation.typ { + case validationTypeMessage: + if wtyp != protowire.BytesType { + return out, ValidationWrongWireType + } + v, n := protowire.ConsumeBytes(b) + if n < 0 { + return out, ValidationInvalid + } + out, st := fmi.validate(v, 0, opts) + out.n = n + return out, st + case validationTypeGroup: + if wtyp != protowire.StartGroupType { + return out, ValidationWrongWireType + } + out, st := fmi.validate(b, f.num, opts) + return out, st + default: + return out, ValidationUnknown + } +} + +// unmarshalPointerLazy is similar to unmarshalPointerEager, but it +// specifically handles lazy unmarshalling. it expects lazyOffset and +// presenceOffset to both be valid. +func (mi *MessageInfo) unmarshalPointerLazy(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) { + initialized := true + var requiredMask uint64 + var lazy **protolazy.XXX_lazyUnmarshalInfo + var presence presence + var lazyIndex []protolazy.IndexEntry + var lastNum protowire.Number + outOfOrder := false + lazyDecode := false + presence = p.Apply(mi.presenceOffset).PresenceInfo() + lazy = p.Apply(mi.lazyOffset).LazyInfoPtr() + if !presence.AnyPresent(mi.presenceSize) { + if opts.CanBeLazy() { + // If the message contains existing data, we need to merge into it. + // Lazy unmarshaling doesn't merge, so only enable it when the + // message is empty (has no presence bitmap). + lazyDecode = true + if *lazy == nil { + *lazy = &protolazy.XXX_lazyUnmarshalInfo{} + } + (*lazy).SetUnmarshalFlags(opts.flags) + if !opts.AliasBuffer() { + // Make a copy of the buffer for lazy unmarshaling. + // Set the AliasBuffer flag so recursive unmarshal + // operations reuse the copy. + b = append([]byte{}, b...) + opts.flags |= piface.UnmarshalAliasBuffer + } + (*lazy).SetBuffer(b) + } + } + // Track special handling of lazy fields. + // + // In the common case, all fields are lazyValidateOnly (and lazyFields remains nil). + // In the event that validation for a field fails, this map tracks handling of the field. + type lazyAction uint8 + const ( + lazyValidateOnly lazyAction = iota // validate the field only + lazyUnmarshalNow // eagerly unmarshal the field + lazyUnmarshalLater // unmarshal the field after the message is fully processed + ) + var lazyFields map[*coderFieldInfo]lazyAction + var exts *map[int32]ExtensionField + start := len(b) + pos := 0 + for len(b) > 0 { + // Parse the tag (field number and wire type). + var tag uint64 + if b[0] < 0x80 { + tag = uint64(b[0]) + b = b[1:] + } else if len(b) >= 2 && b[1] < 128 { + tag = uint64(b[0]&0x7f) + uint64(b[1])<<7 + b = b[2:] + } else { + var n int + tag, n = protowire.ConsumeVarint(b) + if n < 0 { + return out, errDecode + } + b = b[n:] + } + var num protowire.Number + if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) { + return out, errors.New("invalid field number") + } else { + num = protowire.Number(n) + } + wtyp := protowire.Type(tag & 7) + + if wtyp == protowire.EndGroupType { + if num != groupTag { + return out, errors.New("mismatching end group marker") + } + groupTag = 0 + break + } + + var f *coderFieldInfo + if int(num) < len(mi.denseCoderFields) { + f = mi.denseCoderFields[num] + } else { + f = mi.coderFields[num] + } + var n int + err := errUnknown + discardUnknown := false + Field: + switch { + case f != nil: + if f.funcs.unmarshal == nil { + break + } + if f.isLazy && lazyDecode { + switch { + case lazyFields == nil || lazyFields[f] == lazyValidateOnly: + // Attempt to validate this field and leave it for later lazy unmarshaling. + o, valid := mi.skipField(b, f, wtyp, opts) + switch valid { + case ValidationValid: + // Skip over the valid field and continue. + err = nil + presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize) + requiredMask |= f.validation.requiredBit + if !o.initialized { + initialized = false + } + n = o.n + break Field + case ValidationInvalid: + return out, errors.New("invalid proto wire format") + case ValidationWrongWireType: + break Field + case ValidationUnknown: + if lazyFields == nil { + lazyFields = make(map[*coderFieldInfo]lazyAction) + } + if presence.Present(f.presenceIndex) { + // We were unable to determine if the field is valid or not, + // and we've already skipped over at least one instance of this + // field. Clear the presence bit (so if we stop decoding early, + // we don't leave a partially-initialized field around) and flag + // the field for unmarshaling before we return. + presence.ClearPresent(f.presenceIndex) + lazyFields[f] = lazyUnmarshalLater + discardUnknown = true + break Field + } else { + // We were unable to determine if the field is valid or not, + // but this is the first time we've seen it. Flag it as needing + // eager unmarshaling and fall through to the eager unmarshal case below. + lazyFields[f] = lazyUnmarshalNow + } + } + case lazyFields[f] == lazyUnmarshalLater: + // This field will be unmarshaled in a separate pass below. + // Skip over it here. + discardUnknown = true + break Field + default: + // Eagerly unmarshal the field. + } + } + if f.isLazy && !lazyDecode && presence.Present(f.presenceIndex) { + if p.Apply(f.offset).AtomicGetPointer().IsNil() { + mi.lazyUnmarshal(p, f.num) + } + } + var o unmarshalOutput + o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts) + n = o.n + if err != nil { + break + } + requiredMask |= f.validation.requiredBit + if f.funcs.isInit != nil && !o.initialized { + initialized = false + } + if f.presenceIndex != noPresence { + presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize) + } + default: + // Possible extension. + if exts == nil && mi.extensionOffset.IsValid() { + exts = p.Apply(mi.extensionOffset).Extensions() + if *exts == nil { + *exts = make(map[int32]ExtensionField) + } + } + if exts == nil { + break + } + var o unmarshalOutput + o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts) + if err != nil { + break + } + n = o.n + if !o.initialized { + initialized = false + } + } + if err != nil { + if err != errUnknown { + return out, err + } + n = protowire.ConsumeFieldValue(num, wtyp, b) + if n < 0 { + return out, errDecode + } + if !discardUnknown && !opts.DiscardUnknown() && mi.unknownOffset.IsValid() { + u := mi.mutableUnknownBytes(p) + *u = protowire.AppendTag(*u, num, wtyp) + *u = append(*u, b[:n]...) + } + } + b = b[n:] + end := start - len(b) + if lazyDecode && f != nil && f.isLazy { + if num != lastNum { + lazyIndex = append(lazyIndex, protolazy.IndexEntry{ + FieldNum: uint32(num), + Start: uint32(pos), + End: uint32(end), + }) + } else { + i := len(lazyIndex) - 1 + lazyIndex[i].End = uint32(end) + lazyIndex[i].MultipleContiguous = true + } + } + if num < lastNum { + outOfOrder = true + } + pos = end + lastNum = num + } + if groupTag != 0 { + return out, errors.New("missing end group marker") + } + if lazyFields != nil { + // Some fields failed validation, and now need to be unmarshaled. + for f, action := range lazyFields { + if action != lazyUnmarshalLater { + continue + } + initialized = false + if *lazy == nil { + *lazy = &protolazy.XXX_lazyUnmarshalInfo{} + } + if err := mi.unmarshalField((*lazy).Buffer(), p.Apply(f.offset), f, *lazy, opts.flags); err != nil { + return out, err + } + presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize) + } + } + if lazyDecode { + if outOfOrder { + sort.Slice(lazyIndex, func(i, j int) bool { + return lazyIndex[i].FieldNum < lazyIndex[j].FieldNum || + (lazyIndex[i].FieldNum == lazyIndex[j].FieldNum && + lazyIndex[i].Start < lazyIndex[j].Start) + }) + } + if *lazy == nil { + *lazy = &protolazy.XXX_lazyUnmarshalInfo{} + } + + (*lazy).SetIndex(lazyIndex) + } + if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) { + initialized = false + } + if initialized { + out.initialized = true + } + out.n = start - len(b) + return out, nil +} diff --git a/internal/impl/lazy_buffersharing_test.go b/internal/impl/lazy_buffersharing_test.go new file mode 100644 index 000000000..15a7b049c --- /dev/null +++ b/internal/impl/lazy_buffersharing_test.go @@ -0,0 +1,151 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package impl_test + +import ( + "testing" + + mixedpb "google.golang.org/protobuf/internal/testprotos/mixed" + "google.golang.org/protobuf/proto" +) + +var enableLazy = proto.UnmarshalOptions{} +var disableLazy = proto.UnmarshalOptions{ + NoLazyDecoding: true, +} + +func TestCopyTopLevelLazy(t *testing.T) { + testCopyTopLevel(t, enableLazy) +} + +func TestCopyTopLevelEager(t *testing.T) { + testCopyTopLevel(t, disableLazy) +} + +// testCopyTopLevel tests that the buffer is copied to a safe location +// when the opaque proto is the top level proto +func testCopyTopLevel(t *testing.T, unmarshalOpts proto.UnmarshalOptions) { + m := mixedpb.OpaqueLazy_builder{ + Opaque: mixedpb.OpaqueLazy_builder{ + OptionalInt32: proto.Int32(23), + }.Build(), + }.Build() + if got, want := m.GetOpaque().GetOptionalInt32(), int32(23); got != want { + t.Errorf("Build(): unexpected optional_int32: got %v, want %v", got, want) + } + b, err := proto.Marshal(m) + if err != nil { + t.Fatalf("Could not marshal healthy proto %v.", m) + } + m2 := &mixedpb.OpaqueLazy{} + if err := unmarshalOpts.Unmarshal(b, m2); err != nil { + t.Fatalf("Could not unmarshal healthy proto buffer: %v.", b) + } + for i := 0; i < len(b); i++ { + b[i] = byte(0xFF) + } + if got, want := m2.GetOpaque().GetOptionalInt32(), int32(23); got != want { + t.Errorf("Mixed proto referred to shared buffer: got %v, want %v", got, want) + } +} + +func TestCopyWhenContainedInOpenLazy(t *testing.T) { + testCopyWhenContainedInOpen(t, enableLazy) +} + +func TestCopyWhenContainedInOpenEager(t *testing.T) { + testCopyWhenContainedInOpen(t, disableLazy) +} + +// testCopyWhenContainedInOpen tests that the buffer is copied +// for opaque messages that are not on the top level +func testCopyWhenContainedInOpen(t *testing.T, unmarshalOpts proto.UnmarshalOptions) { + m := &mixedpb.OpenLazy{ + Opaque: mixedpb.OpaqueLazy_builder{ + Opaque: mixedpb.OpaqueLazy_builder{ + OptionalInt32: proto.Int32(23), + }.Build(), + }.Build(), + } + if got, want := m.GetOpaque().GetOpaque().GetOptionalInt32(), int32(23); got != want { + t.Errorf("Build(): unexpected optional_int32: got %v, want %v", got, want) + } + b, err := proto.Marshal(m) + if err != nil { + t.Fatalf("Could not marshal healthy proto %v.", m) + } + m2 := &mixedpb.OpenLazy{} + if err := unmarshalOpts.Unmarshal(b, m2); err != nil { + t.Fatalf("Could not unmarshal healthy proto buffer: %v.", b) + } + for i := 0; i < len(b); i++ { + b[i] = byte(0xFF) + } + if got, want := m2.GetOpaque().GetOpaque().GetOptionalInt32(), int32(23); got != want { + t.Errorf("Build(): unexpected optional_int32: got %v, want %v", got, want) + } +} + +func TestNoExcessiveCopyLazy(t *testing.T) { + testNoExcessiveCopy(t, enableLazy) +} + +func TestNoExcessiveCopyEager(t *testing.T) { + testNoExcessiveCopy(t, disableLazy) +} + +// testNoExcessiveCopy tests that an opaque submessage does share the buffer +// if the message above already got it copied +func testNoExcessiveCopy(t *testing.T, unmarshalOpts proto.UnmarshalOptions) { + m := &mixedpb.OpenLazy{ + Opaque: mixedpb.OpaqueLazy_builder{ + Opaque: mixedpb.OpaqueLazy_builder{ + OptionalInt32: proto.Int32(23), + }.Build(), + }.Build(), + } + if got, want := m.GetOpaque().GetOpaque().GetOptionalInt32(), int32(23); got != want { + t.Errorf("Build(): unexpected optional_int32: got %v, want %v", got, want) + } + b, err := proto.Marshal(m) + if err != nil { + t.Fatalf("Could not marshal healthy proto %v.", m) + } + mm := &mixedpb.OpenLazy{} + if err := unmarshalOpts.Unmarshal(b, mm); err != nil { + t.Fatalf("Could not unmarshal healthy proto buffer: %v.", b) + } + m2 := mm.GetOpaque() + m3 := mm.GetOpaque().GetOpaque() + // Now, if we deliberately destroy the OpaqueM2 buffer, the OpaqueM3 buffer should + // be destroyed as well + if m2.XXX_lazyUnmarshalInfo == nil { + if m3.XXX_lazyUnmarshalInfo != nil { + t.Errorf("Inconsistent lazyUnmarshalInfo for subprotos") + } + // nothing to check, we don't have backing store + return + } + + if m3.XXX_lazyUnmarshalInfo == nil { + t.Errorf("Inconsistent lazyUnmarshalInfo for subprotos (2)") + return + } + b = (*m2.XXX_lazyUnmarshalInfo).Protobuf + m2len := len(b) + for i := 0; i < len(b); i++ { + b[i] = byte(0xFF) + } + b = (*m3.XXX_lazyUnmarshalInfo).Protobuf + if m2len != 0 && len(b) == 0 { + t.Errorf("The lazy backing store for submessage is empty when it is not for the surronding message: %v.", m2len) + } + for i, x := range b { + if x != byte(0xFF) { + t.Errorf("Backing store for protocol buffer is not shared (index = %d, x = 0x%x)", i, x) + } + } + +} diff --git a/internal/impl/lazy_field_normalized_test.go b/internal/impl/lazy_field_normalized_test.go new file mode 100644 index 000000000..02c9ebe9f --- /dev/null +++ b/internal/impl/lazy_field_normalized_test.go @@ -0,0 +1,156 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package impl_test + +import ( + "bytes" + "fmt" + "strings" + "testing" + + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/testing/protopack" + + lnwtpb "google.golang.org/protobuf/internal/testprotos/lazy" +) + +func unmarshalsTheSame(b []byte, expected *lnwtpb.FTop) error { + unmarshaledTop := &lnwtpb.FTop{} + if err := proto.Unmarshal(b, unmarshaledTop); err != nil { + return err + } + if !proto.Equal(unmarshaledTop, expected) { + return fmt.Errorf("!proto.Equal") + } + return nil +} + +func bytesTag(num protowire.Number) protopack.Tag { + return protopack.Tag{ + Number: num, + Type: protopack.BytesType, + } +} + +func varintTag(num protowire.Number) protopack.Tag { + return protopack.Tag{ + Number: num, + Type: protopack.VarintType, + } +} + +// Constructs a message encoded in denormalized (non-minimal) wire format, but +// using two levels of nesting: A top-level message with a child message which +// in turn has a grandchild message. +func denormalizedTwoLevelField() ([]byte, *lnwtpb.FTop, error) { + expectedMessage := &lnwtpb.FTop{ + A: proto.Uint32(2342), + Child: &lnwtpb.FSub{ + Grandchild: &lnwtpb.FSub{ + B: proto.Uint32(1337), + }, + }, + } + + fullMessage := protopack.Message{ + varintTag(1), protopack.Varint(2342), + // Child + bytesTag(2), protopack.LengthPrefix(protopack.Message{ + // Grandchild + bytesTag(4), protopack.LengthPrefix(protopack.Message{ + // The first occurrence of B matches expectedMessage: + varintTag(2), protopack.Varint(1337), + // This second duplicative occurrence of B is spec'd in Protobuf: + // https://github.com/protocolbuffers/protobuf/issues/9257 + varintTag(2), protopack.Varint(1337), + }), + }), + }.Marshal() + + return fullMessage, expectedMessage, nil +} + +func TestInvalidWireFormat(t *testing.T) { + fullMessage, expectedMessage, err := denormalizedTwoLevelField() + if err != nil { + t.Fatal(err) + } + + top := &lnwtpb.FTop{} + if err := proto.Unmarshal(fullMessage, top); err != nil { + t.Fatal(err) + } + + // Access the top-level submessage, but not the grandchild. + // This populates the size cache in the top-level message. + top.GetChild() + + marshal1, err := proto.MarshalOptions{ + UseCachedSize: true, + }.Marshal(top) + if err != nil { + t.Fatal(err) + } + if err := unmarshalsTheSame(marshal1, expectedMessage); err != nil { + t.Error(err) + } + + // Call top.GetChild().GetGrandchild() to unmarshal the lazy message, + // which will normalize it: the size cache shrinks from 6 bytes to 3. + // Notably, top.GetChild()’s size cache is not updated! + top.GetChild().GetGrandchild() + marshal2, err := proto.MarshalOptions{ + // GetGrandchild+UseCachedSize is one way to trigger this bug. + // The other way is to call GetGrandchild in another goroutine, + // after proto.Marshal has called proto.Size but + // before proto.Marshal started encoding. + UseCachedSize: true, + }.Marshal(top) + if err != nil { + if strings.Contains(err.Error(), "size mismatch") { + // This is the expected failure mode: proto.Marshal() detects the + // combination of non-minimal wire format and lazy decoding and + // returns an error, prompting the user to disable lazy decoding. + return + } + t.Fatal(err) + } + if err := unmarshalsTheSame(marshal2, expectedMessage); err != nil { + t.Error(err) + } +} + +func TestIdenticalOverAccessWhenDeterministic(t *testing.T) { + fullMessage, _, err := denormalizedTwoLevelField() + if err != nil { + t.Fatal(err) + } + + top := &lnwtpb.FTop{} + if err := proto.Unmarshal(fullMessage, top); err != nil { + t.Fatal(err) + } + + deterministic := proto.MarshalOptions{ + Deterministic: true, + } + marshal1, err := deterministic.Marshal(top) + if err != nil { + t.Fatal(err) + } + + // Call top.GetChild().GetGrandchild() to unmarshal the lazy message, + // which will normalize it. + top.GetChild().GetGrandchild() + + marshal2, err := deterministic.Marshal(top) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(marshal1, marshal2) { + t.Errorf("MarshalOptions{Deterministic: true}.Marshal() not identical over accessing a non-minimal wire format lazy message:\nbefore:\n%x\nafter:\n%x", marshal1, marshal2) + } +} diff --git a/internal/impl/merge.go b/internal/impl/merge.go index 7e65f64f2..8ffdce67d 100644 --- a/internal/impl/merge.go +++ b/internal/impl/merge.go @@ -41,11 +41,38 @@ func (mi *MessageInfo) mergePointer(dst, src pointer, opts mergeOptions) { if src.IsNil() { return } + + var presenceSrc presence + var presenceDst presence + if mi.presenceOffset.IsValid() { + presenceSrc = src.Apply(mi.presenceOffset).PresenceInfo() + presenceDst = dst.Apply(mi.presenceOffset).PresenceInfo() + } + for _, f := range mi.orderedCoderFields { if f.funcs.merge == nil { continue } sfptr := src.Apply(f.offset) + + if f.presenceIndex != noPresence { + if !presenceSrc.Present(f.presenceIndex) { + continue + } + dfptr := dst.Apply(f.offset) + if f.isLazy { + if sfptr.AtomicGetPointer().IsNil() { + mi.lazyUnmarshal(src, f.num) + } + if presenceDst.Present(f.presenceIndex) && dfptr.AtomicGetPointer().IsNil() { + mi.lazyUnmarshal(dst, f.num) + } + } + f.funcs.merge(dst.Apply(f.offset), sfptr, f, opts) + presenceDst.SetPresentUnatomic(f.presenceIndex, mi.presenceSize) + continue + } + if f.isPointer && sfptr.Elem().IsNil() { continue } diff --git a/internal/impl/message.go b/internal/impl/message.go index 741b5ed29..fa10a0f5c 100644 --- a/internal/impl/message.go +++ b/internal/impl/message.go @@ -79,6 +79,9 @@ func (mi *MessageInfo) initOnce() { if mi.initDone == 1 { return } + if opaqueInitHook(mi) { + return + } t := mi.GoReflectType if t.Kind() != reflect.Ptr && t.Elem().Kind() != reflect.Struct { @@ -133,6 +136,9 @@ type structInfo struct { extensionOffset offset extensionType reflect.Type + lazyOffset offset + presenceOffset offset + fieldsByNumber map[protoreflect.FieldNumber]reflect.StructField oneofsByName map[protoreflect.Name]reflect.StructField oneofWrappersByType map[reflect.Type]protoreflect.FieldNumber @@ -145,6 +151,8 @@ func (mi *MessageInfo) makeStructInfo(t reflect.Type) structInfo { weakOffset: invalidOffset, unknownOffset: invalidOffset, extensionOffset: invalidOffset, + lazyOffset: invalidOffset, + presenceOffset: invalidOffset, fieldsByNumber: map[protoreflect.FieldNumber]reflect.StructField{}, oneofsByName: map[protoreflect.Name]reflect.StructField{}, @@ -175,6 +183,10 @@ fieldLoop: si.extensionOffset = offsetOf(f, mi.Exporter) si.extensionType = f.Type } + case "lazyFields", "XXX_lazyUnmarshalInfo": + si.lazyOffset = offsetOf(f, mi.Exporter) + case "XXX_presence": + si.presenceOffset = offsetOf(f, mi.Exporter) default: for _, s := range strings.Split(f.Tag.Get("protobuf"), ",") { if len(s) > 0 && strings.Trim(s, "0123456789") == "" { diff --git a/internal/impl/message_opaque.go b/internal/impl/message_opaque.go new file mode 100644 index 000000000..d407dd791 --- /dev/null +++ b/internal/impl/message_opaque.go @@ -0,0 +1,614 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package impl + +import ( + "fmt" + "math" + "reflect" + "strings" + "sync/atomic" + + "google.golang.org/protobuf/reflect/protoreflect" +) + +type opaqueStructInfo struct { + structInfo +} + +// isOpaque determines whether a protobuf message type is on the Opaque API. It +// checks whether the type is a Go struct that protoc-gen-go would generate. +// +// This function only detects newly generated messages from the v2 +// implementation of protoc-gen-go. It is unable to classify generated messages +// that are too old or those that are generated by a different generator +// such as protoc-gen-gogo. +func isOpaque(t reflect.Type) bool { + // The current detection mechanism is to simply check the first field + // for a struct tag with the "protogen" key. + if t.Kind() == reflect.Struct && t.NumField() > 0 { + pgt := t.Field(0).Tag.Get("protogen") + return strings.HasPrefix(pgt, "opaque.") + } + return false +} + +func opaqueInitHook(mi *MessageInfo) bool { + mt := mi.GoReflectType.Elem() + si := opaqueStructInfo{ + structInfo: mi.makeStructInfo(mt), + } + + if !isOpaque(mt) { + return false + } + + defer atomic.StoreUint32(&mi.initDone, 1) + + mi.fields = map[protoreflect.FieldNumber]*fieldInfo{} + fds := mi.Desc.Fields() + for i := 0; i < fds.Len(); i++ { + fd := fds.Get(i) + fs := si.fieldsByNumber[fd.Number()] + var fi fieldInfo + usePresence, _ := usePresenceForField(si, fd) + + switch { + case fd.IsWeak(): + // Weak fields are no different for opaque. + fi = fieldInfoForWeakMessage(fd, si.weakOffset) + case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic(): + // Oneofs are no different for opaque. + fi = fieldInfoForOneof(fd, si.oneofsByName[fd.ContainingOneof().Name()], mi.Exporter, si.oneofWrappersByNumber[fd.Number()]) + case fd.IsMap(): + fi = mi.fieldInfoForMapOpaque(si, fd, fs) + case fd.IsList() && fd.Message() == nil && usePresence: + fi = mi.fieldInfoForScalarListOpaque(si, fd, fs) + case fd.IsList() && fd.Message() == nil: + // Proto3 lists without presence can use same access methods as open + fi = fieldInfoForList(fd, fs, mi.Exporter) + case fd.IsList() && usePresence: + fi = mi.fieldInfoForMessageListOpaque(si, fd, fs) + case fd.IsList(): + // Proto3 opaque messages that does not need presence bitmap. + // Different representation than open struct, but same logic + fi = mi.fieldInfoForMessageListOpaqueNoPresence(si, fd, fs) + case fd.Message() != nil && usePresence: + fi = mi.fieldInfoForMessageOpaque(si, fd, fs) + case fd.Message() != nil: + // Proto3 messages without presence can use same access methods as open + fi = fieldInfoForMessage(fd, fs, mi.Exporter) + default: + fi = mi.fieldInfoForScalarOpaque(si, fd, fs) + } + mi.fields[fd.Number()] = &fi + } + mi.oneofs = map[protoreflect.Name]*oneofInfo{} + for i := 0; i < mi.Desc.Oneofs().Len(); i++ { + od := mi.Desc.Oneofs().Get(i) + if !od.IsSynthetic() { + mi.oneofs[od.Name()] = makeOneofInfo(od, si.structInfo, mi.Exporter) + } + } + + mi.denseFields = make([]*fieldInfo, fds.Len()*2) + for i := 0; i < fds.Len(); i++ { + if fd := fds.Get(i); int(fd.Number()) < len(mi.denseFields) { + mi.denseFields[fd.Number()] = mi.fields[fd.Number()] + } + } + + for i := 0; i < fds.Len(); { + fd := fds.Get(i) + if od := fd.ContainingOneof(); od != nil && !fd.ContainingOneof().IsSynthetic() { + mi.rangeInfos = append(mi.rangeInfos, mi.oneofs[od.Name()]) + i += od.Fields().Len() + } else { + mi.rangeInfos = append(mi.rangeInfos, mi.fields[fd.Number()]) + i++ + } + } + + mi.makeExtensionFieldsFunc(mt, si.structInfo) + mi.makeUnknownFieldsFunc(mt, si.structInfo) + mi.makeOpaqueCoderMethods(mt, si) + mi.makeFieldTypes(si.structInfo) + + return true +} + +func (mi *MessageInfo) fieldInfoForMapOpaque(si opaqueStructInfo, fd protoreflect.FieldDescriptor, fs reflect.StructField) fieldInfo { + ft := fs.Type + if ft.Kind() != reflect.Map { + panic(fmt.Sprintf("invalid type: got %v, want map kind", ft)) + } + fieldOffset := offsetOf(fs, mi.Exporter) + conv := NewConverter(ft, fd) + return fieldInfo{ + fieldDesc: fd, + has: func(p pointer) bool { + if p.IsNil() { + return false + } + // Don't bother checking presence bits, since we need to + // look at the map length even if the presence bit is set. + rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() + return rv.Len() > 0 + }, + clear: func(p pointer) { + rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() + rv.Set(reflect.Zero(rv.Type())) + }, + get: func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() + if rv.Len() == 0 { + return conv.Zero() + } + return conv.PBValueOf(rv) + }, + set: func(p pointer, v protoreflect.Value) { + pv := conv.GoValueOf(v) + if pv.IsNil() { + panic(fmt.Sprintf("invalid value: setting map field to read-only value")) + } + rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() + rv.Set(pv) + }, + mutable: func(p pointer) protoreflect.Value { + v := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() + if v.IsNil() { + v.Set(reflect.MakeMap(fs.Type)) + } + return conv.PBValueOf(v) + }, + newField: func() protoreflect.Value { + return conv.New() + }, + } +} + +func (mi *MessageInfo) fieldInfoForScalarListOpaque(si opaqueStructInfo, fd protoreflect.FieldDescriptor, fs reflect.StructField) fieldInfo { + ft := fs.Type + if ft.Kind() != reflect.Slice { + panic(fmt.Sprintf("invalid type: got %v, want slice kind", ft)) + } + conv := NewConverter(reflect.PtrTo(ft), fd) + fieldOffset := offsetOf(fs, mi.Exporter) + index, _ := presenceIndex(mi.Desc, fd) + return fieldInfo{ + fieldDesc: fd, + has: func(p pointer) bool { + if p.IsNil() { + return false + } + rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() + return rv.Len() > 0 + }, + clear: func(p pointer) { + rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() + rv.Set(reflect.Zero(rv.Type())) + }, + get: func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + rv := p.Apply(fieldOffset).AsValueOf(fs.Type) + if rv.Elem().Len() == 0 { + return conv.Zero() + } + return conv.PBValueOf(rv) + }, + set: func(p pointer, v protoreflect.Value) { + pv := conv.GoValueOf(v) + if pv.IsNil() { + panic(fmt.Sprintf("invalid value: setting repeated field to read-only value")) + } + mi.setPresent(p, index) + rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() + rv.Set(pv.Elem()) + }, + mutable: func(p pointer) protoreflect.Value { + mi.setPresent(p, index) + return conv.PBValueOf(p.Apply(fieldOffset).AsValueOf(fs.Type)) + }, + newField: func() protoreflect.Value { + return conv.New() + }, + } +} + +func (mi *MessageInfo) fieldInfoForMessageListOpaque(si opaqueStructInfo, fd protoreflect.FieldDescriptor, fs reflect.StructField) fieldInfo { + ft := fs.Type + if ft.Kind() != reflect.Ptr || ft.Elem().Kind() != reflect.Slice { + panic(fmt.Sprintf("invalid type: got %v, want slice kind", ft)) + } + conv := NewConverter(ft, fd) + fieldOffset := offsetOf(fs, mi.Exporter) + index, _ := presenceIndex(mi.Desc, fd) + fieldNumber := fd.Number() + return fieldInfo{ + fieldDesc: fd, + has: func(p pointer) bool { + if p.IsNil() { + return false + } + if !mi.present(p, index) { + return false + } + sp := p.Apply(fieldOffset).AtomicGetPointer() + if sp.IsNil() { + // Lazily unmarshal this field. + mi.lazyUnmarshal(p, fieldNumber) + sp = p.Apply(fieldOffset).AtomicGetPointer() + } + rv := sp.AsValueOf(fs.Type.Elem()) + return rv.Elem().Len() > 0 + }, + clear: func(p pointer) { + fp := p.Apply(fieldOffset) + sp := fp.AtomicGetPointer() + if sp.IsNil() { + sp = fp.AtomicSetPointerIfNil(pointerOfValue(reflect.New(fs.Type.Elem()))) + mi.setPresent(p, index) + } + rv := sp.AsValueOf(fs.Type.Elem()) + rv.Elem().Set(reflect.Zero(rv.Type().Elem())) + }, + get: func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + if !mi.present(p, index) { + return conv.Zero() + } + sp := p.Apply(fieldOffset).AtomicGetPointer() + if sp.IsNil() { + // Lazily unmarshal this field. + mi.lazyUnmarshal(p, fieldNumber) + sp = p.Apply(fieldOffset).AtomicGetPointer() + } + rv := sp.AsValueOf(fs.Type.Elem()) + if rv.Elem().Len() == 0 { + return conv.Zero() + } + return conv.PBValueOf(rv) + }, + set: func(p pointer, v protoreflect.Value) { + fp := p.Apply(fieldOffset) + sp := fp.AtomicGetPointer() + if sp.IsNil() { + sp = fp.AtomicSetPointerIfNil(pointerOfValue(reflect.New(fs.Type.Elem()))) + mi.setPresent(p, index) + } + rv := sp.AsValueOf(fs.Type.Elem()) + val := conv.GoValueOf(v) + if val.IsNil() { + panic(fmt.Sprintf("invalid value: setting repeated field to read-only value")) + } else { + rv.Elem().Set(val.Elem()) + } + }, + mutable: func(p pointer) protoreflect.Value { + fp := p.Apply(fieldOffset) + sp := fp.AtomicGetPointer() + if sp.IsNil() { + if mi.present(p, index) { + // Lazily unmarshal this field. + mi.lazyUnmarshal(p, fieldNumber) + sp = p.Apply(fieldOffset).AtomicGetPointer() + } else { + sp = fp.AtomicSetPointerIfNil(pointerOfValue(reflect.New(fs.Type.Elem()))) + mi.setPresent(p, index) + } + } + rv := sp.AsValueOf(fs.Type.Elem()) + return conv.PBValueOf(rv) + }, + newField: func() protoreflect.Value { + return conv.New() + }, + } +} + +func (mi *MessageInfo) fieldInfoForMessageListOpaqueNoPresence(si opaqueStructInfo, fd protoreflect.FieldDescriptor, fs reflect.StructField) fieldInfo { + ft := fs.Type + if ft.Kind() != reflect.Ptr || ft.Elem().Kind() != reflect.Slice { + panic(fmt.Sprintf("invalid type: got %v, want slice kind", ft)) + } + conv := NewConverter(ft, fd) + fieldOffset := offsetOf(fs, mi.Exporter) + return fieldInfo{ + fieldDesc: fd, + has: func(p pointer) bool { + if p.IsNil() { + return false + } + sp := p.Apply(fieldOffset).AtomicGetPointer() + if sp.IsNil() { + return false + } + rv := sp.AsValueOf(fs.Type.Elem()) + return rv.Elem().Len() > 0 + }, + clear: func(p pointer) { + sp := p.Apply(fieldOffset).AtomicGetPointer() + if !sp.IsNil() { + rv := sp.AsValueOf(fs.Type.Elem()) + rv.Elem().Set(reflect.Zero(rv.Type().Elem())) + } + }, + get: func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + sp := p.Apply(fieldOffset).AtomicGetPointer() + if sp.IsNil() { + return conv.Zero() + } + rv := sp.AsValueOf(fs.Type.Elem()) + if rv.Elem().Len() == 0 { + return conv.Zero() + } + return conv.PBValueOf(rv) + }, + set: func(p pointer, v protoreflect.Value) { + rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() + if rv.IsNil() { + rv.Set(reflect.New(fs.Type.Elem())) + } + val := conv.GoValueOf(v) + if val.IsNil() { + panic(fmt.Sprintf("invalid value: setting repeated field to read-only value")) + } else { + rv.Elem().Set(val.Elem()) + } + }, + mutable: func(p pointer) protoreflect.Value { + rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() + if rv.IsNil() { + rv.Set(reflect.New(fs.Type.Elem())) + } + return conv.PBValueOf(rv) + }, + newField: func() protoreflect.Value { + return conv.New() + }, + } +} + +func (mi *MessageInfo) fieldInfoForScalarOpaque(si opaqueStructInfo, fd protoreflect.FieldDescriptor, fs reflect.StructField) fieldInfo { + ft := fs.Type + nullable := fd.HasPresence() + if oneof := fd.ContainingOneof(); oneof != nil && oneof.IsSynthetic() { + nullable = true + } + deref := false + if nullable && ft.Kind() == reflect.Ptr { + ft = ft.Elem() + deref = true + } + conv := NewConverter(ft, fd) + fieldOffset := offsetOf(fs, mi.Exporter) + index, _ := presenceIndex(mi.Desc, fd) + var getter func(p pointer) protoreflect.Value + if !nullable { + getter = getterForDirectScalar(fd, fs, conv, fieldOffset) + } else { + getter = getterForOpaqueNullableScalar(mi, index, fd, fs, conv, fieldOffset) + } + return fieldInfo{ + fieldDesc: fd, + has: func(p pointer) bool { + if p.IsNil() { + return false + } + if nullable { + return mi.present(p, index) + } + rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() + switch rv.Kind() { + case reflect.Bool: + return rv.Bool() + case reflect.Int32, reflect.Int64: + return rv.Int() != 0 + case reflect.Uint32, reflect.Uint64: + return rv.Uint() != 0 + case reflect.Float32, reflect.Float64: + return rv.Float() != 0 || math.Signbit(rv.Float()) + case reflect.String, reflect.Slice: + return rv.Len() > 0 + default: + panic(fmt.Sprintf("invalid type: %v", rv.Type())) // should never happen + } + }, + clear: func(p pointer) { + if nullable { + mi.clearPresent(p, index) + } + // This is only valuable for bytes and strings, but we do it unconditionally. + rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() + rv.Set(reflect.Zero(rv.Type())) + }, + get: getter, + // TODO: Implement unsafe fast path for set? + set: func(p pointer, v protoreflect.Value) { + rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() + if deref { + if rv.IsNil() { + rv.Set(reflect.New(ft)) + } + rv = rv.Elem() + } + + rv.Set(conv.GoValueOf(v)) + if nullable && rv.Kind() == reflect.Slice && rv.IsNil() { + rv.Set(emptyBytes) + } + if nullable { + mi.setPresent(p, index) + } + }, + newField: func() protoreflect.Value { + return conv.New() + }, + } +} + +func (mi *MessageInfo) fieldInfoForMessageOpaque(si opaqueStructInfo, fd protoreflect.FieldDescriptor, fs reflect.StructField) fieldInfo { + ft := fs.Type + conv := NewConverter(ft, fd) + fieldOffset := offsetOf(fs, mi.Exporter) + index, _ := presenceIndex(mi.Desc, fd) + fieldNumber := fd.Number() + elemType := fs.Type.Elem() + return fieldInfo{ + fieldDesc: fd, + has: func(p pointer) bool { + if p.IsNil() { + return false + } + return mi.present(p, index) + }, + clear: func(p pointer) { + mi.clearPresent(p, index) + p.Apply(fieldOffset).AtomicSetNilPointer() + }, + get: func(p pointer) protoreflect.Value { + if p.IsNil() || !mi.present(p, index) { + return conv.Zero() + } + fp := p.Apply(fieldOffset) + mp := fp.AtomicGetPointer() + if mp.IsNil() { + // Lazily unmarshal this field. + mi.lazyUnmarshal(p, fieldNumber) + mp = fp.AtomicGetPointer() + } + rv := mp.AsValueOf(elemType) + return conv.PBValueOf(rv) + }, + set: func(p pointer, v protoreflect.Value) { + val := pointerOfValue(conv.GoValueOf(v)) + if val.IsNil() { + panic("invalid nil pointer") + } + p.Apply(fieldOffset).AtomicSetPointer(val) + mi.setPresent(p, index) + }, + mutable: func(p pointer) protoreflect.Value { + fp := p.Apply(fieldOffset) + mp := fp.AtomicGetPointer() + if mp.IsNil() { + if mi.present(p, index) { + // Lazily unmarshal this field. + mi.lazyUnmarshal(p, fieldNumber) + mp = fp.AtomicGetPointer() + } else { + mp = pointerOfValue(conv.GoValueOf(conv.New())) + fp.AtomicSetPointer(mp) + mi.setPresent(p, index) + } + } + return conv.PBValueOf(mp.AsValueOf(fs.Type.Elem())) + }, + newMessage: func() protoreflect.Message { + return conv.New().Message() + }, + newField: func() protoreflect.Value { + return conv.New() + }, + } +} + +// A presenceList wraps a List, updating presence bits as necessary when the +// list contents change. +type presenceList struct { + pvalueList + setPresence func(bool) +} +type pvalueList interface { + protoreflect.List + //Unwrapper +} + +func (list presenceList) Append(v protoreflect.Value) { + list.pvalueList.Append(v) + list.setPresence(true) +} +func (list presenceList) Truncate(i int) { + list.pvalueList.Truncate(i) + list.setPresence(i > 0) +} + +// presenceIndex returns the index to pass to presence functions. +// +// TODO: field.Desc.Index() would be simpler, and would give space to record the presence of oneof fields. +func presenceIndex(md protoreflect.MessageDescriptor, fd protoreflect.FieldDescriptor) (uint32, presenceSize) { + found := false + var index, numIndices uint32 + for i := 0; i < md.Fields().Len(); i++ { + f := md.Fields().Get(i) + if f == fd { + found = true + index = numIndices + } + if f.ContainingOneof() == nil || isLastOneofField(f) { + numIndices++ + } + } + if !found { + panic(fmt.Sprintf("BUG: %v not in %v", fd.Name(), md.FullName())) + } + return index, presenceSize(numIndices) +} + +func isLastOneofField(fd protoreflect.FieldDescriptor) bool { + fields := fd.ContainingOneof().Fields() + return fields.Get(fields.Len()-1) == fd +} + +func (mi *MessageInfo) setPresent(p pointer, index uint32) { + p.Apply(mi.presenceOffset).PresenceInfo().SetPresent(index, mi.presenceSize) +} + +func (mi *MessageInfo) clearPresent(p pointer, index uint32) { + p.Apply(mi.presenceOffset).PresenceInfo().ClearPresent(index) +} + +func (mi *MessageInfo) present(p pointer, index uint32) bool { + return p.Apply(mi.presenceOffset).PresenceInfo().Present(index) +} + +// usePresenceForField implements the somewhat intricate logic of when +// the presence bitmap is used for a field. The main logic is that a +// field that is optional or that can be lazy will use the presence +// bit, but for proto2, also maps have a presence bit. It also records +// if the field can ever be lazy, which is true if we have a +// lazyOffset and the field is a message or a slice of messages. A +// field that is lazy will always need a presence bit. Oneofs are not +// lazy and do not use presence, unless they are a synthetic oneof, +// which is a proto3 optional field. For proto3 optionals, we use the +// presence and they can also be lazy when applicable (a message). +func usePresenceForField(si opaqueStructInfo, fd protoreflect.FieldDescriptor) (usePresence, canBeLazy bool) { + hasLazyField := fd.(interface{ IsLazy() bool }).IsLazy() + + // Non-oneof scalar fields with explicit field presence use the presence array. + usesPresenceArray := fd.HasPresence() && fd.Message() == nil && (fd.ContainingOneof() == nil || fd.ContainingOneof().IsSynthetic()) + switch { + case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic(): + return false, false + case fd.IsWeak(): + return false, false + case fd.IsMap(): + return false, false + case fd.Kind() == protoreflect.MessageKind || fd.Kind() == protoreflect.GroupKind: + return hasLazyField, hasLazyField + default: + return usesPresenceArray || (hasLazyField && fd.HasPresence()), false + } +} diff --git a/internal/impl/message_opaque_gen.go b/internal/impl/message_opaque_gen.go new file mode 100644 index 000000000..a69825699 --- /dev/null +++ b/internal/impl/message_opaque_gen.go @@ -0,0 +1,132 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Code generated by generate-types. DO NOT EDIT. + +package impl + +import ( + "reflect" + + "google.golang.org/protobuf/reflect/protoreflect" +) + +func getterForOpaqueNullableScalar(mi *MessageInfo, index uint32, fd protoreflect.FieldDescriptor, fs reflect.StructField, conv Converter, fieldOffset offset) func(p pointer) protoreflect.Value { + ft := fs.Type + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + if fd.Kind() == protoreflect.EnumKind { + // Enums for nullable opaque types. + return func(p pointer) protoreflect.Value { + if p.IsNil() || !mi.present(p, index) { + return conv.Zero() + } + rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() + return conv.PBValueOf(rv) + } + } + switch ft.Kind() { + case reflect.Bool: + return func(p pointer) protoreflect.Value { + if p.IsNil() || !mi.present(p, index) { + return conv.Zero() + } + x := p.Apply(fieldOffset).Bool() + return protoreflect.ValueOfBool(*x) + } + case reflect.Int32: + return func(p pointer) protoreflect.Value { + if p.IsNil() || !mi.present(p, index) { + return conv.Zero() + } + x := p.Apply(fieldOffset).Int32() + return protoreflect.ValueOfInt32(*x) + } + case reflect.Uint32: + return func(p pointer) protoreflect.Value { + if p.IsNil() || !mi.present(p, index) { + return conv.Zero() + } + x := p.Apply(fieldOffset).Uint32() + return protoreflect.ValueOfUint32(*x) + } + case reflect.Int64: + return func(p pointer) protoreflect.Value { + if p.IsNil() || !mi.present(p, index) { + return conv.Zero() + } + x := p.Apply(fieldOffset).Int64() + return protoreflect.ValueOfInt64(*x) + } + case reflect.Uint64: + return func(p pointer) protoreflect.Value { + if p.IsNil() || !mi.present(p, index) { + return conv.Zero() + } + x := p.Apply(fieldOffset).Uint64() + return protoreflect.ValueOfUint64(*x) + } + case reflect.Float32: + return func(p pointer) protoreflect.Value { + if p.IsNil() || !mi.present(p, index) { + return conv.Zero() + } + x := p.Apply(fieldOffset).Float32() + return protoreflect.ValueOfFloat32(*x) + } + case reflect.Float64: + return func(p pointer) protoreflect.Value { + if p.IsNil() || !mi.present(p, index) { + return conv.Zero() + } + x := p.Apply(fieldOffset).Float64() + return protoreflect.ValueOfFloat64(*x) + } + case reflect.String: + if fd.Kind() == protoreflect.BytesKind { + return func(p pointer) protoreflect.Value { + if p.IsNil() || !mi.present(p, index) { + return conv.Zero() + } + x := p.Apply(fieldOffset).StringPtr() + if *x == nil { + return conv.Zero() + } + if len(**x) == 0 { + return protoreflect.ValueOfBytes(nil) + } + return protoreflect.ValueOfBytes([]byte(**x)) + } + } + return func(p pointer) protoreflect.Value { + if p.IsNil() || !mi.present(p, index) { + return conv.Zero() + } + x := p.Apply(fieldOffset).StringPtr() + if *x == nil { + return conv.Zero() + } + return protoreflect.ValueOfString(**x) + } + case reflect.Slice: + if fd.Kind() == protoreflect.StringKind { + return func(p pointer) protoreflect.Value { + if p.IsNil() || !mi.present(p, index) { + return conv.Zero() + } + x := p.Apply(fieldOffset).Bytes() + return protoreflect.ValueOfString(string(*x)) + } + } + return func(p pointer) protoreflect.Value { + if p.IsNil() || !mi.present(p, index) { + return conv.Zero() + } + x := p.Apply(fieldOffset).Bytes() + return protoreflect.ValueOfBytes(*x) + } + } + panic("unexpected protobuf kind: " + ft.Kind().String()) +} diff --git a/internal/impl/message_reflect.go b/internal/impl/message_reflect.go index 98ab94aed..1b9b16a40 100644 --- a/internal/impl/message_reflect.go +++ b/internal/impl/message_reflect.go @@ -207,6 +207,11 @@ func (mi *MessageInfo) makeFieldTypes(si structInfo) { case fd.IsList(): if fd.Enum() != nil || fd.Message() != nil { ft = fs.Type.Elem() + + if ft.Kind() == reflect.Slice { + ft = ft.Elem() + } + } isMessage = fd.Message() != nil case fd.Enum() != nil: diff --git a/internal/impl/message_reflect_field.go b/internal/impl/message_reflect_field.go index 986322b19..a74064620 100644 --- a/internal/impl/message_reflect_field.go +++ b/internal/impl/message_reflect_field.go @@ -256,6 +256,7 @@ func fieldInfoForScalar(fd protoreflect.FieldDescriptor, fs reflect.StructField, ft := fs.Type nullable := fd.HasPresence() isBytes := ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 + var getter func(p pointer) protoreflect.Value if nullable { if ft.Kind() != reflect.Ptr && ft.Kind() != reflect.Slice { // This never occurs for generated message types. @@ -268,19 +269,25 @@ func fieldInfoForScalar(fd protoreflect.FieldDescriptor, fs reflect.StructField, } } conv := NewConverter(ft, fd) - - // TODO: Implement unsafe fast path? fieldOffset := offsetOf(fs, x) + + // Generate specialized getter functions to avoid going through reflect.Value + if nullable { + getter = getterForNullableScalar(fd, fs, conv, fieldOffset) + } else { + getter = getterForDirectScalar(fd, fs, conv, fieldOffset) + } + return fieldInfo{ fieldDesc: fd, has: func(p pointer) bool { if p.IsNil() { return false } - rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() if nullable { - return !rv.IsNil() + return !p.Apply(fieldOffset).Elem().IsNil() } + rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() switch rv.Kind() { case reflect.Bool: return rv.Bool() @@ -300,21 +307,8 @@ func fieldInfoForScalar(fd protoreflect.FieldDescriptor, fs reflect.StructField, rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() rv.Set(reflect.Zero(rv.Type())) }, - get: func(p pointer) protoreflect.Value { - if p.IsNil() { - return conv.Zero() - } - rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() - if nullable { - if rv.IsNil() { - return conv.Zero() - } - if rv.Kind() == reflect.Ptr { - rv = rv.Elem() - } - } - return conv.PBValueOf(rv) - }, + get: getter, + // TODO: Implement unsafe fast path for set? set: func(p pointer, v protoreflect.Value) { rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() if nullable && rv.Kind() == reflect.Ptr { diff --git a/internal/impl/message_reflect_field_gen.go b/internal/impl/message_reflect_field_gen.go new file mode 100644 index 000000000..af5e063a1 --- /dev/null +++ b/internal/impl/message_reflect_field_gen.go @@ -0,0 +1,273 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Code generated by generate-types. DO NOT EDIT. + +package impl + +import ( + "reflect" + + "google.golang.org/protobuf/reflect/protoreflect" +) + +func getterForNullableScalar(fd protoreflect.FieldDescriptor, fs reflect.StructField, conv Converter, fieldOffset offset) func(p pointer) protoreflect.Value { + ft := fs.Type + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + if fd.Kind() == protoreflect.EnumKind { + elemType := fs.Type.Elem() + // Enums for nullable types. + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + rv := p.Apply(fieldOffset).Elem().AsValueOf(elemType) + if rv.IsNil() { + return conv.Zero() + } + return conv.PBValueOf(rv.Elem()) + } + } + switch ft.Kind() { + case reflect.Bool: + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).BoolPtr() + if *x == nil { + return conv.Zero() + } + return protoreflect.ValueOfBool(**x) + } + case reflect.Int32: + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Int32Ptr() + if *x == nil { + return conv.Zero() + } + return protoreflect.ValueOfInt32(**x) + } + case reflect.Uint32: + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Uint32Ptr() + if *x == nil { + return conv.Zero() + } + return protoreflect.ValueOfUint32(**x) + } + case reflect.Int64: + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Int64Ptr() + if *x == nil { + return conv.Zero() + } + return protoreflect.ValueOfInt64(**x) + } + case reflect.Uint64: + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Uint64Ptr() + if *x == nil { + return conv.Zero() + } + return protoreflect.ValueOfUint64(**x) + } + case reflect.Float32: + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Float32Ptr() + if *x == nil { + return conv.Zero() + } + return protoreflect.ValueOfFloat32(**x) + } + case reflect.Float64: + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Float64Ptr() + if *x == nil { + return conv.Zero() + } + return protoreflect.ValueOfFloat64(**x) + } + case reflect.String: + if fd.Kind() == protoreflect.BytesKind { + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).StringPtr() + if *x == nil { + return conv.Zero() + } + if len(**x) == 0 { + return protoreflect.ValueOfBytes(nil) + } + return protoreflect.ValueOfBytes([]byte(**x)) + } + } + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).StringPtr() + if *x == nil { + return conv.Zero() + } + return protoreflect.ValueOfString(**x) + } + case reflect.Slice: + if fd.Kind() == protoreflect.StringKind { + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Bytes() + if len(*x) == 0 { + return conv.Zero() + } + return protoreflect.ValueOfString(string(*x)) + } + } + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Bytes() + if *x == nil { + return conv.Zero() + } + return protoreflect.ValueOfBytes(*x) + } + } + panic("unexpected protobuf kind: " + ft.Kind().String()) +} + +func getterForDirectScalar(fd protoreflect.FieldDescriptor, fs reflect.StructField, conv Converter, fieldOffset offset) func(p pointer) protoreflect.Value { + ft := fs.Type + if fd.Kind() == protoreflect.EnumKind { + // Enums for non nullable types. + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() + return conv.PBValueOf(rv) + } + } + switch ft.Kind() { + case reflect.Bool: + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Bool() + return protoreflect.ValueOfBool(*x) + } + case reflect.Int32: + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Int32() + return protoreflect.ValueOfInt32(*x) + } + case reflect.Uint32: + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Uint32() + return protoreflect.ValueOfUint32(*x) + } + case reflect.Int64: + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Int64() + return protoreflect.ValueOfInt64(*x) + } + case reflect.Uint64: + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Uint64() + return protoreflect.ValueOfUint64(*x) + } + case reflect.Float32: + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Float32() + return protoreflect.ValueOfFloat32(*x) + } + case reflect.Float64: + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Float64() + return protoreflect.ValueOfFloat64(*x) + } + case reflect.String: + if fd.Kind() == protoreflect.BytesKind { + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).String() + if len(*x) == 0 { + return protoreflect.ValueOfBytes(nil) + } + return protoreflect.ValueOfBytes([]byte(*x)) + } + } + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).String() + return protoreflect.ValueOfString(*x) + } + case reflect.Slice: + if fd.Kind() == protoreflect.StringKind { + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Bytes() + return protoreflect.ValueOfString(string(*x)) + } + } + return func(p pointer) protoreflect.Value { + if p.IsNil() { + return conv.Zero() + } + x := p.Apply(fieldOffset).Bytes() + return protoreflect.ValueOfBytes(*x) + } + } + panic("unexpected protobuf kind: " + ft.Kind().String()) +} diff --git a/internal/impl/pointer_unsafe.go b/internal/impl/pointer_unsafe.go index 79e186667..041ebde2d 100644 --- a/internal/impl/pointer_unsafe.go +++ b/internal/impl/pointer_unsafe.go @@ -8,6 +8,8 @@ import ( "reflect" "sync/atomic" "unsafe" + + "google.golang.org/protobuf/internal/protolazy" ) const UnsafeEnabled = true @@ -111,6 +113,13 @@ func (p pointer) BytesPtr() **[]byte { return (**[]byte)(p.p) func (p pointer) BytesSlice() *[][]byte { return (*[][]byte)(p.p) } func (p pointer) WeakFields() *weakFields { return (*weakFields)(p.p) } func (p pointer) Extensions() *map[int32]ExtensionField { return (*map[int32]ExtensionField)(p.p) } +func (p pointer) LazyInfoPtr() **protolazy.XXX_lazyUnmarshalInfo { + return (**protolazy.XXX_lazyUnmarshalInfo)(p.p) +} + +func (p pointer) PresenceInfo() presence { + return presence{P: p.p} +} func (p pointer) Elem() pointer { return pointer{p: *(*unsafe.Pointer)(p.p)} diff --git a/internal/impl/pointer_unsafe_opaque.go b/internal/impl/pointer_unsafe_opaque.go new file mode 100644 index 000000000..38aa7b7dc --- /dev/null +++ b/internal/impl/pointer_unsafe_opaque.go @@ -0,0 +1,42 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package impl + +import ( + "sync/atomic" + "unsafe" +) + +func (p pointer) AtomicGetPointer() pointer { + return pointer{p: atomic.LoadPointer((*unsafe.Pointer)(p.p))} +} + +func (p pointer) AtomicSetPointer(v pointer) { + atomic.StorePointer((*unsafe.Pointer)(p.p), v.p) +} + +func (p pointer) AtomicSetNilPointer() { + atomic.StorePointer((*unsafe.Pointer)(p.p), unsafe.Pointer(nil)) +} + +func (p pointer) AtomicSetPointerIfNil(v pointer) pointer { + if atomic.CompareAndSwapPointer((*unsafe.Pointer)(p.p), unsafe.Pointer(nil), v.p) { + return v + } + return pointer{p: atomic.LoadPointer((*unsafe.Pointer)(p.p))} +} + +type atomicV1MessageInfo struct{ p Pointer } + +func (mi *atomicV1MessageInfo) Get() Pointer { + return Pointer(atomic.LoadPointer((*unsafe.Pointer)(&mi.p))) +} + +func (mi *atomicV1MessageInfo) SetIfNil(p Pointer) Pointer { + if atomic.CompareAndSwapPointer((*unsafe.Pointer)(&mi.p), nil, unsafe.Pointer(p)) { + return p + } + return mi.Get() +} diff --git a/internal/impl/presence.go b/internal/impl/presence.go new file mode 100644 index 000000000..914cb1ded --- /dev/null +++ b/internal/impl/presence.go @@ -0,0 +1,142 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package impl + +import ( + "sync/atomic" + "unsafe" +) + +// presenceSize represents the size of a presence set, which should be the largest index of the set+1 +type presenceSize uint32 + +// presence is the internal representation of the bitmap array in a generated protobuf +type presence struct { + // This is a pointer to the beginning of an array of uint32 + P unsafe.Pointer +} + +func (p presence) toElem(num uint32) (ret *uint32) { + const ( + bitsPerByte = 8 + siz = unsafe.Sizeof(*ret) + ) + // p.P points to an array of uint32, num is the bit in this array that the + // caller wants to check/manipulate. Calculate the index in the array that + // contains this specific bit. E.g.: 76 / 32 = 2 (integer division). + offset := uintptr(num) / (siz * bitsPerByte) * siz + return (*uint32)(unsafe.Pointer(uintptr(p.P) + offset)) +} + +// Present checks for the presence of a specific field number in a presence set. +func (p presence) Present(num uint32) bool { + if p.P == nil { + return false + } + return Export{}.Present(p.toElem(num), num) +} + +// SetPresent adds presence for a specific field number in a presence set. +func (p presence) SetPresent(num uint32, size presenceSize) { + Export{}.SetPresent(p.toElem(num), num, uint32(size)) +} + +// SetPresentUnatomic adds presence for a specific field number in a presence set without using +// atomic operations. Only to be called during unmarshaling. +func (p presence) SetPresentUnatomic(num uint32, size presenceSize) { + Export{}.SetPresentNonAtomic(p.toElem(num), num, uint32(size)) +} + +// ClearPresent removes presence for a specific field number in a presence set. +func (p presence) ClearPresent(num uint32) { + Export{}.ClearPresent(p.toElem(num), num) +} + +// LoadPresenceCache (together with PresentInCache) allows for a +// cached version of checking for presence without re-reading the word +// for every field. It is optimized for efficiency and assumes no +// simltaneous mutation of the presence set (or at least does not have +// a problem with simultaneous mutation giving inconsistent results). +func (p presence) LoadPresenceCache() (current uint32) { + if p.P == nil { + return 0 + } + return atomic.LoadUint32((*uint32)(p.P)) +} + +// PresentInCache reads presence from a cached word in the presence +// bitmap. It caches up a new word if the bit is outside the +// word. This is for really fast iteration through bitmaps in cases +// where we either know that the bitmap will not be altered, or we +// don't care about inconsistencies caused by simultaneous writes. +func (p presence) PresentInCache(num uint32, cachedElement *uint32, current *uint32) bool { + if num/32 != *cachedElement { + o := uintptr(num/32) * unsafe.Sizeof(uint32(0)) + q := (*uint32)(unsafe.Pointer(uintptr(p.P) + o)) + *current = atomic.LoadUint32(q) + *cachedElement = num / 32 + } + return (*current & (1 << (num % 32))) > 0 +} + +// AnyPresent checks if any field is marked as present in the bitmap. +func (p presence) AnyPresent(size presenceSize) bool { + n := uintptr((size + 31) / 32) + for j := uintptr(0); j < n; j++ { + o := j * unsafe.Sizeof(uint32(0)) + q := (*uint32)(unsafe.Pointer(uintptr(p.P) + o)) + b := atomic.LoadUint32(q) + if b > 0 { + return true + } + } + return false +} + +// toRaceDetectData finds the preceding RaceDetectHookData in a +// message by using pointer arithmetic. As the type of the presence +// set (bitmap) varies with the number of fields in the protobuf, we +// can not have a struct type containing the array and the +// RaceDetectHookData. instead the RaceDetectHookData is placed +// immediately before the bitmap array, and we find it by walking +// backwards in the struct. +// +// This method is only called from the race-detect version of the code, +// so RaceDetectHookData is never an empty struct. +func (p presence) toRaceDetectData() *RaceDetectHookData { + var template struct { + d RaceDetectHookData + a [1]uint32 + } + o := (uintptr(unsafe.Pointer(&template.a)) - uintptr(unsafe.Pointer(&template.d))) + return (*RaceDetectHookData)(unsafe.Pointer(uintptr(p.P) - o)) +} + +func atomicLoadShadowPresence(p **[]byte) *[]byte { + return (*[]byte)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(p)))) +} +func atomicStoreShadowPresence(p **[]byte, v *[]byte) { + atomic.CompareAndSwapPointer((*unsafe.Pointer)(unsafe.Pointer(p)), nil, unsafe.Pointer(v)) +} + +// findPointerToRaceDetectData finds the preceding RaceDetectHookData +// in a message by using pointer arithmetic. For the methods called +// directy from generated code, we don't have a pointer to the +// beginning of the presence set, but a pointer inside the array. As +// we know the index of the bit we're manipulating (num), we can +// calculate which element of the array ptr is pointing to. With that +// information we find the preceding RaceDetectHookData and can +// manipulate the shadow bitmap. +// +// This method is only called from the race-detect version of the +// code, so RaceDetectHookData is never an empty struct. +func findPointerToRaceDetectData(ptr *uint32, num uint32) *RaceDetectHookData { + var template struct { + d RaceDetectHookData + a [1]uint32 + } + o := (uintptr(unsafe.Pointer(&template.a)) - uintptr(unsafe.Pointer(&template.d))) + uintptr(num/32)*unsafe.Sizeof(uint32(0)) + return (*RaceDetectHookData)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) - o)) +} diff --git a/internal/impl/validate.go b/internal/impl/validate.go index a24e6bbd7..b534a3d6d 100644 --- a/internal/impl/validate.go +++ b/internal/impl/validate.go @@ -37,6 +37,10 @@ const ( // ValidationValid indicates that unmarshaling the message will succeed. ValidationValid + + // ValidationWrongWireType indicates that a validated field does not have + // the expected wire type. + ValidationWrongWireType ) func (v ValidationStatus) String() string { @@ -149,11 +153,23 @@ func newValidationInfo(fd protoreflect.FieldDescriptor, ft reflect.Type) validat switch fd.Kind() { case protoreflect.MessageKind: vi.typ = validationTypeMessage + + if ft.Kind() == reflect.Ptr { + // Repeated opaque message fields are *[]*T. + ft = ft.Elem() + } + if ft.Kind() == reflect.Slice { vi.mi = getMessageInfo(ft.Elem()) } case protoreflect.GroupKind: vi.typ = validationTypeGroup + + if ft.Kind() == reflect.Ptr { + // Repeated opaque message fields are *[]*T. + ft = ft.Elem() + } + if ft.Kind() == reflect.Slice { vi.mi = getMessageInfo(ft.Elem()) } diff --git a/internal/protolazy/bufferreader.go b/internal/protolazy/bufferreader.go new file mode 100644 index 000000000..82e5cab4a --- /dev/null +++ b/internal/protolazy/bufferreader.go @@ -0,0 +1,364 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Helper code for parsing a protocol buffer + +package protolazy + +import ( + "errors" + "fmt" + "io" + + "google.golang.org/protobuf/encoding/protowire" +) + +// BufferReader is a structure encapsulating a protobuf and a current position +type BufferReader struct { + Buf []byte + Pos int +} + +// NewBufferReader creates a new BufferRead from a protobuf +func NewBufferReader(buf []byte) BufferReader { + return BufferReader{Buf: buf, Pos: 0} +} + +var errOutOfBounds = errors.New("protobuf decoding: out of bounds") +var errOverflow = errors.New("proto: integer overflow") + +func (b *BufferReader) DecodeVarintSlow() (x uint64, err error) { + i := b.Pos + l := len(b.Buf) + + for shift := uint(0); shift < 64; shift += 7 { + if i >= l { + err = io.ErrUnexpectedEOF + return + } + v := b.Buf[i] + i++ + x |= (uint64(v) & 0x7F) << shift + if v < 0x80 { + b.Pos = i + return + } + } + + // The number is too large to represent in a 64-bit value. + err = errOverflow + return +} + +// decodeVarint decodes a varint at the current position +func (b *BufferReader) DecodeVarint() (x uint64, err error) { + i := b.Pos + buf := b.Buf + + if i >= len(buf) { + return 0, io.ErrUnexpectedEOF + } else if buf[i] < 0x80 { + b.Pos++ + return uint64(buf[i]), nil + } else if len(buf)-i < 10 { + return b.DecodeVarintSlow() + } + + var v uint64 + // we already checked the first byte + x = uint64(buf[i]) & 127 + i++ + + v = uint64(buf[i]) + i++ + x |= (v & 127) << 7 + if v < 128 { + goto done + } + + v = uint64(buf[i]) + i++ + x |= (v & 127) << 14 + if v < 128 { + goto done + } + + v = uint64(buf[i]) + i++ + x |= (v & 127) << 21 + if v < 128 { + goto done + } + + v = uint64(buf[i]) + i++ + x |= (v & 127) << 28 + if v < 128 { + goto done + } + + v = uint64(buf[i]) + i++ + x |= (v & 127) << 35 + if v < 128 { + goto done + } + + v = uint64(buf[i]) + i++ + x |= (v & 127) << 42 + if v < 128 { + goto done + } + + v = uint64(buf[i]) + i++ + x |= (v & 127) << 49 + if v < 128 { + goto done + } + + v = uint64(buf[i]) + i++ + x |= (v & 127) << 56 + if v < 128 { + goto done + } + + v = uint64(buf[i]) + i++ + x |= (v & 127) << 63 + if v < 128 { + goto done + } + + return 0, errOverflow + +done: + b.Pos = i + return +} + +// decodeVarint32 decodes a varint32 at the current position +func (b *BufferReader) DecodeVarint32() (x uint32, err error) { + i := b.Pos + buf := b.Buf + + if i >= len(buf) { + return 0, io.ErrUnexpectedEOF + } else if buf[i] < 0x80 { + b.Pos++ + return uint32(buf[i]), nil + } else if len(buf)-i < 5 { + v, err := b.DecodeVarintSlow() + return uint32(v), err + } + + var v uint32 + // we already checked the first byte + x = uint32(buf[i]) & 127 + i++ + + v = uint32(buf[i]) + i++ + x |= (v & 127) << 7 + if v < 128 { + goto done + } + + v = uint32(buf[i]) + i++ + x |= (v & 127) << 14 + if v < 128 { + goto done + } + + v = uint32(buf[i]) + i++ + x |= (v & 127) << 21 + if v < 128 { + goto done + } + + v = uint32(buf[i]) + i++ + x |= (v & 127) << 28 + if v < 128 { + goto done + } + + return 0, errOverflow + +done: + b.Pos = i + return +} + +// skipValue skips a value in the protobuf, based on the specified tag +func (b *BufferReader) SkipValue(tag uint32) (err error) { + wireType := tag & 0x7 + switch protowire.Type(wireType) { + case protowire.VarintType: + err = b.SkipVarint() + case protowire.Fixed64Type: + err = b.SkipFixed64() + case protowire.BytesType: + var n uint32 + n, err = b.DecodeVarint32() + if err == nil { + err = b.Skip(int(n)) + } + case protowire.StartGroupType: + err = b.SkipGroup(tag) + case protowire.Fixed32Type: + err = b.SkipFixed32() + default: + err = fmt.Errorf("Unexpected wire type (%d)", wireType) + } + return +} + +// skipGroup skips a group with the specified tag. It executes efficiently using a tag stack +func (b *BufferReader) SkipGroup(tag uint32) (err error) { + tagStack := make([]uint32, 0, 16) + tagStack = append(tagStack, tag) + var n uint32 + for len(tagStack) > 0 { + tag, err = b.DecodeVarint32() + if err != nil { + return err + } + switch protowire.Type(tag & 0x7) { + case protowire.VarintType: + err = b.SkipVarint() + case protowire.Fixed64Type: + err = b.Skip(8) + case protowire.BytesType: + n, err = b.DecodeVarint32() + if err == nil { + err = b.Skip(int(n)) + } + case protowire.StartGroupType: + tagStack = append(tagStack, tag) + case protowire.Fixed32Type: + err = b.SkipFixed32() + case protowire.EndGroupType: + if protoFieldNumber(tagStack[len(tagStack)-1]) == protoFieldNumber(tag) { + tagStack = tagStack[:len(tagStack)-1] + } else { + err = fmt.Errorf("end group tag %d does not match begin group tag %d at pos %d", + protoFieldNumber(tag), protoFieldNumber(tagStack[len(tagStack)-1]), b.Pos) + } + } + if err != nil { + return err + } + } + return nil +} + +// skipVarint effiently skips a varint +func (b *BufferReader) SkipVarint() (err error) { + i := b.Pos + + if len(b.Buf)-i < 10 { + // Use DecodeVarintSlow() to check for buffer overflow, but ignore result + if _, err := b.DecodeVarintSlow(); err != nil { + return err + } + return nil + } + + if b.Buf[i] < 0x80 { + goto out + } + i++ + + if b.Buf[i] < 0x80 { + goto out + } + i++ + + if b.Buf[i] < 0x80 { + goto out + } + i++ + + if b.Buf[i] < 0x80 { + goto out + } + i++ + + if b.Buf[i] < 0x80 { + goto out + } + i++ + + if b.Buf[i] < 0x80 { + goto out + } + i++ + + if b.Buf[i] < 0x80 { + goto out + } + i++ + + if b.Buf[i] < 0x80 { + goto out + } + i++ + + if b.Buf[i] < 0x80 { + goto out + } + i++ + + if b.Buf[i] < 0x80 { + goto out + } + return errOverflow + +out: + b.Pos = i + 1 + return nil +} + +// skip skips the specified number of bytes +func (b *BufferReader) Skip(n int) (err error) { + if len(b.Buf) < b.Pos+n { + return io.ErrUnexpectedEOF + } + b.Pos += n + return +} + +// skipFixed64 skips a fixed64 +func (b *BufferReader) SkipFixed64() (err error) { + return b.Skip(8) +} + +// skipFixed32 skips a fixed32 +func (b *BufferReader) SkipFixed32() (err error) { + return b.Skip(4) +} + +// skipBytes skips a set of bytes +func (b *BufferReader) SkipBytes() (err error) { + n, err := b.DecodeVarint32() + if err != nil { + return err + } + return b.Skip(int(n)) +} + +// Done returns whether we are at the end of the protobuf +func (b *BufferReader) Done() bool { + return b.Pos == len(b.Buf) +} + +// Remaining returns how many bytes remain +func (b *BufferReader) Remaining() int { + return len(b.Buf) - b.Pos +} diff --git a/internal/protolazy/lazy.go b/internal/protolazy/lazy.go new file mode 100644 index 000000000..ff4d4834b --- /dev/null +++ b/internal/protolazy/lazy.go @@ -0,0 +1,359 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package protolazy contains internal data structures for lazy message decoding. +package protolazy + +import ( + "fmt" + "sort" + + "google.golang.org/protobuf/encoding/protowire" + piface "google.golang.org/protobuf/runtime/protoiface" +) + +// IndexEntry is the structure for an index of the fields in a message of a +// proto (not descending to sub-messages) +type IndexEntry struct { + FieldNum uint32 + // first byte of this tag/field + Start uint32 + // first byte after a contiguous sequence of bytes for this tag/field, which could + // include a single encoding of the field, or multiple encodings for the field + End uint32 + // True if this protobuf segment includes multiple encodings of the field + MultipleContiguous bool +} + +// XXX_lazyUnmarshalInfo has information about a particular lazily decoded message +// +// Deprecated: Do not use. This will be deleted in the near future. +type XXX_lazyUnmarshalInfo struct { + // Index of fields and their positions in the protobuf for this + // message. Make index be a pointer to a slice so it can be updated + // atomically. The index pointer is only set once (lazily when/if + // the index is first needed), and must always be SET and LOADED + // ATOMICALLY. + index *[]IndexEntry + // The protobuf associated with this lazily decoded message. It is + // only set during proto.Unmarshal(). It doesn't need to be set and + // loaded atomically, since any simultaneous set (Unmarshal) and read + // (during a get) would already be a race in the app code. + Protobuf []byte + // The flags present when Unmarshal was originally called for this particular message + unmarshalFlags piface.UnmarshalInputFlags +} + +// The Buffer and SetBuffer methods let v2/internal/impl interact with +// XXX_lazyUnmarshalInfo via an interface, to avoid an import cycle. + +// Buffer returns the lazy unmarshal buffer. +// +// Deprecated: Do not use. This will be deleted in the near future. +func (lazy *XXX_lazyUnmarshalInfo) Buffer() []byte { + return lazy.Protobuf +} + +// SetBuffer sets the lazy unmarshal buffer. +// +// Deprecated: Do not use. This will be deleted in the near future. +func (lazy *XXX_lazyUnmarshalInfo) SetBuffer(b []byte) { + lazy.Protobuf = b +} + +// SetUnmarshalFlags is called to set a copy of the original unmarshalInputFlags. +// The flags should reflect how Unmarshal was called. +func (lazy *XXX_lazyUnmarshalInfo) SetUnmarshalFlags(f piface.UnmarshalInputFlags) { + lazy.unmarshalFlags = f +} + +// UnmarshalFlags returns the original unmarshalInputFlags. +func (lazy *XXX_lazyUnmarshalInfo) UnmarshalFlags() piface.UnmarshalInputFlags { + return lazy.unmarshalFlags +} + +// AllowedPartial returns true if the user originally unmarshalled this message with +// AllowPartial set to true +func (lazy *XXX_lazyUnmarshalInfo) AllowedPartial() bool { + return (lazy.unmarshalFlags & piface.UnmarshalCheckRequired) == 0 +} + +func protoFieldNumber(tag uint32) uint32 { + return tag >> 3 +} + +// buildIndex builds an index of the specified protobuf, return the index +// array and an error. +func buildIndex(buf []byte) ([]IndexEntry, error) { + index := make([]IndexEntry, 0, 16) + var lastProtoFieldNum uint32 + var outOfOrder bool + + var r BufferReader = NewBufferReader(buf) + + for !r.Done() { + var tag uint32 + var err error + var curPos = r.Pos + // INLINED: tag, err = r.DecodeVarint32() + { + i := r.Pos + buf := r.Buf + + if i >= len(buf) { + return nil, errOutOfBounds + } else if buf[i] < 0x80 { + r.Pos++ + tag = uint32(buf[i]) + } else if r.Remaining() < 5 { + var v uint64 + v, err = r.DecodeVarintSlow() + tag = uint32(v) + } else { + var v uint32 + // we already checked the first byte + tag = uint32(buf[i]) & 127 + i++ + + v = uint32(buf[i]) + i++ + tag |= (v & 127) << 7 + if v < 128 { + goto done + } + + v = uint32(buf[i]) + i++ + tag |= (v & 127) << 14 + if v < 128 { + goto done + } + + v = uint32(buf[i]) + i++ + tag |= (v & 127) << 21 + if v < 128 { + goto done + } + + v = uint32(buf[i]) + i++ + tag |= (v & 127) << 28 + if v < 128 { + goto done + } + + return nil, errOutOfBounds + + done: + r.Pos = i + } + } + // DONE: tag, err = r.DecodeVarint32() + + fieldNum := protoFieldNumber(tag) + if fieldNum < lastProtoFieldNum { + outOfOrder = true + } + + // Skip the current value -- will skip over an entire group as well. + // INLINED: err = r.SkipValue(tag) + wireType := tag & 0x7 + switch protowire.Type(wireType) { + case protowire.VarintType: + // INLINED: err = r.SkipVarint() + i := r.Pos + + if len(r.Buf)-i < 10 { + // Use DecodeVarintSlow() to skip while + // checking for buffer overflow, but ignore result + _, err = r.DecodeVarintSlow() + goto out2 + } + if r.Buf[i] < 0x80 { + goto out + } + i++ + + if r.Buf[i] < 0x80 { + goto out + } + i++ + + if r.Buf[i] < 0x80 { + goto out + } + i++ + + if r.Buf[i] < 0x80 { + goto out + } + i++ + + if r.Buf[i] < 0x80 { + goto out + } + i++ + + if r.Buf[i] < 0x80 { + goto out + } + i++ + + if r.Buf[i] < 0x80 { + goto out + } + i++ + + if r.Buf[i] < 0x80 { + goto out + } + i++ + + if r.Buf[i] < 0x80 { + goto out + } + i++ + + if r.Buf[i] < 0x80 { + goto out + } + return nil, errOverflow + out: + r.Pos = i + 1 + // DONE: err = r.SkipVarint() + case protowire.Fixed64Type: + err = r.SkipFixed64() + case protowire.BytesType: + var n uint32 + n, err = r.DecodeVarint32() + if err == nil { + err = r.Skip(int(n)) + } + case protowire.StartGroupType: + err = r.SkipGroup(tag) + case protowire.Fixed32Type: + err = r.SkipFixed32() + default: + err = fmt.Errorf("Unexpected wire type (%d)", wireType) + } + // DONE: err = r.SkipValue(tag) + + out2: + if err != nil { + return nil, err + } + if fieldNum != lastProtoFieldNum { + index = append(index, IndexEntry{FieldNum: fieldNum, + Start: uint32(curPos), + End: uint32(r.Pos)}, + ) + } else { + index[len(index)-1].End = uint32(r.Pos) + index[len(index)-1].MultipleContiguous = true + } + lastProtoFieldNum = fieldNum + } + if outOfOrder { + sort.Slice(index, func(i, j int) bool { + return index[i].FieldNum < index[j].FieldNum || + (index[i].FieldNum == index[j].FieldNum && + index[i].Start < index[j].Start) + }) + } + return index, nil +} + +func (lazy *XXX_lazyUnmarshalInfo) SizeField(num uint32) (size int) { + start, end, found, _, multipleEntries := lazy.FindFieldInProto(num) + if multipleEntries != nil { + for _, entry := range multipleEntries { + size += int(entry.End - entry.Start) + } + return size + } + if !found { + return 0 + } + return int(end - start) +} + +func (lazy *XXX_lazyUnmarshalInfo) AppendField(b []byte, num uint32) ([]byte, bool) { + start, end, found, _, multipleEntries := lazy.FindFieldInProto(num) + if multipleEntries != nil { + for _, entry := range multipleEntries { + b = append(b, lazy.Protobuf[entry.Start:entry.End]...) + } + return b, true + } + if !found { + return nil, false + } + b = append(b, lazy.Protobuf[start:end]...) + return b, true +} + +func (lazy *XXX_lazyUnmarshalInfo) SetIndex(index []IndexEntry) { + atomicStoreIndex(&lazy.index, &index) +} + +// FindFieldInProto looks for field fieldNum in lazyUnmarshalInfo information +// (including protobuf), returns startOffset/endOffset/found. +func (lazy *XXX_lazyUnmarshalInfo) FindFieldInProto(fieldNum uint32) (start, end uint32, found, multipleContiguous bool, multipleEntries []IndexEntry) { + if lazy.Protobuf == nil { + // There is no backing protobuf for this message -- it was made from a builder + return 0, 0, false, false, nil + } + index := atomicLoadIndex(&lazy.index) + if index == nil { + r, err := buildIndex(lazy.Protobuf) + if err != nil { + panic(fmt.Sprintf("findFieldInfo: error building index when looking for field %d: %v", fieldNum, err)) + } + // lazy.index is a pointer to the slice returned by BuildIndex + index = &r + atomicStoreIndex(&lazy.index, index) + } + return lookupField(index, fieldNum) +} + +// lookupField returns the offset at which the indicated field starts using +// the index, offset immediately after field ends (including all instances of +// a repeated field), and bools indicating if field was found and if there +// are multiple encodings of the field in the byte range. +// +// To hande the uncommon case where there are repeated encodings for the same +// field which are not consecutive in the protobuf (so we need to returns +// multiple start/end offsets), we also return a slice multipleEntries. If +// multipleEntries is non-nil, then multiple entries were found, and the +// values in the slice should be used, rather than start/end/found. +func lookupField(indexp *[]IndexEntry, fieldNum uint32) (start, end uint32, found bool, multipleContiguous bool, multipleEntries []IndexEntry) { + // The pointer indexp to the index was already loaded atomically. + // The slice is uniquely associated with the pointer, so it doesn't + // need to be loaded atomically. + index := *indexp + for i, entry := range index { + if fieldNum == entry.FieldNum { + if i < len(index)-1 && entry.FieldNum == index[i+1].FieldNum { + // Handle the uncommon case where there are + // repeated entries for the same field which + // are not contiguous in the protobuf. + multiple := make([]IndexEntry, 1, 2) + multiple[0] = IndexEntry{fieldNum, entry.Start, entry.End, entry.MultipleContiguous} + i++ + for i < len(index) && index[i].FieldNum == fieldNum { + multiple = append(multiple, IndexEntry{fieldNum, index[i].Start, index[i].End, index[i].MultipleContiguous}) + i++ + } + return 0, 0, false, false, multiple + + } + return entry.Start, entry.End, true, entry.MultipleContiguous, nil + } + if fieldNum < entry.FieldNum { + return 0, 0, false, false, nil + } + } + return 0, 0, false, false, nil +} diff --git a/internal/protolazy/pointer_unsafe.go b/internal/protolazy/pointer_unsafe.go new file mode 100644 index 000000000..dc2a64ca6 --- /dev/null +++ b/internal/protolazy/pointer_unsafe.go @@ -0,0 +1,17 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package protolazy + +import ( + "sync/atomic" + "unsafe" +) + +func atomicLoadIndex(p **[]IndexEntry) *[]IndexEntry { + return (*[]IndexEntry)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(p)))) +} +func atomicStoreIndex(p **[]IndexEntry, v *[]IndexEntry) { + atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(p)), unsafe.Pointer(v)) +} diff --git a/internal/race_test/lazy/lazy_race_test.go b/internal/race_test/lazy/lazy_race_test.go new file mode 100644 index 000000000..7d3868345 --- /dev/null +++ b/internal/race_test/lazy/lazy_race_test.go @@ -0,0 +1,494 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This test tests that races on lazy fields in opaque protos are detected by the race detector, +// even though the plain code uses atomic variables in a manner that would hide data races. +// This is essential, as concurrent writes or read-writes on a lazy field can cause undefined +// behaviours. +// +// Using exectest with the race detector to check that the code fails did not work, +// as the race error got propagated from the subprocess and failed the test case in the parent process. +// Instead we create the subprocess where the test is supposed to fail by ourselves. + +// Lazy decoding is only available in the fast path, which the protoreflect tag disables. +//go:build !protoreflect + +package lazy_race_test + +import ( + "fmt" + "os" + "os/exec" + "reflect" + "sync" + "testing" + "unsafe" + + "google.golang.org/protobuf/internal/test/race" + mixedpb "google.golang.org/protobuf/internal/testprotos/mixed" + testopaquepb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_opaque" + "google.golang.org/protobuf/proto" +) + +// To get some output from the subprocess, set this to true +const debug = false + +func makeM2() *testopaquepb.TestAllTypes { + return testopaquepb.TestAllTypes_builder{ + OptionalLazyNestedMessage: testopaquepb.TestAllTypes_NestedMessage_builder{ + A: proto.Int32(1), + Corecursive: testopaquepb.TestAllTypes_builder{ + OptionalBool: proto.Bool(true), + }.Build(), + }.Build(), + RepeatedNestedMessage: []*testopaquepb.TestAllTypes_NestedMessage{ + testopaquepb.TestAllTypes_NestedMessage_builder{ + A: proto.Int32(2), + Corecursive: testopaquepb.TestAllTypes_builder{ + OptionalInt32: proto.Int32(32), + }.Build(), + }.Build(), + }, + }.Build() +} + +type testC struct { + name string + l1 func() + l2 func() +} + +const envVar = "GO_TESTING_IN_SUBPROCESS" + +// TestRaceDetectionOnWrite tests that any combination involving concurrent +// read-write or write-write will trigger the race detector. +func TestRaceDetectionOnWrite(t *testing.T) { + var x *testopaquepb.TestAllTypes + var y *testopaquepb.TestAllTypes_NestedMessage + var z int32 + // A table of test cases to expose to the race detector. + // The name will be set in an environment variable, so don't use special characters or spaces. + // Each entry in the table will be spawned into a sub process, where the actual execution will happen. + cases := []testC{ + { + name: "TestSetSet", + l1: func() { x.SetOptionalLazyNestedMessage(y) }, + l2: func() { x.SetOptionalLazyNestedMessage(y) }, + }, + { + name: "TestClearClear", + l1: func() { x.ClearOptionalLazyNestedMessage() }, + l2: func() { x.ClearOptionalLazyNestedMessage() }, + }, + { + name: "TestSetClear", + l1: func() { x.SetOptionalLazyNestedMessage(y) }, + l2: func() { x.ClearOptionalLazyNestedMessage() }, + }, + { + name: "TestSetGet", + l1: func() { x.SetOptionalLazyNestedMessage(y) }, + l2: func() { + if x.GetOptionalLazyNestedMessage().GetCorecursive().GetOptionalBool() { + z++ + } + }, + }, + { + name: "TestSetHas", + l1: func() { x.SetOptionalLazyNestedMessage(y) }, + l2: func() { + if x.HasOptionalLazyNestedMessage() { + z++ + } + }, + }, + { + name: "TestClearGet", + l1: func() { x.ClearOptionalLazyNestedMessage() }, + l2: func() { + if x.GetOptionalLazyNestedMessage().GetCorecursive().GetOptionalBool() { + z++ + } + }, + }, + { + name: "TestClearHas", + l1: func() { x.ClearOptionalLazyNestedMessage() }, + l2: func() { + if x.HasOptionalLazyNestedMessage() { + z++ + } + }, + }, + } + e := os.Getenv(envVar) + if e != "" { + // We're in the subprocess. As spawnCase will add filter for the subtest, + // we will actually only execute one test in this subprocess even though + // we call t.Run for all cases. + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + x = makeM2() + y = x.GetOptionalLazyNestedMessage() + z = 0 + execCase(t, tc) + return + }) + } + return + } + // If we're not in a subprocess, spawn and check one for each entry in the table + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + spawnCase(t) + }) + } +} + +// execCase actually executes the testcase when we're in a subprocess, it executes +// the two operations of tc in parallel and make sure tsan sees this as parallel +// execution. +func execCase(t *testing.T, tc testC) { + t.Helper() + c1 := make(chan struct{}) + wg := sync.WaitGroup{} + wg.Add(2) + // This is a very complicated but stable way of telling tsan that the + // two operations are executed in parallel. I can only guess why this + // works so I'll leave my speculations out of the comment but + // experiments suggest that it works reliably. + go func() { + c1 <- struct{}{} + tc.l1() + <-c1 + tc.l1() + wg.Done() + }() + go func() { + <-c1 + tc.l2() + c1 <- struct{}{} + tc.l2() + wg.Done() + }() + wg.Wait() +} + +// spawnCase reruns this executable to execute t.Name() with the sub-case tn in the environment variable +func spawnCase(t *testing.T) { + // If we get here, we are in the parent process and should execute ourselves, but filter on the test that called us. + ep, err := os.Executable() + if err != nil { + t.Fatalf("Failed to find my own executable: %v", err) + } + c := exec.Command(ep, "--test.run="+t.Name()) + // Set the environment variable so that we know we're in a subproceess when re-executed + c.Env = append(c.Env, envVar+"=true") + out, err := c.CombinedOutput() + // If we do not get an error, we fail in the parent process, otherwise we're good + if race.Enabled && err == nil { + t.Errorf("Got success, want error under race detector:\n-----------\n%s\n-------------\n", string(out)) + } + if !race.Enabled && err != nil { + t.Errorf("Got error, want success without race detector:\n-----------\n%s\n-------------\n", string(out)) + } + if debug { + fmt.Fprintf(os.Stderr, "Subprocess output:\n-----------\n%s\n-------------\n", string(out)) + } +} + +// TestNoRaceDetection should not fail under race detector (or otherwise) +func TestNoRaceDetection(t *testing.T) { + x := makeM2() + var y int32 + var z int32 + c := make(chan struct{}) + go func() { + for i := 0; i < 10000; i++ { + y += x.GetRepeatedNestedMessage()[0].GetA() + } + close(c) + }() + for i := 0; i < 10000; i++ { + z += x.GetRepeatedNestedMessage()[0].GetA() + } + <-c + if z != y { + t.Errorf("The two go-routines did not calculate the same: %d != %d", z, y) + } +} + +func TestNoRaceOnGetsOfSlices(t *testing.T) { + x := makeM2() + b, err := proto.Marshal(x) + if err != nil { + t.Fatalf("Error while marshaling: %v", err) + } + + var y int32 + var z int32 + d := make(chan int) + + // Check that there are no races when we do concurrent lazy gets of a field + // containing a slice of message pointers. + for i := 0; i < 10000; i++ { + err := proto.Unmarshal(b, x) + if err != nil { + t.Fatalf("Error while unmarshaling: %v", err) + } + go func() { + y += x.GetRepeatedNestedMessage()[0].GetA() + d <- 1 + }() + go func() { + z += x.GetRepeatedNestedMessage()[0].GetA() + d <- 1 + }() + <-d + <-d + } + if z != y { + t.Errorf("The two go-routines did not calculate the same: %d != %d", z, y) + } + close(d) +} + +func TestNoRaceOnGetsOfMessages(t *testing.T) { + x := makeM2() + b, err := proto.Marshal(x) + if err != nil { + t.Fatalf("Error while marshaling: %v", err) + } + + var y int32 + var z int32 + d := make(chan int) + + // Check that there is no race when we do concurrent lazy gets of a field + // pointing to a sub-message. + for i := 0; i < 10000; i++ { + err := proto.Unmarshal(b, x) + if err != nil { + t.Fatalf("Error while unmarshaling: %v", err) + } + go func() { + if x.GetOptionalLazyNestedMessage().GetA() > 0 { + y++ + } + d <- 1 + }() + go func() { + if x.GetOptionalLazyNestedMessage().GetA() > 0 { + z++ + } + d <- 1 + }() + <-d + <-d + } + if z != y { + t.Errorf("The two go-routines did not calculate the same: %d != %d", z, y) + } + + close(d) +} + +func fillRequiredLazy() *testopaquepb.TestRequiredLazy { + return testopaquepb.TestRequiredLazy_builder{ + OptionalLazyMessage: testopaquepb.TestRequired_builder{ + RequiredField: proto.Int32(23), + }.Build(), + }.Build() +} + +func expandedLazy(m *testopaquepb.TestRequiredLazy) bool { + v := reflect.ValueOf(m).Elem() + rf := v.FieldByName("xxx_hidden_OptionalLazyMessage") + rf = reflect.NewAt(rf.Type(), unsafe.Pointer(rf.UnsafeAddr())).Elem() + return rf.Pointer() != 0 +} + +// This test verifies all assumptions of TestParallellMarshalWithRequired +// are (still) valid, to prevent the test from becoming a no-op (again). +func TestParallellMarshalWithRequiredAssumptions(t *testing.T) { + b, err := proto.Marshal(fillRequiredLazy()) + if err != nil { + t.Fatal(err) + } + + ml := &testopaquepb.TestRequiredLazy{} + // Specifying AllowPartial: true at unmarshal time is required, otherwise + // the Marshal call will skip the required field check. + if err := (proto.UnmarshalOptions{AllowPartial: true}).Unmarshal(b, ml); err != nil { + t.Fatal(err) + } + if expandedLazy(ml) { + t.Fatalf("lazy message unexpectedly decoded") + } + + // Marshaling with AllowPartial: true means the no decoding is needed, + // because no required field checks are done. + if _, err := (proto.MarshalOptions{AllowPartial: true}).Marshal(ml); err != nil { + t.Fatal(err) + } + if expandedLazy(ml) { + t.Fatalf("lazy message unexpectedly decoded") + } + + // Whereas marshaling with AllowPartial: false (default) means the message + // will be decoded to check if any required fields are not set. + if _, err := (proto.MarshalOptions{AllowPartial: false}).Marshal(ml); err != nil { + t.Fatal(err) + } + if !expandedLazy(ml) { + t.Fatalf("lazy message unexpectedly not decoded") + } +} + +// TestParallellMarshalWithRequired runs two goroutines that marshal the same +// message. Marshaling a message can result in lazily decoding said message, +// provided the message contains any required fields. This test ensures that +// said lazy decoding can happen without causing races in the other goroutine +// that marshals the same message. +func TestParallellMarshalWithRequired(t *testing.T) { + m := fillRequiredLazy() + b, err := proto.MarshalOptions{}.Marshal(m) + if err != nil { + t.Fatal(err) + } + partial := false + for i := 0; i < 1000; i++ { + partial = !partial + ml := &testopaquepb.TestRequiredLazy{} + d := make(chan bool) + err := proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(b, ml) + if err != nil { + t.Fatalf("Error while unmarshaling: %v", err) + } + + go func() { + b2, err := proto.MarshalOptions{AllowPartial: partial}.Marshal(ml) + if err != nil { + t.Errorf("Marshal error: %v", err) + d <- false + return + } + m := &testopaquepb.TestRequiredLazy{} + if err := (proto.UnmarshalOptions{}).Unmarshal(b2, m); err != nil { + t.Errorf("Unmarshal error: %v", err) + d <- false + return + } + if !proto.Equal(ml, m) { + t.Errorf("Unmarshal roundtrip - protos not equal") + d <- false + return + } + d <- true + }() + go func() { + b2, err := proto.MarshalOptions{AllowPartial: partial}.Marshal(ml) + if err != nil { + t.Errorf("Marshal error: %v", err) + d <- false + return + } + m := &testopaquepb.TestRequiredLazy{} + if err := (proto.UnmarshalOptions{}).Unmarshal(b2, m); err != nil { + if !proto.Equal(ml, m) { + t.Errorf("Unmarshal roundtrip - protos not equal") + d <- false + return + } + if !proto.Equal(ml, m) { + t.Errorf("Unmarshal roundtrip - protos not equal") + d <- false + return + } + } + d <- true + }() + x := <-d + y := <-d + if !x || !y { + t.Fatalf("Worker reported error") + } + } +} + +func fillMixedOpaqueLazy() *mixedpb.OpaqueLazy { + return mixedpb.OpaqueLazy_builder{ + Opaque: mixedpb.OpaqueLazy_builder{ + OptionalInt32: proto.Int32(23), + Hybrid: mixedpb.HybridLazy_builder{ + OptionalInt32: proto.Int32(42), + }.Build(), + }.Build(), + Hybrid: mixedpb.HybridLazy_builder{ + OptionalInt32: proto.Int32(5), + }.Build(), + }.Build() +} + +func TestParallellMarshalMixed(t *testing.T) { + m := fillMixedOpaqueLazy() + b, err := proto.Marshal(m) + if err != nil { + t.Fatal(err) + } + for i := 0; i < 10000; i++ { + ml := &mixedpb.OpaqueLazy{} + d := make(chan bool) + if err := proto.Unmarshal(b, ml); err != nil { + t.Fatalf("Error while unmarshaling: %v", err) + } + + go func() { + b2, err := proto.Marshal(ml) + if err != nil { + t.Errorf("Marshal error: %v", err) + d <- false + return + } + m := &mixedpb.OpaqueLazy{} + if err := proto.Unmarshal(b2, m); err != nil { + t.Errorf("Unmarshal error: %v", err) + d <- false + return + } + if !proto.Equal(ml, m) { // This is what expands all fields of ml + t.Errorf("Unmarshal roundtrip - protos not equal") + d <- false + return + } + d <- true + }() + go func() { + b2, err := proto.Marshal(ml) + if err != nil { + t.Errorf("Marshal error: %v", err) + d <- false + return + } + m := &mixedpb.OpaqueLazy{} + if err := proto.Unmarshal(b2, m); err != nil { + t.Errorf("Unmarshal error: %v", err) + d <- false + return + } + if !proto.Equal(ml, m) { // This is what expands all fields of ml + t.Errorf("Unmarshal roundtrip - protos not equal") + d <- false + return + } + d <- true + }() + x := <-d + y := <-d + if !x || !y { + t.Fatalf("Worker reported error") + } + } +} diff --git a/internal/reflection_test/reflection_hybrid_test.go b/internal/reflection_test/reflection_hybrid_test.go new file mode 100644 index 000000000..4230ab847 --- /dev/null +++ b/internal/reflection_test/reflection_hybrid_test.go @@ -0,0 +1,1003 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package reflection_test + +import ( + "fmt" + "math" + "testing" + + testpb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_hybrid" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/testing/prototest" +) + +func TestOpen3Concrete(t *testing.T) { + + prototest.Message{}.Test(t, newTestMessageOpen3(nil).ProtoReflect().Type()) +} + +func TestOpen3Reflection(t *testing.T) { + prototest.Message{}.Test(t, (*testpb.TestAllTypes)(nil).ProtoReflect().Type()) +} + +func TestOpen3Shadow_GetConcrete_SetReflection(t *testing.T) { + prototest.Message{}.Test(t, newShadow(func() (get, set protoreflect.ProtoMessage) { + m := &testpb.TestAllTypes{} + return newTestMessageOpen3(m), m + }).ProtoReflect().Type()) +} + +func TestOpen3Shadow_GetReflection_SetConcrete(t *testing.T) { + prototest.Message{}.Test(t, newShadow(func() (get, set protoreflect.ProtoMessage) { + m := &testpb.TestAllTypes{} + return m, newTestMessageOpen3(m) + }).ProtoReflect().Type()) +} + +func newTestMessageOpen3(m *testpb.TestAllTypes) protoreflect.ProtoMessage { + return &testProtoMessage{ + m: m, + md: m.ProtoReflect().Descriptor(), + new: func() protoreflect.Message { + return newTestMessageOpen3(&testpb.TestAllTypes{}).ProtoReflect() + }, + has: func(num protoreflect.FieldNumber) bool { + switch num { + case fieldSingularInt32: + return m.GetSingularInt32() != 0 + case fieldSingularInt64: + return m.GetSingularInt64() != 0 + case fieldSingularUint32: + return m.GetSingularUint32() != 0 + case fieldSingularUint64: + return m.GetSingularUint64() != 0 + case fieldSingularSint32: + return m.GetSingularSint32() != 0 + case fieldSingularSint64: + return m.GetSingularSint64() != 0 + case fieldSingularFixed32: + return m.GetSingularFixed32() != 0 + case fieldSingularFixed64: + return m.GetSingularFixed64() != 0 + case fieldSingularSfixed32: + return m.GetSingularSfixed32() != 0 + case fieldSingularSfixed64: + return m.GetSingularSfixed64() != 0 + case fieldSingularFloat: + return m.GetSingularFloat() != 0 || math.Signbit(float64(m.GetSingularFloat())) + case fieldSingularDouble: + return m.GetSingularDouble() != 0 || math.Signbit(m.GetSingularDouble()) + case fieldSingularBool: + return m.GetSingularBool() != false + case fieldSingularString: + return m.GetSingularString() != "" + case fieldSingularBytes: + return len(m.GetSingularBytes()) != 0 + case fieldSingularNestedEnum: + return m.GetSingularNestedEnum() != testpb.TestAllTypes_FOO + case fieldSingularForeignEnum: + return m.GetSingularForeignEnum() != testpb.ForeignEnum_FOREIGN_ZERO + case fieldSingularImportEnum: + return m.GetSingularImportEnum() != testpb.ImportEnum_IMPORT_ZERO + + case fieldOptionalInt32: + return m.HasOptionalInt32() + case fieldOptionalInt64: + return m.HasOptionalInt64() + case fieldOptionalUint32: + return m.HasOptionalUint32() + case fieldOptionalUint64: + return m.HasOptionalUint64() + case fieldOptionalSint32: + return m.HasOptionalSint32() + case fieldOptionalSint64: + return m.HasOptionalSint64() + case fieldOptionalFixed32: + return m.HasOptionalFixed32() + case fieldOptionalFixed64: + return m.HasOptionalFixed64() + case fieldOptionalSfixed32: + return m.HasOptionalSfixed32() + case fieldOptionalSfixed64: + return m.HasOptionalSfixed64() + case fieldOptionalFloat: + return m.HasOptionalFloat() + case fieldOptionalDouble: + return m.HasOptionalDouble() + case fieldOptionalBool: + return m.HasOptionalBool() + case fieldOptionalString: + return m.HasOptionalString() + case fieldOptionalBytes: + return m.HasOptionalBytes() + case fieldOptionalGroup: + return m.HasOptionalgroup() + case fieldNotGroupLikeDelimited: + return m.HasNotGroupLikeDelimited() + case fieldOptionalNestedMessage: + return m.HasOptionalNestedMessage() + case fieldOptionalForeignMessage: + return m.HasOptionalForeignMessage() + case fieldOptionalImportMessage: + return m.HasOptionalImportMessage() + case fieldOptionalNestedEnum: + return m.HasOptionalNestedEnum() + case fieldOptionalForeignEnum: + return m.HasOptionalForeignEnum() + case fieldOptionalImportEnum: + return m.HasOptionalImportEnum() + case fieldOptionalLazyNestedMessage: + return m.HasOptionalLazyNestedMessage() + + case fieldRepeatedInt32: + return len(m.GetRepeatedInt32()) > 0 + case fieldRepeatedInt64: + return len(m.GetRepeatedInt64()) > 0 + case fieldRepeatedUint32: + return len(m.GetRepeatedUint32()) > 0 + case fieldRepeatedUint64: + return len(m.GetRepeatedUint64()) > 0 + case fieldRepeatedSint32: + return len(m.GetRepeatedSint32()) > 0 + case fieldRepeatedSint64: + return len(m.GetRepeatedSint64()) > 0 + case fieldRepeatedFixed32: + return len(m.GetRepeatedFixed32()) > 0 + case fieldRepeatedFixed64: + return len(m.GetRepeatedFixed64()) > 0 + case fieldRepeatedSfixed32: + return len(m.GetRepeatedSfixed32()) > 0 + case fieldRepeatedSfixed64: + return len(m.GetRepeatedSfixed64()) > 0 + case fieldRepeatedFloat: + return len(m.GetRepeatedFloat()) > 0 + case fieldRepeatedDouble: + return len(m.GetRepeatedDouble()) > 0 + case fieldRepeatedBool: + return len(m.GetRepeatedBool()) > 0 + case fieldRepeatedString: + return len(m.GetRepeatedString()) > 0 + case fieldRepeatedBytes: + return len(m.GetRepeatedBytes()) > 0 + case fieldRepeatedGroup: + return len(m.GetRepeatedgroup()) > 0 + case fieldRepeatedNestedMessage: + return len(m.GetRepeatedNestedMessage()) > 0 + case fieldRepeatedForeignMessage: + return len(m.GetRepeatedForeignMessage()) > 0 + case fieldRepeatedImportMessage: + return len(m.GetRepeatedImportmessage()) > 0 + case fieldRepeatedNestedEnum: + return len(m.GetRepeatedNestedEnum()) > 0 + case fieldRepeatedForeignEnum: + return len(m.GetRepeatedForeignEnum()) > 0 + case fieldRepeatedImportEnum: + return len(m.GetRepeatedImportenum()) > 0 + + case fieldMapInt32Int32: + return len(m.GetMapInt32Int32()) > 0 + case fieldMapInt64Int64: + return len(m.GetMapInt64Int64()) > 0 + case fieldMapUint32Uint32: + return len(m.GetMapUint32Uint32()) > 0 + case fieldMapUint64Uint64: + return len(m.GetMapUint64Uint64()) > 0 + case fieldMapSint32Sint32: + return len(m.GetMapSint32Sint32()) > 0 + case fieldMapSint64Sint64: + return len(m.GetMapSint64Sint64()) > 0 + case fieldMapFixed32Fixed32: + return len(m.GetMapFixed32Fixed32()) > 0 + case fieldMapFixed64Fixed64: + return len(m.GetMapFixed64Fixed64()) > 0 + case fieldMapSfixed32Sfixed32: + return len(m.GetMapSfixed32Sfixed32()) > 0 + case fieldMapSfixed64Sfixed64: + return len(m.GetMapSfixed64Sfixed64()) > 0 + case fieldMapInt32Float: + return len(m.GetMapInt32Float()) > 0 + case fieldMapInt32Double: + return len(m.GetMapInt32Double()) > 0 + case fieldMapBoolBool: + return len(m.GetMapBoolBool()) > 0 + case fieldMapStringString: + return len(m.GetMapStringString()) > 0 + case fieldMapStringBytes: + return len(m.GetMapStringBytes()) > 0 + case fieldMapStringNestedMessage: + return len(m.GetMapStringNestedMessage()) > 0 + case fieldMapStringNestedEnum: + return len(m.GetMapStringNestedEnum()) > 0 + + case fieldDefaultInt32: + return m.HasDefaultInt32() + case fieldDefaultInt64: + return m.HasDefaultInt64() + case fieldDefaultUint32: + return m.HasDefaultUint32() + case fieldDefaultUint64: + return m.HasDefaultUint64() + case fieldDefaultSint32: + return m.HasDefaultSint32() + case fieldDefaultSint64: + return m.HasDefaultSint64() + case fieldDefaultFixed32: + return m.HasDefaultFixed32() + case fieldDefaultFixed64: + return m.HasDefaultFixed64() + case fieldDefaultSfixed32: + return m.HasDefaultSfixed32() + case fieldDefaultSfixed64: + return m.HasDefaultSfixed64() + case fieldDefaultFloat: + return m.HasDefaultFloat() + case fieldDefaultDouble: + return m.HasDefaultDouble() + case fieldDefaultBool: + return m.HasDefaultBool() + case fieldDefaultString: + return m.HasDefaultString() + case fieldDefaultBytes: + return m.HasDefaultBytes() + case fieldDefaultNestedEnum: + return m.HasDefaultNestedEnum() + case fieldDefaultForeignEnum: + return m.HasDefaultForeignEnum() + + case fieldOneofUint32: + return m.HasOneofUint32() + case fieldOneofNestedMessage: + return m.HasOneofNestedMessage() + case fieldOneofString: + return m.HasOneofString() + case fieldOneofBytes: + return m.HasOneofBytes() + case fieldOneofBool: + return m.HasOneofBool() + case fieldOneofUint64: + return m.HasOneofUint64() + case fieldOneofFloat: + return m.HasOneofFloat() + case fieldOneofDouble: + return m.HasOneofDouble() + case fieldOneofEnum: + return m.HasOneofEnum() + case fieldOneofGroup: + return m.HasOneofgroup() + case fieldOneofOptionalUint32: + return m.HasOneofOptionalUint32() + + default: + panic(fmt.Sprintf("has: unknown field %d", num)) + } + }, + get: func(num protoreflect.FieldNumber) any { + switch num { + case fieldSingularInt32: + return m.GetSingularInt32() + case fieldSingularInt64: + return m.GetSingularInt64() + case fieldSingularUint32: + return m.GetSingularUint32() + case fieldSingularUint64: + return m.GetSingularUint64() + case fieldSingularSint32: + return m.GetSingularSint32() + case fieldSingularSint64: + return m.GetSingularSint64() + case fieldSingularFixed32: + return m.GetSingularFixed32() + case fieldSingularFixed64: + return m.GetSingularFixed64() + case fieldSingularSfixed32: + return m.GetSingularSfixed32() + case fieldSingularSfixed64: + return m.GetSingularSfixed64() + case fieldSingularFloat: + return m.GetSingularFloat() + case fieldSingularDouble: + return m.GetSingularDouble() + case fieldSingularBool: + return m.GetSingularBool() + case fieldSingularString: + return m.GetSingularString() + case fieldSingularBytes: + return m.GetSingularBytes() + case fieldSingularNestedEnum: + return m.GetSingularNestedEnum() + case fieldSingularForeignEnum: + return m.GetSingularForeignEnum() + case fieldSingularImportEnum: + return m.GetSingularImportEnum() + + case fieldOptionalInt32: + return m.GetOptionalInt32() + case fieldOptionalInt64: + return m.GetOptionalInt64() + case fieldOptionalUint32: + return m.GetOptionalUint32() + case fieldOptionalUint64: + return m.GetOptionalUint64() + case fieldOptionalSint32: + return m.GetOptionalSint32() + case fieldOptionalSint64: + return m.GetOptionalSint64() + case fieldOptionalFixed32: + return m.GetOptionalFixed32() + case fieldOptionalFixed64: + return m.GetOptionalFixed64() + case fieldOptionalSfixed32: + return m.GetOptionalSfixed32() + case fieldOptionalSfixed64: + return m.GetOptionalSfixed64() + case fieldOptionalFloat: + return m.GetOptionalFloat() + case fieldOptionalDouble: + return m.GetOptionalDouble() + case fieldOptionalBool: + return m.GetOptionalBool() + case fieldOptionalString: + return m.GetOptionalString() + case fieldOptionalBytes: + return m.GetOptionalBytes() + case fieldOptionalGroup: + return m.GetOptionalgroup() + case fieldNotGroupLikeDelimited: + return m.GetNotGroupLikeDelimited() + case fieldOptionalNestedMessage: + return m.GetOptionalNestedMessage() + case fieldOptionalForeignMessage: + return m.GetOptionalForeignMessage() + case fieldOptionalImportMessage: + return m.GetOptionalImportMessage() + case fieldOptionalNestedEnum: + return m.GetOptionalNestedEnum() + case fieldOptionalForeignEnum: + return m.GetOptionalForeignEnum() + case fieldOptionalImportEnum: + return m.GetOptionalImportEnum() + case fieldOptionalLazyNestedMessage: + return m.GetOptionalLazyNestedMessage() + + case fieldRepeatedInt32: + return m.GetRepeatedInt32() + case fieldRepeatedInt64: + return m.GetRepeatedInt64() + case fieldRepeatedUint32: + return m.GetRepeatedUint32() + case fieldRepeatedUint64: + return m.GetRepeatedUint64() + case fieldRepeatedSint32: + return m.GetRepeatedSint32() + case fieldRepeatedSint64: + return m.GetRepeatedSint64() + case fieldRepeatedFixed32: + return m.GetRepeatedFixed32() + case fieldRepeatedFixed64: + return m.GetRepeatedFixed64() + case fieldRepeatedSfixed32: + return m.GetRepeatedSfixed32() + case fieldRepeatedSfixed64: + return m.GetRepeatedSfixed64() + case fieldRepeatedFloat: + return m.GetRepeatedFloat() + case fieldRepeatedDouble: + return m.GetRepeatedDouble() + case fieldRepeatedBool: + return m.GetRepeatedBool() + case fieldRepeatedString: + return m.GetRepeatedString() + case fieldRepeatedBytes: + return m.GetRepeatedBytes() + case fieldRepeatedGroup: + return m.GetRepeatedgroup() + case fieldRepeatedNestedMessage: + return m.GetRepeatedNestedMessage() + case fieldRepeatedForeignMessage: + return m.GetRepeatedForeignMessage() + case fieldRepeatedImportMessage: + return m.GetRepeatedImportmessage() + case fieldRepeatedNestedEnum: + return m.GetRepeatedNestedEnum() + case fieldRepeatedForeignEnum: + return m.GetRepeatedForeignEnum() + case fieldRepeatedImportEnum: + return m.GetRepeatedImportenum() + + case fieldMapInt32Int32: + return m.GetMapInt32Int32() + case fieldMapInt64Int64: + return m.GetMapInt64Int64() + case fieldMapUint32Uint32: + return m.GetMapUint32Uint32() + case fieldMapUint64Uint64: + return m.GetMapUint64Uint64() + case fieldMapSint32Sint32: + return m.GetMapSint32Sint32() + case fieldMapSint64Sint64: + return m.GetMapSint64Sint64() + case fieldMapFixed32Fixed32: + return m.GetMapFixed32Fixed32() + case fieldMapFixed64Fixed64: + return m.GetMapFixed64Fixed64() + case fieldMapSfixed32Sfixed32: + return m.GetMapSfixed32Sfixed32() + case fieldMapSfixed64Sfixed64: + return m.GetMapSfixed64Sfixed64() + case fieldMapInt32Float: + return m.GetMapInt32Float() + case fieldMapInt32Double: + return m.GetMapInt32Double() + case fieldMapBoolBool: + return m.GetMapBoolBool() + case fieldMapStringString: + return m.GetMapStringString() + case fieldMapStringBytes: + return m.GetMapStringBytes() + case fieldMapStringNestedMessage: + return m.GetMapStringNestedMessage() + case fieldMapStringNestedEnum: + return m.GetMapStringNestedEnum() + + case fieldDefaultInt32: + return m.GetDefaultInt32() + case fieldDefaultInt64: + return m.GetDefaultInt64() + case fieldDefaultUint32: + return m.GetDefaultUint32() + case fieldDefaultUint64: + return m.GetDefaultUint64() + case fieldDefaultSint32: + return m.GetDefaultSint32() + case fieldDefaultSint64: + return m.GetDefaultSint64() + case fieldDefaultFixed32: + return m.GetDefaultFixed32() + case fieldDefaultFixed64: + return m.GetDefaultFixed64() + case fieldDefaultSfixed32: + return m.GetDefaultSfixed32() + case fieldDefaultSfixed64: + return m.GetDefaultSfixed64() + case fieldDefaultFloat: + return m.GetDefaultFloat() + case fieldDefaultDouble: + return m.GetDefaultDouble() + case fieldDefaultBool: + return m.GetDefaultBool() + case fieldDefaultString: + return m.GetDefaultString() + case fieldDefaultBytes: + return m.GetDefaultBytes() + case fieldDefaultNestedEnum: + return m.GetDefaultNestedEnum() + case fieldDefaultForeignEnum: + return m.GetDefaultForeignEnum() + + case fieldOneofUint32: + return m.GetOneofUint32() + case fieldOneofNestedMessage: + return m.GetOneofNestedMessage() + case fieldOneofString: + return m.GetOneofString() + case fieldOneofBytes: + return m.GetOneofBytes() + case fieldOneofBool: + return m.GetOneofBool() + case fieldOneofUint64: + return m.GetOneofUint64() + case fieldOneofFloat: + return m.GetOneofFloat() + case fieldOneofDouble: + return m.GetOneofDouble() + case fieldOneofEnum: + return m.GetOneofEnum() + case fieldOneofGroup: + return m.GetOneofgroup() + case fieldOneofOptionalUint32: + return m.GetOneofOptionalUint32() + + default: + panic(fmt.Sprintf("get: unknown field %d", num)) + } + }, + set: func(num protoreflect.FieldNumber, v any) { + switch num { + case fieldSingularInt32: + m.SetSingularInt32(v.(int32)) + case fieldSingularInt64: + m.SetSingularInt64(v.(int64)) + case fieldSingularUint32: + m.SetSingularUint32(v.(uint32)) + case fieldSingularUint64: + m.SetSingularUint64(v.(uint64)) + case fieldSingularSint32: + m.SetSingularSint32(v.(int32)) + case fieldSingularSint64: + m.SetSingularSint64(v.(int64)) + case fieldSingularFixed32: + m.SetSingularFixed32(v.(uint32)) + case fieldSingularFixed64: + m.SetSingularFixed64(v.(uint64)) + case fieldSingularSfixed32: + m.SetSingularSfixed32(v.(int32)) + case fieldSingularSfixed64: + m.SetSingularSfixed64(v.(int64)) + case fieldSingularFloat: + m.SetSingularFloat(v.(float32)) + case fieldSingularDouble: + m.SetSingularDouble(v.(float64)) + case fieldSingularBool: + m.SetSingularBool(v.(bool)) + case fieldSingularString: + m.SetSingularString(v.(string)) + case fieldSingularBytes: + m.SetSingularBytes(v.([]byte)) + case fieldSingularNestedEnum: + m.SetSingularNestedEnum(testpb.TestAllTypes_NestedEnum(v.(protoreflect.EnumNumber))) + case fieldSingularForeignEnum: + m.SetSingularForeignEnum(testpb.ForeignEnum(v.(protoreflect.EnumNumber))) + case fieldSingularImportEnum: + m.SetSingularImportEnum(testpb.ImportEnum(v.(protoreflect.EnumNumber))) + + case fieldOptionalInt32: + m.SetOptionalInt32(v.(int32)) + case fieldOptionalInt64: + m.SetOptionalInt64(v.(int64)) + case fieldOptionalUint32: + m.SetOptionalUint32(v.(uint32)) + case fieldOptionalUint64: + m.SetOptionalUint64(v.(uint64)) + case fieldOptionalSint32: + m.SetOptionalSint32(v.(int32)) + case fieldOptionalSint64: + m.SetOptionalSint64(v.(int64)) + case fieldOptionalFixed32: + m.SetOptionalFixed32(v.(uint32)) + case fieldOptionalFixed64: + m.SetOptionalFixed64(v.(uint64)) + case fieldOptionalSfixed32: + m.SetOptionalSfixed32(v.(int32)) + case fieldOptionalSfixed64: + m.SetOptionalSfixed64(v.(int64)) + case fieldOptionalFloat: + m.SetOptionalFloat(v.(float32)) + case fieldOptionalDouble: + m.SetOptionalDouble(v.(float64)) + case fieldOptionalBool: + m.SetOptionalBool(v.(bool)) + case fieldOptionalString: + m.SetOptionalString(v.(string)) + case fieldOptionalBytes: + m.SetOptionalBytes(v.([]byte)) + case fieldOptionalGroup: + m.SetOptionalgroup(v.(*testpb.TestAllTypes_OptionalGroup)) + case fieldNotGroupLikeDelimited: + m.SetNotGroupLikeDelimited(v.(*testpb.TestAllTypes_OptionalGroup)) + case fieldOptionalNestedMessage: + m.SetOptionalNestedMessage(v.(*testpb.TestAllTypes_NestedMessage)) + case fieldOptionalForeignMessage: + m.SetOptionalForeignMessage(v.(*testpb.ForeignMessage)) + case fieldOptionalImportMessage: + m.SetOptionalImportMessage(v.(*testpb.ImportMessage)) + case fieldOptionalNestedEnum: + m.SetOptionalNestedEnum(testpb.TestAllTypes_NestedEnum(v.(protoreflect.EnumNumber))) + case fieldOptionalForeignEnum: + m.SetOptionalForeignEnum(testpb.ForeignEnum(v.(protoreflect.EnumNumber))) + case fieldOptionalImportEnum: + m.SetOptionalImportEnum(testpb.ImportEnum(v.(protoreflect.EnumNumber))) + case fieldOptionalLazyNestedMessage: + m.SetOptionalLazyNestedMessage(v.(*testpb.TestAllTypes_NestedMessage)) + + case fieldRepeatedInt32: + m.SetRepeatedInt32(v.([]int32)) + case fieldRepeatedInt64: + m.SetRepeatedInt64(v.([]int64)) + case fieldRepeatedUint32: + m.SetRepeatedUint32(v.([]uint32)) + case fieldRepeatedUint64: + m.SetRepeatedUint64(v.([]uint64)) + case fieldRepeatedSint32: + m.SetRepeatedSint32(v.([]int32)) + case fieldRepeatedSint64: + m.SetRepeatedSint64(v.([]int64)) + case fieldRepeatedFixed32: + m.SetRepeatedFixed32(v.([]uint32)) + case fieldRepeatedFixed64: + m.SetRepeatedFixed64(v.([]uint64)) + case fieldRepeatedSfixed32: + m.SetRepeatedSfixed32(v.([]int32)) + case fieldRepeatedSfixed64: + m.SetRepeatedSfixed64(v.([]int64)) + case fieldRepeatedFloat: + m.SetRepeatedFloat(v.([]float32)) + case fieldRepeatedDouble: + m.SetRepeatedDouble(v.([]float64)) + case fieldRepeatedBool: + m.SetRepeatedBool(v.([]bool)) + case fieldRepeatedString: + m.SetRepeatedString(v.([]string)) + case fieldRepeatedBytes: + m.SetRepeatedBytes(v.([][]byte)) + case fieldRepeatedGroup: + m.SetRepeatedgroup(v.([]*testpb.TestAllTypes_RepeatedGroup)) + case fieldRepeatedNestedMessage: + m.SetRepeatedNestedMessage(v.([]*testpb.TestAllTypes_NestedMessage)) + case fieldRepeatedForeignMessage: + m.SetRepeatedForeignMessage(v.([]*testpb.ForeignMessage)) + case fieldRepeatedImportMessage: + m.SetRepeatedImportmessage(v.([]*testpb.ImportMessage)) + case fieldRepeatedNestedEnum: + m.SetRepeatedNestedEnum(v.([]testpb.TestAllTypes_NestedEnum)) + case fieldRepeatedForeignEnum: + m.SetRepeatedForeignEnum(v.([]testpb.ForeignEnum)) + case fieldRepeatedImportEnum: + m.SetRepeatedImportenum(v.([]testpb.ImportEnum)) + + case fieldMapInt32Int32: + m.SetMapInt32Int32(v.(map[int32]int32)) + case fieldMapInt64Int64: + m.SetMapInt64Int64(v.(map[int64]int64)) + case fieldMapUint32Uint32: + m.SetMapUint32Uint32(v.(map[uint32]uint32)) + case fieldMapUint64Uint64: + m.SetMapUint64Uint64(v.(map[uint64]uint64)) + case fieldMapSint32Sint32: + m.SetMapSint32Sint32(v.(map[int32]int32)) + case fieldMapSint64Sint64: + m.SetMapSint64Sint64(v.(map[int64]int64)) + case fieldMapFixed32Fixed32: + m.SetMapFixed32Fixed32(v.(map[uint32]uint32)) + case fieldMapFixed64Fixed64: + m.SetMapFixed64Fixed64(v.(map[uint64]uint64)) + case fieldMapSfixed32Sfixed32: + m.SetMapSfixed32Sfixed32(v.(map[int32]int32)) + case fieldMapSfixed64Sfixed64: + m.SetMapSfixed64Sfixed64(v.(map[int64]int64)) + case fieldMapInt32Float: + m.SetMapInt32Float(v.(map[int32]float32)) + case fieldMapInt32Double: + m.SetMapInt32Double(v.(map[int32]float64)) + case fieldMapBoolBool: + m.SetMapBoolBool(v.(map[bool]bool)) + case fieldMapStringString: + m.SetMapStringString(v.(map[string]string)) + case fieldMapStringBytes: + m.SetMapStringBytes(v.(map[string][]byte)) + case fieldMapStringNestedMessage: + m.SetMapStringNestedMessage(v.(map[string]*testpb.TestAllTypes_NestedMessage)) + case fieldMapStringNestedEnum: + m.SetMapStringNestedEnum(v.(map[string]testpb.TestAllTypes_NestedEnum)) + + case fieldDefaultInt32: + m.SetDefaultInt32(v.(int32)) + case fieldDefaultInt64: + m.SetDefaultInt64(v.(int64)) + case fieldDefaultUint32: + m.SetDefaultUint32(v.(uint32)) + case fieldDefaultUint64: + m.SetDefaultUint64(v.(uint64)) + case fieldDefaultSint32: + m.SetDefaultSint32(v.(int32)) + case fieldDefaultSint64: + m.SetDefaultSint64(v.(int64)) + case fieldDefaultFixed32: + m.SetDefaultFixed32(v.(uint32)) + case fieldDefaultFixed64: + m.SetDefaultFixed64(v.(uint64)) + case fieldDefaultSfixed32: + m.SetDefaultSfixed32(v.(int32)) + case fieldDefaultSfixed64: + m.SetDefaultSfixed64(v.(int64)) + case fieldDefaultFloat: + m.SetDefaultFloat(v.(float32)) + case fieldDefaultDouble: + m.SetDefaultDouble(v.(float64)) + case fieldDefaultBool: + m.SetDefaultBool(v.(bool)) + case fieldDefaultString: + m.SetDefaultString(v.(string)) + case fieldDefaultBytes: + m.SetDefaultBytes(v.([]byte)) + case fieldDefaultNestedEnum: + m.SetDefaultNestedEnum(testpb.TestAllTypes_NestedEnum(v.(protoreflect.EnumNumber))) + case fieldDefaultForeignEnum: + m.SetDefaultForeignEnum(testpb.ForeignEnum(v.(protoreflect.EnumNumber))) + + case fieldDefaultInt32: + m.SetDefaultInt32(v.(int32)) + case fieldDefaultInt64: + m.SetDefaultInt64(v.(int64)) + case fieldDefaultUint32: + m.SetDefaultUint32(v.(uint32)) + case fieldDefaultUint64: + m.SetDefaultUint64(v.(uint64)) + case fieldDefaultSint32: + m.SetDefaultSint32(v.(int32)) + case fieldDefaultSint64: + m.SetDefaultSint64(v.(int64)) + case fieldDefaultFixed32: + m.SetDefaultFixed32(v.(uint32)) + case fieldDefaultFixed64: + m.SetDefaultFixed64(v.(uint64)) + case fieldDefaultSfixed32: + m.SetDefaultSfixed32(v.(int32)) + case fieldDefaultSfixed64: + m.SetDefaultSfixed64(v.(int64)) + case fieldDefaultFloat: + m.SetDefaultFloat(v.(float32)) + case fieldDefaultDouble: + m.SetDefaultDouble(v.(float64)) + case fieldDefaultBool: + m.SetDefaultBool(v.(bool)) + case fieldDefaultString: + m.SetDefaultString(v.(string)) + case fieldDefaultBytes: + m.SetDefaultBytes(v.([]byte)) + case fieldDefaultNestedEnum: + m.SetDefaultNestedEnum(testpb.TestAllTypes_NestedEnum(v.(protoreflect.EnumNumber))) + case fieldDefaultForeignEnum: + m.SetDefaultForeignEnum(testpb.ForeignEnum(v.(protoreflect.EnumNumber))) + + case fieldOneofUint32: + m.SetOneofUint32(v.(uint32)) + case fieldOneofNestedMessage: + m.SetOneofNestedMessage(v.(*testpb.TestAllTypes_NestedMessage)) + case fieldOneofString: + m.SetOneofString(v.(string)) + case fieldOneofBytes: + m.SetOneofBytes(v.([]byte)) + case fieldOneofBool: + m.SetOneofBool(v.(bool)) + case fieldOneofUint64: + m.SetOneofUint64(v.(uint64)) + case fieldOneofFloat: + m.SetOneofFloat(v.(float32)) + case fieldOneofDouble: + m.SetOneofDouble(v.(float64)) + case fieldOneofEnum: + m.SetOneofEnum(testpb.TestAllTypes_NestedEnum(v.(protoreflect.EnumNumber))) + case fieldOneofGroup: + m.SetOneofgroup(v.(*testpb.TestAllTypes_OneofGroup)) + case fieldOneofOptionalUint32: + m.SetOneofOptionalUint32(v.(uint32)) + + default: + panic(fmt.Sprintf("set: unknown field %d", num)) + } + }, + clear: func(num protoreflect.FieldNumber) { + switch num { + case fieldSingularInt32: + m.SetSingularInt32(0) + case fieldSingularInt64: + m.SetSingularInt64(0) + case fieldSingularUint32: + m.SetSingularUint32(0) + case fieldSingularUint64: + m.SetSingularUint64(0) + case fieldSingularSint32: + m.SetSingularSint32(0) + case fieldSingularSint64: + m.SetSingularSint64(0) + case fieldSingularFixed32: + m.SetSingularFixed32(0) + case fieldSingularFixed64: + m.SetSingularFixed64(0) + case fieldSingularSfixed32: + m.SetSingularSfixed32(0) + case fieldSingularSfixed64: + m.SetSingularSfixed64(0) + case fieldSingularFloat: + m.SetSingularFloat(0) + case fieldSingularDouble: + m.SetSingularDouble(0) + case fieldSingularBool: + m.SetSingularBool(false) + case fieldSingularString: + m.SetSingularString("") + case fieldSingularBytes: + m.SetSingularBytes(nil) + case fieldSingularNestedEnum: + m.SetSingularNestedEnum(testpb.TestAllTypes_FOO) + case fieldSingularForeignEnum: + m.SetSingularForeignEnum(testpb.ForeignEnum_FOREIGN_ZERO) + case fieldSingularImportEnum: + m.SetSingularImportEnum(testpb.ImportEnum_IMPORT_ZERO) + + case fieldOptionalInt32: + m.ClearOptionalInt32() + case fieldOptionalInt64: + m.ClearOptionalInt64() + case fieldOptionalUint32: + m.ClearOptionalUint32() + case fieldOptionalUint64: + m.ClearOptionalUint64() + case fieldOptionalSint32: + m.ClearOptionalSint32() + case fieldOptionalSint64: + m.ClearOptionalSint64() + case fieldOptionalFixed32: + m.ClearOptionalFixed32() + case fieldOptionalFixed64: + m.ClearOptionalFixed64() + case fieldOptionalSfixed32: + m.ClearOptionalSfixed32() + case fieldOptionalSfixed64: + m.ClearOptionalSfixed64() + case fieldOptionalFloat: + m.ClearOptionalFloat() + case fieldOptionalDouble: + m.ClearOptionalDouble() + case fieldOptionalBool: + m.ClearOptionalBool() + case fieldOptionalString: + m.ClearOptionalString() + case fieldOptionalBytes: + m.ClearOptionalBytes() + case fieldOptionalGroup: + m.ClearOptionalgroup() + case fieldNotGroupLikeDelimited: + m.ClearNotGroupLikeDelimited() + case fieldOptionalNestedMessage: + m.ClearOptionalNestedMessage() + case fieldOptionalForeignMessage: + m.ClearOptionalForeignMessage() + case fieldOptionalImportMessage: + m.ClearOptionalImportMessage() + case fieldOptionalNestedEnum: + m.ClearOptionalNestedEnum() + case fieldOptionalForeignEnum: + m.ClearOptionalForeignEnum() + case fieldOptionalImportEnum: + m.ClearOptionalImportEnum() + case fieldOptionalLazyNestedMessage: + m.ClearOptionalLazyNestedMessage() + + case fieldRepeatedInt32: + m.SetRepeatedInt32(nil) + case fieldRepeatedInt64: + m.SetRepeatedInt64(nil) + case fieldRepeatedUint32: + m.SetRepeatedUint32(nil) + case fieldRepeatedUint64: + m.SetRepeatedUint64(nil) + case fieldRepeatedSint32: + m.SetRepeatedSint32(nil) + case fieldRepeatedSint64: + m.SetRepeatedSint64(nil) + case fieldRepeatedFixed32: + m.SetRepeatedFixed32(nil) + case fieldRepeatedFixed64: + m.SetRepeatedFixed64(nil) + case fieldRepeatedSfixed32: + m.SetRepeatedSfixed32(nil) + case fieldRepeatedSfixed64: + m.SetRepeatedSfixed64(nil) + case fieldRepeatedFloat: + m.SetRepeatedFloat(nil) + case fieldRepeatedDouble: + m.SetRepeatedDouble(nil) + case fieldRepeatedBool: + m.SetRepeatedBool(nil) + case fieldRepeatedString: + m.SetRepeatedString(nil) + case fieldRepeatedBytes: + m.SetRepeatedBytes(nil) + case fieldRepeatedGroup: + m.SetRepeatedgroup(nil) + case fieldRepeatedNestedMessage: + m.SetRepeatedNestedMessage(nil) + case fieldRepeatedForeignMessage: + m.SetRepeatedForeignMessage(nil) + case fieldRepeatedImportMessage: + m.SetRepeatedImportmessage(nil) + case fieldRepeatedNestedEnum: + m.SetRepeatedNestedEnum(nil) + case fieldRepeatedForeignEnum: + m.SetRepeatedForeignEnum(nil) + case fieldRepeatedImportEnum: + m.SetRepeatedImportenum(nil) + + case fieldMapInt32Int32: + m.SetMapInt32Int32(nil) + case fieldMapInt64Int64: + m.SetMapInt64Int64(nil) + case fieldMapUint32Uint32: + m.SetMapUint32Uint32(nil) + case fieldMapUint64Uint64: + m.SetMapUint64Uint64(nil) + case fieldMapSint32Sint32: + m.SetMapSint32Sint32(nil) + case fieldMapSint64Sint64: + m.SetMapSint64Sint64(nil) + case fieldMapFixed32Fixed32: + m.SetMapFixed32Fixed32(nil) + case fieldMapFixed64Fixed64: + m.SetMapFixed64Fixed64(nil) + case fieldMapSfixed32Sfixed32: + m.SetMapSfixed32Sfixed32(nil) + case fieldMapSfixed64Sfixed64: + m.SetMapSfixed64Sfixed64(nil) + case fieldMapInt32Float: + m.SetMapInt32Float(nil) + case fieldMapInt32Double: + m.SetMapInt32Double(nil) + case fieldMapBoolBool: + m.SetMapBoolBool(nil) + case fieldMapStringString: + m.SetMapStringString(nil) + case fieldMapStringBytes: + m.SetMapStringBytes(nil) + case fieldMapStringNestedMessage: + m.SetMapStringNestedMessage(nil) + case fieldMapStringNestedEnum: + m.SetMapStringNestedEnum(nil) + + case fieldDefaultInt32: + m.ClearDefaultInt32() + case fieldDefaultInt64: + m.ClearDefaultInt64() + case fieldDefaultUint32: + m.ClearDefaultUint32() + case fieldDefaultUint64: + m.ClearDefaultUint64() + case fieldDefaultSint32: + m.ClearDefaultSint32() + case fieldDefaultSint64: + m.ClearDefaultSint64() + case fieldDefaultFixed32: + m.ClearDefaultFixed32() + case fieldDefaultFixed64: + m.ClearDefaultFixed64() + case fieldDefaultSfixed32: + m.ClearDefaultSfixed32() + case fieldDefaultSfixed64: + m.ClearDefaultSfixed64() + case fieldDefaultFloat: + m.ClearDefaultFloat() + case fieldDefaultDouble: + m.ClearDefaultDouble() + case fieldDefaultBool: + m.ClearDefaultBool() + case fieldDefaultString: + m.ClearDefaultString() + case fieldDefaultBytes: + m.ClearDefaultBytes() + case fieldDefaultNestedEnum: + m.ClearDefaultNestedEnum() + case fieldDefaultForeignEnum: + m.ClearDefaultForeignEnum() + + case fieldOneofUint32: + m.ClearOneofUint32() + case fieldOneofNestedMessage: + m.ClearOneofNestedMessage() + case fieldOneofString: + m.ClearOneofString() + case fieldOneofBytes: + m.ClearOneofBytes() + case fieldOneofBool: + m.ClearOneofBool() + case fieldOneofUint64: + m.ClearOneofUint64() + case fieldOneofFloat: + m.ClearOneofFloat() + case fieldOneofDouble: + m.ClearOneofDouble() + case fieldOneofEnum: + m.ClearOneofEnum() + case fieldOneofGroup: + m.ClearOneofgroup() + case fieldOneofOptionalUint32: + m.ClearOneofOptionalUint32() + + default: + panic(fmt.Sprintf("clear: unknown field %d", num)) + } + }, + } +} diff --git a/internal/reflection_test/reflection_large_opaque_test.go b/internal/reflection_test/reflection_large_opaque_test.go new file mode 100644 index 000000000..fc88620e7 --- /dev/null +++ b/internal/reflection_test/reflection_large_opaque_test.go @@ -0,0 +1,893 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package reflection_test + +import ( + "fmt" + "testing" + + testpb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_opaque" + "google.golang.org/protobuf/reflect/protoreflect" +) + +func TestLargeOpaqueConcrete(t *testing.T) { + for _, tt := range lazyCombinations { + t.Run(tt.desc, func(t *testing.T) { + tt.ptm.Test(t, newTestMessageLargeOpaque(nil).ProtoReflect().Type()) + }) + } +} + +func TestLargeOpaqueReflection(t *testing.T) { + for _, tt := range lazyCombinations { + t.Run(tt.desc, func(t *testing.T) { + tt.ptm.Test(t, (*testpb.TestManyMessageFieldsMessage)(nil).ProtoReflect().Type()) + }) + } +} + +func TestLargeOpaqueShadow_GetConcrete_SetReflection(t *testing.T) { + for _, tt := range lazyCombinations { + t.Run(tt.desc, func(t *testing.T) { + tt.ptm.Test(t, newShadow(func() (get, set protoreflect.ProtoMessage) { + m := &testpb.TestManyMessageFieldsMessage{} + return newTestMessageLargeOpaque(m), m + }).ProtoReflect().Type()) + }) + } +} + +func TestLargeOpaqueShadow_GetReflection_SetConcrete(t *testing.T) { + for _, tt := range lazyCombinations { + t.Run(tt.desc, func(t *testing.T) { + tt.ptm.Test(t, newShadow(func() (get, set protoreflect.ProtoMessage) { + m := &testpb.TestManyMessageFieldsMessage{} + return m, newTestMessageLargeOpaque(m) + }).ProtoReflect().Type()) + }) + } +} + +func newTestMessageLargeOpaque(m *testpb.TestManyMessageFieldsMessage) protoreflect.ProtoMessage { + return &testProtoMessage{ + m: m, + md: m.ProtoReflect().Descriptor(), + new: func() protoreflect.Message { + return newTestMessageLargeOpaque(&testpb.TestManyMessageFieldsMessage{}).ProtoReflect() + }, + has: func(num protoreflect.FieldNumber) bool { + switch num { + case largeFieldF1: + return m.HasF1() + case largeFieldF2: + return m.HasF2() + case largeFieldF3: + return m.HasF3() + case largeFieldF4: + return m.HasF4() + case largeFieldF5: + return m.HasF5() + case largeFieldF6: + return m.HasF6() + case largeFieldF7: + return m.HasF7() + case largeFieldF8: + return m.HasF8() + case largeFieldF9: + return m.HasF9() + case largeFieldF10: + return m.HasF10() + case largeFieldF11: + return m.HasF11() + case largeFieldF12: + return m.HasF12() + case largeFieldF13: + return m.HasF13() + case largeFieldF14: + return m.HasF14() + case largeFieldF15: + return m.HasF15() + case largeFieldF16: + return m.HasF16() + case largeFieldF17: + return m.HasF17() + case largeFieldF18: + return m.HasF18() + case largeFieldF19: + return m.HasF19() + case largeFieldF20: + return m.HasF20() + case largeFieldF21: + return m.HasF21() + case largeFieldF22: + return m.HasF22() + case largeFieldF23: + return m.HasF23() + case largeFieldF24: + return m.HasF24() + case largeFieldF25: + return m.HasF25() + case largeFieldF26: + return m.HasF26() + case largeFieldF27: + return m.HasF27() + case largeFieldF28: + return m.HasF28() + case largeFieldF29: + return m.HasF29() + case largeFieldF30: + return m.HasF30() + case largeFieldF31: + return m.HasF31() + case largeFieldF32: + return m.HasF32() + case largeFieldF33: + return m.HasF33() + case largeFieldF34: + return m.HasF34() + case largeFieldF35: + return m.HasF35() + case largeFieldF36: + return m.HasF36() + case largeFieldF37: + return m.HasF37() + case largeFieldF38: + return m.HasF38() + case largeFieldF39: + return m.HasF39() + case largeFieldF40: + return m.HasF40() + case largeFieldF41: + return m.HasF41() + case largeFieldF42: + return m.HasF42() + case largeFieldF43: + return m.HasF43() + case largeFieldF44: + return m.HasF44() + case largeFieldF45: + return m.HasF45() + case largeFieldF46: + return m.HasF46() + case largeFieldF47: + return m.HasF47() + case largeFieldF48: + return m.HasF48() + case largeFieldF49: + return m.HasF49() + case largeFieldF50: + return m.HasF50() + case largeFieldF51: + return m.HasF51() + case largeFieldF52: + return m.HasF52() + case largeFieldF53: + return m.HasF53() + case largeFieldF54: + return m.HasF54() + case largeFieldF55: + return m.HasF55() + case largeFieldF56: + return m.HasF56() + case largeFieldF57: + return m.HasF57() + case largeFieldF58: + return m.HasF58() + case largeFieldF59: + return m.HasF59() + case largeFieldF60: + return m.HasF60() + case largeFieldF60: + return m.HasF60() + case largeFieldF61: + return m.HasF61() + case largeFieldF62: + return m.HasF62() + case largeFieldF63: + return m.HasF63() + case largeFieldF64: + return m.HasF64() + case largeFieldF65: + return m.HasF65() + case largeFieldF66: + return m.HasF66() + case largeFieldF67: + return m.HasF67() + case largeFieldF68: + return m.HasF68() + case largeFieldF69: + return m.HasF69() + case largeFieldF70: + return m.HasF70() + case largeFieldF71: + return m.HasF71() + case largeFieldF72: + return m.HasF72() + case largeFieldF73: + return m.HasF73() + case largeFieldF74: + return m.HasF74() + case largeFieldF75: + return m.HasF75() + case largeFieldF76: + return m.HasF76() + case largeFieldF77: + return m.HasF77() + case largeFieldF78: + return m.HasF78() + case largeFieldF79: + return m.HasF79() + case largeFieldF80: + return m.HasF80() + case largeFieldF81: + return m.HasF81() + case largeFieldF82: + return m.HasF82() + case largeFieldF83: + return m.HasF83() + case largeFieldF84: + return m.HasF84() + case largeFieldF85: + return m.HasF85() + case largeFieldF86: + return m.HasF86() + case largeFieldF87: + return m.HasF87() + case largeFieldF88: + return m.HasF88() + case largeFieldF89: + return m.HasF89() + case largeFieldF90: + return m.HasF90() + case largeFieldF91: + return m.HasF91() + case largeFieldF92: + return m.HasF92() + case largeFieldF93: + return m.HasF93() + case largeFieldF94: + return m.HasF94() + case largeFieldF95: + return m.HasF95() + case largeFieldF96: + return m.HasF96() + case largeFieldF97: + return m.HasF97() + case largeFieldF98: + return m.HasF98() + case largeFieldF99: + return m.HasF99() + case largeFieldF100: + return m.HasF100() + + default: + panic(fmt.Sprintf("has: unknown field %d", num)) + } + }, + get: func(num protoreflect.FieldNumber) any { + switch num { + case largeFieldF1: + return m.GetF1() + case largeFieldF2: + return m.GetF2() + case largeFieldF3: + return m.GetF3() + case largeFieldF4: + return m.GetF4() + case largeFieldF5: + return m.GetF5() + case largeFieldF6: + return m.GetF6() + case largeFieldF7: + return m.GetF7() + case largeFieldF8: + return m.GetF8() + case largeFieldF9: + return m.GetF9() + case largeFieldF10: + return m.GetF10() + case largeFieldF11: + return m.GetF11() + case largeFieldF12: + return m.GetF12() + case largeFieldF13: + return m.GetF13() + case largeFieldF14: + return m.GetF14() + case largeFieldF15: + return m.GetF15() + case largeFieldF16: + return m.GetF16() + case largeFieldF17: + return m.GetF17() + case largeFieldF18: + return m.GetF18() + case largeFieldF19: + return m.GetF19() + case largeFieldF20: + return m.GetF20() + case largeFieldF21: + return m.GetF21() + case largeFieldF22: + return m.GetF22() + case largeFieldF23: + return m.GetF23() + case largeFieldF24: + return m.GetF24() + case largeFieldF25: + return m.GetF25() + case largeFieldF26: + return m.GetF26() + case largeFieldF27: + return m.GetF27() + case largeFieldF28: + return m.GetF28() + case largeFieldF29: + return m.GetF29() + case largeFieldF30: + return m.GetF30() + case largeFieldF31: + return m.GetF31() + case largeFieldF32: + return m.GetF32() + case largeFieldF33: + return m.GetF33() + case largeFieldF34: + return m.GetF34() + case largeFieldF35: + return m.GetF35() + case largeFieldF36: + return m.GetF36() + case largeFieldF37: + return m.GetF37() + case largeFieldF38: + return m.GetF38() + case largeFieldF39: + return m.GetF39() + case largeFieldF40: + return m.GetF40() + case largeFieldF41: + return m.GetF41() + case largeFieldF42: + return m.GetF42() + case largeFieldF43: + return m.GetF43() + case largeFieldF44: + return m.GetF44() + case largeFieldF45: + return m.GetF45() + case largeFieldF46: + return m.GetF46() + case largeFieldF47: + return m.GetF47() + case largeFieldF48: + return m.GetF48() + case largeFieldF49: + return m.GetF49() + case largeFieldF50: + return m.GetF50() + case largeFieldF51: + return m.GetF51() + case largeFieldF52: + return m.GetF52() + case largeFieldF53: + return m.GetF53() + case largeFieldF54: + return m.GetF54() + case largeFieldF55: + return m.GetF55() + case largeFieldF56: + return m.GetF56() + case largeFieldF57: + return m.GetF57() + case largeFieldF58: + return m.GetF58() + case largeFieldF59: + return m.GetF59() + case largeFieldF60: + return m.GetF60() + case largeFieldF61: + return m.GetF61() + case largeFieldF62: + return m.GetF62() + case largeFieldF63: + return m.GetF63() + case largeFieldF64: + return m.GetF64() + case largeFieldF65: + return m.GetF65() + case largeFieldF66: + return m.GetF66() + case largeFieldF67: + return m.GetF67() + case largeFieldF68: + return m.GetF68() + case largeFieldF69: + return m.GetF69() + case largeFieldF70: + return m.GetF70() + case largeFieldF71: + return m.GetF71() + case largeFieldF72: + return m.GetF72() + case largeFieldF73: + return m.GetF73() + case largeFieldF74: + return m.GetF74() + case largeFieldF75: + return m.GetF75() + case largeFieldF76: + return m.GetF76() + case largeFieldF77: + return m.GetF77() + case largeFieldF78: + return m.GetF78() + case largeFieldF79: + return m.GetF79() + case largeFieldF80: + return m.GetF80() + case largeFieldF81: + return m.GetF81() + case largeFieldF82: + return m.GetF82() + case largeFieldF83: + return m.GetF83() + case largeFieldF84: + return m.GetF84() + case largeFieldF85: + return m.GetF85() + case largeFieldF86: + return m.GetF86() + case largeFieldF87: + return m.GetF87() + case largeFieldF88: + return m.GetF88() + case largeFieldF89: + return m.GetF89() + case largeFieldF90: + return m.GetF90() + case largeFieldF91: + return m.GetF91() + case largeFieldF92: + return m.GetF92() + case largeFieldF93: + return m.GetF93() + case largeFieldF94: + return m.GetF94() + case largeFieldF95: + return m.GetF95() + case largeFieldF96: + return m.GetF96() + case largeFieldF97: + return m.GetF97() + case largeFieldF98: + return m.GetF98() + case largeFieldF99: + return m.GetF99() + case largeFieldF100: + return m.GetF100() + + default: + panic(fmt.Sprintf("get: unknown field %d", num)) + } + }, + set: func(num protoreflect.FieldNumber, v any) { + switch num { + case largeFieldF1: + m.SetF1(v.(*testpb.TestAllTypes)) + case largeFieldF2: + m.SetF2(v.(*testpb.TestAllTypes)) + case largeFieldF3: + m.SetF3(v.(*testpb.TestAllTypes)) + case largeFieldF4: + m.SetF4(v.(*testpb.TestAllTypes)) + case largeFieldF5: + m.SetF5(v.(*testpb.TestAllTypes)) + case largeFieldF6: + m.SetF6(v.(*testpb.TestAllTypes)) + case largeFieldF7: + m.SetF7(v.(*testpb.TestAllTypes)) + case largeFieldF8: + m.SetF8(v.(*testpb.TestAllTypes)) + case largeFieldF9: + m.SetF9(v.(*testpb.TestAllTypes)) + case largeFieldF10: + m.SetF10(v.(*testpb.TestAllTypes)) + case largeFieldF11: + m.SetF11(v.(*testpb.TestAllTypes)) + case largeFieldF12: + m.SetF12(v.(*testpb.TestAllTypes)) + case largeFieldF13: + m.SetF13(v.(*testpb.TestAllTypes)) + case largeFieldF14: + m.SetF14(v.(*testpb.TestAllTypes)) + case largeFieldF15: + m.SetF15(v.(*testpb.TestAllTypes)) + case largeFieldF16: + m.SetF16(v.(*testpb.TestAllTypes)) + case largeFieldF17: + m.SetF17(v.(*testpb.TestAllTypes)) + case largeFieldF18: + m.SetF18(v.(*testpb.TestAllTypes)) + case largeFieldF19: + m.SetF19(v.(*testpb.TestAllTypes)) + case largeFieldF20: + m.SetF20(v.(*testpb.TestAllTypes)) + case largeFieldF21: + m.SetF21(v.(*testpb.TestAllTypes)) + case largeFieldF22: + m.SetF22(v.(*testpb.TestAllTypes)) + case largeFieldF23: + m.SetF23(v.(*testpb.TestAllTypes)) + case largeFieldF24: + m.SetF24(v.(*testpb.TestAllTypes)) + case largeFieldF25: + m.SetF25(v.(*testpb.TestAllTypes)) + case largeFieldF26: + m.SetF26(v.(*testpb.TestAllTypes)) + case largeFieldF27: + m.SetF27(v.(*testpb.TestAllTypes)) + case largeFieldF28: + m.SetF28(v.(*testpb.TestAllTypes)) + case largeFieldF29: + m.SetF29(v.(*testpb.TestAllTypes)) + case largeFieldF30: + m.SetF30(v.(*testpb.TestAllTypes)) + case largeFieldF31: + m.SetF31(v.(*testpb.TestAllTypes)) + case largeFieldF32: + m.SetF32(v.(*testpb.TestAllTypes)) + case largeFieldF33: + m.SetF33(v.(*testpb.TestAllTypes)) + case largeFieldF34: + m.SetF34(v.(*testpb.TestAllTypes)) + case largeFieldF35: + m.SetF35(v.(*testpb.TestAllTypes)) + case largeFieldF36: + m.SetF36(v.(*testpb.TestAllTypes)) + case largeFieldF37: + m.SetF37(v.(*testpb.TestAllTypes)) + case largeFieldF38: + m.SetF38(v.(*testpb.TestAllTypes)) + case largeFieldF39: + m.SetF39(v.(*testpb.TestAllTypes)) + case largeFieldF40: + m.SetF40(v.(*testpb.TestAllTypes)) + case largeFieldF41: + m.SetF41(v.(*testpb.TestAllTypes)) + case largeFieldF42: + m.SetF42(v.(*testpb.TestAllTypes)) + case largeFieldF43: + m.SetF43(v.(*testpb.TestAllTypes)) + case largeFieldF44: + m.SetF44(v.(*testpb.TestAllTypes)) + case largeFieldF45: + m.SetF45(v.(*testpb.TestAllTypes)) + case largeFieldF46: + m.SetF46(v.(*testpb.TestAllTypes)) + case largeFieldF47: + m.SetF47(v.(*testpb.TestAllTypes)) + case largeFieldF48: + m.SetF48(v.(*testpb.TestAllTypes)) + case largeFieldF49: + m.SetF49(v.(*testpb.TestAllTypes)) + case largeFieldF50: + m.SetF50(v.(*testpb.TestAllTypes)) + case largeFieldF51: + m.SetF51(v.(*testpb.TestAllTypes)) + case largeFieldF52: + m.SetF52(v.(*testpb.TestAllTypes)) + case largeFieldF53: + m.SetF53(v.(*testpb.TestAllTypes)) + case largeFieldF54: + m.SetF54(v.(*testpb.TestAllTypes)) + case largeFieldF55: + m.SetF55(v.(*testpb.TestAllTypes)) + case largeFieldF56: + m.SetF56(v.(*testpb.TestAllTypes)) + case largeFieldF57: + m.SetF57(v.(*testpb.TestAllTypes)) + case largeFieldF58: + m.SetF58(v.(*testpb.TestAllTypes)) + case largeFieldF59: + m.SetF59(v.(*testpb.TestAllTypes)) + case largeFieldF60: + m.SetF60(v.(*testpb.TestAllTypes)) + case largeFieldF61: + m.SetF61(v.(*testpb.TestAllTypes)) + case largeFieldF62: + m.SetF62(v.(*testpb.TestAllTypes)) + case largeFieldF63: + m.SetF63(v.(*testpb.TestAllTypes)) + case largeFieldF64: + m.SetF64(v.(*testpb.TestAllTypes)) + case largeFieldF65: + m.SetF65(v.(*testpb.TestAllTypes)) + case largeFieldF66: + m.SetF66(v.(*testpb.TestAllTypes)) + case largeFieldF67: + m.SetF67(v.(*testpb.TestAllTypes)) + case largeFieldF68: + m.SetF68(v.(*testpb.TestAllTypes)) + case largeFieldF69: + m.SetF69(v.(*testpb.TestAllTypes)) + case largeFieldF70: + m.SetF70(v.(*testpb.TestAllTypes)) + case largeFieldF71: + m.SetF71(v.(*testpb.TestAllTypes)) + case largeFieldF72: + m.SetF72(v.(*testpb.TestAllTypes)) + case largeFieldF73: + m.SetF73(v.(*testpb.TestAllTypes)) + case largeFieldF74: + m.SetF74(v.(*testpb.TestAllTypes)) + case largeFieldF75: + m.SetF75(v.(*testpb.TestAllTypes)) + case largeFieldF76: + m.SetF76(v.(*testpb.TestAllTypes)) + case largeFieldF77: + m.SetF77(v.(*testpb.TestAllTypes)) + case largeFieldF78: + m.SetF78(v.(*testpb.TestAllTypes)) + case largeFieldF79: + m.SetF79(v.(*testpb.TestAllTypes)) + case largeFieldF80: + m.SetF80(v.(*testpb.TestAllTypes)) + case largeFieldF81: + m.SetF81(v.(*testpb.TestAllTypes)) + case largeFieldF82: + m.SetF82(v.(*testpb.TestAllTypes)) + case largeFieldF83: + m.SetF83(v.(*testpb.TestAllTypes)) + case largeFieldF84: + m.SetF84(v.(*testpb.TestAllTypes)) + case largeFieldF85: + m.SetF85(v.(*testpb.TestAllTypes)) + case largeFieldF86: + m.SetF86(v.(*testpb.TestAllTypes)) + case largeFieldF87: + m.SetF87(v.(*testpb.TestAllTypes)) + case largeFieldF88: + m.SetF88(v.(*testpb.TestAllTypes)) + case largeFieldF89: + m.SetF89(v.(*testpb.TestAllTypes)) + case largeFieldF90: + m.SetF90(v.(*testpb.TestAllTypes)) + case largeFieldF91: + m.SetF91(v.(*testpb.TestAllTypes)) + case largeFieldF92: + m.SetF92(v.(*testpb.TestAllTypes)) + case largeFieldF93: + m.SetF93(v.(*testpb.TestAllTypes)) + case largeFieldF94: + m.SetF94(v.(*testpb.TestAllTypes)) + case largeFieldF95: + m.SetF95(v.(*testpb.TestAllTypes)) + case largeFieldF96: + m.SetF96(v.(*testpb.TestAllTypes)) + case largeFieldF97: + m.SetF97(v.(*testpb.TestAllTypes)) + case largeFieldF98: + m.SetF98(v.(*testpb.TestAllTypes)) + case largeFieldF99: + m.SetF99(v.(*testpb.TestAllTypes)) + case largeFieldF100: + m.SetF100(v.(*testpb.TestAllTypes)) + + default: + panic(fmt.Sprintf("set: unknown field %d", num)) + } + }, + clear: func(num protoreflect.FieldNumber) { + switch num { + case largeFieldF1: + m.ClearF1() + case largeFieldF2: + m.ClearF2() + case largeFieldF3: + m.ClearF3() + case largeFieldF4: + m.ClearF4() + case largeFieldF5: + m.ClearF5() + case largeFieldF6: + m.ClearF6() + case largeFieldF7: + m.ClearF7() + case largeFieldF8: + m.ClearF8() + case largeFieldF9: + m.ClearF9() + case largeFieldF10: + m.ClearF10() + case largeFieldF11: + m.ClearF11() + case largeFieldF12: + m.ClearF12() + case largeFieldF13: + m.ClearF13() + case largeFieldF14: + m.ClearF14() + case largeFieldF15: + m.ClearF15() + case largeFieldF16: + m.ClearF16() + case largeFieldF17: + m.ClearF17() + case largeFieldF18: + m.ClearF18() + case largeFieldF19: + m.ClearF19() + case largeFieldF20: + m.ClearF20() + case largeFieldF21: + m.ClearF21() + case largeFieldF22: + m.ClearF22() + case largeFieldF23: + m.ClearF23() + case largeFieldF24: + m.ClearF24() + case largeFieldF25: + m.ClearF25() + case largeFieldF26: + m.ClearF26() + case largeFieldF27: + m.ClearF27() + case largeFieldF28: + m.ClearF28() + case largeFieldF29: + m.ClearF29() + case largeFieldF30: + m.ClearF30() + case largeFieldF31: + m.ClearF31() + case largeFieldF32: + m.ClearF32() + case largeFieldF33: + m.ClearF33() + case largeFieldF34: + m.ClearF34() + case largeFieldF35: + m.ClearF35() + case largeFieldF36: + m.ClearF36() + case largeFieldF37: + m.ClearF37() + case largeFieldF38: + m.ClearF38() + case largeFieldF39: + m.ClearF39() + case largeFieldF40: + m.ClearF40() + case largeFieldF41: + m.ClearF41() + case largeFieldF42: + m.ClearF42() + case largeFieldF43: + m.ClearF43() + case largeFieldF44: + m.ClearF44() + case largeFieldF45: + m.ClearF45() + case largeFieldF46: + m.ClearF46() + case largeFieldF47: + m.ClearF47() + case largeFieldF48: + m.ClearF48() + case largeFieldF49: + m.ClearF49() + case largeFieldF50: + m.ClearF50() + case largeFieldF51: + m.ClearF51() + case largeFieldF52: + m.ClearF52() + case largeFieldF53: + m.ClearF53() + case largeFieldF54: + m.ClearF54() + case largeFieldF55: + m.ClearF55() + case largeFieldF56: + m.ClearF56() + case largeFieldF57: + m.ClearF57() + case largeFieldF58: + m.ClearF58() + case largeFieldF59: + m.ClearF59() + case largeFieldF60: + m.ClearF60() + case largeFieldF60: + m.ClearF60() + case largeFieldF61: + m.ClearF61() + case largeFieldF62: + m.ClearF62() + case largeFieldF63: + m.ClearF63() + case largeFieldF64: + m.ClearF64() + case largeFieldF65: + m.ClearF65() + case largeFieldF66: + m.ClearF66() + case largeFieldF67: + m.ClearF67() + case largeFieldF68: + m.ClearF68() + case largeFieldF69: + m.ClearF69() + case largeFieldF70: + m.ClearF70() + case largeFieldF71: + m.ClearF71() + case largeFieldF72: + m.ClearF72() + case largeFieldF73: + m.ClearF73() + case largeFieldF74: + m.ClearF74() + case largeFieldF75: + m.ClearF75() + case largeFieldF76: + m.ClearF76() + case largeFieldF77: + m.ClearF77() + case largeFieldF78: + m.ClearF78() + case largeFieldF79: + m.ClearF79() + case largeFieldF80: + m.ClearF80() + case largeFieldF81: + m.ClearF81() + case largeFieldF82: + m.ClearF82() + case largeFieldF83: + m.ClearF83() + case largeFieldF84: + m.ClearF84() + case largeFieldF85: + m.ClearF85() + case largeFieldF86: + m.ClearF86() + case largeFieldF87: + m.ClearF87() + case largeFieldF88: + m.ClearF88() + case largeFieldF89: + m.ClearF89() + case largeFieldF90: + m.ClearF90() + case largeFieldF91: + m.ClearF91() + case largeFieldF92: + m.ClearF92() + case largeFieldF93: + m.ClearF93() + case largeFieldF94: + m.ClearF94() + case largeFieldF95: + m.ClearF95() + case largeFieldF96: + m.ClearF96() + case largeFieldF97: + m.ClearF97() + case largeFieldF98: + m.ClearF98() + case largeFieldF99: + m.ClearF99() + case largeFieldF100: + m.ClearF100() + + default: + panic(fmt.Sprintf("clear: unknown field %d", num)) + } + }, + } +} diff --git a/internal/reflection_test/reflection_opaque_test.go b/internal/reflection_test/reflection_opaque_test.go new file mode 100644 index 000000000..7bfaa0d9d --- /dev/null +++ b/internal/reflection_test/reflection_opaque_test.go @@ -0,0 +1,1045 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package reflection_test + +import ( + "fmt" + "math" + "testing" + + testpb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_opaque" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/testing/prototest" +) + +var enableLazy = proto.UnmarshalOptions{} +var disableLazy = proto.UnmarshalOptions{ + NoLazyDecoding: true, +} + +var lazyCombinations = []struct { + desc string + ptm prototest.Message +}{ + { + desc: "lazy decoding", + ptm: prototest.Message{ + UnmarshalOptions: enableLazy, + }, + }, + + { + desc: "no lazy decoding", + ptm: prototest.Message{ + UnmarshalOptions: disableLazy, + }, + }, +} + +func TestOpaqueConcrete(t *testing.T) { + for _, tt := range lazyCombinations { + t.Run(tt.desc, func(t *testing.T) { + tt.ptm.Test(t, newTestMessageOpaque(nil).ProtoReflect().Type()) + }) + } +} + +func TestOpaqueReflection(t *testing.T) { + for _, tt := range lazyCombinations { + t.Run(tt.desc, func(t *testing.T) { + tt.ptm.Test(t, (*testpb.TestAllTypes)(nil).ProtoReflect().Type()) + }) + } +} + +func TestOpaqueShadow_GetConcrete_SetReflection(t *testing.T) { + for _, tt := range lazyCombinations { + t.Run(tt.desc, func(t *testing.T) { + tt.ptm.Test(t, newShadow(func() (get, set protoreflect.ProtoMessage) { + m := &testpb.TestAllTypes{} + return newTestMessageOpaque(m), m + }).ProtoReflect().Type()) + }) + } +} + +func TestOpaqueShadow_GetReflection_SetConcrete(t *testing.T) { + for _, tt := range lazyCombinations { + t.Run(tt.desc, func(t *testing.T) { + tt.ptm.Test(t, newShadow(func() (get, set protoreflect.ProtoMessage) { + m := &testpb.TestAllTypes{} + return m, newTestMessageOpaque(m) + }).ProtoReflect().Type()) + }) + } +} + +func newTestMessageOpaque(m *testpb.TestAllTypes) protoreflect.ProtoMessage { + return &testProtoMessage{ + m: m, + md: m.ProtoReflect().Descriptor(), + new: func() protoreflect.Message { + return newTestMessageOpaque(&testpb.TestAllTypes{}).ProtoReflect() + }, + has: func(num protoreflect.FieldNumber) bool { + switch num { + case fieldSingularInt32: + return m.GetSingularInt32() != 0 + case fieldSingularInt64: + return m.GetSingularInt64() != 0 + case fieldSingularUint32: + return m.GetSingularUint32() != 0 + case fieldSingularUint64: + return m.GetSingularUint64() != 0 + case fieldSingularSint32: + return m.GetSingularSint32() != 0 + case fieldSingularSint64: + return m.GetSingularSint64() != 0 + case fieldSingularFixed32: + return m.GetSingularFixed32() != 0 + case fieldSingularFixed64: + return m.GetSingularFixed64() != 0 + case fieldSingularSfixed32: + return m.GetSingularSfixed32() != 0 + case fieldSingularSfixed64: + return m.GetSingularSfixed64() != 0 + case fieldSingularFloat: + return m.GetSingularFloat() != 0 || math.Signbit(float64(m.GetSingularFloat())) + case fieldSingularDouble: + return m.GetSingularDouble() != 0 || math.Signbit(m.GetSingularDouble()) + case fieldSingularBool: + return m.GetSingularBool() != false + case fieldSingularString: + return m.GetSingularString() != "" + case fieldSingularBytes: + return len(m.GetSingularBytes()) != 0 + case fieldSingularNestedEnum: + return m.GetSingularNestedEnum() != testpb.TestAllTypes_FOO + case fieldSingularForeignEnum: + return m.GetSingularForeignEnum() != testpb.ForeignEnum_FOREIGN_ZERO + case fieldSingularImportEnum: + return m.GetSingularImportEnum() != testpb.ImportEnum_IMPORT_ZERO + + case fieldOptionalInt32: + return m.HasOptionalInt32() + case fieldOptionalInt64: + return m.HasOptionalInt64() + case fieldOptionalUint32: + return m.HasOptionalUint32() + case fieldOptionalUint64: + return m.HasOptionalUint64() + case fieldOptionalSint32: + return m.HasOptionalSint32() + case fieldOptionalSint64: + return m.HasOptionalSint64() + case fieldOptionalFixed32: + return m.HasOptionalFixed32() + case fieldOptionalFixed64: + return m.HasOptionalFixed64() + case fieldOptionalSfixed32: + return m.HasOptionalSfixed32() + case fieldOptionalSfixed64: + return m.HasOptionalSfixed64() + case fieldOptionalFloat: + return m.HasOptionalFloat() + case fieldOptionalDouble: + return m.HasOptionalDouble() + case fieldOptionalBool: + return m.HasOptionalBool() + case fieldOptionalString: + return m.HasOptionalString() + case fieldOptionalBytes: + return m.HasOptionalBytes() + case fieldOptionalGroup: + return m.HasOptionalgroup() + case fieldNotGroupLikeDelimited: + return m.HasNotGroupLikeDelimited() + case fieldOptionalGroup: + return m.HasOptionalgroup() + case fieldOptionalNestedMessage: + return m.HasOptionalNestedMessage() + case fieldOptionalForeignMessage: + return m.HasOptionalForeignMessage() + case fieldOptionalImportMessage: + return m.HasOptionalImportMessage() + case fieldOptionalNestedEnum: + return m.HasOptionalNestedEnum() + case fieldOptionalForeignEnum: + return m.HasOptionalForeignEnum() + case fieldOptionalImportEnum: + return m.HasOptionalImportEnum() + case fieldOptionalLazyNestedMessage: + return m.HasOptionalLazyNestedMessage() + + case fieldRepeatedInt32: + return len(m.GetRepeatedInt32()) > 0 + case fieldRepeatedInt64: + return len(m.GetRepeatedInt64()) > 0 + case fieldRepeatedUint32: + return len(m.GetRepeatedUint32()) > 0 + case fieldRepeatedUint64: + return len(m.GetRepeatedUint64()) > 0 + case fieldRepeatedSint32: + return len(m.GetRepeatedSint32()) > 0 + case fieldRepeatedSint64: + return len(m.GetRepeatedSint64()) > 0 + case fieldRepeatedFixed32: + return len(m.GetRepeatedFixed32()) > 0 + case fieldRepeatedFixed64: + return len(m.GetRepeatedFixed64()) > 0 + case fieldRepeatedSfixed32: + return len(m.GetRepeatedSfixed32()) > 0 + case fieldRepeatedSfixed64: + return len(m.GetRepeatedSfixed64()) > 0 + case fieldRepeatedFloat: + return len(m.GetRepeatedFloat()) > 0 + case fieldRepeatedDouble: + return len(m.GetRepeatedDouble()) > 0 + case fieldRepeatedBool: + return len(m.GetRepeatedBool()) > 0 + case fieldRepeatedString: + return len(m.GetRepeatedString()) > 0 + case fieldRepeatedBytes: + return len(m.GetRepeatedBytes()) > 0 + case fieldRepeatedGroup: + return len(m.GetRepeatedgroup()) > 0 + case fieldRepeatedNestedMessage: + return len(m.GetRepeatedNestedMessage()) > 0 + case fieldRepeatedForeignMessage: + return len(m.GetRepeatedForeignMessage()) > 0 + case fieldRepeatedImportMessage: + return len(m.GetRepeatedImportmessage()) > 0 + case fieldRepeatedNestedEnum: + return len(m.GetRepeatedNestedEnum()) > 0 + case fieldRepeatedForeignEnum: + return len(m.GetRepeatedForeignEnum()) > 0 + case fieldRepeatedImportEnum: + return len(m.GetRepeatedImportenum()) > 0 + + case fieldMapInt32Int32: + return len(m.GetMapInt32Int32()) > 0 + case fieldMapInt64Int64: + return len(m.GetMapInt64Int64()) > 0 + case fieldMapUint32Uint32: + return len(m.GetMapUint32Uint32()) > 0 + case fieldMapUint64Uint64: + return len(m.GetMapUint64Uint64()) > 0 + case fieldMapSint32Sint32: + return len(m.GetMapSint32Sint32()) > 0 + case fieldMapSint64Sint64: + return len(m.GetMapSint64Sint64()) > 0 + case fieldMapFixed32Fixed32: + return len(m.GetMapFixed32Fixed32()) > 0 + case fieldMapFixed64Fixed64: + return len(m.GetMapFixed64Fixed64()) > 0 + case fieldMapSfixed32Sfixed32: + return len(m.GetMapSfixed32Sfixed32()) > 0 + case fieldMapSfixed64Sfixed64: + return len(m.GetMapSfixed64Sfixed64()) > 0 + case fieldMapInt32Float: + return len(m.GetMapInt32Float()) > 0 + case fieldMapInt32Double: + return len(m.GetMapInt32Double()) > 0 + case fieldMapBoolBool: + return len(m.GetMapBoolBool()) > 0 + case fieldMapStringString: + return len(m.GetMapStringString()) > 0 + case fieldMapStringBytes: + return len(m.GetMapStringBytes()) > 0 + case fieldMapStringNestedMessage: + return len(m.GetMapStringNestedMessage()) > 0 + case fieldMapStringNestedEnum: + return len(m.GetMapStringNestedEnum()) > 0 + + case fieldDefaultInt32: + return m.HasDefaultInt32() + case fieldDefaultInt64: + return m.HasDefaultInt64() + case fieldDefaultUint32: + return m.HasDefaultUint32() + case fieldDefaultUint64: + return m.HasDefaultUint64() + case fieldDefaultSint32: + return m.HasDefaultSint32() + case fieldDefaultSint64: + return m.HasDefaultSint64() + case fieldDefaultFixed32: + return m.HasDefaultFixed32() + case fieldDefaultFixed64: + return m.HasDefaultFixed64() + case fieldDefaultSfixed32: + return m.HasDefaultSfixed32() + case fieldDefaultSfixed64: + return m.HasDefaultSfixed64() + case fieldDefaultFloat: + return m.HasDefaultFloat() + case fieldDefaultDouble: + return m.HasDefaultDouble() + case fieldDefaultBool: + return m.HasDefaultBool() + case fieldDefaultString: + return m.HasDefaultString() + case fieldDefaultBytes: + return m.HasDefaultBytes() + case fieldDefaultNestedEnum: + return m.HasDefaultNestedEnum() + case fieldDefaultForeignEnum: + return m.HasDefaultForeignEnum() + + case fieldDefaultInt32: + return m.HasDefaultInt32() + case fieldDefaultInt64: + return m.HasDefaultInt64() + case fieldDefaultUint32: + return m.HasDefaultUint32() + case fieldDefaultUint64: + return m.HasDefaultUint64() + case fieldDefaultSint32: + return m.HasDefaultSint32() + case fieldDefaultSint64: + return m.HasDefaultSint64() + case fieldDefaultFixed32: + return m.HasDefaultFixed32() + case fieldDefaultFixed64: + return m.HasDefaultFixed64() + case fieldDefaultSfixed32: + return m.HasDefaultSfixed32() + case fieldDefaultSfixed64: + return m.HasDefaultSfixed64() + case fieldDefaultFloat: + return m.HasDefaultFloat() + case fieldDefaultDouble: + return m.HasDefaultDouble() + case fieldDefaultBool: + return m.HasDefaultBool() + case fieldDefaultString: + return m.HasDefaultString() + case fieldDefaultBytes: + return m.HasDefaultBytes() + case fieldDefaultNestedEnum: + return m.HasDefaultNestedEnum() + case fieldDefaultForeignEnum: + return m.HasDefaultForeignEnum() + + case fieldOneofUint32: + return m.HasOneofUint32() + case fieldOneofNestedMessage: + return m.HasOneofNestedMessage() + case fieldOneofString: + return m.HasOneofString() + case fieldOneofBytes: + return m.HasOneofBytes() + case fieldOneofBool: + return m.HasOneofBool() + case fieldOneofUint64: + return m.HasOneofUint64() + case fieldOneofFloat: + return m.HasOneofFloat() + case fieldOneofDouble: + return m.HasOneofDouble() + case fieldOneofEnum: + return m.HasOneofEnum() + case fieldOneofGroup: + return m.HasOneofgroup() + case fieldOneofOptionalUint32: + return m.HasOneofOptionalUint32() + + default: + panic(fmt.Sprintf("has: unknown field %d", num)) + } + }, + get: func(num protoreflect.FieldNumber) any { + switch num { + case fieldSingularInt32: + return m.GetSingularInt32() + case fieldSingularInt64: + return m.GetSingularInt64() + case fieldSingularUint32: + return m.GetSingularUint32() + case fieldSingularUint64: + return m.GetSingularUint64() + case fieldSingularSint32: + return m.GetSingularSint32() + case fieldSingularSint64: + return m.GetSingularSint64() + case fieldSingularFixed32: + return m.GetSingularFixed32() + case fieldSingularFixed64: + return m.GetSingularFixed64() + case fieldSingularSfixed32: + return m.GetSingularSfixed32() + case fieldSingularSfixed64: + return m.GetSingularSfixed64() + case fieldSingularFloat: + return m.GetSingularFloat() + case fieldSingularDouble: + return m.GetSingularDouble() + case fieldSingularBool: + return m.GetSingularBool() + case fieldSingularString: + return m.GetSingularString() + case fieldSingularBytes: + return m.GetSingularBytes() + case fieldSingularNestedEnum: + return m.GetSingularNestedEnum() + case fieldSingularForeignEnum: + return m.GetSingularForeignEnum() + case fieldSingularImportEnum: + return m.GetSingularImportEnum() + + case fieldOptionalInt32: + return m.GetOptionalInt32() + case fieldOptionalInt64: + return m.GetOptionalInt64() + case fieldOptionalUint32: + return m.GetOptionalUint32() + case fieldOptionalUint64: + return m.GetOptionalUint64() + case fieldOptionalSint32: + return m.GetOptionalSint32() + case fieldOptionalSint64: + return m.GetOptionalSint64() + case fieldOptionalFixed32: + return m.GetOptionalFixed32() + case fieldOptionalFixed64: + return m.GetOptionalFixed64() + case fieldOptionalSfixed32: + return m.GetOptionalSfixed32() + case fieldOptionalSfixed64: + return m.GetOptionalSfixed64() + case fieldOptionalFloat: + return m.GetOptionalFloat() + case fieldOptionalDouble: + return m.GetOptionalDouble() + case fieldOptionalBool: + return m.GetOptionalBool() + case fieldOptionalString: + return m.GetOptionalString() + case fieldOptionalBytes: + return m.GetOptionalBytes() + case fieldOptionalGroup: + return m.GetOptionalgroup() + case fieldNotGroupLikeDelimited: + return m.GetNotGroupLikeDelimited() + case fieldOptionalNestedMessage: + return m.GetOptionalNestedMessage() + case fieldOptionalForeignMessage: + return m.GetOptionalForeignMessage() + case fieldOptionalImportMessage: + return m.GetOptionalImportMessage() + case fieldOptionalNestedEnum: + return m.GetOptionalNestedEnum() + case fieldOptionalForeignEnum: + return m.GetOptionalForeignEnum() + case fieldOptionalImportEnum: + return m.GetOptionalImportEnum() + case fieldOptionalLazyNestedMessage: + return m.GetOptionalLazyNestedMessage() + + case fieldRepeatedInt32: + return m.GetRepeatedInt32() + case fieldRepeatedInt64: + return m.GetRepeatedInt64() + case fieldRepeatedUint32: + return m.GetRepeatedUint32() + case fieldRepeatedUint64: + return m.GetRepeatedUint64() + case fieldRepeatedSint32: + return m.GetRepeatedSint32() + case fieldRepeatedSint64: + return m.GetRepeatedSint64() + case fieldRepeatedFixed32: + return m.GetRepeatedFixed32() + case fieldRepeatedFixed64: + return m.GetRepeatedFixed64() + case fieldRepeatedSfixed32: + return m.GetRepeatedSfixed32() + case fieldRepeatedSfixed64: + return m.GetRepeatedSfixed64() + case fieldRepeatedFloat: + return m.GetRepeatedFloat() + case fieldRepeatedDouble: + return m.GetRepeatedDouble() + case fieldRepeatedBool: + return m.GetRepeatedBool() + case fieldRepeatedString: + return m.GetRepeatedString() + case fieldRepeatedBytes: + return m.GetRepeatedBytes() + case fieldRepeatedGroup: + return m.GetRepeatedgroup() + case fieldRepeatedNestedMessage: + return m.GetRepeatedNestedMessage() + case fieldRepeatedForeignMessage: + return m.GetRepeatedForeignMessage() + case fieldRepeatedImportMessage: + return m.GetRepeatedImportmessage() + case fieldRepeatedNestedEnum: + return m.GetRepeatedNestedEnum() + case fieldRepeatedForeignEnum: + return m.GetRepeatedForeignEnum() + case fieldRepeatedImportEnum: + return m.GetRepeatedImportenum() + + case fieldMapInt32Int32: + return m.GetMapInt32Int32() + case fieldMapInt64Int64: + return m.GetMapInt64Int64() + case fieldMapUint32Uint32: + return m.GetMapUint32Uint32() + case fieldMapUint64Uint64: + return m.GetMapUint64Uint64() + case fieldMapSint32Sint32: + return m.GetMapSint32Sint32() + case fieldMapSint64Sint64: + return m.GetMapSint64Sint64() + case fieldMapFixed32Fixed32: + return m.GetMapFixed32Fixed32() + case fieldMapFixed64Fixed64: + return m.GetMapFixed64Fixed64() + case fieldMapSfixed32Sfixed32: + return m.GetMapSfixed32Sfixed32() + case fieldMapSfixed64Sfixed64: + return m.GetMapSfixed64Sfixed64() + case fieldMapInt32Float: + return m.GetMapInt32Float() + case fieldMapInt32Double: + return m.GetMapInt32Double() + case fieldMapBoolBool: + return m.GetMapBoolBool() + case fieldMapStringString: + return m.GetMapStringString() + case fieldMapStringBytes: + return m.GetMapStringBytes() + case fieldMapStringNestedMessage: + return m.GetMapStringNestedMessage() + case fieldMapStringNestedEnum: + return m.GetMapStringNestedEnum() + + case fieldDefaultInt32: + return m.GetDefaultInt32() + case fieldDefaultInt64: + return m.GetDefaultInt64() + case fieldDefaultUint32: + return m.GetDefaultUint32() + case fieldDefaultUint64: + return m.GetDefaultUint64() + case fieldDefaultSint32: + return m.GetDefaultSint32() + case fieldDefaultSint64: + return m.GetDefaultSint64() + case fieldDefaultFixed32: + return m.GetDefaultFixed32() + case fieldDefaultFixed64: + return m.GetDefaultFixed64() + case fieldDefaultSfixed32: + return m.GetDefaultSfixed32() + case fieldDefaultSfixed64: + return m.GetDefaultSfixed64() + case fieldDefaultFloat: + return m.GetDefaultFloat() + case fieldDefaultDouble: + return m.GetDefaultDouble() + case fieldDefaultBool: + return m.GetDefaultBool() + case fieldDefaultString: + return m.GetDefaultString() + case fieldDefaultBytes: + return m.GetDefaultBytes() + case fieldDefaultNestedEnum: + return m.GetDefaultNestedEnum() + case fieldDefaultForeignEnum: + return m.GetDefaultForeignEnum() + + case fieldOneofUint32: + return m.GetOneofUint32() + case fieldOneofNestedMessage: + return m.GetOneofNestedMessage() + case fieldOneofString: + return m.GetOneofString() + case fieldOneofBytes: + return m.GetOneofBytes() + case fieldOneofBool: + return m.GetOneofBool() + case fieldOneofUint64: + return m.GetOneofUint64() + case fieldOneofFloat: + return m.GetOneofFloat() + case fieldOneofDouble: + return m.GetOneofDouble() + case fieldOneofEnum: + return m.GetOneofEnum() + case fieldOneofGroup: + return m.GetOneofgroup() + case fieldOneofOptionalUint32: + return m.GetOneofOptionalUint32() + + default: + panic(fmt.Sprintf("get: unknown field %d", num)) + } + }, + set: func(num protoreflect.FieldNumber, v any) { + switch num { + case fieldSingularInt32: + m.SetSingularInt32(v.(int32)) + case fieldSingularInt64: + m.SetSingularInt64(v.(int64)) + case fieldSingularUint32: + m.SetSingularUint32(v.(uint32)) + case fieldSingularUint64: + m.SetSingularUint64(v.(uint64)) + case fieldSingularSint32: + m.SetSingularSint32(v.(int32)) + case fieldSingularSint64: + m.SetSingularSint64(v.(int64)) + case fieldSingularFixed32: + m.SetSingularFixed32(v.(uint32)) + case fieldSingularFixed64: + m.SetSingularFixed64(v.(uint64)) + case fieldSingularSfixed32: + m.SetSingularSfixed32(v.(int32)) + case fieldSingularSfixed64: + m.SetSingularSfixed64(v.(int64)) + case fieldSingularFloat: + m.SetSingularFloat(v.(float32)) + case fieldSingularDouble: + m.SetSingularDouble(v.(float64)) + case fieldSingularBool: + m.SetSingularBool(v.(bool)) + case fieldSingularString: + m.SetSingularString(v.(string)) + case fieldSingularBytes: + m.SetSingularBytes(v.([]byte)) + case fieldSingularNestedEnum: + m.SetSingularNestedEnum(testpb.TestAllTypes_NestedEnum(v.(protoreflect.EnumNumber))) + case fieldSingularForeignEnum: + m.SetSingularForeignEnum(testpb.ForeignEnum(v.(protoreflect.EnumNumber))) + case fieldSingularImportEnum: + m.SetSingularImportEnum(testpb.ImportEnum(v.(protoreflect.EnumNumber))) + + case fieldOptionalInt32: + m.SetOptionalInt32(v.(int32)) + case fieldOptionalInt64: + m.SetOptionalInt64(v.(int64)) + case fieldOptionalUint32: + m.SetOptionalUint32(v.(uint32)) + case fieldOptionalUint64: + m.SetOptionalUint64(v.(uint64)) + case fieldOptionalSint32: + m.SetOptionalSint32(v.(int32)) + case fieldOptionalSint64: + m.SetOptionalSint64(v.(int64)) + case fieldOptionalFixed32: + m.SetOptionalFixed32(v.(uint32)) + case fieldOptionalFixed64: + m.SetOptionalFixed64(v.(uint64)) + case fieldOptionalSfixed32: + m.SetOptionalSfixed32(v.(int32)) + case fieldOptionalSfixed64: + m.SetOptionalSfixed64(v.(int64)) + case fieldOptionalFloat: + m.SetOptionalFloat(v.(float32)) + case fieldOptionalDouble: + m.SetOptionalDouble(v.(float64)) + case fieldOptionalBool: + m.SetOptionalBool(v.(bool)) + case fieldOptionalString: + m.SetOptionalString(v.(string)) + case fieldOptionalBytes: + m.SetOptionalBytes(v.([]byte)) + case fieldOptionalGroup: + m.SetOptionalgroup(v.(*testpb.TestAllTypes_OptionalGroup)) + case fieldNotGroupLikeDelimited: + m.SetNotGroupLikeDelimited(v.(*testpb.TestAllTypes_OptionalGroup)) + case fieldOptionalNestedMessage: + m.SetOptionalNestedMessage(v.(*testpb.TestAllTypes_NestedMessage)) + case fieldOptionalForeignMessage: + m.SetOptionalForeignMessage(v.(*testpb.ForeignMessage)) + case fieldOptionalImportMessage: + m.SetOptionalImportMessage(v.(*testpb.ImportMessage)) + case fieldOptionalNestedEnum: + m.SetOptionalNestedEnum(testpb.TestAllTypes_NestedEnum(v.(protoreflect.EnumNumber))) + case fieldOptionalForeignEnum: + m.SetOptionalForeignEnum(testpb.ForeignEnum(v.(protoreflect.EnumNumber))) + case fieldOptionalImportEnum: + m.SetOptionalImportEnum(testpb.ImportEnum(v.(protoreflect.EnumNumber))) + case fieldOptionalLazyNestedMessage: + m.SetOptionalLazyNestedMessage(v.(*testpb.TestAllTypes_NestedMessage)) + + case fieldRepeatedInt32: + m.SetRepeatedInt32(v.([]int32)) + case fieldRepeatedInt64: + m.SetRepeatedInt64(v.([]int64)) + case fieldRepeatedUint32: + m.SetRepeatedUint32(v.([]uint32)) + case fieldRepeatedUint64: + m.SetRepeatedUint64(v.([]uint64)) + case fieldRepeatedSint32: + m.SetRepeatedSint32(v.([]int32)) + case fieldRepeatedSint64: + m.SetRepeatedSint64(v.([]int64)) + case fieldRepeatedFixed32: + m.SetRepeatedFixed32(v.([]uint32)) + case fieldRepeatedFixed64: + m.SetRepeatedFixed64(v.([]uint64)) + case fieldRepeatedSfixed32: + m.SetRepeatedSfixed32(v.([]int32)) + case fieldRepeatedSfixed64: + m.SetRepeatedSfixed64(v.([]int64)) + case fieldRepeatedFloat: + m.SetRepeatedFloat(v.([]float32)) + case fieldRepeatedDouble: + m.SetRepeatedDouble(v.([]float64)) + case fieldRepeatedBool: + m.SetRepeatedBool(v.([]bool)) + case fieldRepeatedString: + m.SetRepeatedString(v.([]string)) + case fieldRepeatedBytes: + m.SetRepeatedBytes(v.([][]byte)) + case fieldRepeatedGroup: + m.SetRepeatedgroup(v.([]*testpb.TestAllTypes_RepeatedGroup)) + case fieldRepeatedNestedMessage: + m.SetRepeatedNestedMessage(v.([]*testpb.TestAllTypes_NestedMessage)) + case fieldRepeatedForeignMessage: + m.SetRepeatedForeignMessage(v.([]*testpb.ForeignMessage)) + case fieldRepeatedImportMessage: + m.SetRepeatedImportmessage(v.([]*testpb.ImportMessage)) + case fieldRepeatedNestedEnum: + m.SetRepeatedNestedEnum(v.([]testpb.TestAllTypes_NestedEnum)) + case fieldRepeatedForeignEnum: + m.SetRepeatedForeignEnum(v.([]testpb.ForeignEnum)) + case fieldRepeatedImportEnum: + m.SetRepeatedImportenum(v.([]testpb.ImportEnum)) + + case fieldMapInt32Int32: + m.SetMapInt32Int32(v.(map[int32]int32)) + case fieldMapInt64Int64: + m.SetMapInt64Int64(v.(map[int64]int64)) + case fieldMapUint32Uint32: + m.SetMapUint32Uint32(v.(map[uint32]uint32)) + case fieldMapUint64Uint64: + m.SetMapUint64Uint64(v.(map[uint64]uint64)) + case fieldMapSint32Sint32: + m.SetMapSint32Sint32(v.(map[int32]int32)) + case fieldMapSint64Sint64: + m.SetMapSint64Sint64(v.(map[int64]int64)) + case fieldMapFixed32Fixed32: + m.SetMapFixed32Fixed32(v.(map[uint32]uint32)) + case fieldMapFixed64Fixed64: + m.SetMapFixed64Fixed64(v.(map[uint64]uint64)) + case fieldMapSfixed32Sfixed32: + m.SetMapSfixed32Sfixed32(v.(map[int32]int32)) + case fieldMapSfixed64Sfixed64: + m.SetMapSfixed64Sfixed64(v.(map[int64]int64)) + case fieldMapInt32Float: + m.SetMapInt32Float(v.(map[int32]float32)) + case fieldMapInt32Double: + m.SetMapInt32Double(v.(map[int32]float64)) + case fieldMapBoolBool: + m.SetMapBoolBool(v.(map[bool]bool)) + case fieldMapStringString: + m.SetMapStringString(v.(map[string]string)) + case fieldMapStringBytes: + m.SetMapStringBytes(v.(map[string][]byte)) + case fieldMapStringNestedMessage: + m.SetMapStringNestedMessage(v.(map[string]*testpb.TestAllTypes_NestedMessage)) + case fieldMapStringNestedEnum: + m.SetMapStringNestedEnum(v.(map[string]testpb.TestAllTypes_NestedEnum)) + + case fieldDefaultInt32: + m.SetDefaultInt32(v.(int32)) + case fieldDefaultInt64: + m.SetDefaultInt64(v.(int64)) + case fieldDefaultUint32: + m.SetDefaultUint32(v.(uint32)) + case fieldDefaultUint64: + m.SetDefaultUint64(v.(uint64)) + case fieldDefaultSint32: + m.SetDefaultSint32(v.(int32)) + case fieldDefaultSint64: + m.SetDefaultSint64(v.(int64)) + case fieldDefaultFixed32: + m.SetDefaultFixed32(v.(uint32)) + case fieldDefaultFixed64: + m.SetDefaultFixed64(v.(uint64)) + case fieldDefaultSfixed32: + m.SetDefaultSfixed32(v.(int32)) + case fieldDefaultSfixed64: + m.SetDefaultSfixed64(v.(int64)) + case fieldDefaultFloat: + m.SetDefaultFloat(v.(float32)) + case fieldDefaultDouble: + m.SetDefaultDouble(v.(float64)) + case fieldDefaultBool: + m.SetDefaultBool(v.(bool)) + case fieldDefaultString: + m.SetDefaultString(v.(string)) + case fieldDefaultBytes: + m.SetDefaultBytes(v.([]byte)) + case fieldDefaultNestedEnum: + m.SetDefaultNestedEnum(testpb.TestAllTypes_NestedEnum(v.(protoreflect.EnumNumber))) + case fieldDefaultForeignEnum: + m.SetDefaultForeignEnum(testpb.ForeignEnum(v.(protoreflect.EnumNumber))) + + case fieldOneofUint32: + m.SetOneofUint32(v.(uint32)) + case fieldOneofNestedMessage: + m.SetOneofNestedMessage(v.(*testpb.TestAllTypes_NestedMessage)) + case fieldOneofString: + m.SetOneofString(v.(string)) + case fieldOneofBytes: + m.SetOneofBytes(v.([]byte)) + case fieldOneofBool: + m.SetOneofBool(v.(bool)) + case fieldOneofUint64: + m.SetOneofUint64(v.(uint64)) + case fieldOneofFloat: + m.SetOneofFloat(v.(float32)) + case fieldOneofDouble: + m.SetOneofDouble(v.(float64)) + case fieldOneofEnum: + m.SetOneofEnum(testpb.TestAllTypes_NestedEnum(v.(protoreflect.EnumNumber))) + case fieldOneofGroup: + m.SetOneofgroup(v.(*testpb.TestAllTypes_OneofGroup)) + case fieldOneofOptionalUint32: + m.SetOneofOptionalUint32(v.(uint32)) + + default: + panic(fmt.Sprintf("set: unknown field %d", num)) + } + }, + clear: func(num protoreflect.FieldNumber) { + switch num { + case fieldSingularInt32: + m.SetSingularInt32(0) + case fieldSingularInt64: + m.SetSingularInt64(0) + case fieldSingularUint32: + m.SetSingularUint32(0) + case fieldSingularUint64: + m.SetSingularUint64(0) + case fieldSingularSint32: + m.SetSingularSint32(0) + case fieldSingularSint64: + m.SetSingularSint64(0) + case fieldSingularFixed32: + m.SetSingularFixed32(0) + case fieldSingularFixed64: + m.SetSingularFixed64(0) + case fieldSingularSfixed32: + m.SetSingularSfixed32(0) + case fieldSingularSfixed64: + m.SetSingularSfixed64(0) + case fieldSingularFloat: + m.SetSingularFloat(0) + case fieldSingularDouble: + m.SetSingularDouble(0) + case fieldSingularBool: + m.SetSingularBool(false) + case fieldSingularString: + m.SetSingularString("") + case fieldSingularBytes: + m.SetSingularBytes(nil) + case fieldSingularNestedEnum: + m.SetSingularNestedEnum(testpb.TestAllTypes_FOO) + case fieldSingularForeignEnum: + m.SetSingularForeignEnum(testpb.ForeignEnum_FOREIGN_ZERO) + case fieldSingularImportEnum: + m.SetSingularImportEnum(testpb.ImportEnum_IMPORT_ZERO) + + case fieldOptionalInt32: + m.ClearOptionalInt32() + case fieldOptionalInt64: + m.ClearOptionalInt64() + case fieldOptionalUint32: + m.ClearOptionalUint32() + case fieldOptionalUint64: + m.ClearOptionalUint64() + case fieldOptionalSint32: + m.ClearOptionalSint32() + case fieldOptionalSint64: + m.ClearOptionalSint64() + case fieldOptionalFixed32: + m.ClearOptionalFixed32() + case fieldOptionalFixed64: + m.ClearOptionalFixed64() + case fieldOptionalSfixed32: + m.ClearOptionalSfixed32() + case fieldOptionalSfixed64: + m.ClearOptionalSfixed64() + case fieldOptionalFloat: + m.ClearOptionalFloat() + case fieldOptionalDouble: + m.ClearOptionalDouble() + case fieldOptionalBool: + m.ClearOptionalBool() + case fieldOptionalString: + m.ClearOptionalString() + case fieldOptionalBytes: + m.ClearOptionalBytes() + case fieldOptionalGroup: + m.ClearOptionalgroup() + case fieldNotGroupLikeDelimited: + m.ClearNotGroupLikeDelimited() + case fieldOptionalNestedMessage: + m.ClearOptionalNestedMessage() + case fieldOptionalForeignMessage: + m.ClearOptionalForeignMessage() + case fieldOptionalImportMessage: + m.ClearOptionalImportMessage() + case fieldOptionalNestedEnum: + m.ClearOptionalNestedEnum() + case fieldOptionalForeignEnum: + m.ClearOptionalForeignEnum() + case fieldOptionalImportEnum: + m.ClearOptionalImportEnum() + case fieldOptionalLazyNestedMessage: + m.ClearOptionalLazyNestedMessage() + + case fieldRepeatedInt32: + m.SetRepeatedInt32(nil) + case fieldRepeatedInt64: + m.SetRepeatedInt64(nil) + case fieldRepeatedUint32: + m.SetRepeatedUint32(nil) + case fieldRepeatedUint64: + m.SetRepeatedUint64(nil) + case fieldRepeatedSint32: + m.SetRepeatedSint32(nil) + case fieldRepeatedSint64: + m.SetRepeatedSint64(nil) + case fieldRepeatedFixed32: + m.SetRepeatedFixed32(nil) + case fieldRepeatedFixed64: + m.SetRepeatedFixed64(nil) + case fieldRepeatedSfixed32: + m.SetRepeatedSfixed32(nil) + case fieldRepeatedSfixed64: + m.SetRepeatedSfixed64(nil) + case fieldRepeatedFloat: + m.SetRepeatedFloat(nil) + case fieldRepeatedDouble: + m.SetRepeatedDouble(nil) + case fieldRepeatedBool: + m.SetRepeatedBool(nil) + case fieldRepeatedString: + m.SetRepeatedString(nil) + case fieldRepeatedBytes: + m.SetRepeatedBytes(nil) + case fieldRepeatedGroup: + m.SetRepeatedgroup(nil) + case fieldRepeatedNestedMessage: + m.SetRepeatedNestedMessage(nil) + case fieldRepeatedForeignMessage: + m.SetRepeatedForeignMessage(nil) + case fieldRepeatedImportMessage: + m.SetRepeatedImportmessage(nil) + case fieldRepeatedNestedEnum: + m.SetRepeatedNestedEnum(nil) + case fieldRepeatedForeignEnum: + m.SetRepeatedForeignEnum(nil) + case fieldRepeatedImportEnum: + m.SetRepeatedImportenum(nil) + + case fieldMapInt32Int32: + m.SetMapInt32Int32(nil) + case fieldMapInt64Int64: + m.SetMapInt64Int64(nil) + case fieldMapUint32Uint32: + m.SetMapUint32Uint32(nil) + case fieldMapUint64Uint64: + m.SetMapUint64Uint64(nil) + case fieldMapSint32Sint32: + m.SetMapSint32Sint32(nil) + case fieldMapSint64Sint64: + m.SetMapSint64Sint64(nil) + case fieldMapFixed32Fixed32: + m.SetMapFixed32Fixed32(nil) + case fieldMapFixed64Fixed64: + m.SetMapFixed64Fixed64(nil) + case fieldMapSfixed32Sfixed32: + m.SetMapSfixed32Sfixed32(nil) + case fieldMapSfixed64Sfixed64: + m.SetMapSfixed64Sfixed64(nil) + case fieldMapInt32Float: + m.SetMapInt32Float(nil) + case fieldMapInt32Double: + m.SetMapInt32Double(nil) + case fieldMapBoolBool: + m.SetMapBoolBool(nil) + case fieldMapStringString: + m.SetMapStringString(nil) + case fieldMapStringBytes: + m.SetMapStringBytes(nil) + case fieldMapStringNestedMessage: + m.SetMapStringNestedMessage(nil) + case fieldMapStringNestedEnum: + m.SetMapStringNestedEnum(nil) + + case fieldDefaultInt32: + m.ClearDefaultInt32() + case fieldDefaultInt64: + m.ClearDefaultInt64() + case fieldDefaultUint32: + m.ClearDefaultUint32() + case fieldDefaultUint64: + m.ClearDefaultUint64() + case fieldDefaultSint32: + m.ClearDefaultSint32() + case fieldDefaultSint64: + m.ClearDefaultSint64() + case fieldDefaultFixed32: + m.ClearDefaultFixed32() + case fieldDefaultFixed64: + m.ClearDefaultFixed64() + case fieldDefaultSfixed32: + m.ClearDefaultSfixed32() + case fieldDefaultSfixed64: + m.ClearDefaultSfixed64() + case fieldDefaultFloat: + m.ClearDefaultFloat() + case fieldDefaultDouble: + m.ClearDefaultDouble() + case fieldDefaultBool: + m.ClearDefaultBool() + case fieldDefaultString: + m.ClearDefaultString() + case fieldDefaultBytes: + m.ClearDefaultBytes() + case fieldDefaultNestedEnum: + m.ClearDefaultNestedEnum() + case fieldDefaultForeignEnum: + m.ClearDefaultForeignEnum() + + case fieldOneofUint32: + m.ClearOneofUint32() + case fieldOneofNestedMessage: + m.ClearOneofNestedMessage() + case fieldOneofString: + m.ClearOneofString() + case fieldOneofBytes: + m.ClearOneofBytes() + case fieldOneofBool: + m.ClearOneofBool() + case fieldOneofUint64: + m.ClearOneofUint64() + case fieldOneofFloat: + m.ClearOneofFloat() + case fieldOneofDouble: + m.ClearOneofDouble() + case fieldOneofEnum: + m.ClearOneofEnum() + case fieldOneofGroup: + m.ClearOneofgroup() + case fieldOneofOptionalUint32: + m.ClearOneofOptionalUint32() + + default: + panic(fmt.Sprintf("clear: unknown field %d", num)) + } + }, + } +} diff --git a/internal/reflection_test/reflection_open_test.go b/internal/reflection_test/reflection_open_test.go new file mode 100644 index 000000000..228dd2015 --- /dev/null +++ b/internal/reflection_test/reflection_open_test.go @@ -0,0 +1,985 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package reflection_test + +import ( + "fmt" + "math" + "testing" + + testpb "google.golang.org/protobuf/internal/testprotos/testeditions" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/testing/prototest" +) + +func TestOpenConcrete(t *testing.T) { + prototest.Message{}.Test(t, newTestMessageOpen(nil).ProtoReflect().Type()) +} + +func TestOpenReflection(t *testing.T) { + prototest.Message{}.Test(t, (*testpb.TestAllTypes)(nil).ProtoReflect().Type()) +} + +func TestOpenShadow_GetConcrete_SetReflection(t *testing.T) { + prototest.Message{}.Test(t, newShadow(func() (get, set protoreflect.ProtoMessage) { + m := &testpb.TestAllTypes{} + return newTestMessageOpen(m), m + }).ProtoReflect().Type()) +} + +func TestOpenShadow_GetReflection_SetConcrete(t *testing.T) { + prototest.Message{}.Test(t, newShadow(func() (get, set protoreflect.ProtoMessage) { + m := &testpb.TestAllTypes{} + return m, newTestMessageOpen(m) + }).ProtoReflect().Type()) +} + +func newTestMessageOpen(m *testpb.TestAllTypes) protoreflect.ProtoMessage { + return &testProtoMessage{ + m: m, + md: m.ProtoReflect().Descriptor(), + new: func() protoreflect.Message { + return newTestMessageOpen(&testpb.TestAllTypes{}).ProtoReflect() + }, + has: func(num protoreflect.FieldNumber) bool { + switch num { + case fieldSingularInt32: + return m.GetSingularInt32() != 0 + case fieldSingularInt64: + return m.GetSingularInt64() != 0 + case fieldSingularUint32: + return m.GetSingularUint32() != 0 + case fieldSingularUint64: + return m.GetSingularUint64() != 0 + case fieldSingularSint32: + return m.GetSingularSint32() != 0 + case fieldSingularSint64: + return m.GetSingularSint64() != 0 + case fieldSingularFixed32: + return m.GetSingularFixed32() != 0 + case fieldSingularFixed64: + return m.GetSingularFixed64() != 0 + case fieldSingularSfixed32: + return m.GetSingularSfixed32() != 0 + case fieldSingularSfixed64: + return m.GetSingularSfixed64() != 0 + case fieldSingularFloat: + return m.GetSingularFloat() != 0 || math.Signbit(float64(m.GetSingularFloat())) + case fieldSingularDouble: + return m.GetSingularDouble() != 0 || math.Signbit(m.GetSingularDouble()) + case fieldSingularBool: + return m.GetSingularBool() != false + case fieldSingularString: + return m.GetSingularString() != "" + case fieldSingularBytes: + return m.SingularBytes != nil + case fieldSingularNestedEnum: + return m.GetSingularNestedEnum() != testpb.TestAllTypes_FOO + case fieldSingularForeignEnum: + return m.GetSingularForeignEnum() != testpb.ForeignEnum_FOREIGN_ZERO + case fieldSingularImportEnum: + return m.GetSingularImportEnum() != testpb.ImportEnum_IMPORT_ZERO + + case fieldOptionalInt32: + return m.OptionalInt32 != nil + case fieldOptionalInt64: + return m.OptionalInt64 != nil + case fieldOptionalUint32: + return m.OptionalUint32 != nil + case fieldOptionalUint64: + return m.OptionalUint64 != nil + case fieldOptionalSint32: + return m.OptionalSint32 != nil + case fieldOptionalSint64: + return m.OptionalSint64 != nil + case fieldOptionalFixed32: + return m.OptionalFixed32 != nil + case fieldOptionalFixed64: + return m.OptionalFixed64 != nil + case fieldOptionalSfixed32: + return m.OptionalSfixed32 != nil + case fieldOptionalSfixed64: + return m.OptionalSfixed64 != nil + case fieldOptionalFloat: + return m.OptionalFloat != nil + case fieldOptionalDouble: + return m.OptionalDouble != nil + case fieldOptionalBool: + return m.OptionalBool != nil + case fieldOptionalString: + return m.OptionalString != nil + case fieldOptionalBytes: + return m.OptionalBytes != nil + case fieldOptionalGroup: + return m.Optionalgroup != nil + case fieldNotGroupLikeDelimited: + return m.NotGroupLikeDelimited != nil + case fieldOptionalNestedMessage: + return m.OptionalNestedMessage != nil + case fieldOptionalForeignMessage: + return m.OptionalForeignMessage != nil + case fieldOptionalImportMessage: + return m.OptionalImportMessage != nil + case fieldOptionalNestedEnum: + return m.OptionalNestedEnum != nil + case fieldOptionalForeignEnum: + return m.OptionalForeignEnum != nil + case fieldOptionalImportEnum: + return m.OptionalImportEnum != nil + case fieldOptionalLazyNestedMessage: + return m.OptionalLazyNestedMessage != nil + + case fieldRepeatedInt32: + return len(m.GetRepeatedInt32()) > 0 + case fieldRepeatedInt64: + return len(m.GetRepeatedInt64()) > 0 + case fieldRepeatedUint32: + return len(m.GetRepeatedUint32()) > 0 + case fieldRepeatedUint64: + return len(m.GetRepeatedUint64()) > 0 + case fieldRepeatedSint32: + return len(m.GetRepeatedSint32()) > 0 + case fieldRepeatedSint64: + return len(m.GetRepeatedSint64()) > 0 + case fieldRepeatedFixed32: + return len(m.GetRepeatedFixed32()) > 0 + case fieldRepeatedFixed64: + return len(m.GetRepeatedFixed64()) > 0 + case fieldRepeatedSfixed32: + return len(m.GetRepeatedSfixed32()) > 0 + case fieldRepeatedSfixed64: + return len(m.GetRepeatedSfixed64()) > 0 + case fieldRepeatedFloat: + return len(m.GetRepeatedFloat()) > 0 + case fieldRepeatedDouble: + return len(m.GetRepeatedDouble()) > 0 + case fieldRepeatedBool: + return len(m.GetRepeatedBool()) > 0 + case fieldRepeatedString: + return len(m.GetRepeatedString()) > 0 + case fieldRepeatedBytes: + return len(m.GetRepeatedBytes()) > 0 + case fieldRepeatedGroup: + return len(m.GetRepeatedgroup()) > 0 + case fieldRepeatedNestedMessage: + return len(m.GetRepeatedNestedMessage()) > 0 + case fieldRepeatedForeignMessage: + return len(m.GetRepeatedForeignMessage()) > 0 + case fieldRepeatedImportMessage: + return len(m.GetRepeatedImportmessage()) > 0 + case fieldRepeatedNestedEnum: + return len(m.GetRepeatedNestedEnum()) > 0 + case fieldRepeatedForeignEnum: + return len(m.GetRepeatedForeignEnum()) > 0 + case fieldRepeatedImportEnum: + return len(m.GetRepeatedImportenum()) > 0 + + case fieldMapInt32Int32: + return len(m.GetMapInt32Int32()) > 0 + case fieldMapInt64Int64: + return len(m.GetMapInt64Int64()) > 0 + case fieldMapUint32Uint32: + return len(m.GetMapUint32Uint32()) > 0 + case fieldMapUint64Uint64: + return len(m.GetMapUint64Uint64()) > 0 + case fieldMapSint32Sint32: + return len(m.GetMapSint32Sint32()) > 0 + case fieldMapSint64Sint64: + return len(m.GetMapSint64Sint64()) > 0 + case fieldMapFixed32Fixed32: + return len(m.GetMapFixed32Fixed32()) > 0 + case fieldMapFixed64Fixed64: + return len(m.GetMapFixed64Fixed64()) > 0 + case fieldMapSfixed32Sfixed32: + return len(m.GetMapSfixed32Sfixed32()) > 0 + case fieldMapSfixed64Sfixed64: + return len(m.GetMapSfixed64Sfixed64()) > 0 + case fieldMapInt32Float: + return len(m.GetMapInt32Float()) > 0 + case fieldMapInt32Double: + return len(m.GetMapInt32Double()) > 0 + case fieldMapBoolBool: + return len(m.GetMapBoolBool()) > 0 + case fieldMapStringString: + return len(m.GetMapStringString()) > 0 + case fieldMapStringBytes: + return len(m.GetMapStringBytes()) > 0 + case fieldMapStringNestedMessage: + return len(m.GetMapStringNestedMessage()) > 0 + case fieldMapStringNestedEnum: + return len(m.GetMapStringNestedEnum()) > 0 + + case fieldDefaultInt32: + return m.DefaultInt32 != nil + case fieldDefaultInt64: + return m.DefaultInt64 != nil + case fieldDefaultUint32: + return m.DefaultUint32 != nil + case fieldDefaultUint64: + return m.DefaultUint64 != nil + case fieldDefaultSint32: + return m.DefaultSint32 != nil + case fieldDefaultSint64: + return m.DefaultSint64 != nil + case fieldDefaultFixed32: + return m.DefaultFixed32 != nil + case fieldDefaultFixed64: + return m.DefaultFixed64 != nil + case fieldDefaultSfixed32: + return m.DefaultSfixed32 != nil + case fieldDefaultSfixed64: + return m.DefaultSfixed64 != nil + case fieldDefaultFloat: + return m.DefaultFloat != nil + case fieldDefaultDouble: + return m.DefaultDouble != nil + case fieldDefaultBool: + return m.DefaultBool != nil + case fieldDefaultString: + return m.DefaultString != nil + case fieldDefaultBytes: + return m.DefaultBytes != nil + case fieldDefaultNestedEnum: + return m.DefaultNestedEnum != nil + case fieldDefaultForeignEnum: + return m.DefaultForeignEnum != nil + + case fieldOneofUint32: + _, ok := m.OneofField.(*testpb.TestAllTypes_OneofUint32) + return ok + case fieldOneofNestedMessage: + _, ok := m.OneofField.(*testpb.TestAllTypes_OneofNestedMessage) + return ok + case fieldOneofString: + _, ok := m.OneofField.(*testpb.TestAllTypes_OneofString) + return ok + case fieldOneofBytes: + _, ok := m.OneofField.(*testpb.TestAllTypes_OneofBytes) + return ok + case fieldOneofBool: + _, ok := m.OneofField.(*testpb.TestAllTypes_OneofBool) + return ok + case fieldOneofUint64: + _, ok := m.OneofField.(*testpb.TestAllTypes_OneofUint64) + return ok + case fieldOneofFloat: + _, ok := m.OneofField.(*testpb.TestAllTypes_OneofFloat) + return ok + case fieldOneofDouble: + _, ok := m.OneofField.(*testpb.TestAllTypes_OneofDouble) + return ok + case fieldOneofEnum: + _, ok := m.OneofField.(*testpb.TestAllTypes_OneofEnum) + return ok + case fieldOneofGroup: + _, ok := m.OneofField.(*testpb.TestAllTypes_Oneofgroup) + return ok + case fieldOneofOptionalUint32: + _, ok := m.OneofOptional.(*testpb.TestAllTypes_OneofOptionalUint32) + return ok + + default: + panic(fmt.Sprintf("has: unknown field %d", num)) + } + }, + get: func(num protoreflect.FieldNumber) any { + switch num { + case fieldSingularInt32: + return m.GetSingularInt32() + case fieldSingularInt64: + return m.GetSingularInt64() + case fieldSingularUint32: + return m.GetSingularUint32() + case fieldSingularUint64: + return m.GetSingularUint64() + case fieldSingularSint32: + return m.GetSingularSint32() + case fieldSingularSint64: + return m.GetSingularSint64() + case fieldSingularFixed32: + return m.GetSingularFixed32() + case fieldSingularFixed64: + return m.GetSingularFixed64() + case fieldSingularSfixed32: + return m.GetSingularSfixed32() + case fieldSingularSfixed64: + return m.GetSingularSfixed64() + case fieldSingularFloat: + return m.GetSingularFloat() + case fieldSingularDouble: + return m.GetSingularDouble() + case fieldSingularBool: + return m.GetSingularBool() + case fieldSingularString: + return m.GetSingularString() + case fieldSingularBytes: + return m.GetSingularBytes() + case fieldSingularNestedEnum: + return m.GetSingularNestedEnum() + case fieldSingularForeignEnum: + return m.GetSingularForeignEnum() + case fieldSingularImportEnum: + return m.GetSingularImportEnum() + + case fieldOptionalInt32: + return m.GetOptionalInt32() + case fieldOptionalInt64: + return m.GetOptionalInt64() + case fieldOptionalUint32: + return m.GetOptionalUint32() + case fieldOptionalUint64: + return m.GetOptionalUint64() + case fieldOptionalSint32: + return m.GetOptionalSint32() + case fieldOptionalSint64: + return m.GetOptionalSint64() + case fieldOptionalFixed32: + return m.GetOptionalFixed32() + case fieldOptionalFixed64: + return m.GetOptionalFixed64() + case fieldOptionalSfixed32: + return m.GetOptionalSfixed32() + case fieldOptionalSfixed64: + return m.GetOptionalSfixed64() + case fieldOptionalFloat: + return m.GetOptionalFloat() + case fieldOptionalDouble: + return m.GetOptionalDouble() + case fieldOptionalBool: + return m.GetOptionalBool() + case fieldOptionalString: + return m.GetOptionalString() + case fieldOptionalBytes: + return m.GetOptionalBytes() + case fieldOptionalGroup: + return m.GetOptionalgroup() + case fieldNotGroupLikeDelimited: + return m.GetNotGroupLikeDelimited() + case fieldOptionalNestedMessage: + return m.GetOptionalNestedMessage() + case fieldOptionalForeignMessage: + return m.GetOptionalForeignMessage() + case fieldOptionalImportMessage: + return m.GetOptionalImportMessage() + case fieldOptionalNestedEnum: + return m.GetOptionalNestedEnum() + case fieldOptionalForeignEnum: + return m.GetOptionalForeignEnum() + case fieldOptionalImportEnum: + return m.GetOptionalImportEnum() + case fieldOptionalLazyNestedMessage: + return m.GetOptionalLazyNestedMessage() + + case fieldRepeatedInt32: + return m.GetRepeatedInt32() + case fieldRepeatedInt64: + return m.GetRepeatedInt64() + case fieldRepeatedUint32: + return m.GetRepeatedUint32() + case fieldRepeatedUint64: + return m.GetRepeatedUint64() + case fieldRepeatedSint32: + return m.GetRepeatedSint32() + case fieldRepeatedSint64: + return m.GetRepeatedSint64() + case fieldRepeatedFixed32: + return m.GetRepeatedFixed32() + case fieldRepeatedFixed64: + return m.GetRepeatedFixed64() + case fieldRepeatedSfixed32: + return m.GetRepeatedSfixed32() + case fieldRepeatedSfixed64: + return m.GetRepeatedSfixed64() + case fieldRepeatedFloat: + return m.GetRepeatedFloat() + case fieldRepeatedDouble: + return m.GetRepeatedDouble() + case fieldRepeatedBool: + return m.GetRepeatedBool() + case fieldRepeatedString: + return m.GetRepeatedString() + case fieldRepeatedBytes: + return m.GetRepeatedBytes() + case fieldRepeatedGroup: + return m.GetRepeatedgroup() + case fieldRepeatedNestedMessage: + return m.GetRepeatedNestedMessage() + case fieldRepeatedForeignMessage: + return m.GetRepeatedForeignMessage() + case fieldRepeatedImportMessage: + return m.GetRepeatedImportmessage() + case fieldRepeatedNestedEnum: + return m.GetRepeatedNestedEnum() + case fieldRepeatedForeignEnum: + return m.GetRepeatedForeignEnum() + case fieldRepeatedImportEnum: + return m.GetRepeatedImportenum() + + case fieldMapInt32Int32: + return m.GetMapInt32Int32() + case fieldMapInt64Int64: + return m.GetMapInt64Int64() + case fieldMapUint32Uint32: + return m.GetMapUint32Uint32() + case fieldMapUint64Uint64: + return m.GetMapUint64Uint64() + case fieldMapSint32Sint32: + return m.GetMapSint32Sint32() + case fieldMapSint64Sint64: + return m.GetMapSint64Sint64() + case fieldMapFixed32Fixed32: + return m.GetMapFixed32Fixed32() + case fieldMapFixed64Fixed64: + return m.GetMapFixed64Fixed64() + case fieldMapSfixed32Sfixed32: + return m.GetMapSfixed32Sfixed32() + case fieldMapSfixed64Sfixed64: + return m.GetMapSfixed64Sfixed64() + case fieldMapInt32Float: + return m.GetMapInt32Float() + case fieldMapInt32Double: + return m.GetMapInt32Double() + case fieldMapBoolBool: + return m.GetMapBoolBool() + case fieldMapStringString: + return m.GetMapStringString() + case fieldMapStringBytes: + return m.GetMapStringBytes() + case fieldMapStringNestedMessage: + return m.GetMapStringNestedMessage() + case fieldMapStringNestedEnum: + return m.GetMapStringNestedEnum() + + case fieldDefaultInt32: + return m.GetDefaultInt32() + case fieldDefaultInt64: + return m.GetDefaultInt64() + case fieldDefaultUint32: + return m.GetDefaultUint32() + case fieldDefaultUint64: + return m.GetDefaultUint64() + case fieldDefaultSint32: + return m.GetDefaultSint32() + case fieldDefaultSint64: + return m.GetDefaultSint64() + case fieldDefaultFixed32: + return m.GetDefaultFixed32() + case fieldDefaultFixed64: + return m.GetDefaultFixed64() + case fieldDefaultSfixed32: + return m.GetDefaultSfixed32() + case fieldDefaultSfixed64: + return m.GetDefaultSfixed64() + case fieldDefaultFloat: + return m.GetDefaultFloat() + case fieldDefaultDouble: + return m.GetDefaultDouble() + case fieldDefaultBool: + return m.GetDefaultBool() + case fieldDefaultString: + return m.GetDefaultString() + case fieldDefaultBytes: + return m.GetDefaultBytes() + case fieldDefaultNestedEnum: + return m.GetDefaultNestedEnum() + case fieldDefaultForeignEnum: + return m.GetDefaultForeignEnum() + + case fieldOneofUint32: + return m.GetOneofUint32() + case fieldOneofNestedMessage: + return m.GetOneofNestedMessage() + case fieldOneofString: + return m.GetOneofString() + case fieldOneofBytes: + return m.GetOneofBytes() + case fieldOneofBool: + return m.GetOneofBool() + case fieldOneofUint64: + return m.GetOneofUint64() + case fieldOneofFloat: + return m.GetOneofFloat() + case fieldOneofDouble: + return m.GetOneofDouble() + case fieldOneofEnum: + return protoreflect.EnumNumber(m.GetOneofEnum()) + case fieldOneofGroup: + return m.GetOneofgroup() + case fieldOneofOptionalUint32: + return m.GetOneofOptionalUint32() + + default: + panic(fmt.Sprintf("get: unknown field %d", num)) + } + }, + set: func(num protoreflect.FieldNumber, v any) { + switch num { + case fieldSingularInt32: + m.SingularInt32 = v.(int32) + case fieldSingularInt64: + m.SingularInt64 = v.(int64) + case fieldSingularUint32: + m.SingularUint32 = v.(uint32) + case fieldSingularUint64: + m.SingularUint64 = v.(uint64) + case fieldSingularSint32: + m.SingularSint32 = v.(int32) + case fieldSingularSint64: + m.SingularSint64 = v.(int64) + case fieldSingularFixed32: + m.SingularFixed32 = v.(uint32) + case fieldSingularFixed64: + m.SingularFixed64 = v.(uint64) + case fieldSingularSfixed32: + m.SingularSfixed32 = v.(int32) + case fieldSingularSfixed64: + m.SingularSfixed64 = v.(int64) + case fieldSingularFloat: + m.SingularFloat = v.(float32) + case fieldSingularDouble: + m.SingularDouble = v.(float64) + case fieldSingularBool: + m.SingularBool = v.(bool) + case fieldSingularString: + m.SingularString = v.(string) + case fieldSingularBytes: + m.SingularBytes = v.([]byte) + case fieldSingularNestedEnum: + m.SingularNestedEnum = testpb.TestAllTypes_NestedEnum(v.(protoreflect.EnumNumber)) + case fieldSingularForeignEnum: + m.SingularForeignEnum = testpb.ForeignEnum(v.(protoreflect.EnumNumber)) + case fieldSingularImportEnum: + m.SingularImportEnum = testpb.ImportEnum(v.(protoreflect.EnumNumber)) + + case fieldOptionalInt32: + m.OptionalInt32 = proto.Int32(v.(int32)) + case fieldOptionalInt64: + m.OptionalInt64 = proto.Int64(v.(int64)) + case fieldOptionalUint32: + m.OptionalUint32 = proto.Uint32(v.(uint32)) + case fieldOptionalUint64: + m.OptionalUint64 = proto.Uint64(v.(uint64)) + case fieldOptionalSint32: + m.OptionalSint32 = proto.Int32(v.(int32)) + case fieldOptionalSint64: + m.OptionalSint64 = proto.Int64(v.(int64)) + case fieldOptionalFixed32: + m.OptionalFixed32 = proto.Uint32(v.(uint32)) + case fieldOptionalFixed64: + m.OptionalFixed64 = proto.Uint64(v.(uint64)) + case fieldOptionalSfixed32: + m.OptionalSfixed32 = proto.Int32(v.(int32)) + case fieldOptionalSfixed64: + m.OptionalSfixed64 = proto.Int64(v.(int64)) + case fieldOptionalFloat: + m.OptionalFloat = proto.Float32(v.(float32)) + case fieldOptionalDouble: + m.OptionalDouble = proto.Float64(v.(float64)) + case fieldOptionalBool: + m.OptionalBool = proto.Bool(v.(bool)) + case fieldOptionalString: + m.OptionalString = proto.String(v.(string)) + case fieldOptionalBytes: + if v.([]byte) == nil { + v = []byte{} + } + m.OptionalBytes = v.([]byte) + case fieldNotGroupLikeDelimited: + m.NotGroupLikeDelimited = v.(*testpb.TestAllTypes_OptionalGroup) + case fieldOptionalGroup: + m.Optionalgroup = v.(*testpb.TestAllTypes_OptionalGroup) + case fieldOptionalNestedMessage: + m.OptionalNestedMessage = v.(*testpb.TestAllTypes_NestedMessage) + case fieldOptionalForeignMessage: + m.OptionalForeignMessage = v.(*testpb.ForeignMessage) + case fieldOptionalImportMessage: + m.OptionalImportMessage = v.(*testpb.ImportMessage) + case fieldOptionalNestedEnum: + m.OptionalNestedEnum = testpb.TestAllTypes_NestedEnum(v.(protoreflect.EnumNumber)).Enum() + case fieldOptionalForeignEnum: + m.OptionalForeignEnum = testpb.ForeignEnum(v.(protoreflect.EnumNumber)).Enum() + case fieldOptionalImportEnum: + m.OptionalImportEnum = testpb.ImportEnum(v.(protoreflect.EnumNumber)).Enum() + case fieldOptionalLazyNestedMessage: + m.OptionalLazyNestedMessage = v.(*testpb.TestAllTypes_NestedMessage) + + case fieldRepeatedInt32: + m.RepeatedInt32 = v.([]int32) + case fieldRepeatedInt64: + m.RepeatedInt64 = v.([]int64) + case fieldRepeatedUint32: + m.RepeatedUint32 = v.([]uint32) + case fieldRepeatedUint64: + m.RepeatedUint64 = v.([]uint64) + case fieldRepeatedSint32: + m.RepeatedSint32 = v.([]int32) + case fieldRepeatedSint64: + m.RepeatedSint64 = v.([]int64) + case fieldRepeatedFixed32: + m.RepeatedFixed32 = v.([]uint32) + case fieldRepeatedFixed64: + m.RepeatedFixed64 = v.([]uint64) + case fieldRepeatedSfixed32: + m.RepeatedSfixed32 = v.([]int32) + case fieldRepeatedSfixed64: + m.RepeatedSfixed64 = v.([]int64) + case fieldRepeatedFloat: + m.RepeatedFloat = v.([]float32) + case fieldRepeatedDouble: + m.RepeatedDouble = v.([]float64) + case fieldRepeatedBool: + m.RepeatedBool = v.([]bool) + case fieldRepeatedString: + m.RepeatedString = v.([]string) + case fieldRepeatedBytes: + m.RepeatedBytes = v.([][]byte) + case fieldRepeatedGroup: + m.Repeatedgroup = v.([]*testpb.TestAllTypes_RepeatedGroup) + case fieldRepeatedNestedMessage: + m.RepeatedNestedMessage = v.([]*testpb.TestAllTypes_NestedMessage) + case fieldRepeatedForeignMessage: + m.RepeatedForeignMessage = v.([]*testpb.ForeignMessage) + case fieldRepeatedImportMessage: + m.RepeatedImportmessage = v.([]*testpb.ImportMessage) + case fieldRepeatedNestedEnum: + m.RepeatedNestedEnum = v.([]testpb.TestAllTypes_NestedEnum) + case fieldRepeatedForeignEnum: + m.RepeatedForeignEnum = v.([]testpb.ForeignEnum) + case fieldRepeatedImportEnum: + m.RepeatedImportenum = v.([]testpb.ImportEnum) + + case fieldMapInt32Int32: + m.MapInt32Int32 = v.(map[int32]int32) + case fieldMapInt64Int64: + m.MapInt64Int64 = v.(map[int64]int64) + case fieldMapUint32Uint32: + m.MapUint32Uint32 = v.(map[uint32]uint32) + case fieldMapUint64Uint64: + m.MapUint64Uint64 = v.(map[uint64]uint64) + case fieldMapSint32Sint32: + m.MapSint32Sint32 = v.(map[int32]int32) + case fieldMapSint64Sint64: + m.MapSint64Sint64 = v.(map[int64]int64) + case fieldMapFixed32Fixed32: + m.MapFixed32Fixed32 = v.(map[uint32]uint32) + case fieldMapFixed64Fixed64: + m.MapFixed64Fixed64 = v.(map[uint64]uint64) + case fieldMapSfixed32Sfixed32: + m.MapSfixed32Sfixed32 = v.(map[int32]int32) + case fieldMapSfixed64Sfixed64: + m.MapSfixed64Sfixed64 = v.(map[int64]int64) + case fieldMapInt32Float: + m.MapInt32Float = v.(map[int32]float32) + case fieldMapInt32Double: + m.MapInt32Double = v.(map[int32]float64) + case fieldMapBoolBool: + m.MapBoolBool = v.(map[bool]bool) + case fieldMapStringString: + m.MapStringString = v.(map[string]string) + case fieldMapStringBytes: + m.MapStringBytes = v.(map[string][]byte) + case fieldMapStringNestedMessage: + m.MapStringNestedMessage = v.(map[string]*testpb.TestAllTypes_NestedMessage) + case fieldMapStringNestedEnum: + m.MapStringNestedEnum = v.(map[string]testpb.TestAllTypes_NestedEnum) + + case fieldDefaultInt32: + m.DefaultInt32 = proto.Int32(v.(int32)) + case fieldDefaultInt64: + m.DefaultInt64 = proto.Int64(v.(int64)) + case fieldDefaultUint32: + m.DefaultUint32 = proto.Uint32(v.(uint32)) + case fieldDefaultUint64: + m.DefaultUint64 = proto.Uint64(v.(uint64)) + case fieldDefaultSint32: + m.DefaultSint32 = proto.Int32(v.(int32)) + case fieldDefaultSint64: + m.DefaultSint64 = proto.Int64(v.(int64)) + case fieldDefaultFixed32: + m.DefaultFixed32 = proto.Uint32(v.(uint32)) + case fieldDefaultFixed64: + m.DefaultFixed64 = proto.Uint64(v.(uint64)) + case fieldDefaultSfixed32: + m.DefaultSfixed32 = proto.Int32(v.(int32)) + case fieldDefaultSfixed64: + m.DefaultSfixed64 = proto.Int64(v.(int64)) + case fieldDefaultFloat: + m.DefaultFloat = proto.Float32(v.(float32)) + case fieldDefaultDouble: + m.DefaultDouble = proto.Float64(v.(float64)) + case fieldDefaultBool: + m.DefaultBool = proto.Bool(v.(bool)) + case fieldDefaultString: + m.DefaultString = proto.String(v.(string)) + case fieldDefaultBytes: + if v.([]byte) == nil { + v = []byte{} + } + m.DefaultBytes = v.([]byte) + case fieldDefaultNestedEnum: + m.DefaultNestedEnum = testpb.TestAllTypes_NestedEnum(v.(protoreflect.EnumNumber)).Enum() + case fieldDefaultForeignEnum: + m.DefaultForeignEnum = testpb.ForeignEnum(v.(protoreflect.EnumNumber)).Enum() + + case fieldOneofUint32: + m.OneofField = &testpb.TestAllTypes_OneofUint32{v.(uint32)} + case fieldOneofNestedMessage: + m.OneofField = &testpb.TestAllTypes_OneofNestedMessage{v.(*testpb.TestAllTypes_NestedMessage)} + case fieldOneofString: + m.OneofField = &testpb.TestAllTypes_OneofString{v.(string)} + case fieldOneofBytes: + m.OneofField = &testpb.TestAllTypes_OneofBytes{v.([]byte)} + case fieldOneofBool: + m.OneofField = &testpb.TestAllTypes_OneofBool{v.(bool)} + case fieldOneofUint64: + m.OneofField = &testpb.TestAllTypes_OneofUint64{v.(uint64)} + case fieldOneofFloat: + m.OneofField = &testpb.TestAllTypes_OneofFloat{v.(float32)} + case fieldOneofDouble: + m.OneofField = &testpb.TestAllTypes_OneofDouble{v.(float64)} + case fieldOneofEnum: + m.OneofField = &testpb.TestAllTypes_OneofEnum{testpb.TestAllTypes_NestedEnum(v.(protoreflect.EnumNumber))} + case fieldOneofGroup: + m.OneofField = &testpb.TestAllTypes_Oneofgroup{v.(*testpb.TestAllTypes_OneofGroup)} + case fieldOneofOptionalUint32: + m.OneofOptional = &testpb.TestAllTypes_OneofOptionalUint32{v.(uint32)} + + default: + panic(fmt.Sprintf("set: unknown field %d", num)) + } + }, + clear: func(num protoreflect.FieldNumber) { + switch num { + case fieldSingularInt32: + m.SingularInt32 = 0 + case fieldSingularInt64: + m.SingularInt64 = 0 + case fieldSingularUint32: + m.SingularUint32 = 0 + case fieldSingularUint64: + m.SingularUint64 = 0 + case fieldSingularSint32: + m.SingularSint32 = 0 + case fieldSingularSint64: + m.SingularSint64 = 0 + case fieldSingularFixed32: + m.SingularFixed32 = 0 + case fieldSingularFixed64: + m.SingularFixed64 = 0 + case fieldSingularSfixed32: + m.SingularSfixed32 = 0 + case fieldSingularSfixed64: + m.SingularSfixed64 = 0 + case fieldSingularFloat: + m.SingularFloat = 0 + case fieldSingularDouble: + m.SingularDouble = 0 + case fieldSingularBool: + m.SingularBool = false + case fieldSingularString: + m.SingularString = "" + case fieldSingularBytes: + m.SingularBytes = nil + case fieldSingularNestedEnum: + m.SingularNestedEnum = testpb.TestAllTypes_FOO + case fieldSingularForeignEnum: + m.SingularForeignEnum = testpb.ForeignEnum_FOREIGN_ZERO + case fieldSingularImportEnum: + m.SingularImportEnum = testpb.ImportEnum_IMPORT_ZERO + + case fieldOptionalInt32: + m.OptionalInt32 = nil + case fieldOptionalInt64: + m.OptionalInt64 = nil + case fieldOptionalUint32: + m.OptionalUint32 = nil + case fieldOptionalUint64: + m.OptionalUint64 = nil + case fieldOptionalSint32: + m.OptionalSint32 = nil + case fieldOptionalSint64: + m.OptionalSint64 = nil + case fieldOptionalFixed32: + m.OptionalFixed32 = nil + case fieldOptionalFixed64: + m.OptionalFixed64 = nil + case fieldOptionalSfixed32: + m.OptionalSfixed32 = nil + case fieldOptionalSfixed64: + m.OptionalSfixed64 = nil + case fieldOptionalFloat: + m.OptionalFloat = nil + case fieldOptionalDouble: + m.OptionalDouble = nil + case fieldOptionalBool: + m.OptionalBool = nil + case fieldOptionalString: + m.OptionalString = nil + case fieldOptionalBytes: + m.OptionalBytes = nil + case fieldOptionalGroup: + m.Optionalgroup = nil + case fieldNotGroupLikeDelimited: + m.NotGroupLikeDelimited = nil + case fieldOptionalNestedMessage: + m.OptionalNestedMessage = nil + case fieldOptionalForeignMessage: + m.OptionalForeignMessage = nil + case fieldOptionalImportMessage: + m.OptionalImportMessage = nil + case fieldOptionalNestedEnum: + m.OptionalNestedEnum = nil + case fieldOptionalForeignEnum: + m.OptionalForeignEnum = nil + case fieldOptionalImportEnum: + m.OptionalImportEnum = nil + case fieldOptionalLazyNestedMessage: + m.OptionalLazyNestedMessage = nil + + case fieldRepeatedInt32: + m.RepeatedInt32 = nil + case fieldRepeatedInt64: + m.RepeatedInt64 = nil + case fieldRepeatedUint32: + m.RepeatedUint32 = nil + case fieldRepeatedUint64: + m.RepeatedUint64 = nil + case fieldRepeatedSint32: + m.RepeatedSint32 = nil + case fieldRepeatedSint64: + m.RepeatedSint64 = nil + case fieldRepeatedFixed32: + m.RepeatedFixed32 = nil + case fieldRepeatedFixed64: + m.RepeatedFixed64 = nil + case fieldRepeatedSfixed32: + m.RepeatedSfixed32 = nil + case fieldRepeatedSfixed64: + m.RepeatedSfixed64 = nil + case fieldRepeatedFloat: + m.RepeatedFloat = nil + case fieldRepeatedDouble: + m.RepeatedDouble = nil + case fieldRepeatedBool: + m.RepeatedBool = nil + case fieldRepeatedString: + m.RepeatedString = nil + case fieldRepeatedBytes: + m.RepeatedBytes = nil + case fieldRepeatedGroup: + m.Repeatedgroup = nil + case fieldRepeatedNestedMessage: + m.RepeatedNestedMessage = nil + case fieldRepeatedForeignMessage: + m.RepeatedForeignMessage = nil + case fieldRepeatedImportMessage: + m.RepeatedImportmessage = nil + case fieldRepeatedNestedEnum: + m.RepeatedNestedEnum = nil + case fieldRepeatedForeignEnum: + m.RepeatedForeignEnum = nil + case fieldRepeatedImportEnum: + m.RepeatedImportenum = nil + + case fieldMapInt32Int32: + m.MapInt32Int32 = nil + case fieldMapInt64Int64: + m.MapInt64Int64 = nil + case fieldMapUint32Uint32: + m.MapUint32Uint32 = nil + case fieldMapUint64Uint64: + m.MapUint64Uint64 = nil + case fieldMapSint32Sint32: + m.MapSint32Sint32 = nil + case fieldMapSint64Sint64: + m.MapSint64Sint64 = nil + case fieldMapFixed32Fixed32: + m.MapFixed32Fixed32 = nil + case fieldMapFixed64Fixed64: + m.MapFixed64Fixed64 = nil + case fieldMapSfixed32Sfixed32: + m.MapSfixed32Sfixed32 = nil + case fieldMapSfixed64Sfixed64: + m.MapSfixed64Sfixed64 = nil + case fieldMapInt32Float: + m.MapInt32Float = nil + case fieldMapInt32Double: + m.MapInt32Double = nil + case fieldMapBoolBool: + m.MapBoolBool = nil + case fieldMapStringString: + m.MapStringString = nil + case fieldMapStringBytes: + m.MapStringBytes = nil + case fieldMapStringNestedMessage: + m.MapStringNestedMessage = nil + case fieldMapStringNestedEnum: + m.MapStringNestedEnum = nil + + case fieldDefaultInt32: + m.DefaultInt32 = nil + case fieldDefaultInt64: + m.DefaultInt64 = nil + case fieldDefaultUint32: + m.DefaultUint32 = nil + case fieldDefaultUint64: + m.DefaultUint64 = nil + case fieldDefaultSint32: + m.DefaultSint32 = nil + case fieldDefaultSint64: + m.DefaultSint64 = nil + case fieldDefaultFixed32: + m.DefaultFixed32 = nil + case fieldDefaultFixed64: + m.DefaultFixed64 = nil + case fieldDefaultSfixed32: + m.DefaultSfixed32 = nil + case fieldDefaultSfixed64: + m.DefaultSfixed64 = nil + case fieldDefaultFloat: + m.DefaultFloat = nil + case fieldDefaultDouble: + m.DefaultDouble = nil + case fieldDefaultBool: + m.DefaultBool = nil + case fieldDefaultString: + m.DefaultString = nil + case fieldDefaultBytes: + m.DefaultBytes = nil + case fieldDefaultNestedEnum: + m.DefaultNestedEnum = nil + case fieldDefaultForeignEnum: + m.DefaultForeignEnum = nil + + case fieldOneofUint32: + m.OneofField = nil + case fieldOneofNestedMessage: + m.OneofField = nil + case fieldOneofString: + m.OneofField = nil + case fieldOneofBytes: + m.OneofField = nil + case fieldOneofBool: + m.OneofField = nil + case fieldOneofUint64: + m.OneofField = nil + case fieldOneofFloat: + m.OneofField = nil + case fieldOneofDouble: + m.OneofField = nil + case fieldOneofEnum: + m.OneofField = nil + case fieldOneofGroup: + m.OneofField = nil + case fieldOneofOptionalUint32: + m.OneofOptional = nil + + default: + panic(fmt.Sprintf("clear: unknown field %d", num)) + } + }, + } +} diff --git a/internal/reflection_test/reflection_repeated_test.go b/internal/reflection_test/reflection_repeated_test.go new file mode 100644 index 000000000..9f4f32082 --- /dev/null +++ b/internal/reflection_test/reflection_repeated_test.go @@ -0,0 +1,43 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package reflection_test + +import ( + "testing" + + testopaquepb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_opaque" + "google.golang.org/protobuf/proto" +) + +func TestOpaqueRepeated(t *testing.T) { + m := testopaquepb.TestAllTypes_builder{ + RepeatedNestedMessage: []*testopaquepb.TestAllTypes_NestedMessage{ + testopaquepb.TestAllTypes_NestedMessage_builder{ + A: proto.Int32(42), + }.Build(), + }, + }.Build() + + // Clear the repeated_nested_message field. This should not clear the presence bit. + mr := m.ProtoReflect() + fd := mr.Descriptor().Fields().ByNumber(48) + mr.Clear(fd) + if len(m.GetRepeatedNestedMessage()) != 0 { + t.Errorf("protoreflect Clear did not empty the repeated field: got %v, expected []", m.GetRepeatedNestedMessage()) + } + + // Append a new submessage to the input field and set its A field to 23. + dst := mr.Mutable(fd).List() + v := dst.NewElement() + dst.Append(v) + if len(m.GetRepeatedNestedMessage()) != 1 { + t.Fatalf("unexpected number of elements in repeated field: got %v, expected 1", len(m.GetRepeatedNestedMessage())) + } + m.GetRepeatedNestedMessage()[0].SetA(23) + + if mr.Get(fd).List().Len() != 1 { + t.Fatalf("presence bit (incorrectly) cleared") + } +} diff --git a/internal/reflection_test/reflection_test.go b/internal/reflection_test/reflection_test.go new file mode 100644 index 000000000..53f6491d9 --- /dev/null +++ b/internal/reflection_test/reflection_test.go @@ -0,0 +1,768 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package reflection_test + +import ( + "fmt" + "reflect" + "testing" + + testopenpb "google.golang.org/protobuf/internal/testprotos/testeditions" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/runtime/protoiface" + "google.golang.org/protobuf/testing/prototest" +) + +func Test(t *testing.T) { + t.Skip() + for _, m := range []protoreflect.ProtoMessage{ + &testopenpb.TestAllTypes{}, + } { + t.Run(fmt.Sprintf("%T", m), func(t *testing.T) { + prototest.Message{}.Test(t, m.ProtoReflect().Type()) + }) + } +} + +// What follows is infrastructure for a complicated but useful set of tests +// of different views of a message. +// +// Every Protobuf message can be accessed in at least two ways: +// +// - m: a concrete open, hybrid or opaque message +// - m.ProtoReflect(): reflective view of the message +// +// A mutation to one representation must be reflected in the others. +// +// To test the various views of a message, we construct an implementations of +// the protoreflect.Message interface for each. The simplest is the canonical +// reflective view provided by the ProtoReflect method. In addition, for each +// concrete representation we create another view backed by that concrete API. +// (i.e., m.ProtoReflect().KnownFields().Get(1) directly translates to a call +// to m.GetFieldOne().) +// +// Finally, we construct a "shadow" view in which read operations are backed +// by one implementation and write operations by another. +// +// Each of these various views may then be passed to the prototest package +// for validation. +// +// This approach separates the decision of what behaviors to test from the +// implementations being tested; new validation tests added to prototest +// apply to all the various views without additional effort. The disadvantage +// is that there is quite a bit of per-message boilerplate required. +// +// We could attempt to reduce that boilerplate by use of reflection or code +// generation, but both approaches replace simple-but-repetitive code with +// something quite complex. Since the purpose of all this is to test the +// complex, general-purpose canonical implementation, the simple approach +// is safer. + +// Field numbers for the test messages. +var ( + largeMessageDesc protoreflect.MessageDescriptor = (&testopenpb.TestManyMessageFieldsMessage{}).ProtoReflect().Descriptor() + + largeFieldF1 = largeMessageDesc.Fields().ByName("f1").Number() + largeFieldF2 = largeMessageDesc.Fields().ByName("f2").Number() + largeFieldF3 = largeMessageDesc.Fields().ByName("f3").Number() + largeFieldF4 = largeMessageDesc.Fields().ByName("f4").Number() + largeFieldF5 = largeMessageDesc.Fields().ByName("f5").Number() + largeFieldF6 = largeMessageDesc.Fields().ByName("f6").Number() + largeFieldF7 = largeMessageDesc.Fields().ByName("f7").Number() + largeFieldF8 = largeMessageDesc.Fields().ByName("f8").Number() + largeFieldF9 = largeMessageDesc.Fields().ByName("f9").Number() + largeFieldF10 = largeMessageDesc.Fields().ByName("f10").Number() + largeFieldF11 = largeMessageDesc.Fields().ByName("f11").Number() + largeFieldF12 = largeMessageDesc.Fields().ByName("f12").Number() + largeFieldF13 = largeMessageDesc.Fields().ByName("f13").Number() + largeFieldF14 = largeMessageDesc.Fields().ByName("f14").Number() + largeFieldF15 = largeMessageDesc.Fields().ByName("f15").Number() + largeFieldF16 = largeMessageDesc.Fields().ByName("f16").Number() + largeFieldF17 = largeMessageDesc.Fields().ByName("f17").Number() + largeFieldF18 = largeMessageDesc.Fields().ByName("f18").Number() + largeFieldF19 = largeMessageDesc.Fields().ByName("f19").Number() + largeFieldF20 = largeMessageDesc.Fields().ByName("f20").Number() + largeFieldF21 = largeMessageDesc.Fields().ByName("f21").Number() + largeFieldF22 = largeMessageDesc.Fields().ByName("f22").Number() + largeFieldF23 = largeMessageDesc.Fields().ByName("f23").Number() + largeFieldF24 = largeMessageDesc.Fields().ByName("f24").Number() + largeFieldF25 = largeMessageDesc.Fields().ByName("f25").Number() + largeFieldF26 = largeMessageDesc.Fields().ByName("f26").Number() + largeFieldF27 = largeMessageDesc.Fields().ByName("f27").Number() + largeFieldF28 = largeMessageDesc.Fields().ByName("f28").Number() + largeFieldF29 = largeMessageDesc.Fields().ByName("f29").Number() + largeFieldF30 = largeMessageDesc.Fields().ByName("f30").Number() + largeFieldF31 = largeMessageDesc.Fields().ByName("f31").Number() + largeFieldF32 = largeMessageDesc.Fields().ByName("f32").Number() + largeFieldF33 = largeMessageDesc.Fields().ByName("f33").Number() + largeFieldF34 = largeMessageDesc.Fields().ByName("f34").Number() + largeFieldF35 = largeMessageDesc.Fields().ByName("f35").Number() + largeFieldF36 = largeMessageDesc.Fields().ByName("f36").Number() + largeFieldF37 = largeMessageDesc.Fields().ByName("f37").Number() + largeFieldF38 = largeMessageDesc.Fields().ByName("f38").Number() + largeFieldF39 = largeMessageDesc.Fields().ByName("f39").Number() + largeFieldF40 = largeMessageDesc.Fields().ByName("f40").Number() + largeFieldF41 = largeMessageDesc.Fields().ByName("f41").Number() + largeFieldF42 = largeMessageDesc.Fields().ByName("f42").Number() + largeFieldF43 = largeMessageDesc.Fields().ByName("f43").Number() + largeFieldF44 = largeMessageDesc.Fields().ByName("f44").Number() + largeFieldF45 = largeMessageDesc.Fields().ByName("f45").Number() + largeFieldF46 = largeMessageDesc.Fields().ByName("f46").Number() + largeFieldF47 = largeMessageDesc.Fields().ByName("f47").Number() + largeFieldF48 = largeMessageDesc.Fields().ByName("f48").Number() + largeFieldF49 = largeMessageDesc.Fields().ByName("f49").Number() + largeFieldF50 = largeMessageDesc.Fields().ByName("f50").Number() + largeFieldF51 = largeMessageDesc.Fields().ByName("f51").Number() + largeFieldF52 = largeMessageDesc.Fields().ByName("f52").Number() + largeFieldF53 = largeMessageDesc.Fields().ByName("f53").Number() + largeFieldF54 = largeMessageDesc.Fields().ByName("f54").Number() + largeFieldF55 = largeMessageDesc.Fields().ByName("f55").Number() + largeFieldF56 = largeMessageDesc.Fields().ByName("f56").Number() + largeFieldF57 = largeMessageDesc.Fields().ByName("f57").Number() + largeFieldF58 = largeMessageDesc.Fields().ByName("f58").Number() + largeFieldF59 = largeMessageDesc.Fields().ByName("f59").Number() + largeFieldF60 = largeMessageDesc.Fields().ByName("f60").Number() + largeFieldF61 = largeMessageDesc.Fields().ByName("f61").Number() + largeFieldF62 = largeMessageDesc.Fields().ByName("f62").Number() + largeFieldF63 = largeMessageDesc.Fields().ByName("f63").Number() + largeFieldF64 = largeMessageDesc.Fields().ByName("f64").Number() + largeFieldF65 = largeMessageDesc.Fields().ByName("f65").Number() + largeFieldF66 = largeMessageDesc.Fields().ByName("f66").Number() + largeFieldF67 = largeMessageDesc.Fields().ByName("f67").Number() + largeFieldF68 = largeMessageDesc.Fields().ByName("f68").Number() + largeFieldF69 = largeMessageDesc.Fields().ByName("f69").Number() + largeFieldF70 = largeMessageDesc.Fields().ByName("f70").Number() + largeFieldF71 = largeMessageDesc.Fields().ByName("f71").Number() + largeFieldF72 = largeMessageDesc.Fields().ByName("f72").Number() + largeFieldF73 = largeMessageDesc.Fields().ByName("f73").Number() + largeFieldF74 = largeMessageDesc.Fields().ByName("f74").Number() + largeFieldF75 = largeMessageDesc.Fields().ByName("f75").Number() + largeFieldF76 = largeMessageDesc.Fields().ByName("f76").Number() + largeFieldF77 = largeMessageDesc.Fields().ByName("f77").Number() + largeFieldF78 = largeMessageDesc.Fields().ByName("f78").Number() + largeFieldF79 = largeMessageDesc.Fields().ByName("f79").Number() + largeFieldF80 = largeMessageDesc.Fields().ByName("f80").Number() + largeFieldF81 = largeMessageDesc.Fields().ByName("f81").Number() + largeFieldF82 = largeMessageDesc.Fields().ByName("f82").Number() + largeFieldF83 = largeMessageDesc.Fields().ByName("f83").Number() + largeFieldF84 = largeMessageDesc.Fields().ByName("f84").Number() + largeFieldF85 = largeMessageDesc.Fields().ByName("f85").Number() + largeFieldF86 = largeMessageDesc.Fields().ByName("f86").Number() + largeFieldF87 = largeMessageDesc.Fields().ByName("f87").Number() + largeFieldF88 = largeMessageDesc.Fields().ByName("f88").Number() + largeFieldF89 = largeMessageDesc.Fields().ByName("f89").Number() + largeFieldF90 = largeMessageDesc.Fields().ByName("f90").Number() + largeFieldF91 = largeMessageDesc.Fields().ByName("f91").Number() + largeFieldF92 = largeMessageDesc.Fields().ByName("f92").Number() + largeFieldF93 = largeMessageDesc.Fields().ByName("f93").Number() + largeFieldF94 = largeMessageDesc.Fields().ByName("f94").Number() + largeFieldF95 = largeMessageDesc.Fields().ByName("f95").Number() + largeFieldF96 = largeMessageDesc.Fields().ByName("f96").Number() + largeFieldF97 = largeMessageDesc.Fields().ByName("f97").Number() + largeFieldF98 = largeMessageDesc.Fields().ByName("f98").Number() + largeFieldF99 = largeMessageDesc.Fields().ByName("f99").Number() + largeFieldF100 = largeMessageDesc.Fields().ByName("f100").Number() +) + +var ( + messageDesc protoreflect.MessageDescriptor = (&testopenpb.TestAllTypes{}).ProtoReflect().Descriptor() + + fieldSingularInt32 = messageDesc.Fields().ByName("singular_int32").Number() + fieldSingularInt64 = messageDesc.Fields().ByName("singular_int64").Number() + fieldSingularUint32 = messageDesc.Fields().ByName("singular_uint32").Number() + fieldSingularUint64 = messageDesc.Fields().ByName("singular_uint64").Number() + fieldSingularSint32 = messageDesc.Fields().ByName("singular_sint32").Number() + fieldSingularSint64 = messageDesc.Fields().ByName("singular_sint64").Number() + fieldSingularFixed32 = messageDesc.Fields().ByName("singular_fixed32").Number() + fieldSingularFixed64 = messageDesc.Fields().ByName("singular_fixed64").Number() + fieldSingularSfixed32 = messageDesc.Fields().ByName("singular_sfixed32").Number() + fieldSingularSfixed64 = messageDesc.Fields().ByName("singular_sfixed64").Number() + fieldSingularFloat = messageDesc.Fields().ByName("singular_float").Number() + fieldSingularDouble = messageDesc.Fields().ByName("singular_double").Number() + fieldSingularBool = messageDesc.Fields().ByName("singular_bool").Number() + fieldSingularString = messageDesc.Fields().ByName("singular_string").Number() + fieldSingularBytes = messageDesc.Fields().ByName("singular_bytes").Number() + fieldSingularNestedEnum = messageDesc.Fields().ByName("singular_nested_enum").Number() + fieldSingularForeignEnum = messageDesc.Fields().ByName("singular_foreign_enum").Number() + fieldSingularImportEnum = messageDesc.Fields().ByName("singular_import_enum").Number() + + fieldOptionalInt32 = messageDesc.Fields().ByName("optional_int32").Number() + fieldOptionalInt64 = messageDesc.Fields().ByName("optional_int64").Number() + fieldOptionalUint32 = messageDesc.Fields().ByName("optional_uint32").Number() + fieldOptionalUint64 = messageDesc.Fields().ByName("optional_uint64").Number() + fieldOptionalSint32 = messageDesc.Fields().ByName("optional_sint32").Number() + fieldOptionalSint64 = messageDesc.Fields().ByName("optional_sint64").Number() + fieldOptionalFixed32 = messageDesc.Fields().ByName("optional_fixed32").Number() + fieldOptionalFixed64 = messageDesc.Fields().ByName("optional_fixed64").Number() + fieldOptionalSfixed32 = messageDesc.Fields().ByName("optional_sfixed32").Number() + fieldOptionalSfixed64 = messageDesc.Fields().ByName("optional_sfixed64").Number() + fieldOptionalFloat = messageDesc.Fields().ByName("optional_float").Number() + fieldOptionalDouble = messageDesc.Fields().ByName("optional_double").Number() + fieldOptionalBool = messageDesc.Fields().ByName("optional_bool").Number() + fieldOptionalString = messageDesc.Fields().ByName("optional_string").Number() + fieldOptionalBytes = messageDesc.Fields().ByName("optional_bytes").Number() + fieldOptionalGroup = messageDesc.Fields().ByName("optionalgroup").Number() + fieldOptionalNestedMessage = messageDesc.Fields().ByName("optional_nested_message").Number() + fieldOptionalForeignMessage = messageDesc.Fields().ByName("optional_foreign_message").Number() + fieldOptionalImportMessage = messageDesc.Fields().ByName("optional_import_message").Number() + fieldOptionalNestedEnum = messageDesc.Fields().ByName("optional_nested_enum").Number() + fieldOptionalForeignEnum = messageDesc.Fields().ByName("optional_foreign_enum").Number() + fieldOptionalImportEnum = messageDesc.Fields().ByName("optional_import_enum").Number() + fieldOptionalLazyNestedMessage = messageDesc.Fields().ByName("optional_lazy_nested_message").Number() + fieldNotGroupLikeDelimited = messageDesc.Fields().ByName("not_group_like_delimited").Number() + + fieldRepeatedInt32 = messageDesc.Fields().ByName("repeated_int32").Number() + fieldRepeatedInt64 = messageDesc.Fields().ByName("repeated_int64").Number() + fieldRepeatedUint32 = messageDesc.Fields().ByName("repeated_uint32").Number() + fieldRepeatedUint64 = messageDesc.Fields().ByName("repeated_uint64").Number() + fieldRepeatedSint32 = messageDesc.Fields().ByName("repeated_sint32").Number() + fieldRepeatedSint64 = messageDesc.Fields().ByName("repeated_sint64").Number() + fieldRepeatedFixed32 = messageDesc.Fields().ByName("repeated_fixed32").Number() + fieldRepeatedFixed64 = messageDesc.Fields().ByName("repeated_fixed64").Number() + fieldRepeatedSfixed32 = messageDesc.Fields().ByName("repeated_sfixed32").Number() + fieldRepeatedSfixed64 = messageDesc.Fields().ByName("repeated_sfixed64").Number() + fieldRepeatedFloat = messageDesc.Fields().ByName("repeated_float").Number() + fieldRepeatedDouble = messageDesc.Fields().ByName("repeated_double").Number() + fieldRepeatedBool = messageDesc.Fields().ByName("repeated_bool").Number() + fieldRepeatedString = messageDesc.Fields().ByName("repeated_string").Number() + fieldRepeatedBytes = messageDesc.Fields().ByName("repeated_bytes").Number() + fieldRepeatedGroup = messageDesc.Fields().ByName("repeatedgroup").Number() + fieldRepeatedNestedMessage = messageDesc.Fields().ByName("repeated_nested_message").Number() + fieldRepeatedForeignMessage = messageDesc.Fields().ByName("repeated_foreign_message").Number() + fieldRepeatedImportMessage = messageDesc.Fields().ByName("repeated_importmessage").Number() + fieldRepeatedNestedEnum = messageDesc.Fields().ByName("repeated_nested_enum").Number() + fieldRepeatedForeignEnum = messageDesc.Fields().ByName("repeated_foreign_enum").Number() + fieldRepeatedImportEnum = messageDesc.Fields().ByName("repeated_importenum").Number() + + fieldMapInt32Int32 = messageDesc.Fields().ByName("map_int32_int32").Number() + fieldMapInt64Int64 = messageDesc.Fields().ByName("map_int64_int64").Number() + fieldMapUint32Uint32 = messageDesc.Fields().ByName("map_uint32_uint32").Number() + fieldMapUint64Uint64 = messageDesc.Fields().ByName("map_uint64_uint64").Number() + fieldMapSint32Sint32 = messageDesc.Fields().ByName("map_sint32_sint32").Number() + fieldMapSint64Sint64 = messageDesc.Fields().ByName("map_sint64_sint64").Number() + fieldMapFixed32Fixed32 = messageDesc.Fields().ByName("map_fixed32_fixed32").Number() + fieldMapFixed64Fixed64 = messageDesc.Fields().ByName("map_fixed64_fixed64").Number() + fieldMapSfixed32Sfixed32 = messageDesc.Fields().ByName("map_sfixed32_sfixed32").Number() + fieldMapSfixed64Sfixed64 = messageDesc.Fields().ByName("map_sfixed64_sfixed64").Number() + fieldMapInt32Float = messageDesc.Fields().ByName("map_int32_float").Number() + fieldMapInt32Double = messageDesc.Fields().ByName("map_int32_double").Number() + fieldMapBoolBool = messageDesc.Fields().ByName("map_bool_bool").Number() + fieldMapStringString = messageDesc.Fields().ByName("map_string_string").Number() + fieldMapStringBytes = messageDesc.Fields().ByName("map_string_bytes").Number() + fieldMapStringNestedMessage = messageDesc.Fields().ByName("map_string_nested_message").Number() + fieldMapStringNestedEnum = messageDesc.Fields().ByName("map_string_nested_enum").Number() + + fieldDefaultInt32 = messageDesc.Fields().ByName("default_int32").Number() + fieldDefaultInt64 = messageDesc.Fields().ByName("default_int64").Number() + fieldDefaultUint32 = messageDesc.Fields().ByName("default_uint32").Number() + fieldDefaultUint64 = messageDesc.Fields().ByName("default_uint64").Number() + fieldDefaultSint32 = messageDesc.Fields().ByName("default_sint32").Number() + fieldDefaultSint64 = messageDesc.Fields().ByName("default_sint64").Number() + fieldDefaultFixed32 = messageDesc.Fields().ByName("default_fixed32").Number() + fieldDefaultFixed64 = messageDesc.Fields().ByName("default_fixed64").Number() + fieldDefaultSfixed32 = messageDesc.Fields().ByName("default_sfixed32").Number() + fieldDefaultSfixed64 = messageDesc.Fields().ByName("default_sfixed64").Number() + fieldDefaultFloat = messageDesc.Fields().ByName("default_float").Number() + fieldDefaultDouble = messageDesc.Fields().ByName("default_double").Number() + fieldDefaultBool = messageDesc.Fields().ByName("default_bool").Number() + fieldDefaultString = messageDesc.Fields().ByName("default_string").Number() + fieldDefaultBytes = messageDesc.Fields().ByName("default_bytes").Number() + fieldDefaultNestedEnum = messageDesc.Fields().ByName("default_nested_enum").Number() + fieldDefaultForeignEnum = messageDesc.Fields().ByName("default_foreign_enum").Number() + + fieldOneofUint32 = messageDesc.Fields().ByName("oneof_uint32").Number() + fieldOneofNestedMessage = messageDesc.Fields().ByName("oneof_nested_message").Number() + fieldOneofString = messageDesc.Fields().ByName("oneof_string").Number() + fieldOneofBytes = messageDesc.Fields().ByName("oneof_bytes").Number() + fieldOneofBool = messageDesc.Fields().ByName("oneof_bool").Number() + fieldOneofUint64 = messageDesc.Fields().ByName("oneof_uint64").Number() + fieldOneofFloat = messageDesc.Fields().ByName("oneof_float").Number() + fieldOneofDouble = messageDesc.Fields().ByName("oneof_double").Number() + fieldOneofEnum = messageDesc.Fields().ByName("oneof_enum").Number() + fieldOneofGroup = messageDesc.Fields().ByName("oneofgroup").Number() + fieldOneofOptionalUint32 = messageDesc.Fields().ByName("oneof_optional_uint32").Number() +) + +// testMessageType is an implementation of protoreflect.MessageType. +type testMessageType struct { + protoreflect.MessageDescriptor + new func() protoreflect.Message +} + +func (m *testMessageType) New() protoreflect.Message { return m.new() } +func (m *testMessageType) Zero() protoreflect.Message { return m.new() } +func (m *testMessageType) GoType() reflect.Type { panic("unimplemented") } +func (m *testMessageType) Descriptor() protoreflect.MessageDescriptor { return m.MessageDescriptor } + +// testProtoMessage adapts the concrete API for a message to the ProtoReflect interface. +type testProtoMessage struct { + m protoreflect.ProtoMessage + md protoreflect.MessageDescriptor + new func() protoreflect.Message + has func(protoreflect.FieldNumber) bool + get func(protoreflect.FieldNumber) any + set func(protoreflect.FieldNumber, any) + clear func(protoreflect.FieldNumber) +} + +func (m *testProtoMessage) ProtoReflect() protoreflect.Message { return (*testMessage)(m) } + +// testMessage implements protoreflect.Message. +type testMessage testProtoMessage + +func (m *testMessage) Interface() protoreflect.ProtoMessage { return (*testProtoMessage)(m) } +func (m *testMessage) ProtoMethods() *protoiface.Methods { return nil } +func (m *testMessage) Descriptor() protoreflect.MessageDescriptor { return m.md } +func (m *testMessage) Type() protoreflect.MessageType { return &testMessageType{m.md, m.new} } +func (m *testMessage) New() protoreflect.Message { return m.new() } +func (m *testMessage) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) { + fields := m.md.Fields() + for i := 0; i < fields.Len(); i++ { + fd := fields.Get(i) + if !m.Has(fd) { + continue + } + if !f(fd, m.Get(fd)) { + break + } + } +} +func (m *testMessage) Has(fd protoreflect.FieldDescriptor) bool { + return m.has(fd.Number()) +} +func (m *testMessage) Clear(fd protoreflect.FieldDescriptor) { + m.clear(fd.Number()) +} +func (m *testMessage) Get(fd protoreflect.FieldDescriptor) protoreflect.Value { + num := fd.Number() + switch { + case fd.IsMap(): + if !m.has(num) { + return protoreflect.ValueOfMap(&testMap{reflect.Zero(reflect.TypeOf(m.get(num))), fd}) + } + return protoreflect.ValueOfMap(&testMap{reflect.ValueOf(m.get(num)), fd}) + case fd.IsList(): + if !m.has(num) { + return protoreflect.ValueOfList(&zeroList{reflect.TypeOf(m.get(num)).Elem(), fd}) + } + return protoreflect.ValueOfList(&testList{m: (*testProtoMessage)(m), fd: fd}) + case fd.Message() != nil: + if !m.has(fd.Number()) { + return protoreflect.Value{} + } + } + return singularValueOf(m.get(num)) +} +func (m *testMessage) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) { + num := fd.Number() + switch { + case fd.IsMap(): + if !v.Map().IsValid() { + panic("set with invalid map") + } + m.set(num, v.Map().(*testMap).val.Interface()) + case fd.IsList(): + if !v.List().IsValid() { + panic("set with invalid list") + } + m.set(num, v.List().(*testList).field().Interface()) + case fd.Message() != nil: + i := v.Message().Interface() + if p, ok := i.(*testProtoMessage); ok { + i = p.m + } + m.set(num, i) + default: + m.set(num, v.Interface()) + } +} +func (m *testMessage) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value { + num := fd.Number() + if !m.Has(fd) && (fd.IsMap() || fd.IsList() || fd.Message() != nil) { + switch { + case fd.IsMap(): + typ := reflect.ValueOf(m.get(num)).Type() + m.set(num, reflect.MakeMap(typ).Interface()) + return protoreflect.ValueOfMap(&testMap{reflect.ValueOf(m.get(num)), fd}) + case fd.IsList(): + return protoreflect.ValueOfList(&testList{m: (*testProtoMessage)(m), fd: fd}) + case fd.Message() != nil: + typ := reflect.ValueOf(m.get(num)).Type() + m.set(num, reflect.New(typ.Elem()).Interface()) + } + } + return m.Get(fd) +} +func (m *testMessage) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message { + return singularValueOf(m.NewField(fd)).Message() +} +func (m *testMessage) NewField(fd protoreflect.FieldDescriptor) protoreflect.Value { + num := fd.Number() + switch { + case fd.IsMap(): + typ := reflect.ValueOf(m.get(num)).Type() + return protoreflect.ValueOf(&testMap{reflect.MakeMap(typ), fd}) + case fd.IsList(): + typ := reflect.ValueOf(m.get(num)).Type() + return protoreflect.ValueOf(&testList{val: reflect.Zero(typ), fd: fd}) + case fd.Message() != nil: + typ := reflect.ValueOf(m.get(num)).Type() + return singularValueOf(reflect.New(typ.Elem()).Interface()) + default: + // Obtain the default value of the field by creating an empty message + // and calling the getter. + n := m.new().(*testMessage) + return singularValueOf(n.get(num)) + } +} +func (m *testMessage) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor { + for i := 0; i < od.Fields().Len(); i++ { + fd := od.Fields().Get(i) + if m.has(fd.Number()) { + return fd + } + } + return nil +} +func (m *testMessage) GetUnknown() protoreflect.RawFields { + return m.m.ProtoReflect().GetUnknown() +} +func (m *testMessage) SetUnknown(raw protoreflect.RawFields) { + m.m.ProtoReflect().SetUnknown(raw) +} +func (m *testMessage) IsValid() bool { + return !reflect.ValueOf(m.m).IsNil() +} + +func singularValueOf(v any) protoreflect.Value { + switch v := v.(type) { + case protoreflect.ProtoMessage: + return protoreflect.ValueOf(v.ProtoReflect()) + case protoreflect.Enum: + return protoreflect.ValueOf(v.Number()) + default: + return protoreflect.ValueOf(v) + } +} + +// testList implements protoreflect.List over a concrete slice of values. +type testList struct { + m *testProtoMessage + val reflect.Value + fd protoreflect.FieldDescriptor +} + +func (x *testList) field() reflect.Value { + if x.m == nil { + return x.val + } + return reflect.ValueOf(x.m.get(x.fd.Number())) +} +func (x *testList) setField(v reflect.Value) { + if x.m == nil { + x.val = v + return + } + x.m.set(x.fd.Number(), v.Interface()) +} +func (x *testList) Len() int { return x.field().Len() } +func (x *testList) Get(n int) protoreflect.Value { + return singularValueOf(x.field().Index(n).Interface()) +} +func (x *testList) Set(n int, v protoreflect.Value) { + switch x.fd.Kind() { + case protoreflect.MessageKind, protoreflect.GroupKind: + x.field().Index(n).Set(reflect.ValueOf(v.Message().Interface())) + case protoreflect.EnumKind: + x.field().Index(n).SetInt(int64(v.Enum())) + default: + x.field().Index(n).Set(reflect.ValueOf(v.Interface())) + } +} +func (x *testList) Append(v protoreflect.Value) { + f := x.field() + x.setField(reflect.Append(f, reflect.Zero(f.Type().Elem()))) + x.Set(f.Len(), v) +} +func (x *testList) AppendMutable() protoreflect.Value { + if x.fd.Message() == nil { + panic("invalid AppendMutable on list with non-message value type") + } + v := x.NewElement() + x.Append(v) + return v +} +func (x *testList) Truncate(n int) { + x.setField(x.field().Slice(0, n)) +} +func (x *testList) NewMessage() protoreflect.Message { + return x.NewElement().Message() +} +func (x *testList) NewElement() protoreflect.Value { + // For enums, List.NewElement returns the first enum value. + if ee := newEnumElement(x.fd); ee.IsValid() { + return ee + } + var v reflect.Value + typ := x.field().Type().Elem() + if typ.Kind() == reflect.Ptr { + v = reflect.New(typ.Elem()) + } else { + v = reflect.Zero(typ) + } + return singularValueOf(v.Interface()) +} +func (x *testList) IsValid() bool { + return true +} + +func newEnumElement(fd protoreflect.FieldDescriptor) protoreflect.Value { + if fd.Kind() != protoreflect.EnumKind { + return protoreflect.Value{} + } + if val := fd.Enum().Values(); val.Len() > 0 { + return protoreflect.ValueOfEnum(val.Get(0).Number()) + } + return protoreflect.Value{} +} + +// testList implements protoreflect.List over a concrete slice of values. +type zeroList struct { + typ reflect.Type + fd protoreflect.FieldDescriptor +} + +func (x *zeroList) Len() int { return 0 } +func (x *zeroList) Get(n int) protoreflect.Value { panic("get on zero list") } +func (x *zeroList) Set(n int, v protoreflect.Value) { panic("set on zero list") } +func (x *zeroList) Append(v protoreflect.Value) { panic("append on zero list") } +func (x *zeroList) AppendMutable() protoreflect.Value { panic("append on zero list") } +func (x *zeroList) Truncate(n int) { panic("truncate on zero list") } +func (x *zeroList) NewMessage() protoreflect.Message { + return x.NewElement().Message() +} +func (x *zeroList) NewElement() protoreflect.Value { + // For enums, List.NewElement returns the first enum value. + if ee := newEnumElement(x.fd); ee.IsValid() { + return ee + } + var v reflect.Value + if x.typ.Kind() == reflect.Ptr { + v = reflect.New(x.typ.Elem()) + } else { + v = reflect.Zero(x.typ) + } + return singularValueOf(v.Interface()) +} +func (x *zeroList) IsValid() bool { + return false +} + +// testMap implements a protoreflect.Map over a concrete map. +type testMap struct { + val reflect.Value + fd protoreflect.FieldDescriptor +} + +func (x *testMap) key(k protoreflect.MapKey) reflect.Value { return reflect.ValueOf(k.Interface()) } +func (x *testMap) valueToProto(v reflect.Value) protoreflect.Value { + if !v.IsValid() { + return protoreflect.Value{} + } + switch x.fd.Message().Fields().ByNumber(2).Kind() { + case protoreflect.MessageKind, protoreflect.GroupKind: + return protoreflect.ValueOf(v.Interface().(protoreflect.ProtoMessage).ProtoReflect()) + case protoreflect.EnumKind: + return protoreflect.ValueOf(protoreflect.EnumNumber(v.Int())) + default: + return protoreflect.ValueOf(v.Interface()) + } +} +func (x *testMap) Len() int { return x.val.Len() } +func (x *testMap) Has(k protoreflect.MapKey) bool { return x.val.MapIndex(x.key(k)).IsValid() } +func (x *testMap) Get(k protoreflect.MapKey) protoreflect.Value { + return x.valueToProto(x.val.MapIndex(x.key(k))) +} +func (x *testMap) Set(k protoreflect.MapKey, v protoreflect.Value) { + f := x.val + switch x.fd.MapValue().Kind() { + case protoreflect.MessageKind, protoreflect.GroupKind: + f.SetMapIndex(x.key(k), reflect.ValueOf(v.Message().Interface())) + case protoreflect.EnumKind: + rv := reflect.New(f.Type().Elem()).Elem() + rv.SetInt(int64(v.Enum())) + f.SetMapIndex(x.key(k), rv) + default: + f.SetMapIndex(x.key(k), reflect.ValueOf(v.Interface())) + } +} +func (x *testMap) Mutable(k protoreflect.MapKey) protoreflect.Value { + if x.fd.MapValue().Message() == nil { + panic("invalid Mutable on map with non-message value type") + } + v := x.Get(k) + if !v.IsValid() { + v = x.NewValue() + x.Set(k, v) + } + return v +} +func (x *testMap) Clear(k protoreflect.MapKey) { x.val.SetMapIndex(x.key(k), reflect.Value{}) } +func (x *testMap) Range(f func(protoreflect.MapKey, protoreflect.Value) bool) { + iter := x.val.MapRange() + for iter.Next() { + if !f(protoreflect.ValueOf(iter.Key().Interface()).MapKey(), x.valueToProto(iter.Value())) { + return + } + } +} +func (x *testMap) NewMessage() protoreflect.Message { + return x.NewValue().Message() +} +func (x *testMap) NewValue() protoreflect.Value { + var v reflect.Value + if x.fd.MapValue().Message() != nil { + v = reflect.New(x.val.Type().Elem().Elem()) + } else { + v = reflect.Zero(x.val.Type().Elem()) + } + return singularValueOf(v.Interface()) +} +func (x *testMap) IsValid() bool { + return !x.val.IsNil() +} + +// A shadow message is a wrapper around two protoreflect.Message implementations +// presenting different views of the same underlying data. Read operations +// are directed to one implementation and write operations to the other. + +// shadowProtoMessage implements protoreflect.ProtoMessage as a shadow. +type shadowProtoMessage struct { + get, set protoreflect.Message + new func() (get, set protoreflect.ProtoMessage) +} + +func newShadow(newf func() (get, set protoreflect.ProtoMessage)) protoreflect.ProtoMessage { + get, set := newf() + return &shadowProtoMessage{ + get.ProtoReflect(), + set.ProtoReflect(), + newf, + } +} + +func (m *shadowProtoMessage) ProtoReflect() protoreflect.Message { return (*shadowMessage)(m) } + +// shadowMessage implements protoreflect.Message as a shadow. +type shadowMessage shadowProtoMessage + +func (m *shadowMessage) Interface() protoreflect.ProtoMessage { return (*shadowProtoMessage)(m) } +func (m *shadowMessage) ProtoMethods() *protoiface.Methods { return nil } +func (m *shadowMessage) Descriptor() protoreflect.MessageDescriptor { return m.get.Descriptor() } +func (m *shadowMessage) Type() protoreflect.MessageType { + return &testMessageType{m.Descriptor(), m.New} +} +func (m *shadowMessage) New() protoreflect.Message { + get, set := m.new() + return &shadowMessage{ + get: get.ProtoReflect(), + set: set.ProtoReflect(), + new: m.new, + } +} + +// TODO: Implement these. +func (m *shadowMessage) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) { + m.get.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { + return f(fd, m.Get(fd)) + }) +} +func (m *shadowMessage) Has(fd protoreflect.FieldDescriptor) bool { + return m.get.Has(fd) +} +func (m *shadowMessage) Clear(fd protoreflect.FieldDescriptor) { + m.set.Clear(fd) +} +func (m *shadowMessage) Get(fd protoreflect.FieldDescriptor) protoreflect.Value { + v := m.get.Get(fd) + switch { + case fd.IsList(): + return protoreflect.ValueOfList(&shadowList{v.List(), m.set.Get(fd).List()}) + case fd.IsMap(): + return protoreflect.ValueOfMap(&shadowMap{v.Map(), m.set.Get(fd).Map()}) + default: + return v + } +} +func (m *shadowMessage) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) { + switch x := v.Interface().(type) { + case *shadowList: + m.set.Set(fd, protoreflect.ValueOf(x.set)) + case *shadowMap: + m.set.Set(fd, protoreflect.ValueOf(x.set)) + default: + m.set.Set(fd, v) + } +} +func (m *shadowMessage) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value { + v := m.get.Mutable(fd) + switch { + case fd.IsList(): + return protoreflect.ValueOf(&shadowList{v.List(), m.set.Mutable(fd).List()}) + case fd.IsMap(): + return protoreflect.ValueOf(&shadowMap{v.Map(), m.set.Mutable(fd).Map()}) + default: + return v + } +} +func (m *shadowMessage) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message { + return m.NewField(fd).Message() +} +func (m *shadowMessage) NewField(fd protoreflect.FieldDescriptor) protoreflect.Value { + return m.set.NewField(fd) +} +func (m *shadowMessage) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor { + return m.get.WhichOneof(od) +} +func (m *shadowMessage) GetUnknown() protoreflect.RawFields { + return m.get.GetUnknown() +} +func (m *shadowMessage) SetUnknown(raw protoreflect.RawFields) { + m.set.SetUnknown(raw) +} +func (m *shadowMessage) IsValid() bool { + return m.get.IsValid() +} + +// shadowList implements protoreflect.List as a shadow. +type shadowList struct { + get, set protoreflect.List +} + +func (x *shadowList) Len() int { return x.get.Len() } +func (x *shadowList) Get(n int) protoreflect.Value { return x.get.Get(n) } +func (x *shadowList) Set(n int, v protoreflect.Value) { x.set.Set(n, v) } +func (x *shadowList) Append(v protoreflect.Value) { x.set.Append(v) } +func (x *shadowList) AppendMutable() protoreflect.Value { return x.set.AppendMutable() } +func (x *shadowList) Truncate(n int) { x.set.Truncate(n) } +func (x *shadowList) NewMessage() protoreflect.Message { return x.set.NewElement().Message() } +func (x *shadowList) NewElement() protoreflect.Value { return x.set.NewElement() } +func (x *shadowList) IsValid() bool { return x.get.IsValid() } + +// shadowMap implements protoreflect.Map as a shadow. +type shadowMap struct { + get, set protoreflect.Map +} + +func (x *shadowMap) Len() int { return x.get.Len() } +func (x *shadowMap) Has(k protoreflect.MapKey) bool { return x.get.Has(k) } +func (x *shadowMap) Get(k protoreflect.MapKey) protoreflect.Value { return x.get.Get(k) } +func (x *shadowMap) Set(k protoreflect.MapKey, v protoreflect.Value) { x.set.Set(k, v) } +func (x *shadowMap) Mutable(k protoreflect.MapKey) protoreflect.Value { return x.set.Mutable(k) } +func (x *shadowMap) Clear(k protoreflect.MapKey) { x.set.Clear(k) } +func (x *shadowMap) Range(f func(protoreflect.MapKey, protoreflect.Value) bool) { x.get.Range(f) } +func (x *shadowMap) NewMessage() protoreflect.Message { return x.set.NewValue().Message() } +func (x *shadowMap) NewValue() protoreflect.Value { return x.set.NewValue() } +func (x *shadowMap) IsValid() bool { return x.get.IsValid() } diff --git a/internal/testprotos/enums/enums.proto b/internal/testprotos/enums/enums.proto index b8452d4cb..2921ad06c 100644 --- a/internal/testprotos/enums/enums.proto +++ b/internal/testprotos/enums/enums.proto @@ -2,12 +2,14 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -syntax = "proto2"; +edition = "2023"; package goproto.proto.enums; option go_package = "google.golang.org/protobuf/internal/testprotos/enums"; +option features.enum_type = CLOSED; + enum Enum { DEFAULT = 1337; ZERO = 0; diff --git a/internal/testprotos/irregular/test.proto b/internal/testprotos/irregular/test.proto index af8c59d42..7a2929743 100644 --- a/internal/testprotos/irregular/test.proto +++ b/internal/testprotos/irregular/test.proto @@ -15,17 +15,17 @@ import "internal/testprotos/irregular/irregular.proto"; option go_package = "google.golang.org/protobuf/internal/testprotos/irregular"; message Message { - optional IrregularMessage optional_message = 1; - repeated IrregularMessage repeated_message = 2; - required IrregularMessage required_message = 3; - map map_message = 4; + optional .goproto.proto.irregular.IrregularMessage optional_message = 1; + repeated .goproto.proto.irregular.IrregularMessage repeated_message = 2; + required .goproto.proto.irregular.IrregularMessage required_message = 3; + map map_message = 4; oneof union { - IrregularMessage oneof_message = 5; - AberrantMessage oneof_aberrant_message = 6; + .goproto.proto.irregular.IrregularMessage oneof_message = 5; + .goproto.proto.irregular.AberrantMessage oneof_aberrant_message = 6; } - optional AberrantMessage optional_aberrant_message = 7; - repeated AberrantMessage repeated_aberrant_message = 8; - required AberrantMessage required_aberrant_message = 9; - map map_aberrant_message = 10; + optional .goproto.proto.irregular.AberrantMessage optional_aberrant_message = 7; + repeated .goproto.proto.irregular.AberrantMessage repeated_aberrant_message = 8; + required .goproto.proto.irregular.AberrantMessage required_aberrant_message = 9; + map map_aberrant_message = 10; } diff --git a/internal/testprotos/lazy/lazy_normalized_wire_test.proto b/internal/testprotos/lazy/lazy_normalized_wire_test.proto new file mode 100644 index 000000000..f33560169 --- /dev/null +++ b/internal/testprotos/lazy/lazy_normalized_wire_test.proto @@ -0,0 +1,20 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +syntax = "proto2"; + +package lazy_normalized_wire_test; + +option go_package = "google.golang.org/protobuf/internal/testprotos/lazy"; + +message FSub { + optional uint32 b = 2; + optional uint32 c = 3; + optional FSub grandchild = 4 [lazy = true]; +} + +message FTop { + optional uint32 a = 1; + optional FSub child = 2; +} diff --git a/internal/testprotos/lazy/lazy_tree.proto b/internal/testprotos/lazy/lazy_tree.proto new file mode 100644 index 000000000..47641dcac --- /dev/null +++ b/internal/testprotos/lazy/lazy_tree.proto @@ -0,0 +1,29 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +edition = "2023"; + +package lazy_tree; + +option go_package = "google.golang.org/protobuf/internal/testprotos/lazy"; + +message Node { + Node nested = 99 [lazy = true]; + + int32 int32 = 1; + int64 int64 = 2; + uint32 uint32 = 3; + uint64 uint64 = 4; + sint32 sint32 = 5; + sint64 sint64 = 6; + fixed32 fixed32 = 7; + fixed64 fixed64 = 8; + sfixed32 sfixed32 = 9; + sfixed64 sfixed64 = 10; + float float = 11; + double double = 12; + bool bool = 13; + string string = 14; + bytes bytes = 15; +} diff --git a/internal/testprotos/messageset/messagesetpb/message_set.proto b/internal/testprotos/messageset/messagesetpb/message_set.proto index 0e408e0ef..93cfc1455 100644 --- a/internal/testprotos/messageset/messagesetpb/message_set.proto +++ b/internal/testprotos/messageset/messagesetpb/message_set.proto @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -syntax = "proto2"; +edition = "2023"; package goproto.proto.messageset; @@ -11,15 +11,9 @@ option go_package = "google.golang.org/protobuf/internal/testprotos/messageset/m message MessageSet { option message_set_wire_format = true; - extensions 4 to 529999999; - extensions 530000000 to max - [declaration = { - number: 536870912 - full_name: ".goproto.proto.messageset.ExtLargeNumber.message_set_extlarge" - type: ".goproto.proto.messageset.ExtLargeNumber" - }]; + extensions 4 to max; } message MessageSetContainer { - optional MessageSet message_set = 1; + MessageSet message_set = 1; } diff --git a/internal/testprotos/messageset/msetextpb/msetextpb.proto b/internal/testprotos/messageset/msetextpb/msetextpb.proto index c723b3a70..ed7127b21 100644 --- a/internal/testprotos/messageset/msetextpb/msetextpb.proto +++ b/internal/testprotos/messageset/msetextpb/msetextpb.proto @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -syntax = "proto2"; +edition = "2023"; package goproto.proto.messageset; @@ -12,28 +12,28 @@ option go_package = "google.golang.org/protobuf/internal/testprotos/messageset/m message Ext1 { extend MessageSet { - optional Ext1 message_set_ext1 = 1000; + Ext1 message_set_ext1 = 1000; } - optional int32 ext1_field1 = 1; - optional int32 ext1_field2 = 2; + int32 ext1_field1 = 1; + int32 ext1_field2 = 2; } message Ext2 { extend MessageSet { - optional Ext2 message_set_ext2 = 1001; + Ext2 message_set_ext2 = 1001; } - optional int32 ext2_field1 = 1; + int32 ext2_field1 = 1; } message ExtRequired { extend MessageSet { - optional ExtRequired message_set_extrequired = 1002; + ExtRequired message_set_extrequired = 1002; } - required int32 required_field1 = 1; + int32 required_field1 = 1 [features.field_presence = LEGACY_REQUIRED]; } message ExtLargeNumber { extend MessageSet { - optional ExtLargeNumber message_set_extlarge = 536870912; // 1<<29 + ExtLargeNumber message_set_extlarge = 536870912; // 1<<29 } } diff --git a/internal/testprotos/mixed/mixed.proto b/internal/testprotos/mixed/mixed.proto new file mode 100644 index 000000000..81cb0391e --- /dev/null +++ b/internal/testprotos/mixed/mixed.proto @@ -0,0 +1,81 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This testproto explicitly configures the API level of each message. +// +// This allows creating mixed trees of proto messages on different API levels. + +edition = "2023"; + +package goproto.proto.test; + +import "google/protobuf/go_features.proto"; + +option go_package = "google.golang.org/protobuf/internal/testprotos/mixed"; + +message Open { + option features.(pb.go).api_level = API_OPEN; + + // These fields allow for arbitrary mixing. + Open open = 1; + Hybrid hybrid = 2; + Opaque opaque = 3; + + int32 optional_int32 = 4; +} + +message Hybrid { + option features.(pb.go).api_level = API_HYBRID; + + // These fields allow for arbitrary mixing. + Open open = 1; + Hybrid hybrid = 2; + Opaque opaque = 3; + + int32 optional_int32 = 4; +} + +message Opaque { + option features.(pb.go).api_level = API_OPAQUE; + + // These fields allow for arbitrary mixing. + Open open = 1; + Hybrid hybrid = 2; + Opaque opaque = 3; + + int32 optional_int32 = 4; +} + +message OpenLazy { + option features.(pb.go).api_level = API_OPEN; + + // These fields allow for arbitrary mixing. + OpenLazy open = 1 [lazy = true]; + HybridLazy hybrid = 2 [lazy = true]; + OpaqueLazy opaque = 3 [lazy = true]; + + int32 optional_int32 = 4; +} + +message HybridLazy { + option features.(pb.go).api_level = API_HYBRID; + + // These fields allow for arbitrary mixing. + OpenLazy open = 1 [lazy = true]; + HybridLazy hybrid = 2 [lazy = true]; + OpaqueLazy opaque = 3 [lazy = true]; + + int32 optional_int32 = 4; +} + +message OpaqueLazy { + option features.(pb.go).api_level = API_OPAQUE; + + // These fields allow for arbitrary mixing. + OpenLazy open = 1 [lazy = true]; + HybridLazy hybrid = 2 [lazy = true]; + OpaqueLazy opaque = 3 [lazy = true]; + + int32 optional_int32 = 4; +} diff --git a/internal/testprotos/news/news.proto b/internal/testprotos/news/news.proto index 774c94981..0f5031f9c 100644 --- a/internal/testprotos/news/news.proto +++ b/internal/testprotos/news/news.proto @@ -19,12 +19,12 @@ message Article { } string author = 1; - google.protobuf.Timestamp date = 2; + .google.protobuf.Timestamp date = 2; string title = 3; string content = 4; Status status = 8; repeated string tags = 7; - repeated google.protobuf.Any attachments = 6; + repeated .google.protobuf.Any attachments = 6; } message BinaryAttachment { diff --git a/internal/testprotos/required/required.proto b/internal/testprotos/required/required.proto index 2837008ed..d0625d3d4 100644 --- a/internal/testprotos/required/required.proto +++ b/internal/testprotos/required/required.proto @@ -2,71 +2,76 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -syntax = "proto2"; +edition = "2023"; package goproto.proto.testrequired; option go_package = "google.golang.org/protobuf/internal/testprotos/required"; message Int32 { - required int32 v = 1; + int32 v = 1 [features.field_presence = LEGACY_REQUIRED]; } message Int64 { - required int64 v = 1; + int64 v = 1 [features.field_presence = LEGACY_REQUIRED]; } message Uint32 { - required uint32 v = 1; + uint32 v = 1 [features.field_presence = LEGACY_REQUIRED]; } message Uint64 { - required uint64 v = 1; + uint64 v = 1 [features.field_presence = LEGACY_REQUIRED]; } message Sint32 { - required sint32 v = 1; + sint32 v = 1 [features.field_presence = LEGACY_REQUIRED]; } message Sint64 { - required sint64 v = 1; + sint64 v = 1 [features.field_presence = LEGACY_REQUIRED]; } message Fixed32 { - required fixed32 v = 1; + fixed32 v = 1 [features.field_presence = LEGACY_REQUIRED]; } message Fixed64 { - required fixed64 v = 1; + fixed64 v = 1 [features.field_presence = LEGACY_REQUIRED]; } message Float { - required float v = 1; + float v = 1 [features.field_presence = LEGACY_REQUIRED]; } message Double { - required double v = 1; + double v = 1 [features.field_presence = LEGACY_REQUIRED]; } message Bool { - required bool v = 1; + bool v = 1 [features.field_presence = LEGACY_REQUIRED]; } message String { - required string v = 1; + string v = 1 [features.field_presence = LEGACY_REQUIRED]; } message Bytes { - required bytes v = 1; + bytes v = 1 [features.field_presence = LEGACY_REQUIRED]; } message Message { message M {} - required M v = 1; + M v = 1 [features.field_presence = LEGACY_REQUIRED]; } message Group { - required group Group = 1 { - optional int32 v = 1; + message Group { + int32 v = 1; } + + Group group = 1 [ + features.field_presence = LEGACY_REQUIRED, + features.message_encoding = DELIMITED + ]; } diff --git a/internal/testprotos/test/test.proto b/internal/testprotos/test/test.proto index ab0557b20..7bb4a829d 100644 --- a/internal/testprotos/test/test.proto +++ b/internal/testprotos/test/test.proto @@ -398,6 +398,10 @@ message TestRequiredGroupFields { } } +message TestRequiredLazy { + optional TestRequired optional_lazy_message = 1 [lazy = true]; +} + message TestWeak { optional goproto.proto.test.weak.WeakImportMessage1 weak_message1 = 1 [weak = true]; diff --git a/internal/testprotos/testeditions/test.proto b/internal/testprotos/testeditions/test.proto index a850d4223..420279bfb 100644 --- a/internal/testprotos/testeditions/test.proto +++ b/internal/testprotos/testeditions/test.proto @@ -6,6 +6,9 @@ edition = "2023"; package goproto.proto.testeditions; +import "internal/testprotos/enums/enums.proto"; +import "internal/testprotos/testeditions/test_import.proto"; + option go_package = "google.golang.org/protobuf/internal/testprotos/testeditions"; message TestAllTypes { @@ -36,6 +39,10 @@ message TestAllTypes { bool singular_bool = 136 [features.field_presence = IMPLICIT]; string singular_string = 137 [features.field_presence = IMPLICIT]; bytes singular_bytes = 138 [features.field_presence = IMPLICIT]; + // message-typed fields elided, as they cannot specify implicit presence. + NestedEnum singular_nested_enum = 142 [features.field_presence = IMPLICIT]; + ForeignEnum singular_foreign_enum = 143 [features.field_presence = IMPLICIT]; + ImportEnum singular_import_enum = 144 [features.field_presence = IMPLICIT]; int32 optional_int32 = 1; int64 optional_int64 = 2; @@ -62,8 +69,11 @@ message TestAllTypes { [features.message_encoding = DELIMITED]; NestedMessage optional_nested_message = 18; ForeignMessage optional_foreign_message = 19; + ImportMessage optional_import_message = 20; NestedEnum optional_nested_enum = 21; ForeignEnum optional_foreign_enum = 22; + ImportEnum optional_import_enum = 23; + NestedMessage optional_lazy_nested_message = 24 [lazy = true]; repeated int32 repeated_int32 = 31; repeated int64 repeated_int64 = 32; @@ -91,8 +101,10 @@ message TestAllTypes { ]; repeated NestedMessage repeated_nested_message = 48; repeated ForeignMessage repeated_foreign_message = 49; + repeated ImportMessage repeated_importmessage = 50; repeated NestedEnum repeated_nested_enum = 51; repeated ForeignEnum repeated_foreign_enum = 52; + repeated ImportEnum repeated_importenum = 53; map map_int32_int32 = 56; map map_int64_int64 = 57; @@ -154,6 +166,109 @@ message TestAllTypes { } } +message TestManyMessageFieldsMessage { + TestAllTypes f1 = 1; + TestAllTypes f2 = 2; + TestAllTypes f3 = 3; + TestAllTypes f4 = 4; + TestAllTypes f5 = 5; + TestAllTypes f6 = 6; + TestAllTypes f7 = 7; + TestAllTypes f8 = 8; + TestAllTypes f9 = 9; + TestAllTypes f10 = 10; + TestAllTypes f11 = 11; + TestAllTypes f12 = 12; + TestAllTypes f13 = 13; + TestAllTypes f14 = 14; + TestAllTypes f15 = 15; + TestAllTypes f16 = 16; + TestAllTypes f17 = 17; + TestAllTypes f18 = 18; + TestAllTypes f19 = 19; + TestAllTypes f20 = 20; + TestAllTypes f21 = 21; + TestAllTypes f22 = 22; + TestAllTypes f23 = 23; + TestAllTypes f24 = 24; + TestAllTypes f25 = 25; + TestAllTypes f26 = 26; + TestAllTypes f27 = 27; + TestAllTypes f28 = 28; + TestAllTypes f29 = 29; + TestAllTypes f30 = 30; + TestAllTypes f31 = 31; + TestAllTypes f32 = 32; + TestAllTypes f33 = 33; + TestAllTypes f34 = 34; + TestAllTypes f35 = 35; + TestAllTypes f36 = 36; + TestAllTypes f37 = 37; + TestAllTypes f38 = 38; + TestAllTypes f39 = 39; + TestAllTypes f40 = 40; + TestAllTypes f41 = 41; + TestAllTypes f42 = 42; + TestAllTypes f43 = 43; + TestAllTypes f44 = 44; + TestAllTypes f45 = 45; + TestAllTypes f46 = 46; + TestAllTypes f47 = 47; + TestAllTypes f48 = 48; + TestAllTypes f49 = 49; + TestAllTypes f50 = 50; + TestAllTypes f51 = 51; + TestAllTypes f52 = 52; + TestAllTypes f53 = 53; + TestAllTypes f54 = 54; + TestAllTypes f55 = 55; + TestAllTypes f56 = 56; + TestAllTypes f57 = 57; + TestAllTypes f58 = 58; + TestAllTypes f59 = 59; + TestAllTypes f60 = 60; + TestAllTypes f61 = 61; + TestAllTypes f62 = 62; + TestAllTypes f63 = 63; + TestAllTypes f64 = 64; + TestAllTypes f65 = 65; + TestAllTypes f66 = 66; + TestAllTypes f67 = 67; + TestAllTypes f68 = 68; + TestAllTypes f69 = 69; + TestAllTypes f70 = 70; + TestAllTypes f71 = 71; + TestAllTypes f72 = 72; + TestAllTypes f73 = 73; + TestAllTypes f74 = 74; + TestAllTypes f75 = 75; + TestAllTypes f76 = 76; + TestAllTypes f77 = 77; + TestAllTypes f78 = 78; + TestAllTypes f79 = 79; + TestAllTypes f80 = 80; + TestAllTypes f81 = 81; + TestAllTypes f82 = 82; + TestAllTypes f83 = 83; + TestAllTypes f84 = 84; + TestAllTypes f85 = 85; + TestAllTypes f86 = 86; + TestAllTypes f87 = 87; + TestAllTypes f88 = 88; + TestAllTypes f89 = 89; + TestAllTypes f90 = 90; + TestAllTypes f91 = 91; + TestAllTypes f92 = 92; + TestAllTypes f93 = 93; + TestAllTypes f94 = 94; + TestAllTypes f95 = 95; + TestAllTypes f96 = 96; + TestAllTypes f97 = 97; + TestAllTypes f98 = 98; + TestAllTypes f99 = 99; + TestAllTypes f100 = 100; +} + message ForeignMessage { int32 c = 1; int32 d = 2; @@ -191,6 +306,10 @@ message TestRequiredGroupFields { [features.message_encoding = DELIMITED]; } +message TestRequiredLazy { + TestRequired optional_lazy_message = 1 [lazy = true]; +} + message TestPackedTypes { repeated int32 packed_int32 = 90 [features.repeated_field_encoding = PACKED]; repeated int64 packed_int64 = 91 [features.repeated_field_encoding = PACKED]; @@ -248,3 +367,14 @@ extend TestPackedExtensions { repeated ForeignEnum packed_enum = 103 [features.repeated_field_encoding = PACKED]; } + +message RemoteDefault { + goproto.proto.enums.Enum default = 1; + goproto.proto.enums.Enum zero = 2 [default = ZERO]; + goproto.proto.enums.Enum one = 3 [default = ONE]; + goproto.proto.enums.Enum elevent = 4 [default = ELEVENT]; + goproto.proto.enums.Enum seventeen = 5 [default = SEVENTEEN]; + goproto.proto.enums.Enum thirtyseven = 6 [default = THIRTYSEVEN]; + goproto.proto.enums.Enum sixtyseven = 7 [default = SIXTYSEVEN]; + goproto.proto.enums.Enum negative = 8 [default = NEGATIVE]; +} diff --git a/internal/testprotos/testeditions/test_import.proto b/internal/testprotos/testeditions/test_import.proto new file mode 100644 index 000000000..e4d3c2ebd --- /dev/null +++ b/internal/testprotos/testeditions/test_import.proto @@ -0,0 +1,15 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +edition = "2023"; + +package goproto.proto.testeditions; + +option go_package = "google.golang.org/protobuf/internal/testprotos/testeditions"; + +message ImportMessage {} + +enum ImportEnum { + IMPORT_ZERO = 0; +} diff --git a/internal/testprotos/textpbeditions/test2.proto b/internal/testprotos/textpbeditions/test2.proto index efd8caf7b..4a5f6fec5 100644 --- a/internal/testprotos/textpbeditions/test2.proto +++ b/internal/testprotos/textpbeditions/test2.proto @@ -196,7 +196,7 @@ message NestedWithRequired { } message IndirectRequired { - NestedWithRequired opt_nested = 1 [features.field_presence = LEGACY_REQUIRED]; + NestedWithRequired opt_nested = 1; repeated NestedWithRequired rpt_nested = 2; map str_to_nested = 3; diff --git a/proto/decode.go b/proto/decode.go index d75a6534c..a3b5e142d 100644 --- a/proto/decode.go +++ b/proto/decode.go @@ -47,6 +47,12 @@ type UnmarshalOptions struct { // RecursionLimit limits how deeply messages may be nested. // If zero, a default limit is applied. RecursionLimit int + + // + // NoLazyDecoding turns off lazy decoding, which otherwise is enabled by + // default. Lazy decoding only affects submessages (annotated with [lazy = + // true] in the .proto file) within messages that use the Opaque API. + NoLazyDecoding bool } // Unmarshal parses the wire-format message in b and places the result in m. @@ -104,6 +110,16 @@ func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out proto if o.DiscardUnknown { in.Flags |= protoiface.UnmarshalDiscardUnknown } + + if !allowPartial { + // This does not affect how current unmarshal functions work, it just allows them + // to record this for lazy the decoding case. + in.Flags |= protoiface.UnmarshalCheckRequired + } + if o.NoLazyDecoding { + in.Flags |= protoiface.UnmarshalNoLazyDecoding + } + out, err = methods.Unmarshal(in) } else { o.RecursionLimit-- diff --git a/proto/encode.go b/proto/encode.go index 1f847bcc3..f0473c586 100644 --- a/proto/encode.go +++ b/proto/encode.go @@ -63,7 +63,8 @@ type MarshalOptions struct { // options (except for UseCachedSize itself). // // 2. The message and all its submessages have not changed in any - // way since the Size call. + // way since the Size call. For lazily decoded messages, accessing + // a message results in decoding the message, which is a change. // // If either of these invariants is violated, // the results are undefined and may include panics or corrupted output. diff --git a/proto/lazy_bench_test.go b/proto/lazy_bench_test.go new file mode 100644 index 000000000..e405aa1be --- /dev/null +++ b/proto/lazy_bench_test.go @@ -0,0 +1,92 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proto_test + +import ( + "testing" + + "google.golang.org/protobuf/proto" + + lazyopaquepb "google.golang.org/protobuf/internal/testprotos/lazy/lazy_opaque" +) + +// testMessageLinked returns a test message with a few fields of various +// possible types filled in that nests more messages like a linked list. +func testMessageLinked(nesting int) *lazyopaquepb.Node { + const ( + shortVarint = 23 // encodes into 1 byte + longVarint = 562949953421312 // encodes into 8 bytes + ) + msg := lazyopaquepb.Node_builder{ + Int32: proto.Int32(shortVarint), + Int64: proto.Int64(longVarint), + Uint32: proto.Uint32(shortVarint), + Uint64: proto.Uint64(longVarint), + Sint32: proto.Int32(shortVarint), + Sint64: proto.Int64(longVarint), + Fixed32: proto.Uint32(shortVarint), + Fixed64: proto.Uint64(longVarint), + Sfixed32: proto.Int32(shortVarint), + Sfixed64: proto.Int64(longVarint), + Float: proto.Float32(23.42), + Double: proto.Float64(23.42), + Bool: proto.Bool(true), + String: proto.String("hello"), + Bytes: []byte("world"), + }.Build() + if nesting > 0 { + msg.SetNested(testMessageLinked(nesting - 1)) + } + return msg +} + +// A higher nesting level than 15 messages deep does not result in (relative) +// performance changes. In other words, the full effect of lazy decoding is +// visible with a nesting level of 15 messages deep. Lower nesting levels (like +// 5 messages deep) also show significant improvement. +const nesting = 15 + +func BenchmarkUnmarshal(b *testing.B) { + encoded, err := proto.Marshal(testMessageLinked(nesting)) + if err != nil { + b.Fatal(err) + } + + for _, tt := range []struct { + desc string + uopts proto.UnmarshalOptions + }{ + { + desc: "lazy", + uopts: proto.UnmarshalOptions{}, + }, + + // When running the benchmark directly, print lazy vs. nolazy in the + // same run. When using the benchstat tool, you can compare lazy + // vs. nolazy by running only the lazy variant and disabling lazy + // decoding with the -test_lazy_unmarshal command-line flag: + // + // benchstat \ + // nolazy=<(go test -run=^$ -bench=Unmarshal/^lazy -count=6) \ + // lazy=<(go test -run=^$ -bench=Unmarshal/^lazy -count=6 -test_lazy_unmarshal) + { + desc: "nolazy", + uopts: proto.UnmarshalOptions{ + NoLazyDecoding: true, + }, + }, + } { + b.Run(tt.desc, func(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + out := &lazyopaquepb.Node{} + if err := tt.uopts.Unmarshal(encoded, out); err != nil { + b.Fatalf("can't unmarshal message: %v", err) + } + } + }) + } +} diff --git a/proto/lazy_roundtrip_test.go b/proto/lazy_roundtrip_test.go new file mode 100644 index 000000000..08f42c359 --- /dev/null +++ b/proto/lazy_roundtrip_test.go @@ -0,0 +1,125 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proto_test + +import ( + "reflect" + "testing" + "unsafe" + + "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/internal/impl" + testopaquepb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_opaque" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/testing/protocmp" +) + +func fillLazyRequiredMessage() *testopaquepb.TestRequiredLazy { + return testopaquepb.TestRequiredLazy_builder{ + OptionalLazyMessage: testopaquepb.TestRequired_builder{ + RequiredField: proto.Int32(12), + }.Build(), + }.Build() +} + +func expandedLazy(m *testopaquepb.TestRequiredLazy) bool { + v := reflect.ValueOf(m).Elem() + rf := v.FieldByName("xxx_hidden_OptionalLazyMessage") + rf = reflect.NewAt(rf.Type(), unsafe.Pointer(rf.UnsafeAddr())).Elem() + return rf.Pointer() != 0 +} + +// This test ensures that a lazy field keeps being lazy when marshalling +// even if it has required fields (as they have already been checked on +// unmarshal) +func TestLazyRequiredRoundtrip(t *testing.T) { + if !impl.LazyEnabled() { + t.Skipf("this test requires lazy decoding to be enabled") + } + m := fillLazyRequiredMessage() + b, _ := proto.MarshalOptions{}.Marshal(m) + ml := &testopaquepb.TestRequiredLazy{} + err := proto.UnmarshalOptions{}.Unmarshal(b, ml) + if err != nil { + t.Fatalf("Error while unmarshaling: %v", err) + } + // Sanity check, we should have all unexpanded fields in the proto + if expandedLazy(ml) { + t.Fatalf("Proto is not lazy: %#v", ml) + } + // Now we marshal the lazy field. It should still be unexpanded + _, _ = proto.MarshalOptions{}.Marshal(ml) + + if expandedLazy(ml) { + t.Errorf("Proto got expanded by marshal: %#v", ml) + } + + // The following tests the current behavior for cases where we + // cannot guarantee the integrity of the lazy unmarshalled buffer + // because of required fields. This would have to be updated if + // we find another way to check required fields than simply + // unmarshalling everything that has them when we're not sure. + + ml = &testopaquepb.TestRequiredLazy{} + err = proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(b, ml) + if err != nil { + t.Fatalf("Error while unmarshaling: %v", err) + } + // Sanity check, we should have all unexpanded fields in the proto. + if expandedLazy(ml) { + t.Fatalf("Proto is not lazy: %#v", ml) + } + // Now we marshal the proto. The lazy fields will be expanded to + // check required fields. + _, _ = proto.MarshalOptions{}.Marshal(ml) + + if !expandedLazy(ml) { + t.Errorf("Proto did not get expanded by marshal: %#v", ml) + } + + // Finally, we test to see that the fields to not get expanded + // if we are consistently using AllowPartial both for marshal + // and unmarshal. + ml = &testopaquepb.TestRequiredLazy{} + err = proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(b, ml) + if err != nil { + t.Fatalf("Error while unmarshaling: %v", err) + } + // Sanity check, we should have all unexpanded fields in the proto. + if expandedLazy(ml) { + t.Fatalf("Proto is not lazy: %#v", ml) + } + // Now we marshal the proto. The lazy fields will be expanded to + // check required fields. + _, _ = proto.MarshalOptions{AllowPartial: true}.Marshal(ml) + + if expandedLazy(ml) { + t.Errorf("Proto did not get expanded by marshal: %#v", ml) + } + +} + +func TestRoundtripMap(t *testing.T) { + m := testopaquepb.TestAllTypes_builder{ + OptionalLazyNestedMessage: testopaquepb.TestAllTypes_NestedMessage_builder{ + Corecursive: testopaquepb.TestAllTypes_builder{ + MapStringString: map[string]string{ + "a": "b", + }, + }.Build(), + }.Build(), + }.Build() + b, err := proto.Marshal(m) + if err != nil { + t.Fatalf("proto.Marshal: %v", err) + } + got := &testopaquepb.TestAllTypes{} + if err := proto.Unmarshal(b, got); err != nil { + t.Fatalf("proto.Unmarshal: %v", err) + } + if diff := cmp.Diff(m, got, protocmp.Transform()); diff != "" { + t.Errorf("not the same: diff (-want +got):\n%s", diff) + } +} diff --git a/proto/messageset_test.go b/proto/messageset_test.go index 7b7671466..d1cad2147 100644 --- a/proto/messageset_test.go +++ b/proto/messageset_test.go @@ -12,7 +12,9 @@ import ( "google.golang.org/protobuf/testing/protopack" "google.golang.org/protobuf/internal/testprotos/messageset/messagesetpb" + _ "google.golang.org/protobuf/internal/testprotos/messageset/messagesetpb/messagesetpb_opaque" _ "google.golang.org/protobuf/internal/testprotos/messageset/msetextpb" + _ "google.golang.org/protobuf/internal/testprotos/messageset/msetextpb/msetextpb_opaque" ) func init() { diff --git a/proto/oneof_get_test.go b/proto/oneof_get_test.go new file mode 100644 index 000000000..d60015c93 --- /dev/null +++ b/proto/oneof_get_test.go @@ -0,0 +1,273 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proto_test + +import ( + "testing" + + testhybridpb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_hybrid" + testopaquepb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_opaque" +) + +func expectPanic(t *testing.T, f func(), fmt string, a ...any) { + t.Helper() + defer func() { + t.Helper() + if r := recover(); r == nil { + t.Errorf(fmt, a...) + } + }() + f() +} + +func TestOpenGet(t *testing.T) { + x := &testhybridpb.TestAllTypes{} + + tab := []struct { + fName string // Field name (in proto) + set func() // Set the field + clear func() // Clear the field + has func() bool // Has for the field + isZero func() bool // Get and return true if zero value + }{ + { + fName: "oneof_uint32", + set: func() { x.SetOneofUint32(47) }, + clear: func() { x.ClearOneofUint32() }, + has: func() bool { return x.HasOneofUint32() }, + isZero: func() bool { return x.GetOneofUint32() == 0 }, + }, + { + fName: "oneof_nested_message", + set: func() { x.SetOneofNestedMessage(&testhybridpb.TestAllTypes_NestedMessage{}) }, + clear: func() { x.ClearOneofNestedMessage() }, + has: func() bool { return x.HasOneofNestedMessage() }, + isZero: func() bool { return x.GetOneofNestedMessage() == nil }, + }, + { + fName: "oneof_string", + set: func() { x.SetOneofString("test") }, + clear: func() { x.ClearOneofString() }, + has: func() bool { return x.HasOneofString() }, + isZero: func() bool { return x.GetOneofString() == "" }, + }, + { + fName: "oneof_bytes", + set: func() { x.SetOneofBytes([]byte("test")) }, + clear: func() { x.ClearOneofBytes() }, + has: func() bool { return x.HasOneofBytes() }, + isZero: func() bool { return len(x.GetOneofBytes()) == 0 }, + }, + { + fName: "oneof_bool", + set: func() { x.SetOneofBool(true) }, + clear: func() { x.ClearOneofBool() }, + has: func() bool { return x.HasOneofBool() }, + isZero: func() bool { return x.GetOneofBool() == false }, + }, + { + fName: "oneof_uint64", + set: func() { x.SetOneofUint64(7438109473104) }, + clear: func() { x.ClearOneofUint64() }, + has: func() bool { return x.HasOneofUint64() }, + isZero: func() bool { return x.GetOneofUint64() == 0 }, + }, + { + fName: "oneof_float", + set: func() { x.SetOneofFloat(3.1415) }, + clear: func() { x.ClearOneofFloat() }, + has: func() bool { return x.HasOneofFloat() }, + isZero: func() bool { return x.GetOneofFloat() == 0.0 }, + }, + { + fName: "oneof_double", + set: func() { x.SetOneofDouble(3e+8) }, + clear: func() { x.ClearOneofDouble() }, + has: func() bool { return x.HasOneofDouble() }, + isZero: func() bool { return x.GetOneofDouble() == 0.0 }, + }, + { + fName: "oneof_enum", + set: func() { x.SetOneofEnum(testhybridpb.TestAllTypes_BAZ) }, + clear: func() { x.ClearOneofEnum() }, + has: func() bool { return x.HasOneofEnum() }, + isZero: func() bool { return x.GetOneofEnum() == 0 }, + }, + } + + for i, mv := range tab { + x.ClearOneofField() + if got, want := x.HasOneofField(), false; got != want { + t.Errorf("HasOneofField returned %v, expected %v", got, want) + } + if got, want := mv.isZero(), true; got != want { + t.Errorf("Get on empty oneof member did not return zero value, got %v, expected %v (%s)", got, want, mv.fName) + } + mv.set() + + if got, want := x.HasOneofField(), true; got != want { + t.Errorf("HasOneofField returned %v, expected %v (%s)", got, want, mv.fName) + } + if got, want := mv.isZero(), false; got != want { + t.Errorf("Get on non-empty oneof member did return zero value, got %v, expected %v (%s)", got, want, mv.fName) + } + + mv.clear() + if got, want := x.HasOneofField(), false; got != want { + t.Errorf("HasOneofField returned %v, expected %v (%s)", got, want, mv.fName) + } + if got, want := mv.isZero(), true; got != want { + t.Errorf("Get on empty oneof member did not return zero value, got %v, expected %v (%s)", got, want, mv.fName) + } + other := tab[(i+1)%len(tab)] + mv.set() + other.set() + + if got, want := x.HasOneofField(), true; got != want { + t.Errorf("HasOneofField returned %v, expected %v (%s)", got, want, mv.fName) + } + if got, want := mv.isZero(), true; got != want { + t.Errorf("Get on wrong oneof member did not return zero value, got %v, expected %v (%s)", got, want, mv.fName) + } + other.clear() + if got, want := x.HasOneofField(), false; got != want { + t.Errorf("HasOneofField returned %v, expected %v (%s)", got, want, mv.fName) + } + } + x = nil + for _, mv := range tab { + if got, want := mv.isZero(), true; got != want { + t.Errorf("Get on nil receiver did not return zero value, got %v, expected %v (%s)", got, want, mv.fName) + } + if got, want := mv.has(), false; got != want { + t.Errorf("Has on nil receiver failed, got %v, expected %v (%s)", got, want, mv.fName) + } + } +} + +func TestOpaqueGet(t *testing.T) { + x := &testopaquepb.TestAllTypes{} + + tab := []struct { + fName string // Field name (in proto) + set func() // Set the field + clear func() // Clear the field + has func() bool // Has for the field + isZero func() bool // Get and return true if zero value + }{ + { + fName: "oneof_uint32", + set: func() { x.SetOneofUint32(47) }, + clear: func() { x.ClearOneofUint32() }, + has: func() bool { return x.HasOneofUint32() }, + isZero: func() bool { return x.GetOneofUint32() == 0 }, + }, + { + fName: "oneof_nested_message", + set: func() { x.SetOneofNestedMessage(&testopaquepb.TestAllTypes_NestedMessage{}) }, + clear: func() { x.ClearOneofNestedMessage() }, + has: func() bool { return x.HasOneofNestedMessage() }, + isZero: func() bool { return x.GetOneofNestedMessage() == nil }, + }, + { + fName: "oneof_string", + set: func() { x.SetOneofString("test") }, + clear: func() { x.ClearOneofString() }, + has: func() bool { return x.HasOneofString() }, + isZero: func() bool { return x.GetOneofString() == "" }, + }, + { + fName: "oneof_bytes", + set: func() { x.SetOneofBytes([]byte("test")) }, + clear: func() { x.ClearOneofBytes() }, + has: func() bool { return x.HasOneofBytes() }, + isZero: func() bool { return len(x.GetOneofBytes()) == 0 }, + }, + { + fName: "oneof_bool", + set: func() { x.SetOneofBool(true) }, + clear: func() { x.ClearOneofBool() }, + has: func() bool { return x.HasOneofBool() }, + isZero: func() bool { return x.GetOneofBool() == false }, + }, + { + fName: "oneof_uint64", + set: func() { x.SetOneofUint64(7438109473104) }, + clear: func() { x.ClearOneofUint64() }, + has: func() bool { return x.HasOneofUint64() }, + isZero: func() bool { return x.GetOneofUint64() == 0 }, + }, + { + fName: "oneof_float", + set: func() { x.SetOneofFloat(3.1415) }, + clear: func() { x.ClearOneofFloat() }, + has: func() bool { return x.HasOneofFloat() }, + isZero: func() bool { return x.GetOneofFloat() == 0.0 }, + }, + { + fName: "oneof_double", + set: func() { x.SetOneofDouble(3e+8) }, + clear: func() { x.ClearOneofDouble() }, + has: func() bool { return x.HasOneofDouble() }, + isZero: func() bool { return x.GetOneofDouble() == 0.0 }, + }, + { + fName: "oneof_enum", + set: func() { x.SetOneofEnum(testopaquepb.TestAllTypes_BAZ) }, + clear: func() { x.ClearOneofEnum() }, + has: func() bool { return x.HasOneofEnum() }, + isZero: func() bool { return x.GetOneofEnum() == 0 }, + }, + } + + for i, mv := range tab { + x.ClearOneofField() + if got, want := x.HasOneofField(), false; got != want { + t.Errorf("HasOneofField returned %v, expected %v", got, want) + } + if got, want := mv.isZero(), true; got != want { + t.Errorf("Get on empty oneof member did not return zero value, got %v, expected %v (%s)", got, want, mv.fName) + } + mv.set() + + if got, want := x.HasOneofField(), true; got != want { + t.Errorf("HasOneofField returned %v, expected %v (%s)", got, want, mv.fName) + } + if got, want := mv.isZero(), false; got != want { + t.Errorf("Get on non-empty oneof member did return zero value, got %v, expected %v (%s)", got, want, mv.fName) + } + + mv.clear() + if got, want := x.HasOneofField(), false; got != want { + t.Errorf("HasOneofField returned %v, expected %v (%s)", got, want, mv.fName) + } + if got, want := mv.isZero(), true; got != want { + t.Errorf("Get on empty oneof member did not return zero value, got %v, expected %v (%s)", got, want, mv.fName) + } + other := tab[(i+1)%len(tab)] + mv.set() + other.set() + + if got, want := x.HasOneofField(), true; got != want { + t.Errorf("HasOneofField returned %v, expected %v (%s)", got, want, mv.fName) + } + if got, want := mv.isZero(), true; got != want { + t.Errorf("Get on wrong oneof member did not return zero value, got %v, expected %v (%s)", got, want, mv.fName) + } + other.clear() + if got, want := x.HasOneofField(), false; got != want { + t.Errorf("HasOneofField returned %v, expected %v (%s)", got, want, mv.fName) + } + } + x = nil + for _, mv := range tab { + if got, want := mv.isZero(), true; got != want { + t.Errorf("Get on nil receiver did not return zero value, got %v, expected %v (%s)", got, want, mv.fName) + } + if got, want := mv.has(), false; got != want { + t.Errorf("Has on nil receiver failed, got %v, expected %v (%s)", got, want, mv.fName) + } + } +} diff --git a/proto/oneof_set_test.go b/proto/oneof_set_test.go new file mode 100644 index 000000000..265063e35 --- /dev/null +++ b/proto/oneof_set_test.go @@ -0,0 +1,312 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proto_test + +import ( + "testing" + + testhybridpb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_hybrid" + testopaquepb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_opaque" +) + +func TestOpenSetNilReceiver(t *testing.T) { + var x *testhybridpb.TestAllTypes + expectPanic(t, func() { + x.SetOneofUint32(24) + }, "Setting of oneof member on nil receiver did not panic.") + expectPanic(t, func() { + x.ClearOneofUint32() + }, "Clearing of oneof member on nil receiver did not panic.") + expectPanic(t, func() { + x.ClearOneofField() + }, "Clearing of oneof union on nil receiver did not panic.") +} + +func TestOpenSet(t *testing.T) { + x := &testhybridpb.TestAllTypes{} + + tab := []struct { + fName string // Field name (in proto) + set func() // Set the field + clear func() // Clear the field + has func() bool // Has for the field + }{ + { + fName: "oneof_uint32", + set: func() { x.SetOneofUint32(47) }, + clear: func() { x.ClearOneofUint32() }, + has: func() bool { return x.HasOneofUint32() }, + }, + { + fName: "oneof_nested_message", + set: func() { x.SetOneofNestedMessage(&testhybridpb.TestAllTypes_NestedMessage{}) }, + clear: func() { x.ClearOneofNestedMessage() }, + has: func() bool { return x.HasOneofNestedMessage() }, + }, + { + fName: "oneof_string", + set: func() { x.SetOneofString("test") }, + clear: func() { x.ClearOneofString() }, + has: func() bool { return x.HasOneofString() }, + }, + { + fName: "oneof_bytes", + set: func() { x.SetOneofBytes([]byte("test")) }, + clear: func() { x.ClearOneofBytes() }, + has: func() bool { return x.HasOneofBytes() }, + }, + { + fName: "oneof_bool", + set: func() { x.SetOneofBool(true) }, + clear: func() { x.ClearOneofBool() }, + has: func() bool { return x.HasOneofBool() }, + }, + { + fName: "oneof_uint64", + set: func() { x.SetOneofUint64(7438109473104) }, + clear: func() { x.ClearOneofUint64() }, + has: func() bool { return x.HasOneofUint64() }, + }, + { + fName: "oneof_float", + set: func() { x.SetOneofFloat(3.1415) }, + clear: func() { x.ClearOneofFloat() }, + has: func() bool { return x.HasOneofFloat() }, + }, + { + fName: "oneof_double", + set: func() { x.SetOneofDouble(3e+8) }, + clear: func() { x.ClearOneofDouble() }, + has: func() bool { return x.HasOneofDouble() }, + }, + { + fName: "oneof_enum", + set: func() { x.SetOneofEnum(testhybridpb.TestAllTypes_BAZ) }, + clear: func() { x.ClearOneofEnum() }, + has: func() bool { return x.HasOneofEnum() }, + }, + } + + for i, mv := range tab { + x.ClearOneofField() + if got, want := x.HasOneofField(), false; got != want { + t.Errorf("HasOneofField returned %v, expected %v", got, want) + } + mv.set() + + if got, want := x.HasOneofField(), true; got != want { + t.Errorf("HasOneofField returned %v, expected %v (%s)", got, want, mv.fName) + } + if got, want := mv.has(), true; got != want { + t.Errorf("Has on oneof member returned %v, expected %v (%s)", got, want, mv.fName) + } + + mv.clear() + if got, want := x.HasOneofField(), false; got != want { + t.Errorf("HasOneofField returned %v, expected %v (%s)", got, want, mv.fName) + } + if got, want := mv.has(), false; got != want { + t.Errorf("Has on oneof member returned %v, expected %v (%s)", got, want, mv.fName) + } + other := tab[(i+1)%len(tab)] + mv.set() + other.set() + + if got, want := x.HasOneofField(), true; got != want { + t.Errorf("HasOneofField returned %v, expected %v (%s)", got, want, mv.fName) + } + if got, want := mv.has(), false; got != want { + t.Errorf("Has on oneof member returned %v, expected %v (%s)", got, want, mv.fName) + } + other.clear() + if got, want := x.HasOneofField(), false; got != want { + t.Errorf("HasOneofField returned %v, expected %v (%s)", got, want, mv.fName) + } + } + x.SetOneofUint32(47) + if got, want := x.HasOneofField(), true; got != want { + t.Errorf("HasOneofField returned %v, expected %v", got, want) + } + if got, want := x.HasOneofUint32(), true; got != want { + t.Errorf("HasOneofField returned %v, expected %v", got, want) + } + x.SetOneofNestedMessage(nil) + if got, want := x.HasOneofField(), false; got != want { + t.Errorf("HasOneofField returned %v, expected %v", got, want) + } + if got, want := x.HasOneofUint32(), false; got != want { + t.Errorf("HasOneofUint32 returned %v, expected %v", got, want) + } + if got, want := x.HasOneofNestedMessage(), false; got != want { + t.Errorf("HasOneofNestedMessage returned %v, expected %v", got, want) + } + x.SetOneofUint32(47) + if got, want := x.HasOneofField(), true; got != want { + t.Errorf("HasOneofField returned %v, expected %v", got, want) + } + if got, want := x.HasOneofUint32(), true; got != want { + t.Errorf("HasOneofField returned %v, expected %v", got, want) + } + x.SetOneofBytes(nil) + if got, want := x.HasOneofField(), true; got != want { + t.Errorf("HasOneofField returned %v, expected %v", got, want) + } + if got, want := x.HasOneofUint32(), false; got != want { + t.Errorf("HasOneofUint32 returned %v, expected %v", got, want) + } + if got, want := x.HasOneofBytes(), true; got != want { + t.Errorf("HasOneofNestedMessage returned %v, expected %v", got, want) + } +} + +func TestOpaqueSetNilReceiver(t *testing.T) { + var x *testopaquepb.TestAllTypes + expectPanic(t, func() { + x.SetOneofUint32(24) + }, "Setting of oneof member on nil receiver did not panic.") + expectPanic(t, func() { + x.ClearOneofUint32() + }, "Clearing of oneof member on nil receiver did not panic.") + expectPanic(t, func() { + x.ClearOneofField() + }, "Clearing of oneof union on nil receiver did not panic.") +} + +func TestOpaqueSet(t *testing.T) { + x := &testopaquepb.TestAllTypes{} + + tab := []struct { + fName string // Field name (in proto) + set func() // Set the field + clear func() // Clear the field + has func() bool // Has for the field + }{ + { + fName: "oneof_uint32", + set: func() { x.SetOneofUint32(47) }, + clear: func() { x.ClearOneofUint32() }, + has: func() bool { return x.HasOneofUint32() }, + }, + { + fName: "oneof_nested_message", + set: func() { x.SetOneofNestedMessage(&testopaquepb.TestAllTypes_NestedMessage{}) }, + clear: func() { x.ClearOneofNestedMessage() }, + has: func() bool { return x.HasOneofNestedMessage() }, + }, + { + fName: "oneof_string", + set: func() { x.SetOneofString("test") }, + clear: func() { x.ClearOneofString() }, + has: func() bool { return x.HasOneofString() }, + }, + { + fName: "oneof_bytes", + set: func() { x.SetOneofBytes([]byte("test")) }, + clear: func() { x.ClearOneofBytes() }, + has: func() bool { return x.HasOneofBytes() }, + }, + { + fName: "oneof_bool", + set: func() { x.SetOneofBool(true) }, + clear: func() { x.ClearOneofBool() }, + has: func() bool { return x.HasOneofBool() }, + }, + { + fName: "oneof_uint64", + set: func() { x.SetOneofUint64(7438109473104) }, + clear: func() { x.ClearOneofUint64() }, + has: func() bool { return x.HasOneofUint64() }, + }, + { + fName: "oneof_float", + set: func() { x.SetOneofFloat(3.1415) }, + clear: func() { x.ClearOneofFloat() }, + has: func() bool { return x.HasOneofFloat() }, + }, + { + fName: "oneof_double", + set: func() { x.SetOneofDouble(3e+8) }, + clear: func() { x.ClearOneofDouble() }, + has: func() bool { return x.HasOneofDouble() }, + }, + { + fName: "oneof_enum", + set: func() { x.SetOneofEnum(testopaquepb.TestAllTypes_BAZ) }, + clear: func() { x.ClearOneofEnum() }, + has: func() bool { return x.HasOneofEnum() }, + }, + } + + for i, mv := range tab { + x.ClearOneofField() + if got, want := x.HasOneofField(), false; got != want { + t.Errorf("HasOneofField returned %v, expected %v", got, want) + } + mv.set() + + if got, want := x.HasOneofField(), true; got != want { + t.Errorf("HasOneofField returned %v, expected %v (%s)", got, want, mv.fName) + } + if got, want := mv.has(), true; got != want { + t.Errorf("Has on oneof member returned %v, expected %v (%s)", got, want, mv.fName) + } + + mv.clear() + if got, want := x.HasOneofField(), false; got != want { + t.Errorf("HasOneofField returned %v, expected %v (%s)", got, want, mv.fName) + } + if got, want := mv.has(), false; got != want { + t.Errorf("Has on oneof member returned %v, expected %v (%s)", got, want, mv.fName) + } + other := tab[(i+1)%len(tab)] + mv.set() + other.set() + + if got, want := x.HasOneofField(), true; got != want { + t.Errorf("HasOneofField returned %v, expected %v (%s)", got, want, mv.fName) + } + if got, want := mv.has(), false; got != want { + t.Errorf("Has on oneof member returned %v, expected %v (%s)", got, want, mv.fName) + } + other.clear() + if got, want := x.HasOneofField(), false; got != want { + t.Errorf("HasOneofField returned %v, expected %v (%s)", got, want, mv.fName) + } + } + x.SetOneofUint32(47) + if got, want := x.HasOneofField(), true; got != want { + t.Errorf("HasOneofField returned %v, expected %v", got, want) + } + if got, want := x.HasOneofUint32(), true; got != want { + t.Errorf("HasOneofField returned %v, expected %v", got, want) + } + x.SetOneofNestedMessage(nil) + if got, want := x.HasOneofField(), false; got != want { + t.Errorf("HasOneofField returned %v, expected %v", got, want) + } + if got, want := x.HasOneofUint32(), false; got != want { + t.Errorf("HasOneofUint32 returned %v, expected %v", got, want) + } + if got, want := x.HasOneofNestedMessage(), false; got != want { + t.Errorf("HasOneofNestedMessage returned %v, expected %v", got, want) + } + x.SetOneofUint32(47) + if got, want := x.HasOneofField(), true; got != want { + t.Errorf("HasOneofField returned %v, expected %v", got, want) + } + if got, want := x.HasOneofUint32(), true; got != want { + t.Errorf("HasOneofField returned %v, expected %v", got, want) + } + x.SetOneofBytes(nil) + if got, want := x.HasOneofField(), true; got != want { + t.Errorf("HasOneofField returned %v, expected %v", got, want) + } + if got, want := x.HasOneofUint32(), false; got != want { + t.Errorf("HasOneofUint32 returned %v, expected %v", got, want) + } + if got, want := x.HasOneofBytes(), true; got != want { + t.Errorf("HasOneofNestedMessage returned %v, expected %v", got, want) + } +} diff --git a/proto/oneof_which_test.go b/proto/oneof_which_test.go new file mode 100644 index 000000000..bdf10bdb0 --- /dev/null +++ b/proto/oneof_which_test.go @@ -0,0 +1,189 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proto_test + +import ( + "testing" + + testhybridpb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_hybrid" + testopaquepb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_opaque" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" +) + +func TestOpenWhich(t *testing.T) { + var x *testhybridpb.TestAllTypes + if x.WhichOneofField() != 0 { + t.Errorf("WhichOneofField on nil returned %d, expected %d", x.WhichOneofField(), 0) + } + x = &testhybridpb.TestAllTypes{} + if x.WhichOneofField() != 0 { + t.Errorf("WhichOneofField returned %d, expected %d", x.WhichOneofField(), 0) + } + tab := []struct { + m *testhybridpb.TestAllTypes + v protoreflect.FieldNumber + }{ + { + m: testhybridpb.TestAllTypes_builder{ + OneofUint32: proto.Uint32(46), + }.Build(), + v: protoreflect.FieldNumber(testhybridpb.TestAllTypes_OneofUint32_case), + }, + { + m: testhybridpb.TestAllTypes_builder{ + OneofNestedMessage: testhybridpb.TestAllTypes_NestedMessage_builder{A: proto.Int32(46)}.Build(), + }.Build(), + v: protoreflect.FieldNumber(testhybridpb.TestAllTypes_OneofNestedMessage_case), + }, + { + m: testhybridpb.TestAllTypes_builder{ + OneofString: proto.String("foo"), + }.Build(), + v: protoreflect.FieldNumber(testhybridpb.TestAllTypes_OneofString_case), + }, + { + m: testhybridpb.TestAllTypes_builder{ + OneofBytes: []byte("foo"), + }.Build(), + v: protoreflect.FieldNumber(testhybridpb.TestAllTypes_OneofBytes_case), + }, + { + m: testhybridpb.TestAllTypes_builder{ + OneofBool: proto.Bool(true), + }.Build(), + v: protoreflect.FieldNumber(testhybridpb.TestAllTypes_OneofBool_case), + }, + { + m: testhybridpb.TestAllTypes_builder{ + OneofUint64: proto.Uint64(0), + }.Build(), + v: protoreflect.FieldNumber(testhybridpb.TestAllTypes_OneofUint64_case), + }, + { + m: testhybridpb.TestAllTypes_builder{ + OneofFloat: proto.Float32(0.0), + }.Build(), + v: protoreflect.FieldNumber(testhybridpb.TestAllTypes_OneofFloat_case), + }, + { + m: testhybridpb.TestAllTypes_builder{ + OneofDouble: proto.Float64(1.1), + }.Build(), + v: protoreflect.FieldNumber(testhybridpb.TestAllTypes_OneofDouble_case), + }, + { + m: testhybridpb.TestAllTypes_builder{ + OneofEnum: testhybridpb.TestAllTypes_BAZ.Enum(), + }.Build(), + v: protoreflect.FieldNumber(testhybridpb.TestAllTypes_OneofEnum_case), + }, + } + + for _, mv := range tab { + if protoreflect.FieldNumber(mv.m.WhichOneofField()) != mv.v { + t.Errorf("WhichOneofField returned %d, expected %d", mv.m.WhichOneofField(), mv.v) + } + if !mv.m.HasOneofField() { + t.Errorf("HasOneofField returned %t, expected true", mv.m.HasOneofField()) + + } + mv.m.ClearOneofField() + if mv.m.WhichOneofField() != 0 { + t.Errorf("WhichOneofField returned %d, expected %d", mv.m.WhichOneofField(), 0) + } + if mv.m.HasOneofField() { + t.Errorf("HasOneofField returned %t, expected false", mv.m.HasOneofField()) + } + } +} + +func TestOpaqueWhich(t *testing.T) { + var x *testopaquepb.TestAllTypes + if x.WhichOneofField() != 0 { + t.Errorf("WhichOneofField on nil returned %d, expected %d", x.WhichOneofField(), 0) + } + x = &testopaquepb.TestAllTypes{} + if x.WhichOneofField() != 0 { + t.Errorf("WhichOneofField returned %d, expected %d", x.WhichOneofField(), 0) + } + en := testopaquepb.TestAllTypes_BAZ + tab := []struct { + m *testopaquepb.TestAllTypes + v protoreflect.FieldNumber + }{ + { + m: testopaquepb.TestAllTypes_builder{ + OneofUint32: proto.Uint32(46), + }.Build(), + v: protoreflect.FieldNumber(testopaquepb.TestAllTypes_OneofUint32_case), + }, + { + m: testopaquepb.TestAllTypes_builder{ + OneofNestedMessage: testopaquepb.TestAllTypes_NestedMessage_builder{A: proto.Int32(46)}.Build(), + }.Build(), + v: protoreflect.FieldNumber(testopaquepb.TestAllTypes_OneofNestedMessage_case), + }, + { + m: testopaquepb.TestAllTypes_builder{ + OneofString: proto.String("foo"), + }.Build(), + v: protoreflect.FieldNumber(testopaquepb.TestAllTypes_OneofString_case), + }, + { + m: testopaquepb.TestAllTypes_builder{ + OneofBytes: []byte("foo"), + }.Build(), + v: protoreflect.FieldNumber(testopaquepb.TestAllTypes_OneofBytes_case), + }, + { + m: testopaquepb.TestAllTypes_builder{ + OneofBool: proto.Bool(true), + }.Build(), + v: protoreflect.FieldNumber(testopaquepb.TestAllTypes_OneofBool_case), + }, + { + m: testopaquepb.TestAllTypes_builder{ + OneofUint64: proto.Uint64(0), + }.Build(), + v: protoreflect.FieldNumber(testopaquepb.TestAllTypes_OneofUint64_case), + }, + { + m: testopaquepb.TestAllTypes_builder{ + OneofFloat: proto.Float32(0.0), + }.Build(), + v: protoreflect.FieldNumber(testopaquepb.TestAllTypes_OneofFloat_case), + }, + { + m: testopaquepb.TestAllTypes_builder{ + OneofDouble: proto.Float64(1.1), + }.Build(), + v: protoreflect.FieldNumber(testopaquepb.TestAllTypes_OneofDouble_case), + }, + { + m: testopaquepb.TestAllTypes_builder{ + OneofEnum: &en, + }.Build(), + v: protoreflect.FieldNumber(testopaquepb.TestAllTypes_OneofEnum_case), + }, + } + + for _, mv := range tab { + if protoreflect.FieldNumber(mv.m.WhichOneofField()) != mv.v { + t.Errorf("WhichOneofField returned %d, expected %d", mv.m.WhichOneofField(), mv.v) + } + if !mv.m.HasOneofField() { + t.Errorf("HasOneofField returned %t, expected true", mv.m.HasOneofField()) + + } + mv.m.ClearOneofField() + if mv.m.WhichOneofField() != 0 { + t.Errorf("WhichOneofField returned %d, expected %d", mv.m.WhichOneofField(), 0) + } + if mv.m.HasOneofField() { + t.Errorf("HasOneofField returned %t, expected false", mv.m.HasOneofField()) + } + } +} diff --git a/proto/repeated_test.go b/proto/repeated_test.go new file mode 100644 index 000000000..9edb2c654 --- /dev/null +++ b/proto/repeated_test.go @@ -0,0 +1,560 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proto_test + +import ( + "fmt" + "reflect" + "testing" + "unsafe" + + "google.golang.org/protobuf/internal/impl" + testhybridpb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_hybrid" + testopaquepb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_opaque" + "google.golang.org/protobuf/proto" +) + +func TestOpenSetRepeatedNilReceiver(t *testing.T) { + var x *testhybridpb.TestAllTypes + expectPanic(t, func() { + x.SetRepeatedUint32(nil) + }, "Setting repeated field on nil receiver did not panic.") +} + +func TestOpenSetRepeated(t *testing.T) { + x := &testhybridpb.TestAllTypes{} + + tab := []struct { + fName string // Field name (in proto) + set func() // Set the field to empty slice + setNil func() // Set the field to nil + len func() int // length of field, -1 if nil + }{ + { + fName: "repeated_int32", + set: func() { x.SetRepeatedInt32([]int32{}) }, + setNil: func() { x.SetRepeatedInt32(nil) }, + len: func() int { + if x.GetRepeatedInt32() == nil { + return -1 + } + return len(x.GetRepeatedInt32()) + }, + }, + { + fName: "repeated_int64", + set: func() { x.SetRepeatedInt64([]int64{}) }, + setNil: func() { x.SetRepeatedInt64(nil) }, + len: func() int { + if x.GetRepeatedInt64() == nil { + return -1 + } + return len(x.GetRepeatedInt64()) + }, + }, + { + fName: "repeated_uint32", + set: func() { x.SetRepeatedUint32([]uint32{}) }, + setNil: func() { x.SetRepeatedUint32(nil) }, + len: func() int { + if x.GetRepeatedUint32() == nil { + return -1 + } + return len(x.GetRepeatedUint32()) + }, + }, + { + fName: "repeated_uint64", + set: func() { x.SetRepeatedUint64([]uint64{}) }, + setNil: func() { x.SetRepeatedUint64(nil) }, + len: func() int { + if x.GetRepeatedUint64() == nil { + return -1 + } + return len(x.GetRepeatedUint64()) + }, + }, + { + fName: "repeated_sint32", + set: func() { x.SetRepeatedSint32([]int32{}) }, + setNil: func() { x.SetRepeatedSint32(nil) }, + len: func() int { + if x.GetRepeatedSint32() == nil { + return -1 + } + return len(x.GetRepeatedSint32()) + }, + }, + { + fName: "repeated_sint64", + set: func() { x.SetRepeatedSint64([]int64{}) }, + setNil: func() { x.SetRepeatedSint64(nil) }, + len: func() int { + if x.GetRepeatedSint64() == nil { + return -1 + } + return len(x.GetRepeatedSint64()) + }, + }, + { + fName: "repeated_fixed32", + set: func() { x.SetRepeatedFixed32([]uint32{}) }, + setNil: func() { x.SetRepeatedFixed32(nil) }, + len: func() int { + if x.GetRepeatedFixed32() == nil { + return -1 + } + return len(x.GetRepeatedFixed32()) + }, + }, + { + fName: "repeated_fixed64", + set: func() { x.SetRepeatedFixed64([]uint64{}) }, + setNil: func() { x.SetRepeatedFixed64(nil) }, + len: func() int { + if x.GetRepeatedFixed64() == nil { + return -1 + } + return len(x.GetRepeatedFixed64()) + }, + }, + { + fName: "repeated_sfixed32", + set: func() { x.SetRepeatedSfixed32([]int32{}) }, + setNil: func() { x.SetRepeatedSfixed32(nil) }, + len: func() int { + if x.GetRepeatedSfixed32() == nil { + return -1 + } + return len(x.GetRepeatedSfixed32()) + }, + }, + { + fName: "repeated_sfixed64", + set: func() { x.SetRepeatedSfixed64([]int64{}) }, + setNil: func() { x.SetRepeatedSfixed64(nil) }, + len: func() int { + if x.GetRepeatedSfixed64() == nil { + return -1 + } + return len(x.GetRepeatedSfixed64()) + }, + }, + { + fName: "repeated_float", + set: func() { x.SetRepeatedFloat([]float32{}) }, + setNil: func() { x.SetRepeatedFloat(nil) }, + len: func() int { + if x.GetRepeatedFloat() == nil { + return -1 + } + return len(x.GetRepeatedFloat()) + }, + }, + { + fName: "repeated_double", + set: func() { x.SetRepeatedDouble([]float64{}) }, + setNil: func() { x.SetRepeatedDouble(nil) }, + len: func() int { + if x.GetRepeatedDouble() == nil { + return -1 + } + return len(x.GetRepeatedDouble()) + }, + }, + { + fName: "repeated_bool", + set: func() { x.SetRepeatedBool([]bool{}) }, + setNil: func() { x.SetRepeatedBool(nil) }, + len: func() int { + if x.GetRepeatedBool() == nil { + return -1 + } + return len(x.GetRepeatedBool()) + }, + }, + { + fName: "repeated_string", + set: func() { x.SetRepeatedString([]string{}) }, + setNil: func() { x.SetRepeatedString(nil) }, + len: func() int { + if x.GetRepeatedString() == nil { + return -1 + } + return len(x.GetRepeatedString()) + }, + }, + { + fName: "repeated_bytes", + set: func() { x.SetRepeatedBytes([][]byte{}) }, + setNil: func() { x.SetRepeatedBytes(nil) }, + len: func() int { + if x.GetRepeatedBytes() == nil { + return -1 + } + return len(x.GetRepeatedBytes()) + }, + }, + { + fName: "RepeatedGroup", + set: func() { x.SetRepeatedgroup([]*testhybridpb.TestAllTypes_RepeatedGroup{}) }, + setNil: func() { x.SetRepeatedgroup(nil) }, + len: func() int { + if x.GetRepeatedgroup() == nil { + return -1 + } + return len(x.GetRepeatedgroup()) + }, + }, + { + fName: "repeated_nested_message", + set: func() { x.SetRepeatedNestedMessage([]*testhybridpb.TestAllTypes_NestedMessage{}) }, + setNil: func() { x.SetRepeatedNestedMessage(nil) }, + len: func() int { + if x.GetRepeatedNestedMessage() == nil { + return -1 + } + return len(x.GetRepeatedNestedMessage()) + }, + }, + { + fName: "repeated_nested_enum", + set: func() { x.SetRepeatedNestedEnum([]testhybridpb.TestAllTypes_NestedEnum{}) }, + setNil: func() { x.SetRepeatedNestedEnum(nil) }, + len: func() int { + if x.GetRepeatedNestedEnum() == nil { + return -1 + } + return len(x.GetRepeatedNestedEnum()) + }, + }, + } + + for _, mv := range tab { + if mv.len() != -1 { + t.Errorf("Repeated field %s was not nil to start with ", mv.fName) + } + mv.set() + if mv.len() != 0 { + t.Errorf("Repeated field %s did not retain empty slice ", mv.fName) + } + b, err := proto.Marshal(x) + if err != nil { + t.Fatalf("Failed to marshal message, err = %v", err) + } + proto.Unmarshal(b, x) + if mv.len() != -1 { + t.Errorf("Repeated field %s was not nil to start with ", mv.fName) + } + mv.set() + mv.setNil() + if mv.len() != -1 { + t.Errorf("Repeated field %s was not nil event though we set it to ", mv.fName) + } + + } + + // Check that we actually retain the same slice + s := make([]testhybridpb.TestAllTypes_NestedEnum, 0, 455) + x.SetRepeatedNestedEnum(s) + if got, want := cap(x.GetRepeatedNestedEnum()), 455; got != want { + t.Errorf("cap(x.GetRepeatedNestedEnum()) returned %v, expected %v", got, want) + } + // Do this for a message too + s2 := make([]*testhybridpb.TestAllTypes_NestedMessage, 0, 544) + x.SetRepeatedNestedMessage(s2) + if got, want := cap(x.GetRepeatedNestedMessage()), 544; got != want { + t.Errorf("cap(x.GetRepeatedNestedMessage()) returned %v, expected %v", got, want) + } + // Check special bytes behavior + x.SetOptionalBytes(nil) + if got, want := x.HasOptionalBytes(), true; got != want { + t.Errorf("HasOptionalBytes after setting to nil returned %v, expected %v", got, want) + } + if got := x.GetOptionalBytes(); got == nil || len(got) != 0 { + t.Errorf("GetOptionalBytes after setting to nil returned %v, expected %v", got, []byte{}) + } + +} + +func TestOpaqueSetRepeatedNilReceiver(t *testing.T) { + var x *testopaquepb.TestAllTypes + expectPanic(t, func() { + x.SetRepeatedUint32(nil) + }, "Setting repeated field on nil receiver did not panic.") +} + +func TestOpaqueSetRepeated(t *testing.T) { + for _, mode := range []bool{true, false} { + impl.EnableLazyUnmarshal(mode) + t.Run(fmt.Sprintf("LazyUnmarshal_%t", mode), testOpaqueSetRepeatedSub) + } +} + +func testOpaqueSetRepeatedSub(t *testing.T) { + x := &testopaquepb.TestAllTypes{} + + tab := []struct { + fName string // Field name (in proto) + set func() // Set the field to empty slice + setNil func() // Set the field to nil + len func() int // length of field, -1 if nil + }{ + { + fName: "repeated_int32", + set: func() { x.SetRepeatedInt32([]int32{}) }, + setNil: func() { x.SetRepeatedInt32(nil) }, + len: func() int { + if x.GetRepeatedInt32() == nil { + return -1 + } + return len(x.GetRepeatedInt32()) + }, + }, + { + fName: "repeated_int64", + set: func() { x.SetRepeatedInt64([]int64{}) }, + setNil: func() { x.SetRepeatedInt64(nil) }, + len: func() int { + if x.GetRepeatedInt64() == nil { + return -1 + } + return len(x.GetRepeatedInt64()) + }, + }, + { + fName: "repeated_uint32", + set: func() { x.SetRepeatedUint32([]uint32{}) }, + setNil: func() { x.SetRepeatedUint32(nil) }, + len: func() int { + if x.GetRepeatedUint32() == nil { + return -1 + } + return len(x.GetRepeatedUint32()) + }, + }, + { + fName: "repeated_uint64", + set: func() { x.SetRepeatedUint64([]uint64{}) }, + setNil: func() { x.SetRepeatedUint64(nil) }, + len: func() int { + if x.GetRepeatedUint64() == nil { + return -1 + } + return len(x.GetRepeatedUint64()) + }, + }, + { + fName: "repeated_sint32", + set: func() { x.SetRepeatedSint32([]int32{}) }, + setNil: func() { x.SetRepeatedSint32(nil) }, + len: func() int { + if x.GetRepeatedSint32() == nil { + return -1 + } + return len(x.GetRepeatedSint32()) + }, + }, + { + fName: "repeated_sint64", + set: func() { x.SetRepeatedSint64([]int64{}) }, + setNil: func() { x.SetRepeatedSint64(nil) }, + len: func() int { + if x.GetRepeatedSint64() == nil { + return -1 + } + return len(x.GetRepeatedSint64()) + }, + }, + { + fName: "repeated_fixed32", + set: func() { x.SetRepeatedFixed32([]uint32{}) }, + setNil: func() { x.SetRepeatedFixed32(nil) }, + len: func() int { + if x.GetRepeatedFixed32() == nil { + return -1 + } + return len(x.GetRepeatedFixed32()) + }, + }, + { + fName: "repeated_fixed64", + set: func() { x.SetRepeatedFixed64([]uint64{}) }, + setNil: func() { x.SetRepeatedFixed64(nil) }, + len: func() int { + if x.GetRepeatedFixed64() == nil { + return -1 + } + return len(x.GetRepeatedFixed64()) + }, + }, + { + fName: "repeated_sfixed32", + set: func() { x.SetRepeatedSfixed32([]int32{}) }, + setNil: func() { x.SetRepeatedSfixed32(nil) }, + len: func() int { + if x.GetRepeatedSfixed32() == nil { + return -1 + } + return len(x.GetRepeatedSfixed32()) + }, + }, + { + fName: "repeated_sfixed64", + set: func() { x.SetRepeatedSfixed64([]int64{}) }, + setNil: func() { x.SetRepeatedSfixed64(nil) }, + len: func() int { + if x.GetRepeatedSfixed64() == nil { + return -1 + } + return len(x.GetRepeatedSfixed64()) + }, + }, + { + fName: "repeated_float", + set: func() { x.SetRepeatedFloat([]float32{}) }, + setNil: func() { x.SetRepeatedFloat(nil) }, + len: func() int { + if x.GetRepeatedFloat() == nil { + return -1 + } + return len(x.GetRepeatedFloat()) + }, + }, + { + fName: "repeated_double", + set: func() { x.SetRepeatedDouble([]float64{}) }, + setNil: func() { x.SetRepeatedDouble(nil) }, + len: func() int { + if x.GetRepeatedDouble() == nil { + return -1 + } + return len(x.GetRepeatedDouble()) + }, + }, + { + fName: "repeated_bool", + set: func() { x.SetRepeatedBool([]bool{}) }, + setNil: func() { x.SetRepeatedBool(nil) }, + len: func() int { + if x.GetRepeatedBool() == nil { + return -1 + } + return len(x.GetRepeatedBool()) + }, + }, + { + fName: "repeated_string", + set: func() { x.SetRepeatedString([]string{}) }, + setNil: func() { x.SetRepeatedString(nil) }, + len: func() int { + if x.GetRepeatedString() == nil { + return -1 + } + return len(x.GetRepeatedString()) + }, + }, + { + fName: "repeated_bytes", + set: func() { x.SetRepeatedBytes([][]byte{}) }, + setNil: func() { x.SetRepeatedBytes(nil) }, + len: func() int { + if x.GetRepeatedBytes() == nil { + return -1 + } + return len(x.GetRepeatedBytes()) + }, + }, + { + fName: "RepeatedGroup", + set: func() { x.SetRepeatedgroup([]*testopaquepb.TestAllTypes_RepeatedGroup{}) }, + setNil: func() { x.SetRepeatedgroup(nil) }, + len: func() int { + if x.GetRepeatedgroup() == nil { + return -1 + } + return len(x.GetRepeatedgroup()) + }, + }, + { + fName: "repeated_nested_message", + set: func() { x.SetRepeatedNestedMessage([]*testopaquepb.TestAllTypes_NestedMessage{}) }, + setNil: func() { x.SetRepeatedNestedMessage(nil) }, + len: func() int { + if x.GetRepeatedNestedMessage() == nil { + return -1 + } + return len(x.GetRepeatedNestedMessage()) + }, + }, + { + fName: "repeated_nested_enum", + set: func() { x.SetRepeatedNestedEnum([]testopaquepb.TestAllTypes_NestedEnum{}) }, + setNil: func() { x.SetRepeatedNestedEnum(nil) }, + len: func() int { + if x.GetRepeatedNestedEnum() == nil { + return -1 + } + return len(x.GetRepeatedNestedEnum()) + }, + }, + } + + for _, mv := range tab { + if mv.len() != -1 { + t.Errorf("Repeated field %s was not nil to start with ", mv.fName) + } + mv.set() + if mv.len() != 0 { + t.Errorf("Repeated field %s did not retain empty slice ", mv.fName) + } + b, err := proto.Marshal(x) + if err != nil { + t.Fatalf("Failed to marshal message, err = %v", err) + } + proto.Unmarshal(b, x) + if mv.len() != -1 { + t.Errorf("Repeated field %s was not nil to start with ", mv.fName) + } + mv.set() + mv.setNil() + if mv.len() != -1 { + t.Errorf("Repeated field %s was not nil event though we set it to ", mv.fName) + } + + } + + // Check that we actually retain the same slice + s := make([]testopaquepb.TestAllTypes_NestedEnum, 0, 455) + x.SetRepeatedNestedEnum(s) + if got, want := cap(x.GetRepeatedNestedEnum()), 455; got != want { + t.Errorf("cap(x.GetRepeatedNestedEnum()) returned %v, expected %v", got, want) + } + // Do this for a message too + s2 := make([]*testopaquepb.TestAllTypes_NestedMessage, 0, 544) + x.SetRepeatedNestedMessage(s2) + if got, want := cap(x.GetRepeatedNestedMessage()), 544; got != want { + t.Errorf("cap(x.GetRepeatedNestedMessage()) returned %v, expected %v", got, want) + t.Errorf("present: %v, isNilen: %v", checkPresent(x, 34), x.GetRepeatedNestedMessage()) + } + // Check special bytes behavior + x.SetOptionalBytes(nil) + if got, want := x.HasOptionalBytes(), true; got != want { + t.Errorf("HasOptionalBytes after setting to nil returned %v, expected %v", got, want) + } + if got := x.GetOptionalBytes(); got == nil || len(got) != 0 { + t.Errorf("GetOptionalBytes after setting to nil returned %v, expected %v", got, []byte{}) + } +} + +func checkPresent(m proto.Message, fn uint32) bool { + vv := reflect.ValueOf(m).Elem() + rf := vv.FieldByName("XXX_presence") + rf = reflect.NewAt(rf.Type(), unsafe.Pointer(rf.UnsafeAddr())).Elem() + ai := int(fn) / 32 + bit := fn % 32 + ptr := rf.Index(ai).Addr().Interface().(*uint32) + return (*ptr & (1 << bit)) > 0 +} diff --git a/proto/testmessages_opaque_test.go b/proto/testmessages_opaque_test.go new file mode 100644 index 000000000..2cc51f1da --- /dev/null +++ b/proto/testmessages_opaque_test.go @@ -0,0 +1,97 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proto_test + +import ( + "flag" + "fmt" + "os" + "strings" + "testing" + + "google.golang.org/protobuf/internal/impl" + "google.golang.org/protobuf/internal/protobuild" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/testing/protopack" + + _ "google.golang.org/protobuf/internal/testprotos/lazy" + _ "google.golang.org/protobuf/internal/testprotos/lazy/lazy_opaque" + _ "google.golang.org/protobuf/internal/testprotos/required" + _ "google.golang.org/protobuf/internal/testprotos/required/required_opaque" + _ "google.golang.org/protobuf/internal/testprotos/test" + _ "google.golang.org/protobuf/internal/testprotos/test/weak1" + _ "google.golang.org/protobuf/internal/testprotos/test3" + _ "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_hybrid" + _ "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_opaque" +) + +var testLazyUnmarshal = flag.Bool("test_lazy_unmarshal", false, "test lazy proto.Unmarshal") + +func TestMain(m *testing.M) { + flag.Parse() + impl.EnableLazyUnmarshal(*testLazyUnmarshal) + os.Exit(m.Run()) +} + +var relatedMessages = func() map[protoreflect.MessageType][]protoreflect.MessageType { + related := map[protoreflect.MessageType][]protoreflect.MessageType{} + const opaqueNamePrefix = "opaque." + protoregistry.GlobalTypes.RangeMessages(func(mt protoreflect.MessageType) bool { + name := mt.Descriptor().FullName() + if !strings.HasPrefix(string(name), opaqueNamePrefix) { + return true + } + mt1, err := protoregistry.GlobalTypes.FindMessageByName(name[len(opaqueNamePrefix):]) + if err != nil { + panic(fmt.Sprintf("%v: can't find related message", name)) + } + related[mt1] = append(related[mt1], mt) + return true + }) + return related +}() + +func init() { + testValidMessages = append(testValidMessages, []testProto{ + { + desc: "lazy field contains wrong wire type", + checkFastInit: true, + decodeTo: makeMessages(protobuild.Message{ + "optional_nested_message": protobuild.Message{ + protobuild.Unknown: protopack.Message{ + protopack.Tag{2, protopack.VarintType}, protopack.Varint(3), + }.Marshal(), + }, + }), + wire: protopack.Message{ + protopack.Tag{18, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{ + protopack.Tag{2, protopack.VarintType}, protopack.Varint(3), + }), + }.Marshal(), + }, { + desc: "lazy field contains right and wrong wire type", + checkFastInit: true, + decodeTo: makeMessages(protobuild.Message{ + "optional_nested_message": protobuild.Message{ + "corecursive": protobuild.Message{ + "optional_int32": 2, + }, + protobuild.Unknown: protopack.Message{ + protopack.Tag{2, protopack.VarintType}, protopack.Varint(3), + }.Marshal(), + }, + }), + wire: protopack.Message{ + protopack.Tag{18, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{ + protopack.Tag{2, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{ + protopack.Tag{1, protopack.VarintType}, protopack.Varint(2), + }), + protopack.Tag{2, protopack.VarintType}, protopack.Varint(3), + }), + }.Marshal(), + }, + }...) +} diff --git a/proto/testmessages_test.go b/proto/testmessages_test.go index b1fccf80e..4fb0e618b 100644 --- a/proto/testmessages_test.go +++ b/proto/testmessages_test.go @@ -41,6 +41,13 @@ func makeMessages(in protobuild.Message, messages ...proto.Message) []proto.Mess &testeditionspb.TestAllTypes{}, } } + + for _, m := range messages { + for _, mt := range relatedMessages[m.ProtoReflect().Type()] { + messages = append(messages, mt.New().Interface()) + } + } + for _, m := range messages { in.Build(m.ProtoReflect()) } @@ -56,6 +63,13 @@ func templateMessages(messages ...proto.Message) []protoreflect.MessageType { (*testeditionspb.TestAllTypes)(nil), } } + + for _, m := range messages { + for _, mt := range relatedMessages[m.ProtoReflect().Type()] { + messages = append(messages, mt.New().Interface()) + } + } + var out []protoreflect.MessageType for _, m := range messages { out = append(out, m.ProtoReflect().Type()) diff --git a/proto/wrapperopaque.go b/proto/wrapperopaque.go new file mode 100644 index 000000000..267fd0f1f --- /dev/null +++ b/proto/wrapperopaque.go @@ -0,0 +1,80 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proto + +// ValueOrNil returns nil if has is false, or a pointer to a new variable +// containing the value returned by the specified getter. +// +// This function is similar to the wrappers (proto.Int32(), proto.String(), +// etc.), but is generic (works for any field type) and works with the hasser +// and getter of a field, as opposed to a value. +// +// This is convenient when populating builder fields. +// +// Example: +// +// hop := attr.GetDirectHop() +// injectedRoute := ripb.InjectedRoute_builder{ +// Prefixes: route.GetPrefixes(), +// NextHop: proto.ValueOrNil(hop.HasAddress(), hop.GetAddress), +// } +func ValueOrNil[T any](has bool, getter func() T) *T { + if !has { + return nil + } + v := getter() + return &v +} + +// ValueOrDefault returns the protobuf message val if val is not nil, otherwise +// it returns a pointer to an empty val message. +// +// This function allows for translating code from the old Open Struct API to the +// new Opaque API. +// +// The old Open Struct API represented oneof fields with a wrapper struct: +// +// var signedImg *accountpb.SignedImage +// profile := &accountpb.Profile{ +// // The Avatar oneof will be set, with an empty SignedImage. +// Avatar: &accountpb.Profile_SignedImage{signedImg}, +// } +// +// The new Opaque API treats oneof fields like regular fields, there are no more +// wrapper structs: +// +// var signedImg *accountpb.SignedImage +// profile := &accountpb.Profile{} +// profile.SetSignedImage(signedImg) +// +// For convenience, the Opaque API also offers Builders, which allow for a +// direct translation of struct initialization. However, because Builders use +// nilness to represent field presence (but there is no non-nil wrapper struct +// anymore), Builders cannot distinguish between an unset oneof and a set oneof +// with nil message. The above code would need to be translated with help of the +// ValueOrDefault function to retain the same behavior: +// +// var signedImg *accountpb.SignedImage +// return &accountpb.Profile_builder{ +// SignedImage: proto.ValueOrDefault(signedImg), +// }.Build() +func ValueOrDefault[T interface { + *P + Message +}, P any](val T) T { + if val == nil { + return T(new(P)) + } + return val +} + +// ValueOrDefaultBytes is like ValueOrDefault but for working with fields of +// type []byte. +func ValueOrDefaultBytes(val []byte) []byte { + if val == nil { + return []byte{} + } + return val +} diff --git a/proto/wrapperopaque_test.go b/proto/wrapperopaque_test.go new file mode 100644 index 000000000..0fdd4a0ea --- /dev/null +++ b/proto/wrapperopaque_test.go @@ -0,0 +1,173 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proto_test + +import ( + "bytes" + "testing" + + testhybridpb "google.golang.org/protobuf/internal/testprotos/testeditions/testeditions_hybrid" + "google.golang.org/protobuf/proto" +) + +func TestOneofOrDefault(t *testing.T) { + for _, tt := range []struct { + desc string + input func() *testhybridpb.TestAllTypes + }{ + { + desc: "struct literal with nil nested message", + input: func() *testhybridpb.TestAllTypes { + return &testhybridpb.TestAllTypes{ + OneofField: &testhybridpb.TestAllTypes_OneofNestedMessage{ + OneofNestedMessage: nil, + }, + } + }, + }, + + { + desc: "struct literal with non-nil nested message", + input: func() *testhybridpb.TestAllTypes { + return &testhybridpb.TestAllTypes{ + OneofField: &testhybridpb.TestAllTypes_OneofNestedMessage{ + OneofNestedMessage: &testhybridpb.TestAllTypes_NestedMessage{}, + }, + } + }, + }, + + { + desc: "opaque setter with ValueOrDefault", + input: func() *testhybridpb.TestAllTypes { + msg := &testhybridpb.TestAllTypes{} + msg.ClearOneofString() + var val *testhybridpb.TestAllTypes_NestedMessage + msg.SetOneofNestedMessage(proto.ValueOrDefault(val)) + return msg + }, + }, + + { + desc: "opaque builder with ValueOrDefault", + input: func() *testhybridpb.TestAllTypes { + var val *testhybridpb.TestAllTypes_NestedMessage + return testhybridpb.TestAllTypes_builder{ + OneofNestedMessage: proto.ValueOrDefault(val), + }.Build() + }, + }, + } { + t.Run(tt.desc, func(t *testing.T) { + msg := tt.input() + b, err := proto.Marshal(msg) + if err != nil { + t.Fatal(err) + } + want := []byte{130, 7, 0} + if !bytes.Equal(b, want) { + t.Fatalf("Marshal: got %x, want %x", b, want) + } + if !msg.HasOneofField() { + t.Fatalf("HasOneofField was false, want true") + } + if got, want := msg.WhichOneofField(), testhybridpb.TestAllTypes_OneofNestedMessage_case; got != want { + t.Fatalf("WhichOneofField: got %v, want %v", got, want) + } + if !msg.HasOneofNestedMessage() { + t.Fatalf("HasOneofNestedMessage was false, want true") + } + if msg.HasOneofString() { + t.Fatalf("HasOneofString was true, want false") + } + }) + } +} + +func TestOneofOrDefaultBytes(t *testing.T) { + for _, tt := range []struct { + desc string + input func() *testhybridpb.TestAllTypes + wantWire []byte + }{ + { + desc: "struct literal with nil bytes", + input: func() *testhybridpb.TestAllTypes { + return &testhybridpb.TestAllTypes{ + OneofField: &testhybridpb.TestAllTypes_OneofBytes{ + OneofBytes: nil, + }, + } + }, + }, + + { + desc: "struct literal with non-nil bytes", + input: func() *testhybridpb.TestAllTypes { + return &testhybridpb.TestAllTypes{ + OneofField: &testhybridpb.TestAllTypes_OneofBytes{ + OneofBytes: []byte{}, + }, + } + }, + }, + + { + desc: "opaque setter with ValueOrDefaultBytes", + input: func() *testhybridpb.TestAllTypes { + msg := &testhybridpb.TestAllTypes{} + msg.ClearOneofString() + var val []byte + msg.SetOneofBytes(proto.ValueOrDefaultBytes(val)) + return msg + }, + }, + + { + desc: "opaque setter", + input: func() *testhybridpb.TestAllTypes { + msg := &testhybridpb.TestAllTypes{} + msg.ClearOneofString() + var val []byte + msg.SetOneofBytes(val) + return msg + }, + }, + + { + desc: "opaque builder with ValueOrDefaultBytes", + input: func() *testhybridpb.TestAllTypes { + var val []byte + return testhybridpb.TestAllTypes_builder{ + OneofBytes: proto.ValueOrDefaultBytes(val), + }.Build() + }, + }, + } { + t.Run(tt.desc, func(t *testing.T) { + msg := tt.input() + b, err := proto.Marshal(msg) + if err != nil { + t.Fatal(err) + } + want := []byte{146, 7, 0} + if !bytes.Equal(b, want) { + t.Fatalf("Marshal: got %x, want %x", b, want) + } + if !msg.HasOneofField() { + t.Fatalf("HasOneofField was false, want true") + } + if got, want := msg.WhichOneofField(), testhybridpb.TestAllTypes_OneofBytes_case; got != want { + t.Fatalf("WhichOneofField: got %v, want %v", got, want) + } + if !msg.HasOneofBytes() { + t.Fatalf("HasOneofBytes was false, want true") + } + if msg.HasOneofString() { + t.Fatalf("HasOneofString was true, want false") + } + }) + } +} diff --git a/reflect/protodesc/editions.go b/reflect/protodesc/editions.go index d0aeab958..bf0a0ccde 100644 --- a/reflect/protodesc/editions.go +++ b/reflect/protodesc/editions.go @@ -132,6 +132,9 @@ func mergeEditionFeatures(parentDesc protoreflect.Descriptor, child *descriptorp if sep := goFeatures.StripEnumPrefix; sep != nil { parentFS.StripEnumPrefix = int(*sep) } + if al := goFeatures.ApiLevel; al != nil { + parentFS.APILevel = int(*al) + } } return parentFS diff --git a/runtime/protoiface/methods.go b/runtime/protoiface/methods.go index 246156561..28e9e9f03 100644 --- a/runtime/protoiface/methods.go +++ b/runtime/protoiface/methods.go @@ -122,6 +122,22 @@ type UnmarshalInputFlags = uint8 const ( UnmarshalDiscardUnknown UnmarshalInputFlags = 1 << iota + + // UnmarshalAliasBuffer permits unmarshal operations to alias the input buffer. + // The unmarshaller must not modify the contents of the buffer. + UnmarshalAliasBuffer + + // UnmarshalValidated indicates that validation has already been + // performed on the input buffer. + UnmarshalValidated + + // UnmarshalCheckRequired is set if this unmarshal operation ultimately will care if required fields are + // initialized. + UnmarshalCheckRequired + + // UnmarshalNoLazyDecoding is set if this unmarshal operation should not use + // lazy decoding, even when otherwise available. + UnmarshalNoLazyDecoding ) // UnmarshalOutputFlags are output from the Unmarshal method. diff --git a/runtime/protoimpl/impl.go b/runtime/protoimpl/impl.go index 4a1ab7fb3..93df1b569 100644 --- a/runtime/protoimpl/impl.go +++ b/runtime/protoimpl/impl.go @@ -15,6 +15,7 @@ import ( "google.golang.org/protobuf/internal/filedesc" "google.golang.org/protobuf/internal/filetype" "google.golang.org/protobuf/internal/impl" + "google.golang.org/protobuf/internal/protolazy" ) // UnsafeEnabled specifies whether package unsafe can be used. @@ -39,6 +40,9 @@ type ( ExtensionFieldV1 = impl.ExtensionField Pointer = impl.Pointer + + LazyUnmarshalInfo = *protolazy.XXX_lazyUnmarshalInfo + RaceDetectHookData = impl.RaceDetectHookData ) var X impl.Export diff --git a/runtime/protolazy/protolazy.go b/runtime/protolazy/protolazy.go new file mode 100644 index 000000000..cae330dc5 --- /dev/null +++ b/runtime/protolazy/protolazy.go @@ -0,0 +1,31 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package protolazy controls the lazy implementation in the protobuf runtime. +// +// The following logic determines whether lazy decoding is enabled: +// 1. Lazy decoding is enabled by default, unless the environment variable +// GOPROTODEBUG=nolazy is set. +// 2. If still on, calling protolazy.Disable() turns off lazy decoding. +// 3. If still on, proto.UnmarshalOptions's NoLazyDecoding turns off +// lazy decoding for this Unmarshal operation only. +package protolazy + +import ( + "google.golang.org/protobuf/internal/impl" +) + +// Disable disables lazy unmarshaling of opaque messages. +// +// Messages which are still on the OPEN or HYBRID API level (see +// http://go/go-opaque-proto) are never lazily unmarshalled. +// +// Fields must be annotated with [lazy = true] in their .proto file to become +// eligible for lazy unmarshaling. +func Disable() (reenable func()) { + impl.EnableLazyUnmarshal(false) + return func() { + impl.EnableLazyUnmarshal(true) + } +} diff --git a/src/google/protobuf/go_features.proto b/src/google/protobuf/go_features.proto index 7ab74f5ce..a7e14f85e 100644 --- a/src/google/protobuf/go_features.proto +++ b/src/google/protobuf/go_features.proto @@ -33,6 +33,28 @@ message GoFeatures { edition_defaults = { edition: EDITION_PROTO3, value: "false" } ]; + enum APILevel { + // API_LEVEL_UNSPECIFIED results in selecting the OPEN API, + // but needs to be a separate value to distinguish between + // an explicitly set api level or a missing api level. + API_LEVEL_UNSPECIFIED = 0; + API_OPEN = 1; + API_HYBRID = 2; + API_OPAQUE = 3; + } + + // One of OPEN, HYBRID or OPAQUE. + optional APILevel api_level = 2 [ + retention = RETENTION_RUNTIME, + targets = TARGET_TYPE_MESSAGE, + targets = TARGET_TYPE_FILE, + feature_support = { + edition_introduced: EDITION_2023, + }, + edition_defaults = { edition: EDITION_LEGACY, value: "API_LEVEL_UNSPECIFIED" }, + edition_defaults = { edition: EDITION_2024, value: "API_OPAQUE" } + ]; + enum StripEnumPrefix { STRIP_ENUM_PREFIX_UNSPECIFIED = 0; STRIP_ENUM_PREFIX_KEEP = 1; diff --git a/testing/prototest/message.go b/testing/prototest/message.go index eaf53cfe4..def37bff3 100644 --- a/testing/prototest/message.go +++ b/testing/prototest/message.go @@ -33,6 +33,10 @@ type Message struct { FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) } + + // UnmarshalOptions are respected for every Unmarshal call this package + // does. The Resolver and AllowPartial fields are overridden. + UnmarshalOptions proto.UnmarshalOptions } // Test performs tests on a [protoreflect.MessageType] implementation. @@ -74,10 +78,10 @@ func (test Message) Test(t testing.TB, mt protoreflect.MessageType) { t.Errorf("Marshal() = %v, want nil\n%v", err, prototext.Format(m2)) } m3 := mt.New().Interface() - if err := (proto.UnmarshalOptions{ - AllowPartial: true, - Resolver: test.Resolver, - }.Unmarshal(b, m3)); err != nil { + unmarshalOpts := test.UnmarshalOptions + unmarshalOpts.AllowPartial = true + unmarshalOpts.Resolver = test.Resolver + if err := unmarshalOpts.Unmarshal(b, m3); err != nil { t.Errorf("Unmarshal() = %v, want nil\n%v", err, prototext.Format(m2)) } if !proto.Equal(m2, m3) { diff --git a/types/gofeaturespb/go_features.pb.go b/types/gofeaturespb/go_features.pb.go index 5067b89e9..61e3f7664 100644 --- a/types/gofeaturespb/go_features.pb.go +++ b/types/gofeaturespb/go_features.pb.go @@ -18,6 +18,71 @@ import ( sync "sync" ) +type GoFeatures_APILevel int32 + +const ( + // API_LEVEL_UNSPECIFIED results in selecting the OPEN API, + // but needs to be a separate value to distinguish between + // an explicitly set api level or a missing api level. + GoFeatures_API_LEVEL_UNSPECIFIED GoFeatures_APILevel = 0 + GoFeatures_API_OPEN GoFeatures_APILevel = 1 + GoFeatures_API_HYBRID GoFeatures_APILevel = 2 + GoFeatures_API_OPAQUE GoFeatures_APILevel = 3 +) + +// Enum value maps for GoFeatures_APILevel. +var ( + GoFeatures_APILevel_name = map[int32]string{ + 0: "API_LEVEL_UNSPECIFIED", + 1: "API_OPEN", + 2: "API_HYBRID", + 3: "API_OPAQUE", + } + GoFeatures_APILevel_value = map[string]int32{ + "API_LEVEL_UNSPECIFIED": 0, + "API_OPEN": 1, + "API_HYBRID": 2, + "API_OPAQUE": 3, + } +) + +func (x GoFeatures_APILevel) Enum() *GoFeatures_APILevel { + p := new(GoFeatures_APILevel) + *p = x + return p +} + +func (x GoFeatures_APILevel) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (GoFeatures_APILevel) Descriptor() protoreflect.EnumDescriptor { + return file_google_protobuf_go_features_proto_enumTypes[0].Descriptor() +} + +func (GoFeatures_APILevel) Type() protoreflect.EnumType { + return &file_google_protobuf_go_features_proto_enumTypes[0] +} + +func (x GoFeatures_APILevel) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Do not use. +func (x *GoFeatures_APILevel) UnmarshalJSON(b []byte) error { + num, err := protoimpl.X.UnmarshalJSONEnum(x.Descriptor(), b) + if err != nil { + return err + } + *x = GoFeatures_APILevel(num) + return nil +} + +// Deprecated: Use GoFeatures_APILevel.Descriptor instead. +func (GoFeatures_APILevel) EnumDescriptor() ([]byte, []int) { + return file_google_protobuf_go_features_proto_rawDescGZIP(), []int{0, 0} +} + type GoFeatures_StripEnumPrefix int32 const ( @@ -54,11 +119,11 @@ func (x GoFeatures_StripEnumPrefix) String() string { } func (GoFeatures_StripEnumPrefix) Descriptor() protoreflect.EnumDescriptor { - return file_google_protobuf_go_features_proto_enumTypes[0].Descriptor() + return file_google_protobuf_go_features_proto_enumTypes[1].Descriptor() } func (GoFeatures_StripEnumPrefix) Type() protoreflect.EnumType { - return &file_google_protobuf_go_features_proto_enumTypes[0] + return &file_google_protobuf_go_features_proto_enumTypes[1] } func (x GoFeatures_StripEnumPrefix) Number() protoreflect.EnumNumber { @@ -77,17 +142,18 @@ func (x *GoFeatures_StripEnumPrefix) UnmarshalJSON(b []byte) error { // Deprecated: Use GoFeatures_StripEnumPrefix.Descriptor instead. func (GoFeatures_StripEnumPrefix) EnumDescriptor() ([]byte, []int) { - return file_google_protobuf_go_features_proto_rawDescGZIP(), []int{0, 0} + return file_google_protobuf_go_features_proto_rawDescGZIP(), []int{0, 1} } type GoFeatures struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - + state protoimpl.MessageState `protogen:"open.v1"` // Whether or not to generate the deprecated UnmarshalJSON method for enums. - LegacyUnmarshalJsonEnum *bool `protobuf:"varint,1,opt,name=legacy_unmarshal_json_enum,json=legacyUnmarshalJsonEnum" json:"legacy_unmarshal_json_enum,omitempty"` - StripEnumPrefix *GoFeatures_StripEnumPrefix `protobuf:"varint,3,opt,name=strip_enum_prefix,json=stripEnumPrefix,enum=pb.GoFeatures_StripEnumPrefix" json:"strip_enum_prefix,omitempty"` + LegacyUnmarshalJsonEnum *bool `protobuf:"varint,1,opt,name=legacy_unmarshal_json_enum,json=legacyUnmarshalJsonEnum" json:"legacy_unmarshal_json_enum,omitempty"` + // One of OPEN, HYBRID or OPAQUE. + ApiLevel *GoFeatures_APILevel `protobuf:"varint,2,opt,name=api_level,json=apiLevel,enum=pb.GoFeatures_APILevel" json:"api_level,omitempty"` + StripEnumPrefix *GoFeatures_StripEnumPrefix `protobuf:"varint,3,opt,name=strip_enum_prefix,json=stripEnumPrefix,enum=pb.GoFeatures_StripEnumPrefix" json:"strip_enum_prefix,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *GoFeatures) Reset() { @@ -127,6 +193,13 @@ func (x *GoFeatures) GetLegacyUnmarshalJsonEnum() bool { return false } +func (x *GoFeatures) GetApiLevel() GoFeatures_APILevel { + if x != nil && x.ApiLevel != nil { + return *x.ApiLevel + } + return GoFeatures_API_LEVEL_UNSPECIFIED +} + func (x *GoFeatures) GetStripEnumPrefix() GoFeatures_StripEnumPrefix { if x != nil && x.StripEnumPrefix != nil { return *x.StripEnumPrefix @@ -158,7 +231,7 @@ var file_google_protobuf_go_features_proto_rawDesc = []byte{ 0x66, 0x2f, 0x67, 0x6f, 0x5f, 0x66, 0x65, 0x61, 0x74, 0x75, 0x72, 0x65, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x02, 0x70, 0x62, 0x1a, 0x20, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, - 0x74, 0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xe0, 0x03, 0x0a, 0x0a, 0x47, 0x6f, + 0x74, 0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xab, 0x05, 0x0a, 0x0a, 0x47, 0x6f, 0x46, 0x65, 0x61, 0x74, 0x75, 0x72, 0x65, 0x73, 0x12, 0xbe, 0x01, 0x0a, 0x1a, 0x6c, 0x65, 0x67, 0x61, 0x63, 0x79, 0x5f, 0x75, 0x6e, 0x6d, 0x61, 0x72, 0x73, 0x68, 0x61, 0x6c, 0x5f, 0x6a, 0x73, 0x6f, 0x6e, 0x5f, 0x65, 0x6e, 0x75, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x42, 0x80, 0x01, @@ -171,31 +244,44 @@ var file_google_protobuf_go_features_proto_rawDesc = []byte{ 0x20, 0x62, 0x65, 0x20, 0x72, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x64, 0x20, 0x69, 0x6e, 0x20, 0x61, 0x20, 0x66, 0x75, 0x74, 0x75, 0x72, 0x65, 0x20, 0x65, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x52, 0x17, 0x6c, 0x65, 0x67, 0x61, 0x63, 0x79, 0x55, 0x6e, 0x6d, 0x61, 0x72, 0x73, 0x68, 0x61, - 0x6c, 0x4a, 0x73, 0x6f, 0x6e, 0x45, 0x6e, 0x75, 0x6d, 0x12, 0x7c, 0x0a, 0x11, 0x73, 0x74, 0x72, - 0x69, 0x70, 0x5f, 0x65, 0x6e, 0x75, 0x6d, 0x5f, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1e, 0x2e, 0x70, 0x62, 0x2e, 0x47, 0x6f, 0x46, 0x65, 0x61, 0x74, - 0x75, 0x72, 0x65, 0x73, 0x2e, 0x53, 0x74, 0x72, 0x69, 0x70, 0x45, 0x6e, 0x75, 0x6d, 0x50, 0x72, - 0x65, 0x66, 0x69, 0x78, 0x42, 0x30, 0x88, 0x01, 0x01, 0x98, 0x01, 0x06, 0x98, 0x01, 0x07, 0x98, - 0x01, 0x01, 0xa2, 0x01, 0x1b, 0x12, 0x16, 0x53, 0x54, 0x52, 0x49, 0x50, 0x5f, 0x45, 0x4e, 0x55, - 0x4d, 0x5f, 0x50, 0x52, 0x45, 0x46, 0x49, 0x58, 0x5f, 0x4b, 0x45, 0x45, 0x50, 0x18, 0x84, 0x07, - 0xb2, 0x01, 0x03, 0x08, 0xe9, 0x07, 0x52, 0x0f, 0x73, 0x74, 0x72, 0x69, 0x70, 0x45, 0x6e, 0x75, - 0x6d, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x22, 0x92, 0x01, 0x0a, 0x0f, 0x53, 0x74, 0x72, 0x69, - 0x70, 0x45, 0x6e, 0x75, 0x6d, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x12, 0x21, 0x0a, 0x1d, 0x53, - 0x54, 0x52, 0x49, 0x50, 0x5f, 0x45, 0x4e, 0x55, 0x4d, 0x5f, 0x50, 0x52, 0x45, 0x46, 0x49, 0x58, - 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x1a, - 0x0a, 0x16, 0x53, 0x54, 0x52, 0x49, 0x50, 0x5f, 0x45, 0x4e, 0x55, 0x4d, 0x5f, 0x50, 0x52, 0x45, - 0x46, 0x49, 0x58, 0x5f, 0x4b, 0x45, 0x45, 0x50, 0x10, 0x01, 0x12, 0x23, 0x0a, 0x1f, 0x53, 0x54, + 0x6c, 0x4a, 0x73, 0x6f, 0x6e, 0x45, 0x6e, 0x75, 0x6d, 0x12, 0x74, 0x0a, 0x09, 0x61, 0x70, 0x69, + 0x5f, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x17, 0x2e, 0x70, + 0x62, 0x2e, 0x47, 0x6f, 0x46, 0x65, 0x61, 0x74, 0x75, 0x72, 0x65, 0x73, 0x2e, 0x41, 0x50, 0x49, + 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x42, 0x3e, 0x88, 0x01, 0x01, 0x98, 0x01, 0x03, 0x98, 0x01, 0x01, + 0xa2, 0x01, 0x1a, 0x12, 0x15, 0x41, 0x50, 0x49, 0x5f, 0x4c, 0x45, 0x56, 0x45, 0x4c, 0x5f, 0x55, + 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x18, 0x84, 0x07, 0xa2, 0x01, 0x0f, + 0x12, 0x0a, 0x41, 0x50, 0x49, 0x5f, 0x4f, 0x50, 0x41, 0x51, 0x55, 0x45, 0x18, 0xe9, 0x07, 0xb2, + 0x01, 0x03, 0x08, 0xe8, 0x07, 0x52, 0x08, 0x61, 0x70, 0x69, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, + 0x7c, 0x0a, 0x11, 0x73, 0x74, 0x72, 0x69, 0x70, 0x5f, 0x65, 0x6e, 0x75, 0x6d, 0x5f, 0x70, 0x72, + 0x65, 0x66, 0x69, 0x78, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1e, 0x2e, 0x70, 0x62, 0x2e, + 0x47, 0x6f, 0x46, 0x65, 0x61, 0x74, 0x75, 0x72, 0x65, 0x73, 0x2e, 0x53, 0x74, 0x72, 0x69, 0x70, + 0x45, 0x6e, 0x75, 0x6d, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x42, 0x30, 0x88, 0x01, 0x01, 0x98, + 0x01, 0x06, 0x98, 0x01, 0x07, 0x98, 0x01, 0x01, 0xa2, 0x01, 0x1b, 0x12, 0x16, 0x53, 0x54, 0x52, + 0x49, 0x50, 0x5f, 0x45, 0x4e, 0x55, 0x4d, 0x5f, 0x50, 0x52, 0x45, 0x46, 0x49, 0x58, 0x5f, 0x4b, + 0x45, 0x45, 0x50, 0x18, 0x84, 0x07, 0xb2, 0x01, 0x03, 0x08, 0xe9, 0x07, 0x52, 0x0f, 0x73, 0x74, + 0x72, 0x69, 0x70, 0x45, 0x6e, 0x75, 0x6d, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x22, 0x53, 0x0a, + 0x08, 0x41, 0x50, 0x49, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x19, 0x0a, 0x15, 0x41, 0x50, 0x49, + 0x5f, 0x4c, 0x45, 0x56, 0x45, 0x4c, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, + 0x45, 0x44, 0x10, 0x00, 0x12, 0x0c, 0x0a, 0x08, 0x41, 0x50, 0x49, 0x5f, 0x4f, 0x50, 0x45, 0x4e, + 0x10, 0x01, 0x12, 0x0e, 0x0a, 0x0a, 0x41, 0x50, 0x49, 0x5f, 0x48, 0x59, 0x42, 0x52, 0x49, 0x44, + 0x10, 0x02, 0x12, 0x0e, 0x0a, 0x0a, 0x41, 0x50, 0x49, 0x5f, 0x4f, 0x50, 0x41, 0x51, 0x55, 0x45, + 0x10, 0x03, 0x22, 0x92, 0x01, 0x0a, 0x0f, 0x53, 0x74, 0x72, 0x69, 0x70, 0x45, 0x6e, 0x75, 0x6d, + 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x12, 0x21, 0x0a, 0x1d, 0x53, 0x54, 0x52, 0x49, 0x50, 0x5f, + 0x45, 0x4e, 0x55, 0x4d, 0x5f, 0x50, 0x52, 0x45, 0x46, 0x49, 0x58, 0x5f, 0x55, 0x4e, 0x53, 0x50, + 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x1a, 0x0a, 0x16, 0x53, 0x54, 0x52, + 0x49, 0x50, 0x5f, 0x45, 0x4e, 0x55, 0x4d, 0x5f, 0x50, 0x52, 0x45, 0x46, 0x49, 0x58, 0x5f, 0x4b, + 0x45, 0x45, 0x50, 0x10, 0x01, 0x12, 0x23, 0x0a, 0x1f, 0x53, 0x54, 0x52, 0x49, 0x50, 0x5f, 0x45, + 0x4e, 0x55, 0x4d, 0x5f, 0x50, 0x52, 0x45, 0x46, 0x49, 0x58, 0x5f, 0x47, 0x45, 0x4e, 0x45, 0x52, + 0x41, 0x54, 0x45, 0x5f, 0x42, 0x4f, 0x54, 0x48, 0x10, 0x02, 0x12, 0x1b, 0x0a, 0x17, 0x53, 0x54, 0x52, 0x49, 0x50, 0x5f, 0x45, 0x4e, 0x55, 0x4d, 0x5f, 0x50, 0x52, 0x45, 0x46, 0x49, 0x58, 0x5f, - 0x47, 0x45, 0x4e, 0x45, 0x52, 0x41, 0x54, 0x45, 0x5f, 0x42, 0x4f, 0x54, 0x48, 0x10, 0x02, 0x12, - 0x1b, 0x0a, 0x17, 0x53, 0x54, 0x52, 0x49, 0x50, 0x5f, 0x45, 0x4e, 0x55, 0x4d, 0x5f, 0x50, 0x52, - 0x45, 0x46, 0x49, 0x58, 0x5f, 0x53, 0x54, 0x52, 0x49, 0x50, 0x10, 0x03, 0x3a, 0x3c, 0x0a, 0x02, - 0x67, 0x6f, 0x12, 0x1b, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x65, 0x61, 0x74, 0x75, 0x72, 0x65, 0x53, 0x65, 0x74, 0x18, - 0xea, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x70, 0x62, 0x2e, 0x47, 0x6f, 0x46, 0x65, - 0x61, 0x74, 0x75, 0x72, 0x65, 0x73, 0x52, 0x02, 0x67, 0x6f, 0x42, 0x2f, 0x5a, 0x2d, 0x67, 0x6f, - 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2f, 0x67, - 0x6f, 0x66, 0x65, 0x61, 0x74, 0x75, 0x72, 0x65, 0x73, 0x70, 0x62, + 0x53, 0x54, 0x52, 0x49, 0x50, 0x10, 0x03, 0x3a, 0x3c, 0x0a, 0x02, 0x67, 0x6f, 0x12, 0x1b, 0x2e, + 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, + 0x46, 0x65, 0x61, 0x74, 0x75, 0x72, 0x65, 0x53, 0x65, 0x74, 0x18, 0xea, 0x07, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x0e, 0x2e, 0x70, 0x62, 0x2e, 0x47, 0x6f, 0x46, 0x65, 0x61, 0x74, 0x75, 0x72, 0x65, + 0x73, 0x52, 0x02, 0x67, 0x6f, 0x42, 0x2f, 0x5a, 0x2d, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, + 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x75, 0x66, 0x2f, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2f, 0x67, 0x6f, 0x66, 0x65, 0x61, 0x74, + 0x75, 0x72, 0x65, 0x73, 0x70, 0x62, } var ( @@ -210,22 +296,24 @@ func file_google_protobuf_go_features_proto_rawDescGZIP() []byte { return file_google_protobuf_go_features_proto_rawDescData } -var file_google_protobuf_go_features_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_google_protobuf_go_features_proto_enumTypes = make([]protoimpl.EnumInfo, 2) var file_google_protobuf_go_features_proto_msgTypes = make([]protoimpl.MessageInfo, 1) var file_google_protobuf_go_features_proto_goTypes = []any{ - (GoFeatures_StripEnumPrefix)(0), // 0: pb.GoFeatures.StripEnumPrefix - (*GoFeatures)(nil), // 1: pb.GoFeatures - (*descriptorpb.FeatureSet)(nil), // 2: google.protobuf.FeatureSet + (GoFeatures_APILevel)(0), // 0: pb.GoFeatures.APILevel + (GoFeatures_StripEnumPrefix)(0), // 1: pb.GoFeatures.StripEnumPrefix + (*GoFeatures)(nil), // 2: pb.GoFeatures + (*descriptorpb.FeatureSet)(nil), // 3: google.protobuf.FeatureSet } var file_google_protobuf_go_features_proto_depIdxs = []int32{ - 0, // 0: pb.GoFeatures.strip_enum_prefix:type_name -> pb.GoFeatures.StripEnumPrefix - 2, // 1: pb.go:extendee -> google.protobuf.FeatureSet - 1, // 2: pb.go:type_name -> pb.GoFeatures - 3, // [3:3] is the sub-list for method output_type - 3, // [3:3] is the sub-list for method input_type - 2, // [2:3] is the sub-list for extension type_name - 1, // [1:2] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name + 0, // 0: pb.GoFeatures.api_level:type_name -> pb.GoFeatures.APILevel + 1, // 1: pb.GoFeatures.strip_enum_prefix:type_name -> pb.GoFeatures.StripEnumPrefix + 3, // 2: pb.go:extendee -> google.protobuf.FeatureSet + 2, // 3: pb.go:type_name -> pb.GoFeatures + 4, // [4:4] is the sub-list for method output_type + 4, // [4:4] is the sub-list for method input_type + 3, // [3:4] is the sub-list for extension type_name + 2, // [2:3] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name } func init() { file_google_protobuf_go_features_proto_init() } @@ -238,7 +326,7 @@ func file_google_protobuf_go_features_proto_init() { File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_google_protobuf_go_features_proto_rawDesc, - NumEnums: 1, + NumEnums: 2, NumMessages: 1, NumExtensions: 1, NumServices: 0,