diff --git a/Build.csproj b/Build.csproj index 3e16e801c..41fb15b0c 100644 --- a/Build.csproj +++ b/Build.csproj @@ -1,5 +1,6 @@ + diff --git a/Directory.Build.props b/Directory.Build.props index 42de5875c..988cea9b6 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -1,5 +1,8 @@ + + false + 2.0.0 2014 - $([System.DateTime]::Now.Year) Stack Exchange, Inc. true @@ -26,7 +29,15 @@ true false true + true + false + true 00240000048000009400000006020000002400005253413100040000010001007791a689e9d8950b44a9a8886baad2ea180e7a8a854f158c9b98345ca5009cdd2362c84f368f1c3658c132b3c0f74e44ff16aeb2e5b353b6e0fe02f923a050470caeac2bde47a2238a9c7125ed7dab14f486a5a64558df96640933b9f2b6db188fc4a820f96dce963b662fa8864adbff38e5b4542343f162ecdc6dad16912fff + LatestMajor + + + preview + $(DefineConstants);PREVIEW_LANGVER true diff --git a/Directory.Packages.props b/Directory.Packages.props index df8c078a3..fd9d139a0 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -7,11 +7,16 @@ + + + + + - + @@ -25,7 +30,6 @@ - diff --git a/StackExchange.Redis.sln b/StackExchange.Redis.sln index 2ed4ebfb3..86f751e25 100644 --- a/StackExchange.Redis.sln +++ b/StackExchange.Redis.sln @@ -122,9 +122,21 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "docs", "docs\docs.csproj", EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "StackExchange.Redis.Benchmarks", "tests\StackExchange.Redis.Benchmarks\StackExchange.Redis.Benchmarks.csproj", "{59889284-FFEE-82E7-94CB-3B43E87DA6CF}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "eng", "eng", "{5FA0958E-6EBD-45F4-808E-3447A293F96F}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RESP.Core", "src\RESP.Core\RESP.Core.csproj", "{E50EEB8B-6B3F-4C8C-A5C6-C37FB87C01E2}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "StackExchange.Redis.Build", "eng\StackExchange.Redis.Build\StackExchange.Redis.Build.csproj", "{190742E1-FA50-4E36-A8C4-88AE87654340}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RESPite.Tests", "tests\RESPite.Tests\RESPite.Tests.csproj", "{7063E2D3-C591-4604-A5DD-32D4A1678A58}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "eng", "eng", "{C0132984-68D1-4A97-8F8C-AD4E2EECC583}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "StackExchange.Redis.Build", "eng\StackExchange.Redis.Build\StackExchange.Redis.Build.csproj", "{B0055B76-4685-4ECF-A904-88EE4E6FC8F0}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RESPite", "src\RESPite\RESPite.csproj", "{F8762EE5-3461-4F6B-8C24-C876B6D9E637}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RESPite.Redis", "src\RESPite.Redis\RESPite.Redis.csproj", "{3A92C2E7-3033-4FDF-8DDC-5DF43D290537}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RESPite.StackExchange.Redis", "src\RESPite.StackExchange.Redis\RESPite.StackExchange.Redis.csproj", "{A5580114-C236-494E-851C-A21E3DB86FC8}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RESPite.Benchmark", "src\RESPite.Benchmark\RESPite.Benchmark.csproj", "{3725A78B-B6B5-4379-9DE0-37A180ADE95A}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -184,10 +196,34 @@ Global {59889284-FFEE-82E7-94CB-3B43E87DA6CF}.Debug|Any CPU.Build.0 = Debug|Any CPU {59889284-FFEE-82E7-94CB-3B43E87DA6CF}.Release|Any CPU.ActiveCfg = Release|Any CPU {59889284-FFEE-82E7-94CB-3B43E87DA6CF}.Release|Any CPU.Build.0 = Release|Any CPU - {190742E1-FA50-4E36-A8C4-88AE87654340}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {190742E1-FA50-4E36-A8C4-88AE87654340}.Debug|Any CPU.Build.0 = Debug|Any CPU - {190742E1-FA50-4E36-A8C4-88AE87654340}.Release|Any CPU.ActiveCfg = Release|Any CPU - {190742E1-FA50-4E36-A8C4-88AE87654340}.Release|Any CPU.Build.0 = Release|Any CPU + {E50EEB8B-6B3F-4C8C-A5C6-C37FB87C01E2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E50EEB8B-6B3F-4C8C-A5C6-C37FB87C01E2}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E50EEB8B-6B3F-4C8C-A5C6-C37FB87C01E2}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E50EEB8B-6B3F-4C8C-A5C6-C37FB87C01E2}.Release|Any CPU.Build.0 = Release|Any CPU + {7063E2D3-C591-4604-A5DD-32D4A1678A58}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {7063E2D3-C591-4604-A5DD-32D4A1678A58}.Debug|Any CPU.Build.0 = Debug|Any CPU + {7063E2D3-C591-4604-A5DD-32D4A1678A58}.Release|Any CPU.ActiveCfg = Release|Any CPU + {7063E2D3-C591-4604-A5DD-32D4A1678A58}.Release|Any CPU.Build.0 = Release|Any CPU + {B0055B76-4685-4ECF-A904-88EE4E6FC8F0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B0055B76-4685-4ECF-A904-88EE4E6FC8F0}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B0055B76-4685-4ECF-A904-88EE4E6FC8F0}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B0055B76-4685-4ECF-A904-88EE4E6FC8F0}.Release|Any CPU.Build.0 = Release|Any CPU + {F8762EE5-3461-4F6B-8C24-C876B6D9E637}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {F8762EE5-3461-4F6B-8C24-C876B6D9E637}.Debug|Any CPU.Build.0 = Debug|Any CPU + {F8762EE5-3461-4F6B-8C24-C876B6D9E637}.Release|Any CPU.ActiveCfg = Release|Any CPU + {F8762EE5-3461-4F6B-8C24-C876B6D9E637}.Release|Any CPU.Build.0 = Release|Any CPU + {3A92C2E7-3033-4FDF-8DDC-5DF43D290537}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3A92C2E7-3033-4FDF-8DDC-5DF43D290537}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3A92C2E7-3033-4FDF-8DDC-5DF43D290537}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3A92C2E7-3033-4FDF-8DDC-5DF43D290537}.Release|Any CPU.Build.0 = Release|Any CPU + {A5580114-C236-494E-851C-A21E3DB86FC8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A5580114-C236-494E-851C-A21E3DB86FC8}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A5580114-C236-494E-851C-A21E3DB86FC8}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A5580114-C236-494E-851C-A21E3DB86FC8}.Release|Any CPU.Build.0 = Release|Any CPU + {3725A78B-B6B5-4379-9DE0-37A180ADE95A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3725A78B-B6B5-4379-9DE0-37A180ADE95A}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3725A78B-B6B5-4379-9DE0-37A180ADE95A}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3725A78B-B6B5-4379-9DE0-37A180ADE95A}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -210,7 +246,13 @@ Global {A0F89B8B-32A3-4C28-8F1B-ADE343F16137} = {73A5C363-CA1F-44C4-9A9B-EF791A76BA6A} {69A0ACF2-DF1F-4F49-B554-F732DCA938A3} = {73A5C363-CA1F-44C4-9A9B-EF791A76BA6A} {59889284-FFEE-82E7-94CB-3B43E87DA6CF} = {73A5C363-CA1F-44C4-9A9B-EF791A76BA6A} - {190742E1-FA50-4E36-A8C4-88AE87654340} = {5FA0958E-6EBD-45F4-808E-3447A293F96F} + {E50EEB8B-6B3F-4C8C-A5C6-C37FB87C01E2} = {00CA0876-DA9F-44E8-B0DC-A88716BF347A} + {7063E2D3-C591-4604-A5DD-32D4A1678A58} = {73A5C363-CA1F-44C4-9A9B-EF791A76BA6A} + {B0055B76-4685-4ECF-A904-88EE4E6FC8F0} = {C0132984-68D1-4A97-8F8C-AD4E2EECC583} + {F8762EE5-3461-4F6B-8C24-C876B6D9E637} = {00CA0876-DA9F-44E8-B0DC-A88716BF347A} + {3A92C2E7-3033-4FDF-8DDC-5DF43D290537} = {00CA0876-DA9F-44E8-B0DC-A88716BF347A} + {A5580114-C236-494E-851C-A21E3DB86FC8} = {00CA0876-DA9F-44E8-B0DC-A88716BF347A} + {3725A78B-B6B5-4379-9DE0-37A180ADE95A} = {00CA0876-DA9F-44E8-B0DC-A88716BF347A} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {193AA352-6748-47C1-A5FC-C9AA6B5F000B} diff --git a/StackExchange.Redis.sln.DotSettings b/StackExchange.Redis.sln.DotSettings index b72a49d2c..0c18b97d4 100644 --- a/StackExchange.Redis.sln.DotSettings +++ b/StackExchange.Redis.sln.DotSettings @@ -1,5 +1,13 @@  OK PONG + RES + SE + True + True + True + True + True True - True \ No newline at end of file + True + diff --git a/docs/ReleaseNotes.md b/docs/ReleaseNotes.md index 185a679f4..8b24ea1b8 100644 --- a/docs/ReleaseNotes.md +++ b/docs/ReleaseNotes.md @@ -26,7 +26,7 @@ Current package versions: - Add support for XPENDING Idle time filter ([#2822 by david-brink-talogy](https://github.com/StackExchange/StackExchange.Redis/pull/2822)) - Improve `double` formatting performance on net8+ ([#2928 by mgravell](https://github.com/StackExchange/StackExchange.Redis/pull/2928)) - Add `GetServer(RedisKey, ...)` API ([#2936 by mgravell](https://github.com/StackExchange/StackExchange.Redis/pull/2936)) -- Fix error constructing `StreamAdd` message ([#2941 by mgravell](https://github.com/StackExchange/StackExchange.Redis/pull/2941)) +- Fix error constructing `StreamAdd` message ([#2941 by mgravell](https://github.com/StackExchange/StackExchange.Redis/pull/2941)) ## 2.8.58 diff --git a/eng/StackExchange.Redis.Build/EasyArray.cs b/eng/StackExchange.Redis.Build/EasyArray.cs new file mode 100644 index 000000000..400438e1c --- /dev/null +++ b/eng/StackExchange.Redis.Build/EasyArray.cs @@ -0,0 +1,55 @@ +using System.Collections; + +namespace StackExchange.Redis.Build; + +/// +/// Think ImmutableArray{T}, but with structural equality. +/// +/// The data being wrapped. +internal readonly struct EasyArray(T[]? array) : IEquatable>, IEnumerable +{ + public static readonly EasyArray Empty = new([]); + private readonly T[]? _array = array ?? []; + public int Length => _array?.Length ?? 0; + public ref readonly T this[int index] => ref _array![index]; + public ReadOnlySpan Span => _array.AsSpan(); + public bool IsEmpty => Length == 0; + + public static bool operator ==(EasyArray x, EasyArray y) + => x.Equals(y); + + public static bool operator !=(EasyArray x, EasyArray y) + => x.Equals(y); + + public bool Equals(EasyArray other) + { + T[]? tArr = this._array, oArr = other._array; + if (tArr is null) return oArr is null || oArr.Length == 0; + if (oArr is null) return tArr.Length == 0; + + if (tArr.Length != oArr.Length) return false; + for (int i = 0; i < tArr.Length; i++) + { + if (ReferenceEquals(tArr[i], oArr[i])) + return false; + } + return true; + } + + public IEnumerator GetEnumerator() => ((IEnumerable)(_array ?? [])).GetEnumerator(); + + public override bool Equals(object? obj) + => obj is EasyArray other && Equals(other); + + public override int GetHashCode() + { + var arr = _array; + if (arr is null) return 0; + // use length and first item for a quick hash + return arr.Length == 0 + ? 0 + : arr.Length ^ EqualityComparer.Default.GetHashCode(arr[0]); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); +} diff --git a/eng/StackExchange.Redis.Build/IsExternalInit.cs b/eng/StackExchange.Redis.Build/IsExternalInit.cs new file mode 100644 index 000000000..64f57fd4a --- /dev/null +++ b/eng/StackExchange.Redis.Build/IsExternalInit.cs @@ -0,0 +1,7 @@ +// ReSharper disable once CheckNamespace +namespace System.Runtime.CompilerServices; +#if !NET5_0_OR_GREATER +internal static class IsExternalInit +{ +} +#endif diff --git a/eng/StackExchange.Redis.Build/RespCommandGenerator.cs b/eng/StackExchange.Redis.Build/RespCommandGenerator.cs new file mode 100644 index 000000000..3ebef397a --- /dev/null +++ b/eng/StackExchange.Redis.Build/RespCommandGenerator.cs @@ -0,0 +1,1515 @@ +using System.Collections.Immutable; +using System.Diagnostics; +using System.Globalization; +using System.Reflection; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace StackExchange.Redis.Build; + +[Generator(LanguageNames.CSharp)] +public class RespCommandGenerator : IIncrementalGenerator +{ + [Flags] + private enum LiteralFlags + { + None = 0, + Suffix = 1 << 0, // else prefix + // optional, etc + } + + public void Initialize(IncrementalGeneratorInitializationContext context) + { + var literals = context.SyntaxProvider + .CreateSyntaxProvider(Predicate, Transform) + .Where(pair => pair.MethodName is { Length: > 0 }) + .Collect(); + + context.RegisterSourceOutput(literals, Generate); + } + + private bool Predicate(SyntaxNode node, CancellationToken cancellationToken) + { + // looking for [FastHash] partial static class Foo { } + if (node is MethodDeclarationSyntax decl + && decl.Modifiers.Any(SyntaxKind.PartialKeyword)) + { + foreach (var attribList in decl.AttributeLists) + { + foreach (var attrib in attribList.Attributes) + { + if (attrib.Name.ToString() is "RespCommandAttribute" or "RespCommand") return true; + } + } + } + + return false; + } + + private readonly record struct LiteralTuple(string Token, LiteralFlags Flags); + + private readonly record struct ParameterTuple( + string Type, + string Name, + string Modifiers, + ParameterFlags Flags, + EasyArray Literals, + string? ElementType, + string? IgnoreExpression, + int ArgIndex) + { + // variable if collection, nullable, or an explicit ignore expression + public bool IsVariable => OptionalReasons != 0; + + public ParameterFlags OptionalReasons => + Flags & (ParameterFlags.Collection | ParameterFlags.Nullable | ParameterFlags.IgnoreExpression); + + public bool IsCollection => (Flags & ParameterFlags.Collection) != 0; + public bool IsNullable => (Flags & ParameterFlags.Nullable) != 0; + } + + private readonly record struct MethodTuple( + string Namespace, + string TypeName, + string ReturnType, + string MethodName, + string Command, + EasyArray Parameters, + string TypeModifiers, + string MethodModifiers, + string Context, + string? Formatter, + string? Parser, + MethodFlags Flags, + string DebugNotes) + { + public bool IsRespOperation => (Flags & MethodFlags.RespOperation) != 0; + } + + private static string GetFullName(ITypeSymbol type) => + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + + private enum RESPite + { + RespContext, + RespCommandAttribute, + RespKeyAttribute, + RespPrefixAttribute, + RespSuffixAttribute, + RespOperation, + RespIgnoreAttribute, + } + + private static bool IsRESPite(ITypeSymbol? symbol, RESPite type) + { + static string NameOf(RESPite type) => type switch + { + RESPite.RespContext => nameof(RESPite.RespContext), + RESPite.RespCommandAttribute => nameof(RESPite.RespCommandAttribute), + RESPite.RespKeyAttribute => nameof(RESPite.RespKeyAttribute), + RESPite.RespPrefixAttribute => nameof(RESPite.RespPrefixAttribute), + RESPite.RespSuffixAttribute => nameof(RESPite.RespSuffixAttribute), + RESPite.RespOperation => nameof(RESPite.RespOperation), + RESPite.RespIgnoreAttribute => nameof(RESPite.RespIgnoreAttribute), + _ => type.ToString(), + }; + + if (symbol is INamedTypeSymbol named && named.Name == NameOf(type)) + { + // looking likely; check namespace + if (named.ContainingNamespace is { Name: "RESPite", ContainingNamespace.IsGlobalNamespace: true }) + { + return true; + } + + // if the type doesn't resolve: we're going to need to trust it + if (named.TypeKind == TypeKind.Error) return true; + } + + return false; + } + + private enum SERedis + { + CommandFlags, + RedisValue, + RedisKey, + } + + private static bool IsSERedis(ITypeSymbol? symbol, SERedis type) + { + static string NameOf(SERedis type) => type switch + { + SERedis.CommandFlags => nameof(SERedis.CommandFlags), + SERedis.RedisValue => nameof(SERedis.RedisValue), + SERedis.RedisKey => nameof(SERedis.RedisKey), + _ => type.ToString(), + }; + + if (symbol is INamedTypeSymbol named && named.Name == NameOf(type)) + { + // looking likely; check namespace + if (named.ContainingNamespace is + { + Name: "Redis", ContainingNamespace: + { + Name: "StackExchange", + ContainingNamespace.IsGlobalNamespace: true, + } + }) + { + return true; + } + + // if the type doesn't resolve: we're going to need to trust it + if (named.TypeKind == TypeKind.Error) return true; + } + + return false; + } + + private static string GetName(ITypeSymbol type) + { + if (type.ContainingType is null) return type.Name; + var stack = new Stack(); + while (true) + { + stack.Push(type.Name); + if (type.ContainingType is null) break; + type = type.ContainingType; + } + + var sb = new StringBuilder(stack.Pop()); + while (stack.Count != 0) + { + sb.Append('.').Append(stack.Pop()); + } + + return sb.ToString(); + } + + [Conditional("DEBUG")] + private static void AddNotes(ref string notes, string note) + { + if (string.IsNullOrWhiteSpace(notes)) + { + notes = note; + } + else + { + notes += "; " + note; + } + } + + private MethodTuple Transform( + GeneratorSyntaxContext ctx, + CancellationToken cancellationToken) + { + // extract the name and value (defaults to name, but can be overridden via attribute) and the location + if (ctx.SemanticModel.GetDeclaredSymbol(ctx.Node) is not IMethodSymbol method) return default; + if (!(method is { IsPartialDefinition: true, PartialImplementationPart: null })) return default; + + MethodFlags methodFlags = 0; + string returnType, debugNote = ""; + if (method.ReturnsVoid) + { + returnType = ""; + } + // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract + else if (method.ReturnType is null) + { + return default; + } + else + { + ITypeSymbol? rt = method.ReturnType; + if (IsRespOperation(ref rt)) + { + methodFlags |= MethodFlags.RespOperation; + } + returnType = rt is null ? "" : GetFullName(rt); + } + + string ns = "", parentType = ""; + if (method.ContainingType is { } containingType) + { + parentType = GetName(containingType); + ns = containingType.ContainingNamespace.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat); + } + else if (method.ContainingNamespace is { } containingNamespace) + { + ns = containingNamespace.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat); + } + + string value = method.Name.ToLowerInvariant(); + string? formatter = null, parser = null; + foreach (var attrib in method.GetAttributes()) + { + if (IsRESPite(attrib.AttributeClass, RESPite.RespCommandAttribute)) + { + if (attrib.ConstructorArguments.Length == 1) + { + if (attrib.ConstructorArguments[0].Value?.ToString() is { Length: > 0 } val) + { + value = val; + } + } + + foreach (var tuple in attrib.NamedArguments) + { + switch (tuple.Key) + { + case "Formatter": + formatter = tuple.Value.Value?.ToString(); + AddNotes(ref debugNote, $"custom formatter: '{formatter}'"); + break; + case "Parser": + parser = tuple.Value.Value?.ToString(); + AddNotes(ref debugNote, $"custom parser: '{parser}'"); + break; + } + } + + break; // we don't expect another [RespCommand] + } + } + + var parameters = new List(method.Parameters.Length); + + // get context from the available fields + string? context = null; + IParameterSymbol? contextParam = null; + foreach (var param in method.Parameters) + { + if (IsRESPite(param.Type, RESPite.RespContext)) + { + contextParam = param; + context = param.Name; + break; + } + } + + if (context is null) + { + AddNotes(ref debugNote, $"checking {method.ContainingType.Name} for fields"); + foreach (var member in method.ContainingType.GetMembers()) + { + if (member is IFieldSymbol { IsStatic: false } field) + { + if (IsRESPite(field.Type, RESPite.RespContext)) + { + AddNotes(ref debugNote, $"{field.Name} WAS match - {field.Type.Name}"); + context = field.Name; + break; + } + } + } + } + + if (context is null) + { + // get context from primary constructor (actually, we look at all constructors, + // and just hope that the one that matches: works!) + foreach (var ctor in method.ContainingType.Constructors) + { + if (ctor.IsStatic) continue; + foreach (var param in ctor.Parameters) + { + if (IsRESPite(param.Type, RESPite.RespContext)) + { + context = param.Name; + break; + } + } + + if (context is not null) break; + } + } + + if (context is null) + { + // look for indirect from parameter + foreach (var param in method.Parameters) + { + if (IsIndirectRespContext(param.Type, out var memberName)) + { + contextParam = param; + context = $"{param.Name}.{memberName}"; + break; + } + } + } + + if (context is null) + { + // look for indirect from field + foreach (var member in method.ContainingType.GetMembers()) + { + if (member is IFieldSymbol { IsStatic: false } field && + IsIndirectRespContext(field.Type, out var memberName)) + { + context = $"{field.Name}.{memberName}"; + break; + } + } + } + + // See whether instead of x (param, etc.) *being* a RespContext, it could be something that *provides* + // a RespContext; this is especially useful for using punned structs (that just wrap a RespContext) to + // narrow the methods into logical groups, i.e. "strings", "hashes", etc. + static bool IsIndirectRespContext(ITypeSymbol type, out string memberName) + { + foreach (var member in type.GetMembers()) + { + if (member is IFieldSymbol { IsStatic: false } field + && IsRESPite(field.Type, RESPite.RespContext)) + { + memberName = field.Name; + return true; + } + } + + foreach (var member in type.GetMembers()) + { + if (member is IPropertySymbol { IsStatic: false } prop + && IsRESPite(prop.Type, RESPite.RespContext) && prop.GetMethod is not null) + { + memberName = prop.Name; + return true; + } + } + + memberName = ""; + return false; + } + + if (context is null) + { + // last ditch, get context from properties + foreach (var member in method.ContainingType.GetMembers()) + { + if (member is IPropertySymbol { IsStatic: false } prop + && IsRESPite(prop.Type, RESPite.RespContext) && prop.GetMethod is not null) + { + context = prop.Name; + break; + } + } + } + + int nextArgIndex = 0; + foreach (var param in method.Parameters) + { + string? ignoreExpression = null; + var flags = ParameterFlags.Parameter; + if (IsKey(param)) flags |= ParameterFlags.Key; + var elementType = param.Type; + flags |= GetTypeFlags(ref elementType); + string? elementTypeName = ReferenceEquals(elementType, param.Type) ? null : GetFullName(elementType); + if (IsSERedis(param.Type, SERedis.CommandFlags)) + { + flags |= ParameterFlags.CommandFlags; + // magic pattern; we *demand* a method called Context that takes the flags; if this is an extension + // method, assume it is on the first parameter + if ((methodFlags & MethodFlags.ExtensionMethod) != 0) + { + context = $"{method.Parameters[0].Name}.Context({param.Name})"; + } + else + { + context = $"Context({param.Name})"; + } + } + else if (IsRESPite(param.Type, RESPite.RespContext)) + { + // ignore it, but no extra flag + } + else if (contextParam is not null && SymbolEqualityComparer.Default.Equals(param, contextParam)) + { + // ignore it, but no extra flag + } + else + { + flags |= ParameterFlags.Data; + } + + string modifiers = param.RefKind switch + { + RefKind.None => "", + RefKind.In => "in ", + RefKind.Out => "out ", + RefKind.Ref => "ref ", + _ => "", + }; + + if (param.Ordinal == 0 && method.IsExtensionMethod) + { + methodFlags |= MethodFlags.ExtensionMethod; + modifiers = "this " + modifiers; + } + + List? literals = null; + + void AddLiteral(string token, LiteralFlags literalFlags) + { + (literals ??= new()).Add(new(token, literalFlags)); + } + + AddNotes(ref debugNote, $"checking {param.Name} for literals"); + foreach (var attrib in param.GetAttributes()) + { + if (attrib.ConstructorArguments.Length == 1) + { + if (IsRESPite(attrib.AttributeClass, RESPite.RespPrefixAttribute)) + { + if (attrib.ConstructorArguments[0].Value?.ToString() is { } val) + { + AddNotes(ref debugNote, $"prefix {val}"); + AddLiteral(val, LiteralFlags.None); + } + } + + if (IsRESPite(attrib.AttributeClass, RESPite.RespSuffixAttribute)) + { + if (attrib.ConstructorArguments[0].Value?.ToString() is { Length: > 0 } val) + { + AddNotes(ref debugNote, $"suffix {val}"); + AddLiteral(val, LiteralFlags.Suffix); + } + } + + if (IsRESPite(attrib.AttributeClass, RESPite.RespIgnoreAttribute)) + { + var val = attrib.ConstructorArguments[0].Value; + var expr = val switch + { + null when IsSERedis(param.Type, SERedis.RedisValue) | IsSERedis(param.Type, SERedis.RedisKey) => ".IsNull is false", + string s => " != " + CodeLiteral(s), + bool b => b ? " is false" : " is true", // if we *ignore* true, then "incN = foo is false" + long l when attrib.ConstructorArguments[0].Type is INamedTypeSymbol { EnumUnderlyingType: not null } enumType + => " != " + GetEnumExpression(enumType, l), + long l => " != " + l.ToString(CultureInfo.InvariantCulture), + int i when attrib.ConstructorArguments[0].Type is INamedTypeSymbol { EnumUnderlyingType: not null } enumType + => " != " + GetEnumExpression(enumType, i), + int i => " != " + i.ToString(CultureInfo.InvariantCulture), + _ => null, + }; + + if (expr is not null) + { + flags |= ParameterFlags.IgnoreExpression; + ignoreExpression = expr; + } + + static string GetEnumExpression(INamedTypeSymbol enumType, object value) + { + foreach (var member in enumType.GetMembers()) + { + if (member is IFieldSymbol { IsStatic: true, IsConst: true } field + && Equals(field.ConstantValue, value)) + { + return $"{GetFullName(enumType)}.{field.Name}"; + } + } + + return $"({GetFullName(enumType)}){value}"; + } + } + } + } + + var literalArray = literals is null ? EasyArray.Empty : new(literals.ToArray()); + var argIndex = (flags & ParameterFlags.Data) != 0 ? nextArgIndex++ : -1; + + parameters.Add(new(GetFullName(param.Type), param.Name, modifiers, flags, literalArray, elementTypeName, ignoreExpression, argIndex)); + } + + var syntax = (MethodDeclarationSyntax)ctx.Node; + return new( + ns, + parentType, + returnType, + method.Name, + value, + new(parameters.ToArray()), + TypeModifiers(method.ContainingType), + syntax.Modifiers.ToString(), + context ?? "", + formatter, + parser, + methodFlags, + debugNote); + + static string TypeModifiers(ITypeSymbol type) + { + foreach (var symbol in type.DeclaringSyntaxReferences) + { + var syntax = symbol.GetSyntax(); + if (syntax is TypeDeclarationSyntax typeDeclaration) + { + var mods = typeDeclaration.Modifiers.ToString(); + return syntax switch + { + InterfaceDeclarationSyntax => $"{mods} interface", + StructDeclarationSyntax => $"{mods} struct", + _ => $"{mods} class", + }; + } + } + + return "class"; // wut? + } + } + + private bool IsRespOperation(ref ITypeSymbol? type) // identify RespOperation[] + { + if (type is INamedTypeSymbol named && IsRESPite(type, RESPite.RespOperation)) + { + if (named.IsGenericType) + { + if (named.TypeArguments.Length != 1) return false; // unexpected + type = named.TypeArguments[0]; + } + else + { + type = null; + } + return true; + } + return false; + } + + private static ParameterFlags GetTypeFlags(ref ITypeSymbol paramType) + { + var flags = ParameterFlags.None; + if (paramType.IsValueType) flags |= ParameterFlags.ValueType; + switch (paramType.NullableAnnotation) + { + case NullableAnnotation.Annotated: + flags |= ParameterFlags.Nullable; + break; + case NullableAnnotation.None: + if (paramType.IsReferenceType) flags |= ParameterFlags.Nullable; + break; + } + + if (paramType is IArrayTypeSymbol arr) + { + if (arr.Rank == 1 && arr.ElementType.SpecialType != SpecialType.System_Byte) + { + flags |= ParameterFlags.Collection; + paramType = arr.ElementType; + } + } + + if (paramType is INamedTypeSymbol { IsGenericType: true, Arity: 1 } gen) + { + switch (gen.ConstructedFrom.SpecialType) + { + case SpecialType.System_Collections_Generic_ICollection_T: + case SpecialType.System_Collections_Generic_IList_T: + case SpecialType.System_Collections_Generic_IReadOnlyCollection_T: + case SpecialType.System_Collections_Generic_IReadOnlyList_T: + flags |= ParameterFlags.Collection | ParameterFlags.CollectionWithCount; + paramType = gen.TypeArguments[0]; + break; + default: + if (IsSystemCollections(gen.ConstructedFrom, "List")) + { + flags |= ParameterFlags.Collection; + paramType = gen.TypeArguments[0]; + } + if (IsSystemCollections(gen.ConstructedFrom, "ImmutableArray", "Immutable")) + { + flags |= ParameterFlags.Collection | ParameterFlags.ImmutableArray; + paramType = gen.TypeArguments[0]; + } + break; + } + } + + return flags; + + static bool IsSystemCollections(INamedTypeSymbol type, string name, string ns = "Generic") + => type.Name == name && type.ContainingNamespace is { } actualNs && actualNs.Name == ns + && actualNs.ContainingNamespace is + { + Name: + "Collections", + ContainingNamespace: + { + Name: + "System", + ContainingNamespace.IsGlobalNamespace: true, + } + }; + } + + private bool IsKey(IParameterSymbol param) + { + if (param.Name.EndsWith("key", StringComparison.InvariantCultureIgnoreCase)) + { + return true; + } + + foreach (var attrib in param.GetAttributes()) + { + if (IsRESPite(attrib.AttributeClass, RESPite.RespKeyAttribute)) return true; + } + + return false; + } + + private string GetVersion() + { + var asm = GetType().Assembly; + if (asm.GetCustomAttributes(typeof(AssemblyFileVersionAttribute), false).FirstOrDefault() is + AssemblyFileVersionAttribute { Version: { Length: > 0 } } version) + { + return version.Version; + } + + return asm.GetName().Version?.ToString() ?? "??"; + } + + private static string CodeLiteral(string value) + => SyntaxFactory + .LiteralExpression(SyntaxKind.StringLiteralExpression, SyntaxFactory.Literal(value)) + .ToFullString(); + + private void Generate( + SourceProductionContext ctx, + ImmutableArray methods) + { + if (methods.IsDefaultOrEmpty) return; + + var sb = new StringBuilder("// ") + .AppendLine().Append("// ").Append(GetType().Name).Append(" v").Append(GetVersion()).AppendLine(); + + bool first; + int indent = 0; + + // find the unique param types, so we can build helpers + Dictionary, (string Name, + int ShareCount, string Command)> + formatters = + new(FormatterComparer.Default); + + foreach (var method in methods.AsSpan()) + { + if (method.Formatter is not null) continue; // using explicit formatter + var count = DataParameterCount(method.Parameters); + switch (count) + { + case 0: continue; // no parameter to consider + case 1: + var p = FirstDataParameter(method.Parameters); + if (p.Literals.IsEmpty) + { + // no literals, and basic write scenario;consumer should add their own extension method + continue; + } + + break; + } + + // add a new formatter, or mark an existing formatter as shared + var key = method.Parameters; + if (!formatters.TryGetValue(key, out var existing)) + { + formatters.Add(key, ($"__RespFormatter_{formatters.Count}", 1, method.Command)); + } + else + { + formatters[key] = (existing.Name, existing.ShareCount + 1, ""); // incr share count + } + } + + StringBuilder NewLine() => sb.AppendLine().Append(' ', Math.Max(indent * 4, 0)); + NewLine().Append("using global::RESPite;"); + foreach (var method in methods.AsSpan()) + { + if (HasAnyFlag(method.Parameters, ParameterFlags.CommandFlags)) + { + NewLine().Append("using global::RESPite.StackExchange.Redis;"); + break; + } + } + + NewLine().Append("using global::System;"); + NewLine().Append("using global::System.Threading.Tasks;"); + + foreach (var grp in methods.GroupBy(l => (l.Namespace, l.TypeName, l.TypeModifiers))) + { + NewLine(); + int braces = 0; + if (!string.IsNullOrWhiteSpace(grp.Key.Namespace)) + { + NewLine().Append("namespace ").Append(grp.Key.Namespace); + NewLine().Append("{"); + indent++; + braces++; + } + + if (!string.IsNullOrWhiteSpace(grp.Key.TypeName)) + { + if (grp.Key.TypeName.Contains('.')) // nested types + { + var tokens = grp.Key.TypeName.Split('.'); + for (var i = 0; i < tokens.Length; i++) + { + var part = tokens[i]; + if (i == tokens.Length - 1) + { + NewLine().Append(grp.Key.TypeModifiers).Append(' ').Append(part); + } + else + { + NewLine().Append("partial class ").Append(part); + } + + NewLine().Append("{"); + indent++; + braces++; + } + } + else + { + NewLine().Append(grp.Key.TypeModifiers).Append(' ').Append(grp.Key.TypeName); + NewLine().Append("{"); + indent++; + braces++; + } + } + + foreach (var method in grp) + { + if (method.DebugNotes is { Length: > 0 }) + { + NewLine().Append("/* ").Append(method.MethodName).Append(": ") + .Append(method.DebugNotes).Append(" */"); + } + + bool isSharedFormatter = false; + string? formatter = method.Formatter + ?? InbuiltFormatter(method.Parameters); + if (formatter is null && formatters.TryGetValue(method.Parameters, out var tmp)) + { + formatter = $"{tmp.Name}.Default"; + isSharedFormatter = tmp.ShareCount > 1; + } + + // perform string escaping on the generated value (this includes the quotes, note) + var csValue = CodeLiteral(method.Command); + + WriteMethod(false); + if ((method.Flags & MethodFlags.RespOperation) == 0) + { + WriteMethod(true); // also write async half + } + + void WriteMethod(bool asAsync) + { + sb = NewLine().Append(asAsync ? RemovePartial(method.MethodModifiers) : method.MethodModifiers) + .Append(' '); + if (asAsync) + { + sb.Append(HasAnyFlag(method.Parameters, ParameterFlags.CommandFlags) ? "Task" : "ValueTask"); + if (!string.IsNullOrWhiteSpace(method.ReturnType)) + { + sb.Append('<').Append(method.ReturnType).Append('>'); + } + } + else if (method.IsRespOperation) + { + sb.Append("global::RESPite.RespOperation"); + if (!string.IsNullOrWhiteSpace(method.ReturnType)) + { + sb.Append('<').Append(method.ReturnType).Append('>'); + } + } + else + { + sb.Append(string.IsNullOrEmpty(method.ReturnType) ? "void" : method.ReturnType); + } + + sb.Append(' ').Append(method.MethodName).Append(asAsync ? "Async" : "").Append("("); + first = true; + foreach (var param in method.Parameters) + { + if ((param.Flags & ParameterFlags.Parameter) == 0) continue; + if (!first) sb.Append(", "); + first = false; + + sb.Append(param.Modifiers).Append(param.Type).Append(' ').Append(param.Name); + } + + var dataParameters = DataParameterCount(method.Parameters); + sb.Append(")"); + indent++; + + var parser = method.Parser ?? InbuiltParser(method.ReturnType, explicitSuccess: true); + bool useDirectCall = method.Context is { Length: > 0 } & formatter is { Length: > 0 } & + parser is { Length: > 0 }; + + if (string.IsNullOrWhiteSpace(method.Context)) + { + NewLine().Append("=> throw new NotSupportedException(\"No RespContext available\");"); + useDirectCall = false; + } + else if (!(useDirectCall & asAsync)) + { + sb = NewLine(); + if (useDirectCall) sb.Append("// "); + sb.Append("=> ").Append(method.Context).Append(".Command(").Append(csValue).Append("u8"); + if (dataParameters != 0) + { + sb.Append(", "); + WriteTuple(method.Parameters, sb, TupleMode.Values); + + if (!string.IsNullOrWhiteSpace(formatter)) + { + sb.Append(", ").Append(formatter); + } + } + sb.Append(asAsync | method.IsRespOperation ? ").Send" : ").Wait"); + if (!string.IsNullOrWhiteSpace(method.ReturnType)) + { + sb.Append('<').Append(method.ReturnType).Append('>'); + } + + sb.Append("(").Append(parser).Append(")"); + if (asAsync && HasAnyFlag(method.Parameters, ParameterFlags.CommandFlags)) + { + sb.Append(".AsTask()"); + } + + sb.Append(';'); + } + + if (useDirectCall) // avoid the intermediate step when possible + { + sb = NewLine().Append("=> ").Append(method.Context).Append(".Send") + .Append('<'); + WriteTuple( + method.Parameters, + sb, + isSharedFormatter ? TupleMode.SyntheticNames : TupleMode.NamedTuple); + if (!string.IsNullOrWhiteSpace(method.ReturnType)) + { + sb.Append(", ").Append(method.ReturnType); + } + + sb.Append(">(").Append(csValue).Append("u8").Append(", "); + WriteTuple(method.Parameters, sb, TupleMode.Values); + sb.Append(", ").Append(formatter).Append(", ").Append(parser).Append(")"); + if (asAsync) + { + sb.Append(HasAnyFlag(method.Parameters, ParameterFlags.CommandFlags) + ? ".AsTask()" + : ".AsValueTask()"); + } + else if (method.IsRespOperation) + { + // nothing to do + } + else + { + sb.Append(".Wait("); + if (HasAnyFlag(method.Parameters, ParameterFlags.CommandFlags)) + { + // to avoid calling Context(flags) twice, we assume that this member will exist + sb.Append("SyncTimeout"); + } + else + { + sb.Append(method.Context).Append(".SyncTimeout"); + } + + sb.Append(")"); + } + + sb.Append(";"); + } + + indent--; + NewLine(); + } + } + + // handle any closing braces + while (braces-- > 0) + { + indent--; + NewLine().Append("}"); + } + + NewLine(); + } + + foreach (var tuple in formatters) + { + var parameters = tuple.Key; + var name = tuple.Value.Name; + var names = tuple.Value.ShareCount > 1 ? TupleMode.SyntheticNames : TupleMode.NamedTuple; + + NewLine(); + if (tuple.Value.ShareCount > 1) + { + NewLine().Append("// shared by ").Append(tuple.Value.ShareCount).Append(" methods"); + } + else if (tuple.Value.Command is { Length: > 0 }) + { + NewLine().Append("// for command: ").Append(tuple.Value.Command); + } + + sb = NewLine().Append("sealed file class ").Append(name) + .Append(" : global::RESPite.Messages.IRespFormatter<"); + WriteTuple(parameters, sb, names); + sb.Append('>'); + NewLine().Append("{"); + indent++; + NewLine().Append("public static readonly ").Append(name).Append(" Default = new();"); + NewLine(); + + sb = NewLine() + .Append( + "public void Format(scoped ReadOnlySpan command, ref global::RESPite.Messages.RespWriter writer, in "); + WriteTuple(parameters, sb, names); + sb.Append(" request)"); + NewLine().Append("{"); + indent++; + var argCount = DataParameterCount(parameters, out int constantCount, out bool isVariable); + + void WriteParameterName(in ParameterTuple p, StringBuilder? target = null) + { + target ??= sb; + if (argCount == 1) + { + target.Append("request"); + } + else + { + target.Append("request."); + if (names == TupleMode.SyntheticNames) + { + target.Append("Arg").Append(p.ArgIndex); + } + else + { + target.Append(p.Name); + } + } + } + + int index; + if (isVariable) + { + foreach (var parameter in parameters.Span) + { + if (parameter.IsVariable) + { + sb = NewLine().Append("bool __inc").Append(parameter.ArgIndex).Append(" = "); + WriteParameterName(parameter); + switch (parameter.OptionalReasons) + { + case ParameterFlags.Nullable: + sb.Append(" is not null"); + break; + case ParameterFlags.Nullable | ParameterFlags.IgnoreExpression: + sb.Append(" is { } __val").Append(parameter.ArgIndex) + .Append(" && __val").Append(parameter.ArgIndex) + .Append(parameter.IgnoreExpression); + break; + case ParameterFlags.IgnoreExpression: + sb.Append(parameter.IgnoreExpression); + break; + case ParameterFlags.Collection: + // non-nullable collection; literals already handled + switch (parameter.Flags & (ParameterFlags.CollectionWithCount | ParameterFlags.ImmutableArray)) + { + case ParameterFlags.CollectionWithCount: + sb.Append(".Count != 0"); + break; + case ParameterFlags.ImmutableArray: // needs special care because of default (breaks .Length) + sb.Append(".IsDefaultOrEmpty == false"); + break; + default: + sb.Append(".Length != 0"); + break; + } + break; + case ParameterFlags.Collection | ParameterFlags.Nullable: + sb.Append(" is { "); + switch (parameter.Flags & (ParameterFlags.CollectionWithCount | ParameterFlags.ImmutableArray)) + { + case ParameterFlags.CollectionWithCount: + sb.Append("Count: > 0"); + break; + case ParameterFlags.ImmutableArray: // needs special care because of default (breaks .Length) + sb.Append("IsDefaultOrEmpty: false"); + break; + default: + sb.Append("Length: > 0"); + break; + } + sb.Append("}"); + break; + default: + sb.Append($" false /* unhandled combination! */"); + break; + } + sb.Append("; // ").Append(parameter.OptionalReasons); + } + } + + sb = NewLine().Append("writer.WriteCommand(command,"); + bool firstVariableItem = true; + if (constantCount != 0) + { + sb.Append(" ").Append(constantCount).Append(" // constant args"); + firstVariableItem = false; + } + indent++; + index = 0; + foreach (var parameter in parameters.Span) + { + if (parameter.IsVariable) + { + sb = NewLine(); + if (firstVariableItem) + { + firstVariableItem = false; + } + else + { + sb.Append("+ "); + } + sb.Append("(__inc").Append(parameter.ArgIndex).Append(" ? "); + var literalCount = parameter.Literals.Length; + if (!parameter.IsCollection) + { + sb.Append(1 + literalCount); + } + else + { + if (literalCount != 0) sb.Append("("); + WriteParameterName(parameter); + if (parameter.IsNullable) sb.Append("!"); + sb.Append((parameter.Flags & ParameterFlags.CollectionWithCount) == 0 ? ".Length" : ".Count"); + if (literalCount != 0) sb.Append(" + ").Append(literalCount).Append(")"); + } + + sb.Append(" : 0)"); + if (!parameter.IsCollection) + { + // help identify what this is (not needed for collections, since foo.Count etc) + sb.Append(" // "); + WriteParameterName(parameter); + if (tuple.Value.ShareCount != 1) sb.Append(" (").Append(parameter.Name).Append(")"); // give an example + } + + if (literalCount != 0) + { + if (parameter.IsCollection) sb.Append(" //"); + sb.Append(" with"); + foreach (var literal in parameter.Literals.Span) + { + sb.Append(" ").Append(string.IsNullOrEmpty(literal.Token) ? "(count)" : literal.Token); + } + } + } + index++; + } + NewLine().Append(");"); + indent--; + } + else if (tuple.Value.Command is { Length: > 0 } cmd + && Encoding.UTF8.GetByteCount(cmd) == cmd.Length) // check pure ASCII + { + // only used by one command; allow optimization + NewLine().Append("if(writer.CommandMap is null)"); + NewLine().Append("{"); + indent++; + string raw = $"*{constantCount + 1}\r\n${cmd.Length}\r\n{tuple.Value.Command}\r\n"; + sb = NewLine().Append("writer.WriteRaw(").Append(CodeLiteral(raw)).Append("u8); // ") + .Append(cmd).Append(" with ").Append(constantCount).Append(" args"); + indent--; + NewLine().Append("}"); + NewLine().Append("else"); + NewLine().Append("{"); + indent++; + NewLine().Append("writer.WriteCommand(command, ").Append(constantCount).Append(");"); + indent--; + NewLine().Append("}"); + } + else + { + NewLine().Append("writer.WriteCommand(command, ").Append(constantCount).Append(");"); + } + + void WritePrefix(in ParameterTuple p) => WriteLiteral(p, false); + void WriteSuffix(in ParameterTuple p) => WriteLiteral(p, true); + + void WriteLiteral(in ParameterTuple p, bool suffix) + { + LiteralFlags match = suffix ? LiteralFlags.Suffix : LiteralFlags.None; + foreach (var literal in p.Literals.Span) + { + if ((literal.Flags & LiteralFlags.Suffix) == match) + { + if (string.IsNullOrEmpty(literal.Token)) + { + if (p.IsCollection) + { + sb = NewLine().Append("writer.WriteBulkString("); + WriteParameterName(p); + if (p.IsNullable) sb.Append("!"); + sb.Append((p.Flags & ParameterFlags.CollectionWithCount) == 0 ? ".Length" : ".Count") + .Append(");"); + } + else + { + NewLine().Append("#error empty literal for ").Append(p.Name).AppendLine(); + } + } + else + { + var len = Encoding.UTF8.GetByteCount(literal.Token); + var resp = $"${len}\r\n{literal.Token}\r\n"; + NewLine().Append("writer.WriteRaw(").Append(CodeLiteral(resp)).Append("u8); // ") + .Append(literal.Token); + } + } + } + } + + index = 0; + foreach (var parameter in parameters.Span) + { + if ((parameter.Flags & ParameterFlags.DataParameter) == ParameterFlags.DataParameter) + { + if (parameter.IsVariable) + { + sb = NewLine().Append("if (__inc").Append(parameter.ArgIndex).Append(")"); + NewLine().Append("{"); + indent++; + } + + WritePrefix(parameter); + var elementType = parameter.ElementType ?? parameter.Type; + if (parameter.IsCollection) + { + sb = NewLine().Append("foreach (").Append(elementType).Append(" val in "); + WriteParameterName(parameter); + if (parameter.IsNullable) sb.Append("!"); + sb.Append(")"); + NewLine().Append("{"); + indent++; + } + + sb = NewLine().Append("writer."); + if (elementType is "global::StackExchange.Redis.RedisValue" + or "global::StackExchange.Redis.RedisKey") + { + sb.Append("Write"); + } + else + { + sb.Append((parameter.Flags & ParameterFlags.Key) == 0 ? "WriteBulkString" : "WriteKey"); + } + + sb.Append("("); + if (parameter.IsCollection) + { + sb.Append("val"); + } + else + { + WriteParameterName(parameter); + } + sb.Append(");"); + + if (parameter.IsCollection) + { + indent--; + NewLine().Append("}"); + } + + WriteSuffix(parameter); + if (parameter.IsVariable) + { + indent--; + NewLine().Append("}"); + } + index++; + } + } + + Debug.Assert(index == argCount, "wrote all parameters"); + + indent--; + NewLine().Append("}"); + indent--; + NewLine().Append("}"); + } + + NewLine(); + ctx.AddSource(GetType().Name + ".generated.cs", sb.ToString()); + + static void WriteTuple( + EasyArray parameters, + StringBuilder sb, + TupleMode mode) + { + var count = DataParameterCount(parameters); + if (count == 0) return; + if (count < 2) + { + var p = FirstDataParameter(parameters); + sb.Append(mode == TupleMode.Values ? p.Name : p.Type); + return; + } + + sb.Append('('); + int index = 0; + foreach (var param in parameters.Span) + { + if ((param.Flags & ParameterFlags.DataParameter) != ParameterFlags.DataParameter) + { + continue; // note don't increase index + } + + if (index != 0) sb.Append(", "); + + switch (mode) + { + case TupleMode.Values: + sb.Append(param.Name); + break; + case TupleMode.AnonTuple: + sb.Append(param.Type); + break; + case TupleMode.NamedTuple: + sb.Append(param.Type).Append(' ').Append(param.Name); + break; + case TupleMode.SyntheticNames: + sb.Append(param.Type).Append(" Arg").Append(index); + break; + } + + index++; + } + + sb.Append(')'); + } + } + + private static bool HasAnyFlag( + EasyArray parameters, + ParameterFlags any) + { + foreach (var p in parameters.Span) + { + if ((p.Flags & any) != 0) return true; + } + + return false; + } + + private static string? InbuiltFormatter( + EasyArray parameters) + { + if (DataParameterCount(parameters) == 1) + { + var p = FirstDataParameter(parameters); + if (p.Literals.IsEmpty) + { + // can only use the inbuilt formatter if there are no literals + return InbuiltFormatter(p.Type, (p.Flags & ParameterFlags.Key) != 0); + } + } + + return null; + } + + private static ParameterTuple FirstDataParameter( + EasyArray parameters) + { + if (!parameters.IsEmpty) + { + foreach (var parameter in parameters.Span) + { + if ((parameter.Flags & ParameterFlags.DataParameter) == ParameterFlags.DataParameter) + { + return parameter; + } + } + } + + return Array.Empty().First(); + } + + private static int DataParameterCount( + EasyArray parameters) + => DataParameterCount(parameters, out _, out _); + + private static int DataParameterCount( + EasyArray parameters, out int constantCount, out bool isVariable) + { + // note: constantCount includes literals + constantCount = 0; + isVariable = false; + if (parameters.IsEmpty) return 0; + int count = 0; + foreach (var parameter in parameters.Span) + { + if ((parameter.Flags & ParameterFlags.DataParameter) == ParameterFlags.DataParameter) + { + bool thisParamIsVariable = false; + count++; + if (parameter.IsVariable) + { + isVariable = thisParamIsVariable = true; + } + else + { + constantCount++; + } + + if (!(thisParamIsVariable | parameter.Literals.IsEmpty)) + { + constantCount += parameter.Literals.Length; // we include literals if not variable + } + } + } + + return count; + } + + private const string RespFormattersPrefix = "global::RESPite.RespFormatters."; + + private static string? InbuiltFormatter(string type, bool isKey) => type switch + { + "string" => isKey ? (RespFormattersPrefix + "Key.String") : (RespFormattersPrefix + "Value.String"), + "byte[]" => isKey ? (RespFormattersPrefix + "Key.ByteArray") : (RespFormattersPrefix + "Value.ByteArray"), + "int" => RespFormattersPrefix + "Int32", + "long" => RespFormattersPrefix + "Int64", + "float" => RespFormattersPrefix + "Single", + "double" => RespFormattersPrefix + "Double", + "" => RespFormattersPrefix + "Empty", + "global::StackExchange.Redis.RedisKey" => "global::RESPite.StackExchange.Redis.RespFormatters.RedisKey", + "global::StackExchange.Redis.RedisKey[]" => "global::RESPite.StackExchange.Redis.RespFormatters.RedisKeyArray", + "global::StackExchange.Redis.RedisValue" => "global::RESPite.StackExchange.Redis.RespFormatters.RedisValue", + _ => null, + }; + + private const string RespParsersPrefix = "global::RESPite.RespParsers."; + + private static string? InbuiltParser(string type, bool explicitSuccess = false) => type switch + { + "" when explicitSuccess => RespParsersPrefix + "Success", + "bool" => RespParsersPrefix + "Success", + "string" => RespParsersPrefix + "String", + "int" => RespParsersPrefix + "Int32", + "long" => RespParsersPrefix + "Int64", + "float" => RespParsersPrefix + "Single", + "double" => RespParsersPrefix + "Double", + "int?" => RespParsersPrefix + "NullableInt32", + "long?" => RespParsersPrefix + "NullableInt64", + "float?" => RespParsersPrefix + "NullableSingle", + "double?" => RespParsersPrefix + "NullableDouble", + "global::RESPite.RespParsers.ResponseSummary" => RespParsersPrefix + "ResponseSummary.Parser", + "global::StackExchange.Redis.RedisKey" => "global::RESPite.StackExchange.Redis.RespParsers.RedisKey", + "global::StackExchange.Redis.RedisValue" => "global::RESPite.StackExchange.Redis.RespParsers.RedisValue", + "global::StackExchange.Redis.RedisValue[]" => "global::RESPite.StackExchange.Redis.RespParsers.RedisValueArray", + "global::StackExchange.Redis.HashEntry[]" => "global::RESPite.StackExchange.Redis.RespParsers.HashEntryArray", + "global::StackExchange.Redis.SortedSetEntry[]" => "global::RESPite.StackExchange.Redis.RespParsers.SortedSetEntryArray", + "global::StackExchange.Redis.SortedSetEntry?" => "global::RESPite.StackExchange.Redis.RespParsers.SortedSetEntry", + "global::StackExchange.Redis.Lease" => "global::RESPite.StackExchange.Redis.RespParsers.BytesLease", + _ => null, + }; + + private enum TupleMode + { + AnonTuple, + NamedTuple, + Values, + SyntheticNames, + } + + private static string RemovePartial(string modifiers) + { + if (string.IsNullOrWhiteSpace(modifiers) || !modifiers.Contains("partial")) return modifiers; + if (modifiers == "partial") return ""; + if (modifiers.StartsWith("partial ")) return modifiers.Substring(8); + if (modifiers.EndsWith(" partial")) return modifiers.Substring(0, modifiers.Length - 8); + return modifiers.Replace(" partial ", " "); + } + + [Flags] + private enum MethodFlags + { + None = 0, + RespOperation = 1 << 0, + ExtensionMethod = 1 << 1, + } + + [Flags] + private enum ParameterFlags + { + // ReSharper disable once UnusedMember.Local + None = 0, + Parameter = 1 << 0, + Data = 1 << 1, + DataParameter = Data | Parameter, + Key = 1 << 2, + CommandFlags = 1 << 3, + ValueType = 1 << 4, + Nullable = 1 << 5, + Collection = 1 << 6, + CollectionWithCount = 1 << 7, // has .Count, otherwise assumed to have .Length + ImmutableArray = 1 << 8, + IgnoreExpression = 1 << 9, + } + + // compares whether a formatter can be shared, which depends on the key index and types (not names) + private sealed class + FormatterComparer + : IEqualityComparer> + { + private FormatterComparer() { } + public static readonly FormatterComparer Default = new(); + + public bool Equals( + EasyArray x, + EasyArray y) + { + if (x.Length != y.Length) return false; + for (int i = 0; i < x.Length; i++) + { + var px = x[i]; + var py = y[i]; + if (px.Type != py.Type || px.Flags != py.Flags) return false; + // literals need to match by name too + if (!px.Literals.SequenceEqual(py.Literals)) return false; + } + + return true; + } + + public int GetHashCode( + EasyArray obj) + { + var hash = obj.Length; + foreach (var p in obj.Span) + { + hash ^= p.Type.GetHashCode() ^ (int)p.Flags ^ p.Literals.Length; + } + + return hash; + } + } +} diff --git a/eng/StackExchange.Redis.Build/RespCommandGenerator.md b/eng/StackExchange.Redis.Build/RespCommandGenerator.md new file mode 100644 index 000000000..1a3ec5fbc --- /dev/null +++ b/eng/StackExchange.Redis.Build/RespCommandGenerator.md @@ -0,0 +1,18 @@ +# RespCommandGenerator + +Emit basic RESP command bodies. + +The purpose of this generator is to interpret inputs like: + +``` c# +[RespCommand] // optional: include explicit command text +public int void Foo(string key, int delta, double x); +``` + +and implement the relevant sync and async core logic, including +implementing a custom `IRespFormatter<(string, int, double)>`. Note that +the formatter can be reused between commands, so the names are not used internally. + +Note that parameters named `key` are detected automatically for sharding purposes; +when this is not suitable,`[Key]` can be used instead to denote a parameter to use +for sharding - for example `partial void Rename([Key] string fromKey, string toKey)`. \ No newline at end of file diff --git a/eng/StackExchange.Redis.Build/StackExchange.Redis.Build.csproj b/eng/StackExchange.Redis.Build/StackExchange.Redis.Build.csproj index f875133ba..f57005154 100644 --- a/eng/StackExchange.Redis.Build/StackExchange.Redis.Build.csproj +++ b/eng/StackExchange.Redis.Build/StackExchange.Redis.Build.csproj @@ -5,6 +5,7 @@ enable enable true + true @@ -16,5 +17,4 @@ FastHash.cs - diff --git a/global.json b/global.json index 35e954767..f00fd8fcc 100644 --- a/global.json +++ b/global.json @@ -1,6 +1,6 @@ { "sdk": { - "allowPrerelease": false + "allowPrerelease": true } } \ No newline at end of file diff --git a/src/Directory.Build.props b/src/Directory.Build.props index 06e403ebb..3d2acbaaf 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -6,11 +6,16 @@ false - + + + + + $(MSBuildWarningsAsMessages);MSB3277 + diff --git a/src/RESP.Core/AmbientBufferWriter.cs b/src/RESP.Core/AmbientBufferWriter.cs new file mode 100644 index 000000000..bc64abc75 --- /dev/null +++ b/src/RESP.Core/AmbientBufferWriter.cs @@ -0,0 +1,93 @@ +using System; +using System.Buffers; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Resp; + +internal sealed class AmbientBufferWriter : IBufferWriter +{ + [ThreadStatic] + private static AmbientBufferWriter? _threadStaticInstance; + + public static AmbientBufferWriter Get(int estimatedSize) + { + var obj = _threadStaticInstance ??= new AmbientBufferWriter(); + obj.Init(estimatedSize); + return obj; + } + + private byte[] _buffer = []; + private int _committed; + + private void Init(int size) + { + _committed = 0; + if (size < 0) size = 0; + if (_buffer.Length < size) + { + DemandCapacity(size); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void DemandCapacity(int size) + { + const int MIN_BUFFER = 32; + size = Math.Max(size, MIN_BUFFER); + + if (_committed + size > _buffer.Length) + { + GrowBy(size); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void GrowBy(int length) + { + var newSize = Math.Max(_committed + length, checked((_buffer.Length * 3) / 2)); + byte[] newBuffer = ArrayPool.Shared.Rent(newSize), oldBuffer = _buffer; + if (_committed != 0) + { + new ReadOnlySpan(oldBuffer, 0, _committed).CopyTo(newBuffer); + } + + _buffer = newBuffer; + ArrayPool.Shared.Return(oldBuffer); + } + + internal byte[] Detach(out int length) + { + length = _committed; + if (length == 0) return []; + var result = _buffer; + _buffer = []; + _committed = 0; + return result; + } + + public void Advance(int count) + { + var capacity = _buffer.Length - _committed; + if (count < 0 || count > capacity) Throw(); + { + _committed += count; + } + + static void Throw() => throw new ArgumentOutOfRangeException(nameof(count)); + } + + public Memory GetMemory(int sizeHint = 0) + { + DemandCapacity(sizeHint); + return new(_buffer, _committed, _buffer.Length - _committed); + } + + public Span GetSpan(int sizeHint = 0) + { + DemandCapacity(sizeHint); + return new(_buffer, _committed, _buffer.Length - _committed); + } + + internal void Reset() => _committed = 0; +} diff --git a/src/RESP.Core/BatchConnection.cs b/src/RESP.Core/BatchConnection.cs new file mode 100644 index 000000000..29ce1dfef --- /dev/null +++ b/src/RESP.Core/BatchConnection.cs @@ -0,0 +1,229 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.InteropServices; +using System.Threading.Tasks; + +namespace Resp; + +public interface IBatchConnection : IRespConnection +{ + Task FlushAsync(); + void Flush(); +} + +internal sealed class BatchConnection : IBatchConnection +{ + private bool _isDisposed; + private readonly List _unsent; + private readonly IRespConnection _tail; + private readonly RespContext _context; + + public BatchConnection(in RespContext context, int sizeHint) + { + // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract - an abundance of caution + var tail = context.Connection; + if (tail is not { CanWrite: true }) ThrowNonWritable(); + if (tail is BatchConnection) ThrowBatch(); + + _unsent = sizeHint <= 0 ? [] : new List(sizeHint); + _tail = tail!; + _context = context.WithConnection(this); + static void ThrowBatch() => throw new ArgumentException("Nested batches are not supported", nameof(tail)); + + static void ThrowNonWritable() => + throw new ArgumentException("A writable connection is required", nameof(tail)); + } + + public void Dispose() + { + lock (_unsent) + { + /* everyone else checks disposal inside the lock, so: + once we've set this, we can be sure that no more + items will be added */ + _isDisposed = true; + } +#if NET5_0_OR_GREATER + var span = CollectionsMarshal.AsSpan(_unsent); + foreach (var message in span) + { + message.TrySetException(new ObjectDisposedException(ToString())); + } +#else + foreach (var message in _unsent) + { + message.TrySetException(new ObjectDisposedException(ToString())); + } +#endif + _unsent.Clear(); + } + + public ValueTask DisposeAsync() + { + Dispose(); + return default; + } + + public RespConfiguration Configuration => _tail.Configuration; + public bool CanWrite => _tail.CanWrite; + + public int Outstanding + { + get + { + lock (_unsent) + { + return _unsent.Count; + } + } + } + + public ref readonly RespContext Context => ref _context; + + private const string SyncMessage = "Batch connections do not support synchronous sends"; + public void Send(IRespMessage message) => throw new NotSupportedException(SyncMessage); + + public void Send(ReadOnlySpan messages) => throw new NotSupportedException(SyncMessage); + + private void ThrowIfDisposed() + { + if (_isDisposed) Throw(); + static void Throw() => throw new ObjectDisposedException(nameof(BatchConnection)); + } + + public Task SendAsync(IRespMessage message) + { + lock (_unsent) + { + ThrowIfDisposed(); + _unsent.Add(message); + } + + return Task.CompletedTask; + } + + public Task SendAsync(ReadOnlyMemory messages) + { + if (messages.Length != 0) + { + lock (_unsent) + { + ThrowIfDisposed(); +#if NET8_0_OR_GREATER + _unsent.AddRange(messages.Span); // internally optimized +#else + // two-step; first ensure capacity, then add in loop +#if NET6_0_OR_GREATER + _unsent.EnsureCapacity(_unsent.Count + messages.Length); +#else + var required = _unsent.Count + messages.Length; + if (_unsent.Capacity < required) + { + const int maxLength = 0X7FFFFFC7; // not directly available on down-level runtimes :( + var newCapacity = _unsent.Capacity * 2; // try doubling + if ((uint)newCapacity > maxLength) newCapacity = maxLength; // account for max + if (newCapacity < required) newCapacity = required; // in case doubling wasn't enough + _unsent.Capacity = newCapacity; + } +#endif + foreach (var message in messages.Span) + { + _unsent.Add(message); + } +#endif + } + } + + return Task.CompletedTask; + } + + private int Flush(out IRespMessage[] oversized, out IRespMessage? single) + { + lock (_unsent) + { + var count = _unsent.Count; + switch (count) + { + case 0: + oversized = []; + single = null; + break; + case 1: + oversized = []; + single = _unsent[0]; + break; + default: + oversized = ArrayPool.Shared.Rent(count); + single = null; + _unsent.CopyTo(oversized); + break; + } + + _unsent.Clear(); + return count; + } + } + + public Task FlushAsync() + { + var count = Flush(out var oversized, out var single); + return count switch + { + 0 => Task.CompletedTask, + 1 => _tail.SendAsync(single!), + _ => SendAndRecycleAsync(_tail, oversized, count), + }; + + static async Task SendAndRecycleAsync(IRespConnection tail, IRespMessage[] oversized, int count) + { + try + { + await tail.SendAsync(oversized.AsMemory(0, count)).ConfigureAwait(false); + ArrayPool.Shared.Return(oversized); // only on success, in case captured + } + catch (Exception ex) + { + foreach (var message in oversized.AsSpan(0, count)) + { + message.TrySetException(ex); + } + + throw; + } + } + } + + public void Flush() + { + var count = Flush(out var oversized, out var single); + switch (count) + { + case 0: + return; + case 1: + _tail.Send(single!); + return; + } + + try + { + _tail.Send(oversized.AsSpan(0, count)); + } + catch (Exception ex) + { + foreach (var message in oversized.AsSpan(0, count)) + { + message.TrySetException(ex); + } + + throw; + } + finally + { + // in the sync case, Send takes a span - hence can't have been captured anywhere; always recycle + ArrayPool.Shared.Return(oversized); + } + } +} diff --git a/src/RESP.Core/Builder.cs b/src/RESP.Core/Builder.cs new file mode 100644 index 000000000..0615bd3f9 --- /dev/null +++ b/src/RESP.Core/Builder.cs @@ -0,0 +1,33 @@ +using System; +using System.Threading.Tasks; + +namespace Resp; + +public readonly ref struct RespMessageBuilder(RespContext context, ReadOnlySpan command, TRequest value, IRespFormatter formatter) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif +{ + private readonly ReadOnlySpan _command = command; + private readonly TRequest _value = value; // cannot inline to .ctor because of "allows ref struct" + + public TResponse Wait() + => Message.Send(context, _command, _value, formatter, RespParsers.Get()); + public TResponse Wait(IRespParser parser) + => Message.Send(context, _command, _value, formatter, parser); + + public void Wait() + => Message.Send(context, _command, _value, formatter, RespParsers.Success); + public void Wait(IRespParser parser) + => Message.Send(context, _command, _value, formatter, parser); + + public ValueTask AsValueTask() + => Message.SendAsync(context, _command, _value, formatter, RespParsers.Get()); + public ValueTask AsValueTask(IRespParser parser) + => Message.SendAsync(context, _command, _value, formatter, parser); + + public ValueTask AsValueTask() + => Message.SendAsync(context, _command, _value, formatter, RespParsers.Success); + public ValueTask AsValueTask(IRespParser parser) + => Message.SendAsync(context, _command, _value, formatter, parser); +} diff --git a/src/RESP.Core/CustomNetworkStream.cs b/src/RESP.Core/CustomNetworkStream.cs new file mode 100644 index 000000000..4673ea4ed --- /dev/null +++ b/src/RESP.Core/CustomNetworkStream.cs @@ -0,0 +1,297 @@ +#if NETCOREAPP3_0_OR_GREATER + +using System; +using System.Diagnostics; +using System.IO; +using System.Net.Sockets; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Tasks.Sources; + +namespace Resp; + +internal sealed class CustomNetworkStream(Socket socket) : Stream +{ + private SocketAwaitableEventArgs _readArgs = new(), _writeArgs = new(); + private SocketAwaitableEventArgs ReadArgs() => _readArgs.Next(); + private SocketAwaitableEventArgs WriteArgs() => _writeArgs.Next(); + + public override void Close() + { + socket.Close(); + } + + protected override void Dispose(bool disposing) + { + if (disposing) + { + socket.Dispose(); + _readArgs.Dispose(); + _writeArgs.Dispose(); + } + + base.Dispose(disposing); + } + + public override void Flush() { } + + public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask; + + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + + public override void SetLength(long value) => throw new NotSupportedException(); + + public override int Read(byte[] buffer, int offset, int count) => + socket.Receive(buffer, offset, count, SocketFlags.None); + + public override void Write(byte[] buffer, int offset, int count) => + socket.Send(buffer, offset, count, SocketFlags.None); + + public override int Read(Span buffer) => socket.Receive(buffer); + + public override void Write(ReadOnlySpan buffer) => socket.Send(buffer); + + private static void ThrowCancellable() => throw new NotSupportedException( + "Cancellable operations are not supported on this stream; cancellation should be handled at the message level, not the IO level."); + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (cancellationToken.CanBeCanceled) ThrowCancellable(); + var args = ReadArgs(); + args.SetBuffer(buffer, offset, count); + if (socket.ReceiveAsync(args)) return args.Pending().AsTask(); + return Task.FromResult(args.GetInlineResult()); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (cancellationToken.CanBeCanceled) ThrowCancellable(); + var args = WriteArgs(); + args.SetBuffer(buffer, offset, count); + if (socket.SendAsync(args)) return args.Pending().AsTask(); + args.GetInlineResult(); // check for socket errors + return Task.CompletedTask; + } + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + if (cancellationToken.CanBeCanceled) ThrowCancellable(); + var args = ReadArgs(); + args.SetBuffer(buffer); + if (socket.ReceiveAsync(args)) return args.Pending(); + return new(args.GetInlineResult()); + } + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + if (cancellationToken.CanBeCanceled) ThrowCancellable(); + var args = WriteArgs(); + args.SetBuffer(MemoryMarshal.AsMemory(buffer)); + if (socket.SendAsync(args)) return args.PendingNoValue(); + args.GetInlineResult(); // check for socket errors + return default; + } + + public override int ReadByte() + { + Span buffer = stackalloc byte[1]; + int count = socket.Receive(buffer); + return count <= 0 ? -1 : buffer[0]; + } + + public override void WriteByte(byte value) + { + ReadOnlySpan buffer = [value]; + socket.Send(buffer); + } + + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) + { + var args = ReadArgs(); + args.SetBuffer(buffer, offset, count); + args.CompletedSynchronously = false; + if (socket.SendAsync(args)) + { + args.OnCompleted(callback, state); + } + else + { + args.CompletedSynchronously = true; + callback?.Invoke(args); + } + + return args; + } + + public override int EndRead(IAsyncResult asyncResult) => ((SocketAwaitableEventArgs)asyncResult).GetInlineResult(); + + public override IAsyncResult BeginWrite( + byte[] buffer, + int offset, + int count, + AsyncCallback? callback, + object? state) + { + var args = WriteArgs(); + args.SetBuffer(buffer, offset, count); + args.CompletedSynchronously = false; + if (socket.SendAsync(args)) + { + args.OnCompleted(callback, state); + } + else + { + args.CompletedSynchronously = true; + callback?.Invoke(args); + } + + return args; + } + + public override void EndWrite(IAsyncResult asyncResult) => + ((SocketAwaitableEventArgs)asyncResult).GetInlineResult(); + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => true; + public override long Length => throw new NotSupportedException(); + + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override bool CanTimeout => socket.ReceiveTimeout != 0 || socket.SendTimeout != 0; + + public override int ReadTimeout + { + get => socket.ReceiveTimeout; + set => socket.ReceiveTimeout = value; + } + + public override int WriteTimeout + { + get => socket.SendTimeout; + set => socket.SendTimeout = value; + } + + // inspired from Pipelines.Sockets.Unofficial and Kestrel's SocketAwaitableEventArgs; extended to support more scenarios + private sealed class SocketAwaitableEventArgs : SocketAsyncEventArgs, + IValueTaskSource, IValueTaskSource, IAsyncResult + { +#if NET5_0_OR_GREATER + public SocketAwaitableEventArgs() : base(unsafeSuppressExecutionContextFlow: true) { } +#else + public SocketAwaitableEventArgs() { } +#endif + private static readonly Action ContinuationCompleted = _ => { }; + + public WaitHandle AsyncWaitHandle => throw new NotSupportedException(); + public bool CompletedSynchronously { get; set; } + private volatile Action? _continuation; + + private object? _asyncCallbackState; // need an additional state here, unless we introduce type-check overhead + object? IAsyncResult.AsyncState => _asyncCallbackState; + private Action? _reusedAsyncCallback; + private Action AsyncCallback => _reusedAsyncCallback ??= OnAsyncCallback; + + public ValueTask Pending() => new(this, _token); + public ValueTask PendingNoValue() => new(this, _token); + private short _token; + + public SocketAwaitableEventArgs Next() + { + unchecked { _token++; } + + return this; + } + + private void ThrowToken() => throw new InvalidOperationException("Invalid token - overlapped IO error?"); + + private void OnAsyncCallback(object? state) + { + if (state is WaitCallback wc) + { + wc(_asyncCallbackState); + } + } + + protected override void OnCompleted(SocketAsyncEventArgs args) + { + Debug.Assert(ReferenceEquals(args, this), "Incorrect SocketAsyncEventArgs"); + var c = _continuation; + + if (c != null || (c = Interlocked.CompareExchange(ref _continuation, ContinuationCompleted, null)) != null) + { + var continuationState = UserToken; + UserToken = null; + _continuation = ContinuationCompleted; // in case someone's polling IsCompleted + + c(continuationState); // note: inline continuation + } + } + + public int GetInlineResult() + { + _continuation = null; + if (SocketError != SocketError.Success) + { + ThrowSocketError(SocketError); + } + + return BytesTransferred; + } + + void IValueTaskSource.GetResult(short token) => GetResult(token); + + public int GetResult(short token) + { + if (token != _token) ThrowToken(); + _continuation = null; + + if (SocketError != SocketError.Success) + { + ThrowSocketError(SocketError); + } + + return BytesTransferred; + } + + private static void ThrowSocketError(SocketError e) => throw new SocketException((int)e); + + public ValueTaskSourceStatus GetStatus(short token) + { + if (token != _token) ThrowToken(); + return !ReferenceEquals(_continuation, ContinuationCompleted) ? ValueTaskSourceStatus.Pending : + SocketError == SocketError.Success ? ValueTaskSourceStatus.Succeeded : + ValueTaskSourceStatus.Faulted; + } + + public bool IsCompleted => ReferenceEquals(_continuation, ContinuationCompleted); + + public void OnCompleted(AsyncCallback? callback, object? state) + { + _asyncCallbackState = state; + OnCompleted(AsyncCallback, callback, _token, ValueTaskSourceOnCompletedFlags.None); + } + + public void OnCompleted( + Action continuation, + object? state, + short token, + ValueTaskSourceOnCompletedFlags flags) + { + if (token != _token) ThrowToken(); + UserToken = state; + var prevContinuation = Interlocked.CompareExchange(ref _continuation, continuation, null); + if (ReferenceEquals(prevContinuation, ContinuationCompleted)) + { + UserToken = null; + ThreadPool.UnsafeQueueUserWorkItem(continuation, state, preferLocal: true); + } + } + } +} +#endif diff --git a/src/RESP.Core/CycleBuffer.Simple.cs b/src/RESP.Core/CycleBuffer.Simple.cs new file mode 100644 index 000000000..e469a0fab --- /dev/null +++ b/src/RESP.Core/CycleBuffer.Simple.cs @@ -0,0 +1,114 @@ +/* +using System; +using System.Buffers; +using System.Diagnostics; + +#pragma warning disable SA1205 // accessibility on partial - for debugging/test practicality + +partial struct CycleBuffer // basic impl for debugging / validation; just uses single-buffer pack-down +{ + private byte[] _buffer; + private int _committed; + + public void DiscardCommitted(long fullyConsumed) + => DiscardCommitted(checked((int)fullyConsumed)); + + public void DiscardCommitted(int fullyConsumed) + { + Debug.Assert(fullyConsumed >= 0 & fullyConsumed <= _committed); + var remaining = _committed - fullyConsumed; + if (remaining != 0) + { + var buffer = _buffer; + buffer.AsSpan(fullyConsumed, remaining).CopyTo(buffer); + } + + _committed -= fullyConsumed; + } + + public ReadOnlySequence GetAllCommitted() + => new(_buffer, 0, _committed); + + public bool TryGetCommitted(out ReadOnlySpan span) + { + span = _buffer.AsSpan(0, _committed); + return true; + } + + public void Release() + { + var buffer = _buffer; + _committed = 0; + _buffer = []; + ArrayPool.Shared.Return(buffer); + } + + public int PageSize { get; } + + private CycleBuffer(int pageSize) + { + _buffer = []; + PageSize = pageSize; + } + + public static CycleBuffer Create(MemoryPool? pool = null, int pageSize = 0) + { + _ = pool; + return new(Math.Max(pageSize, 1024)); + } + + public void Commit(int bytesRead) + { + Debug.Assert(bytesRead >= 0 & bytesRead <= UncommittedAvailable); + _committed += bytesRead; + } + + public Span GetUncommittedSpan(int hint = 1) + { + if (UncommittedAvailable < hint) Grow(hint); + return _buffer.AsSpan(_committed); + } + public Memory GetUncommittedMemory(int hint = 1) + { + if (UncommittedAvailable < hint) Grow(hint); + return _buffer.AsMemory(_committed); + } + + private void Grow(int hint) + { + hint = Math.Max(hint, 128); // at least a reasonable size + var newLength = Math.Max(_committed + hint, _committed * 2); // what we need, or double what we have; the larger + + var newBuffer = ArrayPool.Shared.Rent(newLength); + var oldBuffer = _buffer; + Debug.Assert(newBuffer.Length > oldBuffer.Length, " should have increased"); + oldBuffer.AsSpan(0, _committed).CopyTo(newBuffer); + ArrayPool.Shared.Return(oldBuffer); + _buffer = newBuffer; + } + + public int UncommittedAvailable => _buffer.Length - _committed; + public bool CommittedIsEmpty => _committed == 0; + + public int GetCommittedLength() => _committed; + + public bool TryGetFirstCommittedSpan(bool fullOnly, out ReadOnlySpan span) + { + var buffer = _buffer; + if (fullOnly) + { + if (_committed >= PageSize) + { + span = buffer.AsSpan(0, _committed); + return true; + } + // offer up a reasonable page size + span = default; + return false; + } + + span = buffer.AsSpan(0, _committed); + return _committed != 0; + } +} +*/ diff --git a/src/RESP.Core/CycleBuffer.cs b/src/RESP.Core/CycleBuffer.cs new file mode 100644 index 000000000..d3411bc2c --- /dev/null +++ b/src/RESP.Core/CycleBuffer.cs @@ -0,0 +1,707 @@ +using System; +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Threading; + +#pragma warning disable SA1205 // accessibility on partial - for debugging/test practicality + +namespace Resp; + +/// +/// Manages the state for a based IO buffer. Unlike Pipe, +/// it is not intended for a separate producer-consumer - there is no thread-safety, and no +/// activation; it just handles the buffers. It is intended to be used as a mutable (non-readonly) +/// field in a type that performs IO; the internal state mutates - it should not be passed around. +/// +/// Notionally, there is an uncommitted area (write) and a committed area (read). Process: +/// - producer loop (*note no concurrency**) +/// - call to get a new scratch +/// - (write to that span) +/// - call to mark complete portions +/// - consumer loop (*note no concurrency**) +/// - call to see if there is a single-span chunk; otherwise +/// - call to get the multi-span chunk +/// - (process none, some, or all of that data) +/// - call to indicate how much data is no longer needed +/// Emphasis: no concurrency! This is intended for a single worker acting as both producer and consumer. +/// +/// There is a *lot* of validation in debug mode; we want to be super sure that we don't corrupt buffer state. +/// +partial struct CycleBuffer +{ + // note: if someone uses an uninitialized CycleBuffer (via default): that's a skills issue; git gud + public static CycleBuffer Create(MemoryPool? pool = null, int pageSize = DefaultPageSize) + { + pool ??= MemoryPool.Shared; + if (pageSize <= 0) pageSize = DefaultPageSize; + if (pageSize > pool.MaxBufferSize) pageSize = pool.MaxBufferSize; + + return new CycleBuffer(pool, pageSize); + } + + private CycleBuffer(MemoryPool pool, int pageSize) + { + Pool = pool; + PageSize = pageSize; + } + + private const int DefaultPageSize = 8 * 1024; + + public int PageSize { get; } + public MemoryPool Pool { get; } + + private Segment? startSegment, endSegment; + + private int endSegmentCommitted, endSegmentLength; + + public bool TryGetCommitted(out ReadOnlySpan span) + { + DebugAssertValid(); + if (!ReferenceEquals(startSegment, endSegment)) + { + span = default; + return false; + } + + span = startSegment is null ? default : startSegment.Memory.Span.Slice(start: 0, length: endSegmentCommitted); + return true; + } + + /// + /// Commits data written to buffers from , making it available for consumption + /// via . This compares to . + /// + public void Commit(int count) + { + DebugAssertValid(); + if (count <= 0) + { + if (count < 0) Throw(); + return; + } + + var available = endSegmentLength - endSegmentCommitted; + if (count > available) Throw(); + endSegmentCommitted += count; + DebugAssertValid(); + + static void Throw() => throw new ArgumentOutOfRangeException(nameof(count)); + } + + public bool CommittedIsEmpty => ReferenceEquals(startSegment, endSegment) & endSegmentCommitted == 0; + + /// + /// Marks committed data as fully consumed; it will no longer appear in later calls to . + /// + public void DiscardCommitted(int count) + { + DebugAssertValid(); + // optimize for most common case, where we consume everything + if (ReferenceEquals(startSegment, endSegment) + & count == endSegmentCommitted + & count > 0) + { + /* + we are consuming all the data in the single segment; we can + just reset that segment back to full size and re-use as-is; + note that we also know that there must *be* a segment + for the count check to pass + */ + endSegmentCommitted = 0; + endSegmentLength = endSegment!.Untrim(expandBackwards: true); + DebugAssertValid(0); + DebugCounters.OnDiscardFull(count); + } + else if (count == 0) + { + // nothing to do + } + else + { + DiscardCommittedSlow(count); + } + } + + public void DiscardCommitted(long count) + { + DebugAssertValid(); + // optimize for most common case, where we consume everything + if (ReferenceEquals(startSegment, endSegment) + & count == endSegmentCommitted + & count > 0) // checks sign *and* non-trimmed + { + // see for logic + endSegmentCommitted = 0; + endSegmentLength = endSegment!.Untrim(expandBackwards: true); + DebugAssertValid(0); + DebugCounters.OnDiscardFull(count); + } + else if (count == 0) + { + // nothing to do + } + else + { + DiscardCommittedSlow(count); + } + } + + private void DiscardCommittedSlow(long count) + { + DebugCounters.OnDiscardPartial(count); +#if DEBUG + var originalLength = GetCommittedLength(); + var originalCount = count; + var expectedLength = originalLength - originalCount; + string blame = nameof(DiscardCommittedSlow); +#endif + while (count > 0) + { + DebugAssertValid(); + var segment = startSegment; + if (segment is null) break; + if (ReferenceEquals(segment, endSegment)) + { + // first==final==only segment + if (count == endSegmentCommitted) + { + endSegmentLength = startSegment!.Untrim(); + endSegmentCommitted = 0; // = untrimmed and unused +#if DEBUG + blame += ",full-final (t)"; +#endif + } + else + { + // discard from the start + int count32 = checked((int)count); + segment.TrimStart(count32); + endSegmentLength -= count32; + endSegmentCommitted -= count32; +#if DEBUG + blame += ",partial-final"; +#endif + } + + count = 0; + break; + } + else if (count < segment.Length) + { + // multiple, but can take some (not all) of the first buffer +#if DEBUG + var len = segment.Length; +#endif + segment.TrimStart((int)count); + Debug.Assert(segment.Length > 0, "parial trim should have left non-empty segment"); +#if DEBUG + Debug.Assert(segment.Length == len - count, "trim failure"); + blame += ",partial-first"; +#endif + count = 0; + break; + } + else + { + // multiple; discard the entire first segment + count -= segment.Length; + startSegment = + segment.ResetAndGetNext(); // we already did a ref-check, so we know this isn't going past endSegment + endSegment!.AppendOrRecycle(segment, maxDepth: 2); + DebugAssertValid(); +#if DEBUG + blame += ",full-first"; +#endif + } + } + + if (count != 0) ThrowCount(); +#if DEBUG + DebugAssertValid(expectedLength, blame); + _ = originalLength; + _ = originalCount; +#endif + + [DoesNotReturn] + static void ThrowCount() => throw new ArgumentOutOfRangeException(nameof(count)); + } + + [Conditional("DEBUG")] + private void DebugAssertValid(long expectedCommittedLength, [CallerMemberName] string caller = "") + { + DebugAssertValid(); + var actual = GetCommittedLength(); + Debug.Assert( + expectedCommittedLength >= 0, + $"Expected committed length is just... wrong: {expectedCommittedLength} (from {caller})"); + Debug.Assert( + expectedCommittedLength == actual, + $"Committed length mismatch: expected {expectedCommittedLength}, got {actual} (from {caller})"); + } + + [Conditional("DEBUG")] + private void DebugAssertValid() + { + if (startSegment is null) + { + Debug.Assert( + endSegmentLength == 0 & endSegmentCommitted == 0, + "un-init state should be zero"); + return; + } + + Debug.Assert(endSegment is not null, "end segment must not be null if start segment exists"); + Debug.Assert( + endSegmentLength == endSegment!.Length, + $"end segment length is incorrect - expected {endSegmentLength}, got {endSegment.Length}"); + Debug.Assert(endSegmentCommitted <= endSegmentLength, $"end segment is over-committed - {endSegmentCommitted} of {endSegmentLength}"); + + // check running indices + startSegment?.DebugAssertValidChain(); + } + + public long GetCommittedLength() + { + DebugAssertValid(); + if (ReferenceEquals(startSegment, endSegment)) + { + return endSegmentCommitted; + } + + // note that the start-segment is pre-trimmed; we don't need to account for an offset on the left + return (endSegment!.RunningIndex + endSegmentCommitted) - startSegment!.RunningIndex; + } + + /// + /// When used with , this means "any non-empty buffer". + /// + public const int GetAnything = 0; + + /// + /// When used with , this means "any full buffer". + /// + public const int GetFullPagesOnly = -1; + + public bool TryGetFirstCommittedSpan(int minBytes, out ReadOnlySpan span) + { + DebugAssertValid(); + if (TryGetFirstCommittedMemory(minBytes, out var memory)) + { + span = memory.Span; + return true; + } + + span = default; + return false; + } + + /// + /// The minLength arg: -ve means "full segments only" (useful when buffering outbound network data to avoid + /// packet fragmentation); otherwise, it is the minimum length we want. + /// + public bool TryGetFirstCommittedMemory(int minBytes, out ReadOnlyMemory memory) + { + if (minBytes == 0) minBytes = 1; // success always means "at least something" + DebugAssertValid(); + if (ReferenceEquals(startSegment, endSegment)) + { + // single page + var available = endSegmentCommitted; + if (available == 0) + { + // empty (includes uninitialized) + memory = default; + return false; + } + + memory = startSegment!.Memory; + var memLength = memory.Length; + if (available == memLength) + { + // full segment; is it enough to make the caller happy? + return available >= minBytes; + } + + // partial segment (and we know it isn't empty) + memory = memory.Slice(start: 0, length: available); + return available >= minBytes & minBytes > 0; // last check here applies the -ve logic + } + + // multi-page; hand out the first page (which is, by definition: full) + memory = startSegment!.Memory; + return memory.Length >= minBytes; + } + + /// + /// Note that this chain is invalidated by any other operations; no concurrency. + /// + public ReadOnlySequence GetAllCommitted() + { + if (ReferenceEquals(startSegment, endSegment)) + { + // single segment, fine + return startSegment is null + ? default + : new ReadOnlySequence(startSegment.Memory.Slice(start: 0, length: endSegmentCommitted)); + } + +#if PARSE_DETAIL + long length = GetCommittedLength(); +#endif + ReadOnlySequence ros = new(startSegment!, 0, endSegment!, endSegmentCommitted); +#if PARSE_DETAIL + Debug.Assert(ros.Length == length, $"length mismatch: calculated {length}, actual {ros.Length}"); +#endif + return ros; + } + + private Segment GetNextSegment() + { + DebugAssertValid(); + if (endSegment is not null) + { + endSegment.TrimEnd(endSegmentCommitted); + Debug.Assert(endSegment.Length == endSegmentCommitted, "trim failure"); + endSegmentLength = endSegmentCommitted; + DebugAssertValid(); + + var spare = endSegment.Next; + if (spare is not null) + { + // we already have a dangling segment; just update state + endSegment.DebugAssertValidChain(); + endSegment = spare; + endSegmentCommitted = 0; + endSegmentLength = spare.Length; + DebugAssertValid(); + return spare; + } + } + + Segment newSegment = Segment.Create(Pool.Rent(PageSize)); + if (endSegment is null) + { + // tabula rasa + endSegmentLength = newSegment.Length; + endSegment = startSegment = newSegment; + DebugAssertValid(); + return newSegment; + } + + endSegment.Append(newSegment); + endSegmentCommitted = 0; + endSegmentLength = newSegment.Length; + endSegment = newSegment; + DebugAssertValid(); + return newSegment; + } + + /// + /// Gets a scratch area for new data; this compares to . + /// + public Span GetUncommittedSpan(int hint = 0) + => GetUncommittedMemory(hint).Span; + + /// + /// Gets a scratch area for new data; this compares to . + /// + public Memory GetUncommittedMemory(int hint = 0) + { + DebugAssertValid(); + var segment = endSegment; + if (segment is not null) + { + var memory = segment.Memory; + if (endSegmentCommitted != 0) memory = memory.Slice(start: endSegmentCommitted); + if (hint <= 0) // allow anything non-empty + { + if (!memory.IsEmpty) return MemoryMarshal.AsMemory(memory); + } + else if (memory.Length >= Math.Min(hint, PageSize >> 2)) // respect the hint up to 1/4 of the page size + { + return MemoryMarshal.AsMemory(memory); + } + } + + // new segment, will always be entire + return MemoryMarshal.AsMemory(GetNextSegment().Memory); + } + + public int UncommittedAvailable + { + get + { + DebugAssertValid(); + return endSegmentLength - endSegmentCommitted; + } + } + + private sealed class Segment : ReadOnlySequenceSegment + { + private Segment() { } + private IMemoryOwner _lease = NullLease.Instance; + private static Segment? _spare; + private Flags _flags; + + [Flags] + private enum Flags + { + None = 0, + StartTrim = 1 << 0, + EndTrim = 1 << 2, + } + + public static Segment Create(IMemoryOwner lease) + { + Debug.Assert(lease is not null, "null lease"); + var memory = lease!.Memory; + if (memory.IsEmpty) ThrowEmpty(); + + var obj = Interlocked.Exchange(ref _spare, null) ?? new(); + return obj.Init(lease, memory); + static void ThrowEmpty() => throw new InvalidOperationException("leased segment is empty"); + } + + private Segment Init(IMemoryOwner lease, Memory memory) + { + _lease = lease; + Memory = memory; + return this; + } + + public int Length => Memory.Length; + + public void Append(Segment next) + { + Debug.Assert(Next is null, "current segment already has a next"); + Debug.Assert(next.Next is null && next.RunningIndex == 0, "inbound next segment is already in a chain"); + next.RunningIndex = RunningIndex + Length; + Next = next; + DebugAssertValidChain(); + } + + private void ApplyChainDelta(int delta) + { + if (delta != 0) + { + var node = Next; + while (node is not null) + { + node.RunningIndex += delta; + node = node.Next; + } + } + } + + public void TrimEnd(int newLength) + { + var delta = Length - newLength; + if (delta != 0) + { + // buffer wasn't fully used; trim + _flags |= Flags.EndTrim; + Memory = Memory.Slice(0, newLength); + ApplyChainDelta(-delta); + DebugAssertValidChain(); + } + } + + public void TrimStart(int remove) + { + if (remove != 0) + { + _flags |= Flags.StartTrim; + Memory = Memory.Slice(start: remove); + RunningIndex += remove; // so that ROS length keeps working; note we *don't* need to adjust the chain + DebugAssertValidChain(); + } + } + + public new Segment? Next + { + get => (Segment?)base.Next; + private set => base.Next = value; + } + + public Segment? ResetAndGetNext() + { + var next = Next; + Next = null; + RunningIndex = 0; + _flags = Flags.None; + Memory = _lease.Memory; // reset, in case we trimmed it + DebugAssertValidChain(); + return next; + } + + public void Recycle() + { + var lease = _lease; + _lease = NullLease.Instance; + lease.Dispose(); + Next = null; + Memory = default; + RunningIndex = 0; + _flags = Flags.None; + Interlocked.Exchange(ref _spare, this); + DebugAssertValidChain(); + } + + private sealed class NullLease : IMemoryOwner + { + private NullLease() { } + public static readonly NullLease Instance = new NullLease(); + public void Dispose() { } + + public Memory Memory => default; + } + + /// + /// Undo any trimming, returning the new full capacity. + /// + public int Untrim(bool expandBackwards = false) + { + var fullMemory = _lease.Memory; + var fullLength = fullMemory.Length; + var delta = fullLength - Length; + if (delta != 0) + { + _flags &= ~(Flags.StartTrim | Flags.EndTrim); + Memory = fullMemory; + if (expandBackwards & RunningIndex >= delta) + { + // push our origin earlier; only valid if + // we're the first segment, otherwise + // we break someone-else's chain + RunningIndex -= delta; + } + else + { + // push everyone else later + ApplyChainDelta(delta); + } + + DebugAssertValidChain(); + } + return fullLength; + } + + public bool StartTrimmed => (_flags & Flags.StartTrim) != 0; + public bool EndTrimmed => (_flags & Flags.EndTrim) != 0; + + [Conditional("DEBUG")] + public void DebugAssertValidChain([CallerMemberName] string blame = "") + { + var node = this; + var runningIndex = RunningIndex; + int index = 0; + while (node.Next is { } next) + { + index++; + var nextRunningIndex = runningIndex + node.Length; + if (nextRunningIndex != next.RunningIndex) ThrowRunningIndex(blame, index); + node = next; + runningIndex = nextRunningIndex; + static void ThrowRunningIndex(string blame, int index) => throw new InvalidOperationException( + $"Critical running index corruption in dangling chain, from '{blame}', segment {index}"); + } + } + + public void AppendOrRecycle(Segment segment, int maxDepth) + { + var node = this; + while (maxDepth-- > 0 && node is not null) + { + if (node.Next is null) // found somewhere to attach it + { + if (segment.Untrim() == 0) break; // turned out to be useless + segment.RunningIndex = node.RunningIndex + node.Length; + node.Next = segment; + return; + } + + node = node.Next; + } + + segment.Recycle(); + } + } + + /// + /// Discard all data and buffers. + /// + public void Release() + { + var node = startSegment; + startSegment = endSegment = null; + endSegmentCommitted = endSegmentLength = 0; + while (node is not null) + { + var next = node.Next; + node.Recycle(); + node = next; + } + } +} + +// this can be shared between CycleBuffer and CycleBuffer.Simple +partial struct CycleBuffer +{ + /// + /// Writes a value to the buffer; comparable to . + /// + public void Write(ReadOnlySpan value) + { + int srcLength = value.Length; + while (srcLength != 0) + { + var target = GetUncommittedSpan(hint: srcLength); + var tgtLength = target.Length; + if (tgtLength >= srcLength) + { + value.CopyTo(target); + Commit(srcLength); + return; + } + + value.Slice(0, tgtLength).CopyTo(target); + Commit(tgtLength); + value = value.Slice(tgtLength); + srcLength -= tgtLength; + } + } + + /// + /// Writes a value to the buffer; comparable to . + /// + public void Write(in ReadOnlySequence value) + { + if (value.IsSingleSegment) + { +#if NETCOREAPP3_0_OR_GREATER || NETSTANDARD2_1 + Write(value.FirstSpan); +#else + Write(value.First.Span); +#endif + } + else + { + WriteMultiSegment(ref this, in value); + } + + static void WriteMultiSegment(ref CycleBuffer @this, in ReadOnlySequence value) + { + foreach (var segment in value) + { +#if NETCOREAPP3_0_OR_GREATER || NETSTANDARD2_1 + @this.Write(value.FirstSpan); +#else + @this.Write(value.First.Span); +#endif + } + } + } +} diff --git a/src/RESP.Core/DebugCounters.cs b/src/RESP.Core/DebugCounters.cs new file mode 100644 index 000000000..48a9f08df --- /dev/null +++ b/src/RESP.Core/DebugCounters.cs @@ -0,0 +1,183 @@ +using System.Diagnostics; +using System.Threading; + +namespace Resp; +#if DEBUG +public partial class DebugCounters +#else +internal partial class DebugCounters +#endif +{ +#if DEBUG + private static int _tallyReadCount, + _tallyAsyncReadCount, + _tallyAsyncReadInlineCount, + _tallyWriteCount, + _tallyAsyncWriteCount, + _tallyAsyncWriteInlineCount, + _tallyCopyOutCount, + _tallyDiscardFullCount, + _tallyDiscardPartialCount, + _tallyPipelineFullAsyncCount, + _tallyPipelineSendAsyncCount, + _tallyPipelineFullSyncCount, + _tallyBatchWriteCount, + _tallyBatchWriteFullPageCount, + _tallyBatchWritePartialPageCount, + _tallyBatchWriteMessageCount; + + private static long _tallyWriteBytes, _tallyReadBytes, _tallyCopyOutBytes, _tallyDiscardAverage; +#endif + [Conditional("DEBUG")] + internal static void OnRead(int bytes) + { +#if DEBUG + Interlocked.Increment(ref _tallyReadCount); + if (bytes > 0) Interlocked.Add(ref _tallyReadBytes, bytes); +#endif + } + + public static void OnBatchWrite(int messageCount) + { +#if DEBUG + Interlocked.Increment(ref _tallyBatchWriteCount); + if (messageCount != 0) Interlocked.Add(ref _tallyBatchWriteMessageCount, messageCount); +#endif + } + + public static void OnBatchWriteFullPage() + { +#if DEBUG + Interlocked.Increment(ref _tallyBatchWriteFullPageCount); +#endif + } + public static void OnBatchWritePartialPage() + { +#if DEBUG + Interlocked.Increment(ref _tallyBatchWritePartialPageCount); +#endif + } + + [Conditional("DEBUG")] + internal static void OnAsyncRead(int bytes, bool inline) + { +#if DEBUG + Interlocked.Increment(ref inline ? ref _tallyAsyncReadInlineCount : ref _tallyAsyncReadCount); + if (bytes > 0) Interlocked.Add(ref _tallyReadBytes, bytes); +#endif + } + + [Conditional("DEBUG")] + internal static void OnWrite(int bytes) + { +#if DEBUG + Interlocked.Increment(ref _tallyWriteCount); + if (bytes > 0) Interlocked.Add(ref _tallyWriteBytes, bytes); +#endif + } + + [Conditional("DEBUG")] + internal static void OnAsyncWrite(int bytes, bool inline) + { +#if DEBUG + Interlocked.Increment(ref inline ? ref _tallyAsyncWriteInlineCount : ref _tallyAsyncWriteCount); + if (bytes > 0) Interlocked.Add(ref _tallyWriteBytes, bytes); +#endif + } + + [Conditional("DEBUG")] + internal static void OnCopyOut(int bytes) + { +#if DEBUG + Interlocked.Increment(ref _tallyCopyOutCount); + if (bytes > 0) Interlocked.Add(ref _tallyCopyOutBytes, bytes); +#endif + } + + [Conditional("DEBUG")] + public static void OnDiscardFull(long count) + { +#if DEBUG + if (count > 0) + { + Interlocked.Increment(ref _tallyDiscardFullCount); + EstimatedMovingRangeAverage(ref _tallyDiscardAverage, count); + } +#endif + } + + [Conditional("DEBUG")] + public static void OnDiscardPartial(long count) + { +#if DEBUG + if (count > 0) + { + Interlocked.Increment(ref _tallyDiscardPartialCount); + EstimatedMovingRangeAverage(ref _tallyDiscardAverage, count); + } +#endif + } + + [Conditional("DEBUG")] + public static void OnPipelineFullAsync() + { +#if DEBUG + Interlocked.Increment(ref _tallyPipelineFullAsyncCount); +#endif + } + + [Conditional("DEBUG")] + public static void OnPipelineSendAsync() + { +#if DEBUG + Interlocked.Increment(ref _tallyPipelineSendAsyncCount); +#endif + } + + [Conditional("DEBUG")] + public static void OnPipelineFullSync() + { +#if DEBUG + Interlocked.Increment(ref _tallyPipelineFullSyncCount); +#endif + } + + private DebugCounters() + { + } + + public static DebugCounters Flush() => new(); + +#if DEBUG + private static void EstimatedMovingRangeAverage(ref long field, long value) + { + var oldValue = Volatile.Read(ref field); + var delta = (value - oldValue) >> 3; // is is a 7:1 old:new EMRA, using integer/bit math (alplha=0.125) + if (delta != 0) Interlocked.Add(ref field, delta); + // note: strictly conflicting concurrent calls can skew the value incorrectly; this is, however, + // preferable to getting into a CEX squabble or requiring a lock - it is debug-only and just useful data + } + + public int ReadCount { get; } = Interlocked.Exchange(ref _tallyReadCount, 0); + public int AsyncReadCount { get; } = Interlocked.Exchange(ref _tallyAsyncReadCount, 0); + public int AsyncReadInlineCount { get; } = Interlocked.Exchange(ref _tallyAsyncReadInlineCount, 0); + public long ReadBytes { get; } = Interlocked.Exchange(ref _tallyReadBytes, 0); + + public int WriteCount { get; } = Interlocked.Exchange(ref _tallyWriteCount, 0); + public int AsyncWriteCount { get; } = Interlocked.Exchange(ref _tallyAsyncWriteCount, 0); + public int AsyncWriteInlineCount { get; } = Interlocked.Exchange(ref _tallyAsyncWriteInlineCount, 0); + public long WriteBytes { get; } = Interlocked.Exchange(ref _tallyWriteBytes, 0); + public int CopyOutCount { get; } = Interlocked.Exchange(ref _tallyCopyOutCount, 0); + public long CopyOutBytes { get; } = Interlocked.Exchange(ref _tallyCopyOutBytes, 0); + public long DiscardAverage { get; } = Interlocked.Exchange(ref _tallyDiscardAverage, 32); + public int DiscardFullCount { get; } = Interlocked.Exchange(ref _tallyDiscardFullCount, 0); + public int DiscardPartialCount { get; } = Interlocked.Exchange(ref _tallyDiscardPartialCount, 0); + public int PipelineFullAsyncCount { get; } = Interlocked.Exchange(ref _tallyPipelineFullAsyncCount, 0); + public int PipelineSendAsyncCount { get; } = Interlocked.Exchange(ref _tallyPipelineSendAsyncCount, 0); + public int PipelineFullSyncCount { get; } = Interlocked.Exchange(ref _tallyPipelineFullSyncCount, 0); + public int BatchWriteCount { get; } = Interlocked.Exchange(ref _tallyBatchWriteCount, 0); + public int BatchWriteFullPageCount { get; } = Interlocked.Exchange(ref _tallyBatchWriteFullPageCount, 0); + public int BatchWritePartialPageCount { get; } = Interlocked.Exchange(ref _tallyBatchWritePartialPageCount, 0); + public int BatchWriteMessageCount { get; } = Interlocked.Exchange(ref _tallyBatchWriteMessageCount, 0); +#endif +} diff --git a/src/RESP.Core/DirectWriteConnection.cs b/src/RESP.Core/DirectWriteConnection.cs new file mode 100644 index 000000000..eaf569d09 --- /dev/null +++ b/src/RESP.Core/DirectWriteConnection.cs @@ -0,0 +1,722 @@ +// #define PARSE_DETAIL // additional trace info in CommitAndParseFrames + +#if DEBUG +#define PARSE_DETAIL // always enable this in debug builds +#endif + +using System; +using System.Buffers; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Resp; + +internal sealed class DirectWriteConnection : IRespConnection +{ + private bool _isDoomed; + private RespScanState _readScanState; + private CycleBuffer _readBuffer, _writeBuffer; + private readonly RespContext _context; + public ref readonly RespContext Context => ref _context; + public bool CanWrite => Volatile.Read(ref _readStatus) == WRITER_AVAILABLE; + + public int Outstanding => _outstanding.Count; + + public Task Reader { get; private set; } = Task.CompletedTask; + + private readonly Stream tail; + private ConcurrentQueue _outstanding = new(); + public RespConfiguration Configuration { get; } + + public DirectWriteConnection(RespConfiguration configuration, Stream tail, bool asyncRead = true) + { + Configuration = configuration; + if (!(tail.CanRead && tail.CanWrite)) Throw(); + this.tail = tail; + var memoryPool = configuration.GetService>(); + _readBuffer = CycleBuffer.Create(memoryPool); + _writeBuffer = CycleBuffer.Create(memoryPool); + if (asyncRead) + { + Reader = Task.Run(ReadAllAsync); + } + else + { + new Thread(ReadAll).Start(); + } + + _context = RespContext.For(this); + + static void Throw() => throw new ArgumentException("Stream must be readable and writable", nameof(tail)); + } + + public RespMode Mode { get; set; } = RespMode.Resp2; + + public enum RespMode + { + Resp2, + Resp2PubSub, + Resp3, + } + + private static byte[]? SharedNoLease; + + private bool CommitAndParseFrames(int bytesRead) + { + if (bytesRead <= 0) + { + return false; + } + + // let's bypass a bunch of ldarg0 by hoisting the field-refs (this is **NOT** a struct copy; emphasis "ref") + ref RespScanState state = ref _readScanState; + ref CycleBuffer readBuffer = ref _readBuffer; + +#if PARSE_DETAIL + string src = $"parse {bytesRead}"; + try +#endif + { + Debug.Assert(readBuffer.GetCommittedLength() >= 0, "multi-segment running-indices are corrupt"); +#if PARSE_DETAIL + src += $" ({readBuffer.GetCommittedLength()}+{bytesRead}-{state.TotalBytes})"; +#endif + Debug.Assert( + bytesRead <= readBuffer.UncommittedAvailable, + $"Insufficient bytes in {nameof(CommitAndParseFrames)}; got {bytesRead}, Available={readBuffer.UncommittedAvailable}"); + readBuffer.Commit(bytesRead); +#if PARSE_DETAIL + src += $",total {readBuffer.GetCommittedLength()}"; +#endif + var scanner = RespFrameScanner.Default; + + OperationStatus status = OperationStatus.NeedMoreData; + if (readBuffer.TryGetCommitted(out var fullSpan)) + { + int fullyConsumed = 0; + var toParse = fullSpan.Slice((int)state.TotalBytes); // skip what we've already parsed + + Debug.Assert(!toParse.IsEmpty); + while (true) + { +#if PARSE_DETAIL + src += $",span {toParse.Length}"; +#endif + int totalBytesBefore = (int)state.TotalBytes; + if (toParse.Length < RespScanState.MinBytes + || (status = scanner.TryRead(ref state, toParse)) != OperationStatus.Done) + { + break; + } + + Debug.Assert( + state is + { + IsComplete: true, TotalBytes: >= RespScanState.MinBytes, Prefix: not RespPrefix.None + }, + "Invalid RESP read state"); + + // extract the frame + var bytes = (int)state.TotalBytes; +#if PARSE_DETAIL + src += $",frame {bytes}"; +#endif + // send the frame somewhere (note this is the *full* frame, not just the bit we just parsed) + OnResponseFrame(state.Prefix, fullSpan.Slice(fullyConsumed, bytes), ref SharedNoLease); + + // update our buffers to the unread potions and reset for a new RESP frame + fullyConsumed += bytes; + toParse = toParse.Slice(bytes - totalBytesBefore); // move past the extra bytes we just read + state = default; + status = OperationStatus.NeedMoreData; + } + + readBuffer.DiscardCommitted(fullyConsumed); + } + else // the same thing again, but this time with multi-segment sequence + { + var fullSequence = readBuffer.GetAllCommitted(); + Debug.Assert( + fullSequence is { IsEmpty: false, IsSingleSegment: false }, + "non-trivial sequence expected"); + + long fullyConsumed = 0; + var toParse = fullSequence.Slice((int)state.TotalBytes); // skip what we've already parsed + while (true) + { +#if PARSE_DETAIL + src += $",ros {toParse.Length}"; +#endif + int totalBytesBefore = (int)state.TotalBytes; + if (toParse.Length < RespScanState.MinBytes + || (status = scanner.TryRead(ref state, toParse)) != OperationStatus.Done) + { + break; + } + + Debug.Assert( + state is + { + IsComplete: true, TotalBytes: >= RespScanState.MinBytes, Prefix: not RespPrefix.None + }, + "Invalid RESP read state"); + + // extract the frame + var bytes = (int)state.TotalBytes; +#if PARSE_DETAIL + src += $",frame {bytes}"; +#endif + // send the frame somewhere (note this is the *full* frame, not just the bit we just parsed) + OnResponseFrame(state.Prefix, fullSequence.Slice(fullyConsumed, bytes)); + + // update our buffers to the unread potions and reset for a new RESP frame + fullyConsumed += bytes; + toParse = toParse.Slice(bytes - totalBytesBefore); // move past the extra bytes we just read + state = default; + status = OperationStatus.NeedMoreData; + } + + readBuffer.DiscardCommitted(fullyConsumed); + } + + if (status != OperationStatus.NeedMoreData) + { + ThrowStatus(status); + + static void ThrowStatus(OperationStatus status) => + throw new InvalidOperationException($"Unexpected operation status: {status}"); + } + + return true; + } +#if PARSE_DETAIL + catch (Exception ex) + { + Debug.WriteLine($"{nameof(CommitAndParseFrames)}: {ex.Message}"); + Debug.WriteLine(src); + ActivationHelper.DebugBreak(); + throw new InvalidOperationException($"{src} lead to {ex.Message}", ex); + } +#endif + } + + private async Task ReadAllAsync() + { + try + { + int read; + do + { + var buffer = _readBuffer.GetUncommittedMemory(); + var pending = tail.ReadAsync(buffer, CancellationToken.None); +#if DEBUG + bool inline = pending.IsCompleted; +#endif + read = await pending.ConfigureAwait(false); +#if DEBUG + DebugCounters.OnAsyncRead(read, inline); +#endif + } + // another formatter glitch + while (CommitAndParseFrames(read)); + + Volatile.Write(ref _readStatus, READER_COMPLETED); + _readBuffer.Release(); // clean exit, we can recycle + } + catch (Exception ex) + { + OnReadException(ex); + throw; + } + finally + { + OnReadAllFinally(); + } + } + + private void ReadAll() + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + Reader = tcs.Task; + try + { + int read; + do + { +#if NETCOREAPP || NETSTANDARD2_1_OR_GREATER + var buffer = _readBuffer.GetUncommittedSpan(); + read = tail.Read(buffer); +#else + var buffer = _readBuffer.GetUncommittedMemory(); + read = tail.Read(buffer); +#endif + DebugCounters.OnRead(read); + } + // another formatter glitch + while (CommitAndParseFrames(read)); + + Volatile.Write(ref _readStatus, READER_COMPLETED); + _readBuffer.Release(); // clean exit, we can recycle + tcs.TrySetResult(null); + } + catch (Exception ex) + { + tcs.TrySetException(ex); + OnReadException(ex); + } + finally + { + OnReadAllFinally(); + } + } + + private void OnReadException(Exception ex) + { + _fault ??= ex; + Volatile.Write(ref _readStatus, READER_FAILED); + Debug.WriteLine($"Reader failed: {ex.Message}"); + ActivationHelper.DebugBreak(); + while (_outstanding.TryDequeue(out var pending)) + { + pending.TrySetException(ex); + } + } + + private void OnReadAllFinally() + { + Doom(); + _readBuffer.Release(); + + // abandon anything in the queue + while (_outstanding.TryDequeue(out var pending)) + { + pending.TrySetCanceled(CancellationToken.None); + } + } + + private static readonly ulong + ArrayPong_LC_Bulk = RespConstants.UnsafeCpuUInt64("*2\r\n$4\r\npong\r\n$"u8), + ArrayPong_UC_Bulk = RespConstants.UnsafeCpuUInt64("*2\r\n$4\r\nPONG\r\n$"u8), + ArrayPong_LC_Simple = RespConstants.UnsafeCpuUInt64("*2\r\n+pong\r\n$"u8), + ArrayPong_UC_Simple = RespConstants.UnsafeCpuUInt64("*2\r\n+PONG\r\n$"u8); + + private static readonly uint + pong = RespConstants.UnsafeCpuUInt32("pong"u8), + PONG = RespConstants.UnsafeCpuUInt32("PONG"u8); + + private void OnOutOfBand(ReadOnlySpan payload, ref byte[]? lease) + { + throw new NotImplementedException(nameof(OnOutOfBand)); + } + + private void OnResponseFrame(RespPrefix prefix, ReadOnlySequence payload) + { + if (payload.IsSingleSegment) + { +#if NETCOREAPP || NETSTANDARD2_1_OR_GREATER + OnResponseFrame(prefix, payload.FirstSpan, ref SharedNoLease); +#else + OnResponseFrame(prefix, payload.First.Span, ref SharedNoLease); +#endif + } + else + { + var len = checked((int)payload.Length); + byte[]? oversized = ArrayPool.Shared.Rent(len); + payload.CopyTo(oversized); + OnResponseFrame(prefix, new(oversized, 0, len), ref oversized); + + // the lease could have been claimed by the activation code (to prevent another memcpy); otherwise, free + if (oversized is not null) + { + ArrayPool.Shared.Return(oversized); + } + } + } + + [Conditional("DEBUG")] + private static void DebugValidateSingleFrame(ReadOnlySpan payload) + { + var reader = new RespReader(payload); + reader.MoveNext(); + reader.SkipChildren(); + if (reader.TryMoveNext()) + { + throw new InvalidOperationException($"Unexpected trailing {reader.Prefix}"); + } + + if (reader.ProtocolBytesRemaining != 0) + { + var copy = reader; // leave reader alone for inspection + var prefix = copy.TryMoveNext() ? copy.Prefix : RespPrefix.None; + throw new InvalidOperationException( + $"Unexpected additional {reader.ProtocolBytesRemaining} bytes remaining, {prefix}"); + } + } + + private void OnResponseFrame(RespPrefix prefix, ReadOnlySpan payload, ref byte[]? lease) + { + DebugValidateSingleFrame(payload); + if (prefix == RespPrefix.Push || + (prefix == RespPrefix.Array && Mode is RespMode.Resp2PubSub && !IsArrayPong(payload))) + { + // out-of-band; pub/sub etc + OnOutOfBand(payload, ref lease); + return; + } + + // request/response; match to inbound + if (_outstanding.TryDequeue(out var pending)) + { + ActivationHelper.ProcessResponse(pending, payload, ref lease); + } + else + { + Debug.Fail("Unexpected response without pending message!"); + } + + static bool IsArrayPong(ReadOnlySpan payload) + { + if (payload.Length >= sizeof(ulong)) + { + var raw = RespConstants.UnsafeCpuUInt64(payload); + if (raw == ArrayPong_LC_Bulk + || raw == ArrayPong_UC_Bulk + || raw == ArrayPong_LC_Simple + || raw == ArrayPong_UC_Simple) + { + var reader = new RespReader(payload); + return reader.TryMoveNext() // have root + && reader.Prefix == RespPrefix.Array // root is array + && reader.TryMoveNext() // have first child + && (reader.IsInlneCpuUInt32(pong) || reader.IsInlneCpuUInt32(PONG)); // pong + } + } + + return false; + } + } + + private int _writeStatus, _readStatus; + private const int WRITER_AVAILABLE = 0, WRITER_TAKEN = 1, WRITER_DOOMED = 2; + private const int READER_ACTIVE = 0, READER_FAILED = 1, READER_COMPLETED = 2; + + private void TakeWriter() + { + var status = Interlocked.CompareExchange(ref _writeStatus, WRITER_TAKEN, WRITER_AVAILABLE); + if (status != WRITER_AVAILABLE) ThrowWriterNotAvailable(); + Debug.Assert(Volatile.Read(ref _writeStatus) == WRITER_TAKEN, "writer should be taken"); + } + + private void ThrowWriterNotAvailable() + { + var fault = Volatile.Read(ref _fault); + var status = Volatile.Read(ref _writeStatus); + var msg = status switch + { + WRITER_TAKEN => "A write operation is already in progress; concurrent writes are not supported.", + WRITER_DOOMED when fault is not null => "This connection is terminated; no further writes are possible: " + + fault.Message, + WRITER_DOOMED => "This connection is terminated; no further writes are possible.", + _ => $"Unexpected writer status: {status}", + }; + throw fault is null ? new InvalidOperationException(msg) : new InvalidOperationException(msg, fault); + } + + private Exception? _fault; + + private void ReleaseWriter(int status = WRITER_AVAILABLE) + { + if (status == WRITER_AVAILABLE && _isDoomed) + { + status = WRITER_DOOMED; + } + + Interlocked.CompareExchange(ref _writeStatus, status, WRITER_TAKEN); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void OnRequestUnavailable(IRespMessage message) + { + if (!message.IsCompleted) + { + // make sure they know something is wrong + message.TrySetException(new InvalidOperationException("Connection is not available")); + } + } + + public void Send(IRespMessage message) + { + bool releaseRequest = message.TryReserveRequest(out var bytes); + if (!releaseRequest) + { + OnRequestUnavailable(message); + return; + } + + DebugValidateSingleFrame(bytes.Span); + TakeWriter(); + try + { + _outstanding.Enqueue(message); + releaseRequest = false; // once we write, only release on success +#if NETCOREAPP || NETSTANDARD2_1_OR_GREATER + tail.Write(bytes.Span); +#else + tail.Write(bytes); +#endif + DebugCounters.OnWrite(bytes.Length); + ReleaseWriter(); + message.ReleaseRequest(); + } + catch (Exception ex) + { + Debug.WriteLine($"Writer failed: {ex.Message}"); + ActivationHelper.DebugBreak(); + ReleaseWriter(WRITER_DOOMED); + if (releaseRequest) message.ReleaseRequest(); + throw; + } + } + + public void Send(ReadOnlySpan messages) + { + switch (messages.Length) + { + case 0: + return; + case 1: + Send(messages[0]); + return; + } + + TakeWriter(); + IRespMessage? toRelease = null; + try + { + foreach (var message in messages) + { + if (message.TryReserveRequest(out var bytes)) + { + toRelease = message; + } + else + { + OnRequestUnavailable(message); + continue; + } + + DebugValidateSingleFrame(bytes.Span); + _outstanding.Enqueue(message); + toRelease = null; // once we write, only release on success +#if NETCOREAPP || NETSTANDARD2_1_OR_GREATER + tail.Write(bytes.Span); +#else + tail.Write(bytes); +#endif + DebugCounters.OnWrite(bytes.Length); + ReleaseWriter(); + message.ReleaseRequest(); + } + } + catch (Exception ex) + { + Debug.WriteLine($"Writer failed: {ex.Message}"); + ActivationHelper.DebugBreak(); + ReleaseWriter(WRITER_DOOMED); + toRelease?.ReleaseRequest(); + foreach (var message in messages) + { + // assume all bad + message.TrySetException(ex); + } + + throw; + } + } + + public Task SendAsync(IRespMessage message) + { + bool releaseRequest = message.TryReserveRequest(out var bytes); + if (!releaseRequest) + { + OnRequestUnavailable(message); + return Task.CompletedTask; + } + + DebugValidateSingleFrame(bytes.Span); + TakeWriter(); + try + { + _outstanding.Enqueue(message); + releaseRequest = false; // once we write, only release on success + var pendingWrite = tail.WriteAsync(bytes, CancellationToken.None); + if (!pendingWrite.IsCompleted) + { + return AwaitedSingleWithToken( + this, + pendingWrite, +#if DEBUG + bytes.Length, +#endif + message); + } + + pendingWrite.GetAwaiter().GetResult(); + DebugCounters.OnAsyncWrite(bytes.Length, true); + ReleaseWriter(); + message.ReleaseRequest(); + return Task.CompletedTask; + } + catch (Exception ex) + { + Debug.WriteLine($"Writer failed: {ex.Message}"); + ActivationHelper.DebugBreak(); + ReleaseWriter(WRITER_DOOMED); + if (releaseRequest) message.ReleaseRequest(); + throw; + } + + static async Task AwaitedSingleWithToken( + DirectWriteConnection @this, + ValueTask pendingWrite, +#if DEBUG + int length, +#endif + IRespMessage message) + { + try + { + await pendingWrite.ConfigureAwait(false); +#if DEBUG + DebugCounters.OnAsyncWrite(length, false); +#endif + @this.ReleaseWriter(); + message.ReleaseRequest(); + } + catch + { + @this.ReleaseWriter(WRITER_DOOMED); + throw; + } + } + } + + public Task SendAsync(ReadOnlyMemory messages) + { + switch (messages.Length) + { + case 0: + return Task.CompletedTask; + case 1: + return SendAsync(messages.Span[0]); + default: + return CombineAndSendMultipleAsync(this, messages); + } + } + + private async Task CombineAndSendMultipleAsync(DirectWriteConnection @this, ReadOnlyMemory messages) + { + TakeWriter(); + IRespMessage? toRelease = null; + int definitelySent = 0; + try + { + int length = messages.Length; + for (int i = 0; i < length; i++) + { + var message = messages.Span[i]; + if (!message.TryReserveRequest(out var bytes)) + { + OnRequestUnavailable(message); + continue; // skip this message + } + + toRelease = message; + // append to the scratch and consider written (even though we haven't actually) + _writeBuffer.Write(bytes.Span); + toRelease = null; + message.ReleaseRequest(); + @this._outstanding.Enqueue(message); + + // do we have any full segments? if so, write them and narrow "messages" + if (_writeBuffer.TryGetFirstCommittedMemory(CycleBuffer.GetFullPagesOnly, out var memory)) + { + do + { + var pending = tail.WriteAsync(memory, CancellationToken.None); + DebugCounters.OnAsyncWrite(memory.Length, inline: pending.IsCompleted); + await pending.ConfigureAwait(false); + DebugCounters.OnBatchWriteFullPage(); + + _writeBuffer.DiscardCommitted(memory.Length); // mark the data as no longer needed + } + // and if one buffer was full, we might have multiple (think: "large BLOB outbound") + while (_writeBuffer.TryGetFirstCommittedMemory(CycleBuffer.GetFullPagesOnly, out memory)); + + definitelySent = i + 1; // for exception handling: no need to doom these if later fails + } + } + + // and send any remaining data + while (_writeBuffer.TryGetFirstCommittedMemory(CycleBuffer.GetAnything, out var memory)) + { + var pending = tail.WriteAsync(memory, CancellationToken.None); + DebugCounters.OnAsyncWrite(memory.Length, inline: pending.IsCompleted); + await pending.ConfigureAwait(false); + DebugCounters.OnBatchWritePartialPage(); + + _writeBuffer.DiscardCommitted(memory.Length); // mark the data as no longer needed + } + + Debug.Assert(_writeBuffer.CommittedIsEmpty, "should have written everything"); + + ReleaseWriter(); + DebugCounters.OnBatchWrite(messages.Length); + } + catch (Exception ex) + { + Debug.WriteLine($"Writer failed: {ex.Message}"); + ActivationHelper.DebugBreak(); + ReleaseWriter(WRITER_DOOMED); + toRelease?.ReleaseRequest(); + foreach (var message in messages.Span.Slice(start: definitelySent)) + { + message.TrySetException(ex); + } + + throw; + } + } + + private void Doom() + { + _isDoomed = true; // without a reader, there's no point writing + Interlocked.CompareExchange(ref _writeStatus, WRITER_DOOMED, WRITER_AVAILABLE); + } + + public void Dispose() + { + _fault ??= new ObjectDisposedException(ToString()); + Doom(); + tail.Dispose(); + } + + public override string ToString() => nameof(DirectWriteConnection); + + public ValueTask DisposeAsync() + { +#if COREAPP3_0_OR_GREATER || NETSTANDARD2_1_OR_GREATER + return tail.DisposeAsync().AsTask(); +#else + Dispose(); + return default; +#endif + } +} diff --git a/src/RESP.Core/FrameScanInfo.cs b/src/RESP.Core/FrameScanInfo.cs new file mode 100644 index 000000000..84e4cb604 --- /dev/null +++ b/src/RESP.Core/FrameScanInfo.cs @@ -0,0 +1,34 @@ +namespace Resp; + +/* +/// +/// Additional information about a frame parsing operation. +/// +public struct FrameScanInfo +{ + /// + /// Initialize an instance. + /// + public FrameScanInfo(bool isOutbound) => IsOutbound = isOutbound; + + /// + /// Indicates whether the data operation is outbound. + /// + public bool IsOutbound { get; } + + /// + /// The amount of data, in bytes, to read before attempting to read the next frame. + /// + public int MinBytes => 3; // minimum legal RESP frame is: _\r\n + + /// + /// Gets the total number of bytes processed. + /// + public long BytesRead { get; set; } + + /// + /// Indicates whether this is an out-of-band payload. + /// + public bool IsOutOfBand { get; set; } +} +*/ diff --git a/src/RESP.Core/Global.cs b/src/RESP.Core/Global.cs new file mode 100644 index 000000000..593d3f98b --- /dev/null +++ b/src/RESP.Core/Global.cs @@ -0,0 +1,4 @@ +using System; +using System.Runtime.CompilerServices; + +[assembly: CLSCompliant(true)] diff --git a/src/RESP.Core/IRespConnection.cs b/src/RESP.Core/IRespConnection.cs new file mode 100644 index 000000000..55273e54e --- /dev/null +++ b/src/RESP.Core/IRespConnection.cs @@ -0,0 +1,23 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Resp; + +public interface IRespConnection : IDisposable, IAsyncDisposable +{ + RespConfiguration Configuration { get; } + bool CanWrite { get; } + int Outstanding { get; } + + /// + /// Gets the default context associates with this connection. + /// + ref readonly RespContext Context { get; } + + void Send(IRespMessage message); + void Send(ReadOnlySpan messages); + + Task SendAsync(IRespMessage message); + Task SendAsync(ReadOnlyMemory messages); +} diff --git a/src/RESP.Core/IRespReader.cs b/src/RESP.Core/IRespReader.cs new file mode 100644 index 000000000..2629efdf0 --- /dev/null +++ b/src/RESP.Core/IRespReader.cs @@ -0,0 +1,14 @@ +// using RESPite.Messages; +// +// namespace Resp; +// +// /// +// /// Reads RESP payloads. +// /// +// internal interface IRespReader : IReader +// { +// /// +// /// Read a given value. +// /// +// TResponse Read(in TRequest request, ref RespReader reader); +// } diff --git a/src/RESP.Core/Message.cs b/src/RESP.Core/Message.cs new file mode 100644 index 000000000..edd50f594 --- /dev/null +++ b/src/RESP.Core/Message.cs @@ -0,0 +1,211 @@ +using System; +using System.Threading.Tasks; + +namespace Resp; + +public static class Message +{ + public static TResponse Send( + in RespContext context, + scoped ReadOnlySpan command, + in TRequest request, + IRespFormatter formatter, + IRespParser parser) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + var bytes = Serialize(context.RespCommandMap, command, request, formatter, out int length); + var msg = SyncInternalRespMessage.Create( + bytes, + length, + parser, + in Void.Instance, + context.CancellationToken); + context.Connection.Send(msg); + return msg.WaitAndRecycle(context.Connection.Configuration.SyncTimeout); + } + + public static TResponse Send( + in RespContext context, + scoped ReadOnlySpan command, + in TRequest request, + in TState state, + IRespFormatter formatter, + IRespParser parser) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + var bytes = Serialize(context.RespCommandMap, command, in request, formatter, out int length); + var msg = SyncInternalRespMessage.Create( + bytes, + length, + parser, + in state, + context.CancellationToken); + context.Connection.Send(msg); + return msg.WaitAndRecycle(context.Connection.Configuration.SyncTimeout); + } + + public static ValueTask SendAsync( + in RespContext context, + scoped ReadOnlySpan command, + in TRequest request, + IRespFormatter formatter, + IRespParser parser) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + var bytes = Serialize(context.RespCommandMap, command, request, formatter, out int length); + var msg = AsyncInternalRespMessage.Create( + bytes, + length, + parser, + in Void.Instance, + context.CancellationToken); + return msg.WaitTypedAsync(context.Connection.SendAsync(msg)); + } + + public static ValueTask SendAsync( + in RespContext context, + scoped ReadOnlySpan command, + in TRequest request, + in TState state, + IRespFormatter formatter, + IRespParser parser) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + var bytes = Serialize(context.RespCommandMap, command, in request, formatter, out int length); + var msg = AsyncInternalRespMessage.Create( + bytes, + length, + parser, + in state, + context.CancellationToken); + return msg.WaitTypedAsync(context.Connection.SendAsync(msg)); + } + + public static void Send( + in RespContext context, + scoped ReadOnlySpan command, + in TRequest request, + IRespFormatter formatter, + IRespParser parser) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + var bytes = Serialize(context.RespCommandMap, command, request, formatter, out int length); + var msg = SyncInternalRespMessage.Create( + bytes, + length, + parser, + in Void.Instance, + context.CancellationToken); + context.Connection.Send(msg); + msg.WaitAndRecycle(context.Connection.Configuration.SyncTimeout); + } + + public static void Send( + in RespContext context, + scoped ReadOnlySpan command, + in TRequest request, + in TState state, + IRespFormatter formatter, + IRespParser parser) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + var bytes = Serialize(context.RespCommandMap, command, in request, formatter, out int length); + var msg = SyncInternalRespMessage.Create( + bytes, + length, + parser, + in state, + context.CancellationToken); + context.Connection.Send(msg); + msg.WaitAndRecycle(context.Connection.Configuration.SyncTimeout); + } + + public static ValueTask SendAsync( + in RespContext context, + scoped ReadOnlySpan command, + in TRequest request, + IRespFormatter formatter, + IRespParser parser) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + var bytes = Serialize(context.RespCommandMap, command, request, formatter, out int length); + var msg = AsyncInternalRespMessage.Create( + bytes, + length, + parser, + in Void.Instance, + context.CancellationToken); + return msg.WaitUntypedAsync(context.Connection.SendAsync(msg)); + } + + public static ValueTask SendAsync( + in RespContext context, + scoped ReadOnlySpan command, + in TRequest request, + in TState state, + IRespFormatter formatter, + IRespParser parser) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + var bytes = Serialize(context.RespCommandMap, command, in request, formatter, out int length); + var msg = AsyncInternalRespMessage.Create( + bytes, + length, + parser, + in state, + context.CancellationToken); + return msg.WaitUntypedAsync(context.Connection.SendAsync(msg)); + } + + private static byte[] Serialize( + RespCommandMap commandMap, + ReadOnlySpan command, + in TRequest request, + IRespFormatter formatter, + out int length) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + int size = 0; + if (formatter is IRespSizeEstimator estimator) + { + size = estimator.EstimateSize(command, request); + } + + var buffer = AmbientBufferWriter.Get(size); + try + { + var writer = new RespWriter(buffer); + if (!ReferenceEquals(commandMap, RespCommandMap.Default)) + { + writer.CommandMap = commandMap; + } + + formatter.Format(command, ref writer, request); + writer.Flush(); + return buffer.Detach(out length); + } + catch + { + buffer.Reset(); + throw; + } + } +} diff --git a/src/RESP.Core/PipelinedConnection.cs b/src/RESP.Core/PipelinedConnection.cs new file mode 100644 index 000000000..fc27b8c47 --- /dev/null +++ b/src/RESP.Core/PipelinedConnection.cs @@ -0,0 +1,218 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Resp; + +internal class PipelinedConnection : IRespConnection +{ + private readonly IRespConnection _tail; + private readonly RespContext _context; + private readonly SemaphoreSlim _semaphore = new(1); + + public ref readonly RespContext Context => ref _context; + public PipelinedConnection(in RespContext tail) + { + _tail = tail.Connection; + _context = tail.WithConnection(this); + } + + public void Dispose() + { + _semaphore.Dispose(); + _tail.Dispose(); + } + + public ValueTask DisposeAsync() + { + _semaphore.Dispose(); + return _tail.DisposeAsync(); + } + + public RespConfiguration Configuration => _tail.Configuration; + public bool CanWrite => _semaphore.CurrentCount > 0 && _tail.CanWrite; + public int Outstanding => _tail.Outstanding; + + public void Send(IRespMessage message) + { + _semaphore.Wait(message.CancellationToken); + try + { + _tail.Send(message); + } + catch (Exception ex) + { + message.TrySetException(ex); + throw; + } + finally + { + _semaphore.Release(); + } + } + + public void Send(ReadOnlySpan messages) + { + switch (messages.Length) + { + case 0: return; + case 1: + Send(messages[0]); + return; + } + _semaphore.Wait(messages[0].CancellationToken); + try + { + _tail.Send(messages); + } + catch (Exception ex) + { + TrySetException(messages, ex); + throw; + } + finally + { + _semaphore.Release(); + } + } + + public Task SendAsync(IRespMessage message) + { + bool haveLock = false; + try + { + haveLock = _semaphore.Wait(0); + if (!haveLock) + { + DebugCounters.OnPipelineFullAsync(); + return FullAsync(this, message); + } + + var pending = _tail.SendAsync(message); + if (!pending.IsCompleted) + { + DebugCounters.OnPipelineSendAsync(); + haveLock = false; // transferring + return AwaitAndReleaseLock(pending); + } + + DebugCounters.OnPipelineFullSync(); + pending.GetAwaiter().GetResult(); + return Task.CompletedTask; + } + catch (Exception ex) + { + message.TrySetException(ex); + throw; + } + finally + { + if (haveLock) _semaphore.Release(); + } + + static async Task FullAsync(PipelinedConnection @this, IRespMessage message) + { + try + { + await @this._semaphore.WaitAsync(message.CancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + message.TrySetException(ex); + throw; + } + + try + { + await @this._tail.SendAsync(message).ConfigureAwait(false); + } + finally + { + @this._semaphore.Release(); + } + } + } + + private async Task AwaitAndReleaseLock(Task pending) + { + try + { + await pending.ConfigureAwait(false); + } + finally + { + _semaphore.Release(); + } + } + + private static void TrySetException(ReadOnlySpan messages, Exception ex) + { + foreach (var message in messages) + { + message.TrySetException(ex); + } + } + + public Task SendAsync(ReadOnlyMemory messages) + { + switch (messages.Length) + { + case 0: return Task.CompletedTask; + case 1: return SendAsync(messages.Span[0]); + } + bool haveLock = false; + try + { + haveLock = _semaphore.Wait(0); + if (!haveLock) + { + DebugCounters.OnPipelineFullAsync(); + return FullAsync(this, messages); + } + + var pending = _tail.SendAsync(messages); + if (!pending.IsCompleted) + { + DebugCounters.OnPipelineSendAsync(); + haveLock = false; // transferring + return AwaitAndReleaseLock(pending); + } + + DebugCounters.OnPipelineFullSync(); + pending.GetAwaiter().GetResult(); + return Task.CompletedTask; + } + catch (Exception ex) + { + TrySetException(messages.Span, ex); + throw; + } + finally + { + if (haveLock) _semaphore.Release(); + } + + static async Task FullAsync(PipelinedConnection @this, ReadOnlyMemory messages) + { + bool haveLock = false; // we don't have the lock initially + try + { + await @this._semaphore.WaitAsync(messages.Span[0].CancellationToken).ConfigureAwait(false); + haveLock = true; + await @this._tail.SendAsync(messages).ConfigureAwait(false); + } + catch (Exception ex) + { + TrySetException(messages.Span, ex); + throw; + } + finally + { + if (haveLock) + { + @this._semaphore.Release(); + } + } + } + } +} diff --git a/src/RESP.Core/PublicAPI/PublicAPI.Shipped.txt b/src/RESP.Core/PublicAPI/PublicAPI.Shipped.txt new file mode 100644 index 000000000..7dc5c5811 --- /dev/null +++ b/src/RESP.Core/PublicAPI/PublicAPI.Shipped.txt @@ -0,0 +1 @@ +#nullable enable diff --git a/src/RESP.Core/PublicAPI/PublicAPI.Unshipped.txt b/src/RESP.Core/PublicAPI/PublicAPI.Unshipped.txt new file mode 100644 index 000000000..91c1ad8f0 --- /dev/null +++ b/src/RESP.Core/PublicAPI/PublicAPI.Unshipped.txt @@ -0,0 +1,191 @@ +#nullable enable +abstract Resp.RespPayload.Dispose(bool disposing) -> void +abstract Resp.RespPayload.GetPayload() -> System.Buffers.ReadOnlySequence +override Resp.RespPayload.ToString() -> string! +override Resp.RespScanState.Equals(object? obj) -> bool +override Resp.RespScanState.GetHashCode() -> int +override Resp.RespScanState.ToString() -> string! +Resp.FrameScanInfo +Resp.FrameScanInfo.BytesRead.get -> long +Resp.FrameScanInfo.BytesRead.set -> void +Resp.FrameScanInfo.FrameScanInfo() -> void +Resp.FrameScanInfo.FrameScanInfo(bool isOutbound) -> void +Resp.FrameScanInfo.IsOutbound.get -> bool +Resp.FrameScanInfo.IsOutOfBand.get -> bool +Resp.FrameScanInfo.IsOutOfBand.set -> void +Resp.FrameScanInfo.ReadHint.get -> int +Resp.FrameScanInfo.ReadHint.set -> void +Resp.ICommandMap +Resp.ICommandMap.Map(scoped ref System.ReadOnlySpan command) -> void +Resp.IRespConnection +Resp.IRespConnection.CanWrite.get -> bool +Resp.IRespConnection.Outstanding.get -> int +Resp.IRespConnection.Send(Resp.RespPayload! payload) -> Resp.RespPayload! +Resp.IRespConnection.SendAsync(Resp.RespPayload! payload, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +Resp.IRespFormatter +Resp.IRespFormatter.Format(scoped System.ReadOnlySpan command, ref Resp.RespWriter writer, in TRequest request) -> void +Resp.IRespMetadataParser +Resp.IRespParser +Resp.IRespParser.Parse(in TRequest request, ref Resp.RespReader reader) -> TResponse +Resp.IRespParser +Resp.IRespParser.Parse(ref Resp.RespReader reader) -> TResponse +Resp.IRespSizeEstimator +Resp.IRespSizeEstimator.EstimateSize(scoped System.ReadOnlySpan command, in TRequest request) -> int +Resp.RedisCommands.RedisString +Resp.RedisCommands.RedisString.Get() -> string? +Resp.RedisCommands.RedisString.RedisString(Resp.IRespConnection! connection, string! key) -> void +Resp.RedisCommands.RedisString.Set(string! value) -> void +Resp.RedisCommands.RespConnectionExtensions +Resp.RespAttributeReader +Resp.RespAttributeReader.RespAttributeReader() -> void +Resp.RespConnectionExtensions +Resp.RespException +Resp.RespException.RespException(string! message) -> void +Resp.RespFrameScanner +Resp.RespFrameScanner.OnBeforeFrame(ref Resp.RespScanState state, ref Resp.FrameScanInfo info) -> void +Resp.RespFrameScanner.TryRead(ref Resp.RespScanState state, in System.Buffers.ReadOnlySequence data, ref Resp.FrameScanInfo info) -> System.Buffers.OperationStatus +Resp.RespFrameScanner.ValidateRequest(in System.Buffers.ReadOnlySequence message) -> void +Resp.RespPayload +Resp.RespPayload.Dispose() -> void +Resp.RespPayload.Payload.get -> System.Buffers.ReadOnlySequence +Resp.RespPayload.RespPayload() -> void +Resp.RespPayload.Validate(bool checkError = true) -> void +Resp.RespPrefix +Resp.RespPrefix.Array = 42 -> Resp.RespPrefix +Resp.RespPrefix.Attribute = 124 -> Resp.RespPrefix +Resp.RespPrefix.BigInteger = 40 -> Resp.RespPrefix +Resp.RespPrefix.Boolean = 35 -> Resp.RespPrefix +Resp.RespPrefix.BulkError = 33 -> Resp.RespPrefix +Resp.RespPrefix.BulkString = 36 -> Resp.RespPrefix +Resp.RespPrefix.Double = 44 -> Resp.RespPrefix +Resp.RespPrefix.Integer = 58 -> Resp.RespPrefix +Resp.RespPrefix.Map = 37 -> Resp.RespPrefix +Resp.RespPrefix.None = 0 -> Resp.RespPrefix +Resp.RespPrefix.Null = 95 -> Resp.RespPrefix +Resp.RespPrefix.Push = 62 -> Resp.RespPrefix +Resp.RespPrefix.Set = 126 -> Resp.RespPrefix +Resp.RespPrefix.SimpleError = 45 -> Resp.RespPrefix +Resp.RespPrefix.SimpleString = 43 -> Resp.RespPrefix +Resp.RespPrefix.StreamContinuation = 59 -> Resp.RespPrefix +Resp.RespPrefix.StreamTerminator = 46 -> Resp.RespPrefix +Resp.RespPrefix.VerbatimString = 61 -> Resp.RespPrefix +Resp.RespReader +Resp.RespReader.AggregateChildren() -> Resp.RespReader.AggregateEnumerator +Resp.RespReader.AggregateEnumerator +Resp.RespReader.AggregateEnumerator.AggregateEnumerator() -> void +Resp.RespReader.AggregateEnumerator.AggregateEnumerator(scoped in Resp.RespReader reader) -> void +Resp.RespReader.AggregateEnumerator.Current.get -> Resp.RespReader +Resp.RespReader.AggregateEnumerator.DemandNext() -> void +Resp.RespReader.AggregateEnumerator.FillAll(scoped System.Span target, Resp.RespReader.Projection! projection) -> void +Resp.RespReader.AggregateEnumerator.GetEnumerator() -> Resp.RespReader.AggregateEnumerator +Resp.RespReader.AggregateEnumerator.MoveNext() -> bool +Resp.RespReader.AggregateEnumerator.MoveNext(Resp.RespPrefix prefix) -> bool +Resp.RespReader.AggregateEnumerator.MoveNext(Resp.RespAttributeReader! respAttributeReader, ref T attributes) -> bool +Resp.RespReader.AggregateEnumerator.MoveNext(Resp.RespPrefix prefix, Resp.RespAttributeReader! respAttributeReader, ref T attributes) -> bool +Resp.RespReader.AggregateEnumerator.MovePast(out Resp.RespReader reader) -> void +Resp.RespReader.AggregateEnumerator.ReadOne(Resp.RespReader.Projection! projection) -> T +Resp.RespReader.AggregateEnumerator.Value -> Resp.RespReader +Resp.RespReader.AggregateLength() -> int +Resp.RespReader.BytesConsumed.get -> long +Resp.RespReader.CopyTo(System.Span target) -> int +Resp.RespReader.DemandAggregate() -> void +Resp.RespReader.DemandEnd() -> void +Resp.RespReader.DemandNotNull() -> void +Resp.RespReader.DemandScalar() -> void +Resp.RespReader.FillAll(scoped System.Span target, Resp.RespReader.Projection! projection) -> void +Resp.RespReader.Is(byte value) -> bool +Resp.RespReader.Is(System.ReadOnlySpan value) -> bool +Resp.RespReader.IsAggregate.get -> bool +Resp.RespReader.IsAttribute.get -> bool +Resp.RespReader.IsError.get -> bool +Resp.RespReader.IsNull.get -> bool +Resp.RespReader.IsScalar.get -> bool +Resp.RespReader.IsStreaming.get -> bool +Resp.RespReader.MoveNext() -> void +Resp.RespReader.MoveNext(Resp.RespPrefix prefix) -> void +Resp.RespReader.MoveNext(Resp.RespAttributeReader! respAttributeReader, ref T attributes) -> void +Resp.RespReader.MoveNext(Resp.RespPrefix prefix, Resp.RespAttributeReader! respAttributeReader, ref T attributes) -> void +Resp.RespReader.MoveNextAggregate() -> void +Resp.RespReader.MoveNextScalar() -> void +Resp.RespReader.ParseBytes(Resp.RespReader.Parser! parser, TState? state) -> T +Resp.RespReader.ParseBytes(Resp.RespReader.Parser! parser) -> T +Resp.RespReader.ParseChars(Resp.RespReader.Parser! parser, TState? state) -> T +Resp.RespReader.ParseChars(Resp.RespReader.Parser! parser) -> T +Resp.RespReader.Parser +Resp.RespReader.Parser +Resp.RespReader.Prefix.get -> Resp.RespPrefix +Resp.RespReader.Projection +Resp.RespReader.ReadBoolean() -> bool +Resp.RespReader.ReadDecimal() -> decimal +Resp.RespReader.ReadDouble() -> double +Resp.RespReader.ReadEnum(T unknownValue = default(T)) -> T +Resp.RespReader.ReadInt32() -> int +Resp.RespReader.ReadInt64() -> long +Resp.RespReader.ReadString() -> string? +Resp.RespReader.ReadString(out string! prefix) -> string? +Resp.RespReader.RespReader() -> void +Resp.RespReader.RespReader(scoped in System.Buffers.ReadOnlySequence value) -> void +Resp.RespReader.RespReader(System.ReadOnlySpan value) -> void +Resp.RespReader.ScalarChunks() -> Resp.RespReader.ScalarEnumerator +Resp.RespReader.ScalarEnumerator +Resp.RespReader.ScalarEnumerator.Current.get -> System.ReadOnlySpan +Resp.RespReader.ScalarEnumerator.CurrentLength.get -> int +Resp.RespReader.ScalarEnumerator.GetEnumerator() -> Resp.RespReader.ScalarEnumerator +Resp.RespReader.ScalarEnumerator.MoveNext() -> bool +Resp.RespReader.ScalarEnumerator.MovePast(out Resp.RespReader reader) -> void +Resp.RespReader.ScalarEnumerator.ScalarEnumerator() -> void +Resp.RespReader.ScalarEnumerator.ScalarEnumerator(scoped in Resp.RespReader reader) -> void +Resp.RespReader.ScalarIsEmpty() -> bool +Resp.RespReader.ScalarLength() -> int +Resp.RespReader.ScalarLongLength() -> long +Resp.RespReader.SkipChildren() -> void +Resp.RespReader.TryGetSpan(out System.ReadOnlySpan value) -> bool +Resp.RespReader.TryMoveNext() -> bool +Resp.RespReader.TryMoveNext(bool checkError) -> bool +Resp.RespReader.TryMoveNext(Resp.RespPrefix prefix) -> bool +Resp.RespReader.TryMoveNext(Resp.RespAttributeReader! respAttributeReader, ref T attributes) -> bool +Resp.RespReader.TryReadDouble(out double value, bool allowTokens = true) -> bool +Resp.RespReader.TryReadInt32(out int value) -> bool +Resp.RespReader.TryReadInt64(out long value) -> bool +Resp.RespReader.TryReadNext() -> bool +Resp.RespScanState +Resp.RespScanState.IsComplete.get -> bool +Resp.RespScanState.IsOutOfBand.get -> bool +Resp.RespScanState.RespScanState() -> void +Resp.RespScanState.TotalBytes.get -> long +Resp.RespScanState.TryRead(in System.Buffers.ReadOnlySequence value, out long bytesRead) -> bool +Resp.RespScanState.TryRead(ref Resp.RespReader reader, out long bytesRead) -> bool +Resp.RespScanState.TryRead(System.ReadOnlySpan value, out int bytesRead) -> bool +Resp.RespWriter +Resp.RespWriter.CommandMap.get -> Resp.ICommandMap? +Resp.RespWriter.CommandMap.set -> void +Resp.RespWriter.Flush() -> void +Resp.RespWriter.RespWriter() -> void +Resp.RespWriter.RespWriter(System.Buffers.IBufferWriter! target) -> void +Resp.RespWriter.RespWriter(System.Span target) -> void +Resp.RespWriter.WriteArray(int count) -> void +Resp.RespWriter.WriteBulkString(bool value) -> void +Resp.RespWriter.WriteBulkString(int value) -> void +Resp.RespWriter.WriteBulkString(long value) -> void +Resp.RespWriter.WriteBulkString(scoped System.ReadOnlySpan value) -> void +Resp.RespWriter.WriteBulkString(scoped System.ReadOnlySpan value) -> void +Resp.RespWriter.WriteBulkString(string! value) -> void +Resp.RespWriter.WriteCommand(scoped System.ReadOnlySpan command, int args) -> void +Resp.RespWriter.WriteRaw(scoped System.ReadOnlySpan buffer) -> void +static Resp.RespConnectionExtensions.Send(this Resp.IRespConnection! connection, scoped System.ReadOnlySpan command, TRequest request, Resp.IRespFormatter? formatter = null) -> Resp.RespPayload! +static Resp.RespFrameScanner.Default.get -> Resp.RespFrameScanner! +static Resp.RespFrameScanner.Subscription.get -> Resp.RespFrameScanner! +static Resp.RespPayload.Create(System.Buffers.ReadOnlySequence payload) -> Resp.RespPayload! +static Resp.RespPayload.Create(System.ReadOnlyMemory payload) -> Resp.RespPayload! +static Resp.RespScanState.Create(bool pubSubConnection) -> Resp.RespScanState +virtual Resp.RespAttributeReader.Read(ref Resp.RespReader reader, ref T value) -> void +virtual Resp.RespAttributeReader.ReadKeyValuePair(scoped System.ReadOnlySpan key, ref Resp.RespReader reader, ref T value) -> bool +virtual Resp.RespAttributeReader.ReadKeyValuePairs(ref Resp.RespReader reader, ref T value) -> int +virtual Resp.RespPayload.Wait(System.TimeSpan timeout) -> void +virtual Resp.RespPayload.WaitAsync() -> System.Threading.Tasks.Task! +virtual Resp.RespReader.Parser.Invoke(System.ReadOnlySpan value, TState? state) -> TValue +virtual Resp.RespReader.Parser.Invoke(System.ReadOnlySpan value) -> TValue +virtual Resp.RespReader.Projection.Invoke(ref Resp.RespReader value) -> T +Resp.RedisCommands.RedisString.GetAsync() -> System.Threading.Tasks.Task! +Resp.RedisCommands.RedisString.RedisString(Resp.IRespConnection! connection, string! key, System.Threading.CancellationToken cancellationToken) -> void +Resp.RedisCommands.RedisString.RedisString(Resp.IRespConnection! connection, string! key, System.TimeSpan timeout = default(System.TimeSpan)) -> void diff --git a/src/RESP.Core/PublicAPI/net5.0/PublicAPI.Shipped.txt b/src/RESP.Core/PublicAPI/net5.0/PublicAPI.Shipped.txt new file mode 100644 index 000000000..7dc5c5811 --- /dev/null +++ b/src/RESP.Core/PublicAPI/net5.0/PublicAPI.Shipped.txt @@ -0,0 +1 @@ +#nullable enable diff --git a/src/RESP.Core/PublicAPI/net5.0/PublicAPI.Unshipped.txt b/src/RESP.Core/PublicAPI/net5.0/PublicAPI.Unshipped.txt new file mode 100644 index 000000000..8ba652ba4 --- /dev/null +++ b/src/RESP.Core/PublicAPI/net5.0/PublicAPI.Unshipped.txt @@ -0,0 +1,2 @@ +#nullable enable +System.Runtime.CompilerServices.IsExternalInit (forwarded, contained in System.Runtime) \ No newline at end of file diff --git a/src/RESP.Core/PublicAPI/net7.0/PublicAPI.Shipped.txt b/src/RESP.Core/PublicAPI/net7.0/PublicAPI.Shipped.txt new file mode 100644 index 000000000..7dc5c5811 --- /dev/null +++ b/src/RESP.Core/PublicAPI/net7.0/PublicAPI.Shipped.txt @@ -0,0 +1 @@ +#nullable enable diff --git a/src/RESP.Core/PublicAPI/net7.0/PublicAPI.Unshipped.txt b/src/RESP.Core/PublicAPI/net7.0/PublicAPI.Unshipped.txt new file mode 100644 index 000000000..76395b3e6 --- /dev/null +++ b/src/RESP.Core/PublicAPI/net7.0/PublicAPI.Unshipped.txt @@ -0,0 +1,2 @@ +#nullable enable +Resp.RespReader.ParseChars(System.IFormatProvider? formatProvider = null) -> T \ No newline at end of file diff --git a/src/RESP.Core/PublicAPI/net8.0/PublicAPI.Shipped.txt b/src/RESP.Core/PublicAPI/net8.0/PublicAPI.Shipped.txt new file mode 100644 index 000000000..815c92006 --- /dev/null +++ b/src/RESP.Core/PublicAPI/net8.0/PublicAPI.Shipped.txt @@ -0,0 +1 @@ +#nullable enable \ No newline at end of file diff --git a/src/RESP.Core/PublicAPI/net8.0/PublicAPI.Unshipped.txt b/src/RESP.Core/PublicAPI/net8.0/PublicAPI.Unshipped.txt new file mode 100644 index 000000000..235faee15 --- /dev/null +++ b/src/RESP.Core/PublicAPI/net8.0/PublicAPI.Unshipped.txt @@ -0,0 +1,2 @@ +#nullable enable +Resp.RespReader.ParseBytes(System.IFormatProvider? formatProvider = null) -> T diff --git a/src/RESP.Core/README.md b/src/RESP.Core/README.md new file mode 100644 index 000000000..4a32b6bcc --- /dev/null +++ b/src/RESP.Core/README.md @@ -0,0 +1,3 @@ +# RESP.Core + +This library contains the low-level RESP (Redis, etc) APIs. It is not intended for general use. \ No newline at end of file diff --git a/src/RESP.Core/RESP.Core.csproj b/src/RESP.Core/RESP.Core.csproj new file mode 100644 index 000000000..ee0188560 --- /dev/null +++ b/src/RESP.Core/RESP.Core.csproj @@ -0,0 +1,64 @@ + + + enable + + net461;netstandard2.0;net472;net6.0;net8.0;net9.0 + Resp + Low-level RESP (Redis, etc) APIs. + RESP.Core + RESP.Core + RESP.Core + RESP + true + true + README.md + $(NoWarn);CS1591 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + FrameworkShims.cs + + + NullableHacks.cs + + + SkipLocalsInit.cs + + + \ No newline at end of file diff --git a/src/RESP.Core/Raw.cs b/src/RESP.Core/Raw.cs new file mode 100644 index 000000000..0328e7961 --- /dev/null +++ b/src/RESP.Core/Raw.cs @@ -0,0 +1,139 @@ +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; + +#if NETCOREAPP3_0_OR_GREATER +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; +#endif + +namespace Resp; + +/// +/// Pre-computed payload fragments, for high-volume scenarios / common values. +/// +/// +/// CPU-endianness applies here; we can't just use "const" - however, modern JITs treat "static readonly" *almost* the same as "const", so: meh. +/// +internal static class Raw +{ + public static ulong Create64(ReadOnlySpan bytes, int length) + { + if (length != bytes.Length) + { + throw new ArgumentException($"Length check failed: {length} vs {bytes.Length}, value: {RespConstants.UTF8.GetString(bytes)}", nameof(length)); + } + if (length < 0 || length > sizeof(ulong)) + { + throw new ArgumentOutOfRangeException(nameof(length), $"Invalid length {length} - must be 0-{sizeof(ulong)}"); + } + + // this *will* be aligned; this approach intentionally chosen for parity with write + Span scratch = stackalloc byte[sizeof(ulong)]; + if (length != sizeof(ulong)) scratch.Slice(length).Clear(); + bytes.CopyTo(scratch); + return Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(scratch)); + } + + public static uint Create32(ReadOnlySpan bytes, int length) + { + if (length != bytes.Length) + { + throw new ArgumentException($"Length check failed: {length} vs {bytes.Length}, value: {RespConstants.UTF8.GetString(bytes)}", nameof(length)); + } + if (length < 0 || length > sizeof(uint)) + { + throw new ArgumentOutOfRangeException(nameof(length), $"Invalid length {length} - must be 0-{sizeof(uint)}"); + } + + // this *will* be aligned; this approach intentionally chosen for parity with write + Span scratch = stackalloc byte[sizeof(uint)]; + if (length != sizeof(uint)) scratch.Slice(length).Clear(); + bytes.CopyTo(scratch); + return Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(scratch)); + } + + public static ulong BulkStringEmpty_6 = Create64("$0\r\n\r\n"u8, 6); + + public static ulong BulkStringInt32_M1_8 = Create64("$2\r\n-1\r\n"u8, 8); + public static ulong BulkStringInt32_0_7 = Create64("$1\r\n0\r\n"u8, 7); + public static ulong BulkStringInt32_1_7 = Create64("$1\r\n1\r\n"u8, 7); + public static ulong BulkStringInt32_2_7 = Create64("$1\r\n2\r\n"u8, 7); + public static ulong BulkStringInt32_3_7 = Create64("$1\r\n3\r\n"u8, 7); + public static ulong BulkStringInt32_4_7 = Create64("$1\r\n4\r\n"u8, 7); + public static ulong BulkStringInt32_5_7 = Create64("$1\r\n5\r\n"u8, 7); + public static ulong BulkStringInt32_6_7 = Create64("$1\r\n6\r\n"u8, 7); + public static ulong BulkStringInt32_7_7 = Create64("$1\r\n7\r\n"u8, 7); + public static ulong BulkStringInt32_8_7 = Create64("$1\r\n8\r\n"u8, 7); + public static ulong BulkStringInt32_9_7 = Create64("$1\r\n9\r\n"u8, 7); + public static ulong BulkStringInt32_10_8 = Create64("$2\r\n10\r\n"u8, 8); + + public static ulong BulkStringPrefix_M1_5 = Create64("$-1\r\n"u8, 5); + public static uint BulkStringPrefix_0_4 = Create32("$0\r\n"u8, 4); + public static uint BulkStringPrefix_1_4 = Create32("$1\r\n"u8, 4); + public static uint BulkStringPrefix_2_4 = Create32("$2\r\n"u8, 4); + public static uint BulkStringPrefix_3_4 = Create32("$3\r\n"u8, 4); + public static uint BulkStringPrefix_4_4 = Create32("$4\r\n"u8, 4); + public static uint BulkStringPrefix_5_4 = Create32("$5\r\n"u8, 4); + public static uint BulkStringPrefix_6_4 = Create32("$6\r\n"u8, 4); + public static uint BulkStringPrefix_7_4 = Create32("$7\r\n"u8, 4); + public static uint BulkStringPrefix_8_4 = Create32("$8\r\n"u8, 4); + public static uint BulkStringPrefix_9_4 = Create32("$9\r\n"u8, 4); + public static ulong BulkStringPrefix_10_5 = Create64("$10\r\n"u8, 5); + + public static ulong ArrayPrefix_M1_5 = Create64("*-1\r\n"u8, 5); + public static uint ArrayPrefix_0_4 = Create32("*0\r\n"u8, 4); + public static uint ArrayPrefix_1_4 = Create32("*1\r\n"u8, 4); + public static uint ArrayPrefix_2_4 = Create32("*2\r\n"u8, 4); + public static uint ArrayPrefix_3_4 = Create32("*3\r\n"u8, 4); + public static uint ArrayPrefix_4_4 = Create32("*4\r\n"u8, 4); + public static uint ArrayPrefix_5_4 = Create32("*5\r\n"u8, 4); + public static uint ArrayPrefix_6_4 = Create32("*6\r\n"u8, 4); + public static uint ArrayPrefix_7_4 = Create32("*7\r\n"u8, 4); + public static uint ArrayPrefix_8_4 = Create32("*8\r\n"u8, 4); + public static uint ArrayPrefix_9_4 = Create32("*9\r\n"u8, 4); + public static ulong ArrayPrefix_10_5 = Create64("*10\r\n"u8, 5); + +#if NETCOREAPP3_0_OR_GREATER + private static uint FirstAndLast(char first, char last) + { + Debug.Assert(first < 128 && last < 128, "ASCII please"); + Span scratch = [(byte)first, 0, 0, (byte)last]; + // this *will* be aligned; this approach intentionally chosen for how we read + return Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(scratch)); + } + + public const int CommonRespIndex_Success = 0; + public const int CommonRespIndex_SingleDigitInteger = 1; + public const int CommonRespIndex_DoubleDigitInteger = 2; + public const int CommonRespIndex_SingleDigitString = 3; + public const int CommonRespIndex_DoubleDigitString = 4; + public const int CommonRespIndex_SingleDigitArray = 5; + public const int CommonRespIndex_DoubleDigitArray = 6; + public const int CommonRespIndex_Error = 7; + + public static readonly Vector256 CommonRespPrefixes = Vector256.Create( + FirstAndLast('+', '\r'), // success +OK\r\n + FirstAndLast(':', '\n'), // single-digit integer :4\r\n + FirstAndLast(':', '\r'), // double-digit integer :42\r\n + FirstAndLast('$', '\n'), // 0-9 char string $0\r\n\r\n + FirstAndLast('$', '\r'), // null/10-99 char string $-1\r\n or $10\r\nABCDEFGHIJ\r\n + FirstAndLast('*', '\n'), // 0-9 length array *0\r\n + FirstAndLast('*', '\r'), // null/10-99 length array *-1\r\n or *10\r\n:0\r\n:0\r\n:0\r\n:0\r\n:0\r\n:0\r\n:0\r\n:0\r\n:0\r\n:0\r\n + FirstAndLast('-', 'R')); // common errors -ERR something bad happened + + public static readonly Vector256 FirstLastMask = CreateUInt32(0xFF0000FF); + + private static Vector256 CreateUInt32(uint value) + { +#if NET7_0_OR_GREATER + return Vector256.Create(value); +#else + return Vector256.Create(value, value, value, value, value, value, value, value); +#endif + } + +#endif +} diff --git a/src/RESP.Core/RespAttributeReader.cs b/src/RESP.Core/RespAttributeReader.cs new file mode 100644 index 000000000..5d65a9200 --- /dev/null +++ b/src/RESP.Core/RespAttributeReader.cs @@ -0,0 +1,74 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Resp; + +/// +/// Allows attribute data to be parsed conveniently. +/// +/// The type of data represented by this reader. +public abstract class RespAttributeReader +{ + /// + /// Parse a group of attributes. + /// + public virtual void Read(ref RespReader reader, ref T value) + { + reader.Demand(RespPrefix.Attribute); + _ = ReadKeyValuePairs(ref reader, ref value); + } + + /// + /// Parse an aggregate as a set of key/value pairs. + /// + /// The number of pairs successfully processed. + protected virtual int ReadKeyValuePairs(ref RespReader reader, ref T value) + { + var iterator = reader.AggregateChildren(); + + byte[] pooledBuffer = []; + Span localBuffer = stackalloc byte[128]; + int count = 0; + while (iterator.MoveNext() && iterator.Value.TryReadNext()) + { + if (iterator.Value.IsScalar) + { + var key = iterator.Value.Buffer(ref pooledBuffer, localBuffer); + + if (iterator.MoveNext() && iterator.Value.TryReadNext()) + { + if (ReadKeyValuePair(key, ref iterator.Value, ref value)) + { + count++; + } + } + else + { + break; // no matching value for this key + } + } + else + { + if (iterator.MoveNext() && iterator.Value.TryReadNext()) + { + // we won't try to handle aggregate keys; skip the value + } + else + { + break; // no matching value for this key + } + } + } + iterator.MovePast(out reader); + return count; + } + + /// + /// Parse an individual key/value pair. + /// + /// True if the pair was successfully processed. + public virtual bool ReadKeyValuePair(scoped ReadOnlySpan key, ref RespReader reader, ref T value) => false; +} diff --git a/src/RESP.Core/RespCommandAttribute.cs b/src/RESP.Core/RespCommandAttribute.cs new file mode 100644 index 000000000..be8c1e037 --- /dev/null +++ b/src/RESP.Core/RespCommandAttribute.cs @@ -0,0 +1,36 @@ +using System; +using System.Diagnostics; + +namespace Resp; + +[AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = false)] +[Conditional("DEBUG")] +public sealed class RespCommandAttribute(string? command = null) : Attribute +{ + public string? Command => command; + public string? Formatter { get; set; } + public string? Parser { get; set; } + + public static class Parsers + { + private const string Prefix = "global::Resp.RespParsers."; + + /// + public const string Summary = "global::Resp." + nameof(ResponseSummary) + "." + nameof(ResponseSummary.Parser); + + public const string ByteArray = Prefix + nameof(RespParsers.ByteArray); + public const string String = Prefix + nameof(RespParsers.String); + public const string Int32 = Prefix + nameof(RespParsers.Int32); + public const string Int64 = Prefix + nameof(RespParsers.Int64); + public const string NullableInt64 = Prefix + nameof(RespParsers.NullableInt64); + public const string NullableInt32 = Prefix + nameof(RespParsers.NullableInt32); + public const string NullableSingle = Prefix + nameof(RespParsers.NullableSingle); + public const string BufferWriter = Prefix + nameof(RespParsers.BufferWriter); + public const string ByteArrayArray = Prefix + nameof(RespParsers.ByteArrayArray); + public const string OK = Prefix + nameof(RespParsers.OK); + public const string Single = Prefix + nameof(RespParsers.Single); + public const string Double = Prefix + nameof(RespParsers.Double); + public const string Success = Prefix + nameof(RespParsers.Success); + public const string NullableDouble = Prefix + nameof(RespParsers.NullableDouble); + } +} diff --git a/src/RESP.Core/RespConnectionExtensions.cs b/src/RESP.Core/RespConnectionExtensions.cs new file mode 100644 index 000000000..3d9d38785 --- /dev/null +++ b/src/RESP.Core/RespConnectionExtensions.cs @@ -0,0 +1,287 @@ +// #define PREFER_SYNC_WRITE // makes async calls use synchronous writes + +using System; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Resp; + +public interface IRespFormatter +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif +{ + void Format(scoped ReadOnlySpan command, ref RespWriter writer, in TRequest request); +} + +public interface IRespSizeEstimator : IRespFormatter +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif +{ + int EstimateSize(scoped ReadOnlySpan command, in TRequest request); +} + +public interface IRespParser +{ + TResponse Parse(in TState state, ref RespReader reader); +} + +internal interface IRespInternalMessage : IRespMessage +{ + bool AllowInlineParsing { get; } +} + +internal interface IRespInlineParser // if implemented, parsing is permitted on the IO thread +{ +} + +public interface IRespMetadataParser // if implemented, the consumer must manually advance to the content +{ +} + +public abstract class RespCommandMap +{ + /// + /// Apply any remapping to the command. + /// + /// The command requested. + /// The remapped command; this can be the original command, a remapped command, or an empty instance if the command is not available. + public abstract ReadOnlySpan Map(ReadOnlySpan command); + + /// + /// Indicates whether the specified command is available. + /// + public virtual bool IsAvailable(ReadOnlySpan command) + => Map(command).Length != 0; + + public static RespCommandMap Default { get; } = new DefaultRespCommandMap(); + + private sealed class DefaultRespCommandMap : RespCommandMap + { + public override ReadOnlySpan Map(ReadOnlySpan command) => command; + public override bool IsAvailable(ReadOnlySpan command) => true; + } +} + +/// +/// Over-arching configuration for a RESP system. +/// +public class RespConfiguration +{ + private static readonly TimeSpan DefaultSyncTimeout = TimeSpan.FromSeconds(10); + + public static RespConfiguration Default { get; } = new( + RespCommandMap.Default, [], DefaultSyncTimeout, NullServiceProvider.Instance); + + public static Builder Create() => default; // for discoverability + + public struct Builder // intentionally mutable + { + public TimeSpan? SyncTimeout { get; set; } + public IServiceProvider? ServiceProvider { get; set; } + public RespCommandMap? CommandMap { get; set; } + public object? KeyPrefix { get; set; } // can be a string or byte[] + + public Builder(RespConfiguration? source) + { + if (source is not null) + { + CommandMap = source.RespCommandMap; + SyncTimeout = source.SyncTimeout; + KeyPrefix = source.KeyPrefix.ToArray(); + ServiceProvider = source.ServiceProvider; + // undo defaults + if (ReferenceEquals(CommandMap, RespCommandMap.Default)) CommandMap = null; + if (ReferenceEquals(ServiceProvider, NullServiceProvider.Instance)) ServiceProvider = null; + } + } + + public RespConfiguration Create() + { + byte[] prefix = KeyPrefix switch + { + null => [], + string { Length: 0 } => [], + string s => Encoding.UTF8.GetBytes(s), + byte[] { Length: 0 } => [], + byte[] b => b.AsSpan().ToArray(), // create isolated copy for mutability reasons + _ => throw new ArgumentException("KeyPrefix must be a string or byte[]", nameof(KeyPrefix)), + }; + + if (prefix.Length == 0 & SyncTimeout is null & CommandMap is null & ServiceProvider is null) return Default; + + return new( + CommandMap ?? RespCommandMap.Default, + prefix, + SyncTimeout ?? DefaultSyncTimeout, + ServiceProvider ?? NullServiceProvider.Instance); + } + } + + private RespConfiguration( + RespCommandMap respCommandMap, + byte[] keyPrefix, + TimeSpan syncTimeout, + IServiceProvider serviceProvider) + { + RespCommandMap = respCommandMap; + SyncTimeout = syncTimeout; + _keyPrefix = (byte[])keyPrefix.Clone(); // create isolated copy + ServiceProvider = serviceProvider; + } + + private readonly byte[] _keyPrefix; + public IServiceProvider ServiceProvider { get; } + public RespCommandMap RespCommandMap { get; } + public TimeSpan SyncTimeout { get; } + public ReadOnlySpan KeyPrefix => _keyPrefix; + + public Builder AsBuilder() => new(this); + + private sealed class NullServiceProvider : IServiceProvider + { + public static readonly NullServiceProvider Instance = new(); + private NullServiceProvider() { } + public object? GetService(Type serviceType) => null; + } + + internal T? GetService() where T : class + => ServiceProvider.GetService(typeof(T)) as T; +} + +/// +/// Transient state for a RESP operation. +/// +public readonly struct RespContext +{ + private readonly IRespConnection _connection; + private readonly int _database; + private readonly CancellationToken _cancellationToken; + + private const string CtorUsageWarning = $"The context from {nameof(IRespConnection)}.{nameof(IRespConnection.Context)} should be preferred, using {nameof(WithCancellationToken)} etc as necessary."; + + /// + public override string ToString() => _connection?.ToString() ?? "(null)"; + + [Obsolete(CtorUsageWarning)] + public RespContext(IRespConnection connection) : this(connection, -1, CancellationToken.None) + { + } + + [Obsolete(CtorUsageWarning)] + public RespContext(IRespConnection connection, CancellationToken cancellationToken) + : this(connection, -1, cancellationToken) + { + } + + /// + /// Transient state for a RESP operation. + /// + [Obsolete(CtorUsageWarning)] + public RespContext( + IRespConnection connection, + int database = -1, + CancellationToken cancellationToken = default) + { + _connection = connection; + _database = database; + _cancellationToken = cancellationToken; + } + + public IRespConnection Connection => _connection; + public int Database => _database; + public CancellationToken CancellationToken => _cancellationToken; + + public RespMessageBuilder Command(ReadOnlySpan command, T value, IRespFormatter formatter) + => new(this, command, value, formatter); + + public RespMessageBuilder Command(ReadOnlySpan command) + => new(this, command, Void.Instance, RespFormatters.Void); + + public RespMessageBuilder Command(ReadOnlySpan command, string value, bool isKey) + => new(this, command, value, RespFormatters.String(isKey)); + + public RespMessageBuilder Command(ReadOnlySpan command, byte[] value, bool isKey) + => new(this, command, value, RespFormatters.ByteArray(isKey)); + + public RespCommandMap RespCommandMap => _connection.Configuration.RespCommandMap; + + public RespContext WithCancellationToken(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + RespContext clone = this; + Unsafe.AsRef(in clone._cancellationToken) = cancellationToken; + return clone; + } + + public RespContext WithDatabase(int database) + { + RespContext clone = this; + Unsafe.AsRef(in clone._database) = database; + return clone; + } + + public RespContext WithConnection(IRespConnection connection) + { + RespContext clone = this; + Unsafe.AsRef(in clone._connection) = connection; + return clone; + } + + public IBatchConnection CreateBatch(int sizeHint = 0) => new BatchConnection(in this, sizeHint); + + internal static RespContext For(IRespConnection connection) +#pragma warning disable CS0618 // Type or member is obsolete + => new(connection); +#pragma warning restore CS0618 // Type or member is obsolete +} + +public static class RespConnectionExtensions +{ + /// + /// Enforces stricter ordering guarantees, so that unawaited async operations cannot cause overlapping writes. + /// + public static IRespConnection ForPipeline(this IRespConnection connection) + => connection is PipelinedConnection ? connection : new PipelinedConnection(in connection.Context); + + public static IRespConnection WithConfiguration(this IRespConnection connection, RespConfiguration configuration) + => ReferenceEquals(configuration, connection.Configuration) + ? connection + : new ConfiguredConnection(connection, configuration); + + private sealed class ConfiguredConnection : IRespConnection + { + private readonly IRespConnection _tail; + private readonly RespConfiguration _configuration; + private readonly RespContext _context; + + public ref readonly RespContext Context => ref _context; + public ConfiguredConnection(IRespConnection tail, RespConfiguration configuration) + { + _tail = tail; + _configuration = configuration; + _context = RespContext.For(this); + } + + public void Dispose() => _tail.Dispose(); + + public ValueTask DisposeAsync() => _tail.DisposeAsync(); + + public RespConfiguration Configuration => _configuration; + + public bool CanWrite => _tail.CanWrite; + + public int Outstanding => _tail.Outstanding; + + public void Send(IRespMessage message) => _tail.Send(message); + public void Send(ReadOnlySpan messages) => _tail.Send(messages); + + public Task SendAsync(IRespMessage message) => + _tail.SendAsync(message); + + public Task SendAsync(ReadOnlyMemory messages) => _tail.SendAsync(messages); + } +} diff --git a/src/RESP.Core/RespConnectionPool.cs b/src/RESP.Core/RespConnectionPool.cs new file mode 100644 index 000000000..71a3f97bf --- /dev/null +++ b/src/RESP.Core/RespConnectionPool.cs @@ -0,0 +1,189 @@ +using System; +using System.Collections.Concurrent; +using System.ComponentModel; +using System.IO; +using System.Net; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; + +namespace Resp; + +public sealed class RespConnectionPool : IDisposable +{ + private readonly RespConfiguration _configuration; + private const int DefaultCount = 10; + private bool _isDisposed; + + [Obsolete("This is for testing only")] + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public bool UseCustomNetworkStream { get; set; } + + private readonly ConcurrentQueue _pool = []; + private readonly Func _createConnection; + private readonly int _count; + + public RespConnectionPool( + Func createConnection, + RespConfiguration? configuration = null, + int count = RespConnectionPool.DefaultCount) + { + _createConnection = createConnection; + _count = count; + _configuration = configuration ?? RespConfiguration.Default; + } + + public RespConnectionPool( + IPAddress? address = null, + int port = 6379, + RespConfiguration? configuration = null, + int count = DefaultCount) + : this(new IPEndPoint(address ?? IPAddress.Loopback, port), configuration, count) + { + } + + public RespConnectionPool(EndPoint endPoint, RespConfiguration? configuration = null, int count = DefaultCount) + { +#pragma warning disable CS0618 // Type or member is obsolete + _createConnection = config => CreateConnection(config, endPoint, UseCustomNetworkStream); +#pragma warning restore CS0618 // Type or member is obsolete + _count = count; + _configuration = configuration ?? RespConfiguration.Default; + } + + /// + /// Borrow a connection from the pool. + /// + /// The database to override in the context of the leased connection. + /// The cancellation token to override in the context of the leased connection. + public IRespConnection GetConnection(int? database = null, CancellationToken? cancellationToken = null) + { + ThrowIfDisposed(); + if (cancellationToken.HasValue) + { + cancellationToken.GetValueOrDefault().ThrowIfCancellationRequested(); + } + + if (!_pool.TryDequeue(out var connection)) + { + connection = _createConnection(_configuration); + } + + return new PoolWrapper(this, connection, database, cancellationToken); + } + + private void ThrowIfDisposed() + { + if (_isDisposed) Throw(); + static void Throw() => throw new ObjectDisposedException(nameof(RespConnectionPool)); + } + + public void Dispose() + { + _isDisposed = true; + while (_pool.TryDequeue(out var connection)) + { + connection.Dispose(); + } + } + + private void Return(IRespConnection tail) + { + if (_isDisposed || !tail.CanWrite || _pool.Count >= _count) + { + tail.Dispose(); + } + else + { + _pool.Enqueue(tail); + } + } + + private static IRespConnection CreateConnection(RespConfiguration config, EndPoint endpoint, bool useCustom) + { + Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + socket.NoDelay = true; + socket.Connect(endpoint); + return new DirectWriteConnection(config, Wrap(socket, useCustom)); + + static Stream Wrap(Socket socket, bool useCustom) + { +#if NETCOREAPP3_0_OR_GREATER + if (useCustom) return new CustomNetworkStream(socket); +#endif + return new NetworkStream(socket); + } + } + + private sealed class PoolWrapper : IRespConnection + { + private bool _isDisposed; + private readonly RespConnectionPool _pool; + private readonly IRespConnection _tail; + private readonly RespContext _context; + + public ref readonly RespContext Context => ref _context; + + public PoolWrapper( + RespConnectionPool pool, + IRespConnection tail, + int? database, + CancellationToken? cancellationToken) + { + _pool = pool; + _tail = tail; + _context = RespContext.For(this); + if (database.HasValue) _context = _context.WithDatabase(database.GetValueOrDefault()); + if (cancellationToken.HasValue) + _context = _context.WithCancellationToken(cancellationToken.GetValueOrDefault()); + } + + public void Dispose() + { + _isDisposed = true; + _pool.Return(_tail); + } + + public bool CanWrite => !_isDisposed && _tail.CanWrite; + + public int Outstanding => _tail.Outstanding; + + public RespConfiguration Configuration => _tail.Configuration; + + private void ThrowIfDisposed() + { + if (_isDisposed) Throw(); + static void Throw() => throw new ObjectDisposedException(nameof(PoolWrapper)); + } + + public ValueTask DisposeAsync() + { + Dispose(); + return default; + } + + public void Send(IRespMessage message) + { + ThrowIfDisposed(); + _tail.Send(message); + } + + public void Send(ReadOnlySpan messages) + { + ThrowIfDisposed(); + _tail.Send(messages); + } + + public Task SendAsync(IRespMessage message) + { + ThrowIfDisposed(); + return _tail.SendAsync(message); + } + + public Task SendAsync(ReadOnlyMemory messages) + { + ThrowIfDisposed(); + return _tail.SendAsync(messages); + } + } +} diff --git a/src/RESP.Core/RespConstants.cs b/src/RESP.Core/RespConstants.cs new file mode 100644 index 000000000..4eff2a509 --- /dev/null +++ b/src/RESP.Core/RespConstants.cs @@ -0,0 +1,52 @@ +using System; +using System.Buffers.Binary; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; + +namespace Resp; + +internal static class RespConstants +{ + public static readonly UTF8Encoding UTF8 = new(false); + + public static ReadOnlySpan CrlfBytes => "\r\n"u8; + + public static readonly ushort CrLfUInt16 = UnsafeCpuUInt16(CrlfBytes); + + public static ReadOnlySpan OKBytes => "OK"u8; + public static readonly ushort OKUInt16 = UnsafeCpuUInt16(OKBytes); + + public static readonly uint BulkStringStreaming = UnsafeCpuUInt32("$?\r\n"u8); + public static readonly uint BulkStringNull = UnsafeCpuUInt32("$-1\r"u8); + + public static readonly uint ArrayStreaming = UnsafeCpuUInt32("*?\r\n"u8); + public static readonly uint ArrayNull = UnsafeCpuUInt32("*-1\r"u8); + + public static ushort UnsafeCpuUInt16(ReadOnlySpan bytes) + => Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(bytes)); + public static ushort UnsafeCpuUInt16(ReadOnlySpan bytes, int offset) + => Unsafe.ReadUnaligned(ref Unsafe.Add(ref MemoryMarshal.GetReference(bytes), offset)); + public static byte UnsafeCpuByte(ReadOnlySpan bytes, int offset) + => Unsafe.Add(ref MemoryMarshal.GetReference(bytes), offset); + public static uint UnsafeCpuUInt32(ReadOnlySpan bytes) + => Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(bytes)); + public static uint UnsafeCpuUInt32(ReadOnlySpan bytes, int offset) + => Unsafe.ReadUnaligned(ref Unsafe.Add(ref MemoryMarshal.GetReference(bytes), offset)); + public static ulong UnsafeCpuUInt64(ReadOnlySpan bytes) + => Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(bytes)); + public static ushort CpuUInt16(ushort bigEndian) + => BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(bigEndian) : bigEndian; + public static uint CpuUInt32(uint bigEndian) + => BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(bigEndian) : bigEndian; + public static ulong CpuUInt64(ulong bigEndian) + => BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(bigEndian) : bigEndian; + + public const int MaxRawBytesInt32 = 11, // "-2147483648" + MaxRawBytesInt64 = 20, // "-9223372036854775808", + MaxProtocolBytesIntegerInt32 = MaxRawBytesInt32 + 3, // ?X10X\r\n where ? could be $, *, etc - usually a length prefix + MaxProtocolBytesBulkStringIntegerInt32 = MaxRawBytesInt32 + 7, // $NN\r\nX11X\r\n for NN (length) 1-11 + MaxProtocolBytesBulkStringIntegerInt64 = MaxRawBytesInt64 + 7, // $NN\r\nX20X\r\n for NN (length) 1-20 + MaxRawBytesNumber = 20, // note G17 format, allow 20 for payload + MaxProtocolBytesBytesNumber = MaxRawBytesNumber + 7; // $NN\r\nX...X\r\n for NN (length) 1-20 +} diff --git a/src/RESP.Core/RespException.cs b/src/RESP.Core/RespException.cs new file mode 100644 index 000000000..edce85fb8 --- /dev/null +++ b/src/RESP.Core/RespException.cs @@ -0,0 +1,10 @@ +using System; + +namespace Resp; + +/// +/// Represents a RESP error message. +/// +public sealed class RespException(string message) : Exception(message) +{ +} diff --git a/src/RESP.Core/RespFormatters.cs b/src/RESP.Core/RespFormatters.cs new file mode 100644 index 000000000..d0ab98d9e --- /dev/null +++ b/src/RESP.Core/RespFormatters.cs @@ -0,0 +1,61 @@ +using System; + +namespace Resp; + +public static class RespFormatters +{ + public static IRespFormatter String(bool isKey) => isKey ? Key.String : Value.String; + public static IRespFormatter ByteArray(bool isKey) => isKey ? Key.ByteArray : Value.ByteArray; + public static class Key + { + // ReSharper disable once MemberHidesStaticFromOuterClass + public static IRespFormatter String => Formatter.Default; + // ReSharper disable once MemberHidesStaticFromOuterClass + public static IRespFormatter ByteArray => Formatter.Default; + + internal sealed class Formatter : IRespFormatter, IRespFormatter + { + private Formatter() { } + public static readonly Formatter Default = new(); + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in string value) + { + writer.WriteCommand(command, 1); + writer.WriteKey(value); + } + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in byte[] value) + { + writer.WriteCommand(command, 1); + writer.WriteKey(value); + } + } + } + + public static class Value + { + // ReSharper disable once MemberHidesStaticFromOuterClass + public static IRespFormatter String => Formatter.Default; + // ReSharper disable once MemberHidesStaticFromOuterClass + public static IRespFormatter ByteArray => Formatter.Default; + + internal sealed class Formatter : IRespFormatter, IRespFormatter, IRespFormatter + { + private Formatter() { } + public static readonly Formatter Default = new(); + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in Void value) + { + writer.WriteCommand(command, 0); + } + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in string value) + { + writer.WriteCommand(command, 1); + writer.WriteBulkString(value); + } + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in byte[] value) + { + writer.WriteCommand(command, 1); + writer.WriteBulkString(value); + } + } + } + public static IRespFormatter Void => Value.Formatter.Default; +} diff --git a/src/RESP.Core/RespFrameScanner.cs b/src/RESP.Core/RespFrameScanner.cs new file mode 100644 index 000000000..628b4c65b --- /dev/null +++ b/src/RESP.Core/RespFrameScanner.cs @@ -0,0 +1,193 @@ +using System; +using System.Buffers; +using static Resp.RespConstants; +namespace Resp; + +/// +/// Scans RESP frames. +/// . +public sealed class RespFrameScanner // : IFrameSacanner, IFrameValidator +{ + /// + /// Gets a frame scanner for RESP2 request/response connections, or RESP3 connections. + /// + public static RespFrameScanner Default { get; } = new(false); + + /// + /// Gets a frame scanner that identifies RESP2 pub/sub messages. + /// + public static RespFrameScanner Subscription { get; } = new(true); + private RespFrameScanner(bool pubsub) => _pubsub = pubsub; + private readonly bool _pubsub; + + private static readonly uint FastNull = UnsafeCpuUInt32("_\r\n\0"u8), + SingleCharScalarMask = CpuUInt32(0xFF00FFFF), + SingleDigitInteger = UnsafeCpuUInt32(":\0\r\n"u8), + EitherBoolean = UnsafeCpuUInt32("#\0\r\n"u8), + FirstThree = CpuUInt32(0xFFFFFF00); + private static readonly ulong OK = UnsafeCpuUInt64("+OK\r\n\0\0\0"u8), + PONG = UnsafeCpuUInt64("+PONG\r\n\0"u8), + DoubleCharScalarMask = CpuUInt64(0xFF0000FFFF000000), + DoubleDigitInteger = UnsafeCpuUInt64(":\0\0\r\n"u8), + FirstFive = CpuUInt64(0xFFFFFFFFFF000000), + FirstSeven = CpuUInt64(0xFFFFFFFFFFFFFF00); + + private const OperationStatus UseReader = (OperationStatus)(-1); + private static OperationStatus TryFastRead(ReadOnlySpan data, ref RespScanState info) + { + // use silly math to detect the most common short patterns without needing + // to access a reader, or use indexof etc; handles: + // +OK\r\n + // +PONG\r\n + // :N\r\n for any single-digit N (integer) + // :NN\r\n for any double-digit N (integer) + // #N\r\n for any single-digit N (boolean) + // _\r\n (null) + uint hi, lo; + switch (data.Length) + { + case 0: + case 1: + case 2: + return OperationStatus.NeedMoreData; + case 3: + hi = (((uint)UnsafeCpuUInt16(data)) << 16) | (((uint)UnsafeCpuByte(data, 2)) << 8); + break; + default: + hi = UnsafeCpuUInt32(data); + break; + } + if ((hi & FirstThree) == FastNull) + { + info.SetComplete(3, RespPrefix.Null); + return OperationStatus.Done; + } + + var masked = hi & SingleCharScalarMask; + if (masked == SingleDigitInteger) + { + info.SetComplete(4, RespPrefix.Integer); + return OperationStatus.Done; + } + else if (masked == EitherBoolean) + { + info.SetComplete(4, RespPrefix.Boolean); + return OperationStatus.Done; + } + + switch (data.Length) + { + case 3: + return OperationStatus.NeedMoreData; + case 4: + return UseReader; + case 5: + lo = ((uint)data[4]) << 24; + break; + case 6: + lo = ((uint)UnsafeCpuUInt16(data, 4)) << 16; + break; + case 7: + lo = ((uint)UnsafeCpuUInt16(data, 4)) << 16 | ((uint)UnsafeCpuByte(data, 6)) << 8; + break; + default: + lo = UnsafeCpuUInt32(data, 4); + break; + } + var u64 = BitConverter.IsLittleEndian ? ((((ulong)lo) << 32) | hi) : ((((ulong)hi) << 32) | lo); + if (((u64 & FirstFive) == OK) | ((u64 & DoubleCharScalarMask) == DoubleDigitInteger)) + { + info.SetComplete(5, RespPrefix.SimpleString); + return OperationStatus.Done; + } + if ((u64 & FirstSeven) == PONG) + { + info.SetComplete(7, RespPrefix.SimpleString); + return OperationStatus.Done; + } + return UseReader; + } + + /// + /// Attempt to read more data as part of the current frame. + /// + public OperationStatus TryRead(ref RespScanState state, in ReadOnlySequence data) + { + if (!_pubsub & state.TotalBytes == 0 & data.IsSingleSegment) + { +#if NETCOREAPP3_1_OR_GREATER + var status = TryFastRead(data.FirstSpan, ref state); +#else + var status = TryFastRead(data.First.Span, ref state); +#endif + if (status != UseReader) return status; + } + + return TryReadViaReader(ref state, in data); + + static OperationStatus TryReadViaReader(ref RespScanState state, in ReadOnlySequence data) + { + var reader = new RespReader(in data); + var complete = state.TryRead(ref reader, out var consumed); + if (complete) + { + return OperationStatus.Done; + } + return OperationStatus.NeedMoreData; + } + } + + /// + /// Attempt to read more data as part of the current frame. + /// + public OperationStatus TryRead(ref RespScanState state, ReadOnlySpan data) + { + if (!_pubsub & state.TotalBytes == 0) + { +#if NETCOREAPP3_1_OR_GREATER + var status = TryFastRead(data, ref state); +#else + var status = TryFastRead(data, ref state); +#endif + if (status != UseReader) return status; + } + + return TryReadViaReader(ref state, data); + + static OperationStatus TryReadViaReader(ref RespScanState state, ReadOnlySpan data) + { + var reader = new RespReader(data); + var complete = state.TryRead(ref reader, out var consumed); + if (complete) + { + return OperationStatus.Done; + } + return OperationStatus.NeedMoreData; + } + } + + /// + /// Validate that the supplied message is a valid RESP request, specifically: that it contains a single + /// top-level array payload with bulk-string elements, the first of which is non-empty (the command). + /// + public void ValidateRequest(in ReadOnlySequence message) + { + if (message.IsEmpty) Throw("Empty RESP frame"); + RespReader reader = new(in message); + reader.MoveNext(RespPrefix.Array); + reader.DemandNotNull(); + if (reader.IsStreaming) Throw("Streaming is not supported in this context"); + var count = reader.AggregateLength(); + for (int i = 0; i < count; i++) + { + reader.MoveNext(RespPrefix.BulkString); + reader.DemandNotNull(); + if (reader.IsStreaming) Throw("Streaming is not supported in this context"); + + if (i == 0 && reader.ScalarIsEmpty()) Throw("command must be non-empty"); + } + reader.DemandEnd(); + + static void Throw(string message) => throw new InvalidOperationException(message); + } +} diff --git a/src/RESP.Core/RespParsers.cs b/src/RESP.Core/RespParsers.cs new file mode 100644 index 000000000..958b3a3af --- /dev/null +++ b/src/RESP.Core/RespParsers.cs @@ -0,0 +1,177 @@ +using System; +using System.Buffers; +using System.Diagnostics.CodeAnalysis; + +namespace Resp; + +public readonly struct ResponseSummary(RespPrefix prefix, int length, long protocolBytes) : IEquatable +{ + public RespPrefix Prefix { get; } = prefix; + public int Length { get; } = length; + public long ProtocolBytes { get; } = protocolBytes; + + /// + public override string ToString() => $"{Prefix}, Length: {Length}, Protocol Bytes: {ProtocolBytes}"; + + /// + public bool Equals(ResponseSummary other) => EqualsCore(in other); + + private bool EqualsCore(in ResponseSummary other) => + Prefix == other.Prefix && Length == other.Length && ProtocolBytes == other.ProtocolBytes; + + bool IEquatable.Equals(ResponseSummary other) => EqualsCore(in other); + + /// + public override bool Equals(object? obj) => obj is ResponseSummary summary && EqualsCore(in summary); + + /// + public override int GetHashCode() => (int)Prefix ^ Length ^ ProtocolBytes.GetHashCode(); + + public static IRespParser Parser => ResponseSummaryParser.Default; + + private sealed class ResponseSummaryParser : IRespParser, IRespInlineParser, IRespMetadataParser + { + private ResponseSummaryParser() { } + public static readonly ResponseSummaryParser Default = new(); + + public ResponseSummary Parse(in Void state, ref RespReader reader) + { + var protocolBytes = reader.ProtocolBytesRemaining; + int length = 0; + if (reader.TryMoveNext()) + { + if (reader.IsScalar) length = reader.ScalarLength(); + else if (reader.IsAggregate) length = reader.AggregateLength(); + } + return new ResponseSummary(reader.Prefix, length, protocolBytes); + } + } +} + +public static class RespParsers +{ + public static IRespParser Success => InbuiltInlineParsers.Default; + public static IRespParser OK => OKParser.Default; + public static IRespParser String => InbuiltCopyOutParsers.Default; + public static IRespParser Int32 => InbuiltInlineParsers.Default; + public static IRespParser NullableInt32 => InbuiltInlineParsers.Default; + public static IRespParser Int64 => InbuiltInlineParsers.Default; + public static IRespParser NullableInt64 => InbuiltInlineParsers.Default; + public static IRespParser Single => InbuiltInlineParsers.Default; + public static IRespParser NullableSingle => InbuiltInlineParsers.Default; + public static IRespParser Double => InbuiltInlineParsers.Default; + public static IRespParser NullableDouble => InbuiltInlineParsers.Default; + public static IRespParser ByteArray => InbuiltCopyOutParsers.Default; + public static IRespParser ByteArrayArray => InbuiltCopyOutParsers.Default; + public static IRespParser, int> BufferWriter => InbuiltCopyOutParsers.Default; + + private sealed class Cache + { + public static IRespParser? Instance = + (InbuiltCopyOutParsers.Default as IRespParser) ?? // regular (may allocate, etc) + (InbuiltInlineParsers.Default as IRespParser) ?? // inline + (ResponseSummary.Parser as IRespParser); // inline+metadata + } + + public static IRespParser Get() + => Cache.Instance ??= GetCore(); + + public static void Set(IRespParser parser) + { + var obj = (InbuiltCopyOutParsers.Default as IRespParser) ?? + (InbuiltInlineParsers.Default as IRespParser); + if (obj is not null) ThrowInbuiltParser(typeof(TResponse)); + Cache.Instance = parser; + } + + private static IRespParser GetCore() + { + var obj = (InbuiltCopyOutParsers.Default as IRespParser) ?? + (InbuiltInlineParsers.Default as IRespParser); + if (obj is null) + { + ThrowNoParser(typeof(TResponse)); + } + + return Cache.Instance = obj; + } + + [DoesNotReturn] + private static void ThrowNoParser(Type type) => throw new InvalidOperationException( + message: + $"No default parser registered for type '{type.FullName}'; a custom parser must be specified via {nameof(RespParsers)}.{nameof(RespParsers.Set)}(...)."); + + [DoesNotReturn] + private static void ThrowInbuiltParser(Type type) => throw new InvalidOperationException( + message: $"Type '{type.FullName}' has inbuilt handling and cannot be changed."); + + private sealed class InbuiltInlineParsers : IRespParser, IRespInlineParser, + IRespParser, IRespParser, + IRespParser, IRespParser, + IRespParser, IRespParser, + IRespParser, IRespParser + { + private InbuiltInlineParsers() { } + public static readonly InbuiltInlineParsers Default = new(); + + public Void Parse(in Void state, ref RespReader reader) => Void.Instance; + + int IRespParser.Parse(in Void state, ref RespReader reader) => reader.ReadInt32(); + + int? IRespParser.Parse(in Void state, ref RespReader reader) => + reader.IsNull ? null : reader.ReadInt32(); + + long IRespParser.Parse(in Void state, ref RespReader reader) => reader.ReadInt64(); + + long? IRespParser.Parse(in Void state, ref RespReader reader) => + reader.IsNull ? null : reader.ReadInt64(); + + float IRespParser.Parse(in Void state, ref RespReader reader) => (float)reader.ReadDouble(); + + float? IRespParser.Parse(in Void state, ref RespReader reader) => + reader.IsNull ? null : (float)reader.ReadDouble(); + + double IRespParser.Parse(in Void state, ref RespReader reader) => reader.ReadDouble(); + + double? IRespParser.Parse(in Void state, ref RespReader reader) => + reader.IsNull ? null : reader.ReadDouble(); + } + + private sealed class OKParser : IRespParser, IRespInlineParser + { + private OKParser() { } + public static readonly OKParser Default = new(); + + public Void Parse(in Void state, ref RespReader reader) + { + if (!(reader.Prefix == RespPrefix.SimpleString && reader.IsOK())) + { + Throw(); + } + + return default; + static void Throw() => throw new InvalidOperationException("Expected +OK response"); + } + } + + private sealed class InbuiltCopyOutParsers : IRespParser, + IRespParser, IRespParser, + IRespParser, int> + { + private InbuiltCopyOutParsers() { } + public static readonly InbuiltCopyOutParsers Default = new(); + + string? IRespParser.Parse(in Void state, ref RespReader reader) => reader.ReadString(); + byte[]? IRespParser.Parse(in Void state, ref RespReader reader) => reader.ReadByteArray(); + + byte[]?[]? IRespParser.Parse(in Void state, ref RespReader reader) => + reader.ReadArray(static (ref RespReader reader) => reader.ReadByteArray()); + + int IRespParser, int>.Parse(in IBufferWriter state, ref RespReader reader) + { + reader.DemandScalar(); + if (reader.IsNull) return -1; + return reader.CopyTo(state); + } + } +} diff --git a/src/RESP.Core/RespPayload.cs b/src/RESP.Core/RespPayload.cs new file mode 100644 index 000000000..62449c97c --- /dev/null +++ b/src/RESP.Core/RespPayload.cs @@ -0,0 +1,671 @@ +using System; +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Tasks.Sources; + +namespace Resp; + +public interface IRespMessage +{ + /// + /// Gets the request payload, reserving the value. This must be released using . + /// + bool TryReserveRequest(out ReadOnlyMemory payload); + + bool IsCompleted { get; } + + /// + /// Releases the request payload. + /// + void ReleaseRequest(); + + bool TrySetCanceled(CancellationToken cancellationToken = default); + bool TrySetException(Exception exception); + + /// + /// Parse the response and complete the request. + /// + void ProcessResponse(ref RespReader reader); + + /// + /// Cancellation associated with this message. Note that this should not typically be used to + /// cancel IO operations (for example sockets), as that would break the entire stream - however, it + /// can be used to interrupt intermediate processing before it is submitted. + /// + CancellationToken CancellationToken { get; } +} + +internal static class ActivationHelper +{ + private sealed class WorkItem +#if NETCOREAPP3_0_OR_GREATER + : IThreadPoolWorkItem +#endif + { + private WorkItem() + { +#if NET5_0_OR_GREATER + Unsafe.SkipInit(out _payload); +#else + _payload = []; +#endif + } + + private void Init(byte[] payload, int length, IRespMessage message) + { + _payload = payload; + _length = length; + _message = message; + } + + private byte[] _payload; + private int _length; + private IRespMessage? _message; + + private static WorkItem? _spare; // do NOT use ThreadStatic - different producer/consumer, no overlap + + public static void UnsafeQueueUserWorkItem( + IRespMessage message, + ReadOnlySpan payload, + ref byte[]? lease) + { + if (lease is null) + { + // we need to create our own copy of the data + lease = ArrayPool.Shared.Rent(payload.Length); + payload.CopyTo(lease); + } + + var obj = Interlocked.Exchange(ref _spare, null) ?? new(); + obj.Init(lease, payload.Length, message); + lease = null; // count as claimed + + DebugCounters.OnCopyOut(payload.Length); +#if NETCOREAPP3_0_OR_GREATER + ThreadPool.UnsafeQueueUserWorkItem(obj, false); +#else + ThreadPool.UnsafeQueueUserWorkItem(WaitCallback, obj); +#endif + } +#if !NETCOREAPP3_0_OR_GREATER + private static readonly WaitCallback WaitCallback = state => ((WorkItem)state!).Execute(); +#endif + + public static void Execute(IRespMessage? message, ReadOnlySpan payload) + { + if (message is { IsCompleted: false }) + { + try + { + var reader = new RespReader(payload); + message.ProcessResponse(ref reader); + } + catch (Exception ex) + { + message.TrySetException(ex); + } + } + } + + public void Execute() + { + var message = _message; + var payload = _payload; + var length = _length; + _message = null; + _payload = []; + _length = 0; + Interlocked.Exchange(ref _spare, this); + Execute(message, new(payload, 0, length)); + ArrayPool.Shared.Return(payload); + } + } + + public static void ProcessResponse(IRespMessage? pending, ReadOnlySpan payload, ref byte[]? lease) + { + if (pending is null) + { + // nothing to do + } + else if (pending is IRespInternalMessage { AllowInlineParsing: true }) + { + WorkItem.Execute(pending, payload); + } + else + { + WorkItem.UnsafeQueueUserWorkItem(pending, payload, ref lease); + } + } + + private static readonly Action CancellationCallback = static state + => ((IRespMessage)state!).TrySetCanceled(); + + public static CancellationTokenRegistration RegisterForCancellation( + IRespMessage message, + CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + return cancellationToken.Register(CancellationCallback, message); + } + + [Conditional("DEBUG")] + public static void DebugBreak() + { +#if DEBUG + if (Debugger.IsAttached) Debugger.Break(); +#endif + } + + [Conditional("DEBUG")] + public static void DebugBreakIf(bool condition) + { +#if DEBUG + if (condition && Debugger.IsAttached) Debugger.Break(); +#endif + } +} + +internal abstract class InternalRespMessageBase : IRespInternalMessage +{ + private IRespParser? _parser; + private byte[] _requestPayload = []; + private int _requestLength, _requestRefCount = 1; + private TState _state = default!; + private CancellationToken _cancellationToken; + private CancellationTokenRegistration _cancellationTokenRegistration; + + public abstract bool IsCompleted { get; } + + public bool TryReserveRequest(out ReadOnlyMemory payload) + { + payload = default; + while (true) // need to take reservation + { + if (IsCompleted) return false; + var oldCount = Volatile.Read(ref _requestRefCount); + if (oldCount == 0) return false; + if (Interlocked.CompareExchange(ref _requestRefCount, checked(oldCount + 1), oldCount) == oldCount) break; + } + + payload = new(_requestPayload, 0, _requestLength); + return true; + } + + public void ReleaseRequest() + { + if (!TryReleaseRequest()) ThrowReleased(); + + static void ThrowReleased() => + throw new InvalidOperationException("The request payload has already been released"); + } + + private bool TryReleaseRequest() // bool here means "it wasn't already zero"; it doesn't mean "it became zero" + { + while (true) + { + var oldCount = Volatile.Read(ref _requestRefCount); + if (oldCount == 0) return false; + if (Interlocked.CompareExchange(ref _requestRefCount, oldCount - 1, oldCount) == oldCount) + { + if (oldCount == 1) // we were the last one; recycle + { + _parser = null; + var arr = _requestPayload; + _requestLength = 0; + _requestPayload = []; + ArrayPool.Shared.Return(arr); + } + + return true; + } + } + } + + protected abstract bool TrySetResult(TResponse value); + public abstract bool TrySetException(Exception exception); + public abstract bool TrySetCanceled(CancellationToken cancellationToken = default); + + // ReSharper disable once SuspiciousTypeConversion.Global + public bool AllowInlineParsing => _parser is null or IRespInlineParser; + + void IRespMessage.ProcessResponse(ref RespReader reader) + { + // ReSharper disable once SuspiciousTypeConversion.Global + if (_parser is { } parser) + { + if (parser is not IRespMetadataParser) + { + reader.MoveNext(); // skip attributes and process errors + } + + var result = parser.Parse(in _state, ref reader); + TryReleaseRequest(); + TrySetResult(result); + } + } + + public CancellationToken CancellationToken => _cancellationToken; + + protected void UnregisterCancellation() + { + var reg = _cancellationTokenRegistration; + _cancellationTokenRegistration = default; + reg.Dispose(); + } + + protected void Reset() + { + _parser = null; + _state = default!; + _requestLength = 0; + _requestPayload = []; + _requestRefCount = 0; + _cancellationToken = CancellationToken.None; + _cancellationTokenRegistration = default; + } + + protected void Init( + byte[] requestPayload, + int requestLength, + IRespParser? parser, + in TState state, + CancellationToken cancellationToken) + { + if (cancellationToken.CanBeCanceled) + { + cancellationToken.ThrowIfCancellationRequested(); + _cancellationTokenRegistration = ActivationHelper.RegisterForCancellation(this, cancellationToken); + } + + _parser = parser; + _state = state; + _requestPayload = requestPayload; + _requestLength = requestLength; + _requestRefCount = 1; + _cancellationToken = cancellationToken; + } +} + +internal static class SyncRespMessageStatus // think "enum", but need Volatile.Read friendliness +{ + internal const int + Pending = 0, + Completed = 1, + Faulted = 2, + Cancelled = 3, + Timeout = 4; +} + +internal sealed class SyncInternalRespMessage : InternalRespMessageBase +{ + private SyncInternalRespMessage() { } + + private int _status; + private TResponse _result = default!; + private Exception? _exception; + + protected override bool TrySetResult(TResponse value) + { + if (Volatile.Read(ref _status) == SyncRespMessageStatus.Pending) + { + lock (this) + { + if (_status == SyncRespMessageStatus.Pending) + { + _result = value; + _status = SyncRespMessageStatus.Completed; + Monitor.PulseAll(this); + return true; + } + } + } + + return false; + } + + public override bool TrySetException(Exception exception) + { + if (Volatile.Read(ref _status) == SyncRespMessageStatus.Pending) + { + var newStatus = exception switch + { + TimeoutException => SyncRespMessageStatus.Timeout, + OperationCanceledException => SyncRespMessageStatus.Cancelled, + _ => SyncRespMessageStatus.Faulted, + }; + lock (this) + { + if (_status == SyncRespMessageStatus.Pending) + { + _exception = exception; + _status = newStatus; + Monitor.PulseAll(this); + return true; + } + } + } + + return false; + } + + public override bool TrySetCanceled(CancellationToken cancellationToken = default) + { + if (Volatile.Read(ref _status) == SyncRespMessageStatus.Pending) + { + if (!cancellationToken.IsCancellationRequested) + { + // if the inbound token was not cancelled: use our own + cancellationToken = CancellationToken; + } + + lock (this) + { + if (_status == SyncRespMessageStatus.Pending) + { + _status = SyncRespMessageStatus.Cancelled; + _exception = new OperationCanceledException(cancellationToken); + Monitor.PulseAll(this); + return true; + } + } + } + + return false; + } + + public override bool IsCompleted => Volatile.Read(ref _status) != SyncRespMessageStatus.Pending; + + public TResponse WaitAndRecycle(TimeSpan timeout) + { + int status = Volatile.Read(ref _status); + if (status == SyncRespMessageStatus.Pending) + { + lock (this) + { + status = _status; + if (status == SyncRespMessageStatus.Pending) + { + if (timeout == TimeSpan.Zero) + { + Monitor.Wait(this); + status = _status; + } + else if (!Monitor.Wait(this, timeout)) + { + status = _status = SyncRespMessageStatus.Timeout; + } + else + { + status = _status; + } + } + } + } + + switch (status) + { + case SyncRespMessageStatus.Completed: + var result = _result; // snapshot + if (_spare is null && TryReset()) + { + _spare = this; + } + + return result; + case SyncRespMessageStatus.Faulted: + throw _exception ?? new InvalidOperationException("Operation failed"); + case SyncRespMessageStatus.Cancelled: + throw _exception ?? new OperationCanceledException(CancellationToken); + case SyncRespMessageStatus.Timeout: + throw _exception ?? new TimeoutException(); + default: + throw new InvalidOperationException($"Unexpected status: {status}"); + } + } + + private bool TryReset() + { + Reset(); + _exception = null; + _result = default!; + _status = SyncRespMessageStatus.Pending; + return true; + } + + [ThreadStatic] + // this comment just to stop a weird formatter glitch + private static SyncInternalRespMessage? _spare; + + public static SyncInternalRespMessage Create( + byte[] requestPayload, + int requestLength, + IRespParser? parser, + in TState state, + CancellationToken cancellationToken) + { + var obj = _spare ?? new(); + _spare = null; + obj.Init(requestPayload, requestLength, parser, in state, cancellationToken); + + return obj; + } +} + +#if NET9_0_OR_GREATER && NEVER +internal sealed class AsyncInternalRespMessage( + byte[] requestPayload, + int requestLength, + IRespParser? parser) + : InternalRespMessageBase(requestPayload, requestLength, parser) +{ + [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + private static extern Task CreateTask(object? state, TaskCreationOptions options); + + [UnsafeAccessor(UnsafeAccessorKind.Method)] + private static extern bool TrySetException(Task obj, Exception exception); + + [UnsafeAccessor(UnsafeAccessorKind.Method)] + private static extern bool TrySetResult(Task obj, TResponse value); + + [UnsafeAccessor(UnsafeAccessorKind.Method)] + private static extern bool TrySetCanceled(Task obj, CancellationToken cancellationToken); + + // ReSharper disable once SuspiciousTypeConversion.Global + private readonly Task _task = CreateTask( + null, + // if we're using IO-thread parsing, we *must* still dispatch downstream continuations to the thread-pool to + // prevent thread-theft; otherwise, we're fine to run downstream inline (we already jumped) + parser is IRespInlineParser ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None); + + private CancellationTokenRegistration _cancellationTokenRegistration; + + public override bool IsCompleted => _task.IsCompleted; + protected override bool TrySetResult(TResponse value) + { + UnregisterCancellation(); + return TrySetResult(_task, value); + } + + public override bool TrySetException(Exception exception) + { + UnregisterCancellation(); + return TrySetException(_task, exception); + } + + public override bool TrySetCanceled(CancellationToken cancellationToken) + { + UnregisterCancellation(); + return TrySetCanceled(_task, cancellationToken); + } + + private void UnregisterCancellation() + { + _cancellationTokenRegistration.Dispose(); + _cancellationTokenRegistration = default; + } + + public Task WaitAsync(CancellationToken cancellationToken = default) + { + if (cancellationToken.CanBeCanceled) + { + _cancellationTokenRegistration = ActivationHelper.RegisterForCancellation(this, cancellationToken); + } + + return _task; + } +} +#else +internal sealed class AsyncInternalRespMessage : InternalRespMessageBase, + IValueTaskSource, IValueTaskSource +{ + [ThreadStatic] + // this comment just to stop a weird formatter glitch + private static AsyncInternalRespMessage? _spare; + + // we need synchronization over multiple attempts (completion, cancellation, abort) trying + // to signal the MRTCS + private int _completedFlag; + + private bool SetCompleted(bool withSuccess = false) + { + if (Interlocked.CompareExchange(ref _completedFlag, 1, 0) == 0) + { + // stop listening for CT notifications + UnregisterCancellation(); + + // configure threading model; failure can be triggered from any thread - *always* + // dispatch to pool; in the success case, we're either on the IO thread + // (if inline-parsing is enabled) - in which case, yes: dispatch - or we've + // already jumped to a pool thread for the parse step. So: the only + // time we want to complete inline is success and not inline-parsing. + _asyncCore.RunContinuationsAsynchronously = !withSuccess || AllowInlineParsing; + + return true; + } + + return false; + } + + public static AsyncInternalRespMessage Create( + byte[] requestPayload, + int requestLength, + IRespParser? parser, + in TState state, + CancellationToken cancellationToken) + { + var obj = _spare ?? new(); + _spare = null; + obj._asyncCore.RunContinuationsAsynchronously = true; + obj.Init(requestPayload, requestLength, parser, in state, cancellationToken); + return obj; + } + + private ManualResetValueTaskSourceCore _asyncCore; + + public override bool IsCompleted => Volatile.Read(ref _completedFlag) == 1; + + protected override bool TrySetResult(TResponse value) + { + if (SetCompleted(withSuccess: true)) + { + _asyncCore.SetResult(value); + return true; + } + + return false; + } + + public override bool TrySetException(Exception exception) + { + if (SetCompleted()) + { + _asyncCore.SetException(exception); + return true; + } + + return false; + } + + public override bool TrySetCanceled(CancellationToken cancellationToken = default) + { + if (SetCompleted()) + { + if (!cancellationToken.IsCancellationRequested) + { + // if the inbound token was not cancelled: use our own + cancellationToken = CancellationToken; + } + + _asyncCore.SetException(new OperationCanceledException(cancellationToken)); + return true; + } + + return false; + } + + public ValueTask WaitTypedAsync() => new(this, _asyncCore.Version); + + internal ValueTask WaitTypedAsync(Task send) + { + if (!send.IsCompleted) return Awaited(send, this); + send.GetAwaiter().GetResult(); + return new(this, _asyncCore.Version); + + static async ValueTask Awaited(Task task, AsyncInternalRespMessage @this) + { + await task.ConfigureAwait(false); + return await @this.WaitTypedAsync().ConfigureAwait(false); + } + } + + public ValueTask WaitUntypedAsync() => new(this, _asyncCore.Version); + + internal ValueTask WaitUntypedAsync(Task send) + { + if (!send.IsCompleted) return Awaited(send, this); + send.GetAwaiter().GetResult(); + return new(this, _asyncCore.Version); + + static async ValueTask Awaited(Task task, AsyncInternalRespMessage @this) + { + await task.ConfigureAwait(false); + await @this.WaitUntypedAsync().ConfigureAwait(false); + } + } + + public ValueTaskSourceStatus GetStatus(short token) => _asyncCore.GetStatus(token); + + public void OnCompleted( + Action continuation, + object? state, + short token, + ValueTaskSourceOnCompletedFlags flags) + => _asyncCore.OnCompleted(continuation, state, token, flags); + + public TResponse GetResult(short token) + { + Debug.Assert(IsCompleted, "Async payload should already be completed"); + var result = _asyncCore.GetResult(token); + // recycle on success (only) + if (_spare is null && TryReset()) + { + _spare = this; + } + + return result; + } + + private bool TryReset() + { + Reset(); + _asyncCore.Reset(); // incr version, etc + _completedFlag = 0; + return true; + } + + void IValueTaskSource.GetResult(short token) => _ = GetResult(token); +} +#endif diff --git a/src/RESP.Core/RespPrefix.cs b/src/RESP.Core/RespPrefix.cs new file mode 100644 index 000000000..382be7925 --- /dev/null +++ b/src/RESP.Core/RespPrefix.cs @@ -0,0 +1,97 @@ +namespace Resp; + +/// +/// RESP protocol prefix. +/// +public enum RespPrefix : byte +{ + /// + /// Invalid. + /// + None = 0, + + /// + /// Simple strings: +OK\r\n. + /// + SimpleString = (byte)'+', + + /// + /// Simple errors: -ERR message\r\n. + /// + SimpleError = (byte)'-', + + /// + /// Integers: :123\r\n. + /// + Integer = (byte)':', + + /// + /// String with support for binary data: $7\r\nmessage\r\n. + /// + BulkString = (byte)'$', + + /// + /// Multiple inner messages: *1\r\n+message\r\n. + /// + Array = (byte)'*', + + /// + /// Null strings/arrays: _\r\n. + /// + Null = (byte)'_', + + /// + /// Boolean values: #T\r\n. + /// + Boolean = (byte)'#', + + /// + /// Floating-point number: ,123.45\r\n. + /// + Double = (byte)',', + + /// + /// Large integer number: (12...89\r\n. + /// + BigInteger = (byte)'(', + + /// + /// Error with support for binary data: !7\r\nmessage\r\n. + /// + BulkError = (byte)'!', + + /// + /// String that should be interpreted verbatim: =11\r\ntxt:message\r\n. + /// + VerbatimString = (byte)'=', + + /// + /// Multiple sub-items that represent a map. + /// + Map = (byte)'%', + + /// + /// Multiple sub-items that represent a set. + /// + Set = (byte)'~', + + /// + /// Out-of band messages. + /// + Push = (byte)'>', + + /// + /// Continuation of streaming scalar values. + /// + StreamContinuation = (byte)';', + + /// + /// End sentinel for streaming aggregate values. + /// + StreamTerminator = (byte)'.', + + /// + /// Metadata about the next element. + /// + Attribute = (byte)'|', +} diff --git a/src/RESP.Core/RespReader.AggregateEnumerator.cs b/src/RESP.Core/RespReader.AggregateEnumerator.cs new file mode 100644 index 000000000..7e57910e5 --- /dev/null +++ b/src/RESP.Core/RespReader.AggregateEnumerator.cs @@ -0,0 +1,196 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.ComponentModel; +using System.Runtime.CompilerServices; + +#pragma warning disable IDE0079 // Remove unnecessary suppression +#pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct +#pragma warning restore IDE0079 // Remove unnecessary suppression + +namespace Resp; + +public ref partial struct RespReader +{ + /// + /// Reads the sub-elements associated with an aggregate value. + /// + public readonly AggregateEnumerator AggregateChildren() => new(in this); + + /// + /// Reads the sub-elements associated with an aggregate value. + /// + public ref struct AggregateEnumerator + { + // Note that _reader is the overall reader that can see outside this aggregate, as opposed + // to Current which is the sub-tree of the current element *only* + private RespReader _reader; + private int _remaining; + + /// + /// Create a new enumerator for the specified . + /// + /// The reader containing the data for this operation. + public AggregateEnumerator(scoped in RespReader reader) + { + reader.DemandAggregate(); + _remaining = reader.IsStreaming ? -1 : reader._length; + _reader = reader; + Value = default; + } + + /// + public readonly AggregateEnumerator GetEnumerator() => this; + + /// + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public RespReader Current => Value; + + /// + /// Gets the current element associated with this reader. + /// + public RespReader Value; // intentionally a field, because of ref-semantics + + /// + /// Move to the next child if possible, and move the child element into the next node. + /// + public bool MoveNext(RespPrefix prefix) + { + bool result = MoveNext(); + if (result) + { + Value.MoveNext(prefix); + } + return result; + } + + /// + /// Move to the next child if possible, and move the child element into the next node. + /// + /// The type of data represented by this reader. + public bool MoveNext(RespPrefix prefix, RespAttributeReader respAttributeReader, ref T attributes) + { + bool result = MoveNext(respAttributeReader, ref attributes); + if (result) + { + Value.MoveNext(prefix); + } + return result; + } + + /// > + public bool MoveNext() + { + object? attributes = null; + return MoveNextCore(null, ref attributes); + } + + /// > + /// The type of data represented by this reader. + public bool MoveNext(RespAttributeReader respAttributeReader, ref T attributes) + => MoveNextCore(respAttributeReader, ref attributes); + + /// > + private bool MoveNextCore(RespAttributeReader? attributeReader, ref T attributes) + { + if (_remaining == 0) + { + Value = default; + return false; + } + + // in order to provide access to attributes etc, we want Current to be positioned + // *before* the next element; for that, we'll take a snapshot before we read + _reader.MovePastCurrent(); + var snapshot = _reader.Clone(); + + if (attributeReader is null) + { + _reader.MoveNext(); + } + else + { + _reader.MoveNext(attributeReader, ref attributes); + } + if (_remaining > 0) + { + // non-streaming, decrement + _remaining--; + } + else if (_reader.Prefix == RespPrefix.StreamTerminator) + { + // end of streaming aggregate + _remaining = 0; + Value = default; + return false; + } + + // move past that sub-tree and trim the "snapshot" state, giving + // us a scoped reader that is *just* that sub-tree + _reader.SkipChildren(); + snapshot.TrimToTotal(_reader.BytesConsumed); + + Value = snapshot; + return true; + } + + /// + /// Move to the end of this aggregate and export the state of the . + /// + /// The reader positioned at the end of the data; this is commonly + /// used to update a tree reader, to get to the next data after the aggregate. + public void MovePast(out RespReader reader) + { + while (MoveNext()) { } + reader = _reader; + } + + public void DemandNext() + { + if (!MoveNext()) ThrowEOF(); + Value.MoveNext(); // skip any attributes etc + } + + public T ReadOne(Projection projection) + { + DemandNext(); + return projection(ref Value); + } + + public void FillAll(scoped Span target, Projection projection) + { + for (int i = 0; i < target.Length; i++) + { + if (!MoveNext()) ThrowEOF(); + + Value.MoveNext(); // skip any attributes etc + target[i] = projection(ref Value); + } + } + } + + internal void TrimToTotal(long length) => TrimToRemaining(length - BytesConsumed); + + internal void TrimToRemaining(long bytes) + { + if (_prefix != RespPrefix.None || bytes < 0) Throw(); + + var current = CurrentAvailable; + if (bytes <= current) + { + UnsafeTrimCurrentBy(current - (int)bytes); + _remainingTailLength = 0; + return; + } + + bytes -= current; + if (bytes <= _remainingTailLength) + { + _remainingTailLength = bytes; + return; + } + + Throw(); + static void Throw() => throw new ArgumentOutOfRangeException(nameof(bytes)); + } +} diff --git a/src/RESP.Core/RespReader.Debug.cs b/src/RESP.Core/RespReader.Debug.cs new file mode 100644 index 000000000..9e911911c --- /dev/null +++ b/src/RESP.Core/RespReader.Debug.cs @@ -0,0 +1,33 @@ +using System.Diagnostics; + +#pragma warning disable IDE0079 // Remove unnecessary suppression +#pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct +#pragma warning restore IDE0079 // Remove unnecessary suppression + +namespace Resp; + +[DebuggerDisplay($"{{{nameof(GetDebuggerDisplay)}(),nq}}")] +public ref partial struct RespReader +{ + internal bool DebugEquals(in RespReader other) + => _prefix == other._prefix + && _length == other._length + && _flags == other._flags + && _bufferIndex == other._bufferIndex + && _positionBase == other._positionBase + && _remainingTailLength == other._remainingTailLength; + + internal new string ToString() => $"{Prefix} ({_flags}); length {_length}, {TotalAvailable} remaining"; + + internal void DebugReset() + { + _bufferIndex = 0; + _length = 0; + _flags = 0; + _prefix = RespPrefix.None; + } + +#if DEBUG + internal bool VectorizeDisabled { get; set; } +#endif +} diff --git a/src/RESP.Core/RespReader.ScalarEnumerator.cs b/src/RESP.Core/RespReader.ScalarEnumerator.cs new file mode 100644 index 000000000..9169ad709 --- /dev/null +++ b/src/RESP.Core/RespReader.ScalarEnumerator.cs @@ -0,0 +1,107 @@ +using System; +using System.Buffers; +using System.Collections; +using System.Collections.Generic; + +#pragma warning disable IDE0079 // Remove unnecessary suppression +#pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct +#pragma warning restore IDE0079 // Remove unnecessary suppression + +namespace Resp; + +public ref partial struct RespReader +{ + /// + /// Gets the chunks associated with a scalar value. + /// + public readonly ScalarEnumerator ScalarChunks() => new(in this); + + /// + /// Allows enumeration of chunks in a scalar value; this includes simple values + /// that span multiple segments, and streaming + /// scalar RESP values. + /// + public ref struct ScalarEnumerator + { + /// + public readonly ScalarEnumerator GetEnumerator() => this; + + private RespReader _reader; + + private ReadOnlySpan _current; + private ReadOnlySequenceSegment? _tail; + private int _offset, _remaining; + + /// + /// Create a new enumerator for the specified . + /// + /// The reader containing the data for this operation. + public ScalarEnumerator(scoped in RespReader reader) + { + reader.DemandScalar(); + _reader = reader; + InitSegment(); + } + + private void InitSegment() + { + _current = _reader.CurrentSpan(); + _tail = _reader._tail; + _offset = CurrentLength = 0; + _remaining = _reader._length; + if (_reader.TotalAvailable < _remaining) ThrowEOF(); + } + + /// + public bool MoveNext() + { + while (true) // for each streaming element + { + _offset += CurrentLength; + while (_remaining > 0) // for each span in the current element + { + // look in the active span + var take = Math.Min(_remaining, _current.Length - _offset); + if (take > 0) // more in the current chunk + { + _remaining -= take; + CurrentLength = take; + return true; + } + + // otherwise, we expect more tail data + if (_tail is null) ThrowEOF(); + + _current = _tail.Memory.Span; + _offset = 0; + _tail = _tail.Next; + } + + if (!_reader.MoveNextStreamingScalar()) break; + InitSegment(); + } + + CurrentLength = 0; + return false; + } + + /// + public readonly ReadOnlySpan Current => _current.Slice(_offset, CurrentLength); + + /// + /// Gets the or . + /// + public int CurrentLength { readonly get; private set; } + + /// + /// Move to the end of this aggregate and export the state of the . + /// + /// The reader positioned at the end of the data; this is commonly + /// used to update a tree reader, to get to the next data after the aggregate. + public void MovePast(out RespReader reader) + { + while (MoveNext()) { } + reader = _reader; + } + } +} diff --git a/src/RESP.Core/RespReader.Span.cs b/src/RESP.Core/RespReader.Span.cs new file mode 100644 index 000000000..796ecc397 --- /dev/null +++ b/src/RESP.Core/RespReader.Span.cs @@ -0,0 +1,85 @@ +#define USE_UNSAFE_SPAN + +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +#pragma warning disable IDE0079 // Remove unnecessary suppression +#pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct +#pragma warning restore IDE0079 // Remove unnecessary suppression + +namespace Resp; + +/* + How we actually implement the underlying buffer depends on the capabilities of the runtime. + */ + +#if NET7_0_OR_GREATER && USE_UNSAFE_SPAN + +public ref partial struct RespReader +{ + // intent: avoid lots of slicing by dealing with everything manually, and accepting the "don't get it wrong" rule + private ref byte _bufferRoot; + private int _bufferLength; + + private partial void UnsafeTrimCurrentBy(int count) + { + Debug.Assert(count >= 0 && count <= _bufferLength, "Unsafe trim length"); + _bufferLength -= count; + } + + private readonly partial ref byte UnsafeCurrent + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => ref Unsafe.Add(ref _bufferRoot, _bufferIndex); + } + + private readonly partial int CurrentLength + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => _bufferLength; + } + + private readonly partial ReadOnlySpan CurrentSpan() => MemoryMarshal.CreateReadOnlySpan( + ref UnsafeCurrent, CurrentAvailable); + + private readonly partial ReadOnlySpan UnsafePastPrefix() => MemoryMarshal.CreateReadOnlySpan( + ref Unsafe.Add(ref _bufferRoot, _bufferIndex + 1), + _bufferLength - (_bufferIndex + 1)); + + private partial void SetCurrent(ReadOnlySpan value) + { + _bufferRoot = ref MemoryMarshal.GetReference(value); + _bufferLength = value.Length; + } +} +#else +public ref partial struct RespReader // much more conservative - uses slices etc +{ + private ReadOnlySpan _buffer; + + private partial void UnsafeTrimCurrentBy(int count) + { + _buffer = _buffer.Slice(0, _buffer.Length - count); + } + + private readonly partial ref byte UnsafeCurrent + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => ref Unsafe.AsRef(in _buffer[_bufferIndex]); // hack around CS8333 + } + + private readonly partial int CurrentLength + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => _buffer.Length; + } + + private readonly partial ReadOnlySpan UnsafePastPrefix() => _buffer.Slice(_bufferIndex + 1); + + private readonly partial ReadOnlySpan CurrentSpan() => _buffer.Slice(_bufferIndex); + + private partial void SetCurrent(ReadOnlySpan value) => _buffer = value; +} +#endif diff --git a/src/RESP.Core/RespReader.Utils.cs b/src/RESP.Core/RespReader.Utils.cs new file mode 100644 index 000000000..a5302fb13 --- /dev/null +++ b/src/RESP.Core/RespReader.Utils.cs @@ -0,0 +1,318 @@ +using System; +using System.Buffers.Text; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Runtime.CompilerServices; + +#pragma warning disable IDE0079 // Remove unnecessary suppression +#pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct +#pragma warning restore IDE0079 // Remove unnecessary suppression + +namespace Resp; + +public ref partial struct RespReader +{ + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void UnsafeAssertClLf(int offset) => UnsafeAssertClLf(ref UnsafeCurrent, offset); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void UnsafeAssertClLf(scoped ref byte source, int offset) + { + if (Unsafe.ReadUnaligned(ref Unsafe.Add(ref source, offset)) != RespConstants.CrLfUInt16) + { + ThrowProtocolFailure("Expected CR/LF"); + } + } + + private enum LengthPrefixResult + { + NeedMoreData, + Length, + Null, + Streaming, + } + + /// + /// Asserts that the current element is a scalar type. + /// + public readonly void DemandScalar() + { + if (!IsScalar) Throw(Prefix); + static void Throw(RespPrefix prefix) => throw new InvalidOperationException($"This operation requires a scalar element, got {prefix}"); + } + + /// + /// Asserts that the current element is a scalar type. + /// + public readonly void DemandAggregate() + { + if (!IsAggregate) Throw(Prefix); + static void Throw(RespPrefix prefix) => throw new InvalidOperationException($"This operation requires an aggregate element, got {prefix}"); + } + + private static LengthPrefixResult TryReadLengthPrefix(ReadOnlySpan bytes, out int value, out int byteCount) + { + var end = bytes.IndexOf(RespConstants.CrlfBytes); + if (end < 0) + { + byteCount = value = 0; + if (bytes.Length >= RespConstants.MaxRawBytesInt32 + 2) + { + ThrowProtocolFailure("Unterminated or over-length integer"); // should have failed; report failure to prevent infinite loop + } + return LengthPrefixResult.NeedMoreData; + } + byteCount = end + 2; + switch (end) + { + case 0: + ThrowProtocolFailure("Length prefix expected"); + goto case default; // not reached, just satisfying definite assignment + case 1 when bytes[0] == (byte)'?': + value = 0; + return LengthPrefixResult.Streaming; + default: + if (end > RespConstants.MaxRawBytesInt32 || !(Utf8Parser.TryParse(bytes, out value, out var consumed) && consumed == end)) + { + ThrowProtocolFailure("Unable to parse integer"); + value = 0; + } + if (value < 0) + { + if (value == -1) + { + value = 0; + return LengthPrefixResult.Null; + } + ThrowProtocolFailure("Invalid negative length prefix"); + } + return LengthPrefixResult.Length; + } + } + + private readonly RespReader Clone() => this; // useful for performing streaming operations without moving the primary + + [MethodImpl(MethodImplOptions.NoInlining), DoesNotReturn] + private static void ThrowProtocolFailure(string message) + => throw new InvalidOperationException("RESP protocol failure: " + message); // protocol exception? + + [MethodImpl(MethodImplOptions.NoInlining), DoesNotReturn] + internal static void ThrowEOF() => throw new EndOfStreamException(); + + [MethodImpl(MethodImplOptions.NoInlining), DoesNotReturn] + private static void ThrowFormatException() => throw new FormatException(); + + private int RawTryReadByte() + { + if (_bufferIndex < CurrentLength || TryMoveToNextSegment()) + { + var result = UnsafeCurrent; + _bufferIndex++; + return result; + } + return -1; + } + + private int RawPeekByte() + { + return (CurrentLength < _bufferIndex || TryMoveToNextSegment()) ? UnsafeCurrent : -1; + } + + private bool RawAssertCrLf() + { + if (CurrentAvailable >= 2) + { + UnsafeAssertClLf(0); + _bufferIndex += 2; + return true; + } + else + { + int next = RawTryReadByte(); + if (next < 0) return false; + if (next == '\r') + { + next = RawTryReadByte(); + if (next < 0) return false; + if (next == '\n') return true; + } + ThrowProtocolFailure("Expected CR/LF"); + return false; + } + } + + private LengthPrefixResult RawTryReadLengthPrefix() + { + _length = 0; + if (!RawTryFindCrLf(out int end)) + { + if (TotalAvailable >= RespConstants.MaxRawBytesInt32 + 2) + { + ThrowProtocolFailure("Unterminated or over-length integer"); // should have failed; report failure to prevent infinite loop + } + return LengthPrefixResult.NeedMoreData; + } + + switch (end) + { + case 0: + ThrowProtocolFailure("Length prefix expected"); + goto case default; // not reached, just satisfying definite assignment + case 1: + var b = (byte)RawTryReadByte(); + RawAssertCrLf(); + if (b == '?') + { + return LengthPrefixResult.Streaming; + } + else + { + _length = ParseSingleDigit(b); + return LengthPrefixResult.Length; + } + default: + if (end > RespConstants.MaxRawBytesInt32) + { + ThrowProtocolFailure("Unable to parse integer"); + } + Span bytes = stackalloc byte[end]; + RawFillBytes(bytes); + RawAssertCrLf(); + if (!(Utf8Parser.TryParse(bytes, out _length, out var consumed) && consumed == end)) + { + ThrowProtocolFailure("Unable to parse integer"); + } + + if (_length < 0) + { + if (_length == -1) + { + _length = 0; + return LengthPrefixResult.Null; + } + ThrowProtocolFailure("Invalid negative length prefix"); + } + + return LengthPrefixResult.Length; + } + } + + private void RawFillBytes(scoped Span target) + { + do + { + var current = CurrentSpan(); + if (current.Length >= target.Length) + { + // more than enough, need to trim + current.Slice(0, target.Length).CopyTo(target); + _bufferIndex += target.Length; + return; // we're done + } + else + { + // take what we can + current.CopyTo(target); + target = target.Slice(current.Length); + // we could move _bufferIndex here, but we're about to trash that in TryMoveToNextSegment + } + } + while (TryMoveToNextSegment()); + ThrowEOF(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int ParseSingleDigit(byte value) + { + return value switch + { + (byte)'0' or (byte)'1' or (byte)'2' or (byte)'3' or (byte)'4' or (byte)'5' or (byte)'6' or (byte)'7' or (byte)'8' or (byte)'9' => value - (byte)'0', + _ => Invalid(value), + }; + + [MethodImpl(MethodImplOptions.NoInlining), DoesNotReturn] + static int Invalid(byte value) => throw new FormatException($"Unable to parse integer: '{(char)value}'"); + } + + private readonly bool RawTryAssertInlineScalarPayloadCrLf() + { + Debug.Assert(IsInlineScalar, "should be inline scalar"); + + var reader = Clone(); + var len = reader._length; + if (len == 0) return reader.RawAssertCrLf(); + + do + { + var current = reader.CurrentSpan(); + if (current.Length >= len) + { + reader._bufferIndex += len; + return reader.RawAssertCrLf(); // we're done + } + else + { + // take what we can + len -= current.Length; + // we could move _bufferIndex here, but we're about to trash that in TryMoveToNextSegment + } + } + while (reader.TryMoveToNextSegment()); + return false; // EOF + } + + private readonly bool RawTryFindCrLf(out int length) + { + length = 0; + RespReader reader = Clone(); + do + { + var span = reader.CurrentSpan(); + var index = span.IndexOf((byte)'\r'); + if (index >= 0) + { + checked + { + length += index; + } + // move past the CR and assert the LF + reader._bufferIndex += index + 1; + var next = reader.RawTryReadByte(); + if (next < 0) break; // we don't know + if (next != '\n') ThrowProtocolFailure("CR/LF expected"); + + return true; + } + checked + { + length += span.Length; + } + } + while (reader.TryMoveToNextSegment()); + length = 0; + return false; + } + + private string GetDebuggerDisplay() + { + return ToString(); + } + + internal int GetInitialScanCount(out ushort streamingAggregateDepth) + { + // this is *similar* to GetDelta, but: without any discount for attributes + switch (_flags & (RespFlags.IsAggregate | RespFlags.IsStreaming)) + { + case RespFlags.IsAggregate: + streamingAggregateDepth = 0; + return _length - 1; + case RespFlags.IsAggregate | RespFlags.IsStreaming: + streamingAggregateDepth = 1; + return 0; + default: + streamingAggregateDepth = 0; + return -1; + } + } +} diff --git a/src/RESP.Core/RespReader.cs b/src/RESP.Core/RespReader.cs new file mode 100644 index 000000000..1b0f0cead --- /dev/null +++ b/src/RESP.Core/RespReader.cs @@ -0,0 +1,1599 @@ +using System; +using System.Buffers; +using System.Buffers.Text; +using System.ComponentModel; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.IO; +using System.Runtime.CompilerServices; +using System.Text; + +#if NETCOREAPP3_0_OR_GREATER +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; +#endif + +#pragma warning disable IDE0079 // Remove unnecessary suppression +#pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct +#pragma warning restore IDE0079 // Remove unnecessary suppression + +namespace Resp; + +/// +/// Provides low level RESP parsing functionality. +/// +public ref partial struct RespReader +{ + [Flags] + private enum RespFlags : byte + { + None = 0, + IsScalar = 1 << 0, // simple strings, bulk strings, etc + IsAggregate = 1 << 1, // arrays, maps, sets, etc + IsNull = 1 << 2, // explicit null RESP types, or bulk-strings/aggregates with length -1 + IsInlineScalar = 1 << 3, // a non-null scalar, i.e. with payload+CrLf + IsAttribute = 1 << 4, // is metadata for following elements + IsStreaming = 1 << 5, // unknown length + IsError = 1 << 6, // an explicit error reported inside the protocol + } + + // relates to the element we're currently reading + private RespFlags _flags; + private RespPrefix _prefix; + + private int _length; // for null: 0; for scalars: the length of the payload; for aggregates: the child count + + // the current buffer that we're observing + private int _bufferIndex; // after TryRead, this should be positioned immediately before the actual data + + // the position in a multi-segment payload + private long _positionBase; // total data we've already moved past in *previous* buffers + private ReadOnlySequenceSegment? _tail; // the next tail node + private long _remainingTailLength; // how much more can we consume from the tail? + + public long ProtocolBytesRemaining => TotalAvailable; + + private readonly int CurrentAvailable => CurrentLength - _bufferIndex; + + private readonly long TotalAvailable => CurrentAvailable + _remainingTailLength; + private partial void UnsafeTrimCurrentBy(int count); + private readonly partial ref byte UnsafeCurrent { get; } + private readonly partial int CurrentLength { get; } + private partial void SetCurrent(ReadOnlySpan value); + private RespPrefix UnsafePeekPrefix() => (RespPrefix)UnsafeCurrent; + private readonly partial ReadOnlySpan UnsafePastPrefix(); + private readonly partial ReadOnlySpan CurrentSpan(); + + /// + /// Get the scalar value as a single-segment span. + /// + /// True if this is a non-streaming scalar element that covers a single span only, otherwise False. + /// If a scalar reports False, can be used to iterate the entire payload. + /// When True, the contents of the scalar value. + public readonly bool TryGetSpan(out ReadOnlySpan value) + { + if (IsInlineScalar && CurrentAvailable >= _length) + { + value = CurrentSpan().Slice(0, _length); + return true; + } + + value = default; + return IsNullScalar; + } + + /// + /// Returns the position after the end of the current element. + /// + public readonly long BytesConsumed => _positionBase + _bufferIndex + TrailingLength; + + /// + /// Body length of scalar values, plus any terminating sentinels. + /// + private readonly int TrailingLength => (_flags & RespFlags.IsInlineScalar) == 0 ? 0 : (_length + 2); + + /// + /// Gets the RESP kind of the current element. + /// + public readonly RespPrefix Prefix => _prefix; + + /// + /// The payload length of this scalar element (includes combined length for streaming scalars). + /// + public readonly int ScalarLength() => IsInlineScalar ? _length : IsNullScalar ? 0 : checked((int)ScalarLengthSlow()); + + /// + /// Indicates whether this scalar value is zero-length. + /// + public readonly bool ScalarIsEmpty() => IsInlineScalar ? _length == 0 : (IsNullScalar || !ScalarChunks().MoveNext()); + + /// + /// The payload length of this scalar element (includes combined length for streaming scalars). + /// + public readonly long ScalarLongLength() => IsInlineScalar ? _length : IsNullScalar ? 0 : ScalarLengthSlow(); + + private readonly long ScalarLengthSlow() + { + DemandScalar(); + long length = 0; + var iterator = ScalarChunks(); + while (iterator.MoveNext()) + { + length += iterator.CurrentLength; + } + return length; + } + + /// + /// The number of child elements associated with an aggregate. + /// + /// For + /// and aggregates, this is twice the value reported in the RESP protocol, + /// i.e. a map of the form %2\r\n... will report 4 as the length. + /// Note that if the data could be streaming (), it may be preferable to use + /// the API, using the API to update the outer reader. + public readonly int AggregateLength() => (_flags & (RespFlags.IsAggregate | RespFlags.IsStreaming)) == RespFlags.IsAggregate + ? _length : AggregateLengthSlow(); + + public delegate T Projection(ref RespReader value); + + public void FillAll(scoped Span target, Projection projection) + { + DemandNotNull(); + AggregateChildren().FillAll(target, projection); + } + + private readonly int AggregateLengthSlow() + { + switch (_flags & (RespFlags.IsAggregate | RespFlags.IsStreaming)) + { + case RespFlags.IsAggregate: + return _length; + case RespFlags.IsAggregate | RespFlags.IsStreaming: + break; + default: + DemandAggregate(); // we expect this to throw + break; + } + + int count = 0; + var reader = Clone(); + while (true) + { + if (!reader.TryMoveNext()) ThrowEOF(); + if (reader.Prefix == RespPrefix.StreamTerminator) + { + return count; + } + reader.SkipChildren(); + count++; + } + } + + /// + /// Indicates whether this is a scalar value, i.e. with a potential payload body. + /// + public readonly bool IsScalar => (_flags & RespFlags.IsScalar) != 0; + + internal readonly bool IsInlineScalar => (_flags & RespFlags.IsInlineScalar) != 0; + + internal readonly bool IsNullScalar => (_flags & (RespFlags.IsScalar | RespFlags.IsNull)) == (RespFlags.IsScalar | RespFlags.IsNull); + + /// + /// Indicates whether this is an aggregate value, i.e. represents a collection of sub-values. + /// + public readonly bool IsAggregate => (_flags & RespFlags.IsAggregate) != 0; + + /// + /// Indicates whether this is a null value; this could be an explicit , + /// or a scalar or aggregate a negative reported length. + /// + public readonly bool IsNull => (_flags & RespFlags.IsNull) != 0; + + /// + /// Indicates whether this is an attribute value, i.e. metadata relating to later element data. + /// + public readonly bool IsAttribute => (_flags & RespFlags.IsAttribute) != 0; + + /// + /// Indicates whether this represents streaming content, where the or is not known in advance. + /// + public readonly bool IsStreaming => (_flags & RespFlags.IsStreaming) != 0; + + /// + /// Equivalent to both and . + /// + internal readonly bool IsStreamingScalar => (_flags & (RespFlags.IsScalar | RespFlags.IsStreaming)) == (RespFlags.IsScalar | RespFlags.IsStreaming); + + /// + /// Indicates errors reported inside the protocol. + /// + public readonly bool IsError => (_flags & RespFlags.IsError) != 0; + + /// + /// Gets the effective change (in terms of how many RESP nodes we expect to see) from consuming this element. + /// For simple scalars, this is -1 because we have one less node to read; for simple aggregates, this is + /// AggregateLength-1 because we will have consumed one element, but now need to read the additional + /// child elements. Attributes report 0, since they supplement data + /// we still need to consume. The final terminator for streaming data reports a delta of -1, otherwise: 0. + /// + /// This does not account for being nested inside a streaming aggregate; the caller must deal with that manually. + internal int Delta() => (_flags & (RespFlags.IsScalar | RespFlags.IsAggregate | RespFlags.IsStreaming | RespFlags.IsAttribute)) switch + { + RespFlags.IsScalar => -1, + RespFlags.IsAggregate => _length - 1, + RespFlags.IsAggregate | RespFlags.IsAttribute => _length, + _ => 0, + }; + + /// + /// Assert that this is the final element in the current payload. + /// + /// If additional elements are available. + public void DemandEnd() + { + while (IsStreamingScalar) + { + if (!TryReadNext()) ThrowEOF(); + } + if (TryReadNext()) + { + Throw(Prefix); + } + static void Throw(RespPrefix prefix) => throw new InvalidOperationException($"Expected end of payload, but found {prefix}"); + } + + private bool TryReadNextSkipAttributes() + { + while (TryReadNext()) + { + if (IsAttribute) + { + SkipChildren(); + } + else + { + return true; + } + } + return false; + } + + private bool TryReadNextProcessAttributes(RespAttributeReader respAttributeReader, ref T attributes) + { + while (TryReadNext()) + { + if (IsAttribute) + { + respAttributeReader.Read(ref this, ref attributes); + } + else + { + return true; + } + } + return false; + } + + /// + /// Move to the next content element; this skips attribute metadata, checking for RESP error messages by default. + /// + /// If the data is exhausted before a streaming scalar is exhausted. + /// If the data contains an explicit error element. + public bool TryMoveNext() + { + while (IsStreamingScalar) // close out the current streaming scalar + { + if (!TryReadNextSkipAttributes()) ThrowEOF(); + } + + if (TryReadNextSkipAttributes()) + { + if (IsError) ThrowError(); + return true; + } + return false; + } + + /// + /// Move to the next content element; this skips attribute metadata, checking for RESP error messages by default. + /// + /// Whether to check and throw for error messages. + /// If the data is exhausted before a streaming scalar is exhausted. + /// If the data contains an explicit error element. + public bool TryMoveNext(bool checkError) + { + while (IsStreamingScalar) // close out the current streaming scalar + { + if (!TryReadNextSkipAttributes()) ThrowEOF(); + } + + if (TryReadNextSkipAttributes()) + { + if (checkError && IsError) ThrowError(); + return true; + } + return false; + } + + /// + /// Move to the next content element; this skips attribute metadata, checking for RESP error messages by default. + /// + /// Parser for attribute data preceding the data. + /// The state for attributes encountered. + /// If the data is exhausted before a streaming scalar is exhausted. + /// If the data contains an explicit error element. + /// The type of data represented by this reader. + public bool TryMoveNext(RespAttributeReader respAttributeReader, ref T attributes) + { + while (IsStreamingScalar) // close out the current streaming scalar + { + if (!TryReadNextSkipAttributes()) ThrowEOF(); + } + + if (TryReadNextProcessAttributes(respAttributeReader, ref attributes)) + { + if (IsError) ThrowError(); + return true; + } + return false; + } + + /// + /// Move to the next content element, asserting that it is of the expected type; this skips attribute metadata, checking for RESP error messages by default. + /// + /// The expected data type. + /// If the data is exhausted before a streaming scalar is exhausted. + /// If the data contains an explicit error element. + /// If the data is not of the expected type. + public bool TryMoveNext(RespPrefix prefix) + { + bool result = TryMoveNext(); + if (result) Demand(prefix); + return result; + } + + /// + /// Move to the next content element; this skips attribute metadata, checking for RESP error messages by default. + /// + /// If the data is exhausted before content is found. + /// If the data contains an explicit error element. + public void MoveNext() + { + if (!TryMoveNext()) ThrowEOF(); + } + + /// + /// Move to the next content element; this skips attribute metadata, checking for RESP error messages by default. + /// + /// Parser for attribute data preceding the data. + /// The state for attributes encountered. + /// If the data is exhausted before content is found. + /// If the data contains an explicit error element. + /// The type of data represented by this reader. + public void MoveNext(RespAttributeReader respAttributeReader, ref T attributes) + { + if (!TryMoveNext(respAttributeReader, ref attributes)) ThrowEOF(); + } + + private bool MoveNextStreamingScalar() + { + if (IsStreamingScalar) + { + while (TryReadNext()) + { + if (IsAttribute) + { + SkipChildren(); + } + else + { + if (Prefix != RespPrefix.StreamContinuation) ThrowProtocolFailure("Streaming continuation expected"); + return _length > 0; + } + } + ThrowEOF(); // we should have found something! + } + return false; + } + + /// + /// Move to the next content element () and assert that it is a scalar (). + /// + /// If the data is exhausted before content is found. + /// If the data contains an explicit error element. + /// If the data is not a scalar type. + public void MoveNextScalar() + { + MoveNext(); + DemandScalar(); + } + + /// + /// Move to the next content element () and assert that it is an aggregate (). + /// + /// If the data is exhausted before content is found. + /// If the data contains an explicit error element. + /// If the data is not an aggregate type. + public void MoveNextAggregate() + { + MoveNext(); + DemandAggregate(); + } + + /// + /// Move to the next content element () and assert that it of type specified + /// in . + /// + /// The expected data type. + /// Parser for attribute data preceding the data. + /// The state for attributes encountered. + /// If the data is exhausted before content is found. + /// If the data contains an explicit error element. + /// If the data is not of the expected type. + /// The type of data represented by this reader. + public void MoveNext(RespPrefix prefix, RespAttributeReader respAttributeReader, ref T attributes) + { + MoveNext(respAttributeReader, ref attributes); + Demand(prefix); + } + + /// + /// Move to the next content element () and assert that it of type specified + /// in . + /// + /// The expected data type. + /// If the data is exhausted before content is found. + /// If the data contains an explicit error element. + /// If the data is not of the expected type. + public void MoveNext(RespPrefix prefix) + { + MoveNext(); + Demand(prefix); + } + + internal void Demand(RespPrefix prefix) + { + if (Prefix != prefix) Throw(prefix, Prefix); + static void Throw(RespPrefix expected, RespPrefix actual) => throw new InvalidOperationException($"Expected {expected} element, but found {actual}."); + } + + private readonly void ThrowError() => throw new RespException(ReadString()!); + + /// + /// Skip all sub elements of the current node; this includes both aggregate children and scalar streaming elements. + /// + public void SkipChildren() + { + // if this is a simple non-streaming scalar, then: there's nothing complex to do; otherwise, re-use the + // frame scanner logic to seek past the noise (this way, we avoid recursion etc) + switch (_flags & (RespFlags.IsScalar | RespFlags.IsAggregate | RespFlags.IsStreaming)) + { + case RespFlags.None: + // no current element + break; + case RespFlags.IsScalar: + // simple scalar + MovePastCurrent(); + break; + default: + // something more complex + RespScanState state = new(in this); + if (!state.TryRead(ref this, out _)) ThrowEOF(); + break; + } + } + + /// + /// Reads the current element as a string value. + /// + public readonly string? ReadString() => ReadString(out _); + + /// + /// Reads the current element as a string value. + /// + public readonly string? ReadString(out string prefix) + { + byte[] pooled = []; + try + { + var span = Buffer(ref pooled, stackalloc byte[256]); + prefix = ""; + if (span.IsEmpty) + { + return IsNull ? null : ""; + } + if (Prefix == RespPrefix.VerbatimString + && span.Length >= 4 && span[3] == ':') + { + // "the first three bytes provide information about the format of the following string, + // which can be txt for plain text, or mkd for markdown. The fourth byte is always :. + // Then the real string follows." + var prefixValue = RespConstants.UnsafeCpuUInt32(span); + if (prefixValue == PrefixTxt) + { + prefix = "txt"; + } + else if (prefixValue == PrefixMkd) + { + prefix = "mkd"; + } + else + { + prefix = RespConstants.UTF8.GetString(span.Slice(0, 3)); + } + span = span.Slice(4); + } + return RespConstants.UTF8.GetString(span); + } + finally + { + ArrayPool.Shared.Return(pooled); + } + } + + private static readonly uint + PrefixTxt = RespConstants.UnsafeCpuUInt32("txt:"u8), + PrefixMkd = RespConstants.UnsafeCpuUInt32("mkd:"u8); + + /// + /// Reads the current element as a string value. + /// + public readonly byte[]? ReadByteArray() + { + byte[] pooled = []; + try + { + var span = Buffer(ref pooled, stackalloc byte[256]); + if (span.IsEmpty) + { + return IsNull ? null : []; + } + return span.ToArray(); + } + finally + { + ArrayPool.Shared.Return(pooled); + } + } + + /// + /// Reads the current element using a general purpose text parser. + /// + /// The type of data being parsed. + public readonly T ParseBytes(Parser parser) + { + byte[] pooled = []; + var span = Buffer(ref pooled, stackalloc byte[256]); + try + { + return parser(span); + } + finally + { + ArrayPool.Shared.Return(pooled); + } + } + + /// + /// Reads the current element using a general purpose text parser. + /// + /// The type of data being parsed. + /// State required by the parser. + public readonly T ParseBytes(Parser parser, TState? state) + { + byte[] pooled = []; + var span = Buffer(ref pooled, stackalloc byte[256]); + try + { + return parser(span, default); + } + finally + { + ArrayPool.Shared.Return(pooled); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly ReadOnlySpan Buffer(Span target) + { + if (TryGetSpan(out var simple)) + { + return simple; + } + +#if NET6_0_OR_GREATER + return BufferSlow(ref Unsafe.NullRef(), target, usePool: false); +#else + byte[] pooled = []; + return BufferSlow(ref pooled, target, usePool: false); +#endif + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly ReadOnlySpan Buffer(scoped ref byte[] pooled, Span target = default) + => TryGetSpan(out var simple) ? simple : BufferSlow(ref pooled, target, true); + + [MethodImpl(MethodImplOptions.NoInlining)] + private readonly ReadOnlySpan BufferSlow(scoped ref byte[] pooled, Span target, bool usePool) + { + DemandScalar(); + + if (IsInlineScalar && usePool) + { + // grow to the correct size in advance, if needed + var length = ScalarLength(); + if (length > target.Length) + { + var bigger = ArrayPool.Shared.Rent(length); + ArrayPool.Shared.Return(pooled); + target = pooled = bigger; + } + } + + var iterator = ScalarChunks(); + ReadOnlySpan current; + int offset = 0; + while (iterator.MoveNext()) + { + // will the current chunk fit? + current = iterator.Current; + if (current.TryCopyTo(target.Slice(offset))) + { + // fits into the current buffer + offset += current.Length; + } + else if (!usePool) + { + // rent disallowed; fill what we can + var available = target.Slice(offset); + current.Slice(0, available.Length).CopyTo(available); + return target; // we filled it + } + else + { + // rent a bigger buffer, copy and recycle + var bigger = ArrayPool.Shared.Rent(offset + current.Length); + if (offset != 0) + { + target.Slice(0, offset).CopyTo(bigger); + } + ArrayPool.Shared.Return(pooled); + target = pooled = bigger; + current.CopyTo(target.Slice(offset)); + } + } + return target.Slice(0, offset); + } + + /// + /// Reads the current element using a general purpose byte parser. + /// + /// The type of data being parsed. + public readonly T ParseChars(Parser parser) + { + byte[] bArr = []; + char[] cArr = []; + try + { + var bSpan = Buffer(ref bArr, stackalloc byte[128]); + var maxChars = RespConstants.UTF8.GetMaxCharCount(bSpan.Length); + Span cSpan = maxChars <= 128 ? stackalloc char[128] : (cArr = ArrayPool.Shared.Rent(maxChars)); + int chars = RespConstants.UTF8.GetChars(bSpan, cSpan); + return parser(cSpan.Slice(0, chars)); + } + finally + { + ArrayPool.Shared.Return(bArr); + ArrayPool.Shared.Return(cArr); + } + } + + /// + /// Reads the current element using a general purpose byte parser. + /// + /// The type of data being parsed. + /// State required by the parser. + public readonly T ParseChars(Parser parser, TState? state) + { + byte[] bArr = []; + char[] cArr = []; + try + { + var bSpan = Buffer(ref bArr, stackalloc byte[128]); + var maxChars = RespConstants.UTF8.GetMaxCharCount(bSpan.Length); + Span cSpan = maxChars <= 128 ? stackalloc char[128] : (cArr = ArrayPool.Shared.Rent(maxChars)); + int chars = RespConstants.UTF8.GetChars(bSpan, cSpan); + return parser(cSpan.Slice(0, chars), state); + } + finally + { + ArrayPool.Shared.Return(bArr); + ArrayPool.Shared.Return(cArr); + } + } + +#if NET7_0_OR_GREATER + /// + /// Reads the current element using . + /// + /// The type of data being parsed. +#pragma warning disable RS0016, RS0027 // back-compat overload + public readonly T ParseChars(IFormatProvider? formatProvider = null) where T : ISpanParsable +#pragma warning restore RS0016, RS0027 // back-compat overload + { + byte[] bArr = []; + char[] cArr = []; + try + { + var bSpan = Buffer(ref bArr, stackalloc byte[128]); + var maxChars = RespConstants.UTF8.GetMaxCharCount(bSpan.Length); + Span cSpan = maxChars <= 128 ? stackalloc char[128] : (cArr = ArrayPool.Shared.Rent(maxChars)); + int chars = RespConstants.UTF8.GetChars(bSpan, cSpan); + return T.Parse(cSpan.Slice(0, chars), formatProvider ?? CultureInfo.InvariantCulture); + } + finally + { + ArrayPool.Shared.Return(bArr); + ArrayPool.Shared.Return(cArr); + } + } +#endif + +#if NET8_0_OR_GREATER + /// + /// Reads the current element using . + /// + /// The type of data being parsed. +#pragma warning disable RS0016, RS0027 // back-compat overload + public readonly T ParseBytes(IFormatProvider? formatProvider = null) where T : IUtf8SpanParsable +#pragma warning restore RS0016, RS0027 // back-compat overload + { + byte[] bArr = []; + try + { + var bSpan = Buffer(ref bArr, stackalloc byte[128]); + return T.Parse(bSpan, formatProvider ?? CultureInfo.InvariantCulture); + } + finally + { + ArrayPool.Shared.Return(bArr); + } + } +#endif + + /// + /// General purpose parsing callback. + /// + /// The type of source data being parsed. + /// State required by the parser. + /// The output type of data being parsed. + public delegate TValue Parser(ReadOnlySpan value, TState? state); + + /// + /// General purpose parsing callback. + /// + /// The type of source data being parsed. + /// The output type of data being parsed. + public delegate TValue Parser(ReadOnlySpan value); + + /// + /// Initializes a new instance of the struct. + /// + /// The raw contents to parse with this instance. + public RespReader(ReadOnlySpan value) + { + _length = 0; + _flags = RespFlags.None; + _prefix = RespPrefix.None; + SetCurrent(value); + + _remainingTailLength = _positionBase = 0; + _tail = null; + } + + private void MovePastCurrent() + { + // skip past the trailing portion of a value, if any + var skip = TrailingLength; + if (_bufferIndex + skip <= CurrentLength) + { + _bufferIndex += skip; // available in the current buffer + } + else + { + AdvanceSlow(skip); + } + + // reset the current state + _length = 0; + _flags = 0; + _prefix = RespPrefix.None; + } + + /// + public RespReader(scoped in ReadOnlySequence value) +#if NETCOREAPP3_0_OR_GREATER + : this(value.FirstSpan) +#else + : this(value.First.Span) +#endif + { + if (!value.IsSingleSegment) + { + _remainingTailLength = value.Length - CurrentLength; + _tail = (value.Start.GetObject() as ReadOnlySequenceSegment)?.Next ?? MissingNext(); + } + + [MethodImpl(MethodImplOptions.NoInlining), DoesNotReturn] + static ReadOnlySequenceSegment MissingNext() => throw new ArgumentException("Unable to extract tail segment", nameof(value)); + } + + /// + /// Attempt to move to the next RESP element. + /// + /// Unless you are intentionally handling errors, attributes and streaming data, should be preferred. + [EditorBrowsable(EditorBrowsableState.Never), Browsable(false)] + public unsafe bool TryReadNext() + { + MovePastCurrent(); + +#if NETCOREAPP3_0_OR_GREATER + // check what we have available; don't worry about zero/fetching the next segment; this is only + // for SIMD lookup, and zero would only apply when data ends exactly on segment boundaries, which + // is incredible niche + var available = CurrentAvailable; + + if (Avx2.IsSupported && Bmi1.IsSupported && available >= sizeof(uint)) + { + // read the first 4 bytes + ref byte origin = ref UnsafeCurrent; + var comparand = Unsafe.ReadUnaligned(ref origin); + + // broadcast those 4 bytes into a vector, mask to get just the first and last byte, and apply a SIMD equality test with our known cases + var eqs = Avx2.CompareEqual(Avx2.And(Avx2.BroadcastScalarToVector256(&comparand), Raw.FirstLastMask), Raw.CommonRespPrefixes); + + // reinterpret that as floats, and pick out the sign bits (which will be 1 for "equal", 0 for "not equal"); since the + // test cases are mutually exclusive, we expect zero or one matches, so: lzcount tells us which matched + var index = Bmi1.TrailingZeroCount((uint)Avx.MoveMask(Unsafe.As, Vector256>(ref eqs))); + int len; +#if DEBUG + if (VectorizeDisabled) index = uint.MaxValue; // just to break the switch +#endif + switch (index) + { + case Raw.CommonRespIndex_Success when available >= 5 && Unsafe.Add(ref origin, 4) == (byte)'\n': + _prefix = RespPrefix.SimpleString; + _length = 2; + _bufferIndex++; + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + return true; + case Raw.CommonRespIndex_SingleDigitInteger when Unsafe.Add(ref origin, 2) == (byte)'\r': + _prefix = RespPrefix.Integer; + _length = 1; + _bufferIndex++; + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + return true; + case Raw.CommonRespIndex_DoubleDigitInteger when available >= 5 && Unsafe.Add(ref origin, 4) == (byte)'\n': + _prefix = RespPrefix.Integer; + _length = 2; + _bufferIndex++; + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + return true; + case Raw.CommonRespIndex_SingleDigitString when Unsafe.Add(ref origin, 2) == (byte)'\r': + if (comparand == RespConstants.BulkStringStreaming) + { + _flags = RespFlags.IsScalar | RespFlags.IsStreaming; + } + else + { + len = ParseSingleDigit(Unsafe.Add(ref origin, 1)); + if (available < len + 6) break; // need more data + + UnsafeAssertClLf(4 + len); + _length = len; + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + } + _prefix = RespPrefix.BulkString; + _bufferIndex += 4; + return true; + case Raw.CommonRespIndex_DoubleDigitString when available >= 5 && Unsafe.Add(ref origin, 4) == (byte)'\n': + if (comparand == RespConstants.BulkStringNull) + { + _length = 0; + _flags = RespFlags.IsScalar | RespFlags.IsNull; + } + else + { + len = ParseDoubleDigitsNonNegative(ref Unsafe.Add(ref origin, 1)); + if (available < len + 7) break; // need more data + + UnsafeAssertClLf(5 + len); + _length = len; + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + } + _prefix = RespPrefix.BulkString; + _bufferIndex += 5; + return true; + case Raw.CommonRespIndex_SingleDigitArray when Unsafe.Add(ref origin, 2) == (byte)'\r': + if (comparand == RespConstants.ArrayStreaming) + { + _flags = RespFlags.IsAggregate | RespFlags.IsStreaming; + } + else + { + _flags = RespFlags.IsAggregate; + _length = ParseSingleDigit(Unsafe.Add(ref origin, 1)); + } + _prefix = RespPrefix.Array; + _bufferIndex += 4; + return true; + case Raw.CommonRespIndex_DoubleDigitArray when available >= 5 && Unsafe.Add(ref origin, 4) == (byte)'\n': + if (comparand == RespConstants.ArrayNull) + { + _flags = RespFlags.IsAggregate | RespFlags.IsNull; + } + else + { + _length = ParseDoubleDigitsNonNegative(ref Unsafe.Add(ref origin, 1)); + _flags = RespFlags.IsAggregate; + } + _prefix = RespPrefix.Array; + _bufferIndex += 5; + return true; + case Raw.CommonRespIndex_Error: + len = UnsafePastPrefix().IndexOf(RespConstants.CrlfBytes); + if (len < 0) break; // need more data + + _prefix = RespPrefix.SimpleError; + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar | RespFlags.IsError; + _length = len; + _bufferIndex++; + return true; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static int ParseDoubleDigitsNonNegative(ref byte value) => (10 * ParseSingleDigit(value)) + ParseSingleDigit(Unsafe.Add(ref value, 1)); +#endif + + // no fancy vectorization, but: we can still try to find the payload the fast way in a single segment + if (_bufferIndex + 3 <= CurrentLength) // shortest possible RESP fragment is length 3 + { + var remaining = UnsafePastPrefix(); + switch (_prefix = UnsafePeekPrefix()) + { + case RespPrefix.SimpleString: + case RespPrefix.SimpleError: + case RespPrefix.Integer: + case RespPrefix.Boolean: + case RespPrefix.Double: + case RespPrefix.BigInteger: + // CRLF-terminated + _length = remaining.IndexOf(RespConstants.CrlfBytes); + if (_length < 0) break; // can't find, need more data + _bufferIndex++; // payload follows prefix directly + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + if (_prefix == RespPrefix.SimpleError) _flags |= RespFlags.IsError; + return true; + case RespPrefix.BulkError: + case RespPrefix.BulkString: + case RespPrefix.VerbatimString: + // length prefix with value payload; first, the length + switch (TryReadLengthPrefix(remaining, out _length, out int consumed)) + { + case LengthPrefixResult.Length: + // still need to valid terminating CRLF + if (remaining.Length < consumed + _length + 2) break; // need more data + UnsafeAssertClLf(1 + consumed + _length); + + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + break; + case LengthPrefixResult.Null: + _flags = RespFlags.IsScalar | RespFlags.IsNull; + break; + case LengthPrefixResult.Streaming: + _flags = RespFlags.IsScalar | RespFlags.IsStreaming; + break; + } + if (_flags == 0) break; // will need more data to know + if (_prefix == RespPrefix.BulkError) _flags |= RespFlags.IsError; + _bufferIndex += 1 + consumed; + return true; + case RespPrefix.StreamContinuation: + // length prefix, possibly with value payload; first, the length + switch (TryReadLengthPrefix(remaining, out _length, out consumed)) + { + case LengthPrefixResult.Length when _length == 0: + // EOF, no payload + _flags = RespFlags.IsScalar; // don't claim as streaming, we want this to count towards delta-decrement + break; + case LengthPrefixResult.Length: + // still need to valid terminating CRLF + if (remaining.Length < consumed + _length + 2) break; // need more data + UnsafeAssertClLf(1 + consumed + _length); + + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar | RespFlags.IsStreaming; + break; + case LengthPrefixResult.Null: + case LengthPrefixResult.Streaming: + ThrowProtocolFailure("Invalid streaming scalar length prefix"); + break; + } + if (_flags == 0) break; // will need more data to know + _bufferIndex += 1 + consumed; + return true; + case RespPrefix.Array: + case RespPrefix.Set: + case RespPrefix.Map: + case RespPrefix.Push: + case RespPrefix.Attribute: + // length prefix without value payload (child values follow) + switch (TryReadLengthPrefix(remaining, out _length, out consumed)) + { + case LengthPrefixResult.Length: + _flags = RespFlags.IsAggregate; + if (AggregateLengthNeedsDoubling()) _length *= 2; + break; + case LengthPrefixResult.Null: + _flags = RespFlags.IsAggregate | RespFlags.IsNull; + break; + case LengthPrefixResult.Streaming: + _flags = RespFlags.IsAggregate | RespFlags.IsStreaming; + break; + } + if (_flags == 0) break; // will need more data to know + if (_prefix is RespPrefix.Attribute) _flags |= RespFlags.IsAttribute; + _bufferIndex += consumed + 1; + return true; + case RespPrefix.Null: // null + // note we already checked we had 3 bytes + UnsafeAssertClLf(1); + _flags = RespFlags.IsScalar | RespFlags.IsNull; + _bufferIndex += 3; // skip prefix+terminator + return true; + case RespPrefix.StreamTerminator: + // note we already checked we had 3 bytes + UnsafeAssertClLf(1); + _flags = RespFlags.IsAggregate; // don't claim as streaming - this counts towards delta + _bufferIndex += 3; // skip prefix+terminator + return true; + default: + ThrowProtocolFailure("Unexpected protocol prefix: " + _prefix); + return false; + } + } + + return TryReadNextSlow(ref this); + } + + private static bool TryReadNextSlow(ref RespReader live) + { + // in the case of failure, we don't want to apply any changes, + // so we work against an isolated copy until we're happy + live.MovePastCurrent(); + RespReader isolated = live; + + int next = isolated.RawTryReadByte(); + if (next < 0) return false; + + switch (isolated._prefix = (RespPrefix)next) + { + case RespPrefix.SimpleString: + case RespPrefix.SimpleError: + case RespPrefix.Integer: + case RespPrefix.Boolean: + case RespPrefix.Double: + case RespPrefix.BigInteger: + // CRLF-terminated + if (!isolated.RawTryFindCrLf(out isolated._length)) return false; + isolated._flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + if (isolated._prefix == RespPrefix.SimpleError) isolated._flags |= RespFlags.IsError; + break; + case RespPrefix.BulkError: + case RespPrefix.BulkString: + case RespPrefix.VerbatimString: + // length prefix with value payload + switch (isolated.RawTryReadLengthPrefix()) + { + case LengthPrefixResult.Length: + // still need to valid terminating CRLF + isolated._flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + if (!isolated.RawTryAssertInlineScalarPayloadCrLf()) return false; + break; + case LengthPrefixResult.Null: + isolated._flags = RespFlags.IsScalar | RespFlags.IsNull; + break; + case LengthPrefixResult.Streaming: + isolated._flags = RespFlags.IsScalar | RespFlags.IsStreaming; + break; + case LengthPrefixResult.NeedMoreData: + return false; + default: + ThrowProtocolFailure("Unexpected length prefix"); + return false; + } + if (isolated._prefix == RespPrefix.BulkError) isolated._flags |= RespFlags.IsError; + break; + case RespPrefix.Array: + case RespPrefix.Set: + case RespPrefix.Map: + case RespPrefix.Push: + case RespPrefix.Attribute: + // length prefix without value payload (child values follow) + switch (isolated.RawTryReadLengthPrefix()) + { + case LengthPrefixResult.Length: + isolated._flags = RespFlags.IsAggregate; + if (isolated.AggregateLengthNeedsDoubling()) isolated._length *= 2; + break; + case LengthPrefixResult.Null: + isolated._flags = RespFlags.IsAggregate | RespFlags.IsNull; + break; + case LengthPrefixResult.Streaming: + isolated._flags = RespFlags.IsAggregate | RespFlags.IsStreaming; + break; + case LengthPrefixResult.NeedMoreData: + return false; + default: + ThrowProtocolFailure("Unexpected length prefix"); + return false; + } + if (isolated._prefix is RespPrefix.Attribute) isolated._flags |= RespFlags.IsAttribute; + break; + case RespPrefix.Null: // null + if (!isolated.RawAssertCrLf()) return false; + isolated._flags = RespFlags.IsScalar | RespFlags.IsNull; + break; + case RespPrefix.StreamTerminator: + if (!isolated.RawAssertCrLf()) return false; + isolated._flags = RespFlags.IsAggregate; // don't claim as streaming - this counts towards delta + break; + case RespPrefix.StreamContinuation: + // length prefix, possibly with value payload; first, the length + switch (isolated.RawTryReadLengthPrefix()) + { + case LengthPrefixResult.Length when isolated._length == 0: + // EOF, no payload + isolated._flags = RespFlags.IsScalar; // don't claim as streaming, we want this to count towards delta-decrement + break; + case LengthPrefixResult.Length: + // still need to valid terminating CRLF + isolated._flags = RespFlags.IsScalar | RespFlags.IsInlineScalar | RespFlags.IsStreaming; + if (!isolated.RawTryAssertInlineScalarPayloadCrLf()) return false; // need more data + break; + case LengthPrefixResult.Null: + case LengthPrefixResult.Streaming: + ThrowProtocolFailure("Invalid streaming scalar length prefix"); + break; + case LengthPrefixResult.NeedMoreData: + default: + return false; + } + break; + default: + ThrowProtocolFailure("Unexpected protocol prefix: " + isolated._prefix); + return false; + } + // commit the speculative changes back, and accept + live = isolated; + return true; + } + + private void AdvanceSlow(long bytes) + { + while (bytes > 0) + { + var available = CurrentLength - _bufferIndex; + if (bytes <= available) + { + _bufferIndex += (int)bytes; + return; + } + bytes -= available; + + if (!TryMoveToNextSegment()) Throw(); + } + + [DoesNotReturn] + static void Throw() => throw new EndOfStreamException("Unexpected end of payload; this is unexpected because we already validated that it was available!"); + } + + private bool AggregateLengthNeedsDoubling() => _prefix is RespPrefix.Map or RespPrefix.Attribute; + + private bool TryMoveToNextSegment() + { + while (_tail is not null && _remainingTailLength > 0) + { + var memory = _tail.Memory; + _tail = _tail.Next; + if (!memory.IsEmpty) + { + var span = memory.Span; // check we can get this before mutating anything + _positionBase += CurrentLength; + if (span.Length > _remainingTailLength) + { + span = span.Slice(0, (int)_remainingTailLength); + _remainingTailLength = 0; + } + else + { + _remainingTailLength -= span.Length; + } + SetCurrent(span); + _bufferIndex = 0; + return true; + } + } + return false; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly bool IsOK() // go mad with this, because it is used so often + { + return TryGetSpan(out var span) && span.Length == 2 + ? Unsafe.ReadUnaligned(ref UnsafeCurrent) == RespConstants.OKUInt16 + : IsSlow(RespConstants.OKBytes); + } + + /// + /// Indicates whether the current element is a scalar with a value that matches the provided . + /// + /// The payload value to verify. + public readonly bool Is(ReadOnlySpan value) + => TryGetSpan(out var span) ? span.SequenceEqual(value) : IsSlow(value); + + internal readonly bool IsInlneCpuUInt32(uint value) + { + if (IsInlineScalar && _length == sizeof(uint)) + { + return CurrentAvailable >= sizeof(uint) + ? Unsafe.ReadUnaligned(ref UnsafeCurrent) == value + : SlowIsInlneCpuUInt32(value); + } + + return false; + } + + private readonly bool SlowIsInlneCpuUInt32(uint value) + { + Debug.Assert(IsInlineScalar && _length == sizeof(uint), "should be inline scalar of length 4"); + Span buffer = stackalloc byte[sizeof(uint)]; + var copy = this; + copy.RawFillBytes(buffer); + return RespConstants.UnsafeCpuUInt32(buffer) == value; + } + + /// + /// Indicates whether the current element is a scalar with a value that matches the provided . + /// + /// The payload value to verify. + public readonly bool Is(byte value) + { + if (IsInlineScalar && _length == 1 && CurrentAvailable >= 1) + { + return UnsafeCurrent == value; + } + + ReadOnlySpan span = [value]; + return IsSlow(span); + } + + private readonly bool IsSlow(ReadOnlySpan testValue) + { + DemandScalar(); + if (IsNull) return false; // nothing equals null + if (TotalAvailable < testValue.Length) return false; + + if (!IsStreaming && testValue.Length != ScalarLength()) return false; + + var iterator = ScalarChunks(); + while (true) + { + if (testValue.IsEmpty) + { + // nothing left to test; if also nothing left to read, great! + return !iterator.MoveNext(); + } + if (!iterator.MoveNext()) + { + return false; // test is longer + } + + var current = iterator.Current; + if (testValue.Length < current.Length) return false; // payload is longer + + if (!current.SequenceEqual(testValue.Slice(0, current.Length))) return false; // payload is different + + testValue = testValue.Slice(current.Length); // validated; continue + } + } + + /// + /// Copy the current scalar value out into the supplied , or as much as can be copied. + /// + /// The destination for the copy operation. + /// The number of bytes successfully copied. + public readonly int CopyTo(Span target) + { + if (TryGetSpan(out var value)) + { + if (target.Length < value.Length) value = value.Slice(0, target.Length); + + value.CopyTo(target); + return value.Length; + } + + int totalBytes = 0; + var iterator = ScalarChunks(); + while (iterator.MoveNext()) + { + value = iterator.Current; + if (target.Length <= value.Length) + { + value.Slice(0, target.Length).CopyTo(target); + return totalBytes + target.Length; + } + + value.CopyTo(target); + target = target.Slice(value.Length); + totalBytes += value.Length; + } + return totalBytes; + } + + /// + /// Copy the current scalar value out into the supplied , or as much as can be copied. + /// + /// The destination for the copy operation. + /// The number of bytes successfully copied. + public readonly int CopyTo(IBufferWriter target) + { + if (TryGetSpan(out var value)) + { + target.Write(value); + return value.Length; + } + + int totalBytes = 0; + var iterator = ScalarChunks(); + while (iterator.MoveNext()) + { + value = iterator.Current; + target.Write(value); + totalBytes += value.Length; + } + return totalBytes; + } + + /// + /// Asserts that the current element is not null. + /// + public void DemandNotNull() + { + if (IsNull) Throw(); + static void Throw() => throw new InvalidOperationException("A non-null element was expected"); + } + + /// + /// Read the current element as a value. + /// + [SuppressMessage("Style", "IDE0018:Inline variable declaration", Justification = "No it can't - conditional")] + public readonly long ReadInt64() + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesInt64 + 1]); + long value; + if (!(span.Length <= RespConstants.MaxRawBytesInt64 + && Utf8Parser.TryParse(span, out value, out int bytes) + && bytes == span.Length)) + { + ThrowFormatException(); + value = 0; + } + return value; + } + + /// + /// Try to read the current element as a value. + /// + public readonly bool TryReadInt64(out long value) + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesInt64 + 1]); + if (span.Length <= RespConstants.MaxRawBytesInt64) + { + return Utf8Parser.TryParse(span, out value, out int bytes) & bytes == span.Length; + } + + value = 0; + return false; + } + + /// + /// Read the current element as a value. + /// + [SuppressMessage("Style", "IDE0018:Inline variable declaration", Justification = "No it can't - conditional")] + public readonly int ReadInt32() + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesInt32 + 1]); + int value; + if (!(span.Length <= RespConstants.MaxRawBytesInt32 + && Utf8Parser.TryParse(span, out value, out int bytes) + && bytes == span.Length)) + { + ThrowFormatException(); + value = 0; + } + return value; + } + + /// + /// Try to read the current element as a value. + /// + public readonly bool TryReadInt32(out int value) + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesInt32 + 1]); + if (span.Length <= RespConstants.MaxRawBytesInt32) + { + return Utf8Parser.TryParse(span, out value, out int bytes) & bytes == span.Length; + } + + value = 0; + return false; + } + + /// + /// Read the current element as a value. + /// + public readonly double ReadDouble() + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesNumber + 1]); + + if (span.Length <= RespConstants.MaxRawBytesNumber + && Utf8Parser.TryParse(span, out double value, out int bytes) + && bytes == span.Length) + { + return value; + } + switch (span.Length) + { + case 3 when "inf"u8.SequenceEqual(span): + return double.PositiveInfinity; + case 3 when "nan"u8.SequenceEqual(span): + return double.NaN; + case 4 when "+inf"u8.SequenceEqual(span): // not actually mentioned in spec, but: we'll allow it + return double.PositiveInfinity; + case 4 when "-inf"u8.SequenceEqual(span): + return double.NegativeInfinity; + } + ThrowFormatException(); + return 0; + } + + /// + /// Try to read the current element as a value. + /// + public bool TryReadDouble(out double value, bool allowTokens = true) + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesNumber + 1]); + + if (span.Length <= RespConstants.MaxRawBytesNumber + && Utf8Parser.TryParse(span, out value, out int bytes) + && bytes == span.Length) + { + return true; + } + + if (allowTokens) + { + switch (span.Length) + { + case 3 when "inf"u8.SequenceEqual(span): + value = double.PositiveInfinity; + return true; + case 3 when "nan"u8.SequenceEqual(span): + value = double.NaN; + return true; + case 4 when "+inf"u8.SequenceEqual(span): // not actually mentioned in spec, but: we'll allow it + value = double.PositiveInfinity; + return true; + case 4 when "-inf"u8.SequenceEqual(span): + value = double.NegativeInfinity; + return true; + } + } + + value = 0; + return false; + } + + internal readonly bool TryReadShortAscii(out string value) + { + const int ShortLength = 31; + + var span = Buffer(stackalloc byte[ShortLength + 1]); + value = ""; + if (span.IsEmpty) return true; + + if (span.Length <= ShortLength) + { + // check for anything that looks binary or unicode + foreach (var b in span) + { + // allow [SPACE]-thru-[DEL], plus CR/LF + if (!(b < 127 & (b >= 32 | (b is 12 or 13)))) + { + return false; + } + } + + value = Encoding.UTF8.GetString(span); + return true; + } + + return false; + } + + /// + /// Read the current element as a value. + /// + [SuppressMessage("Style", "IDE0018:Inline variable declaration", Justification = "No it can't - conditional")] + public readonly decimal ReadDecimal() + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesNumber + 1]); + decimal value; + if (!(span.Length <= RespConstants.MaxRawBytesNumber + && Utf8Parser.TryParse(span, out value, out int bytes) + && bytes == span.Length)) + { + ThrowFormatException(); + value = 0; + } + return value; + } + + /// + /// Read the current element as a value. + /// + public readonly bool ReadBoolean() + { + var span = Buffer(stackalloc byte[2]); + if (span.Length == 1) + { + switch (span[0]) + { + case (byte)'0' when Prefix == RespPrefix.Integer: return false; + case (byte)'1' when Prefix == RespPrefix.Integer: return true; + case (byte)'f' when Prefix == RespPrefix.Boolean: return false; + case (byte)'t' when Prefix == RespPrefix.Boolean: return true; + } + } + ThrowFormatException(); + return false; + } + + /// + /// Parse a scalar value as an enum of type . + /// + /// The value to report if the value is not recognized. + /// The type of enum being parsed. + public readonly T ReadEnum(T unknownValue = default) where T : struct, Enum + { +#if NET6_0_OR_GREATER + return ParseChars(static (chars, state) => Enum.TryParse(chars, true, out T value) ? value : state, unknownValue); +#else + return Enum.TryParse(ReadString(), true, out T value) ? value : unknownValue; +#endif + } + + public T[]? ReadArray(Projection projection) + { + DemandAggregate(); + if (IsNull) return null; + var len = AggregateLength(); + if (len == 0) return []; + T[] result = new T[len]; + FillAll(result, projection); + return result; + } +} diff --git a/src/RESP.Core/RespReaderExtensions.cs b/src/RESP.Core/RespReaderExtensions.cs new file mode 100644 index 000000000..37e013dc9 --- /dev/null +++ b/src/RESP.Core/RespReaderExtensions.cs @@ -0,0 +1,142 @@ +// using System; +// using System.Buffers; +// using System.Diagnostics; +// +// namespace Resp; +// +// /// +// /// Utility methods for s. +// /// +// internal static class RespReaderExtensions +// { +// public static RedisValue ReadRedisValue(in this RespReader reader) +// { +// reader.DemandScalar(); +// if (reader.IsNull) return RedisValue.Null; +// +// var len = reader.ScalarLength(); +// switch (reader.Prefix) +// { +// case RespPrefix.Boolean: +// return reader.ReadBoolean(); +// case RespPrefix.Integer: +// return reader.ReadInt64(); +// case RespPrefix.Double: +// return reader.ReadDouble(); +// } +// +// if (len == 0) return RedisValue.EmptyString; +// +// // try to be efficient with obvious numbers and short strings +// if (reader.TryReadInt64(out var i64)) +// { +// return i64; +// } +// +// if (reader.TryReadDouble(out var f64, allowTokens: false)) +// { +// return f64; +// } +// +// if (reader.TryReadShortAscii(out var s)) +// { +// return s; +// } +// +// // otherwise, copy out the blob +// var result = new byte[len]; +// int actual = reader.CopyTo(result); +// Debug.Assert(actual == len); +// return result; +// } +// +// public static RedisKey ReadRedisKey(in this RespReader reader) +// { +// reader.DemandScalar(); +// if (reader.IsNull) return RedisKey.Null; +// +// var len = reader.ScalarLength(); +// if (len == 0) return ""; +// +// if (reader.TryReadShortAscii(out var s)) +// { +// return s; +// } +// +// // copy out the blob +// var result = new byte[len]; +// int actual = reader.CopyTo(result); +// Debug.Assert(actual == len); +// return result; +// } +// +// /* +// +// /// +// /// Interpret a scalar value as a value. +// /// +// public static LeasedString ReadLeasedString(in this RespReader reader) +// { +// if (reader.TryGetSpan(out var span)) return reader.IsNull ? default : new LeasedString(span); +// +// var len = reader.ScalarLength(); +// var result = new LeasedString(len, out var memory); +// int actual = reader.CopyTo(memory.Span); +// Debug.Assert(actual == len); +// return result; +// } +// +// /// +// /// Interpret an aggregate value as a value. +// /// +// public static LeasedStrings ReadLeasedStrings(in this RespReader reader) +// { +// Debug.Assert(reader.IsAggregate, "should have already checked for aggregate"); +// reader.DemandAggregate(); +// if (reader.IsNull) return default; +// +// int count = 0, bytes = 0; +// foreach (var child in reader.AggregateChildren()) +// { +// count++; +// bytes += child.ScalarLength(); +// } +// if (count == 0) return LeasedStrings.Empty; +// +// var builder = new LeasedStrings.Builder(count, bytes); +// foreach (var child in reader.AggregateChildren()) +// { +// if (child.IsNull) +// { +// builder.AddNull(); +// } +// else +// { +// var len = child.ScalarLength(); +// var span = builder.Add(len); +// child.CopyTo(span); +// } +// } +// return builder.Create(); +// } +// +// /// +// /// Indicates whether the given value is an byte match. +// /// +// public static bool Is(in this RespReader reader, in SimpleString value) +// { +// if (value.TryGetBytes(span: out var span)) +// { +// return reader.Is(span) & reader.IsNull == value.IsNull; +// } +// +// var len = value.GetByteCount(); +// var oversized = ArrayPool.Shared.Rent(len); +// var actual = value.CopyTo(oversized); +// Debug.Assert(actual == len); +// var result = reader.Is(new ReadOnlySpan(oversized, 0, len)); +// ArrayPool.Shared.Return(oversized); +// return result; +// } +// */ +// } diff --git a/src/RESP.Core/RespReaders.cs b/src/RESP.Core/RespReaders.cs new file mode 100644 index 000000000..58cbac2ea --- /dev/null +++ b/src/RESP.Core/RespReaders.cs @@ -0,0 +1,341 @@ +// using System.Buffers; +// using System.Diagnostics.CodeAnalysis; +// using System.Runtime.CompilerServices; +// using RESPite.Messages; +// using static RESPite.Resp.RespConstants; +// +// namespace Resp; +// +// /// +// /// Provides common RESP reader implementations. +// /// +// internal static class RespReaders +// { +// internal static readonly Impl Common = new(); +// +// /// +// /// Reads payloads. +// /// +// public static IRespReader String => Common; +// +// /// +// /// Reads payloads. +// /// +// public static IRespReader Int32 => Common; +// +// /// +// /// Reads payloads. +// /// +// public static IRespReader NullableInt32 => Common; +// +// /// +// /// Reads payloads. +// /// +// public static IRespReader Int64 => Common; +// +// /// +// /// Reads payloads. +// /// +// public static IRespReader NullableInt64 => Common; +// +// /// +// /// Reads 'OK' acknowledgements. +// /// +// public static IRespReader OK => Common; +// +// /// +// /// Reads payloads. +// /// +// public static IRespReader LeasedString => Common; +// +// /// +// /// Reads arrays of opaque payloads. +// /// +// public static IRespReader LeasedStrings => Common; +// +// internal static void ThrowMissingExpected(string expected, [CallerMemberName] string caller = "") +// => throw new InvalidOperationException($"Did not receive expected response: '{expected}'"); +// +// internal sealed class Impl : +// IRespReader, +// IRespReader, +// IRespReader, +// IRespReader, +// IRespReader, +// IRespReader, +// IRespReader, +// IRespReader, +// IRespReader +// { +// private static readonly uint OK_HiNibble = UnsafeCpuUInt32("+OK\r"u8); +// Empty IReader.Read(in Empty request, in ReadOnlySequence content) +// { +// if (content.IsSingleSegment) +// { +// #if NETCOREAPP3_1_OR_GREATER +// var span = content.FirstSpan; +// #else +// var span = content.First.Span; +// #endif +// if (span.Length != 5 || !(UnsafeCpuUInt32(span) == OK_HiNibble & UnsafeCpuByte(span, 4) == (byte)'\n')) ThrowMissingExpected("OK"); +// } +// else +// { +// Slower(content); +// } +// return default; +// +// static Empty Slower(scoped in ReadOnlySequence content) +// { +// var reader = new RespReader(content); +// reader.MoveNext(RespPrefix.SimpleString); +// if (!reader.IsOK()) ThrowMissingExpected("OK"); +// return default; +// } +// } +// +// Empty IRespReader.Read(in Empty request, ref RespReader reader) +// { +// reader.MoveNext(RespPrefix.SimpleString); +// if (!reader.IsOK()) ThrowMissingExpected("OK"); +// return default; +// } +// +// string? IRespReader.Read(in Empty request, ref RespReader reader) +// { +// reader.MoveNextScalar(); +// return reader.ReadString(); +// } +// +// string? IReader.Read(in Empty request, in ReadOnlySequence content) +// { +// var reader = new RespReader(in content); +// reader.MoveNextScalar(); +// return reader.ReadString(); +// } +// +// long IReader.Read(in Empty request, in ReadOnlySequence content) +// { +// if (content.IsSingleSegment && content.Length <= 12) // 9 chars for pre-billion integers, plus 3 protocol chars +// { +// return ((IReader)this).Read(request, content); +// } +// var reader = new RespReader(in content); +// reader.MoveNextScalar(); +// reader.DemandNotNull(); +// return reader.ReadInt64(); +// } +// +// long? IReader.Read(in Empty request, in ReadOnlySequence content) +// { +// if (content.IsSingleSegment && content.Length <= 12) // 9 chars for pre-billion integers, plus 3 protocol chars +// { +// return ((IReader)this).Read(request, content); +// } +// var reader = new RespReader(in content); +// reader.MoveNextScalar(); +// return reader.IsNull ? null : reader.ReadInt64(); +// } +// +// long IRespReader.Read(in Empty request, ref RespReader reader) +// { +// reader.MoveNextScalar(); +// reader.DemandNotNull(); +// return reader.ReadInt64(); +// } +// +// long? IRespReader.Read(in Empty request, ref RespReader reader) +// { +// reader.MoveNextScalar(); +// return reader.IsNull ? null : reader.ReadInt64(); +// } +// +// int IRespReader.Read(in Empty request, ref RespReader reader) +// { +// reader.MoveNextScalar(); +// reader.DemandNotNull(); +// return reader.ReadInt32(); +// } +// +// int? IRespReader.Read(in Empty request, ref RespReader reader) +// { +// reader.MoveNextScalar(); +// return reader.IsNull ? null : reader.ReadInt32(); +// } +// +// LeasedString IReader.Read(in Empty request, in ReadOnlySequence content) +// { +// var reader = new RespReader(in content); +// reader.MoveNextScalar(); +// return reader.ReadLeasedString(); +// } +// +// LeasedString IRespReader.Read(in Empty request, ref RespReader reader) +// { +// reader.MoveNextScalar(); +// return reader.ReadLeasedString(); +// } +// +// bool IReader.Read(in Empty request, in ReadOnlySequence content) +// { +// var reader = new RespReader(in content); +// reader.MoveNextScalar(); +// return reader.IsOK() || reader.Is((byte)'1'); +// } +// +// bool IRespReader.Read(in Empty request, ref RespReader reader) +// { +// reader.MoveNextScalar(); +// return reader.IsOK() || reader.Is((byte)'1'); +// } +// +// private static bool TryReadFastInt32(ReadOnlySpan span, out int value) +// { +// switch (span.Length) +// { +// case 4: // :N\r\n +// if ((UnsafeCpuUInt32(span) & SingleCharScalarMask) == SingleDigitInteger) +// { +// value = Digit(UnsafeCpuByte(span, 1)); +// return true; +// } +// break; +// case 5: // :NN\r\n +// if ((UnsafeCpuUInt32(span) & DoubleCharScalarMask) == DoubleDigitInteger +// & UnsafeCpuByte(span, 4) == (byte)'\n') +// { +// value = (10 * Digit(UnsafeCpuByte(span, 1))) +// + Digit(UnsafeCpuByte(span, 2)); +// return true; +// } +// break; +// case 7: // $1\r\nN\r\n +// if (UnsafeCpuUInt32(span) == BulkSingleDigitPrefix +// && UnsafeCpuUInt16(span, 5) == CrLfUInt16) +// { +// value = Digit(UnsafeCpuByte(span, 4)); +// return true; +// } +// break; +// case 8: // $2\r\nNN\r\n +// if (UnsafeCpuUInt32(span) == BulkDoubleDigitPrefix +// && UnsafeCpuUInt16(span, 6) == CrLfUInt16) +// { +// value = (10 * Digit(UnsafeCpuByte(span, 4))) +// + Digit(UnsafeCpuByte(span, 5)); +// return true; +// } +// break; +// } +// value = default; +// return false; +// +// static int Digit(byte value) +// { +// var i = value - '0'; +// if (i < 0 | i > 9) ThrowFormat(); +// return i; +// } +// } +// +// int IReader.Read(in Empty request, in ReadOnlySequence content) +// { +// if (content.IsSingleSegment) +// { +// #if NETCOREAPP3_1_OR_GREATER +// var span = content.FirstSpan; +// #else +// var span = content.First.Span; +// #endif +// if (TryReadFastInt32(span, out int i)) return i; +// } +// var reader = new RespReader(in content); +// reader.MoveNextScalar(); +// reader.DemandNotNull(); +// return reader.ReadInt32(); +// } +// +// int? IReader.Read(in Empty request, in ReadOnlySequence content) +// { +// if (content.IsSingleSegment) +// { +// #if NETCOREAPP3_1_OR_GREATER +// var span = content.FirstSpan; +// #else +// var span = content.First.Span; +// #endif +// if (TryReadFastInt32(span, out int i)) return i; +// } +// var reader = new RespReader(in content); +// reader.MoveNextScalar(); +// return reader.IsNull ? null : reader.ReadInt32(); +// } +// +// LeasedStrings IReader.Read(in Empty request, in ReadOnlySequence content) +// { +// var reader = new RespReader(in content); +// reader.MoveNextAggregate(); +// return reader.ReadLeasedStrings(); +// } +// +// LeasedStrings IRespReader.Read(in Empty request, ref RespReader reader) +// { +// reader.MoveNextAggregate(); +// return reader.ReadLeasedStrings(); +// } +// +// private static readonly uint +// SingleCharScalarMask = CpuUInt32(0xFF00FFFF), +// DoubleCharScalarMask = CpuUInt32(0xFF0000FF), +// SingleDigitInteger = UnsafeCpuUInt32(":\0\r\n"u8), +// DoubleDigitInteger = UnsafeCpuUInt32(":\0\0\r"u8), +// BulkSingleDigitPrefix = UnsafeCpuUInt32("$1\r\n"u8), +// BulkDoubleDigitPrefix = UnsafeCpuUInt32("$2\r\n"u8); +// } +// +// /// +// /// Reads values as an enum of type . +// /// +// public sealed class EnumReader : IRespReader, IRespReader where T : struct, Enum +// { +// /// +// /// Gets the reader instance. +// /// +// public static EnumReader Instance { get; } = new(); +// +// private EnumReader() +// { +// } +// +// T IReader.Read(in Empty request, in ReadOnlySequence content) +// { +// RespReader reader = new(content); +// reader.MoveNextScalar(); +// reader.DemandNotNull(); +// return reader.ReadEnum(default); +// } +// +// T? IReader.Read(in Empty request, in ReadOnlySequence content) +// { +// RespReader reader = new(content); +// reader.MoveNextScalar(); +// return reader.IsNull ? null : reader.ReadEnum(default); +// } +// +// T IRespReader.Read(in Empty request, ref RespReader reader) +// { +// reader.MoveNextScalar(); +// reader.DemandNotNull(); +// return reader.ReadEnum(default); +// } +// +// T? IRespReader.Read(in Empty request, ref RespReader reader) +// { +// reader.MoveNextScalar(); +// return reader.IsNull ? null : reader.ReadEnum(default); +// } +// } +// +// [DoesNotReturn, MethodImpl(MethodImplOptions.NoInlining)] +// private static void ThrowFormat() => throw new FormatException(); +// } diff --git a/src/RESP.Core/RespScanState.cs b/src/RESP.Core/RespScanState.cs new file mode 100644 index 000000000..7eba5d8be --- /dev/null +++ b/src/RESP.Core/RespScanState.cs @@ -0,0 +1,161 @@ +using System; +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; + +namespace Resp; + +/// +/// Holds state used for RESP frame parsing, i.e. detecting the RESP for an entire top-level message. +/// +public struct RespScanState +{ + /* + The key point of ScanState is to skim over a RESP stream with minimal frame processing, to find the + end of a single top-level RESP message. We start by expecting 1 message, and then just read, with the + rules that the end of a message subtracts one, and aggregates add N. Streaming scalars apply zero offset + until the scalar stream terminator. Attributes also apply zero offset. + Note that streaming aggregates change the rules - when at least one streaming aggregate is in effect, + no offsets are applied until we get back out of the outermost streaming aggregate - we achieve this + by simply counting the streaming aggregate depth, which is usually zero. + Note that in reality streaming (scalar and aggregates) and attributes are non-existent; in addition + to being specific to RESP3, no known server currently implements these parts of the RESP3 specification, + so everything here is theoretical, but: works according to the spec. + */ + private int _delta; // when this becomes -1, we have fully read a top-level message; + private ushort _streamingAggregateDepth; + private RespPrefix _prefix; + + public RespPrefix Prefix => _prefix; + + private long _totalBytes; +#if DEBUG + private int _elementCount; + + /// + public override string ToString() => $"{_prefix}, consumed: {_totalBytes} bytes, {_elementCount} nodes, complete: {IsComplete}"; +#else + /// + public override string ToString() => _prefix.ToString(); +#endif + + /// + public override bool Equals([NotNullWhen(true)] object? obj) => throw new NotSupportedException(); + + /// + public override int GetHashCode() => throw new NotSupportedException(); + + /// + /// Gets whether an entire top-level RESP message has been consumed. + /// + public bool IsComplete => _delta == -1; + + /// + /// Gets the total length of the payload read (or read so far, if it is not yet complete); this combines payloads from multiple + /// TryRead operations. + /// + public long TotalBytes => _totalBytes; + + // used when spotting common replies - we entirely bypass the usual reader/delta mechanism + internal void SetComplete(int totalBytes, RespPrefix prefix) + { + _totalBytes = totalBytes; + _delta = -1; + _prefix = prefix; +#if DEBUG + _elementCount = 1; +#endif + } + + /// + /// The amount of data, in bytes, to read before attempting to read the next frame. + /// + public const int MinBytes = 3; // minimum legal RESP frame is: _\r\n + + /// + /// Create a new value that can parse the supplied node (and subtree). + /// + internal RespScanState(in RespReader reader) + { + Debug.Assert(reader.Prefix != RespPrefix.None, "missing RESP prefix"); + _totalBytes = 0; + _delta = reader.GetInitialScanCount(out _streamingAggregateDepth); + } + + /// + /// Scan as far as possible, stopping when an entire top-level RESP message has been consumed or the data is exhausted. + /// + /// True if a top-level RESP message has been consumed. + public bool TryRead(ref RespReader reader, out long bytesRead) + { + bytesRead = ReadCore(ref reader, reader.BytesConsumed); + return IsComplete; + } + + /// + /// Scan as far as possible, stopping when an entire top-level RESP message has been consumed or the data is exhausted. + /// + /// True if a top-level RESP message has been consumed. + public bool TryRead(ReadOnlySpan value, out int bytesRead) + { + var reader = new RespReader(value); + bytesRead = (int)ReadCore(ref reader); + return IsComplete; + } + + /// + /// Scan as far as possible, stopping when an entire top-level RESP message has been consumed or the data is exhausted. + /// + /// True if a top-level RESP message has been consumed. + public bool TryRead(in ReadOnlySequence value, out long bytesRead) + { + var reader = new RespReader(in value); + bytesRead = ReadCore(ref reader); + return IsComplete; + } + + /// + /// Scan as far as possible, stopping when an entire top-level RESP message has been consumed or the data is exhausted. + /// + /// The number of bytes consumed in this operation. + private long ReadCore(ref RespReader reader, long startOffset = 0) + { + while (_delta >= 0 && reader.TryReadNext()) + { +#if DEBUG + _elementCount++; +#endif + if (!reader.IsAttribute & _prefix == RespPrefix.None) + { + _prefix = reader.Prefix; + } + + if (reader.IsAggregate) ApplyAggregateRules(ref reader); + + if (_streamingAggregateDepth == 0) _delta += reader.Delta(); + } + + var bytesRead = reader.BytesConsumed - startOffset; + _totalBytes += bytesRead; + return bytesRead; + } + + private void ApplyAggregateRules(ref RespReader reader) + { + Debug.Assert(reader.IsAggregate, "RESP aggregate expected"); + if (reader.IsStreaming) + { + // entering an aggregate stream + if (_streamingAggregateDepth == ushort.MaxValue) ThrowTooDeep(); + _streamingAggregateDepth++; + } + else if (reader.Prefix == RespPrefix.StreamTerminator) + { + // exiting an aggregate stream + if (_streamingAggregateDepth == 0) ThrowUnexpectedTerminator(); + _streamingAggregateDepth--; + } + static void ThrowTooDeep() => throw new InvalidOperationException("Maximum streaming aggregate depth exceeded."); + static void ThrowUnexpectedTerminator() => throw new InvalidOperationException("Unexpected streaming aggregate terminator."); + } +} diff --git a/src/RESP.Core/RespWriter.cs b/src/RESP.Core/RespWriter.cs new file mode 100644 index 000000000..db10d381d --- /dev/null +++ b/src/RESP.Core/RespWriter.cs @@ -0,0 +1,913 @@ +using System; +using System.Buffers; +using System.Buffers.Text; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; + +namespace Resp; + +/// +/// Provides low-level RESP formatting operations. +/// +public ref struct RespWriter +{ + private readonly IBufferWriter? _target; + + [SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Clarity")] + private int _index; + + internal readonly int IndexInCurrentBuffer => _index; + +#if NET7_0_OR_GREATER + private ref byte StartOfBuffer; + private int BufferLength; + + private ref byte WriteHead + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => ref Unsafe.Add(ref StartOfBuffer, _index); + } + + private Span Tail + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => MemoryMarshal.CreateSpan(ref Unsafe.Add(ref StartOfBuffer, _index), BufferLength - _index); + } + + private void WriteRawUnsafe(byte value) => Unsafe.Add(ref StartOfBuffer, _index++) = value; + + private readonly ReadOnlySpan WrittenLocalBuffer => + MemoryMarshal.CreateReadOnlySpan(ref StartOfBuffer, _index); +#else + private Span _buffer; + private readonly int BufferLength => _buffer.Length; + + private readonly ref byte StartOfBuffer + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => ref MemoryMarshal.GetReference(_buffer); + } + + private readonly ref byte WriteHead + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => ref Unsafe.Add(ref MemoryMarshal.GetReference(_buffer), _index); + } + + private readonly Span Tail => _buffer.Slice(_index); + private void WriteRawUnsafe(byte value) => _buffer[_index++] = value; + + private readonly ReadOnlySpan WrittenLocalBuffer => _buffer.Slice(0, _index); +#endif + + internal readonly string DebugBuffer() => RespConstants.UTF8.GetString(WrittenLocalBuffer); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void WriteCrLfUnsafe() + { + Unsafe.WriteUnaligned(ref WriteHead, RespConstants.CrLfUInt16); + _index += 2; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void WriteCrLf() + { + if (Available >= 2) + { + Unsafe.WriteUnaligned(ref WriteHead, RespConstants.CrLfUInt16); + _index += 2; + } + else + { + WriteRaw(RespConstants.CrlfBytes); + } + } + + private readonly int Available + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => BufferLength - _index; + } + + /// + /// Create a new RESP writer over the provided target. + /// + public RespWriter(IBufferWriter target) + { + _target = target; + _index = 0; +#if NET7_0_OR_GREATER + StartOfBuffer = ref Unsafe.NullRef(); + BufferLength = 0; +#else + _buffer = default; +#endif + GetBuffer(); + } + + /// + /// Create a new RESP writer over the provided target. + /// + public RespWriter(Span target) + { + _index = 0; +#if NET7_0_OR_GREATER + BufferLength = target.Length; + StartOfBuffer = ref MemoryMarshal.GetReference(target); +#else + _buffer = target; +#endif + } + + /// + /// Commits any unwritten bytes to the output. + /// + public void Flush() + { + if (_index != 0 && _target is not null) + { + _target.Advance(_index); +#if NET7_0_OR_GREATER + _index = BufferLength = 0; + StartOfBuffer = ref Unsafe.NullRef(); +#else + _index = 0; + _buffer = default; +#endif + } + } + + private void FlushAndGetBuffer(int sizeHint) + { + Flush(); + GetBuffer(sizeHint); + } + + private void GetBuffer(int sizeHint = 128) + { + if (Available == 0) + { + if (_target is null) + { + ThrowFixedBufferExceeded(); + } + else + { + const int MIN_BUFFER = 1024; + _index = 0; +#if NET7_0_OR_GREATER + var span = _target.GetSpan(Math.Max(sizeHint, MIN_BUFFER)); + BufferLength = span.Length; + StartOfBuffer = ref MemoryMarshal.GetReference(span); +#else + _buffer = _target.GetSpan(Math.Max(sizeHint, MIN_BUFFER)); +#endif + ActivationHelper.DebugBreakIf(Available == 0); + } + } + } + + [DoesNotReturn, MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowFixedBufferExceeded() => + throw new InvalidOperationException("Fixed buffer cannot be expanded"); + + /// + /// Write raw RESP data to the output; no validation will occur. + /// + public void WriteRaw(scoped ReadOnlySpan buffer) + { + const int MAX_TO_DOUBLE_BUFFER = 128; + if (buffer.Length <= MAX_TO_DOUBLE_BUFFER && buffer.Length <= Available) + { + buffer.CopyTo(Tail); + _index += buffer.Length; + } + else + { + // write directly to the output + Flush(); + if (_target is null) + { + ThrowFixedBufferExceeded(); + } + else + { + _target.Write(buffer); + } + } + } + + public RespCommandMap? CommandMap { get; set; } + + /// + /// Write a command header. + /// + /// The command name to write. + /// The number of arguments for the command (excluding the command itself). + public void WriteCommand(scoped ReadOnlySpan command, int args) + { + if (args < 0) Throw(); + WritePrefixedInteger(RespPrefix.Array, args + 1); + if (command.IsEmpty) ThrowEmptyCommand(); + if (CommandMap is { } map) + { + var mapped = map.Map(command); + if (mapped.IsEmpty) ThrowCommandUnavailable(command); + command = mapped; + } + + WriteBulkString(command); + + static void Throw() => throw new ArgumentOutOfRangeException(nameof(args)); + + static void ThrowEmptyCommand() => + throw new ArgumentException(paramName: nameof(command), message: "Empty command specified."); + + static void ThrowCommandUnavailable(ReadOnlySpan command) + => throw new ArgumentException( + paramName: nameof(command), + message: $"The command {Encoding.UTF8.GetString(command)} is not available."); + } + + /// + /// Write a key as a bulk string. + /// + /// The key to write. + public void WriteKey(scoped ReadOnlySpan value) => WriteBulkString(value); + + /// + /// Write a key as a bulk string. + /// + /// The key to write. + public void WriteKey(ReadOnlyMemory value) => WriteBulkString(value.Span); + + /// + /// Write a key as a bulk string. + /// + /// The key to write. + public void WriteKey(scoped ReadOnlySpan value) => WriteBulkString(value); + + /// + /// Write a key as a bulk string. + /// + /// The key to write. + public void WriteKey(ReadOnlyMemory value) => WriteBulkString(value.Span); + + /// + /// Write a key as a bulk string. + /// + /// The key to write. + public void WriteKey(string value) => WriteBulkString(value); + + /// + /// Write a key as a bulk string. + /// + /// The key to write. + public void WriteKey(byte[] value) => WriteBulkString(value.AsSpan()); + + /// + /// Write a payload as a bulk string. + /// + /// The payload to write. + public void WriteBulkString(byte[] value) => WriteBulkString(value.AsSpan()); + + /// + /// Write a payload as a bulk string. + /// + /// The payload to write. + public void WriteBulkString(ReadOnlyMemory value) + => WriteBulkString(value.Span); + + /// + /// Write a payload as a bulk string. + /// + /// The payload to write. + public void WriteBulkString(scoped ReadOnlySpan value) + { + if (value.IsEmpty) + { + if (Available >= 6) + { + WriteRawPrechecked(Raw.BulkStringEmpty_6, 6); + } + else + { + WriteRaw("$0\r\n\r\n"u8); + } + } + else + { + WriteBulkStringHeader(value.Length); + if (Available >= value.Length + 2) + { + value.CopyTo(Tail); + _index += value.Length; + WriteCrLfUnsafe(); + } + else + { + // slow path + WriteRaw(value); + WriteCrLf(); + } + } + } + + /* + /// + /// Write a payload as a bulk string. + /// + /// The payload to write. + public void WriteBulkString(in SimpleString value) + { + if (value.IsEmpty) + { + WriteRaw("$0\r\n\r\n"u8); + } + else if (value.TryGetBytes(span: out var bytes)) + { + WriteBulkString(bytes); + } + else if (value.TryGetChars(span: out var chars)) + { + WriteBulkString(chars); + } + else if (value.TryGetBytes(sequence: out var bytesSeq)) + { + WriteBulkString(bytesSeq); + } + else if (value.TryGetChars(sequence: out var charsSeq)) + { + WriteBulkString(charsSeq); + } + else + { + Throw(); + } + + static void Throw() => throw new InvalidOperationException($"It was not possible to read the {nameof(SimpleString)} contents"); + } + */ + + /// + /// Write an integer as a bulk string. + /// + public void WriteBulkString(bool value) => WriteBulkString(value ? 1 : 0); + + /// + /// Write a floating point as a bulk string. + /// + public void WriteBulkString(double value) + { + if (value == 0.0 | double.IsNaN(value) | double.IsInfinity(value)) + { + WriteKnownDouble(ref this, value); + + static void WriteKnownDouble(ref RespWriter writer, double value) + { + if (value == 0.0) + { + writer.WriteRaw("$1\r\n0\r\n"u8); + } + else if (double.IsNaN(value)) + { + writer.WriteRaw("$3\r\nnan\r\n"u8); + } + else if (double.IsPositiveInfinity(value)) + { + writer.WriteRaw("$3\r\ninf\r\n"u8); + } + else if (double.IsNegativeInfinity(value)) + { + writer.WriteRaw("$4\r\n-inf\r\n"u8); + } + else + { + Throw(); + static void Throw() => throw new ArgumentOutOfRangeException(nameof(value)); + } + } + } + else + { + Debug.Assert(RespConstants.MaxProtocolBytesBytesNumber <= 32); + Span scratch = stackalloc byte[24]; + if (!Utf8Formatter.TryFormat(value, scratch, out int bytes, G17)) + ThrowFormatException(); + WritePrefixedInteger(RespPrefix.BulkString, bytes); + WriteRaw(scratch.Slice(0, bytes)); + WriteCrLf(); + } + } + + private static readonly StandardFormat G17 = new('G', 17); + + /// + /// Write an integer as a bulk string. + /// + public void WriteBulkString(long value) + { + if (value >= -1 & value <= 20) + { + WriteRaw(value switch + { + -1 => "$2\r\n-1\r\n"u8, + 0 => "$1\r\n0\r\n"u8, + 1 => "$1\r\n1\r\n"u8, + 2 => "$1\r\n2\r\n"u8, + 3 => "$1\r\n3\r\n"u8, + 4 => "$1\r\n4\r\n"u8, + 5 => "$1\r\n5\r\n"u8, + 6 => "$1\r\n6\r\n"u8, + 7 => "$1\r\n7\r\n"u8, + 8 => "$1\r\n8\r\n"u8, + 9 => "$1\r\n9\r\n"u8, + 10 => "$2\r\n10\r\n"u8, + 11 => "$2\r\n11\r\n"u8, + 12 => "$2\r\n12\r\n"u8, + 13 => "$2\r\n13\r\n"u8, + 14 => "$2\r\n14\r\n"u8, + 15 => "$2\r\n15\r\n"u8, + 16 => "$2\r\n16\r\n"u8, + 17 => "$2\r\n17\r\n"u8, + 18 => "$2\r\n18\r\n"u8, + 19 => "$2\r\n19\r\n"u8, + 20 => "$2\r\n20\r\n"u8, + _ => Throw(), + }); + + static ReadOnlySpan Throw() => throw new ArgumentOutOfRangeException(nameof(value)); + } + else if (Available >= RespConstants.MaxProtocolBytesBulkStringIntegerInt64) + { + var singleDigit = value >= -99_999_999 && value <= 999_999_999; + WriteRawUnsafe((byte)RespPrefix.BulkString); + + var target = Tail.Slice(singleDigit ? 3 : 4); // N\r\n or NN\r\n + if (!Utf8Formatter.TryFormat(value, target, out var valueBytes)) + ThrowFormatException(); + + Debug.Assert(valueBytes > 0 && singleDigit ? valueBytes < 10 : valueBytes is 10 or 11); + if (!Utf8Formatter.TryFormat(valueBytes, Tail, out var prefixBytes)) + ThrowFormatException(); + Debug.Assert(prefixBytes == (singleDigit ? 1 : 2)); + _index += prefixBytes; + WriteCrLfUnsafe(); + _index += valueBytes; + WriteCrLfUnsafe(); + } + else + { + Debug.Assert(RespConstants.MaxRawBytesInt64 <= 24); + Span scratch = stackalloc byte[24]; + if (!Utf8Formatter.TryFormat(value, scratch, out int bytes)) + ThrowFormatException(); + WritePrefixedInteger(RespPrefix.BulkString, bytes); + WriteRaw(scratch.Slice(0, bytes)); + WriteCrLf(); + } + } + + private static void ThrowFormatException() => throw new FormatException(); + + private void WritePrefixedInteger(RespPrefix prefix, int length) + { + if (Available >= RespConstants.MaxProtocolBytesIntegerInt32) + { + WriteRawUnsafe((byte)prefix); + if (length >= 0 & length <= 9) + { + WriteRawUnsafe((byte)(length + '0')); + } + else + { + if (!Utf8Formatter.TryFormat(length, Tail, out var bytesWritten)) + { + ThrowFormatException(); + } + + _index += bytesWritten; + } + + WriteCrLfUnsafe(); + } + else + { + WriteViaStack(ref this, prefix, length); + } + + static void WriteViaStack(ref RespWriter respWriter, RespPrefix prefix, int length) + { + Debug.Assert(RespConstants.MaxProtocolBytesIntegerInt32 <= 16); + Span buffer = stackalloc byte[16]; + buffer[0] = (byte)prefix; + int payloadLength; + if (length >= 0 & length <= 9) + { + buffer[1] = (byte)(length + '0'); + payloadLength = 1; + } + else if (!Utf8Formatter.TryFormat(length, buffer.Slice(1), out payloadLength)) + { + ThrowFormatException(); + } + + Unsafe.WriteUnaligned(ref buffer[payloadLength + 1], RespConstants.CrLfUInt16); + respWriter.WriteRaw(buffer.Slice(0, payloadLength + 3)); + } + + bool writeToStack = Available < RespConstants.MaxProtocolBytesIntegerInt32; + + Span target = writeToStack ? stackalloc byte[16] : Tail; + target[0] = (byte)prefix; + } + + /// + /// Write a payload as a bulk string. + /// + /// The payload to write. + public void WriteBulkString(string value) + { + // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract + if (value is null) ThrowNull(); + WriteBulkString(value.AsSpan()); + } + + [MethodImpl(MethodImplOptions.NoInlining), DoesNotReturn] + // ReSharper disable once NotResolvedInText + private static void ThrowNull() => + throw new ArgumentNullException("value", "Null values cannot be sent from client to server"); + + internal void WriteBulkStringUnoptimized(string? value) + { + if (value is null) ThrowNull(); + if (value.Length == 0) + { + WriteRaw("$0\r\n\r\n"u8); + } + else + { + var byteCount = RespConstants.UTF8.GetByteCount(value); + WritePrefixedInteger(RespPrefix.BulkString, byteCount); + if (Available >= byteCount) + { + var actual = RespConstants.UTF8.GetBytes(value.AsSpan(), Tail); + Debug.Assert(actual == byteCount); + _index += actual; + } + else + { + WriteUtf8Slow(value.AsSpan(), byteCount); + } + + WriteCrLf(); + } + } + + /// + /// Write a payload as a bulk string. + /// + /// The payload to write. + public void WriteBulkString(ReadOnlyMemory value) => WriteBulkString(value.Span); + + /// + /// Write a payload as a bulk string. + /// + /// The payload to write. + public void WriteBulkString(scoped ReadOnlySpan value) + { + if (value.Length == 0) + { + if (Available >= 6) + { + WriteRawPrechecked(Raw.BulkStringEmpty_6, 6); + } + else + { + WriteRaw("$0\r\n\r\n"u8); + } + } + else + { + var byteCount = RespConstants.UTF8.GetByteCount(value); + WriteBulkStringHeader(byteCount); + if (Available >= 2 + byteCount) + { + var actual = RespConstants.UTF8.GetBytes(value, Tail); + Debug.Assert(actual == byteCount); + _index += actual; + WriteCrLfUnsafe(); + } + else + { + FlushAndGetBuffer(Math.Min(byteCount, MAX_BUFFER_HINT)); + if (Available >= byteCount + 2) + { + // that'll work + var actual = RespConstants.UTF8.GetBytes(value, Tail); + Debug.Assert(actual == byteCount); + _index += actual; + WriteCrLfUnsafe(); + } + else + { + WriteUtf8Slow(value, byteCount); + WriteCrLf(); + } + } + } + } + + private const int MAX_BUFFER_HINT = 64 * 1024; + + private void WriteUtf8Slow(scoped ReadOnlySpan value, int remaining) + { + var enc = _perThreadEncoder; + if (enc is null) + { + enc = _perThreadEncoder = RespConstants.UTF8.GetEncoder(); + } + else + { + enc.Reset(); + } + + bool completed; + int charsUsed, bytesUsed; + do + { + enc.Convert(value, Tail, false, out charsUsed, out bytesUsed, out completed); + value = value.Slice(charsUsed); + _index += bytesUsed; + remaining -= bytesUsed; + FlushAndGetBuffer(Math.Min(remaining, MAX_BUFFER_HINT)); + } + // until done... + while (!completed); + + if (remaining != 0) + { + // any trailing data? + FlushAndGetBuffer(Math.Min(remaining, MAX_BUFFER_HINT)); + enc.Convert(value, Tail, true, out charsUsed, out bytesUsed, out completed); + Debug.Assert(charsUsed == 0 && completed); + _index += bytesUsed; + remaining -= bytesUsed; + } + + enc.Reset(); + Debug.Assert(remaining == 0); + } + + internal void WriteBulkString(in ReadOnlySequence value) + { + if (value.IsSingleSegment) + { +#if NETCOREAPP3_0_OR_GREATER + WriteBulkString(value.FirstSpan); +#else + WriteBulkString(value.First.Span); +#endif + } + else + { + // lazy for now + int len = checked((int)value.Length); + byte[] buffer = ArrayPool.Shared.Rent(len); + value.CopyTo(buffer); + WriteBulkString(new ReadOnlySpan(buffer, 0, len)); + ArrayPool.Shared.Return(buffer); + } + } + + internal void WriteBulkString(in ReadOnlySequence value) + { + if (value.IsSingleSegment) + { +#if NETCOREAPP3_0_OR_GREATER + WriteBulkString(value.FirstSpan); +#else + WriteBulkString(value.First.Span); +#endif + } + else + { + // lazy for now + int len = checked((int)value.Length); + char[] buffer = ArrayPool.Shared.Rent(len); + value.CopyTo(buffer); + WriteBulkString(new ReadOnlySpan(buffer, 0, len)); + ArrayPool.Shared.Return(buffer); + } + } + + /// + /// Experimental. + /// + public void WriteBulkString(int value) + { + if (Available >= sizeof(ulong)) + { + switch (value) + { + case -1: + WriteRawPrechecked(Raw.BulkStringInt32_M1_8, 8); + return; + case 0: + WriteRawPrechecked(Raw.BulkStringInt32_0_7, 7); + return; + case 1: + WriteRawPrechecked(Raw.BulkStringInt32_1_7, 7); + return; + case 2: + WriteRawPrechecked(Raw.BulkStringInt32_2_7, 7); + return; + case 3: + WriteRawPrechecked(Raw.BulkStringInt32_3_7, 7); + return; + case 4: + WriteRawPrechecked(Raw.BulkStringInt32_4_7, 7); + return; + case 5: + WriteRawPrechecked(Raw.BulkStringInt32_5_7, 7); + return; + case 6: + WriteRawPrechecked(Raw.BulkStringInt32_6_7, 7); + return; + case 7: + WriteRawPrechecked(Raw.BulkStringInt32_7_7, 7); + return; + case 8: + WriteRawPrechecked(Raw.BulkStringInt32_8_7, 7); + return; + case 9: + WriteRawPrechecked(Raw.BulkStringInt32_9_7, 7); + return; + case 10: + WriteRawPrechecked(Raw.BulkStringInt32_10_8, 8); + return; + } + } + + WriteBulkStringUnoptimized(value); + } + + internal void WriteBulkStringUnoptimized(int value) + { + if (Available >= RespConstants.MaxProtocolBytesBulkStringIntegerInt32) + { + var singleDigit = value >= -99_999_999 && value <= 999_999_999; + WriteRawUnsafe((byte)RespPrefix.BulkString); + + var target = Tail.Slice(singleDigit ? 3 : 4); // N\r\n or NN\r\n + if (!Utf8Formatter.TryFormat(value, target, out var valueBytes)) + ThrowFormatException(); + + Debug.Assert(valueBytes > 0 && singleDigit ? valueBytes < 10 : valueBytes is 10 or 11); + if (!Utf8Formatter.TryFormat(valueBytes, Tail, out var prefixBytes)) + ThrowFormatException(); + Debug.Assert(prefixBytes == (singleDigit ? 1 : 2)); + _index += prefixBytes; + WriteCrLfUnsafe(); + _index += valueBytes; + WriteCrLfUnsafe(); + } + else + { + Debug.Assert(RespConstants.MaxRawBytesInt32 <= 16); + Span scratch = stackalloc byte[16]; + if (!Utf8Formatter.TryFormat(value, scratch, out int bytes)) + ThrowFormatException(); + WritePrefixedInteger(RespPrefix.BulkString, bytes); + WriteRaw(scratch.Slice(0, bytes)); + WriteCrLf(); + } + } + + /// + /// Write an array header. + /// + /// The number of elements in the array. + public void WriteArray(int count) + { + if (Available >= sizeof(uint)) + { + switch (count) + { + case 0: + WriteRawPrechecked(Raw.ArrayPrefix_0_4, 4); + return; + case 1: + WriteRawPrechecked(Raw.ArrayPrefix_1_4, 4); + return; + case 2: + WriteRawPrechecked(Raw.ArrayPrefix_2_4, 4); + return; + case 3: + WriteRawPrechecked(Raw.ArrayPrefix_3_4, 4); + return; + case 4: + WriteRawPrechecked(Raw.ArrayPrefix_4_4, 4); + return; + case 5: + WriteRawPrechecked(Raw.ArrayPrefix_5_4, 4); + return; + case 6: + WriteRawPrechecked(Raw.ArrayPrefix_6_4, 4); + return; + case 7: + WriteRawPrechecked(Raw.ArrayPrefix_7_4, 4); + return; + case 8: + WriteRawPrechecked(Raw.ArrayPrefix_8_4, 4); + return; + case 9: + WriteRawPrechecked(Raw.ArrayPrefix_9_4, 4); + return; + case 10 when Available >= sizeof(ulong): + WriteRawPrechecked(Raw.ArrayPrefix_10_5, 5); + return; + case -1: + WriteRawPrechecked(Raw.ArrayPrefix_M1_5, 5); + return; + } + } + + WritePrefixedInteger(RespPrefix.Array, count); + } + + private void WriteBulkStringHeader(int count) + { + if (Available >= sizeof(uint)) + { + switch (count) + { + case 0: + WriteRawPrechecked(Raw.BulkStringPrefix_0_4, 4); + return; + case 1: + WriteRawPrechecked(Raw.BulkStringPrefix_1_4, 4); + return; + case 2: + WriteRawPrechecked(Raw.BulkStringPrefix_2_4, 4); + return; + case 3: + WriteRawPrechecked(Raw.BulkStringPrefix_3_4, 4); + return; + case 4: + WriteRawPrechecked(Raw.BulkStringPrefix_4_4, 4); + return; + case 5: + WriteRawPrechecked(Raw.BulkStringPrefix_5_4, 4); + return; + case 6: + WriteRawPrechecked(Raw.BulkStringPrefix_6_4, 4); + return; + case 7: + WriteRawPrechecked(Raw.BulkStringPrefix_7_4, 4); + return; + case 8: + WriteRawPrechecked(Raw.BulkStringPrefix_8_4, 4); + return; + case 9: + WriteRawPrechecked(Raw.BulkStringPrefix_9_4, 4); + return; + case 10 when Available >= sizeof(ulong): + WriteRawPrechecked(Raw.BulkStringPrefix_10_5, 5); + return; + case -1 when Available >= sizeof(ulong): + WriteRawPrechecked(Raw.BulkStringPrefix_M1_5, 5); + return; + } + } + + WritePrefixedInteger(RespPrefix.BulkString, count); + } + + internal void WriteArrayUnpotimized(int count) => WritePrefixedInteger(RespPrefix.Array, count); + + private void WriteRawPrechecked(ulong value, int count) + { + Debug.Assert(Available >= sizeof(ulong)); + Debug.Assert(count >= 0 && count <= sizeof(long)); + Unsafe.WriteUnaligned(ref WriteHead, value); + _index += count; + } + + private void WriteRawPrechecked(uint value, int count) + { + Debug.Assert(Available >= sizeof(uint)); + Debug.Assert(count >= 0 && count <= sizeof(uint)); + Unsafe.WriteUnaligned(ref WriteHead, value); + _index += count; + } + + internal void DebugResetIndex() => _index = 0; + + [ThreadStatic] + // used for multi-chunk encoding + private static Encoder? _perThreadEncoder; +} diff --git a/src/RESP.Core/ResponseReader.cs b/src/RESP.Core/ResponseReader.cs new file mode 100644 index 000000000..1702f1695 --- /dev/null +++ b/src/RESP.Core/ResponseReader.cs @@ -0,0 +1,54 @@ +// using System.Buffers; +// using RESPite.Messages; +// +// namespace Resp; +// +// /// +// /// Base implementation for RESP writers that do not depend on the request parameter. +// /// +// public abstract class ResponseReader : IReader, IRespReader +// { +// TResponse IReader.Read(in Empty request, in ReadOnlySequence content) +// => Read(content); +// +// /// +// /// Read a raw RESP payload. +// /// +// public virtual TResponse Read(scoped in ReadOnlySequence content) +// { +// var reader = new RespReader(in content); +// reader.MoveNext(); +// return Read(ref reader); +// } +// +// /// +// /// Read a RESP payload via the API. +// /// +// public virtual TResponse Read(ref RespReader reader) +// => throw new NotSupportedException("A " + nameof(Read) + " overload must be overridden"); +// +// TResponse IRespReader.Read(in Empty request, ref RespReader reader) +// => Read(ref reader); +// } +// +// /// +// /// Base implementation for RESP writers that do depend on the request parameter. +// /// +// public abstract class ResponseReader : IReader, IRespReader +// { +// /// +// /// Read a raw RESP payload. +// /// +// public virtual TResponse Read(in TRequest request, in ReadOnlySequence content) +// { +// var reader = new RespReader(in content); +// reader.MoveNext(); +// return Read(in request, ref reader); +// } +// +// /// +// /// Read a RESP payload via the API. +// /// +// public virtual TResponse Read(in TRequest request, ref RespReader reader) +// => throw new NotSupportedException("A " + nameof(Read) + " overload must be overridden"); +// } diff --git a/src/RESP.Core/Void.cs b/src/RESP.Core/Void.cs new file mode 100644 index 000000000..d213c6e29 --- /dev/null +++ b/src/RESP.Core/Void.cs @@ -0,0 +1,7 @@ +namespace Resp; + +public readonly struct Void +{ + private static readonly Void _shared = default; + public static ref readonly Void Instance => ref _shared; +} diff --git a/src/RESPite.Benchmark/BenchmarkBase.cs b/src/RESPite.Benchmark/BenchmarkBase.cs new file mode 100644 index 000000000..a8d31cbea --- /dev/null +++ b/src/RESPite.Benchmark/BenchmarkBase.cs @@ -0,0 +1,682 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; +using System.Linq; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; + +// influenced by redis-benchmark, see .md file +namespace RESPite.Benchmark; + +public abstract class BenchmarkBase : IDisposable +{ + protected const string + GetSetKey = "key:__rand_int__", + CounterKey = "counter:__rand_int__", + ListKey = "mylist", + SetKey = "myset", + HashKey = "myhash", + SortedSetKey = "myzset", + StreamKey = "mystream"; + + public PipelineStrategy PipelineMode { get; } = + PipelineStrategy.Batch; // the default, for parity with how redis-benchmark works + + public enum PipelineStrategy + { + /// + /// Build a batch of operations, send them all at once. + /// + Batch, + + /// + /// Use a queue to pipeline operations - when we hit the pipeline depth, we pop one, push one, await the popped. + /// + Queue, + } + + private readonly HashSet _tests = new(StringComparer.OrdinalIgnoreCase); + protected bool RunTest(string name) => _tests.Count == 0 || _tests.Contains(name); + public virtual void Dispose() => GC.SuppressFinalize(this); + public int Port { get; } = 6379; + public int PipelineDepth { get; } = 1; + public bool Multiplexed { get; } + public bool SupportCancel { get; } + public bool Loop { get; } + public bool Quiet { get; } + public int ClientCount { get; } = 50; + private int _operationsPerClient; + public int OperationsPerClient(int divisor = 1) => _operationsPerClient / divisor; + + public int TotalOperations(int divisor = 1) => OperationsPerClient(divisor) * ClientCount; + + protected readonly byte[] Payload; + + protected BenchmarkBase(string[] args) + { + int operations = 100_000; + + string tests = ""; + for (int i = 0; i < args.Length; i++) + { + switch (args[i]) + { + case "-p" when i != args.Length - 1 && int.TryParse(args[++i], out int tmp) && tmp > 0: + Port = tmp; + break; + case "-c" when i != args.Length - 1 && int.TryParse(args[++i], out int tmp) && tmp > 0: + ClientCount = tmp; + break; + case "-n" when i != args.Length - 1 && int.TryParse(args[++i], out int tmp) && tmp > 0: + operations = tmp; + break; + case "-P" when i != args.Length - 1 && int.TryParse(args[++i], out int tmp) && tmp > 0: + PipelineDepth = tmp; + break; + case "+m": + Multiplexed = true; + break; + case "-m": + Multiplexed = false; + break; + case "+x": + SupportCancel = true; + break; + case "-c": + SupportCancel = false; + break; + case "-l": + Loop = true; + break; + case "-q": + Quiet = true; + break; + case "-t" when i != args.Length - 1: + tests = args[++i]; + break; + case "--batch": + PipelineMode = PipelineStrategy.Batch; + break; + case "--queue": + PipelineMode = PipelineStrategy.Queue; + break; + } + } + + if (!string.IsNullOrWhiteSpace(tests)) + { + foreach (var test in tests.Split(',')) + { + var t = test.Trim(); + if (!string.IsNullOrWhiteSpace(t)) _tests.Add(t); + } + } + + _operationsPerClient = operations / ClientCount; + + Payload = "abc"u8.ToArray(); + } + + public abstract Task RunAll(); + + public async Task RunBasicLoopAsync() + { + await DeleteAsync(CounterKey).ConfigureAwait(false); + + if (ClientCount <= 1) + { + await RunBasicLoopAsync(0); + } + else + { + Task[] tasks = new Task[ClientCount]; + for (int i = 0; i < ClientCount; i++) + { + var loopSnapshot = i; + tasks[i] = Task.Run(() => RunBasicLoopAsync(loopSnapshot)); + } + + await Task.WhenAll(tasks); + } + } + + protected abstract Task RunBasicLoopAsync(int clientId); + protected abstract Task DeleteAsync(string key); +} + +public abstract class BenchmarkBase(string[] args) : BenchmarkBase(args) +{ + protected override Task DeleteAsync(string key) => DeleteAsync(GetClient(0), key); + + protected virtual Task OnCleanupAsync(TClient client) => Task.CompletedTask; + + protected virtual Task InitAsync(TClient client) => Task.CompletedTask; + + public async Task CleanupAsync() + { + try + { + var client = GetClient(0); + await DeleteAsync(client, GetSetKey).ConfigureAwait(false); + await DeleteAsync(client, CounterKey).ConfigureAwait(false); + await DeleteAsync(client, ListKey).ConfigureAwait(false); + await DeleteAsync(client, SetKey).ConfigureAwait(false); + await DeleteAsync(client, HashKey).ConfigureAwait(false); + await DeleteAsync(client, SortedSetKey).ConfigureAwait(false); + await DeleteAsync(client, StreamKey).ConfigureAwait(false); + await OnCleanupAsync(client).ConfigureAwait(false); + } + catch (Exception ex) + { + await Console.Error.WriteLineAsync($"Cleanup: {ex.Message}"); + } + } + + protected virtual ValueTask Flush(TClient client) => default; + protected virtual void PrepareBatch(TClient client, int count) { } + + private async Task PipelineUntyped( + TClient client, + Func operation, + int divisor) + { + var opsPerClient = OperationsPerClient(divisor); + int i = 0; + try + { + if (PipelineDepth <= 1) + { + for (; i < opsPerClient; i++) + { + await operation(client).ConfigureAwait(false); + } + } + else if (PipelineMode == PipelineStrategy.Queue) + { + var queue = new Queue(opsPerClient); + for (; i < opsPerClient; i++) + { + if (queue.Count == opsPerClient) + { + await queue.Dequeue().ConfigureAwait(false); + } + + queue.Enqueue(operation(client)); + } + + while (queue.Count > 0) + { + await queue.Dequeue().ConfigureAwait(false); + } + } + else if (PipelineMode == PipelineStrategy.Batch) + { + int count = 0; + var oversized = ArrayPool.Shared.Rent(PipelineDepth); + PrepareBatch(client, Math.Min(opsPerClient, PipelineDepth)); + for (; i < opsPerClient; i++) + { + oversized[count++] = operation(client); + if (count == PipelineDepth) + { + await Flush(client).ConfigureAwait(false); + PrepareBatch(client, Math.Min(opsPerClient - i, PipelineDepth)); + for (int j = 0; j < count; j++) + { + await oversized[j].ConfigureAwait(false); + } + + count = 0; + } + } + + await Flush(client).ConfigureAwait(false); + for (int j = 0; j < count; j++) + { + await oversized[j].ConfigureAwait(false); + } + + ArrayPool.Shared.Return(oversized); + } + else + { + throw new InvalidOperationException($"Unexpected pipeline mode: {PipelineMode}"); + } + } + catch (Exception ex) + { + await Console.Error.WriteLineAsync($"{operation.Method.Name} failed after {i} operations"); + Program.WriteException(ex); + } + + return DBNull.Value; + } + + private async Task PipelineTyped(TClient client, Func> operation, int divisor) + { + var opsPerClient = OperationsPerClient(divisor); + int i = 0; + T result = default!; + try + { + if (PipelineDepth == 1) + { + for (; i < opsPerClient; i++) + { + result = await operation(client).ConfigureAwait(false); + } + } + else if (PipelineMode == PipelineStrategy.Queue) + { + var queue = new Queue>(opsPerClient); + for (; i < opsPerClient; i++) + { + if (queue.Count == opsPerClient) + { + _ = await queue.Dequeue().ConfigureAwait(false); + } + + queue.Enqueue(operation(client)); + } + + while (queue.Count > 0) + { + result = await queue.Dequeue().ConfigureAwait(false); + } + } + else if (PipelineMode == PipelineStrategy.Batch) + { + int count = 0; + var oversized = ArrayPool>.Shared.Rent(PipelineDepth); + PrepareBatch(client, Math.Min(opsPerClient, PipelineDepth)); + for (; i < opsPerClient; i++) + { + oversized[count++] = operation(client); + if (count == PipelineDepth) + { + await Flush(client).ConfigureAwait(false); + PrepareBatch(client, Math.Min(opsPerClient - (i + 1), PipelineDepth)); + for (int j = 0; j < count; j++) + { + result = await oversized[j].ConfigureAwait(false); + } + + count = 0; + } + } + + await Flush(client).ConfigureAwait(false); + for (int j = 0; j < count; j++) + { + result = await oversized[j].ConfigureAwait(false); + } + + ArrayPool>.Shared.Return(oversized); + } + else + { + throw new InvalidOperationException($"Unexpected pipeline mode: {PipelineMode}"); + } + } + catch (Exception ex) + { + await Console.Error.WriteLineAsync($"{operation.Method.Name} failed after {i} operations"); + Program.WriteException(ex); + } + + return result; + } + + public async Task InitAsync() + { + for (int i = 0; i < ClientCount; i++) + { + await InitAsync(GetClient(i)).ConfigureAwait(false); + } + } + + protected abstract TClient GetClient(int index); + protected virtual TClient WithCancellation(TClient client, CancellationToken cancellationToken) => client; + protected abstract Task DeleteAsync(TClient client, string key); + + protected abstract TClient CreateBatch(TClient client); + + protected Task RunAsync( + string? key, + Func> action, + bool deleteKey, + int divisor = 1) + => RunAsyncCore( + key, + GetNameCore(action, out var desc), + desc, + client => action(client).AsUntypedValueTask(), + client => PipelineTyped(client, action, divisor), + [], + deleteKey, + divisor); + + protected Task RunAsync( + string? key, + Func> action, + params string[] consumers) + => RunAsyncCore( + key, + GetNameCore(action, out var desc), + desc, + client => action(client).AsUntypedValueTask(), + client => PipelineTyped(client, action, 1), + consumers, + consumers.Length != 0, + 1); + + protected Task RunAsync( + string? key, + Func action, + bool deleteKey, + int divisor = 1) + => RunAsyncCore( + key, + GetNameCore(action, out var desc), + desc, + action, + client => PipelineUntyped(client, action, divisor), + [], + deleteKey, + divisor); + + protected Task RunAsync( + string? key, + Func action, + params string[] consumers) + => RunAsyncCore( + key, + GetNameCore(action, out var desc), + desc, + action, + client => PipelineUntyped(client, action, 1), + consumers, + consumers.Length != 0, + 1); + + private static string GetNameCore(Delegate underlyingAction, out string description) + { + string name = underlyingAction.Method.Name; + + if (underlyingAction.Method.GetCustomAttribute(typeof(DisplayNameAttribute)) is DisplayNameAttribute + { + DisplayName: { Length: > 0 } + } dna) + { + name = dna.DisplayName; + } + + description = ""; + if (underlyingAction.Method.GetCustomAttribute(typeof(DescriptionAttribute)) is DescriptionAttribute + { + Description: { Length: > 0 } + } da) + { + description = da.Description; + } + + return name; + } + + protected static string GetName(Func> action) => GetNameCore(action, out _); + protected static string GetName(Func action) => GetNameCore(action, out _); + + private async Task RunAsyncCore( + string? key, + string name, + string description, + Func test, + Func> pipeline, + string[] consumers, + bool deleteKey, + int divisor) + { + // skip test if not needed + string auxReason = ""; + if (!RunTest(name)) + { + auxReason = string.Join(", ", consumers.Where(x => RunTest(x))); + if (auxReason.Length == 0) return; // not needed by any consumers either + auxReason = $" (required for {auxReason})"; + } + + // include additional test metadata + if (description is { Length: > 0 }) + { + description = $" ({description})"; + } + + if (Quiet) + { + Console.Write($"{name}:"); + } + else + { + Console.Write( + $"====== {name}{description}{auxReason} ====== (clients: {ClientCount:#,##0}, ops: {TotalOperations(divisor):#,##0}"); + if (Multiplexed) + { + Console.Write(", mux"); + } + + if (SupportCancel) + { + Console.Write(", cancel"); + } + + Console.Write(PipelineDepth > 1 ? $", {PipelineMode}: {PipelineDepth:#,##0}" : ", sequential"); + + Console.WriteLine(")"); + } + + bool didNotRun = false; + try + { + if (key is not null && deleteKey) + { + await DeleteAsync(GetClient(0), key).ConfigureAwait(false); + } + + try + { + await test(GetClient(0)).ConfigureAwait(false); + } + catch (Exception ex) + { + await Console.Error.WriteLineAsync($"\t{ex.Message}"); + didNotRun = true; + return; + } + + var pending = new Task[ClientCount]; + int index = 0; +#if DEBUG + Internal.DebugCounters.Flush(); +#endif + // optionally support cancellation, applied per-test + CancellationToken cancellationToken = CancellationToken.None; + using var cts = SupportCancel ? new CancellationTokenSource(TimeSpan.FromSeconds(20)) : null; + if (SupportCancel) cancellationToken = cts!.Token; + + var watch = Stopwatch.StartNew(); + for (int i = 0; i < ClientCount; i++) + { + var client = GetClient(i); + if (PipelineMode == PipelineStrategy.Batch && PipelineDepth > 1) + { + client = CreateBatch(client); + } + + pending[index++] = Task.Run( + () => pipeline(WithCancellation(client, cancellationToken)), + cancellationToken); + } + + await Task.WhenAll(pending).ConfigureAwait(false); + watch.Stop(); + + var seconds = watch.Elapsed.TotalSeconds; + // ReSharper disable once PossibleLossOfFraction + var rate = TotalOperations(divisor) / seconds; + if (Quiet) + { + Console.WriteLine($"\t{rate:###,###,##0} requests per second"); + return; + } + else + { + Console.WriteLine( + $"{TotalOperations(divisor):###,###,##0} requests completed in {seconds:0.00} seconds, {rate:###,###,##0} ops/sec"); + } + + if (!Quiet & typeof(T) != typeof(DBNull)) + { + const string format = "Typical result: {0}"; + + T result = await pending[^1]; + Console.WriteLine(format, result); + } + } + catch (Exception ex) + { + if (Quiet) Console.WriteLine(); + Program.WriteException(ex, name); + } + finally + { + _ = didNotRun; +#if DEBUG + var counters = Internal.DebugCounters.Flush(); // flush even if not showing + if (!Quiet & !didNotRun) + { + if (counters.WriteBytes != 0) + { + Console.Write($"Write: {FormatBytes(counters.WriteBytes)}"); + if (counters.SyncWriteCount != 0) Console.Write($"; {counters.SyncWriteCount:#,##0} sync"); + if (counters.AsyncWriteInlineCount != 0) + Console.Write($"; {counters.AsyncWriteInlineCount:#,##0} async-inline"); + if (counters.AsyncWriteCount != 0) Console.Write($"; {counters.AsyncWriteCount:#,##0} full-async"); + Console.WriteLine(); + } + + if (counters.ReadBytes != 0) + { + Console.Write($"Read: {FormatBytes(counters.ReadBytes)}"); + if (counters.ReadCount != 0) Console.Write($"; {counters.ReadCount:#,##0} sync"); + if (counters.AsyncReadInlineCount != 0) + Console.Write($"; {counters.AsyncReadInlineCount:#,##0} async-inline"); + if (counters.AsyncReadCount != 0) Console.Write($"; {counters.AsyncReadCount:#,##0} full-async"); + Console.WriteLine(); + } + + if (counters.DiscardFullCount + counters.DiscardPartialCount != 0) + { + Console.Write($"Discard average: {FormatBytes(counters.DiscardAverage)}"); + if (counters.DiscardFullCount != 0) Console.Write($"; {counters.DiscardFullCount} full"); + if (counters.DiscardPartialCount != 0) Console.Write($"; {counters.DiscardPartialCount} partial"); + Console.WriteLine(); + } + + if (counters.CopyOutCount != 0) + { + Console.WriteLine( + $"Copy out: {FormatBytes(counters.CopyOutBytes)}; {counters.CopyOutCount:#,##0} times"); + } + + if (counters.PipelineFullAsyncCount != 0 + | counters.PipelineSendAsyncCount != 0 + | counters.PipelineFullSyncCount != 0) + { + Console.Write("Pipelining"); + if (counters.PipelineFullSyncCount != 0) + Console.Write($"; full sync: {counters.PipelineFullSyncCount:#,##0}"); + if (counters.PipelineSendAsyncCount != 0) + Console.Write($"; send async: {counters.PipelineSendAsyncCount:#,##0}"); + if (counters.PipelineFullAsyncCount != 0) + Console.Write($"; full async: {counters.PipelineFullAsyncCount:#,##0}"); + Console.WriteLine(); + } + + if (counters.BatchWriteCount != 0) + { + Console.Write($"Batching; {counters.BatchWriteCount:#,##0} batches"); + if (counters.BatchWriteFullPageCount != 0) + Console.Write($"; {counters.BatchWriteFullPageCount:#,###,##0} full pages"); + if (counters.BatchWritePartialPageCount != 0) + Console.Write($"; {counters.BatchWritePartialPageCount:#,###,##0} partial pages"); + if (counters.BatchWriteMessageCount != 0) + Console.Write($"; {counters.BatchWriteMessageCount:#,###,##0} messages"); + Console.WriteLine(); + } + + if (counters.BatchGrowCount != 0) + { + Console.WriteLine( + $"Batch growth; {counters.BatchGrowCount:#,##0} events, {counters.BatchGrowCopyCount:#,###,##0} elements copied"); + } + + if (counters.BatchBufferLeaseCount != 0 | counters.BatchMultiRootMessageCount != 0) + { + Console.Write( + $"Multi-message batching: {counters.BatchMultiRootMessageCount:#,###,##0} batches, {counters.BatchMultiChildMessageCount:#,###,##0} sub-messages"); + if (counters.BatchBufferLeaseCount != 0) + { + Console.Write( + $"; {counters.BatchBufferLeaseCount:#,###,##0} blocks leased, {counters.BatchBufferReturnCount:#,###,##0} blocks returned, {counters.BatchBufferElementsOutstanding:#,###,##0} elements outstanding"); + } + Console.WriteLine(); + } + + if (counters.BufferCreatedCount != 0 || + counters.BufferRecycledCount != 0 | counters.BufferMessageCount != 0) + { + Console.Write("Buffers"); + if (counters.BufferCreatedCount != 0) + { + Console.Write( + $"; created: {counters.BufferCreatedCount:#,###,##0}, {FormatBytes(counters.BufferTotalBytes)}"); + // always write recycled count - it being zero is important + Console.Write( + $"; recycled: {counters.BufferRecycledCount:#,###,##0}, {FormatBytes(counters.BufferRecycledBytes)}"); + } + + if (counters.BufferMessageCount != 0) + { + Console.Write( + $"; {counters.BufferMessageCount:#,###,##0} messages, {FormatBytes(counters.BufferMessageBytes)}"); + } + + Console.Write( + $"; max working {FormatBytes(counters.BufferMaxOutstandingBytes)}; {counters.BufferPinCount:#,###,##0} pins; {counters.BufferLeakCount:#,###,##0} leaks"); + Console.WriteLine(); + } + + static string FormatBytes(long bytes) + { + // ReSharper disable InconsistentNaming + const long k = 1024, M = k * k, G = M * k, T = G * k; + + // ReSharper restore InconsistentNaming + return bytes switch + { + < k => $"{bytes:#,##0} B", + < M => $"{bytes / (double)k:#,##0.00} KiB", + < G => $"{bytes / (double)M:#,##0.00} MiB", + < T => $"{bytes / (double)G:#,##0.00} GiB", + _ => $"{bytes / (double)T:#,##0.00} TiB", + }; + } + } +#endif + if (!Quiet) Console.WriteLine(); + } + } +} diff --git a/src/RESPite.Benchmark/BridgeBenchmark.cs b/src/RESPite.Benchmark/BridgeBenchmark.cs new file mode 100644 index 000000000..defa330c1 --- /dev/null +++ b/src/RESPite.Benchmark/BridgeBenchmark.cs @@ -0,0 +1,15 @@ +using RESPite.StackExchange.Redis; +using StackExchange.Redis; + +namespace RESPite.Benchmark; + +public sealed class BridgeBenchmark(string[] args) : OldCoreBenchmarkBase(args) +{ + public override string ToString() => "bridge SE.Redis"; + protected override IConnectionMultiplexer Create(int port) + { + var obj = new RespMultiplexer(); + obj.Connect("127.0.0.1:{Port}"); + return obj; + } +} diff --git a/src/RESPite.Benchmark/NewCoreBenchmark.cs b/src/RESPite.Benchmark/NewCoreBenchmark.cs new file mode 100644 index 000000000..2ab93790a --- /dev/null +++ b/src/RESPite.Benchmark/NewCoreBenchmark.cs @@ -0,0 +1,433 @@ +using System; +using System.ComponentModel; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using RESPite.Connections; +using RESPite.Messages; + +namespace RESPite.Benchmark; + +public sealed class NewCoreBenchmark : BenchmarkBase +{ + public override string ToString() => "new IO core"; + + private readonly RespConnectionPool _connectionPool; + + private readonly RespContext[] _clients; + private readonly (string Key, byte[] Value)[] _pairs; + + protected override RespContext GetClient(int index) => _clients[index]; + + protected override Task DeleteAsync(RespContext client, string key) => client.DelAsync(key).AsTask(); + + protected override RespContext WithCancellation(RespContext client, CancellationToken cancellationToken) + => client.WithCancellationToken(cancellationToken); + + protected override Task InitAsync(RespContext client) => client.PingAsync().AsTask(); + + public NewCoreBenchmark(string[] args) : base(args) + { + _clients = new RespContext[ClientCount]; + + _connectionPool = new(count: Multiplexed ? 1 : ClientCount); + _connectionPool.ConnectionError += (_, e) => Program.WriteException(e.Exception, e.Operation); + _pairs = new (string, byte[])[10]; + + for (var i = 0; i < 10; i++) + { + _pairs[i] = ($"key:__rand_int__{i}", Payload); + } + + if (Multiplexed) + { + var conn = _connectionPool.GetConnection().Synchronized(); + var ctx = conn.Context; + for (int i = 0; i < ClientCount; i++) // init all + { + _clients[i] = ctx; + } + } + else + { + for (int i = 0; i < ClientCount; i++) // init all + { + var conn = _connectionPool.GetConnection(); + if (PipelineDepth > 1) + { + conn = conn.Synchronized(); + } + + _clients[i] = conn.Context; + } + } + } + + public override void Dispose() + { + _connectionPool.Dispose(); + foreach (var client in _clients) + { + client.Connection.Dispose(); + } + } + + protected override async Task OnCleanupAsync(RespContext client) + { + foreach (var pair in _pairs) + { + await client.DelAsync(pair.Key).ConfigureAwait(false); + } + } + + public override async Task RunAll() + { + await InitAsync().ConfigureAwait(false); + // await RunAsync(PingInline).ConfigureAwait(false); + await RunAsync(null, PingBulk).ConfigureAwait(false); + + await RunAsync(GetSetKey, Set, GetName(Get)).ConfigureAwait(false); + await RunAsync(GetSetKey, Get).ConfigureAwait(false); + + await RunAsync(CounterKey, Incr, true).ConfigureAwait(false); + + await RunAsync(ListKey, LPush, GetName(LPop)).ConfigureAwait(false); + await RunAsync(ListKey, LPop).ConfigureAwait(false); + + await RunAsync(ListKey, RPush, GetName(RPop)).ConfigureAwait(false); + await RunAsync(ListKey, RPop).ConfigureAwait(false); + + await RunAsync(SetKey, SAdd, GetName(SPop)).ConfigureAwait(false); + await RunAsync(SetKey, SPop).ConfigureAwait(false); + + await RunAsync(HashKey, HSet).ConfigureAwait(false); + + await RunAsync(SortedSetKey, ZAdd, GetName(ZPopMin)).ConfigureAwait(false); + await RunAsync(SortedSetKey, ZPopMin).ConfigureAwait(false); + + await RunAsync(null, MSet).ConfigureAwait(false); + await RunAsync(StreamKey, XAdd).ConfigureAwait(false); + + // leave until last, they're slower + if (RunTest(GetName(LRange100)) || + RunTest(GetName(LRange300)) || + RunTest(GetName(LRange500)) || + RunTest(GetName(LRange600))) + { + await LRangeInit650(GetClient(0)).ConfigureAwait(false); + await RunAsync(ListKey, LRange100, false, 10).ConfigureAwait(false); + await RunAsync(ListKey, LRange300, false, 10).ConfigureAwait(false); + await RunAsync(ListKey, LRange500, false, 10).ConfigureAwait(false); + await RunAsync(ListKey, LRange600, false, 10).ConfigureAwait(false); + } + + await CleanupAsync().ConfigureAwait(false); + } + + protected override RespContext CreateBatch(RespContext client) => client.CreateBatch(PipelineDepth).Context; + + protected override ValueTask Flush(RespContext client) + { + if (client.Connection is RespBatch batch) + { + return new(batch.FlushAsync()); + } + + return default; + } + + protected override void PrepareBatch(RespContext client, int count) + { + if (client.Connection is RespBatch batch) + { + batch.EnsureCapacity(count); + } + } + + [DisplayName("PING_INLINE")] + // ReSharper disable once UnusedMember.Local + private ValueTask PingInline(RespContext ctx) => ctx.PingInlineAsync(Payload); + + [DisplayName("PING_BULK")] + private ValueTask PingBulk(RespContext ctx) => ctx.PingAsync(Payload); + + [DisplayName("INCR")] + private ValueTask Incr(RespContext ctx) => ctx.IncrAsync(CounterKey); + + [DisplayName("GET")] + private ValueTask Get(RespContext ctx) => ctx.GetAsync(GetSetKey); + + [DisplayName("SET")] + private ValueTask Set(RespContext ctx) => ctx.SetAsync(GetSetKey, Payload); + + [DisplayName("LPUSH")] + private ValueTask LPush(RespContext ctx) => ctx.LPushAsync(ListKey, Payload); + + [DisplayName("RPUSH")] + private ValueTask RPush(RespContext ctx) => ctx.RPushAsync(ListKey, Payload); + + [DisplayName("LRANGE_100")] + private ValueTask LRange100(RespContext ctx) => ctx.LRangeAsync(ListKey, 0, 99); + + [DisplayName("LRANGE_300")] + private ValueTask LRange300(RespContext ctx) => ctx.LRangeAsync(ListKey, 0, 299); + + [DisplayName("LRANGE_500")] + private ValueTask LRange500(RespContext ctx) => ctx.LRangeAsync(ListKey, 0, 499); + + [DisplayName("LRANGE_600")] + private ValueTask LRange600(RespContext ctx) => ctx.LRangeAsync(ListKey, 0, 599); + + [DisplayName("LPOP")] + private ValueTask LPop(RespContext ctx) => ctx.LPopAsync(ListKey); + + [DisplayName("RPOP")] + private ValueTask RPop(RespContext ctx) => ctx.RPopAsync(ListKey); + + [DisplayName("SADD")] + private ValueTask SAdd(RespContext ctx) => ctx.SAddAsync(SetKey, "element:__rand_int__"); + + [DisplayName("HSET")] + private ValueTask HSet(RespContext ctx) => ctx.HSetAsync(HashKey, "element:__rand_int__", Payload); + + [DisplayName("ZADD")] + private ValueTask ZAdd(RespContext ctx) => ctx.ZAddAsync(SortedSetKey, 0, "element:__rand_int__"); + + [DisplayName("ZPOPMIN")] + private ValueTask ZPopMin(RespContext ctx) => ctx.ZPopMinAsync(SortedSetKey); + + [DisplayName("SPOP")] + private ValueTask SPop(RespContext ctx) => ctx.SPopAsync(SetKey); + + [DisplayName("MSET"), Description("10 keys")] + private ValueTask MSet(RespContext ctx) => ctx.MSetAsync(_pairs); + + private async ValueTask LRangeInit650(RespContext ctx) + { + await ctx.DelAsync(ListKey).ConfigureAwait(false); + await ctx.LPushAsync(ListKey, Payload, 650); + if (await ctx.LLenAsync(ListKey).ConfigureAwait(false) != 650) + { + throw new InvalidOperationException(); + } + } + + [DisplayName("XADD")] + private ValueTask XAdd(RespContext ctx) => + ctx.XAddAsync(StreamKey, "*", "myfield", Payload); + + protected override async Task RunBasicLoopAsync(int clientId) + { + // The purpose of this is to represent a more realistic loop using natural code + // rather than code that is drowning in test infrastructure. + var client = GetClient(clientId); + var depth = PipelineDepth; + int tickCount = 0; // this is just so we don't query DateTime. + long previousValue = (await client.GetInt32Async(CounterKey).ConfigureAwait(false)) ?? 0, + currentValue = previousValue; + var watch = Stopwatch.StartNew(); + long previousMillis = watch.ElapsedMilliseconds; + + bool Tick() + { + var currentMillis = watch.ElapsedMilliseconds; + var elapsedMillis = currentMillis - previousMillis; + if (elapsedMillis >= 1000) + { + if (clientId == 0) // only one client needs to update the UI + { + var qty = currentValue - previousValue; + var seconds = elapsedMillis / 1000.0; + Console.WriteLine( + $"{qty:#,###,##0} ops in {seconds:#0.00}s, {qty / seconds:#,###,##0}/s\ttotal: {currentValue:#,###,###,##0}"); + + // reset for next UI update + previousValue = currentValue; + previousMillis = currentMillis; + } + + if (currentMillis >= 20_000) + { + if (clientId == 0) + { + Console.WriteLine(); + Console.WriteLine( + $"\t Overall: {currentValue:#,###,###,##0} ops in {currentMillis / 1000:#0.00}s, {currentValue / (currentMillis / 1000.0):#,###,##0}/s"); + Console.WriteLine(); + } + + return true; // stop after some time + } + } + + tickCount = 0; + return false; + } + + if (depth <= 1) + { + while (true) + { + currentValue = await client.IncrAsync(CounterKey).ConfigureAwait(false); + + if (++tickCount >= 1000 && Tick()) break; // only check whether to output every N iterations + } + } + else + { + ValueTask[] pending = new ValueTask[depth]; + await using var batch = client.CreateBatch(depth); + var ctx = batch.Context; + while (true) + { + for (int i = 0; i < depth; i++) + { + pending[i] = ctx.IncrAsync(CounterKey); + } + + await batch.FlushAsync().ConfigureAwait(false); + batch.EnsureCapacity(depth); // batches don't assume re-use + for (var i = 0; i < depth; i++) + { + currentValue = await pending[i].ConfigureAwait(false); + } + + tickCount += depth; + if (tickCount >= 1000 && Tick()) break; // only check whether to output every N iterations + } + } + } +} + +internal static partial class RedisCommands +{ + [RespCommand] + internal static partial RespParsers.ResponseSummary Ping(this in RespContext ctx); + + [RespCommand] + internal static partial RespParsers.ResponseSummary SPop(this in RespContext ctx, string key); + + [RespCommand] + internal static partial int SAdd(this in RespContext ctx, string key, string payload); + + [RespCommand] + internal static partial RespParsers.ResponseSummary Set(this in RespContext ctx, string key, byte[] payload); + + [RespCommand] + internal static partial int LLen(this in RespContext ctx, string key); + + [RespCommand] + internal static partial int LPush(this in RespContext ctx, string key, byte[] payload); + + [RespCommand(Formatter = LPushFormatter.Name)] + internal static partial int LPush(this in RespContext ctx, string key, byte[] payload, int count); + + private sealed class LPushFormatter : IRespFormatter<(string Key, byte[] Payload, int Count)> + { + public const string Name = $"{nameof(LPushFormatter)}.{nameof(Instance)}"; + private LPushFormatter() { } + public static readonly LPushFormatter Instance = new(); + + public void Format( + scoped ReadOnlySpan command, + ref RespWriter writer, + in (string Key, byte[] Payload, int Count) request) + { + writer.WriteCommand(command, request.Count + 1); + writer.WriteKey(request.Key); + for (int i = 0; i < request.Count; i++) + { + // duplicate for lazy bulk load + writer.WriteBulkString(request.Payload); + } + } + } + + [RespCommand] + internal static partial int RPush(this in RespContext ctx, string key, byte[] payload); + + [RespCommand] + internal static partial RespParsers.ResponseSummary LPop(this in RespContext ctx, string key); + + [RespCommand] + internal static partial RespParsers.ResponseSummary RPop(this in RespContext ctx, string key); + + [RespCommand] + internal static partial RespParsers.ResponseSummary + LRange(this in RespContext ctx, string key, int start, int stop); + + [RespCommand] + internal static partial int HSet(this in RespContext ctx, string key, string field, byte[] payload); + + [RespCommand] + internal static partial RespParsers.ResponseSummary Ping(this in RespContext ctx, byte[] payload); + + [RespCommand] + internal static partial int Incr(this in RespContext ctx, string key); + + [RespCommand] + internal static partial RespParsers.ResponseSummary Del(this in RespContext ctx, string key); + + [RespCommand] + internal static partial RespParsers.ResponseSummary ZPopMin(this in RespContext ctx, string key); + + [RespCommand] + internal static partial int ZAdd(this in RespContext ctx, string key, double score, string payload); + + [RespCommand("get")] + internal static partial int? GetInt32(this in RespContext ctx, string key); + + [RespCommand] + internal static partial RespParsers.ResponseSummary XAdd( + this in RespContext ctx, + string key, + string id, + string field, + byte[] value); + + [RespCommand] + internal static partial RespParsers.ResponseSummary Get(this in RespContext ctx, string key); + + [RespCommand(Formatter = PairsFormatter.Name)] // custom command formatter + internal static partial bool MSet(this in RespContext ctx, (string, byte[])[] pairs); + + internal static RespParsers.ResponseSummary PingInline(this in RespContext ctx, byte[] payload) + => ctx.Command("ping"u8, payload, InlinePingFormatter.Instance).Wait(RespParsers.ResponseSummary.Parser); + + internal static ValueTask PingInlineAsync(this in RespContext ctx, byte[] payload) + => ctx.Command("ping"u8, payload, InlinePingFormatter.Instance) + .Send(RespParsers.ResponseSummary.Parser); + + private sealed class InlinePingFormatter : IRespFormatter + { + private InlinePingFormatter() { } + public static readonly InlinePingFormatter Instance = new(); + + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in byte[] request) + { + writer.WriteRaw(command); + writer.WriteRaw(" "u8); + writer.WriteRaw(request); + writer.WriteRaw("\r\n"u8); + } + } + + private sealed class PairsFormatter : IRespFormatter<(string Key, byte[] Value)[]> + { + public const string Name = $"{nameof(PairsFormatter)}.{nameof(Instance)}"; + public static readonly PairsFormatter Instance = new PairsFormatter(); + + public void Format( + scoped ReadOnlySpan command, + ref RespWriter writer, + in (string Key, byte[] Value)[] request) + { + writer.WriteCommand(command, 2 * request.Length); + foreach (var pair in request) + { + writer.WriteKey(pair.Key); + writer.WriteBulkString(pair.Value); + } + } + } +} diff --git a/src/RESPite.Benchmark/OldCoreBenchmark.cs b/src/RESPite.Benchmark/OldCoreBenchmark.cs new file mode 100644 index 000000000..630460e2b --- /dev/null +++ b/src/RESPite.Benchmark/OldCoreBenchmark.cs @@ -0,0 +1,9 @@ +using StackExchange.Redis; + +namespace RESPite.Benchmark; + +public sealed class OldCoreBenchmark(string[] args) : OldCoreBenchmarkBase(args) +{ + public override string ToString() => "legacy SE.Redis"; + protected override IConnectionMultiplexer Create(int port) => ConnectionMultiplexer.Connect($"127.0.0.1:{Port}"); +} diff --git a/src/RESPite.Benchmark/OldCoreBenchmarkBase.cs b/src/RESPite.Benchmark/OldCoreBenchmarkBase.cs new file mode 100644 index 000000000..0bab0ea4c --- /dev/null +++ b/src/RESPite.Benchmark/OldCoreBenchmarkBase.cs @@ -0,0 +1,299 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; +using System.Threading.Tasks; +using StackExchange.Redis; + +namespace RESPite.Benchmark; + +public abstract class OldCoreBenchmarkBase : BenchmarkBase +{ + private readonly IConnectionMultiplexer _connectionMultiplexer; + private readonly IDatabase _client; + private readonly KeyValuePair[] _pairs; + + public OldCoreBenchmarkBase(string[] args) : base(args) + { + // ReSharper disable once VirtualMemberCallInConstructor + _connectionMultiplexer = Create(Port); + _client = _connectionMultiplexer.GetDatabase(); + _pairs = new KeyValuePair[10]; + + for (var i = 0; i < 10; i++) + { + _pairs[i] = new($"{"key:__rand_int__"}{i}", Payload); + } + } + + protected abstract IConnectionMultiplexer Create(int port); + + protected override async Task OnCleanupAsync(IDatabaseAsync client) + { + foreach (var pair in _pairs) + { + await client.KeyDeleteAsync(pair.Key); + } + } + + protected override Task InitAsync(IDatabaseAsync client) => client.PingAsync(); + + public override void Dispose() + { + _connectionMultiplexer.Dispose(); + } + + protected override IDatabaseAsync GetClient(int index) => _client; + protected override Task DeleteAsync(IDatabaseAsync client, string key) => client.KeyDeleteAsync(key); + + public override async Task RunAll() + { + await InitAsync().ConfigureAwait(false); + // await RunAsync(PingInline).ConfigureAwait(false); + await RunAsync(null, PingBulk).ConfigureAwait(false); + + await RunAsync(GetSetKey, Set, GetName(Get)).ConfigureAwait(false); + await RunAsync(GetSetKey, Get).ConfigureAwait(false); + + await RunAsync(CounterKey, Incr, true).ConfigureAwait(false); + + await RunAsync(ListKey, LPush, GetName(LPop)).ConfigureAwait(false); + await RunAsync(ListKey, LPop).ConfigureAwait(false); + + await RunAsync(ListKey, RPush, GetName(RPop)).ConfigureAwait(false); + await RunAsync(ListKey, RPop).ConfigureAwait(false); + + await RunAsync(SetKey, SAdd, GetName(SPop)).ConfigureAwait(false); + await RunAsync(SetKey, SPop).ConfigureAwait(false); + + await RunAsync(HashKey, HSet).ConfigureAwait(false); + + await RunAsync(SortedSetKey, ZAdd, GetName(ZPopMin)).ConfigureAwait(false); + await RunAsync(SortedSetKey, ZPopMin).ConfigureAwait(false); + + await RunAsync(null, MSet).ConfigureAwait(false); + await RunAsync(StreamKey, XAdd).ConfigureAwait(false); + + // leave until last, they're slower + if (RunTest(GetName(LRange100)) || + RunTest(GetName(LRange300)) || + RunTest(GetName(LRange500)) || + RunTest(GetName(LRange600))) + { + await LRangeInit650(GetClient(0)).ConfigureAwait(false); + await RunAsync(ListKey, LRange100, false, 10).ConfigureAwait(false); + await RunAsync(ListKey, LRange300, false, 10).ConfigureAwait(false); + await RunAsync(ListKey, LRange500, false, 10).ConfigureAwait(false); + await RunAsync(ListKey, LRange600, false, 10).ConfigureAwait(false); + } + + await CleanupAsync().ConfigureAwait(false); + } + + protected override IDatabaseAsync CreateBatch(IDatabaseAsync client) => ((IDatabase)client).CreateBatch(); + + protected override ValueTask Flush(IDatabaseAsync client) + { + if (client is IBatch batch) + { + batch.Execute(); + } + + return default; + } + + protected override async Task RunBasicLoopAsync(int clientId) + { + // The purpose of this is to represent a more realistic loop using natural code + // rather than code that is drowning in test infrastructure. + var client = (IDatabase)GetClient(clientId); // need IDatabase for CreateBatch + var depth = PipelineDepth; + int tickCount = 0; // this is just so we don't query DateTime. + var tmp = await client.StringGetAsync(CounterKey).ConfigureAwait(false); + long previousValue = tmp.IsNull ? 0 : (long)tmp, currentValue = previousValue; + var watch = Stopwatch.StartNew(); + long previousMillis = watch.ElapsedMilliseconds; + + bool Tick() + { + var currentMillis = watch.ElapsedMilliseconds; + var elapsedMillis = currentMillis - previousMillis; + if (elapsedMillis >= 1000) + { + if (clientId == 0) // only one client needs to update the UI + { + var qty = currentValue - previousValue; + var seconds = elapsedMillis / 1000.0; + Console.WriteLine( + $"{qty:#,###,##0} ops in {seconds:#0.00}s, {qty / seconds:#,###,##0}/s\ttotal: {currentValue:#,###,###,##0}"); + + // reset for next UI update + previousValue = currentValue; + previousMillis = currentMillis; + } + + if (currentMillis >= 20_000) + { + if (clientId == 0) + { + Console.WriteLine(); + Console.WriteLine( + $"\t Overall: {currentValue:#,###,###,##0} ops in {currentMillis / 1000:#0.00}s, {currentValue / (currentMillis / 1000.0):#,###,##0}/s"); + Console.WriteLine(); + } + + return true; // stop after some time + } + } + + tickCount = 0; + return false; + } + + if (depth <= 1) + { + while (true) + { + currentValue = await client.StringIncrementAsync(CounterKey).ConfigureAwait(false); + + if (++tickCount >= 1000 && Tick()) break; // only check whether to output every N iterations + } + } + else + { + Task[] pending = new Task[depth]; + var batch = client.CreateBatch(depth); + while (true) + { + for (int i = 0; i < depth; i++) + { + pending[i] = batch.StringIncrementAsync(CounterKey); + } + + batch.Execute(); + for (int i = 0; i < depth; i++) + { + currentValue = await pending[i].ConfigureAwait(false); + } + + tickCount += depth; + if (tickCount >= 1000 && Tick()) break; // only check whether to output every N iterations + } + } + } + + [DisplayName("GET")] + private ValueTask Get(IDatabaseAsync client) => GetAndMeasureString(client); + + private async ValueTask GetAndMeasureString(IDatabaseAsync client) + { + using var lease = await client.StringGetLeaseAsync(GetSetKey).ConfigureAwait(false); + return lease?.Length ?? -1; + } + + [DisplayName("SET")] + private ValueTask Set(IDatabaseAsync client) => client.StringSetAsync(GetSetKey, Payload).AsValueTask(); + + [DisplayName("PING_BULK")] + private ValueTask PingBulk(IDatabaseAsync client) => client.PingAsync().AsValueTask(); + + [DisplayName("INCR")] + private ValueTask Incr(IDatabaseAsync client) => client.StringIncrementAsync(CounterKey).AsValueTask(); + + [DisplayName("HSET")] + private ValueTask HSet(IDatabaseAsync client) => + client.HashSetAsync(HashKey, "element:__rand_int__", Payload).AsValueTask(); + + [DisplayName("SADD")] + private ValueTask SAdd(IDatabaseAsync client) => + client.SetAddAsync(SetKey, "element:__rand_int__").AsValueTask(); + + [DisplayName("LPUSH")] + private ValueTask LPush(IDatabaseAsync client) => client.ListLeftPushAsync(ListKey, Payload).AsValueTask(); + + [DisplayName("RPUSH")] + private ValueTask RPush(IDatabaseAsync client) => client.ListRightPushAsync(ListKey, Payload).AsValueTask(); + + [DisplayName("LPOP")] + private ValueTask LPop(IDatabaseAsync client) => client.ListLeftPopAsync(ListKey).AsValueTask(); + + [DisplayName("RPOP")] + private ValueTask RPop(IDatabaseAsync client) => client.ListRightPopAsync(ListKey).AsValueTask(); + + [DisplayName("SPOP")] + private ValueTask SPop(IDatabaseAsync client) => client.SetPopAsync(SetKey).AsValueTask(); + + [DisplayName("ZADD")] + private ValueTask ZAdd(IDatabaseAsync client) => + client.SortedSetAddAsync(SortedSetKey, "element:__rand_int__", 0).AsValueTask(); + + [DisplayName("ZPOPMIN")] + private ValueTask ZPopMin(IDatabaseAsync client) => HasSortedSetElement(client.SortedSetPopAsync(SortedSetKey)); + + private async ValueTask HasSortedSetElement(Task pending) + { + var result = await pending.ConfigureAwait(false); + return result.HasValue ? 1 : 0; + } + + [DisplayName("MSET")] + private ValueTask MSet(IDatabaseAsync client) => client.StringSetAsync(_pairs).AsValueTask(); + + [DisplayName("XADD")] + private ValueTask XAdd(IDatabaseAsync client) => + client.StreamAddAsync(StreamKey, "myfield", Payload).AsValueTask(); + + [DisplayName("LRANGE_100")] + private ValueTask LRange100(IDatabaseAsync client) => CountAsync(client.ListRangeAsync(ListKey, 0, 99)); + + [DisplayName("LRANGE_300")] + private ValueTask LRange300(IDatabaseAsync client) => CountAsync(client.ListRangeAsync(ListKey, 0, 299)); + + [DisplayName("LRANGE_500")] + private ValueTask LRange500(IDatabaseAsync client) => CountAsync(client.ListRangeAsync(ListKey, 0, 499)); + + [DisplayName("LRANGE_600")] + private ValueTask LRange600(IDatabaseAsync client) => + CountAsync(client.ListRangeAsync(ListKey, 0, 599)); + + private static ValueTask CountAsync(Task task) => task.ContinueWith( + t => t.Result.Length, TaskContinuationOptions.ExecuteSynchronously).AsValueTask(); + + private async ValueTask LRangeInit650(IDatabaseAsync client) + { + var batch = CreateBatch(client); + _ = batch.KeyDeleteAsync(ListKey, flags: CommandFlags.FireAndForget); + for (int i = 0; i < 650; i++) + { + _ = batch.ListLeftPushAsync(ListKey, Payload, flags: CommandFlags.FireAndForget); + } + + await Flush(batch).ConfigureAwait(false); + if (await client.ListLengthAsync(ListKey).ConfigureAwait(false) != 650) + { + throw new InvalidOperationException(); + } + } +} + +internal static class TaskExtensions +{ + public static ValueTask AsValueTask(this Task task) => new(task); + + /* + public static ValueTask AsUntypedValueTask(this Task task) => new(task); + public static ValueTask AsValueTask(this Task task) => new(task); + */ + + public static ValueTask AsUntypedValueTask(this ValueTask task) + { + if (!task.IsCompleted) return Awaited(task); + task.GetAwaiter().GetResult(); + return default; + + static async ValueTask Awaited(ValueTask task) + { + await task.ConfigureAwait(false); + } + } +} diff --git a/src/RESPite.Benchmark/Program.cs b/src/RESPite.Benchmark/Program.cs new file mode 100644 index 000000000..3d421d173 --- /dev/null +++ b/src/RESPite.Benchmark/Program.cs @@ -0,0 +1,98 @@ +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading.Tasks; + +namespace RESPite.Benchmark; + +internal static class Program +{ + private static async Task Main(string[] args) + { + bool basic = false; + try + { + List benchmarks = []; + foreach (var arg in args) + { + switch (arg) + { + case "--old": + benchmarks.Add(new OldCoreBenchmark(args)); + break; + case "--bridge": + benchmarks.Add(new BridgeBenchmark(args)); + break; + case "--new": + benchmarks.Add(new NewCoreBenchmark(args)); + break; + case "--basic": + basic = true; + break; + } + } + + if (benchmarks.Count == 0) + { + benchmarks.Add(new NewCoreBenchmark(args)); + } + + do + { + foreach (var bench in benchmarks) + { + if (benchmarks.Count > 1) + { + Console.WriteLine($"### {bench} ###"); + } + + if (basic) + { + await bench.RunBasicLoopAsync().ConfigureAwait(false); + } + else + { + await bench.RunAll().ConfigureAwait(false); + } + } + } + // ReSharper disable once LoopVariableIsNeverChangedInsideLoop + while (benchmarks[0].Loop); + + foreach (var bench in benchmarks) + { + bench.Dispose(); + } + return 0; + } + catch (Exception ex) + { + WriteException(ex); + return -1; + } + } + + internal static void WriteException(Exception? ex, [CallerMemberName] string operation = "") + { + Console.Error.WriteLine(); + Console.Error.WriteLine($"### EXCEPTION: {operation}"); + while (ex is not null) + { + Console.Error.WriteLine(); + Console.Error.WriteLine($"{ex.GetType().Name}: {ex.Message}"); + Console.Error.WriteLine($"\t{ex.StackTrace}"); + var data = ex.Data; + // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract + if (data is not null) + { + foreach (var key in data.Keys) + { + Console.Error.WriteLine($"\t{key}: {data[key]}"); + } + } + + ex = ex.InnerException; + } + Console.Error.WriteLine(); + } +} diff --git a/src/RESPite.Benchmark/RESPite.Benchmark.csproj b/src/RESPite.Benchmark/RESPite.Benchmark.csproj new file mode 100644 index 000000000..273e383d4 --- /dev/null +++ b/src/RESPite.Benchmark/RESPite.Benchmark.csproj @@ -0,0 +1,27 @@ + + + + enable + + Exe + net8.0;net9.0 + resp-benchmark + true + command-line "RESP" benchmark client, comparable to redis-benchmark + True + readme.md + False + True + false + false + 2025 - $([System.DateTime]::Now.Year) Marc Gravell + + + + + + + + + + diff --git a/src/RESPite.Benchmark/RespBenchmark.md b/src/RESPite.Benchmark/RespBenchmark.md new file mode 100644 index 000000000..b579d84de --- /dev/null +++ b/src/RESPite.Benchmark/RespBenchmark.md @@ -0,0 +1,352 @@ +# Influenced by redis-benchmark, which has typical output (with the default config) as below. + +Keys used (by default): + +- `key:__rand_int__` +- `counter:__rand_int__` +- `mylist` + +====== PING_INLINE ====== +100000 requests completed in 2.45 seconds +50 parallel clients +3 bytes payload +keep alive: 1 + +98.22% <= 1 milliseconds +99.88% <= 2 milliseconds +99.93% <= 3 milliseconds +99.99% <= 4 milliseconds +100.00% <= 5 milliseconds +100.00% <= 5 milliseconds +40849.68 requests per second + +====== PING_BULK ====== +100000 requests completed in 2.45 seconds +50 parallel clients +3 bytes payload +keep alive: 1 + +97.27% <= 1 milliseconds +99.86% <= 2 milliseconds +99.92% <= 3 milliseconds +99.94% <= 4 milliseconds +99.95% <= 23 milliseconds +99.96% <= 24 milliseconds +99.98% <= 25 milliseconds +100.00% <= 25 milliseconds +40866.37 requests per second + +====== SET ====== +100000 requests completed in 2.46 seconds +50 parallel clients +3 bytes payload +keep alive: 1 + +96.99% <= 1 milliseconds +99.47% <= 2 milliseconds +99.71% <= 3 milliseconds +99.86% <= 4 milliseconds +99.87% <= 9 milliseconds +99.88% <= 10 milliseconds +99.92% <= 11 milliseconds +99.93% <= 12 milliseconds +99.94% <= 13 milliseconds +99.96% <= 14 milliseconds +99.97% <= 15 milliseconds +99.97% <= 16 milliseconds +100.00% <= 27 milliseconds +40650.41 requests per second + +====== GET ====== +100000 requests completed in 3.00 seconds +50 parallel clients +3 bytes payload +keep alive: 1 + +90.56% <= 1 milliseconds +98.90% <= 2 milliseconds +99.46% <= 3 milliseconds +99.61% <= 4 milliseconds +99.70% <= 5 milliseconds +99.73% <= 6 milliseconds +99.75% <= 7 milliseconds +99.75% <= 9 milliseconds +99.77% <= 10 milliseconds +99.79% <= 12 milliseconds +99.80% <= 14 milliseconds +99.80% <= 15 milliseconds +99.83% <= 16 milliseconds +99.90% <= 17 milliseconds +99.93% <= 18 milliseconds +99.96% <= 19 milliseconds +99.98% <= 20 milliseconds +99.98% <= 22 milliseconds +99.98% <= 30 milliseconds +99.99% <= 31 milliseconds +100.00% <= 31 milliseconds +33377.84 requests per second + +====== INCR ====== +100000 requests completed in 2.94 seconds +50 parallel clients +3 bytes payload +keep alive: 1 + +93.21% <= 1 milliseconds +99.21% <= 2 milliseconds +99.70% <= 3 milliseconds +99.81% <= 4 milliseconds +99.86% <= 5 milliseconds +99.89% <= 6 milliseconds +99.93% <= 7 milliseconds +99.94% <= 8 milliseconds +99.96% <= 11 milliseconds +99.96% <= 12 milliseconds +99.96% <= 13 milliseconds +99.97% <= 14 milliseconds +99.97% <= 24 milliseconds +100.00% <= 24 milliseconds +34048.35 requests per second + +====== LPUSH ====== +100000 requests completed in 2.98 seconds +50 parallel clients +3 bytes payload +keep alive: 1 + +92.58% <= 1 milliseconds +99.21% <= 2 milliseconds +99.57% <= 3 milliseconds +99.71% <= 4 milliseconds +99.82% <= 5 milliseconds +99.85% <= 6 milliseconds +99.85% <= 7 milliseconds +99.88% <= 9 milliseconds +99.93% <= 10 milliseconds +99.93% <= 13 milliseconds +99.93% <= 14 milliseconds +99.95% <= 16 milliseconds +99.95% <= 31 milliseconds +99.99% <= 32 milliseconds +100.00% <= 32 milliseconds +33512.07 requests per second + +====== LPOP ====== +100000 requests completed in 2.91 seconds +50 parallel clients +3 bytes payload +keep alive: 1 + +92.81% <= 1 milliseconds +99.33% <= 2 milliseconds +99.89% <= 3 milliseconds +99.94% <= 4 milliseconds +99.96% <= 5 milliseconds +99.97% <= 15 milliseconds +99.98% <= 16 milliseconds +100.00% <= 17 milliseconds +34317.09 requests per second + +====== SADD ====== +100000 requests completed in 2.87 seconds +50 parallel clients +3 bytes payload +keep alive: 1 + +94.26% <= 1 milliseconds +99.58% <= 2 milliseconds +99.87% <= 3 milliseconds +99.93% <= 4 milliseconds +99.98% <= 17 milliseconds +99.98% <= 18 milliseconds +100.00% <= 19 milliseconds +34855.35 requests per second + +====== SPOP ====== +100000 requests completed in 2.99 seconds +50 parallel clients +3 bytes payload +keep alive: 1 + +91.00% <= 1 milliseconds +99.30% <= 2 milliseconds +99.69% <= 3 milliseconds +99.80% <= 4 milliseconds +99.85% <= 5 milliseconds +99.85% <= 8 milliseconds +99.86% <= 9 milliseconds +99.89% <= 10 milliseconds +99.92% <= 13 milliseconds +99.94% <= 14 milliseconds +99.95% <= 16 milliseconds +100.00% <= 16 milliseconds +33456.00 requests per second + +====== LPUSH (needed to benchmark LRANGE) ====== +100000 requests completed in 2.92 seconds +50 parallel clients +3 bytes payload +keep alive: 1 + +93.25% <= 1 milliseconds +99.45% <= 2 milliseconds +99.75% <= 3 milliseconds +99.86% <= 4 milliseconds +99.89% <= 5 milliseconds +99.91% <= 6 milliseconds +99.93% <= 9 milliseconds +99.95% <= 10 milliseconds +99.96% <= 11 milliseconds +99.97% <= 14 milliseconds +99.98% <= 15 milliseconds +99.99% <= 17 milliseconds +100.00% <= 18 milliseconds +100.00% <= 20 milliseconds +100.00% <= 20 milliseconds +34258.31 requests per second + +====== LRANGE_100 (first 100 elements) ====== +100000 requests completed in 4.33 seconds +50 parallel clients +3 bytes payload +keep alive: 1 + +35.50% <= 1 milliseconds +98.90% <= 2 milliseconds +99.61% <= 3 milliseconds +99.76% <= 4 milliseconds +99.83% <= 5 milliseconds +99.83% <= 7 milliseconds +99.84% <= 8 milliseconds +99.88% <= 9 milliseconds +99.88% <= 10 milliseconds +99.91% <= 11 milliseconds +99.91% <= 12 milliseconds +99.91% <= 13 milliseconds +99.96% <= 15 milliseconds +99.96% <= 34 milliseconds +99.97% <= 35 milliseconds +100.00% <= 39 milliseconds +100.00% <= 39 milliseconds +23089.36 requests per second + +====== LRANGE_300 (first 300 elements) ====== +100000 requests completed in 7.12 seconds +50 parallel clients +3 bytes payload +keep alive: 1 + +0.01% <= 1 milliseconds +84.00% <= 2 milliseconds +98.64% <= 3 milliseconds +99.44% <= 4 milliseconds +99.65% <= 5 milliseconds +99.70% <= 6 milliseconds +99.72% <= 7 milliseconds +99.75% <= 8 milliseconds +99.77% <= 9 milliseconds +99.81% <= 10 milliseconds +99.85% <= 11 milliseconds +99.87% <= 12 milliseconds +99.89% <= 13 milliseconds +99.90% <= 14 milliseconds +99.92% <= 15 milliseconds +99.96% <= 16 milliseconds +99.97% <= 17 milliseconds +99.99% <= 18 milliseconds +99.99% <= 26 milliseconds +99.99% <= 32 milliseconds +100.00% <= 37 milliseconds +100.00% <= 38 milliseconds +100.00% <= 39 milliseconds +14039.03 requests per second + +====== LRANGE_500 (first 450 elements) ====== +100000 requests completed in 8.32 seconds +50 parallel clients +3 bytes payload +keep alive: 1 + +0.71% <= 1 milliseconds +49.73% <= 2 milliseconds +96.81% <= 3 milliseconds +99.35% <= 4 milliseconds +99.79% <= 5 milliseconds +99.83% <= 6 milliseconds +99.84% <= 7 milliseconds +99.85% <= 8 milliseconds +99.91% <= 9 milliseconds +99.91% <= 10 milliseconds +99.91% <= 12 milliseconds +99.91% <= 27 milliseconds +99.91% <= 28 milliseconds +99.92% <= 29 milliseconds +99.93% <= 30 milliseconds +99.96% <= 31 milliseconds +99.96% <= 49 milliseconds +99.96% <= 50 milliseconds +99.98% <= 99 milliseconds +99.98% <= 100 milliseconds +100.00% <= 100 milliseconds +12022.12 requests per second + +====== LRANGE_600 (first 600 elements) ====== +100000 requests completed in 10.27 seconds +50 parallel clients +3 bytes payload +keep alive: 1 + +0.15% <= 1 milliseconds +28.15% <= 2 milliseconds +72.35% <= 3 milliseconds +96.20% <= 4 milliseconds +98.96% <= 5 milliseconds +99.68% <= 6 milliseconds +99.80% <= 7 milliseconds +99.85% <= 8 milliseconds +99.87% <= 9 milliseconds +99.88% <= 10 milliseconds +99.88% <= 11 milliseconds +99.88% <= 12 milliseconds +99.89% <= 13 milliseconds +99.89% <= 14 milliseconds +99.89% <= 15 milliseconds +99.90% <= 16 milliseconds +99.91% <= 17 milliseconds +99.91% <= 18 milliseconds +99.91% <= 19 milliseconds +99.92% <= 20 milliseconds +99.93% <= 21 milliseconds +99.95% <= 22 milliseconds +99.95% <= 23 milliseconds +99.96% <= 24 milliseconds +99.97% <= 25 milliseconds +99.97% <= 26 milliseconds +99.98% <= 27 milliseconds +100.00% <= 28 milliseconds +100.00% <= 29 milliseconds +100.00% <= 29 milliseconds +9736.15 requests per second + +====== MSET (10 keys) ====== +100000 requests completed in 2.94 seconds +50 parallel clients +3 bytes payload +keep alive: 1 + +92.48% <= 1 milliseconds +99.33% <= 2 milliseconds +99.91% <= 3 milliseconds +99.93% <= 4 milliseconds +99.94% <= 6 milliseconds +99.94% <= 11 milliseconds +99.96% <= 12 milliseconds +99.97% <= 13 milliseconds +99.98% <= 14 milliseconds +99.98% <= 17 milliseconds +99.99% <= 18 milliseconds +99.99% <= 19 milliseconds +99.99% <= 25 milliseconds +100.00% <= 30 milliseconds +100.00% <= 30 milliseconds +34059.95 requests per second \ No newline at end of file diff --git a/src/RESPite.Benchmark/readme.md b/src/RESPite.Benchmark/readme.md new file mode 100644 index 000000000..6767895f4 --- /dev/null +++ b/src/RESPite.Benchmark/readme.md @@ -0,0 +1,18 @@ +# resp-benchmark + +The `resp-benchmark` tool is a command-line "RESP" benchmark client, comparable to `redis-benchmark`, and +many of the arguments are the same. This is mostly for internal team usage, but is included here for +reference. + +Example usage: + +``` bash +> dotnet tool install -g RESPite.Benchmark + +# basic usage +> resp-benchmark + +# 50 clients, pipeline to 100, multiplexed, 1M operations, only test incr, loop +> resp-benchmark -c 50 -P 100 -n 1000000 +m -t incr -l + +``` \ No newline at end of file diff --git a/src/RESPite.Redis/Alt/DownlevelExtensions.cs b/src/RESPite.Redis/Alt/DownlevelExtensions.cs new file mode 100644 index 000000000..03fac7d6b --- /dev/null +++ b/src/RESPite.Redis/Alt/DownlevelExtensions.cs @@ -0,0 +1,15 @@ +using System.Runtime.CompilerServices; + +namespace RESPite.Redis.Alt; // legacy fallback for down-level compilers + +/// +/// For use with older compilers that don't support byref-return, extension-everything, etc. +/// +public static class DownlevelExtensions +{ + public static RedisStrings AsStrings(this in RespContext context) + => Unsafe.As(ref Unsafe.AsRef(in context)); + + public static RedisKeys AsKeys(this in RespContext context) + => Unsafe.As(ref Unsafe.AsRef(in context)); +} diff --git a/src/RESPite.Redis/Formatters.cs b/src/RESPite.Redis/Formatters.cs new file mode 100644 index 000000000..eecddd14d --- /dev/null +++ b/src/RESPite.Redis/Formatters.cs @@ -0,0 +1,9 @@ +namespace RESPite.Redis; + +internal static class Formatters +{ + private const string Global = "global::RESPite.Redis"; + + public const string KeyStringArray = + $"{Global}.{nameof(KeyStringArrayFormatter)}.{nameof(KeyStringArrayFormatter.Instance)}"; +} diff --git a/src/RESPite.Redis/KeyStringArrayFormatter.cs b/src/RESPite.Redis/KeyStringArrayFormatter.cs new file mode 100644 index 000000000..eb65b2507 --- /dev/null +++ b/src/RESPite.Redis/KeyStringArrayFormatter.cs @@ -0,0 +1,19 @@ +using System; +using RESPite; +using RESPite.Messages; + +namespace RESPite.Redis; + +internal sealed class KeyStringArrayFormatter : IRespFormatter> +{ + public static readonly KeyStringArrayFormatter Instance = new(); + + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in ReadOnlyMemory keys) + { + writer.WriteCommand(command, keys.Length); + foreach (var key in keys.Span) + { + writer.WriteKey(key); + } + } +} diff --git a/src/RESPite.Redis/RESPite.Redis.csproj b/src/RESPite.Redis/RESPite.Redis.csproj new file mode 100644 index 000000000..33fc28fd1 --- /dev/null +++ b/src/RESPite.Redis/RESPite.Redis.csproj @@ -0,0 +1,22 @@ + + + + true + net461;netstandard2.0;net8.0 + $(NoWarn);CS1591 + 2025 - $([System.DateTime]::Now.Year) Marc Gravell + + + + + + + + + NullableHacks.cs + + + SkipLocalsInit.cs + + + diff --git a/src/RESPite.Redis/RedisExtensions.cs b/src/RESPite.Redis/RedisExtensions.cs new file mode 100644 index 000000000..f4468bebe --- /dev/null +++ b/src/RESPite.Redis/RedisExtensions.cs @@ -0,0 +1,21 @@ +using System.Runtime.CompilerServices; + +namespace RESPite.Redis; + +public static class RedisExtensions +{ +#if PREVIEW_LANGVER + extension(in RespContext context) + { + // since this is valid... + // public ref readonly RespContext Self => ref context; + + // so must this be (importantly, RedisStrings has only a single RespContext field) + public ref readonly RedisStrings Strings + => ref Unsafe.As(ref Unsafe.AsRef(in context)); + + public ref readonly RedisKeys Keys + => ref Unsafe.As(ref Unsafe.AsRef(in context)); + } +#endif +} diff --git a/src/RESPite.Redis/RedisKeys.cs b/src/RESPite.Redis/RedisKeys.cs new file mode 100644 index 000000000..1e9d4531e --- /dev/null +++ b/src/RESPite.Redis/RedisKeys.cs @@ -0,0 +1,16 @@ +using System; + +namespace RESPite.Redis; + +// note that members may also be added as extensions if necessary +public readonly partial struct RedisKeys(in RespContext context) +{ + // ReSharper disable once UnusedMember.Local + private readonly RespContext _context = context; + + [RespCommand] + public partial void Del(string key); + + [RespCommand(Formatter = Formatters.KeyStringArray)] + public partial int Del(ReadOnlyMemory keys); +} diff --git a/src/RESPite.Redis/RedisStrings.cs b/src/RESPite.Redis/RedisStrings.cs new file mode 100644 index 000000000..02d224dc6 --- /dev/null +++ b/src/RESPite.Redis/RedisStrings.cs @@ -0,0 +1,99 @@ +using System; +using System.Threading.Tasks; +using RESPite.Messages; + +#if !PREVIEW_LANGVER +using RESPite.Redis.Alt; +#endif + +namespace RESPite.Redis; + +// note that members may also be added as extensions if necessary +public readonly partial struct RedisStrings(in RespContext context) +{ + private readonly RespContext _context = context; + + // re-expose del +#if PREVIEW_LANGVER + public void Del(string key) => _context.Keys.Del(key); + public ValueTask DelAsync(string key) => _context.Keys.DelAsync(key); +#else + public void Del(string key) => _context.AsKeys().Del(key); + public ValueTask DelAsync(string key) => _context.AsKeys().DelAsync(key); +#endif + + [RespCommand] + public partial int Append(string key, string value); + + [RespCommand] + public partial int Append(string key, ReadOnlyMemory value); + + [RespCommand] + public partial int Decr(string key); + + [RespCommand] + public partial int DecrBy(string key, int value); + + [RespCommand] + public partial string Get(string key); + + [RespCommand("get")] + public partial int GetInt32(string key); + + [RespCommand("get")] + public partial double GetDouble(string key); + + [RespCommand] + public partial string GetDel(string key); + + [RespCommand(Formatter = ExpiryTimeSpanFormatter.Formatter)] + public partial string GetEx(string key, TimeSpan expiry); + + private sealed class ExpiryTimeSpanFormatter : IRespFormatter<(string Key, TimeSpan Expiry)> + { + public const string Formatter = $"{nameof(ExpiryTimeSpanFormatter)}.{nameof(Instance)}"; + + public static readonly ExpiryTimeSpanFormatter Instance = new(); + + public void Format( + scoped ReadOnlySpan command, + ref RespWriter writer, + in (string Key, TimeSpan Expiry) request) + { + writer.WriteCommand(command, 3); + writer.WriteKey(request.Key); + writer.WriteBulkString("PX"u8); + writer.WriteBulkString((long)request.Expiry.TotalMilliseconds); + } + } + + [RespCommand] + public partial string GetRange(string key, int start, int end); + + [RespCommand] + public partial string GetSet(string key, string value); + + [RespCommand] + public partial string GetSet(string key, ReadOnlyMemory value); + + [RespCommand] + public partial int Incr(string key); + + [RespCommand] + public partial int IncrBy(string key, int value); + + [RespCommand] + public partial double IncrByFloat(string key, double value); + + [RespCommand] + public partial void Set(string key, string value); + + [RespCommand] + public partial void Set(string key, ReadOnlyMemory value); + + [RespCommand] + public partial void Set(string key, int value); + + [RespCommand] + public partial void Set(string key, double value); +} diff --git a/src/RESPite.StackExchange.Redis/Global.cs b/src/RESPite.StackExchange.Redis/Global.cs new file mode 100644 index 000000000..0751e11a8 --- /dev/null +++ b/src/RESPite.StackExchange.Redis/Global.cs @@ -0,0 +1,9 @@ +#if NET5_0_OR_GREATER +[module:global::System.Runtime.CompilerServices.SkipLocalsInit] +#else +// we've gone some disambiguation to do... +extern alias seredis; +global using DoesNotReturnAttribute = seredis::System.Diagnostics.CodeAnalysis.DoesNotReturnAttribute; + +[module:seredis::System.Runtime.CompilerServices.SkipLocalsInit] +#endif diff --git a/src/RESPite.StackExchange.Redis/RESPite.StackExchange.Redis.csproj b/src/RESPite.StackExchange.Redis/RESPite.StackExchange.Redis.csproj new file mode 100644 index 000000000..7f6d1952a --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RESPite.StackExchange.Redis.csproj @@ -0,0 +1,34 @@ + + + + true + net461;netstandard2.0;net8.0 + enable + enable + $(NoWarn);CS1591 + readme.md + false + 2025 - $([System.DateTime]::Now.Year) Marc Gravell + + + + + + + + + + + + + + + + + RespContextDatabase.cs + + + RedisCommands.cs + + + diff --git a/src/RESPite.StackExchange.Redis/RedisCommands.HashCommands.cs b/src/RESPite.StackExchange.Redis/RedisCommands.HashCommands.cs new file mode 100644 index 000000000..48e52aa67 --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RedisCommands.HashCommands.cs @@ -0,0 +1,875 @@ +using System.Runtime.CompilerServices; +using RESPite.Internal; +using RESPite.Messages; +using StackExchange.Redis; + +// ReSharper disable MemberCanBePrivate.Global +// ReSharper disable InconsistentNaming +namespace RESPite.StackExchange.Redis; + +internal static partial class RedisCommands +{ + // this is just a "type pun" - it should be an invisible/magic pointer cast to the JIT + public static ref readonly HashCommands Hashes(this in RespContext context) + => ref Unsafe.As(ref Unsafe.AsRef(in context)); +} + +public readonly struct HashCommands(in RespContext context) +{ + public readonly RespContext Context = context; // important: this is the only field +} + +internal static partial class HashCommandsExtensions +{ + [RespCommand] + public static partial RespOperation HDel(this in HashCommands context, RedisKey key, RedisValue field); + + [RespCommand] + public static partial RespOperation HDel(this in HashCommands context, RedisKey key, RedisValue[] fields); + + [RespCommand] + public static partial RespOperation HExists(this in HashCommands context, RedisKey key, RedisValue field); + + [RespCommand(Parser = "ExpireResultParser.Default")] + private static partial RespOperation HExpire( + this in HashCommands context, + RedisKey key, + long seconds, + [RespIgnore(ExpireWhen.Always)] ExpireWhen when, + [RespPrefix("FIELDS"), RespPrefix] RedisValue[] fields); + + [RespCommand(Parser = "ExpireResultParser.Default")] + private static partial RespOperation HExpireAt( + this in HashCommands context, + RedisKey key, + long seconds, + [RespIgnore(ExpireWhen.Always)] ExpireWhen when, + [RespPrefix("FIELDS"), RespPrefix] RedisValue[] fields); + + [RespCommand(Parser = "ExpireResultParser.Default")] + private static partial RespOperation HPExpire( + this in HashCommands context, + RedisKey key, + long milliseconds, + [RespIgnore(ExpireWhen.Always)] ExpireWhen when, + [RespPrefix("FIELDS"), RespPrefix] RedisValue[] fields); + + [RespCommand(Parser = "ExpireResultParser.Default")] + private static partial RespOperation HPExpireAt( + this in HashCommands context, + RedisKey key, + long milliseconds, + [RespIgnore(ExpireWhen.Always)] ExpireWhen when, + [RespPrefix("FIELDS"), RespPrefix] RedisValue[] fields); + + private sealed class ExpireResultParser : IRespParser, IRespParser + { + private ExpireResultParser() { } + public static readonly ExpireResultParser Default = new(); + + ExpireResult IRespParser.Parse(ref RespReader reader) + { + if (reader.IsAggregate & !reader.IsNull) + { + // if aggregate: take the first element + reader.MoveNext(); + } + + // otherwise, take first from array + return (ExpireResult)reader.ReadInt64(); + } + + ExpireResult[] IRespParser.Parse(ref RespReader reader) + => reader.ReadArray(static (ref RespReader reader) => (ExpireResult)reader.ReadInt64(), scalar: true)!; + } + + internal static RespOperation HExpire( + this in HashCommands context, + RedisKey key, + TimeSpan expiry, + ExpireWhen when, + RedisValue[] fields) + { + var millis = (long)expiry.TotalMilliseconds; + if (millis % 1000 == 0) // use seconds + { + return HExpire(context, key, millis / 1000, when, fields); + } + + return HPExpire(context, key, millis, when, fields); + } + + internal static RespOperation HExpireAt( + this in HashCommands context, + RedisKey key, + DateTime expiry, + ExpireWhen when, + RedisValue[] fields) + { + var millis = RedisDatabase.GetUnixTimeMilliseconds(expiry); + if (millis % 1000 == 0) // use seconds + { + return HExpireAt(context, key, millis / 1000, when, fields); + } + + return HPExpireAt(context, key, millis, when, fields); + } + + [RespCommand(Parser = "RespParsers.DateTimeFromSeconds")] + public static partial RespOperation HExpireTime( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix("1")] RedisValue field); + + [RespCommand(Parser = "RespParsers.DateTimeArrayFromSeconds")] + public static partial RespOperation HExpireTime( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix] RedisValue[] fields); + + [RespCommand(nameof(HPExpireTime))] + public static partial RespOperation HPExpireTimeRaw( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix("1")] RedisValue field); + + [RespCommand(nameof(HPExpireTime))] + public static partial RespOperation HPExpireTimeRaw( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix] RedisValue[] fields); + + [RespCommand(Parser = "RespParsers.DateTimeFromMilliseconds")] + public static partial RespOperation HPExpireTime( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix("1")] RedisValue field); + + [RespCommand(Parser = "RespParsers.DateTimeArrayFromMilliseconds")] + public static partial RespOperation HPExpireTime( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix] RedisValue[] fields); + + [RespCommand] + public static partial RespOperation HGet( + this in HashCommands context, + RedisKey key, + RedisValue field); + + [RespCommand] + public static partial RespOperation HGetDel( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix] RedisValue[] fields); + + [RespCommand] + public static partial RespOperation HGetDel( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix("1")] + RedisValue fields); + + [RespCommand(nameof(HGetDel))] + public static partial RespOperation?> HGetDelLease( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix("1")] + RedisValue fields); + + public static RespOperation HGetEx( + this in HashCommands context, + RedisKey key, + RedisValue field, + bool persist = false) + => HGetEx(context, key, persist ? HashExpiryMode.PERSIST : HashExpiryMode.None, -1, field); + + public static RespOperation?> HGetExLease( + this in HashCommands context, + RedisKey key, + RedisValue field, + bool persist = false) + => HGetExLease(context, key, persist ? HashExpiryMode.PERSIST : HashExpiryMode.None, -1, field); + + internal static RespOperation?> HGetExLease( + this in HashCommands context, + RedisKey key, + RedisValue field, + TimeSpan? expiry, + bool persist) + => expiry.HasValue + ? HGetExLease(context, key, expiry.GetValueOrDefault(), field) + : HGetExLease(context, key, field, persist); + + internal static RespOperation HGetEx( + this in HashCommands context, + RedisKey key, + RedisValue field, + TimeSpan? expiry, + bool persist) + => expiry.HasValue + ? HGetEx(context, key, expiry.GetValueOrDefault(), field) + : HGetEx(context, key, field, persist); + + internal static RespOperation HGetEx( + this in HashCommands context, + RedisKey key, + RedisValue[] fields, + TimeSpan? expiry, + bool persist) + => expiry.HasValue + ? HGetEx(context, key, expiry.GetValueOrDefault(), fields) + : HGetEx(context, key, fields, persist); + + public static RespOperation HGetEx( + this in HashCommands context, + RedisKey key, + RedisValue[] fields, + bool persist = false) + => HGetEx(context, key, persist ? HashExpiryMode.PERSIST : HashExpiryMode.None, -1, fields); + + public static RespOperation HGetEx( + this in HashCommands context, + RedisKey key, + DateTime expiry, + RedisValue field) + { + var millis = RedisDatabase.GetUnixTimeMilliseconds(expiry); + if (millis % 1000 == 0) // use seconds + { + return HGetEx(context, key, HashExpiryMode.EXAT, millis / 1000, field); + } + + return HGetEx(context, key, HashExpiryMode.PXAT, millis, field); + } + + public static RespOperation?> HGetExLease( + this in HashCommands context, + RedisKey key, + DateTime expiry, + RedisValue field) + { + var millis = RedisDatabase.GetUnixTimeMilliseconds(expiry); + if (millis % 1000 == 0) // use seconds + { + return HGetExLease(context, key, HashExpiryMode.EXAT, millis / 1000, field); + } + + return HGetExLease(context, key, HashExpiryMode.PXAT, millis, field); + } + + public static RespOperation HGetEx( + this in HashCommands context, + RedisKey key, + DateTime expiry, + RedisValue[] fields) + { + var millis = RedisDatabase.GetUnixTimeMilliseconds(expiry); + if (millis % 1000 == 0) // use seconds + { + return HGetEx(context, key, HashExpiryMode.EXAT, millis / 1000, fields); + } + + return HGetEx(context, key, HashExpiryMode.PXAT, millis, fields); + } + + public static RespOperation HGetEx( + this in HashCommands context, + RedisKey key, + TimeSpan expiry, + RedisValue field) + { + var millis = (long)expiry.TotalMilliseconds; + if (millis % 1000 == 0) // use seconds + { + return HGetEx(context, key, HashExpiryMode.EX, millis / 1000, field); + } + + return HGetEx(context, key, HashExpiryMode.PX, millis, field); + } + + public static RespOperation?> HGetExLease( + this in HashCommands context, + RedisKey key, + TimeSpan expiry, + RedisValue field) + { + var millis = (long)expiry.TotalMilliseconds; + if (millis % 1000 == 0) // use seconds + { + return HGetExLease(context, key, HashExpiryMode.EX, millis / 1000, field); + } + + return HGetExLease(context, key, HashExpiryMode.PX, millis, field); + } + + public static RespOperation HGetEx( + this in HashCommands context, + RedisKey key, + TimeSpan expiry, + RedisValue[] fields) + { + var millis = (long)expiry.TotalMilliseconds; + if (millis % 1000 == 0) // use seconds + { + return HGetEx(context, key, HashExpiryMode.EXAT, millis / 1000, fields); + } + + return HGetEx(context, key, HashExpiryMode.PXAT, millis, fields); + } + + internal enum HashExpiryMode + { + None, + EX, + PX, + EXAT, + PXAT, + PERSIST, + KEEPTTL, + } + + [RespCommand] + private static partial RespOperation HGetEx( + this in HashCommands context, + RedisKey key, + [RespIgnore(HashExpiryMode.None)] HashExpiryMode mode, + [RespIgnore(-1)] long value, + [RespPrefix("FIELDS"), RespPrefix] RedisValue[] fields); + + [RespCommand] + private static partial RespOperation HGetEx( + this in HashCommands context, + RedisKey key, + [RespIgnore(HashExpiryMode.None)] HashExpiryMode mode, + [RespIgnore(-1)] long value, + [RespPrefix("FIELDS"), RespPrefix("1")] RedisValue field); + + [RespCommand(nameof(HGetEx))] + private static partial RespOperation?> HGetExLease( + this in HashCommands context, + RedisKey key, + [RespIgnore(HashExpiryMode.None)] HashExpiryMode mode, + [RespIgnore(-1)] long value, + [RespPrefix("FIELDS"), RespPrefix("1")] RedisValue field); + + [RespCommand(nameof(HGet))] + public static partial RespOperation?> HGetLease( + this in HashCommands context, + RedisKey key, + RedisValue field); + + [RespCommand] + public static partial RespOperation HGetAll(this in HashCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation HIncrBy( + this in HashCommands context, + RedisKey key, + RedisValue field, + long value = 1); + + [RespCommand] + public static partial RespOperation HIncrByFloat( + this in HashCommands context, + RedisKey key, + RedisValue field, + double value); + + [RespCommand] + public static partial RespOperation HKeys(this in HashCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation HLen(this in HashCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation HMGet( + this in HashCommands context, + RedisKey key, + RedisValue[] fields); + + [RespCommand] + public static partial RespOperation HRandField(this in HashCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation + HRandField(this in HashCommands context, RedisKey key, long count); + + [RespCommand] + public static partial RespOperation HRandFieldWithValues( + this in HashCommands context, + RedisKey key, + [RespSuffix("WITHVALUES")] long count); + + [RespCommand] + public static partial RespOperation HSet( + this in HashCommands context, + RedisKey key, + RedisValue field, + RedisValue value); + + internal static RespOperation HSet( + this in HashCommands context, + RedisKey key, + RedisValue field, + RedisValue value, + When when) + { + switch (when) + { + case When.Always: + return HSet(context, key, field, value); + case When.NotExists: + return HSetNX(context, key, field, value); + default: + when.AlwaysOrNotExists(); // throws + return default; + } + } + + [RespCommand(Formatter = "HSetFormatter.Instance")] + public static partial RespOperation HSet(this in HashCommands context, RedisKey key, HashEntry[] fields); + + private sealed class HSetFormatter : IRespFormatter<(RedisKey Key, HashEntry[] Fields)> + { + private HSetFormatter() { } + public static readonly HSetFormatter Instance = new(); + + public void Format( + scoped ReadOnlySpan command, + ref RespWriter writer, + in (RedisKey Key, HashEntry[] Fields) request) + { + writer.WriteCommand(command, 1 + (request.Fields.Length * 2)); + writer.Write(request.Key); + foreach (var entry in request.Fields) + { + writer.Write(entry.Name); + writer.Write(entry.Value); + } + } + } + + public static RespOperation HSetEx( + this in HashCommands context, + RedisKey key, + TimeSpan expiry, + RedisValue field, + RedisValue value, + When when = When.Always) + { + var millis = (long)expiry.TotalMilliseconds; + if (millis % 1000 == 0) // use seconds + { + return HSetEx(context, key, when, HashExpiryMode.EX, millis / 1000, field, value); + } + + return HSetEx(context, key, when, HashExpiryMode.PX, millis, field, value); + } + + // "Legacy" - OK, so: historically, HashFieldSetAndSetExpiry returned RedisValue; this is ... bizarre, + // since HSETEX returns a bool. So: in the name of not breaking the world, we'll keep returning RedisValue; + // but: in the nice clean shiny API: expose bool + internal static RespOperation HSetExLegacy( + this in HashCommands context, + RedisKey key, + TimeSpan expiry, + RedisValue field, + RedisValue value, + When when) + { + var millis = (long)expiry.TotalMilliseconds; + if (millis % 1000 == 0) // use seconds + { + return HSetExLegacy(context, key, when, HashExpiryMode.EX, millis / 1000, field, value); + } + + return HSetExLegacy(context, key, when, HashExpiryMode.PX, millis, field, value); + } + + internal static RespOperation HSetExLegacy( + this in HashCommands context, + RedisKey key, + TimeSpan? expiry, + RedisValue field, + RedisValue value, + When when, + bool keepTtl) + { + if (expiry.HasValue) return HSetExLegacy(context, key, expiry.GetValueOrDefault(), field, value, when); + return HSetExLegacy(context, key, field, value, when, keepTtl); + } + + public static RespOperation HSetEx( + this in HashCommands context, + RedisKey key, + TimeSpan expiry, + HashEntry[] fields, + When when = When.Always) + { + if (fields.Length == 1) return HSetEx(context, key, expiry, fields[0].Name, fields[0].Value, when); + var millis = (long)expiry.TotalMilliseconds; + if (millis % 1000 == 0) // use seconds + { + return HSetEx(context, key, when, HashExpiryMode.EX, millis / 1000, fields); + } + + return HSetEx(context, key, when, HashExpiryMode.PX, millis, fields); + } + + private static RespOperation HSetExLegacy( + this in HashCommands context, + RedisKey key, + TimeSpan expiry, + HashEntry[] fields, + When when) + { + if (fields.Length == 1) return HSetExLegacy(context, key, expiry, fields[0].Name, fields[0].Value, when); + var millis = (long)expiry.TotalMilliseconds; + if (millis % 1000 == 0) // use seconds + { + return HSetExLegacy(context, key, when, HashExpiryMode.EX, millis / 1000, fields); + } + + return HSetExLegacy(context, key, when, HashExpiryMode.PX, millis, fields); + } + + internal static RespOperation HSetExLegacy( + this in HashCommands context, + RedisKey key, + TimeSpan? expiry, + HashEntry[] fields, + When when, + bool keepTtl) + { + if (expiry.HasValue) return HSetExLegacy(context, key, expiry.GetValueOrDefault(), fields, when); + return HSetExLegacy(context, key, fields, when, keepTtl); + } + + public static RespOperation HSetEx( + this in HashCommands context, + RedisKey key, + DateTime expiry, + RedisValue field, + RedisValue value, + When when = When.Always) + { + var millis = RedisDatabase.GetUnixTimeMilliseconds(expiry); + if (millis % 1000 == 0) // use seconds + { + return HSetEx(context, key, when, HashExpiryMode.EXAT, millis / 1000, field, value); + } + + return HSetEx(context, key, when, HashExpiryMode.PXAT, millis, field, value); + } + + internal static RespOperation HSetExLegacy( + this in HashCommands context, + RedisKey key, + DateTime expiry, + RedisValue field, + RedisValue value, + When when) + { + var millis = RedisDatabase.GetUnixTimeMilliseconds(expiry); + if (millis % 1000 == 0) // use seconds + { + return HSetExLegacy(context, key, when, HashExpiryMode.EXAT, millis / 1000, field, value); + } + + return HSetExLegacy(context, key, when, HashExpiryMode.PXAT, millis, field, value); + } + + public static RespOperation HSetEx( + this in HashCommands context, + RedisKey key, + DateTime expiry, + HashEntry[] fields, + When when = When.Always) + { + if (fields.Length == 1) return HSetEx(context, key, expiry, fields[0].Name, fields[0].Value, when); + var millis = RedisDatabase.GetUnixTimeMilliseconds(expiry); + if (millis % 1000 == 0) // use seconds + { + return HSetEx(context, key, when, HashExpiryMode.EXAT, millis / 1000, fields); + } + + return HSetEx(context, key, when, HashExpiryMode.PXAT, millis, fields); + } + + internal static RespOperation HSetExLegacy( + this in HashCommands context, + RedisKey key, + DateTime expiry, + HashEntry[] fields, + When when) + { + if (fields.Length == 1) return HSetExLegacy(context, key, expiry, fields[0].Name, fields[0].Value, when); + var millis = RedisDatabase.GetUnixTimeMilliseconds(expiry); + if (millis % 1000 == 0) // use seconds + { + return HSetExLegacy(context, key, when, HashExpiryMode.EXAT, millis / 1000, fields); + } + + return HSetExLegacy(context, key, when, HashExpiryMode.PXAT, millis, fields); + } + + public static RespOperation HSetEx( + this in HashCommands context, + RedisKey key, + RedisValue field, + RedisValue value, + When when = When.Always, + bool keepTtl = false) + => HSetEx(context, key, when, keepTtl ? HashExpiryMode.KEEPTTL : HashExpiryMode.None, -1, field, value); + + private static RespOperation HSetExLegacy( + this in HashCommands context, + RedisKey key, + RedisValue field, + RedisValue value, + When when = When.Always, + bool keepTtl = false) + => HSetExLegacy(context, key, when, keepTtl ? HashExpiryMode.KEEPTTL : HashExpiryMode.None, -1, field, value); + + public static RespOperation HSetEx( + this in HashCommands context, + RedisKey key, + HashEntry[] fields, + When when = When.Always, + bool keepTtl = false) + { + if (fields.Length == 1) return HSetEx(context, key, fields[0].Name, fields[0].Value, when, keepTtl); + return HSetEx(context, key, when, keepTtl ? HashExpiryMode.KEEPTTL : HashExpiryMode.None, -1, fields); + } + + private static RespOperation HSetExLegacy( + this in HashCommands context, + RedisKey key, + HashEntry[] fields, + When when, + bool keepTtl) + { + if (fields.Length == 1) return HSetExLegacy(context, key, fields[0].Name, fields[0].Value, when, keepTtl); + return HSetExLegacy(context, key, when, keepTtl ? HashExpiryMode.KEEPTTL : HashExpiryMode.None, -1, fields); + } + + [RespCommand(Formatter = "HSetExFormatter.Instance")] + private static partial RespOperation HSetEx( + this in HashCommands context, + RedisKey key, + When when, + HashExpiryMode mode, + long expiry, + RedisValue field, + RedisValue value); + + [RespCommand(Formatter = "HSetExFormatter.Instance")] + private static partial RespOperation HSetEx( + this in HashCommands context, + RedisKey key, + When when, + HashExpiryMode mode, + long expiry, + HashEntry[] fields); + + [RespCommand(nameof(HSetEx), Formatter = "HSetExFormatter.Instance")] + private static partial RespOperation HSetExLegacy( + this in HashCommands context, + RedisKey key, + When when, + HashExpiryMode mode, + long expiry, + RedisValue field, + RedisValue value); + + [RespCommand(nameof(HSetEx), Formatter = "HSetExFormatter.Instance")] + private static partial RespOperation HSetExLegacy( + this in HashCommands context, + RedisKey key, + When when, + HashExpiryMode mode, + long expiry, + HashEntry[] fields); + + private sealed class + HSetExFormatter : IRespFormatter<(RedisKey Key, When When, HashExpiryMode Mode, long Expiry, HashEntry[] Fields)>, + IRespFormatter<(RedisKey Key, When When, HashExpiryMode Mode, long Expiry, RedisValue Field, RedisValue Value)> + { + private HSetExFormatter() { } + public static readonly HSetExFormatter Instance = new(); + + public void Format( + scoped ReadOnlySpan command, + ref RespWriter writer, + in (RedisKey Key, When When, HashExpiryMode Mode, long Expiry, HashEntry[] Fields) request) + { + bool __inc0 = request.When != When.Always; // IgnoreExpression + bool __inc1 = request.Mode != HashExpiryMode.None; // IgnoreExpression + bool __inc2 = request.Expiry != -1; // IgnoreExpression +#pragma warning disable SA1118 + writer.WriteCommand(command, 3 // constant args: key, FIELDS, numfields + + (__inc0 ? 1 : 0) // request.When + + (__inc1 ? 1 : 0) // request.Mode + + (__inc2 ? 1 : 0) // request.Expiry + + (request.Fields.Length * 2)); // request.Fields +#pragma warning restore SA1118 + writer.Write(request.Key); + if (__inc0) + { + writer.WriteRaw(GetRaw(request.When)); + } + if (__inc1) + { + writer.WriteBulkString(request.Mode); + } + if (__inc2) + { + writer.WriteBulkString(request.Expiry); + } + writer.WriteRaw("$6\r\nFIELDS\r\n"u8); // FIELDS + writer.WriteBulkString(request.Fields.Length); + foreach (var entry in request.Fields) + { + writer.Write(entry.Name); + writer.Write(entry.Value); + } + } + + public void Format( + scoped ReadOnlySpan command, + ref RespWriter writer, + in (RedisKey Key, When When, HashExpiryMode Mode, long Expiry, RedisValue Field, RedisValue Value) request) + { + bool __inc0 = request.When != When.Always; // IgnoreExpression + bool __inc1 = request.Mode != HashExpiryMode.None; // IgnoreExpression + bool __inc2 = request.Expiry != -1; // IgnoreExpression +#pragma warning disable SA1118 + writer.WriteCommand(command, 5 // constant args: key, FIELDS, numfields, field, value + + (__inc0 ? 1 : 0) // request.When + + (__inc1 ? 1 : 0) // request.Mode + + (__inc2 ? 1 : 0)); // request.Expiry +#pragma warning restore SA1118 + writer.Write(request.Key); + if (__inc0) + { + writer.WriteRaw(GetRaw(request.When)); + } + if (__inc1) + { + writer.WriteBulkString(request.Mode); + } + if (__inc2) + { + writer.WriteBulkString(request.Expiry); + } + writer.WriteRaw("$6\r\nFIELDS\r\n$1\r\n1\r\n"u8); // FIELDS 1 + writer.Write(request.Field); + writer.Write(request.Value); + } + + private static ReadOnlySpan GetRaw(When when) + { + return when switch + { + When.Exists => "FXX"u8, + When.NotExists => "FNX"u8, + _ => Throw(), + }; + static ReadOnlySpan Throw() => throw new ArgumentOutOfRangeException(nameof(when)); + } + } + + [RespCommand] + public static partial RespOperation HSetNX( + this in HashCommands context, + RedisKey key, + RedisValue field, + RedisValue value); + + [RespCommand] + public static partial RespOperation HStrLen(this in HashCommands context, RedisKey key, RedisValue field); + + [RespCommand(Parser = "PersistResultParser.Default")] + public static partial RespOperation HPersist( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix("1")] RedisValue field); + + [RespCommand(Parser = "PersistResultParser.Default")] + public static partial RespOperation HPersist( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix] RedisValue[] fields); + + private sealed class PersistResultParser : IRespParser, IRespParser, IRespInlineParser + { + private PersistResultParser() { } + public static readonly PersistResultParser Default = new(); + PersistResult IRespParser.Parse(ref RespReader reader) + { + if (reader.IsAggregate) + { + reader.MoveNext(); // read first element from array + } + return (PersistResult)reader.ReadInt64(); + } + + PersistResult[] IRespParser.Parse(ref RespReader reader) => reader.ReadArray( + static (ref RespReader reader) => (PersistResult)reader.ReadInt64(), + scalar: true)!; + } + + [RespCommand(nameof(HPTtl))] + public static partial RespOperation HPTtlRaw( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix("1")] + RedisValue field); + + [RespCommand(nameof(HPTtl))] + public static partial RespOperation HPTtlRaw( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix] RedisValue[] fields); + + [RespCommand(Parser = "RespParsers.TimeSpanFromMilliseconds")] + public static partial RespOperation HPTtl( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix("1")] RedisValue field); + + [RespCommand(Parser = "RespParsers.TimeSpanArrayFromMilliseconds")] + public static partial RespOperation HPTtl( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix] RedisValue[] fields); + + [RespCommand(nameof(HTtl))] + public static partial RespOperation HTtlRaw( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix("1")] + RedisValue field); + + [RespCommand(nameof(HTtl))] + public static partial RespOperation HTtlRaw( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix] RedisValue[] fields); + + [RespCommand(Parser = "RespParsers.TimeSpanFromSeconds")] + public static partial RespOperation HTtl( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix("1")] RedisValue field); + + [RespCommand(Parser = "RespParsers.TimeSpanArrayFromSeconds")] + public static partial RespOperation HTtl( + this in HashCommands context, + RedisKey key, + [RespPrefix("FIELDS"), RespPrefix] RedisValue[] fields); + + [RespCommand] + public static partial RespOperation HVals(this in HashCommands context, RedisKey key); +} diff --git a/src/RESPite.StackExchange.Redis/RedisCommands.HyperLogLogCommands.cs b/src/RESPite.StackExchange.Redis/RedisCommands.HyperLogLogCommands.cs new file mode 100644 index 000000000..553d0b78d --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RedisCommands.HyperLogLogCommands.cs @@ -0,0 +1,39 @@ +using System.Runtime.CompilerServices; +using StackExchange.Redis; + +// ReSharper disable MemberCanBePrivate.Global +// ReSharper disable InconsistentNaming +namespace RESPite.StackExchange.Redis; + +internal static partial class RedisCommands +{ + // this is just a "type pun" - it should be an invisible/magic pointer cast to the JIT + public static ref readonly HyperLogLogCommands HyperLogLogs(this in RespContext context) + => ref Unsafe.As(ref Unsafe.AsRef(in context)); +} + +public readonly struct HyperLogLogCommands(in RespContext context) +{ + public readonly RespContext Context = context; // important: this is the only field +} + +internal static partial class HyperLogLogCommandsExtensions +{ + [RespCommand] + public static partial RespOperation PfAdd(this in HyperLogLogCommands context, RedisKey key, RedisValue value); + + [RespCommand] + public static partial RespOperation PfAdd(this in HyperLogLogCommands context, RedisKey key, RedisValue[] values); + + [RespCommand] + public static partial RespOperation PfCount(this in HyperLogLogCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation PfCount(this in HyperLogLogCommands context, RedisKey[] keys); + + [RespCommand] + public static partial RespOperation PfMerge(this in HyperLogLogCommands context, RedisKey destination, RedisKey first, RedisKey second); + + [RespCommand] + public static partial RespOperation PfMerge(this in HyperLogLogCommands context, RedisKey destination, RedisKey[] sourceKeys); +} diff --git a/src/RESPite.StackExchange.Redis/RedisCommands.KeyCommands.cs b/src/RESPite.StackExchange.Redis/RedisCommands.KeyCommands.cs new file mode 100644 index 000000000..91b5eb255 --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RedisCommands.KeyCommands.cs @@ -0,0 +1,260 @@ +using System.Runtime.CompilerServices; +using RESPite.Messages; +using StackExchange.Redis; + +// ReSharper disable MemberCanBePrivate.Global +// ReSharper disable InconsistentNaming +namespace RESPite.StackExchange.Redis; + +internal static partial class RedisCommands +{ + // this is just a "type pun" - it should be an invisible/magic pointer cast to the JIT + public static ref readonly KeyCommands Keys(this in RespContext context) + => ref Unsafe.As(ref Unsafe.AsRef(in context)); +} + +public readonly struct KeyCommands(in RespContext context) +{ + public readonly RespContext Context = context; // important: this is the only field +} + +internal static partial class KeyCommandsExtensions +{ + [RespCommand(Formatter = "CopyFormatter.Instance")] + public static partial RespOperation Copy( + this in KeyCommands context, + RedisKey source, + RedisKey destination, + int destinationDatabase = -1, + bool replace = false); + + [RespCommand] + public static partial RespOperation Del(this in KeyCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation Del(this in KeyCommands context, [RespKey] RedisKey[] keys); + + [RespCommand] + public static partial RespOperation Dump(this in KeyCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation Exists(this in KeyCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation Exists(this in KeyCommands context, [RespKey] RedisKey[] keys); + + public static RespOperation Expire( + this in KeyCommands context, + RedisKey key, + TimeSpan? expiry, + ExpireWhen when = ExpireWhen.Always) + { + if (expiry is null || expiry == TimeSpan.MaxValue) + { + if (when != ExpireWhen.Always) Throw(when); + return Persist(context, key); + static void Throw(ExpireWhen when) => throw new ArgumentException($"PERSIST cannot be used with {when}."); + } + + var millis = (long)expiry.GetValueOrDefault().TotalMilliseconds; + if (millis % 1000 == 0) // use seconds + { + return Expire(context, key, millis / 1000, when); + } + + return PExpire(context, key, millis, when); + } + + [RespCommand] + public static partial RespOperation Expire( + this in KeyCommands context, + RedisKey key, + long seconds, + [RespIgnore(ExpireWhen.Always)] ExpireWhen when = ExpireWhen.Always); + + public static RespOperation ExpireAt( + this in KeyCommands context, + RedisKey key, + DateTime? expiry, + ExpireWhen when = ExpireWhen.Always) + { + if (expiry is null || expiry == DateTime.MaxValue) + { + if (when != ExpireWhen.Always) Throw(when); + return Persist(context, key); + static void Throw(ExpireWhen when) => throw new ArgumentException($"PERSIST cannot be used with {when}."); + } + + var millis = RedisDatabase.GetUnixTimeMilliseconds(expiry.GetValueOrDefault()); + if (millis % 1000 == 0) // use seconds + { + return ExpireAt(context, key, millis / 1000, when); + } + + return PExpireAt(context, key, millis, when); + } + + [RespCommand] + public static partial RespOperation ExpireAt( + this in KeyCommands context, + RedisKey key, + long seconds, + [RespIgnore(ExpireWhen.Always)] ExpireWhen when = ExpireWhen.Always); + + [RespCommand(Parser = "RespParsers.DateTimeFromSeconds")] + public static partial RespOperation ExpireTime(this in KeyCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation Move(this in KeyCommands context, RedisKey key, int db); + + [RespCommand("object")] + public static partial RespOperation ObjectEncoding( + this in KeyCommands context, + [RespPrefix("ENCODING")] RedisKey key); + + [RespCommand("object")] + public static partial RespOperation ObjectFreq( + this in KeyCommands context, + [RespPrefix("FREQ")] RedisKey key); + + [RespCommand("object", Parser = "RespParsers.TimeSpanFromSeconds")] + public static partial RespOperation ObjectIdleTime( + this in KeyCommands context, + [RespPrefix("IDLETIME")] RedisKey key); + + [RespCommand("object")] + public static partial RespOperation ObjectRefCount( + this in KeyCommands context, + [RespPrefix("REFCOUNT")] RedisKey key); + + [RespCommand] + public static partial RespOperation PExpire( + this in KeyCommands context, + RedisKey key, + long milliseconds, + [RespIgnore(ExpireWhen.Always)] ExpireWhen when = ExpireWhen.Always); + + [RespCommand] + public static partial RespOperation PExpireAt( + this in KeyCommands context, + RedisKey key, + long milliseconds, + [RespIgnore(ExpireWhen.Always)] ExpireWhen when = ExpireWhen.Always); + + [RespCommand(Parser = "RespParsers.DateTimeFromMilliseconds")] + public static partial RespOperation PExpireTime(this in KeyCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation Persist(this in KeyCommands context, RedisKey key); + + [RespCommand(Parser = "RespParsers.TimeSpanFromMilliseconds")] + public static partial RespOperation Pttl(this in KeyCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation RandomKey(this in KeyCommands context); + + [RespCommand] + public static partial RespOperation Rename(this in KeyCommands context, RedisKey key, RedisKey newKey); + + [RespCommand] + public static RespOperation Rename(this in KeyCommands context, RedisKey key, RedisKey newKey, When when) + { + switch (when) + { + case When.Always: + return Rename(context, key, newKey); + case When.NotExists: + return RenameNx(context, key, newKey); + default: + when.AlwaysOrNotExists(); // throws + return default; + } + } + + [RespCommand] + public static partial RespOperation RenameNx(this in KeyCommands context, RedisKey key, RedisKey newKey); + + [RespCommand(Formatter = "RestoreFormatter.Instance")] + public static partial RespOperation Restore( + this in KeyCommands context, + RedisKey key, + TimeSpan? ttl, + byte[] serializedValue); + + [RespCommand] + public static partial RespOperation Touch(this in KeyCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation Touch(this in KeyCommands context, [RespKey] RedisKey[] keys); + + [RespCommand(Parser = "RespParsers.TimeSpanFromSeconds")] + public static partial RespOperation Ttl(this in KeyCommands context, RedisKey key); + + [RespCommand(Parser = "RedisTypeParser.Instance")] + public static partial RespOperation Type(this in KeyCommands context, RedisKey key); + + private sealed class RedisTypeParser : IRespParser + { + public static readonly RedisTypeParser Instance = new(); + private RedisTypeParser() { } + + public RedisType Parse(ref RespReader reader) + { + if (reader.IsNull) return RedisType.None; + if (reader.Is("zset"u8)) return RedisType.SortedSet; + return reader.ReadEnum(RedisType.Unknown); + } + } + + private sealed class CopyFormatter : IRespFormatter<(RedisKey Source, RedisKey Destination, int DestinationDatabase, + bool Replace)> + { + public static readonly CopyFormatter Instance = new(); + private CopyFormatter() { } + + public void Format( + scoped ReadOnlySpan command, + ref RespWriter writer, + in (RedisKey Source, RedisKey Destination, int DestinationDatabase, bool Replace) request) + { + writer.WriteCommand(command, (request.DestinationDatabase >= 0 ? 4 : 2) + (request.Replace ? 1 : 0)); + writer.Write(request.Source); + writer.Write(request.Destination); + if (request.DestinationDatabase >= 0) + { + writer.WriteRaw("$2\r\nDB\r\n"u8); + writer.WriteBulkString(request.DestinationDatabase); + } + + if (request.Replace) + { + writer.WriteRaw("$7\r\nREPLACE\r\n"u8); + } + } + } + + private sealed class RestoreFormatter : IRespFormatter<(RedisKey Key, TimeSpan? Ttl, byte[] SerializedValue)> + { + public static readonly RestoreFormatter Instance = new(); + private RestoreFormatter() { } + + public void Format( + scoped ReadOnlySpan command, + ref RespWriter writer, + in (RedisKey Key, TimeSpan? Ttl, byte[] SerializedValue) request) + { + writer.WriteCommand(command, 3); + writer.Write(request.Key); + if (request.Ttl.HasValue) + { + writer.WriteBulkString((long)request.Ttl.Value.TotalMilliseconds); + } + else + { + writer.WriteRaw("$1\r\n0\r\n"u8); + } + + writer.WriteBulkString(request.SerializedValue); + } + } +} diff --git a/src/RESPite.StackExchange.Redis/RedisCommands.ListCommands.cs b/src/RESPite.StackExchange.Redis/RedisCommands.ListCommands.cs new file mode 100644 index 000000000..1112e26c2 --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RedisCommands.ListCommands.cs @@ -0,0 +1,223 @@ +using System.Runtime.CompilerServices; +using RESPite.Messages; +using StackExchange.Redis; + +// ReSharper disable MemberCanBePrivate.Global +// ReSharper disable InconsistentNaming +namespace RESPite.StackExchange.Redis; + +internal static partial class RedisCommands +{ + // this is just a "type pun" - it should be an invisible/magic pointer cast to the JIT + public static ref readonly ListCommands Lists(this in RespContext context) + => ref Unsafe.As(ref Unsafe.AsRef(in context)); +} + +public readonly struct ListCommands(in RespContext context) +{ + public readonly RespContext Context = context; // important: this is the only field +} + +internal static partial class ListCommandsExtensions +{ + /* + [RespCommand] + public static partial RespOperation BLMove( + this in ListCommands context, + RedisKey source, + RedisKey destination, + ListSide sourceSide, + ListSide destinationSide, + double timeoutSeconds); + + [RespCommand] + public static partial RespOperation BLMPop( + this in ListCommands context, + [RespKey] RedisKey[] keys, + ListSide side, + long count, + double timeoutSeconds); + + [RespCommand] + public static partial RespOperation BLPop( + this in ListCommands context, + [RespKey] RedisKey[] keys, + double timeoutSeconds); + + [RespCommand] + public static partial RespOperation BRPop( + this in ListCommands context, + [RespKey] RedisKey[] keys, + double timeoutSeconds); + + [RespCommand] + public static partial RespOperation BRPopLPush( + this in ListCommands context, + RedisKey source, + RedisKey destination, + double timeoutSeconds); + */ + + [RespCommand] + public static partial RespOperation LIndex(this in ListCommands context, RedisKey key, long index); + + [RespCommand(Formatter = "LInsertFormatter.Instance")] + public static partial RespOperation LInsert( + this in ListCommands context, + RedisKey key, + bool insertBefore, + RedisValue pivot, + RedisValue element); + + private sealed class + LInsertFormatter : IRespFormatter<(RedisKey Key, bool InsertBefore, RedisValue Pivot, RedisValue Element)> + { + public static readonly LInsertFormatter Instance = new(); + private LInsertFormatter() { } + + public void Format( + scoped ReadOnlySpan command, + ref RespWriter writer, + in (RedisKey Key, bool InsertBefore, RedisValue Pivot, RedisValue Element) request) + { + writer.WriteCommand(command, 4); + writer.Write(request.Key); + writer.WriteRaw(request.InsertBefore ? "$6\r\nBEFORE\r\n"u8 : "$5\r\nAFTER\r\n"u8); + writer.Write(request.Pivot); + writer.Write(request.Element); + } + } + + [RespCommand] + public static partial RespOperation LLen(this in ListCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation LMove( + this in ListCommands context, + RedisKey source, + RedisKey destination, + ListSide sourceSide, + ListSide destinationSide); + + [RespCommand] + public static partial RespOperation LPop(this in ListCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation LPop(this in ListCommands context, RedisKey key, long count); + + [RespCommand(Parser = "RespParsers.Int64Index")] + public static partial RespOperation LPos( + this in ListCommands context, + RedisKey key, + RedisValue element, + [RespPrefix("RANK"), RespIgnore(1)] long rank = 1, + [RespPrefix("MAXLEN"), RespIgnore(0)] long maxLen = 0); + + [RespCommand] + public static partial RespOperation LPos( + this in ListCommands context, + RedisKey key, + RedisValue element, + [RespPrefix("RANK"), RespIgnore(1)] long rank, + [RespPrefix("MAXLEN"), RespIgnore(0)] long maxLen, + [RespPrefix("COUNT")] long count); + + [RespCommand] + public static partial RespOperation LPush(this in ListCommands context, RedisKey key, RedisValue element); + + internal static RespOperation Push(this in ListCommands context, RedisKey key, RedisValue element, ListSide side, When when) + { + switch (when) + { + case When.Always: + return side == ListSide.Left ? LPush(context, key, element) : RPush(context, key, element); + case When.Exists: + return side == ListSide.Left ? LPushX(context, key, element) : RPushX(context, key, element); + default: + when.AlwaysOrExists(); // throws + return default; + } + } + + internal static RespOperation Push(this in ListCommands context, RedisKey key, RedisValue[] elements, ListSide side, When when) + { + switch (when) + { + case When.Always when elements.Length == 1: + return side == ListSide.Left ? LPush(context, key, elements[0]) : RPush(context, key, elements[0]); + case When.Always when elements.Length > 1: + return side == ListSide.Left ? LPush(context, key, elements) : RPush(context, key, elements); + case When.Exists when elements.Length == 1: + return side == ListSide.Left ? LPushX(context, key, elements[0]) : RPushX(context, key, elements[0]); + case When.Exists when elements.Length > 1: + return side == ListSide.Left ? LPushX(context, key, elements) : RPushX(context, key, elements); + default: + when.AlwaysOrExists(); // check that "when" is valid + return LLen(context, key); // handle zero case (no insert, just get length) + } + } + + [RespCommand] + public static partial RespOperation LPush(this in ListCommands context, RedisKey key, RedisValue[] elements); + + [RespCommand] + public static partial RespOperation LPushX(this in ListCommands context, RedisKey key, RedisValue element); + + [RespCommand] + public static partial RespOperation LPushX(this in ListCommands context, RedisKey key, RedisValue[] elements); + + [RespCommand] + public static partial RespOperation LRange( + this in ListCommands context, + RedisKey key, + long start, + long stop); + + [RespCommand] + public static partial RespOperation LRem( + this in ListCommands context, + RedisKey key, + long count, + RedisValue element); + + [RespCommand] + public static partial RespOperation LSet( + this in ListCommands context, + RedisKey key, + long index, + RedisValue element); + + [RespCommand] + public static partial RespOperation LTrim(this in ListCommands context, RedisKey key, long start, long stop); + + [RespCommand] + public static partial RespOperation RPop(this in ListCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation RPop(this in ListCommands context, RedisKey key, long count); + + [RespCommand] + public static partial RespOperation RPopLPush( + this in ListCommands context, + RedisKey source, + RedisKey destination); + + [RespCommand] + public static partial RespOperation RPush(this in ListCommands context, RedisKey key, RedisValue element); + + [RespCommand] + public static partial RespOperation RPush(this in ListCommands context, RedisKey key, RedisValue[] elements); + + [RespCommand] + public static partial RespOperation RPushX(this in ListCommands context, RedisKey key, RedisValue element); + + [RespCommand] + public static partial RespOperation RPushX(this in ListCommands context, RedisKey key, RedisValue[] elements); + + [RespCommand(Parser = "RespParsers.ListPopResult")] + public static partial RespOperation LMPop( + this in ListCommands context, + [RespPrefix, RespKey] RedisKey[] keys, + ListSide side, + [RespIgnore(1), RespPrefix("COUNT")] long count = 1); +} diff --git a/src/RESPite.StackExchange.Redis/RedisCommands.ServerCommands.cs b/src/RESPite.StackExchange.Redis/RedisCommands.ServerCommands.cs new file mode 100644 index 000000000..2313f3039 --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RedisCommands.ServerCommands.cs @@ -0,0 +1,19 @@ +using System.Runtime.CompilerServices; + +namespace RESPite.StackExchange.Redis; + +internal static partial class RedisCommands +{ + // this is just a "type pun" - it should be an invisible/magic pointer cast to the JIT + public static ref readonly ServerCommands Servers(this in RespContext context) + => ref Unsafe.As(ref Unsafe.AsRef(in context)); +} +internal readonly struct ServerCommands(in RespContext context) +{ + public readonly RespContext Context = context; // important: this is the only field +} +internal static partial class ServerCommandsExtensions +{ + [RespCommand] + public static partial void Ping(this in ServerCommands context); +} diff --git a/src/RESPite.StackExchange.Redis/RedisCommands.SetCommands.cs b/src/RESPite.StackExchange.Redis/RedisCommands.SetCommands.cs new file mode 100644 index 000000000..b44dbef8d --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RedisCommands.SetCommands.cs @@ -0,0 +1,154 @@ +using System.Runtime.CompilerServices; +using StackExchange.Redis; + +// ReSharper disable InconsistentNaming +// ReSharper disable MemberCanBePrivate.Global +namespace RESPite.StackExchange.Redis; + +internal static partial class RedisCommands +{ + // this is just a "type pun" - it should be an invisible/magic pointer cast to the JIT + public static ref readonly SetCommands Sets(this in RespContext context) + => ref Unsafe.As(ref Unsafe.AsRef(in context)); +} + +public readonly struct SetCommands(in RespContext context) +{ + public readonly RespContext Context = context; // important: this is the only field +} + +internal static partial class SetCommandsExtensions +{ + [RespCommand] + public static partial RespOperation SAdd(this in SetCommands context, RedisKey key, RedisValue member); + + [RespCommand] + public static partial RespOperation SAdd(this in SetCommands context, RedisKey key, RedisValue[] members); + + [RespCommand] + public static partial RespOperation SCard(this in SetCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation SDiff(this in SetCommands context, RedisKey first, RedisKey second); + + [RespCommand] + public static partial RespOperation SDiff(this in SetCommands context, RedisKey[] keys); + + [RespCommand] + public static partial RespOperation SDiffStore(this in SetCommands context, RedisKey destination, RedisKey first, RedisKey second); + + [RespCommand] + public static partial RespOperation SDiffStore(this in SetCommands context, RedisKey destination, RedisKey[] keys); + + [RespCommand] + public static partial RespOperation SInter(this in SetCommands context, RedisKey first, RedisKey second); + + [RespCommand] + public static partial RespOperation SInter(this in SetCommands context, RedisKey[] keys); + + [RespCommand] + public static partial RespOperation SInterCard(this in SetCommands context, RedisKey first, RedisKey second, long limit = 0); + + [RespCommand] + public static partial RespOperation SInterCard(this in SetCommands context, RedisKey[] keys, long limit = 0); + + [RespCommand] + public static partial RespOperation SInterStore(this in SetCommands context, RedisKey destination, RedisKey first, RedisKey second); + + [RespCommand] + public static partial RespOperation SInterStore(this in SetCommands context, RedisKey destination, RedisKey[] keys); + + [RespCommand] + public static partial RespOperation SIsMember(this in SetCommands context, RedisKey key, RedisValue member); + + [RespCommand] + public static partial RespOperation SMIsMember(this in SetCommands context, RedisKey key, RedisValue[] members); + + [RespCommand] + public static partial RespOperation SMembers(this in SetCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation SMove(this in SetCommands context, RedisKey source, RedisKey destination, RedisValue member); + + [RespCommand] + public static partial RespOperation SPop(this in SetCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation SPop(this in SetCommands context, RedisKey key, long count); + + [RespCommand] + public static partial RespOperation SRandMember(this in SetCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation SRandMember(this in SetCommands context, RedisKey key, long count); + + [RespCommand] + public static partial RespOperation SRem(this in SetCommands context, RedisKey key, RedisValue member); + + [RespCommand] + public static partial RespOperation SRem(this in SetCommands context, RedisKey key, RedisValue[] members); + + [RespCommand] + public static partial RespOperation SUnion(this in SetCommands context, RedisKey first, RedisKey second); + + [RespCommand] + public static partial RespOperation SUnion(this in SetCommands context, RedisKey[] keys); + + [RespCommand] + public static partial RespOperation SUnionStore(this in SetCommands context, RedisKey destination, RedisKey first, RedisKey second); + + [RespCommand] + public static partial RespOperation SUnionStore(this in SetCommands context, RedisKey destination, RedisKey[] keys); + + internal static RespOperation CombineStore( + this in SetCommands context, + SetOperation operation, + RedisKey destination, + RedisKey first, + RedisKey second) => + operation switch + { + SetOperation.Difference => context.SDiffStore(destination, first, second), + SetOperation.Intersect => context.SInterStore(destination, first, second), + SetOperation.Union => context.SUnionStore(destination, first, second), + _ => throw new ArgumentOutOfRangeException(nameof(operation)), + }; + + internal static RespOperation CombineStore( + this in SetCommands context, + SetOperation operation, + RedisKey destination, + RedisKey[] keys) => + operation switch + { + SetOperation.Difference => context.SDiffStore(destination, keys), + SetOperation.Intersect => context.SInterStore(destination, keys), + SetOperation.Union => context.SUnionStore(destination, keys), + _ => throw new ArgumentOutOfRangeException(nameof(operation)), + }; + + internal static RespOperation Combine( + this in SetCommands context, + SetOperation operation, + RedisKey first, + RedisKey second) => + operation switch + { + SetOperation.Difference => context.SDiff(first, second), + SetOperation.Intersect => context.SInter(first, second), + SetOperation.Union => context.SUnion(first, second), + _ => throw new ArgumentOutOfRangeException(nameof(operation)), + }; + + internal static RespOperation Combine( + this in SetCommands context, + SetOperation operation, + RedisKey[] keys) => + operation switch + { + SetOperation.Difference => context.SDiff(keys), + SetOperation.Intersect => context.SInter(keys), + SetOperation.Union => context.SUnion(keys), + _ => throw new ArgumentOutOfRangeException(nameof(operation)), + }; +} diff --git a/src/RESPite.StackExchange.Redis/RedisCommands.SortedSetCommands.cs b/src/RESPite.StackExchange.Redis/RedisCommands.SortedSetCommands.cs new file mode 100644 index 000000000..176352720 --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RedisCommands.SortedSetCommands.cs @@ -0,0 +1,879 @@ +using System.Buffers; +using System.Runtime.CompilerServices; +using RESPite.Messages; +using StackExchange.Redis; + +// ReSharper disable MemberCanBePrivate.Global +// ReSharper disable InconsistentNaming +namespace RESPite.StackExchange.Redis; + +internal static partial class RedisCommands +{ + // this is just a "type pun" - it should be an invisible/magic pointer cast to the JIT + public static ref readonly SortedSetCommands SortedSets(this in RespContext context) + => ref Unsafe.As(ref Unsafe.AsRef(in context)); +} + +public readonly struct SortedSetCommands(in RespContext context) +{ + public readonly RespContext Context = context; // important: this is the only field + + public abstract class ZRangeRequest + { + [Flags] + internal enum ModeFlags + { + None = 0, + WithScores = 1, + ByLex = 2, + ByScore = 4, + } + + private void DemandType(Type type, string factory) + { + if (GetType() != type) Throw(factory); + static void Throw(string factory) => throw new InvalidOperationException($"The request for this operation must be created via {factory}"); + } + + /// + /// Indicates whether the data should be reversed. + /// + public bool Reverse { get; set; } + + /// + /// The offset into the sub-range for the matching elements. + /// + public long Offset { get; set; } + + /// + /// The number of elements to return. A netative value returns all elements from the . + /// + public long Count { get; set; } = -1; + + internal void Write( + scoped ReadOnlySpan command, + ref RespWriter writer, + RedisKey source, + RedisKey destination, + ModeFlags flags) + { + bool writeLimit = Offset != 0 || Count >= 0; + ReadOnlySpan by = default; + switch (flags & (ModeFlags.ByLex | ModeFlags.ByScore)) + { + case ModeFlags.ByLex: + DemandType(typeof(ZRangeRequestByLex), nameof(ByLex)); + break; + case ModeFlags.ByScore: + DemandType(typeof(ZRangeRequestByScore), nameof(ByScore)); + break; + default: + by = this switch + { + ZRangeRequestByLex => "$5\r\nBYLEX\r\n"u8, + ZRangeRequestByScore => "$7\r\nBYSCORE\r\n"u8, + _ => default, + }; + break; + } + + bool withScores = (flags & ModeFlags.WithScores) != 0; + int argCount = (by.IsEmpty ? 3 : 4) + + (withScores ? 1 : 0) + + (Reverse ? 1 : 0) + (writeLimit ? 3 : 0) + + (destination.IsNull ? 0 : 1); + writer.WriteCommand(command, argCount); + if (!destination.IsNull) writer.Write(destination); + writer.Write(source); + WriteStartStop(ref writer); + if (!by.IsEmpty) writer.WriteRaw(by); + if (Reverse) writer.WriteRaw("$3\r\nREV\r]\n"u8); + if (writeLimit) + { + writer.WriteRaw("$5\r\nLIMIT\r\n"u8); + writer.WriteBulkString(Offset); + writer.WriteBulkString(Count); + } + if (withScores) writer.WriteRaw("$10\r\nWITHSCORES\r\n"u8); + } + protected abstract void WriteStartStop(ref RespWriter writer); + private protected ZRangeRequest() { } + + public static ZRangeRequest ByRank(long start, long stop) + => new ZRangeRequestByRank(start, stop); + + public static ZRangeRequest ByLex(RedisValue start, RedisValue stop, Exclude exclude) + => new ZRangeRequestByLex(start, stop, exclude); + + public static ZRangeRequest ByScore(double start, double stop, Exclude exclude) + => new ZRangeRequestByScore(start, stop, exclude); + + private sealed class ZRangeRequestByRank(long start, long stop) : ZRangeRequest + { + protected override void WriteStartStop(ref RespWriter writer) + { + writer.WriteBulkString(start); + writer.WriteBulkString(stop); + } + } + private sealed class ZRangeRequestByLex(RedisValue start, RedisValue stop, Exclude exclude) : ZRangeRequest + { + protected override void WriteStartStop(ref RespWriter writer) + { + Write(ref writer, start, exclude, true); + Write(ref writer, stop, exclude, false); + } + } + + internal static void Write(ref RespWriter writer, in RedisValue value, Exclude exclude, bool isStart) + { + bool exclusive = (exclude & (isStart ? Exclude.Start : Exclude.Stop)) != 0; + if (value.IsNull) + { + writer.WriteRaw(isStart ? "$1\r\n-\r\n"u8 : "$1\r\n+\r\n"u8); + } + else + { + var len = value.GetByteCount(); + byte[]? lease = null; + var span = len < 128 ? stackalloc byte[128] : (lease = ArrayPool.Shared.Rent(len)); + span[0] = exclusive ? (byte)'(' : (byte)'['; + value.CopyTo(span.Slice(1)); // allow for the prefix + writer.WriteBulkString(span.Slice(0, len + 1)); + if (lease is not null) ArrayPool.Shared.Return(lease); + } + } + + private sealed class ZRangeRequestByScore(double start, double stop, Exclude exclude) : ZRangeRequest + { + protected override void WriteStartStop(ref RespWriter writer) + { + Write(ref writer, start, exclude, true); + Write(ref writer, stop, exclude, false); + } + } + + internal static void Write(ref RespWriter writer, double value, Exclude exclude, bool isStart) + { + bool exclusive = (exclude & (isStart ? Exclude.Start : Exclude.Stop)) != 0; + if (exclusive) + { + writer.WriteBulkStringExclusive(value); + } + else + { + writer.WriteBulkString(value); + } + } + } +} + +internal static partial class SortedSetCommandsExtensions +{ + [RespCommand] + public static partial RespOperation ZAdd( + this in SortedSetCommands context, + RedisKey key, + RedisValue member, + double score); + + [RespCommand(Formatter = "ZAddFormatter.Instance")] + public static partial RespOperation ZAdd( + this in SortedSetCommands context, + RedisKey key, + SortedSetWhen when, + RedisValue member, + double score); + + [RespCommand(Formatter = "ZAddFormatter.Instance")] + public static RespOperation ZAdd( + this in SortedSetCommands context, + RedisKey key, + SortedSetEntry[] values) => + context.ZAdd(key, SortedSetWhen.Always, values); + + [RespCommand(Formatter = "ZAddFormatter.Instance")] + public static partial RespOperation ZAdd( + this in SortedSetCommands context, + RedisKey key, + SortedSetWhen when, + SortedSetEntry[] values); + + [RespCommand] + public static partial RespOperation ZCard(this in SortedSetCommands context, RedisKey key); + + internal static RespOperation ZCardOrCount( + this in SortedSetCommands context, + RedisKey key, + double min, + double max, + Exclude exclude) + { + if (double.IsNegativeInfinity(min) && double.IsPositiveInfinity(max)) + { + return context.ZCard(key); + } + + return context.ZCount(key, min, max, exclude); + } + + internal static RespOperation Combine( + this in SortedSetCommands context, + SetOperation operation, + RedisKey[] keys, + double[]? weights = null, + Aggregate? aggregate = null) => + operation switch + { + SetOperation.Difference => context.ZDiff(keys), + SetOperation.Intersect => context.ZInter(keys, weights, aggregate), + SetOperation.Union => context.ZUnion(keys, weights, aggregate), + _ => throw new ArgumentOutOfRangeException(nameof(operation)), + }; + + internal static RespOperation Combine( + this in SortedSetCommands context, + SetOperation operation, + RedisKey first, + RedisKey second, + Aggregate? aggregate = null) => + operation switch + { + SetOperation.Difference => context.ZDiff(first, second), + SetOperation.Intersect => context.ZInter(first, second, aggregate), + SetOperation.Union => context.ZUnion(first, second, aggregate), + _ => throw new ArgumentOutOfRangeException(nameof(operation)), + }; + + internal static RespOperation CombineAndStore( + this in SortedSetCommands context, + SetOperation operation, + RedisKey destination, + RedisKey[] keys, + double[]? weights = null, + Aggregate? aggregate = null) => + operation switch + { + SetOperation.Difference => context.ZDiffStore(destination, keys), + SetOperation.Intersect => context.ZInterStore(destination, keys, weights, aggregate), + SetOperation.Union => context.ZUnionStore(destination, keys, weights, aggregate), + _ => throw new ArgumentOutOfRangeException(nameof(operation)), + }; + + internal static RespOperation CombineAndStore( + this in SortedSetCommands context, + SetOperation operation, + RedisKey destination, + RedisKey first, + RedisKey second, + Aggregate? aggregate = null) => + operation switch + { + SetOperation.Difference => context.ZDiffStore(destination, first, second), + SetOperation.Intersect => context.ZInterStore(destination, first, second, aggregate), + SetOperation.Union => context.ZUnionStore(destination, first, second, aggregate), + _ => throw new ArgumentOutOfRangeException(nameof(operation)), + }; + + internal static RespOperation CombineWithScores( + this in SortedSetCommands context, + SetOperation operation, + RedisKey[] keys, + double[]? weights = null, + Aggregate? aggregate = null) => + operation switch + { + SetOperation.Difference => context.ZDiffWithScores(keys), + SetOperation.Intersect => context.ZInterWithScores(keys, weights, aggregate), + SetOperation.Union => context.ZUnionWithScores(keys, weights, aggregate), + _ => throw new ArgumentOutOfRangeException(nameof(operation)), + }; + + internal static RespOperation CombineWithScores( + this in SortedSetCommands context, + SetOperation operation, + RedisKey first, + RedisKey second, + Aggregate? aggregate = null) => + operation switch + { + SetOperation.Difference => context.ZDiffWithScores(first, second), + SetOperation.Intersect => context.ZInterWithScores(first, second, aggregate), + SetOperation.Union => context.ZUnionWithScores(first, second, aggregate), + _ => throw new ArgumentOutOfRangeException(nameof(operation)), + }; + + [RespCommand(Formatter = "ZCountFormatter.Instance")] + public static partial RespOperation ZCount( + this in SortedSetCommands context, + RedisKey key, + double min, + double max, + Exclude exclude = Exclude.None); + + private sealed class ZCountFormatter : IRespFormatter<(RedisKey Key, double Min, double Max, Exclude Exclude)> + { + private ZCountFormatter() { } + public static readonly ZCountFormatter Instance = new(); + + public void Format( + scoped ReadOnlySpan command, + ref RespWriter writer, + in (RedisKey Key, double Min, double Max, Exclude Exclude) request) + { + writer.WriteCommand(command, 3); + writer.Write(request.Key); + SortedSetCommands.ZRangeRequest.Write(ref writer, request.Min, request.Exclude, true); + SortedSetCommands.ZRangeRequest.Write(ref writer, request.Max, request.Exclude, false); + } + } + + [RespCommand] + public static partial RespOperation ZDiff( + this in SortedSetCommands context, + RedisKey[] keys); + + [RespCommand] + public static partial RespOperation ZDiff( + this in SortedSetCommands context, + RedisKey first, + RedisKey second); + + [RespCommand] + public static partial RespOperation ZDiffStore( + this in SortedSetCommands context, + RedisKey destination, + RedisKey[] keys); + + [RespCommand] + public static partial RespOperation ZDiffStore( + this in SortedSetCommands context, + RedisKey destination, + RedisKey first, + RedisKey second); + + [RespCommand(nameof(ZDiff))] + public static partial RespOperation ZDiffWithScores( + this in SortedSetCommands context, + [RespSuffix("WITHSCORES")] RedisKey[] keys); + + [RespCommand(nameof(ZDiff))] + public static partial RespOperation ZDiffWithScores( + this in SortedSetCommands context, + RedisKey first, + [RespSuffix("WITHSCORES")] RedisKey second); + + [RespCommand] + public static partial RespOperation ZIncrBy( + this in SortedSetCommands context, + RedisKey key, + RedisValue member, + double increment); + + [RespCommand] + public static partial RespOperation ZInter( + this in SortedSetCommands context, + RedisKey[] keys, + [RespPrefix("WEIGHTS")] double[]? weights = null, + [RespPrefix("AGGREGATE")] Aggregate? aggregate = null); + + [RespCommand] + public static partial RespOperation ZInter( + this in SortedSetCommands context, + RedisKey first, + RedisKey second, + [RespPrefix("AGGREGATE")] Aggregate? aggregate = null); + + [RespCommand] + public static partial RespOperation ZInterCard( + this in SortedSetCommands context, + [RespPrefix] RedisKey[] keys, + [RespPrefix("LIMIT"), RespIgnore(0)] long limit = 0); + + [RespCommand] + public static partial RespOperation ZInterCard( + this in SortedSetCommands context, + [RespPrefix("2")] RedisKey first, + RedisKey second, + [RespPrefix("LIMIT"), RespIgnore(0)] long limit = 0); + + [RespCommand] + public static partial RespOperation ZInterStore( + this in SortedSetCommands context, + RedisKey destination, + RedisKey[] keys, + [RespPrefix("WEIGHTS")] double[]? weights = null, + [RespPrefix("AGGREGATE")] Aggregate? aggregate = null); + + [RespCommand] + public static partial RespOperation ZInterStore( + this in SortedSetCommands context, + RedisKey destination, + RedisKey first, + RedisKey second, + [RespPrefix("AGGREGATE")] Aggregate? aggregate = null); + + [RespCommand(nameof(ZInter))] + public static partial RespOperation ZInterWithScores( + this in SortedSetCommands context, + [RespSuffix("WITHSCORES")] RedisKey[] keys, + [RespPrefix("WEIGHTS")] double[]? weights = null, + [RespPrefix("AGGREGATE")] Aggregate? aggregate = null); + + [RespCommand(nameof(ZInter))] + public static partial RespOperation ZInterWithScores( + this in SortedSetCommands context, + RedisKey first, + [RespSuffix("WITHSCORES")] RedisKey second, + [RespPrefix("AGGREGATE")] Aggregate? aggregate = null); + + [RespCommand(Formatter = "ZLexCountFormatter.Instance")] + public static partial RespOperation ZLexCount( + this in SortedSetCommands context, + RedisKey key, + RedisValue min, + RedisValue max, + Exclude exclude = Exclude.None); + + private sealed class ZLexCountFormatter : IRespFormatter<(RedisKey Key, RedisValue Min, RedisValue Max, Exclude Exclude)> + { + private ZLexCountFormatter() { } + public static readonly ZLexCountFormatter Instance = new(); + + public void Format( + scoped ReadOnlySpan command, + ref RespWriter writer, + in (RedisKey Key, RedisValue Min, RedisValue Max, Exclude Exclude) request) + { + writer.WriteCommand(command, 3); + writer.Write(request.Key); + SortedSetCommands.ZRangeRequest.Write(ref writer, request.Min, request.Exclude, true); + SortedSetCommands.ZRangeRequest.Write(ref writer, request.Max, request.Exclude, false); + } + } + + [RespCommand(nameof(ZMPop))] + private static partial RespOperation ZMPopMax( + this in SortedSetCommands context, + [RespPrefix, RespSuffix("MAX")] RedisKey[] keys, + [RespIgnore(1), RespPrefix("COUNT")] long count); + + [RespCommand(nameof(ZMPop))] + private static partial RespOperation ZMPopMin( + this in SortedSetCommands context, + [RespPrefix, RespSuffix("MIN")] RedisKey[] keys, + [RespIgnore(1), RespPrefix("COUNT")] long count); + + public static RespOperation ZMPop( + this in SortedSetCommands context, + RedisKey[] keys, + Order order = Order.Ascending, + long count = 1) + => order == Order.Ascending ? context.ZMPopMin(keys, count) : context.ZMPopMax(keys, count); + + internal static RespOperation ZPop( + this in SortedSetCommands context, + RedisKey key, + Order order) => + order == Order.Ascending + ? context.ZPopMin(key) + : context.ZPopMax(key); + + internal static RespOperation ZPop( + this in SortedSetCommands context, + RedisKey key, + long count, + Order order) => + order == Order.Ascending + ? context.ZPopMin(key, count) + : context.ZPopMax(key, count); + + [RespCommand] + public static partial RespOperation ZPopMax(this in SortedSetCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation ZPopMax( + this in SortedSetCommands context, + RedisKey key, + long count); + + [RespCommand] + public static partial RespOperation ZPopMin(this in SortedSetCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation ZPopMin( + this in SortedSetCommands context, + RedisKey key, + long count); + + [RespCommand] + public static partial RespOperation ZRandMember(this in SortedSetCommands context, RedisKey key); + + [RespCommand] + public static partial RespOperation ZRandMember( + this in SortedSetCommands context, + RedisKey key, + long count); + + [RespCommand(nameof(ZRandMember))] + public static partial RespOperation ZRandMemberWithScores( + this in SortedSetCommands context, + RedisKey key, + [RespSuffix("WITHSCORES")] long count); + + private sealed class ZRangeFormatter : IRespFormatter<(RedisKey Key, SortedSetCommands.ZRangeRequest Request)>, + IRespFormatter<(RedisKey Destination, RedisKey Source, SortedSetCommands.ZRangeRequest Request)> + { + private readonly SortedSetCommands.ZRangeRequest.ModeFlags _flags; + private ZRangeFormatter(SortedSetCommands.ZRangeRequest.ModeFlags flags) => _flags = flags; + public static readonly ZRangeFormatter NoScores = new(SortedSetCommands.ZRangeRequest.ModeFlags.None); + public static readonly ZRangeFormatter WithScores = new(SortedSetCommands.ZRangeRequest.ModeFlags.WithScores); + public static readonly ZRangeFormatter ByLexNoScores = new(SortedSetCommands.ZRangeRequest.ModeFlags.ByLex); + public static readonly ZRangeFormatter ByLexWithScores = new(SortedSetCommands.ZRangeRequest.ModeFlags.WithScores | SortedSetCommands.ZRangeRequest.ModeFlags.ByLex); + public static readonly ZRangeFormatter ByScoreNoScores = new(SortedSetCommands.ZRangeRequest.ModeFlags.ByScore); + public static readonly ZRangeFormatter ByScoreWithScores = new(SortedSetCommands.ZRangeRequest.ModeFlags.WithScores | SortedSetCommands.ZRangeRequest.ModeFlags.ByScore); + + public void Format( + scoped ReadOnlySpan command, + ref RespWriter writer, + in (RedisKey Key, SortedSetCommands.ZRangeRequest Request) request) + => request.Request.Write(command, ref writer, RedisKey.Null, request.Key, _flags); + public void Format( + scoped ReadOnlySpan command, + ref RespWriter writer, + in (RedisKey Destination, RedisKey Source, SortedSetCommands.ZRangeRequest Request) request) + => request.Request.Write(command, ref writer, request.Destination, request.Source, _flags); + } + + [RespCommand] // by rank + public static partial RespOperation ZRange( + this in SortedSetCommands context, + RedisKey key, + long min, + long max); + + [RespCommand(Formatter = "ZRangeFormatter.NoScores")] // flexible + public static partial RespOperation ZRange( + this in SortedSetCommands context, + RedisKey key, + SortedSetCommands.ZRangeRequest request); + + internal static RespOperation ZRange( + this in SortedSetCommands context, + RedisKey key, + long min, + long max, + Order order) => order == Order.Ascending ? context.ZRange(key, min, max) : context.ZRevRange(key, max, min); + + [RespCommand(nameof(ZRange))] // by rank, with scores + public static partial RespOperation ZRangeWithScores( + this in SortedSetCommands context, + RedisKey key, + long min, + [RespSuffix("WITHSCORES")] long max); + + internal static RespOperation ZRangeWithScores( + this in SortedSetCommands context, + RedisKey key, + long min, + long max, + Order order) => order == Order.Ascending ? context.ZRangeWithScores(key, min, max) : context.ZRevRangeWithScores(key, max, min); + + [RespCommand(nameof(ZRange), Formatter = "ZRangeFormatter.WithScores")] // flexible, with scores + public static partial RespOperation ZRangeWithScores( + this in SortedSetCommands context, + RedisKey key, + SortedSetCommands.ZRangeRequest request); + + [RespCommand(Formatter = "ZRangeFormatter.ByLexNoScores")] + public static partial RespOperation ZRangeByLex( + this in SortedSetCommands context, + RedisKey key, + SortedSetCommands.ZRangeRequest request); + + internal static RespOperation ZRangeByLex( + this in SortedSetCommands context, + RedisKey key, + SortedSetCommands.ZRangeRequest request, + Order order) => order == Order.Ascending ? context.ZRangeByLex(key, request) : context.ZRevRangeByLex(key, request); + + [RespCommand(nameof(ZRangeByLex), Formatter = "ZRangeFormatter.ByLexWithScores")] + public static partial RespOperation ZRangeByLexWithScores( + this in SortedSetCommands context, + RedisKey key, + SortedSetCommands.ZRangeRequest request); + + [RespCommand(Formatter = "ZRangeFormatter.ByScoreNoScores")] + public static partial RespOperation ZRangeByScore( + this in SortedSetCommands context, + RedisKey key, + SortedSetCommands.ZRangeRequest request); + + internal static RespOperation ZRangeByScore( + this in SortedSetCommands context, + RedisKey key, + SortedSetCommands.ZRangeRequest request, + Order order) => order == Order.Ascending ? context.ZRangeByScore(key, request) : context.ZRevRangeByScore(key, request); + + [RespCommand(nameof(ZRangeByScore), Formatter = "ZRangeFormatter.ByScoreWithScores")] + public static partial RespOperation ZRangeByScoreWithScores( + this in SortedSetCommands context, + RedisKey key, + SortedSetCommands.ZRangeRequest request); + + internal static RespOperation ZRangeByScoreWithScores( + this in SortedSetCommands context, + RedisKey key, + SortedSetCommands.ZRangeRequest request, + Order order) => order == Order.Ascending ? context.ZRangeByScoreWithScores(key, request) : context.ZRevRangeByScoreWithScores(key, request); + + [RespCommand] // by rank + public static partial RespOperation ZRangeStore( + this in SortedSetCommands context, + RedisKey destination, + RedisKey source, + long min, + long max); + + [RespCommand(Formatter = "ZRangeFormatter.NoScores")] // flexible + public static partial RespOperation ZRangeStore( + this in SortedSetCommands context, + RedisKey destination, + RedisKey source, + SortedSetCommands.ZRangeRequest request); + + internal static RespOperation ZRangeStore( + this in SortedSetCommands context, + RedisKey sourceKey, + RedisKey destinationKey, + RedisValue start, + RedisValue stop, + SortedSetOrder sortedSetOrder, + Exclude exclude, + Order order, + long skip, + long? take) + { + SortedSetCommands.ZRangeRequest request = + sortedSetOrder switch + { + SortedSetOrder.ByRank => SortedSetCommands.ZRangeRequest.ByRank((long)start, (long)stop), + SortedSetOrder.ByLex => SortedSetCommands.ZRangeRequest.ByLex(start, stop, exclude), + SortedSetOrder.ByScore => SortedSetCommands.ZRangeRequest.ByScore((double)start, (double)stop, exclude), + _ => throw new ArgumentOutOfRangeException(nameof(sortedSetOrder)), + }; + request.Offset = skip; + if (take is not null) request.Count = take.Value; + request.Reverse = order == Order.Descending; + return context.ZRangeStore(destinationKey, sourceKey, request); + } + + internal static RespOperation ZRank( + this in SortedSetCommands context, + RedisKey key, + RedisValue member, + Order order) => + order == Order.Ascending + ? context.ZRank(key, member) + : context.ZRevRank(key, member); + + [RespCommand] + public static partial RespOperation ZRank( + this in SortedSetCommands context, + RedisKey key, + RedisValue member); + + [RespCommand] + public static partial RespOperation ZRem( + this in SortedSetCommands context, + RedisKey key, + RedisValue member); + + [RespCommand] + public static partial RespOperation ZRem( + this in SortedSetCommands context, + RedisKey key, + RedisValue[] members); + + [RespCommand(Formatter = "ZRangeFormatter.ByLexNoScores")] + public static partial RespOperation ZRemRangeByLex( + this in SortedSetCommands context, + RedisKey key, + SortedSetCommands.ZRangeRequest request); + + [RespCommand] + public static partial RespOperation ZRemRangeByRank( + this in SortedSetCommands context, + RedisKey key, + long start, + long stop); + + [RespCommand(Formatter = "ZRangeFormatter.ByScoreNoScores")] + public static partial RespOperation ZRemRangeByScore( + this in SortedSetCommands context, + RedisKey key, + SortedSetCommands.ZRangeRequest request); + + [RespCommand] + public static partial RespOperation ZRevRange( + this in SortedSetCommands context, + RedisKey key, + long start, + long stop); + + [RespCommand(nameof(ZRevRange))] + public static partial RespOperation ZRevRangeWithScores( + this in SortedSetCommands context, + RedisKey key, + long start, + [RespSuffix("WITHSCORES")] long stop); + + [RespCommand(Formatter = "ZRangeFormatter.ByLexNoScores")] + public static partial RespOperation ZRevRangeByLex( + this in SortedSetCommands context, + RedisKey key, + SortedSetCommands.ZRangeRequest request); + + [RespCommand(nameof(ZRevRangeByLex), Formatter = "ZRangeFormatter.ByLexWithScores")] + public static partial RespOperation ZRevRangeByLexWithScores( + this in SortedSetCommands context, + RedisKey key, + SortedSetCommands.ZRangeRequest request); + + [RespCommand(Formatter = "ZRangeFormatter.ByScoreNoScores")] + public static partial RespOperation ZRevRangeByScore( + this in SortedSetCommands context, + RedisKey key, + SortedSetCommands.ZRangeRequest request); + + [RespCommand(Formatter = "ZRangeFormatter.ByScoreWithScores")] + public static partial RespOperation ZRevRangeByScoreWithScores( + this in SortedSetCommands context, + RedisKey key, + SortedSetCommands.ZRangeRequest request); + + [RespCommand] + public static partial RespOperation ZRevRank( + this in SortedSetCommands context, + RedisKey key, + RedisValue member); + + [RespCommand(Parser = "RespParsers.ZScanSimple")] + public static partial RespOperation> ZScan( + this in SortedSetCommands context, + RedisKey key, + long cursor, + [RespPrefix("MATCH"), RespIgnore] RedisValue pattern = default, + [RespPrefix("COUNT"), RespIgnore(10)] long count = 10); + + [RespCommand] + public static partial RespOperation ZScore( + this in SortedSetCommands context, + RedisKey key, + RedisValue member); + + [RespCommand] + public static partial RespOperation ZScore( + this in SortedSetCommands context, + RedisKey key, + RedisValue[] members); + + [RespCommand] + public static partial RespOperation ZUnion( + this in SortedSetCommands context, + RedisKey[] keys, + [RespPrefix("WEIGHTS")] double[]? weights = null, + [RespPrefix("AGGREGATE")] Aggregate? aggregate = null); + + [RespCommand] + public static partial RespOperation ZUnion( + this in SortedSetCommands context, + RedisKey first, + RedisKey second, + [RespPrefix("AGGREGATE")] Aggregate? aggregate = null); + + [RespCommand] + public static partial RespOperation ZUnionStore( + this in SortedSetCommands context, + RedisKey destination, + RedisKey[] keys, + [RespPrefix("WEIGHTS")] double[]? weights = null, + [RespPrefix("AGGREGATE")] Aggregate? aggregate = null); + + [RespCommand] + public static partial RespOperation ZUnionStore( + this in SortedSetCommands context, + RedisKey destination, + RedisKey first, + RedisKey second, + [RespPrefix("AGGREGATE")] Aggregate? aggregate = null); + + [RespCommand(nameof(ZUnion))] + public static partial RespOperation ZUnionWithScores( + this in SortedSetCommands context, + [RespSuffix("WITHSCORES")] RedisKey[] keys, + [RespPrefix("WEIGHTS")] double[]? weights = null, + [RespPrefix("AGGREGATE")] Aggregate? aggregate = null); + + [RespCommand(nameof(ZUnion))] + public static partial RespOperation ZUnionWithScores( + this in SortedSetCommands context, + RedisKey first, + [RespSuffix("WITHSCORES")] RedisKey second, + [RespPrefix("AGGREGATE")] Aggregate? aggregate = null); + + private sealed class ZAddFormatter : + IRespFormatter<(RedisKey Key, SortedSetWhen When, RedisValue Member, double Score)>, + IRespFormatter<(RedisKey Key, SortedSetWhen When, SortedSetEntry[] Values)> + { + private ZAddFormatter() { } + public static readonly ZAddFormatter Instance = new(); + + public void Format( + scoped ReadOnlySpan command, + ref RespWriter writer, + in (RedisKey Key, SortedSetWhen When, RedisValue Member, double Score) request) + { + var argCount = 3 + GetWhenFlagCount(request.When); + writer.WriteCommand(command, argCount); + writer.Write(request.Key); + WriteWhenFlags(ref writer, request.When); + writer.WriteBulkString(request.Score); + writer.Write(request.Member); + } + + public void Format( + scoped ReadOnlySpan command, + ref RespWriter writer, + in (RedisKey Key, SortedSetWhen When, SortedSetEntry[] Values) request) + { + var argCount = 1 + GetWhenFlagCount(request.When) + (request.Values.Length * 2); + writer.WriteCommand(command, argCount); + writer.Write(request.Key); + WriteWhenFlags(ref writer, request.When); + foreach (var entry in request.Values) + { + writer.WriteBulkString(entry.Score); + writer.Write(entry.Element); + } + } + + private static int GetWhenFlagCount(SortedSetWhen when) + { + when &= SortedSetWhen.NotExists | SortedSetWhen.Exists | SortedSetWhen.GreaterThan | SortedSetWhen.LessThan; + return (int)when.CountBits(); + } + + private static void WriteWhenFlags(ref RespWriter writer, SortedSetWhen when) + { + if ((when & SortedSetWhen.NotExists) != 0) + writer.WriteBulkString("NX"u8); + if ((when & SortedSetWhen.Exists) != 0) + writer.WriteBulkString("XX"u8); + if ((when & SortedSetWhen.GreaterThan) != 0) + writer.WriteBulkString("GT"u8); + if ((when & SortedSetWhen.LessThan) != 0) + writer.WriteBulkString("LT"u8); + } + } +} diff --git a/src/RESPite.StackExchange.Redis/RedisCommands.StringCommands.cs b/src/RESPite.StackExchange.Redis/RedisCommands.StringCommands.cs new file mode 100644 index 000000000..00d1cfdf4 --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RedisCommands.StringCommands.cs @@ -0,0 +1,22 @@ +using System.Runtime.CompilerServices; +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +internal static partial class RedisCommands +{ + // this is just a "type pun" - it should be an invisible/magic pointer cast to the JIT + public static ref readonly StringCommands Strings(this in RespContext context) + => ref Unsafe.As(ref Unsafe.AsRef(in context)); +} + +internal readonly struct StringCommands(in RespContext context) +{ + public readonly RespContext Context = context; // important: this is the only field +} + +internal static partial class StringCommandsExtensions +{ + [RespCommand("get")] + public static partial RespOperation Get(this in StringCommands context, RedisKey key); +} diff --git a/src/RESPite.StackExchange.Redis/RedisCommands.cs b/src/RESPite.StackExchange.Redis/RedisCommands.cs new file mode 100644 index 000000000..35c604715 --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RedisCommands.cs @@ -0,0 +1,41 @@ +using System.Buffers; +using System.Runtime.CompilerServices; + +namespace RESPite.StackExchange.Redis; + +internal static partial class RedisCommands +{ + public static ref readonly RespContext Self(this in RespContext context) + => ref context; // this just proves that the above are well-defined in terms of escape analysis +} + +public readonly struct ScanResult +{ + private const int MSB = 1 << 31; + private readonly int _countAndIsPooled; // and use MSB for "ispooled" + private readonly T[] values; + + public ScanResult(long cursor, T[] values) + { + Cursor = cursor; + this.values = values; + _countAndIsPooled = values.Length; + } + internal ScanResult(long cursor, T[] values, int count) + { + this.Cursor = cursor; + this.values = values; + _countAndIsPooled = count | MSB; + } + + public long Cursor { get; } + public ReadOnlySpan Values => new(values, 0, _countAndIsPooled & ~MSB); + + internal void UnsafeRecycle() + { + var arr = values; + bool recycle = (_countAndIsPooled & MSB) != 0; + Unsafe.AsRef(in this) = default; // best effort at salting the earth + if (recycle && arr is not null) ArrayPool.Shared.Return(arr); + } +} diff --git a/src/RESPite.StackExchange.Redis/RespContextBatch.cs b/src/RESPite.StackExchange.Redis/RespContextBatch.cs new file mode 100644 index 000000000..d6e80cc6f --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespContextBatch.cs @@ -0,0 +1,21 @@ +using RESPite.Connections; +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +internal sealed class RespContextBatch : RespContextDatabase, IBatch, IDisposable, IRespContextSource +{ + private readonly RespBatch _batch; + + public RespContextBatch(IConnectionMultiplexer muxer, IRespContextSource source, int db) : base(muxer, source, db) + { + _batch = source.Context.CreateBatch(); + SetSource(this); + } + + void IBatch.Execute() => _batch.Flush(); + + public void Dispose() => _batch.Dispose(); + + public ref readonly RespContext Context => ref _batch.Context; +} diff --git a/src/RESPite.StackExchange.Redis/RespContextDatabase.Connection.cs b/src/RESPite.StackExchange.Redis/RespContextDatabase.Connection.cs new file mode 100644 index 000000000..f9ed5bf6d --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespContextDatabase.Connection.cs @@ -0,0 +1,67 @@ +using System.Net; +using RESPite.Messages; +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +internal partial class RespContextDatabase +{ + // Connection and core methods + public bool IsConnected(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + internal static readonly byte[] PingRaw = "*1\r\n$4\r\nping\r\n"u8.ToArray(); + + public Task PingAsync(CommandFlags flags = CommandFlags.None) => + Context(flags).Send("ping"u8, DateTime.UtcNow, PingParser.Default, PingRaw).AsTask(); + + public TimeSpan Ping(CommandFlags flags = CommandFlags.None) => + Context(flags).Send("ping"u8, DateTime.UtcNow, PingParser.Default, PingRaw).Wait(SyncTimeout); + + internal sealed class PingParser : IRespParser + { + public static readonly PingParser Default = new(); + private PingParser() { } + public TimeSpan Parse(in DateTime state, ref RespReader reader) => DateTime.UtcNow - state; + } + + public Task IdentifyEndpointAsync(RedisKey key = default, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public EndPoint? IdentifyEndpoint(RedisKey key = default, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public IBatch CreateBatch(object? asyncState = null) + { + if (asyncState is not null) throw new NotSupportedException($"{nameof(asyncState)} is not supported"); + return new RespContextBatch(_muxer, _source, _db); + } + + public ITransaction CreateTransaction(object? asyncState = null) + { + throw new NotImplementedException(); + } + + // Key migration + public Task KeyMigrateAsync( + RedisKey key, + EndPoint toServer, + int toDatabase = 0, + int timeoutMilliseconds = 0, + MigrateOptions migrateOptions = MigrateOptions.None, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public void KeyMigrate( + RedisKey key, + EndPoint toServer, + int toDatabase = 0, + int timeoutMilliseconds = 0, + MigrateOptions migrateOptions = MigrateOptions.None, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + // Debug + [RespCommand("debug")] + public partial RedisValue DebugObject([RespPrefix("object")] RedisKey key, CommandFlags flags = CommandFlags.None); +} diff --git a/src/RESPite.StackExchange.Redis/RespContextDatabase.Geo.cs b/src/RESPite.StackExchange.Redis/RespContextDatabase.Geo.cs new file mode 100644 index 000000000..87d28ab47 --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespContextDatabase.Geo.cs @@ -0,0 +1,224 @@ +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +internal partial class RespContextDatabase +{ + // Async Geo methods + public Task GeoAddAsync( + RedisKey key, + double longitude, + double latitude, + RedisValue member, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task GeoAddAsync(RedisKey key, GeoEntry value, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task GeoAddAsync(RedisKey key, GeoEntry[] values, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task GeoRemoveAsync(RedisKey key, RedisValue member, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task GeoDistanceAsync( + RedisKey key, + RedisValue member1, + RedisValue member2, + GeoUnit unit = GeoUnit.Meters, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task GeoHashAsync(RedisKey key, RedisValue[] members, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task GeoHashAsync(RedisKey key, RedisValue member, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task GeoPositionAsync( + RedisKey key, + RedisValue member, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task GeoRadiusAsync( + RedisKey key, + RedisValue member, + double radius, + GeoUnit unit = GeoUnit.Meters, + int count = -1, + Order? order = null, + GeoRadiusOptions options = GeoRadiusOptions.Default, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task GeoRadiusAsync( + RedisKey key, + double longitude, + double latitude, + double radius, + GeoUnit unit = GeoUnit.Meters, + int count = -1, + Order? order = null, + GeoRadiusOptions options = GeoRadiusOptions.Default, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task GeoSearchAsync( + RedisKey key, + RedisValue member, + GeoSearchShape shape, + int count = -1, + bool demandClosest = true, + Order? order = null, + GeoRadiusOptions options = GeoRadiusOptions.Default, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task GeoSearchAsync( + RedisKey key, + double longitude, + double latitude, + GeoSearchShape shape, + int count = -1, + bool demandClosest = true, + Order? order = null, + GeoRadiusOptions options = GeoRadiusOptions.Default, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task GeoSearchAndStoreAsync( + RedisKey sourceKey, + RedisKey destinationKey, + RedisValue member, + GeoSearchShape shape, + int count = -1, + bool demandClosest = true, + Order? order = null, + bool storeDistances = false, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task GeoSearchAndStoreAsync( + RedisKey sourceKey, + RedisKey destinationKey, + double longitude, + double latitude, + GeoSearchShape shape, + int count = -1, + bool demandClosest = true, + Order? order = null, + bool storeDistances = false, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + // Synchronous Geo methods + public bool GeoAdd( + RedisKey key, + double longitude, + double latitude, + RedisValue member, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public bool GeoAdd(RedisKey key, GeoEntry value, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long GeoAdd(RedisKey key, GeoEntry[] values, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public bool GeoRemove(RedisKey key, RedisValue member, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public double? GeoDistance( + RedisKey key, + RedisValue member1, + RedisValue member2, + GeoUnit unit = GeoUnit.Meters, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public string?[] GeoHash(RedisKey key, RedisValue[] members, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public string? GeoHash(RedisKey key, RedisValue member, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + [RespCommand("geopos")] + public partial GeoPosition?[] GeoPosition(RedisKey key, RedisValue[] members, CommandFlags flags = CommandFlags.None); + + public GeoPosition? GeoPosition(RedisKey key, RedisValue member, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public GeoRadiusResult[] GeoRadius( + RedisKey key, + RedisValue member, + double radius, + GeoUnit unit = GeoUnit.Meters, + int count = -1, + Order? order = null, + GeoRadiusOptions options = GeoRadiusOptions.Default, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public GeoRadiusResult[] GeoRadius( + RedisKey key, + double longitude, + double latitude, + double radius, + GeoUnit unit = GeoUnit.Meters, + int count = -1, + Order? order = null, + GeoRadiusOptions options = GeoRadiusOptions.Default, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public GeoRadiusResult[] GeoSearch( + RedisKey key, + RedisValue member, + GeoSearchShape shape, + int count = -1, + bool demandClosest = true, + Order? order = null, + GeoRadiusOptions options = GeoRadiusOptions.Default, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public GeoRadiusResult[] GeoSearch( + RedisKey key, + double longitude, + double latitude, + GeoSearchShape shape, + int count = -1, + bool demandClosest = true, + Order? order = null, + GeoRadiusOptions options = GeoRadiusOptions.Default, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long GeoSearchAndStore( + RedisKey sourceKey, + RedisKey destinationKey, + RedisValue member, + GeoSearchShape shape, + int count = -1, + bool demandClosest = true, + Order? order = null, + bool storeDistances = false, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long GeoSearchAndStore( + RedisKey sourceKey, + RedisKey destinationKey, + double longitude, + double latitude, + GeoSearchShape shape, + int count = -1, + bool demandClosest = true, + Order? order = null, + bool storeDistances = false, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); +} diff --git a/src/RESPite.StackExchange.Redis/RespContextDatabase.Hash.cs b/src/RESPite.StackExchange.Redis/RespContextDatabase.Hash.cs new file mode 100644 index 000000000..d721e9a8a --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespContextDatabase.Hash.cs @@ -0,0 +1,526 @@ +using RESPite.Messages; +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +internal partial class RespContextDatabase +{ + public long HashDecrement( + RedisKey key, + RedisValue hashField, + long value = 1, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HIncrBy(key, hashField, -value).Wait(SyncTimeout); + + public double HashDecrement( + RedisKey key, + RedisValue hashField, + double value, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HIncrByFloat(key, hashField, -value).Wait(SyncTimeout); + + public Task HashDecrementAsync( + RedisKey key, + RedisValue hashField, + long value = 1, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HIncrBy(key, hashField, -value).AsTask(); + + public Task HashDecrementAsync( + RedisKey key, + RedisValue hashField, + double value, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HIncrByFloat(key, hashField, -value).AsTask(); + + public bool HashDelete( + RedisKey key, + RedisValue hashFields, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HDel(key, hashFields).Wait(SyncTimeout); + + public long HashDelete( + RedisKey key, + RedisValue[] hashFields, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HDel(key, hashFields).Wait(SyncTimeout); + + public Task HashDeleteAsync( + RedisKey key, + RedisValue hashField, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HDel(key, hashField).AsTask(); + + public Task HashDeleteAsync( + RedisKey key, + RedisValue[] hashFields, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HDel(key, hashFields).AsTask(); + + public bool HashExists(RedisKey key, RedisValue hashField, CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HExists(key, hashField).Wait(SyncTimeout); + + public Task HashExistsAsync( + RedisKey key, + RedisValue hashField, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HExists(key, hashField).AsTask(); + + public ExpireResult[] HashFieldExpire( + RedisKey key, + RedisValue[] hashFields, + TimeSpan expiry, + ExpireWhen when = ExpireWhen.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HExpire(key, expiry, when, hashFields).Wait(SyncTimeout); + + public ExpireResult[] HashFieldExpire( + RedisKey key, + RedisValue[] hashFields, + DateTime expiry, + ExpireWhen when = ExpireWhen.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HExpireAt(key, expiry, when, hashFields).Wait(SyncTimeout); + + public Task HashFieldExpireAsync( + RedisKey key, + RedisValue[] hashFields, + TimeSpan expiry, + ExpireWhen when = ExpireWhen.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HExpire(key, expiry, when, hashFields).AsTask(); + + public Task HashFieldExpireAsync( + RedisKey key, + RedisValue[] hashFields, + DateTime expiry, + ExpireWhen when = ExpireWhen.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HExpireAt(key, expiry, when, hashFields).AsTask(); + + public RedisValue HashFieldGetAndDelete( + RedisKey key, + RedisValue hashField, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetDel(key, hashField).Wait(SyncTimeout); + + public RedisValue[] HashFieldGetAndDelete( + RedisKey key, + RedisValue[] hashFields, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetDel(key, hashFields).Wait(SyncTimeout); + + public Task HashFieldGetAndDeleteAsync( + RedisKey key, + RedisValue hashField, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetDel(key, hashField).AsTask(); + + public Task HashFieldGetAndDeleteAsync( + RedisKey key, + RedisValue[] hashFields, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetDel(key, hashFields).AsTask(); + + public RedisValue HashFieldGetAndSetExpiry( + RedisKey key, + RedisValue hashField, + TimeSpan? expiry = null, + bool persist = false, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetEx(key, hashField, expiry, persist).Wait(SyncTimeout); + + public RedisValue HashFieldGetAndSetExpiry( + RedisKey key, + RedisValue hashField, + DateTime expiry, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetEx(key, expiry, hashField).Wait(SyncTimeout); + + public RedisValue[] HashFieldGetAndSetExpiry( + RedisKey key, + RedisValue[] hashFields, + TimeSpan? expiry = null, + bool persist = false, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetEx(key, hashFields, expiry, persist).Wait(SyncTimeout); + + public RedisValue[] HashFieldGetAndSetExpiry( + RedisKey key, + RedisValue[] hashFields, + DateTime expiry, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetEx(key, expiry, hashFields).Wait(SyncTimeout); + + public Task HashFieldGetAndSetExpiryAsync( + RedisKey key, + RedisValue hashField, + TimeSpan? expiry = null, + bool persist = false, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetEx(key, hashField, expiry, persist).AsTask(); + + public Task HashFieldGetAndSetExpiryAsync( + RedisKey key, + RedisValue hashField, + DateTime expiry, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetEx(key, expiry, hashField).AsTask(); + + public Task HashFieldGetAndSetExpiryAsync( + RedisKey key, + RedisValue[] hashFields, + TimeSpan? expiry = null, + bool persist = false, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetEx(key, hashFields, expiry, persist).AsTask(); + + public Task HashFieldGetAndSetExpiryAsync( + RedisKey key, + RedisValue[] hashFields, + DateTime expiry, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetEx(key, expiry, hashFields).AsTask(); + + public long[] HashFieldGetExpireDateTime( + RedisKey key, + RedisValue[] hashFields, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HPExpireTimeRaw(key, hashFields).Wait(SyncTimeout); + + public Task HashFieldGetExpireDateTimeAsync( + RedisKey key, + RedisValue[] hashFields, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HPExpireTimeRaw(key, hashFields).AsTask(); + + public Lease? HashFieldGetLeaseAndDelete( + RedisKey key, + RedisValue hashField, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetDelLease(key, hashField).Wait(SyncTimeout); + + public Task?> HashFieldGetLeaseAndDeleteAsync( + RedisKey key, + RedisValue hashField, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetDelLease(key, hashField).AsTask(); + + public Lease? HashFieldGetLeaseAndSetExpiry( + RedisKey key, + RedisValue hashField, + TimeSpan? expiry = null, + bool persist = false, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetExLease(key, hashField, expiry, persist).Wait(SyncTimeout); + + public Lease? HashFieldGetLeaseAndSetExpiry( + RedisKey key, + RedisValue hashField, + DateTime expiry, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetExLease(key, expiry, hashField).Wait(SyncTimeout); + + public Task?> HashFieldGetLeaseAndSetExpiryAsync( + RedisKey key, + RedisValue hashField, + TimeSpan? expiry = null, + bool persist = false, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetExLease(key, hashField, expiry, persist).AsTask(); + + public Task?> HashFieldGetLeaseAndSetExpiryAsync( + RedisKey key, + RedisValue hashField, + DateTime expiry, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetExLease(key, expiry, hashField).AsTask(); + + public long[] HashFieldGetTimeToLive( + RedisKey key, + RedisValue[] hashFields, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HTtlRaw(key, hashFields).Wait(SyncTimeout); + + public Task HashFieldGetTimeToLiveAsync( + RedisKey key, + RedisValue[] hashFields, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HTtlRaw(key, hashFields).AsTask(); + + public PersistResult[] HashFieldPersist( + RedisKey key, + RedisValue[] hashFields, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HPersist(key, hashFields).Wait(SyncTimeout); + + public Task HashFieldPersistAsync( + RedisKey key, + RedisValue[] hashFields, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HPersist(key, hashFields).AsTask(); + + public RedisValue HashFieldSetAndSetExpiry( + RedisKey key, + RedisValue hashField, + RedisValue value, + TimeSpan? expiry = null, + bool keepTtl = false, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HSetExLegacy(key, expiry, hashField, value, when, keepTtl).Wait(SyncTimeout); + + public RedisValue HashFieldSetAndSetExpiry( + RedisKey key, + RedisValue hashField, + RedisValue value, + DateTime expiry, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HSetExLegacy(key, expiry, hashField, value, when).Wait(SyncTimeout); + + public RedisValue HashFieldSetAndSetExpiry( + RedisKey key, + HashEntry[] hashFields, + TimeSpan? expiry = null, + bool keepTtl = false, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HSetExLegacy(key, expiry, hashFields, when, keepTtl).Wait(SyncTimeout); + + public RedisValue HashFieldSetAndSetExpiry( + RedisKey key, + HashEntry[] hashFields, + DateTime expiry, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HSetExLegacy(key, expiry, hashFields, when).Wait(SyncTimeout); + + public Task HashFieldSetAndSetExpiryAsync( + RedisKey key, + RedisValue hashField, + RedisValue value, + TimeSpan? expiry = null, + bool keepTtl = false, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HSetExLegacy(key, expiry, hashField, value, when, keepTtl).AsTask(); + + public Task HashFieldSetAndSetExpiryAsync( + RedisKey key, + RedisValue hashField, + RedisValue value, + DateTime expiry, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HSetExLegacy(key, expiry, hashField, value, when).AsTask(); + + public Task HashFieldSetAndSetExpiryAsync( + RedisKey key, + HashEntry[] hashFields, + TimeSpan? expiry = null, + bool keepTtl = false, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HSetExLegacy(key, expiry, hashFields, when, keepTtl).AsTask(); + + public Task HashFieldSetAndSetExpiryAsync( + RedisKey key, + HashEntry[] hashFields, + DateTime expiry, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HSetExLegacy(key, expiry, hashFields, when).AsTask(); + + public RedisValue HashGet(RedisKey key, RedisValue hashField, CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGet(key, hashField).Wait(SyncTimeout); + + public RedisValue[] HashGet(RedisKey key, RedisValue[] hashFields, CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HMGet(key, hashFields).Wait(SyncTimeout); + + public HashEntry[] HashGetAll(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetAll(key).Wait(SyncTimeout); + + public Task HashGetAsync( + RedisKey key, + RedisValue hashField, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGet(key, hashField).AsTask(); + + public Task HashGetAsync( + RedisKey key, + RedisValue[] hashFields, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HMGet(key, hashFields).AsTask(); + + public Task HashGetAllAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetAll(key).AsTask(); + + public Lease? HashGetLease(RedisKey key, RedisValue hashField, CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetLease(key, hashField).Wait(SyncTimeout); + + public Task?> HashGetLeaseAsync( + RedisKey key, + RedisValue hashField, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HGetLease(key, hashField).AsTask(); + + public long HashIncrement( + RedisKey key, + RedisValue hashField, + long value = 1, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HIncrBy(key, hashField, value).Wait(SyncTimeout); + + public double HashIncrement( + RedisKey key, + RedisValue hashField, + double value, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HIncrByFloat(key, hashField, value).Wait(SyncTimeout); + + public Task HashIncrementAsync( + RedisKey key, + RedisValue hashField, + long value = 1, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HIncrBy(key, hashField, value).AsTask(); + + public Task HashIncrementAsync( + RedisKey key, + RedisValue hashField, + double value, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HIncrByFloat(key, hashField, value).AsTask(); + + public RedisValue[] HashKeys(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HKeys(key).Wait(SyncTimeout); + + public Task HashKeysAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HKeys(key).AsTask(); + + public long HashLength(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HLen(key).Wait(SyncTimeout); + + public Task HashLengthAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HLen(key).AsTask(); + + public RedisValue HashRandomField(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HRandField(key).Wait(SyncTimeout); + + public RedisValue[] HashRandomFields( + RedisKey key, + long count, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HRandField(key, count).Wait(SyncTimeout); + + public HashEntry[] HashRandomFieldsWithValues( + RedisKey key, + long count, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HRandFieldWithValues(key, count).Wait(SyncTimeout); + + public Task HashRandomFieldAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HRandField(key).AsTask(); + + public Task HashRandomFieldsAsync( + RedisKey key, + long count, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HRandField(key, count).AsTask(); + + public Task HashRandomFieldsWithValuesAsync( + RedisKey key, + long count, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HRandFieldWithValues(key, count).AsTask(); + + public IEnumerable HashScan( + RedisKey key, + RedisValue pattern = default, + int pageSize = RedisBase.CursorUtils.DefaultLibraryPageSize, + long cursor = RedisBase.CursorUtils.Origin, + int pageOffset = 0, + CommandFlags flags = CommandFlags.None) + => throw new NotImplementedException(); + + public IEnumerable HashScan( + RedisKey key, + RedisValue pattern, + int pageSize, + CommandFlags flags) + => throw new NotImplementedException(); + + public IAsyncEnumerable HashScanAsync( + RedisKey key, + RedisValue pattern = default, + int pageSize = RedisBase.CursorUtils.DefaultLibraryPageSize, + long cursor = RedisBase.CursorUtils.Origin, + int pageOffset = 0, + CommandFlags flags = CommandFlags.None) + => throw new NotImplementedException(); + + public IEnumerable HashScanNoValues( + RedisKey key, + RedisValue pattern = default, + int pageSize = RedisBase.CursorUtils.DefaultLibraryPageSize, + long cursor = RedisBase.CursorUtils.Origin, + int pageOffset = 0, + CommandFlags flags = CommandFlags.None) + => throw new NotImplementedException(); + + public IAsyncEnumerable HashScanNoValuesAsync( + RedisKey key, + RedisValue pattern = default, + int pageSize = RedisBase.CursorUtils.DefaultLibraryPageSize, + long cursor = RedisBase.CursorUtils.Origin, + int pageOffset = 0, + CommandFlags flags = CommandFlags.None) + => throw new NotImplementedException(); + + public bool HashSet( + RedisKey key, + RedisValue hashField, + RedisValue value, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HSet(key, hashField, value, when).Wait(SyncTimeout); + + public Task HashSetAsync( + RedisKey key, + RedisValue hashField, + RedisValue value, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HSet(key, hashField, value, when).AsTask(); + + public void HashSet( + RedisKey key, + HashEntry[] hashFields, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HSet(key, hashFields).Wait(SyncTimeout); + + public Task HashSetAsync( + RedisKey key, + HashEntry[] hashFields, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HSet(key, hashFields).AsTask(); + + public long HashStringLength( + RedisKey key, + RedisValue hashField, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HStrLen(key, hashField).Wait(SyncTimeout); + + public Task HashStringLengthAsync( + RedisKey key, + RedisValue hashField, + CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HStrLen(key, hashField).AsTask(); + + public RedisValue[] HashValues(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HVals(key).Wait(SyncTimeout); + + public Task HashValuesAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Hashes().HVals(key).AsTask(); +} diff --git a/src/RESPite.StackExchange.Redis/RespContextDatabase.HyperLogLog.cs b/src/RESPite.StackExchange.Redis/RespContextDatabase.HyperLogLog.cs new file mode 100644 index 000000000..ddc07dfba --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespContextDatabase.HyperLogLog.cs @@ -0,0 +1,79 @@ +using RESPite.Messages; +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +internal partial class RespContextDatabase +{ + // Async HyperLogLog methods + public Task HyperLogLogAddAsync( + RedisKey key, + RedisValue value, + CommandFlags flags = CommandFlags.None) => + Context(flags).HyperLogLogs().PfAdd(key, value).AsTask(); + + public Task HyperLogLogAddAsync( + RedisKey key, + RedisValue[] values, + CommandFlags flags = CommandFlags.None) => + Context(flags).HyperLogLogs().PfAdd(key, values).AsTask(); + + public Task HyperLogLogLengthAsync( + RedisKey key, + CommandFlags flags = CommandFlags.None) => + Context(flags).HyperLogLogs().PfCount(key).AsTask(); + + public Task HyperLogLogLengthAsync( + RedisKey[] keys, + CommandFlags flags = CommandFlags.None) => + Context(flags).HyperLogLogs().PfCount(keys).AsTask(); + + public Task HyperLogLogMergeAsync( + RedisKey destination, + RedisKey first, + RedisKey second, + CommandFlags flags = CommandFlags.None) => + Context(flags).HyperLogLogs().PfMerge(destination, first, second).AsTask(); + + public Task HyperLogLogMergeAsync( + RedisKey destination, + RedisKey[] sourceKeys, + CommandFlags flags = CommandFlags.None) => + Context(flags).HyperLogLogs().PfMerge(destination, sourceKeys).AsTask(); + + // Synchronous HyperLogLog methods + public bool HyperLogLogAdd( + RedisKey key, + RedisValue value, + CommandFlags flags = CommandFlags.None) => + Context(flags).HyperLogLogs().PfAdd(key, value).Wait(SyncTimeout); + + public bool HyperLogLogAdd( + RedisKey key, + RedisValue[] values, + CommandFlags flags = CommandFlags.None) => + Context(flags).HyperLogLogs().PfAdd(key, values).Wait(SyncTimeout); + + public long HyperLogLogLength( + RedisKey key, + CommandFlags flags = CommandFlags.None) => + Context(flags).HyperLogLogs().PfCount(key).Wait(SyncTimeout); + + public long HyperLogLogLength( + RedisKey[] keys, + CommandFlags flags = CommandFlags.None) => + Context(flags).HyperLogLogs().PfCount(keys).Wait(SyncTimeout); + + public void HyperLogLogMerge( + RedisKey destination, + RedisKey first, + RedisKey second, + CommandFlags flags = CommandFlags.None) => + Context(flags).HyperLogLogs().PfMerge(destination, first, second).Wait(SyncTimeout); + + public void HyperLogLogMerge( + RedisKey destination, + RedisKey[] sourceKeys, + CommandFlags flags = CommandFlags.None) => + Context(flags).HyperLogLogs().PfMerge(destination, sourceKeys).Wait(SyncTimeout); +} diff --git a/src/RESPite.StackExchange.Redis/RespContextDatabase.Key.cs b/src/RESPite.StackExchange.Redis/RespContextDatabase.Key.cs new file mode 100644 index 000000000..e8cf31e2f --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespContextDatabase.Key.cs @@ -0,0 +1,192 @@ +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +internal partial class RespContextDatabase +{ + public bool KeyCopy( + RedisKey sourceKey, + RedisKey destinationKey, + int destinationDatabase = -1, + bool replace = false, + CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Copy(sourceKey, destinationKey, destinationDatabase, replace).Wait(SyncTimeout); + + public Task KeyCopyAsync( + RedisKey sourceKey, + RedisKey destinationKey, + int destinationDatabase = -1, + bool replace = false, + CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Copy(sourceKey, destinationKey, destinationDatabase, replace).AsTask(); + + public bool KeyDelete(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Del(key).Wait(SyncTimeout); + + public long KeyDelete(RedisKey[] keys, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Del(keys).Wait(SyncTimeout); + + public Task KeyDeleteAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Del(key).AsTask(); + + public Task KeyDeleteAsync(RedisKey[] keys, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Del(keys).AsTask(); + + public byte[]? KeyDump(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Dump(key).Wait(SyncTimeout); + + public Task KeyDumpAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Dump(key).AsTask(); + + public string? KeyEncoding(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().ObjectEncoding(key).Wait(SyncTimeout); + + public Task KeyEncodingAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().ObjectEncoding(key).AsTask(); + + public bool KeyExists(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Exists(key).Wait(SyncTimeout); + + public long KeyExists(RedisKey[] keys, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Exists(keys).Wait(SyncTimeout); + + public Task KeyExistsAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Exists(key).AsTask(); + + public Task KeyExistsAsync(RedisKey[] keys, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Exists(keys).AsTask(); + + public bool KeyExpire(RedisKey key, TimeSpan? expiry, CommandFlags flags) + => Context(flags).Keys().Expire(key, expiry).Wait(SyncTimeout); + + public bool KeyExpire( + RedisKey key, + TimeSpan? expiry, + ExpireWhen when = ExpireWhen.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Expire(key, expiry, when).Wait(SyncTimeout); + + public bool KeyExpire(RedisKey key, DateTime? expiry, CommandFlags flags) + => Context(flags).Keys().ExpireAt(key, expiry).Wait(SyncTimeout); + + public bool KeyExpire( + RedisKey key, + DateTime? expiry, + ExpireWhen when = ExpireWhen.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().ExpireAt(key, expiry, when).Wait(SyncTimeout); + + public Task KeyExpireAsync(RedisKey key, TimeSpan? expiry, CommandFlags flags) + => Context(flags).Keys().Expire(key, expiry).AsTask(); + + public Task KeyExpireAsync( + RedisKey key, + TimeSpan? expiry, + ExpireWhen when = ExpireWhen.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Expire(key, expiry, when).AsTask(); + + public Task KeyExpireAsync(RedisKey key, DateTime? expiry, CommandFlags flags) + => Context(flags).Keys().ExpireAt(key, expiry).AsTask(); + + public Task KeyExpireAsync( + RedisKey key, + DateTime? expiry, + ExpireWhen when = ExpireWhen.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().ExpireAt(key, expiry, when).AsTask(); + + public DateTime? KeyExpireTime(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().PExpireTime(key).Wait(SyncTimeout); + + public Task KeyExpireTimeAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().PExpireTime(key).AsTask(); + + public long? KeyFrequency(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().ObjectFreq(key).Wait(SyncTimeout); + + public Task KeyFrequencyAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().ObjectFreq(key).AsTask(); + + public TimeSpan? KeyIdleTime(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().ObjectIdleTime(key).Wait(SyncTimeout); + + public Task KeyIdleTimeAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().ObjectIdleTime(key).AsTask(); + + public bool KeyMove(RedisKey key, int database, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Move(key, database).Wait(SyncTimeout); + + public Task KeyMoveAsync(RedisKey key, int database, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Move(key, database).AsTask(); + + public bool KeyPersist(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Persist(key).Wait(SyncTimeout); + + public Task KeyPersistAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Persist(key).AsTask(); + + public RedisKey KeyRandom(CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().RandomKey().Wait(SyncTimeout); + + public Task KeyRandomAsync(CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().RandomKey().AsTask(); + + public long? KeyRefCount(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().ObjectRefCount(key).Wait(SyncTimeout); + + public Task KeyRefCountAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().ObjectRefCount(key).AsTask(); + + public bool KeyRename( + RedisKey key, + RedisKey newKey, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Rename(key, newKey, when).Wait(SyncTimeout); + + public Task KeyRenameAsync( + RedisKey key, + RedisKey newKey, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Rename(key, newKey, when).AsTask(); + + public void KeyRestore( + RedisKey key, + byte[] value, + TimeSpan? expiry = null, + CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Restore(key, expiry, value).Wait(SyncTimeout); + + public Task KeyRestoreAsync( + RedisKey key, + byte[] value, + TimeSpan? expiry = null, + CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Restore(key, expiry, value).AsTask(); + + public TimeSpan? KeyTimeToLive(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Pttl(key).Wait(SyncTimeout); + + public Task KeyTimeToLiveAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Pttl(key).AsTask(); + + public bool KeyTouch(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Touch(key).Wait(SyncTimeout); + + public long KeyTouch(RedisKey[] keys, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Touch(keys).Wait(SyncTimeout); + + public Task KeyTouchAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Touch(key).AsTask(); + + public Task KeyTouchAsync(RedisKey[] keys, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Touch(keys).AsTask(); + + public RedisType KeyType(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Type(key).Wait(SyncTimeout); + + public Task KeyTypeAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Keys().Type(key).AsTask(); +} diff --git a/src/RESPite.StackExchange.Redis/RespContextDatabase.List.cs b/src/RESPite.StackExchange.Redis/RespContextDatabase.List.cs new file mode 100644 index 000000000..995950ee4 --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespContextDatabase.List.cs @@ -0,0 +1,268 @@ +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +internal partial class RespContextDatabase +{ + public RedisValue ListGetByIndex(RedisKey key, long index, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LIndex(key, index).Wait(SyncTimeout); + + public Task ListGetByIndexAsync(RedisKey key, long index, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LIndex(key, index).AsTask(); + + public long ListInsertAfter( + RedisKey key, + RedisValue pivot, + RedisValue value, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LInsert(key, false, pivot, value).Wait(SyncTimeout); + + public Task ListInsertAfterAsync( + RedisKey key, + RedisValue pivot, + RedisValue value, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LInsert(key, false, pivot, value).AsTask(); + + public long ListInsertBefore( + RedisKey key, + RedisValue pivot, + RedisValue value, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LInsert(key, true, pivot, value).Wait(SyncTimeout); + + public Task ListInsertBeforeAsync( + RedisKey key, + RedisValue pivot, + RedisValue value, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LInsert(key, true, pivot, value).AsTask(); + + public RedisValue ListLeftPop(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LPop(key).Wait(SyncTimeout); + + public RedisValue[] ListLeftPop(RedisKey key, long count, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LPop(key, count).Wait(SyncTimeout); + + public ListPopResult ListLeftPop(RedisKey[] keys, long count, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LMPop(keys, ListSide.Left, count).Wait(SyncTimeout); + + public Task ListLeftPopAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LPop(key).AsTask(); + + public Task ListLeftPopAsync(RedisKey key, long count, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LPop(key, count).AsTask(); + + public Task ListLeftPopAsync(RedisKey[] keys, long count, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LMPop(keys, ListSide.Left, count).AsTask(); + + public long ListLeftPush( + RedisKey key, + RedisValue value, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().Push(key, value, ListSide.Left, when).Wait(SyncTimeout); + + public long ListLeftPush( + RedisKey key, + RedisValue[] values, + When when, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().Push(key, values, ListSide.Left, when).Wait(SyncTimeout); + + public long ListLeftPush(RedisKey key, RedisValue[] values, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LPush(key, values).Wait(SyncTimeout); + + public Task ListLeftPushAsync( + RedisKey key, + RedisValue value, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().Push(key, value, ListSide.Left, when).AsTask(); + + public Task ListLeftPushAsync( + RedisKey key, + RedisValue[] values, + When when, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().Push(key, values, ListSide.Left, when).AsTask(); + + public Task ListLeftPushAsync(RedisKey key, RedisValue[] values, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LPush(key, values).AsTask(); + + public long ListLength(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LLen(key).Wait(SyncTimeout); + + public Task ListLengthAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LLen(key).AsTask(); + + public RedisValue ListMove( + RedisKey sourceKey, + RedisKey destinationKey, + ListSide sourceSide, + ListSide destinationSide, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LMove(sourceKey, destinationKey, sourceSide, destinationSide).Wait(SyncTimeout); + + public Task ListMoveAsync( + RedisKey sourceKey, + RedisKey destinationKey, + ListSide sourceSide, + ListSide destinationSide, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LMove(sourceKey, destinationKey, sourceSide, destinationSide).AsTask(); + + public long ListPosition( + RedisKey key, + RedisValue element, + long rank = 1, + long maxLength = 0, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LPos(key, element, rank, maxLength).Wait(SyncTimeout); + + public Task ListPositionAsync( + RedisKey key, + RedisValue element, + long rank = 1, + long maxLength = 0, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LPos(key, element, rank, maxLength).AsTask(); + + public long[] ListPositions( + RedisKey key, + RedisValue element, + long count, + long rank = 1, + long maxLength = 0, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LPos(key, element, rank, maxLength, count).Wait(SyncTimeout); + + public Task ListPositionsAsync( + RedisKey key, + RedisValue element, + long count, + long rank = 1, + long maxLength = 0, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LPos(key, element, rank, maxLength, count).AsTask(); + + public RedisValue[] ListRange( + RedisKey key, + long start, + long stop, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LRange(key, start, stop).Wait(SyncTimeout); + + public Task ListRangeAsync( + RedisKey key, + long start, + long stop, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LRange(key, start, stop).AsTask(); + + public long ListRemove( + RedisKey key, + RedisValue value, + long count = 0, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LRem(key, count, value).Wait(SyncTimeout); + + public Task ListRemoveAsync( + RedisKey key, + RedisValue value, + long count = 0, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LRem(key, count, value).AsTask(); + + public RedisValue ListRightPop(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().RPop(key).Wait(SyncTimeout); + + public RedisValue[] ListRightPop(RedisKey key, long count, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().RPop(key, count).Wait(SyncTimeout); + + public ListPopResult ListRightPop(RedisKey[] keys, long count, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LMPop(keys, ListSide.Right, count).Wait(SyncTimeout); + + public Task ListRightPopAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().RPop(key).AsTask(); + + public Task ListRightPopAsync(RedisKey key, long count, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().RPop(key, count).AsTask(); + + public Task ListRightPopAsync(RedisKey[] keys, long count, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LMPop(keys, ListSide.Right, count).AsTask(); + + public RedisValue ListRightPopLeftPush( + RedisKey source, + RedisKey destination, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().RPopLPush(source, destination).Wait(SyncTimeout); + + public Task ListRightPopLeftPushAsync( + RedisKey source, + RedisKey destination, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().RPopLPush(source, destination).AsTask(); + + public long ListRightPush( + RedisKey key, + RedisValue value, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().Push(key, value, ListSide.Right, when).Wait(SyncTimeout); + + public long ListRightPush( + RedisKey key, + RedisValue[] values, + When when, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().Push(key, values, ListSide.Right, when).Wait(SyncTimeout); + + public long ListRightPush(RedisKey key, RedisValue[] values, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().RPush(key, values).Wait(SyncTimeout); + + public Task ListRightPushAsync( + RedisKey key, + RedisValue value, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().Push(key, value, ListSide.Right, when).AsTask(); + + public Task ListRightPushAsync( + RedisKey key, + RedisValue[] values, + When when, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().Push(key, values, ListSide.Right, when).AsTask(); + + public Task ListRightPushAsync(RedisKey key, RedisValue[] values, CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().RPush(key, values).AsTask(); + + public void ListSetByIndex( + RedisKey key, + long index, + RedisValue value, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LSet(key, index, value).Wait(SyncTimeout); + + public Task ListSetByIndexAsync( + RedisKey key, + long index, + RedisValue value, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LSet(key, index, value).AsTask(); + + public void ListTrim( + RedisKey key, + long start, + long stop, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LTrim(key, start, stop).Wait(SyncTimeout); + + public Task ListTrimAsync( + RedisKey key, + long start, + long stop, + CommandFlags flags = CommandFlags.None) + => Context(flags).Lists().LTrim(key, start, stop).AsTask(); +} diff --git a/src/RESPite.StackExchange.Redis/RespContextDatabase.Lock.cs b/src/RESPite.StackExchange.Redis/RespContextDatabase.Lock.cs new file mode 100644 index 000000000..cc9e4c919 --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespContextDatabase.Lock.cs @@ -0,0 +1,40 @@ +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +internal partial class RespContextDatabase +{ + // Async Lock methods + public Task LockExtendAsync( + RedisKey key, + RedisValue value, + TimeSpan expiry, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task LockQueryAsync(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task LockReleaseAsync(RedisKey key, RedisValue value, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task LockTakeAsync( + RedisKey key, + RedisValue value, + TimeSpan expiry, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + // Synchronous Lock methods + public bool LockExtend(RedisKey key, RedisValue value, TimeSpan expiry, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisValue LockQuery(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public bool LockRelease(RedisKey key, RedisValue value, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public bool LockTake(RedisKey key, RedisValue value, TimeSpan expiry, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); +} diff --git a/src/RESPite.StackExchange.Redis/RespContextDatabase.Script.cs b/src/RESPite.StackExchange.Redis/RespContextDatabase.Script.cs new file mode 100644 index 000000000..9930e6f9f --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespContextDatabase.Script.cs @@ -0,0 +1,112 @@ +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +internal partial class RespContextDatabase +{ + // Async Script/Execute/Publish methods + public Task PublishAsync(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task ExecuteAsync(string command, params object[] args) => + throw new NotImplementedException(); + + public Task ExecuteAsync( + string command, + ICollection? args, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task ScriptEvaluateAsync( + string script, + RedisKey[]? keys = null, + RedisValue[]? values = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task ScriptEvaluateAsync( + byte[] hash, + RedisKey[]? keys = null, + RedisValue[]? values = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task ScriptEvaluateAsync( + LuaScript script, + object? parameters = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task ScriptEvaluateAsync( + LoadedLuaScript script, + object? parameters = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task ScriptEvaluateReadOnlyAsync( + string script, + RedisKey[]? keys = null, + RedisValue[]? values = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task ScriptEvaluateReadOnlyAsync( + byte[] hash, + RedisKey[]? keys = null, + RedisValue[]? values = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + // Synchronous Script/Execute/Publish methods + public long Publish(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisResult Execute(string command, params object[] args) => + throw new NotImplementedException(); + + public RedisResult Execute( + string command, + ICollection? args, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisResult ScriptEvaluate( + string script, + RedisKey[]? keys = null, + RedisValue[]? values = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisResult ScriptEvaluate( + byte[] hash, + RedisKey[]? keys = null, + RedisValue[]? values = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisResult ScriptEvaluate( + LuaScript script, + object? parameters = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisResult ScriptEvaluate( + LoadedLuaScript script, + object? parameters = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisResult ScriptEvaluateReadOnly( + string script, + RedisKey[]? keys = null, + RedisValue[]? values = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisResult ScriptEvaluateReadOnly( + byte[] hash, + RedisKey[]? keys = null, + RedisValue[]? values = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); +} diff --git a/src/RESPite.StackExchange.Redis/RespContextDatabase.Set.cs b/src/RESPite.StackExchange.Redis/RespContextDatabase.Set.cs new file mode 100644 index 000000000..4563ccd05 --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespContextDatabase.Set.cs @@ -0,0 +1,175 @@ +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +internal partial class RespContextDatabase +{ + public bool SetAdd(RedisKey key, RedisValue value, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SAdd(key, value).Wait(SyncTimeout); + + public long SetAdd(RedisKey key, RedisValue[] values, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SAdd(key, values).Wait(SyncTimeout); + + public Task SetAddAsync(RedisKey key, RedisValue value, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SAdd(key, value).AsTask(); + + public Task SetAddAsync(RedisKey key, RedisValue[] values, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SAdd(key, values).AsTask(); + + public RedisValue[] SetCombine( + SetOperation operation, + RedisKey first, + RedisKey second, + CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().Combine(operation, first, second).Wait(SyncTimeout); + + public RedisValue[] SetCombine(SetOperation operation, RedisKey[] keys, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().Combine(operation, keys).Wait(SyncTimeout); + + public long SetCombineAndStore( + SetOperation operation, + RedisKey destination, + RedisKey first, + RedisKey second, + CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().CombineStore(operation, destination, first, second).Wait(SyncTimeout); + + public long SetCombineAndStore( + SetOperation operation, + RedisKey destination, + RedisKey[] keys, + CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().CombineStore(operation, destination, keys).Wait(SyncTimeout); + + public Task SetCombineAndStoreAsync( + SetOperation operation, + RedisKey destination, + RedisKey first, + RedisKey second, + CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().CombineStore(operation, destination, first, second).AsTask(); + + public Task SetCombineAndStoreAsync( + SetOperation operation, + RedisKey destination, + RedisKey[] keys, + CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().CombineStore(operation, destination, keys).AsTask(); + + public Task SetCombineAsync( + SetOperation operation, + RedisKey first, + RedisKey second, + CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().Combine(operation, first, second).AsTask(); + + public Task SetCombineAsync( + SetOperation operation, + RedisKey[] keys, + CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().Combine(operation, keys).AsTask(); + + public bool SetContains(RedisKey key, RedisValue value, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SIsMember(key, value).Wait(SyncTimeout); + + public bool[] SetContains(RedisKey key, RedisValue[] values, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SMIsMember(key, values).Wait(SyncTimeout); + + public Task SetContainsAsync(RedisKey key, RedisValue value, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SIsMember(key, value).AsTask(); + + public Task SetContainsAsync(RedisKey key, RedisValue[] values, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SMIsMember(key, values).AsTask(); + + public long SetIntersectionLength(RedisKey[] keys, long limit = 0, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SInterCard(keys, limit).Wait(SyncTimeout); + + public Task SetIntersectionLengthAsync( + RedisKey[] keys, + long limit = 0, + CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SInterCard(keys, limit).AsTask(); + + public long SetLength(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SCard(key).Wait(SyncTimeout); + + public Task SetLengthAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SCard(key).AsTask(); + + public RedisValue[] SetMembers(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SMembers(key).Wait(SyncTimeout); + + public Task SetMembersAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SMembers(key).AsTask(); + + public bool SetMove( + RedisKey source, + RedisKey destination, + RedisValue value, + CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SMove(source, destination, value).Wait(SyncTimeout); + + public Task SetMoveAsync( + RedisKey source, + RedisKey destination, + RedisValue value, + CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SMove(source, destination, value).AsTask(); + + public RedisValue SetPop(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SPop(key).Wait(SyncTimeout); + + public RedisValue[] SetPop(RedisKey key, long count, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SPop(key, count).Wait(SyncTimeout); + + public Task SetPopAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SPop(key).AsTask(); + + public Task SetPopAsync(RedisKey key, long count, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SPop(key, count).AsTask(); + + public RedisValue SetRandomMember(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SRandMember(key).Wait(SyncTimeout); + + public RedisValue[] SetRandomMembers(RedisKey key, long count, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SRandMember(key, count).Wait(SyncTimeout); + + public Task SetRandomMemberAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SRandMember(key).AsTask(); + + public Task SetRandomMembersAsync(RedisKey key, long count, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SRandMember(key, count).AsTask(); + + public bool SetRemove(RedisKey key, RedisValue value, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SRem(key, value).Wait(SyncTimeout); + + public long SetRemove(RedisKey key, RedisValue[] values, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SRem(key, values).Wait(SyncTimeout); + + public Task SetRemoveAsync(RedisKey key, RedisValue value, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SRem(key, value).AsTask(); + + public Task SetRemoveAsync(RedisKey key, RedisValue[] values, CommandFlags flags = CommandFlags.None) + => Context(flags).Sets().SRem(key, values).AsTask(); + + public IEnumerable SetScan(RedisKey key, RedisValue pattern, int pageSize, CommandFlags flags) => + throw new NotImplementedException(); + + public IEnumerable SetScan( + RedisKey key, + RedisValue pattern = default, + int pageSize = RedisBase.CursorUtils.DefaultLibraryPageSize, + long cursor = RedisBase.CursorUtils.Origin, + int pageOffset = 0, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public IAsyncEnumerable SetScanAsync( + RedisKey key, + RedisValue pattern = default, + int pageSize = RedisBase.CursorUtils.DefaultLibraryPageSize, + long cursor = RedisBase.CursorUtils.Origin, + int pageOffset = 0, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); +} diff --git a/src/RESPite.StackExchange.Redis/RespContextDatabase.Sort.cs b/src/RESPite.StackExchange.Redis/RespContextDatabase.Sort.cs new file mode 100644 index 000000000..dc1148194 --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespContextDatabase.Sort.cs @@ -0,0 +1,54 @@ +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +internal partial class RespContextDatabase +{ + // Async Sort methods + public Task SortAsync( + RedisKey key, + long skip = 0, + long take = -1, + Order order = Order.Ascending, + SortType sortType = SortType.Numeric, + RedisValue by = default, + RedisValue[]? get = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task SortAndStoreAsync( + RedisKey destination, + RedisKey key, + long skip = 0, + long take = -1, + Order order = Order.Ascending, + SortType sortType = SortType.Numeric, + RedisValue by = default, + RedisValue[]? get = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + // Synchronous Sort methods + public RedisValue[] Sort( + RedisKey key, + long skip = 0, + long take = -1, + Order order = Order.Ascending, + SortType sortType = SortType.Numeric, + RedisValue by = default, + RedisValue[]? get = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long SortAndStore( + RedisKey destination, + RedisKey key, + long skip = 0, + long take = -1, + Order order = Order.Ascending, + SortType sortType = SortType.Numeric, + RedisValue by = default, + RedisValue[]? get = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); +} diff --git a/src/RESPite.StackExchange.Redis/RespContextDatabase.SortedSet.cs b/src/RESPite.StackExchange.Redis/RespContextDatabase.SortedSet.cs new file mode 100644 index 000000000..d4b870261 --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespContextDatabase.SortedSet.cs @@ -0,0 +1,596 @@ +using RESPite.Messages; +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +internal partial class RespContextDatabase +{ + public bool SortedSetAdd( + RedisKey key, + RedisValue member, + double score, + CommandFlags flags) + => Context(flags).SortedSets().ZAdd(key, member, score).Wait(SyncTimeout); + + public bool SortedSetAdd( + RedisKey key, + RedisValue member, + double score, + When when, + CommandFlags flags) + => Context(flags).SortedSets().ZAdd(key, when.ToSortedSetWhen(), member, score).Wait(SyncTimeout); + + public bool SortedSetAdd( + RedisKey key, + RedisValue member, + double score, + SortedSetWhen when, + CommandFlags flags) + => Context(flags).SortedSets().ZAdd(key, when, member, score).Wait(SyncTimeout); + + public long SortedSetAdd(RedisKey key, SortedSetEntry[] values, CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZAdd(key, values).Wait(SyncTimeout); + + public long SortedSetAdd( + RedisKey key, + SortedSetEntry[] values, + When when, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZAdd(key, when.ToSortedSetWhen(), values).Wait(SyncTimeout); + + public long SortedSetAdd( + RedisKey key, + SortedSetEntry[] values, + SortedSetWhen when = SortedSetWhen.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZAdd(key, when, values).Wait(SyncTimeout); + + public Task SortedSetAddAsync( + RedisKey key, + RedisValue member, + double score, + CommandFlags flags) + => Context(flags).SortedSets().ZAdd(key, member, score).AsTask(); + + public Task SortedSetAddAsync( + RedisKey key, + RedisValue member, + double score, + When when, + CommandFlags flags) + => Context(flags).SortedSets().ZAdd(key, when.ToSortedSetWhen(), member, score).AsTask(); + + public Task SortedSetAddAsync( + RedisKey key, + RedisValue member, + double score, + SortedSetWhen when, + CommandFlags flags) + => Context(flags).SortedSets().ZAdd(key, when, member, score).AsTask(); + + public Task SortedSetAddAsync( + RedisKey key, + SortedSetEntry[] values, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZAdd(key, values).AsTask(); + + public Task SortedSetAddAsync( + RedisKey key, + SortedSetEntry[] values, + When when, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZAdd(key, when.ToSortedSetWhen(), values).AsTask(); + + public Task SortedSetAddAsync( + RedisKey key, + SortedSetEntry[] values, + SortedSetWhen when = SortedSetWhen.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZAdd(key, when, values).AsTask(); + + public RedisValue[] SortedSetCombine( + SetOperation operation, + RedisKey[] keys, + double[]? weights = null, + Aggregate aggregate = Aggregate.Sum, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().Combine(operation, keys, weights, aggregate).Wait(SyncTimeout); + + public long SortedSetCombineAndStore( + SetOperation operation, + RedisKey destination, + RedisKey first, + RedisKey second, + Aggregate aggregate = Aggregate.Sum, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().CombineAndStore(operation, destination, first, second, aggregate).Wait(SyncTimeout); + + public long SortedSetCombineAndStore( + SetOperation operation, + RedisKey destination, + RedisKey[] keys, + double[]? weights = null, + Aggregate aggregate = Aggregate.Sum, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().CombineAndStore(operation, destination, keys, weights, aggregate).Wait(SyncTimeout); + + public Task SortedSetCombineAsync( + SetOperation operation, + RedisKey[] keys, + double[]? weights = null, + Aggregate aggregate = Aggregate.Sum, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().Combine(operation, keys, weights, aggregate).AsTask(); + + public Task SortedSetCombineAndStoreAsync( + SetOperation operation, + RedisKey destination, + RedisKey first, + RedisKey second, + Aggregate aggregate = Aggregate.Sum, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().CombineAndStore(operation, destination, first, second, aggregate).AsTask(); + + public Task SortedSetCombineAndStoreAsync( + SetOperation operation, + RedisKey destination, + RedisKey[] keys, + double[]? weights = null, + Aggregate aggregate = Aggregate.Sum, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().CombineAndStore(operation, destination, keys, weights, aggregate).AsTask(); + + public SortedSetEntry[] SortedSetCombineWithScores( + SetOperation operation, + RedisKey[] keys, + double[]? weights = null, + Aggregate aggregate = Aggregate.Sum, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().CombineWithScores(operation, keys, weights, aggregate).Wait(SyncTimeout); + + public Task SortedSetCombineWithScoresAsync( + SetOperation operation, + RedisKey[] keys, + double[]? weights = null, + Aggregate aggregate = Aggregate.Sum, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().CombineWithScores(operation, keys, weights, aggregate).AsTask(); + + public double SortedSetDecrement( + RedisKey key, + RedisValue member, + double value, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZIncrBy(key, member, -value).Wait(SyncTimeout); + + public Task SortedSetDecrementAsync( + RedisKey key, + RedisValue member, + double value, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZIncrBy(key, member, -value).AsTask(); + + public double SortedSetIncrement( + RedisKey key, + RedisValue member, + double value, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZIncrBy(key, member, value).Wait(SyncTimeout); + + public Task SortedSetIncrementAsync( + RedisKey key, + RedisValue member, + double value, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZIncrBy(key, member, value).AsTask(); + + public long SortedSetIntersectionLength(RedisKey[] keys, long limit = 0, CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZInterCard(keys, limit).Wait(SyncTimeout); + + public Task SortedSetIntersectionLengthAsync( + RedisKey[] keys, + long limit = 0, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZInterCard(keys, limit).AsTask(); + + public long SortedSetLength( + RedisKey key, + double min = double.NegativeInfinity, + double max = double.PositiveInfinity, + Exclude exclude = Exclude.None, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZCardOrCount(key, min, max, exclude).Wait(SyncTimeout); + + public Task SortedSetLengthAsync( + RedisKey key, + double min = double.NegativeInfinity, + double max = double.PositiveInfinity, + Exclude exclude = Exclude.None, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZCardOrCount(key, min, max, exclude).AsTask(); + + public long SortedSetLengthByValue( + RedisKey key, + RedisValue min, + RedisValue max, + Exclude exclude = Exclude.None, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZLexCount(key, min, max, exclude).Wait(SyncTimeout); + + public Task SortedSetLengthByValueAsync( + RedisKey key, + RedisValue min, + RedisValue max, + Exclude exclude = Exclude.None, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZLexCount(key, min, max, exclude).AsTask(); + + public SortedSetEntry? SortedSetPop( + RedisKey key, + Order order = Order.Ascending, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZPop(key, order).Wait(SyncTimeout); + + public SortedSetEntry[] SortedSetPop( + RedisKey key, + long count, + Order order = Order.Ascending, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZPop(key, count, order).Wait(SyncTimeout); + + public SortedSetPopResult SortedSetPop( + RedisKey[] keys, + long count, + Order order = Order.Ascending, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZMPop(keys, order, count).Wait(SyncTimeout); + + public Task SortedSetPopAsync( + RedisKey key, + Order order = Order.Ascending, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZPop(key, order).AsTask(); + + public Task SortedSetPopAsync( + RedisKey key, + long count, + Order order = Order.Ascending, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZPop(key, count, order).AsTask(); + + public Task SortedSetPopAsync( + RedisKey[] keys, + long count, + Order order = Order.Ascending, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZMPop(keys, order, count).AsTask(); + + public RedisValue SortedSetRandomMember(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRandMember(key).Wait(SyncTimeout); + + public Task SortedSetRandomMemberAsync(RedisKey key, CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRandMember(key).AsTask(); + + public RedisValue[] SortedSetRandomMembers(RedisKey key, long count, CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRandMember(key, count).Wait(SyncTimeout); + + public Task SortedSetRandomMembersAsync( + RedisKey key, + long count, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRandMember(key, count).AsTask(); + + public SortedSetEntry[] SortedSetRandomMembersWithScores( + RedisKey key, + long count, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRandMemberWithScores(key, count).Wait(SyncTimeout); + + public Task SortedSetRandomMembersWithScoresAsync( + RedisKey key, + long count, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRandMemberWithScores(key, count).AsTask(); + + public long? SortedSetRank( + RedisKey key, + RedisValue member, + Order order = Order.Ascending, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRank(key, member, order).Wait(SyncTimeout); + + public Task SortedSetRankAsync( + RedisKey key, + RedisValue member, + Order order = Order.Ascending, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRank(key, member, order).AsTask(); + + public RedisValue[] SortedSetRangeByRank( + RedisKey key, + long start = 0, + long stop = -1, + Order order = Order.Ascending, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRange(key, start, stop, order).Wait(SyncTimeout); + + public long SortedSetRangeAndStore( + RedisKey sourceKey, + RedisKey destinationKey, + RedisValue start, + RedisValue stop, + SortedSetOrder sortedSetOrder = SortedSetOrder.ByRank, + Exclude exclude = Exclude.None, + Order order = Order.Ascending, + long skip = 0, + long? take = null, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRangeStore(sourceKey, destinationKey, start, stop, sortedSetOrder, exclude, order, skip, take).Wait(SyncTimeout); + + public Task SortedSetRangeByRankAsync( + RedisKey key, + long start = 0, + long stop = -1, + Order order = Order.Ascending, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRange(key, start, stop, order).AsTask(); + + public Task SortedSetRangeAndStoreAsync( + RedisKey sourceKey, + RedisKey destinationKey, + RedisValue start, + RedisValue stop, + SortedSetOrder sortedSetOrder = SortedSetOrder.ByRank, + Exclude exclude = Exclude.None, + Order order = Order.Ascending, + long skip = 0, + long? take = null, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRangeStore(sourceKey, destinationKey, start, stop, sortedSetOrder, exclude, order, skip, take).AsTask(); + + public SortedSetEntry[] SortedSetRangeByRankWithScores( + RedisKey key, + long start = 0, + long stop = -1, + Order order = Order.Ascending, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRangeWithScores(key, start, stop, order).Wait(SyncTimeout); + + public Task SortedSetRangeByRankWithScoresAsync( + RedisKey key, + long start = 0, + long stop = -1, + Order order = Order.Ascending, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRangeWithScores(key, start, stop, order).AsTask(); + + public RedisValue[] SortedSetRangeByScore( + RedisKey key, + double start = double.NegativeInfinity, + double stop = double.PositiveInfinity, + Exclude exclude = Exclude.None, + Order order = Order.Ascending, + long skip = 0, + long take = -1, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRangeByScore(key, ByScore(start, stop, exclude, skip, take), order).Wait(SyncTimeout); + + private static SortedSetCommands.ZRangeRequest ByScore(double start, double stop, Exclude exclude, long skip, long take) + { + var req = SortedSetCommands.ZRangeRequest.ByScore(start, stop, exclude); + req.Offset = skip; + req.Count = take; + return req; + } + + public Task SortedSetRangeByScoreAsync( + RedisKey key, + double start = double.NegativeInfinity, + double stop = double.PositiveInfinity, + Exclude exclude = Exclude.None, + Order order = Order.Ascending, + long skip = 0, + long take = -1, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRangeByScore(key, ByScore(start, stop, exclude, skip, take), order).AsTask(); + + public SortedSetEntry[] SortedSetRangeByScoreWithScores( + RedisKey key, + double start = double.NegativeInfinity, + double stop = double.PositiveInfinity, + Exclude exclude = Exclude.None, + Order order = Order.Ascending, + long skip = 0, + long take = -1, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRangeByScoreWithScores(key, ByScore(start, stop, exclude, skip, take), order).Wait(SyncTimeout); + + public Task SortedSetRangeByScoreWithScoresAsync( + RedisKey key, + double start = double.NegativeInfinity, + double stop = double.PositiveInfinity, + Exclude exclude = Exclude.None, + Order order = Order.Ascending, + long skip = 0, + long take = -1, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRangeByScoreWithScores(key, ByScore(start, stop, exclude, skip, take), order).AsTask(); + + private static SortedSetCommands.ZRangeRequest ByLex(RedisValue start, RedisValue stop, Exclude exclude, long skip, long take) + { + var req = SortedSetCommands.ZRangeRequest.ByLex(start, stop, exclude); + req.Offset = skip; + req.Count = take; + return req; + } + + public RedisValue[] SortedSetRangeByValue( + RedisKey key, + RedisValue min = default, + RedisValue max = default, + Exclude exclude = Exclude.None, + long skip = 0, + long take = -1, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRangeByLex(key, ByLex(min, max, exclude, skip, take)).Wait(SyncTimeout); + + public RedisValue[] SortedSetRangeByValue( + RedisKey key, + RedisValue min, + RedisValue max, + Exclude exclude, + Order order, + long skip = 0, + long take = -1, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRangeByLex(key, ByLex(min, max, exclude, skip, take), order).Wait(SyncTimeout); + + public Task SortedSetRangeByValueAsync( + RedisKey key, + RedisValue min = default, + RedisValue max = default, + Exclude exclude = Exclude.None, + long skip = 0, + long take = -1, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRangeByLex(key, ByLex(min, max, exclude, skip, take)).AsTask(); + + public Task SortedSetRangeByValueAsync( + RedisKey key, + RedisValue min, + RedisValue max, + Exclude exclude, + Order order, + long skip = 0, + long take = -1, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRangeByLex(key, ByLex(min, max, exclude, skip, take), order).AsTask(); + + public bool SortedSetRemove(RedisKey key, RedisValue member, CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRem(key, member).Wait(SyncTimeout); + + public long SortedSetRemove(RedisKey key, RedisValue[] members, CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRem(key, members).Wait(SyncTimeout); + + public Task SortedSetRemoveAsync(RedisKey key, RedisValue member, CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRem(key, member).AsTask(); + + public Task SortedSetRemoveAsync( + RedisKey key, + RedisValue[] members, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRem(key, members).AsTask(); + + public long SortedSetRemoveRangeByRank( + RedisKey key, + long start, + long stop, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRemRangeByRank(key, start, stop).Wait(SyncTimeout); + + public Task SortedSetRemoveRangeByRankAsync( + RedisKey key, + long start, + long stop, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRemRangeByRank(key, start, stop).AsTask(); + + public long SortedSetRemoveRangeByScore( + RedisKey key, + double start, + double stop, + Exclude exclude = Exclude.None, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRemRangeByScore(key, SortedSetCommands.ZRangeRequest.ByScore(start, stop, exclude)).Wait(SyncTimeout); + + public Task SortedSetRemoveRangeByScoreAsync( + RedisKey key, + double start, + double stop, + Exclude exclude = Exclude.None, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRemRangeByScore(key, SortedSetCommands.ZRangeRequest.ByScore(start, stop, exclude)).AsTask(); + + public long SortedSetRemoveRangeByValue( + RedisKey key, + RedisValue min, + RedisValue max, + Exclude exclude = Exclude.None, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRemRangeByScore(key, SortedSetCommands.ZRangeRequest.ByLex(min, max, exclude)).Wait(SyncTimeout); + + public Task SortedSetRemoveRangeByValueAsync( + RedisKey key, + RedisValue min, + RedisValue max, + Exclude exclude = Exclude.None, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZRemRangeByScore(key, SortedSetCommands.ZRangeRequest.ByLex(min, max, exclude)).AsTask(); + + public IEnumerable + SortedSetScan(RedisKey key, RedisValue pattern, int pageSize, CommandFlags flags) + => throw new NotImplementedException(); + + public IEnumerable SortedSetScan( + RedisKey key, + RedisValue pattern = default, + int pageSize = RedisBase.CursorUtils.DefaultLibraryPageSize, + long cursor = RedisBase.CursorUtils.Origin, + int pageOffset = 0, + CommandFlags flags = CommandFlags.None) + => throw new NotImplementedException(); + + public IAsyncEnumerable SortedSetScanAsync( + RedisKey key, + RedisValue pattern = default, + int pageSize = RedisBase.CursorUtils.DefaultLibraryPageSize, + long cursor = RedisBase.CursorUtils.Origin, + int pageOffset = 0, + CommandFlags flags = CommandFlags.None) + => throw new NotImplementedException(); + + public double? SortedSetScore(RedisKey key, RedisValue member, CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZScore(key, member).Wait(SyncTimeout); + + public Task SortedSetScoreAsync(RedisKey key, RedisValue member, CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZScore(key, member).AsTask(); + + public double?[] SortedSetScores(RedisKey key, RedisValue[] members, CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZScore(key, members).Wait(SyncTimeout); + + public Task SortedSetScoresAsync( + RedisKey key, + RedisValue[] members, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZScore(key, members).AsTask(); + + public bool SortedSetUpdate( + RedisKey key, + RedisValue member, + double score, + SortedSetWhen when = SortedSetWhen.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZAdd(key, when, member, score).Wait(SyncTimeout); + + public long SortedSetUpdate( + RedisKey key, + SortedSetEntry[] values, + SortedSetWhen when = SortedSetWhen.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZAdd(key, when, values).Wait(SyncTimeout); + + public Task SortedSetUpdateAsync( + RedisKey key, + RedisValue member, + double score, + SortedSetWhen when = SortedSetWhen.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZAdd(key, when, member, score).AsTask(); + + public Task SortedSetUpdateAsync( + RedisKey key, + SortedSetEntry[] values, + SortedSetWhen when = SortedSetWhen.Always, + CommandFlags flags = CommandFlags.None) + => Context(flags).SortedSets().ZAdd(key, when, values).AsTask(); +} diff --git a/src/RESPite.StackExchange.Redis/RespContextDatabase.Stream.cs b/src/RESPite.StackExchange.Redis/RespContextDatabase.Stream.cs new file mode 100644 index 000000000..8fbdd8cea --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespContextDatabase.Stream.cs @@ -0,0 +1,586 @@ +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +internal partial class RespContextDatabase +{ + // Async Stream methods + public Task StreamAcknowledgeAsync( + RedisKey key, + RedisValue groupName, + RedisValue messageId, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamAcknowledgeAsync( + RedisKey key, + RedisValue groupName, + RedisValue[] messageIds, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamAcknowledgeAndDeleteAsync( + RedisKey key, + RedisValue groupName, + StreamTrimMode trimMode, + RedisValue messageId, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamAcknowledgeAndDeleteAsync( + RedisKey key, + RedisValue groupName, + StreamTrimMode trimMode, + RedisValue[] messageIds, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamAddAsync( + RedisKey key, + RedisValue streamField, + RedisValue streamValue, + RedisValue? messageId = null, + int? maxLength = null, + bool useApproximateMaxLength = false, + CommandFlags flags = CommandFlags.None) => + (messageId is null & maxLength is null & !useApproximateMaxLength) + ? StreamAddSimpleCoreAsync(key, streamField, streamValue, flags) + : throw new NotImplementedException(); + + [RespCommand("xadd")] + private partial RedisValue StreamAddSimpleCore( + RedisKey key, + [RespPrefix("*")] + RedisValue streamField, + RedisValue streamValue, + CommandFlags flags = CommandFlags.None); + + public Task StreamAddAsync( + RedisKey key, + NameValueEntry[] streamPairs, + RedisValue? messageId = null, + int? maxLength = null, + bool useApproximateMaxLength = false, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamAddAsync( + RedisKey key, + RedisValue streamField, + RedisValue streamValue, + RedisValue? messageId = null, + long? maxLength = null, + bool useApproximateMaxLength = false, + long? limit = null, + StreamTrimMode trimMode = StreamTrimMode.KeepReferences, + CommandFlags flags = CommandFlags.None) => + (messageId is null & maxLength is null & !useApproximateMaxLength + & limit is null & trimMode == StreamTrimMode.KeepReferences) + ? StreamAddSimpleCoreAsync(key, streamField, streamValue, flags) + : throw new NotImplementedException(); + + public Task StreamAddAsync( + RedisKey key, + NameValueEntry[] streamPairs, + RedisValue? messageId = null, + long? maxLength = null, + bool useApproximateMaxLength = false, + long? limit = null, + StreamTrimMode trimMode = StreamTrimMode.KeepReferences, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamAutoClaimAsync( + RedisKey key, + RedisValue consumerGroup, + RedisValue claimingConsumer, + long minIdleTimeInMs, + RedisValue startAtId, + int? count = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamAutoClaimIdsOnlyAsync( + RedisKey key, + RedisValue consumerGroup, + RedisValue claimingConsumer, + long minIdleTimeInMs, + RedisValue startAtId, + int? count = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamClaimAsync( + RedisKey key, + RedisValue consumerGroup, + RedisValue claimingConsumer, + long minIdleTimeInMs, + RedisValue[] messageIds, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamClaimIdsOnlyAsync( + RedisKey key, + RedisValue consumerGroup, + RedisValue claimingConsumer, + long minIdleTimeInMs, + RedisValue[] messageIds, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamConsumerGroupSetPositionAsync( + RedisKey key, + RedisValue groupName, + RedisValue position, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamConsumerInfoAsync( + RedisKey key, + RedisValue groupName, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamCreateConsumerGroupAsync( + RedisKey key, + RedisValue groupName, + RedisValue? position = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamCreateConsumerGroupAsync( + RedisKey key, + RedisValue groupName, + RedisValue? position = null, + bool createStream = true, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamDeleteAsync( + RedisKey key, + RedisValue[] messageIds, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamDeleteAsync( + RedisKey key, + RedisValue[] messageIds, + StreamTrimMode trimMode, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamDeleteConsumerAsync( + RedisKey key, + RedisValue groupName, + RedisValue consumerName, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamDeleteConsumerGroupAsync( + RedisKey key, + RedisValue groupName, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamGroupInfoAsync(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamInfoAsync(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamLengthAsync(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamPendingAsync( + RedisKey key, + RedisValue groupName, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamPendingMessagesAsync( + RedisKey key, + RedisValue groupName, + int count, + RedisValue consumerName, + RedisValue? minId = null, + RedisValue? maxId = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamPendingMessagesAsync( + RedisKey key, + RedisValue groupName, + int count, + RedisValue consumerName, + RedisValue? minId = null, + RedisValue? maxId = null, + long? idleTime = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamRangeAsync( + RedisKey key, + RedisValue? minId = null, + RedisValue? maxId = null, + int? count = null, + Order messageOrder = Order.Ascending, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamReadAsync( + RedisKey key, + RedisValue position, + int? count = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamReadAsync( + StreamPosition[] streamPositions, + int? countPerStream = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamReadGroupAsync( + RedisKey key, + RedisValue groupName, + RedisValue consumerName, + RedisValue? position = null, + int? count = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamReadGroupAsync( + RedisKey key, + RedisValue groupName, + RedisValue consumerName, + RedisValue? position = null, + int? count = null, + bool noAck = false, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamReadGroupAsync( + StreamPosition[] streamPositions, + RedisValue groupName, + RedisValue consumerName, + int? countPerStream = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamReadGroupAsync( + StreamPosition[] streamPositions, + RedisValue groupName, + RedisValue consumerName, + int? countPerStream = null, + bool noAck = false, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamTrimAsync( + RedisKey key, + int maxLength, + bool useApproximateMaxLength = false, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamTrimAsync( + RedisKey key, + long maxLength, + bool useApproximateMaxLength = false, + long? limit = null, + StreamTrimMode trimMode = StreamTrimMode.KeepReferences, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StreamTrimByMinIdAsync( + RedisKey key, + RedisValue minId, + bool useApproximateMaxLength = false, + long? limit = null, + StreamTrimMode trimMode = StreamTrimMode.KeepReferences, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + // Synchronous Stream methods + public long StreamAcknowledge( + RedisKey key, + RedisValue groupName, + RedisValue messageId, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long StreamAcknowledge( + RedisKey key, + RedisValue groupName, + RedisValue[] messageIds, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public StreamTrimResult StreamAcknowledgeAndDelete( + RedisKey key, + RedisValue groupName, + StreamTrimMode trimMode, + RedisValue messageId, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public StreamTrimResult[] StreamAcknowledgeAndDelete( + RedisKey key, + RedisValue groupName, + StreamTrimMode trimMode, + RedisValue[] messageIds, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisValue StreamAdd( + RedisKey key, + RedisValue streamField, + RedisValue streamValue, + RedisValue? messageId = null, + int? maxLength = null, + bool useApproximateMaxLength = false, + CommandFlags flags = CommandFlags.None) => + (messageId is null & maxLength is null & !useApproximateMaxLength) + ? StreamAddSimpleCore(key, streamField, streamValue, flags) + : throw new NotImplementedException(); + + public RedisValue StreamAdd( + RedisKey key, + NameValueEntry[] streamPairs, + RedisValue? messageId = null, + int? maxLength = null, + bool useApproximateMaxLength = false, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisValue StreamAdd( + RedisKey key, + RedisValue streamField, + RedisValue streamValue, + RedisValue? messageId = null, + long? maxLength = null, + bool useApproximateMaxLength = false, + long? limit = null, + StreamTrimMode trimMode = StreamTrimMode.KeepReferences, + CommandFlags flags = CommandFlags.None) => + (messageId is null & maxLength is null & !useApproximateMaxLength + & limit is null & trimMode == StreamTrimMode.KeepReferences) + ? StreamAddSimpleCore(key, streamField, streamValue, flags) + : throw new NotImplementedException(); + + public RedisValue StreamAdd( + RedisKey key, + NameValueEntry[] streamPairs, + RedisValue? messageId = null, + long? maxLength = null, + bool useApproximateMaxLength = false, + long? limit = null, + StreamTrimMode trimMode = StreamTrimMode.KeepReferences, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public StreamAutoClaimResult StreamAutoClaim( + RedisKey key, + RedisValue consumerGroup, + RedisValue claimingConsumer, + long minIdleTimeInMs, + RedisValue startAtId, + int? count = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public StreamAutoClaimIdsOnlyResult StreamAutoClaimIdsOnly( + RedisKey key, + RedisValue consumerGroup, + RedisValue claimingConsumer, + long minIdleTimeInMs, + RedisValue startAtId, + int? count = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public StreamEntry[] StreamClaim( + RedisKey key, + RedisValue consumerGroup, + RedisValue claimingConsumer, + long minIdleTimeInMs, + RedisValue[] messageIds, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisValue[] StreamClaimIdsOnly( + RedisKey key, + RedisValue consumerGroup, + RedisValue claimingConsumer, + long minIdleTimeInMs, + RedisValue[] messageIds, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public bool StreamConsumerGroupSetPosition( + RedisKey key, + RedisValue groupName, + RedisValue position, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public StreamConsumerInfo[] StreamConsumerInfo( + RedisKey key, + RedisValue groupName, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public bool StreamCreateConsumerGroup( + RedisKey key, + RedisValue groupName, + RedisValue? position = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public bool StreamCreateConsumerGroup( + RedisKey key, + RedisValue groupName, + RedisValue? position = null, + bool createStream = true, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long StreamDelete(RedisKey key, RedisValue[] messageIds, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public StreamTrimResult[] StreamDelete( + RedisKey key, + RedisValue[] messageIds, + StreamTrimMode trimMode, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long StreamDeleteConsumer( + RedisKey key, + RedisValue groupName, + RedisValue consumerName, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public bool StreamDeleteConsumerGroup(RedisKey key, RedisValue groupName, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public StreamGroupInfo[] StreamGroupInfo(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public StreamInfo StreamInfo(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long StreamLength(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public StreamPendingInfo StreamPending( + RedisKey key, + RedisValue groupName, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public StreamPendingMessageInfo[] StreamPendingMessages( + RedisKey key, + RedisValue groupName, + int count, + RedisValue consumerName, + RedisValue? minId = null, + RedisValue? maxId = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public StreamPendingMessageInfo[] StreamPendingMessages( + RedisKey key, + RedisValue groupName, + int count, + RedisValue consumerName, + RedisValue? minId = null, + RedisValue? maxId = null, + long? idleTime = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public StreamEntry[] StreamRange( + RedisKey key, + RedisValue? minId = null, + RedisValue? maxId = null, + int? count = null, + Order messageOrder = Order.Ascending, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public StreamEntry[] StreamRead( + RedisKey key, + RedisValue position, + int? count = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisStream[] StreamRead( + StreamPosition[] streamPositions, + int? countPerStream = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public StreamEntry[] StreamReadGroup( + RedisKey key, + RedisValue groupName, + RedisValue consumerName, + RedisValue? position = null, + int? count = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public StreamEntry[] StreamReadGroup( + RedisKey key, + RedisValue groupName, + RedisValue consumerName, + RedisValue? position = null, + int? count = null, + bool noAck = false, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisStream[] StreamReadGroup( + StreamPosition[] streamPositions, + RedisValue groupName, + RedisValue consumerName, + int? countPerStream = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisStream[] StreamReadGroup( + StreamPosition[] streamPositions, + RedisValue groupName, + RedisValue consumerName, + int? countPerStream = null, + bool noAck = false, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long StreamTrim( + RedisKey key, + int maxLength, + bool useApproximateMaxLength = false, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long StreamTrim( + RedisKey key, + long maxLength, + bool useApproximateMaxLength = false, + long? limit = null, + StreamTrimMode trimMode = StreamTrimMode.KeepReferences, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long StreamTrimByMinId( + RedisKey key, + RedisValue minId, + bool useApproximateMaxLength = false, + long? limit = null, + StreamTrimMode trimMode = StreamTrimMode.KeepReferences, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); +} diff --git a/src/RESPite.StackExchange.Redis/RespContextDatabase.String.cs b/src/RESPite.StackExchange.Redis/RespContextDatabase.String.cs new file mode 100644 index 000000000..73c96649b --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespContextDatabase.String.cs @@ -0,0 +1,438 @@ +using RESPite.Messages; +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +internal partial class RespContextDatabase +{ + // Async String methods + public Task StringAppendAsync(RedisKey key, RedisValue value, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringBitCountAsync(RedisKey key, long start, long end, CommandFlags flags) => + throw new NotImplementedException(); + + public Task StringBitCountAsync( + RedisKey key, + long start = 0, + long end = -1, + StringIndexType indexType = StringIndexType.Byte, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringBitOperationAsync( + Bitwise operation, + RedisKey destination, + RedisKey first, + RedisKey second = default, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringBitOperationAsync( + Bitwise operation, + RedisKey destination, + RedisKey[] keys, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringBitPositionAsync(RedisKey key, bool bit, long start, long end, CommandFlags flags) => + throw new NotImplementedException(); + + public Task StringBitPositionAsync( + RedisKey key, + bool bit, + long start = 0, + long end = -1, + StringIndexType indexType = StringIndexType.Byte, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringDecrementAsync(RedisKey key, long value = 1, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringDecrementAsync(RedisKey key, double value, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringGetAsync(RedisKey[] keys, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringGetBitAsync(RedisKey key, long offset, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringGetRangeAsync( + RedisKey key, + long start, + long end, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringGetSetAsync(RedisKey key, RedisValue value, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringGetSetExpiryAsync( + RedisKey key, + TimeSpan? expiry, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringGetSetExpiryAsync( + RedisKey key, + DateTime expiry, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringGetDeleteAsync(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringGetWithExpiryAsync(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringIncrementAsync(RedisKey key, long value = 1, CommandFlags flags = CommandFlags.None) + => value == 1 ? StringIncrementUnitAsync(key, flags) : StringIncrementNonUnitAsync(key, value, flags); + + public Task StringLengthAsync(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringLongestCommonSubsequenceAsync( + RedisKey first, + RedisKey second, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringLongestCommonSubsequenceLengthAsync( + RedisKey first, + RedisKey second, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringLongestCommonSubsequenceWithMatchesAsync( + RedisKey first, + RedisKey second, + long minLength = 0, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringSetAsync(RedisKey key, RedisValue value, TimeSpan? expiry, When when) => + StringSetAsync(key, value, expiry, false, when, CommandFlags.None); + + public Task StringSetAsync(RedisKey key, RedisValue value, TimeSpan? expiry, When when, CommandFlags flags) => + StringSetAsync(key, value, expiry, false, when, flags); + + public Task StringSetAndGetAsync( + RedisKey key, + RedisValue value, + TimeSpan? expiry, + When when, + CommandFlags flags) => + throw new NotImplementedException(); + + public Task StringSetAndGetAsync( + RedisKey key, + RedisValue value, + TimeSpan? expiry = null, + bool keepTtl = false, + When when = When.Always, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringSetBitAsync(RedisKey key, long offset, bool bit, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task StringSetRangeAsync( + RedisKey key, + long offset, + RedisValue value, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + // Synchronous String methods + public long StringAppend(RedisKey key, RedisValue value, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long StringBitCount(RedisKey key, long start, long end, CommandFlags flags) => + throw new NotImplementedException(); + + public long StringBitCount( + RedisKey key, + long start = 0, + long end = -1, + StringIndexType indexType = StringIndexType.Byte, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long StringBitOperation( + Bitwise operation, + RedisKey destination, + RedisKey first, + RedisKey second = default, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long StringBitOperation( + Bitwise operation, + RedisKey destination, + RedisKey[] keys, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long StringBitPosition(RedisKey key, bool bit, long start, long end, CommandFlags flags) => + throw new NotImplementedException(); + + public long StringBitPosition( + RedisKey key, + bool bit, + long start = 0, + long end = -1, + StringIndexType indexType = StringIndexType.Byte, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long StringDecrement(RedisKey key, long value = 1, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public double StringDecrement(RedisKey key, double value, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + [RespCommand("get")] + public partial RedisValue StringGet(RedisKey key, CommandFlags flags = CommandFlags.None); + + public RedisValue[] StringGet(RedisKey[] keys, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + [RespCommand("get")] + public partial Lease? StringGetLease(RedisKey key, CommandFlags flags = CommandFlags.None); + + public bool StringGetBit(RedisKey key, long offset, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisValue StringGetRange(RedisKey key, long start, long end, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisValue StringGetSet(RedisKey key, RedisValue value, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisValue StringGetSetExpiry(RedisKey key, TimeSpan? expiry, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisValue StringGetSetExpiry(RedisKey key, DateTime expiry, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisValue StringGetDelete(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisValueWithExpiry StringGetWithExpiry(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long StringIncrement(RedisKey key, long value = 1, CommandFlags flags = CommandFlags.None) + => value == 1 ? StringIncrementUnit(key, flags) : StringIncrementNonUnit(key, value, flags); + + [RespCommand("incr")] + private partial long StringIncrementUnit(RedisKey key, CommandFlags flags); + + [RespCommand("incrby")] + private partial long StringIncrementNonUnit(RedisKey key, long value, CommandFlags flags); + + [RespCommand("incrbyfloat")] + public partial double StringIncrement(RedisKey key, double value, CommandFlags flags = CommandFlags.None); + + public long StringLength(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public string? StringLongestCommonSubsequence( + RedisKey first, + RedisKey second, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long StringLongestCommonSubsequenceLength( + RedisKey first, + RedisKey second, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public LCSMatchResult StringLongestCommonSubsequenceWithMatches( + RedisKey first, + RedisKey second, + long minLength = 0, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public bool StringSet(RedisKey key, RedisValue value, TimeSpan? expiry, When when) => + StringSet(key, value, expiry, false, when, CommandFlags.None); + + public bool StringSet(RedisKey key, RedisValue value, TimeSpan? expiry, When when, CommandFlags flags) => + StringSet(key, value, expiry, false, when, flags); + + public bool StringSet( + RedisKey key, + RedisValue value, + TimeSpan? expiry = null, + bool keepTtl = false, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + => value.IsNull + ? KeyDelete(key, flags) + : StringSetCore(key, value, expiry.NullIfMaxValue(), keepTtl, when, flags); + + public Task StringSetAsync( + RedisKey key, + RedisValue value, + TimeSpan? expiry = null, + bool keepTtl = false, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + => value.IsNull + ? KeyDeleteAsync(key, flags) + : StringSetCoreAsync(key, value, expiry.NullIfMaxValue(), keepTtl, when, flags); + + [RespCommand("set", Formatter = StringSetFormatter.Formatter)] + private partial bool StringSetCore( + RedisKey key, + RedisValue value, + TimeSpan? expiry = null, + bool keepTtl = false, + When when = When.Always, + CommandFlags flags = CommandFlags.None); + + private sealed class StringSetFormatter : IRespFormatter<(RedisKey Key, RedisValue Value, TimeSpan? Expiry, bool + KeepTtl, + When When)>, IRespFormatter[]> + { + public const string Formatter = $"{nameof(StringSetFormatter)}.{nameof(Instance)}"; + public static readonly StringSetFormatter Instance = new(); + private StringSetFormatter() { } + + public void Format( + scoped ReadOnlySpan command, + ref RespWriter writer, + in (RedisKey Key, RedisValue Value, TimeSpan? Expiry, bool KeepTtl, When When) request) + { + // SET key value [NX | XX] [GET] [EX seconds | PX milliseconds | + // EXAT unix-time-seconds | PXAT unix-time-milliseconds | KEEPTTL] + var argCount = 2 + request.When switch + { + When.Always => 0, + When.Exists or When.NotExists => 1, + _ => throw new ArgumentOutOfRangeException(nameof(request.When)), + } + (request.Expiry.HasValue ? 2 : 0) + (request.KeepTtl ? 1 : 0); + writer.WriteCommand(command, argCount); + writer.Write(request.Key); + writer.Write(request.Value); + switch (request.When) + { + case When.Exists: + writer.WriteBulkString("EX"u8); + break; + case When.NotExists: + writer.WriteBulkString("NX"u8); + break; + } + + if (request.Expiry.HasValue) + { + var millis = (long)request.Expiry.Value.TotalMilliseconds; + if ((millis % 1000) == 0) + { + writer.WriteBulkString("EX"u8); + writer.WriteBulkString(millis / 1000); + } + else + { + writer.WriteBulkString("PX"u8); + writer.WriteBulkString(millis); + } + } + + if (request.KeepTtl) + { + writer.WriteBulkString("KEEPTTL"u8); + } + } + + public void Format( + scoped ReadOnlySpan command, + ref RespWriter writer, + in KeyValuePair[] request) + { + writer.WriteCommand(command, 2 * request.Length); + foreach (var pair in request) + { + writer.Write(pair.Key); + writer.Write(pair.Value); + } + } + } + + public bool StringSet( + KeyValuePair[] values, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + { + switch (values.Length) + { + case 0: return false; + case 1: return StringSet(values[0].Key, values[0].Value, null, false, when, flags); + default: + when.AlwaysOrNotExists(); + return when == When.Always + ? StringMSetCore(values, flags) + : StringMSetNXCore(values, flags); + } + } + + public Task StringSetAsync( + KeyValuePair[] values, + When when = When.Always, + CommandFlags flags = CommandFlags.None) + { + switch (values.Length) + { + case 0: return FalseTask; + case 1: return StringSetAsync(values[0].Key, values[0].Value, null, false, when, flags); + default: + when.AlwaysOrNotExists(); + return when == When.Always + ? StringMSetCoreAsync(values, flags) + : StringMSetNXCoreAsync(values, flags); + } + } + + private static readonly Task FalseTask = Task.FromResult(false); + + [RespCommand("mset", Formatter = StringSetFormatter.Formatter)] + private partial bool StringMSetCore( + KeyValuePair[] values, + CommandFlags flags = CommandFlags.None); + + [RespCommand("msetnx", Formatter = StringSetFormatter.Formatter)] + private partial bool StringMSetNXCore( + KeyValuePair[] values, + CommandFlags flags = CommandFlags.None); + + public RedisValue StringSetAndGet( + RedisKey key, + RedisValue value, + TimeSpan? expiry, + When when, + CommandFlags flags) => + throw new NotImplementedException(); + + public RedisValue StringSetAndGet( + RedisKey key, + RedisValue value, + TimeSpan? expiry = null, + bool keepTtl = false, + When when = When.Always, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public bool StringSetBit(RedisKey key, long offset, bool bit, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisValue StringSetRange( + RedisKey key, + long offset, + RedisValue value, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); +} diff --git a/src/RESPite.StackExchange.Redis/RespContextDatabase.VectorSets.cs b/src/RESPite.StackExchange.Redis/RespContextDatabase.VectorSets.cs new file mode 100644 index 000000000..83e48ce83 --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespContextDatabase.VectorSets.cs @@ -0,0 +1,123 @@ +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +internal partial class RespContextDatabase +{ + // Vector Set operations + public Task VectorSetAddAsync( + RedisKey key, + VectorSetAddRequest request, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task VectorSetLengthAsync(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task VectorSetDimensionAsync(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task?> VectorSetGetApproximateVectorAsync( + RedisKey key, + RedisValue member, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task VectorSetGetAttributesJsonAsync( + RedisKey key, + RedisValue member, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task VectorSetInfoAsync(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task VectorSetContainsAsync(RedisKey key, RedisValue member, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task?> VectorSetGetLinksAsync( + RedisKey key, + RedisValue member, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task?> VectorSetGetLinksWithScoresAsync( + RedisKey key, + RedisValue member, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task VectorSetRandomMemberAsync(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task VectorSetRandomMembersAsync( + RedisKey key, + long count, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task VectorSetRemoveAsync(RedisKey key, RedisValue member, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task VectorSetSetAttributesJsonAsync( + RedisKey key, + RedisValue member, + string attributesJson, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task?> VectorSetSimilaritySearchAsync( + RedisKey key, + VectorSetSimilaritySearchRequest query, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public bool VectorSetAdd(RedisKey key, VectorSetAddRequest request, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long VectorSetLength(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public int VectorSetDimension(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Lease? VectorSetGetApproximateVector( + RedisKey key, + RedisValue member, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public string? + VectorSetGetAttributesJson(RedisKey key, RedisValue member, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public VectorSetInfo? VectorSetInfo(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public bool VectorSetContains(RedisKey key, RedisValue member, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Lease? + VectorSetGetLinks(RedisKey key, RedisValue member, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Lease? VectorSetGetLinksWithScores( + RedisKey key, + RedisValue member, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public RedisValue VectorSetRandomMember(RedisKey key, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public RedisValue[] VectorSetRandomMembers(RedisKey key, long count, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public bool VectorSetRemove(RedisKey key, RedisValue member, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public bool VectorSetSetAttributesJson( + RedisKey key, + RedisValue member, + string attributesJson, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Lease? VectorSetSimilaritySearch( + RedisKey key, + VectorSetSimilaritySearchRequest query, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); +} diff --git a/src/RESPite.StackExchange.Redis/RespContextDatabase.cs b/src/RESPite.StackExchange.Redis/RespContextDatabase.cs new file mode 100644 index 000000000..65e7d60b1 --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespContextDatabase.cs @@ -0,0 +1,54 @@ +using RESPite.Connections; +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +/// +/// Implements IDatabase on top of a , which provides access to a RESP context; this +/// could be direct to a known server or routed - the is responsible for +/// that determination. +/// +internal partial class RespContextDatabase : IDatabase +{ + private readonly IConnectionMultiplexer _muxer; + private IRespContextSource _source; + private readonly int _db; + + /// + /// Initializes a new instance of the class. + /// Implements IDatabase on top of a , which provides access to a RESP context; this + /// could be direct to a known server or routed - the is responsible for + /// that determination. + /// + public RespContextDatabase(IConnectionMultiplexer muxer, IRespContextSource source, int db) + { + _muxer = muxer; + _source = source; + _db = db; + } + + // change the proxy being used + protected void SetSource(IRespContextSource source) + => this._source = source; + + private RespContext Context(CommandFlags flags) => _source.Context.With(_db, flags); + + private TimeSpan SyncTimeout => _source.Context.SyncTimeout; + public int Database => _db; + + IConnectionMultiplexer IRedisAsync.Multiplexer => _muxer; + + public bool TryWait(Task task) => task.Wait(SyncTimeout); + + public void Wait(Task task) => _muxer.Wait(task); + + public T Wait(Task task) => _muxer.Wait(task); + + public void WaitAll(params Task[] tasks) => _muxer.WaitAll(tasks); +} + +internal static class MiscExtensions +{ + internal static TimeSpan? NullIfMaxValue(this TimeSpan? value) + => value == TimeSpan.MaxValue ? null : value; +} diff --git a/src/RESPite.StackExchange.Redis/RespContextExtensions.cs b/src/RESPite.StackExchange.Redis/RespContextExtensions.cs new file mode 100644 index 000000000..b03b1019f --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespContextExtensions.cs @@ -0,0 +1,20 @@ +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +internal static class RespContextExtensions +{ + // Question: cache this, or rebuild each time? the latter handles shutdown better. + // internal readonly RespContext Context = proxy.Context.WithDatabase(db); + internal static RespContext With(this in RespContext context, int db, CommandFlags flags) + { + // the flags intentionally align between CommandFlags and RespContextFlags + const RespContext.RespContextFlags FlagMask = RespContext.RespContextFlags.DemandPrimary + | RespContext.RespContextFlags.DemandReplica + | RespContext.RespContextFlags.PreferReplica + | RespContext.RespContextFlags.NoRedirect + | RespContext.RespContextFlags.FireAndForget + | RespContext.RespContextFlags.NoScriptCache; + return context.With(db, (RespContext.RespContextFlags)flags, FlagMask); + } +} diff --git a/src/RESPite.StackExchange.Redis/RespContextServer.cs b/src/RESPite.StackExchange.Redis/RespContextServer.cs new file mode 100644 index 000000000..0e8caf5ee --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespContextServer.cs @@ -0,0 +1,460 @@ +using System.Net; +using RESPite.Connections; +using RESPite.Connections.Internal; +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +/// +/// Implements IServer on top of a , which represents a fixed single connection +/// to a single redis instance. The connection exposed is the "interactive" connection. +/// +internal sealed class RespContextServer(RespMultiplexer muxer, Node node) : IServer +{ + // deliberately not caching this - if the connection changes, we want to know about it + internal RespContext Context(CommandFlags flags) => node.Context.With(-1, flags); + + private TimeSpan SyncTimeout => node.Context.SyncTimeout; + + public IConnectionMultiplexer Multiplexer => muxer; + + public Task PingAsync(CommandFlags flags = CommandFlags.None) + => Context(flags).Send("ping"u8, DateTime.UtcNow, RespContextDatabase.PingParser.Default, RespContextDatabase.PingRaw).AsTask(); + + public bool TryWait(Task task) => task.Wait(Multiplexer.TimeoutMilliseconds); + + public void Wait(Task task) => Multiplexer.Wait(task); + + public T Wait(Task task) => Multiplexer.Wait(task); + + public void WaitAll(params Task[] tasks) => Multiplexer.WaitAll(tasks); + + public TimeSpan Ping(CommandFlags flags = CommandFlags.None) + => Context(flags).Send("ping"u8, DateTime.UtcNow, RespContextDatabase.PingParser.Default, RespContextDatabase.PingRaw).Wait(SyncTimeout); + + public ClusterConfiguration? ClusterConfiguration => throw new NotImplementedException(); + public EndPoint EndPoint => node.Manager.ConnectionFactory.GetEndPoint(node.EndPoint, node.Port); + public RedisFeatures Features => new(Version); + public bool IsConnected => node.IsConnected; + public RedisProtocol Protocol { get; } + bool IServer.IsSlave => node.IsReplica; + public bool IsReplica => node.IsReplica; + public bool AllowSlaveWrites { get; set; } + public bool AllowReplicaWrites { get; set; } + public ServerType ServerType { get; } + public Version Version => throw new NotImplementedException(); + public int DatabaseCount => throw new NotImplementedException(); + public void ClientKill(EndPoint endpoint, CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task ClientKillAsync(EndPoint endpoint, CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public long ClientKill( + long? id = null, + ClientType? clientType = null, + EndPoint? endpoint = null, + bool skipMe = true, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task ClientKillAsync( + long? id = null, + ClientType? clientType = null, + EndPoint? endpoint = null, + bool skipMe = true, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long ClientKill(ClientKillFilter filter, CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task ClientKillAsync( + ClientKillFilter filter, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public ClientInfo[] ClientList(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task ClientListAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public ClusterConfiguration? ClusterNodes(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task ClusterNodesAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public string? ClusterNodesRaw(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task ClusterNodesRawAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public KeyValuePair[] ConfigGet( + RedisValue pattern = default, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task[]> ConfigGetAsync( + RedisValue pattern = default, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public void ConfigResetStatistics(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task ConfigResetStatisticsAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public void ConfigRewrite(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task ConfigRewriteAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public void ConfigSet( + RedisValue setting, + RedisValue value, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task ConfigSetAsync( + RedisValue setting, + RedisValue value, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public long CommandCount(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task CommandCountAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public RedisKey[] CommandGetKeys( + RedisValue[] command, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task CommandGetKeysAsync( + RedisValue[] command, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public string[] CommandList( + RedisValue? moduleName = null, + RedisValue? category = null, + RedisValue? pattern = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task CommandListAsync( + RedisValue? moduleName = null, + RedisValue? category = null, + RedisValue? pattern = null, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public long DatabaseSize( + int database = -1, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task DatabaseSizeAsync( + int database = -1, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public RedisValue Echo( + RedisValue message, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task EchoAsync( + RedisValue message, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public RedisResult Execute(string command, params object[] args) => throw new NotImplementedException(); + + public Task ExecuteAsync(string command, params object[] args) => throw new NotImplementedException(); + + public RedisResult Execute( + string command, + ICollection args, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task ExecuteAsync( + string command, + ICollection args, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public RedisResult Execute(int? database, string command, ICollection args, CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public Task ExecuteAsync(int? database, string command, ICollection args, CommandFlags flags = CommandFlags.None) + => throw new NotImplementedException(); + + public void FlushAllDatabases(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task FlushAllDatabasesAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public void FlushDatabase( + int database = -1, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task FlushDatabaseAsync( + int database = -1, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public ServerCounters GetCounters() => throw new NotImplementedException(); + + public IGrouping>[] Info( + RedisValue section = default, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task>[]> InfoAsync( + RedisValue section = default, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public string? InfoRaw( + RedisValue section = default, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task InfoRawAsync( + RedisValue section = default, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public IEnumerable Keys( + int database, + RedisValue pattern, + int pageSize, + CommandFlags flags) => throw new NotImplementedException(); + + public IEnumerable Keys( + int database = -1, + RedisValue pattern = default, + int pageSize = RedisBase.CursorUtils.DefaultLibraryPageSize, + long cursor = RedisBase.CursorUtils.Origin, + int pageOffset = 0, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public IAsyncEnumerable KeysAsync( + int database = -1, + RedisValue pattern = default, + int pageSize = RedisBase.CursorUtils.DefaultLibraryPageSize, + long cursor = RedisBase.CursorUtils.Origin, + int pageOffset = 0, + CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public DateTime LastSave(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task LastSaveAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public void MakeMaster( + ReplicationChangeOptions options, + TextWriter? log = null) => throw new NotImplementedException(); + + public Task MakePrimaryAsync( + ReplicationChangeOptions options, + TextWriter? log = null) => throw new NotImplementedException(); + + public Role Role(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task RoleAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public void Save( + SaveType type, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task SaveAsync( + SaveType type, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public bool ScriptExists( + string script, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task ScriptExistsAsync( + string script, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public bool ScriptExists( + byte[] sha1, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task ScriptExistsAsync( + byte[] sha1, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public void ScriptFlush(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task ScriptFlushAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public byte[] ScriptLoad( + string script, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task ScriptLoadAsync( + string script, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public LoadedLuaScript ScriptLoad( + LuaScript script, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task ScriptLoadAsync( + LuaScript script, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public void Shutdown( + ShutdownMode shutdownMode = ShutdownMode.Default, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public void SlaveOf( + EndPoint master, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task SlaveOfAsync( + EndPoint master, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public void ReplicaOf( + EndPoint master, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task ReplicaOfAsync( + EndPoint master, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public CommandTrace[] SlowlogGet( + int count = 0, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task SlowlogGetAsync( + int count = 0, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public void SlowlogReset(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task SlowlogResetAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public RedisChannel[] SubscriptionChannels( + RedisChannel pattern = default, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task SubscriptionChannelsAsync( + RedisChannel pattern = default, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public long SubscriptionPatternCount(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task SubscriptionPatternCountAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public long SubscriptionSubscriberCount( + RedisChannel channel, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task SubscriptionSubscriberCountAsync( + RedisChannel channel, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public void SwapDatabases( + int first, + int second, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task SwapDatabasesAsync( + int first, + int second, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public DateTime Time(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task TimeAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public string LatencyDoctor(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task LatencyDoctorAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public long LatencyReset( + string[]? eventNames = null, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task LatencyResetAsync( + string[]? eventNames = null, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public LatencyHistoryEntry[] LatencyHistory( + string eventName, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task LatencyHistoryAsync( + string eventName, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public LatencyLatestEntry[] LatencyLatest(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task LatencyLatestAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public string MemoryDoctor(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task MemoryDoctorAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public void MemoryPurge(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task MemoryPurgeAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public RedisResult MemoryStats(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task MemoryStatsAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public string? MemoryAllocatorStats(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task MemoryAllocatorStatsAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public EndPoint? SentinelGetMasterAddressByName( + string serviceName, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task SentinelGetMasterAddressByNameAsync( + string serviceName, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public EndPoint[] SentinelGetSentinelAddresses( + string serviceName, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task SentinelGetSentinelAddressesAsync( + string serviceName, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public EndPoint[] SentinelGetReplicaAddresses( + string serviceName, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task SentinelGetReplicaAddressesAsync( + string serviceName, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public KeyValuePair[] SentinelMaster( + string serviceName, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task[]> SentinelMasterAsync( + string serviceName, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public KeyValuePair[][] SentinelMasters(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task[][]> SentinelMastersAsync(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public KeyValuePair[][] SentinelSlaves( + string serviceName, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task[][]> SentinelSlavesAsync( + string serviceName, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public KeyValuePair[][] SentinelReplicas( + string serviceName, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task[][]> SentinelReplicasAsync( + string serviceName, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public void SentinelFailover( + string serviceName, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task SentinelFailoverAsync( + string serviceName, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public KeyValuePair[][] SentinelSentinels( + string serviceName, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task[][]> SentinelSentinelsAsync( + string serviceName, + CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); +} diff --git a/src/RESPite.StackExchange.Redis/RespFormatters.cs b/src/RESPite.StackExchange.Redis/RespFormatters.cs new file mode 100644 index 000000000..3a4f497f5 --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespFormatters.cs @@ -0,0 +1,179 @@ +using System.Buffers; +using RESPite.Messages; +using StackExchange.Redis; +using StorageType = StackExchange.Redis.RedisValue.StorageType; + +namespace RESPite.StackExchange.Redis; + +public static class RespFormatters +{ + public static IRespFormatter RedisValue => DefaultFormatter.Instance; + public static IRespFormatter RedisKey => DefaultFormatter.Instance; + public static IRespFormatter RedisKeyArray => DefaultFormatter.Instance; + + private sealed class DefaultFormatter : IRespFormatter, IRespFormatter, IRespFormatter + { + public static readonly DefaultFormatter Instance = new(); + private DefaultFormatter() { } + + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in RedisValue request) + { + writer.WriteCommand(command, 1); + writer.Write(request); + } + + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in RedisKey request) + { + writer.WriteCommand(command, 1); + writer.Write(request); + } + + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in RedisKey[] request) + { + writer.WriteCommand(command, 1 + request.Length); + foreach (var key in request) + { + writer.Write(key); + } + } + } + + // ReSharper disable once MemberCanBePrivate.Global + public static void Write(this ref RespWriter writer, in RedisKey key) + { + if (key.TryGetSimpleBuffer(out var arr)) + { + key.AssertNotNull(); + writer.WriteKey(arr); + } + else + { + var len = key.TotalLength(); + byte[]? lease = null; + var span = len <= 128 ? stackalloc byte[128] : (lease = ArrayPool.Shared.Rent(len)); + var written = key.CopyTo(span); + writer.WriteKey(span.Slice(0, written)); + if (lease is not null) ArrayPool.Shared.Return(lease); + } + } + + internal static void WriteBulkString(this ref RespWriter writer, HashCommandsExtensions.HashExpiryMode when) + { + switch (when) + { + case HashCommandsExtensions.HashExpiryMode.EX: + writer.WriteRaw("$2\r\nEX\r\n"u8); + break; + case HashCommandsExtensions.HashExpiryMode.PX: + writer.WriteRaw("$2\r\nPX\r\n"u8); + break; + case HashCommandsExtensions.HashExpiryMode.EXAT: + writer.WriteRaw("$4\r\nEXAT\r\n"u8); + break; + case HashCommandsExtensions.HashExpiryMode.PXAT: + writer.WriteRaw("$4\r\nPXAT\r\n"u8); + break; + case HashCommandsExtensions.HashExpiryMode.PERSIST: + writer.WriteRaw("$7\r\nPERSIST\r\n"u8); + break; + case HashCommandsExtensions.HashExpiryMode.KEEPTTL: + writer.WriteRaw("$7\r\nKEEPTTL\r\n"u8); + break; + default: + Throw(); + static void Throw() => throw new ArgumentOutOfRangeException(nameof(when)); + break; + } + } + + internal static void WriteBulkString(this ref RespWriter writer, ExpireWhen when) + { + switch (when) + { + case ExpireWhen.HasExpiry: + writer.WriteRaw("$2\r\nXX\r\n"u8); + break; + case ExpireWhen.HasNoExpiry: + writer.WriteRaw("$2\r\nNX\r\n"u8); + break; + case ExpireWhen.GreaterThanCurrentExpiry: + writer.WriteRaw("$2\r\nGT\r\n"u8); + break; + case ExpireWhen.LessThanCurrentExpiry: + writer.WriteRaw("$2\r\nLT\r\n"u8); + break; + default: + Throw(); + static void Throw() => throw new ArgumentOutOfRangeException(nameof(when)); + break; + } + } + + internal static void WriteBulkString(this ref RespWriter writer, ListSide side) + { + switch (side) + { + case ListSide.Left: + writer.WriteRaw("$4\r\nLEFT\r\n"u8); + break; + case ListSide.Right: + writer.WriteRaw("$5\r\nRIGHT\r\n"u8); + break; + default: + Throw(); + static void Throw() => throw new ArgumentOutOfRangeException(nameof(side)); + break; + } + } + + internal static void WriteBulkString(this ref RespWriter writer, Aggregate? aggregate) + { + switch (aggregate!.Value) + { + case Aggregate.Sum: + writer.WriteRaw("$3\r\nSUM\r\n"u8); + break; + case Aggregate.Min: + writer.WriteRaw("$3\r\nMIN\r\n"u8); + break; + case Aggregate.Max: + writer.WriteRaw("$3\r\nMAX\r\n"u8); + break; + default: + Throw(); + static void Throw() => throw new ArgumentOutOfRangeException(nameof(aggregate)); + break; + } + } + + // ReSharper disable once MemberCanBePrivate.Global + public static void Write(this ref RespWriter writer, in RedisValue value) + { + switch (value.Type) + { + case StorageType.Double: + writer.WriteBulkString(value.OverlappedValueDouble); + break; + case StorageType.Int64: + writer.WriteBulkString(value.OverlappedValueInt64); + break; + case StorageType.UInt64: + writer.WriteBulkString(value.OverlappedValueUInt64); + break; + case StorageType.String: + writer.WriteBulkString((string)value.DirectObject!); + break; + case StorageType.Raw: + writer.WriteBulkString((ReadOnlyMemory)value); + break; + case StorageType.Null: + value.AssertNotNull(); + break; + default: + Throw(value.Type); + break; + } + static void Throw(StorageType type) + => throw new InvalidOperationException($"Unexpected {type} value."); + } +} diff --git a/src/RESPite.StackExchange.Redis/RespMultiplexer.cs b/src/RESPite.StackExchange.Redis/RespMultiplexer.cs new file mode 100644 index 000000000..b80d6df1e --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespMultiplexer.cs @@ -0,0 +1,309 @@ +using System.Buffers; +using System.Net; +using RESPite.Connections; +using RESPite.Connections.Internal; +using StackExchange.Redis; +using StackExchange.Redis.Maintenance; +using StackExchange.Redis.Profiling; + +namespace RESPite.StackExchange.Redis; + +public sealed class RespMultiplexer : IConnectionMultiplexer +{ + private readonly RespConnectionManager _connectionManager = new(); + private ConfigurationOptions? _options; + private string _clientName = ""; + + private ConfigurationOptions Options + { + get + { + return _options ?? ThrowNotConnected(); + + static ConfigurationOptions ThrowNotConnected() => + throw new InvalidOperationException("Not connected."); + } + set + { + if (value is null) throw new ArgumentNullException(nameof(Options)); + if (Interlocked.CompareExchange(ref _options, value, null) is not null) + throw new InvalidOperationException("Options have already been set."); + } + } + + /// + public override string ToString() => GetType().Name; + + public ValueTask DisposeAsync() => _connectionManager.DisposeAsync(); + + public void Dispose() => _connectionManager.Dispose(); + + public void Connect(string configurationString, TextWriter? log = null) + => Connect(ConfigurationOptions.Parse(configurationString), log); + + public void Connect(ConfigurationOptions options, TextWriter? log = null) + { + Options = options; + var parsed = ParseOptions(options, out _clientName); + _connectionManager.Connect(parsed, GetEndpoints(options, out var oversized), log); + ArrayPool.Shared.Return(oversized); + } + + public Task ConnectAsync(string configurationString, TextWriter? log = null) + => ConnectAsync(ConfigurationOptions.Parse(configurationString), log); + + public async Task ConnectAsync(ConfigurationOptions options, TextWriter? log = null) + { + Options = options; + var parsed = ParseOptions(options, out _clientName); + await _connectionManager.ConnectAsync(parsed, GetEndpoints(options, out var oversized), log); + ArrayPool.Shared.Return(oversized); + } + + private static RespConfiguration ParseOptions(ConfigurationOptions options, out string clientName) + { + var config = RespConfiguration.Default.AsBuilder(); + clientName = options.ClientName ?? options.Defaults.ClientName; + config.SyncTimeout = TimeSpan.FromMilliseconds(options.SyncTimeout); + config.DefaultDatabase = options.DefaultDatabase ?? 0; + return config.CreateConfiguration(); + } + + private ReadOnlySpan GetEndpoints( + ConfigurationOptions options, + out RespConnectionManager.EndpointPair[] oversized) + { + oversized = ArrayPool.Shared.Rent(Math.Max(options.EndPoints.Count, 1)); + if (options.EndPoints.Count == 0) + { + oversized[0] = new("127.0.0.1", 6379); + return oversized.AsSpan(0, 1); + } + else + { + int count = 0; + foreach (var endpoint in options.EndPoints) + { + if (!_connectionManager.ConnectionFactory.TryParse(endpoint, out var host, out var port)) + { + throw new ArgumentException($"Could not parse host and port from {endpoint}", nameof(endpoint)); + } + + oversized[count++] = new(host, port); + } + + return oversized.AsSpan(0, count); + } + } + + // ReSharper disable once ConvertToAutoProperty + string IConnectionMultiplexer.ClientName => _clientName; + + string IConnectionMultiplexer.Configuration => Options.ToString(includePassword: false); + + private int SyncTimeoutMilliseconds => Options.SyncTimeout; + int IConnectionMultiplexer.TimeoutMilliseconds => Options.SyncTimeout; + + long IConnectionMultiplexer.OperationCount => _connectionManager.OperationCount; + + bool IConnectionMultiplexer.PreserveAsyncOrder + { + get => false; + set { } + } + + public bool IsConnected => _connectionManager.IsConnected; + + bool IConnectionMultiplexer.IsConnecting => _connectionManager.IsConnecting; + + bool IConnectionMultiplexer.IncludeDetailInExceptions + { + get => Options.IncludeDetailInExceptions; + set => Options.IncludeDetailInExceptions = value; + } + + int IConnectionMultiplexer.StormLogThreshold + { + get => 0; + set { } + } + + void IConnectionMultiplexer.RegisterProfiler(Func profilingSessionProvider) { } + + ServerCounters IConnectionMultiplexer.GetCounters() => throw new NotImplementedException(); + +#pragma warning disable CS0067 // Event is never used + private event EventHandler? ErrorMessage; + + private event EventHandler? ConnectionFailed, ConnectionRestored; + private event EventHandler? InternalError; + private event EventHandler? ConfigurationChanged, ConfigurationChangedBroadcast; + private event EventHandler? ServerMaintenanceEvent; + private event EventHandler? HashSlotMoved; +#pragma warning restore CS0067 // Event is never used + + event EventHandler? IConnectionMultiplexer.ErrorMessage + { + add => ErrorMessage += value; + remove => ErrorMessage -= value; + } + + event EventHandler? IConnectionMultiplexer.ConnectionFailed + { + add => ConnectionFailed += value; + remove => ConnectionFailed -= value; + } + + event EventHandler? IConnectionMultiplexer.InternalError + { + add => InternalError += value; + remove => InternalError -= value; + } + + event EventHandler? IConnectionMultiplexer.ConnectionRestored + { + add => ConnectionRestored += value; + remove => ConnectionRestored -= value; + } + + event EventHandler? IConnectionMultiplexer.ConfigurationChanged + { + add => ConfigurationChanged += value; + remove => ConfigurationChanged -= value; + } + + event EventHandler? IConnectionMultiplexer.ConfigurationChangedBroadcast + { + add => ConfigurationChangedBroadcast += value; + remove => ConfigurationChangedBroadcast -= value; + } + + event EventHandler? IConnectionMultiplexer.ServerMaintenanceEvent + { + add => ServerMaintenanceEvent += value; + remove => ServerMaintenanceEvent -= value; + } + + public EndPoint[] GetEndPoints(bool configuredOnly = false) + { + throw new NotImplementedException(); + } + + void IConnectionMultiplexer.Wait(Task task) + { + if (!task.Wait(SyncTimeoutMilliseconds)) + { + ThrowTimeout(); + } + + task.GetAwaiter().GetResult(); + } + + private static void ThrowTimeout() => throw new TimeoutException(); + + T IConnectionMultiplexer.Wait(Task task) + { + if (!task.Wait(SyncTimeoutMilliseconds)) + { + ThrowTimeout(); + } + + return task.GetAwaiter().GetResult(); + } + + void IConnectionMultiplexer.WaitAll(params Task[] tasks) + { + if (!Task.WaitAll(tasks, SyncTimeoutMilliseconds)) + { + ThrowTimeout(); + } + } + + event EventHandler? IConnectionMultiplexer.HashSlotMoved + { + add => HashSlotMoved += value; + remove => HashSlotMoved -= value; + } + + int IConnectionMultiplexer.HashSlot(RedisKey key) => throw new NotImplementedException(); + + ISubscriber IConnectionMultiplexer.GetSubscriber(object? asyncState) => throw new NotImplementedException(); + + public IDatabase GetDatabase(int db = -1, object? asyncState = null) + { + if (db < 0) db = Options.DefaultDatabase ?? 0; + return new RespContextDatabase(this, _connectionManager, db); + } + + IServer IConnectionMultiplexer.GetServer(string host, int port, object? asyncState) => + GetServer(_connectionManager.GetNode(host, port), asyncState); + + IServer IConnectionMultiplexer.GetServer(string hostAndPort, object? asyncState) => + GetServer(_connectionManager.GetNode(hostAndPort), asyncState); + + IServer IConnectionMultiplexer.GetServer(IPAddress host, int port) => + GetServer(_connectionManager.GetNode(host.ToString(), port), null); + + public IServer GetServer(EndPoint endpoint, object? asyncState = null) + { + if (!_connectionManager.ConnectionFactory.TryParse(endpoint, out var host, out var port)) + { + throw new ArgumentException($"Could not parse host and port from {endpoint}", nameof(endpoint)); + } + + return GetServer(_connectionManager.GetNode(host, port), asyncState); + } + + public IServer GetServer(RedisKey key, object? asyncState = null, CommandFlags flags = CommandFlags.None) + { + if (key.IsNull) // just get anything + { + var node = _connectionManager.GetRandomNode(); + if (node is not null) return GetServer(node, asyncState); + } + throw new NotImplementedException(); + } + + private IServer GetServer(Node node, object? asyncState) + { + if (asyncState is not null) ThrowNotSupported(); + if (node.UserObject is not IServer server) + { + server = new RespContextServer(this, node); + node.UserObject = server; + } + + return server; + static void ThrowNotSupported() => throw new NotSupportedException($"{nameof(asyncState)} is not supported"); + } + + IServer[] IConnectionMultiplexer.GetServers() => throw new NotImplementedException(); + + public Task ConfigureAsync(TextWriter? log = null) => throw new NotImplementedException(); + + public bool Configure(TextWriter? log = null) => throw new NotImplementedException(); + + public string GetStatus() => throw new NotImplementedException(); + + public void GetStatus(TextWriter log) => throw new NotImplementedException(); + + public void Close(bool allowCommandsToComplete = true) => throw new NotImplementedException(); + + public Task CloseAsync(bool allowCommandsToComplete = true) => throw new NotImplementedException(); + + public string? GetStormLog() => throw new NotImplementedException(); + + public void ResetStormLog() => throw new NotImplementedException(); + + public long PublishReconfigure(CommandFlags flags = CommandFlags.None) => throw new NotImplementedException(); + + public Task PublishReconfigureAsync(CommandFlags flags = CommandFlags.None) => + throw new NotImplementedException(); + + public int GetHashSlot(RedisKey key) => throw new NotImplementedException(); + + public void ExportConfiguration(Stream destination, ExportOptions options = ExportOptions.All) => + throw new NotImplementedException(); + + public void AddLibraryNameSuffix(string suffix) => throw new NotImplementedException(); +} diff --git a/src/RESPite.StackExchange.Redis/RespParsers.ScanParsers.cs b/src/RESPite.StackExchange.Redis/RespParsers.ScanParsers.cs new file mode 100644 index 000000000..ad562468f --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespParsers.ScanParsers.cs @@ -0,0 +1,36 @@ +using RESPite.Messages; +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +public static partial class RespParsers +{ + internal static IRespParser> ZScanSimple = ScanResultParser.NonLeased; + internal static IRespParser> ZScanLeased = ScanResultParser.Leased; + + private sealed class ScanResultParser : IRespParser> + { + public static readonly ScanResultParser NonLeased = new(false); + public static readonly ScanResultParser Leased = new(true); + private readonly bool _leased; + private ScanResultParser(bool leased) => _leased = leased; + + ScanResult IRespParser>.Parse(ref RespReader reader) + { + reader.DemandAggregate(); + reader.MoveNextScalar(); + var cursor = reader.ReadInt64(); + reader.MoveNextAggregate(); + if (_leased) + { + var values = DefaultParser.ReadLeasedSortedSetEntryArray(ref reader, out int count); + return new(cursor, values, count); + } + else + { + var values = DefaultParser.ReadSortedSetEntryArray(ref reader); + return new(cursor, values); + } + } + } +} diff --git a/src/RESPite.StackExchange.Redis/RespParsers.cs b/src/RESPite.StackExchange.Redis/RespParsers.cs new file mode 100644 index 000000000..f0a4f0258 --- /dev/null +++ b/src/RESPite.StackExchange.Redis/RespParsers.cs @@ -0,0 +1,201 @@ +using RESPite.Internal; +using RESPite.Messages; +using StackExchange.Redis; + +namespace RESPite.StackExchange.Redis; + +public static partial class RespParsers +{ + public static IRespParser RedisValue => DefaultParser.Instance; + public static IRespParser RedisValueArray => DefaultParser.Instance; + public static IRespParser RedisKey => DefaultParser.Instance; + public static IRespParser> BytesLease => DefaultParser.Instance; + public static IRespParser HashEntryArray => DefaultParser.Instance; + public static IRespParser SortedSetEntryArray => DefaultParser.Instance; + public static IRespParser SortedSetEntry => DefaultParser.Instance; + public static IRespParser TimeSpanFromSeconds => TimeParser.FromSeconds; + public static IRespParser TimeSpanArrayFromSeconds => TimeParser.FromSeconds; + public static IRespParser DateTimeFromSeconds => TimeParser.FromSeconds; + public static IRespParser DateTimeArrayFromSeconds => TimeParser.FromSeconds; + public static IRespParser TimeSpanFromMilliseconds => TimeParser.FromMilliseconds; + public static IRespParser TimeSpanArrayFromMilliseconds => TimeParser.FromMilliseconds; + public static IRespParser DateTimeFromMilliseconds => TimeParser.FromMilliseconds; + public static IRespParser DateTimeArrayFromMilliseconds => TimeParser.FromMilliseconds; + internal static IRespParser Int64Index => Int64DefaultNegativeOneParser.Instance; + internal static IRespParser ListPopResult => DefaultParser.Instance; + + public static RedisValue ReadRedisValue(ref RespReader reader) + { + reader.DemandScalar(); + if (reader.IsNull) return global::StackExchange.Redis.RedisValue.Null; + if (reader.TryReadInt64(out var i64)) return i64; + if (reader.TryReadDouble(out var f64)) return f64; + + if (reader.UnsafeTryReadShortAscii(out var s)) return s; + return reader.ReadByteArray(); + } + + public static RedisKey ReadRedisKey(ref RespReader reader) + { + reader.DemandScalar(); + if (reader.IsNull) return global::StackExchange.Redis.RedisKey.Null; + if (reader.UnsafeTryReadShortAscii(out var s)) return s; + return reader.ReadByteArray(); + } + + private static readonly RespReader.Projection SharedReadRedisValue = ReadRedisValue; + private static readonly RespReader.Projection SharedReadRedisKey = ReadRedisKey; + + private sealed class DefaultParser : IRespParser, IRespParser, + IRespParser>, IRespParser, IRespParser, + IRespParser, IRespParser, IRespParser, + IRespParser + { + private DefaultParser() { } + public static readonly DefaultParser Instance = new(); + + RedisValue IRespParser.Parse(ref RespReader reader) => ReadRedisValue(ref reader); + + RedisKey IRespParser.Parse(ref RespReader reader) => ReadRedisKey(ref reader); + + Lease IRespParser>.Parse(ref RespReader reader) + { + reader.DemandScalar(); + if (reader.IsNull) return null!; + var len = reader.ScalarLength(); + var lease = Lease.Create(len); + reader.CopyTo(lease.Span); + return lease; + } + + RedisValue[] IRespParser.Parse(ref RespReader reader) + => reader.ReadArray(SharedReadRedisValue, scalar: true)!; + + RedisKey[] IRespParser.Parse(ref RespReader reader) + => reader.ReadArray(SharedReadRedisKey, scalar: true)!; + + HashEntry[] IRespParser.Parse(ref RespReader reader) + { + return reader.ReadPairArray( + SharedReadRedisValue, + SharedReadRedisValue, + static (x, y) => new HashEntry(x, y), + scalar: true)!; + + /* we could also do this locally: + reader.DemandAggregate(); + if (reader.IsNull) return null!; + var len = reader.AggregateLength() / 2; + if (len == 0) return []; + + var result = new HashEntry[len]; + for (int i = 0; i < result.Length; i++) + { + reader.MoveNextScalar(); + var x = ReadRedisValue(ref reader); + reader.MoveNextScalar(); + var y = ReadRedisValue(ref reader); + result[i] = new HashEntry(x, y); + } + + return result; + */ + } + + ListPopResult IRespParser.Parse(ref RespReader reader) + { + if (reader.IsNull) return global::StackExchange.Redis.ListPopResult.Null; + reader.DemandAggregate(); + reader.MoveNext(); + var key = ReadRedisKey(ref reader); + reader.MoveNext(); + var arr = reader.ReadArray(SharedReadRedisValue, scalar: true)!; + return new(key, arr); + } + + SortedSetEntry[] IRespParser.Parse(ref RespReader reader) + => ReadSortedSetEntryArray(ref reader); + + internal static SortedSetEntry[] ReadSortedSetEntryArray(ref RespReader reader) => reader.ReadPairArray( + SharedReadRedisValue, + static (ref RespReader reader) => reader.ReadDouble(), + static (x, y) => new SortedSetEntry(x, y), + scalar: true)!; + + internal static SortedSetEntry[] ReadLeasedSortedSetEntryArray(ref RespReader reader, out int count) + => reader.ReadLeasedPairArray( + SharedReadRedisValue, + static (ref RespReader reader) => reader.ReadDouble(), + static (x, y) => new SortedSetEntry(x, y), + out count, + scalar: true)!; + + SortedSetEntry? IRespParser.Parse(ref RespReader reader) + { + if (reader.IsNull) return null; + reader.DemandAggregate(); + if (reader.AggregateLength() < 2) return null; + reader.MoveNext(); + var member = ReadRedisValue(ref reader); + reader.MoveNext(); + var score = reader.ReadDouble(); + return new SortedSetEntry(member, score); + } + } +} + +internal sealed class Int64DefaultNegativeOneParser : IRespParser, IRespInlineParser +{ + private Int64DefaultNegativeOneParser() { } + public static readonly Int64DefaultNegativeOneParser Instance = new(); + public long Parse(ref RespReader reader) => reader.IsNull ? -1 : reader.ReadInt64(); +} + +internal sealed class TimeParser : IRespParser, IRespParser, IRespInlineParser, + IRespParser, IRespParser +{ + private readonly bool _millis; + public static readonly TimeParser FromMilliseconds = new(true); + public static readonly TimeParser FromSeconds = new(false); + + private readonly RespReader.Projection _readTimeSpan; + private readonly RespReader.Projection _readDateTime; + private TimeParser(bool millis) + { + _millis = millis; + _readTimeSpan = ReadTimeSpan; + _readDateTime = ReadDateTime; + } + + TimeSpan? IRespParser.Parse(ref RespReader reader) => ReadTimeSpan(ref reader); + private TimeSpan? ReadTimeSpan(ref RespReader reader) + { + if (reader.IsNull) return null; + if (reader.IsAggregate) + { + reader.MoveNext(); // take first element from aggregate + if (reader.IsNull) return null; + } + var value = reader.ReadInt64(); + if (value < 0) return null; // -1 means no expiry and -2 means key does not exist + return _millis ? TimeSpan.FromMilliseconds(value) : TimeSpan.FromSeconds(value); + } + + DateTime? IRespParser.Parse(ref RespReader reader) => ReadDateTime(ref reader); + private DateTime? ReadDateTime(ref RespReader reader) + { + if (reader.IsNull) return null; + if (reader.IsAggregate) + { + reader.MoveNext(); // take first element from aggregate + if (reader.IsNull) return null; + } + var value = reader.ReadInt64(); + if (value < 0) return null; // -1 means no expiry and -2 means key does not exist + return _millis ? RedisBase.UnixEpoch.AddMilliseconds(value) : RedisBase.UnixEpoch.AddSeconds(value); + } + + TimeSpan?[] IRespParser.Parse(ref RespReader reader) => reader.ReadArray(_readTimeSpan, scalar: true)!; + + DateTime?[] IRespParser.Parse(ref RespReader reader) => reader.ReadArray(_readDateTime, scalar: true)!; +} diff --git a/src/RESPite.StackExchange.Redis/readme.md b/src/RESPite.StackExchange.Redis/readme.md new file mode 100644 index 000000000..4b193bf5d --- /dev/null +++ b/src/RESPite.StackExchange.Redis/readme.md @@ -0,0 +1,5 @@ +# RESPite.StackExchange.Redis + +This libary is a bridge between StackExchange.Redis and RESPite. It provides the `IConnectionMultiplexer`, +`IDatabase`, `IServer` APIs, but implemented using the `RespConnection` and `RespContext` primitives from +RESPite. This is the intended direction for StackExchange.Redis vFuture. \ No newline at end of file diff --git a/src/RESPite/Connections/IRespContextSource.cs b/src/RESPite/Connections/IRespContextSource.cs new file mode 100644 index 000000000..2305e45c3 --- /dev/null +++ b/src/RESPite/Connections/IRespContextSource.cs @@ -0,0 +1,6 @@ +namespace RESPite.Connections; + +public interface IRespContextSource +{ + ref readonly RespContext Context { get; } +} diff --git a/src/RESPite/Connections/Internal/BasicBatchConnection.cs b/src/RESPite/Connections/Internal/BasicBatchConnection.cs new file mode 100644 index 000000000..6490acf51 --- /dev/null +++ b/src/RESPite/Connections/Internal/BasicBatchConnection.cs @@ -0,0 +1,84 @@ +using System.Buffers; + +namespace RESPite.Connections.Internal; + +/// +/// Holds basic RespOperation, queue and release - turns +/// multiple send/send-many calls into a single send-many call. +/// +internal sealed class BasicBatchConnection(in RespContext context, int sizeHint) : BufferingBatchConnection(context, sizeHint) +{ + public override Task FlushAsync() + { + try + { + var count = Flush(out var oversized, out var single); + return count switch + { + 0 => Task.CompletedTask, + 1 => Tail.WriteAsync(single!), + _ => SendAndRecycleAsync(Tail, oversized, count), + }; + } + catch (Exception ex) + { + OnConnectionError(ex); + throw; + } + + static async Task SendAndRecycleAsync(RespConnection tail, RespOperation[] oversized, int count) + { + try + { + await tail.WriteAsync(oversized.AsMemory(0, count)).ConfigureAwait(false); + ArrayPool.Shared.Return(oversized); // only on success, in case captured + } + catch (Exception ex) + { + TrySetException(oversized.AsSpan(0, count), ex); + throw; + } + } + } + + public override void Flush() + { + string operation = nameof(Flush); + int count; + RespOperation[] oversized; + RespOperation single; + try + { + count = Flush(out oversized, out single); + switch (count) + { + case 0: + return; + case 1: + operation = nameof(Tail.Write); + Tail.Write(single!); + return; + } + } + catch (Exception ex) + { + OnConnectionError(ex, operation); + throw; + } + + try + { + Tail.Write(oversized.AsSpan(0, count)); + } + catch (Exception ex) + { + TrySetException(oversized.AsSpan(0, count), ex); + throw; + } + finally + { + // in the sync case, Send takes a span - hence can't have been captured anywhere; always recycle + ArrayPool.Shared.Return(oversized); + } + } +} diff --git a/src/RESPite/Connections/Internal/BufferingBatchConnection.cs b/src/RESPite/Connections/Internal/BufferingBatchConnection.cs new file mode 100644 index 000000000..1cf0264ff --- /dev/null +++ b/src/RESPite/Connections/Internal/BufferingBatchConnection.cs @@ -0,0 +1,165 @@ +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using RESPite.Internal; + +namespace RESPite.Connections.Internal; + +/// +/// Collects messages into a buffer, and then flushes them all at once. Subclass defines how to flush. +/// +internal abstract class BufferingBatchConnection(in RespContext context, int sizeHint) : RespBatch(context) +{ + internal static void Return(ref RespOperation[] buffer) + { + if (buffer.Length != 0) + { + DebugCounters.OnBatchBufferReturn(buffer.Length); + ArrayPool.Shared.Return(buffer); + buffer = []; + } + } + + private static RespOperation[] Rent(int sizeHint) + { + if (sizeHint <= 0) return []; + var arr = ArrayPool.Shared.Rent(sizeHint); + DebugCounters.OnBatchBufferLease(arr.Length); + return arr; + } + + private RespOperation[] _buffer = Rent(sizeHint); + + private int _count = 0; + + protected object SyncLock => this; + + protected override void OnDispose(bool disposing) + { + if (disposing) + { + lock (SyncLock) + { + /* everyone else checks disposal inside the lock; + the base type already marked as disposed, so: + once we're past this point, we can be sure that no more + items will be added */ + Debug.Assert(IsDisposed); + } + + var buffer = _buffer; + _buffer = []; + var span = buffer.AsSpan(0, _count); + foreach (var message in span) + { + message.Message.TrySetException(message.Token, CreateObjectDisposedException()); + } + + Return(ref buffer); + ConnectionError = null; + } + + base.OnDispose(disposing); + } + + internal override int OutstandingOperations => _count; // always a thread-race, no point locking + + public override void Write(in RespOperation message) + { + lock (SyncLock) + { + ThrowIfDisposed(); + EnsureSpaceForLocked(1); + _buffer[_count++] = message; + } + } + + public override void EnsureCapacity(int additionalCount) + { + if (additionalCount > _buffer.Length - _count) + { + lock (SyncLock) + { + EnsureSpaceForLocked(additionalCount); + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void EnsureSpaceForLocked(int add) + { + var required = _count + add; + if (_buffer.Length < required) GrowLocked(required); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void GrowLocked(int required) + { + const int maxLength = 0X7FFFFFC7; // not directly available on down-level runtimes :( + var newCapacity = _buffer.Length * 2; // try doubling + if ((uint)newCapacity > maxLength) newCapacity = maxLength; // account for max + if (newCapacity < required) newCapacity = required; // in case doubling wasn't enough + + var newBuffer = Rent(newCapacity); + DebugCounters.OnBatchGrow(_count); + _buffer.AsSpan(0, _count).CopyTo(newBuffer); + Return(ref _buffer); + _buffer = newBuffer; + } + + internal override void Write(ReadOnlySpan messages) + { + if (messages.Length != 0) + { + lock (SyncLock) + { + ThrowIfDisposed(); + EnsureSpaceForLocked(messages.Length); + messages.CopyTo(_buffer.AsSpan(_count)); + _count += messages.Length; + } + } + } + + protected int Flush(out RespOperation[] oversized, out RespOperation single) + { + lock (SyncLock) + { + var count = _count; + switch (_count) + { + case 0: + // nothing to do, keep our local buffer + oversized = []; + single = default; + return 0; + case 1: + // but keep our local buffer, just reset the count + oversized = []; + single = _buffer[0]; + _count = 0; + return 1; + default: + // hand the caller our buffer, and reset + oversized = _buffer; + single = default; + _buffer = []; // we *expect* people to only flush once, so: don't rent a new one + _count = 0; + return count; + } + } + } + + protected void OnConnectionError(Exception exception, [CallerMemberName] string operation = "") + => OnConnectionError(ConnectionError, exception, operation); + + public override event EventHandler? ConnectionError; + + protected static void TrySetException(ReadOnlySpan messages, Exception ex) + { + foreach (var message in messages) + { + message.Message.TrySetException(message.Token, ex); + } + } +} diff --git a/src/RESPite/Connections/Internal/ConfiguredConnection.cs b/src/RESPite/Connections/Internal/ConfiguredConnection.cs new file mode 100644 index 000000000..ad5ae6ce7 --- /dev/null +++ b/src/RESPite/Connections/Internal/ConfiguredConnection.cs @@ -0,0 +1,4 @@ +namespace RESPite.Connections.Internal; + +internal sealed class ConfiguredConnection(in RespContext tail, RespConfiguration configuration) + : DecoratorConnection(tail, configuration); diff --git a/src/RESPite/Connections/Internal/DecoratorConnection.cs b/src/RESPite/Connections/Internal/DecoratorConnection.cs new file mode 100644 index 000000000..83ac600b6 --- /dev/null +++ b/src/RESPite/Connections/Internal/DecoratorConnection.cs @@ -0,0 +1,74 @@ +namespace RESPite.Connections.Internal; + +internal abstract class DecoratorConnection : RespConnection +{ + protected readonly RespConnection Tail; + + public DecoratorConnection(in RespContext tail, RespConfiguration? configuration = null) + : base(tail, configuration) + { + Tail = tail.Connection; + } + + internal override void ThrowIfUnhealthy() => Tail.ThrowIfUnhealthy(); + + protected virtual bool OwnsConnection => true; + + internal override bool IsHealthy => base.IsHealthy & Tail.IsHealthy; + internal override int OutstandingOperations => Tail.OutstandingOperations; + + protected override void OnDispose(bool disposing) + { + if (PrivateConnectionError is not null) + { + PrivateConnectionError = null; // force unsubscribe + Tail.ConnectionError -= _onConnectionError; + } + if (disposing & OwnsConnection) Tail.Dispose(); + } + + protected override ValueTask OnDisposeAsync() => + OwnsConnection ? Tail.DisposeAsync() : default; + + // Note that default behaviour *does not* add a dispose check, as it + // assumes that the connection is "owned", and therefore the tail will throw. + public override void Write(in RespOperation message) => Tail.Write(message); + + internal override void Write(ReadOnlySpan messages) => Tail.Write(messages); + + public override Task WriteAsync(in RespOperation message) => Tail.WriteAsync(in message); + + internal override Task WriteAsync(ReadOnlyMemory messages) => Tail.WriteAsync(messages); + + private event EventHandler? PrivateConnectionError; // to wrap "sender" + private EventHandler? _onConnectionError; // local lazy callback + public override event EventHandler? ConnectionError + { + add + { + if (value is not null) + { + if (PrivateConnectionError is null) + { + Tail.ConnectionError += _onConnectionError ??= OnConnectionError; + } + + PrivateConnectionError += value; + } + } + remove + { + if (value is not null) + { + PrivateConnectionError -= value; + if (PrivateConnectionError is null) // last unsubscribe + { + Tail.ConnectionError -= _onConnectionError; + } + } + } + } + + private void OnConnectionError(object? sender, RespConnectionErrorEventArgs e) + => PrivateConnectionError?.Invoke(this, e); // mask sender +} diff --git a/src/RESPite/Connections/Internal/MergingBatchConnection.cs b/src/RESPite/Connections/Internal/MergingBatchConnection.cs new file mode 100644 index 000000000..fe8e91914 --- /dev/null +++ b/src/RESPite/Connections/Internal/MergingBatchConnection.cs @@ -0,0 +1,80 @@ +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using RESPite.Internal; +using RESPite.Messages; + +namespace RESPite.Connections.Internal; + +/// +/// Holds basic RespOperation, queue and release - turns +/// multiple send calls into a single multi-message send. +/// +internal sealed class MergingBatchConnection(in RespContext context, int sizeHint) : BufferingBatchConnection(context, sizeHint) +{ + // Collate new messages in a batch-specific buffer, rather than the usual thread-local one; this means + // that all the messages will be in contiguous memory. + private readonly BlockBufferSerializer _serializer = BlockBufferSerializer.Create(retainChain: true); + + protected override void OnDispose(bool disposing) + { + if (disposing) + { + _serializer.Clear(); + } + + base.OnDispose(disposing); + } + + internal override BlockBufferSerializer Serializer + { + get + { + ThrowIfDisposed(); + return _serializer; + } + } + + private bool Flush(out RespOperation single) + { + lock (SyncLock) + { + var payload = _serializer.Flush(); + var count = Flush(out var oversized, out single); + switch (count) + { + case 0: + Debug.Assert(payload.IsEmpty); + return false; + case 1: + Debug.Assert(!payload.IsEmpty); + // send as a single-message we don't need the extra add-ref on the entire payload + BlockBufferSerializer.BlockBuffer.Release(in payload); + + return true; + default: + Debug.Assert(!payload.IsEmpty); + var msg = RespMultiMessage.Get(oversized, count); + msg.Init(payload, Context.CancellationToken); + single = new(msg); + return true; + } + } + } + + public override Task FlushAsync() + { + return Flush(out var single) + ? Tail.WriteAsync(single) + : Task.CompletedTask; + } + + public override void Flush() + { + if (Flush(out var single)) + { + Tail.Write(single); + } + } +} diff --git a/src/RESPite/Connections/Internal/Node.cs b/src/RESPite/Connections/Internal/Node.cs new file mode 100644 index 000000000..d58aab3ca --- /dev/null +++ b/src/RESPite/Connections/Internal/Node.cs @@ -0,0 +1,285 @@ +using RESPite.Internal; + +namespace RESPite.Connections.Internal; + +internal sealed class Node : IDisposable, IAsyncDisposable, IRespContextSource +{ + private bool _isDisposed; + public override string ToString() => Label; + public string EndPoint { get; } + public int Port { get; } + private string? _label; + internal string Label => _label ??= $"{EndPoint}:{Port}"; + internal RespConnectionManager Manager { get; } + + public Node(RespConnectionManager manager, string endPoint, int port) + { + Manager = manager; + EndPoint = endPoint; + Port = port; + _interactive = new(this, false); + } + + internal object? UserObject { get; set; } + public bool IsConnected => _interactive.IsConnected; + public bool IsConnecting => _interactive.IsConnecting; + public bool IsReplica { get; private set; } + + public void Dispose() + { + _isDisposed = true; + _interactive.Dispose(); + _subscription?.Dispose(); + } + + public async ValueTask DisposeAsync() + { + _isDisposed = true; + await _interactive.DisposeAsync().ConfigureAwait(false); + if (_subscription is { } obj) + { + await obj.DisposeAsync().ConfigureAwait(false); + } + } + + private readonly NodeConnection _interactive; + private NodeConnection? _subscription; + + public ref readonly RespContext Context => ref _interactive.Context; + + public RespConnection InteractiveConnection => _interactive.Connection; + + public Task ConnectAsync( + TextWriter? log = null, + bool force = false, + bool pubSub = false) + { + if (_isDisposed) return Task.FromResult(false); + if (!pubSub) + { + return _interactive.ConnectAsync(log, force); + } + + _subscription ??= new(this, pubSub); + return _subscription.ConnectAsync(log, force); + } + + public Shard AsShard() + { + return new( + 0, + int.MaxValue, + Port, + IsReplica ? ShardFlags.Replica : ShardFlags.None, + EndPoint, + "", + this); + } +} + +internal sealed class NodeConnection : IDisposable, IAsyncDisposable, IRespContextSource +{ + // private EventHandler? _onConnectionError; + private readonly Node _node; + private readonly bool _pubSub; + + public override string ToString() => Label; + + public NodeConnection(Node node, bool pubSub) + { + _node = node; + _pubSub = pubSub; + } + + private string? _label; + private string Label => _label ??= _pubSub ? $"{_node.Label}/s" : _node.Label; + public Node Node => _node; + private int _state = (int)NodeState.Disconnected; + + private NodeState State => (NodeState)_state; + + private enum NodeState + { + Disconnected, + Connecting, + Connected, + Faulted, + Disposed, + } + + public bool IsFaulted => State == NodeState.Faulted; + public bool IsConnected => State == NodeState.Connected; + public bool IsConnecting => State == NodeState.Connecting; + + public ref readonly RespContext Context => ref _connection.Context; + private RespConnection _connection = RespContext.Null.Connection; + public RespConnection Connection => _connection; + + public async Task ConnectAsync( + TextWriter? log = null, + bool force = false, + CancellationToken cancellationToken = default) + { + int state; + bool connecting = false; + do + { + state = _state; + switch ((NodeState)state) + { + case NodeState.Connected when force: + case NodeState.Connecting when force: + log.LogLocked($"[{Label}] (already {(NodeState)state}, but forcing reconnect...)"); + break; // reconnect anyway! + case NodeState.Connected: + case NodeState.Connecting: + log.LogLocked($"[{Label}] (already {(NodeState)state})"); + return true; + case NodeState.Disposed: + log.LogLocked($"[{Label}] (already {(NodeState)state})"); + return false; + } + } + // otherwise: move to connecting (or retry, if there was a race) + while (Interlocked.CompareExchange(ref _state, (int)NodeState.Connecting, state) != state); + + try + { + // observe outcome of CEX above (noting that if forcing, we don't do that CEX) + if (State == NodeState.Connecting) state = (int)NodeState.Connecting; + + log.LogLocked($"[{Label}] connecting..."); + connecting = true; + var manager = _node.Manager; + var connection = await manager.ConnectionFactory.ConnectAsync( + _node.EndPoint, + _node.Port, + cancellationToken: cancellationToken).ConfigureAwait(false); + connecting = false; + + log.LogLocked($"[{Label}] Performing handshake..."); + // TODO: handshake + + // finalize the connections + log.LogLocked($"[{Label}] Finalizing..."); + var oldConnection = _connection; + _connection = connection.Synchronized(); + await oldConnection.DisposeAsync().ConfigureAwait(false); + + // check nothing changed while we weren't looking + if (Interlocked.CompareExchange(ref _state, (int)NodeState.Connected, state) == state) + { + // success + log.LogLocked($"[{Label}] (success)"); + /* + connection.ConnectionError += _onConnectionError ??= OnConnectionError; + + if (state == (int)NodeState.Faulted) OnConnectionRestored(); + */ + return true; + } + + log.LogLocked($"[{Label}] (unable to complete; became {State})"); + _connection = oldConnection; + return false; + } + catch (Exception ex) + { + log.LogLocked($"[{Label}] Faulted: {ex.Message}{(connecting ? " (while connecting)" : "")}"); + // something failed; cleanup and move to faulted, unless disposed + if (State != NodeState.Disposed) + { + _state = (int)NodeState.Faulted; + } + + var conn = _connection; + _connection = RespContext.Null.Connection; + await conn.DisposeAsync(); + + /* + var failureType = ConnectionFailureType.InternalFailure; + if (connecting) + { + failureType = ConnectionFailureType.UnableToConnect; + } + else if (ex is SocketException) + { + failureType = ConnectionFailureType.SocketFailure; + } + else if (ex is ObjectDisposedException) + { + failureType = ConnectionFailureType.ConnectionDisposed; + } + + OnConnectionError(failureType, ex); + */ + return false; + } + } +/* + private void OnConnectionError(object? sender, RespConnection.RespConnectionErrorEventArgs e) + { + var handler = _multiplexer.DirectConnectionFailed; + if (handler is not null) + { + handler(_multiplexer, new ConnectionFailedEventArgs( + handler, + _multiplexer, + _endPoint, + _connectionType, + ConnectionFailureType.InternalFailure, + e.Exception, + Label)); + } + } + + private void OnConnectionError(ConnectionFailureType failureType, Exception? exception = null) + { + var handler = _multiplexer.DirectConnectionFailed; + if (handler is not null) + { + handler(_multiplexer, new ConnectionFailedEventArgs( + handler, + _multiplexer, + _endPoint, + _connectionType, + failureType, + exception, + Label)); + } + } + + private void OnConnectionRestored() + { + var handler = _multiplexer.DirectConnectionRestored; + if (handler is not null) + { + handler(_multiplexer, new ConnectionFailedEventArgs( + handler, + _multiplexer, + _endPoint, + _connectionType, + ConnectionFailureType.None, + null, + Label)); + } + }*/ + + public void Dispose() + { + _state = (int)NodeState.Disposed; + var conn = _connection; + _connection = RespContext.Null.Connection; + conn.Dispose(); + // OnConnectionError(ConnectionFailureType.ConnectionDisposed); + } + + public async ValueTask DisposeAsync() + { + _state = (int)NodeState.Disposed; + var conn = _connection; + _connection = RespContext.Null.Connection; + await conn.DisposeAsync().ConfigureAwait(false); + // OnConnectionError(ConnectionFailureType.ConnectionDisposed); + } +} diff --git a/src/RESPite/Connections/Internal/NullConnection.cs b/src/RESPite/Connections/Internal/NullConnection.cs new file mode 100644 index 000000000..a803e2c48 --- /dev/null +++ b/src/RESPite/Connections/Internal/NullConnection.cs @@ -0,0 +1,79 @@ +namespace RESPite.Connections.Internal; + +internal sealed class NullConnection : RespConnection +{ + private enum FailureMode + { + Default, + Disposed, + NonRoutable, + } + + private readonly FailureMode _failureMode; + + public static NullConnection WithConfiguration(RespConfiguration configuration) + => ReferenceEquals(configuration, RespConfiguration.Default) + ? Default + : new(configuration, FailureMode.Default); + + // convenience singletons (all but Default are lazily created) + public static readonly NullConnection Default = new(RespConfiguration.Default, FailureMode.Default); + private static NullConnection? _disposed, _nonRoutable; + public static NullConnection Disposed => + _disposed ??= new(RespConfiguration.Default, FailureMode.Disposed); + public static NullConnection NonRoutable => + _nonRoutable ??= new(RespConfiguration.Default, FailureMode.NonRoutable); + + internal override int OutstandingOperations => 0; + + private NullConnection(RespConfiguration configuration, FailureMode failureMode) : base(configuration) + => _failureMode = failureMode; + + private void SetError(in RespOperation message) + { + message.TrySetException(_failureMode switch + { + FailureMode.Disposed => new ObjectDisposedException(nameof(RespConnection)), + FailureMode.NonRoutable => new InvalidOperationException("No connection is available for this operation."), + _ => new NotSupportedException("Null connections do not support sending messages."), + }); + } + + public override void Write(in RespOperation message) => SetError(in message); + + public override Task WriteAsync(in RespOperation message) + { + SetError(message); + return Task.CompletedTask; + } + + internal override void Write(ReadOnlySpan messages) + { + foreach (var message in messages) + { + SetError(in message); + } + } + + internal override Task WriteAsync(ReadOnlyMemory messages) + { + foreach (var message in messages.Span) + { + SetError(in message); + } + + return Task.CompletedTask; + } + + public override event EventHandler? ConnectionError + { + add + { + } + remove + { + } + } + + internal override void ThrowIfUnhealthy() { } +} diff --git a/src/RESPite/Connections/Internal/RoutedConnection.cs b/src/RESPite/Connections/Internal/RoutedConnection.cs new file mode 100644 index 000000000..de9ea6310 --- /dev/null +++ b/src/RESPite/Connections/Internal/RoutedConnection.cs @@ -0,0 +1,110 @@ +using System.Runtime.CompilerServices; +using RESPite.Internal; + +namespace RESPite.Connections.Internal; + +internal sealed class RoutedConnection : RespConnection +{ + private Shard[] _shards = []; + + private Shard[] _primaries = [], _replicas = []; + + public void SetRoutingTable(ReadOnlySpan shards) + { + if (shards.Length == _shards.Length) + { + bool match = true; + int index = 0; + Shard previous = default; + foreach (ref readonly Shard shard in shards) + { + if (index != 0 && previous.CompareTo(shard) > 0) ThrowNotSorted(); + if (!shard.Equals(_shards[index++])) + { + match = false; + break; + } + + previous = shard; + } + + if (match) return; // nothing has changed + } + + _shards = shards.ToArray(); + + static void ThrowNotSorted() => + throw new InvalidOperationException($"The input to {nameof(SetRoutingTable)} must be pre-sorted."); + } + + public override event EventHandler? ConnectionError + { + add => throw new NotSupportedException(); + remove => throw new NotSupportedException(); + } + + internal override int OutstandingOperations + { + get + { + int count = 0; + foreach (var shard in _shards) + { + if (shard.GetConnection() is { } conn) count += conn.OutstandingOperations; + } + + return count; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public override void Write(in RespOperation message) + { + // simplest thing possible for now; long term, we could do bunching. + var conn = Select( + replicas: (message.Flags & RespMessageBase.StateFlags.Replica) != 0, + slot: message.Slot); + if (conn is null) + { + WriteNonPreferred(message); + } + else + { + // this is the happy path + conn.Write(in message); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void WriteNonPreferred(in RespOperation message) + { + var flags = message.Flags; + var conn = (flags & RespMessageBase.StateFlags.Demand) == 0 + ? Select((flags & RespMessageBase.StateFlags.Replica) == 0, message.Slot) + : null; + if (conn is null) + { + message.TrySetException( + new InvalidOperationException("No connection is available to handle this request.")); + } + else + { + conn.Write(in message); + } + } + + private RespConnection? Select(bool replicas, int slot) + { + var shards = replicas ? _replicas : _primaries; + foreach (var shard in shards) + { + if ((shard.From <= slot & shard.To >= slot) + && shard.GetConnection() is { IsHealthy: true } conn) + { + return conn; + } + } + + return null; + } +} diff --git a/src/RESPite/Connections/Internal/Shard.cs b/src/RESPite/Connections/Internal/Shard.cs new file mode 100644 index 000000000..d35a02b70 --- /dev/null +++ b/src/RESPite/Connections/Internal/Shard.cs @@ -0,0 +1,84 @@ +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +namespace RESPite.Connections.Internal; + +[Flags] +internal enum ShardFlags +{ + None = 0, + Replica = 1, +} + +internal readonly struct Shard( + int from, + int to, + int port, + ShardFlags flags, + string primary, + string secondary, + IRespContextSource? source) : IEquatable, IComparable, IComparable +{ + public readonly int From = from; + public readonly int To = to; + public readonly int Port = port; + public readonly ShardFlags Flags = flags; + public readonly string Primary = primary; + public readonly string Secondary = secondary; + public bool Repliace => (Flags & ShardFlags.Replica) != 0; + + private readonly IRespContextSource? source = source; + + public override string ToString() => $"[{From}-{To}] {source}"; + public int CompareTo(object? obj) => obj is Shard shard ? CompareTo(in shard) : -1; + + public override int GetHashCode() => From ^ To ^ Port ^ (int)Flags ^ Primary.GetHashCode(); + + public override bool Equals([NotNullWhen(true)] object? obj) + => obj is Shard other && Equals(other); + + bool IEquatable.Equals(Shard other) => Equals(in other); + + public bool Equals(in Shard other) => + (From == other.From + & To == other.To + & Port == other.Port + & Flags == other.Flags + & Primary == other.Primary + & Secondary == other.Secondary) + && ReferenceEquals(source, other.source); + + int IComparable.CompareTo(Shard other) => CompareTo(in other); + + public int CompareTo(in Shard other) + { + int delta = From - other.From; + if (delta == 0) + { + delta = To - other.To; + if (delta == 0) + { + delta = (int)Flags - (int)other.Flags; + if (delta == 0) + { + delta = string.CompareOrdinal(Primary, other.Primary); + } + } + } + + return delta; + } + + public RespConnection? GetConnection() + { + if (source is not null) + { + // in this *very specific* case: watch out for null by-refs; we don't + // do this exhaustively! + ref readonly RespContext ctx = ref source.Context; + if (!Unsafe.IsNullRef(ref Unsafe.AsRef(in ctx))) return ctx.Connection; + } + + return null; + } +} diff --git a/src/RESPite/Connections/Internal/StreamConnection.cs b/src/RESPite/Connections/Internal/StreamConnection.cs new file mode 100644 index 000000000..2f6cf8e96 --- /dev/null +++ b/src/RESPite/Connections/Internal/StreamConnection.cs @@ -0,0 +1,829 @@ +// #define PARSE_DETAIL // additional trace info in CommitAndParseFrames + +#if DEBUG +#define PARSE_DETAIL // always enable this in debug builds +#endif + +using System.Buffers; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using RESPite.Internal; +using RESPite.Messages; + +namespace RESPite.Connections.Internal; + +internal sealed class StreamConnection : RespConnection +{ + private bool _isDoomed; + private RespScanState _readScanState; + private CycleBuffer _readBuffer, _writeBuffer; + + internal override int OutstandingOperations => _outstanding.Count; + internal override bool IsHealthy => !_isDoomed; + + public Task Reader { get; private set; } = Task.CompletedTask; + + private readonly Stream tail; + private ConcurrentQueue _outstanding = new(); + + public StreamConnection(in RespContext context, RespConfiguration configuration, Stream tail, bool asyncRead = true) + : base(context, configuration) + { + if (!(tail.CanRead && tail.CanWrite)) Throw(); + this.tail = tail; + var memoryPool = Configuration.GetService>(); + _readBuffer = CycleBuffer.Create(memoryPool); + _writeBuffer = CycleBuffer.Create(memoryPool); + if (asyncRead) + { + Reader = Task.Run(ReadAllAsync); + } + else + { + new Thread(ReadAll).Start(); + } + + static void Throw() => throw new ArgumentException("Stream must be readable and writable", nameof(tail)); + } + + public StreamConnection(RespConfiguration configuration, Stream tail, bool asyncRead = true) + : this(RespContext.Null, configuration, tail, asyncRead) + { + } + + public RespMode Mode { get; set; } = RespMode.Resp2; + + public enum RespMode + { + Resp2, + Resp2PubSub, + Resp3, + } + + private static byte[]? SharedNoLease; + + private bool CommitAndParseFrames(int bytesRead) + { + if (bytesRead <= 0) + { + return false; + } + + // let's bypass a bunch of ldarg0 by hoisting the field-refs (this is **NOT** a struct copy; emphasis "ref") + ref RespScanState state = ref _readScanState; + ref CycleBuffer readBuffer = ref _readBuffer; + +#if PARSE_DETAIL + string src = $"parse {bytesRead}"; + try +#endif + { + Debug.Assert(readBuffer.GetCommittedLength() >= 0, "multi-segment running-indices are corrupt"); +#if PARSE_DETAIL + src += $" ({readBuffer.GetCommittedLength()}+{bytesRead}-{state.TotalBytes})"; +#endif + Debug.Assert( + bytesRead <= readBuffer.UncommittedAvailable, + $"Insufficient bytes in {nameof(CommitAndParseFrames)}; got {bytesRead}, Available={readBuffer.UncommittedAvailable}"); + readBuffer.Commit(bytesRead); +#if PARSE_DETAIL + src += $",total {readBuffer.GetCommittedLength()}"; +#endif + var scanner = RespFrameScanner.Default; + + OperationStatus status = OperationStatus.NeedMoreData; + if (readBuffer.TryGetCommitted(out var fullSpan)) + { + int fullyConsumed = 0; + var toParse = fullSpan.Slice((int)state.TotalBytes); // skip what we've already parsed + + Debug.Assert(!toParse.IsEmpty); + while (true) + { +#if PARSE_DETAIL + src += $",span {toParse.Length}"; +#endif + int totalBytesBefore = (int)state.TotalBytes; + if (toParse.Length < RespScanState.MinBytes + || (status = scanner.TryRead(ref state, toParse)) != OperationStatus.Done) + { + break; + } + + Debug.Assert( + state is + { + IsComplete: true, TotalBytes: >= RespScanState.MinBytes, Prefix: not RespPrefix.None + }, + "Invalid RESP read state"); + + // extract the frame + var bytes = (int)state.TotalBytes; +#if PARSE_DETAIL + src += $",frame {bytes}"; +#endif + // send the frame somewhere (note this is the *full* frame, not just the bit we just parsed) + OnResponseFrame(state.Prefix, fullSpan.Slice(fullyConsumed, bytes), ref SharedNoLease); + + // update our buffers to the unread potions and reset for a new RESP frame + fullyConsumed += bytes; + toParse = toParse.Slice(bytes - totalBytesBefore); // move past the extra bytes we just read + state = default; + status = OperationStatus.NeedMoreData; + } + + readBuffer.DiscardCommitted(fullyConsumed); + } + else // the same thing again, but this time with multi-segment sequence + { + var fullSequence = readBuffer.GetAllCommitted(); + Debug.Assert( + fullSequence is { IsEmpty: false, IsSingleSegment: false }, + "non-trivial sequence expected"); + + long fullyConsumed = 0; + var toParse = fullSequence.Slice((int)state.TotalBytes); // skip what we've already parsed + while (true) + { +#if PARSE_DETAIL + src += $",ros {toParse.Length}"; +#endif + int totalBytesBefore = (int)state.TotalBytes; + if (toParse.Length < RespScanState.MinBytes + || (status = scanner.TryRead(ref state, toParse)) != OperationStatus.Done) + { + break; + } + + Debug.Assert( + state is + { + IsComplete: true, TotalBytes: >= RespScanState.MinBytes, Prefix: not RespPrefix.None + }, + "Invalid RESP read state"); + + // extract the frame + var bytes = (int)state.TotalBytes; +#if PARSE_DETAIL + src += $",frame {bytes}"; +#endif + // send the frame somewhere (note this is the *full* frame, not just the bit we just parsed) + OnResponseFrame(state.Prefix, fullSequence.Slice(fullyConsumed, bytes)); + + // update our buffers to the unread potions and reset for a new RESP frame + fullyConsumed += bytes; + toParse = toParse.Slice(bytes - totalBytesBefore); // move past the extra bytes we just read + state = default; + status = OperationStatus.NeedMoreData; + } + + readBuffer.DiscardCommitted(fullyConsumed); + } + + if (status != OperationStatus.NeedMoreData) + { + ThrowStatus(status); + + static void ThrowStatus(OperationStatus status) => + throw new InvalidOperationException($"Unexpected operation status: {status}"); + } + + return true; + } +#if PARSE_DETAIL + catch (Exception ex) + { + Debug.WriteLine($"{nameof(CommitAndParseFrames)}: {ex.Message}"); + Debug.WriteLine(src); + ActivationHelper.DebugBreak(); + throw new InvalidOperationException($"{src} lead to {ex.Message}", ex); + } +#endif + } + + private async Task ReadAllAsync() + { + try + { + int read; + do + { + var buffer = _readBuffer.GetUncommittedMemory(); + var pending = tail.ReadAsync(buffer, CancellationToken.None); +#if DEBUG + bool inline = pending.IsCompleted; +#endif + read = await pending.ConfigureAwait(false); +#if DEBUG + DebugCounters.OnAsyncRead(read, inline); +#endif + } + // another formatter glitch + while (CommitAndParseFrames(read)); + + Volatile.Write(ref _readStatus, ReaderCompleted); + _readBuffer.Release(); // clean exit, we can recycle + } + catch (Exception ex) + { + OnReadException(ex); + throw; + } + finally + { + OnReadAllFinally(); + } + } + + private void ReadAll() + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + Reader = tcs.Task; + try + { + int read; + do + { +#if NETCOREAPP || NETSTANDARD2_1_OR_GREATER + var buffer = _readBuffer.GetUncommittedSpan(); + read = tail.Read(buffer); +#else + var buffer = _readBuffer.GetUncommittedMemory(); + read = tail.Read(buffer); +#endif + DebugCounters.OnRead(read); + } + // another formatter glitch + while (CommitAndParseFrames(read)); + + Volatile.Write(ref _readStatus, ReaderCompleted); + _readBuffer.Release(); // clean exit, we can recycle + tcs.TrySetResult(null); + } + catch (Exception ex) + { + tcs.TrySetException(ex); + OnReadException(ex); + } + finally + { + OnReadAllFinally(); + } + } + + internal override void ThrowIfUnhealthy() + { + if (_fault is { } fault) Throw(fault); + base.ThrowIfUnhealthy(); + + static void Throw(Exception fault) => throw new InvalidOperationException("Connection is unhealthy", fault); + } + + private void OnReadException(Exception ex, [CallerMemberName] string operation = "") + { + _fault ??= ex; + Volatile.Write(ref _readStatus, ReaderFailed); + Debug.WriteLine($"Reader failed: {ex.Message}"); + ActivationHelper.DebugBreak(); + while (_outstanding.TryDequeue(out var pending)) + { + pending.Message.TrySetException(pending.Token, ex); + } + + OnConnectionError(ConnectionError, ex, operation); + } + + private void OnReadAllFinally() + { + Doom(); + _readBuffer.Release(); + + // abandon anything in the queue + while (_outstanding.TryDequeue(out var pending)) + { + pending.Message.TrySetCanceled(pending.Token, CancellationToken.None); + } + } + + private static readonly ulong + ArrayPong_LC_Bulk = RespConstants.UnsafeCpuUInt64("*2\r\n$4\r\npong\r\n$"u8), + ArrayPong_UC_Bulk = RespConstants.UnsafeCpuUInt64("*2\r\n$4\r\nPONG\r\n$"u8), + ArrayPong_LC_Simple = RespConstants.UnsafeCpuUInt64("*2\r\n+pong\r\n$"u8), + ArrayPong_UC_Simple = RespConstants.UnsafeCpuUInt64("*2\r\n+PONG\r\n$"u8); + + private static readonly uint + pong = RespConstants.UnsafeCpuUInt32("pong"u8), + PONG = RespConstants.UnsafeCpuUInt32("PONG"u8); + + private void OnOutOfBand(ReadOnlySpan payload, ref byte[]? lease) + { + throw new NotImplementedException(nameof(OnOutOfBand)); + } + + private void OnResponseFrame(RespPrefix prefix, ReadOnlySequence payload) + { + if (payload.IsSingleSegment) + { +#if NETCOREAPP || NETSTANDARD2_1_OR_GREATER + OnResponseFrame(prefix, payload.FirstSpan, ref SharedNoLease); +#else + OnResponseFrame(prefix, payload.First.Span, ref SharedNoLease); +#endif + } + else + { + var len = checked((int)payload.Length); + byte[]? oversized = ArrayPool.Shared.Rent(len); + payload.CopyTo(oversized); + OnResponseFrame(prefix, new(oversized, 0, len), ref oversized); + + // the lease could have been claimed by the activation code (to prevent another memcpy); otherwise, free + if (oversized is not null) + { + ArrayPool.Shared.Return(oversized); + } + } + } + + [Conditional("DEBUG")] + private static void DebugValidateSingleFrame(ReadOnlySpan payload) + { + var reader = new RespReader(payload); + reader.MoveNext(); + reader.SkipChildren(); + + if (reader.TryMoveNext()) + { + throw new InvalidOperationException($"Unexpected trailing {reader.Prefix}"); + } + + if (reader.ProtocolBytesRemaining != 0) + { + var copy = reader; // leave reader alone for inspection + var prefix = copy.TryMoveNext() ? copy.Prefix : RespPrefix.None; + throw new InvalidOperationException( + $"Unexpected additional {reader.ProtocolBytesRemaining} bytes remaining, {prefix}"); + } + } + + [Conditional("DEBUG")] + private static void DebugValidateFrameCount(in ReadOnlySequence payload, int count) + { + var reader = new RespReader(payload); + while (count-- > 0) + { + reader.MoveNext(); + reader.SkipChildren(); + } + + if (reader.TryMoveNext()) + { + throw new InvalidOperationException($"Unexpected trailing {reader.Prefix}"); + } + + if (reader.ProtocolBytesRemaining != 0) + { + var copy = reader; // leave reader alone for inspection + var prefix = copy.TryMoveNext() ? copy.Prefix : RespPrefix.None; + throw new InvalidOperationException( + $"Unexpected additional {reader.ProtocolBytesRemaining} bytes remaining, {prefix}"); + } + } + + private void OnResponseFrame(RespPrefix prefix, ReadOnlySpan payload, ref byte[]? lease) + { + DebugValidateSingleFrame(payload); + if (prefix == RespPrefix.Push || + (prefix == RespPrefix.Array && Mode is RespMode.Resp2PubSub && !IsArrayPong(payload))) + { + // out-of-band; pub/sub etc + OnOutOfBand(payload, ref lease); + return; + } + + // request/response; match to inbound + if (_outstanding.TryDequeue(out var pending)) + { + ActivationHelper.ProcessResponse(pending, payload, ref lease); + } + else + { + Debug.Fail("Unexpected response without pending message!"); + } + + static bool IsArrayPong(ReadOnlySpan payload) + { + if (payload.Length >= sizeof(ulong)) + { + var raw = RespConstants.UnsafeCpuUInt64(payload); + if (raw == ArrayPong_LC_Bulk + || raw == ArrayPong_UC_Bulk + || raw == ArrayPong_LC_Simple + || raw == ArrayPong_UC_Simple) + { + var reader = new RespReader(payload); + return reader.TryMoveNext() // have root + && reader.Prefix == RespPrefix.Array // root is array + && reader.TryMoveNext() // have first child + && (reader.IsInlneCpuUInt32(pong) || reader.IsInlneCpuUInt32(PONG)); // pong + } + } + + return false; + } + } + + private int _writeStatus, _readStatus; + private const int WriterAvailable = 0, WriterTaken = 1, WriterDoomed = 2; + private const int ReaderActive = 0, ReaderFailed = 1, ReaderCompleted = 2; + + private void TakeWriter() + { + var status = Interlocked.CompareExchange(ref _writeStatus, WriterTaken, WriterAvailable); + if (status != WriterAvailable) ThrowWriterNotAvailable(); + Debug.Assert(Volatile.Read(ref _writeStatus) == WriterTaken, "writer should be taken"); + } + + private void ThrowWriterNotAvailable() + { + var fault = Volatile.Read(ref _fault); + var status = Volatile.Read(ref _writeStatus); + var msg = status switch + { + WriterTaken => "A write operation is already in progress; concurrent writes are not supported.", + WriterDoomed when fault is not null => "This connection is terminated; no further writes are possible: " + + fault.Message, + WriterDoomed => "This connection is terminated; no further writes are possible.", + _ => $"Unexpected writer status: {status}", + }; + throw fault is null ? new InvalidOperationException(msg) : new InvalidOperationException(msg, fault); + } + + private Exception? _fault; + + private void ReleaseWriter(int status = WriterAvailable) + { + if (status == WriterAvailable && _isDoomed) + { + status = WriterDoomed; + } + + Interlocked.CompareExchange(ref _writeStatus, status, WriterTaken); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void OnRequestUnavailable(in RespOperation message) + { + if (!message.IsCompleted) + { + // make sure they know something is wrong + message.TrySetException(new InvalidOperationException("Request is not available")); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void EnqueueMultiMessage(in RespOperation operation, ReadOnlySpan operations) + { + // This typically *does not* include the batch message itself. + DebugCounters.OnMultiMessageWrite(operations.Length); + foreach (var message in operations) + { + _outstanding.Enqueue(message); + } + // The root message typically gets completed here - on the receiving side, all + // we see is N unrelated inbound messages; the batch terminates at write. + if (!operation.TrySetResultAfterUnloadingSubMessages()) + { + _outstanding.Enqueue(operation); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void Enqueue(in RespOperation operation) + { + if (operation.TryGetSubMessages(out var operations)) + { + // rare path - multi-message batch + EnqueueMultiMessage(in operation, operations); + } + else + { + _outstanding.Enqueue(operation); + } + } + + public override void Write(in RespOperation message) + { + bool releaseRequest = message.Message.TryReserveRequest(message.Token, out var bytes); + if (!releaseRequest) + { + OnRequestUnavailable(message); + return; + } + + DebugValidateFrameCount(bytes, message.MessageCount); + TakeWriter(); + try + { + Enqueue(in message); + releaseRequest = false; // once we write, only release on success + if (bytes.IsSingleSegment) + { +#if NETCOREAPP || NETSTANDARD2_1_OR_GREATER + tail.Write(bytes.FirstSpan); +#else + tail.Write(bytes.First); +#endif + DebugCounters.OnSyncWrite(bytes.First.Length); + } + else + { + WriteMultiSegment(tail, in bytes); + } + + ReleaseWriter(); + message.Message.ReleaseRequest(); + } + catch (Exception ex) + { + Debug.WriteLine($"Writer failed: {ex.Message}"); + ActivationHelper.DebugBreak(); + ReleaseWriter(WriterDoomed); + if (releaseRequest) message.Message.ReleaseRequest(); + OnConnectionError(ConnectionError, ex); + throw; + } + } + + private static void WriteMultiSegment(Stream tail, in ReadOnlySequence payload) + { + foreach (var segment in payload) + { +#if NETCOREAPP || NETSTANDARD2_1_OR_GREATER + tail.Write(segment.Span); +#else + tail.Write(segment); +#endif + DebugCounters.OnSyncWrite(segment.Length); + } + } + + private static async ValueTask WriteMultiSegmentAsync(Stream tail, ReadOnlySequence payload) + { + foreach (var segment in payload) + { + var pending = tail.WriteAsync(segment, CancellationToken.None); + DebugCounters.OnAsyncWrite(segment.Length, pending.IsCompleted); + await pending.ConfigureAwait(false); + } + } + + internal override void Write(ReadOnlySpan messages) + { + switch (messages.Length) + { + case 0: + return; + case 1: + Write(messages[0]); + return; + } + + TakeWriter(); + RespMessageBase? toRelease = null; + try + { + foreach (var message in messages) + { + if (message.Message.TryReserveRequest(message.Token, out var bytes)) + { + toRelease = message.Message; + } + else + { + OnRequestUnavailable(message); + continue; + } + + DebugValidateFrameCount(bytes, message.MessageCount); + Enqueue(in message); + toRelease = null; // once we write, only release on success + if (bytes.IsSingleSegment) + { +#if NETCOREAPP || NETSTANDARD2_1_OR_GREATER + tail.Write(bytes.FirstSpan); +#else + tail.Write(bytes.First); +#endif + DebugCounters.OnSyncWrite(bytes.First.Length); + } + else + { + WriteMultiSegment(tail, in bytes); + } + + ReleaseWriter(); + message.Message.ReleaseRequest(); + } + } + catch (Exception ex) + { + Debug.WriteLine($"Writer failed: {ex.Message}"); + ActivationHelper.DebugBreak(); + ReleaseWriter(WriterDoomed); + toRelease?.ReleaseRequest(); + foreach (var message in messages) + { + // assume all bad + message.Message.TrySetException(message.Token, ex); + } + + OnConnectionError(ConnectionError, ex); + throw; + } + } + + public override Task WriteAsync(in RespOperation message) + { + bool releaseRequest = message.Message.TryReserveRequest(message.Token, out var bytes); + if (!releaseRequest) + { + OnRequestUnavailable(message); + return Task.CompletedTask; + } + + DebugValidateFrameCount(bytes, message.MessageCount); + try + { + Enqueue(in message); + releaseRequest = false; // once we write, only release on success + ValueTask pendingWrite; + if (bytes.IsSingleSegment) + { + pendingWrite = tail.WriteAsync(bytes.First, CancellationToken.None); + DebugCounters.OnAsyncWrite(bytes.First.Length, pendingWrite.IsCompleted); + } + else + { + pendingWrite = WriteMultiSegmentAsync(tail, bytes); + } + + if (!pendingWrite.IsCompleted) + { + return AwaitedSingleWithToken(this, pendingWrite, message.Message); + } + pendingWrite.GetAwaiter().GetResult(); + ReleaseWriter(); + message.Message.ReleaseRequest(); + return Task.CompletedTask; + } + catch (Exception ex) + { + Debug.WriteLine($"Writer failed: {ex.Message}"); + ActivationHelper.DebugBreak(); + ReleaseWriter(WriterDoomed); + if (releaseRequest) message.Message.ReleaseRequest(); + OnConnectionError(ConnectionError, ex); + throw; + } + + static async Task AwaitedSingleWithToken( + StreamConnection @this, + ValueTask pendingWrite, + RespMessageBase message) + { + try + { + await pendingWrite.ConfigureAwait(false); + @this.ReleaseWriter(); + message.ReleaseRequest(); + } + catch (Exception ex) + { + @this.ReleaseWriter(WriterDoomed); + OnConnectionError(@this.ConnectionError, ex, $"{nameof(WriteAsync)}:{nameof(AwaitedSingleWithToken)}"); + throw; + } + } + } + + internal override Task WriteAsync(ReadOnlyMemory messages) + { + switch (messages.Length) + { + case 0: + return Task.CompletedTask; + case 1: + return WriteAsync(messages.Span[0]); + default: + return CombineAndSendMultipleAsync(this, messages); + } + } + + public override event EventHandler? ConnectionError; // use simple handler + + private async Task CombineAndSendMultipleAsync(StreamConnection @this, ReadOnlyMemory messages) + { + TakeWriter(); + RespMessageBase? toRelease = null; + int definitelySent = 0; + try + { + int length = messages.Length; + for (int i = 0; i < length; i++) + { + var message = messages.Span[i]; + if (!message.Message.TryReserveRequest(message.Token, out var bytes)) + { + OnRequestUnavailable(message); + continue; // skip this message + } + + DebugValidateFrameCount(bytes, message.MessageCount); + toRelease = message.Message; + // append to the scratch and consider written (even though we haven't actually) + _writeBuffer.Write(bytes); + toRelease = null; + message.Message.ReleaseRequest(); + @this.Enqueue(in message); + + // do we have any full segments? if so, write them and narrow "messages" + if (_writeBuffer.TryGetFirstCommittedMemory(CycleBuffer.GetFullPagesOnly, out var memory)) + { + do + { + var pending = tail.WriteAsync(memory, CancellationToken.None); + DebugCounters.OnAsyncWrite(memory.Length, inline: pending.IsCompleted); + await pending.ConfigureAwait(false); + DebugCounters.OnBatchWriteFullPage(); + + _writeBuffer.DiscardCommitted(memory.Length); // mark the data as no longer needed + } + // and if one buffer was full, we might have multiple (think: "large BLOB outbound") + while (_writeBuffer.TryGetFirstCommittedMemory(CycleBuffer.GetFullPagesOnly, out memory)); + + definitelySent = i + 1; // for exception handling: no need to doom these if later fails + } + } + + // and send any remaining data + while (_writeBuffer.TryGetFirstCommittedMemory(CycleBuffer.GetAnything, out var memory)) + { + var pending = tail.WriteAsync(memory, CancellationToken.None); + DebugCounters.OnAsyncWrite(memory.Length, inline: pending.IsCompleted); + await pending.ConfigureAwait(false); + DebugCounters.OnBatchWritePartialPage(); + + _writeBuffer.DiscardCommitted(memory.Length); // mark the data as no longer needed + } + + Debug.Assert(_writeBuffer.CommittedIsEmpty, "should have written everything"); + + ReleaseWriter(); + DebugCounters.OnBatchWrite(messages.Length); + } + catch (Exception ex) + { + Debug.WriteLine($"Writer failed: {ex.Message}"); + ActivationHelper.DebugBreak(); + ReleaseWriter(WriterDoomed); + toRelease?.ReleaseRequest(); + foreach (var message in messages.Span.Slice(start: definitelySent)) + { + message.Message.TrySetException(message.Token, ex); + } + + OnConnectionError(ConnectionError, ex); + throw; + } + } + + private void Doom() + { + _isDoomed = true; // without a reader, there's no point writing + Interlocked.CompareExchange(ref _writeStatus, WriterDoomed, WriterAvailable); + } + + protected override void OnDispose(bool disposing) + { + if (disposing) + { + _fault ??= new ObjectDisposedException(ToString()); + Doom(); + tail.Dispose(); + } + } + + protected override ValueTask OnDisposeAsync() + { + _fault ??= new ObjectDisposedException(ToString()); + Doom(); +#if COREAPP3_0_OR_GREATER || NETSTANDARD2_1_OR_GREATER + return tail.DisposeAsync().AsTask(); +#else + tail.Dispose(); + return default; +#endif + } +} diff --git a/src/RESPite/Connections/Internal/SynchronizedConnection.cs b/src/RESPite/Connections/Internal/SynchronizedConnection.cs new file mode 100644 index 000000000..b7e7cf53b --- /dev/null +++ b/src/RESPite/Connections/Internal/SynchronizedConnection.cs @@ -0,0 +1,201 @@ +using RESPite.Internal; + +namespace RESPite.Connections.Internal; + +internal sealed class SynchronizedConnection(in RespContext tail) : DecoratorConnection(tail) +{ + private readonly SemaphoreSlim _semaphore = new(1); + + protected override void OnDispose(bool disposing) + { + if (disposing) + { + _semaphore.Dispose(); + } + base.OnDispose(disposing); + } + + protected override ValueTask OnDisposeAsync() + { + _semaphore.Dispose(); + return base.OnDisposeAsync(); + } + + internal override bool IsHealthy => _semaphore.CurrentCount > 0 & base.IsHealthy; + public override void Write(in RespOperation message) + { + try + { + _semaphore.Wait(message.CancellationToken); + Tail.Write(message); + } + catch (Exception ex) + { + message.TrySetException(ex); + throw; + } + finally + { + _semaphore.Release(); + } + } + + internal override void Write(ReadOnlySpan messages) + { + switch (messages.Length) + { + case 0: return; + case 1: + Write(messages[0]); + return; + } + + try + { + _semaphore.Wait(messages[0].CancellationToken); + Tail.Write(messages); + } + catch (Exception ex) + { + MarkFaulted(messages, ex); + throw; + } + finally + { + _semaphore.Release(); + } + } + + public override Task WriteAsync(in RespOperation message) + { + bool haveLock = false; + try + { + haveLock = _semaphore.Wait(0); + if (!haveLock) + { + DebugCounters.OnPipelineFullAsync(); + return FullAsync(this, message); + } + + var pending = Tail.WriteAsync(message); + if (!pending.IsCompleted) + { + DebugCounters.OnPipelineSendAsync(); + haveLock = false; // transferring + return AwaitAndReleaseLock(pending); + } + + DebugCounters.OnPipelineFullSync(); + pending.GetAwaiter().GetResult(); + return Task.CompletedTask; + } + catch (Exception ex) + { + message.Message.TrySetException(message.Token, ex); + throw; + } + finally + { + if (haveLock) _semaphore.Release(); + } + + static async Task FullAsync(SynchronizedConnection @this, RespOperation message) + { + try + { + await @this._semaphore.WaitAsync(message.CancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + message.Message.TrySetException(message.Token, ex); + throw; + } + + try + { + await @this.Tail.WriteAsync(message).ConfigureAwait(false); + } + finally + { + @this._semaphore.Release(); + } + } + } + + private async Task AwaitAndReleaseLock(Task pending) + { + try + { + await pending.ConfigureAwait(false); + } + finally + { + _semaphore.Release(); + } + } + + internal override Task WriteAsync(ReadOnlyMemory messages) + { + switch (messages.Length) + { + case 0: return Task.CompletedTask; + case 1: return WriteAsync(messages.Span[0]); + } + + bool haveLock = false; + try + { + haveLock = _semaphore.Wait(0); + if (!haveLock) + { + DebugCounters.OnPipelineFullAsync(); + return FullAsync(this, messages); + } + + var pending = Tail.WriteAsync(messages); + if (!pending.IsCompleted) + { + DebugCounters.OnPipelineSendAsync(); + haveLock = false; // transferring + return AwaitAndReleaseLock(pending); + } + + DebugCounters.OnPipelineFullSync(); + pending.GetAwaiter().GetResult(); + return Task.CompletedTask; + } + catch (Exception ex) + { + MarkFaulted(messages.Span, ex); + throw; + } + finally + { + if (haveLock) _semaphore.Release(); + } + + static async Task FullAsync(SynchronizedConnection @this, ReadOnlyMemory messages) + { + bool haveLock = false; // we don't have the lock initially + try + { + await @this._semaphore.WaitAsync(messages.Span[0].CancellationToken).ConfigureAwait(false); + haveLock = true; + await @this.Tail.WriteAsync(messages).ConfigureAwait(false); + } + catch (Exception ex) + { + MarkFaulted(messages.Span, ex); + throw; + } + finally + { + if (haveLock) + { + @this._semaphore.Release(); + } + } + } + } +} diff --git a/src/RESPite/Connections/RespConnectionExtensions.cs b/src/RESPite/Connections/RespConnectionExtensions.cs new file mode 100644 index 000000000..be9791255 --- /dev/null +++ b/src/RESPite/Connections/RespConnectionExtensions.cs @@ -0,0 +1,17 @@ +using RESPite.Connections.Internal; + +namespace RESPite.Connections; + +public static class RespConnectionExtensions +{ + /// + /// Enforces stricter ordering guarantees, so that unawaited async operations cannot cause overlapping writes. + /// + public static RespConnection Synchronized(this RespConnection connection) + => connection is SynchronizedConnection ? connection : new SynchronizedConnection(in connection.Context); + + public static RespConnection WithConfiguration(this RespConnection connection, RespConfiguration configuration) + => ReferenceEquals(configuration, connection.Configuration) + ? connection + : new ConfiguredConnection(in connection.Context, configuration); +} diff --git a/src/RESPite/Connections/RespConnectionFactory.cs b/src/RESPite/Connections/RespConnectionFactory.cs new file mode 100644 index 000000000..9077f6d86 --- /dev/null +++ b/src/RESPite/Connections/RespConnectionFactory.cs @@ -0,0 +1,135 @@ +using System.Globalization; +using System.Net; +using System.Net.Sockets; + +namespace RESPite.Connections; + +/// +/// Controls connection to endpoints. By default, this is TCP streams. +/// +// ReSharper disable once ClassWithVirtualMembersNeverInherited.Global +public class RespConnectionFactory +{ + private static RespConnectionFactory? _default, _defaultTls; + public static RespConnectionFactory Default => _default ??= new(); + public static RespConnectionFactory DefaultTls => _defaultTls ??= new(true); + protected RespConnectionFactory(bool tls = false) => _tls = tls; + private readonly bool _tls; + + public virtual string DefaultHost => "127.0.0.1"; + public virtual int DefaultPort => _tls ? 6380 : 6379; + + /// + /// Connect to the designated endpoint and return an open for the duplex + /// connection. + /// + /// The location to connect to; how this is interpreted is implementation-specific, + /// but will commonly be an IP address or DNS hostname. + /// The port to connect to, if appropriate. + /// The configuration for the connection. + /// Cancellation for the operation. + /// An open for the duplex connection. + public virtual async ValueTask ConnectAsync( + string endpoint, + int port, + RespConfiguration? configuration = null, + CancellationToken cancellationToken = default) + { + var ep = GetEndPoint(endpoint, port); + Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + socket.NoDelay = true; +#if NET6_0_OR_GREATER + await socket.ConnectAsync(ep, cancellationToken).ConfigureAwait(false); +#else + // hack together cancellation via dispose + using (cancellationToken.Register( + static state => ((Socket)state).Dispose(), socket)) + { + try + { + await socket.ConnectAsync(ep).ConfigureAwait(false); + } + catch (ObjectDisposedException) when (cancellationToken.IsCancellationRequested) + { + throw new OperationCanceledException(cancellationToken); + } + catch (SocketException) when (cancellationToken.IsCancellationRequested) + { + throw new OperationCanceledException(cancellationToken); + } + } +#endif + var stream = new NetworkStream(socket); + var authed = await AuthenticateAsync(stream, cancellationToken).ConfigureAwait(false); + return RespConnection.Create(authed, configuration); + } + + protected virtual ValueTask AuthenticateAsync(Stream stream, CancellationToken cancellationToken) + { + if (_tls) throw new NotImplementedException("TLS"); + return new(stream); + } + + protected internal virtual EndPoint GetEndPoint(string endpoint, int port) + { + if (port == 0) port = DefaultPort; + if (string.IsNullOrWhiteSpace(endpoint)) + { + endpoint = DefaultHost; + } + + return endpoint switch + { + "127.0.0.1" => new IPEndPoint(IPAddress.Loopback, port), + "::1" or "0:0:0:0:0:0:0:1" => new IPEndPoint(IPAddress.IPv6Loopback, port), + _ when IPAddress.TryParse(endpoint, out var address) => new IPEndPoint(address, port), + _ => new DnsEndPoint(endpoint, port), + }; + } + + public virtual bool TryParse(EndPoint endpoint, out string host, out int port) + { + if (endpoint is DnsEndPoint dns) + { + host = dns.Host switch + { + "localhost" or "." => "127.0.0.1", + _ => dns.Host, + }; + port = dns.Port; + return true; + } + + if (endpoint is IPEndPoint ip) + { + host = ip.Address.ToString(); + port = ip.Port; + return true; + } + + host = ""; + port = 0; + return false; + } + + public virtual bool TryParse(string hostAndPort, out string host, out int port) + { + int i = hostAndPort.LastIndexOf(':'); + if (i < 0) + { + host = hostAndPort; + port = 0; + return true; + } + + host = hostAndPort.Substring(0, i); + if (int.TryParse(hostAndPort.Substring(i + 1), NumberStyles.Integer, CultureInfo.InvariantCulture, out port)) + { + return true; + } + + host = hostAndPort; + port = 0; + return false; + } +} diff --git a/src/RESPite/Connections/RespConnectionManager.cs b/src/RESPite/Connections/RespConnectionManager.cs new file mode 100644 index 000000000..c8672f04a --- /dev/null +++ b/src/RESPite/Connections/RespConnectionManager.cs @@ -0,0 +1,253 @@ +using System.Buffers; +using System.Diagnostics.CodeAnalysis; +using System.Net; +using RESPite.Connections.Internal; +using RESPite.Internal; + +namespace RESPite.Connections; + +public sealed class RespConnectionManager : IRespContextSource +{ + /// + public override string ToString() => GetType().Name; + + // the routed connection performs message-inspection based routing; on a single node + // instance that isn't necessary, so the default-connection abstracts over that: + // in a single-node instance, the default-connection will be the single interactive connection + // otherwise, the default-connection will be the routed connection + private RoutedConnection? _routedConnection; + private RespContext _defaultContext = RespContext.Null; + internal ref readonly RespContext Context => ref _defaultContext; + ref readonly RespContext IRespContextSource.Context => ref _defaultContext; + + private readonly CancellationTokenSource _lifetime = new(); + + private RespConnectionFactory? _factory; + + public RespConnectionFactory ConnectionFactory + { + get => _factory ??= RespConnectionFactory.Default; + set + { + // ReSharper disable once JoinNullCheckWithUsage + if (value is null) throw new ArgumentNullException(nameof(ConnectionFactory)); + _factory = value; + } + } + + private Node[] _nodes = []; + internal CancellationToken Lifetime => _lifetime.Token; + private RespConfiguration? _options; + internal RespConfiguration Options => _options ?? ThrowNotConnected(); + + [DoesNotReturn] + private RespConfiguration ThrowNotConnected() + => throw new InvalidOperationException($"The {GetType().Name} has not been connected."); + + internal readonly struct EndpointPair(string endpoint, int port) + { + public override string ToString() => $"{Endpoint}:{Port}"; + + public readonly string Endpoint = endpoint; + public readonly int Port = port; + public override int GetHashCode() => (Endpoint?.GetHashCode() ?? 0) ^ Port; + + public override bool Equals(object? obj) => obj is EndpointPair other && + (Endpoint == other.Endpoint & Port == other.Port); + } + + private void OnConnect(RespConfiguration options, ReadOnlySpan endpoints) + { + if (options is null) throw new ArgumentNullException(nameof(options)); + if (Interlocked.CompareExchange(ref _options, options, null) is not null) + { + throw new InvalidOperationException($"A {GetType().Name} can only be connected once."); + } + + var nodes = new Node[Math.Max(endpoints.Length, 1)]; + var factory = ConnectionFactory; + if (endpoints.IsEmpty) + { + nodes[0] = new Node(this, factory.DefaultHost, factory.DefaultPort); + } + else + { + for (int i = 0; i < endpoints.Length; i++) + { + var host = endpoints[i].Endpoint; + if (string.IsNullOrWhiteSpace(host) || host is "." or "localhost") + host = "127.0.0.1"; + var port = endpoints[i].Port; + if (port == 0) port = factory.DefaultPort; + nodes[i] = new Node(this, host, port); + } + } + + _nodes = nodes; + } + + internal void Connect(RespConfiguration options, ReadOnlySpan endpoints, TextWriter? log = null) + // use sync over async; reduce code-duplication, and sync wouldn't add anything + => ConnectAsync(options, endpoints, log).Wait(Lifetime); + + internal Task ConnectAsync(RespConfiguration options, ReadOnlySpan endpoints, TextWriter? log = null) + { + OnConnect(options, endpoints); + var snapshot = _nodes; + log.LogLocked($"Connecting to {snapshot.Length} nodes..."); + Task[] pending = new Task[snapshot.Length]; + for (int i = 0; i < snapshot.Length; i++) + { + pending[i] = snapshot[i].ConnectAsync(log); + } + + return ConnectAsyncAwaited(pending, log, snapshot.Length); + } + + private async Task ConnectAsyncAwaited(Task[] pending, TextWriter? log, int nodeCount) + { + await Task.WhenAll(pending).ConfigureAwait(false); + int success = 0; + foreach (var task in pending) + { + // note WhenAll ensures all connected + if (task.Result) success++; + } + + // configure our primary connection + OnNodesChanged(); + + log.LogLocked($"Connected to {success} of {nodeCount} nodes."); + } + + public void Dispose() + { + var routed = _routedConnection; + _routedConnection = null; + _defaultContext = NullConnection.Disposed.Context; + _lifetime.Cancel(); + routed?.Dispose(); + foreach (var node in _nodes) + { + node.Dispose(); + } + } + + public async ValueTask DisposeAsync() + { + var routed = _routedConnection; + _routedConnection = null; + _defaultContext = NullConnection.Disposed.Context; +#if NET8_0_OR_GREATER + await _lifetime.CancelAsync().ConfigureAwait(false); +#else + _lifetime.Cancel(); +#endif + if (routed is not null) + { + await routed.DisposeAsync().ConfigureAwait(false); + } + + foreach (var node in _nodes) + { + await node.DisposeAsync().ConfigureAwait(false); + } + } + + public string ClientName { get; private set; } = ""; + public int TimeoutMilliseconds => (int)Options.SyncTimeout.TotalMilliseconds; + public long OperationCount => 0; + + public bool PreserveAsyncOrder + { + get => false; + [Obsolete("This feature is no longer supported", false)] + set { } + } + + public bool IsConnected + { + get + { + foreach (var node in _nodes) + { + if (node.IsConnected) return true; + } + + return false; + } + } + + public bool IsConnecting + { + get + { + foreach (var node in _nodes) + { + if (node.IsConnecting) return true; + } + + return false; + } + } + + private void OnNodesChanged() + { + var nodes = _nodes; + _defaultContext = nodes.Length switch + { + 0 => NullConnection.NonRoutable.Context, // nowhere to go + 1 => nodes[0] is { IsConnected: true } conn + ? conn.Context + : NullConnection.NonRoutable.Context, // nowhere to go + _ => BuildRouted(nodes), + }; + } + + private ref readonly RespContext BuildRouted(Node[] nodes) + { + Shard[] oversized = ArrayPool.Shared.Rent(nodes.Length); + for (int i = 0; i < nodes.Length; i++) + { + oversized[i] = nodes[i].AsShard(); + } + + Array.Sort(oversized, 0, nodes.Length); + var conn = _routedConnection ??= new(); + conn.SetRoutingTable(new ReadOnlySpan(oversized, 0, nodes.Length)); + ArrayPool.Shared.Return(oversized); + return ref conn.Context; + } + + internal Node GetNode(string host, int port) + { + foreach (var node in _nodes) + { + if (node.EndPoint == host && node.Port == port) return node; + } + + throw new KeyNotFoundException($"No node found for {host}:{port}"); + } + + internal Node GetNode(string hostAndPort) => ConnectionFactory.TryParse(hostAndPort, out var host, out var port) + ? GetNode(host, port) + : throw new ArgumentException($"Could not parse host and port from '{hostAndPort}'", nameof(hostAndPort)); + + internal Node? GetRandomNode() + { + var nodes = _nodes; + if (nodes is { Length: > 0 }) + { + var index = SharedRandom.Next(nodes.Length); + return nodes[index]; + } + + return null; + } + +#if NET5_0_OR_GREATER + private static Random SharedRandom => Random.Shared; +#else + private static Random SharedRandom { get; } = new(); +#endif +} diff --git a/src/RESPite/Connections/RespConnectionPool.cs b/src/RESPite/Connections/RespConnectionPool.cs new file mode 100644 index 000000000..a1e4bd191 --- /dev/null +++ b/src/RESPite/Connections/RespConnectionPool.cs @@ -0,0 +1,214 @@ +using System.Collections.Concurrent; +using System.ComponentModel; +using System.Net; +using System.Net.Sockets; +using RESPite.Connections.Internal; +using RESPite.Internal; + +namespace RESPite.Connections; + +public sealed class RespConnectionPool : IDisposable +{ + private const int DefaultCount = 10; + private bool _isDisposed; + + [Obsolete("This is for testing only")] + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public bool UseCustomNetworkStream { get; set; } + + private readonly ConcurrentQueue _pool = []; + private readonly Func> _createConnection; + private readonly int _count; + private readonly RespContext _defaultTemplate; + + public ref readonly RespContext Template => ref _defaultTemplate; + + public event EventHandler? ConnectionError; + + private void OnConnectionError(object? sender, RespConnection.RespConnectionErrorEventArgs e) + => ConnectionError?.Invoke(this, e); // mask sender + + private readonly EventHandler _onConnectionError; + + public RespConnectionPool(int count = DefaultCount) : this(RespContext.Null, "127.0.0.1", 6379, count) + { + } + + public RespConnectionPool( + in RespContext template, + Func> createConnection, + int count = DefaultCount) + { + _createConnection = createConnection; + _count = count; + template.CancellationToken.ThrowIfCancellationRequested(); + // swap out the connection for a dummy (retaining the configuration) + var configuredConnection = NullConnection.WithConfiguration(template.Connection.Configuration); + _defaultTemplate = template.WithConnection(configuredConnection); + _onConnectionError = OnConnectionError; + } + + public RespConnectionPool( + in RespContext template, + string endpoint, + int port, + int count = DefaultCount, + RespConnectionFactory? connectionFactory = null) + : this(template, MakeCreateConnection(endpoint, port, connectionFactory), count) + { + } + + private static Func> MakeCreateConnection( + string endpoint, + int port, + RespConnectionFactory? connectionFactory) + { + connectionFactory ??= RespConnectionFactory.Default; + return (config, cancellationToken) + => connectionFactory.ConnectAsync(endpoint, port, config, cancellationToken); + } + + /// + /// Borrow a connection from the pool, using the default template. + /// + public RespConnection GetConnection(CancellationToken cancellationToken = default) + { + if (cancellationToken.CanBeCanceled) + { + var context = _defaultTemplate.WithCancellationToken(cancellationToken); + return GetConnection(in context); + } + else + { + return GetConnection(in _defaultTemplate); + } + } + + public RespConnection GetConnection(in RespContext template) // sync over async + { + var pending = GetConnectionAsync(in template); + if (!pending.IsCompleted) return pending.AsTask().GetAwaiter().GetResult(); + return pending.GetAwaiter().GetResult(); + } + + /// + /// Borrow a connection from the pool, using the default template. + /// + public ValueTask GetConnectionAsync(CancellationToken cancellationToken = default) + { + if (cancellationToken.CanBeCanceled) + { + var context = _defaultTemplate.WithCancellationToken(cancellationToken); + return GetConnectionAsync(in context); + } + else + { + return GetConnectionAsync(in _defaultTemplate); + } + } + + /// + /// Borrow a connection from the pool. + /// + /// The template context to use for the leased connection; everything except the connection + /// will be inherited by the new context. + public ValueTask GetConnectionAsync(in RespContext template) + { + ThrowIfDisposed(); + template.CancellationToken.ThrowIfCancellationRequested(); + + if (_pool.TryDequeue(out var connection)) return new(connection); + + var pending = _createConnection(template.Connection.Configuration, template.CancellationToken); + if (!pending.IsCompleted) return Awaited(template, pending); + + connection = pending.GetAwaiter().GetResult(); + connection.ConnectionError += _onConnectionError; + connection = new PoolWrapper(this, template.WithConnection(connection)); + return new(connection); + } + + private async ValueTask Awaited(RespContext template, ValueTask pending) + { + var connection = await pending.ConfigureAwait(false); + connection.ConnectionError += _onConnectionError; + return new PoolWrapper(this, template.WithConnection(connection)); + } + + private void ThrowIfDisposed() + { + if (_isDisposed) Throw(); + static void Throw() => throw new ObjectDisposedException(nameof(RespConnectionPool)); + } + + public void Dispose() + { + _isDisposed = true; + while (_pool.TryDequeue(out var connection)) + { + connection.Dispose(); + } + } + + private void Return(RespConnection tail) + { + if (_isDisposed || !tail.IsHealthy || _pool.Count >= _count) + { + tail.Dispose(); + } + else + { + _pool.Enqueue(tail); + } + } + + private sealed class PoolWrapper( + RespConnectionPool pool, + in RespContext tail) : DecoratorConnection(tail) + { + protected override bool OwnsConnection => false; + + private const string ConnectionErrorNotSupportedMessage = + $"{nameof(ConnectionError)} events are not supported on pooled connections; use {nameof(RespConnectionPool)}.{nameof(RespConnectionPool.ConnectionError)} instead"; + + public override event EventHandler? ConnectionError + { + add => throw new NotSupportedException(ConnectionErrorNotSupportedMessage); + remove => throw new NotSupportedException(ConnectionErrorNotSupportedMessage); + } + + protected override void OnDispose(bool disposing) + { + if (disposing) + { + pool.Return(Tail); + } + + base.OnDispose(disposing); + } + + public override void Write(in RespOperation message) + { + ThrowIfDisposed(); + Tail.Write(message); + } + + internal override void Write(ReadOnlySpan messages) + { + ThrowIfDisposed(); + Tail.Write(messages); + } + + public override Task WriteAsync(in RespOperation message) + { + ThrowIfDisposed(); + return Tail.WriteAsync(message); + } + + internal override Task WriteAsync(ReadOnlyMemory messages) + { + ThrowIfDisposed(); + return Tail.WriteAsync(messages); + } + } +} diff --git a/src/RESPite/Internal/ActivationHelper.cs b/src/RESPite/Internal/ActivationHelper.cs new file mode 100644 index 000000000..8abd9bac4 --- /dev/null +++ b/src/RESPite/Internal/ActivationHelper.cs @@ -0,0 +1,116 @@ +using System.Buffers; +using System.Diagnostics; + +namespace RESPite.Internal; + +internal static class ActivationHelper +{ + private sealed class WorkItem +#if NETCOREAPP3_0_OR_GREATER + : IThreadPoolWorkItem +#endif + { + private WorkItem() + { +#if NET5_0_OR_GREATER + System.Runtime.CompilerServices.Unsafe.SkipInit(out _payload); +#else + _payload = []; +#endif + } + + private void Init(byte[] payload, int length, in RespOperation message) + { + _payload = payload; + _length = length; + _message = message; + } + + private byte[] _payload; + private int _length; + private RespOperation _message; + + private static WorkItem? _spare; // do NOT use ThreadStatic - different producer/consumer, no overlap + + public static void UnsafeQueueUserWorkItem( + in RespOperation message, + ReadOnlySpan payload, + ref byte[]? lease) + { + if (lease is null) + { + // we need to create our own copy of the data + lease = ArrayPool.Shared.Rent(payload.Length); + payload.CopyTo(lease); + } + + var obj = Interlocked.Exchange(ref _spare, null) ?? new(); + obj.Init(lease, payload.Length, message); + lease = null; // count as claimed + + DebugCounters.OnCopyOut(payload.Length); +#if NETCOREAPP3_0_OR_GREATER + ThreadPool.UnsafeQueueUserWorkItem(obj, false); +#else + ThreadPool.UnsafeQueueUserWorkItem(WaitCallback, obj); +#endif + } +#if !NETCOREAPP3_0_OR_GREATER + private static readonly WaitCallback WaitCallback = state => ((WorkItem)state!).Execute(); +#endif + + public void Execute() + { + var message = _message; + var payload = _payload; + var length = _length; + _message = default; + _payload = []; + _length = 0; + Interlocked.Exchange(ref _spare, this); + var msg = message; + msg.Message.TrySetResult(msg.Token, new ReadOnlySpan(payload, 0, length)); + ArrayPool.Shared.Return(payload); + } + } + + public static void ProcessResponse(in RespOperation pending, ReadOnlySpan payload, ref byte[]? lease) + { + var msg = pending.Message; + if (msg.AllowInlineParsing) + { + msg.TrySetResult(pending.Token, payload); + } + else + { + WorkItem.UnsafeQueueUserWorkItem(pending, payload, ref lease); + } + } + + private static readonly Action CancellationCallback = static state + => ((RespMessageBase)state!).TrySetCanceledTrustToken(); + + public static CancellationTokenRegistration RegisterForCancellation( + RespMessageBase message, + CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + return cancellationToken.Register(CancellationCallback, message); + } + + [Conditional("DEBUG")] + public static void DebugBreak() + { +#if DEBUG + if (Debugger.IsAttached) Debugger.Break(); +#endif + } + + [Conditional("DEBUG")] + public static void DebugBreakIf(bool condition) + { +#if DEBUG + if (condition && Debugger.IsAttached) Debugger.Break(); +#endif + } +} diff --git a/src/RESPite/Internal/BlockBuffer.cs b/src/RESPite/Internal/BlockBuffer.cs new file mode 100644 index 000000000..752d74c8d --- /dev/null +++ b/src/RESPite/Internal/BlockBuffer.cs @@ -0,0 +1,341 @@ +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace RESPite.Internal; + +internal abstract partial class BlockBufferSerializer +{ + internal sealed class BlockBuffer : MemoryManager + { + private BlockBuffer(BlockBufferSerializer parent, int minCapacity) + { + _arrayPool = parent._arrayPool; + _array = _arrayPool.Rent(minCapacity); + DebugCounters.OnBufferCapacity(_array.Length); +#if DEBUG + _parent = parent; + parent.DebugBufferCreated(); +#endif + } + + private int _refCount = 1; + private int _finalizedOffset, _writeOffset; + private readonly ArrayPool _arrayPool; + private byte[] _array; +#if DEBUG + private int _finalizedCount; + private BlockBufferSerializer _parent; +#endif + + public override string ToString() => +#if DEBUG + $"{_finalizedCount} messages; " + +#endif + $"{_finalizedOffset} finalized bytes; writing: {NonFinalizedData.Length} bytes, {Available} available; observers: {_refCount}"; + + // only used when filling; _buffer should be non-null + private int Available => _array.Length - _writeOffset; + public Memory UncommittedMemory => _array.AsMemory(_writeOffset); + public Span UncommittedSpan => _array.AsSpan(_writeOffset); + + // decrease ref-count; dispose if necessary + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Release() + { + if (Interlocked.Decrement(ref _refCount) <= 0) Recycle(); + } + + public void AddRef() + { + if (!TryAddRef()) Throw(); + static void Throw() => throw new ObjectDisposedException(nameof(BlockBuffer)); + } + + public bool TryAddRef() + { + int count; + do + { + count = Volatile.Read(ref _refCount); + if (count <= 0) return false; + } + // repeat until we can successfully swap/incr + while (Interlocked.CompareExchange(ref _refCount, count + 1, count) != count); + + return true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] // called rarely vs Dispose + private void Recycle() + { + var count = Volatile.Read(ref _refCount); + if (count == 0) + { + _array.DebugScramble(); +#if DEBUG + GC.SuppressFinalize(this); // only have a finalizer in debug + _parent.DebugBufferRecycled(_array.Length); +#endif + _arrayPool.Return(_array); + _array = []; + } + + Debug.Assert(count == 0, $"over-disposal? count={count}"); + } + +#if DEBUG +#pragma warning disable CA2015 // Adding a finalizer to a type derived from MemoryManager may permit memory to be freed while it is still in use by a Span + // (the above is fine because we don't actually release anything - just a counter) + ~BlockBuffer() + { + _parent.DebugBufferLeaked(); + DebugCounters.OnBufferLeaked(); + } +#pragma warning restore CA2015 +#endif + + public static BlockBuffer GetBuffer(BlockBufferSerializer parent, int sizeHint) + { + // note this isn't an actual "max", just a max of what we guarantee; we give the caller + // whatever is left in the buffer; the clamped hint just decides whether we need a *new* buffer + const int MinSize = 16, MaxSize = 128; + sizeHint = Math.Min(Math.Max(sizeHint, MinSize), MaxSize); + + var buffer = parent.Buffer; // most common path is "exists, with enough data" + return buffer is not null && buffer.AvailableWithResetIfUseful() >= sizeHint + ? buffer + : GetBufferSlow(parent, sizeHint); + } + + // would it be useful and possible to reset? i.e. if all finalized chunks have been returned, + private int AvailableWithResetIfUseful() + { + if (_finalizedOffset != 0 // at least some chunks have been finalized + && Volatile.Read(ref _refCount) == 1 // all finalized chunks returned + & _writeOffset == _finalizedOffset) // we're not in the middle of serializing something new + { + _writeOffset = _finalizedOffset = 0; // swipe left + } + + return _array.Length - _writeOffset; + } + + private static BlockBuffer GetBufferSlow(BlockBufferSerializer parent, int minBytes) + { + // note clamp on size hint has already been applied + const int DefaultBufferSize = 2048; + var buffer = parent.Buffer; + if (buffer is null) + { + // first buffer + return parent.Buffer = new BlockBuffer(parent, DefaultBufferSize); + } + + Debug.Assert(minBytes > buffer.Available, "existing buffer has capacity - why are we here?"); + + if (buffer.TryResizeFor(minBytes)) + { + Debug.Assert(buffer.Available >= minBytes); + return buffer; + } + + // We've tried reset and resize - no more tricks; we need to move to a new buffer, starting with a + // capacity for any existing data in this message, plus the new chunk we're adding. + var nonFinalizedBytes = buffer.NonFinalizedData; + var newBuffer = new BlockBuffer(parent, Math.Max(nonFinalizedBytes.Length + minBytes, DefaultBufferSize)); + + // copy the existing message data, if any (the previous message might have finished near the + // boundary, in which case we might not have written anything yet) + newBuffer.CopyFrom(nonFinalizedBytes); + Debug.Assert(newBuffer.Available >= minBytes, "should have requested extra capacity"); + + // the ~emperor~ buffer is dead; long live the ~emperor~ buffer + parent.Buffer = newBuffer; + buffer.MarkComplete(parent); + return newBuffer; + } + + // used for elective reset (rather than "because we ran out of space") + public static void Clear(BlockBufferSerializer parent) + { + if (parent.Buffer is { } buffer) + { + parent.Buffer = null; + buffer.MarkComplete(parent); + } + } + + public static ReadOnlyMemory RetainCurrent(BlockBufferSerializer parent) + { + if (parent.Buffer is { } buffer && buffer._finalizedOffset != 0) + { + parent.Buffer = null; + buffer.AddRef(); + return buffer.CreateMemory(0, buffer._finalizedOffset); + } + // nothing useful to detach! + return default; + } + + private void MarkComplete(BlockBufferSerializer parent) + { + // record that the old buffer no longer logically has any non-committed bytes (mostly just for ToString()) + _writeOffset = _finalizedOffset; + Debug.Assert(IsNonCommittedEmpty); + + // see if the caller wants to take ownership of the segment + if (_finalizedOffset != 0 && !parent.ClaimSegment(CreateMemory(0, _finalizedOffset))) + { + Release(); // decrement the observer + } +#if DEBUG + DebugCounters.OnBufferCompleted(_finalizedCount, _finalizedOffset); +#endif + } + + private void CopyFrom(Span source) + { + source.CopyTo(UncommittedSpan); + _writeOffset += source.Length; + } + + private Span NonFinalizedData => _array.AsSpan( + _finalizedOffset, _writeOffset - _finalizedOffset); + + private bool TryResizeFor(int extraBytes) + { + if (_finalizedOffset == 0 & // we can only do this if there are no other messages in the buffer + Volatile.Read(ref _refCount) == 1) // and no-one else is looking (we already tried reset) + { + // we're already on the boundary - don't scrimp; just do the math from the end of the buffer + byte[] newArray = _arrayPool.Rent(_array.Length + extraBytes); + DebugCounters.OnBufferCapacity(newArray.Length - _array.Length); // account for extra only + + // copy the existing data (we always expect some, since we've clamped extraBytes to be + // much smaller than the default buffer size) + NonFinalizedData.CopyTo(newArray); + _array.DebugScramble(); + _arrayPool.Return(_array); + _array = newArray; + return true; + } + + return false; + } + + public static void Advance(BlockBufferSerializer parent, int count) + { + if (count == 0) return; + if (count < 0) ThrowOutOfRange(); + var buffer = parent.Buffer; + if (buffer is null || buffer.Available < count) ThrowOutOfRange(); + buffer._writeOffset += count; + + [DoesNotReturn] + static void ThrowOutOfRange() => throw new ArgumentOutOfRangeException(nameof(count)); + } + + public void RevertUnfinalized(BlockBufferSerializer parent) + { + // undo any writes (something went wrong during serialize) + _finalizedOffset = _writeOffset; + } + + private ReadOnlyMemory FinalizeBlock() + { + var length = _writeOffset - _finalizedOffset; + Debug.Assert(length > 0, "already checked this in FinalizeMessage!"); + var chunk = CreateMemory(_finalizedOffset, length); + _finalizedOffset = _writeOffset; // move the write head +#if DEBUG + _finalizedCount++; + _parent.DebugMessageFinalized(length); +#endif + Interlocked.Increment(ref _refCount); // add an observer + return chunk; + } + + private bool IsNonCommittedEmpty => _finalizedOffset == _writeOffset; + + public static ReadOnlyMemory FinalizeMessage(BlockBufferSerializer parent) + { + var buffer = parent.Buffer; + if (buffer is null || buffer.IsNonCommittedEmpty) + { +#if DEBUG // still count it for logging purposes + if (buffer is not null) buffer._finalizedCount++; + parent.DebugMessageFinalized(0); +#endif + return default; + } + + return buffer.FinalizeBlock(); + } + + // MemoryManager pieces + protected override void Dispose(bool disposing) + { + if (disposing) Release(); + } + + public override Span GetSpan() => _array; + public int Length => _array.Length; + + // base version is CreateMemory(GetSpan().Length); avoid that GetSpan() + public override Memory Memory => CreateMemory(_array.Length); + + public override unsafe MemoryHandle Pin(int elementIndex = 0) + { + // We *could* be cute and use a shared pin - but that's a *lot* + // of work (synchronization), requires extra storage, and for an + // API that is very unlikely; hence: we'll use per-call GC pins. + GCHandle handle = GCHandle.Alloc(_array, GCHandleType.Pinned); + DebugCounters.OnBufferPinned(); // prove how unlikely this is + byte* ptr = (byte*)handle.AddrOfPinnedObject(); + // note no IPinnable in the MemoryHandle; + return new MemoryHandle(ptr + elementIndex, handle); + } + + // This would only be called if we passed out a MemoryHandle with ourselves + // as IPinnable (in Pin), which: we don't. + public override void Unpin() => throw new NotSupportedException(); + + protected override bool TryGetArray(out ArraySegment segment) + { + segment = new ArraySegment(_array); + return true; + } + + internal static void Release(in ReadOnlySequence request) + { + if (request.IsSingleSegment) + { + if (MemoryMarshal.TryGetMemoryManager( + request.First, out var block)) + { + block.Release(); + } + } + else + { + ReleaseMultiBlock(in request); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void ReleaseMultiBlock(in ReadOnlySequence request) + { + foreach (var segment in request) + { + if (MemoryMarshal.TryGetMemoryManager( + segment, out var block)) + { + block.Release(); + } + } + } + } + } +} diff --git a/src/RESPite/Internal/BlockBufferSerializer.cs b/src/RESPite/Internal/BlockBufferSerializer.cs new file mode 100644 index 000000000..22ff5361f --- /dev/null +++ b/src/RESPite/Internal/BlockBufferSerializer.cs @@ -0,0 +1,93 @@ +using System.Buffers; +using System.Diagnostics; +using RESPite.Messages; + +namespace RESPite.Internal; + +/// +/// Provides abstracted access to a buffer-writing API. Conveniently, we only give the caller +/// RespWriter - which they cannot export (ref-type), thus we never actually give the +/// public caller our IBufferWriter{byte}. Likewise, note that serialization is synchronous, +/// i.e. never switches thread during an operation. This gives us quite a bit of flexibility. +/// There are two main uses of BlockBufferSerializer: +/// 1. thread-local: ambient, used for random messages so that each thread is quietly packing +/// a thread-specific buffer; zero concurrency because of [ThreadStatic] hackery. +/// 2. batching: RespBatch hosts a serializer that reflects the batch we're building; successive +/// commands in the same batch are written adjacently in a shared buffer - we explicitly +/// detect and reject concurrency attempts in a batch (which is fair: a batch has order). +/// +internal abstract partial class BlockBufferSerializer(ArrayPool? arrayPool = null) : IBufferWriter +{ + private readonly ArrayPool _arrayPool = arrayPool ?? ArrayPool.Shared; + private protected abstract BlockBuffer? Buffer { get; set; } + + Memory IBufferWriter.GetMemory(int sizeHint) => BlockBuffer.GetBuffer(this, sizeHint).UncommittedMemory; + + Span IBufferWriter.GetSpan(int sizeHint) => BlockBuffer.GetBuffer(this, sizeHint).UncommittedSpan; + + void IBufferWriter.Advance(int count) => BlockBuffer.Advance(this, count); + + public virtual void Clear() => BlockBuffer.Clear(this); + + internal virtual ReadOnlySequence Flush() => throw new NotSupportedException(); + + public virtual ReadOnlyMemory Serialize( + RespCommandMap? commandMap, + ReadOnlySpan command, + in TRequest request, + IRespFormatter formatter) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + try + { + var writer = new RespWriter(this); + writer.CommandMap = commandMap; + formatter.Format(command, ref writer, request); + writer.Flush(); + return BlockBuffer.FinalizeMessage(this); + } + catch + { + Buffer?.RevertUnfinalized(this); + throw; + } + } + + protected virtual bool ClaimSegment(ReadOnlyMemory segment) => false; + +#if DEBUG + private int _countAdded, _countRecycled, _countLeaked, _countMessages; + private long _countMessageBytes; + public int CountLeaked => Volatile.Read(ref _countLeaked); + public int CountRecycled => Volatile.Read(ref _countRecycled); + public int CountAdded => Volatile.Read(ref _countAdded); + public int CountMessages => Volatile.Read(ref _countMessages); + public long CountMessageBytes => Volatile.Read(ref _countMessageBytes); + + [Conditional("DEBUG")] + private void DebugBufferLeaked() => Interlocked.Increment(ref _countLeaked); + + [Conditional("DEBUG")] + private void DebugBufferRecycled(int length) + { + Interlocked.Increment(ref _countRecycled); + DebugCounters.OnBufferRecycled(length); + } + + [Conditional("DEBUG")] + private void DebugBufferCreated() + { + Interlocked.Increment(ref _countAdded); + DebugCounters.OnBufferCreated(); + } + + [Conditional("DEBUG")] + private void DebugMessageFinalized(int bytes) + { + Interlocked.Increment(ref _countMessages); + Interlocked.Add(ref _countMessageBytes, bytes); + } +#endif +} diff --git a/src/RESPite/Internal/CycleBuffer.cs b/src/RESPite/Internal/CycleBuffer.cs new file mode 100644 index 000000000..a0d827679 --- /dev/null +++ b/src/RESPite/Internal/CycleBuffer.cs @@ -0,0 +1,706 @@ +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +#pragma warning disable SA1205 // accessibility on partial - for debugging/test practicality + +namespace RESPite.Internal; + +/// +/// Manages the state for a based IO buffer. Unlike Pipe, +/// it is not intended for a separate producer-consumer - there is no thread-safety, and no +/// activation; it just handles the buffers. It is intended to be used as a mutable (non-readonly) +/// field in a type that performs IO; the internal state mutates - it should not be passed around. +/// +/// Notionally, there is an uncommitted area (write) and a committed area (read). Process: +/// - producer loop (*note no concurrency**) +/// - call to get a new scratch +/// - (write to that span) +/// - call to mark complete portions +/// - consumer loop (*note no concurrency**) +/// - call to see if there is a single-span chunk; otherwise +/// - call to get the multi-span chunk +/// - (process none, some, or all of that data) +/// - call to indicate how much data is no longer needed +/// Emphasis: no concurrency! This is intended for a single worker acting as both producer and consumer. +/// +/// There is a *lot* of validation in debug mode; we want to be super sure that we don't corrupt buffer state. +/// +internal partial struct CycleBuffer +{ + // note: if someone uses an uninitialized CycleBuffer (via default): that's a skills issue; git gud + public static CycleBuffer Create(MemoryPool? pool = null, int pageSize = DefaultPageSize) + { + pool ??= MemoryPool.Shared; + if (pageSize <= 0) pageSize = DefaultPageSize; + if (pageSize > pool.MaxBufferSize) pageSize = pool.MaxBufferSize; + + return new CycleBuffer(pool, pageSize); + } + + private CycleBuffer(MemoryPool pool, int pageSize) + { + Pool = pool; + PageSize = pageSize; + } + + private const int DefaultPageSize = 8 * 1024; + + public int PageSize { get; } + public MemoryPool Pool { get; } + + private Segment? startSegment, endSegment; + + private int endSegmentCommitted, endSegmentLength; + + public bool TryGetCommitted(out ReadOnlySpan span) + { + DebugAssertValid(); + if (!ReferenceEquals(startSegment, endSegment)) + { + span = default; + return false; + } + + span = startSegment is null ? default : startSegment.Memory.Span.Slice(start: 0, length: endSegmentCommitted); + return true; + } + + /// + /// Commits data written to buffers from , making it available for consumption + /// via . This compares to . + /// + public void Commit(int count) + { + DebugAssertValid(); + if (count <= 0) + { + if (count < 0) Throw(); + return; + } + + var available = endSegmentLength - endSegmentCommitted; + if (count > available) Throw(); + endSegmentCommitted += count; + DebugAssertValid(); + + static void Throw() => throw new ArgumentOutOfRangeException(nameof(count)); + } + + public bool CommittedIsEmpty => ReferenceEquals(startSegment, endSegment) & endSegmentCommitted == 0; + + /// + /// Marks committed data as fully consumed; it will no longer appear in later calls to . + /// + public void DiscardCommitted(int count) + { + DebugAssertValid(); + // optimize for most common case, where we consume everything + if (ReferenceEquals(startSegment, endSegment) + & count == endSegmentCommitted + & count > 0) + { + /* + we are consuming all the data in the single segment; we can + just reset that segment back to full size and re-use as-is; + note that we also know that there must *be* a segment + for the count check to pass + */ + endSegmentCommitted = 0; + endSegmentLength = endSegment!.Untrim(expandBackwards: true); + DebugAssertValid(0); + DebugCounters.OnDiscardFull(count); + } + else if (count == 0) + { + // nothing to do + } + else + { + DiscardCommittedSlow(count); + } + } + + public void DiscardCommitted(long count) + { + DebugAssertValid(); + // optimize for most common case, where we consume everything + if (ReferenceEquals(startSegment, endSegment) + & count == endSegmentCommitted + & count > 0) // checks sign *and* non-trimmed + { + // see for logic + endSegmentCommitted = 0; + endSegmentLength = endSegment!.Untrim(expandBackwards: true); + DebugAssertValid(0); + DebugCounters.OnDiscardFull(count); + } + else if (count == 0) + { + // nothing to do + } + else + { + DiscardCommittedSlow(count); + } + } + + private void DiscardCommittedSlow(long count) + { + DebugCounters.OnDiscardPartial(count); +#if DEBUG + var originalLength = GetCommittedLength(); + var originalCount = count; + var expectedLength = originalLength - originalCount; + string blame = nameof(DiscardCommittedSlow); +#endif + while (count > 0) + { + DebugAssertValid(); + var segment = startSegment; + if (segment is null) break; + if (ReferenceEquals(segment, endSegment)) + { + // first==final==only segment + if (count == endSegmentCommitted) + { + endSegmentLength = startSegment!.Untrim(); + endSegmentCommitted = 0; // = untrimmed and unused +#if DEBUG + blame += ",full-final (t)"; +#endif + } + else + { + // discard from the start + int count32 = checked((int)count); + segment.TrimStart(count32); + endSegmentLength -= count32; + endSegmentCommitted -= count32; +#if DEBUG + blame += ",partial-final"; +#endif + } + + count = 0; + break; + } + else if (count < segment.Length) + { + // multiple, but can take some (not all) of the first buffer +#if DEBUG + var len = segment.Length; +#endif + segment.TrimStart((int)count); + Debug.Assert(segment.Length > 0, "parial trim should have left non-empty segment"); +#if DEBUG + Debug.Assert(segment.Length == len - count, "trim failure"); + blame += ",partial-first"; +#endif + count = 0; + break; + } + else + { + // multiple; discard the entire first segment + count -= segment.Length; + startSegment = + segment.ResetAndGetNext(); // we already did a ref-check, so we know this isn't going past endSegment + endSegment!.AppendOrRecycle(segment, maxDepth: 2); + DebugAssertValid(); +#if DEBUG + blame += ",full-first"; +#endif + } + } + + if (count != 0) ThrowCount(); +#if DEBUG + DebugAssertValid(expectedLength, blame); + _ = originalLength; + _ = originalCount; +#endif + + [DoesNotReturn] + static void ThrowCount() => throw new ArgumentOutOfRangeException(nameof(count)); + } + + [Conditional("DEBUG")] + private void DebugAssertValid(long expectedCommittedLength, [CallerMemberName] string caller = "") + { + DebugAssertValid(); + var actual = GetCommittedLength(); + Debug.Assert( + expectedCommittedLength >= 0, + $"Expected committed length is just... wrong: {expectedCommittedLength} (from {caller})"); + Debug.Assert( + expectedCommittedLength == actual, + $"Committed length mismatch: expected {expectedCommittedLength}, got {actual} (from {caller})"); + } + + [Conditional("DEBUG")] + private void DebugAssertValid() + { + if (startSegment is null) + { + Debug.Assert( + endSegmentLength == 0 & endSegmentCommitted == 0, + "un-init state should be zero"); + return; + } + + Debug.Assert(endSegment is not null, "end segment must not be null if start segment exists"); + Debug.Assert( + endSegmentLength == endSegment!.Length, + $"end segment length is incorrect - expected {endSegmentLength}, got {endSegment.Length}"); + Debug.Assert(endSegmentCommitted <= endSegmentLength, $"end segment is over-committed - {endSegmentCommitted} of {endSegmentLength}"); + + // check running indices + startSegment?.DebugAssertValidChain(); + } + + public long GetCommittedLength() + { + DebugAssertValid(); + if (ReferenceEquals(startSegment, endSegment)) + { + return endSegmentCommitted; + } + + // note that the start-segment is pre-trimmed; we don't need to account for an offset on the left + return (endSegment!.RunningIndex + endSegmentCommitted) - startSegment!.RunningIndex; + } + + /// + /// When used with , this means "any non-empty buffer". + /// + public const int GetAnything = 0; + + /// + /// When used with , this means "any full buffer". + /// + public const int GetFullPagesOnly = -1; + + public bool TryGetFirstCommittedSpan(int minBytes, out ReadOnlySpan span) + { + DebugAssertValid(); + if (TryGetFirstCommittedMemory(minBytes, out var memory)) + { + span = memory.Span; + return true; + } + + span = default; + return false; + } + + /// + /// The minLength arg: -ve means "full segments only" (useful when buffering outbound network data to avoid + /// packet fragmentation); otherwise, it is the minimum length we want. + /// + public bool TryGetFirstCommittedMemory(int minBytes, out ReadOnlyMemory memory) + { + if (minBytes == 0) minBytes = 1; // success always means "at least something" + DebugAssertValid(); + if (ReferenceEquals(startSegment, endSegment)) + { + // single page + var available = endSegmentCommitted; + if (available == 0) + { + // empty (includes uninitialized) + memory = default; + return false; + } + + memory = startSegment!.Memory; + var memLength = memory.Length; + if (available == memLength) + { + // full segment; is it enough to make the caller happy? + return available >= minBytes; + } + + // partial segment (and we know it isn't empty) + memory = memory.Slice(start: 0, length: available); + return available >= minBytes & minBytes > 0; // last check here applies the -ve logic + } + + // multi-page; hand out the first page (which is, by definition: full) + memory = startSegment!.Memory; + return memory.Length >= minBytes; + } + + /// + /// Note that this chain is invalidated by any other operations; no concurrency. + /// + public ReadOnlySequence GetAllCommitted() + { + if (ReferenceEquals(startSegment, endSegment)) + { + // single segment, fine + return startSegment is null + ? default + : new ReadOnlySequence(startSegment.Memory.Slice(start: 0, length: endSegmentCommitted)); + } + +#if PARSE_DETAIL + long length = GetCommittedLength(); +#endif + ReadOnlySequence ros = new(startSegment!, 0, endSegment!, endSegmentCommitted); +#if PARSE_DETAIL + Debug.Assert(ros.Length == length, $"length mismatch: calculated {length}, actual {ros.Length}"); +#endif + return ros; + } + + private Segment GetNextSegment() + { + DebugAssertValid(); + if (endSegment is not null) + { + endSegment.TrimEnd(endSegmentCommitted); + Debug.Assert(endSegment.Length == endSegmentCommitted, "trim failure"); + endSegmentLength = endSegmentCommitted; + DebugAssertValid(); + + var spare = endSegment.Next; + if (spare is not null) + { + // we already have a dangling segment; just update state + endSegment.DebugAssertValidChain(); + endSegment = spare; + endSegmentCommitted = 0; + endSegmentLength = spare.Length; + DebugAssertValid(); + return spare; + } + } + + Segment newSegment = Segment.Create(Pool.Rent(PageSize)); + if (endSegment is null) + { + // tabula rasa + endSegmentLength = newSegment.Length; + endSegment = startSegment = newSegment; + DebugAssertValid(); + return newSegment; + } + + endSegment.Append(newSegment); + endSegmentCommitted = 0; + endSegmentLength = newSegment.Length; + endSegment = newSegment; + DebugAssertValid(); + return newSegment; + } + + /// + /// Gets a scratch area for new data; this compares to . + /// + public Span GetUncommittedSpan(int hint = 0) + => GetUncommittedMemory(hint).Span; + + /// + /// Gets a scratch area for new data; this compares to . + /// + public Memory GetUncommittedMemory(int hint = 0) + { + DebugAssertValid(); + var segment = endSegment; + if (segment is not null) + { + var memory = segment.Memory; + if (endSegmentCommitted != 0) memory = memory.Slice(start: endSegmentCommitted); + if (hint <= 0) // allow anything non-empty + { + if (!memory.IsEmpty) return MemoryMarshal.AsMemory(memory); + } + else if (memory.Length >= Math.Min(hint, PageSize >> 2)) // respect the hint up to 1/4 of the page size + { + return MemoryMarshal.AsMemory(memory); + } + } + + // new segment, will always be entire + return MemoryMarshal.AsMemory(GetNextSegment().Memory); + } + + public int UncommittedAvailable + { + get + { + DebugAssertValid(); + return endSegmentLength - endSegmentCommitted; + } + } + + private sealed class Segment : ReadOnlySequenceSegment + { + private Segment() { } + private IMemoryOwner _lease = NullLease.Instance; + private static Segment? _spare; + private Flags _flags; + + [Flags] + private enum Flags + { + None = 0, + StartTrim = 1 << 0, + EndTrim = 1 << 2, + } + + public static Segment Create(IMemoryOwner lease) + { + Debug.Assert(lease is not null, "null lease"); + var memory = lease!.Memory; + if (memory.IsEmpty) ThrowEmpty(); + + var obj = Interlocked.Exchange(ref _spare, null) ?? new(); + return obj.Init(lease, memory); + static void ThrowEmpty() => throw new InvalidOperationException("leased segment is empty"); + } + + private Segment Init(IMemoryOwner lease, Memory memory) + { + _lease = lease; + Memory = memory; + return this; + } + + public int Length => Memory.Length; + + public void Append(Segment next) + { + Debug.Assert(Next is null, "current segment already has a next"); + Debug.Assert(next.Next is null && next.RunningIndex == 0, "inbound next segment is already in a chain"); + next.RunningIndex = RunningIndex + Length; + Next = next; + DebugAssertValidChain(); + } + + private void ApplyChainDelta(int delta) + { + if (delta != 0) + { + var node = Next; + while (node is not null) + { + node.RunningIndex += delta; + node = node.Next; + } + } + } + + public void TrimEnd(int newLength) + { + var delta = Length - newLength; + if (delta != 0) + { + // buffer wasn't fully used; trim + _flags |= Flags.EndTrim; + Memory = Memory.Slice(0, newLength); + ApplyChainDelta(-delta); + DebugAssertValidChain(); + } + } + + public void TrimStart(int remove) + { + if (remove != 0) + { + _flags |= Flags.StartTrim; + Memory = Memory.Slice(start: remove); + RunningIndex += remove; // so that ROS length keeps working; note we *don't* need to adjust the chain + DebugAssertValidChain(); + } + } + + public new Segment? Next + { + get => (Segment?)base.Next; + private set => base.Next = value; + } + + public Segment? ResetAndGetNext() + { + var next = Next; + Next = null; + RunningIndex = 0; + _flags = Flags.None; + Memory = _lease.Memory; // reset, in case we trimmed it + DebugAssertValidChain(); + return next; + } + + public void Recycle() + { + var lease = _lease; + _lease = NullLease.Instance; + lease.Dispose(); + Next = null; + Memory = default; + RunningIndex = 0; + _flags = Flags.None; + Interlocked.Exchange(ref _spare, this); + DebugAssertValidChain(); + } + + private sealed class NullLease : IMemoryOwner + { + private NullLease() { } + public static readonly NullLease Instance = new NullLease(); + public void Dispose() { } + + public Memory Memory => default; + } + + /// + /// Undo any trimming, returning the new full capacity. + /// + public int Untrim(bool expandBackwards = false) + { + var fullMemory = _lease.Memory; + var fullLength = fullMemory.Length; + var delta = fullLength - Length; + if (delta != 0) + { + _flags &= ~(Flags.StartTrim | Flags.EndTrim); + Memory = fullMemory; + if (expandBackwards & RunningIndex >= delta) + { + // push our origin earlier; only valid if + // we're the first segment, otherwise + // we break someone-else's chain + RunningIndex -= delta; + } + else + { + // push everyone else later + ApplyChainDelta(delta); + } + + DebugAssertValidChain(); + } + return fullLength; + } + + public bool StartTrimmed => (_flags & Flags.StartTrim) != 0; + public bool EndTrimmed => (_flags & Flags.EndTrim) != 0; + + [Conditional("DEBUG")] + public void DebugAssertValidChain([CallerMemberName] string blame = "") + { + var node = this; + var runningIndex = RunningIndex; + int index = 0; + while (node.Next is { } next) + { + index++; + var nextRunningIndex = runningIndex + node.Length; + if (nextRunningIndex != next.RunningIndex) ThrowRunningIndex(blame, index); + node = next; + runningIndex = nextRunningIndex; + static void ThrowRunningIndex(string blame, int index) => throw new InvalidOperationException( + $"Critical running index corruption in dangling chain, from '{blame}', segment {index}"); + } + } + + public void AppendOrRecycle(Segment segment, int maxDepth) + { + segment.Memory.DebugScramble(); + var node = this; + while (maxDepth-- > 0 && node is not null) + { + if (node.Next is null) // found somewhere to attach it + { + if (segment.Untrim() == 0) break; // turned out to be useless + segment.RunningIndex = node.RunningIndex + node.Length; + node.Next = segment; + return; + } + + node = node.Next; + } + + segment.Recycle(); + } + } + + /// + /// Discard all data and buffers. + /// + public void Release() + { + var node = startSegment; + startSegment = endSegment = null; + endSegmentCommitted = endSegmentLength = 0; + while (node is not null) + { + var next = node.Next; + node.Recycle(); + node = next; + } + } +} + +// this can be shared between CycleBuffer and CycleBuffer.Simple +partial struct CycleBuffer +{ + /// + /// Writes a value to the buffer; comparable to . + /// + public void Write(ReadOnlySpan value) + { + int srcLength = value.Length; + while (srcLength != 0) + { + var target = GetUncommittedSpan(hint: srcLength); + var tgtLength = target.Length; + if (tgtLength >= srcLength) + { + value.CopyTo(target); + Commit(srcLength); + return; + } + + value.Slice(0, tgtLength).CopyTo(target); + Commit(tgtLength); + value = value.Slice(tgtLength); + srcLength -= tgtLength; + } + } + + /// + /// Writes a value to the buffer; comparable to . + /// + public void Write(in ReadOnlySequence value) + { + if (value.IsSingleSegment) + { +#if NETCOREAPP3_0_OR_GREATER || NETSTANDARD2_1 + Write(value.FirstSpan); +#else + Write(value.First.Span); +#endif + } + else + { + WriteMultiSegment(ref this, in value); + } + + static void WriteMultiSegment(ref CycleBuffer @this, in ReadOnlySequence value) + { + foreach (var segment in value) + { +#if NETCOREAPP3_0_OR_GREATER || NETSTANDARD2_1 + @this.Write(value.FirstSpan); +#else + @this.Write(value.First.Span); +#endif + } + } + } +} diff --git a/src/RESPite/Internal/DebugCounters.cs b/src/RESPite/Internal/DebugCounters.cs new file mode 100644 index 000000000..dd3d9f6d4 --- /dev/null +++ b/src/RESPite/Internal/DebugCounters.cs @@ -0,0 +1,318 @@ +using System.Diagnostics; + +namespace RESPite.Internal; +#if DEBUG +public partial class DebugCounters +#else +internal partial class DebugCounters +#endif +{ +#if DEBUG + private static int _tallyReadCount, + _tallyAsyncReadCount, + _tallyAsyncReadInlineCount, + _tallySyncWriteCount, + _tallyAsyncWriteCount, + _tallyAsyncWriteInlineCount, + _tallyCopyOutCount, + _tallyDiscardFullCount, + _tallyDiscardPartialCount, + _tallyPipelineFullAsyncCount, + _tallyPipelineSendAsyncCount, + _tallyPipelineFullSyncCount, + _tallyBatchWriteCount, + _tallyBatchWriteFullPageCount, + _tallyBatchWritePartialPageCount, + _tallyBatchWriteMessageCount, + _tallyBufferCreatedCount, + _tallyBufferRecycledCount, + _tallyBufferMessageCount, + _tallyBufferPinCount, + _tallyBufferLeakCount, + _tallyBatchGrowCount, + _tallyBatchBufferLeaseCount, + _tallyBatchBufferReturnCount, + _tallyBatchMultiRootMessageCount; + + private static long _tallyWriteBytes, + _tallyReadBytes, + _tallyCopyOutBytes, + _tallyDiscardAverage, + _tallyBufferMessageBytes, + _tallyBufferRecycledBytes, + _tallyBufferMaxOutstandingBytes, + _tallyBufferTotalBytes, + _tallyBatchGrowCopyCount, + _tallyBatchBufferElementsOutstanding, + _tallyBatchMultiChildMessageCount; +#endif + + [Conditional("DEBUG")] + internal static void OnRead(int bytes) + { +#if DEBUG + Interlocked.Increment(ref _tallyReadCount); + if (bytes > 0) Interlocked.Add(ref _tallyReadBytes, bytes); +#endif + } + + public static void OnBatchGrow(int count) + { +#if DEBUG + Interlocked.Increment(ref _tallyBatchGrowCount); + if (count > 0) Interlocked.Add(ref _tallyBatchGrowCopyCount, count); +#endif + } + + public static void OnBatchWrite(int messageCount) + { +#if DEBUG + Interlocked.Increment(ref _tallyBatchWriteCount); + if (messageCount != 0) Interlocked.Add(ref _tallyBatchWriteMessageCount, messageCount); +#endif + } + + public static void OnBatchWriteFullPage() + { +#if DEBUG + Interlocked.Increment(ref _tallyBatchWriteFullPageCount); +#endif + } + + public static void OnBatchWritePartialPage() + { +#if DEBUG + Interlocked.Increment(ref _tallyBatchWritePartialPageCount); +#endif + } + + public static void OnBatchBufferLease(int length) + { +#if DEBUG + Interlocked.Increment(ref _tallyBatchBufferLeaseCount); + Interlocked.Add(ref _tallyBatchBufferElementsOutstanding, length); +#endif + } + + public static void OnBatchBufferReturn(int length) + { +#if DEBUG + Interlocked.Increment(ref _tallyBatchBufferReturnCount); + Interlocked.Add(ref _tallyBatchBufferElementsOutstanding, -length); +#endif + } + + public static void OnMultiMessageWrite(int length) + { +#if DEBUG + Interlocked.Increment(ref _tallyBatchMultiRootMessageCount); + Interlocked.Add(ref _tallyBatchMultiChildMessageCount, length); +#endif + } + + [Conditional("DEBUG")] + internal static void OnAsyncRead(int bytes, bool inline) + { +#if DEBUG + Interlocked.Increment(ref inline ? ref _tallyAsyncReadInlineCount : ref _tallyAsyncReadCount); + if (bytes > 0) Interlocked.Add(ref _tallyReadBytes, bytes); +#endif + } + + [Conditional("DEBUG")] + internal static void OnSyncWrite(int bytes) + { +#if DEBUG + Interlocked.Increment(ref _tallySyncWriteCount); + if (bytes > 0) Interlocked.Add(ref _tallyWriteBytes, bytes); +#endif + } + + [Conditional("DEBUG")] + internal static void OnAsyncWrite(int bytes, bool inline) + { +#if DEBUG + Interlocked.Increment(ref inline ? ref _tallyAsyncWriteInlineCount : ref _tallyAsyncWriteCount); + if (bytes > 0) Interlocked.Add(ref _tallyWriteBytes, bytes); +#endif + } + + [Conditional("DEBUG")] + internal static void OnCopyOut(int bytes) + { +#if DEBUG + Interlocked.Increment(ref _tallyCopyOutCount); + if (bytes > 0) Interlocked.Add(ref _tallyCopyOutBytes, bytes); +#endif + } + + [Conditional("DEBUG")] + public static void OnDiscardFull(long count) + { +#if DEBUG + if (count > 0) + { + Interlocked.Increment(ref _tallyDiscardFullCount); + EstimatedMovingRangeAverage(ref _tallyDiscardAverage, count); + } +#endif + } + + [Conditional("DEBUG")] + public static void OnDiscardPartial(long count) + { +#if DEBUG + if (count > 0) + { + Interlocked.Increment(ref _tallyDiscardPartialCount); + EstimatedMovingRangeAverage(ref _tallyDiscardAverage, count); + } +#endif + } + + [Conditional("DEBUG")] + public static void OnPipelineFullAsync() + { +#if DEBUG + Interlocked.Increment(ref _tallyPipelineFullAsyncCount); +#endif + } + + [Conditional("DEBUG")] + public static void OnPipelineSendAsync() + { +#if DEBUG + Interlocked.Increment(ref _tallyPipelineSendAsyncCount); +#endif + } + + [Conditional("DEBUG")] + public static void OnPipelineFullSync() + { +#if DEBUG + Interlocked.Increment(ref _tallyPipelineFullSyncCount); +#endif + } + + [Conditional("DEBUG")] + public static void OnBufferCreated() + { +#if DEBUG + Interlocked.Increment(ref _tallyBufferCreatedCount); +#endif + } + + [Conditional("DEBUG")] + public static void OnBufferRecycled(int messageBytes) + { +#if DEBUG + Interlocked.Increment(ref _tallyBufferRecycledCount); + var now = Interlocked.Add(ref _tallyBufferRecycledBytes, messageBytes); + var outstanding = Volatile.Read(ref _tallyBufferMessageBytes) - now; + + while (true) + { + var oldOutstanding = Volatile.Read(ref _tallyBufferMaxOutstandingBytes); + // loop until either it isn't an increase, or we successfully perform + // the swap + if (outstanding <= oldOutstanding + || Interlocked.CompareExchange( + ref _tallyBufferMaxOutstandingBytes, + outstanding, + oldOutstanding) == oldOutstanding) break; + } +#endif + } + + [Conditional("DEBUG")] + public static void OnBufferCompleted(int messageCount, int messageBytes) + { +#if DEBUG + Interlocked.Add(ref _tallyBufferMessageCount, messageCount); + Interlocked.Add(ref _tallyBufferMessageBytes, messageBytes); +#endif + } + + public static void OnBufferCapacity(int bytes) + { +#if DEBUG + Interlocked.Add(ref _tallyBufferTotalBytes, bytes); +#endif + } + + public static void OnBufferPinned() + { +#if DEBUG + Interlocked.Increment(ref _tallyBufferPinCount); +#endif + } + + public static void OnBufferLeaked() + { +#if DEBUG + Interlocked.Increment(ref _tallyBufferLeakCount); +#endif + } + + private DebugCounters() + { + } + + public static DebugCounters Flush() + { + #if DEBUG + BlockBufferSerializer.Shared.Clear(); // release any outstanding buffers + #endif + return new(); + } + +#if DEBUG + private static void EstimatedMovingRangeAverage(ref long field, long value) + { + var oldValue = Volatile.Read(ref field); + var delta = (value - oldValue) >> 3; // is is a 7:1 old:new EMRA, using integer/bit math (alplha=0.125) + if (delta != 0) Interlocked.Add(ref field, delta); + // note: strictly conflicting concurrent calls can skew the value incorrectly; this is, however, + // preferable to getting into a CEX squabble or requiring a lock - it is debug-only and just useful data + } + + public int ReadCount { get; } = Interlocked.Exchange(ref _tallyReadCount, 0); + public int AsyncReadCount { get; } = Interlocked.Exchange(ref _tallyAsyncReadCount, 0); + public int AsyncReadInlineCount { get; } = Interlocked.Exchange(ref _tallyAsyncReadInlineCount, 0); + public long ReadBytes { get; } = Interlocked.Exchange(ref _tallyReadBytes, 0); + + public int SyncWriteCount { get; } = Interlocked.Exchange(ref _tallySyncWriteCount, 0); + public int AsyncWriteCount { get; } = Interlocked.Exchange(ref _tallyAsyncWriteCount, 0); + public int AsyncWriteInlineCount { get; } = Interlocked.Exchange(ref _tallyAsyncWriteInlineCount, 0); + public long WriteBytes { get; } = Interlocked.Exchange(ref _tallyWriteBytes, 0); + public int CopyOutCount { get; } = Interlocked.Exchange(ref _tallyCopyOutCount, 0); + public long CopyOutBytes { get; } = Interlocked.Exchange(ref _tallyCopyOutBytes, 0); + public long DiscardAverage { get; } = Interlocked.Exchange(ref _tallyDiscardAverage, 32); + public int DiscardFullCount { get; } = Interlocked.Exchange(ref _tallyDiscardFullCount, 0); + public int DiscardPartialCount { get; } = Interlocked.Exchange(ref _tallyDiscardPartialCount, 0); + public int PipelineFullAsyncCount { get; } = Interlocked.Exchange(ref _tallyPipelineFullAsyncCount, 0); + public int PipelineSendAsyncCount { get; } = Interlocked.Exchange(ref _tallyPipelineSendAsyncCount, 0); + public int PipelineFullSyncCount { get; } = Interlocked.Exchange(ref _tallyPipelineFullSyncCount, 0); + public int BatchWriteCount { get; } = Interlocked.Exchange(ref _tallyBatchWriteCount, 0); + public int BatchWriteFullPageCount { get; } = Interlocked.Exchange(ref _tallyBatchWriteFullPageCount, 0); + public int BatchWritePartialPageCount { get; } = Interlocked.Exchange(ref _tallyBatchWritePartialPageCount, 0); + public int BatchWriteMessageCount { get; } = Interlocked.Exchange(ref _tallyBatchWriteMessageCount, 0); + public int BatchGrowCount { get; } = Interlocked.Exchange(ref _tallyBatchGrowCount, 0); + public long BatchGrowCopyCount { get; } = Interlocked.Exchange(ref _tallyBatchGrowCopyCount, 0); + public int BatchBufferLeaseCount { get; } = Interlocked.Exchange(ref _tallyBatchBufferLeaseCount, 0); + public int BatchBufferReturnCount { get; } = Interlocked.Exchange(ref _tallyBatchBufferReturnCount, 0); + public long BatchBufferElementsOutstanding { get; } = Interlocked.Exchange(ref _tallyBatchBufferElementsOutstanding, 0); + public int BatchMultiRootMessageCount { get; } = Interlocked.Exchange(ref _tallyBatchMultiRootMessageCount, 0); + public long BatchMultiChildMessageCount { get; } = Interlocked.Exchange(ref _tallyBatchMultiChildMessageCount, 0); + + public int BufferCreatedCount { get; } = Interlocked.Exchange(ref _tallyBufferCreatedCount, 0); + public int BufferRecycledCount { get; } = Interlocked.Exchange(ref _tallyBufferRecycledCount, 0); + public long BufferRecycledBytes { get; } = Interlocked.Exchange(ref _tallyBufferRecycledBytes, 0); + public long BufferMaxOutstandingBytes { get; } = Interlocked.Exchange(ref _tallyBufferMaxOutstandingBytes, 0); + public int BufferMessageCount { get; } = Interlocked.Exchange(ref _tallyBufferMessageCount, 0); + public long BufferMessageBytes { get; } = Interlocked.Exchange(ref _tallyBufferMessageBytes, 0); + public long BufferTotalBytes { get; } = Interlocked.Exchange(ref _tallyBufferTotalBytes, 0); + public int BufferPinCount { get; } = Interlocked.Exchange(ref _tallyBufferPinCount, 0); + public int BufferLeakCount { get; } = Interlocked.Exchange(ref _tallyBufferLeakCount, 0); +#endif +} diff --git a/src/RESPite/Internal/IRespInlineParser.cs b/src/RESPite/Internal/IRespInlineParser.cs new file mode 100644 index 000000000..31b054a40 --- /dev/null +++ b/src/RESPite/Internal/IRespInlineParser.cs @@ -0,0 +1,5 @@ +namespace RESPite.Internal; + +internal interface IRespInlineParser // marker interface for readers safe to use on the IO thread +{ +} diff --git a/src/RESPite/Internal/Raw.cs b/src/RESPite/Internal/Raw.cs new file mode 100644 index 000000000..65d0c5059 --- /dev/null +++ b/src/RESPite/Internal/Raw.cs @@ -0,0 +1,138 @@ +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; + +#if NETCOREAPP3_0_OR_GREATER +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; +#endif + +namespace RESPite.Internal; + +/// +/// Pre-computed payload fragments, for high-volume scenarios / common values. +/// +/// +/// CPU-endianness applies here; we can't just use "const" - however, modern JITs treat "static readonly" *almost* the same as "const", so: meh. +/// +internal static class Raw +{ + public static ulong Create64(ReadOnlySpan bytes, int length) + { + if (length != bytes.Length) + { + throw new ArgumentException($"Length check failed: {length} vs {bytes.Length}, value: {RespConstants.UTF8.GetString(bytes)}", nameof(length)); + } + if (length < 0 || length > sizeof(ulong)) + { + throw new ArgumentOutOfRangeException(nameof(length), $"Invalid length {length} - must be 0-{sizeof(ulong)}"); + } + + // this *will* be aligned; this approach intentionally chosen for parity with write + Span scratch = stackalloc byte[sizeof(ulong)]; + if (length != sizeof(ulong)) scratch.Slice(length).Clear(); + bytes.CopyTo(scratch); + return Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(scratch)); + } + + public static uint Create32(ReadOnlySpan bytes, int length) + { + if (length != bytes.Length) + { + throw new ArgumentException($"Length check failed: {length} vs {bytes.Length}, value: {RespConstants.UTF8.GetString(bytes)}", nameof(length)); + } + if (length < 0 || length > sizeof(uint)) + { + throw new ArgumentOutOfRangeException(nameof(length), $"Invalid length {length} - must be 0-{sizeof(uint)}"); + } + + // this *will* be aligned; this approach intentionally chosen for parity with write + Span scratch = stackalloc byte[sizeof(uint)]; + if (length != sizeof(uint)) scratch.Slice(length).Clear(); + bytes.CopyTo(scratch); + return Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(scratch)); + } + + public static ulong BulkStringEmpty_6 = Create64("$0\r\n\r\n"u8, 6); + + public static ulong BulkStringInt32_M1_8 = Create64("$2\r\n-1\r\n"u8, 8); + public static ulong BulkStringInt32_0_7 = Create64("$1\r\n0\r\n"u8, 7); + public static ulong BulkStringInt32_1_7 = Create64("$1\r\n1\r\n"u8, 7); + public static ulong BulkStringInt32_2_7 = Create64("$1\r\n2\r\n"u8, 7); + public static ulong BulkStringInt32_3_7 = Create64("$1\r\n3\r\n"u8, 7); + public static ulong BulkStringInt32_4_7 = Create64("$1\r\n4\r\n"u8, 7); + public static ulong BulkStringInt32_5_7 = Create64("$1\r\n5\r\n"u8, 7); + public static ulong BulkStringInt32_6_7 = Create64("$1\r\n6\r\n"u8, 7); + public static ulong BulkStringInt32_7_7 = Create64("$1\r\n7\r\n"u8, 7); + public static ulong BulkStringInt32_8_7 = Create64("$1\r\n8\r\n"u8, 7); + public static ulong BulkStringInt32_9_7 = Create64("$1\r\n9\r\n"u8, 7); + public static ulong BulkStringInt32_10_8 = Create64("$2\r\n10\r\n"u8, 8); + + public static ulong BulkStringPrefix_M1_5 = Create64("$-1\r\n"u8, 5); + public static uint BulkStringPrefix_0_4 = Create32("$0\r\n"u8, 4); + public static uint BulkStringPrefix_1_4 = Create32("$1\r\n"u8, 4); + public static uint BulkStringPrefix_2_4 = Create32("$2\r\n"u8, 4); + public static uint BulkStringPrefix_3_4 = Create32("$3\r\n"u8, 4); + public static uint BulkStringPrefix_4_4 = Create32("$4\r\n"u8, 4); + public static uint BulkStringPrefix_5_4 = Create32("$5\r\n"u8, 4); + public static uint BulkStringPrefix_6_4 = Create32("$6\r\n"u8, 4); + public static uint BulkStringPrefix_7_4 = Create32("$7\r\n"u8, 4); + public static uint BulkStringPrefix_8_4 = Create32("$8\r\n"u8, 4); + public static uint BulkStringPrefix_9_4 = Create32("$9\r\n"u8, 4); + public static ulong BulkStringPrefix_10_5 = Create64("$10\r\n"u8, 5); + + public static ulong ArrayPrefix_M1_5 = Create64("*-1\r\n"u8, 5); + public static uint ArrayPrefix_0_4 = Create32("*0\r\n"u8, 4); + public static uint ArrayPrefix_1_4 = Create32("*1\r\n"u8, 4); + public static uint ArrayPrefix_2_4 = Create32("*2\r\n"u8, 4); + public static uint ArrayPrefix_3_4 = Create32("*3\r\n"u8, 4); + public static uint ArrayPrefix_4_4 = Create32("*4\r\n"u8, 4); + public static uint ArrayPrefix_5_4 = Create32("*5\r\n"u8, 4); + public static uint ArrayPrefix_6_4 = Create32("*6\r\n"u8, 4); + public static uint ArrayPrefix_7_4 = Create32("*7\r\n"u8, 4); + public static uint ArrayPrefix_8_4 = Create32("*8\r\n"u8, 4); + public static uint ArrayPrefix_9_4 = Create32("*9\r\n"u8, 4); + public static ulong ArrayPrefix_10_5 = Create64("*10\r\n"u8, 5); + +#if NETCOREAPP3_0_OR_GREATER + private static uint FirstAndLast(char first, char last) + { + Debug.Assert(first < 128 && last < 128, "ASCII please"); + Span scratch = [(byte)first, 0, 0, (byte)last]; + // this *will* be aligned; this approach intentionally chosen for how we read + return Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(scratch)); + } + + public const int CommonRespIndex_Success = 0; + public const int CommonRespIndex_SingleDigitInteger = 1; + public const int CommonRespIndex_DoubleDigitInteger = 2; + public const int CommonRespIndex_SingleDigitString = 3; + public const int CommonRespIndex_DoubleDigitString = 4; + public const int CommonRespIndex_SingleDigitArray = 5; + public const int CommonRespIndex_DoubleDigitArray = 6; + public const int CommonRespIndex_Error = 7; + + public static readonly Vector256 CommonRespPrefixes = Vector256.Create( + FirstAndLast('+', '\r'), // success +OK\r\n + FirstAndLast(':', '\n'), // single-digit integer :4\r\n + FirstAndLast(':', '\r'), // double-digit integer :42\r\n + FirstAndLast('$', '\n'), // 0-9 char string $0\r\n\r\n + FirstAndLast('$', '\r'), // null/10-99 char string $-1\r\n or $10\r\nABCDEFGHIJ\r\n + FirstAndLast('*', '\n'), // 0-9 length array *0\r\n + FirstAndLast('*', '\r'), // null/10-99 length array *-1\r\n or *10\r\n:0\r\n:0\r\n:0\r\n:0\r\n:0\r\n:0\r\n:0\r\n:0\r\n:0\r\n:0\r\n + FirstAndLast('-', 'R')); // common errors -ERR something bad happened + + public static readonly Vector256 FirstLastMask = CreateUInt32(0xFF0000FF); + + private static Vector256 CreateUInt32(uint value) + { +#if NET7_0_OR_GREATER + return Vector256.Create(value); +#else + return Vector256.Create(value, value, value, value, value, value, value, value); +#endif + } + +#endif +} diff --git a/src/RESPite/Internal/RespConstants.cs b/src/RESPite/Internal/RespConstants.cs new file mode 100644 index 000000000..accb8400b --- /dev/null +++ b/src/RESPite/Internal/RespConstants.cs @@ -0,0 +1,53 @@ +using System.Buffers.Binary; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +// ReSharper disable InconsistentNaming +namespace RESPite.Internal; + +internal static class RespConstants +{ + public static readonly UTF8Encoding UTF8 = new(false); + + public static ReadOnlySpan CrlfBytes => "\r\n"u8; + + public static readonly ushort CrLfUInt16 = UnsafeCpuUInt16(CrlfBytes); + + public static ReadOnlySpan OKBytes_LC => "ok"u8; + public static ReadOnlySpan OKBytes => "OK"u8; + public static readonly ushort OKUInt16 = UnsafeCpuUInt16(OKBytes); + public static readonly ushort OKUInt16_LC = UnsafeCpuUInt16(OKBytes_LC); + + public static readonly uint BulkStringStreaming = UnsafeCpuUInt32("$?\r\n"u8); + public static readonly uint BulkStringNull = UnsafeCpuUInt32("$-1\r"u8); + + public static readonly uint ArrayStreaming = UnsafeCpuUInt32("*?\r\n"u8); + public static readonly uint ArrayNull = UnsafeCpuUInt32("*-1\r"u8); + + public static ushort UnsafeCpuUInt16(ReadOnlySpan bytes) + => Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(bytes)); + public static ushort UnsafeCpuUInt16(ReadOnlySpan bytes, int offset) + => Unsafe.ReadUnaligned(ref Unsafe.Add(ref MemoryMarshal.GetReference(bytes), offset)); + public static byte UnsafeCpuByte(ReadOnlySpan bytes, int offset) + => Unsafe.Add(ref MemoryMarshal.GetReference(bytes), offset); + public static uint UnsafeCpuUInt32(ReadOnlySpan bytes) + => Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(bytes)); + public static uint UnsafeCpuUInt32(ReadOnlySpan bytes, int offset) + => Unsafe.ReadUnaligned(ref Unsafe.Add(ref MemoryMarshal.GetReference(bytes), offset)); + public static ulong UnsafeCpuUInt64(ReadOnlySpan bytes) + => Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(bytes)); + public static ushort CpuUInt16(ushort bigEndian) + => BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(bigEndian) : bigEndian; + public static uint CpuUInt32(uint bigEndian) + => BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(bigEndian) : bigEndian; + public static ulong CpuUInt64(ulong bigEndian) + => BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(bigEndian) : bigEndian; + + public const int MaxRawBytesInt32 = 11, // "-2147483648" + MaxRawBytesInt64 = 20, // "-9223372036854775808", + MaxProtocolBytesIntegerInt32 = MaxRawBytesInt32 + 3, // ?X10X\r\n where ? could be $, *, etc - usually a length prefix + MaxProtocolBytesBulkStringIntegerInt32 = MaxRawBytesInt32 + 7, // $NN\r\nX11X\r\n for NN (length) 1-11 + MaxProtocolBytesBulkStringIntegerInt64 = MaxRawBytesInt64 + 7, // $NN\r\nX20X\r\n for NN (length) 1-20 + MaxRawBytesNumber = 20, // note G17 format, allow 20 for payload + MaxProtocolBytesBytesNumber = MaxRawBytesNumber + 7; // $NN\r\nX...X\r\n for NN (length) 1-20 +} diff --git a/src/RESPite/Internal/RespMessageBase.cs b/src/RESPite/Internal/RespMessageBase.cs new file mode 100644 index 000000000..1d06f9150 --- /dev/null +++ b/src/RESPite/Internal/RespMessageBase.cs @@ -0,0 +1,381 @@ +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading.Tasks.Sources; +using RESPite.Messages; + +namespace RESPite.Internal; + +internal abstract class RespMessageBase : IValueTaskSource +{ + protected RespMessageBase() => RespOperation.DebugOnAllocateMessage(); + + private CancellationToken _cancellationToken; + private CancellationTokenRegistration _cancellationTokenRegistration; + private int _requestRefCount, _flags, _slot; + private ReadOnlySequence _request; + public ref readonly CancellationToken CancellationToken => ref _cancellationToken; + + [Flags] + internal enum StateFlags + { + None = 0, + IsSent = 1 << 0, // the request has been sent + OutcomeKnown = 1 << 1, // controls which code flow gets to set an outcome + Complete = 1 << 2, // indicates whether all follow-up has completed + NoPulse = 1 << 4, // don't pulse when completing - either async, or timeout + Doomed = 1 << 5, // something went wrong, do not recycle + HasParser = 1 << 6, // we have a parser + MetadataParser = 1 << 7, // the parser wants to consume metadata + InlineParser = 1 << 8, // we can safely use the parser on the IO thread + Replica = 1 << 9, // request a replica (otherwise, primary is requested) + Demand = 1 << 10, // the presence/absence of Replica is a hard demand + } + + internal StateFlags Flags => (StateFlags)Volatile.Read(ref _flags); + public virtual int MessageCount => 1; + internal int Slot => _slot; + + protected void InitParser(object? parser) + { + if (parser is null) + { + SetFlag(StateFlags.InlineParser); // F+F + } + else + { + var flags = StateFlags.HasParser; + // detect parsers that want to manually parse attributes, errors, etc. + if (parser is IRespMetadataParser) flags |= StateFlags.MetadataParser; + // detect fast, internal, non-allocating parsers (int, bool, etc.) + if (parser is IRespInlineParser) flags |= StateFlags.InlineParser; + SetFlag(flags); + } + } + + public bool AllowInlineParsing => HasFlag(StateFlags.InlineParser); + + public bool TrySetResult(short token, ref RespReader reader) + { + var flags = Flags & (StateFlags.MetadataParser | StateFlags.HasParser | StateFlags.OutcomeKnown); + if ((flags & StateFlags.OutcomeKnown) != 0 | Token != token) return false; + switch (flags) + { + case StateFlags.HasParser: + case StateFlags.HasParser | StateFlags.MetadataParser: + try + { + if ((flags & StateFlags.MetadataParser) == 0) + { + reader.MoveNext(); + } + + return TrySetResultPrecheckedToken(ref reader); + } + catch (Exception ex) + { + return TrySetExceptionPrecheckedToken(ex); + } + default: + return TrySetDefaultResultPrecheckedToken(); + } + } + + // if this is a multi-message type, then when adding to the "sent awaiting resport" queue, + // instead of adding the message, we add the sub-messages **instead** (and not the root message) + public virtual bool TryGetSubMessages(short token, out ReadOnlySpan operations) + { + operations = default; + return false; + } + + // if this is a multi-message type, this does cleanup after TryGetSubMessages has been consumed + public virtual bool TrySetResultAfterUnloadingSubMessages(short token) => false; + + public bool TrySetResult(short token, scoped ReadOnlySpan response) + { + RespReader reader = new(response); + return TrySetResult(token, ref reader); + } + + public bool TrySetResult(short token, in ReadOnlySequence response) + { + RespReader reader = new(response); + return TrySetResult(token, ref reader); + } + + protected abstract bool TrySetResultPrecheckedToken(ref RespReader reader); + protected abstract bool TrySetDefaultResultPrecheckedToken(); + + public abstract short Token { get; } + + [Obsolete("Prefer de-virtualized version via CheckTokenCore")] + private protected abstract void CheckToken(short token); + + private protected abstract ValueTaskSourceStatus OwnStatus { get; } + + public abstract ValueTaskSourceStatus GetStatus(short token); + + public bool IsSent(short token) + { +#pragma warning disable CS0618 // can't access CheckTokenCore in base-class + CheckToken(token); +#pragma warning restore CS0618 + return HasFlag(StateFlags.IsSent); + } + + protected bool SetFlag(StateFlags flag) + { + Debug.Assert(flag != 0, "trying to set a zero flag"); +#if NET5_0_OR_GREATER + return (Interlocked.Or(ref _flags, (int)flag) & (int)flag) == 0; +#else + while (true) + { + var oldValue = Volatile.Read(ref _flags); + var newValue = oldValue | (int)flag; + if (oldValue == newValue || + Interlocked.CompareExchange(ref _flags, newValue, oldValue) == oldValue) + { + return (oldValue & (int)flag) == 0; + } + } +#endif + } + + // in the "any" sense + protected bool HasFlag(StateFlags flag) => (Volatile.Read(ref _flags) & (int)flag) != 0; + + public void Init(bool sent, CancellationToken cancellationToken) + { + Debug.Assert(Flags is 0 or StateFlags.InlineParser, $"flags should be zero; got {Flags}"); + Debug.Assert(_requestRefCount == 0, "trying to set a request more than once"); + if (sent) SetFlag(StateFlags.IsSent); + if (cancellationToken.CanBeCanceled) + { + _cancellationToken = cancellationToken; + _cancellationTokenRegistration = ActivationHelper.RegisterForCancellation(this, cancellationToken); + } + } + + public void Init( + ReadOnlyMemory request, + CancellationToken cancellationToken) => Init(new ReadOnlySequence(request), cancellationToken); + + public void Init( + ReadOnlySequence request, + CancellationToken cancellationToken) + { + Debug.Assert(_requestRefCount == 0, "trying to set a request more than once"); + _request = request; + _requestRefCount = 1; + if (cancellationToken.CanBeCanceled) + { + _cancellationToken = cancellationToken; + _cancellationTokenRegistration = ActivationHelper.RegisterForCancellation(this, cancellationToken); + } + } + + protected void UnregisterCancellation() + { + _cancellationTokenRegistration.Dispose(); + _cancellationTokenRegistration = default; + _cancellationToken = CancellationToken.None; + } + + protected virtual void Reset(bool recycle) + { + Debug.Assert( + !recycle || OwnStatus == ValueTaskSourceStatus.Succeeded, + "We should only be recycling completed messages"); + // note we only reset on success, and on + // success we've already unregistered cancellation + _request = default; + _requestRefCount = _flags = 0; + _slot = -1; + NextToken(); + if (recycle) Recycle(); + } + + protected abstract void Recycle(); + protected abstract void NextToken(); + + internal void OnSent(short token) + { + // only if our token matches, but: don't throw + if (token == Token) OnSent(); + } + + protected virtual void OnSent() => SetFlag(StateFlags.IsSent); + + public bool TryReserveRequest(short token, out ReadOnlySequence payload, bool recordSent = true) + { + while (true) // redo in case of CEX failure + { + Debug.Assert(OwnStatus == ValueTaskSourceStatus.Pending); + + var oldCount = Volatile.Read(ref _requestRefCount); + if (oldCount == 0 | token != Token) + { + payload = default; + return false; + } + + if (Interlocked.CompareExchange(ref _requestRefCount, checked(oldCount + 1), oldCount) == oldCount) + { + if (recordSent) OnSent(); + + payload = _request; + return true; + } + } + } + + public void ReleaseRequest() + { + if (!TryReleaseRequest()) ThrowReleased(); + + static void ThrowReleased() => + throw new InvalidOperationException("The request payload has already been released"); + } + + private bool TryReleaseRequest() // bool here means "it wasn't already zero"; it doesn't mean "it became zero" + { + while (true) + { + var oldCount = Volatile.Read(ref _requestRefCount); + if (oldCount == 0) return false; + if (Interlocked.CompareExchange(ref _requestRefCount, oldCount - 1, oldCount) == oldCount) + { + if (oldCount == 1) // we were the last one; recycle + { + BlockBufferSerializer.BlockBuffer.Release(in _request); + _request = default; + } + + return true; + } + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + protected void ThrowNotSent(short token) + { +#pragma warning disable CS0618 // can't access CheckTokenCore in base-class + CheckToken(token); // prefer a token explanation +#pragma warning restore CS0618 + throw new InvalidOperationException( + "This command has not yet been sent; waiting is not possible. If this is a transaction or batch, you must execute that first."); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + protected void SetNotSentAsync(short token) + { +#pragma warning disable CS0618 // can't access CheckTokenCore in base-class + CheckToken(token); +#pragma warning restore CS0618 + TrySetExceptionPrecheckedToken(new InvalidOperationException( + "This command has not yet been sent; awaiting is not possible. If this is a transaction or batch, you must execute that first.")); + } + + // spoof untyped on top of typed + void IValueTaskSource.GetResult(short token) => GetResultVoid(token); + + // ReSharper disable once UnusedMember.Local + private bool TrySetOutcomeKnown(short token, bool withSuccess) + => Token == token && TrySetOutcomeKnownPrecheckedToken(withSuccess); + + protected bool TrySetOutcomeKnownPrecheckedToken(bool withSuccess) + { + if (!SetFlag(StateFlags.OutcomeKnown)) return false; + UnregisterCancellation(); + TryReleaseRequest(); // we won't be needing this again + + // configure threading model; failure can be triggered from any thread - *always* + // dispatch to pool; in the success case, we're either on the IO thread + // (if inline-parsing is enabled) - in which case, yes: dispatch - or we've + // already jumped to a pool thread for the parse step. So: the only + // time we want to complete inline is success and not inline-parsing. + SetRunContinuationsAsynchronously(!withSuccess | AllowInlineParsing); + + return true; + } + + private protected abstract void SetRunContinuationsAsynchronously(bool value); + public abstract void GetResultVoid(short token); + public abstract void WaitVoid(short token, TimeSpan timeout); + + public bool TrySetCanceled(short token, CancellationToken cancellationToken = default) + { + if (!cancellationToken.IsCancellationRequested) + { + // use our own token if nothing more specific supplied + cancellationToken = _cancellationToken; + } + + return token == Token && TrySetCanceledPrecheckedToken(cancellationToken); + } + + // this is the path used by cancellation registration callbacks; always use our own + // cancellation token, and we must trust the version token + internal void TrySetCanceledTrustToken() => TrySetCanceledPrecheckedToken(_cancellationToken); + + private bool TrySetCanceledPrecheckedToken(CancellationToken cancellationToken) + { + if (!TrySetOutcomeKnownPrecheckedToken(false)) return false; + SetExceptionPreChecked(new OperationCanceledException(cancellationToken)); + SetFullyComplete(success: false); + return true; + } + + public bool TrySetException(short token, Exception exception) + => token == Token && TrySetExceptionPrecheckedToken(exception); + + private protected abstract void SetExceptionPreChecked(Exception exception); + + private bool TrySetExceptionPrecheckedToken(Exception exception) + { + if (!TrySetOutcomeKnownPrecheckedToken(false)) return false; // first winner only + SetExceptionPreChecked(exception); + SetFullyComplete(success: false); + return true; + } + + protected void SetFullyComplete(bool success) + { + var pulse = !HasFlag(StateFlags.NoPulse); + SetFlag(success + ? (StateFlags.Complete | StateFlags.NoPulse) + : (StateFlags.Complete | StateFlags.NoPulse | StateFlags.Doomed)); + + // for safety, always take the lock unless we know they've actively exited + if (pulse) + { + lock (this) + { + Monitor.PulseAll(this); + } + } + } + + protected bool TrySetTimeoutPrecheckedToken() + { + if (!TrySetOutcomeKnownPrecheckedToken(false)) return false; + + SetExceptionPreChecked(new TimeoutException()); + SetFullyComplete(success: false); + return true; + } + + public abstract void OnCompleted( + Action continuation, + object? state, + short token, + ValueTaskSourceOnCompletedFlags flags); + + public abstract void OnCompletedWithNotSentDetection( + Action continuation, + object? state, + short token, + ValueTaskSourceOnCompletedFlags flags); +} diff --git a/src/RESPite/Internal/RespMessageBaseT.cs b/src/RESPite/Internal/RespMessageBaseT.cs new file mode 100644 index 000000000..498440ca6 --- /dev/null +++ b/src/RESPite/Internal/RespMessageBaseT.cs @@ -0,0 +1,181 @@ +using System.Runtime.CompilerServices; +using System.Threading.Tasks.Sources; +using RESPite.Messages; + +namespace RESPite.Internal; + +internal abstract class RespMessageBase : RespMessageBase, IValueTaskSource +{ + private ManualResetValueTaskSourceCore _asyncCore; + + protected abstract TResponse Parse(ref RespReader reader); + + public override short Token => _asyncCore.Version; + + private protected override ValueTaskSourceStatus OwnStatus => _asyncCore.GetStatus(_asyncCore.Version); + + /* asking about the status too early is usually a very bad sign that they're doing + something like awaiting a message in a transaction that hasn't been sent */ + public override ValueTaskSourceStatus GetStatus(short token) + { + CheckTokenCore(token); + return _asyncCore.GetStatus(token); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void CheckTokenCore(short token) + { + if (token != _asyncCore.Version) // use cheap test + { + // note that _asyncCore just gives a default InvalidOperationException message; let's see if we can do better + ThrowInvalidToken(); + } + static void ThrowInvalidToken() => throw new InvalidOperationException( + $"The {nameof(RespOperation)} token is invalid; the most likely cause is awaiting an operation multiple times."); + } + + [Obsolete("Prefer de-virtualized version via " + nameof(CheckTokenCore))] + private protected override void CheckToken(short token) => CheckTokenCore(token); + + // this is used from Task/ValueTask; we can't avoid that - in theory + // we *coiuld* sort of make it work for ValueTask, but if anyone + // calls .AsTask() on it, it would fail + public override void OnCompleted( + Action continuation, + object? state, + short token, + ValueTaskSourceOnCompletedFlags flags) + { + CheckTokenCore(token); + SetFlag(StateFlags.NoPulse); // async doesn't need to be pulsed + _asyncCore.OnCompleted(continuation, state, token, flags); + } + + public override void OnCompletedWithNotSentDetection( + Action continuation, + object? state, + short token, + ValueTaskSourceOnCompletedFlags flags) + { + CheckTokenCore(token); + if (!HasFlag(StateFlags.IsSent)) SetNotSentAsync(token); + SetFlag(StateFlags.NoPulse); // async doesn't need to be pulsed + _asyncCore.OnCompleted(continuation, state, token, flags); + } + + private protected override void SetRunContinuationsAsynchronously(bool value) + => _asyncCore.RunContinuationsAsynchronously = value; + + public override void GetResultVoid(short token) => _ = GetResult(token); + public override void WaitVoid(short token, TimeSpan timeout) => _ = Wait(token, timeout); + + public TResponse Wait(short token, TimeSpan timeout) + { + switch (Flags & (StateFlags.Complete | StateFlags.IsSent)) + { + case StateFlags.IsSent: // this is the normal case + break; + case StateFlags.Complete | StateFlags.IsSent: // already complete + return GetResult(token); + default: + ThrowNotSent(token); // always throws + break; + } + + bool isTimeout = false; + CheckTokenCore(token); + lock (this) + { + switch (Flags & (StateFlags.Complete | StateFlags.NoPulse)) + { + case StateFlags.NoPulse | StateFlags.Complete: + case StateFlags.Complete: + break; // fine, we're complete + case 0: + // THIS IS OUR EXPECTED BRANCH; not complete, and will pulse + if (timeout == TimeSpan.Zero) + { + Monitor.Wait(this); + } + else if (!Monitor.Wait(this, timeout)) + { + isTimeout = true; + SetFlag(StateFlags.NoPulse); // no point in being woken, we're exiting + } + + break; + case StateFlags.NoPulse: + ThrowWillNotPulse(); + break; + } + } + + UnregisterCancellation(); + if (isTimeout) TrySetTimeoutPrecheckedToken(); + + return GetResult(token); + + static void ThrowWillNotPulse() => throw new InvalidOperationException( + "This operation cannot be waited because it entered async/await mode - most likely by calling AsTask()"); + } + + protected bool TrySetResultPrecheckedToken(TResponse response) + { + if (!TrySetOutcomeKnownPrecheckedToken(true)) return false; + + _asyncCore.SetResult(response); + SetFullyComplete(success: true); + return true; + } + + private TResponse ThrowFailureWithCleanup(short token) + { + var status = GetStatus(token); + try + { + if (status == ValueTaskSourceStatus.Pending) + { + if (!HasFlag(StateFlags.IsSent)) ThrowNotSent(_asyncCore.Version); + throw new InvalidOperationException( + "This operation has been sent but has not yet completed; the result is not available."); + } + return _asyncCore.GetResult(token); + } + finally + { + // we're not recycling; this is for GC reasons only + Reset(false); + } + } + + public TResponse GetResult(short token) + { + // failure uses some try/catch logic, let's put that to one side + // (it is very tempting to peek inside GetStatus with UnsafeAccessor...) + if (_asyncCore.Version != token || _asyncCore.GetStatus(token) != ValueTaskSourceStatus.Succeeded) + { + return ThrowFailureWithCleanup(token); + } + var result = _asyncCore.GetResult(token); + /* + If we get here, we're successful; increment "version"/"token" *immediately*. Technically + we could defer to when it is reused (after recycling), but then repeated calls will appear + to work for a while, which might lead to undetected problems in local builds (without much concurrency), + and we'd rather make people know that there's a problem immediately. This also means that any + continuation primitives (callback/state) are available for GC. + */ + Reset(true); + return result; + } + + private protected override void SetExceptionPreChecked(Exception exception) + => _asyncCore.SetException(exception); + + protected override bool TrySetResultPrecheckedToken(ref RespReader reader) => + TrySetResultPrecheckedToken(Parse(ref reader)); + + protected override bool TrySetDefaultResultPrecheckedToken() + => TrySetResultPrecheckedToken(default!); + + protected override void NextToken() => _asyncCore.Reset(); +} diff --git a/src/RESPite/Internal/RespMultiMessage.cs b/src/RESPite/Internal/RespMultiMessage.cs new file mode 100644 index 000000000..44d13da28 --- /dev/null +++ b/src/RESPite/Internal/RespMultiMessage.cs @@ -0,0 +1,74 @@ +using System.Diagnostics; +using System.Runtime.CompilerServices; +using RESPite.Connections.Internal; +using RESPite.Messages; + +namespace RESPite.Internal; + +internal sealed class RespMultiMessage : RespMessageBase +{ + private RespOperation[] _oversized; + private int _count; + + [ThreadStatic] + // used for object recycling of the async machinery + private static RespMultiMessage? _threadStaticSpare; + + private ReadOnlySpan Operations => new(_oversized, 0, _count); + + internal static RespMultiMessage Get(RespOperation[] oversized, int count) + { + RespMultiMessage obj = _threadStaticSpare ?? new(); + _threadStaticSpare = null; + obj._oversized = oversized; + obj._count = count; + return obj; + } + + public override bool TryGetSubMessages(short token, out ReadOnlySpan operations) + { + operations = token == Token ? Operations : default; + return true; // always return true; this means that flush gets called + } + + public override bool TrySetResultAfterUnloadingSubMessages(short token) + { + if (token == Token && TrySetResultPrecheckedToken(_count)) + { + // release the buffer immediately - it isn't needed any more + _count = 0; + BufferingBatchConnection.Return(ref _oversized); + return true; + } + + return false; + } + + protected override void Recycle() => _threadStaticSpare = this; + + private RespMultiMessage() => Unsafe.SkipInit(out _oversized); + + protected override int Parse(ref RespReader reader) + { + Debug.Fail("Not expecting to see results, since unrolled during write"); + return _count; + } + + protected override void OnSent() + { + base.OnSent(); + foreach (var op in Operations) + { + op.OnSent(); + } + } + + protected override void Reset(bool recycle) + { + _count = 0; + BufferingBatchConnection.Return(ref _oversized); + base.Reset(recycle); + } + + public override int MessageCount => _count; +} diff --git a/src/RESPite/Internal/RespOperationExtensions.cs b/src/RESPite/Internal/RespOperationExtensions.cs new file mode 100644 index 000000000..0aedccc69 --- /dev/null +++ b/src/RESPite/Internal/RespOperationExtensions.cs @@ -0,0 +1,58 @@ +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace RESPite.Internal; + +internal static class RespOperationExtensions +{ +#if PREVIEW_LANGVER + extension(in RespOperation operation) + { + // since this is valid... + public ref readonly RespOperation Self => ref operation; + + // so is this (the types are layout-identical) + public ref readonly RespOperation Untyped => ref Unsafe.As, RespOperation>( + ref Unsafe.AsRef(in operation)); + } +#endif + + // if we're recycling a buffer, we need to consider it trashable by other threads; for + // debug purposes, force this by overwriting with *****, aka the meaning of life + [Conditional("DEBUG")] + internal static void DebugScramble(this Span value) + => value.Fill(42); + + [Conditional("DEBUG")] + internal static void DebugScramble(this Memory value) + => value.Span.Fill(42); + + [Conditional("DEBUG")] + internal static void DebugScramble(this ReadOnlyMemory value) + => MemoryMarshal.AsMemory(value).Span.Fill(42); + + [Conditional("DEBUG")] + internal static void DebugScramble(this ReadOnlySequence value) + { + if (value.IsSingleSegment) + { + value.First.DebugScramble(); + } + else + { + foreach (var segment in value) + { + segment.DebugScramble(); + } + } + } + + [Conditional("DEBUG")] + internal static void DebugScramble(this byte[]? value) + { + if (value is not null) + value.AsSpan().Fill(42); + } +} diff --git a/src/RESPite/Internal/RespStatefulMessage.cs b/src/RESPite/Internal/RespStatefulMessage.cs new file mode 100644 index 000000000..6d818b664 --- /dev/null +++ b/src/RESPite/Internal/RespStatefulMessage.cs @@ -0,0 +1,35 @@ +using System.Runtime.CompilerServices; +using RESPite.Messages; + +namespace RESPite.Internal; + +internal sealed class RespStatefulMessage : RespMessageBase +{ + private TState _state; + private IRespParser? _parser; + [ThreadStatic] + // used for object recycling of the async machinery + private static RespStatefulMessage? _threadStaticSpare; + internal static RespStatefulMessage Get(in TState state, IRespParser? parser) + { + RespStatefulMessage obj = _threadStaticSpare ?? new(); + _threadStaticSpare = null; + obj._state = state; + obj._parser = parser; + obj.InitParser(parser); + return obj; + } + + protected override void Recycle() => _threadStaticSpare = this; + + private RespStatefulMessage() => Unsafe.SkipInit(out _state); + + protected override TResponse Parse(ref RespReader reader) => _parser!.Parse(in _state, ref reader); + + protected override void Reset(bool recycle) + { + _state = default!; + _parser = null!; + base.Reset(recycle); + } +} diff --git a/src/RESPite/Internal/RespStatelessMessage.cs b/src/RESPite/Internal/RespStatelessMessage.cs new file mode 100644 index 000000000..fd7e1768b --- /dev/null +++ b/src/RESPite/Internal/RespStatelessMessage.cs @@ -0,0 +1,32 @@ +using RESPite.Messages; + +namespace RESPite.Internal; + +internal sealed class RespStatelessMessage : RespMessageBase +{ + private IRespParser? _parser; + [ThreadStatic] + // used for object recycling of the async machinery + private static RespStatelessMessage? _threadStaticSpare; + + internal static RespStatelessMessage Get(IRespParser? parser) + { + RespStatelessMessage obj = _threadStaticSpare ?? new(); + _threadStaticSpare = null; + obj._parser = parser; + obj.InitParser(parser); + return obj; + } + + protected override void Recycle() => _threadStaticSpare = this; + + private RespStatelessMessage() { } + + protected override TResponse Parse(ref RespReader reader) => _parser!.Parse(ref reader); + + protected override void Reset(bool recycle) + { + _parser = null!; + base.Reset(recycle); + } +} diff --git a/src/RESPite/Internal/SynchronizedBlockBufferSerializer.cs b/src/RESPite/Internal/SynchronizedBlockBufferSerializer.cs new file mode 100644 index 000000000..cd5aabafb --- /dev/null +++ b/src/RESPite/Internal/SynchronizedBlockBufferSerializer.cs @@ -0,0 +1,121 @@ +using System.Buffers; +using RESPite.Messages; + +namespace RESPite.Internal; + +internal partial class BlockBufferSerializer +{ + internal static BlockBufferSerializer Create(bool retainChain = false) => + new SynchronizedBlockBufferSerializer(retainChain); + + /// + /// Used for things like . + /// + private sealed class SynchronizedBlockBufferSerializer(bool retainChain) : BlockBufferSerializer + { + private bool _discardDuringClear; + + private protected override BlockBuffer? Buffer { get; set; } // simple per-instance auto-prop + + // use lock-based synchronization + public override ReadOnlyMemory Serialize( + RespCommandMap? commandMap, + ReadOnlySpan command, + in TRequest request, + IRespFormatter formatter) + { + bool haveLock = false; + try // note that "lock" unrolls to something very similar; we're not adding anything unusual here + { + // in reality, we *expect* people to not attempt to use batches concurrently, *and* + // we expect serialization to be very fast, but: out of an abundance of caution, + // add a timeout - just to avoid surprises (since people can write their own formatters) + Monitor.TryEnter(this, LockTimeout, ref haveLock); + if (!haveLock) ThrowTimeout(); + return base.Serialize(commandMap, command, in request, formatter); + } + finally + { + if (haveLock) Monitor.Exit(this); + } + + static void ThrowTimeout() => throw new TimeoutException( + "It took a long time to get access to the serialization-buffer. This is very odd - please " + + "ask on GitHub, but *as a guess*, you have a custom RESP formatter that is really slow *and* " + + "you are using concurrent access to a RESP batch / transaction."); + } + + private static readonly TimeSpan LockTimeout = TimeSpan.FromSeconds(5); + + private Segment? _head, _tail; + + protected override bool ClaimSegment(ReadOnlyMemory segment) + { + if (retainChain & !_discardDuringClear) + { + if (_head is null) + { + _head = _tail = new Segment(segment); + } + else + { + _tail = new Segment(segment, _tail); + } + + // note we don't need to increment the ref-count; because of this "true" + return true; + } + + return false; + } + + internal override ReadOnlySequence Flush() + { + if (_head is null) + { + // at worst, single-segment - we can skip the alloc + return new(BlockBuffer.RetainCurrent(this)); + } + + // otherwise, flush everything *keeping the chain* + ClearWithDiscard(discard: false); + ReadOnlySequence seq = new(_head, 0, _tail!, _tail!.Length); + _head = _tail = null; + return seq; + } + + public override void Clear() + { + ClearWithDiscard(discard: true); + _head = _tail = null; + } + + private void ClearWithDiscard(bool discard) + { + try + { + _discardDuringClear = discard; + base.Clear(); + } + finally + { + _discardDuringClear = false; + } + } + + private sealed class Segment : ReadOnlySequenceSegment + { + public Segment(ReadOnlyMemory memory, Segment? previous = null) + { + Memory = memory; + if (previous is not null) + { + previous.Next = this; + RunningIndex = previous.RunningIndex + previous.Length; + } + } + + public int Length => Memory.Length; + } + } +} diff --git a/src/RESPite/Internal/ThreadLocalBlockBufferSerializer.cs b/src/RESPite/Internal/ThreadLocalBlockBufferSerializer.cs new file mode 100644 index 000000000..1c1895ff4 --- /dev/null +++ b/src/RESPite/Internal/ThreadLocalBlockBufferSerializer.cs @@ -0,0 +1,21 @@ +namespace RESPite.Internal; + +internal partial class BlockBufferSerializer +{ + internal static BlockBufferSerializer Shared => ThreadLocalBlockBufferSerializer.Instance; + private sealed class ThreadLocalBlockBufferSerializer : BlockBufferSerializer + { + private ThreadLocalBlockBufferSerializer() { } + public static readonly ThreadLocalBlockBufferSerializer Instance = new(); + + [ThreadStatic] + // side-step concurrency using per-thread semantics + private static BlockBuffer? _perTreadBuffer; + + private protected override BlockBuffer? Buffer + { + get => _perTreadBuffer; + set => _perTreadBuffer = value; + } + } +} diff --git a/src/RESPite/Internal/Utils.cs b/src/RESPite/Internal/Utils.cs new file mode 100644 index 000000000..c3ba8f3d2 --- /dev/null +++ b/src/RESPite/Internal/Utils.cs @@ -0,0 +1,32 @@ +namespace RESPite.Internal; + +internal static class Utils +{ + internal static void LogLocked(this TextWriter? writer, string message) + { + if (writer is null) return; + lock (writer) + { + writer.WriteLine(message); + } + } + +#if NET10_0_OR_GREATER + internal static void LogLocked( + this TextWriter? writer, + ref System.Runtime.CompilerServices.DefaultInterpolatedStringHandler message) + { + if (writer is null) + { + message.Clear(); + } + else + { + lock (writer) + { + writer.WriteLine(message.ToStringAndClear()); + } + } + } +#endif +} diff --git a/src/RESPite/Messages/IRespFormatterT.cs b/src/RESPite/Messages/IRespFormatterT.cs new file mode 100644 index 000000000..9857f3add --- /dev/null +++ b/src/RESPite/Messages/IRespFormatterT.cs @@ -0,0 +1,19 @@ +namespace RESPite.Messages; + +public interface IRespFormatter +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif +{ + void Format(scoped ReadOnlySpan command, ref RespWriter writer, in TRequest request); +} + +/* +public interface IRespSizeEstimator : IRespFormatter +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif +{ + int EstimateSize(scoped ReadOnlySpan command, in TRequest request); +} +*/ diff --git a/src/RESPite/Messages/IRespMetadataParser.cs b/src/RESPite/Messages/IRespMetadataParser.cs new file mode 100644 index 000000000..4d7943a4e --- /dev/null +++ b/src/RESPite/Messages/IRespMetadataParser.cs @@ -0,0 +1,10 @@ +namespace RESPite.Messages; + +/// +/// When implemented by a or , +/// indicates that the reader should not be pre-initialized to the first node - which would otherwise +/// consume attributes and errors. +/// +public interface IRespMetadataParser +{ +} diff --git a/src/RESPite/Messages/IRespParser_Typed.cs b/src/RESPite/Messages/IRespParser_Typed.cs new file mode 100644 index 000000000..25f96ec3b --- /dev/null +++ b/src/RESPite/Messages/IRespParser_Typed.cs @@ -0,0 +1,13 @@ +namespace RESPite.Messages; + +/// +/// Parses a RESP response into a typed value of type . +/// +/// The type of value being parsed. +public interface IRespParser +{ + /// + /// Parse into a . + /// + TResponse Parse(ref RespReader reader); +} diff --git a/src/RESPite/Messages/IRespParser_Typed_Stateful.cs b/src/RESPite/Messages/IRespParser_Typed_Stateful.cs new file mode 100644 index 000000000..3e549e92c --- /dev/null +++ b/src/RESPite/Messages/IRespParser_Typed_Stateful.cs @@ -0,0 +1,12 @@ +namespace RESPite.Messages; + +public interface IRespParser +{ + /// + /// Parse into a , + /// using the state from . + /// + /// The state to use when parsing. + /// The reader to parse. + TResponse Parse(in TState state, ref RespReader reader); +} diff --git a/src/RESPite/Messages/RespAttributeReader.cs b/src/RESPite/Messages/RespAttributeReader.cs new file mode 100644 index 000000000..699d70a49 --- /dev/null +++ b/src/RESPite/Messages/RespAttributeReader.cs @@ -0,0 +1,68 @@ +namespace RESPite.Messages; + +/// +/// Allows attribute data to be parsed conveniently. +/// +/// The type of data represented by this reader. +public abstract class RespAttributeReader +{ + /// + /// Parse a group of attributes. + /// + public virtual void Read(ref RespReader reader, ref T value) + { + reader.Demand(RespPrefix.Attribute); + _ = ReadKeyValuePairs(ref reader, ref value); + } + + /// + /// Parse an aggregate as a set of key/value pairs. + /// + /// The number of pairs successfully processed. + protected virtual int ReadKeyValuePairs(ref RespReader reader, ref T value) + { + var iterator = reader.AggregateChildren(); + + byte[] pooledBuffer = []; + Span localBuffer = stackalloc byte[128]; + int count = 0; + while (iterator.MoveNext() && iterator.Value.TryReadNext()) + { + if (iterator.Value.IsScalar) + { + var key = iterator.Value.Buffer(ref pooledBuffer, localBuffer); + + if (iterator.MoveNext() && iterator.Value.TryReadNext()) + { + if (ReadKeyValuePair(key, ref iterator.Value, ref value)) + { + count++; + } + } + else + { + break; // no matching value for this key + } + } + else + { + if (iterator.MoveNext() && iterator.Value.TryReadNext()) + { + // we won't try to handle aggregate keys; skip the value + } + else + { + break; // no matching value for this key + } + } + } + iterator.MovePast(out reader); + return count; + } + + /// + /// Parse an individual key/value pair. + /// + /// True if the pair was successfully processed. + public virtual bool ReadKeyValuePair(scoped ReadOnlySpan key, ref RespReader reader, ref T value) => false; +} diff --git a/src/RESPite/Messages/RespFrameScanner.cs b/src/RESPite/Messages/RespFrameScanner.cs new file mode 100644 index 000000000..322bfa5e9 --- /dev/null +++ b/src/RESPite/Messages/RespFrameScanner.cs @@ -0,0 +1,193 @@ +using System.Buffers; +using RESPite.Messages; +using static RESPite.Internal.RespConstants; +namespace RESPite.Internal; + +/// +/// Scans RESP frames. +/// . +public sealed class RespFrameScanner // : IFrameSacanner, IFrameValidator +{ + /// + /// Gets a frame scanner for RESP2 request/response connections, or RESP3 connections. + /// + public static RespFrameScanner Default { get; } = new(false); + + /// + /// Gets a frame scanner that identifies RESP2 pub/sub messages. + /// + public static RespFrameScanner Subscription { get; } = new(true); + private RespFrameScanner(bool pubsub) => _pubsub = pubsub; + private readonly bool _pubsub; + + private static readonly uint FastNull = UnsafeCpuUInt32("_\r\n\0"u8), + SingleCharScalarMask = CpuUInt32(0xFF00FFFF), + SingleDigitInteger = UnsafeCpuUInt32(":\0\r\n"u8), + EitherBoolean = UnsafeCpuUInt32("#\0\r\n"u8), + FirstThree = CpuUInt32(0xFFFFFF00); + private static readonly ulong OK = UnsafeCpuUInt64("+OK\r\n\0\0\0"u8), + PONG = UnsafeCpuUInt64("+PONG\r\n\0"u8), + DoubleCharScalarMask = CpuUInt64(0xFF0000FFFF000000), + DoubleDigitInteger = UnsafeCpuUInt64(":\0\0\r\n"u8), + FirstFive = CpuUInt64(0xFFFFFFFFFF000000), + FirstSeven = CpuUInt64(0xFFFFFFFFFFFFFF00); + + private const OperationStatus UseReader = (OperationStatus)(-1); + private static OperationStatus TryFastRead(ReadOnlySpan data, ref RespScanState info) + { + // use silly math to detect the most common short patterns without needing + // to access a reader, or use indexof etc; handles: + // +OK\r\n + // +PONG\r\n + // :N\r\n for any single-digit N (integer) + // :NN\r\n for any double-digit N (integer) + // #N\r\n for any single-digit N (boolean) + // _\r\n (null) + uint hi, lo; + switch (data.Length) + { + case 0: + case 1: + case 2: + return OperationStatus.NeedMoreData; + case 3: + hi = (((uint)UnsafeCpuUInt16(data)) << 16) | (((uint)UnsafeCpuByte(data, 2)) << 8); + break; + default: + hi = UnsafeCpuUInt32(data); + break; + } + if ((hi & FirstThree) == FastNull) + { + info.SetComplete(3, RespPrefix.Null); + return OperationStatus.Done; + } + + var masked = hi & SingleCharScalarMask; + if (masked == SingleDigitInteger) + { + info.SetComplete(4, RespPrefix.Integer); + return OperationStatus.Done; + } + else if (masked == EitherBoolean) + { + info.SetComplete(4, RespPrefix.Boolean); + return OperationStatus.Done; + } + + switch (data.Length) + { + case 3: + return OperationStatus.NeedMoreData; + case 4: + return UseReader; + case 5: + lo = ((uint)data[4]) << 24; + break; + case 6: + lo = ((uint)UnsafeCpuUInt16(data, 4)) << 16; + break; + case 7: + lo = ((uint)UnsafeCpuUInt16(data, 4)) << 16 | ((uint)UnsafeCpuByte(data, 6)) << 8; + break; + default: + lo = UnsafeCpuUInt32(data, 4); + break; + } + var u64 = BitConverter.IsLittleEndian ? ((((ulong)lo) << 32) | hi) : ((((ulong)hi) << 32) | lo); + if (((u64 & FirstFive) == OK) | ((u64 & DoubleCharScalarMask) == DoubleDigitInteger)) + { + info.SetComplete(5, RespPrefix.SimpleString); + return OperationStatus.Done; + } + if ((u64 & FirstSeven) == PONG) + { + info.SetComplete(7, RespPrefix.SimpleString); + return OperationStatus.Done; + } + return UseReader; + } + + /// + /// Attempt to read more data as part of the current frame. + /// + public OperationStatus TryRead(ref RespScanState state, in ReadOnlySequence data) + { + if (!_pubsub & state.TotalBytes == 0 & data.IsSingleSegment) + { +#if NETCOREAPP3_1_OR_GREATER + var status = TryFastRead(data.FirstSpan, ref state); +#else + var status = TryFastRead(data.First.Span, ref state); +#endif + if (status != UseReader) return status; + } + + return TryReadViaReader(ref state, in data); + + static OperationStatus TryReadViaReader(ref RespScanState state, in ReadOnlySequence data) + { + var reader = new RespReader(in data); + var complete = state.TryRead(ref reader, out var consumed); + if (complete) + { + return OperationStatus.Done; + } + return OperationStatus.NeedMoreData; + } + } + + /// + /// Attempt to read more data as part of the current frame. + /// + public OperationStatus TryRead(ref RespScanState state, ReadOnlySpan data) + { + if (!_pubsub & state.TotalBytes == 0) + { +#if NETCOREAPP3_1_OR_GREATER + var status = TryFastRead(data, ref state); +#else + var status = TryFastRead(data, ref state); +#endif + if (status != UseReader) return status; + } + + return TryReadViaReader(ref state, data); + + static OperationStatus TryReadViaReader(ref RespScanState state, ReadOnlySpan data) + { + var reader = new RespReader(data); + var complete = state.TryRead(ref reader, out var consumed); + if (complete) + { + return OperationStatus.Done; + } + return OperationStatus.NeedMoreData; + } + } + + /// + /// Validate that the supplied message is a valid RESP request, specifically: that it contains a single + /// top-level array payload with bulk-string elements, the first of which is non-empty (the command). + /// + public void ValidateRequest(in ReadOnlySequence message) + { + if (message.IsEmpty) Throw("Empty RESP frame"); + RespReader reader = new(in message); + reader.MoveNext(RespPrefix.Array); + reader.DemandNotNull(); + if (reader.IsStreaming) Throw("Streaming is not supported in this context"); + var count = reader.AggregateLength(); + for (int i = 0; i < count; i++) + { + reader.MoveNext(RespPrefix.BulkString); + reader.DemandNotNull(); + if (reader.IsStreaming) Throw("Streaming is not supported in this context"); + + if (i == 0 && reader.ScalarIsEmpty()) Throw("command must be non-empty"); + } + reader.DemandEnd(); + + static void Throw(string message) => throw new InvalidOperationException(message); + } +} diff --git a/src/RESPite/Messages/RespPrefix.cs b/src/RESPite/Messages/RespPrefix.cs new file mode 100644 index 000000000..09fa5e5d8 --- /dev/null +++ b/src/RESPite/Messages/RespPrefix.cs @@ -0,0 +1,97 @@ +namespace RESPite.Messages; + +/// +/// RESP protocol prefix. +/// +public enum RespPrefix : byte +{ + /// + /// Invalid. + /// + None = 0, + + /// + /// Simple strings: +OK\r\n. + /// + SimpleString = (byte)'+', + + /// + /// Simple errors: -ERR message\r\n. + /// + SimpleError = (byte)'-', + + /// + /// Integers: :123\r\n. + /// + Integer = (byte)':', + + /// + /// String with support for binary data: $7\r\nmessage\r\n. + /// + BulkString = (byte)'$', + + /// + /// Multiple inner messages: *1\r\n+message\r\n. + /// + Array = (byte)'*', + + /// + /// Null strings/arrays: _\r\n. + /// + Null = (byte)'_', + + /// + /// Boolean values: #T\r\n. + /// + Boolean = (byte)'#', + + /// + /// Floating-point number: ,123.45\r\n. + /// + Double = (byte)',', + + /// + /// Large integer number: (12...89\r\n. + /// + BigInteger = (byte)'(', + + /// + /// Error with support for binary data: !7\r\nmessage\r\n. + /// + BulkError = (byte)'!', + + /// + /// String that should be interpreted verbatim: =11\r\ntxt:message\r\n. + /// + VerbatimString = (byte)'=', + + /// + /// Multiple sub-items that represent a map. + /// + Map = (byte)'%', + + /// + /// Multiple sub-items that represent a set. + /// + Set = (byte)'~', + + /// + /// Out-of band messages. + /// + Push = (byte)'>', + + /// + /// Continuation of streaming scalar values. + /// + StreamContinuation = (byte)';', + + /// + /// End sentinel for streaming aggregate values. + /// + StreamTerminator = (byte)'.', + + /// + /// Metadata about the next element. + /// + Attribute = (byte)'|', +} diff --git a/src/RESPite/Messages/RespReader.AggregateEnumerator.cs b/src/RESPite/Messages/RespReader.AggregateEnumerator.cs new file mode 100644 index 000000000..1853d2ee6 --- /dev/null +++ b/src/RESPite/Messages/RespReader.AggregateEnumerator.cs @@ -0,0 +1,214 @@ +using System.Collections; +using System.ComponentModel; + +#pragma warning disable IDE0079 // Remove unnecessary suppression +#pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct +#pragma warning restore IDE0079 // Remove unnecessary suppression + +namespace RESPite.Messages; + +public ref partial struct RespReader +{ + /// + /// Reads the sub-elements associated with an aggregate value. + /// + public readonly AggregateEnumerator AggregateChildren() => new(in this); + + /// + /// Reads the sub-elements associated with an aggregate value. + /// + public ref struct AggregateEnumerator + { + // Note that _reader is the overall reader that can see outside this aggregate, as opposed + // to Current which is the sub-tree of the current element *only* + private RespReader _reader; + private int _remaining; + + /// + /// Create a new enumerator for the specified . + /// + /// The reader containing the data for this operation. + public AggregateEnumerator(scoped in RespReader reader) + { + reader.DemandAggregate(); + _remaining = reader.IsStreaming ? -1 : reader._length; + _reader = reader; + Value = default; + } + + /// + public readonly AggregateEnumerator GetEnumerator() => this; + + /// + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public RespReader Current => Value; + + /// + /// Gets the current element associated with this reader. + /// + public RespReader Value; // intentionally a field, because of ref-semantics + + /// + /// Move to the next child if possible, and move the child element into the next node. + /// + public bool MoveNext(RespPrefix prefix) + { + bool result = MoveNext(); + if (result) + { + Value.MoveNext(prefix); + } + return result; + } + + /// + /// Move to the next child if possible, and move the child element into the next node. + /// + /// The type of data represented by this reader. + public bool MoveNext(RespPrefix prefix, RespAttributeReader respAttributeReader, ref T attributes) + { + bool result = MoveNext(respAttributeReader, ref attributes); + if (result) + { + Value.MoveNext(prefix); + } + return result; + } + + /// > + public bool MoveNext() + { + object? attributes = null; + return MoveNextCore(null, ref attributes); + } + + /// > + /// The type of data represented by this reader. + public bool MoveNext(RespAttributeReader respAttributeReader, ref T attributes) + => MoveNextCore(respAttributeReader, ref attributes); + + /// > + private bool MoveNextCore(RespAttributeReader? attributeReader, ref T attributes) + { + if (_remaining == 0) + { + Value = default; + return false; + } + + // in order to provide access to attributes etc, we want Current to be positioned + // *before* the next element; for that, we'll take a snapshot before we read + _reader.MovePastCurrent(); + var snapshot = _reader.Clone(); + + if (attributeReader is null) + { + _reader.MoveNext(); + } + else + { + _reader.MoveNext(attributeReader, ref attributes); + } + if (_remaining > 0) + { + // non-streaming, decrement + _remaining--; + } + else if (_reader.Prefix == RespPrefix.StreamTerminator) + { + // end of streaming aggregate + _remaining = 0; + Value = default; + return false; + } + + // move past that sub-tree and trim the "snapshot" state, giving + // us a scoped reader that is *just* that sub-tree + _reader.SkipChildren(); + snapshot.TrimToTotal(_reader.BytesConsumed); + + Value = snapshot; + return true; + } + + /// + /// Move to the end of this aggregate and export the state of the . + /// + /// The reader positioned at the end of the data; this is commonly + /// used to update a tree reader, to get to the next data after the aggregate. + public void MovePast(out RespReader reader) + { + while (MoveNext()) { } + reader = _reader; + } + + public void DemandNext() + { + if (!MoveNext()) ThrowEof(); + Value.MoveNext(); // skip any attributes etc + } + + public T ReadOne(Projection projection) + { + DemandNext(); + return projection(ref Value); + } + + public void FillAll(scoped Span target, Projection projection) + { + for (int i = 0; i < target.Length; i++) + { + if (!MoveNext()) ThrowEof(); + + Value.MoveNext(); // skip any attributes etc + target[i] = projection(ref Value); + } + } + + public void FillAll( + scoped Span target, + Projection first, + Projection second, + Func combine) + { + for (int i = 0; i < target.Length; i++) + { + if (!MoveNext()) ThrowEof(); + + Value.MoveNext(); // skip any attributes etc + var x = first(ref Value); + + if (!MoveNext()) ThrowEof(); + + Value.MoveNext(); // skip any attributes etc + var y = second(ref Value); + target[i] = combine(x, y); + } + } + } + + internal void TrimToTotal(long length) => TrimToRemaining(length - BytesConsumed); + + internal void TrimToRemaining(long bytes) + { + if (_prefix != RespPrefix.None || bytes < 0) Throw(); + + var current = CurrentAvailable; + if (bytes <= current) + { + UnsafeTrimCurrentBy(current - (int)bytes); + _remainingTailLength = 0; + return; + } + + bytes -= current; + if (bytes <= _remainingTailLength) + { + _remainingTailLength = bytes; + return; + } + + Throw(); + static void Throw() => throw new ArgumentOutOfRangeException(nameof(bytes)); + } +} diff --git a/src/RESPite/Messages/RespReader.Debug.cs b/src/RESPite/Messages/RespReader.Debug.cs new file mode 100644 index 000000000..3f471bbd1 --- /dev/null +++ b/src/RESPite/Messages/RespReader.Debug.cs @@ -0,0 +1,33 @@ +using System.Diagnostics; + +#pragma warning disable IDE0079 // Remove unnecessary suppression +#pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct +#pragma warning restore IDE0079 // Remove unnecessary suppression + +namespace RESPite.Messages; + +[DebuggerDisplay($"{{{nameof(GetDebuggerDisplay)}(),nq}}")] +public ref partial struct RespReader +{ + internal bool DebugEquals(in RespReader other) + => _prefix == other._prefix + && _length == other._length + && _flags == other._flags + && _bufferIndex == other._bufferIndex + && _positionBase == other._positionBase + && _remainingTailLength == other._remainingTailLength; + + internal new string ToString() => $"{Prefix} ({_flags}); length {_length}, {TotalAvailable} remaining"; + + internal void DebugReset() + { + _bufferIndex = 0; + _length = 0; + _flags = 0; + _prefix = RespPrefix.None; + } + +#if DEBUG + internal bool VectorizeDisabled { get; set; } +#endif +} diff --git a/src/RESPite/Messages/RespReader.ScalarEnumerator.cs b/src/RESPite/Messages/RespReader.ScalarEnumerator.cs new file mode 100644 index 000000000..9e8ffbe70 --- /dev/null +++ b/src/RESPite/Messages/RespReader.ScalarEnumerator.cs @@ -0,0 +1,105 @@ +using System.Buffers; +using System.Collections; + +#pragma warning disable IDE0079 // Remove unnecessary suppression +#pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct +#pragma warning restore IDE0079 // Remove unnecessary suppression + +namespace RESPite.Messages; + +public ref partial struct RespReader +{ + /// + /// Gets the chunks associated with a scalar value. + /// + public readonly ScalarEnumerator ScalarChunks() => new(in this); + + /// + /// Allows enumeration of chunks in a scalar value; this includes simple values + /// that span multiple segments, and streaming + /// scalar RESP values. + /// + public ref struct ScalarEnumerator + { + /// + public readonly ScalarEnumerator GetEnumerator() => this; + + private RespReader _reader; + + private ReadOnlySpan _current; + private ReadOnlySequenceSegment? _tail; + private int _offset, _remaining; + + /// + /// Create a new enumerator for the specified . + /// + /// The reader containing the data for this operation. + public ScalarEnumerator(scoped in RespReader reader) + { + reader.DemandScalar(); + _reader = reader; + InitSegment(); + } + + private void InitSegment() + { + _current = _reader.CurrentSpan(); + _tail = _reader._tail; + _offset = CurrentLength = 0; + _remaining = _reader._length; + if (_reader.TotalAvailable < _remaining) ThrowEof(); + } + + /// + public bool MoveNext() + { + while (true) // for each streaming element + { + _offset += CurrentLength; + while (_remaining > 0) // for each span in the current element + { + // look in the active span + var take = Math.Min(_remaining, _current.Length - _offset); + if (take > 0) // more in the current chunk + { + _remaining -= take; + CurrentLength = take; + return true; + } + + // otherwise, we expect more tail data + if (_tail is null) ThrowEof(); + + _current = _tail.Memory.Span; + _offset = 0; + _tail = _tail.Next; + } + + if (!_reader.MoveNextStreamingScalar()) break; + InitSegment(); + } + + CurrentLength = 0; + return false; + } + + /// + public readonly ReadOnlySpan Current => _current.Slice(_offset, CurrentLength); + + /// + /// Gets the or . + /// + public int CurrentLength { readonly get; private set; } + + /// + /// Move to the end of this aggregate and export the state of the . + /// + /// The reader positioned at the end of the data; this is commonly + /// used to update a tree reader, to get to the next data after the aggregate. + public void MovePast(out RespReader reader) + { + while (MoveNext()) { } + reader = _reader; + } + } +} diff --git a/src/RESPite/Messages/RespReader.Span.cs b/src/RESPite/Messages/RespReader.Span.cs new file mode 100644 index 000000000..fd3870ef3 --- /dev/null +++ b/src/RESPite/Messages/RespReader.Span.cs @@ -0,0 +1,84 @@ +#define USE_UNSAFE_SPAN + +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +#pragma warning disable IDE0079 // Remove unnecessary suppression +#pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct +#pragma warning restore IDE0079 // Remove unnecessary suppression + +namespace RESPite.Messages; + +/* + How we actually implement the underlying buffer depends on the capabilities of the runtime. + */ + +#if NET7_0_OR_GREATER && USE_UNSAFE_SPAN + +public ref partial struct RespReader +{ + // intent: avoid lots of slicing by dealing with everything manually, and accepting the "don't get it wrong" rule + private ref byte _bufferRoot; + private int _bufferLength; + + private partial void UnsafeTrimCurrentBy(int count) + { + Debug.Assert(count >= 0 && count <= _bufferLength, "Unsafe trim length"); + _bufferLength -= count; + } + + private readonly partial ref byte UnsafeCurrent + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => ref Unsafe.Add(ref _bufferRoot, _bufferIndex); + } + + private readonly partial int CurrentLength + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => _bufferLength; + } + + private readonly partial ReadOnlySpan CurrentSpan() => MemoryMarshal.CreateReadOnlySpan( + ref UnsafeCurrent, CurrentAvailable); + + private readonly partial ReadOnlySpan UnsafePastPrefix() => MemoryMarshal.CreateReadOnlySpan( + ref Unsafe.Add(ref _bufferRoot, _bufferIndex + 1), + _bufferLength - (_bufferIndex + 1)); + + private partial void SetCurrent(ReadOnlySpan value) + { + _bufferRoot = ref MemoryMarshal.GetReference(value); + _bufferLength = value.Length; + } +} +#else +public ref partial struct RespReader // much more conservative - uses slices etc +{ + private ReadOnlySpan _buffer; + + private partial void UnsafeTrimCurrentBy(int count) + { + _buffer = _buffer.Slice(0, _buffer.Length - count); + } + + private readonly partial ref byte UnsafeCurrent + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => ref Unsafe.AsRef(in _buffer[_bufferIndex]); // hack around CS8333 + } + + private readonly partial int CurrentLength + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => _buffer.Length; + } + + private readonly partial ReadOnlySpan UnsafePastPrefix() => _buffer.Slice(_bufferIndex + 1); + + private readonly partial ReadOnlySpan CurrentSpan() => _buffer.Slice(_bufferIndex); + + private partial void SetCurrent(ReadOnlySpan value) => _buffer = value; +} +#endif diff --git a/src/RESPite/Messages/RespReader.Utils.cs b/src/RESPite/Messages/RespReader.Utils.cs new file mode 100644 index 000000000..da6b641d8 --- /dev/null +++ b/src/RESPite/Messages/RespReader.Utils.cs @@ -0,0 +1,317 @@ +using System.Buffers.Text; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using RESPite.Internal; + +#pragma warning disable IDE0079 // Remove unnecessary suppression +#pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct +#pragma warning restore IDE0079 // Remove unnecessary suppression + +namespace RESPite.Messages; + +public ref partial struct RespReader +{ + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void UnsafeAssertClLf(int offset) => UnsafeAssertClLf(ref UnsafeCurrent, offset); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void UnsafeAssertClLf(scoped ref byte source, int offset) + { + if (Unsafe.ReadUnaligned(ref Unsafe.Add(ref source, offset)) != RespConstants.CrLfUInt16) + { + ThrowProtocolFailure("Expected CR/LF"); + } + } + + private enum LengthPrefixResult + { + NeedMoreData, + Length, + Null, + Streaming, + } + + /// + /// Asserts that the current element is a scalar type. + /// + public readonly void DemandScalar() + { + if (!IsScalar) Throw(Prefix); + static void Throw(RespPrefix prefix) => throw new InvalidOperationException($"This operation requires a scalar element, got {prefix}"); + } + + /// + /// Asserts that the current element is a scalar type. + /// + public readonly void DemandAggregate() + { + if (!IsAggregate) Throw(Prefix); + static void Throw(RespPrefix prefix) => throw new InvalidOperationException($"This operation requires an aggregate element, got {prefix}"); + } + + private static LengthPrefixResult TryReadLengthPrefix(ReadOnlySpan bytes, out int value, out int byteCount) + { + var end = bytes.IndexOf(RespConstants.CrlfBytes); + if (end < 0) + { + byteCount = value = 0; + if (bytes.Length >= RespConstants.MaxRawBytesInt32 + 2) + { + ThrowProtocolFailure("Unterminated or over-length integer"); // should have failed; report failure to prevent infinite loop + } + return LengthPrefixResult.NeedMoreData; + } + byteCount = end + 2; + switch (end) + { + case 0: + ThrowProtocolFailure("Length prefix expected"); + goto case default; // not reached, just satisfying definite assignment + case 1 when bytes[0] == (byte)'?': + value = 0; + return LengthPrefixResult.Streaming; + default: + if (end > RespConstants.MaxRawBytesInt32 || !(Utf8Parser.TryParse(bytes, out value, out var consumed) && consumed == end)) + { + ThrowProtocolFailure("Unable to parse integer"); + value = 0; + } + if (value < 0) + { + if (value == -1) + { + value = 0; + return LengthPrefixResult.Null; + } + ThrowProtocolFailure("Invalid negative length prefix"); + } + return LengthPrefixResult.Length; + } + } + + private readonly RespReader Clone() => this; // useful for performing streaming operations without moving the primary + + [MethodImpl(MethodImplOptions.NoInlining), DoesNotReturn] + private static void ThrowProtocolFailure(string message) + => throw new InvalidOperationException("RESP protocol failure: " + message); // protocol exception? + + [MethodImpl(MethodImplOptions.NoInlining), DoesNotReturn] + internal static void ThrowEof() => throw new EndOfStreamException(); + + [MethodImpl(MethodImplOptions.NoInlining), DoesNotReturn] + private static void ThrowFormatException() => throw new FormatException(); + + private int RawTryReadByte() + { + if (_bufferIndex < CurrentLength || TryMoveToNextSegment()) + { + var result = UnsafeCurrent; + _bufferIndex++; + return result; + } + return -1; + } + + private int RawPeekByte() + { + return (CurrentLength < _bufferIndex || TryMoveToNextSegment()) ? UnsafeCurrent : -1; + } + + private bool RawAssertCrLf() + { + if (CurrentAvailable >= 2) + { + UnsafeAssertClLf(0); + _bufferIndex += 2; + return true; + } + else + { + int next = RawTryReadByte(); + if (next < 0) return false; + if (next == '\r') + { + next = RawTryReadByte(); + if (next < 0) return false; + if (next == '\n') return true; + } + ThrowProtocolFailure("Expected CR/LF"); + return false; + } + } + + private LengthPrefixResult RawTryReadLengthPrefix() + { + _length = 0; + if (!RawTryFindCrLf(out int end)) + { + if (TotalAvailable >= RespConstants.MaxRawBytesInt32 + 2) + { + ThrowProtocolFailure("Unterminated or over-length integer"); // should have failed; report failure to prevent infinite loop + } + return LengthPrefixResult.NeedMoreData; + } + + switch (end) + { + case 0: + ThrowProtocolFailure("Length prefix expected"); + goto case default; // not reached, just satisfying definite assignment + case 1: + var b = (byte)RawTryReadByte(); + RawAssertCrLf(); + if (b == '?') + { + return LengthPrefixResult.Streaming; + } + else + { + _length = ParseSingleDigit(b); + return LengthPrefixResult.Length; + } + default: + if (end > RespConstants.MaxRawBytesInt32) + { + ThrowProtocolFailure("Unable to parse integer"); + } + Span bytes = stackalloc byte[end]; + RawFillBytes(bytes); + RawAssertCrLf(); + if (!(Utf8Parser.TryParse(bytes, out _length, out var consumed) && consumed == end)) + { + ThrowProtocolFailure("Unable to parse integer"); + } + + if (_length < 0) + { + if (_length == -1) + { + _length = 0; + return LengthPrefixResult.Null; + } + ThrowProtocolFailure("Invalid negative length prefix"); + } + + return LengthPrefixResult.Length; + } + } + + private void RawFillBytes(scoped Span target) + { + do + { + var current = CurrentSpan(); + if (current.Length >= target.Length) + { + // more than enough, need to trim + current.Slice(0, target.Length).CopyTo(target); + _bufferIndex += target.Length; + return; // we're done + } + else + { + // take what we can + current.CopyTo(target); + target = target.Slice(current.Length); + // we could move _bufferIndex here, but we're about to trash that in TryMoveToNextSegment + } + } + while (TryMoveToNextSegment()); + ThrowEof(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int ParseSingleDigit(byte value) + { + return value switch + { + (byte)'0' or (byte)'1' or (byte)'2' or (byte)'3' or (byte)'4' or (byte)'5' or (byte)'6' or (byte)'7' or (byte)'8' or (byte)'9' => value - (byte)'0', + _ => Invalid(value), + }; + + [MethodImpl(MethodImplOptions.NoInlining), DoesNotReturn] + static int Invalid(byte value) => throw new FormatException($"Unable to parse integer: '{(char)value}'"); + } + + private readonly bool RawTryAssertInlineScalarPayloadCrLf() + { + Debug.Assert(IsInlineScalar, "should be inline scalar"); + + var reader = Clone(); + var len = reader._length; + if (len == 0) return reader.RawAssertCrLf(); + + do + { + var current = reader.CurrentSpan(); + if (current.Length >= len) + { + reader._bufferIndex += len; + return reader.RawAssertCrLf(); // we're done + } + else + { + // take what we can + len -= current.Length; + // we could move _bufferIndex here, but we're about to trash that in TryMoveToNextSegment + } + } + while (reader.TryMoveToNextSegment()); + return false; // EOF + } + + private readonly bool RawTryFindCrLf(out int length) + { + length = 0; + RespReader reader = Clone(); + do + { + var span = reader.CurrentSpan(); + var index = span.IndexOf((byte)'\r'); + if (index >= 0) + { + checked + { + length += index; + } + // move past the CR and assert the LF + reader._bufferIndex += index + 1; + var next = reader.RawTryReadByte(); + if (next < 0) break; // we don't know + if (next != '\n') ThrowProtocolFailure("CR/LF expected"); + + return true; + } + checked + { + length += span.Length; + } + } + while (reader.TryMoveToNextSegment()); + length = 0; + return false; + } + + private string GetDebuggerDisplay() + { + return ToString(); + } + + internal readonly int GetInitialScanCount(out ushort streamingAggregateDepth) + { + // this is *similar* to GetDelta, but: without any discount for attributes + switch (_flags & (RespFlags.IsAggregate | RespFlags.IsStreaming)) + { + case RespFlags.IsAggregate: + streamingAggregateDepth = 0; + return _length - 1; + case RespFlags.IsAggregate | RespFlags.IsStreaming: + streamingAggregateDepth = 1; + return 0; + default: + streamingAggregateDepth = 0; + return -1; + } + } +} diff --git a/src/RESPite/Messages/RespReader.cs b/src/RESPite/Messages/RespReader.cs new file mode 100644 index 000000000..4c99187d3 --- /dev/null +++ b/src/RESPite/Messages/RespReader.cs @@ -0,0 +1,1772 @@ +using System.Buffers; +using System.Buffers.Text; +using System.ComponentModel; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Runtime.CompilerServices; +using System.Text; +using RESPite.Internal; + +#if NETCOREAPP3_0_OR_GREATER +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; +#endif + +#pragma warning disable IDE0079 // Remove unnecessary suppression +#pragma warning disable CS0282 // There is no defined ordering between fields in multiple declarations of partial struct +#pragma warning restore IDE0079 // Remove unnecessary suppression + +namespace RESPite.Messages; + +/// +/// Provides low level RESP parsing functionality. +/// +public ref partial struct RespReader +{ + [Flags] + private enum RespFlags : byte + { + None = 0, + IsScalar = 1 << 0, // simple strings, bulk strings, etc + IsAggregate = 1 << 1, // arrays, maps, sets, etc + IsNull = 1 << 2, // explicit null RESP types, or bulk-strings/aggregates with length -1 + IsInlineScalar = 1 << 3, // a non-null scalar, i.e. with payload+CrLf + IsAttribute = 1 << 4, // is metadata for following elements + IsStreaming = 1 << 5, // unknown length + IsError = 1 << 6, // an explicit error reported inside the protocol + } + + // relates to the element we're currently reading + private RespFlags _flags; + private RespPrefix _prefix; + + private int _length; // for null: 0; for scalars: the length of the payload; for aggregates: the child count + + // the current buffer that we're observing + private int _bufferIndex; // after TryRead, this should be positioned immediately before the actual data + + // the position in a multi-segment payload + private long _positionBase; // total data we've already moved past in *previous* buffers + private ReadOnlySequenceSegment? _tail; // the next tail node + private long _remainingTailLength; // how much more can we consume from the tail? + + public long ProtocolBytesRemaining => TotalAvailable; + + private readonly int CurrentAvailable => CurrentLength - _bufferIndex; + + private readonly long TotalAvailable => CurrentAvailable + _remainingTailLength; + private partial void UnsafeTrimCurrentBy(int count); + private readonly partial ref byte UnsafeCurrent { get; } + private readonly partial int CurrentLength { get; } + private partial void SetCurrent(ReadOnlySpan value); + private RespPrefix UnsafePeekPrefix() => (RespPrefix)UnsafeCurrent; + private readonly partial ReadOnlySpan UnsafePastPrefix(); + private readonly partial ReadOnlySpan CurrentSpan(); + + /// + /// Get the scalar value as a single-segment span. + /// + /// True if this is a non-streaming scalar element that covers a single span only, otherwise False. + /// If a scalar reports False, can be used to iterate the entire payload. + /// When True, the contents of the scalar value. + public readonly bool TryGetSpan(out ReadOnlySpan value) + { + if (IsInlineScalar && CurrentAvailable >= _length) + { + value = CurrentSpan().Slice(0, _length); + return true; + } + + value = default; + return IsNullScalar; + } + + /// + /// Returns the position after the end of the current element. + /// + public readonly long BytesConsumed => _positionBase + _bufferIndex + TrailingLength; + + /// + /// Body length of scalar values, plus any terminating sentinels. + /// + private readonly int TrailingLength => (_flags & RespFlags.IsInlineScalar) == 0 ? 0 : (_length + 2); + + /// + /// Gets the RESP kind of the current element. + /// + public readonly RespPrefix Prefix => _prefix; + + /// + /// The payload length of this scalar element (includes combined length for streaming scalars). + /// + public readonly int ScalarLength() => + IsInlineScalar ? _length : IsNullScalar ? 0 : checked((int)ScalarLengthSlow()); + + /// + /// Indicates whether this scalar value is zero-length. + /// + public readonly bool ScalarIsEmpty() => + IsInlineScalar ? _length == 0 : (IsNullScalar || !ScalarChunks().MoveNext()); + + /// + /// The payload length of this scalar element (includes combined length for streaming scalars). + /// + public readonly long ScalarLongLength() => IsInlineScalar ? _length : IsNullScalar ? 0 : ScalarLengthSlow(); + + private readonly long ScalarLengthSlow() + { + DemandScalar(); + long length = 0; + var iterator = ScalarChunks(); + while (iterator.MoveNext()) + { + length += iterator.CurrentLength; + } + + return length; + } + + /// + /// The number of child elements associated with an aggregate. + /// + /// For + /// and aggregates, this is twice the value reported in the RESP protocol, + /// i.e. a map of the form %2\r\n... will report 4 as the length. + /// Note that if the data could be streaming (), it may be preferable to use + /// the API, using the API to update the outer reader. + public readonly int AggregateLength() => + (_flags & (RespFlags.IsAggregate | RespFlags.IsStreaming)) == RespFlags.IsAggregate + ? _length + : AggregateLengthSlow(); + + public delegate T Projection(ref RespReader value); + + public void FillAll(scoped Span target, Projection projection) + { + DemandNotNull(); + AggregateChildren().FillAll(target, projection); + } + + private readonly int AggregateLengthSlow() + { + switch (_flags & (RespFlags.IsAggregate | RespFlags.IsStreaming)) + { + case RespFlags.IsAggregate: + return _length; + case RespFlags.IsAggregate | RespFlags.IsStreaming: + break; + default: + DemandAggregate(); // we expect this to throw + break; + } + + int count = 0; + var reader = Clone(); + while (true) + { + if (!reader.TryMoveNext()) ThrowEof(); + if (reader.Prefix == RespPrefix.StreamTerminator) + { + return count; + } + + reader.SkipChildren(); + count++; + } + } + + /// + /// Indicates whether this is a scalar value, i.e. with a potential payload body. + /// + public readonly bool IsScalar => (_flags & RespFlags.IsScalar) != 0; + + internal readonly bool IsInlineScalar => (_flags & RespFlags.IsInlineScalar) != 0; + + internal readonly bool IsNullScalar => + (_flags & (RespFlags.IsScalar | RespFlags.IsNull)) == (RespFlags.IsScalar | RespFlags.IsNull); + + /// + /// Indicates whether this is an aggregate value, i.e. represents a collection of sub-values. + /// + public readonly bool IsAggregate => (_flags & RespFlags.IsAggregate) != 0; + + /// + /// Indicates whether this is a null value; this could be an explicit , + /// or a scalar or aggregate a negative reported length. + /// + public readonly bool IsNull => (_flags & RespFlags.IsNull) != 0; + + /// + /// Indicates whether this is an attribute value, i.e. metadata relating to later element data. + /// + public readonly bool IsAttribute => (_flags & RespFlags.IsAttribute) != 0; + + /// + /// Indicates whether this represents streaming content, where the or is not known in advance. + /// + public readonly bool IsStreaming => (_flags & RespFlags.IsStreaming) != 0; + + /// + /// Equivalent to both and . + /// + internal readonly bool IsStreamingScalar => (_flags & (RespFlags.IsScalar | RespFlags.IsStreaming)) == + (RespFlags.IsScalar | RespFlags.IsStreaming); + + /// + /// Indicates errors reported inside the protocol. + /// + public readonly bool IsError => (_flags & RespFlags.IsError) != 0; + + /// + /// Gets the effective change (in terms of how many RESP nodes we expect to see) from consuming this element. + /// For simple scalars, this is -1 because we have one less node to read; for simple aggregates, this is + /// AggregateLength-1 because we will have consumed one element, but now need to read the additional + /// child elements. Attributes report 0, since they supplement data + /// we still need to consume. The final terminator for streaming data reports a delta of -1, otherwise: 0. + /// + /// This does not account for being nested inside a streaming aggregate; the caller must deal with that manually. + internal int Delta() => + (_flags & (RespFlags.IsScalar | RespFlags.IsAggregate | RespFlags.IsStreaming | RespFlags.IsAttribute)) switch + { + RespFlags.IsScalar => -1, + RespFlags.IsAggregate => _length - 1, + RespFlags.IsAggregate | RespFlags.IsAttribute => _length, + _ => 0, + }; + + /// + /// Assert that this is the final element in the current payload. + /// + /// If additional elements are available. + public void DemandEnd() + { + while (IsStreamingScalar) + { + if (!TryReadNext()) ThrowEof(); + } + + if (TryReadNext()) + { + Throw(Prefix); + } + + static void Throw(RespPrefix prefix) => + throw new InvalidOperationException($"Expected end of payload, but found {prefix}"); + } + + private bool TryReadNextSkipAttributes() + { + while (TryReadNext()) + { + if (IsAttribute) + { + SkipChildren(); + } + else + { + return true; + } + } + + return false; + } + + private bool TryReadNextProcessAttributes(RespAttributeReader respAttributeReader, ref T attributes) + { + while (TryReadNext()) + { + if (IsAttribute) + { + respAttributeReader.Read(ref this, ref attributes); + } + else + { + return true; + } + } + + return false; + } + + /// + /// Move to the next content element; this skips attribute metadata, checking for RESP error messages by default. + /// + /// If the data is exhausted before a streaming scalar is exhausted. + /// If the data contains an explicit error element. + public bool TryMoveNext() + { + while (IsStreamingScalar) // close out the current streaming scalar + { + if (!TryReadNextSkipAttributes()) ThrowEof(); + } + + if (TryReadNextSkipAttributes()) + { + if (IsError) ThrowError(); + return true; + } + + return false; + } + + /// + /// Move to the next content element; this skips attribute metadata, checking for RESP error messages by default. + /// + /// Whether to check and throw for error messages. + /// If the data is exhausted before a streaming scalar is exhausted. + /// If the data contains an explicit error element. + public bool TryMoveNext(bool checkError) + { + while (IsStreamingScalar) // close out the current streaming scalar + { + if (!TryReadNextSkipAttributes()) ThrowEof(); + } + + if (TryReadNextSkipAttributes()) + { + if (checkError && IsError) ThrowError(); + return true; + } + + return false; + } + + /// + /// Move to the next content element; this skips attribute metadata, checking for RESP error messages by default. + /// + /// Parser for attribute data preceding the data. + /// The state for attributes encountered. + /// If the data is exhausted before a streaming scalar is exhausted. + /// If the data contains an explicit error element. + /// The type of data represented by this reader. + public bool TryMoveNext(RespAttributeReader respAttributeReader, ref T attributes) + { + while (IsStreamingScalar) // close out the current streaming scalar + { + if (!TryReadNextSkipAttributes()) ThrowEof(); + } + + if (TryReadNextProcessAttributes(respAttributeReader, ref attributes)) + { + if (IsError) ThrowError(); + return true; + } + + return false; + } + + /// + /// Move to the next content element, asserting that it is of the expected type; this skips attribute metadata, checking for RESP error messages by default. + /// + /// The expected data type. + /// If the data is exhausted before a streaming scalar is exhausted. + /// If the data contains an explicit error element. + /// If the data is not of the expected type. + public bool TryMoveNext(RespPrefix prefix) + { + bool result = TryMoveNext(); + if (result) Demand(prefix); + return result; + } + + /// + /// Move to the next content element; this skips attribute metadata, checking for RESP error messages by default. + /// + /// If the data is exhausted before content is found. + /// If the data contains an explicit error element. + public void MoveNext() + { + if (!TryMoveNext()) ThrowEof(); + } + + /// + /// Move to the next content element; this skips attribute metadata, checking for RESP error messages by default. + /// + /// Parser for attribute data preceding the data. + /// The state for attributes encountered. + /// If the data is exhausted before content is found. + /// If the data contains an explicit error element. + /// The type of data represented by this reader. + public void MoveNext(RespAttributeReader respAttributeReader, ref T attributes) + { + if (!TryMoveNext(respAttributeReader, ref attributes)) ThrowEof(); + } + + private bool MoveNextStreamingScalar() + { + if (IsStreamingScalar) + { + while (TryReadNext()) + { + if (IsAttribute) + { + SkipChildren(); + } + else + { + if (Prefix != RespPrefix.StreamContinuation) + ThrowProtocolFailure("Streaming continuation expected"); + return _length > 0; + } + } + + ThrowEof(); // we should have found something! + } + + return false; + } + + /// + /// Move to the next content element () and assert that it is a scalar (). + /// + /// If the data is exhausted before content is found. + /// If the data contains an explicit error element. + /// If the data is not a scalar type. + public void MoveNextScalar() + { + MoveNext(); + DemandScalar(); + } + + /// + /// Move to the next content element () and assert that it is an aggregate (). + /// + /// If the data is exhausted before content is found. + /// If the data contains an explicit error element. + /// If the data is not an aggregate type. + public void MoveNextAggregate() + { + MoveNext(); + DemandAggregate(); + } + + /// + /// Move to the next content element () and assert that it of type specified + /// in . + /// + /// The expected data type. + /// Parser for attribute data preceding the data. + /// The state for attributes encountered. + /// If the data is exhausted before content is found. + /// If the data contains an explicit error element. + /// If the data is not of the expected type. + /// The type of data represented by this reader. + public void MoveNext(RespPrefix prefix, RespAttributeReader respAttributeReader, ref T attributes) + { + MoveNext(respAttributeReader, ref attributes); + Demand(prefix); + } + + /// + /// Move to the next content element () and assert that it of type specified + /// in . + /// + /// The expected data type. + /// If the data is exhausted before content is found. + /// If the data contains an explicit error element. + /// If the data is not of the expected type. + public void MoveNext(RespPrefix prefix) + { + MoveNext(); + Demand(prefix); + } + + internal void Demand(RespPrefix prefix) + { + if (Prefix != prefix) Throw(prefix, Prefix); + + static void Throw(RespPrefix expected, RespPrefix actual) => + throw new InvalidOperationException($"Expected {expected} element, but found {actual}."); + } + + private readonly void ThrowError() => throw new RespException(ReadString()!); + + /// + /// Skip all sub elements of the current node; this includes both aggregate children and scalar streaming elements. + /// + public void SkipChildren() + { + // if this is a simple non-streaming scalar, then: there's nothing complex to do; otherwise, re-use the + // frame scanner logic to seek past the noise (this way, we avoid recursion etc) + switch (_flags & (RespFlags.IsScalar | RespFlags.IsAggregate | RespFlags.IsStreaming)) + { + case RespFlags.None: + // no current element + break; + case RespFlags.IsScalar: + // simple scalar + MovePastCurrent(); + break; + default: + // something more complex + RespScanState state = new(in this); + if (!state.TryRead(ref this, out _)) ThrowEof(); + break; + } + } + + /// + /// Reads the current element as a string value. + /// + public readonly string? ReadString() => ReadString(out _); + + /// + /// Reads the current element as a string value. + /// + public readonly string? ReadString(out string prefix) + { + byte[] pooled = []; + try + { + var span = Buffer(ref pooled, stackalloc byte[256]); + prefix = ""; + if (span.IsEmpty) + { + return IsNull ? null : ""; + } + + if (Prefix == RespPrefix.VerbatimString + && span.Length >= 4 && span[3] == ':') + { + // "the first three bytes provide information about the format of the following string, + // which can be txt for plain text, or mkd for markdown. The fourth byte is always :. + // Then the real string follows." + var prefixValue = RespConstants.UnsafeCpuUInt32(span); + if (prefixValue == PrefixTxt) + { + prefix = "txt"; + } + else if (prefixValue == PrefixMkd) + { + prefix = "mkd"; + } + else + { + prefix = RespConstants.UTF8.GetString(span.Slice(0, 3)); + } + + span = span.Slice(4); + } + + return RespConstants.UTF8.GetString(span); + } + finally + { + ArrayPool.Shared.Return(pooled); + } + } + + private static readonly uint + PrefixTxt = RespConstants.UnsafeCpuUInt32("txt:"u8), + PrefixMkd = RespConstants.UnsafeCpuUInt32("mkd:"u8); + + /// + /// Reads the current element as a string value. + /// + public readonly byte[]? ReadByteArray() + { + byte[] pooled = []; + try + { + var span = Buffer(ref pooled, stackalloc byte[256]); + if (span.IsEmpty) + { + return IsNull ? null : []; + } + + return span.ToArray(); + } + finally + { + ArrayPool.Shared.Return(pooled); + } + } + + /// + /// Reads the current element using a general purpose text parser. + /// + /// The type of data being parsed. + public readonly T ParseBytes(Parser parser) + { + byte[] pooled = []; + var span = Buffer(ref pooled, stackalloc byte[256]); + try + { + return parser(span); + } + finally + { + ArrayPool.Shared.Return(pooled); + } + } + + /// + /// Reads the current element using a general purpose text parser. + /// + /// The type of data being parsed. + /// State required by the parser. + public readonly T ParseBytes(Parser parser, TState? state) + { + byte[] pooled = []; + var span = Buffer(ref pooled, stackalloc byte[256]); + try + { + return parser(span, default); + } + finally + { + ArrayPool.Shared.Return(pooled); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly ReadOnlySpan Buffer(Span target) + { + if (TryGetSpan(out var simple)) + { + return simple; + } + +#if NET6_0_OR_GREATER + return BufferSlow(ref Unsafe.NullRef(), target, usePool: false); +#else + byte[] pooled = []; + return BufferSlow(ref pooled, target, usePool: false); +#endif + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly ReadOnlySpan Buffer(scoped ref byte[] pooled, Span target = default) + => TryGetSpan(out var simple) ? simple : BufferSlow(ref pooled, target, true); + + [MethodImpl(MethodImplOptions.NoInlining)] + private readonly ReadOnlySpan BufferSlow(scoped ref byte[] pooled, Span target, bool usePool) + { + DemandScalar(); + + if (IsInlineScalar && usePool) + { + // grow to the correct size in advance, if needed + var length = ScalarLength(); + if (length > target.Length) + { + var bigger = ArrayPool.Shared.Rent(length); + ArrayPool.Shared.Return(pooled); + target = pooled = bigger; + } + } + + var iterator = ScalarChunks(); + ReadOnlySpan current; + int offset = 0; + while (iterator.MoveNext()) + { + // will the current chunk fit? + current = iterator.Current; + if (current.TryCopyTo(target.Slice(offset))) + { + // fits into the current buffer + offset += current.Length; + } + else if (!usePool) + { + // rent disallowed; fill what we can + var available = target.Slice(offset); + current.Slice(0, available.Length).CopyTo(available); + return target; // we filled it + } + else + { + // rent a bigger buffer, copy and recycle + var bigger = ArrayPool.Shared.Rent(offset + current.Length); + if (offset != 0) + { + target.Slice(0, offset).CopyTo(bigger); + } + + ArrayPool.Shared.Return(pooled); + target = pooled = bigger; + current.CopyTo(target.Slice(offset)); + } + } + + return target.Slice(0, offset); + } + + /// + /// Reads the current element using a general purpose byte parser. + /// + /// The type of data being parsed. + public readonly T ParseChars(Parser parser) + { + byte[] bArr = []; + char[] cArr = []; + try + { + var bSpan = Buffer(ref bArr, stackalloc byte[128]); + var maxChars = RespConstants.UTF8.GetMaxCharCount(bSpan.Length); + Span cSpan = maxChars <= 128 ? stackalloc char[128] : (cArr = ArrayPool.Shared.Rent(maxChars)); + int chars = RespConstants.UTF8.GetChars(bSpan, cSpan); + return parser(cSpan.Slice(0, chars)); + } + finally + { + ArrayPool.Shared.Return(bArr); + ArrayPool.Shared.Return(cArr); + } + } + + /// + /// Reads the current element using a general purpose byte parser. + /// + /// The type of data being parsed. + /// State required by the parser. + public readonly T ParseChars(Parser parser, TState? state) + { + byte[] bArr = []; + char[] cArr = []; + try + { + var bSpan = Buffer(ref bArr, stackalloc byte[128]); + var maxChars = RespConstants.UTF8.GetMaxCharCount(bSpan.Length); + Span cSpan = maxChars <= 128 ? stackalloc char[128] : (cArr = ArrayPool.Shared.Rent(maxChars)); + int chars = RespConstants.UTF8.GetChars(bSpan, cSpan); + return parser(cSpan.Slice(0, chars), state); + } + finally + { + ArrayPool.Shared.Return(bArr); + ArrayPool.Shared.Return(cArr); + } + } + +#if NET7_0_OR_GREATER + /// + /// Reads the current element using . + /// + /// The type of data being parsed. +#pragma warning disable RS0016, RS0027 // back-compat overload + public readonly T ParseChars(IFormatProvider? formatProvider = null) where T : ISpanParsable +#pragma warning restore RS0016, RS0027 // back-compat overload + { + byte[] bArr = []; + char[] cArr = []; + try + { + var bSpan = Buffer(ref bArr, stackalloc byte[128]); + var maxChars = RespConstants.UTF8.GetMaxCharCount(bSpan.Length); + Span cSpan = maxChars <= 128 ? stackalloc char[128] : (cArr = ArrayPool.Shared.Rent(maxChars)); + int chars = RespConstants.UTF8.GetChars(bSpan, cSpan); + return T.Parse(cSpan.Slice(0, chars), formatProvider ?? CultureInfo.InvariantCulture); + } + finally + { + ArrayPool.Shared.Return(bArr); + ArrayPool.Shared.Return(cArr); + } + } +#endif + +#if NET8_0_OR_GREATER + /// + /// Reads the current element using . + /// + /// The type of data being parsed. +#pragma warning disable RS0016, RS0027 // back-compat overload + public readonly T ParseBytes(IFormatProvider? formatProvider = null) where T : IUtf8SpanParsable +#pragma warning restore RS0016, RS0027 // back-compat overload + { + byte[] bArr = []; + try + { + var bSpan = Buffer(ref bArr, stackalloc byte[128]); + return T.Parse(bSpan, formatProvider ?? CultureInfo.InvariantCulture); + } + finally + { + ArrayPool.Shared.Return(bArr); + } + } +#endif + + /// + /// General purpose parsing callback. + /// + /// The type of source data being parsed. + /// State required by the parser. + /// The output type of data being parsed. + public delegate TValue Parser(ReadOnlySpan value, TState? state); + + /// + /// General purpose parsing callback. + /// + /// The type of source data being parsed. + /// The output type of data being parsed. + public delegate TValue Parser(ReadOnlySpan value); + + /// + /// Initializes a new instance of the struct. + /// + /// The raw contents to parse with this instance. + public RespReader(ReadOnlySpan value) + { + _length = 0; + _flags = RespFlags.None; + _prefix = RespPrefix.None; + SetCurrent(value); + + _remainingTailLength = _positionBase = 0; + _tail = null; + } + + private void MovePastCurrent() + { + // skip past the trailing portion of a value, if any + var skip = TrailingLength; + if (_bufferIndex + skip <= CurrentLength) + { + _bufferIndex += skip; // available in the current buffer + } + else + { + AdvanceSlow(skip); + } + + // reset the current state + _length = 0; + _flags = 0; + _prefix = RespPrefix.None; + } + + /// + public RespReader(scoped in ReadOnlySequence value) +#if NETCOREAPP3_0_OR_GREATER + : this(value.FirstSpan) +#else + : this(value.First.Span) +#endif + { + if (!value.IsSingleSegment) + { + _remainingTailLength = value.Length - CurrentLength; + _tail = (value.Start.GetObject() as ReadOnlySequenceSegment)?.Next ?? MissingNext(); + } + + [MethodImpl(MethodImplOptions.NoInlining), DoesNotReturn] + static ReadOnlySequenceSegment MissingNext() => + throw new ArgumentException("Unable to extract tail segment", nameof(value)); + } + + /// + /// Attempt to move to the next RESP element. + /// + /// Unless you are intentionally handling errors, attributes and streaming data, should be preferred. + [EditorBrowsable(EditorBrowsableState.Never), Browsable(false)] + public unsafe bool TryReadNext() + { + MovePastCurrent(); + +#if NETCOREAPP3_0_OR_GREATER + // check what we have available; don't worry about zero/fetching the next segment; this is only + // for SIMD lookup, and zero would only apply when data ends exactly on segment boundaries, which + // is incredible niche + var available = CurrentAvailable; + + if (Avx2.IsSupported && Bmi1.IsSupported && available >= sizeof(uint)) + { + // read the first 4 bytes + ref byte origin = ref UnsafeCurrent; + var comparand = Unsafe.ReadUnaligned(ref origin); + + // broadcast those 4 bytes into a vector, mask to get just the first and last byte, and apply a SIMD equality test with our known cases + var eqs = + Avx2.CompareEqual(Avx2.And(Avx2.BroadcastScalarToVector256(&comparand), Raw.FirstLastMask), Raw.CommonRespPrefixes); + + // reinterpret that as floats, and pick out the sign bits (which will be 1 for "equal", 0 for "not equal"); since the + // test cases are mutually exclusive, we expect zero or one matches, so: lzcount tells us which matched + var index = + Bmi1.TrailingZeroCount((uint)Avx.MoveMask(Unsafe.As, Vector256>(ref eqs))); + int len; +#if DEBUG + if (VectorizeDisabled) index = uint.MaxValue; // just to break the switch +#endif + switch (index) + { + case Raw.CommonRespIndex_Success when available >= 5 && Unsafe.Add(ref origin, 4) == (byte)'\n': + _prefix = RespPrefix.SimpleString; + _length = 2; + _bufferIndex++; + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + return true; + case Raw.CommonRespIndex_SingleDigitInteger when Unsafe.Add(ref origin, 2) == (byte)'\r': + _prefix = RespPrefix.Integer; + _length = 1; + _bufferIndex++; + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + return true; + case Raw.CommonRespIndex_DoubleDigitInteger when available >= 5 && Unsafe.Add(ref origin, 4) == (byte)'\n': + _prefix = RespPrefix.Integer; + _length = 2; + _bufferIndex++; + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + return true; + case Raw.CommonRespIndex_SingleDigitString when Unsafe.Add(ref origin, 2) == (byte)'\r': + if (comparand == RespConstants.BulkStringStreaming) + { + _flags = RespFlags.IsScalar | RespFlags.IsStreaming; + } + else + { + len = ParseSingleDigit(Unsafe.Add(ref origin, 1)); + if (available < len + 6) break; // need more data + + UnsafeAssertClLf(4 + len); + _length = len; + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + } + _prefix = RespPrefix.BulkString; + _bufferIndex += 4; + return true; + case Raw.CommonRespIndex_DoubleDigitString when available >= 5 && Unsafe.Add(ref origin, 4) == (byte)'\n': + if (comparand == RespConstants.BulkStringNull) + { + _length = 0; + _flags = RespFlags.IsScalar | RespFlags.IsNull; + } + else + { + len = ParseDoubleDigitsNonNegative(ref Unsafe.Add(ref origin, 1)); + if (available < len + 7) break; // need more data + + UnsafeAssertClLf(5 + len); + _length = len; + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + } + _prefix = RespPrefix.BulkString; + _bufferIndex += 5; + return true; + case Raw.CommonRespIndex_SingleDigitArray when Unsafe.Add(ref origin, 2) == (byte)'\r': + if (comparand == RespConstants.ArrayStreaming) + { + _flags = RespFlags.IsAggregate | RespFlags.IsStreaming; + } + else + { + _flags = RespFlags.IsAggregate; + _length = ParseSingleDigit(Unsafe.Add(ref origin, 1)); + } + _prefix = RespPrefix.Array; + _bufferIndex += 4; + return true; + case Raw.CommonRespIndex_DoubleDigitArray when available >= 5 && Unsafe.Add(ref origin, 4) == (byte)'\n': + if (comparand == RespConstants.ArrayNull) + { + _flags = RespFlags.IsAggregate | RespFlags.IsNull; + } + else + { + _length = ParseDoubleDigitsNonNegative(ref Unsafe.Add(ref origin, 1)); + _flags = RespFlags.IsAggregate; + } + _prefix = RespPrefix.Array; + _bufferIndex += 5; + return true; + case Raw.CommonRespIndex_Error: + len = UnsafePastPrefix().IndexOf(RespConstants.CrlfBytes); + if (len < 0) break; // need more data + + _prefix = RespPrefix.SimpleError; + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar | RespFlags.IsError; + _length = len; + _bufferIndex++; + return true; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static int ParseDoubleDigitsNonNegative(ref byte value) => (10 * ParseSingleDigit(value)) + ParseSingleDigit(Unsafe.Add(ref value, 1)); +#endif + + // no fancy vectorization, but: we can still try to find the payload the fast way in a single segment + if (_bufferIndex + 3 <= CurrentLength) // shortest possible RESP fragment is length 3 + { + var remaining = UnsafePastPrefix(); + switch (_prefix = UnsafePeekPrefix()) + { + case RespPrefix.SimpleString: + case RespPrefix.SimpleError: + case RespPrefix.Integer: + case RespPrefix.Boolean: + case RespPrefix.Double: + case RespPrefix.BigInteger: + // CRLF-terminated + _length = remaining.IndexOf(RespConstants.CrlfBytes); + if (_length < 0) break; // can't find, need more data + _bufferIndex++; // payload follows prefix directly + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + if (_prefix == RespPrefix.SimpleError) _flags |= RespFlags.IsError; + return true; + case RespPrefix.BulkError: + case RespPrefix.BulkString: + case RespPrefix.VerbatimString: + // length prefix with value payload; first, the length + switch (TryReadLengthPrefix(remaining, out _length, out int consumed)) + { + case LengthPrefixResult.Length: + // still need to valid terminating CRLF + if (remaining.Length < consumed + _length + 2) break; // need more data + UnsafeAssertClLf(1 + consumed + _length); + + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + break; + case LengthPrefixResult.Null: + _flags = RespFlags.IsScalar | RespFlags.IsNull; + break; + case LengthPrefixResult.Streaming: + _flags = RespFlags.IsScalar | RespFlags.IsStreaming; + break; + } + + if (_flags == 0) break; // will need more data to know + if (_prefix == RespPrefix.BulkError) _flags |= RespFlags.IsError; + _bufferIndex += 1 + consumed; + return true; + case RespPrefix.StreamContinuation: + // length prefix, possibly with value payload; first, the length + switch (TryReadLengthPrefix(remaining, out _length, out consumed)) + { + case LengthPrefixResult.Length when _length == 0: + // EOF, no payload + _flags = RespFlags + .IsScalar; // don't claim as streaming, we want this to count towards delta-decrement + break; + case LengthPrefixResult.Length: + // still need to valid terminating CRLF + if (remaining.Length < consumed + _length + 2) break; // need more data + UnsafeAssertClLf(1 + consumed + _length); + + _flags = RespFlags.IsScalar | RespFlags.IsInlineScalar | RespFlags.IsStreaming; + break; + case LengthPrefixResult.Null: + case LengthPrefixResult.Streaming: + ThrowProtocolFailure("Invalid streaming scalar length prefix"); + break; + } + + if (_flags == 0) break; // will need more data to know + _bufferIndex += 1 + consumed; + return true; + case RespPrefix.Array: + case RespPrefix.Set: + case RespPrefix.Map: + case RespPrefix.Push: + case RespPrefix.Attribute: + // length prefix without value payload (child values follow) + switch (TryReadLengthPrefix(remaining, out _length, out consumed)) + { + case LengthPrefixResult.Length: + _flags = RespFlags.IsAggregate; + if (AggregateLengthNeedsDoubling()) _length *= 2; + break; + case LengthPrefixResult.Null: + _flags = RespFlags.IsAggregate | RespFlags.IsNull; + break; + case LengthPrefixResult.Streaming: + _flags = RespFlags.IsAggregate | RespFlags.IsStreaming; + break; + } + + if (_flags == 0) break; // will need more data to know + if (_prefix is RespPrefix.Attribute) _flags |= RespFlags.IsAttribute; + _bufferIndex += consumed + 1; + return true; + case RespPrefix.Null: // null + // note we already checked we had 3 bytes + UnsafeAssertClLf(1); + _flags = RespFlags.IsScalar | RespFlags.IsNull; + _bufferIndex += 3; // skip prefix+terminator + return true; + case RespPrefix.StreamTerminator: + // note we already checked we had 3 bytes + UnsafeAssertClLf(1); + _flags = RespFlags.IsAggregate; // don't claim as streaming - this counts towards delta + _bufferIndex += 3; // skip prefix+terminator + return true; + default: + ThrowProtocolFailure("Unexpected protocol prefix: " + _prefix); + return false; + } + } + + return TryReadNextSlow(ref this); + } + + private static bool TryReadNextSlow(ref RespReader live) + { + // in the case of failure, we don't want to apply any changes, + // so we work against an isolated copy until we're happy + live.MovePastCurrent(); + RespReader isolated = live; + + int next = isolated.RawTryReadByte(); + if (next < 0) return false; + + switch (isolated._prefix = (RespPrefix)next) + { + case RespPrefix.SimpleString: + case RespPrefix.SimpleError: + case RespPrefix.Integer: + case RespPrefix.Boolean: + case RespPrefix.Double: + case RespPrefix.BigInteger: + // CRLF-terminated + if (!isolated.RawTryFindCrLf(out isolated._length)) return false; + isolated._flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + if (isolated._prefix == RespPrefix.SimpleError) isolated._flags |= RespFlags.IsError; + break; + case RespPrefix.BulkError: + case RespPrefix.BulkString: + case RespPrefix.VerbatimString: + // length prefix with value payload + switch (isolated.RawTryReadLengthPrefix()) + { + case LengthPrefixResult.Length: + // still need to valid terminating CRLF + isolated._flags = RespFlags.IsScalar | RespFlags.IsInlineScalar; + if (!isolated.RawTryAssertInlineScalarPayloadCrLf()) return false; + break; + case LengthPrefixResult.Null: + isolated._flags = RespFlags.IsScalar | RespFlags.IsNull; + break; + case LengthPrefixResult.Streaming: + isolated._flags = RespFlags.IsScalar | RespFlags.IsStreaming; + break; + case LengthPrefixResult.NeedMoreData: + return false; + default: + ThrowProtocolFailure("Unexpected length prefix"); + return false; + } + + if (isolated._prefix == RespPrefix.BulkError) isolated._flags |= RespFlags.IsError; + break; + case RespPrefix.Array: + case RespPrefix.Set: + case RespPrefix.Map: + case RespPrefix.Push: + case RespPrefix.Attribute: + // length prefix without value payload (child values follow) + switch (isolated.RawTryReadLengthPrefix()) + { + case LengthPrefixResult.Length: + isolated._flags = RespFlags.IsAggregate; + if (isolated.AggregateLengthNeedsDoubling()) isolated._length *= 2; + break; + case LengthPrefixResult.Null: + isolated._flags = RespFlags.IsAggregate | RespFlags.IsNull; + break; + case LengthPrefixResult.Streaming: + isolated._flags = RespFlags.IsAggregate | RespFlags.IsStreaming; + break; + case LengthPrefixResult.NeedMoreData: + return false; + default: + ThrowProtocolFailure("Unexpected length prefix"); + return false; + } + + if (isolated._prefix is RespPrefix.Attribute) isolated._flags |= RespFlags.IsAttribute; + break; + case RespPrefix.Null: // null + if (!isolated.RawAssertCrLf()) return false; + isolated._flags = RespFlags.IsScalar | RespFlags.IsNull; + break; + case RespPrefix.StreamTerminator: + if (!isolated.RawAssertCrLf()) return false; + isolated._flags = RespFlags.IsAggregate; // don't claim as streaming - this counts towards delta + break; + case RespPrefix.StreamContinuation: + // length prefix, possibly with value payload; first, the length + switch (isolated.RawTryReadLengthPrefix()) + { + case LengthPrefixResult.Length when isolated._length == 0: + // EOF, no payload + isolated._flags = + RespFlags + .IsScalar; // don't claim as streaming, we want this to count towards delta-decrement + break; + case LengthPrefixResult.Length: + // still need to valid terminating CRLF + isolated._flags = RespFlags.IsScalar | RespFlags.IsInlineScalar | RespFlags.IsStreaming; + if (!isolated.RawTryAssertInlineScalarPayloadCrLf()) return false; // need more data + break; + case LengthPrefixResult.Null: + case LengthPrefixResult.Streaming: + ThrowProtocolFailure("Invalid streaming scalar length prefix"); + break; + case LengthPrefixResult.NeedMoreData: + default: + return false; + } + + break; + default: + ThrowProtocolFailure("Unexpected protocol prefix: " + isolated._prefix); + return false; + } + + // commit the speculative changes back, and accept + live = isolated; + return true; + } + + private void AdvanceSlow(long bytes) + { + while (bytes > 0) + { + var available = CurrentLength - _bufferIndex; + if (bytes <= available) + { + _bufferIndex += (int)bytes; + return; + } + + bytes -= available; + + if (!TryMoveToNextSegment()) Throw(); + } + + [DoesNotReturn] + static void Throw() => throw new EndOfStreamException( + "Unexpected end of payload; this is unexpected because we already validated that it was available!"); + } + + private bool AggregateLengthNeedsDoubling() => _prefix is RespPrefix.Map or RespPrefix.Attribute; + + private bool TryMoveToNextSegment() + { + while (_tail is not null && _remainingTailLength > 0) + { + var memory = _tail.Memory; + _tail = _tail.Next; + if (!memory.IsEmpty) + { + var span = memory.Span; // check we can get this before mutating anything + _positionBase += CurrentLength; + if (span.Length > _remainingTailLength) + { + span = span.Slice(0, (int)_remainingTailLength); + _remainingTailLength = 0; + } + else + { + _remainingTailLength -= span.Length; + } + + SetCurrent(span); + _bufferIndex = 0; + return true; + } + } + + return false; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly bool IsOK() // go mad with this, because it is used so often + { + if (TryGetSpan(out var span) && span.Length == 2) + { + var u16 = Unsafe.ReadUnaligned(ref UnsafeCurrent); + return u16 == RespConstants.OKUInt16 | u16 == RespConstants.OKUInt16_LC; + } + + return IsSlow(RespConstants.OKBytes, RespConstants.OKBytes_LC); + } + + /// + /// Indicates whether the current element is a scalar with a value that matches the provided . + /// + /// The payload value to verify. + public readonly bool Is(ReadOnlySpan value) + => TryGetSpan(out var span) ? span.SequenceEqual(value) : IsSlow(value); + + /// + /// Indicates whether the current element is a scalar with a value that matches the provided . + /// + /// The payload value to verify. + public readonly bool Is(ReadOnlySpan value) + { + var bytes = RespConstants.UTF8.GetMaxByteCount(value.Length); + byte[]? oversized = null; + Span buffer = bytes <= 128 ? stackalloc byte[128] : (oversized = ArrayPool.Shared.Rent(bytes)); + bytes = RespConstants.UTF8.GetBytes(value, buffer); + bool result = Is(buffer.Slice(0, bytes)); + if (oversized is not null) ArrayPool.Shared.Return(oversized); + return result; + } + + internal readonly bool IsInlneCpuUInt32(uint value) + { + if (IsInlineScalar && _length == sizeof(uint)) + { + return CurrentAvailable >= sizeof(uint) + ? Unsafe.ReadUnaligned(ref UnsafeCurrent) == value + : SlowIsInlneCpuUInt32(value); + } + + return false; + } + + private readonly bool SlowIsInlneCpuUInt32(uint value) + { + Debug.Assert(IsInlineScalar && _length == sizeof(uint), "should be inline scalar of length 4"); + Span buffer = stackalloc byte[sizeof(uint)]; + var copy = this; + copy.RawFillBytes(buffer); + return RespConstants.UnsafeCpuUInt32(buffer) == value; + } + + /// + /// Indicates whether the current element is a scalar with a value that matches the provided . + /// + /// The payload value to verify. + public readonly bool Is(byte value) + { + if (IsInlineScalar && _length == 1 && CurrentAvailable >= 1) + { + return UnsafeCurrent == value; + } + + ReadOnlySpan span = [value]; + return IsSlow(span); + } + + private readonly bool IsSlow(ReadOnlySpan testValue0, ReadOnlySpan testValue2) + => IsSlow(testValue0) || IsSlow(testValue2); + + private readonly bool IsSlow(ReadOnlySpan testValue) + { + DemandScalar(); + if (IsNull) return false; // nothing equals null + if (TotalAvailable < testValue.Length) return false; + + if (!IsStreaming && testValue.Length != ScalarLength()) return false; + + var iterator = ScalarChunks(); + while (true) + { + if (testValue.IsEmpty) + { + // nothing left to test; if also nothing left to read, great! + return !iterator.MoveNext(); + } + + if (!iterator.MoveNext()) + { + return false; // test is longer + } + + var current = iterator.Current; + if (testValue.Length < current.Length) return false; // payload is longer + + if (!current.SequenceEqual(testValue.Slice(0, current.Length))) return false; // payload is different + + testValue = testValue.Slice(current.Length); // validated; continue + } + } + + /// + /// Copy the current scalar value out into the supplied , or as much as can be copied. + /// + /// The destination for the copy operation. + /// The number of bytes successfully copied. + public readonly int CopyTo(Span target) + { + if (TryGetSpan(out var value)) + { + if (target.Length < value.Length) value = value.Slice(0, target.Length); + + value.CopyTo(target); + return value.Length; + } + + int totalBytes = 0; + var iterator = ScalarChunks(); + while (iterator.MoveNext()) + { + value = iterator.Current; + if (target.Length <= value.Length) + { + value.Slice(0, target.Length).CopyTo(target); + return totalBytes + target.Length; + } + + value.CopyTo(target); + target = target.Slice(value.Length); + totalBytes += value.Length; + } + + return totalBytes; + } + + /// + /// Copy the current scalar value out into the supplied , or as much as can be copied. + /// + /// The destination for the copy operation. + /// The number of bytes successfully copied. + public readonly int CopyTo(IBufferWriter target) + { + if (TryGetSpan(out var value)) + { + target.Write(value); + return value.Length; + } + + int totalBytes = 0; + var iterator = ScalarChunks(); + while (iterator.MoveNext()) + { + value = iterator.Current; + target.Write(value); + totalBytes += value.Length; + } + + return totalBytes; + } + + /// + /// Asserts that the current element is not null. + /// + public void DemandNotNull() + { + if (IsNull) Throw(); + static void Throw() => throw new InvalidOperationException("A non-null element was expected"); + } + + /// + /// Read the current element as a value. + /// + [SuppressMessage("Style", "IDE0018:Inline variable declaration", Justification = "No it can't - conditional")] + public readonly long ReadInt64() + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesInt64 + 1]); + long value; + if (!(span.Length <= RespConstants.MaxRawBytesInt64 + && Utf8Parser.TryParse(span, out value, out int bytes) + && bytes == span.Length)) + { + ThrowFormatException(); + value = 0; + } + + return value; + } + + /// + /// Try to read the current element as a value. + /// + public readonly bool TryReadInt64(out long value) + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesInt64 + 1]); + if (span.Length <= RespConstants.MaxRawBytesInt64) + { + return Utf8Parser.TryParse(span, out value, out int bytes) & bytes == span.Length; + } + + value = 0; + return false; + } + + /// + /// Read the current element as a value. + /// + [SuppressMessage("Style", "IDE0018:Inline variable declaration", Justification = "No it can't - conditional")] + public readonly int ReadInt32() + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesInt32 + 1]); + int value; + if (!(span.Length <= RespConstants.MaxRawBytesInt32 + && Utf8Parser.TryParse(span, out value, out int bytes) + && bytes == span.Length)) + { + ThrowFormatException(); + value = 0; + } + + return value; + } + + /// + /// Try to read the current element as a value. + /// + public readonly bool TryReadInt32(out int value) + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesInt32 + 1]); + if (span.Length <= RespConstants.MaxRawBytesInt32) + { + return Utf8Parser.TryParse(span, out value, out int bytes) & bytes == span.Length; + } + + value = 0; + return false; + } + + /// + /// Read the current element as a value. + /// + public readonly double ReadDouble() + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesNumber + 1]); + + if (span.Length <= RespConstants.MaxRawBytesNumber + && Utf8Parser.TryParse(span, out double value, out int bytes) + && bytes == span.Length) + { + return value; + } + + switch (span.Length) + { + case 3 when "inf"u8.SequenceEqual(span): + return double.PositiveInfinity; + case 3 when "nan"u8.SequenceEqual(span): + return double.NaN; + case 4 when "+inf"u8.SequenceEqual(span): // not actually mentioned in spec, but: we'll allow it + return double.PositiveInfinity; + case 4 when "-inf"u8.SequenceEqual(span): + return double.NegativeInfinity; + } + + ThrowFormatException(); + return 0; + } + + /// + /// Try to read the current element as a value. + /// + public bool TryReadDouble(out double value, bool allowTokens = true) + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesNumber + 1]); + + if (span.Length <= RespConstants.MaxRawBytesNumber + && Utf8Parser.TryParse(span, out value, out int bytes) + && bytes == span.Length) + { + return true; + } + + if (allowTokens) + { + switch (span.Length) + { + case 3 when "inf"u8.SequenceEqual(span): + value = double.PositiveInfinity; + return true; + case 3 when "nan"u8.SequenceEqual(span): + value = double.NaN; + return true; + case 4 when "+inf"u8.SequenceEqual(span): // not actually mentioned in spec, but: we'll allow it + value = double.PositiveInfinity; + return true; + case 4 when "-inf"u8.SequenceEqual(span): + value = double.NegativeInfinity; + return true; + } + } + + value = 0; + return false; + } + + /// + /// Note this uses a stackalloc buffer; requesting too much may overflow the stack. + /// + internal readonly bool UnsafeTryReadShortAscii(out string value, int maxLength = 127) + { + var span = Buffer(stackalloc byte[maxLength + 1]); + value = ""; + if (span.IsEmpty) return true; + + if (span.Length <= maxLength) + { + // check for anything that looks binary or unicode + foreach (var b in span) + { + // allow [SPACE]-thru-[DEL], plus CR/LF + if (!(b < 127 & (b >= 32 | (b is 12 or 13)))) + { + return false; + } + } + + value = Encoding.UTF8.GetString(span); + return true; + } + + return false; + } + + /// + /// Read the current element as a value. + /// + [SuppressMessage("Style", "IDE0018:Inline variable declaration", Justification = "No it can't - conditional")] + public readonly decimal ReadDecimal() + { + var span = Buffer(stackalloc byte[RespConstants.MaxRawBytesNumber + 1]); + decimal value; + if (!(span.Length <= RespConstants.MaxRawBytesNumber + && Utf8Parser.TryParse(span, out value, out int bytes) + && bytes == span.Length)) + { + ThrowFormatException(); + value = 0; + } + + return value; + } + + /// + /// Read the current element as a value. + /// + public readonly bool ReadBoolean() + { + var span = Buffer(stackalloc byte[2]); + switch (span.Length) + { + case 1: + switch (span[0]) + { + case (byte)'0' when Prefix == RespPrefix.Integer: return false; + case (byte)'1' when Prefix == RespPrefix.Integer: return true; + case (byte)'f' when Prefix == RespPrefix.Boolean: return false; + case (byte)'t' when Prefix == RespPrefix.Boolean: return true; + } + + break; + case 2 when Prefix == RespPrefix.SimpleString && IsOK(): return true; + } + + ThrowFormatException(); + return false; + } + + /// + /// Parse a scalar value as an enum of type . + /// + /// The value to report if the value is not recognized. + /// The type of enum being parsed. + public readonly T ReadEnum(T unknownValue = default) where T : struct, Enum + { +#if NET6_0_OR_GREATER + return ParseChars(static (chars, state) => Enum.TryParse(chars, true, out T value) ? value : state, unknownValue); +#else + return Enum.TryParse(ReadString(), true, out T value) ? value : unknownValue; +#endif + } + + public TResult[]? ReadArray(Projection projection, bool scalar = false) + { + DemandAggregate(); + if (IsNull) return null; + var len = AggregateLength(); + if (len == 0) return []; + var result = new TResult[len]; + if (scalar) + { + // if the data to be consumed is simple (scalar), we can use + // a simpler path that doesn't need to worry about RESP subtrees + for (int i = 0; i < result.Length; i++) + { + MoveNextScalar(); + result[i] = projection(ref this); + } + } + else + { + var agg = AggregateChildren(); + agg.FillAll(result, projection); + agg.MovePast(out this); + } + + return result; + } + + public TResult[]? ReadPairArray( + Projection first, + Projection second, + Func combine, + bool scalar = true) + { + DemandAggregate(); + if (IsNull) return null; + int sourceLength = AggregateLength(); + if (sourceLength is 0 or 1) return []; + var result = new TResult[sourceLength >> 1]; + if (scalar) + { + // if the data to be consumed is simple (scalar), we can use + // a simpler path that doesn't need to worry about RESP subtrees + for (int i = 0; i < result.Length; i++) + { + MoveNextScalar(); + var x = first(ref this); + MoveNextScalar(); + var y = second(ref this); + result[i] = combine(x, y); + } + // if we have an odd number of source elements, skip the last one + if ((sourceLength & 1) != 0) MoveNextScalar(); + } + else + { + var agg = AggregateChildren(); + agg.FillAll(result, first, second, combine); + agg.MovePast(out this); + } + return result; + } + internal TResult[]? ReadLeasedPairArray( + Projection first, + Projection second, + Func combine, + out int count, + bool scalar = true) + { + DemandAggregate(); + if (IsNull) + { + count = 0; + return null; + } + int sourceLength = AggregateLength(); + count = sourceLength >> 1; + if (count is 0) return []; + + var oversized = ArrayPool.Shared.Rent(count); + var result = oversized.AsSpan(0, count); + if (scalar) + { + // if the data to be consumed is simple (scalar), we can use + // a simpler path that doesn't need to worry about RESP subtrees + for (int i = 0; i < result.Length; i++) + { + MoveNextScalar(); + var x = first(ref this); + MoveNextScalar(); + var y = second(ref this); + result[i] = combine(x, y); + } + // if we have an odd number of source elements, skip the last one + if ((sourceLength & 1) != 0) MoveNextScalar(); + } + else + { + var agg = AggregateChildren(); + agg.FillAll(result, first, second, combine); + agg.MovePast(out this); + } + return oversized; + } +} diff --git a/src/RESPite/Messages/RespScanState.cs b/src/RESPite/Messages/RespScanState.cs new file mode 100644 index 000000000..f40d08b96 --- /dev/null +++ b/src/RESPite/Messages/RespScanState.cs @@ -0,0 +1,160 @@ +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; + +namespace RESPite.Messages; + +/// +/// Holds state used for RESP frame parsing, i.e. detecting the RESP for an entire top-level message. +/// +public struct RespScanState +{ + /* + The key point of ScanState is to skim over a RESP stream with minimal frame processing, to find the + end of a single top-level RESP message. We start by expecting 1 message, and then just read, with the + rules that the end of a message subtracts one, and aggregates add N. Streaming scalars apply zero offset + until the scalar stream terminator. Attributes also apply zero offset. + Note that streaming aggregates change the rules - when at least one streaming aggregate is in effect, + no offsets are applied until we get back out of the outermost streaming aggregate - we achieve this + by simply counting the streaming aggregate depth, which is usually zero. + Note that in reality streaming (scalar and aggregates) and attributes are non-existent; in addition + to being specific to RESP3, no known server currently implements these parts of the RESP3 specification, + so everything here is theoretical, but: works according to the spec. + */ + private int _delta; // when this becomes -1, we have fully read a top-level message; + private ushort _streamingAggregateDepth; + private RespPrefix _prefix; + + public RespPrefix Prefix => _prefix; + + private long _totalBytes; +#if DEBUG + private int _elementCount; + + /// + public override string ToString() => $"{_prefix}, consumed: {_totalBytes} bytes, {_elementCount} nodes, complete: {IsComplete}"; +#else + /// + public override string ToString() => _prefix.ToString(); +#endif + + /// + public override bool Equals([NotNullWhen(true)] object? obj) => throw new NotSupportedException(); + + /// + public override int GetHashCode() => throw new NotSupportedException(); + + /// + /// Gets whether an entire top-level RESP message has been consumed. + /// + public bool IsComplete => _delta == -1; + + /// + /// Gets the total length of the payload read (or read so far, if it is not yet complete); this combines payloads from multiple + /// TryRead operations. + /// + public long TotalBytes => _totalBytes; + + // used when spotting common replies - we entirely bypass the usual reader/delta mechanism + internal void SetComplete(int totalBytes, RespPrefix prefix) + { + _totalBytes = totalBytes; + _delta = -1; + _prefix = prefix; +#if DEBUG + _elementCount = 1; +#endif + } + + /// + /// The amount of data, in bytes, to read before attempting to read the next frame. + /// + public const int MinBytes = 3; // minimum legal RESP frame is: _\r\n + + /// + /// Create a new value that can parse the supplied node (and subtree). + /// + internal RespScanState(in RespReader reader) + { + Debug.Assert(reader.Prefix != RespPrefix.None, "missing RESP prefix"); + _totalBytes = 0; + _delta = reader.GetInitialScanCount(out _streamingAggregateDepth); + } + + /// + /// Scan as far as possible, stopping when an entire top-level RESP message has been consumed or the data is exhausted. + /// + /// True if a top-level RESP message has been consumed. + public bool TryRead(ref RespReader reader, out long bytesRead) + { + bytesRead = ReadCore(ref reader, reader.BytesConsumed); + return IsComplete; + } + + /// + /// Scan as far as possible, stopping when an entire top-level RESP message has been consumed or the data is exhausted. + /// + /// True if a top-level RESP message has been consumed. + public bool TryRead(ReadOnlySpan value, out int bytesRead) + { + var reader = new RespReader(value); + bytesRead = (int)ReadCore(ref reader); + return IsComplete; + } + + /// + /// Scan as far as possible, stopping when an entire top-level RESP message has been consumed or the data is exhausted. + /// + /// True if a top-level RESP message has been consumed. + public bool TryRead(in ReadOnlySequence value, out long bytesRead) + { + var reader = new RespReader(in value); + bytesRead = ReadCore(ref reader); + return IsComplete; + } + + /// + /// Scan as far as possible, stopping when an entire top-level RESP message has been consumed or the data is exhausted. + /// + /// The number of bytes consumed in this operation. + private long ReadCore(ref RespReader reader, long startOffset = 0) + { + while (_delta >= 0 && reader.TryReadNext()) + { +#if DEBUG + _elementCount++; +#endif + if (!reader.IsAttribute & _prefix == RespPrefix.None) + { + _prefix = reader.Prefix; + } + + if (reader.IsAggregate) ApplyAggregateRules(ref reader); + + if (_streamingAggregateDepth == 0) _delta += reader.Delta(); + } + + var bytesRead = reader.BytesConsumed - startOffset; + _totalBytes += bytesRead; + return bytesRead; + } + + private void ApplyAggregateRules(ref RespReader reader) + { + Debug.Assert(reader.IsAggregate, "RESP aggregate expected"); + if (reader.IsStreaming) + { + // entering an aggregate stream + if (_streamingAggregateDepth == ushort.MaxValue) ThrowTooDeep(); + _streamingAggregateDepth++; + } + else if (reader.Prefix == RespPrefix.StreamTerminator) + { + // exiting an aggregate stream + if (_streamingAggregateDepth == 0) ThrowUnexpectedTerminator(); + _streamingAggregateDepth--; + } + static void ThrowTooDeep() => throw new InvalidOperationException("Maximum streaming aggregate depth exceeded."); + static void ThrowUnexpectedTerminator() => throw new InvalidOperationException("Unexpected streaming aggregate terminator."); + } +} diff --git a/src/RESPite/Messages/RespWriter.cs b/src/RESPite/Messages/RespWriter.cs new file mode 100644 index 000000000..88b1caa1e --- /dev/null +++ b/src/RESPite/Messages/RespWriter.cs @@ -0,0 +1,994 @@ +using System; +using System.Buffers; +using System.Buffers.Text; +using System.ComponentModel; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +using RESPite.Internal; + +namespace RESPite.Messages; + +/// +/// Provides low-level RESP formatting operations. +/// +public ref struct RespWriter +{ + private readonly IBufferWriter? _target; + + [SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Clarity")] + private int _index; + + internal readonly int IndexInCurrentBuffer => _index; + +#if NET7_0_OR_GREATER + private ref byte StartOfBuffer; + private int BufferLength; + + private ref byte WriteHead + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => ref Unsafe.Add(ref StartOfBuffer, _index); + } + + private Span Tail + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => MemoryMarshal.CreateSpan(ref Unsafe.Add(ref StartOfBuffer, _index), BufferLength - _index); + } + + private void WriteRawUnsafe(byte value) => Unsafe.Add(ref StartOfBuffer, _index++) = value; + + private readonly ReadOnlySpan WrittenLocalBuffer => + MemoryMarshal.CreateReadOnlySpan(ref StartOfBuffer, _index); +#else + private Span _buffer; + private readonly int BufferLength => _buffer.Length; + + private readonly ref byte StartOfBuffer + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => ref MemoryMarshal.GetReference(_buffer); + } + + private readonly ref byte WriteHead + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => ref Unsafe.Add(ref MemoryMarshal.GetReference(_buffer), _index); + } + + private readonly Span Tail => _buffer.Slice(_index); + private void WriteRawUnsafe(byte value) => _buffer[_index++] = value; + + private readonly ReadOnlySpan WrittenLocalBuffer => _buffer.Slice(0, _index); +#endif + + internal readonly string DebugBuffer() => RespConstants.UTF8.GetString(WrittenLocalBuffer); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void WriteCrLfUnsafe() + { + Unsafe.WriteUnaligned(ref WriteHead, RespConstants.CrLfUInt16); + _index += 2; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void WriteCrLf() + { + if (Available >= 2) + { + Unsafe.WriteUnaligned(ref WriteHead, RespConstants.CrLfUInt16); + _index += 2; + } + else + { + WriteRaw(RespConstants.CrlfBytes); + } + } + + private readonly int Available + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => BufferLength - _index; + } + + /// + /// Create a new RESP writer over the provided target. + /// + public RespWriter(IBufferWriter target) + { + _target = target; + _index = 0; +#if NET7_0_OR_GREATER + StartOfBuffer = ref Unsafe.NullRef(); + BufferLength = 0; +#else + _buffer = default; +#endif + GetBuffer(); + } + + /// + /// Create a new RESP writer over the provided target. + /// + public RespWriter(Span target) + { + _index = 0; +#if NET7_0_OR_GREATER + BufferLength = target.Length; + StartOfBuffer = ref MemoryMarshal.GetReference(target); +#else + _buffer = target; +#endif + } + + /// + /// Commits any unwritten bytes to the output. + /// + public void Flush() + { + if (_index != 0 && _target is not null) + { + _target.Advance(_index); +#if NET7_0_OR_GREATER + _index = BufferLength = 0; + StartOfBuffer = ref Unsafe.NullRef(); +#else + _index = 0; + _buffer = default; +#endif + } + } + + private void FlushAndGetBuffer(int sizeHint) + { + Flush(); + GetBuffer(sizeHint); + } + + private void GetBuffer(int sizeHint = 128) + { + if (Available == 0) + { + if (_target is null) + { + ThrowFixedBufferExceeded(); + } + else + { + const int MIN_BUFFER = 1024; + _index = 0; +#if NET7_0_OR_GREATER + var span = _target.GetSpan(Math.Max(sizeHint, MIN_BUFFER)); + BufferLength = span.Length; + StartOfBuffer = ref MemoryMarshal.GetReference(span); +#else + _buffer = _target.GetSpan(Math.Max(sizeHint, MIN_BUFFER)); +#endif + ActivationHelper.DebugBreakIf(Available == 0); + } + } + } + + [DoesNotReturn, MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowFixedBufferExceeded() => + throw new InvalidOperationException("Fixed buffer cannot be expanded"); + + /// + /// Write raw RESP data to the output; no validation will occur. + /// + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public void WriteRaw(scoped ReadOnlySpan buffer) + { + const int MAX_TO_DOUBLE_BUFFER = 128; + if (buffer.Length <= MAX_TO_DOUBLE_BUFFER && buffer.Length <= Available) + { + buffer.CopyTo(Tail); + _index += buffer.Length; + } + else + { + // write directly to the output + Flush(); + if (_target is null) + { + ThrowFixedBufferExceeded(); + } + else + { + _target.Write(buffer); + } + } + } + + public RespCommandMap? CommandMap { get; set; } + + /// + /// Write a command header. + /// + /// The command name to write. + /// The number of arguments for the command (excluding the command itself). + public void WriteCommand(scoped ReadOnlySpan command, int args) + { + if (args < 0) Throw(); + WritePrefixInteger(RespPrefix.Array, args + 1); + if (command.IsEmpty) ThrowEmptyCommand(); + if (CommandMap is { } map) + { + var mapped = map.Map(command); + if (mapped.IsEmpty) ThrowCommandUnavailable(command); + command = mapped; + } + + WriteBulkString(command); + + static void Throw() => throw new ArgumentOutOfRangeException(nameof(args)); + + static void ThrowEmptyCommand() => + throw new ArgumentException(paramName: nameof(command), message: "Empty command specified."); + + static void ThrowCommandUnavailable(ReadOnlySpan command) + => throw new ArgumentException( + paramName: nameof(command), + message: $"The command {Encoding.UTF8.GetString(command)} is not available."); + } + + /// + /// Write a key as a bulk string. + /// + /// The key to write. + public void WriteKey(scoped ReadOnlySpan value) => WriteBulkString(value); + + /// + /// Write a key as a bulk string. + /// + /// The key to write. + public void WriteKey(ReadOnlyMemory value) => WriteBulkString(value.Span); + + /// + /// Write a key as a bulk string. + /// + /// The key to write. + public void WriteKey(scoped ReadOnlySpan value) => WriteBulkString(value); + + /// + /// Write a key as a bulk string. + /// + /// The key to write. + public void WriteKey(ReadOnlyMemory value) => WriteBulkString(value.Span); + + /// + /// Write a key as a bulk string. + /// + /// The key to write. + public void WriteKey(string value) => WriteBulkString(value); + + /// + /// Write a key as a bulk string. + /// + /// The key to write. + public void WriteKey(byte[] value) => WriteBulkString(value.AsSpan()); + + /// + /// Write a payload as a bulk string. + /// + /// The payload to write. + public void WriteBulkString(byte[] value) => WriteBulkString(value.AsSpan()); + + /// + /// Write a payload as a bulk string. + /// + /// The payload to write. + public void WriteBulkString(ReadOnlyMemory value) + => WriteBulkString(value.Span); + + /// + /// Write a payload as a bulk string. + /// + /// The payload to write. + public void WriteBulkString(scoped ReadOnlySpan value) + { + if (value.IsEmpty) + { + if (Available >= 6) + { + WriteRawPrechecked(Raw.BulkStringEmpty_6, 6); + } + else + { + WriteRaw("$0\r\n\r\n"u8); + } + } + else + { + WriteBulkStringHeader(value.Length); + if (Available >= value.Length + 2) + { + value.CopyTo(Tail); + _index += value.Length; + WriteCrLfUnsafe(); + } + else + { + // slow path + WriteRaw(value); + WriteCrLf(); + } + } + } + + /* + /// + /// Write a payload as a bulk string. + /// + /// The payload to write. + public void WriteBulkString(in SimpleString value) + { + if (value.IsEmpty) + { + WriteRaw("$0\r\n\r\n"u8); + } + else if (value.TryGetBytes(span: out var bytes)) + { + WriteBulkString(bytes); + } + else if (value.TryGetChars(span: out var chars)) + { + WriteBulkString(chars); + } + else if (value.TryGetBytes(sequence: out var bytesSeq)) + { + WriteBulkString(bytesSeq); + } + else if (value.TryGetChars(sequence: out var charsSeq)) + { + WriteBulkString(charsSeq); + } + else + { + Throw(); + } + + static void Throw() => throw new InvalidOperationException($"It was not possible to read the {nameof(SimpleString)} contents"); + } + */ + + /// + /// Write an integer as a bulk string. + /// + public void WriteBulkString(bool value) => WriteBulkString(value ? 1 : 0); + + /// + /// Write a floating point as a bulk string. + /// + public void WriteBulkString(double value) // implicitly: inclusive + { + if (value == 0.0 | double.IsNaN(value) | double.IsInfinity(value)) + { + WriteKnownDoubleInclusive(ref this, value); + + static void WriteKnownDoubleInclusive(ref RespWriter writer, double value) + { + if (value == 0.0) + { + writer.WriteRaw("$1\r\n0\r\n"u8); + } + else if (double.IsNaN(value)) + { + writer.WriteRaw("$3\r\nnan\r\n"u8); + } + else if (double.IsPositiveInfinity(value)) + { + writer.WriteRaw("$3\r\ninf\r\n"u8); + } + else if (double.IsNegativeInfinity(value)) + { + writer.WriteRaw("$4\r\n-inf\r\n"u8); + } + else + { + Throw(); + static void Throw() => throw new ArgumentOutOfRangeException(nameof(value)); + } + } + } + else + { + Debug.Assert((RespConstants.MaxProtocolBytesBytesNumber + 1) <= 32); + Span scratch = stackalloc byte[32]; + if (!Utf8Formatter.TryFormat(value, scratch, out int bytes, G17)) + ThrowFormatException(); + + WritePrefixInteger(RespPrefix.BulkString, bytes); + WriteRaw(scratch.Slice(0, bytes)); + WriteCrLf(); + } + } + + internal void WriteBulkStringExclusive(double value) + { + if (value == 0.0 | double.IsNaN(value) | double.IsInfinity(value)) + { + WriteKnownDoubleExclusive(ref this, value); + + static void WriteKnownDoubleExclusive(ref RespWriter writer, double value) + { + if (value == 0.0) + { + writer.WriteRaw("$2\r\n(0\r\n"u8); + } + else if (double.IsNaN(value)) + { + writer.WriteRaw("$4\r\n(nan\r\n"u8); + } + else if (double.IsPositiveInfinity(value)) + { + writer.WriteRaw("$4\r\n(inf\r\n"u8); + } + else if (double.IsNegativeInfinity(value)) + { + writer.WriteRaw("$5\r\n(-inf\r\n"u8); + } + else + { + Throw(); + static void Throw() => throw new ArgumentOutOfRangeException(nameof(value)); + } + } + } + else + { + Debug.Assert((RespConstants.MaxProtocolBytesBytesNumber + 1) <= 32); + Span scratch = stackalloc byte[32]; + scratch[0] = (byte)'('; + if (!Utf8Formatter.TryFormat(value, scratch.Slice(1), out int bytes, G17)) + ThrowFormatException(); + bytes++; + + WritePrefixInteger(RespPrefix.BulkString, bytes); + WriteRaw(scratch.Slice(0, bytes)); + WriteCrLf(); + } + } + + private static readonly StandardFormat G17 = new('G', 17); + + /// + /// Write an integer as a bulk string. + /// + public void WriteBulkString(long value) + { + if (value >= -1 & value <= 20) + { + WriteRaw(value switch + { + -1 => "$2\r\n-1\r\n"u8, + 0 => "$1\r\n0\r\n"u8, + 1 => "$1\r\n1\r\n"u8, + 2 => "$1\r\n2\r\n"u8, + 3 => "$1\r\n3\r\n"u8, + 4 => "$1\r\n4\r\n"u8, + 5 => "$1\r\n5\r\n"u8, + 6 => "$1\r\n6\r\n"u8, + 7 => "$1\r\n7\r\n"u8, + 8 => "$1\r\n8\r\n"u8, + 9 => "$1\r\n9\r\n"u8, + 10 => "$2\r\n10\r\n"u8, + 11 => "$2\r\n11\r\n"u8, + 12 => "$2\r\n12\r\n"u8, + 13 => "$2\r\n13\r\n"u8, + 14 => "$2\r\n14\r\n"u8, + 15 => "$2\r\n15\r\n"u8, + 16 => "$2\r\n16\r\n"u8, + 17 => "$2\r\n17\r\n"u8, + 18 => "$2\r\n18\r\n"u8, + 19 => "$2\r\n19\r\n"u8, + 20 => "$2\r\n20\r\n"u8, + _ => Throw(), + }); + + static ReadOnlySpan Throw() => throw new ArgumentOutOfRangeException(nameof(value)); + } + else if (Available >= RespConstants.MaxProtocolBytesBulkStringIntegerInt64) + { + var singleDigit = value >= -99_999_999 && value <= 999_999_999; + WriteRawUnsafe((byte)RespPrefix.BulkString); + + var target = Tail.Slice(singleDigit ? 3 : 4); // N\r\n or NN\r\n + if (!Utf8Formatter.TryFormat(value, target, out var valueBytes)) + ThrowFormatException(); + + Debug.Assert(valueBytes > 0 && singleDigit ? valueBytes < 10 : valueBytes is 10 or 11); + if (!Utf8Formatter.TryFormat(valueBytes, Tail, out var prefixBytes)) + ThrowFormatException(); + Debug.Assert(prefixBytes == (singleDigit ? 1 : 2)); + _index += prefixBytes; + WriteCrLfUnsafe(); + _index += valueBytes; + WriteCrLfUnsafe(); + } + else + { + Debug.Assert(RespConstants.MaxRawBytesInt64 <= 24); + Span scratch = stackalloc byte[24]; + if (!Utf8Formatter.TryFormat(value, scratch, out int bytes)) + ThrowFormatException(); + WritePrefixInteger(RespPrefix.BulkString, bytes); + WriteRaw(scratch.Slice(0, bytes)); + WriteCrLf(); + } + } + + /// + /// Write an unsigned integer as a bulk string. + /// + public void WriteBulkString(ulong value) + { + if (value <= (ulong)long.MaxValue) + { + // re-use existing code for most values + WriteBulkString((long)value); + } + else if (Available >= RespConstants.MaxProtocolBytesBulkStringIntegerInt64) + { + WriteRaw("$20\r\n"u8); + if (!Utf8Formatter.TryFormat(value, Tail, out var bytes) || bytes != 20) + ThrowFormatException(); + _index += 20; + WriteCrLfUnsafe(); + } + else + { + WriteRaw("$20\r\n"u8); + Span scratch = stackalloc byte[20]; + if (!Utf8Formatter.TryFormat(value, scratch, out int bytes) || bytes != 20) + ThrowFormatException(); + WriteRaw(scratch); + WriteCrLf(); + } + } + + private static void ThrowFormatException() => throw new FormatException(); + + private void WritePrefixInteger(RespPrefix prefix, int length) + { + if (Available >= RespConstants.MaxProtocolBytesIntegerInt32) + { + WriteRawUnsafe((byte)prefix); + if (length >= 0 & length <= 9) + { + WriteRawUnsafe((byte)(length + '0')); + } + else + { + if (!Utf8Formatter.TryFormat(length, Tail, out var bytesWritten)) + { + ThrowFormatException(); + } + + _index += bytesWritten; + } + + WriteCrLfUnsafe(); + } + else + { + WriteViaStack(ref this, prefix, length); + } + + static void WriteViaStack(ref RespWriter respWriter, RespPrefix prefix, int length) + { + Debug.Assert(RespConstants.MaxProtocolBytesIntegerInt32 <= 16); + Span buffer = stackalloc byte[16]; + buffer[0] = (byte)prefix; + int payloadLength; + if (length >= 0 & length <= 9) + { + buffer[1] = (byte)(length + '0'); + payloadLength = 1; + } + else if (!Utf8Formatter.TryFormat(length, buffer.Slice(1), out payloadLength)) + { + ThrowFormatException(); + } + + Unsafe.WriteUnaligned(ref buffer[payloadLength + 1], RespConstants.CrLfUInt16); + respWriter.WriteRaw(buffer.Slice(0, payloadLength + 3)); + } + + bool writeToStack = Available < RespConstants.MaxProtocolBytesIntegerInt32; + + Span target = writeToStack ? stackalloc byte[16] : Tail; + target[0] = (byte)prefix; + } + + /// + /// Write a payload as a bulk string. + /// + /// The payload to write. + public void WriteBulkString(string value) + { + // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract + if (value is null) ThrowNull(); + WriteBulkString(value.AsSpan()); + } + + [MethodImpl(MethodImplOptions.NoInlining), DoesNotReturn] + // ReSharper disable once NotResolvedInText + private static void ThrowNull() => + // ReSharper disable once NotResolvedInText + throw new ArgumentNullException("value", "Null values cannot be sent from client to server"); + + internal void WriteBulkStringUnoptimized(string? value) + { + if (value is null) ThrowNull(); + if (value.Length == 0) + { + WriteRaw("$0\r\n\r\n"u8); + } + else + { + var byteCount = RespConstants.UTF8.GetByteCount(value); + WritePrefixInteger(RespPrefix.BulkString, byteCount); + if (Available >= byteCount) + { + var actual = RespConstants.UTF8.GetBytes(value.AsSpan(), Tail); + Debug.Assert(actual == byteCount); + _index += actual; + } + else + { + WriteUtf8Slow(value.AsSpan(), byteCount); + } + + WriteCrLf(); + } + } + + /// + /// Write a payload as a bulk string. + /// + /// The payload to write. + public void WriteBulkString(ReadOnlyMemory value) => WriteBulkString(value.Span); + + /// + /// Write a payload as a bulk string. + /// + /// The payload to write. + public void WriteBulkString(scoped ReadOnlySpan value) + { + if (value.Length == 0) + { + if (Available >= 6) + { + WriteRawPrechecked(Raw.BulkStringEmpty_6, 6); + } + else + { + WriteRaw("$0\r\n\r\n"u8); + } + } + else + { + var byteCount = RespConstants.UTF8.GetByteCount(value); + WriteBulkStringHeader(byteCount); + if (Available >= 2 + byteCount) + { + var actual = RespConstants.UTF8.GetBytes(value, Tail); + Debug.Assert(actual == byteCount); + _index += actual; + WriteCrLfUnsafe(); + } + else + { + FlushAndGetBuffer(Math.Min(byteCount, MAX_BUFFER_HINT)); + if (Available >= byteCount + 2) + { + // that'll work + var actual = RespConstants.UTF8.GetBytes(value, Tail); + Debug.Assert(actual == byteCount); + _index += actual; + WriteCrLfUnsafe(); + } + else + { + WriteUtf8Slow(value, byteCount); + WriteCrLf(); + } + } + } + } + + private const int MAX_BUFFER_HINT = 64 * 1024; + + private void WriteUtf8Slow(scoped ReadOnlySpan value, int remaining) + { + var enc = _perThreadEncoder; + if (enc is null) + { + enc = _perThreadEncoder = RespConstants.UTF8.GetEncoder(); + } + else + { + enc.Reset(); + } + + bool completed; + int charsUsed, bytesUsed; + do + { + enc.Convert(value, Tail, false, out charsUsed, out bytesUsed, out completed); + value = value.Slice(charsUsed); + _index += bytesUsed; + remaining -= bytesUsed; + FlushAndGetBuffer(Math.Min(remaining, MAX_BUFFER_HINT)); + } + // until done... + while (!completed); + + if (remaining != 0) + { + // any trailing data? + FlushAndGetBuffer(Math.Min(remaining, MAX_BUFFER_HINT)); + enc.Convert(value, Tail, true, out charsUsed, out bytesUsed, out completed); + Debug.Assert(charsUsed == 0 && completed); + _index += bytesUsed; + // ReSharper disable once RedundantAssignment - it is in debug! + remaining -= bytesUsed; + } + + enc.Reset(); + Debug.Assert(remaining == 0); + } + + internal void WriteBulkString(in ReadOnlySequence value) + { + if (value.IsSingleSegment) + { +#if NETCOREAPP3_0_OR_GREATER + WriteBulkString(value.FirstSpan); +#else + WriteBulkString(value.First.Span); +#endif + } + else + { + // lazy for now + int len = checked((int)value.Length); + byte[] buffer = ArrayPool.Shared.Rent(len); + value.CopyTo(buffer); + WriteBulkString(new ReadOnlySpan(buffer, 0, len)); + ArrayPool.Shared.Return(buffer); + } + } + + internal void WriteBulkString(in ReadOnlySequence value) + { + if (value.IsSingleSegment) + { +#if NETCOREAPP3_0_OR_GREATER + WriteBulkString(value.FirstSpan); +#else + WriteBulkString(value.First.Span); +#endif + } + else + { + // lazy for now + int len = checked((int)value.Length); + char[] buffer = ArrayPool.Shared.Rent(len); + value.CopyTo(buffer); + WriteBulkString(new ReadOnlySpan(buffer, 0, len)); + ArrayPool.Shared.Return(buffer); + } + } + + /// + /// Experimental. + /// + public void WriteBulkString(int value) + { + if (Available >= sizeof(ulong)) + { + switch (value) + { + case -1: + WriteRawPrechecked(Raw.BulkStringInt32_M1_8, 8); + return; + case 0: + WriteRawPrechecked(Raw.BulkStringInt32_0_7, 7); + return; + case 1: + WriteRawPrechecked(Raw.BulkStringInt32_1_7, 7); + return; + case 2: + WriteRawPrechecked(Raw.BulkStringInt32_2_7, 7); + return; + case 3: + WriteRawPrechecked(Raw.BulkStringInt32_3_7, 7); + return; + case 4: + WriteRawPrechecked(Raw.BulkStringInt32_4_7, 7); + return; + case 5: + WriteRawPrechecked(Raw.BulkStringInt32_5_7, 7); + return; + case 6: + WriteRawPrechecked(Raw.BulkStringInt32_6_7, 7); + return; + case 7: + WriteRawPrechecked(Raw.BulkStringInt32_7_7, 7); + return; + case 8: + WriteRawPrechecked(Raw.BulkStringInt32_8_7, 7); + return; + case 9: + WriteRawPrechecked(Raw.BulkStringInt32_9_7, 7); + return; + case 10: + WriteRawPrechecked(Raw.BulkStringInt32_10_8, 8); + return; + } + } + + WriteBulkStringUnoptimized(value); + } + + internal void WriteBulkStringUnoptimized(int value) + { + if (Available >= RespConstants.MaxProtocolBytesBulkStringIntegerInt32) + { + var singleDigit = value >= -99_999_999 && value <= 999_999_999; + WriteRawUnsafe((byte)RespPrefix.BulkString); + + var target = Tail.Slice(singleDigit ? 3 : 4); // N\r\n or NN\r\n + if (!Utf8Formatter.TryFormat(value, target, out var valueBytes)) + ThrowFormatException(); + + Debug.Assert(valueBytes > 0 && singleDigit ? valueBytes < 10 : valueBytes is 10 or 11); + if (!Utf8Formatter.TryFormat(valueBytes, Tail, out var prefixBytes)) + ThrowFormatException(); + Debug.Assert(prefixBytes == (singleDigit ? 1 : 2)); + _index += prefixBytes; + WriteCrLfUnsafe(); + _index += valueBytes; + WriteCrLfUnsafe(); + } + else + { + Debug.Assert(RespConstants.MaxRawBytesInt32 <= 16); + Span scratch = stackalloc byte[16]; + if (!Utf8Formatter.TryFormat(value, scratch, out int bytes)) + ThrowFormatException(); + WritePrefixInteger(RespPrefix.BulkString, bytes); + WriteRaw(scratch.Slice(0, bytes)); + WriteCrLf(); + } + } + + /// + /// Write an array header. + /// + /// The number of elements in the array. + public void WriteArray(int count) + { + if (Available >= sizeof(uint)) + { + switch (count) + { + case 0: + WriteRawPrechecked(Raw.ArrayPrefix_0_4, 4); + return; + case 1: + WriteRawPrechecked(Raw.ArrayPrefix_1_4, 4); + return; + case 2: + WriteRawPrechecked(Raw.ArrayPrefix_2_4, 4); + return; + case 3: + WriteRawPrechecked(Raw.ArrayPrefix_3_4, 4); + return; + case 4: + WriteRawPrechecked(Raw.ArrayPrefix_4_4, 4); + return; + case 5: + WriteRawPrechecked(Raw.ArrayPrefix_5_4, 4); + return; + case 6: + WriteRawPrechecked(Raw.ArrayPrefix_6_4, 4); + return; + case 7: + WriteRawPrechecked(Raw.ArrayPrefix_7_4, 4); + return; + case 8: + WriteRawPrechecked(Raw.ArrayPrefix_8_4, 4); + return; + case 9: + WriteRawPrechecked(Raw.ArrayPrefix_9_4, 4); + return; + case 10 when Available >= sizeof(ulong): + WriteRawPrechecked(Raw.ArrayPrefix_10_5, 5); + return; + case -1: + WriteRawPrechecked(Raw.ArrayPrefix_M1_5, 5); + return; + } + } + + WritePrefixInteger(RespPrefix.Array, count); + } + + private void WriteBulkStringHeader(int count) + { + if (Available >= sizeof(uint)) + { + switch (count) + { + case 0: + WriteRawPrechecked(Raw.BulkStringPrefix_0_4, 4); + return; + case 1: + WriteRawPrechecked(Raw.BulkStringPrefix_1_4, 4); + return; + case 2: + WriteRawPrechecked(Raw.BulkStringPrefix_2_4, 4); + return; + case 3: + WriteRawPrechecked(Raw.BulkStringPrefix_3_4, 4); + return; + case 4: + WriteRawPrechecked(Raw.BulkStringPrefix_4_4, 4); + return; + case 5: + WriteRawPrechecked(Raw.BulkStringPrefix_5_4, 4); + return; + case 6: + WriteRawPrechecked(Raw.BulkStringPrefix_6_4, 4); + return; + case 7: + WriteRawPrechecked(Raw.BulkStringPrefix_7_4, 4); + return; + case 8: + WriteRawPrechecked(Raw.BulkStringPrefix_8_4, 4); + return; + case 9: + WriteRawPrechecked(Raw.BulkStringPrefix_9_4, 4); + return; + case 10 when Available >= sizeof(ulong): + WriteRawPrechecked(Raw.BulkStringPrefix_10_5, 5); + return; + case -1 when Available >= sizeof(ulong): + WriteRawPrechecked(Raw.BulkStringPrefix_M1_5, 5); + return; + } + } + + WritePrefixInteger(RespPrefix.BulkString, count); + } + + internal void WriteArrayUnpotimized(int count) => WritePrefixInteger(RespPrefix.Array, count); + + private void WriteRawPrechecked(ulong value, int count) + { + Debug.Assert(Available >= sizeof(ulong)); + Debug.Assert(count >= 0 && count <= sizeof(long)); + Unsafe.WriteUnaligned(ref WriteHead, value); + _index += count; + } + + private void WriteRawPrechecked(uint value, int count) + { + Debug.Assert(Available >= sizeof(uint)); + Debug.Assert(count >= 0 && count <= sizeof(uint)); + Unsafe.WriteUnaligned(ref WriteHead, value); + _index += count; + } + + internal void DebugResetIndex() => _index = 0; + + [ThreadStatic] + // used for multi-chunk encoding + private static Encoder? _perThreadEncoder; +} diff --git a/src/RESPite/RESPite.csproj b/src/RESPite/RESPite.csproj new file mode 100644 index 000000000..ed46defc7 --- /dev/null +++ b/src/RESPite/RESPite.csproj @@ -0,0 +1,91 @@ + + + + true + net461;netstandard2.0;net472;net6.0;net8.0;net9.0 + enable + enable + false + 2025 - $([System.DateTime]::Now.Year) Marc Gravell + readme.md + + + + + + + + + + + + + + + + + RespOperation.cs + + + RespReader.cs + + + RespReader.cs + + + RespReader.cs + + + RespReader.cs + + + RespReader.cs + + + BlockBufferSerializer.cs + + + BlockBufferSerializer.cs + + + BlockBufferSerializer.cs + + + RespMessageBase.cs + + + RespMessageBase.cs + + + RespMessageBase.cs + + + BufferingBatchConnection.cs + + + BufferingBatchConnection.cs + + + RespMessageBase.cs + + + DecoratorConnection.cs + + + DecoratorConnection.cs + + + + + + FrameworkShims.cs + + + NullableHacks.cs + + + SkipLocalsInit.cs + + + + diff --git a/src/RESPite/RespBatch.cs b/src/RESPite/RespBatch.cs new file mode 100644 index 000000000..1ca8c8183 --- /dev/null +++ b/src/RESPite/RespBatch.cs @@ -0,0 +1,34 @@ +namespace RESPite; + +public abstract class RespBatch : RespConnection +{ + // a batch doesn't act as a proxy to the tail, so we don't need to DecoratorConnection logic + protected readonly RespConnection Tail; + private protected RespBatch(in RespContext tail) : base(tail) + { + Tail = tail.Connection; + // ack: yes, I know we won't spot every recursive+decorated scenario + if (Tail is RespBatch) ThrowNestedBatch(); + + static void ThrowNestedBatch() => + throw new ArgumentException("Nested batches are not supported", nameof(tail)); + } + + public abstract Task FlushAsync(); + public abstract void Flush(); + + internal override void ThrowIfUnhealthy() + { + Tail.ThrowIfUnhealthy(); + base.ThrowIfUnhealthy(); + } + + internal override bool IsHealthy => base.IsHealthy & Tail.IsHealthy; + + /// + /// Suggests that the batch should ensure it has enough capacity for the given number of additional operations. + /// Note that this contrasts with , where the number provided + /// is the total number of elements. + /// + public virtual void EnsureCapacity(int additionalCount) { } +} diff --git a/src/RESPite/RespCommandAttribute.cs b/src/RESPite/RespCommandAttribute.cs new file mode 100644 index 000000000..e3f0c5239 --- /dev/null +++ b/src/RESPite/RespCommandAttribute.cs @@ -0,0 +1,34 @@ +using System.Diagnostics; + +namespace RESPite; + +[AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = false)] +[Conditional("DEBUG")] +public sealed class RespCommandAttribute(string? command = null) : Attribute +{ + public string? Command => command; + public string? Formatter { get; set; } + public string? Parser { get; set; } + + public static class Parsers + { + private const string Prefix = "global::RESPite.RespParsers."; + + public const string Summary = Prefix + nameof(RespParsers.ResponseSummary) + + "." + nameof(RespParsers.ResponseSummary.Parser); + public const string ByteArray = Prefix + nameof(RespParsers.ByteArray); + public const string String = Prefix + nameof(RespParsers.String); + public const string Int32 = Prefix + nameof(RespParsers.Int32); + public const string Int64 = Prefix + nameof(RespParsers.Int64); + public const string NullableInt64 = Prefix + nameof(RespParsers.NullableInt64); + public const string NullableInt32 = Prefix + nameof(RespParsers.NullableInt32); + public const string NullableSingle = Prefix + nameof(RespParsers.NullableSingle); + public const string BufferWriter = Prefix + nameof(RespParsers.BufferWriter); + public const string ByteArrayArray = Prefix + nameof(RespParsers.ByteArrayArray); + public const string OK = Prefix + nameof(RespParsers.OK); + public const string Single = Prefix + nameof(RespParsers.Single); + public const string Double = Prefix + nameof(RespParsers.Double); + public const string Success = Prefix + nameof(RespParsers.Success); + public const string NullableDouble = Prefix + nameof(RespParsers.NullableDouble); + } +} diff --git a/src/RESPite/RespCommandMap.cs b/src/RESPite/RespCommandMap.cs new file mode 100644 index 000000000..3d5f28cb1 --- /dev/null +++ b/src/RESPite/RespCommandMap.cs @@ -0,0 +1,25 @@ +namespace RESPite; + +public abstract class RespCommandMap +{ + /// + /// Apply any remapping to the command. + /// + /// The command requested. + /// The remapped command; this can be the original command, a remapped command, or an empty instance if the command is not available. + public abstract ReadOnlySpan Map(ReadOnlySpan command); + + /// + /// Indicates whether the specified command is available. + /// + public virtual bool IsAvailable(ReadOnlySpan command) + => Map(command).Length != 0; + + public static RespCommandMap Default { get; } = new DefaultRespCommandMap(); + + private sealed class DefaultRespCommandMap : RespCommandMap + { + public override ReadOnlySpan Map(ReadOnlySpan command) => command; + public override bool IsAvailable(ReadOnlySpan command) => true; + } +} diff --git a/src/RESPite/RespConfiguration.cs b/src/RESPite/RespConfiguration.cs new file mode 100644 index 000000000..f0d079e9f --- /dev/null +++ b/src/RESPite/RespConfiguration.cs @@ -0,0 +1,95 @@ +using System.Text; + +namespace RESPite; + +/// +/// Over-arching configuration for a RESP system. +/// +public class RespConfiguration +{ + private static readonly TimeSpan DefaultSyncTimeout = TimeSpan.FromSeconds(10); + + public static RespConfiguration Default { get; } = new( + RespCommandMap.Default, [], DefaultSyncTimeout, NullServiceProvider.Instance, 0); + + public static Builder CreateBuilder() => default; // for discoverability + + public struct Builder // intentionally mutable + { + public TimeSpan? SyncTimeout { get; set; } + public IServiceProvider? ServiceProvider { get; set; } + public RespCommandMap? CommandMap { get; set; } + public int DefaultDatabase { get; set; } + public object? KeyPrefix { get; set; } // can be a string or byte[] + + public Builder(RespConfiguration? source) + { + if (source is not null) + { + CommandMap = source.CommandMap; + SyncTimeout = source.SyncTimeout; + KeyPrefix = source.KeyPrefix.ToArray(); + ServiceProvider = source.ServiceProvider; + DefaultDatabase = source.DefaultDatabase; + // undo defaults + if (ReferenceEquals(CommandMap, RespCommandMap.Default)) CommandMap = null; + if (ReferenceEquals(ServiceProvider, NullServiceProvider.Instance)) ServiceProvider = null; + } + } + + public RespConfiguration CreateConfiguration() + { + byte[] prefix = KeyPrefix switch + { + null => [], + string { Length: 0 } => [], + string s => Encoding.UTF8.GetBytes(s), + byte[] { Length: 0 } => [], + byte[] b => b.AsSpan().ToArray(), // create isolated copy for mutability reasons + _ => throw new ArgumentException($"{nameof(KeyPrefix)} must be a string or byte[]", nameof(KeyPrefix)), + }; + + if (prefix.Length == 0 & SyncTimeout is null & CommandMap is null & ServiceProvider is null) return Default; + + return new( + CommandMap ?? RespCommandMap.Default, + prefix, + SyncTimeout ?? DefaultSyncTimeout, + ServiceProvider ?? NullServiceProvider.Instance, + DefaultDatabase); + } + } + + private RespConfiguration( + RespCommandMap commandMap, + byte[] keyPrefix, + TimeSpan syncTimeout, + IServiceProvider serviceProvider, + int defaultDatabase) + { + CommandMap = commandMap; + SyncTimeout = syncTimeout; + _keyPrefix = (byte[])keyPrefix.Clone(); // create isolated copy + ServiceProvider = serviceProvider; + DefaultDatabase = defaultDatabase; + } + + private readonly byte[] _keyPrefix; + public IServiceProvider ServiceProvider { get; } + public RespCommandMap CommandMap { get; } + public TimeSpan SyncTimeout { get; } + public ReadOnlySpan KeyPrefix => _keyPrefix; + public int DefaultDatabase { get; } + + public Builder AsBuilder() => new(this); + + private sealed class NullServiceProvider : IServiceProvider + { + public static readonly NullServiceProvider Instance = new(); + private NullServiceProvider() { } + public object? GetService(Type serviceType) => null; + } + + internal T? GetService() where T : class + => ServiceProvider.GetService(typeof(T)) as T; +} diff --git a/src/RESPite/RespConnection.cs b/src/RESPite/RespConnection.cs new file mode 100644 index 000000000..d9f056fce --- /dev/null +++ b/src/RESPite/RespConnection.cs @@ -0,0 +1,204 @@ +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Net; +using System.Net.Sockets; +using System.Runtime.CompilerServices; +using RESPite.Connections; +using RESPite.Connections.Internal; +using RESPite.Internal; + +namespace RESPite; + +public abstract class RespConnection : IDisposable, IAsyncDisposable, IRespContextSource +{ + public sealed class RespConnectionErrorEventArgs(Exception exception, [CallerMemberName] string operation = "") + : EventArgs + { + public Exception Exception { get; } = exception; + public string Operation { get; } = operation; + } + + private bool _isDisposed; + internal bool IsDisposed => _isDisposed; + + private readonly RespContext _context; + public ref readonly RespContext Context => ref _context; + public RespConfiguration Configuration { get; } + public abstract event EventHandler? ConnectionError; + + private protected static void OnConnectionError( + EventHandler? handler, + Exception exception, + [CallerMemberName] string operation = "") + { + handler?.Invoke(null, new(exception, operation)); + } + + internal virtual bool IsHealthy => !_isDisposed; + + internal virtual BlockBufferSerializer Serializer => BlockBufferSerializer.Shared; + + internal abstract int OutstandingOperations { get; } + internal readonly RespCommandMap? NonDefaultCommandMap; // prevent checking this each write + public TimeSpan SyncTimeout { get; } + + public static RespConnection Create(Stream stream, RespConfiguration? configuration = null) + => new StreamConnection(configuration ?? RespConfiguration.Default, stream); + + // this is the usual usage, since we want context to be preserved + private protected RespConnection(in RespContext tail, RespConfiguration? configuration = null) + { + var conn = tail.Connection; + if (conn is not { IsHealthy: true }) + { + // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract + if (conn is null) ThrowNullTail(); // trust no-one + else conn.ThrowIfUnhealthy(); + } + + Configuration = configuration ?? conn.Configuration; + _context = tail.WithConnection(this); + + // hoist and pre-check the command map once per connection + var commandMap = Configuration.CommandMap; + NonDefaultCommandMap = ReferenceEquals(commandMap, RespCommandMap.Default) ? null : commandMap; + SyncTimeout = Configuration.SyncTimeout; // snapshot to reduce indirection + + [DoesNotReturn] + static void ThrowNullTail() => + throw new ArgumentException("No tail connection provided.", nameof(tail)); + } + + internal virtual void ThrowIfUnhealthy() + { + if (_isDisposed) ThrowDisposed(); + } + + // this is atypical - only for use when creating null connections + private protected RespConnection(RespConfiguration? configuration = null) + { + Configuration = configuration ?? RespConfiguration.Default; + _context = default; + _context = _context.WithConnection(this); + Debug.Assert(this is NullConnection); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + protected void ThrowIfDisposed() + { + if (_isDisposed) ThrowDisposed(); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void ThrowDisposed() => throw CreateObjectDisposedException(); + + internal Exception CreateObjectDisposedException() => new ObjectDisposedException(GetType().Name); + + public void Dispose() + { + _isDisposed = this is not NullConnection; + OnDispose(true); + } + + protected virtual void OnDispose(bool disposing) + { + } + + public ValueTask DisposeAsync() + { + _isDisposed = this is not NullConnection; + return OnDisposeAsync(); + } + + protected virtual ValueTask OnDisposeAsync() + { + OnDispose(true); + return default; + } + + public abstract void Write(in RespOperation message); + + internal virtual void Write(ReadOnlySpan messages) + { + int i = 0; + try + { + for (i = 0; i < messages.Length; i++) + { + Write(messages[i]); + } + } + catch (Exception ex) + { + MarkFaulted(messages.Slice(i), ex); + throw; + } + } + + public virtual Task WriteAsync(in RespOperation message) + { + Write(message); + return Task.CompletedTask; + } + + internal virtual Task WriteAsync(ReadOnlyMemory messages) + { + switch (messages.Length) + { + case 0: return Task.CompletedTask; + case 1: return WriteAsync(messages.Span[0]); + } + + int i = 0; + try + { + for (; i < messages.Length; i++) + { + var pending = WriteAsync(messages.Span[i]); + if (!pending.IsCompleted) + return Awaited(this, pending, messages.Slice(i)); + pending.GetAwaiter().GetResult(); + } + } + catch (Exception ex) + { + MarkFaulted(messages.Span.Slice(i), ex); + throw; + } + + return Task.CompletedTask; + + static async Task Awaited(RespConnection connection, Task pending, ReadOnlyMemory messages) + { + int i = 0; + try + { + await pending.ConfigureAwait(false); + for (i = 1; i < messages.Length; i++) + { + await connection.WriteAsync(messages.Span[i]).ConfigureAwait(false); + } + } + catch (Exception ex) + { + MarkFaulted(messages.Span.Slice(i), ex); + throw; + } + } + } + + protected static void MarkFaulted(ReadOnlySpan messages, Exception fault) + { + foreach (var message in messages) + { + try + { + message.Message.TrySetException(message.Token, fault); + } + catch + { + // best efforts + } + } + } +} diff --git a/src/RESPite/RespContext.cs b/src/RESPite/RespContext.cs new file mode 100644 index 000000000..812cf4cd2 --- /dev/null +++ b/src/RESPite/RespContext.cs @@ -0,0 +1,245 @@ +#define MULTI_BATCH // use combining batches, rather than simple batches + +using System.Runtime.CompilerServices; +using RESPite.Connections.Internal; + +namespace RESPite; + +/// +/// Transient state for a RESP operation. +/// +public readonly struct RespContext +{ + public static ref readonly RespContext Null => ref NullConnection.Default.Context; + + private readonly RespConnection _connection; + public readonly CancellationToken CancellationToken; + private readonly int _database; + private readonly RespContextFlags _flags; + + public RespContextFlags Flags => _flags; + + [Flags] + public enum RespContextFlags + { + /// + /// No additional flags; this is the default. Operations will prefer primary nodes if available. + /// + None = 0, + + /// + /// The equivalent of with `false`. + /// + DisableCaptureContext = 1, + + // IMPORTANT: the following align with CommandFlags, to avoid needing any additional mapping. + + /// + /// The caller is not interested in the result; the caller will immediately receive a default-value + /// of the expected return type (this value is not indicative of anything at the server). + /// + FireAndForget = 2, + + /// + /// This operation should only be performed on the primary. + /// + DemandPrimary = 4, + + /// + /// This operation should be performed on the replica if it is available, but will be performed on + /// a primary if no replicas are available. Suitable for read operations only. + /// + PreferReplica = 8, // note: we're using a 2-bit set here, which [Flags] formatting hates + + /// + /// This operation should only be performed on a replica. Suitable for read operations only. + /// + DemandReplica = 12, // note: we're using a 2-bit set here, which [Flags] formatting hates + + /// + /// Indicates that this operation should not be forwarded to other servers as a result of an ASK or MOVED response. + /// + NoRedirect = 64, + + /// + /// Indicates that script-related operations should use EVAL, not SCRIPT LOAD + EVALSHA. + /// + NoScriptCache = 512, + } + + /// + public override string ToString() => _connection?.ToString() ?? "(null)"; + + public RespConnection Connection => _connection; + public int Database => _database; + + public RespCommandMap CommandMap => _connection.NonDefaultCommandMap ?? RespCommandMap.Default; + public TimeSpan SyncTimeout => _connection.SyncTimeout; + + /// + /// REPLACES the associated with this context. + /// + public RespContext WithCancellationToken(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + RespContext clone = this; + Unsafe.AsRef(in clone.CancellationToken) = cancellationToken; + return clone; + } + + /// + /// COMBINES the associated with this context + /// with an additional cancellation. The returned + /// represents the lifetime of the combined operation, and should be + /// disposed when complete. + /// + public Lifetime WithCombineCancellationToken(CancellationToken cancellationToken) + { + if (!cancellationToken.CanBeCanceled + || cancellationToken == CancellationToken) + { + // would have no effect + CancellationToken.ThrowIfCancellationRequested(); + return new(in this, null); + } + + cancellationToken.ThrowIfCancellationRequested(); + if (!CancellationToken.CanBeCanceled) + { + // we don't currently have cancellation; no need for a link + return new(in this, null, cancellationToken); + } + + CancellationToken.ThrowIfCancellationRequested(); + var src = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, CancellationToken); + return new(in this, src, src.Token); + } + + public Lifetime WithCombine(IDisposable lifetime) + => new(in this, lifetime); + + public Lifetime WithCombineTimeout(TimeSpan timeout) + { + if (timeout <= TimeSpan.Zero) Throw(); + CancellationTokenSource src; + if (CancellationToken.CanBeCanceled) + { + CancellationToken.ThrowIfCancellationRequested(); + src = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + src.CancelAfter(timeout); + } + else + { + src = new CancellationTokenSource(timeout); + } + + static void Throw() => throw new ArgumentOutOfRangeException(nameof(timeout)); + + return new Lifetime(in this, src, src.Token); + } + + public readonly struct Lifetime : IDisposable + { + // Unusual public field; a ref-readonly would be preferable, but by-ref props have restrictions on structs. + // We would rather avoid the copy semantics associated with a regular property getter. + public readonly RespContext Context; + + private readonly IDisposable? _source; + + internal Lifetime(in RespContext context, IDisposable? source) + { + Context = context; + _source = source; + } + + internal Lifetime(in RespContext context, IDisposable? source, CancellationToken cancellationToken) + { + Context = context; // snapshot, we can now mutate this locally + _source = source; + Unsafe.AsRef(in Context.CancellationToken) = cancellationToken; + } + + public void Dispose() + { + var src = _source; + // best effort cleanup, noting that copies may exist + // (which is also why we can't risk TryReset+pool) + Unsafe.AsRef(in _source) = null; + Unsafe.AsRef(in Context.CancellationToken) = AlreadyCanceled; + src?.Dispose(); // don't cancel on EOL; want consistent behaviour with/without link + } + + private static readonly CancellationToken AlreadyCanceled = CreateCancelledToken(); + + private static CancellationToken CreateCancelledToken() + { + CancellationTokenSource cts = new(); + cts.Cancel(); + return cts.Token; + } + } + + public RespContext WithDatabase(int database) + { + RespContext clone = this; + Unsafe.AsRef(in clone._database) = database; + return clone; + } + + public RespContext WithConnection(RespConnection connection) + { + RespContext clone = this; + Unsafe.AsRef(in clone._connection) = connection; + return clone; + } + + public RespContext ConfigureAwait(bool continueOnCapturedContext) + { + RespContext clone = this; + Unsafe.AsRef(in clone._flags) = continueOnCapturedContext + ? _flags & ~RespContextFlags.DisableCaptureContext + : _flags | RespContextFlags.DisableCaptureContext; + return clone; + } + + /// + /// Replaces the associated with this context. + /// + public RespContext WithFlags(RespContextFlags flags) + { + RespContext clone = this; + Unsafe.AsRef(in clone._flags) = flags; + return clone; + } + + /// + /// Replaces the and associated with this context. + /// + public RespContext With(int database, RespContextFlags flags) + { + RespContext clone = this; + Unsafe.AsRef(in clone._database) = database; + Unsafe.AsRef(in clone._flags) = flags; + return clone; + } + + /// + /// Replaces the and associated with this context, + /// using a mask to determine which flags to replace. Passing + /// for will replace no flags. + /// + public RespContext With(int database, RespContextFlags flags, RespContextFlags mask) + { + RespContext clone = this; + Unsafe.AsRef(in clone._database) = database; + Unsafe.AsRef(in clone._flags) = (flags & ~mask) | (_flags & mask); + return clone; + } + + public RespBatch CreateBatch(int sizeHint = 0) +#if MULTI_BATCH + => new MergingBatchConnection(in this, sizeHint); +#else + => new BasicBatchConnection(in this, sizeHint); +#endif +} diff --git a/src/RESPite/RespContextExtensions.cs b/src/RESPite/RespContextExtensions.cs new file mode 100644 index 000000000..6272a39a1 --- /dev/null +++ b/src/RESPite/RespContextExtensions.cs @@ -0,0 +1,288 @@ +using System.Buffers; +using RESPite.Internal; +using RESPite.Messages; + +namespace RESPite; + +public static class RespContextExtensions +{ + public static RespOperationBuilder Command( + this in RespContext context, + ReadOnlySpan command, + TRequest request, + IRespFormatter formatter) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + => new(in context, command, request, formatter); + + /* not sure that default formatters (RespFormatters.Get) make sense + public static RespOperationBuilder Command( + this in RespContext context, + ReadOnlySpan command, + in TRequest value) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + => new(in context, command, value, RespFormatters.Get()); + */ + + public static RespOperationBuilder Command(this in RespContext context, ReadOnlySpan command) + => new(in context, command, false, RespFormatters.Empty); + + public static RespOperationBuilder Command( + this in RespContext context, + ReadOnlySpan command, + string value, + bool isKey) + => new(in context, command, value, RespFormatters.String(isKey)); + + public static RespOperationBuilder Command( + this in RespContext context, + ReadOnlySpan command, + byte[] value, + bool isKey) + => new(in context, command, value, RespFormatters.ByteArray(isKey)); + + /// + /// Creates an operation and synchronously writes it to the connection. + /// + /// The type of the request data being sent. + public static RespOperation Send( + this in RespContext context, + ReadOnlySpan command, + in TRequest request, + IRespFormatter formatter, + IRespParser parser) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + var op = CreateOperation(context, command, request, formatter, parser); + context.Connection.Write(op); + return op; + } + + /// + /// Creates an operation and synchronously writes it to the connection. + /// + /// The type of the request data being sent. + /// The type of the response data being received. + public static RespOperation Send( + this in RespContext context, + ReadOnlySpan command, + in TRequest request, + IRespFormatter formatter, + IRespParser parser) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + var op = CreateOperation(context, command, request, formatter, parser); + context.Connection.Write(op); + return op; + } + + /// + /// Creates an operation and synchronously writes it to the connection. + /// + /// The type of the response data being received. + public static RespOperation Send( + this in RespContext context, + ReadOnlySpan command, + IRespParser parser) + { + var op = CreateOperation(context, command, false, RespFormatters.Empty, parser); + context.Connection.Write(op); + return op; + } + + /// + /// Creates an operation and synchronously writes it to the connection. + /// + /// The type of the request data being sent. + /// The type of state data required by the parser. + /// The type of the response data being received. + public static RespOperation Send( + this in RespContext context, + ReadOnlySpan command, + in TRequest request, + IRespFormatter formatter, + in TState state, + IRespParser parser) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + var op = CreateOperation(context, command, request, formatter, in state, parser); + context.Connection.Write(op); + return op; + } + + /// + /// Creates an operation and synchronously writes it to the connection. + /// + /// The type of state data required by the parser. + /// The type of the response data being received. + public static RespOperation Send( + this in RespContext context, + ReadOnlySpan command, + in TState state, + IRespParser parser) + { + var op = CreateOperation(context, command, false, RespFormatters.Empty, in state, parser); + context.Connection.Write(op); + return op; + } + + /// + /// Creates an operation and synchronously writes it to the connection. + /// + /// The type of state data required by the parser. + /// The type of the response data being received. + /// The raw payload is the entire RESP fragment, only used if there is not a command-map. + internal static RespOperation Send( + this in RespContext context, + ReadOnlySpan command, + in TState state, + IRespParser parser, + byte[] rawPayload) + { + var op = CreateOperation(context, command, rawPayload, RespFormatters.Raw, in state, parser); + context.Connection.Write(op); + return op; + } + + /// + /// Creates an operation and asynchronously writes it to the connection, awaiting the completion of the underlying write. + /// + /// The type of the request data being sent. + public static ValueTask SendAsync( + this in RespContext context, + ReadOnlySpan command, + in TRequest request, + IRespFormatter formatter, + IRespParser parser) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + var op = CreateOperation(context, command, request, formatter, parser); + var write = context.Connection.WriteAsync(op); + if (!write.IsCompleted) return AwaitedVoid(op, write); + write.GetAwaiter().GetResult(); + return new(op); + + static async ValueTask AwaitedVoid(RespOperation op, Task write) + { + await write.ConfigureAwait(false); + return op; + } + } + + /// + /// Creates an operation and asynchronously writes it to the connection, awaiting the completion of the underlying write. + /// + /// The type of the request data being sent. + /// The type of the response data being received. + public static ValueTask> SendAsync( + this in RespContext context, + ReadOnlySpan command, + in TRequest request, + IRespFormatter formatter, + IRespParser parser) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + var op = CreateOperation(context, command, request, formatter, parser); + var write = context.Connection.WriteAsync(op); + if (!write.IsCompleted) return Awaited(op, write); + write.GetAwaiter().GetResult(); + return new(op); + } + + /// + /// Creates an operation and asynchronously writes it to the connection, awaiting the completion of the underlying write. + /// + /// The type of the request data being sent. + /// The type of state data required by the parser. + /// The type of the response data being received. + public static ValueTask> SendAsync( + this in RespContext context, + ReadOnlySpan command, + in TRequest request, + IRespFormatter formatter, + in TState state, + IRespParser parser) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + var op = CreateOperation(context, command, request, formatter, in state, parser); + var write = context.Connection.WriteAsync(op); + if (!write.IsCompleted) return Awaited(op, write); + write.GetAwaiter().GetResult(); + return new(op); + } + + private static async ValueTask> Awaited(RespOperation op, Task write) + { + await write.ConfigureAwait(false); + return op; + } + + public static RespOperation CreateOperation( + in RespContext context, // deliberately not "this" + ReadOnlySpan command, + in TRequest request, + IRespFormatter formatter, + IRespParser parser) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + var conn = context.Connection; + var memory = + conn.Serializer.Serialize(conn.NonDefaultCommandMap, command, request, formatter); + var msg = RespStatelessMessage.Get(parser); + msg.Init(memory, context.CancellationToken); + return new(msg); + } + + public static RespOperation CreateOperation( + in RespContext context, // deliberately not "this" + ReadOnlySpan command, + in TRequest request, + IRespFormatter formatter, + IRespParser parser) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + var conn = context.Connection; + var memory = + conn.Serializer.Serialize(conn.NonDefaultCommandMap, command, request, formatter); + var msg = RespStatelessMessage.Get(parser); + msg.Init(memory, context.CancellationToken); + return new(msg); + } + + public static RespOperation CreateOperation( + in RespContext context, // deliberately not "this" + ReadOnlySpan command, + in TRequest request, + IRespFormatter formatter, + in TState state, + IRespParser parser) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif + { + var conn = context.Connection; + var memory = conn.Serializer.Serialize(conn.NonDefaultCommandMap, command, request, formatter); + var msg = RespStatefulMessage.Get(in state, parser); + msg.Init(memory, context.CancellationToken); + return new(msg); + } +} diff --git a/src/RESPite/RespException.cs b/src/RESPite/RespException.cs new file mode 100644 index 000000000..86a344577 --- /dev/null +++ b/src/RESPite/RespException.cs @@ -0,0 +1,8 @@ +namespace RESPite; + +/// +/// Represents a RESP error message. +/// +public sealed class RespException(string message) : Exception(message) +{ +} diff --git a/src/RESPite/RespFormatters.cs b/src/RESPite/RespFormatters.cs new file mode 100644 index 000000000..979891cfc --- /dev/null +++ b/src/RESPite/RespFormatters.cs @@ -0,0 +1,161 @@ +using RESPite.Messages; + +namespace RESPite; + +public static class RespFormatters +{ + public static IRespFormatter String(bool isKey) => isKey ? Key.String : Value.String; + public static IRespFormatter> Chars(bool isKey) => isKey ? Key.Chars : Value.Chars; + public static IRespFormatter ByteArray(bool isKey) => isKey ? Key.ByteArray : Value.ByteArray; + public static IRespFormatter> Bytes(bool isKey) => isKey ? Key.Bytes : Value.Bytes; + public static IRespFormatter Empty => EmptyFormatter.Instance; + public static IRespFormatter Int32 => Value.Formatter.Default; + public static IRespFormatter Int64 => Value.Formatter.Default; + public static IRespFormatter Single => Value.Formatter.Default; + public static IRespFormatter Double => Value.Formatter.Default; + internal static IRespFormatter Raw => RawFormatter.Instance; + + public static class Key + { + // ReSharper disable MemberHidesStaticFromOuterClass + public static IRespFormatter String => Formatter.Default; + public static IRespFormatter> Chars => Formatter.Default; + public static IRespFormatter ByteArray => Formatter.Default; + + public static IRespFormatter> Bytes => Formatter.Default; + // ReSharper restore MemberHidesStaticFromOuterClass + + // (just to fix an auto-format glitch) + internal sealed class Formatter : IRespFormatter, IRespFormatter, + IRespFormatter>, IRespFormatter> + { + private Formatter() { } + public static readonly Formatter Default = new(); + + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in string value) + { + writer.WriteCommand(command, 1); + writer.WriteKey(value); + } + + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in byte[] value) + { + writer.WriteCommand(command, 1); + writer.WriteKey(value); + } + + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in ReadOnlyMemory value) + { + writer.WriteCommand(command, 1); + writer.WriteKey(value); + } + + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in ReadOnlyMemory value) + { + writer.WriteCommand(command, 1); + writer.WriteKey(value); + } + } + } + + public static class Value + { + // ReSharper disable MemberHidesStaticFromOuterClass + public static IRespFormatter String => Formatter.Default; + public static IRespFormatter> Chars => Formatter.Default; + public static IRespFormatter ByteArray => Formatter.Default; + + public static IRespFormatter> Bytes => Formatter.Default; + // ReSharper restore MemberHidesStaticFromOuterClass + + // (just to fix an auto-format glitch) + internal sealed class Formatter : IRespFormatter, IRespFormatter, + IRespFormatter>, IRespFormatter>, + IRespFormatter, IRespFormatter, + IRespFormatter, IRespFormatter + { + private Formatter() { } + public static readonly Formatter Default = new(); + + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in string value) + { + writer.WriteCommand(command, 1); + writer.WriteBulkString(value); + } + + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in byte[] value) + { + writer.WriteCommand(command, 1); + writer.WriteBulkString(value); + } + + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in ReadOnlyMemory value) + { + writer.WriteCommand(command, 1); + writer.WriteBulkString(value); + } + + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in ReadOnlyMemory value) + { + writer.WriteCommand(command, 1); + writer.WriteBulkString(value); + } + + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in int value) + { + writer.WriteCommand(command, 1); + writer.WriteBulkString(value); + } + + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in long value) + { + writer.WriteCommand(command, 1); + writer.WriteBulkString(value); + } + + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in float value) + { + writer.WriteCommand(command, 1); + writer.WriteBulkString(value); + } + + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in double value) + { + writer.WriteCommand(command, 1); + writer.WriteBulkString(value); + } + } + } + + private sealed class EmptyFormatter : IRespFormatter + { + private EmptyFormatter() { } + public static readonly EmptyFormatter Instance = new(); + + public void Format(scoped ReadOnlySpan command, ref RespWriter writer, in bool value) + { + writer.WriteCommand(command, 0); + } + } + + private sealed class RawFormatter : IRespFormatter + { + private RawFormatter() { } + public static readonly RawFormatter Instance = new(); + + public void Format( + scoped ReadOnlySpan command, + ref RespWriter writer, + in byte[] value) + { + if (writer.CommandMap is null) + { + writer.WriteRaw(value); + } + else + { + writer.WriteCommand(command, 0); + } + } + } +} diff --git a/src/RESPite/RespIgnoreAttribute.cs b/src/RESPite/RespIgnoreAttribute.cs new file mode 100644 index 000000000..a79a96b61 --- /dev/null +++ b/src/RESPite/RespIgnoreAttribute.cs @@ -0,0 +1,13 @@ +using System.ComponentModel; +using System.Diagnostics; + +namespace RESPite; + +[AttributeUsage(AttributeTargets.Parameter)] +[Conditional("DEBUG"), ImmutableObject(true)] +public sealed class RespIgnoreAttribute(object? value = null) : Attribute +{ + // note; nulls are always ignored (taking NRTs into account); the purpose + // of an explicit null is for RedisValue - this prompts HasValue checks (i.e. non-trivial value). + public object? Value => value; +} diff --git a/src/RESPite/RespKeyAttribute.cs b/src/RESPite/RespKeyAttribute.cs new file mode 100644 index 000000000..5c8b87a71 --- /dev/null +++ b/src/RESPite/RespKeyAttribute.cs @@ -0,0 +1,10 @@ +using System.ComponentModel; +using System.Diagnostics; + +namespace RESPite; + +[AttributeUsage(AttributeTargets.Parameter)] +[Conditional("DEBUG"), ImmutableObject(true)] +public sealed class RespKeyAttribute() : Attribute +{ +} diff --git a/src/RESPite/RespOperation.cs b/src/RESPite/RespOperation.cs new file mode 100644 index 000000000..572ea17c3 --- /dev/null +++ b/src/RESPite/RespOperation.cs @@ -0,0 +1,239 @@ +using System.Buffers; +using System.ComponentModel; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading.Tasks.Sources; +using RESPite.Internal; +using RESPite.Messages; + +namespace RESPite; + +/// +/// Represents a RESP operation that does not return a value (other than to signal completion). +/// This works almost identically to when based on +/// , and the usage semantics are the same. In particular, +/// note that a value can only be consumed once. Unlike , the +/// value can be awaited synchronously if required. +/// +public readonly struct RespOperation : ICriticalNotifyCompletion +{ +#if DEBUG + [ThreadStatic] + // how many resp-operations have we chewed through? + private static int _debugPerThreadMessageAllocations; + + internal static int DebugPerThreadMessageAllocations => _debugPerThreadMessageAllocations; +#else + internal static int DebugPerThreadMessageAllocations => 0; +#endif + + [Conditional("DEBUG")] + internal static void DebugOnAllocateMessage() + { +#if DEBUG + _debugPerThreadMessageAllocations++; +#endif + } + + // it is important that this layout remains identical between RespOperation and RespOperation + private readonly RespMessageBase _message; + private readonly short _token; + private readonly bool _disableCaptureContext; // default is false, so: bypass + + internal RespOperation(RespMessageBase message, short token, bool disableCaptureContext) + { + _message = message; + _token = token; + _disableCaptureContext = disableCaptureContext; + } + + internal RespOperation(RespMessageBase message, bool disableCaptureContext = false) + { + _message = message; + _token = message.Token; + _disableCaptureContext = disableCaptureContext; + } + + public bool IsSent => Message.IsSent(_token); + internal RespMessageBase Message => _message ?? ThrowNoMessage(); + + internal static RespMessageBase ThrowNoMessage() + => throw new InvalidOperationException($"{nameof(RespOperation)} is not correctly initialized"); + + /// + /// Treats this operation as a . + /// + public static implicit operator ValueTask(in RespOperation operation) + => new(operation.Message, operation._token); + + /// + public Task AsTask() => new ValueTask(Message, _token).AsTask(); + + public ValueTask AsValueTask() => new(Message, _token); + + /// + public void Wait(TimeSpan timeout = default) + => Message.WaitVoid(_token, timeout); + + /// + public bool IsCompleted => Message.GetStatus(_token) != ValueTaskSourceStatus.Pending; + + /// + public bool IsCompletedSuccessfully => Message.GetStatus(_token) == ValueTaskSourceStatus.Succeeded; + + /// + public bool IsFaulted => Message.GetStatus(_token) == ValueTaskSourceStatus.Faulted; + + /// + public bool IsCanceled => Message.GetStatus(_token) == ValueTaskSourceStatus.Canceled; + + public ref readonly CancellationToken CancellationToken => ref Message.CancellationToken; + + internal short Token => _token; + internal int MessageCount => Message.MessageCount; + internal bool TrySetException(Exception exception) => Message.TrySetException(_token, exception); + + internal bool TrySetCancelled(CancellationToken cancellationToken = default) => + Message.TrySetCanceled(_token, cancellationToken); + + internal bool TryReserveRequest(out ReadOnlySequence payload, bool recordSent = true) => + Message.TryReserveRequest(_token, out payload, recordSent); + + internal void ReleaseRequest() => Message.ReleaseRequest(); + + internal static readonly Action InvokeState = static state => ((Action)state!).Invoke(); + + /// + /// + public void OnCompleted(Action continuation) + { + // UseSchedulingContext === continueOnCapturedContext, always add FlowExecutionContext + var flags = _disableCaptureContext + ? ValueTaskSourceOnCompletedFlags.FlowExecutionContext + : ValueTaskSourceOnCompletedFlags.FlowExecutionContext | + ValueTaskSourceOnCompletedFlags.UseSchedulingContext; + Message.OnCompletedWithNotSentDetection(InvokeState, continuation, _token, flags); + } + + /// + public void UnsafeOnCompleted(Action continuation) + { + // UseSchedulingContext === continueOnCapturedContext + var flags = _disableCaptureContext + ? ValueTaskSourceOnCompletedFlags.None + : ValueTaskSourceOnCompletedFlags.UseSchedulingContext; + Message.OnCompletedWithNotSentDetection(InvokeState, continuation, _token, flags); + } + + /// + public void GetResult() => Message.GetResultVoid(_token); + + /// + public RespOperation GetAwaiter() => this; + + /// + public RespOperation ConfigureAwait(bool continueOnCapturedContext) + { + var clone = this; + Unsafe.AsRef(in clone._disableCaptureContext) = !continueOnCapturedContext; + return clone; + } + + /// + /// Provides a mechanism to control the outcome of a ; this is mostly + /// intended for testing purposes. It is broadly comparable to . + /// + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public readonly struct Remote + { + private readonly RespMessageBase _message; + private readonly short _token; + + internal Remote(RespMessageBase message) + { + _message = message; + _token = message.Token; + } + + public bool IsTokenMatch => _token == _message.Token; + + /// + public bool TrySetCanceled(CancellationToken cancellationToken = default) + => _message.TrySetCanceled(_token); + + /// + public bool TrySetException(Exception exception) + => _message.TrySetException(_token, exception); + + /// + /// The parser provided during creation is used to process the result. + public bool TrySetResult(scoped ReadOnlySpan response) + => _message.TrySetResult(_token, response); + + /// + /// The parser provided during creation is used to process the result. + public bool TrySetResult(in ReadOnlySequence response) + => _message.TrySetResult(_token, response); + } + + /// + /// Create a disconnected without a RESP parser; this is only intended for testing purposes. + /// + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public static RespOperation Create( + out Remote remote, + bool sent = true, + CancellationToken cancellationToken = default) + { + var msg = RespStatelessMessage.Get(null); + msg.Init(sent, cancellationToken); + remote = new(msg); + return new RespOperation(msg); + } + + /// + /// Create a disconnected with a stateless RESP parser; this is only intended for testing purposes. + /// + /// The result of the operation. + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public static RespOperation Create( + IRespParser? parser, + out Remote remote, + bool sent = true, + CancellationToken cancellationToken = default) + { + var msg = RespStatelessMessage.Get(parser); + msg.Init(sent, cancellationToken); + remote = new(msg); + return new RespOperation(msg); + } + + /// + /// Create a disconnected with a stateful RESP parser; this is only intended for testing purposes. + /// + /// The state used by the parser. + /// The result of the operation. + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public static RespOperation Create( + in TState state, + IRespParser? parser, + out Remote remote, + bool sent = true, + CancellationToken cancellationToken = default) + { + var msg = RespStatefulMessage.Get(in state, parser); + msg.Init(sent, cancellationToken); + remote = new(msg); + return new RespOperation(msg); + } + + internal void OnSent() => Message.OnSent(Token); + + internal bool TryGetSubMessages(out ReadOnlySpan operations) + => Message.TryGetSubMessages(Token, out operations); + + internal bool TrySetResultAfterUnloadingSubMessages() => Message.TrySetResultAfterUnloadingSubMessages(Token); + + internal RespMessageBase.StateFlags Flags => Message.Flags; + internal int Slot => _message.Slot; +} diff --git a/src/RESPite/RespOperationBuilder.cs b/src/RESPite/RespOperationBuilder.cs new file mode 100644 index 000000000..e686cb724 --- /dev/null +++ b/src/RESPite/RespOperationBuilder.cs @@ -0,0 +1,44 @@ +using RESPite.Messages; + +namespace RESPite; + +public readonly ref struct RespOperationBuilder( + in RespContext context, + ReadOnlySpan command, + TRequest request, + IRespFormatter formatter) +#if NET9_0_OR_GREATER + where TRequest : allows ref struct +#endif +{ + private readonly RespContext _context = context; + private readonly ReadOnlySpan _command = command; + private readonly TRequest request = request; // cannot inline to .ctor because of "allows ref struct" + + public TResponse Wait() + => Send(RespParsers.Get()).Wait(_context.SyncTimeout); + + public TResponse Wait(IRespParser parser) + => Send(parser).Wait(_context.SyncTimeout); + + public TResponse Wait(in TState state) + => Send(in state, RespParsers.Get()).Wait(_context.SyncTimeout); + + public TResponse Wait(in TState state, IRespParser parser) + => Send(in state, parser).Wait(_context.SyncTimeout); + + public void Wait() => Send(RespParsers.Success).Wait(_context.SyncTimeout); + + public RespOperation Send() + => _context.Send(_command, request, formatter, RespParsers.Get()); + + public RespOperation Send(IRespParser parser) + => _context.Send(_command, request, formatter, parser); + + public RespOperation Send() => _context.Send(_command, request, formatter, RespParsers.Success); + public RespOperation Send(in TState state) + => _context.Send(_command, request, formatter, in state, RespParsers.Get()); + + public RespOperation Send(in TState state, IRespParser parser) + => _context.Send(_command, request, formatter, in state, parser); +} diff --git a/src/RESPite/RespOperationT.cs b/src/RESPite/RespOperationT.cs new file mode 100644 index 000000000..8de2c9542 --- /dev/null +++ b/src/RESPite/RespOperationT.cs @@ -0,0 +1,118 @@ +using System.Runtime.CompilerServices; +using System.Threading.Tasks.Sources; +using RESPite.Internal; + +namespace RESPite; + +/// +/// Represents a RESP operation that returns a value of type . +/// This works almost identically to when based on +/// , and the usage semantics are the same. In particular, +/// note that a value can only be consumed once. Unlike , the +/// value can be awaited synchronously if required. +/// +/// The type of value returned by the operation. +public readonly struct RespOperation : ICriticalNotifyCompletion +{ + // it is important that this layout remains identical between RespOperation and RespOperation + private readonly RespMessageBase _message; + private readonly short _token; + private readonly bool _disableCaptureContext; + + internal RespOperation(RespMessageBase message, short token, bool disableCaptureContext) + { + _message = message; + _token = token; + _disableCaptureContext = disableCaptureContext; + } + internal RespOperation(RespMessageBase message, bool disableCaptureContext = false) + { + _message = message; + _token = message.Token; + _disableCaptureContext = disableCaptureContext; + } + + public CancellationToken CancellationToken => Message.CancellationToken; + + private RespMessageBase Message => _message ?? (RespMessageBase)RespOperation.ThrowNoMessage(); + + /// + /// Treats this operation as an untyped . + /// + #if PREVIEW_LANGVER + [Obsolete($"When possible, prefer .Untyped")] + #endif + public static implicit operator RespOperation(in RespOperation operation) + => Unsafe.As, RespOperation>(ref Unsafe.AsRef(in operation)); + + /// + /// Treats this operation as an untyped . + /// + public static implicit operator ValueTask(in RespOperation operation) + => new(operation.Message, operation._token); + + /// + /// Treats this operation as a . + /// + public static implicit operator ValueTask(in RespOperation operation) + => new(operation.Message, operation._token); + + /// + public Task AsTask() => new ValueTask(Message, _token).AsTask(); + + public ValueTask AsValueTask() => new(Message, _token); + + /// + public T Wait(TimeSpan timeout = default) + => Message.Wait(_token, timeout); + + /// + public bool IsCompleted => Message.GetStatus(_token) != ValueTaskSourceStatus.Pending; + + /// + public bool IsCompletedSuccessfully => Message.GetStatus(_token) == ValueTaskSourceStatus.Succeeded; + + /// + public bool IsFaulted => Message.GetStatus(_token) == ValueTaskSourceStatus.Faulted; + + /// + public bool IsCanceled => Message.GetStatus(_token) == ValueTaskSourceStatus.Canceled; + + /// + /// + public void OnCompleted(Action continuation) + { + // UseSchedulingContext === continueOnCapturedContext, always add FlowExecutionContext + var flags = _disableCaptureContext + ? ValueTaskSourceOnCompletedFlags.FlowExecutionContext + : ValueTaskSourceOnCompletedFlags.FlowExecutionContext | + ValueTaskSourceOnCompletedFlags.UseSchedulingContext; + Message.OnCompletedWithNotSentDetection(RespOperation.InvokeState, continuation, _token, flags); + } + + public bool IsSent => Message.IsSent(_token); + + /// + public void UnsafeOnCompleted(Action continuation) + { + // UseSchedulingContext === continueOnCapturedContext + var flags = _disableCaptureContext + ? ValueTaskSourceOnCompletedFlags.None + : ValueTaskSourceOnCompletedFlags.UseSchedulingContext; + Message.OnCompletedWithNotSentDetection(RespOperation.InvokeState, continuation, _token, flags); + } + + /// + public T GetResult() => Message.GetResult(_token); + + /// + public RespOperation GetAwaiter() => this; + + /// + public RespOperation ConfigureAwait(bool continueOnCapturedContext) + { + var clone = this; + Unsafe.AsRef(in clone._disableCaptureContext) = !continueOnCapturedContext; + return clone; + } +} diff --git a/src/RESPite/RespParsers.cs b/src/RESPite/RespParsers.cs new file mode 100644 index 000000000..2ef47d537 --- /dev/null +++ b/src/RESPite/RespParsers.cs @@ -0,0 +1,191 @@ +using System.Buffers; +using System.Diagnostics.CodeAnalysis; +using RESPite.Internal; +using RESPite.Messages; + +namespace RESPite; + +public static class RespParsers +{ + public static IRespParser Success => InbuiltInlineParsers.Default; + public static IRespParser OK => OKParser.Default; + public static IRespParser String => InbuiltCopyOutParsers.Default; + public static IRespParser Int32 => InbuiltInlineParsers.Default; + public static IRespParser NullableInt32 => InbuiltInlineParsers.Default; + public static IRespParser Int64 => InbuiltInlineParsers.Default; + public static IRespParser NullableInt64 => InbuiltInlineParsers.Default; + public static IRespParser Single => InbuiltInlineParsers.Default; + public static IRespParser NullableSingle => InbuiltInlineParsers.Default; + public static IRespParser Double => InbuiltInlineParsers.Default; + public static IRespParser NullableDouble => InbuiltInlineParsers.Default; + public static IRespParser ByteArray => InbuiltCopyOutParsers.Default; + public static IRespParser ByteArrayArray => InbuiltCopyOutParsers.Default; + public static IRespParser, int> BufferWriter => InbuiltCopyOutParsers.Default; + + private static class StatelessCache + { + public static IRespParser? Instance = + (InbuiltCopyOutParsers.Default as IRespParser) ?? // regular (may allocate, etc) + (InbuiltInlineParsers.Default as IRespParser) ?? // inline + (ResponseSummary.Parser as IRespParser); // inline+metadata + } + + private static class StatefulCache + { + public static IRespParser? Instance = + InbuiltCopyOutParsers.Default as IRespParser; // ?? // regular (may allocate, etc) + // (InbuiltInlineParsers.Default as IRespParser) ?? // inline + // (ResponseSummary.Parser as IRespParser); // inline+metadata + } + + private static bool IsInbuilt(object? obj) => obj is InbuiltCopyOutParsers or InbuiltInlineParsers + or ResponseSummary.ResponseSummaryParser; + + public static IRespParser Get() + { + var obj = StatelessCache.Instance; + if (obj is null) ThrowNoParser(); + return obj; + } + + public static IRespParser Get() + { + var obj = StatefulCache.Instance; + if (obj is null) ThrowNoParser(); + return obj; + } + + public static void Set(IRespParser parser) + { + if (IsInbuilt(StatelessCache.Instance)) ThrowInbuiltParser(); + StatelessCache.Instance = parser; + } + + public static void Set(IRespParser parser) + { + if (IsInbuilt(StatefulCache.Instance)) ThrowInbuiltParser(); + StatefulCache.Instance = parser; + } + + [DoesNotReturn] + private static void ThrowNoParser() => throw new InvalidOperationException( + message: + $"No default parser registered for this type; a custom parser must be specified via {nameof(RespParsers)}.{nameof(Set)}(...)."); + + [DoesNotReturn] + private static void ThrowInbuiltParser() => throw new InvalidOperationException( + message: $"This type has inbuilt handling and cannot be changed."); + + private sealed class InbuiltInlineParsers : IRespInlineParser, + IRespParser, + IRespParser, IRespParser, + IRespParser, IRespParser, + IRespParser, IRespParser, + IRespParser, IRespParser + { + private InbuiltInlineParsers() { } + public static readonly InbuiltInlineParsers Default = new(); + + bool IRespParser.Parse(ref RespReader reader) => reader.ReadBoolean(); + int IRespParser.Parse(ref RespReader reader) => reader.ReadInt32(); + + int? IRespParser.Parse(ref RespReader reader) => reader.IsNull ? null : reader.ReadInt32(); + + long IRespParser.Parse(ref RespReader reader) => reader.ReadInt64(); + + long? IRespParser.Parse(ref RespReader reader) => reader.IsNull ? null : reader.ReadInt64(); + + float IRespParser.Parse(ref RespReader reader) => (float)reader.ReadDouble(); + + float? IRespParser.Parse(ref RespReader reader) => reader.IsNull ? null : (float)reader.ReadDouble(); + + double IRespParser.Parse(ref RespReader reader) => reader.ReadDouble(); + + double? IRespParser.Parse(ref RespReader reader) => reader.IsNull ? null : reader.ReadDouble(); + } + + private sealed class OKParser : IRespParser, IRespInlineParser + { + private OKParser() { } + public static readonly OKParser Default = new(); + + public bool Parse(ref RespReader reader) + { + if (!(reader.Prefix == RespPrefix.SimpleString && reader.IsOK())) + { + Throw(); + } + + return true; + static void Throw() => throw new InvalidOperationException("Expected +OK response or similar."); + } + } + + private sealed class InbuiltCopyOutParsers : IRespParser, + IRespParser, IRespParser, + IRespParser, int> + { + private InbuiltCopyOutParsers() { } + public static readonly InbuiltCopyOutParsers Default = new(); + + string? IRespParser.Parse(ref RespReader reader) => reader.ReadString(); + byte[]? IRespParser.Parse(ref RespReader reader) => reader.ReadByteArray(); + + byte[]?[]? IRespParser.Parse(ref RespReader reader) => + reader.ReadArray(static (ref RespReader reader) => reader.ReadByteArray()); + + int IRespParser, int>.Parse(in IBufferWriter state, ref RespReader reader) + { + reader.DemandScalar(); + if (reader.IsNull) return -1; + return reader.CopyTo(state); + } + } + + public readonly struct ResponseSummary(RespPrefix prefix, int length, long protocolBytes) + : IEquatable + { + public RespPrefix Prefix { get; } = prefix; + public int Length { get; } = length; + public long ProtocolBytes { get; } = protocolBytes; + + /// + public override string ToString() => $"{Prefix}, Length: {Length}, Protocol Bytes: {ProtocolBytes}"; + + /// + public bool Equals(ResponseSummary other) => EqualsCore(in other); + + private bool EqualsCore(in ResponseSummary other) => + Prefix == other.Prefix && Length == other.Length && ProtocolBytes == other.ProtocolBytes; + + bool IEquatable.Equals(ResponseSummary other) => EqualsCore(in other); + + /// + public override bool Equals(object? obj) => obj is ResponseSummary summary && EqualsCore(in summary); + + /// + public override int GetHashCode() => (int)Prefix ^ Length ^ ProtocolBytes.GetHashCode(); + + public static IRespParser Parser => ResponseSummaryParser.Default; + + internal sealed class ResponseSummaryParser : IRespParser, IRespInlineParser, + IRespMetadataParser + { + private ResponseSummaryParser() { } + public static readonly ResponseSummaryParser Default = new(); + + public ResponseSummary Parse(ref RespReader reader) + { + var protocolBytes = reader.ProtocolBytesRemaining; + int length = 0; + if (reader.TryMoveNext()) + { + if (reader.IsScalar) length = reader.ScalarLength(); + else if (reader.IsAggregate) length = reader.AggregateLength(); + } + + return new ResponseSummary(reader.Prefix, length, protocolBytes); + } + } + } +} diff --git a/src/RESPite/RespPrefixAttribute.cs b/src/RESPite/RespPrefixAttribute.cs new file mode 100644 index 000000000..e9d4fd438 --- /dev/null +++ b/src/RESPite/RespPrefixAttribute.cs @@ -0,0 +1,12 @@ +using System.ComponentModel; +using System.Diagnostics; + +namespace RESPite; + +// note: omitting the token means that a collection-count prefix will be written +[AttributeUsage(AttributeTargets.Parameter, AllowMultiple = true)] +[Conditional("DEBUG"), ImmutableObject(true)] +public sealed class RespPrefixAttribute(string token = "") : Attribute +{ + public string Token => token; +} diff --git a/src/RESPite/RespSuffixAttribute.cs b/src/RESPite/RespSuffixAttribute.cs new file mode 100644 index 000000000..5bd8e3515 --- /dev/null +++ b/src/RESPite/RespSuffixAttribute.cs @@ -0,0 +1,11 @@ +using System.ComponentModel; +using System.Diagnostics; + +namespace RESPite; + +[AttributeUsage(AttributeTargets.Parameter, AllowMultiple = true)] +[Conditional("DEBUG"), ImmutableObject(true)] +public sealed class RespSuffixAttribute(string token) : Attribute +{ + public string Token => token; +} diff --git a/src/RESPite/ValueTaskExtensions.cs b/src/RESPite/ValueTaskExtensions.cs new file mode 100644 index 000000000..56ec0cfde --- /dev/null +++ b/src/RESPite/ValueTaskExtensions.cs @@ -0,0 +1,137 @@ +#if NET8_0_OR_GREATER +using System.Runtime.CompilerServices; +#else +using System.Reflection; +#endif +using RESPite.Internal; + +namespace RESPite; + +/// +/// The results of asynchronous RESPite operations can be treated interchangeably as either or +/// (or their generic twins: and ). +/// is a more familiar, and is convenient in pre-existing code; +/// is more context-aware, and adds a few additional capabilities, such as: +/// - most notably: automatic detection if attempting to wait/await before a message has been sent. +/// - to check whether the message has been sent to a server. +/// - to access cancellation information about this message. +/// - to wait synchronously for the operation to complete. +/// - a can be implicitly converted to a (unlike to ). +/// +/// Neither representation is more efficient, and the semantics are identical - the result can only be waited/awaited once +/// (unless hoisted into a ). +/// +public static class ValueTaskExtensions +{ + public static bool TryGetRespOperation(this ValueTask value, out RespOperation operation) + { + if (FieldAccessor.Object(value) is not RespMessageBase msg) + { + operation = default; + return false; + } + + short token = FieldAccessor.Token(value); + bool continueOnCapturedContext = FieldAccessor.ContinueOnCapturedContext(value); + operation = new RespOperation(msg, token, continueOnCapturedContext); + return true; + } + + public static RespOperation AsRespOperation(this ValueTask value) + { + if (!TryGetRespOperation(value, out var operation)) Throw(typeof(T)); + return operation; + } + + public static bool TryGetRespOperation(this ValueTask value, out RespOperation operation) + { + if (FieldAccessor.Object(value) is not RespMessageBase msg) + { + operation = default; + return false; + } + + short token = FieldAccessor.Token(value); + bool continueOnCapturedContext = FieldAccessor.ContinueOnCapturedContext(value); + operation = new RespOperation(msg, token, continueOnCapturedContext); + return true; + } + + public static RespOperation AsRespOperation(this ValueTask value) + { + if (!TryGetRespOperation(value, out var operation)) Throw(); + return operation; + } + + private static void Throw(Type type) + => throw new ArgumentException( + $"The {nameof(ValueTask)}<{type.Name}> does not wrap does not wrap a {nameof(RespMessageBase)}<{type.Name}>"); + + private static void Throw() => + throw new ArgumentException($"The {nameof(ValueTask)} does not wrap a {nameof(RespMessageBase)}"); + + // from here on: evil reflection to peek inside ValueTask[] and extract the fields we need + private static class FieldAccessor + { +#if NET8_0_OR_GREATER + [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_obj")] + public static extern ref readonly object? Object(in ValueTask task); + + [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_token")] + public static extern ref readonly short Token(in ValueTask task); + + [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_continueOnCapturedContext")] + public static extern ref readonly bool ContinueOnCapturedContext(in ValueTask task); +#else + private static readonly FieldInfo _obj = + typeof(ValueTask).GetField(nameof(_obj), BindingFlags.NonPublic | BindingFlags.Instance)!; + + private static readonly FieldInfo _token = + typeof(ValueTask).GetField(nameof(_token), BindingFlags.NonPublic | BindingFlags.Instance)!; + + private static readonly FieldInfo? _continueOnCapturedContext = typeof(ValueTask).GetField( + nameof(_continueOnCapturedContext), + BindingFlags.NonPublic | BindingFlags.Instance); + + public static object? Object(ValueTask task) => _obj.GetValue(task); + + public static short Token(ValueTask task) => (short)_token.GetValue(task)!; + + public static bool ContinueOnCapturedContext(ValueTask task) + => _continueOnCapturedContext is not null + && (bool)_continueOnCapturedContext.GetValue(task)!; +#endif + } + + private static class FieldAccessor + { +#if NET8_0_OR_GREATER + [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_obj")] + public static extern ref readonly object? Object(in ValueTask task); + + [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_token")] + public static extern ref readonly short Token(in ValueTask task); + + [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_continueOnCapturedContext")] + public static extern ref readonly bool ContinueOnCapturedContext(in ValueTask task); +#else + private static readonly FieldInfo _obj = + typeof(ValueTask).GetField(nameof(_obj), BindingFlags.NonPublic | BindingFlags.Instance)!; + + private static readonly FieldInfo _token = + typeof(ValueTask).GetField(nameof(_token), BindingFlags.NonPublic | BindingFlags.Instance)!; + + private static readonly FieldInfo? _continueOnCapturedContext = typeof(ValueTask).GetField( + nameof(_continueOnCapturedContext), + BindingFlags.NonPublic | BindingFlags.Instance); + + public static object? Object(ValueTask task) => _obj.GetValue(task); + + public static short Token(ValueTask task) => (short)_token.GetValue(task)!; + + public static bool ContinueOnCapturedContext(ValueTask task) + => _continueOnCapturedContext is not null + && (bool)_continueOnCapturedContext.GetValue(task)!; +#endif + } +} diff --git a/src/RESPite/readme.md b/src/RESPite/readme.md new file mode 100644 index 000000000..541d89525 --- /dev/null +++ b/src/RESPite/readme.md @@ -0,0 +1,120 @@ +# RESPite + +RESPite is a high-performance low-level RESP (Redis, etc) library, used as the IO core for +StackExchange.Redis v3+. It is also available for direct use from other places! + +## Getting Started + +RESPite has two key primitives: + +- a *connection*, `RespConnection`. +- a *context*, `RespContext` - which is a connection plus other local ambient context such as database, cancellation, etc. + +The first thing we need, then, is to create a connection. There are many ways to do this, but to +create a connection to the local default Redis instance: + +``` c# +using var conn = RespConnection.Create(); +// ... +``` + +This gives us a single socket-based connection. Usually a *connection* is long-lived and used for +a great many RESP operations, with the `using` here closing socket eventually. + +Once we have a connection, we can start using it immediately, via the default *context*, from +`.Context`. Usually, it is the *context* that we should be passing around, not a connection: +the context *has* a connection plus local ambient configuration. So: + +``` c# +var ctx = conn.Context; +``` + +Once we have a *context*, we can use that to execute commands: + +``` c# +ctx.SomeOperation(...); +``` + +But: what is `SomeOperation(...)`? ***That's up to you.*** + +### Defining commands + +The RESPite libary only handles the RESP layer - it doesn't add the methods associated with Redis +(don't worry: RESPite.Redis does that - we're not animals!). However, in the general case where you +want to add your own RESP methods, we can do exactly that. The easiest way is by letting the tools do +the work for us: + +``` c# +static class MyCommands +{ + [RespOperation("incr")] // arg optional - it would assume "increment" if omitted + public partial static int Increment(this in RespContext ctx, string key); + + [RespOperation("incrby")] + public partial static int Increment(this in RespContext ctx, string key, int value); +} +``` + +Build-time tools will provide the implementation for us, including adding an `async` version. The code +for this isn't *difficult* - simply: it is *unnecessary*, since in most cases the intent can be clearly +understood. This avoids opportunities to fat-finger things (or get things wrong between the synchronous +and asynchronous versions). + +We can now use: + +``` c# +var x = ctx.Increment("mykey"); +var y = await ctx.IncrementAsync("mykey", 42); +``` + +That's *basically* it. If you need more control over how non-trivial commands are formatted and parsed, +APIs exist for that. But for most common scenarios: that's all we need. + +### Cancellation + +Unusually, our `IncrementAsync` method *does not* have a `CancellationToken cancellationToken = default` +parameter; instead, cancellation is conveyed *in the context*. This also means that cancellation works +for *both* the synchronous and asynchronous versions! We can supply our own cancellation: + +``` c# +var ctx = conn.Context.WithCancellationToken(request.CancellationToken); +// use ctx for commands +``` + +Now `ctx` is not just the *default* context - it has the cancellation token we supplied, and it is used +everywhere automatically! The `RespContext` type is cheap and allocation-free; it has no lifetime etc - it +is just a bundle of state required for RESP operations. We can freely `With...` them: + +``` c# +var db = conn.Context.WithDatabase(4).WithCancellationToken(request.CancellationToken); +// use db for commands +``` + +If you're thinking "Wait - if `RespContext` carries cancellation, does `WithCancellationToken(...)` *replace* +the cancellation, or *combine* the two cancellations?", then: have a cookie. The answer is "replace", but we can also +combine multiple cancellations, noting that now we need to scope that to a *lifetime*: + +``` c# +using var lifetime = db.WithDatabase(4).WithCombineCancellationToken(anotherCancellationToken); +// use lifetime.Context for commands +``` + +This will automatically do the most appropriate thing based on whether neither, one, or both tokens +are cancellable. We can do the same thing with a timeout: + +``` c# +using var lifetime = db.WithCombineTimeout(TimeSpan.FromSeconds(5)); +// use lifetime.Context for commands +``` + +Note that this timeout applies to the *lifetime*, not individual operations (i.e. if we loop forever +performing fast operations: it will still cancel after five seconds). From the name +`WithCombineTimeout`, you can probably guess that this works *in addition to* the +existing cancellation state. Help yourself to another cookie. + +## Summary + +With the combination of `RespConnection` for the long-lived connection, +`RespContext` for the transient local configuration (via various `With*` methods), +and our automatically generated `[RespCommand]` methods: we can easily and +efficiently talk to a range of RESP databases. diff --git a/src/StackExchange.Redis/ConfigurationOptions.cs b/src/StackExchange.Redis/ConfigurationOptions.cs index c0021f024..e9c1d1ac2 100644 --- a/src/StackExchange.Redis/ConfigurationOptions.cs +++ b/src/StackExchange.Redis/ConfigurationOptions.cs @@ -330,9 +330,9 @@ internal static LocalCertificateSelectionCallback CreatePemUserCertificateCallba { // PEM handshakes not universally supported and causes a runtime error about ephemeral certificates; to avoid, export as PFX using var pem = X509Certificate2.CreateFromPemFile(userCertificatePath, userKeyPath); -#pragma warning disable SYSLIB0057 // Type or member is obsolete +#pragma warning disable SYSLIB0057 // because of TFM support var pfx = new X509Certificate2(pem.Export(X509ContentType.Pfx)); -#pragma warning restore SYSLIB0057 // Type or member is obsolete +#pragma warning restore SYSLIB0057 return (sender, targetHost, localCertificates, remoteCertificate, acceptableIssuers) => pfx; } @@ -340,7 +340,9 @@ internal static LocalCertificateSelectionCallback CreatePemUserCertificateCallba internal static LocalCertificateSelectionCallback CreatePfxUserCertificateCallback(string userCertificatePath, string? password, X509KeyStorageFlags storageFlags = X509KeyStorageFlags.DefaultKeySet) { +#pragma warning disable SYSLIB0057 // because of TFM support var pfx = new X509Certificate2(userCertificatePath, password ?? "", storageFlags); +#pragma warning restore SYSLIB0057 return (sender, targetHost, localCertificates, remoteCertificate, acceptableIssuers) => pfx; } @@ -351,7 +353,9 @@ internal static LocalCertificateSelectionCallback CreatePfxUserCertificateCallba public void TrustIssuer(X509Certificate2 issuer) => CertificateValidationCallback = TrustIssuerCallback(issuer); internal static RemoteCertificateValidationCallback TrustIssuerCallback(string issuerCertificatePath) +#pragma warning disable SYSLIB0057 // because of TFM support => TrustIssuerCallback(new X509Certificate2(issuerCertificatePath)); +#pragma warning restore SYSLIB0057 private static RemoteCertificateValidationCallback TrustIssuerCallback(X509Certificate2 issuer) { if (issuer == null) throw new ArgumentNullException(nameof(issuer)); diff --git a/src/StackExchange.Redis/Enums/SortedSetWhen.cs b/src/StackExchange.Redis/Enums/SortedSetWhen.cs index 517aaeaa5..c7a038325 100644 --- a/src/StackExchange.Redis/Enums/SortedSetWhen.cs +++ b/src/StackExchange.Redis/Enums/SortedSetWhen.cs @@ -45,7 +45,7 @@ internal static uint CountBits(this SortedSetWhen when) return c; } - internal static SortedSetWhen Parse(When when) => when switch + internal static SortedSetWhen ToSortedSetWhen(this When when) => when switch { When.Always => SortedSetWhen.Always, When.Exists => SortedSetWhen.Exists, diff --git a/src/StackExchange.Redis/ExceptionFactory.cs b/src/StackExchange.Redis/ExceptionFactory.cs index 7e4eca49a..3cfb0268c 100644 --- a/src/StackExchange.Redis/ExceptionFactory.cs +++ b/src/StackExchange.Redis/ExceptionFactory.cs @@ -107,7 +107,7 @@ internal static Exception NoConnectionAvailable( serverSnapshot = new ServerEndPoint[] { server }; } - var innerException = PopulateInnerExceptions(serverSnapshot == default ? multiplexer.GetServerSnapshot() : serverSnapshot); + var innerException = PopulateInnerExceptions(serverSnapshot.IsEmpty ? multiplexer.GetServerSnapshot() : serverSnapshot); // Try to get a useful error message for the user. long attempts = multiplexer._connectAttemptCount, completions = multiplexer._connectCompletedCount; diff --git a/src/StackExchange.Redis/FrameworkShims.cs b/src/StackExchange.Redis/FrameworkShims.cs index 9472df9ae..0c264bf5c 100644 --- a/src/StackExchange.Redis/FrameworkShims.cs +++ b/src/StackExchange.Redis/FrameworkShims.cs @@ -5,8 +5,12 @@ [assembly: System.Runtime.CompilerServices.TypeForwardedTo(typeof(System.Runtime.CompilerServices.IsExternalInit))] #else // To support { get; init; } properties +using System.Buffers; using System.ComponentModel; +using System.Runtime.InteropServices; using System.Text; +using System.Threading; +using System.Threading.Tasks; namespace System.Runtime.CompilerServices { @@ -16,12 +20,112 @@ internal static class IsExternalInit { } #endif #if !(NETCOREAPP || NETSTANDARD2_1_OR_GREATER) +namespace System.IO +{ + internal static class StreamExtensions + { + public static void Write(this Stream stream, ReadOnlyMemory value) + { + if (MemoryMarshal.TryGetArray(value, out var segment)) + { + stream.Write(segment.Array!, segment.Offset, segment.Count); + } + else + { + var leased = ArrayPool.Shared.Rent(value.Length); + value.CopyTo(leased); + stream.Write(leased, 0, value.Length); + ArrayPool.Shared.Return(leased); // on success only + } + } + + public static int Read(this Stream stream, Memory value) + { + if (MemoryMarshal.TryGetArray(value, out var segment)) + { + return stream.Read(segment.Array!, segment.Offset, segment.Count); + } + else + { + var leased = ArrayPool.Shared.Rent(value.Length); + int bytes = stream.Read(leased, 0, value.Length); + if (bytes > 0) + { + leased.AsSpan(0, bytes).CopyTo(value.Span); + } + ArrayPool.Shared.Return(leased); // on success only + return bytes; + } + } + public static ValueTask ReadAsync(this Stream stream, Memory value, CancellationToken cancellationToken) + { + if (MemoryMarshal.TryGetArray(value, out var segment)) + { + return new(stream.ReadAsync(segment.Array!, segment.Offset, segment.Count, cancellationToken)); + } + else + { + var leased = ArrayPool.Shared.Rent(value.Length); + var pending = stream.ReadAsync(leased, 0, value.Length, cancellationToken); + if (!pending.IsCompleted) + { + return Awaited(pending, value, leased); + } + + var bytes = pending.GetAwaiter().GetResult(); + if (bytes > 0) + { + leased.AsSpan(0, bytes).CopyTo(value.Span); + } + ArrayPool.Shared.Return(leased); // on success only + return new(bytes); + + static async ValueTask Awaited(Task pending, Memory value, byte[] leased) + { + var bytes = await pending.ConfigureAwait(false); + if (bytes > 0) + { + leased.AsSpan(0, bytes).CopyTo(value.Span); + } + ArrayPool.Shared.Return(leased); // on success only + return bytes; + } + } + } + + public static ValueTask WriteAsync(this Stream stream, ReadOnlyMemory value, CancellationToken cancellationToken) + { + if (MemoryMarshal.TryGetArray(value, out var segment)) + { + return new(stream.WriteAsync(segment.Array!, segment.Offset, segment.Count, cancellationToken)); + } + else + { + var leased = ArrayPool.Shared.Rent(value.Length); + value.CopyTo(leased); + var pending = stream.WriteAsync(leased, 0, value.Length, cancellationToken); + if (!pending.IsCompleted) + { + return Awaited(pending, leased); + } + pending.GetAwaiter().GetResult(); + ArrayPool.Shared.Return(leased); // on success only + return default; + } + static async ValueTask Awaited(Task pending, byte[] leased) + { + await pending.ConfigureAwait(false); + ArrayPool.Shared.Return(leased); // on success only + } + } + } +} namespace System.Text { - internal static class EncodingExtensions + internal static unsafe class EncodingExtensions { - public static unsafe int GetBytes(this Encoding encoding, ReadOnlySpan source, Span destination) + public static int GetBytes(this Encoding encoding, ReadOnlySpan source, Span destination) { fixed (byte* bPtr = destination) { @@ -31,6 +135,42 @@ public static unsafe int GetBytes(this Encoding encoding, ReadOnlySpan sou } } } + public static string GetString(this Encoding encoding, ReadOnlySpan source) + { + fixed (byte* bPtr = source) + { + return encoding.GetString(bPtr, source.Length); + } + } + public static int GetChars(this Encoding encoding, ReadOnlySpan source, Span destination) + { + fixed (byte* bPtr = source) + { + fixed (char* cPtr = destination) + { + return encoding.GetChars(bPtr, source.Length, cPtr, destination.Length); + } + } + } + + public static int GetByteCount(this Encoding encoding, ReadOnlySpan source) + { + fixed (char* cPtr = source) + { + return encoding.GetByteCount(cPtr, source.Length); + } + } + + public static void Convert(this Encoder encoder, ReadOnlySpan source, Span destination, bool flush, out int charsUsed, out int bytesUsed, out bool completed) + { + fixed (char* cPtr = source) + { + fixed (byte* bPtr = destination) + { + encoder.Convert(cPtr, source.Length, bPtr, destination.Length, flush, out charsUsed, out bytesUsed, out completed); + } + } + } } } #endif diff --git a/src/StackExchange.Redis/PhysicalConnection.cs b/src/StackExchange.Redis/PhysicalConnection.cs index c21bc07fc..c36f2da72 100644 --- a/src/StackExchange.Redis/PhysicalConnection.cs +++ b/src/StackExchange.Redis/PhysicalConnection.cs @@ -914,7 +914,7 @@ internal void WriteHeader(RedisCommand command, int arguments, CommandBytes comm internal void RecordQuit() { // don't blame redis if we fired the first shot - Thread.VolatileWrite(ref clientSentQuit, 1); + Volatile.Write(ref clientSentQuit, 1); (_ioPipe as SocketConnection)?.TrySetProtocolShutdown(PipeShutdownKind.ProtocolExitClient); } @@ -1967,7 +1967,7 @@ private async Task ReadFromPipe() { _readStatus = ReadStatus.Faulted; // this CEX is just a hardcore "seriously, read the actual value" - there's no - // convenient "Thread.VolatileRead(ref T field) where T : class", and I don't + // convenient "Volatile.Read(ref T field) where T : class", and I don't // want to make the field volatile just for this one place that needs it if (isReading) { diff --git a/src/StackExchange.Redis/RedisBase.cs b/src/StackExchange.Redis/RedisBase.cs index 095835efd..1863cf2b1 100644 --- a/src/StackExchange.Redis/RedisBase.cs +++ b/src/StackExchange.Redis/RedisBase.cs @@ -1,6 +1,7 @@ using System; using System.Diagnostics.CodeAnalysis; using System.Threading.Tasks; +using StackExchange.Redis; namespace StackExchange.Redis { @@ -69,43 +70,6 @@ internal virtual RedisFeatures GetFeatures(in RedisKey key, CommandFlags flags, return new RedisFeatures(version); } - protected static void WhenAlwaysOrExists(When when) - { - switch (when) - { - case When.Always: - case When.Exists: - break; - default: - throw new ArgumentException(when + " is not valid in this context; the permitted values are: Always, Exists"); - } - } - - protected static void WhenAlwaysOrExistsOrNotExists(When when) - { - switch (when) - { - case When.Always: - case When.Exists: - case When.NotExists: - break; - default: - throw new ArgumentException(when + " is not valid in this context; the permitted values are: Always, Exists, NotExists"); - } - } - - protected static void WhenAlwaysOrNotExists(When when) - { - switch (when) - { - case When.Always: - case When.NotExists: - break; - default: - throw new ArgumentException(when + " is not valid in this context; the permitted values are: Always, NotExists"); - } - } - private ResultProcessor.TimingProcessor.TimerMessage GetTimerMessage(CommandFlags flags) { // do the best we can with available commands @@ -137,3 +101,56 @@ internal static bool IsNil(in RedisValue pattern) } } } + +internal static class WhenExtensions +{ + internal static void AlwaysOnly(this When when) + { + if (when != When.Always) Throw(when); + static void Throw(When when) => throw new ArgumentException(when + " is not valid in this context; the permitted values are: Always"); + } + + internal static void AlwaysOrExists(this When when) + { + switch (when) + { + case When.Always: + case When.Exists: + break; + default: + Throw(when); + break; + } + static void Throw(When when) => throw new ArgumentException(when + " is not valid in this context; the permitted values are: Always, Exists"); + } + + internal static void AlwaysOrExistsOrNotExists(this When when) + { + switch (when) + { + case When.Always: + case When.Exists: + case When.NotExists: + break; + default: + Throw(when); + break; + } + static void Throw(When when) + => throw new ArgumentException(when + " is not valid in this context; the permitted values are: Always, Exists, NotExists"); + } + + internal static void AlwaysOrNotExists(this When when) + { + switch (when) + { + case When.Always: + case When.NotExists: + break; + default: + Throw(when); + break; + } + static void Throw(When when) => throw new ArgumentException(when + " is not valid in this context; the permitted values are: Always, NotExists"); + } +} diff --git a/src/StackExchange.Redis/RedisDatabase.cs b/src/StackExchange.Redis/RedisDatabase.cs index 349864a1b..bc119a601 100644 --- a/src/StackExchange.Redis/RedisDatabase.cs +++ b/src/StackExchange.Redis/RedisDatabase.cs @@ -938,7 +938,7 @@ private CursorEnumerable HashScanNoValuesAsync(RedisKey key, RedisVa public bool HashSet(RedisKey key, RedisValue hashField, RedisValue value, When when = When.Always, CommandFlags flags = CommandFlags.None) { - WhenAlwaysOrNotExists(when); + when.AlwaysOrNotExists(); var msg = value.IsNull ? Message.Create(Database, flags, RedisCommand.HDEL, key, hashField) : Message.Create(Database, flags, when == When.Always ? RedisCommand.HSET : RedisCommand.HSETNX, key, hashField, value); @@ -960,7 +960,7 @@ public long HashStringLength(RedisKey key, RedisValue hashField, CommandFlags fl public Task HashSetAsync(RedisKey key, RedisValue hashField, RedisValue value, When when = When.Always, CommandFlags flags = CommandFlags.None) { - WhenAlwaysOrNotExists(when); + when.AlwaysOrNotExists(); var msg = value.IsNull ? Message.Create(Database, flags, RedisCommand.HDEL, key, hashField) : Message.Create(Database, flags, when == When.Always ? RedisCommand.HSET : RedisCommand.HSETNX, key, hashField, value); @@ -1398,14 +1398,14 @@ public Task KeyRandomAsync(CommandFlags flags = CommandFlags.None) public bool KeyRename(RedisKey key, RedisKey newKey, When when = When.Always, CommandFlags flags = CommandFlags.None) { - WhenAlwaysOrNotExists(when); + when.AlwaysOrNotExists(); var msg = Message.Create(Database, flags, when == When.Always ? RedisCommand.RENAME : RedisCommand.RENAMENX, key, newKey); return ExecuteSync(msg, ResultProcessor.Boolean); } public Task KeyRenameAsync(RedisKey key, RedisKey newKey, When when = When.Always, CommandFlags flags = CommandFlags.None) { - WhenAlwaysOrNotExists(when); + when.AlwaysOrNotExists(); var msg = Message.Create(Database, flags, when == When.Always ? RedisCommand.RENAME : RedisCommand.RENAMENX, key, newKey); return ExecuteAsync(msg, ResultProcessor.Boolean); } @@ -1558,14 +1558,14 @@ public Task ListPositionsAsync(RedisKey key, RedisValue element, long co public long ListLeftPush(RedisKey key, RedisValue value, When when = When.Always, CommandFlags flags = CommandFlags.None) { - WhenAlwaysOrExists(when); + when.AlwaysOrExists(); var msg = Message.Create(Database, flags, when == When.Always ? RedisCommand.LPUSH : RedisCommand.LPUSHX, key, value); return ExecuteSync(msg, ResultProcessor.Int64); } public long ListLeftPush(RedisKey key, RedisValue[] values, When when = When.Always, CommandFlags flags = CommandFlags.None) { - WhenAlwaysOrExists(when); + when.AlwaysOrExists(); if (values == null) throw new ArgumentNullException(nameof(values)); var command = when == When.Always ? RedisCommand.LPUSH : RedisCommand.LPUSHX; var msg = values.Length == 0 ? Message.Create(Database, flags, RedisCommand.LLEN, key) : Message.Create(Database, flags, command, key, values); @@ -1581,14 +1581,14 @@ public long ListLeftPush(RedisKey key, RedisValue[] values, CommandFlags flags = public Task ListLeftPushAsync(RedisKey key, RedisValue value, When when = When.Always, CommandFlags flags = CommandFlags.None) { - WhenAlwaysOrExists(when); + when.AlwaysOrExists(); var msg = Message.Create(Database, flags, when == When.Always ? RedisCommand.LPUSH : RedisCommand.LPUSHX, key, value); return ExecuteAsync(msg, ResultProcessor.Int64); } public Task ListLeftPushAsync(RedisKey key, RedisValue[] values, When when = When.Always, CommandFlags flags = CommandFlags.None) { - WhenAlwaysOrExists(when); + when.AlwaysOrExists(); if (values == null) throw new ArgumentNullException(nameof(values)); var command = when == When.Always ? RedisCommand.LPUSH : RedisCommand.LPUSHX; var msg = values.Length == 0 ? Message.Create(Database, flags, RedisCommand.LLEN, key) : Message.Create(Database, flags, command, key, values); @@ -1700,14 +1700,14 @@ public Task ListRightPopLeftPushAsync(RedisKey source, RedisKey dest public long ListRightPush(RedisKey key, RedisValue value, When when = When.Always, CommandFlags flags = CommandFlags.None) { - WhenAlwaysOrExists(when); + when.AlwaysOrExists(); var msg = Message.Create(Database, flags, when == When.Always ? RedisCommand.RPUSH : RedisCommand.RPUSHX, key, value); return ExecuteSync(msg, ResultProcessor.Int64); } public long ListRightPush(RedisKey key, RedisValue[] values, When when = When.Always, CommandFlags flags = CommandFlags.None) { - WhenAlwaysOrExists(when); + when.AlwaysOrExists(); if (values == null) throw new ArgumentNullException(nameof(values)); var command = when == When.Always ? RedisCommand.RPUSH : RedisCommand.RPUSHX; var msg = values.Length == 0 ? Message.Create(Database, flags, RedisCommand.LLEN, key) : Message.Create(Database, flags, command, key, values); @@ -1723,14 +1723,14 @@ public long ListRightPush(RedisKey key, RedisValue[] values, CommandFlags flags public Task ListRightPushAsync(RedisKey key, RedisValue value, When when = When.Always, CommandFlags flags = CommandFlags.None) { - WhenAlwaysOrExists(when); + when.AlwaysOrExists(); var msg = Message.Create(Database, flags, when == When.Always ? RedisCommand.RPUSH : RedisCommand.RPUSHX, key, value); return ExecuteAsync(msg, ResultProcessor.Int64); } public Task ListRightPushAsync(RedisKey key, RedisValue[] values, When when = When.Always, CommandFlags flags = CommandFlags.None) { - WhenAlwaysOrExists(when); + when.AlwaysOrExists(); if (values == null) throw new ArgumentNullException(nameof(values)); var command = when == When.Always ? RedisCommand.RPUSH : RedisCommand.RPUSHX; var msg = values.Length == 0 ? Message.Create(Database, flags, RedisCommand.LLEN, key) : Message.Create(Database, flags, command, key, values); @@ -2271,7 +2271,7 @@ public bool SortedSetAdd(RedisKey key, RedisValue member, double score, CommandF SortedSetAdd(key, member, score, SortedSetWhen.Always, flags); public bool SortedSetAdd(RedisKey key, RedisValue member, double score, When when = When.Always, CommandFlags flags = CommandFlags.None) => - SortedSetAdd(key, member, score, SortedSetWhenExtensions.Parse(when), flags); + SortedSetAdd(key, member, score, SortedSetWhenExtensions.ToSortedSetWhen(when), flags); public bool SortedSetAdd(RedisKey key, RedisValue member, double score, SortedSetWhen when = SortedSetWhen.Always, CommandFlags flags = CommandFlags.None) { @@ -2289,7 +2289,7 @@ public long SortedSetAdd(RedisKey key, SortedSetEntry[] values, CommandFlags fla SortedSetAdd(key, values, SortedSetWhen.Always, flags); public long SortedSetAdd(RedisKey key, SortedSetEntry[] values, When when = When.Always, CommandFlags flags = CommandFlags.None) => - SortedSetAdd(key, values, SortedSetWhenExtensions.Parse(when), flags); + SortedSetAdd(key, values, SortedSetWhenExtensions.ToSortedSetWhen(when), flags); public long SortedSetAdd(RedisKey key, SortedSetEntry[] values, SortedSetWhen when = SortedSetWhen.Always, CommandFlags flags = CommandFlags.None) { @@ -2307,7 +2307,7 @@ public Task SortedSetAddAsync(RedisKey key, RedisValue member, double scor SortedSetAddAsync(key, member, score, SortedSetWhen.Always, flags); public Task SortedSetAddAsync(RedisKey key, RedisValue member, double score, When when = When.Always, CommandFlags flags = CommandFlags.None) => - SortedSetAddAsync(key, member, score, SortedSetWhenExtensions.Parse(when), flags); + SortedSetAddAsync(key, member, score, SortedSetWhenExtensions.ToSortedSetWhen(when), flags); public Task SortedSetAddAsync(RedisKey key, RedisValue member, double score, SortedSetWhen when = SortedSetWhen.Always, CommandFlags flags = CommandFlags.None) { @@ -2325,7 +2325,7 @@ public Task SortedSetAddAsync(RedisKey key, SortedSetEntry[] values, Comma SortedSetAddAsync(key, values, SortedSetWhen.Always, flags); public Task SortedSetAddAsync(RedisKey key, SortedSetEntry[] values, When when = When.Always, CommandFlags flags = CommandFlags.None) => - SortedSetAddAsync(key, values, SortedSetWhenExtensions.Parse(when), flags); + SortedSetAddAsync(key, values, SortedSetWhenExtensions.ToSortedSetWhen(when), flags); public Task SortedSetAddAsync(RedisKey key, SortedSetEntry[] values, SortedSetWhen when = SortedSetWhen.Always, CommandFlags flags = CommandFlags.None) { @@ -3776,7 +3776,7 @@ public Task StringSetRangeAsync(RedisKey key, long offset, RedisValu return ExecuteAsync(msg, ResultProcessor.RedisValue); } - private static long GetUnixTimeMilliseconds(DateTime when) => when.Kind switch + internal static long GetUnixTimeMilliseconds(DateTime when) => when.Kind switch { DateTimeKind.Local or DateTimeKind.Utc => (when.ToUniversalTime() - RedisBase.UnixEpoch).Ticks / TimeSpan.TicksPerMillisecond, _ => throw new ArgumentException("Expiry time must be either Utc or Local", nameof(when)), @@ -4598,12 +4598,15 @@ private Message GetStreamAddMessage(RedisKey key, RedisValue entryId, long? maxL throw new ArgumentOutOfRangeException(nameof(maxLength), "maxLength must be greater than 0."); } + var includeMaxLen = maxLength.HasValue ? 2 : 0; + var includeApproxLen = maxLength.HasValue && useApproximateMaxLength ? 1 : 0; + var totalLength = (streamPairs.Length * 2) // Room for the name/value pairs - + 1 // The stream entry ID - + (maxLength.HasValue ? 2 : 0) // MAXLEN N - + (maxLength.HasValue && useApproximateMaxLength ? 1 : 0) // ~ - + (mode == StreamTrimMode.KeepReferences ? 0 : 1) // relevant trim-mode keyword - + (limit.HasValue ? 2 : 0); // LIMIT N + + 1 // The stream entry ID + + (maxLength.HasValue ? 2 : 0) // MAXLEN N + + (maxLength.HasValue && useApproximateMaxLength ? 1 : 0) // ~ + + (mode == StreamTrimMode.KeepReferences ? 0 : 1) // relevant trim-mode keyword + + (limit.HasValue ? 2 : 0); // LIMIT N var values = new RedisValue[totalLength]; @@ -5024,7 +5027,7 @@ private Message GetStringGetWithExpiryMessage(RedisKey key, CommandFlags flags, case 0: return null; case 1: return GetStringSetMessage(values[0].Key, values[0].Value, null, false, when, flags); default: - WhenAlwaysOrNotExists(when); + when.AlwaysOrNotExists(); int slot = ServerSelectionStrategy.NoSlot, offset = 0; var args = new RedisValue[values.Length * 2]; var serverSelectionStrategy = multiplexer.ServerSelectionStrategy; @@ -5046,7 +5049,7 @@ private Message GetStringSetMessage( When when = When.Always, CommandFlags flags = CommandFlags.None) { - WhenAlwaysOrExistsOrNotExists(when); + when.AlwaysOrExistsOrNotExists(); if (value.IsNull) return Message.Create(Database, flags, RedisCommand.DEL, key); if (expiry == null || expiry.Value == TimeSpan.MaxValue) @@ -5095,7 +5098,7 @@ private Message GetStringSetAndGetMessage( When when = When.Always, CommandFlags flags = CommandFlags.None) { - WhenAlwaysOrExistsOrNotExists(when); + when.AlwaysOrExistsOrNotExists(); if (value.IsNull) return Message.Create(Database, flags, RedisCommand.GETDEL, key); if (expiry == null || expiry.Value == TimeSpan.MaxValue) diff --git a/src/StackExchange.Redis/StackExchange.Redis.csproj b/src/StackExchange.Redis/StackExchange.Redis.csproj index b13a12423..761faa819 100644 --- a/src/StackExchange.Redis/StackExchange.Redis.csproj +++ b/src/StackExchange.Redis/StackExchange.Redis.csproj @@ -2,7 +2,7 @@ enable - net461;netstandard2.0;net472;netcoreapp3.1;net6.0;net8.0 + net461;netstandard2.0;net472;net6.0;net8.0;net9.0 High performance Redis client, incorporating both synchronous and asynchronous usage. StackExchange.Redis StackExchange.Redis @@ -46,6 +46,7 @@ + diff --git a/tests/BasicTest/BasicTest.csproj b/tests/BasicTest/BasicTest.csproj index 593d26619..97b916dd5 100644 --- a/tests/BasicTest/BasicTest.csproj +++ b/tests/BasicTest/BasicTest.csproj @@ -2,11 +2,10 @@ StackExchange.Redis.BasicTest .NET Core - net472;net8.0 + net472;net8.0;net9.0 BasicTest Exe BasicTest - @@ -15,6 +14,9 @@ + + + diff --git a/tests/BasicTest/CustomConfig.cs b/tests/BasicTest/CustomConfig.cs new file mode 100644 index 000000000..d062f0f1f --- /dev/null +++ b/tests/BasicTest/CustomConfig.cs @@ -0,0 +1,30 @@ +using System.Runtime.InteropServices; +using BenchmarkDotNet.Columns; +using BenchmarkDotNet.Configs; +using BenchmarkDotNet.Diagnosers; +using BenchmarkDotNet.Environments; +using BenchmarkDotNet.Jobs; +using BenchmarkDotNet.Validators; + +namespace BasicTest; + +internal class CustomConfig : ManualConfig +{ + protected virtual Job Configure(Job j) + => j.WithGcMode(new GcMode { Force = true }) + // .With(InProcessToolchain.Instance) + ; + + public CustomConfig() + { + AddDiagnoser(MemoryDiagnoser.Default); + AddColumn(StatisticColumn.OperationsPerSecond); + AddValidator(JitOptimizationsValidator.FailOnError); + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + AddJob(Configure(Job.Default.WithRuntime(ClrRuntime.Net472))); + } + + AddJob(Configure(Job.Default.WithRuntime(CoreRuntime.Core80))); + } +} diff --git a/tests/BasicTest/Issue898.cs b/tests/BasicTest/Issue898.cs new file mode 100644 index 000000000..00f449e16 --- /dev/null +++ b/tests/BasicTest/Issue898.cs @@ -0,0 +1,79 @@ +using System; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; +using StackExchange.Redis; + +namespace BasicTest; + +[Config(typeof(SlowConfig))] +public class Issue898 : IDisposable +{ + private readonly ConnectionMultiplexer mux; + private readonly IDatabase db; + + public void Dispose() + { + mux?.Dispose(); + GC.SuppressFinalize(this); + } + + public Issue898() + { + mux = ConnectionMultiplexer.Connect("127.0.0.1:6379"); + db = mux.GetDatabase(); + } + + private const int Max = 100000; + + [Benchmark(OperationsPerInvoke = Max)] + public void Load() + { + for (int i = 0; i < Max; ++i) + { + db.StringSet(i.ToString(), i); + } + } + + [Benchmark(OperationsPerInvoke = Max)] + public async Task LoadAsync() + { + for (int i = 0; i < Max; ++i) + { + await db.StringSetAsync(i.ToString(), i).ConfigureAwait(false); + } + } + + [Benchmark(OperationsPerInvoke = Max)] + public void Sample() + { + var rnd = new Random(); + + for (int i = 0; i < Max; ++i) + { + var r = rnd.Next(0, Max - 1); + + var rv = db.StringGet(r.ToString()); + if (rv != r) + { + throw new Exception($"Unexpected {rv}, expected {r}"); + } + } + } + + [Benchmark(OperationsPerInvoke = Max)] + public async Task SampleAsync() + { + var rnd = new Random(); + + for (int i = 0; i < Max; ++i) + { + var r = rnd.Next(0, Max - 1); + + var rv = await db.StringGetAsync(r.ToString()).ConfigureAwait(false); + if (rv != r) + { + throw new Exception($"Unexpected {rv}, expected {r}"); + } + } + } +} diff --git a/tests/BasicTest/Program.cs b/tests/BasicTest/Program.cs index 2977c42c2..c32eef998 100644 --- a/tests/BasicTest/Program.cs +++ b/tests/BasicTest/Program.cs @@ -1,276 +1,4 @@ -using System; -using System.Reflection; -using System.Threading.Tasks; -using BenchmarkDotNet.Attributes; -using BenchmarkDotNet.Columns; -using BenchmarkDotNet.Configs; -using BenchmarkDotNet.Diagnosers; -using BenchmarkDotNet.Environments; -using BenchmarkDotNet.Jobs; +using System.Reflection; using BenchmarkDotNet.Running; -using BenchmarkDotNet.Validators; -using StackExchange.Redis; -namespace BasicTest -{ - internal static class Program - { - private static void Main(string[] args) => BenchmarkSwitcher.FromAssembly(typeof(Program).GetTypeInfo().Assembly).Run(args); - } - internal class CustomConfig : ManualConfig - { - protected virtual Job Configure(Job j) - => j.WithGcMode(new GcMode { Force = true }) - // .With(InProcessToolchain.Instance) - ; - - public CustomConfig() - { - AddDiagnoser(MemoryDiagnoser.Default); - AddColumn(StatisticColumn.OperationsPerSecond); - AddValidator(JitOptimizationsValidator.FailOnError); - - AddJob(Configure(Job.Default.WithRuntime(ClrRuntime.Net472))); - AddJob(Configure(Job.Default.WithRuntime(CoreRuntime.Core50))); - } - } - internal sealed class SlowConfig : CustomConfig - { - protected override Job Configure(Job j) - => j.WithLaunchCount(1) - .WithWarmupCount(1) - .WithIterationCount(5); - } - - [Config(typeof(CustomConfig))] - public class RedisBenchmarks : IDisposable - { - private SocketManager mgr; - private ConnectionMultiplexer connection; - private IDatabase db; - - [GlobalSetup] - public void Setup() - { - // Pipelines.Sockets.Unofficial.SocketConnection.AssertDependencies(); - var options = ConfigurationOptions.Parse("127.0.0.1:6379"); - connection = ConnectionMultiplexer.Connect(options); - db = connection.GetDatabase(3); - - db.KeyDelete(GeoKey); - db.GeoAdd(GeoKey, 13.361389, 38.115556, "Palermo "); - db.GeoAdd(GeoKey, 15.087269, 37.502669, "Catania"); - - db.KeyDelete(HashKey); - for (int i = 0; i < 1000; i++) - { - db.HashSet(HashKey, i, i); - } - } - - private static readonly RedisKey GeoKey = "GeoTest", IncrByKey = "counter", StringKey = "string", HashKey = "hash"; - void IDisposable.Dispose() - { - mgr?.Dispose(); - connection?.Dispose(); - mgr = null; - db = null; - connection = null; - GC.SuppressFinalize(this); - } - - private const int COUNT = 50; - - /// - /// Run INCRBY lots of times. - /// - // [Benchmark(Description = "INCRBY/s", OperationsPerInvoke = COUNT)] - public int ExecuteIncrBy() - { - var rand = new Random(12345); - - db.KeyDelete(IncrByKey, CommandFlags.FireAndForget); - int expected = 0; - for (int i = 0; i < COUNT; i++) - { - int x = rand.Next(50); - expected += x; - db.StringIncrement(IncrByKey, x, CommandFlags.FireAndForget); - } - int actual = (int)db.StringGet(IncrByKey); - if (actual != expected) throw new InvalidOperationException($"expected: {expected}, actual: {actual}"); - return actual; - } - - /// - /// Run INCRBY lots of times. - /// - // [Benchmark(Description = "INCRBY/a", OperationsPerInvoke = COUNT)] - public async Task ExecuteIncrByAsync() - { - var rand = new Random(12345); - - db.KeyDelete(IncrByKey, CommandFlags.FireAndForget); - int expected = 0; - for (int i = 0; i < COUNT; i++) - { - int x = rand.Next(50); - expected += x; - await db.StringIncrementAsync(IncrByKey, x, CommandFlags.FireAndForget).ConfigureAwait(false); - } - int actual = (int)await db.StringGetAsync(IncrByKey).ConfigureAwait(false); - if (actual != expected) throw new InvalidOperationException($"expected: {expected}, actual: {actual}"); - return actual; - } - - /// - /// Run GEORADIUS lots of times. - /// - // [Benchmark(Description = "GEORADIUS/s", OperationsPerInvoke = COUNT)] - public int ExecuteGeoRadius() - { - int total = 0; - for (int i = 0; i < COUNT; i++) - { - var results = db.GeoRadius(GeoKey, 15, 37, 200, GeoUnit.Kilometers, options: GeoRadiusOptions.WithCoordinates | GeoRadiusOptions.WithDistance | GeoRadiusOptions.WithGeoHash); - total += results.Length; - } - return total; - } - - /// - /// Run GEORADIUS lots of times. - /// - // [Benchmark(Description = "GEORADIUS/a", OperationsPerInvoke = COUNT)] - public async Task ExecuteGeoRadiusAsync() - { - int total = 0; - for (int i = 0; i < COUNT; i++) - { - var results = await db.GeoRadiusAsync(GeoKey, 15, 37, 200, GeoUnit.Kilometers, options: GeoRadiusOptions.WithCoordinates | GeoRadiusOptions.WithDistance | GeoRadiusOptions.WithGeoHash).ConfigureAwait(false); - total += results.Length; - } - return total; - } - - /// - /// Run StringSet lots of times. - /// - [Benchmark(Description = "StringSet/s", OperationsPerInvoke = COUNT)] - public void StringSet() - { - for (int i = 0; i < COUNT; i++) - { - db.StringSet(StringKey, "hey"); - } - } - - /// - /// Run StringGet lots of times. - /// - [Benchmark(Description = "StringGet/s", OperationsPerInvoke = COUNT)] - public void StringGet() - { - for (int i = 0; i < COUNT; i++) - { - db.StringGet(StringKey); - } - } - - /// - /// Run HashGetAll lots of times. - /// - [Benchmark(Description = "HashGetAll F+F/s", OperationsPerInvoke = COUNT)] - public void HashGetAll_FAF() - { - for (int i = 0; i < COUNT; i++) - { - db.HashGetAll(HashKey, CommandFlags.FireAndForget); - db.Ping(); // to wait for response - } - } - - /// - /// Run HashGetAll lots of times. - /// - [Benchmark(Description = "HashGetAll F+F/a", OperationsPerInvoke = COUNT)] - - public async Task HashGetAllAsync_FAF() - { - for (int i = 0; i < COUNT; i++) - { - await db.HashGetAllAsync(HashKey, CommandFlags.FireAndForget); - await db.PingAsync(); // to wait for response - } - } - } - - [Config(typeof(SlowConfig))] - public class Issue898 : IDisposable - { - private readonly ConnectionMultiplexer mux; - private readonly IDatabase db; - - public void Dispose() - { - mux?.Dispose(); - GC.SuppressFinalize(this); - } - public Issue898() - { - mux = ConnectionMultiplexer.Connect("127.0.0.1:6379"); - db = mux.GetDatabase(); - } - - private const int Max = 100000; - [Benchmark(OperationsPerInvoke = Max)] - public void Load() - { - for (int i = 0; i < Max; ++i) - { - db.StringSet(i.ToString(), i); - } - } - [Benchmark(OperationsPerInvoke = Max)] - public async Task LoadAsync() - { - for (int i = 0; i < Max; ++i) - { - await db.StringSetAsync(i.ToString(), i).ConfigureAwait(false); - } - } - [Benchmark(OperationsPerInvoke = Max)] - public void Sample() - { - var rnd = new Random(); - - for (int i = 0; i < Max; ++i) - { - var r = rnd.Next(0, Max - 1); - - var rv = db.StringGet(r.ToString()); - if (rv != r) - { - throw new Exception($"Unexpected {rv}, expected {r}"); - } - } - } - - [Benchmark(OperationsPerInvoke = Max)] - public async Task SampleAsync() - { - var rnd = new Random(); - - for (int i = 0; i < Max; ++i) - { - var r = rnd.Next(0, Max - 1); - - var rv = await db.StringGetAsync(r.ToString()).ConfigureAwait(false); - if (rv != r) - { - throw new Exception($"Unexpected {rv}, expected {r}"); - } - } - } - } -} +BenchmarkSwitcher.FromAssembly(typeof(Program).GetTypeInfo().Assembly).Run(args); diff --git a/tests/BasicTest/RedisBenchmarks.cs b/tests/BasicTest/RedisBenchmarks.cs new file mode 100644 index 000000000..240a3f471 --- /dev/null +++ b/tests/BasicTest/RedisBenchmarks.cs @@ -0,0 +1,462 @@ +using System; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; +#if !TEST_BASELINE +using RESPite; +using RESPite.Connections; +using RESPite.Redis; +#if !PREVIEW_LANGVER +using RESPite.Redis.Alt; // needed for AsStrings() etc +#endif +#endif +using StackExchange.Redis; + +namespace BasicTest; + +[Config(typeof(CustomConfig))] +public class RedisBenchmarks : IDisposable +{ + private SocketManager mgr; + private ConnectionMultiplexer connection; + private IDatabase db; +#if !TEST_BASELINE + private RespConnectionPool pool, customPool; +#endif + + [GlobalSetup] + public void Setup() + { + // Pipelines.Sockets.Unofficial.SocketConnection.AssertDependencies(); +#if !TEST_BASELINE + pool = new(); +#pragma warning disable CS0618 // Type or member is obsolete + customPool = new() { UseCustomNetworkStream = true }; +#pragma warning restore CS0618 // Type or member is obsolete +#endif + // var options = ConfigurationOptions.Parse("127.0.0.1:6379"); + // connection = ConnectionMultiplexer.Connect(options); + // db = connection.GetDatabase(3); + // + // db.KeyDelete(GeoKey); + // db.KeyDelete(StringKey_K); + // db.StringSet(StringKey_K, StringValue_S); + // db.GeoAdd(GeoKey, 13.361389, 38.115556, "Palermo "); + // db.GeoAdd(GeoKey, 15.087269, 37.502669, "Catania"); + // + // db.KeyDelete(HashKey); + // for (int i = 0; i < 1000; i++) + // { + // db.HashSet(HashKey, i, i); + // } + } + + public const string StringKey_S = "string", StringValue_S = "some suitably non-trivial value"; + + public static readonly RedisKey GeoKey = "GeoTest", + IncrByKey = "counter", + StringKey_K = StringKey_S, + HashKey = "hash"; + + public static readonly RedisValue StringValue_V = StringValue_S; + + void IDisposable.Dispose() + { +#if !TEST_BASELINE + pool?.Dispose(); + customPool?.Dispose(); +#endif + mgr?.Dispose(); + connection?.Dispose(); + mgr = null; + db = null; + connection = null; + GC.SuppressFinalize(this); + } + + public const int OperationsPerInvoke = 128; + + /// + /// Run INCRBY lots of times. + /// + // [Benchmark(Description = "INCRBY/s", OperationsPerInvoke = COUNT)] + public int ExecuteIncrBy() + { + var rand = new Random(12345); + + db.KeyDelete(IncrByKey, CommandFlags.FireAndForget); + int expected = 0; + for (int i = 0; i < OperationsPerInvoke; i++) + { + int x = rand.Next(50); + expected += x; + db.StringIncrement(IncrByKey, x, CommandFlags.FireAndForget); + } + + int actual = (int)db.StringGet(IncrByKey); + if (actual != expected) throw new InvalidOperationException($"expected: {expected}, actual: {actual}"); + return actual; + } + + /// + /// Run INCRBY lots of times. + /// + // [Benchmark(Description = "INCRBY/a", OperationsPerInvoke = COUNT)] + public async Task ExecuteIncrByAsync() + { + var rand = new Random(12345); + + db.KeyDelete(IncrByKey, CommandFlags.FireAndForget); + int expected = 0; + for (int i = 0; i < OperationsPerInvoke; i++) + { + int x = rand.Next(50); + expected += x; + await db.StringIncrementAsync(IncrByKey, x, CommandFlags.FireAndForget).ConfigureAwait(false); + } + + int actual = (int)await db.StringGetAsync(IncrByKey).ConfigureAwait(false); + if (actual != expected) throw new InvalidOperationException($"expected: {expected}, actual: {actual}"); + return actual; + } + + /// + /// Run GEORADIUS lots of times. + /// + // [Benchmark(Description = "GEORADIUS/s", OperationsPerInvoke = COUNT)] + public int ExecuteGeoRadius() + { + int total = 0; + const GeoRadiusOptions options = GeoRadiusOptions.WithCoordinates | GeoRadiusOptions.WithDistance | + GeoRadiusOptions.WithGeoHash; + for (int i = 0; i < OperationsPerInvoke; i++) + { + var results = db.GeoRadius( + GeoKey, + 15, + 37, + 200, + GeoUnit.Kilometers, + options: options); + total += results.Length; + } + + return total; + } + + /// + /// Run GEORADIUS lots of times. + /// + // [Benchmark(Description = "GEORADIUS/a", OperationsPerInvoke = COUNT)] + public async Task ExecuteGeoRadiusAsync() + { + var options = GeoRadiusOptions.WithCoordinates | GeoRadiusOptions.WithDistance | + GeoRadiusOptions.WithGeoHash; + int total = 0; + for (int i = 0; i < OperationsPerInvoke; i++) + { + var results = await db.GeoRadiusAsync( + GeoKey, + 15, + 37, + 200, + GeoUnit.Kilometers, + options: options) + .ConfigureAwait(false); + total += results.Length; + } + + return total; + } + + /// + /// Run StringSet lots of times. + /// + // [Benchmark(Description = "StringSet/s", OperationsPerInvoke = COUNT)] + public void StringSet() + { + for (int i = 0; i < OperationsPerInvoke; i++) + { + db.StringSet(StringKey_K, StringValue_V); + } + } + + /// + /// Run StringGet lots of times. + /// + // [Benchmark(Description = "StringGet/s", OperationsPerInvoke = COUNT)] + public void StringGet() + { + for (int i = 0; i < OperationsPerInvoke; i++) + { + db.StringGet(StringKey_K); + } + } + +#if !TEST_BASELINE + /// + /// Run StringSet lots of times. + /// + // [Benchmark(Description = "C StringSet/s", OperationsPerInvoke = COUNT)] + public void StringSet_Core() + { + using var conn = pool.GetConnection(); +#if PREVIEW_LANGVER + ref readonly RedisStrings s = ref conn.Context.Strings; +#else + var s = conn.Context.AsStrings(); +#endif + for (int i = 0; i < OperationsPerInvoke; i++) + { + s.Set(StringKey_S, StringValue_S); + } + } + + /// + /// Run StringGet lots of times. + /// + // [Benchmark(Description = "C StringGet/s", OperationsPerInvoke = COUNT)] + public void StringGet_Core() + { + using var conn = pool.GetConnection(); +#if PREVIEW_LANGVER + ref readonly RedisStrings s = ref conn.Context.Strings; +#else + var s = conn.Context.AsStrings(); +#endif + for (int i = 0; i < OperationsPerInvoke; i++) + { + s.Get(StringKey_S); + } + } + + /// + /// Run StringSet lots of times. + /// + // [Benchmark(Description = "PC StringSet/s", OperationsPerInvoke = COUNT)] + public void StringSet_Pipelined_Core() + { + using var conn = pool.GetConnection().Synchronized(); +#if PREVIEW_LANGVER + ref readonly RedisStrings s = ref conn.Context.Strings; +#else + var s = conn.Context.AsStrings(); +#endif + for (int i = 0; i < OperationsPerInvoke; i++) + { + s.Set(StringKey_S, StringValue_S); + } + } + + /// + /// Run StringSet lots of times. + /// + // [Benchmark(Description = "PCA StringSet/s", OperationsPerInvoke = COUNT)] + public async Task StringSet_Pipelined_Core_Async() + { + using var conn = pool.GetConnection().Synchronized(); + var ctx = conn.Context; + for (int i = 0; i < OperationsPerInvoke; i++) + { +#if PREVIEW_LANGVER + await ctx.Strings.SetAsync(StringKey_S, StringValue_S); +#else + await ctx.AsStrings().SetAsync(StringKey_S, StringValue_S); +#endif + } + } + + /// + /// Run StringGet lots of times. + /// + // [Benchmark(Description = "PC StringGet/s", OperationsPerInvoke = COUNT)] + public void StringGet_Pipelined_Core() + { + using var conn = pool.GetConnection().Synchronized(); +#if PREVIEW_LANGVER + ref readonly RedisStrings s = ref conn.Context.Strings; +#else + var s = conn.Context.AsStrings(); +#endif + for (int i = 0; i < OperationsPerInvoke; i++) + { + s.Get(StringKey_S); + } + } + + /// + /// Run StringGet lots of times. + /// + // [Benchmark(Description = "PCA StringGet/s", OperationsPerInvoke = COUNT)] + public async Task StringGet_Pipelined_Core_Async() + { + using var conn = pool.GetConnection().Synchronized(); + var ctx = conn.Context; + for (int i = 0; i < OperationsPerInvoke; i++) + { +#if PREVIEW_LANGVER + await ctx.Strings.GetAsync(StringKey_S); +#else + await ctx.AsStrings().GetAsync(StringKey_S); +#endif + } + } +#endif + + /// + /// Run HashGetAll lots of times. + /// + // [Benchmark(Description = "HashGetAll F+F/s", OperationsPerInvoke = COUNT)] + public void HashGetAll_FAF() + { + for (int i = 0; i < OperationsPerInvoke; i++) + { + db.HashGetAll(HashKey, CommandFlags.FireAndForget); + db.Ping(); // to wait for response + } + } + + /// + /// Run HashGetAll lots of times. + /// + // [Benchmark(Description = "HashGetAll F+F/a", OperationsPerInvoke = COUNT)] + public async Task HashGetAllAsync_FAF() + { + for (int i = 0; i < OperationsPerInvoke; i++) + { + await db.HashGetAllAsync(HashKey, CommandFlags.FireAndForget); + await db.PingAsync(); // to wait for response + } + } + + /// + /// Run incr lots of times. + /// + // [Benchmark(Description = "old incr", OperationsPerInvoke = OperationsPerInvoke)] + public int IncrBy_Old() + { + RedisValue value = 0; + db.StringSet(StringKey_K, value); + for (int i = 0; i < OperationsPerInvoke; i++) + { + value = db.StringIncrement(StringKey_K); + } + + return (int)value; + } + +#if !TEST_BASELINE + /// + /// Run incr lots of times. + /// + [Benchmark(Description = "new incr /p", OperationsPerInvoke = OperationsPerInvoke)] + public int IncrBy_New_Pipelined() + { + using var conn = pool.GetConnection().Synchronized(); +#if PREVIEW_LANGVER + ref readonly RedisStrings s = ref conn.Context.Strings; +#else + var s = conn.Context.AsStrings(); +#endif + int value = 0; + s.Set(StringKey_S, value); + for (int i = 0; i < OperationsPerInvoke; i++) + { + value = s.Incr(StringKey_K); + } + + return value; + } + + /// + /// Run incr lots of times. + /// + [Benchmark(Description = "new incr /p/a", OperationsPerInvoke = OperationsPerInvoke)] + public async Task IncrBy_New_Pipelined_Async() + { + using var conn = pool.GetConnection().Synchronized(); + var ctx = conn.Context; + int value = 0; +#if PREVIEW_LANGVER + await ctx.Strings.SetAsync(StringKey_S, value); +#else + await ctx.AsStrings().SetAsync(StringKey_S, value); +#endif + for (int i = 0; i < OperationsPerInvoke; i++) + { +#if PREVIEW_LANGVER + value = await ctx.Strings.IncrAsync(StringKey_K); +#else + value = await ctx.AsStrings().IncrAsync(StringKey_K); +#endif + } + + return value; + } + + /// + /// Run incr lots of times. + /// + [Benchmark(Description = "new incr", OperationsPerInvoke = OperationsPerInvoke)] + public int IncrBy_New() + { + using var conn = pool.GetConnection(); +#if PREVIEW_LANGVER + ref readonly RedisStrings s = ref conn.Context.Strings; +#else + var s = conn.Context.AsStrings(); +#endif + int value = 0; + s.Set(StringKey_S, value); + for (int i = 0; i < OperationsPerInvoke; i++) + { + value = s.Incr(StringKey_K); + } + + return value; + } + + /// + /// Run incr lots of times. + /// + // [Benchmark(Description = "new incr /pc", OperationsPerInvoke = OperationsPerInvoke)] + public int IncrBy_New_Pipelined_Custom() + { + using var conn = customPool.GetConnection().Synchronized(); +#if PREVIEW_LANGVER + ref readonly RedisStrings s = ref conn.Context.Strings; +#else + var s = conn.Context.AsStrings(); +#endif + int value = 0; + s.Set(StringKey_S, value); + for (int i = 0; i < OperationsPerInvoke; i++) + { + value = s.Incr(StringKey_K); + } + + return value; + } + + /// + /// Run incr lots of times. + /// + // [Benchmark(Description = "new incr /c", OperationsPerInvoke = OperationsPerInvoke)] + public int IncrBy_New_Custom() + { + using var conn = customPool.GetConnection(); +#if PREVIEW_LANGVER + ref readonly RedisStrings s = ref conn.Context.Strings; +#else + var s = conn.Context.AsStrings(); +#endif + int value = 0; + s.Set(StringKey_S, value); + for (int i = 0; i < OperationsPerInvoke; i++) + { + value = s.Incr(StringKey_K); + } + + return value; + } +#endif +} diff --git a/tests/BasicTest/SlowConfig.cs b/tests/BasicTest/SlowConfig.cs new file mode 100644 index 000000000..cc5aa4537 --- /dev/null +++ b/tests/BasicTest/SlowConfig.cs @@ -0,0 +1,11 @@ +using BenchmarkDotNet.Jobs; + +namespace BasicTest; + +internal sealed class SlowConfig : CustomConfig +{ + protected override Job Configure(Job j) + => j.WithLaunchCount(1) + .WithWarmupCount(1) + .WithIterationCount(5); +} diff --git a/tests/BasicTestBaseline/BasicTestBaseline.csproj b/tests/BasicTestBaseline/BasicTestBaseline.csproj index a9f75e441..7bc6d3697 100644 --- a/tests/BasicTestBaseline/BasicTestBaseline.csproj +++ b/tests/BasicTestBaseline/BasicTestBaseline.csproj @@ -6,12 +6,11 @@ BasicTestBaseline Exe BasicTestBaseline - $(DefineConstants);TEST_BASELINE - + diff --git a/tests/RESPite.Tests/BasicIntegrationTests.cs b/tests/RESPite.Tests/BasicIntegrationTests.cs new file mode 100644 index 000000000..312699b54 --- /dev/null +++ b/tests/RESPite.Tests/BasicIntegrationTests.cs @@ -0,0 +1,99 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using RESPite.Connections; +using RESPite.Messages; +using RESPite.Redis.Alt; // needed for AsStrings() etc +using Xunit; + +namespace RESPite.Tests; + +public class BasicIntegrationTests(ConnectionFixture fixture, ITestOutputHelper log) : IntegrationTestBase(fixture, log) +{ + [Fact] + public void Format() + { + Span buffer = stackalloc byte[128]; + var writer = new RespWriter(buffer); + RespFormatters.Value.String.Format("get"u8, ref writer, "abc"); + writer.Flush(); + Assert.Equal("*2\r\n$3\r\nget\r\n$3\r\nabc\r\n", writer.DebugBuffer()); + } + + [Fact] + public void Parse() + { + ReadOnlySpan buffer = "$3\r\nabc\r\n"u8; + var reader = new RespReader(buffer); + reader.MoveNext(); + var value = RespParsers.String.Parse(ref reader); + reader.DemandEnd(); + Assert.Equal("abc", value); + } + + [Theory] + [InlineData(1)] + [InlineData(5)] + [InlineData(100)] + [InlineData(1000)] + public void Ping(int count) + { + using var conn = GetConnection(); + var ctx = conn.Context; + for (int i = 0; i < count; i++) + { + var key = $"{Me()}{i}"; + ctx.AsStrings().Set(key, $"def{i}"); + var val = ctx.AsStrings().Get(key); + Assert.Equal($"def{i}", val); + } + } + + [Theory] + [InlineData(1)] + [InlineData(5)] + [InlineData(100)] + [InlineData(1000)] + public async Task PingAsync(int count) + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + await using var conn = GetConnection(); + for (int i = 0; i < count; i++) + { + var ctx = conn.Context.WithCancellationToken(cts.Token); + var key = $"{Me()}{i}"; + await ctx.AsStrings().SetAsync(key, $"def{i}"); + var val = await ctx.AsStrings().GetAsync(key); + Assert.Equal($"def{i}", val); + } + } + + [Theory] + [InlineData(1, false)] + [InlineData(5, false)] + [InlineData(100, false)] + [InlineData(1, true)] + [InlineData(5, true)] + [InlineData(100, true)] + public async Task PingPipelinedAsync(int count, bool forPipeline) + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + + await using var conn = forPipeline ? GetConnection().Synchronized() : GetConnection(); + + ValueTask[] tasks = new ValueTask[count]; + for (int i = 0; i < count; i++) + { + RespContext ctx = conn.Context.WithCancellationToken(cts.Token); + var key = $"{Me()}{i}"; + _ = ctx.AsStrings().SetAsync(key, $"def{i}"); + tasks[i] = ctx.AsStrings().GetAsync(key); + } + + for (int i = 0; i < count; i++) + { + var val = await tasks[i]; + Assert.Equal($"def{i}", val); + } + } +} diff --git a/tests/RESPite.Tests/BatchTests.cs b/tests/RESPite.Tests/BatchTests.cs new file mode 100644 index 000000000..17e5f515f --- /dev/null +++ b/tests/RESPite.Tests/BatchTests.cs @@ -0,0 +1,97 @@ +using System; +using System.Threading.Tasks; +using Xunit; + +namespace RESPite.Tests; + +public partial class BatchTests +{ + [Fact] + public async Task TestInfrastructure() + { + await TestServer.Execute(ctx => FooAsync(ctx), "*1\r\n$3\r\nfoo\r\n"u8, ":42\r\n"u8, 42); + await TestServer.Execute(ctx => FooAsync(ctx), "*1\r\n$3\r\nfoo\r\n", ":42\r\n", 42); + await TestServer.Execute(ctx => BarAsync(ctx), "*1\r\n$3\r\nbar\r\n"u8, "+ok\r\n"u8); + await TestServer.Execute(ctx => BarAsync(ctx), "*1\r\n$3\r\nbar\r\n", "+OK\r\n"); + } + + [Fact(Timeout = 500)] // this should be very fast unless something is very wrong + public async Task SimpleBatching() + { + // server setup + using var server = new TestServer(); + var cancellationToken = server.Context.CancellationToken; + Assert.Equal(TestContext.Current.CancellationToken, cancellationToken); // check server has CT + Assert.True(cancellationToken.CanBeCanceled); + + // prepare a batch + ValueTask a, b, c, d, e, f; + using (var batch = server.Context.CreateBatch()) + { + Assert.Equal(cancellationToken, batch.Context.CancellationToken); // check the batch inherited CT + + b = TestAsync(batch.Context, 1); + Assert.Equal(cancellationToken, b.AsRespOperation().CancellationToken); // check batch ops inherit CT + c = TestAsync(batch.Context, 2); + d = TestAsync(batch.Context, 3); + + // we want to sandwich the batch between two regular operations + a = TestAsync(server.Context, 0); // uses SERVER + Assert.Equal(cancellationToken, a.AsRespOperation().CancellationToken); // check server ops inherit CT + Assert.True(a.AsRespOperation().IsSent); + Assert.False(d.AsRespOperation().IsSent); + await batch.FlushAsync(); // uses BATCH + + // await something not flushed, inside the scope of the batch + f = TestAsync(batch.Context, 10); + + // Because of https://github.com/dotnet/runtime/issues/119232, we can't detect unsent operations + // in ValueTask/Task (technically we could for ValueTask[T], but it would break .AsTask()), but + // we can check the unwrapped handling. + var ex = await Assert.ThrowsAsync(async () => await f.AsRespOperation()); + Assert.StartsWith("This command has not yet been sent", ex.Message); + + // and try one that escapes the batch (should get disposed) + f = TestAsync(batch.Context, 10); // never flushed, intentionally + } + // we *can* safely await if the batch is disposed + await Assert.ThrowsAsync(async () => await f); + + // check what was sent + server.AssertSent("*2\r\n$4\r\ntest\r\n$1\r\n0\r\n"u8); + server.AssertSent("*2\r\n$4\r\ntest\r\n$1\r\n1\r\n"u8); + server.AssertSent("*2\r\n$4\r\ntest\r\n$1\r\n2\r\n"u8); + server.AssertSent("*2\r\n$4\r\ntest\r\n$1\r\n3\r\n"u8); + + server.AssertAllSent(); // that's everything + + Assert.True(d.AsRespOperation().IsSent, "batch ops should report as sent"); + e = TestAsync(server.Context, 4); // uses SERVER again + server.AssertSent("*2\r\n$4\r\ntest\r\n$1\r\n4\r\n"u8); + server.AssertAllSent(); // that's everything + + // check what is received (all in one chunk) + server.Respond(":5\r\n:6\r\n:7\r\n:8\r\n:9\r\n"u8); + Assert.Equal(5, await a); + Assert.Equal(6, await b); + Assert.Equal(7, await c); + Assert.Equal(8, await d); + Assert.Equal(9, await e); + + // but can only be awaited once + await Assert.ThrowsAsync(async () => await a); + await Assert.ThrowsAsync(async () => await b); + await Assert.ThrowsAsync(async () => await c); + await Assert.ThrowsAsync(async () => await d); + await Assert.ThrowsAsync(async () => await e); + } + + [RespCommand] + private static partial int Test(in RespContext ctx, int value); + + [RespCommand] + private static partial int Foo(in RespContext ctx); + + [RespCommand] + private static partial void Bar(in RespContext ctx); +} diff --git a/tests/RESPite.Tests/BlockBufferTests.cs b/tests/RESPite.Tests/BlockBufferTests.cs new file mode 100644 index 000000000..923a4dc15 --- /dev/null +++ b/tests/RESPite.Tests/BlockBufferTests.cs @@ -0,0 +1,160 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; +using RESPite.Internal; +using Xunit; + +namespace RESPite.Tests; + +public class BlockBufferTests(ITestOutputHelper log) +{ + private void Log(ReadOnlySpan span) + { +#if NET + log.WriteLine(Encoding.UTF8.GetString(span)); +#else + unsafe + { + fixed (byte* p = span) + { + log.WriteLine(Encoding.UTF8.GetString(p, span.Length)); + } + } +#endif + } + + [Fact] + public void CanCreateAndWriteSimpleBuffer() + { + var buffer = BlockBufferSerializer.Create(); + var a = buffer.Serialize(null, "get"u8, "abc", RespFormatters.Key.String); + var b = buffer.Serialize(null, "get"u8, "def", RespFormatters.Key.String); + var c = buffer.Serialize(null, "get"u8, "ghi", RespFormatters.Key.String); + buffer.Clear(); +#if DEBUG + Assert.Equal(1, buffer.CountAdded); + Assert.Equal(3, buffer.CountMessages); + Assert.Equal(66, buffer.CountMessageBytes); // contents shown/verified below + Assert.Equal(0, buffer.CountRecycled); + Assert.Equal(0, buffer.CountLeaked); +#endif + // check the payloads + Log(a.Span); + Assert.True(a.Span.SequenceEqual("*2\r\n$3\r\nget\r\n$3\r\nabc\r\n"u8)); + Log(a.Span); + Assert.True(b.Span.SequenceEqual("*2\r\n$3\r\nget\r\n$3\r\ndef\r\n"u8)); + Log(c.Span); + Assert.True(c.Span.SequenceEqual("*2\r\n$3\r\nget\r\n$3\r\nghi\r\n"u8)); + AssertRelease(a); + AssertRelease(b); +#if DEBUG + Assert.Equal(0, buffer.CountRecycled); + Assert.Equal(0, buffer.CountLeaked); +#endif + AssertRelease(c); +#if DEBUG + Assert.Equal(1, buffer.CountRecycled); + Assert.Equal(0, buffer.CountLeaked); +#endif + } + + private static void AssertRelease(ReadOnlyMemory buffer) + { + Assert.True(MemoryMarshal.TryGetMemoryManager(buffer, out var manager)); + manager.Release(); + } + + [Fact] + public void CanWriteLotsOfBuffers_WithCheapReset() // when messages are consumed before more are added + { + var buffer = BlockBufferSerializer.Create(); +#if DEBUG + Assert.Equal(0, buffer.CountAdded); + Assert.Equal(0, buffer.CountRecycled); + Assert.Equal(0, buffer.CountLeaked); + Assert.Equal(0, buffer.CountMessages); +#endif + for (int i = 0; i < 5000; i++) + { + var a = buffer.Serialize(null, "get"u8, "abc", RespFormatters.Key.String); + var b = buffer.Serialize(null, "get"u8, "def", RespFormatters.Key.String); + var c = buffer.Serialize(null, "get"u8, "ghi", RespFormatters.Key.String); + Assert.True(MemoryMarshal.TryGetArray(a, out var aSegment)); + Assert.True(MemoryMarshal.TryGetArray(b, out var bSegment)); + Assert.True(MemoryMarshal.TryGetArray(c, out var cSegment)); + Assert.Equal(0, aSegment.Offset); + Assert.Equal(22, aSegment.Count); + Assert.Equal(22, bSegment.Offset); + Assert.Equal(22, bSegment.Count); + Assert.Equal(44, cSegment.Offset); + Assert.Equal(22, cSegment.Count); + Assert.Same(aSegment.Array, bSegment.Array); + Assert.Same(aSegment.Array, cSegment.Array); + AssertRelease(a); + AssertRelease(b); + AssertRelease(c); + } +#if DEBUG + Assert.Equal(1, buffer.CountAdded); + Assert.Equal(0, buffer.CountRecycled); + Assert.Equal(0, buffer.CountLeaked); + Assert.Equal(15_000, buffer.CountMessages); +#endif + buffer.Clear(); +#if DEBUG + Assert.Equal(1, buffer.CountAdded); + Assert.Equal(1, buffer.CountRecycled); + Assert.Equal(0, buffer.CountLeaked); + Assert.Equal(15_000, buffer.CountMessages); +#endif + } + + [Fact] + public void CanWriteLotsOfBuffers() + { + var buffer = BlockBufferSerializer.Create(); + List> blocks = new(15_000); +#if DEBUG + Assert.Equal(0, buffer.CountAdded); + Assert.Equal(0, buffer.CountRecycled); + Assert.Equal(0, buffer.CountLeaked); + Assert.Equal(0, buffer.CountMessages); +#endif + for (int i = 0; i < 5000; i++) + { + var block = buffer.Serialize(null, "get"u8, "abc", RespFormatters.Key.String); + blocks.Add(block); + block = buffer.Serialize(null, "get"u8, "def", RespFormatters.Key.String); + blocks.Add(block); + block = buffer.Serialize(null, "get"u8, "ghi", RespFormatters.Key.String); + blocks.Add(block); + } + + // Each buffer is 2048 by default, so: 93 per buffer; at least 162 buffers (looking at CountAdded). + // In reality, we apply some round-ups and minimum buffer sizes, which pushes it a little higher, but: not much. + // However, the runtime can also choose to issue bigger leases than we expect, pushing it down! What matters + // isn't the specific number, but: that it isn't huge. +#if DEBUG + Assert.Equal(15_000, buffer.CountMessages); + Assert.True(buffer.CountAdded < 200, "too many buffers used"); + Assert.Equal(0, buffer.CountRecycled); + Assert.Equal(0, buffer.CountLeaked); +#endif + buffer.Clear(); +#if DEBUG + Assert.Equal(15_000, buffer.CountMessages); + Assert.True(buffer.CountAdded < 200, "too many buffers used"); + Assert.Equal(0, buffer.CountRecycled); + Assert.Equal(0, buffer.CountLeaked); +#endif + + foreach (var block in blocks) AssertRelease(block); +#if DEBUG + Assert.Equal(15_000, buffer.CountMessages); + Assert.True(buffer.CountAdded < 200, "too many buffers used"); + Assert.Equal(buffer.CountAdded, buffer.CountRecycled); + Assert.Equal(0, buffer.CountLeaked); +#endif + } +} diff --git a/tests/RESPite.Tests/ConnectionFixture.cs b/tests/RESPite.Tests/ConnectionFixture.cs new file mode 100644 index 000000000..c80ddc022 --- /dev/null +++ b/tests/RESPite.Tests/ConnectionFixture.cs @@ -0,0 +1,198 @@ +using System; +using System.IO; +using System.Net; +using System.Threading.Tasks; +using Microsoft.Testing.Platform.Extensions.Messages; +using RESPite.Connections; +using RESPite.StackExchange.Redis; +using StackExchange.Redis; +using StackExchange.Redis.Maintenance; +using StackExchange.Redis.Profiling; +using Xunit; + +[assembly: AssemblyFixture(typeof(RESPite.Tests.ConnectionFixture))] + +namespace RESPite.Tests; + +public class ConnectionFixture : IDisposable +{ + private readonly IConnectionMultiplexer _muxer; + private readonly RespConnectionPool _pool = new(); + + public ConnectionFixture() + { + _muxer = new DummyMultiplexer(this); + } + + public void Dispose() => _pool.Dispose(); + + public RespConnection GetConnection() + { + var template = _pool.Template.WithCancellationToken(TestContext.Current.CancellationToken); + return _pool.GetConnection(template); + } + + public IConnectionMultiplexer Multiplexer => _muxer; +} + +internal sealed class DummyMultiplexer(ConnectionFixture fixture) : IConnectionMultiplexer +{ + public override string ToString() => nameof(DummyMultiplexer); + private readonly ConnectionFixture _fixture = fixture; + private readonly string clientName = ""; + private readonly string configuration = ""; +#pragma warning disable CS0649 // Field is never assigned to, and will always have its default value + private int timeoutMilliseconds; + private long operationCount; + private bool preserveAsyncOrder; + private bool isConnected; + private bool isConnecting; + private bool includeDetailInExceptions; + private int stormLogThreshold; +#pragma warning restore CS0649 // Field is never assigned to, and will always have its default value + + void IDisposable.Dispose() { } + + ValueTask IAsyncDisposable.DisposeAsync() => default; + + string IConnectionMultiplexer.ClientName => clientName; + + string IConnectionMultiplexer.Configuration => configuration; + + int IConnectionMultiplexer.TimeoutMilliseconds => timeoutMilliseconds; + + long IConnectionMultiplexer.OperationCount => operationCount; + + bool IConnectionMultiplexer.PreserveAsyncOrder + { + get => preserveAsyncOrder; + set => preserveAsyncOrder = value; + } + + bool IConnectionMultiplexer.IsConnected => isConnected; + + bool IConnectionMultiplexer.IsConnecting => isConnecting; + + bool IConnectionMultiplexer.IncludeDetailInExceptions + { + get => includeDetailInExceptions; + set => includeDetailInExceptions = value; + } + + int IConnectionMultiplexer.StormLogThreshold + { + get => stormLogThreshold; + set => stormLogThreshold = value; + } + + void IConnectionMultiplexer.RegisterProfiler(Func profilingSessionProvider) => + throw new NotImplementedException(); + + ServerCounters IConnectionMultiplexer.GetCounters() => throw new NotImplementedException(); + + event EventHandler? IConnectionMultiplexer.ErrorMessage + { + add => throw new NotImplementedException(); + remove => throw new NotImplementedException(); + } + + event EventHandler? IConnectionMultiplexer.ConnectionFailed + { + add => throw new NotImplementedException(); + remove => throw new NotImplementedException(); + } + + event EventHandler? IConnectionMultiplexer.InternalError + { + add => throw new NotImplementedException(); + remove => throw new NotImplementedException(); + } + + event EventHandler? IConnectionMultiplexer.ConnectionRestored + { + add => throw new NotImplementedException(); + remove => throw new NotImplementedException(); + } + + event EventHandler? IConnectionMultiplexer.ConfigurationChanged + { + add => throw new NotImplementedException(); + remove => throw new NotImplementedException(); + } + + event EventHandler? IConnectionMultiplexer.ConfigurationChangedBroadcast + { + add => throw new NotImplementedException(); + remove => throw new NotImplementedException(); + } + + event EventHandler? IConnectionMultiplexer.ServerMaintenanceEvent + { + add => throw new NotImplementedException(); + remove => throw new NotImplementedException(); + } + + EndPoint[] IConnectionMultiplexer.GetEndPoints(bool configuredOnly) => throw new NotImplementedException(); + + void IConnectionMultiplexer.Wait(Task task) => throw new NotImplementedException(); + + T IConnectionMultiplexer.Wait(Task task) => throw new NotImplementedException(); + + void IConnectionMultiplexer.WaitAll(params Task[] tasks) => throw new NotImplementedException(); + + event EventHandler? IConnectionMultiplexer.HashSlotMoved + { + add => throw new NotImplementedException(); + remove => throw new NotImplementedException(); + } + + int IConnectionMultiplexer.HashSlot(RedisKey key) => throw new NotImplementedException(); + + ISubscriber IConnectionMultiplexer.GetSubscriber(object? asyncState) => throw new NotImplementedException(); + + IDatabase IConnectionMultiplexer.GetDatabase(int db, object? asyncState) => throw new NotImplementedException(); + + IServer IConnectionMultiplexer.GetServer(string host, int port, object? asyncState) => + throw new NotImplementedException(); + + IServer IConnectionMultiplexer.GetServer(string hostAndPort, object? asyncState) => + throw new NotImplementedException(); + + IServer IConnectionMultiplexer.GetServer(IPAddress host, int port) => throw new NotImplementedException(); + + IServer IConnectionMultiplexer.GetServer(EndPoint endpoint, object? asyncState) => + throw new NotImplementedException(); + + public IServer GetServer(RedisKey key, object? asyncState = null, CommandFlags flags = CommandFlags.None) + => throw new NotImplementedException(); + + IServer[] IConnectionMultiplexer.GetServers() => throw new NotImplementedException(); + + Task IConnectionMultiplexer.ConfigureAsync(TextWriter? log) => throw new NotImplementedException(); + + bool IConnectionMultiplexer.Configure(TextWriter? log) => throw new NotImplementedException(); + + string IConnectionMultiplexer.GetStatus() => throw new NotImplementedException(); + + void IConnectionMultiplexer.GetStatus(TextWriter log) => throw new NotImplementedException(); + + void IConnectionMultiplexer.Close(bool allowCommandsToComplete) => throw new NotImplementedException(); + + Task IConnectionMultiplexer.CloseAsync(bool allowCommandsToComplete) => throw new NotImplementedException(); + + string? IConnectionMultiplexer.GetStormLog() => throw new NotImplementedException(); + + void IConnectionMultiplexer.ResetStormLog() => throw new NotImplementedException(); + + long IConnectionMultiplexer.PublishReconfigure(CommandFlags flags) => throw new NotImplementedException(); + + Task IConnectionMultiplexer.PublishReconfigureAsync(CommandFlags flags) => + throw new NotImplementedException(); + + int IConnectionMultiplexer.GetHashSlot(RedisKey key) => throw new NotImplementedException(); + + void IConnectionMultiplexer.ExportConfiguration(Stream destination, ExportOptions options) => + throw new NotImplementedException(); + + void IConnectionMultiplexer.AddLibraryNameSuffix(string suffix) => throw new NotImplementedException(); +} diff --git a/tests/RESPite.Tests/CycleBufferTests.cs b/tests/RESPite.Tests/CycleBufferTests.cs new file mode 100644 index 000000000..bbffdb51b --- /dev/null +++ b/tests/RESPite.Tests/CycleBufferTests.cs @@ -0,0 +1,182 @@ +using System; +using System.Buffers; +using System.Diagnostics; +using System.Threading; +using RESPite.Internal; +using RESPite.Messages; +using Xunit; + +namespace RESPite.Tests; + +public class CycleBufferTests +{ + [Fact] + public void SimpleUsage() + { + CycleBuffer buffer = CycleBuffer.Create(); + Assert.True(buffer.CommittedIsEmpty); + Assert.Equal(0, buffer.GetCommittedLength()); + Assert.False(buffer.TryGetFirstCommittedSpan(0, out _)); + + buffer.Write("hello world"u8); + Assert.False(buffer.CommittedIsEmpty, "should be empty"); + Assert.Equal(11, buffer.GetCommittedLength()); + + Assert.False(buffer.TryGetFirstCommittedSpan(-1, out _), "should have rejected full"); + Assert.True(buffer.TryGetFirstCommittedSpan(0, out var committed), "should have accepted partial"); + Assert.True(committed.SequenceEqual("hello world"u8)); + buffer.DiscardCommitted(11); + Assert.True(buffer.CommittedIsEmpty); + Assert.Equal(0, buffer.GetCommittedLength()); + Assert.False(buffer.TryGetFirstCommittedSpan(0, out _)); + + // now partial consume + buffer.Write("partial consume"u8); + Assert.False(buffer.CommittedIsEmpty); + Assert.Equal(15, buffer.GetCommittedLength()); + + Assert.False(buffer.TryGetFirstCommittedSpan(-1, out _)); + Assert.True(buffer.TryGetFirstCommittedSpan(0, out committed)); + Assert.True(committed.SequenceEqual("partial consume"u8)); + buffer.DiscardCommitted(8); + Assert.False(buffer.CommittedIsEmpty); + Assert.Equal(7, buffer.GetCommittedLength()); + Assert.True(buffer.TryGetFirstCommittedSpan(0, out committed)); + Assert.True(committed.SequenceEqual("consume"u8)); + buffer.DiscardCommitted(7); + Assert.True(buffer.CommittedIsEmpty); + Assert.Equal(0, buffer.GetCommittedLength()); + Assert.False(buffer.TryGetFirstCommittedSpan(0, out _)); + buffer.Release(); + } + + private sealed class CountingMemoryPool(MemoryPool? tail = null) : MemoryPool + { + private readonly MemoryPool _tail = tail ?? MemoryPool.Shared; + private int count; + + public int Count => Volatile.Read(ref count); + public override IMemoryOwner Rent(int minBufferSize = -1) => new Wrapper(this, _tail.Rent(minBufferSize)); + + protected override void Dispose(bool disposing) => throw new NotImplementedException(); + + private void Decrement() => Interlocked.Decrement(ref count); + + private CountingMemoryPool Increment() + { + Interlocked.Increment(ref count); + return this; + } + + public override int MaxBufferSize => _tail.MaxBufferSize; + + private sealed class Wrapper(CountingMemoryPool parent, IMemoryOwner tail) : IMemoryOwner + { + private int _disposed; + private readonly CountingMemoryPool _parent = parent.Increment(); + + public void Dispose() + { + if (Interlocked.CompareExchange(ref _disposed, 1, 0) == 0) + { + _parent.Decrement(); + tail.Dispose(); + } + else + { + ThrowDisposed(); + } + } + + private void ThrowDisposed() => throw new ObjectDisposedException(nameof(MemoryPool)); + + public Memory Memory + { + get + { + if (Volatile.Read(ref _disposed) != 0) ThrowDisposed(); + return tail.Memory; + } + } + } + } + + [Fact] + public void SkipAggregate() + { + var reader = new RespReader("*1\r\n$3\r\nabc\r\n"u8); // ["abc"] + reader.MoveNext(); + reader.SkipChildren(); + Assert.False(reader.TryMoveNext()); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MultiSegmentUsage(bool multiSegmentRead) + { + byte[] garbage = new byte[1024 * 1024]; + var rand = new Random(Seed: 134521); + rand.NextBytes(garbage); + + int offset = 0; + var mgr = new CountingMemoryPool(); + CycleBuffer buffer = CycleBuffer.Create(mgr); + Assert.Equal(0, mgr.Count); + while (offset < garbage.Length) + { + var size = rand.Next(1, garbage.Length - offset + 1); + Debug.Assert(size > 0); + buffer.Write(new ReadOnlySpan(garbage, offset, size)); + offset += size; + Assert.Equal(offset, buffer.GetCommittedLength()); + } + + Assert.True(mgr.Count >= 50); // some non-trivial count + int total = 0; + if (multiSegmentRead) + { + while (!buffer.CommittedIsEmpty) + { + var seq = buffer.GetAllCommitted(); + var take = rand.Next((int)Math.Min(seq.Length, 4 * buffer.PageSize)) + 1; + var slice = seq.Slice(0, take); + Assert.True(SequenceEqual(slice, new(garbage, total, take)), "data integrity check"); + buffer.DiscardCommitted(take); + total += take; + } + } + else + { + while (buffer.TryGetFirstCommittedSpan(0, out var span)) + { + var take = rand.Next(span.Length) + 1; + var slice = span.Slice(0, take); + Assert.True(slice.SequenceEqual(new(garbage, total, take)), "data integrity check"); + buffer.DiscardCommitted(take); + total += take; + } + } + + Assert.Equal(garbage.Length, total); + Assert.Equal(3, mgr.Count); + buffer.Release(); + + Assert.Equal(0, mgr.Count); + + static bool SequenceEqual(ReadOnlySequence seq1, ReadOnlySpan seq2) + { + if (seq1.IsSingleSegment) + { + return seq1.First.Span.SequenceEqual(seq2); + } + + if (seq1.Length != seq2.Length) return false; + var arr = ArrayPool.Shared.Rent(seq2.Length); + seq1.CopyTo(arr); + var result = arr.AsSpan(0, seq2.Length).SequenceEqual(seq2); + ArrayPool.Shared.Return(arr); + return result; + } + } +} diff --git a/tests/RESPite.Tests/IntegrationTestBase.cs b/tests/RESPite.Tests/IntegrationTestBase.cs new file mode 100644 index 000000000..d1ef3c7e8 --- /dev/null +++ b/tests/RESPite.Tests/IntegrationTestBase.cs @@ -0,0 +1,33 @@ +using System.Runtime.CompilerServices; +using System.Threading.Tasks; +using RESPite.Redis.Alt; +using RESPite.StackExchange.Redis; +using StackExchange.Redis; +using Xunit; + +namespace RESPite.Tests; + +public abstract class IntegrationTestBase(ConnectionFixture fixture, ITestOutputHelper log) +{ + public RespConnection GetConnection([CallerMemberName] string caller = "") + { + var conn = fixture.GetConnection(); // includes cancellation from the test + // most of the time, they'll be using a key from Me(), so: pre-emptively nuke it + conn.Context.AsKeys().Del(caller); + return conn; + } + + public async ValueTask GetConnectionAsync([CallerMemberName] string caller = "") + { + var conn = fixture.GetConnection(); // includes cancellation from the test + // most of the time, they'll be using a key from Me(), so: pre-emptively nuke it + await conn.Context.AsKeys().DelAsync(caller).ConfigureAwait(false); + return conn; + } + + public IDatabase AsDatabase(RespConnection conn, int db = 0) => new RespContextDatabase(fixture.Multiplexer, conn, db); + + public void Log(string message) => log?.WriteLine(message); + + protected string Me([CallerMemberName] string caller = "") => caller; +} diff --git a/tests/RESPite.Tests/LogWriter.cs b/tests/RESPite.Tests/LogWriter.cs new file mode 100644 index 000000000..45ffce8b6 --- /dev/null +++ b/tests/RESPite.Tests/LogWriter.cs @@ -0,0 +1,11 @@ +using System.IO; +using System.Text; +using Xunit; + +namespace RESPite.Tests; + +internal sealed class LogWriter(ITestOutputHelper? log) : TextWriter +{ + public override Encoding Encoding => Encoding.Unicode; + public override void WriteLine(string? value) => log?.WriteLine(value ?? ""); +} diff --git a/tests/RESPite.Tests/OperationUnitTests.cs b/tests/RESPite.Tests/OperationUnitTests.cs new file mode 100644 index 000000000..27c877628 --- /dev/null +++ b/tests/RESPite.Tests/OperationUnitTests.cs @@ -0,0 +1,331 @@ +using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Tasks.Sources; +using Xunit; + +namespace RESPite.Tests; + +[SuppressMessage( + "Usage", + "xUnit1031:Do not use blocking task operations in test method", + Justification = "This isn't actually async; we're testing an awaitable.")] +public class OperationUnitTests +{ + private static CancellationToken CancellationToken => TestContext.Current.CancellationToken; + + [Theory] + [InlineData(true, false)] + [InlineData(false, false)] + [InlineData(true, true)] + [InlineData(false, true)] + public void ManuallyImplementedAsync_NotSent_Untyped(bool sent, bool @unsafe) + { + var op = RespOperation.Create(out var remote, sent, cancellationToken: CancellationToken); + Assert.Equal(sent, op.IsSent); + var awaiter = op.GetAwaiter(); + Assert.False(awaiter.IsCompleted, "not completed first IsCompleted check"); + + if (@unsafe) + { + op.UnsafeOnCompleted(() => { }); + } + else + { + op.OnCompleted(() => { }); + } + + if (sent) + { + Assert.False(awaiter.IsCompleted, "incomplete after OnCompleted"); + Assert.True(remote.TrySetResult(default)); + awaiter.GetResult(); + } + else + { + Assert.True(awaiter.IsFaulted, "faulted after OnCompleted"); + Assert.False(remote.TrySetResult(default)); + var ex = Assert.Throws(() => awaiter.GetResult()); + Assert.Contains("This command has not yet been sent", ex.Message); + } + + Assert.Throws(() => awaiter.GetResult()); + } + + [Theory] + [InlineData(true, false)] + [InlineData(false, false)] + [InlineData(true, true)] + [InlineData(false, true)] + public void ManuallyImplementedAsync_NotSent_Typed(bool sent, bool @unsafe) + { + var op = RespOperation.Create(null, out var remote, sent, cancellationToken: CancellationToken); + Assert.Equal(sent, op.IsSent); + var awaiter = op.GetAwaiter(); + Assert.False(awaiter.IsCompleted, "not completed first IsCompleted check"); + + if (@unsafe) + { + op.UnsafeOnCompleted(() => { }); + } + else + { + op.OnCompleted(() => { }); + } + + if (sent) + { + Assert.False(awaiter.IsCompleted, "incomplete after OnCompleted"); + Assert.True(remote.TrySetResult(default)); + awaiter.GetResult(); + } + else + { + Assert.True(awaiter.IsFaulted, "faulted after OnCompleted"); + Assert.False(remote.TrySetResult(default)); + var ex = Assert.Throws(() => awaiter.GetResult()); + Assert.Contains("This command has not yet been sent", ex.Message); + } + + Assert.Throws(() => awaiter.GetResult()); + } + + [Theory] + [InlineData(true, false)] + [InlineData(false, false)] + [InlineData(true, true)] + [InlineData(false, true)] + public void ManuallyImplementedAsync_NotSent_Stateful(bool sent, bool @unsafe) + { + var op = RespOperation.Create("abc", null, out var remote, sent, CancellationToken); + Assert.Equal(sent, op.IsSent); + var awaiter = op.GetAwaiter(); + Assert.False(awaiter.IsCompleted, "not completed first IsCompleted check"); + + if (@unsafe) + { + op.UnsafeOnCompleted(() => { }); + } + else + { + op.OnCompleted(() => { }); + } + + if (sent) + { + Assert.False(awaiter.IsCompleted, "incomplete after OnCompleted"); + Assert.True(remote.TrySetResult(default)); + awaiter.GetResult(); + } + else + { + Assert.True(awaiter.IsFaulted, "faulted after OnCompleted"); + Assert.False(remote.TrySetResult(default)); + var ex = Assert.Throws(() => awaiter.GetResult()); + Assert.Contains("This command has not yet been sent", ex.Message); + } + + Assert.Throws(() => awaiter.GetResult()); + } + + [Fact(Timeout = 1000)] + public void UnsentDetectedSync() + { + var op = RespOperation.Create(out var remote, false, CancellationToken); + var ex = Assert.Throws(() => op.Wait()); + Assert.Contains("This command has not yet been sent", ex.Message); + } + + [Fact(Timeout = 1000)] + public async Task UnsentDetected_Operation_Async() + { + var op = RespOperation.Create(out var remote, false, CancellationToken); + Assert.False(op.IsCompleted); + var ex = await Assert.ThrowsAsync(async () => await op); + Assert.Contains("This command has not yet been sent", ex.Message); + } + + [Fact(Timeout = 1000)] + public async Task UnsentNotDetected_ValueTask_Async() + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + cts.CancelAfter(100); + var op = RespOperation.Create(out var remote, false, cts.Token); + var ex = await Assert.ThrowsAsync(async () => await op.AsValueTask()); + AssertCT(ex.CancellationToken, cts.Token); + } + + [Fact] + public void CoreValueTaskToTaskSupportsCancellation() + { + // The purpose of this test is to show that there are some inherent limitations in netfx + // regarding IVTS:AsTask (compared with modern .NET), specifically: + // - it manifests as TaskCanceledException instead of OperationCanceledException + // - the token is not propagated correctly - it comes back as .None + var cts = new CancellationTokenSource(); + cts.Cancel(); + var ta = new TestAwaitable(); + var task = ta.AsValueTask().AsTask(); + Assert.Equal(TaskStatus.WaitingForActivation, task.Status); + ta.Cancel(cts.Token); + Assert.Equal(TaskStatus.Canceled, task.Status); + // ReSharper disable once MethodSupportsCancellation - this task is not incomplete +#pragma warning disable xUnit1051 + // use awaiter to unroll aggregate exception +#if NETFRAMEWORK + var ex = Assert.Throws(() => task.GetAwaiter().GetResult()); +#else + var ex = Assert.Throws(() => task.GetAwaiter().GetResult()); +#endif +#pragma warning restore xUnit1051 + var summary = SummarizeCT(ex.CancellationToken, cts.Token); + +#if NETFRAMEWORK // I *wish* this wasn't the case, but: wishes are free + Assert.Equal( + CancellationProblems.DefaultToken | CancellationProblems.NotCanceled + | CancellationProblems.CannotBeCanceled | CancellationProblems.NotExpectedToken, + summary); +#else + Assert.Equal(CancellationProblems.None, summary); +#endif + } + + private sealed class TestAwaitable : IValueTaskSource + { + private ManualResetValueTaskSourceCore _core; + public ValueTask AsValueTask() => new(this, _core.Version); + public void GetResult(short token) => _core.GetResult(token); + public void Cancel(CancellationToken token) => _core.SetException(new OperationCanceledException(token)); + public ValueTaskSourceStatus GetStatus(short token) => _core.GetStatus(token); + + public void OnCompleted( + Action continuation, + object? state, + short token, + ValueTaskSourceOnCompletedFlags flags) + => _core.OnCompleted(continuation, state, token, flags); + } + + [Fact(Timeout = 1000)] + public async Task UnsentNotDetected_Task_Async() + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + cts.CancelAfter(100); + var op = RespOperation.Create(out var remote, false, cts.Token); + var ex = await Assert.ThrowsAnyAsync(async () => await op.AsTask()); + + #if NETFRAMEWORK // see CoreValueTaskToTaskSupportsCancellation for more context + Assert.Equal(CancellationToken.None, ex.CancellationToken); + #else + AssertCT(ex.CancellationToken, cts.Token); + #endif + } + + [Flags] + private enum CancellationProblems + { + None = 0, + DefaultToken = 1 << 0, + NotCanceled = 1 << 1, + CannotBeCanceled = 1 << 2, + TestInfrastuctureToken = 1 << 3, + NotExpectedToken = 1 << 4, + } + + private static CancellationProblems SummarizeCT(CancellationToken actual, CancellationToken expected) + { + CancellationProblems problems = 0; + if (actual == CancellationToken.None) problems |= CancellationProblems.DefaultToken; + if (!actual.IsCancellationRequested) problems |= CancellationProblems.NotCanceled; + if (!actual.CanBeCanceled) problems |= CancellationProblems.CannotBeCanceled; + if (actual == CancellationToken) problems |= CancellationProblems.TestInfrastuctureToken; + if (actual != expected) problems |= CancellationProblems.NotExpectedToken; + return problems; + } + + private static void AssertCT(CancellationToken actual, CancellationToken expected) + => Assert.Equal(CancellationProblems.None, SummarizeCT(actual, expected)); + + [Fact(Timeout = 1000)] + public void CanCreateAndCompleteOperation() + { + var op = RespOperation.Create(out var remote, cancellationToken: CancellationToken); + + // initial state + Assert.False(op.IsCanceled); + Assert.False(op.IsCompleted); + Assert.False(op.IsCompletedSuccessfully); + Assert.False(op.IsFaulted); + + // complete first time + Assert.True(remote.TrySetResult(default)); + Assert.False(op.IsCanceled); + Assert.True(op.IsCompleted); + Assert.True(op.IsCompletedSuccessfully); + Assert.False(op.IsFaulted); + + // additional completions fail + Assert.False(remote.TrySetResult(default)); +#pragma warning disable xUnit1051 + Assert.False(remote.TrySetCanceled()); +#pragma warning restore xUnit1051 + Assert.False(remote.TrySetException(null!)); + + // can get result + Assert.True(remote.IsTokenMatch, "should match before GetResult"); + op.GetResult(); + Assert.False(remote.IsTokenMatch, "should have reset token"); + + // but only once, after that: bad things + Assert.Throws(() => op.GetResult()); + Assert.Throws(() => op.IsCanceled); + Assert.Throws(() => op.IsCompleted); + Assert.Throws(() => op.IsCompletedSuccessfully); + Assert.Throws(() => op.IsFaulted); + + // additional completions continue to fail + Assert.False(remote.TrySetResult(default), "TrySetResult"); + Assert.False(remote.TrySetCanceled(CancellationToken), "TrySetCanceled"); + Assert.False(remote.TrySetException(null!), "TrySetException"); + } + + [Fact(Timeout = 1000)] + public void CanCreateAndCompleteWithoutLeaking() + { + int before = RespOperation.DebugPerThreadMessageAllocations; + for (int i = 0; i < 100; i++) + { + var op = RespOperation.Create(out var remote, cancellationToken: CancellationToken); + remote.TrySetResult(default); + Assert.True(op.IsCompleted); + op.Wait(); + } + + int after = RespOperation.DebugPerThreadMessageAllocations; + var allocs = after - before; + Debug.Assert(allocs < 2, $"allocations: {allocs}"); + } + + [Fact(Timeout = 1000)] + public async Task CanCreateAndCompleteWithoutLeaking_Async() + { + var threadId = Environment.CurrentManagedThreadId; + int before = RespOperation.DebugPerThreadMessageAllocations; + for (int i = 0; i < 100; i++) + { + var op = RespOperation.Create(out var remote, cancellationToken: CancellationToken); + remote.TrySetResult(default); + Assert.True(op.IsCompleted); + await op; + } + + int after = RespOperation.DebugPerThreadMessageAllocations; + var allocs = after - before; + Debug.Assert(allocs < 2, $"allocations: {allocs}"); + + // do not expect thread switch + Assert.Equal(threadId, Environment.CurrentManagedThreadId); + } +} diff --git a/tests/RESPite.Tests/RESPite.Tests.csproj b/tests/RESPite.Tests/RESPite.Tests.csproj new file mode 100644 index 000000000..90f681c9b --- /dev/null +++ b/tests/RESPite.Tests/RESPite.Tests.csproj @@ -0,0 +1,25 @@ + + + + net481;net8.0 + enable + false + true + Exe + + + + + + + + + + + + + + + + + diff --git a/tests/RESPite.Tests/RedisDatabaseTests.cs b/tests/RESPite.Tests/RedisDatabaseTests.cs new file mode 100644 index 000000000..a2df1ade1 --- /dev/null +++ b/tests/RESPite.Tests/RedisDatabaseTests.cs @@ -0,0 +1,41 @@ +using System.Threading.Tasks; +using StackExchange.Redis; +using Xunit; + +namespace RESPite.Tests; + +public class RedisDatabaseTests(ConnectionFixture fixture, ITestOutputHelper log) + : IntegrationTestBase(fixture, log) +{ + [Fact] + public void HashSetGetAll() + { + var key = Me(); + + using var conn = GetConnection(); + var db = AsDatabase(conn); + db.HashSet(key, "abc", "xyz"); + db.HashSet(key, "def", "uvw"); + + var all = db.HashGetAll(key); + Assert.Equal(2, all.Length); + Assert.Contains(new HashEntry("abc", "xyz"), all); + Assert.Contains(new HashEntry("def", "uvw"), all); + } + + [Fact] + public async Task HashSetGetAllAsync() + { + var key = Me(); + + await using var conn = await GetConnectionAsync(); + var db = AsDatabase(conn); + await db.HashSetAsync(key, "abc", "xyz"); + await db.HashSetAsync(key, "def", "uvw"); + + var all = await db.HashGetAllAsync(key); + Assert.Equal(2, all.Length); + Assert.Contains(new HashEntry("abc", "xyz"), all); + Assert.Contains(new HashEntry("def", "uvw"), all); + } +} diff --git a/tests/RESPite.Tests/RedisStringsIntegrationTests.cs b/tests/RESPite.Tests/RedisStringsIntegrationTests.cs new file mode 100644 index 000000000..e4ad8b8d9 --- /dev/null +++ b/tests/RESPite.Tests/RedisStringsIntegrationTests.cs @@ -0,0 +1,40 @@ +using System.Threading.Tasks; +using RESPite.Redis.Alt; // needed for AsStrings() etc +using Xunit; +using FactAttribute = StackExchange.Redis.Tests.FactAttribute; + +namespace RESPite.Tests; + +public class RedisStringsIntegrationTests(ConnectionFixture fixture, ITestOutputHelper log) + : IntegrationTestBase(fixture, log) +{ + [Fact] + public void Incr() + { + var key = Me(); + + using var conn = GetConnection(); + var ctx = conn.Context; + for (int i = 0; i < 5; i++) + { + ctx.AsStrings().Incr(key); + } + var result = ctx.AsStrings().GetInt32(key); + Assert.Equal(5, result); + } + + [Fact] + public async Task IncrAsync() + { + var key = Me(); + + await using var conn = GetConnection(); + var ctx = conn.Context; + for (int i = 0; i < 5; i++) + { + await ctx.AsStrings().IncrAsync(key); + } + var result = await ctx.AsStrings().GetInt32Async(key); + Assert.Equal(5, result); + } +} diff --git a/tests/RESPite.Tests/RespMultiplexerTests.cs b/tests/RESPite.Tests/RespMultiplexerTests.cs new file mode 100644 index 000000000..b11e69b0d --- /dev/null +++ b/tests/RESPite.Tests/RespMultiplexerTests.cs @@ -0,0 +1,35 @@ +using System.Linq; +using System.Threading.Tasks; +using RESPite.StackExchange.Redis; +using StackExchange.Redis; +using Xunit; + +namespace RESPite.Tests; + +public class RespMultiplexerTests(ITestOutputHelper log) +{ + private readonly LogWriter logWriter = new(log); + + [Fact] + public async Task CanConnect() + { + await using var muxer = new RespMultiplexer(); + await muxer.ConnectAsync("localhost:6379", log: logWriter); + Assert.True(muxer.IsConnected); + + var server = muxer.GetServer(default(RedisKey)); + Assert.IsType(server); // we expect this to *not* use routing + server.Ping(); + await server.PingAsync(); + + var db = muxer.GetDatabase(); + var proxied = Assert.IsType(db); + // since this is a single-node instance, we expect the proxied database to use the interactive connection + db.Ping(); + await db.PingAsync(); + + // ReSharper disable once MethodHasAsyncOverload + proxied.Ping(); + await proxied.PingAsync(); + } +} diff --git a/tests/RESPite.Tests/RespReaderTests.cs b/tests/RESPite.Tests/RespReaderTests.cs new file mode 100644 index 000000000..1cc04bba2 --- /dev/null +++ b/tests/RESPite.Tests/RespReaderTests.cs @@ -0,0 +1,863 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Numerics; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; +using RESPite.Internal; +using RESPite.Messages; +using Xunit; +using Xunit.Sdk; +using Xunit.v3; + +namespace RESPite.Tests; + +public class RespReaderTests(ITestOutputHelper logger) +{ + public readonly struct RespPayload(string label, ReadOnlySequence payload, byte[] expected, bool? outOfBand, int count) + { + public override string ToString() => Label; + public string Label { get; } = label; + public ReadOnlySequence PayloadRaw { get; } = payload; + public int Length { get; } = CheckPayload(payload, expected, outOfBand, count); + private static int CheckPayload(scoped in ReadOnlySequence actual, byte[] expected, bool? outOfBand, int count) + { + Assert.Equal(expected.LongLength, actual.Length); + var pool = ArrayPool.Shared.Rent(expected.Length); + actual.CopyTo(pool); + bool isSame = pool.AsSpan(0, expected.Length).SequenceEqual(expected); + ArrayPool.Shared.Return(pool); + Assert.True(isSame, "Data mismatch"); + + // verify that the data exactly passes frame-scanning + long totalBytes = 0; + RespReader reader = new(actual); + while (count > 0) + { + RespScanState state = default; + Assert.True(state.TryRead(ref reader, out long bytesRead)); + totalBytes += bytesRead; + Assert.True(state.IsComplete, nameof(state.IsComplete)); + if (outOfBand.HasValue) + { + if (outOfBand.Value) + { + Assert.Equal(RespPrefix.Push, state.Prefix); + } + else + { + Assert.NotEqual(RespPrefix.Push, state.Prefix); + } + } + count--; + } + Assert.Equal(expected.Length, totalBytes); + reader.DemandEnd(); + return expected.Length; + } + + public RespReader Reader() => new(PayloadRaw); + } + + public sealed class RespAttribute : DataAttribute + { + public override bool SupportsDiscoveryEnumeration() => true; + + private readonly object _value; + public bool OutOfBand { get; init; } = false; + + private bool? EffectiveOutOfBand => Count == 1 ? OutOfBand : default(bool?); + public int Count { get; init; } = 1; + + public RespAttribute(string value) => _value = value; + public RespAttribute(params string[] values) => _value = values; + + public override ValueTask> GetData(MethodInfo testMethod, DisposalTracker disposalTracker) + => new(GetData(testMethod).ToArray()); + + public IEnumerable GetData(MethodInfo testMethod) + { + switch (_value) + { + case string s: + foreach (var item in GetVariants(s, EffectiveOutOfBand, Count)) + { + yield return new TheoryDataRow(item); + } + break; + case string[] arr: + foreach (string s in arr) + { + foreach (var item in GetVariants(s, EffectiveOutOfBand, Count)) + { + yield return new TheoryDataRow(item); + } + } + break; + } + } + + private static IEnumerable GetVariants(string value, bool? outOfBand, int count) + { + var bytes = Encoding.UTF8.GetBytes(value); + + // all in one + yield return new("Right-sized", new(bytes), bytes, outOfBand, count); + + var bigger = new byte[bytes.Length + 4]; + bytes.CopyTo(bigger.AsSpan(2, bytes.Length)); + bigger.AsSpan(0, 2).Fill(0xFF); + bigger.AsSpan(bytes.Length + 2, 2).Fill(0xFF); + + // all in one, oversized + yield return new("Oversized", new(bigger, 2, bytes.Length), bytes, outOfBand, count); + + // two-chunks + for (int i = 0; i <= bytes.Length; i++) + { + int offset = 2 + i; + var left = new Segment(new ReadOnlyMemory(bigger, 0, offset), null); + var right = new Segment(new ReadOnlyMemory(bigger, offset, bigger.Length - offset), left); + yield return new($"Split:{i}", new ReadOnlySequence(left, 2, right, right.Length - 2), bytes, outOfBand, count); + } + + // N-chunks + Segment head = new(new(bytes, 0, 1), null), tail = head; + for (int i = 1; i < bytes.Length; i++) + { + tail = new(new(bytes, i, 1), tail); + } + yield return new("Chunk-per-byte", new(head, 0, tail, 1), bytes, outOfBand, count); + } + } + + [Theory, Resp("$3\r\n128\r\n")] + public void HandleSplitTokens(RespPayload payload) + { + RespReader reader = payload.Reader(); + RespScanState scan = default; + bool readResult = scan.TryRead(ref reader, out _); + logger.WriteLine(scan.ToString()); + Assert.Equal(payload.Length, reader.BytesConsumed); + Assert.True(readResult); + } + + // the examples from https://github.com/redis/redis-specifications/blob/master/protocol/RESP3.md + [Theory, Resp("$11\r\nhello world\r\n", "$?\r\n;6\r\nhello \r\n;5\r\nworld\r\n;0\r\n")] + public void BlobString(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.BulkString); + Assert.True(reader.Is("hello world"u8)); + Assert.Equal("hello world", reader.ReadString()); + Assert.Equal("hello world", reader.ReadString(out var prefix)); + Assert.Equal("", prefix); +#if NET7_0_OR_GREATER + Assert.Equal("hello world", reader.ParseChars()); +#endif + /* interestingly, string does not implement IUtf8SpanParsable +#if NET8_0_OR_GREATER + Assert.Equal("hello world", reader.ParseBytes()); +#endif + */ + reader.DemandEnd(); + } + + [Theory, Resp("$0\r\n\r\n", "$?\r\n;0\r\n")] + public void EmptyBlobString(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.BulkString); + Assert.True(reader.Is(""u8)); + Assert.Equal("", reader.ReadString()); + reader.DemandEnd(); + } + + [Theory, Resp("+hello world\r\n")] + public void SimpleString(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.SimpleString); + Assert.True(reader.Is("hello world"u8)); + Assert.Equal("hello world", reader.ReadString()); + Assert.Equal("hello world", reader.ReadString(out var prefix)); + Assert.Equal("", prefix); + reader.DemandEnd(); + } + + [Theory, Resp("-ERR this is the error description\r\n")] + public void SimpleError_ImplicitErrors(RespPayload payload) + { + var ex = Assert.Throws(() => + { + var reader = payload.Reader(); + reader.MoveNext(); + }); + Assert.Equal("ERR this is the error description", ex.Message); + } + + [Theory, Resp("-ERR this is the error description\r\n")] + public void SimpleError_Careful(RespPayload payload) + { + var reader = payload.Reader(); + Assert.True(reader.TryReadNext()); + Assert.Equal(RespPrefix.SimpleError, reader.Prefix); + Assert.True(reader.Is("ERR this is the error description"u8)); + Assert.Equal("ERR this is the error description", reader.ReadString()); + reader.DemandEnd(); + } + + [Theory, Resp(":1234\r\n")] + public void Number(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Integer); + Assert.True(reader.Is("1234"u8)); + Assert.Equal("1234", reader.ReadString()); + Assert.Equal(1234, reader.ReadInt32()); + Assert.Equal(1234D, reader.ReadDouble()); + Assert.Equal(1234M, reader.ReadDecimal()); +#if NET7_0_OR_GREATER + Assert.Equal(1234, reader.ParseChars()); + Assert.Equal(1234D, reader.ParseChars()); + Assert.Equal(1234M, reader.ParseChars()); +#endif +#if NET8_0_OR_GREATER + Assert.Equal(1234, reader.ParseBytes()); + Assert.Equal(1234D, reader.ParseBytes()); + Assert.Equal(1234M, reader.ParseBytes()); +#endif + reader.DemandEnd(); + } + + [Theory, Resp("_\r\n")] + public void Null(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Null); + Assert.True(reader.Is(""u8)); + Assert.Null(reader.ReadString()); + reader.DemandEnd(); + } + + [Theory, Resp("$-1\r\n")] + public void NullString(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.BulkString); + Assert.True(reader.IsNull); + Assert.Null(reader.ReadString()); + Assert.Equal(0, reader.ScalarLength()); + Assert.True(reader.Is(""u8)); + Assert.True(reader.ScalarIsEmpty()); + + var iterator = reader.ScalarChunks(); + Assert.False(iterator.MoveNext()); + iterator.MovePast(out reader); + reader.DemandEnd(); + } + + [Theory, Resp(",1.23\r\n")] + public void Double(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Double); + Assert.True(reader.Is("1.23"u8)); + Assert.Equal("1.23", reader.ReadString()); + Assert.Equal(1.23D, reader.ReadDouble()); + Assert.Equal(1.23M, reader.ReadDecimal()); + reader.DemandEnd(); + } + + [Theory, Resp(":10\r\n")] + public void Integer_Simple(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Integer); + Assert.True(reader.Is("10"u8)); + Assert.Equal("10", reader.ReadString()); + Assert.Equal(10, reader.ReadInt32()); + Assert.Equal(10D, reader.ReadDouble()); + Assert.Equal(10M, reader.ReadDecimal()); + reader.DemandEnd(); + } + + [Theory, Resp(",10\r\n")] + public void Double_Simple(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Double); + Assert.True(reader.Is("10"u8)); + Assert.Equal("10", reader.ReadString()); + Assert.Equal(10, reader.ReadInt32()); + Assert.Equal(10D, reader.ReadDouble()); + Assert.Equal(10M, reader.ReadDecimal()); + reader.DemandEnd(); + } + + [Theory, Resp(",inf\r\n")] + public void Double_Infinity(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Double); + Assert.True(reader.Is("inf"u8)); + Assert.Equal("inf", reader.ReadString()); + var val = reader.ReadDouble(); + Assert.True(double.IsInfinity(val)); + Assert.True(double.IsPositiveInfinity(val)); + reader.DemandEnd(); + } + + [Theory, Resp(",+inf\r\n")] + public void Double_PosInfinity(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Double); + Assert.True(reader.Is("+inf"u8)); + Assert.Equal("+inf", reader.ReadString()); + var val = reader.ReadDouble(); + Assert.True(double.IsInfinity(val)); + Assert.True(double.IsPositiveInfinity(val)); + reader.DemandEnd(); + } + + [Theory, Resp(",-inf\r\n")] + public void Double_NegInfinity(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Double); + Assert.True(reader.Is("-inf"u8)); + Assert.Equal("-inf", reader.ReadString()); + var val = reader.ReadDouble(); + Assert.True(double.IsInfinity(val)); + Assert.True(double.IsNegativeInfinity(val)); + reader.DemandEnd(); + } + + [Theory, Resp(",nan\r\n")] + public void Double_NaN(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Double); + Assert.True(reader.Is("nan"u8)); + Assert.Equal("nan", reader.ReadString()); + var val = reader.ReadDouble(); + Assert.True(double.IsNaN(val)); + reader.DemandEnd(); + } + + [Theory, Resp("#t\r\n")] + public void Boolean_T(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Boolean); + Assert.True(reader.ReadBoolean()); + reader.DemandEnd(); + } + + [Theory, Resp("#f\r\n")] + public void Boolean_F(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Boolean); + Assert.False(reader.ReadBoolean()); + reader.DemandEnd(); + } + + [Theory, Resp(":1\r\n")] + public void Boolean_1(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Integer); + Assert.True(reader.ReadBoolean()); + reader.DemandEnd(); + } + + [Theory, Resp(":0\r\n")] + public void Boolean_0(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Integer); + Assert.False(reader.ReadBoolean()); + reader.DemandEnd(); + } + + [Theory, Resp("!21\r\nSYNTAX invalid syntax\r\n", "!?\r\n;6\r\nSYNTAX\r\n;15\r\n invalid syntax\r\n;0\r\n")] + public void BlobError_ImplicitErrors(RespPayload payload) + { + var ex = Assert.Throws(() => + { + var reader = payload.Reader(); + reader.MoveNext(); + }); + Assert.Equal("SYNTAX invalid syntax", ex.Message); + } + + [Theory, Resp("!21\r\nSYNTAX invalid syntax\r\n", "!?\r\n;6\r\nSYNTAX\r\n;15\r\n invalid syntax\r\n;0\r\n")] + public void BlobError_Careful(RespPayload payload) + { + var reader = payload.Reader(); + Assert.True(reader.TryReadNext()); + Assert.Equal(RespPrefix.BulkError, reader.Prefix); + Assert.True(reader.Is("SYNTAX invalid syntax"u8)); + Assert.Equal("SYNTAX invalid syntax", reader.ReadString()); + reader.DemandEnd(); + } + + [Theory, Resp("=15\r\ntxt:Some string\r\n", "=?\r\n;4\r\ntxt:\r\n;11\r\nSome string\r\n;0\r\n")] + public void VerbatimString(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.VerbatimString); + Assert.Equal("Some string", reader.ReadString()); + Assert.Equal("Some string", reader.ReadString(out var prefix)); + Assert.Equal("txt", prefix); + + Assert.Equal("Some string", reader.ReadString(out var prefix2)); + Assert.Same(prefix, prefix2); // check prefix recognized and reuse literal + reader.DemandEnd(); + } + + [Theory, Resp("(3492890328409238509324850943850943825024385\r\n")] + public void BigIntegers(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.BigInteger); + Assert.Equal("3492890328409238509324850943850943825024385", reader.ReadString()); +#if NET8_0_OR_GREATER + var actual = reader.ParseChars(chars => BigInteger.Parse(chars, CultureInfo.InvariantCulture)); + + var expected = BigInteger.Parse("3492890328409238509324850943850943825024385"); + Assert.Equal(expected, actual); +#endif + } + + [Theory, Resp("*3\r\n:1\r\n:2\r\n:3\r\n", "*?\r\n:1\r\n:2\r\n:3\r\n.\r\n")] + public void Array(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Array); + Assert.Equal(3, reader.AggregateLength()); + var iterator = reader.AggregateChildren(); + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(1, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(2, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(3, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext(RespPrefix.Integer)); + iterator.MovePast(out reader); + reader.DemandEnd(); + + reader = payload.Reader(); + reader.MoveNext(RespPrefix.Array); + int[] arr = new int[reader.AggregateLength()]; + int i = 0; + foreach (var sub in reader.AggregateChildren()) + { + sub.MoveNext(RespPrefix.Integer); + arr[i++] = sub.ReadInt32(); + sub.DemandEnd(); + } + iterator.MovePast(out reader); + reader.DemandEnd(); + + Assert.Equal([1, 2, 3], arr); + } + + [Theory, Resp("*-1\r\n")] + public void NullArray(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Array); + Assert.True(reader.IsNull); + Assert.Equal(0, reader.AggregateLength()); + var iterator = reader.AggregateChildren(); + Assert.False(iterator.MoveNext()); + iterator.MovePast(out reader); + reader.DemandEnd(); + } + + [Theory, Resp("*2\r\n*3\r\n:1\r\n$5\r\nhello\r\n:2\r\n#f\r\n", "*?\r\n*?\r\n:1\r\n$5\r\nhello\r\n:2\r\n.\r\n#f\r\n.\r\n")] + public void NestedArray(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Array); + + Assert.Equal(2, reader.AggregateLength()); + + var iterator = reader.AggregateChildren(); + Assert.True(iterator.MoveNext(RespPrefix.Array)); + + Assert.Equal(3, iterator.Value.AggregateLength()); + var subIterator = iterator.Value.AggregateChildren(); + Assert.True(subIterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(1, subIterator.Value.ReadInt64()); + subIterator.Value.DemandEnd(); + + Assert.True(subIterator.MoveNext(RespPrefix.BulkString)); + Assert.True(subIterator.Value.Is("hello"u8)); + subIterator.Value.DemandEnd(); + + Assert.True(subIterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(2, subIterator.Value.ReadInt64()); + subIterator.Value.DemandEnd(); + + Assert.False(subIterator.MoveNext()); + + Assert.True(iterator.MoveNext(RespPrefix.Boolean)); + Assert.False(iterator.Value.ReadBoolean()); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext()); + iterator.MovePast(out reader); + + reader.DemandEnd(); + } + + [Theory, Resp("%2\r\n+first\r\n:1\r\n+second\r\n:2\r\n", "%?\r\n+first\r\n:1\r\n+second\r\n:2\r\n.\r\n")] + public void Map(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Map); + + Assert.Equal(4, reader.AggregateLength()); + + var iterator = reader.AggregateChildren(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("first".AsSpan())); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(1, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("second"u8)); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(2, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext()); + + iterator.MovePast(out reader); + reader.DemandEnd(); + } + + [Theory, Resp("~5\r\n+orange\r\n+apple\r\n#t\r\n:100\r\n:999\r\n", "~?\r\n+orange\r\n+apple\r\n#t\r\n:100\r\n:999\r\n.\r\n")] + public void Set(RespPayload payload) + { + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Set); + + Assert.Equal(5, reader.AggregateLength()); + + var iterator = reader.AggregateChildren(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("orange".AsSpan())); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("apple"u8)); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Boolean)); + Assert.True(iterator.Value.ReadBoolean()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(100, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(999, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext()); + + iterator.MovePast(out reader); + reader.DemandEnd(); + } + + private sealed class TestAttributeReader : RespAttributeReader<(int Count, int Ttl, decimal A, decimal B)> + { + public override void Read(ref RespReader reader, ref (int Count, int Ttl, decimal A, decimal B) value) + { + value.Count += ReadKeyValuePairs(ref reader, ref value); + } + private TestAttributeReader() { } + public static readonly TestAttributeReader Instance = new(); + public static (int Count, int Ttl, decimal A, decimal B) Zero = (0, 0, 0, 0); + public override bool ReadKeyValuePair(scoped ReadOnlySpan key, ref RespReader reader, ref (int Count, int Ttl, decimal A, decimal B) value) + { + if (key.SequenceEqual("ttl"u8) && reader.IsScalar) + { + value.Ttl = reader.ReadInt32(); + } + else if (key.SequenceEqual("key-popularity"u8) && reader.IsAggregate) + { + ReadKeyValuePairs(ref reader, ref value); // recurse to process a/b below + } + else if (key.SequenceEqual("a"u8) && reader.IsScalar) + { + value.A = reader.ReadDecimal(); + } + else if (key.SequenceEqual("b"u8) && reader.IsScalar) + { + value.B = reader.ReadDecimal(); + } + else + { + return false; // not recognized + } + return true; // recognized + } + } + + [Theory, Resp( + "|1\r\n+key-popularity\r\n%2\r\n$1\r\na\r\n,0.1923\r\n$1\r\nb\r\n,0.0012\r\n*2\r\n:2039123\r\n:9543892\r\n", + "|1\r\n+key-popularity\r\n%2\r\n$1\r\na\r\n,0.1923\r\n$1\r\nb\r\n,0.0012\r\n*?\r\n:2039123\r\n:9543892\r\n.\r\n")] + public void AttributeRoot(RespPayload payload) + { + // ignore the attribute data + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Array); + Assert.Equal(2, reader.AggregateLength()); + var iterator = reader.AggregateChildren(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(2039123, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(9543892, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext()); + iterator.MovePast(out reader); + reader.DemandEnd(); + + // process the attribute data + var state = TestAttributeReader.Zero; + reader = payload.Reader(); + reader.MoveNext(RespPrefix.Array, TestAttributeReader.Instance, ref state); + Assert.Equal(1, state.Count); + Assert.Equal(0.1923M, state.A); + Assert.Equal(0.0012M, state.B); + state = TestAttributeReader.Zero; + + Assert.Equal(2, reader.AggregateLength()); + iterator = reader.AggregateChildren(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer, TestAttributeReader.Instance, ref state)); + Assert.Equal(2039123, iterator.Value.ReadInt32()); + Assert.Equal(0, state.Count); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer, TestAttributeReader.Instance, ref state)); + Assert.Equal(9543892, iterator.Value.ReadInt32()); + Assert.Equal(0, state.Count); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext()); + iterator.MovePast(out reader); + reader.DemandEnd(); + } + + [Theory, Resp("*3\r\n:1\r\n:2\r\n|1\r\n+ttl\r\n:3600\r\n:3\r\n", "*?\r\n:1\r\n:2\r\n|1\r\n+ttl\r\n:3600\r\n:3\r\n.\r\n")] + public void AttributeInner(RespPayload payload) + { + // ignore the attribute data + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Array); + Assert.Equal(3, reader.AggregateLength()); + var iterator = reader.AggregateChildren(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(1, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(2, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer)); + Assert.Equal(3, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext()); + iterator.MovePast(out reader); + reader.DemandEnd(); + + // process the attribute data + var state = TestAttributeReader.Zero; + reader = payload.Reader(); + reader.MoveNext(RespPrefix.Array, TestAttributeReader.Instance, ref state); + Assert.Equal(0, state.Count); + Assert.Equal(3, reader.AggregateLength()); + iterator = reader.AggregateChildren(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer, TestAttributeReader.Instance, ref state)); + Assert.Equal(0, state.Count); + Assert.Equal(1, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer, TestAttributeReader.Instance, ref state)); + Assert.Equal(0, state.Count); + Assert.Equal(2, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.Integer, TestAttributeReader.Instance, ref state)); + Assert.Equal(1, state.Count); + Assert.Equal(3600, state.Ttl); + state = TestAttributeReader.Zero; // reset + Assert.Equal(3, iterator.Value.ReadInt32()); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext(TestAttributeReader.Instance, ref state)); + Assert.Equal(0, state.Count); + iterator.MovePast(out reader); + reader.DemandEnd(); + } + + [Theory, Resp(">3\r\n+message\r\n+somechannel\r\n+this is the message\r\n", OutOfBand = true)] + public void Push(RespPayload payload) + { + // ignore the attribute data + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Push); + Assert.Equal(3, reader.AggregateLength()); + var iterator = reader.AggregateChildren(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("message"u8)); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("somechannel"u8)); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("this is the message"u8)); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext()); + iterator.MovePast(out reader); + reader.DemandEnd(); + } + + [Theory, Resp(">3\r\n+message\r\n+somechannel\r\n+this is the message\r\n$9\r\nGet-Reply\r\n", Count = 2)] + public void PushThenGetReply(RespPayload payload) + { + // ignore the attribute data + var reader = payload.Reader(); + reader.MoveNext(RespPrefix.Push); + Assert.Equal(3, reader.AggregateLength()); + var iterator = reader.AggregateChildren(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("message"u8)); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("somechannel"u8)); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("this is the message"u8)); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext()); + iterator.MovePast(out reader); + + reader.MoveNext(RespPrefix.BulkString); + Assert.True(reader.Is("Get-Reply"u8)); + reader.DemandEnd(); + } + + [Theory, Resp("$9\r\nGet-Reply\r\n>3\r\n+message\r\n+somechannel\r\n+this is the message\r\n", Count = 2)] + public void GetReplyThenPush(RespPayload payload) + { + // ignore the attribute data + var reader = payload.Reader(); + + reader.MoveNext(RespPrefix.BulkString); + Assert.True(reader.Is("Get-Reply"u8)); + + reader.MoveNext(RespPrefix.Push); + Assert.Equal(3, reader.AggregateLength()); + var iterator = reader.AggregateChildren(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("message"u8)); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("somechannel"u8)); + iterator.Value.DemandEnd(); + + Assert.True(iterator.MoveNext(RespPrefix.SimpleString)); + Assert.True(iterator.Value.Is("this is the message"u8)); + iterator.Value.DemandEnd(); + + Assert.False(iterator.MoveNext()); + iterator.MovePast(out reader); + + reader.DemandEnd(); + } + + [Theory, Resp("*0\r\n$4\r\npass\r\n", "*1\r\n+ok\r\n$4\r\npass\r\n", "*-1\r\n$4\r\npass\r\n", "*?\r\n.\r\n$4\r\npass\r\n", Count = 2)] + public void ArrayThenString(RespPayload payload) + { + var reader = payload.Reader(); + Assert.True(reader.TryMoveNext(RespPrefix.Array)); + reader.SkipChildren(); + + Assert.True(reader.TryMoveNext(RespPrefix.BulkString)); + Assert.True(reader.Is("pass"u8)); + + reader.DemandEnd(); + + // and the same using child iterator + reader = payload.Reader(); + Assert.True(reader.TryMoveNext(RespPrefix.Array)); + var iterator = reader.AggregateChildren(); + iterator.MovePast(out reader); + + Assert.True(reader.TryMoveNext(RespPrefix.BulkString)); + Assert.True(reader.Is("pass"u8)); + + reader.DemandEnd(); + } + + private sealed class Segment : ReadOnlySequenceSegment + { + public override string ToString() => RespConstants.UTF8.GetString(Memory.Span) + .Replace("\r", "\\r").Replace("\n", "\\n"); + + public Segment(ReadOnlyMemory value, Segment? head) + { + Memory = value; + if (head is not null) + { + RunningIndex = head.RunningIndex + head.Memory.Length; + head.Next = this; + } + } + public bool IsEmpty => Memory.IsEmpty; + public int Length => Memory.Length; + } +} diff --git a/tests/RESPite.Tests/RespWriterTests.cs b/tests/RESPite.Tests/RespWriterTests.cs new file mode 100644 index 000000000..6462ee991 --- /dev/null +++ b/tests/RESPite.Tests/RespWriterTests.cs @@ -0,0 +1,42 @@ +using RESPite.Messages; +using Xunit; + +namespace RESPite.Tests; + +public class RespWriterTests +{ + [Theory] + [InlineData(0, "$1\r\n0\r\n")] + [InlineData(-1, "$2\r\n-1\r\n")] + [InlineData(-12, "$3\r\n-12\r\n")] + [InlineData(-123, "$4\r\n-123\r\n")] + [InlineData(-1234, "$5\r\n-1234\r\n")] + [InlineData(-12345, "$6\r\n-12345\r\n")] + [InlineData(-123456, "$7\r\n-123456\r\n")] + [InlineData(-1234567, "$8\r\n-1234567\r\n")] + [InlineData(-12345678, "$9\r\n-12345678\r\n")] + [InlineData(-123456789, "$10\r\n-123456789\r\n")] + [InlineData(-1234567890, "$11\r\n-1234567890\r\n")] + [InlineData(int.MinValue, "$11\r\n-2147483648\r\n")] + [InlineData(1, "$1\r\n1\r\n")] + [InlineData(12, "$2\r\n12\r\n")] + [InlineData(123, "$3\r\n123\r\n")] + [InlineData(1234, "$4\r\n1234\r\n")] + [InlineData(12345, "$5\r\n12345\r\n")] + [InlineData(123456, "$6\r\n123456\r\n")] + [InlineData(1234567, "$7\r\n1234567\r\n")] + [InlineData(12345678, "$8\r\n12345678\r\n")] + [InlineData(123456789, "$9\r\n123456789\r\n")] + [InlineData(1234567890, "$10\r\n1234567890\r\n")] + [InlineData(int.MaxValue, "$10\r\n2147483647\r\n")] + + public void BulkStringInteger(int value, string expected) + { + using var aw = new TestBufferWriter(); + var writer = new RespWriter(aw); + writer.WriteBulkString(value); + writer.Flush(); + var actual = aw.ToString(); + Assert.Equal(expected, actual); + } +} diff --git a/tests/RESPite.Tests/TestBufferWriter.cs b/tests/RESPite.Tests/TestBufferWriter.cs new file mode 100644 index 000000000..c258498dc --- /dev/null +++ b/tests/RESPite.Tests/TestBufferWriter.cs @@ -0,0 +1,52 @@ +using System; +using System.Buffers; +using System.Text; + +namespace RESPite.Tests; + +// note that ArrayBufferWriter{T} is not available on all target platforms +public sealed class TestBufferWriter : IBufferWriter, IDisposable +{ + private byte[] _buffer = []; + private int _committed; + + public override string ToString() => Encoding.UTF8.GetString(_buffer, 0, _committed); + public ReadOnlySpan Committed => _buffer.AsSpan(0, _committed); + + public void Advance(int count) + { + if (count < 0 | count + _committed > _buffer.Length) throw new ArgumentOutOfRangeException(nameof(count)); + _committed += count; + } + + private void Ensure(int sizeHint) + { + sizeHint = Math.Max(sizeHint, 128); + if (_buffer.Length < _committed + sizeHint) + { + var newBuffer = ArrayPool.Shared.Rent(Math.Max(_buffer.Length * 2, _committed + sizeHint)); + Committed.CopyTo(newBuffer); + ArrayPool.Shared.Return(_buffer); + _buffer = newBuffer; + } + } + + public Memory GetMemory(int sizeHint = 0) + { + Ensure(sizeHint); + return _buffer.AsMemory(_committed); + } + + public Span GetSpan(int sizeHint = 0) + { + Ensure(sizeHint); + return _buffer.AsSpan(_committed); + } + + public void Dispose() + { + _committed = 0; + ArrayPool.Shared.Return(_buffer); + _buffer = []; + } +} diff --git a/tests/RESPite.Tests/TestServer.cs b/tests/RESPite.Tests/TestServer.cs new file mode 100644 index 000000000..5b3a3f1ce --- /dev/null +++ b/tests/RESPite.Tests/TestServer.cs @@ -0,0 +1,358 @@ +using System; +using System.Buffers; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using RESPite.Connections.Internal; +using RESPite.Internal; +using Xunit; + +namespace RESPite.Tests; + +internal sealed class TestServer : IDisposable +{ + private readonly TestRespServerStream _stream = new(); + public RespConnection Connection { get; } + + public TestServer(RespConfiguration? configuration = null) + { + Connection = new StreamConnection( + RespContext.Null.WithCancellationToken(TestContext.Current.CancellationToken), + configuration ?? RespConfiguration.Default, + _stream); + } + + public void Dispose() + { + // ReSharper disable once ConditionalAccessQualifierIsNonNullableAccordingToAPIContract + _stream?.Dispose(); + // ReSharper disable once ConditionalAccessQualifierIsNonNullableAccordingToAPIContract + Connection?.Dispose(); + } + + public static ValueTask Execute( + Func> operation, + ReadOnlySpan request, + ReadOnlySpan response) + => ExecuteCore(operation, request, response); + + // intended for use with [InlineData("...")] scenarios + public static ValueTask Execute( + Func> operation, + string request, + string response) + { + var lease = Encode(request, response, out var reqSpan, out var respSpan); + return ExecuteCore(operation, reqSpan, respSpan, lease); + } + + private static byte[] Encode( + string request, + string response, + out ReadOnlySpan requestSpan, + out ReadOnlySpan responseSpan) + { + var byteCount = Encoding.UTF8.GetByteCount(request) + Encoding.UTF8.GetByteCount(response); + var lease = ArrayPool.Shared.Rent(byteCount); + var reqLen = Encoding.UTF8.GetBytes(request.AsSpan(), lease.AsSpan()); + var respLen = Encoding.UTF8.GetBytes(response.AsSpan(), lease.AsSpan(reqLen)); + requestSpan = lease.AsSpan(0, reqLen); + responseSpan = lease.AsSpan(reqLen, respLen); + return lease; + } + + private static ValueTask ExecuteCore( + Func> operation, + ReadOnlySpan request, + ReadOnlySpan response, + byte[]? lease = null) + { + bool disposeServer = true; + TestServer? server = null; + try + { + server = new TestServer(); + var pending = operation(server.Context); + server.AssertSent(request, final: true); + Assert.False(pending.IsCompleted); + server.Respond(response); + disposeServer = false; + return AwaitAndDispose(server, pending); + } + finally + { + if (disposeServer) server?.Dispose(); + if (lease is not null) ArrayPool.Shared.Return(lease); + } + + static async ValueTask AwaitAndDispose(TestServer server, ValueTask pending) + { + using (server) + { + return await pending.ConfigureAwait(false); + } + } + } + + public static ValueTask Execute( + Func> operation, + ReadOnlySpan request, + ReadOnlySpan response, + T expected) + => AwaitAndValidate(Execute(operation, request, response), expected); + + // intended for use with [InlineData("...")] scenarios + public static ValueTask Execute( + Func> operation, + string request, + string response, + T expected) + => AwaitAndValidate(Execute(operation, request, response), expected); + + public static ValueTask Execute( + Func operation, + ReadOnlySpan request, + ReadOnlySpan response) + => ExecuteCore(operation, request, response); + + // intended for use with [InlineData("...")] scenarios + public static ValueTask Execute( + Func operation, + string request, + string response) + { + var lease = Encode(request, response, out var reqSpan, out var respSpan); + return ExecuteCore(operation, reqSpan, respSpan, lease); + } + + private static ValueTask ExecuteCore( + Func operation, + ReadOnlySpan request, + ReadOnlySpan response, + byte[]? lease = null) + { + bool disposeServer = true; + TestServer? server = null; + try + { + server = new TestServer(); + var pending = operation(server.Context); + server.AssertSent(request, final: true); + Assert.False(pending.IsCompleted); + server.Respond(response); + disposeServer = false; + return AwaitAndDispose(server, pending); + } + finally + { + if (disposeServer) server?.Dispose(); + if (lease is not null) ArrayPool.Shared.Return(lease); + } + + static async ValueTask AwaitAndDispose(TestServer server, ValueTask pending) + { + using (server) + { + await pending.ConfigureAwait(false); + } + } + } + + private static async ValueTask AwaitAndValidate(ValueTask pending, T expected) + { + var actual = await pending.ConfigureAwait(false); + Assert.Equal(expected, actual); + } + + public ref readonly RespContext Context => ref Connection.Context; + + public void Respond(ReadOnlySpan serverToClient) => _stream.Respond(serverToClient); + + public void AssertSent(ReadOnlySpan clientToServer, bool final = false) + { + _stream.AssertSent(clientToServer); + if (final) _stream.AssertAllSent(); + } + + public void AssertAllSent() => _stream.AssertAllSent(); + + private sealed class TestRespServerStream : Stream + { + private bool _disposed, _closed; + + public override void Close() + { + _closed = true; + lock (inboundLock) + { + Monitor.PulseAll(inboundLock); + } + } + + protected override void Dispose(bool disposing) + { + _disposed = true; + if (disposing) + { + lock (inboundLock) + { + Monitor.PulseAll(inboundLock); + } + } + } + + public override void Flush() { } + + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + + public override void SetLength(long value) => throw new NotSupportedException(); + + private void ThrowIfDisposed() + { + if (_disposed) throw new ObjectDisposedException(GetType().Name); + } + + public override int Read(byte[] buffer, int offset, int count) + => ReadCore(buffer.AsSpan(offset, count)); + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + var read = ReadCore(buffer.AsSpan(offset, count)); + return Task.FromResult(read); + } + + public void Respond(ReadOnlySpan serverToClient) + { + lock (inboundLock) + { + if (!(_disposed | _disposed)) + { + _inbound.Write(serverToClient); + } + + Monitor.PulseAll(inboundLock); + } + } + + private int ReadCore(Span destination) + { + ThrowIfDisposed(); + lock (inboundLock) + { + while (_inbound.CommittedIsEmpty) + { + if (_closed) return 0; + Monitor.Wait(inboundLock); + ThrowIfDisposed(); + } + + if (destination.IsEmpty) return 0; // zero-length read + Assert.True(_inbound.TryGetFirstCommittedSpan(1, out var span)); + Assert.False(span.IsEmpty); + if (span.Length > destination.Length) span = span.Slice(0, destination.Length); + span.CopyTo(destination); + return span.Length; + } + } + +#if NET + public override int Read(Span buffer) => ReadCore(buffer); + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + var read = ReadCore(buffer.Span); + return new(read); + } +#endif + + // ReSharper disable once ChangeFieldTypeToSystemThreadingLock - TFM dependent + private readonly object outboundLock = new(), inboundLock = new(); + + private CycleBuffer _outbound = CycleBuffer.Create(MemoryPool.Shared), + _inbound = CycleBuffer.Create(MemoryPool.Shared); + + private void WriteCore(ReadOnlySpan source) + { + lock (outboundLock) + { + _outbound.Write(source); + } + } + + public override void Write(byte[] buffer, int offset, int count) + => WriteCore(buffer.AsSpan(offset, count)); + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + WriteCore(buffer.AsSpan(offset, count)); + return Task.CompletedTask; + } + +#if NET + public override void Write(ReadOnlySpan buffer) => WriteCore(buffer); + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + WriteCore(buffer.Span); + return default; + } +#endif + + // verifies that there is no more request data unaccounted for + public void AssertAllSent() + { + bool empty; + lock (outboundLock) + { + empty = _outbound.CommittedIsEmpty; + } + + Assert.True(empty); + } + + /// + /// Verifies and discards outbound data. + /// + public void AssertSent(ReadOnlySpan clientToServer) + { + lock (outboundLock) + { + var available = _outbound.GetCommittedLength(); + Assert.True( + available >= clientToServer.Length, + $"expected {clientToServer.Length} bytes, {available} available"); + while (!clientToServer.IsEmpty) + { + Assert.True(_outbound.TryGetFirstCommittedSpan(1, out var received), "should have data available"); + var take = Math.Min(received.Length, clientToServer.Length); + Assert.True(take > 0, "should have some data to compare"); + var xBytes = clientToServer.Slice(0, take); + var yBytes = received.Slice(0, take); + if (!xBytes.SequenceEqual(yBytes)) + { + var xText = Encoding.UTF8.GetString(xBytes).Replace("\r\n", "\\r\\n"); + var yText = Encoding.UTF8.GetString(yBytes).Replace("\r\n", "\\r\\n"); + Assert.Equal(xText, yText); + } + + _outbound.DiscardCommitted(take); + clientToServer = clientToServer.Slice(take); + } + } + } + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => true; + public override long Length => throw new NotSupportedException(); + + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + } +} diff --git a/tests/StackExchange.Redis.Tests/SSLTests.cs b/tests/StackExchange.Redis.Tests/SSLTests.cs index 0dafe3f9b..5c01bd817 100644 --- a/tests/StackExchange.Redis.Tests/SSLTests.cs +++ b/tests/StackExchange.Redis.Tests/SSLTests.cs @@ -240,7 +240,9 @@ public async Task RedisLabsSSL() Skip.IfNoConfig(nameof(TestConfig.Config.RedisLabsSslServer), TestConfig.Current.RedisLabsSslServer); Skip.IfNoConfig(nameof(TestConfig.Config.RedisLabsPfxPath), TestConfig.Current.RedisLabsPfxPath); +#pragma warning disable SYSLIB0057 // because of TFM support var cert = new X509Certificate2(TestConfig.Current.RedisLabsPfxPath, ""); +#pragma warning restore SYSLIB0057 Assert.NotNull(cert); Log("Thumbprint: " + cert.Thumbprint); diff --git a/tests/StackExchange.Redis.Tests/SyncContextTests.cs b/tests/StackExchange.Redis.Tests/SyncContextTests.cs index b98caefeb..5feb37e3d 100644 --- a/tests/StackExchange.Redis.Tests/SyncContextTests.cs +++ b/tests/StackExchange.Redis.Tests/SyncContextTests.cs @@ -122,7 +122,7 @@ public MySyncContext(TextWriter log) private int _opCount; private void Incr() => Interlocked.Increment(ref _opCount); - public void Reset() => Thread.VolatileWrite(ref _opCount, 0); + public void Reset() => Volatile.Write(ref _opCount, 0); public override string ToString() => $"Sync context ({(IsCurrent ? "active" : "inactive")}): {OpCount}";