Skip to content

Commit

Permalink
Adds UnsafeCollectionOperations for unsafe access to RepeatedField<T> (
Browse files Browse the repository at this point in the history
…#16772)

This is a proposal to add `UnsafeCollectionOperations` for fast access on `RepeatedField<T>`
  #16745

Closes #16772

COPYBARA_INTEGRATE_REVIEW=#16772 from PascalSenn:pse/readonlyspan-proposal-repeated-field dcb862a
PiperOrigin-RevId: 702356972
  • Loading branch information
PascalSenn authored and copybara-github committed Dec 3, 2024
1 parent b2acbd3 commit a1b0088
Show file tree
Hide file tree
Showing 3 changed files with 350 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
#region Copyright notice and license

// Protocol Buffers - Google's data interchange format
// Copyright 2015 Google Inc. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd

#endregion

using System;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using NUnit.Framework;

namespace Google.Protobuf.Collections;

/// <summary>
/// Tests for <see cref="UnsafeCollectionOperations"/>: span access to a
/// <see cref="RepeatedField{T}"/> and direct manipulation of its count.
/// </summary>
public class UnsafeCollectionOperationsTest
{
    [Test]
    public void NullFieldAsSpanValueType()
    {
        RepeatedField<int> field = null;
        Assert.Throws<ArgumentNullException>(() => UnsafeCollectionOperations.AsSpan(field));
    }

    [Test]
    public void NullFieldAsSpanClass()
    {
        RepeatedField<object> field = null;
        Assert.Throws<ArgumentNullException>(() => UnsafeCollectionOperations.AsSpan(field));
    }

    [Test]
    public void FieldAsSpanValueType()
    {
        var field = new RepeatedField<int>();

        // Lengths 0..35 - presumably chosen so the field grows through several backing-array
        // sizes while being reused across iterations.
        foreach (var length in Enumerable.Range(0, 36))
        {
            // An emptied field must produce an empty span.
            field.Clear();
            ValidateContentEquality(field, UnsafeCollectionOperations.AsSpan(field));

            for (var i = 0; i < length; i++)
            {
                field.Add(i);
            }

            ValidateContentEquality(field, UnsafeCollectionOperations.AsSpan(field));

            // One more element: a freshly obtained span must reflect the new count.
            field.Add(length + 1);
            ValidateContentEquality(field, UnsafeCollectionOperations.AsSpan(field));
        }

        static void ValidateContentEquality(RepeatedField<int> field, Span<int> span)
        {
            Assert.AreEqual(field.Count, span.Length);

            for (var i = 0; i < span.Length; i++)
            {
                Assert.AreEqual(field[i], span[i]);
            }
        }
    }

    [Test]
    public void FieldAsSpanClass()
    {
        var field = new RepeatedField<IntAsObject>();

        // Same scenario as FieldAsSpanValueType, but with a reference-type element.
        foreach (var length in Enumerable.Range(0, 36))
        {
            field.Clear();
            ValidateContentEquality(field, UnsafeCollectionOperations.AsSpan(field));

            for (var i = 0; i < length; i++)
            {
                field.Add(new IntAsObject { Value = i });
            }

            ValidateContentEquality(field, UnsafeCollectionOperations.AsSpan(field));

            field.Add(new IntAsObject { Value = length + 1 });
            ValidateContentEquality(field, UnsafeCollectionOperations.AsSpan(field));
        }

        static void ValidateContentEquality(
            RepeatedField<IntAsObject> field,
            Span<IntAsObject> span)
        {
            Assert.AreEqual(field.Count, span.Length);

            for (var i = 0; i < span.Length; i++)
            {
                Assert.AreEqual(field[i].Value, span[i].Value);
            }
        }
    }

    [Test]
    public void FieldAsSpanLinkBreaksOnResize()
    {
        var field = new RepeatedField<int>();

        // Fill the field exactly to its capacity so that the next Add has to reallocate.
        // Filling until Count == Capacity keeps the test independent of the initial array size.
        field.Add(0);
        while (field.Count < field.Capacity)
        {
            field.Add(field.Count);
        }

        var span = UnsafeCollectionOperations.AsSpan(field);

        var startCapacity = field.Capacity;
        var startCount = field.Count;
        Assert.AreEqual(startCount, startCapacity);
        Assert.AreEqual(startCount, span.Length);

        // While no resize has happened, span and field observe each other's writes.
        for (var i = 0; i < span.Length; i++)
        {
            span[i]++;
            Assert.AreEqual(field[i], span[i]);

            field[i]++;
            Assert.AreEqual(field[i], span[i]);
        }

        // Resize to break link between Span and RepeatedField
        field.Add(11);

        Assert.AreNotEqual(startCapacity, field.Capacity);
        Assert.AreNotEqual(startCount, field.Count);
        Assert.AreEqual(startCount, span.Length);

        // The span still refers to the old array, so writes on either side are no longer shared.
        for (var i = 0; i < span.Length; i++)
        {
            span[i] += 2;
            Assert.AreNotEqual(field[i], span[i]);

            field[i] += 3;
            Assert.AreNotEqual(field[i], span[i]);
        }
    }

    [Test]
    public void FieldSetCount()
    {
        // Argument validation: null field and negative count.
        RepeatedField<int> field = null;
        Assert.Throws<ArgumentNullException>(() => UnsafeCollectionOperations.SetCount(field, 3));

        field = new RepeatedField<int>();
        Assert.Throws<ArgumentOutOfRangeException>(
            () => UnsafeCollectionOperations.SetCount(field, -1));

        // Growing an empty field.
        UnsafeCollectionOperations.SetCount(field, 5);
        Assert.AreEqual(5, field.Count);

        field = new RepeatedField<int> { 1, 2, 3, 4, 5 };
        ref var intRef = ref MemoryMarshal.GetReference(UnsafeCollectionOperations.AsSpan(field));

        // Shrinking keeps the leading elements and keeps the same backing storage.
        UnsafeCollectionOperations.SetCount(field, 3);
        Assert.AreEqual(3, field.Count);
        Assert.Throws<ArgumentOutOfRangeException>(() => field[3] = 42);
        var span = UnsafeCollectionOperations.AsSpan(field);
        SequenceEqual(span, new[] { 1, 2, 3 });
        Assert.True(Unsafe.AreSame(ref intRef, ref MemoryMarshal.GetReference(span)));

        // Growing back within capacity keeps the leading elements and the same backing storage.
        UnsafeCollectionOperations.SetCount(field, 5);
        span = UnsafeCollectionOperations.AsSpan(field);

        // When NET5_0_OR_GREATER is defined, shrinking only clears slots whose element type is or
        // contains references, so the trimmed ints reappear. Other targets always clear them.
#if NET5_0_OR_GREATER
        var expected = new[] { 1, 2, 3, 4, 5 };
#else
        var expected = new[] { 1, 2, 3, 0, 0 };
#endif
        SequenceEqual(span, expected);
        Assert.True(Unsafe.AreSame(ref intRef, ref MemoryMarshal.GetReference(span)));

        // Growing beyond capacity reallocates but preserves the existing content.
        var newCount = field.Capacity * 2;
        UnsafeCollectionOperations.SetCount(field, newCount);
        Assert.AreEqual(newCount, field.Count);
        span = UnsafeCollectionOperations.AsSpan(field);
        SequenceEqual(span.Slice(0, 3), new[] { 1, 2, 3 });
        Assert.False(Unsafe.AreSame(ref intRef, ref MemoryMarshal.GetReference(span)));

        RepeatedField<string> listReference = new() { "a", "b", "c", "d", "e" };
        var listSpan = UnsafeCollectionOperations.AsSpan(listReference);
        ref var stringRef = ref MemoryMarshal.GetReference(listSpan);
        UnsafeCollectionOperations.SetCount(listReference, 3);

        // Shrinking a reference-type field keeps the surviving elements and the same storage.
        listSpan = UnsafeCollectionOperations.AsSpan(listReference);
        SequenceEqual(listSpan, new[] { "a", "b", "c" });
        Assert.True(Unsafe.AreSame(ref stringRef, ref MemoryMarshal.GetReference(listSpan)));
        UnsafeCollectionOperations.SetCount(listReference, 5);

        // Growing back shows that the trimmed reference-type slots were cleared to null.
        listSpan = UnsafeCollectionOperations.AsSpan(listReference);
        SequenceEqual(listSpan, new[] { "a", "b", "c", null, null });
        Assert.True(Unsafe.AreSame(ref stringRef, ref MemoryMarshal.GetReference(listSpan)));
    }

    // Asserts that both spans have the same length and pairwise-equal elements.
    private static void SequenceEqual<T>(Span<T> span, Span<T> expected)
    {
        Assert.AreEqual(expected.Length, span.Length);
        for (var i = 0; i < expected.Length; i++)
        {
            Assert.AreEqual(expected[i], span[i]);
        }
    }

    // Minimal reference-type element used to exercise RepeatedField<T> with a class T.
    private class IntAsObject
    {
        public int Value;
    }
}
36 changes: 36 additions & 0 deletions csharp/src/Google.Protobuf/Collections/RepeatedField.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
using System.IO;
using System.Linq;
using System.Security;
#if NET5_0_OR_GREATER
using System.Runtime.CompilerServices;
#endif

namespace Google.Protobuf.Collections
{
Expand Down Expand Up @@ -643,6 +646,39 @@ public T this[int index]
}
}

/// <summary>
/// Returns a span over the first <c>count</c> elements of the backing array.
/// </summary>
[SecuritySafeCritical]
internal Span<T> AsSpan()
{
    return new Span<T>(array, 0, count);
}

/// <summary>
/// Sets the element count to <paramref name="targetCount"/>.
/// </summary>
/// <param name="targetCount">The new count. Must be non-negative.</param>
/// <exception cref="ArgumentOutOfRangeException">
/// Thrown if <paramref name="targetCount"/> is negative.
/// </exception>
internal void SetCount(int targetCount)
{
    if (targetCount < 0)
    {
        throw new ArgumentOutOfRangeException(
            nameof(targetCount), targetCount, "Non-negative number required.");
    }

    if (targetCount > Capacity)
    {
        // Requested count does not fit the current capacity.
        EnsureSize(targetCount);
    }
    else if (targetCount < count)
    {
#if NET5_0_OR_GREATER
        // Only reference types need to be cleared to allow GC to collect them.
        bool clearTrimmedSlots = RuntimeHelpers.IsReferenceOrContainsReferences<T>();
#else
        bool clearTrimmedSlots = true;
#endif
        if (clearTrimmedSlots)
        {
            int trimmed = count - targetCount;
            Array.Clear(array, targetCount, trimmed);
        }
    }

    count = targetCount;
}

#region Explicit interface implementation for IList and ICollection.
bool IList.IsFixedSize => false;

Expand Down
96 changes: 96 additions & 0 deletions csharp/src/Google.Protobuf/UnsafeCollectionOperations.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#region Copyright notice and license

// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd

#endregion

using System;
using System.Security;
using Google.Protobuf.Collections;

namespace Google.Protobuf;

/// <summary>
/// Unsafe helpers that expose the underlying data representation of protobuf collections.
/// </summary>
[SecuritySafeCritical]
public static class UnsafeCollectionOperations
{
    /// <summary>
    /// <para>
    /// Wraps the current backing array of <paramref name="field"/> in a <see cref="Span{T}"/>.
    /// </para>
    /// <para>
    /// Do not store null into the returned <see cref="Span{T}"/>. To remove items, call
    /// <see cref="RepeatedField{T}.Remove(T)"/> or <see cref="RepeatedField{T}.RemoveAt(int)"/>.
    /// </para>
    /// <para>
    /// The <see cref="Span{T}"/> stays valid only until the size of the
    /// <see cref="RepeatedField{T}"/> changes; after that its state is undefined. Overwriting
    /// existing elements without changing the size is safe, provided no null values are written.
    /// </para>
    /// </summary>
    /// <typeparam name="T">
    /// The element type of the <see cref="RepeatedField{T}"/>.
    /// </typeparam>
    /// <param name="field">
    /// The <see cref="RepeatedField{T}"/> whose current backing array is wrapped. Must not be
    /// null.
    /// </param>
    /// <returns>
    /// A <see cref="Span{T}"/> over the current backing array of the
    /// <see cref="RepeatedField{T}"/>.
    /// </returns>
    /// <exception cref="ArgumentNullException">
    /// Thrown if <paramref name="field"/> is <see langword="null"/>.
    /// </exception>
    public static Span<T> AsSpan<T>(RepeatedField<T> field) =>
        ProtoPreconditions.CheckNotNull(field, nameof(field)).AsSpan();

    /// <summary>
    /// <para>
    /// Sets the count of <paramref name="field"/> to <paramref name="count"/>.
    /// </para>
    /// <para>
    /// Call this only when the code that follows guarantees to populate the field with the
    /// specified number of items.
    /// </para>
    /// <para>
    /// When <paramref name="count"/> is smaller than <see cref="RepeatedField{T}.Count"/>, the
    /// collection is effectively trimmed to its first <paramref name="count"/> elements.
    /// <see cref="RepeatedField{T}.Capacity"/> does not change, so the underlying array remains
    /// allocated.
    /// </para>
    /// </summary>
    /// <typeparam name="T">
    /// The element type of the <see cref="RepeatedField{T}"/>.
    /// </typeparam>
    /// <param name="field">
    /// The field whose count is set. Must not be null.
    /// </param>
    /// <param name="count">
    /// The new count for the field. Must be non-negative.
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// Thrown if <paramref name="field"/> is <see langword="null"/>.
    /// </exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// Thrown if <paramref name="count"/> is negative.
    /// </exception>
    public static void SetCount<T>(RepeatedField<T> field, int count) =>
        ProtoPreconditions.CheckNotNull(field, nameof(field)).SetCount(count);
}

0 comments on commit a1b0088

Please sign in to comment.