From 3efadeb5962fc80ea502d13bf050a7168447a7b7 Mon Sep 17 00:00:00 2001 From: Aron Tsang Date: Fri, 4 Apr 2025 15:57:29 +0800 Subject: [PATCH 1/5] Implement a fast lookup for wildcard certificates --- ...verCertificateSelector.CertificateStore.cs | 139 ++++++++++++++++++ .../Certificates/ServerCertificateSelector.cs | 40 ++++- ...ReverseProxyServiceCollectionExtensions.cs | 4 +- 3 files changed, 177 insertions(+), 6 deletions(-) create mode 100644 src/Kubernetes.Controller/Certificates/ServerCertificateSelector.CertificateStore.cs diff --git a/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.CertificateStore.cs b/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.CertificateStore.cs new file mode 100644 index 000000000..1da95167b --- /dev/null +++ b/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.CertificateStore.cs @@ -0,0 +1,139 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Formats.Asn1; +using System.Globalization; +using System.Linq; +using System.Security.Cryptography.X509Certificates; + +namespace Yarp.Kubernetes.Controller.Certificates; + +internal partial class ServerCertificateSelector +{ + private class CertificateStore + { + private readonly List _wildCardDomains = new(); + private readonly Dictionary _certificates = new(StringComparer.OrdinalIgnoreCase); + + public CertificateStore(IEnumerable certificates) + { + + foreach (var certificate in certificates) + { + foreach (var domain in GetDomains(certificate)) + { + if (domain.StartsWith("*.")) + { + _wildCardDomains.Add(new (domain[2..], certificate)); + } + else + { + _certificates[domain] = certificate; + } + } + } + + _wildCardDomains.Sort(DomainNameComparer.Instance); + } + + + public X509Certificate2 GetCertificate(string domain) + { + // First search for exact match for certificate. + if (_certificates.TryGetValue(domain, out var cert)) + { + return cert; + } + + + // By using a binary search, we can achieve O(log n) suffix search whilst avoiding a complex + // tree/trie structure in the heap. + if (_wildCardDomains.BinarySearch(new WildCardDomain(domain, null!), DomainNameComparer.Instance) is { } index and < -1) + { + var candidate = _wildCardDomains[~index]; + if (domain.EndsWith(candidate.Domain, true, CultureInfo.InvariantCulture)) + { + return candidate.Certificate; + } + } + + return _wildCardDomains.FirstOrDefault()?.Certificate + ?? _certificates.Values.FirstOrDefault(); + } + } + + private record WildCardDomain(string Domain, X509Certificate2 Certificate); + + private static IEnumerable GetDomains(X509Certificate2 certificate) + { + if (certificate.GetNameInfo(X509NameType.DnsName, false) is { } dnsName) + { + yield return dnsName; + } + + const string SAN_OID = "2.5.29.17"; + var extension = certificate.Extensions[SAN_OID]; + if (extension is null) + { + yield break; + } + + var dnsNameTag = new Asn1Tag(TagClass.ContextSpecific, tagValue: 2, isConstructed: false); + + var asnReader = new AsnReader(extension.RawData, AsnEncodingRules.BER); + var sequenceReader = asnReader.ReadSequence(Asn1Tag.Sequence); + while (sequenceReader.HasData) + { + var tag = sequenceReader.PeekTag(); + if (tag != dnsNameTag) + { + sequenceReader.ReadEncodedValue(); + continue; + } + + var alternativeName = sequenceReader.ReadCharacterString(UniversalTagNumber.IA5String, dnsNameTag); + yield return alternativeName; + } + + } + + + /// + /// Sorts domain names right to left. + /// This allows us to use a Binary Search to achieve a suffix + /// search. + /// + private class DomainNameComparer : IComparer + { + public static readonly DomainNameComparer Instance = new(); + + public int Compare(WildCardDomain x, WildCardDomain y) + { + return Compare(x!.Domain.AsSpan(), y!.Domain.AsSpan()); + } + + private static int Compare(ReadOnlySpan x, ReadOnlySpan y) + { + + var length = Math.Min(x.Length, y.Length); + + for (var i = 1; i <= length; i++) + { + var charA = x[^i] & 0x3F; + var charB = y[^i] & 0x3F; + + if (charA == charB) + { + continue; + } + + return charB - charA; + } + + return x.Length - y.Length; + } + + } +} diff --git a/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.cs b/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.cs index 0c2bfdd10..ac8999484 100644 --- a/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.cs +++ b/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.cs @@ -1,27 +1,57 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; +using System.Timers; using Microsoft.AspNetCore.Connections; +using Microsoft.Extensions.Hosting; namespace Yarp.Kubernetes.Controller.Certificates; -internal class ServerCertificateSelector : IServerCertificateSelector +internal partial class ServerCertificateSelector + : BackgroundService + , IServerCertificateSelector { - private X509Certificate2 _defaultCertificate; + private readonly ConcurrentDictionary _certificates = new(); + private bool _hasBeenUpdated; + + private CertificateStore _certificateStore = new(Array.Empty()); public void AddCertificate(NamespacedName certificateName, X509Certificate2 certificate) { - _defaultCertificate = certificate; + _certificates[certificateName] = certificate; + _hasBeenUpdated = true; } public X509Certificate2 GetCertificate(ConnectionContext connectionContext, string domainName) { - return _defaultCertificate; + return _certificateStore.GetCertificate(domainName); } public void RemoveCertificate(NamespacedName certificateName) { - _defaultCertificate = null; + _ = _certificates.TryRemove(certificateName, out _); + _hasBeenUpdated = true; + } + + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + // Poll every 10 seconds for updates to + while (!stoppingToken.IsCancellationRequested) + { + await Task.Delay(TimeSpan.FromSeconds(10), stoppingToken); + if (_hasBeenUpdated) + { + _hasBeenUpdated = false; + _certificateStore = new CertificateStore(_certificates.Values); + } + } } } + + diff --git a/src/Kubernetes.Controller/Management/KubernetesReverseProxyServiceCollectionExtensions.cs b/src/Kubernetes.Controller/Management/KubernetesReverseProxyServiceCollectionExtensions.cs index 663fa5921..cd4570053 100644 --- a/src/Kubernetes.Controller/Management/KubernetesReverseProxyServiceCollectionExtensions.cs +++ b/src/Kubernetes.Controller/Management/KubernetesReverseProxyServiceCollectionExtensions.cs @@ -92,7 +92,9 @@ public static IServiceCollection AddKubernetesControllerRuntime(this IServiceCol services.RegisterResourceInformer("type=kubernetes.io/tls"); // Add the Ingress/Secret to certificate management - services.AddSingleton(); + services.AddSingleton(); + services.AddHostedService(x => x.GetRequiredService()); + services.AddSingleton(x => x.GetRequiredService()); services.AddSingleton(); // ingress status updater From 6b3455bb0b5d5ade2134aac80e314ae602f41094 Mon Sep 17 00:00:00 2001 From: Aron Tsang Date: Sat, 5 Apr 2025 17:36:15 +0800 Subject: [PATCH 2/5] Refactor for Unit Tests --- .../Certificates/ImmutableCertificateCache.cs | 110 ++++++++++++++ ...verCertificateSelector.CertificateCache.cs | 50 +++++++ ...verCertificateSelector.CertificateStore.cs | 139 ------------------ .../Certificates/ServerCertificateSelector.cs | 4 +- .../Certificates/CertificateCacheTests.cs | 9 ++ 5 files changed, 171 insertions(+), 141 deletions(-) create mode 100644 src/Kubernetes.Controller/Certificates/ImmutableCertificateCache.cs create mode 100644 src/Kubernetes.Controller/Certificates/ServerCertificateSelector.CertificateCache.cs delete mode 100644 src/Kubernetes.Controller/Certificates/ServerCertificateSelector.CertificateStore.cs create mode 100644 test/Kubernetes.Tests/Certificates/CertificateCacheTests.cs diff --git a/src/Kubernetes.Controller/Certificates/ImmutableCertificateCache.cs b/src/Kubernetes.Controller/Certificates/ImmutableCertificateCache.cs new file mode 100644 index 000000000..d2261d4e8 --- /dev/null +++ b/src/Kubernetes.Controller/Certificates/ImmutableCertificateCache.cs @@ -0,0 +1,110 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Linq; +#nullable enable +namespace Yarp.Kubernetes.Controller.Certificates; + +public abstract class ImmutableCertificateCache where TCert : class +{ + private readonly List _wildCardDomains = new(); + private readonly Dictionary _certificates = new(StringComparer.OrdinalIgnoreCase); + + public ImmutableCertificateCache(IEnumerable certificates, Func> getDomains) + { + foreach (var certificate in certificates) + { + foreach (var domain in getDomains(certificate)) + { + if (domain.StartsWith("*.")) + { + _wildCardDomains.Add(new (domain[1..], certificate)); + } + else + { + _certificates[domain] = certificate; + } + } + } + + _wildCardDomains.Sort(DomainNameComparer.Instance); + } + + public bool TryGetCertificateExact(string domain, [NotNullWhen(true)] out TCert? certificate) => + _certificates.TryGetValue(domain, out certificate); + + public bool TryGetWildcardCertificate(string domain, [NotNullWhen(true)] out TCert? certificate) + { + if (_wildCardDomains.BinarySearch(new WildCardDomain(domain, null!), DomainNameComparer.Instance) is { } index and < -1) + { + var candidate = _wildCardDomains[~index]; + if (domain.EndsWith(candidate.Domain, true, CultureInfo.InvariantCulture)) + { + certificate = candidate.Certificate; + return true; + } + } + + certificate = null; + return false; + } + + public TCert? GetDefaultCertificate() => _wildCardDomains.FirstOrDefault()?.Certificate + ?? _certificates.Values.FirstOrDefault(); + + public TCert? GetCertificate(string domain) + { + if (TryGetCertificateExact(domain, out var certificate)) + { + return certificate; + } + if (TryGetWildcardCertificate(domain, out certificate)) + { + return certificate; + } + + return GetDefaultCertificate(); + } + + private record WildCardDomain(string Domain, TCert Certificate); + + /// + /// Sorts domain names right to left. + /// This allows us to use a Binary Search to achieve a suffix + /// search. + /// + private class DomainNameComparer : IComparer + { + public static readonly DomainNameComparer Instance = new(); + + public int Compare(WildCardDomain? x, WildCardDomain? y) + { + return Compare(x!.Domain.AsSpan(), y!.Domain.AsSpan()); + } + + private static int Compare(ReadOnlySpan x, ReadOnlySpan y) + { + + var length = Math.Min(x.Length, y.Length); + + for (var i = 1; i <= length; i++) + { + var charA = x[^i] & 0x3F; + var charB = y[^i] & 0x3F; + + if (charA == charB) + { + continue; + } + + return charB - charA; + } + + return x.Length - y.Length; + } + } +} diff --git a/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.CertificateCache.cs b/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.CertificateCache.cs new file mode 100644 index 000000000..243d81bdb --- /dev/null +++ b/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.CertificateCache.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +#nullable enable +using System.Collections.Generic; +using System.Formats.Asn1; +using System.Security.Cryptography.X509Certificates; + +namespace Yarp.Kubernetes.Controller.Certificates; + +internal partial class ServerCertificateSelector +{ + private class ImmutableX509CertificateCache(IEnumerable certificates) + : ImmutableCertificateCache(certificates, GetDomains); + + private static IEnumerable GetDomains(X509Certificate2 certificate) + { + if (certificate.GetNameInfo(X509NameType.DnsName, false) is { } dnsName) + { + yield return dnsName; + } + + const string SAN_OID = "2.5.29.17"; + var extension = certificate.Extensions[SAN_OID]; + if (extension is null) + { + yield break; + } + + var dnsNameTag = new Asn1Tag(TagClass.ContextSpecific, tagValue: 2, isConstructed: false); + + var asnReader = new AsnReader(extension.RawData, AsnEncodingRules.BER); + var sequenceReader = asnReader.ReadSequence(Asn1Tag.Sequence); + while (sequenceReader.HasData) + { + var tag = sequenceReader.PeekTag(); + if (tag != dnsNameTag) + { + sequenceReader.ReadEncodedValue(); + continue; + } + + var alternativeName = sequenceReader.ReadCharacterString(UniversalTagNumber.IA5String, dnsNameTag); + yield return alternativeName; + } + + } + + + +} diff --git a/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.CertificateStore.cs b/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.CertificateStore.cs deleted file mode 100644 index 1da95167b..000000000 --- a/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.CertificateStore.cs +++ /dev/null @@ -1,139 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System; -using System.Collections.Generic; -using System.Formats.Asn1; -using System.Globalization; -using System.Linq; -using System.Security.Cryptography.X509Certificates; - -namespace Yarp.Kubernetes.Controller.Certificates; - -internal partial class ServerCertificateSelector -{ - private class CertificateStore - { - private readonly List _wildCardDomains = new(); - private readonly Dictionary _certificates = new(StringComparer.OrdinalIgnoreCase); - - public CertificateStore(IEnumerable certificates) - { - - foreach (var certificate in certificates) - { - foreach (var domain in GetDomains(certificate)) - { - if (domain.StartsWith("*.")) - { - _wildCardDomains.Add(new (domain[2..], certificate)); - } - else - { - _certificates[domain] = certificate; - } - } - } - - _wildCardDomains.Sort(DomainNameComparer.Instance); - } - - - public X509Certificate2 GetCertificate(string domain) - { - // First search for exact match for certificate. - if (_certificates.TryGetValue(domain, out var cert)) - { - return cert; - } - - - // By using a binary search, we can achieve O(log n) suffix search whilst avoiding a complex - // tree/trie structure in the heap. - if (_wildCardDomains.BinarySearch(new WildCardDomain(domain, null!), DomainNameComparer.Instance) is { } index and < -1) - { - var candidate = _wildCardDomains[~index]; - if (domain.EndsWith(candidate.Domain, true, CultureInfo.InvariantCulture)) - { - return candidate.Certificate; - } - } - - return _wildCardDomains.FirstOrDefault()?.Certificate - ?? _certificates.Values.FirstOrDefault(); - } - } - - private record WildCardDomain(string Domain, X509Certificate2 Certificate); - - private static IEnumerable GetDomains(X509Certificate2 certificate) - { - if (certificate.GetNameInfo(X509NameType.DnsName, false) is { } dnsName) - { - yield return dnsName; - } - - const string SAN_OID = "2.5.29.17"; - var extension = certificate.Extensions[SAN_OID]; - if (extension is null) - { - yield break; - } - - var dnsNameTag = new Asn1Tag(TagClass.ContextSpecific, tagValue: 2, isConstructed: false); - - var asnReader = new AsnReader(extension.RawData, AsnEncodingRules.BER); - var sequenceReader = asnReader.ReadSequence(Asn1Tag.Sequence); - while (sequenceReader.HasData) - { - var tag = sequenceReader.PeekTag(); - if (tag != dnsNameTag) - { - sequenceReader.ReadEncodedValue(); - continue; - } - - var alternativeName = sequenceReader.ReadCharacterString(UniversalTagNumber.IA5String, dnsNameTag); - yield return alternativeName; - } - - } - - - /// - /// Sorts domain names right to left. - /// This allows us to use a Binary Search to achieve a suffix - /// search. - /// - private class DomainNameComparer : IComparer - { - public static readonly DomainNameComparer Instance = new(); - - public int Compare(WildCardDomain x, WildCardDomain y) - { - return Compare(x!.Domain.AsSpan(), y!.Domain.AsSpan()); - } - - private static int Compare(ReadOnlySpan x, ReadOnlySpan y) - { - - var length = Math.Min(x.Length, y.Length); - - for (var i = 1; i <= length; i++) - { - var charA = x[^i] & 0x3F; - var charB = y[^i] & 0x3F; - - if (charA == charB) - { - continue; - } - - return charB - charA; - } - - return x.Length - y.Length; - } - - } -} diff --git a/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.cs b/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.cs index ac8999484..7dfc40b17 100644 --- a/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.cs +++ b/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.cs @@ -20,7 +20,7 @@ internal partial class ServerCertificateSelector private readonly ConcurrentDictionary _certificates = new(); private bool _hasBeenUpdated; - private CertificateStore _certificateStore = new(Array.Empty()); + private ImmutableX509CertificateCache _certificateStore = new(Array.Empty()); public void AddCertificate(NamespacedName certificateName, X509Certificate2 certificate) { @@ -48,7 +48,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) if (_hasBeenUpdated) { _hasBeenUpdated = false; - _certificateStore = new CertificateStore(_certificates.Values); + _certificateStore = new ImmutableX509CertificateCache(_certificates.Values); } } } diff --git a/test/Kubernetes.Tests/Certificates/CertificateCacheTests.cs b/test/Kubernetes.Tests/Certificates/CertificateCacheTests.cs new file mode 100644 index 000000000..c1855df26 --- /dev/null +++ b/test/Kubernetes.Tests/Certificates/CertificateCacheTests.cs @@ -0,0 +1,9 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Yarp.Kubernetes.Controller.Certificates.Tests; + +public class CertificateCacheTests +{ + +} From d9790581460882a32abe237f2bea279884ef559b Mon Sep 17 00:00:00 2001 From: Aron Tsang Date: Sat, 5 Apr 2025 18:15:44 +0800 Subject: [PATCH 3/5] Fix code and add Unit Tests for CertificateCache --- Directory.Build.props | 2 +- .../Certificates/ImmutableCertificateCache.cs | 72 ++++++++++++------- ...verCertificateSelector.CertificateCache.cs | 13 +++- .../Certificates/CertificateCacheTests.cs | 49 ++++++++++++- 4 files changed, 109 insertions(+), 27 deletions(-) diff --git a/Directory.Build.props b/Directory.Build.props index bf2077319..c09be010e 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -7,7 +7,7 @@ © Microsoft Corporation. All rights reserved. - 12.0 + 13.0 MIT Microsoft true diff --git a/src/Kubernetes.Controller/Certificates/ImmutableCertificateCache.cs b/src/Kubernetes.Controller/Certificates/ImmutableCertificateCache.cs index d2261d4e8..52999c894 100644 --- a/src/Kubernetes.Controller/Certificates/ImmutableCertificateCache.cs +++ b/src/Kubernetes.Controller/Certificates/ImmutableCertificateCache.cs @@ -34,27 +34,9 @@ public ImmutableCertificateCache(IEnumerable certificates, Func - _certificates.TryGetValue(domain, out certificate); - - public bool TryGetWildcardCertificate(string domain, [NotNullWhen(true)] out TCert? certificate) - { - if (_wildCardDomains.BinarySearch(new WildCardDomain(domain, null!), DomainNameComparer.Instance) is { } index and < -1) - { - var candidate = _wildCardDomains[~index]; - if (domain.EndsWith(candidate.Domain, true, CultureInfo.InvariantCulture)) - { - certificate = candidate.Certificate; - return true; - } - } - certificate = null; - return false; - } - public TCert? GetDefaultCertificate() => _wildCardDomains.FirstOrDefault()?.Certificate - ?? _certificates.Values.FirstOrDefault(); + protected abstract TCert? GetDefaultCertificate(); public TCert? GetCertificate(string domain) { @@ -70,7 +52,35 @@ public bool TryGetWildcardCertificate(string domain, [NotNullWhen(true)] out TCe return GetDefaultCertificate(); } - private record WildCardDomain(string Domain, TCert Certificate); + protected IReadOnlyList WildcardCertificates => _wildCardDomains; + + protected IReadOnlyDictionary Certificates => _certificates; + + protected record WildCardDomain(string Domain, TCert? Certificate); + + private bool TryGetCertificateExact(string domain, [NotNullWhen(true)] out TCert? certificate) => + _certificates.TryGetValue(domain, out certificate); + + private bool TryGetWildcardCertificate(string domain, [NotNullWhen(true)] out TCert? certificate) + { + if (_wildCardDomains.BinarySearch(new WildCardDomain(domain, null!), DomainNameComparer.Instance) is { } index) + { + if (index > -1) + { + certificate = _wildCardDomains[index].Certificate!; + return true; + } + // var candidate = _wildCardDomains[~index]; + // if (domain.EndsWith(candidate.Domain, true, CultureInfo.InvariantCulture)) + // { + // certificate = candidate.Certificate!; + // return true; + // } + } + + certificate = null; + return false; + } /// /// Sorts domain names right to left. @@ -83,7 +93,21 @@ private class DomainNameComparer : IComparer public int Compare(WildCardDomain? x, WildCardDomain? y) { - return Compare(x!.Domain.AsSpan(), y!.Domain.AsSpan()); + var ret = Compare(x!.Domain.AsSpan(), y!.Domain.AsSpan()); + if (ret != 0) + { + return ret; + } + + switch (x!.Certificate, y!.Certificate) + { + case (null, {}) when x.Domain.Length > y.Domain.Length: + return 0; + case ({}, null) when x.Domain.Length < y.Domain.Length: + return 0; + default: + return x.Domain.Length - y.Domain.Length; + } } private static int Compare(ReadOnlySpan x, ReadOnlySpan y) @@ -93,8 +117,8 @@ private static int Compare(ReadOnlySpan x, ReadOnlySpan y) for (var i = 1; i <= length; i++) { - var charA = x[^i] & 0x3F; - var charB = y[^i] & 0x3F; + var charA = x[^i] & 0x5F; + var charB = y[^i] & 0x5F; if (charA == charB) { @@ -104,7 +128,7 @@ private static int Compare(ReadOnlySpan x, ReadOnlySpan y) return charB - charA; } - return x.Length - y.Length; + return 0; } } } diff --git a/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.CertificateCache.cs b/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.CertificateCache.cs index 243d81bdb..fe2a25bcb 100644 --- a/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.CertificateCache.cs +++ b/src/Kubernetes.Controller/Certificates/ServerCertificateSelector.CertificateCache.cs @@ -3,6 +3,7 @@ #nullable enable using System.Collections.Generic; using System.Formats.Asn1; +using System.Linq; using System.Security.Cryptography.X509Certificates; namespace Yarp.Kubernetes.Controller.Certificates; @@ -10,7 +11,17 @@ namespace Yarp.Kubernetes.Controller.Certificates; internal partial class ServerCertificateSelector { private class ImmutableX509CertificateCache(IEnumerable certificates) - : ImmutableCertificateCache(certificates, GetDomains); + : ImmutableCertificateCache(certificates, GetDomains) + { + protected override X509Certificate2? GetDefaultCertificate() + { + if (WildcardCertificates.Count != 0) + { + return WildcardCertificates[0].Certificate; + } + return Certificates.Values.FirstOrDefault(); + } + } private static IEnumerable GetDomains(X509Certificate2 certificate) { diff --git a/test/Kubernetes.Tests/Certificates/CertificateCacheTests.cs b/test/Kubernetes.Tests/Certificates/CertificateCacheTests.cs index c1855df26..125fbc5e2 100644 --- a/test/Kubernetes.Tests/Certificates/CertificateCacheTests.cs +++ b/test/Kubernetes.Tests/Certificates/CertificateCacheTests.cs @@ -1,9 +1,56 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; +using Xunit; +#nullable enable namespace Yarp.Kubernetes.Controller.Certificates.Tests; public class CertificateCacheTests { - + + private static readonly FakeCertificateCache Cache = new( + new FakeCertificate("Acme", "mail.acme.com", "www.acme.com"), + new FakeCertificate("Initech", "*.initech.com", "initech.com"), + new FakeCertificate("Northwind", "*.northwind.com") + ); + + [Theory] + [InlineData("www.acme.com", "Acme")] + [InlineData("www.ACME.com", "Acme")] + [InlineData("mail.acme.com", "Acme")] + [InlineData("acme.com", null)] + [InlineData("store.acme.com", null)] + [InlineData("www.northwind.com", "Northwind")] + [InlineData("mail.northwind.com", "Northwind")] + [InlineData("northwind.com", null)] + [InlineData("initech.com", "Initech")] + [InlineData("www.initech.com", "Initech")] + [InlineData("www.IniTech.coM", "Initech")] + public void CertificateConversionFromPem(string requestedDomain, string? expectedCompanyName) + { + var certificate = Cache.GetCertificate(requestedDomain); + if (expectedCompanyName != null) + { + Assert.Equal(expectedCompanyName, certificate?.Name); + } + else + { + Assert.Null(certificate?.Name); + } + } + + private record FakeCertificate(string Name, params string[] Domains); + + private class FakeCertificateCache(params IEnumerable certificates) + : ImmutableCertificateCache(certificates, static cert => cert.Domains) + { + protected override FakeCertificate? GetDefaultCertificate() + { + return null; + } + } } + + + From 3448610fd796f69b5b36ba09820b773565184895 Mon Sep 17 00:00:00 2001 From: Aron Tsang Date: Sat, 5 Apr 2025 21:41:43 +0800 Subject: [PATCH 4/5] Fix test --- test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs b/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs index 82bb01a1b..110526d6b 100644 --- a/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs +++ b/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs @@ -1064,7 +1064,7 @@ public async Task RequestWithCookieHeaders(params string[] cookies) { var events = TestEventListener.Collect(); - + var httpContext = new DefaultHttpContext(); httpContext.Request.Method = "GET"; httpContext.Request.Headers[HeaderNames.Cookie] = cookies; @@ -1260,7 +1260,7 @@ public static IEnumerable ResponseMultiHeadersData() { foreach (var header in ResponseMultiHeaderNames()) { - foreach (var version in new[] { "1.1", "2.0" }) + foreach (var version in new[] { "1.1", "2.0" }) { foreach (var value in MultiValues()) { @@ -2567,8 +2567,9 @@ public async Task Response_RemoveProhibitedHeaders(string protocol, string prohi await sut.SendAsync(httpContext, destinationPrefix, client, new ForwarderRequestConfig { Version = Version.Parse(protocol) }); + string[] headers = httpContext.Response.Headers[PreservedHeaderName]; Assert.Equal((int)HttpStatusCode.OK, httpContext.Response.StatusCode); - Assert.Equal(PreservedHeaderValue, string.Join(", ", httpContext.Response.Headers[PreservedHeaderName])); + Assert.Equal(PreservedHeaderValue, string.Join(", ", headers)); foreach (var (name, _) in prohibitedHeaders) { From e56abc3e05b0efb8f0aa410655cd24c8a59f26a2 Mon Sep 17 00:00:00 2001 From: Aron Tsang Date: Sun, 6 Apr 2025 03:25:30 +0800 Subject: [PATCH 5/5] Switch to using record struct for better cache locality --- .../Certificates/ImmutableCertificateCache.cs | 37 ++++++------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/src/Kubernetes.Controller/Certificates/ImmutableCertificateCache.cs b/src/Kubernetes.Controller/Certificates/ImmutableCertificateCache.cs index 52999c894..1b3fdd3d4 100644 --- a/src/Kubernetes.Controller/Certificates/ImmutableCertificateCache.cs +++ b/src/Kubernetes.Controller/Certificates/ImmutableCertificateCache.cs @@ -4,8 +4,7 @@ using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; -using System.Globalization; -using System.Linq; + #nullable enable namespace Yarp.Kubernetes.Controller.Certificates; @@ -56,26 +55,17 @@ public ImmutableCertificateCache(IEnumerable certificates, Func Certificates => _certificates; - protected record WildCardDomain(string Domain, TCert? Certificate); + protected record struct WildCardDomain(string Domain, TCert? Certificate); private bool TryGetCertificateExact(string domain, [NotNullWhen(true)] out TCert? certificate) => _certificates.TryGetValue(domain, out certificate); private bool TryGetWildcardCertificate(string domain, [NotNullWhen(true)] out TCert? certificate) { - if (_wildCardDomains.BinarySearch(new WildCardDomain(domain, null!), DomainNameComparer.Instance) is { } index) + if (_wildCardDomains.BinarySearch(new WildCardDomain(domain, null!), DomainNameComparer.Instance) is { } index and > -1) { - if (index > -1) - { - certificate = _wildCardDomains[index].Certificate!; - return true; - } - // var candidate = _wildCardDomains[~index]; - // if (domain.EndsWith(candidate.Domain, true, CultureInfo.InvariantCulture)) - // { - // certificate = candidate.Certificate!; - // return true; - // } + certificate = _wildCardDomains[index].Certificate!; + return true; } certificate = null; @@ -91,23 +81,20 @@ private class DomainNameComparer : IComparer { public static readonly DomainNameComparer Instance = new(); - public int Compare(WildCardDomain? x, WildCardDomain? y) + public int Compare(WildCardDomain x, WildCardDomain y) { - var ret = Compare(x!.Domain.AsSpan(), y!.Domain.AsSpan()); + var ret = Compare(x.Domain.AsSpan(), y.Domain.AsSpan()); if (ret != 0) { return ret; } - switch (x!.Certificate, y!.Certificate) + return (x.Certificate, y.Certificate) switch { - case (null, {}) when x.Domain.Length > y.Domain.Length: - return 0; - case ({}, null) when x.Domain.Length < y.Domain.Length: - return 0; - default: - return x.Domain.Length - y.Domain.Length; - } + (null, not null) when x.Domain.Length > y.Domain.Length => 0, + (not null, null) when x.Domain.Length < y.Domain.Length => 0, + _ => x.Domain.Length - y.Domain.Length + }; } private static int Compare(ReadOnlySpan x, ReadOnlySpan y)